In [3]:
import os
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
os.environ["HF_HOME"] = "D:/huggingface"

if not os.path.exists("D:/huggingface"):
    os.makedirs("D:/huggingface")

In [4]:
import json
from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import clip
from transformers import CLIPProcessor, CLIPModel

In [5]:
print(torch.__version__)
torch.cuda.is_available()

2.2.2+cu118


True

In [9]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("validmodel/indo-fashion-dataset")

print("Path to dataset files:", path)

Downloading from https://www.kaggle.com/api/v1/datasets/download/validmodel/indo-fashion-dataset?dataset_version_number=15...


100%|██████████| 2.69G/2.69G [02:34<00:00, 18.7MB/s]

Extracting files...





Path to dataset files: C:\Users\25472\.cache\kagglehub\datasets\validmodel\indo-fashion-dataset\versions\15


In [6]:
import os
base_dir = "./indo-fashion-dataset/versions/15/"
json_path = os.path.join(base_dir, "train_data.json")

with open(json_path, 'r') as f:
    input_data = []
    for line in f:
        obj = json.loads(line)
        input_data.append(obj)

In [41]:
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)



In [31]:
class CustomDataset:
    def __init__(self, data, image_dir, processor):
        # Filter out the problem data
        problem_ids = ["55852.jpeg", "68567.jpeg", "78873.png", "71885.png"]
        self.data = [item for item in data if not any(pid in item["image_path"] for pid in problem_ids)]
        self.image_dir = image_dir
        self.processor = processor
        
    def __len__(self):
        return len(self.data)
        
    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, self.data[idx]["image_path"])
        try:
            image_path = os.path.join(self.image_dir, self.data[idx]["image_path"])
            image = Image.open(image_path).convert("RGB")
            
            processed_image = self.processor(images=image, return_tensors="pt")["pixel_values"][0]
            processed_text = self.processor.tokenizer(
                text=self.data[idx]["product_title"],
                padding="max_length",
                max_length=77,
                truncation=True,
                return_tensors="pt"
            )["input_ids"][0]
            
            return processed_image, processed_text
            
        except Exception as e:
            print(f"Error processing image {image_path}: {e}")
            return self.__getitem__((idx + 1) % len(self.data))
        
        #return processed["pixel_values"][0], processed["input_ids"][0]

In [32]:
base_dir = "./indo-fashion-dataset/versions/15/"
json_path = os.path.join(base_dir, "train_data.json")
image_dir = os.path.join(base_dir, "")

dataset = CustomDataset(input_data, image_dir, processor)

In [33]:
import os
base_dir = "./indo-fashion-dataset/versions/15/"
test_path = os.path.join(base_dir, "test_data.json")

with open(test_path, 'r') as f:
    test_data = []
    for line in f:
        obj = json.loads(line)
        test_data.append(obj)

val_path = os.path.join(base_dir, "val_data.json")

with open(test_path, 'r') as f:
    val_data = []
    for line in f:
        obj = json.loads(line)
        val_data.append(obj)

In [34]:
print(len(dataset))
print(len(test_data))
print(len(val_data))

91162
7500
7500


In [35]:
for idx in range(5):
    sample = dataset[idx]
    print(f"Test {idx}:", type(sample))

Test 0: <class 'tuple'>
Test 1: <class 'tuple'>
Test 2: <class 'tuple'>
Test 3: <class 'tuple'>
Test 4: <class 'tuple'>


In [15]:
# Test the dataloader
try:
    print("Testing first sample loading...")
    single_sample = dataset[0]
    print("Single sample loaded successfully")
    print("Image shape:", single_sample[0].shape)
    print("Text shape:", single_sample[1].shape)
except Exception as e:
    print("Error loading single sample:", e)

try:
    print("\nTesting batch loading...")
    batch = next(iter(dataloader))
    print("Batch loaded successfully")
except Exception as e:
    print("Error loading batch:", e)

Testing first sample loading...
Single sample loaded successfully
Image shape: torch.Size([3, 224, 224])
Text shape: torch.Size([77])

