In [1]:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import CLIPModel, CLIPTokenizer, CLIPProcessor
import torch.nn.functional as F
from itertools import combinations
from tqdm import tqdm

In [2]:
try:

    root = "/content/drive/MyDrive/MASTER_THESIS/"

    from google.colab import drive
    drive.mount('/content/drive')

    df = pd.read_csv(root + "full_dataset_moondream_captions.csv")
    N = len(df)
    BATCH_SIZE = 256
    model_path = root + "MODELS/"

    print("Running on Google Colab")

except:
    print("Running on local machine")

    df = pd.read_csv('../../../private_data/CSV/full_dataset_moondream_captions.csv')
    N = 500
    BATCH_SIZE = 8
    model_path = "../../../private_data/MODELS/"
#

df["EN"] = df["EN"].apply(lambda x: x.split("..")[0])
df["FR"] = df["FR"].apply(lambda x: x.split("..")[0])
df["NL"] = df["NL"].apply(lambda x: x.split("..")[0])   

df["EN"] = df["EN"].apply(lambda x: x.split(". .")[0])
df["FR"] = df["FR"].apply(lambda x: x.split(". .")[0])
df["NL"] = df["NL"].apply(lambda x: x.split(". .")[0])   

df = df.iloc[:N]
mean_length_EN = df["EN"].apply(lambda x: len(x.split())).mean()
mean_length_FR = df["FR"].apply(lambda x: len(x.split())).mean()
print(f"Mean length of EN captions: {mean_length_EN}")
print(f"Mean length of FR captions: {mean_length_FR}")
df

Running on local machine
Mean length of EN captions: 38.63
Mean length of FR captions: 38.618


Unnamed: 0,recordID,task,EN,FR,NL
0,64,caption,A religious scene features a central figure o...,Une scène religieuse présente une figure centr...,Een religieuze scène toont een centrale figuur...
1,64,What objects do you see ?,"In the image, there are two people on a cross...","À l'image, il y a deux personnes sur une croix...",In het beeld zijn er twee mensen aan het kruis...
2,64,What colors do you see ?,The image features a painting with a predomin...,L'image présente une peinture avec un schéma d...,De afbeelding is voorzien van een schilderij m...
3,64,Is this image bright or dark ?,The image is dark.,L'image est sombre.,Het beeld is donker.
4,64,What emotion do you feel when looking at this ...,"When looking at this image, I feel a sense of...","En regardant cette image, je ressens un sentim...","Als ik naar dit beeld kijk, voel ik een gevoel..."
...,...,...,...,...,...
495,348,caption,A nude woman in a red dress and a nude man in...,Une femme nue dans une robe rouge et un homme ...,Een naakte vrouw in een rode jurk en een naakt...
496,348,What objects do you see ?,"In the image, there are two men depicted in a...","Dans l'image, il y a deux hommes représentés d...",In het beeld zijn er twee mannen afgebeeld in ...
497,348,What colors do you see ?,The image features a painting with a man and ...,L'image comporte une peinture avec un homme et...,De afbeelding is voorzien van een schilderij m...
498,348,Is this image bright or dark ?,The image is bright.,L'image est lumineuse.,Het beeld is helder.


In [3]:
base_model_base = "openai/clip-vit-large-patch14"
processor = CLIPProcessor.from_pretrained(base_model_base)
model = CLIPModel.from_pretrained(base_model_base)
tokenizer = CLIPTokenizer.from_pretrained(base_model_base)
model_weights_path = model_path + "art-base.pt"
model.load_state_dict(torch.load(model_weights_path, weights_only=True))

<All keys matched successfully>

In [4]:
def getNumberOfParamsWithGrad(model):
    num_params = 0
    for name, param in model.named_parameters():
        if param.requires_grad:
            num_params += param.numel()
    return num_params

nb_params = getNumberOfParamsWithGrad(model)
print(f"Number of parameters with grad: {nb_params}")

# Freeze vision encoder and other non-text layers
for name, param in model.named_parameters():
    if ("text_model" not in name) and ("logit_scale" not in name):
        param.requires_grad = False

nb_params = getNumberOfParamsWithGrad(model)
print(f"Number of parameters with grad after freezing: {nb_params}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model.to(device)
model.train()

device

Number of parameters with grad: 427616513
Number of parameters with grad after freezing: 123060481


device(type='cuda')

In [5]:
optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=5e-6)

In [6]:
def compute_embeddings(dataloader, silent=False):
    embeddings = []
    with torch.no_grad():
        if silent:
            iterator = dataloader
        else:
            iterator = tqdm(dataloader, desc="Computing embeddings")

        for batch in iterator:
            inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True).to(device)
            outputs = model.get_text_features(**inputs)
            # Normalize the embeddings
            outputs = F.normalize(outputs, p=2, dim=-1)
            embeddings.append(outputs.cpu())
    return torch.cat(embeddings)

In [7]:
# Compute the embeddings of the anchors
class TextDataset(Dataset):
    def __init__(self, df, column):
        self.texts = df[column].tolist()

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return self.texts[idx]

dataset_anchors = TextDataset(df, 'EN')
dataloader_anchors = DataLoader(dataset_anchors, batch_size=BATCH_SIZE, shuffle=True)
len(dataset_anchors), len(dataloader_anchors)

A = compute_embeddings(dataloader_anchors)
A.shape

