In [1]:
import os
import re
from concurrent.futures import ThreadPoolExecutor

import pandas as pd
from PIL import Image

import torch
import torch.nn.parallel
import torch.utils.data
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, random_split
from peft import LoraConfig, get_peft_model
from transformers import TrainingArguments, Trainer
from transformers import CLIPProcessor, CLIPModel, EvalPrediction

  from .autonotebook import tqdm as notebook_tqdm


### Load dataset

In [2]:
# Raw GLAMI-1M training table; the preview below shows one row per item with its
# item_id, image_id, geo, name, description and category columns.
train_csv_path = './dataset/GLAMI-1M-train-dataset/GLAMI-1M-train.csv'
df = pd.read_csv(train_csv_path)

In [3]:
# Preview the first rows of the raw training table
df.head()

Unnamed: 0,item_id,image_id,geo,name,description,category,category_name,label_source
0,0,0,cz,Casio Collection MTP-1259D-7BEF,,645,mens-watches,custom-tag
1,3,3,cz,Casio Collection MTP-1303D-1AVEF,,645,mens-watches,custom-tag
2,5,5,cz,Hot Diamonds Náušnice Eternity Interlocking St...,,468,womens-earrings,custom-tag
3,7,7,cz,Shepherd Papuče ANTON,Shepherd Papuče ANTON Béžová. K dispozici v pá...,40328,mens-boots,custom-tag
4,8,8,cz,Vestis Dámský župan Milano 3130 - dle obrázku - S,Dámský župan Milano Vestis s proužky. Župan s ...,437,womens-bathrobes,custom-tag


In [4]:
# (pattern, replacement) pairs applied in order after hyphens become spaces.
# The `\b` anchors restrict every rewrite to whole words. The previous plain
# `str.replace` calls also rewrote substrings: "sandals" -> "sorals",
# "handbags" -> "hordbags", "men shoes" -> "men'shoes".
_CATEGORY_WORD_REPLACEMENTS = (
    (r"\bwomen ?s\b", "women's"),   # "womens" / "women s"
    (r"\bmen ?s\b", "men's"),       # "mens" / "men s" (no match inside "women's")
    (r"\bboy ?s\b", "boy's"),
    (r"\bgirl ?s\b", "girl's"),
    (r"\band\b", "or"),
    (r"\bt shirts\b", "t-shirts"),  # restore the hyphen removed from "t-shirts"
)


def preprocess_category_name(category_name: str) -> str:
    """Turn a GLAMI category slug into caption-friendly text.

    Example: ``"womens-t-shirts"`` -> ``"women's t-shirts"``.

    Hyphens become spaces. Then the word-level rewrites in
    ``_CATEGORY_WORD_REPLACEMENTS`` are applied, matching whole words only.

    Args:
        category_name: hyphen-separated category slug.

    Returns:
        The normalised category text.
    """
    text = category_name.replace("-", " ")
    for pattern, replacement in _CATEGORY_WORD_REPLACEMENTS:
        text = re.sub(pattern, replacement, text)
    return text

def preprocess_category_names_in_batch(category_names):
    """Apply ``preprocess_category_name`` to every name, preserving input order.

    Args:
        category_names: any iterable of category-name strings (list, Series, array).

    Returns:
        A list with one processed name per input name.
    """
    # The per-name work is pure-Python string manipulation, which holds the GIL.
    # A thread pool therefore gives no parallel speed-up and only adds pool
    # start-up cost on every call; a plain comprehension builds the same list.
    return [preprocess_category_name(name) for name in category_names]

In [5]:
# Category slugs repeat across many rows (see `category_counts` below). Normalise
# each distinct slug once and map the result back onto every row, instead of
# processing all rows in fixed-size batches.
unique_category_names = df["category_name"].unique()
processed_name_by_slug = dict(
    zip(unique_category_names, preprocess_category_names_in_batch(unique_category_names))
)
df["processed_category_name"] = df["category_name"].map(processed_name_by_slug)