Testing batch loading...
Batch loaded successfully


In [36]:
# dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=0)
from sklearn.model_selection import train_test_split

# train_loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0)
train_loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0)
val_dataset = CustomDataset(val_data, image_dir, processor)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=0)

test_dataset = CustomDataset(test_data, image_dir, processor)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

In [15]:
# 测试 DataLoader
for images, texts in train_loader:
    print("Image batch shape:", images.shape)
    print("Text batch shape:", texts.shape)
    break

Image batch shape: torch.Size([32, 3, 224, 224])
Text batch shape: torch.Size([32, 77])


In [16]:
from tqdm import tqdm

In [17]:
def convert_models_to_fp32(model):
    for param in model.parameters():
        param.data = param.data.float()
        if param.grad is not None:
            param.grad.data = param.grad.data.float()

In [57]:
class EarlyStopping:
   def __init__(self, patience=5, min_delta=0, verbose=True):
       self.patience = patience
       self.min_delta = min_delta
       self.counter = 0
       self.best_loss = None
       self.early_stop = False
       self.verbose = verbose

   def __call__(self, val_loss):
       if self.best_loss is None:
           self.best_loss = val_loss
       elif val_loss > self.best_loss - self.min_delta:
           self.counter += 1
           if self.verbose:
               print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
           if self.counter >= self.patience:
               self.early_stop = True
       else:
           self.best_loss = val_loss
           self.counter = 0

In [43]:
num_epochs = 20
# optimizer = torch.optim.Adam(model.parameters(), lr=5e-5,betas=(0.9,0.98),eps=1e-6,weight_decay=0.2)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, betas=(0.9, 0.98), eps=1e-6, weight_decay=0.2)

loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=1000,  
    gamma=0.75 
)

early_stopping = EarlyStopping(patience=5)

from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter('runs/clip_training')

for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    num_batches = 0
    pbar = tqdm(train_loader,total=len(train_loader))

    for batch in pbar:
        optimizer.zero_grad()

        images, texts = batch
        images = images.to(device)
        texts = texts.to(device)
        
        outputs = model(input_ids=texts, pixel_values=images)
        # ground_truth = torch.arange(len(images), device=device)
        ground_truth = torch.arange(len(images), dtype=torch.long, device=device)
        batch_loss = (loss_img(outputs.logits_per_image, ground_truth) + 
                     loss_txt(outputs.logits_per_text, ground_truth))/2
        
        batch_loss.backward()

        # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        # total_loss.backward()
        if device == "cpu":
            optimizer.step()
            scheduler.step()
        else:
            convert_models_to_fp32(model)
            optimizer.step()
            scheduler.step()
        
        # current_lr = scheduler.get_last_lr()[0]
        train_loss += batch_loss.item()
        num_batches += 1
        
        avg_loss = train_loss / num_batches        
        pbar.set_description(f"Epoch {epoch+1}/{num_epochs}, Avg Loss: {avg_loss:.4f}")
   
    model.eval()
    current_lr = scheduler.get_last_lr()[0]
    val_loss = 0
    with torch.no_grad():
        for batch in val_loader:
            images, texts = [x.to(device) for x in batch]
            outputs = model(input_ids=texts, pixel_values=images)
            ground_truth = torch.arange(len(images), device=device)
            loss = (loss_img(outputs.logits_per_image, ground_truth) + 
                    loss_txt(outputs.logits_per_text, ground_truth))/2

            val_loss += loss.item()

    val_avg_loss = val_loss / len(val_loader)
    print(f"Val Loss: {val_avg_loss:.4f}, Current LR: {current_lr:.6f}")

    early_stopping(val_avg_loss)
    if early_stopping.early_stop:
        print("Early stopping triggered")
        break

    writer.add_scalars('Loss', {
        'train': train_loss/len(train_loader),
        'val': val_avg_loss,
        'Learning Rate': current_lr
    }, epoch)

writer.close()

total_loss = avg_loss
save_path = "./checkpoints/clip_checkpoint.pt"
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': total_loss,
}, save_path)

