In [1]:
import os
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
os.environ["HF_HOME"] = "D:/huggingface"

if not os.path.exists("D:/huggingface"):
    os.makedirs("D:/huggingface")

import json
from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import clip
from transformers import CLIPProcessor, CLIPModel

In [4]:
import os
base_dir = "./indo-fashion-dataset/versions/15/"
json_path = os.path.join(base_dir, "train_data.json")

with open(json_path, 'r') as f:
    input_data = []
    for line in f:
        obj = json.loads(line)
        input_data.append(obj)
        
test_path = os.path.join(base_dir, "test_data.json")

with open(test_path, 'r') as f:
    test_data = []
    for line in f:
        obj = json.loads(line)
        test_data.append(obj)

val_path = os.path.join(base_dir, "val_data.json")

with open(test_path, 'r') as f:
    val_data = []
    for line in f:
        obj = json.loads(line)
        val_data.append(obj)

In [3]:
class CustomDataset:
    def __init__(self, data, image_dir, processor):
        # Filter out the problem data
        problem_ids = ["55852.jpeg", "68567.jpeg", "78873.png", "71885.png"]
        self.data = [item for item in data if not any(pid in item["image_path"] for pid in problem_ids)]
        self.image_dir = image_dir
        self.processor = processor
        
    def __len__(self):
        return len(self.data)
        
    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, self.data[idx]["image_path"])

        try:
            image_path = os.path.join(self.image_dir, self.data[idx]["image_path"])
            image = Image.open(image_path).convert("RGB")
            
            processed_image = self.processor(images=image, return_tensors="pt")["pixel_values"][0]
            processed_text = self.processor.tokenizer(
                text=self.data[idx]["product_title"],
                padding="max_length",
                max_length=77,
                truncation=True,
                return_tensors="pt"
            )["input_ids"][0]
            
            return processed_image, processed_text
            
        except Exception as e:
            print(f"Error processing image {image_path}: {e}")
            return self.__getitem__((idx + 1) % len(self.data))

In [6]:
base_dir = "./indo-fashion-dataset/versions/15/"
json_path = os.path.join(base_dir, "train_data.json")
image_dir = os.path.join(base_dir, "")

processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

dataset = CustomDataset(input_data, image_dir, processor)
train_loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0)
val_dataset = CustomDataset(val_data, image_dir, processor)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=0)

test_dataset = CustomDataset(test_data, image_dir, processor)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)



In [7]:
from tqdm import tqdm

def convert_models_to_fp32(model):
    for param in model.parameters():
        param.data = param.data.float()
        if param.grad is not None:
            param.grad.data = param.grad.data.float()

In [8]:
def freeze_vision(model, freeze_text_encoder=False, freeze_vision_layers = 6):
    if freeze_text_encoder:
        for layer in model.text_model.encoder.layers[:2]:
            for param in layer.parameters():
                param.requires_grad = False


    for name, param in model.vision_model.named_parameters():
        if "encoder.layers" in name:  
            layer_idx = int(name.split(".")[2])  
            if layer_idx < freeze_vision_layers:  
                param.requires_grad = False
        else:
            continue

In [9]:
model_freeze_vision = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
# processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_freeze_vision = model_freeze_vision.to(device)

freeze_vision(model_freeze_vision, freeze_text_encoder=False, freeze_vision_layers = 6)

for name, param in model_freeze_vision.named_parameters():
    print(f"{name}: {'Frozen' if not param.requires_grad else 'Trainable'}")

logit_scale: Trainable
text_model.embeddings.token_embedding.weight: Trainable
text_model.embeddings.position_embedding.weight: Trainable
text_model.encoder.layers.0.self_attn.k_proj.weight: Trainable
text_model.encoder.layers.0.self_attn.k_proj.bias: Trainable
text_model.encoder.layers.0.self_attn.v_proj.weight: Trainable
text_model.encoder.layers.0.self_attn.v_proj.bias: Trainable
text_model.encoder.layers.0.self_attn.q_proj.weight: Trainable
text_model.encoder.layers.0.self_attn.q_proj.bias: Trainable
text_model.encoder.layers.0.self_attn.out_proj.weight: Trainable
text_model.encoder.layers.0.self_attn.out_proj.bias: Trainable
text_model.encoder.layers.0.layer_norm1.weight: Trainable
text_model.encoder.layers.0.layer_norm1.bias: Trainable
text_model.encoder.layers.0.mlp.fc1.weight: Trainable
text_model.encoder.layers.0.mlp.fc1.bias: Trainable
text_model.encoder.layers.0.mlp.fc2.weight: Trainable
text_model.encoder.layers.0.mlp.fc2.bias: Trainable
text_model.encoder.layers.0.layer_no