# Expose the row index as `id`. Then drop the raw category name and keep only the
# columns used by the rest of the notebook.
df = df.reset_index().rename(columns={'index': 'id'})
df = df.drop(columns=["category_name"])[["processed_category_name", "id", "image_id"]]

In [6]:
# Check the processed table: the caption-ready category name plus the `id` and `image_id` used later
df.head()

Unnamed: 0,processed_category_name,id,image_id
0,men's watches,0,0
1,men's watches,1,3
2,women's earrings,2,5
3,men's boots,3,7
4,women's bathrobes,4,8


In [7]:
# One row per processed category name, with the number of items in that
# category (`count`) and the list of their `id`s (`ids`).
category_counts = (
    df.groupby("processed_category_name")
    .agg(
        count=("id", "size"),
        ids=("id", lambda group: list(group)),
    )
    .reset_index()
)

In [8]:
# Show item counts per processed category and the `id`s belonging to each
category_counts

Unnamed: 0,processed_category_name,count,ids
0,baby accessories,1200,"[6059, 8460, 8461, 8811, 10357, 12939, 12940, ..."
1,baby clothing,27896,"[14957, 14958, 14963, 14964, 14977, 14984, 149..."
2,baby shoes,14577,"[37415, 37942, 38129, 38884, 38941, 40618, 406..."
3,bathroom,2318,"[841, 842, 2138, 2329, 2330, 2331, 2332, 2335,..."
4,bed linen,5254,"[361, 362, 363, 388, 389, 676, 1582, 1738, 230..."
...,...,...,...
186,women's undershirts,773,"[4867, 11589, 12103, 12198, 12521, 37824, 3817..."
187,women's vests,3149,"[9589, 12081, 12883, 18566, 21051, 21508, 2175..."
188,women's wallets,5442,"[107, 109, 139, 345, 401, 412, 430, 435, 692, ..."
189,women's watches,3062,"[3691, 4070, 4156, 4160, 4163, 4166, 4167, 416..."


### Load original CLIP model

In [9]:
# Pretrained CLIP checkpoint (model + processor) from the Hugging Face Hub
model_id = "openai/clip-vit-base-patch32"

cuda_available = torch.cuda.is_available()
device = torch.device("cuda") if cuda_available else torch.device("cpu")
# `torch_dtype` is NOT passed to `from_pretrained` below. It only records whether
# fp16 mixed precision should be enabled when the training arguments are built.
torch_dtype = torch.float16 if cuda_available else torch.float32

# Pin the whole model to `device` through `device_map`. This avoids the
# `CUBLAS_STATUS_EXECUTION_FAILED` error seen on multi-GPU machines.
# See: https://forums.developer.nvidia.com/t/cublas-status-execution-failed/27370
model = CLIPModel.from_pretrained(model_id, device_map=device)
processor = CLIPProcessor.from_pretrained(model_id, clean_up_tokenization_spaces=True)

### Custom dataset class