Epoch 1/20, Avg Loss: 0.9022: 100%|██████████| 2849/2849 [25:36<00:00,  1.85it/s]


Val Loss: 1.8269, Current LR: 0.000028


Epoch 2/20, Avg Loss: 0.4580: 100%|██████████| 2849/2849 [25:02<00:00,  1.90it/s]


Val Loss: 1.5725, Current LR: 0.000012


Epoch 3/20, Avg Loss: 0.2898: 100%|██████████| 2849/2849 [24:22<00:00,  1.95it/s]


Val Loss: 1.4427, Current LR: 0.000005


Epoch 4/20, Avg Loss: 0.2038: 100%|██████████| 2849/2849 [24:30<00:00,  1.94it/s]


Val Loss: 1.4406, Current LR: 0.000002


Epoch 5/20, Avg Loss: 0.1644: 100%|██████████| 2849/2849 [24:26<00:00,  1.94it/s]


Val Loss: 1.4868, Current LR: 0.000001


Epoch 6/20, Avg Loss: 0.1671:   1%|          | 29/2849 [00:15<24:58,  1.88it/s]


KeyboardInterrupt: 

In [45]:
test_dataset = CustomDataset(test_data, image_dir, processor)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

In [46]:
def evaluate_model(model, test_loader, device):
    """
    Calculate image-text and text-image retrieval accuracy（R@1, R@5, R@10）
    """
    model.eval()
    all_image_features = []
    all_text_features = []
    
    with torch.no_grad():
        for images, texts in tqdm(test_loader, desc="Extracting features"):
            images = images.to(device)
            texts = texts.to(device)
            
            outputs = model(input_ids=texts, pixel_values=images)
            image_features = outputs.image_embeds
            text_features = outputs.text_embeds
            
            all_image_features.append(image_features)
            all_text_features.append(text_features)
    
    image_features = torch.cat(all_image_features)
    text_features = torch.cat(all_text_features)
    
    similarity = (image_features @ text_features.T)
    
    metrics = {}
    for k in [1, 5, 10]:
        i2t_recall = compute_recall_at_k(similarity, k)
        t2i_recall = compute_recall_at_k(similarity.T, k)
        metrics[f'I2T_R@{k}'] = i2t_recall
        metrics[f'T2I_R@{k}'] = t2i_recall
    
    return metrics

def compute_recall_at_k(similarity, k):
    batch_size = similarity.shape[0]
    indices = torch.topk(similarity, k, dim=1)[1]
    targets = torch.arange(batch_size).view(-1, 1).to(similarity.device)
    correct = (indices == targets)
    return (correct.sum(1) > 0).float().mean().item()

In [47]:
metrics = evaluate_model(model, test_loader, device)
print("Test Results:")
for k, v in metrics.items():
   print(f"{k}: {v:.4f}")

Extracting features: 100%|██████████| 469/469 [00:45<00:00, 10.28it/s]

Test Results:
I2T_R@1: 0.1652
T2I_R@1: 0.1548
I2T_R@5: 0.4007
T2I_R@5: 0.3861
I2T_R@10: 0.5171
T2I_R@10: 0.5160





In [50]:
def freeze_backbone(model, freeze_text_encoder=False):
    """
    :param freeze_text_encoder: whether to freeze the text encoder
    """
    if freeze_text_encoder:
        # Freeze the first two layers of the text encoder
        for layer in model.text_model.encoder.layers[:2]:
            for param in layer.parameters():
                param.requires_grad = False


    # Freeze the first eight layers of the vision encoder
    for name, param in model.vision_model.named_parameters():
        if "encoder.layers" in name:  
            layer_idx = int(name.split(".")[2])  
            if layer_idx < 8:  
                param.requires_grad = False
        else:
            continue

In [60]:
model_freeze = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
# processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_freeze = model_freeze.to(device)

freeze_backbone(model_freeze, freeze_text_encoder=True)

for name, param in model_freeze.named_parameters():
    print(f"{name}: {'Frozen' if not param.requires_grad else 'Trainable'}")