class BoatsDataset(Dataset):
    def __init__(self, anchors_embeddings, df, columns):
        self.boats = [df[col].tolist() for col in columns]
        self.size = len(self.boats[0])
        self.anchors_embeddings = anchors_embeddings

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return {
            "anchors_embeddings": self.anchors_embeddings[idx],
            "boats": [boat[idx] for boat in self.boats],
        }

dataset = BoatsDataset(A, df, ['FR']) # , 'NL'
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

for batch in dataloader:
    anchors_embeddings = batch["anchors_embeddings"]
    boats = batch["boats"]
    print(f"Anchors embeddings shape: {anchors_embeddings.shape}")
    print(f"Boats shape: {len(boats)}")
    boat_example = boats[0]
    print(f"Boat example shape: {len(boat_example)}")
    break

len(dataset), len(dataloader)

Computing embeddings: 100%|██████████| 63/63 [00:01<00:00, 41.96it/s]

Anchors embeddings shape: torch.Size([8, 768])
Boats shape: 1
Boat example shape: 8





(500, 63)

In [8]:
def perfAnalysis():
    # Compute the mean cosine similarity between the anchor and boat embeddings
    def compute_mean_cosine_similarity(anchors_embeddings, boats):
        boat_embeddings = compute_embeddings(boats, silent=True)
        similarities = []
        for anchor_embedding in anchors_embeddings:
            similarity = F.cosine_similarity(anchor_embedding, boat_embeddings)
            similarities.append(similarity.mean().item())
        return sum(similarities) / len(similarities)
    
    # Compute the mean cosine similarity for each batch
    mean_similarities = []
    for batch in dataloader:
        anchors_embeddings = batch["anchors_embeddings"]
        boats = batch["boats"]
        mean_similarity = compute_mean_cosine_similarity(anchors_embeddings, boats)
        mean_similarities.append(mean_similarity)

    # Compute the overall mean cosine similarity
    overall_mean_similarity = sum(mean_similarities) / len(mean_similarities)

    print(f"Overall Mean Cosine Similarity: {overall_mean_similarity:.4f}") 

In [9]:
def anchor_cosine_loss(anchor_embeddings, text_features):
    loss = 1 - F.cosine_similarity(anchor_embeddings, text_features, dim=1)
    return loss.mean()

def anchor_cosine_margin_loss(anchor_embeddings, text_features, margin=0.3):
    cosine_sim = F.cosine_similarity(anchor_embeddings, text_features, dim=1)
    # Margin pushes similarity to be > margin
    loss = F.relu(margin - cosine_sim)
    return loss.mean()

N = 5

model.eval()
print("Computing initial performance analysis...")
perfAnalysis()

model.train()

# Training loop
for epoch in range(N):
    total_loss = 0
    tqdm_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{N}", total=len(dataloader))
    i = 0
    for batch in tqdm_bar:
        # batch: list of 32 elements, each is a list of 3 strings
        # Flatten: ["sent1_en", "sent1_fr", "sent1_nl", "sent2_en", ...]
        others_texts = batch["boats"]
        anchors_embeddings = batch["anchors_embeddings"]
    
        # Flatten others_texts
        # Clone anchors_embeddings for each boat
        anchors_embeddings_repeated = []
        for i in range(len(others_texts)):
            anchors_embeddings_repeated.append(anchors_embeddings)
        anchors_embeddings_repeated = torch.cat(anchors_embeddings_repeated, dim=0)

        boats = [text for sublist in others_texts for text in sublist]

        anchors_embeddings = anchors_embeddings_repeated.to(device)

        inputs = tokenizer(boats, padding=True, return_tensors="pt", truncation=True).to(device)
        text_features = model.get_text_features(**inputs)
        text_features = F.normalize(text_features, dim=1)

        # Compute cosine similarity loss for each group
        batch_loss = anchor_cosine_margin_loss(anchors_embeddings, text_features)
        batch_loss /= len(boats)

        batch_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        total_loss += batch_loss.item()
        i += 1 
        tqdm_bar.set_postfix(loss=total_loss)

    model.eval()
    perfAnalysis()
    model.train()

    # Print average loss for the epoch
    tqdm_bar.close()

    print(f"Epoch {epoch+1} - Avg Loss: {total_loss / len(dataloader):.4f}")

Computing initial performance analysis...
Overall Mean Cosine Similarity: 0.2044


Epoch 1/5: 100%|██████████| 63/63 [00:05<00:00, 11.59it/s, loss=0.208]


Overall Mean Cosine Similarity: 0.3930
Epoch 1 - Avg Loss: 0.0033


Epoch 2/5: 100%|██████████| 63/63 [00:05<00:00, 11.24it/s, loss=0.0067] 


Overall Mean Cosine Similarity: 0.4020
Epoch 2 - Avg Loss: 0.0001


Epoch 3/5: 100%|██████████| 63/63 [00:05<00:00, 11.48it/s, loss=0.00076] 


Overall Mean Cosine Similarity: 0.3973
Epoch 3 - Avg Loss: 0.0000


Epoch 4/5: 100%|██████████| 63/63 [00:05<00:00, 11.94it/s, loss=0.000575]


Overall Mean Cosine Similarity: 0.4019
Epoch 4 - Avg Loss: 0.0000


Epoch 5/5: 100%|██████████| 63/63 [00:05<00:00, 11.57it/s, loss=0.00105] 


Overall Mean Cosine Similarity: 0.4021
Epoch 5 - Avg Loss: 0.0000


In [None]:
# Save the model weights
model_weights_path = model_path + "art-base-TextFT.pt"
torch.save(model.state_dict(), model_weights_path)
#tokenizer.save_pretrained(model_weights_path)

: 