<a href="https://colab.research.google.com/github/ghostwalkin/Multimodal-finetuning-ImgText-pair/blob/main/Finetune_CLIP_Amazon_products_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install datasets sentence-transformers huggingface-hub --quiet

# Data Download

In [2]:
from datasets import load_dataset,Dataset

In [None]:
# Load the dataset identified by this repository id. Later cells index `ds` by
# the "train", "valid" and "test" splits and read the "anchor" and "positive"
# columns of each split.
ds=load_dataset("dparijat/amazon-image-title-triplet-250k-cleaned")

In [None]:
# Actually invoke the cleanup; the bare attribute reference only displayed the
# bound method and removed nothing.
ds.cleanup_cache_files()

# model download and freeze params


In [None]:
from sentence_transformers import SentenceTransformer

# Pretrained CLIP checkpoint to start from; `model` is the object that the
# following cells freeze, evaluate and fine-tune.
model_name="sentence-transformers/clip-ViT-L-14"
model=SentenceTransformer(model_name)

In [None]:
# Name fragments that mark a parameter as trainable; every parameter whose
# name contains none of them is frozen.
trainable_params = ["projection"]
for name, param in model.named_parameters():
    # requires_grad is True exactly when some fragment occurs in the name.
    param.requires_grad = any(fragment in name for fragment in trainable_params)

In [None]:
# Sanity check: list the parameters that are still trainable together with
# their shapes. The tensors themselves are not printed, because dumping whole
# weight matrices floods the cell output without adding information.
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, tuple(param.shape))

# evals

*   Recall@1: the fraction of images for which the embedding of the matching title is the single most similar text embedding (top-1 prediction) among all titles.



In [None]:
import torch
from torch import Tensor,nn
import torch.nn.functional as F

from sentence_transformers.evaluation import SentenceEvaluator


In [None]:
def calculate_recall_at_k(image_list:list, text_list: list, k=1):
    """Image-to-text Recall@k computed with the notebook-level ``model``.

    ``image_list[i]`` and ``text_list[i]`` are treated as a matching pair. Both
    lists are embedded with ``model.encode`` and, for every image, the texts
    are ranked by cosine similarity; the image counts as a hit when its own
    text is among the ``k`` highest-ranked ones.

    Args:
        image_list: Images to embed (query side).
        text_list: Texts to embed (candidate side); same length as ``image_list``.
        k: Number of top-ranked texts inspected per image. Values larger than
            the number of texts are clamped instead of raising.

    Returns:
        Fraction of images whose paired text is in the top ``k``, as a float in
        [0, 1]. Empty input returns ``0.0`` without touching the model.

    Raises:
        AssertionError: If the two lists differ in length.
    """
    assert len(image_list) == len(text_list)
    num_pairs = len(image_list)
    if num_pairs == 0:
        # Nothing to rank; avoids a division by zero below.
        return 0.0

    image_embeddings = model.encode(image_list, batch_size=32, show_progress_bar=True, convert_to_tensor=True)
    text_embeddings = model.encode(text_list, batch_size=32, show_progress_bar=True, convert_to_tensor=True)

    # With unit-length rows, cosine similarity is a plain matrix product.
    image_embeddings = F.normalize(image_embeddings, dim=1)
    text_embeddings = F.normalize(text_embeddings, dim=1)

    top_k = min(k, num_pairs)  # torch.topk raises if k exceeds the candidate count
    rows_per_chunk = 256  # bounds the similarity matrix to rows_per_chunk x num_pairs
    correct = 0
    for start in range(0, num_pairs, rows_per_chunk):
        similarity = image_embeddings[start:start + rows_per_chunk] @ text_embeddings.T
        top_indices = similarity.topk(top_k, dim=1).indices
        # Row j of this chunk belongs to pair index start + j.
        expected = torch.arange(start, start + similarity.shape[0], device=top_indices.device).unsqueeze(1)
        correct += int((top_indices == expected).any(dim=1).sum())
    return correct / num_pairs

# Recall@1 of the current `model` on the train split: the "anchor" column is
# passed as the images and the "positive" column as the texts.
recall_at_1 = calculate_recall_at_k(ds["train"]["anchor"], ds["train"]["positive"], k=1)
print(f"Recall@1 for train: {recall_at_1}")

In [None]:
# Same Recall@1 measurement on the "valid" split.
recall_val=calculate_recall_at_k(ds["valid"]["anchor"], ds["valid"]["positive"], k=1)
print(f"Recall@1 for val: {recall_val}")

In [None]:
# Same Recall@1 measurement on the "test" split.
recall_test=calculate_recall_at_k(ds["test"]["anchor"], ds["test"]["positive"], k=1)
print(f"Recall@1 for test: {recall_test}")