In [11]:
from earlystopping import EarlyStopping
from evaluate import evaluate_model

In [None]:
num_epochs = 15
# optimizer = torch.optim.Adam(model_freeze.parameters(), lr=5e-5,betas=(0.9,0.98),eps=1e-6,weight_decay=0.2)
optimizer = torch.optim.AdamW(model_freeze_vision.parameters(), lr=5e-5, betas=(0.9, 0.98), eps=1e-6, weight_decay=0.1)

loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()
# scheduler = torch.optim.lr_scheduler.StepLR(
#     optimizer,
#     step_size=7500,  
#     gamma=0.75  
# )

total_steps = len(train_loader) * num_epochs
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=5e-5,
    total_steps=total_steps,
    pct_start=0.1,  
    final_div_factor=10  
)

early_stopping = EarlyStopping(patience=5)

from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter('runs/clip_training')

for epoch in range(num_epochs):
    model_freeze_vision.train()
    train_loss = 0
    num_batches = 0
    pbar = tqdm(train_loader,total=len(train_loader))

    for batch in pbar:
        optimizer.zero_grad()

        images, texts = batch
        images = images.to(device)
        texts = texts.to(device)
        
        outputs = model_freeze_vision(input_ids=texts, pixel_values=images)
        # ground_truth = torch.arange(len(images), device=device)
        ground_truth = torch.arange(len(images), dtype=torch.long, device=device)
        batch_loss = (loss_img(outputs.logits_per_image, ground_truth) + 
                     loss_txt(outputs.logits_per_text, ground_truth))/2
        
        batch_loss.backward()

        if device == "cpu":
            optimizer.step()
            scheduler.step()
        else:
            convert_models_to_fp32(model_freeze_vision)
            optimizer.step()
            scheduler.step()
        
        # current_lr = scheduler.get_last_lr()[0]
        train_loss += batch_loss.item()
        num_batches += 1
        
        avg_loss = train_loss / num_batches        
        pbar.set_description(f"Epoch {epoch+1}/{num_epochs}, Avg Loss: {avg_loss:.4f}")
   
    model_freeze_vision.eval()
    current_lr = scheduler.get_last_lr()[0]
    val_loss = 0
    with torch.no_grad():
        for batch in val_loader:
            images, texts = [x.to(device) for x in batch]
            outputs = model_freeze_vision(input_ids=texts, pixel_values=images)
            ground_truth = torch.arange(len(images), device=device)
            loss = (loss_img(outputs.logits_per_image, ground_truth) + 
                    loss_txt(outputs.logits_per_text, ground_truth))/2

            val_loss += loss.item()

    val_avg_loss = val_loss / len(val_loader)
    print(f"Val Loss: {val_avg_loss:.4f}, Current LR: {current_lr:.6f}")

    early_stopping(val_avg_loss)
    if early_stopping.early_stop:
        print("Early stopping triggered")
        break

    writer.add_scalars('Loss', {
        'train': train_loss/len(train_loader),
        'val': val_avg_loss,
        'Learning Rate': current_lr
    }, epoch)

writer.close()