logit_scale: Trainable
text_model.embeddings.token_embedding.weight: Trainable
text_model.embeddings.position_embedding.weight: Trainable
text_model.encoder.layers.0.self_attn.k_proj.weight: Frozen
text_model.encoder.layers.0.self_attn.k_proj.bias: Frozen
text_model.encoder.layers.0.self_attn.v_proj.weight: Frozen
text_model.encoder.layers.0.self_attn.v_proj.bias: Frozen
text_model.encoder.layers.0.self_attn.q_proj.weight: Frozen
text_model.encoder.layers.0.self_attn.q_proj.bias: Frozen
text_model.encoder.layers.0.self_attn.out_proj.weight: Frozen
text_model.encoder.layers.0.self_attn.out_proj.bias: Frozen
text_model.encoder.layers.0.layer_norm1.weight: Frozen
text_model.encoder.layers.0.layer_norm1.bias: Frozen
text_model.encoder.layers.0.mlp.fc1.weight: Frozen
text_model.encoder.layers.0.mlp.fc1.bias: Frozen
text_model.encoder.layers.0.mlp.fc2.weight: Frozen
text_model.encoder.layers.0.mlp.fc2.bias: Frozen
text_model.encoder.layers.0.layer_norm2.weight: Frozen
text_model.encoder.laye

In [61]:
num_epochs = 15
# optimizer = torch.optim.Adam(model_freeze.parameters(), lr=5e-5,betas=(0.9,0.98),eps=1e-6,weight_decay=0.2)
optimizer = torch.optim.AdamW(model_freeze.parameters(), lr=5e-5, betas=(0.9, 0.98), eps=1e-6, weight_decay=0.25)

loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=10000, 
    gamma=0.75  
)

early_stopping = EarlyStopping(patience=5)

from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter('runs/clip_training')

for epoch in range(num_epochs):
    model_freeze.train()
    train_loss = 0
    num_batches = 0
    pbar = tqdm(train_loader,total=len(train_loader))

    for batch in pbar:
        optimizer.zero_grad()

        images, texts = batch
        images = images.to(device)
        texts = texts.to(device)
        
        outputs = model_freeze(input_ids=texts, pixel_values=images)
        # ground_truth = torch.arange(len(images), device=device)
        ground_truth = torch.arange(len(images), dtype=torch.long, device=device)
        batch_loss = (loss_img(outputs.logits_per_image, ground_truth) + 
                     loss_txt(outputs.logits_per_text, ground_truth))/2
        
        batch_loss.backward()

        # torch.nn.utils.clip_grad_norm_(model_freeze.parameters(), max_norm=1.0)

        # total_loss.backward()
        if device == "cpu":
            optimizer.step()
            scheduler.step()
        else:
            convert_models_to_fp32(model_freeze)
            optimizer.step()
            scheduler.step()
        
        # current_lr = scheduler.get_last_lr()[0]
        train_loss += batch_loss.item()
        num_batches += 1
        
        avg_loss = train_loss / num_batches        
        pbar.set_description(f"Epoch {epoch+1}/{num_epochs}, Avg Loss: {avg_loss:.4f}")
   
    model_freeze.eval()
    current_lr = scheduler.get_last_lr()[0]
    val_loss = 0
    with torch.no_grad():
        for batch in val_loader:
            images, texts = [x.to(device) for x in batch]
            outputs = model_freeze(input_ids=texts, pixel_values=images)
            ground_truth = torch.arange(len(images), device=device)
            loss = (loss_img(outputs.logits_per_image, ground_truth) + 
                    loss_txt(outputs.logits_per_text, ground_truth))/2

            val_loss += loss.item()

    val_avg_loss = val_loss / len(val_loader)
    print(f"Val Loss: {val_avg_loss:.4f}, Current LR: {current_lr:.6f}")

    early_stopping(val_avg_loss)
    if early_stopping.early_stop:
        print("Early stopping triggered")
        break

    writer.add_scalars('Loss', {
        'train': train_loss/len(train_loader),
        'val': val_avg_loss,
        'Learning Rate': current_lr
    }, epoch)