In [None]:
from typing import List, Dict
from sentence_transformers.evaluation import SentenceEvaluator
class ImageTextRetrievalEvaluator(SentenceEvaluator):
    """Evaluator reporting image-to-text Recall@k.

    ``images[i]`` and ``texts[i]`` are treated as a matching pair. When called,
    the evaluator embeds both lists with the given model, ranks all texts for
    every image by cosine similarity and reports the fraction of images whose
    own text is among the ``k`` highest-ranked texts.
    """

    def __init__(
        self,
        images: List,
        texts: List[str],
        name: str = '',
        k: int = 1,
        batch_size: int = 32,
        show_progress_bar: bool = False
    ):
        """Store the evaluation data and settings.

        Args:
            images: Images to embed (query side).
            texts: Texts to embed (candidate side), index-aligned with ``images``.
            name: Prefix used in the returned metric key.
            k: Number of top-ranked texts inspected per image.
            batch_size: Batch size forwarded to ``model.encode``.
            show_progress_bar: Whether ``model.encode`` shows a progress bar.
        """
        self.images = images
        self.texts = texts
        self.name = name
        self.k = k
        self.batch_size = batch_size
        self.show_progress_bar = show_progress_bar

    def __call__(self,
        model: SentenceTransformer,
        output_path: str = None,
        epoch: int = -1,
        steps: int = -1) -> Dict[str, float]:
        """Compute Recall@k for ``model``.

        ``output_path``, ``epoch`` and ``steps`` are accepted for interface
        compatibility and are not used.

        Returns:
            A one-entry dict ``{"<name>_Recall@<k>": value}``. The value is
            ``0.0`` when there are no images to evaluate.
        """
        metric_key = f'{self.name}_Recall@{self.k}'
        if len(self.images) == 0:
            # Nothing to rank; avoids a division by zero below.
            return {metric_key: 0.0}

        # Get embeddings for all images in batches
        img_embeddings = model.encode(
            self.images,
            batch_size=self.batch_size,
            show_progress_bar=self.show_progress_bar,
            convert_to_tensor=True
        )
        # Get embeddings for all texts in batches
        txt_embeddings = model.encode(
            self.texts,
            batch_size=self.batch_size,
            show_progress_bar=self.show_progress_bar,
            convert_to_tensor=True
        )

        # With unit-length rows, cosine similarity is a plain matrix product.
        img_embeddings = F.normalize(img_embeddings, dim=1)
        txt_embeddings = F.normalize(txt_embeddings, dim=1)

        num_images = img_embeddings.shape[0]
        # Honour the configured k (previously hard-coded to 1 although the
        # metric key advertised Recall@k); clamp because torch.topk raises
        # when k exceeds the number of candidates.
        top_k = min(self.k, txt_embeddings.shape[0])
        rows_per_chunk = 256  # bounds the similarity matrix to rows_per_chunk x num_texts

        correct = 0
        for start in range(0, num_images, rows_per_chunk):
            similarity = img_embeddings[start:start + rows_per_chunk] @ txt_embeddings.T
            top_indices = similarity.topk(top_k, dim=1).indices
            # Row j of this chunk belongs to pair index start + j.
            expected = torch.arange(start, start + similarity.shape[0], device=top_indices.device).unsqueeze(1)
            correct += int((top_indices == expected).any(dim=1).sum())

        recall_at_k = correct / num_images

        return {metric_key: recall_at_k}



def create_recall_evaluator(set_name, k=1):
    """
        Build an ImageTextRetrievalEvaluator (Recall@k) for one dataset split.

        set_name selects the split of the notebook-level `ds` ("train",
        "valid", or "test"); its "anchor" column supplies the images and its
        "positive" column the texts. The evaluator is named
        "clip_score-<set_name>".
    """
    split = ds[f"{set_name}"]
    evaluator_kwargs = {
        "images": split["anchor"],
        "texts": split["positive"],
        "name": f"clip_score-{set_name}",
        "k": k,
    }
    return ImageTextRetrievalEvaluator(**evaluator_kwargs)

In [None]:
# One Recall@1 evaluator per monitored split, created in train -> valid order.
evaluator_recall_train, evaluator_recall_valid = (
    create_recall_evaluator(split_name, k=1) for split_name in ("train", "valid")
)


# training args and trainer

In [None]:
# Release cached CUDA memory left over from the evaluation cells before the
# training cells below are run.
torch.cuda.empty_cache()

In [None]:
from sentence_transformers.training_args import SentenceTransformerTrainingArguments
from sentence_transformers.trainer import SentenceTransformerTrainer
from sentence_transformers.losses import MultipleNegativesRankingLoss,CachedMultipleNegativesRankingLoss

In [None]:
#@title args

# Loss used for fine-tuning; the non-cached variant is kept commented out as
# an alternative.
#loss=MultipleNegativesRankingLoss(model)
loss=CachedMultipleNegativesRankingLoss(model)