In [10]:
class FashionDataset(Dataset):
    """Image/caption pairs for contrastive fine-tuning of CLIP.

    Each row of ``dataframe`` must provide ``processed_category_name`` (used to
    build the caption ``"a photo of the <category>"``) and ``image_id`` (the
    image is read from ``<image_dir>/<image_id>.jpg``). Rows are addressed by
    position (``iloc``).

    Args:
        dataframe: table with the columns above.
        image_dir: directory containing the ``.jpg`` files.
        processor: callable such as a ``CLIPProcessor``. It is called with
            ``text=``, ``images=``, ``return_tensors="pt"`` and ``padding=True``,
            and must return a mapping of tensors with a leading dimension of 1.
    """

    def __init__(self, dataframe, image_dir, processor):
        self.dataframe = dataframe
        self.image_dir = image_dir
        self.processor = processor

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Return the processor outputs for row ``idx`` with the leading batch dimension removed."""
        # Look the row up once instead of once per column.
        row = self.dataframe.iloc[idx]
        category = row['processed_category_name']
        image_path = os.path.join(self.image_dir, f"{row['image_id']}.jpg")

        # Use the file as a context manager so its handle is closed immediately
        # rather than whenever the garbage collector gets to it; across a large
        # image set, leaked handles can exhaust file descriptors. Convert to RGB
        # *before* resizing: PIL always uses nearest-neighbour resampling for
        # "P"/"1" mode images, so resizing first would degrade palette images.
        with Image.open(image_path) as img:
            image = img.convert("RGB").resize((224, 224))

        query = f'a photo of the {category}'

        # Process the text and image into tensors
        inputs = self.processor(text=query, images=image, return_tensors="pt", padding=True)

        # Drop the batch dimension of 1; the collate function adds a real batch dimension back.
        return {key: val.squeeze(0) for key, val in inputs.items()}

In [11]:
image_dir = './dataset/GLAMI-1M-train-dataset/images'
dataset = FashionDataset(df, image_dir, processor)

# Split the dataset 80/20. The seeded generator makes the split identical across
# kernel restarts. Otherwise re-running the notebook (or resuming from a saved
# checkpoint) would draw a new validation set that overlaps images the model has
# already been trained on.
split_seed = 42
train_size = int(0.8 * len(dataset))  # 80% for training
val_size = len(dataset) - train_size  # 20% for validation
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(split_seed),
)

### Configuration for Parameter-Efficient Fine-Tuning (PEFT)

In [12]:
# Low-Rank Adaptation (LoRA) settings for parameter-efficient fine-tuning.
# Adapters are attached to modules named k_proj / v_proj / q_proj / out_proj
# (presumably the attention projections of both CLIP encoders; confirm against model.named_modules()).
peft_config = LoraConfig(
    base_model_name_or_path=model_id,
    target_modules=["k_proj", "v_proj", "q_proj", "out_proj"],
    r=16,             # adapter rank
    lora_alpha=16,    # scaling numerator; PEFT scales adapter output by lora_alpha / r (= 1 here)
    lora_dropout=0.1,
)

In [13]:
def _count_parameters(module, trainable_only=False):
    """Total ``numel()`` over ``module``'s parameters; with ``trainable_only`` only those with ``requires_grad``."""
    return sum(
        param.numel()
        for param in module.parameters()
        if param.requires_grad or not trainable_only
    )

# Count BEFORE wrapping. Judging by the output below, get_peft_model leaves only
# the adapter weights trainable, so counting afterwards would describe the
# frozen model instead of the original one.
total_params = _count_parameters(model)
trainable_params = _count_parameters(model, trainable_only=True)

# Wrap the model with LoRA adapters and count what is still trainable
lora_model = get_peft_model(model, peft_config)
lora_trainable_params = _count_parameters(lora_model, trainable_only=True)

# Print the details
print(f"Total parameters in original model: {total_params}")
print(f"Trainable parameters in original model: {trainable_params}")
print(f"Trainable parameters after applying LoRA: {lora_trainable_params}")
print(f"Percentage of parameters being trained: {100 * lora_trainable_params / total_params:.2f}%")

Total parameters in original model: 151277313
Trainable parameters in original model: 151277313
Trainable parameters after applying LoRA: 1966080
Percentage of parameters being trained: 1.30%


### Configure training arguments