total_loss = avg_loss
save_path = "./checkpoints/clip_checkpoint_FreezeVision.pt"
torch.save({
    'epoch': epoch,
    'model_state_dict': model_freeze_vision.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': total_loss,
}, save_path)

  attn_output = torch.nn.functional.scaled_dot_product_attention(
Epoch 1/15, Avg Loss: 0.7448: 100%|██████████| 2849/2849 [22:16<00:00,  2.13it/s]


Val Loss: 2.0478, Current LR: 0.000038


Epoch 2/15, Avg Loss: 0.6993: 100%|██████████| 2849/2849 [21:57<00:00,  2.16it/s]


Val Loss: 1.8497, Current LR: 0.000050


Epoch 3/15, Avg Loss: 0.5586: 100%|██████████| 2849/2849 [22:00<00:00,  2.16it/s]


Val Loss: 1.7303, Current LR: 0.000048


Epoch 4/15, Avg Loss: 0.4737: 100%|██████████| 2849/2849 [24:00<00:00,  1.98it/s]


Val Loss: 1.6637, Current LR: 0.000046


Epoch 5/15, Avg Loss: 0.4024: 100%|██████████| 2849/2849 [22:03<00:00,  2.15it/s]


Val Loss: 1.6780, Current LR: 0.000042
EarlyStopping counter: 1 out of 5


Epoch 6/15, Avg Loss: 0.3456: 100%|██████████| 2849/2849 [22:08<00:00,  2.14it/s]


Val Loss: 1.6189, Current LR: 0.000038


Epoch 7/15, Avg Loss: 0.2914: 100%|██████████| 2849/2849 [22:12<00:00,  2.14it/s]


Val Loss: 1.5547, Current LR: 0.000032


Epoch 8/15, Avg Loss: 0.2482: 100%|██████████| 2849/2849 [22:11<00:00,  2.14it/s]


Val Loss: 1.4881, Current LR: 0.000027


Epoch 9/15, Avg Loss: 0.2042: 100%|██████████| 2849/2849 [22:09<00:00,  2.14it/s]


Val Loss: 1.4947, Current LR: 0.000021
EarlyStopping counter: 1 out of 5


Epoch 10/15, Avg Loss: 0.1712: 100%|██████████| 2849/2849 [22:10<00:00,  2.14it/s]


Val Loss: 1.4738, Current LR: 0.000015


Epoch 11/15, Avg Loss: 0.1437: 100%|██████████| 2849/2849 [22:10<00:00,  2.14it/s]


Val Loss: 1.5159, Current LR: 0.000010
EarlyStopping counter: 1 out of 5


Epoch 12/15, Avg Loss: 0.1205: 100%|██████████| 2849/2849 [22:23<00:00,  2.12it/s]


Val Loss: 1.5045, Current LR: 0.000006
EarlyStopping counter: 2 out of 5


Epoch 13/15, Avg Loss: 0.1062: 100%|██████████| 2849/2849 [22:42<00:00,  2.09it/s]


Val Loss: 1.5283, Current LR: 0.000003
EarlyStopping counter: 3 out of 5


Epoch 14/15, Avg Loss: 0.0964: 100%|██████████| 2849/2849 [22:09<00:00,  2.14it/s]


Val Loss: 1.5348, Current LR: 0.000001
EarlyStopping counter: 4 out of 5


Epoch 15/15, Avg Loss: 0.0945: 100%|██████████| 2849/2849 [22:06<00:00,  2.15it/s]


Val Loss: 1.5372, Current LR: 0.000000
EarlyStopping counter: 5 out of 5
Early stopping triggered


In [13]:
# freeze the first six layers of the vision encoder

metrics = evaluate_model(model_freeze_vision, test_loader, device)
print("Test Results:")
for k, v in metrics.items():
   print(f"{k}: {v:.4f}")

Extracting features: 100%|██████████| 469/469 [00:33<00:00, 14.09it/s]


Test Results:
I2T_R@1: 0.1924
T2I_R@1: 0.1873
I2T_R@5: 0.4407
T2I_R@5: 0.4397
I2T_R@10: 0.5607
T2I_R@10: 0.5613


In [None]:
# zero-shot prediction

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
model = model.to(device)
metrics = evaluate_model(model, test_loader, device)
print("Test Results:")
for k, v in metrics.items():
   print(f"{k}: {v:.4f}")

Extracting features: 100%|██████████| 469/469 [00:33<00:00, 13.89it/s]


Test Results:
I2T_R@1: 0.0351
T2I_R@1: 0.0220
I2T_R@5: 0.1085
T2I_R@5: 0.0707
I2T_R@10: 0.1557
T2I_R@10: 0.1117


In [17]:
# Finetune all the layers

unfreeze_path = "./checkpoints/clip_checkpoint_unfreeze.pt"
unfreeze_checkpoint = torch.load(unfreeze_path)
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
model.load_state_dict(unfreeze_checkpoint['model_state_dict'])
model = model.to(device)

metrics = evaluate_model(model, test_loader, device)
print("Test Results:")
for k, v in metrics.items():
   print(f"{k}: {v:.4f}")

Extracting features: 100%|██████████| 469/469 [00:33<00:00, 14.04it/s]


Test Results:
I2T_R@1: 0.1652
T2I_R@1: 0.1548
I2T_R@5: 0.4007
T2I_R@5: 0.3861
I2T_R@10: 0.5171
T2I_R@10: 0.5160


In [18]:
# Freeze the backbone of both vision and text encoders

Mixfreeze_path = "./checkpoints/clip_checkpoint_MixFreeze.pt"
Mixfreeze_checkpoint = torch.load(Mixfreeze_path)
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
model.load_state_dict(Mixfreeze_checkpoint['model_state_dict'])
model = model.to(device)

metrics = evaluate_model(model, test_loader, device)
print("Test Results:")
for k, v in metrics.items():
   print(f"{k}: {v:.4f}")

Extracting features: 100%|██████████| 469/469 [00:34<00:00, 13.55it/s]

Test Results:
I2T_R@1: 0.1659
T2I_R@1: 0.1599
I2T_R@5: 0.3944
T2I_R@5: 0.3887
I2T_R@10: 0.5117
T2I_R@10: 0.5053





In [38]:
def V2T_retrieval(model, processor, image_path, text_queries, k=5):
    """Retrieval text from image"""
    model.eval()

    image = Image.open(image_path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    image_features = model.get_image_features(inputs['pixel_values'].to(device))

    all_text_features = []
    for text_query in text_queries:
        inputs = processor.tokenizer(text=text_query, return_tensors="pt", truncation=True)
        # text_features = model.get_text_features(inputs['pixel_values'].to(device))
        text_features = model.get_text_features(**inputs.to(device))
        all_text_features.append(text_features)

    all_text_features = torch.cat(all_text_features)
    similarities = image_features @ all_text_features.T
    top_indices = similarities.topk(k).indices

    return [text_queries[idx] for idx in top_indices[0]]

In [43]:
# text_query = "Women's Digital Cotton Linen Blend Saree with Unstitched Blouse Piece(DigiPatta)"
# Women's Khadi Cotton Saree With Blouse Piece (UFO301119SOLD_KHDI_1_Grey)
text_queries = [
    "Women's Georgette Saree with Blouse Piece",
    "Women's Khadi Cotton Saree With Blouse Piece",
    "Georgette Strip Print Saree",
    "A photo of a blue saree.",
    "X SUNEET VARMA Women's Georgette and Pattern Mesh Ruffle Saree & Solid Blouse",
    "Women's Digital Cotton Linen Blend Saree with Unstitched Blouse Piece(DigiPatta)",
    "A photo of a blouse saree.",
    "Ilkal Resham Cotton Traditional Satin Border Tope Teni Pallu Saree",
    "Women's Saree with Solid Piece",
    "Women's Patola Style Art Silk Saree"
]
image_path = "./indo-fashion-dataset/versions/15/images/val/0.jpeg"

V2T_retrieval(model_freeze_vision, processor, image_path, text_queries, k=5)

["Women's Khadi Cotton Saree With Blouse Piece",
 "Women's Saree with Solid Piece",
 'A photo of a blouse saree.',
 'Ilkal Resham Cotton Traditional Satin Border Tope Teni Pallu Saree',
 "X SUNEET VARMA Women's Georgette and Pattern Mesh Ruffle Saree & Solid Blouse"]

In [50]:
def T2V_retrieval(model, processor, text_query, image_folder,k=10):
    """Retrieval image from text"""
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    
    inputs = processor.tokenizer(text=text_query, return_tensors="pt", truncation=True)
    text_features = model.get_text_features(**inputs.to(device))

    all_image_features = []
    all_image_paths = []
    for img_path in tqdm(glob.glob(f"{image_folder}/*.jpg")):
        try:
            image = Image.open(img_path).convert("RGB")
            inputs = processor(images=image, return_tensors="pt")
            image_features = model.get_image_features(inputs['pixel_values'].to(device))
            all_image_features.append(image_features)
            all_image_paths.append(img_path)
        except Exception as e:
            print(f"Error processing image {img_path}: {e}")
            continue

    all_image_features = torch.cat(all_image_features)
    similarities = text_features @ all_image_features.T
    top_indices = similarities.topk(k).indices

    return [all_image_paths[idx] for idx in top_indices[0]]


In [None]:
image_folder = "./indo-fashion-dataset/versions/15/images/val"
text_query = "Women's Khadi Cotton Saree With Blouse Piece"

T2V_retrieval(model_freeze_vision, processor, text_query, image_folder, k=10)