writer.close()

total_loss = avg_loss
save_path = "./checkpoints/clip_checkpoint_MixFreeze.pt"
torch.save({
    'epoch': epoch,
    'model_state_dict': model_freeze.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': total_loss,
}, save_path)

Epoch 1/15, Avg Loss: 0.8894: 100%|██████████| 2849/2849 [21:30<00:00,  2.21it/s]


Val Loss: 2.0474, Current LR: 0.000050


Epoch 2/15, Avg Loss: 0.6199: 100%|██████████| 2849/2849 [22:19<00:00,  2.13it/s]


Val Loss: 1.8842, Current LR: 0.000050


Epoch 3/15, Avg Loss: 0.5218: 100%|██████████| 2849/2849 [21:20<00:00,  2.23it/s]


Val Loss: 1.8388, Current LR: 0.000050


Epoch 4/15, Avg Loss: 0.4364: 100%|██████████| 2849/2849 [21:27<00:00,  2.21it/s]


Val Loss: 1.6039, Current LR: 0.000038


Epoch 5/15, Avg Loss: 0.3538: 100%|██████████| 2849/2849 [21:22<00:00,  2.22it/s]


Val Loss: 1.6639, Current LR: 0.000038
EarlyStopping counter: 1 out of 5


Epoch 6/15, Avg Loss: 0.3252: 100%|██████████| 2849/2849 [21:18<00:00,  2.23it/s]


Val Loss: 1.6248, Current LR: 0.000038
EarlyStopping counter: 2 out of 5


Epoch 7/15, Avg Loss: 0.2974: 100%|██████████| 2849/2849 [21:22<00:00,  2.22it/s]


Val Loss: 1.6318, Current LR: 0.000038
EarlyStopping counter: 3 out of 5


Epoch 8/15, Avg Loss: 0.2463: 100%|██████████| 2849/2849 [21:24<00:00,  2.22it/s]


Val Loss: 1.5952, Current LR: 0.000028


Epoch 9/15, Avg Loss: 0.2241: 100%|██████████| 2849/2849 [21:25<00:00,  2.22it/s]


Val Loss: 1.6175, Current LR: 0.000028
EarlyStopping counter: 1 out of 5


Epoch 10/15, Avg Loss: 0.2106: 100%|██████████| 2849/2849 [21:23<00:00,  2.22it/s]


Val Loss: 1.6358, Current LR: 0.000028
EarlyStopping counter: 2 out of 5


Epoch 11/15, Avg Loss: 0.1946: 100%|██████████| 2849/2849 [21:26<00:00,  2.22it/s]


Val Loss: 1.5876, Current LR: 0.000021


Epoch 12/15, Avg Loss: 0.1647: 100%|██████████| 2849/2849 [21:23<00:00,  2.22it/s]


Val Loss: 1.6168, Current LR: 0.000021
EarlyStopping counter: 1 out of 5


Epoch 13/15, Avg Loss: 0.1575: 100%|██████████| 2849/2849 [21:30<00:00,  2.21it/s]


Val Loss: 1.5993, Current LR: 0.000021
EarlyStopping counter: 2 out of 5


Epoch 14/15, Avg Loss: 0.1546: 100%|██████████| 2849/2849 [21:27<00:00,  2.21it/s]


Val Loss: 1.6163, Current LR: 0.000021
EarlyStopping counter: 3 out of 5


Epoch 15/15, Avg Loss: 0.1338: 100%|██████████| 2849/2849 [21:25<00:00,  2.22it/s]


Val Loss: 1.5699, Current LR: 0.000016


In [62]:
metrics = evaluate_model(model_freeze, test_loader, device)
print("Test Results:")
for k, v in metrics.items():
   print(f"{k}: {v:.4f}")

Extracting features: 100%|██████████| 469/469 [00:35<00:00, 13.18it/s]


Test Results:
I2T_R@1: 0.1659
T2I_R@1: 0.1599
I2T_R@5: 0.3944
T2I_R@5: 0.3887
I2T_R@10: 0.5117
T2I_R@10: 0.5053