In [14]:
# Define the training arguments
use_cpu = device.type == "cpu"
use_fp16 = torch_dtype == torch.float16
print(f"Training on CPU: {use_cpu}")
print(f"Training with FP16: {use_fp16}")
training_args = TrainingArguments(
    output_dir="./results",             # Directory to save checkpoints and results
    per_device_train_batch_size=512,    # Batch size for training (adjust based on your GPU memory)
    per_device_eval_batch_size=512,     # Batch size for evaluation
    num_train_epochs=5,                 # Number of epochs to train for
    logging_dir="./logs",               # Directory for storing logs
    logging_steps=10,                   # Log after every 10 steps
    load_best_model_at_end=True,        # Reload the best checkpoint when training finishes
    metric_for_best_model="loss",       # "Best" = lowest eval loss (the Trainer's implicit default, made explicit);
                                        # evaluation must therefore report a loss
    save_total_limit=2,                 # Limit the number of saved checkpoints
    eval_strategy="epoch",              # Evaluate after every epoch
    save_strategy="epoch",              # Save checkpoint after every epoch
    learning_rate=5e-5,                 # Learning rate (adjust based on your model size and dataset)
    report_to="none",                   # Avoid reporting to external services like WandB, TensorBoard
    fp16=use_fp16,                      # Enable mixed-precision training (if your hardware supports it)
    gradient_accumulation_steps=2,      # Accumulate gradients over 2 steps before updating. Note: the contrastive
                                        # loss is still computed per 512-pair micro-batch, so accumulation does
                                        # not enlarge the pool of in-batch negatives
    remove_unused_columns=False,        # custom_collate_fn already yields only the processor outputs; skip the
                                        # Trainer's signature-based column filtering, which may not see through
                                        # the PEFT wrapper's generic forward() on some transformers versions
    use_cpu=use_cpu                     # Use CPU for training
)

Training on CPU: False
Training with FP16: True


In [15]:
def cosine_similarity(embeddings_1, embeddings_2):
    """Pairwise cosine similarity between two sets of embeddings.

    Args:
        embeddings_1: 2-D tensor with one embedding per row, shape (n, d).
        embeddings_2: 2-D tensor with one embedding per row, shape (m, d).

    Returns:
        (n, m) tensor whose ``[i, j]`` entry is the cosine of the angle between
        ``embeddings_1[i]`` and ``embeddings_2[j]``.
    """
    # Normalise here so the result really is a cosine similarity even for raw
    # embeddings. The previous version returned plain dot products. Already
    # normalised inputs are left unchanged up to rounding.
    unit_1 = torch.nn.functional.normalize(embeddings_1, dim=-1)
    unit_2 = torch.nn.functional.normalize(embeddings_2, dim=-1)
    return unit_1 @ unit_2.T

def compute_metrics(eval_pred: "EvalPrediction", chunk_size=512):
    """In-batch image<->text retrieval accuracy for CLIP-style embeddings.

    ``eval_pred.predictions`` must be a pair ``(image_embeddings, text_embeddings)``
    of equal length, where row ``i`` of each belongs to the same image/caption pair.
    The Trainer passes predictions as NumPy arrays, so both are converted to float32
    tensors and L2-normalised first.

    Pairs are scored in consecutive chunks of ``chunk_size``, mirroring the in-batch
    negatives of the training loss. This avoids one N x N similarity matrix, which
    becomes very large for a full validation split. 512 matches the per-device eval
    batch size in ``training_args``. Pass ``chunk_size=None`` (or 0) to score all
    pairs at once.

    A query counts as a hit when its true partner scores at least as high as every
    other candidate in its chunk. Items sharing an identical caption (same category)
    typically get identical text embeddings; this rule keeps the resulting ties from
    being counted as misses.

    Returns:
        dict with ``image_to_text_accuracy`` and ``text_to_image_accuracy`` in [0, 1]
        (0.0 for empty input).
    """
    image_embeddings, text_embeddings = eval_pred.predictions
    image_embeddings = torch.nn.functional.normalize(
        torch.as_tensor(image_embeddings, dtype=torch.float32), dim=-1
    )
    text_embeddings = torch.nn.functional.normalize(
        torch.as_tensor(text_embeddings, dtype=torch.float32), dim=-1
    )

    num_pairs = image_embeddings.shape[0]
    step = chunk_size if chunk_size else max(num_pairs, 1)

    image_to_text_hits = 0
    text_to_image_hits = 0
    for start in range(0, num_pairs, step):
        # Cosine similarity of unit vectors: rows are images, columns are texts.
        similarity = image_embeddings[start:start + step] @ text_embeddings[start:start + step].T
        true_scores = similarity.diagonal()
        image_to_text_hits += int((true_scores >= similarity.max(dim=1).values).sum())
        text_to_image_hits += int((true_scores >= similarity.max(dim=0).values).sum())

    denominator = max(num_pairs, 1)
    return {
        "image_to_text_accuracy": image_to_text_hits / denominator,
        "text_to_image_accuracy": text_to_image_hits / denominator,
    }