# Hyperparameters fed into the training arguments below.
num_epochs = 10
batch_size = 32
lr = 1e-5
# Also determines the output directory: models/<finetuned_model_name>
finetuned_model_name = "clip-amazon-product-title-similarity-v4"

train_args = SentenceTransformerTrainingArguments(
    output_dir=f"models/{finetuned_model_name}",
    num_train_epochs=num_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    learning_rate=lr,
    fp16=True,
    gradient_checkpointing=True,
    gradient_accumulation_steps=2,
    torch_empty_cache_steps=4,

    seed=42,
    lr_scheduler_type="linear",


    # Evaluation settings
    # NOTE(review): eval_steps presumably only applies when
    # eval_strategy="steps"; with "epoch" it is likely ignored — confirm.
    # NOTE(review): logging_steps=1 logs on every step — confirm this
    # verbosity is intended for a full run.
    eval_strategy="epoch",
    eval_steps=1,
    logging_steps=1,
)

In [None]:
#@title trainer
import os
# Verbose TorchDynamo logging, apparently for debugging.
# NOTE(review): these variables are set after torch was imported in an earlier
# cell — confirm they still take effect, or remove them if no longer needed.
os.environ["TORCH_LOGS"] = "+dynamo"  # Enable dynamo logs
os.environ["TORCHDYNAMO_VERBOSE"] = "1"

# Wire everything together: the "train" split is the training set, the "valid"
# split is the eval dataset, and both recall evaluators are passed as a list.
trainer = SentenceTransformerTrainer(
    model=model,
    args=train_args,
    train_dataset=ds["train"],
    eval_dataset=ds["valid"],
    evaluator=[evaluator_recall_train, evaluator_recall_valid],
    loss=loss,
    callbacks=None
)

# Start fine-tuning.
trainer.train()

In [None]:
# Reload a saved checkpoint as a separate model for evaluation.
# NOTE(review): the absolute /content path and the checkpoint step (1870) are
# hard-coded; they must match what the training run actually wrote under its
# output_dir — confirm before running.
model2=SentenceTransformer("/content/models/clip-amazon-product-title-similarity-v4/checkpoint-1870")

# evaluating finetuned model

In [None]:
def calculate_recall_at_1_new_model(sentencemodel,dataset_split, k=1):
    """Image-to-text Recall@k of ``sentencemodel`` on one dataset split.

    The split's "anchor" column is embedded as the images and its "positive"
    column as the texts; row ``i`` of both columns is treated as a matching
    pair. For every image the texts are ranked by cosine similarity, and the
    image counts as a hit when its own text is among the ``k`` highest-ranked.

    Args:
        sentencemodel: Model exposing ``eval()`` and
            ``encode(items, batch_size=..., show_progress_bar=..., convert_to_tensor=...)``.
        dataset_split: Mapping-like object with "anchor" and "positive" columns
            of equal length.
        k: Number of top-ranked texts inspected per image. Values larger than
            the number of texts are clamped instead of raising.

    Returns:
        Fraction of images whose paired text is in the top ``k``, as a float in
        [0, 1]. An empty split returns ``0.0`` without encoding anything.

    Raises:
        AssertionError: If the two columns differ in length.
    """
    sentencemodel.eval()

    # Look each column up once instead of on every use; repeated column access
    # is redundant work and potentially expensive.
    images = dataset_split["anchor"]
    texts = dataset_split["positive"]
    assert len(images) == len(texts)

    num_pairs = len(texts)
    if num_pairs == 0:
        # Nothing to rank; avoids a division by zero below.
        return 0.0

    image_embeddings = sentencemodel.encode(images, batch_size=32, show_progress_bar=True, convert_to_tensor=True)
    text_embeddings = sentencemodel.encode(texts, batch_size=32, show_progress_bar=True, convert_to_tensor=True)

    # With unit-length rows, cosine similarity is a plain matrix product.
    image_embeddings = F.normalize(image_embeddings, dim=1)
    text_embeddings = F.normalize(text_embeddings, dim=1)

    top_k = min(k, num_pairs)  # torch.topk raises if k exceeds the candidate count
    rows_per_chunk = 256  # bounds the similarity matrix to rows_per_chunk x num_pairs
    correct = 0
    for start in range(0, num_pairs, rows_per_chunk):
        similarity = image_embeddings[start:start + rows_per_chunk] @ text_embeddings.T
        top_indices = similarity.topk(top_k, dim=1).indices
        # Row j of this chunk belongs to pair index start + j.
        expected = torch.arange(start, start + similarity.shape[0], device=top_indices.device).unsqueeze(1)
        correct += int((top_indices == expected).any(dim=1).sum())
    return correct / num_pairs

In [None]:
# Recall@1 of the reloaded checkpoint on the "test" split; the value is shown
# as the cell's output.
calculate_recall_at_1_new_model(model2,ds["test"],k=1)