def compute_loss(model, inputs, return_outputs=False):
    """Symmetric contrastive (CLIP-style) loss for one batch of matching pairs.

    Runs ``model(**inputs)`` and averages two cross-entropies: image->text over
    ``logits_per_image`` and text->image over ``logits_per_text``. Pair ``i`` is
    the positive for row ``i``.

    Args:
        model: callable returning an object with ``logits_per_image`` and
            ``logits_per_text`` (square, batch x batch).
        inputs: keyword arguments forwarded to ``model``.
        return_outputs: when true, also return the raw model outputs.

    Returns:
        The scalar loss, or ``(loss, outputs)`` if ``return_outputs`` is set.
    """
    outputs = model(**inputs)
    per_image = outputs.logits_per_image
    per_text = outputs.logits_per_text

    # Matching pairs sit on the diagonal, so the target class for row i is i.
    targets = torch.arange(per_image.size(0), device=per_image.device)
    cross_entropy = torch.nn.CrossEntropyLoss()
    image_to_text = cross_entropy(per_image, targets)
    text_to_image = cross_entropy(per_text, targets)
    loss = (image_to_text + text_to_image) / 2

    if return_outputs:
        return loss, outputs
    return loss

def custom_collate_fn(batch):
    """Merge per-sample dicts of tensors into one batched dict.

    ``input_ids`` and ``attention_mask`` can differ in length between samples, so
    they are right-padded with 0 to the longest sequence in the batch. Every
    other key (e.g. ``pixel_values``, assumed to share one shape) is stacked
    along a new leading dimension. Keys follow the order of the first sample.
    """
    variable_length_keys = ("input_ids", "attention_mask")
    collated = {}
    for name in batch[0]:
        tensors = [sample[name] for sample in batch]
        if name in variable_length_keys:
            collated[name] = pad_sequence(tensors, batch_first=True, padding_value=0)
        else:
            collated[name] = torch.stack(tensors)
    return collated

### Training

In [16]:
class CustomTrainer(Trainer):
    """Trainer that optimises the module-level symmetric contrastive ``compute_loss``.

    The same loss is used for training and evaluation.
    """

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Delegate to the module-level contrastive ``compute_loss``.

        ``**kwargs`` absorbs extra arguments that newer ``transformers`` releases
        pass to ``Trainer.compute_loss`` (e.g. ``num_items_in_batch``). The
        contrastive loss does not use them, and without this parameter those
        versions raise ``TypeError``.
        """
        return compute_loss(model, inputs, return_outputs)

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        """Evaluation step that computes and returns only the contrastive loss.

        NOTE(review): the batches carry no label tensors. As far as the Trainer
        source goes, the default implementation then skips computing a loss,
        which leaves no ``eval_loss`` for ``load_best_model_at_end``. It also
        accumulates every field of the model output (hidden states included)
        over the whole evaluation set. Returning only the loss keeps evaluation
        cheap and makes ``eval_loss`` available. Predictions and labels are
        ``None``, so the Trainer does not call ``compute_metrics``.
        """
        inputs = self._prepare_inputs(inputs)
        with torch.no_grad(), self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs)
        return loss.detach(), None, None

In [17]:
trainer = CustomTrainer(
    model=lora_model,                   # CLIP wrapped with LoRA adapters
    args=training_args,                 # Training arguments defined above
    train_dataset=train_dataset,        # 80% split
    eval_dataset=val_dataset,           # 20% split
    data_collator=custom_collate_fn,    # Pads token ids / attention masks per batch and stacks image tensors
    compute_metrics=compute_metrics,    # Image<->text retrieval accuracy
)

In [None]:
# Train the model. Evaluation and checkpointing happen once per epoch
# (`eval_strategy` / `save_strategy` in `training_args`).
trainer.train()