In [None]:
import os
import pickle
from ast import literal_eval

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from dotenv import load_dotenv 
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import CLIPProcessor, CLIPModel

In [25]:
# loading variables from .env file
load_dotenv("../../private_data/.env") 

# PARENT gets us to the root of the project
PARENT = "./../../"


def _project_path(env_name):
    """Return PARENT joined with the relative path stored in the environment variable ``env_name``.

    Raises:
        KeyError: if the variable is not set. Without this check the string
            concatenation fails with a TypeError that does not say which
            variable is missing.
    """
    relative_path = os.getenv(env_name)
    if relative_path is None:
        raise KeyError(
            f"Environment variable {env_name!r} is not set; "
            "check that ../../private_data/.env exists and defines it"
        )
    return PARENT + relative_path


FOLDER_TABLE = _project_path("FOLDER_TABLE")
FILE_FABRITIUS_DATA = _project_path("FILE_FABRITIUS_DATA")
FILE_FABRITIUS_DATA_FILTERED = _project_path("FILE_FABRITIUS_DATA_FILTERED")
FILE_FABRITIUS_DATA_FILTERED_DOWNLOADED = _project_path("FILE_FABRITIUS_DATA_FILTERED_DOWNLOADED")
FOLDER_FIGURES = _project_path("FOLDER_FIGURES")
IMAGES_FOLDER = _project_path("IMAGES_FOLDER")

DB_INPUT_ARTPIECES = _project_path("DB_INPUT_ARTPIECES")
DB_INPUT_ARTISTS = _project_path("DB_INPUT_ARTISTS")

BENCHMARK_1 = _project_path("BENCHMARK_1")

FILE_SUBJECTMATTERS_PARSED = _project_path("FILE_SUBJECTMATTERS_PARSED")

# Use the first GPU when available, otherwise fall back to the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [26]:
# Get the artworks data
# Read the artworks table from the CSV at DB_INPUT_ARTPIECES; later cells read its
# "recordID" and "imageLowResFilename" columns to locate each artwork's image.
ARTWORKS = pd.read_csv(DB_INPUT_ARTPIECES)

In [27]:
# Methods to get an image from a recordID
def fixPath(path):
    """Return ``path`` with every ".././" occurrence collapsed into "../"."""
    return "../".join(path.split(".././"))

# Ordered (old, new) substitutions applied to every image path. The order matters
# because each replacement runs on the output of the previous one.
# NOTE(review): the substitutions run on the full path, IMAGES_FOLDER prefix included —
# confirm the prefix never contains one of these fragments.
_PATH_SUBSTITUTIONS = (
    ("internet", "Internet"),
    ("Mod", "mod"),
    ("MOD", "mod"),
    ("Old", "old"),
    ("Stefaan", "stefaan"),
    ("Art-Foto", "art-foto"),
    ("\\", "/"),
)

# Map every recordID to the fixed-up path of its low resolution image
recordID_to_imageLowResFilename = {}
for recordID, raw_filename in zip(ARTWORKS["recordID"], ARTWORKS["imageLowResFilename"]):
    # Drop the first character of the stored filename before prefixing the images folder
    path = fixPath(IMAGES_FOLDER + raw_filename[1:])
    for old, new in _PATH_SUBSTITUTIONS:
        path = path.replace(old, new)
    recordID_to_imageLowResFilename[recordID] = path

# Test it: report every mapped path that is missing on disk
for path in recordID_to_imageLowResFilename.values():
    if not os.path.exists(path):
        print(f"Path does not exist: {path}")

In [28]:
def v_literal_eval(val):
    """Parse ``val`` as a Python literal, returning NaN when it cannot be parsed.

    ``literal_eval`` failures signalled by ValueError or SyntaxError are mapped
    to ``np.nan``; any other exception propagates.
    """
    try:
        parsed = literal_eval(val)
    except (ValueError, SyntaxError):
        parsed = np.nan
    return parsed
# Load the benchmark; the three additional_info_* columns are parsed with v_literal_eval
DATA = pd.read_csv(
    BENCHMARK_1,
    converters={f"additional_info_{lang_code}": v_literal_eval for lang_code in ("fr", "en", "nl")},
)
unique_focus = DATA['focus'].unique()

# Draw one random row per focus to preview the data (unseeded: the rows shown vary between runs)
example_df = None
for focus in unique_focus:
    focus_data = DATA[DATA['focus'] == focus]
    sampled_row = focus_data.sample(1)
    if example_df is None:
        example_df = sampled_row
    else:
        example_df = pd.concat([example_df, sampled_row], ignore_index=True)

# The token-length columns are left out of the preview
example_df = example_df.drop(columns=['tokenized_length_fr', 'tokenized_length_en', 'tokenized_length_nl'])
example_df

Unnamed: 0,recordID,category,focus,caption_fr,caption_en,caption_nl,additional_info_fr,additional_info_en,additional_info_nl
0,8291,Tableau,colors,"Une vieille femme sans abri au coin d'une rue,...",An old homeless woman at the corner of a stree...,Een oude dakloze vrouw op de hoek van een stra...,[Couleurs sombres],[Dark colors],[Donkere kleuren]
1,4996,Tableau,content,Un portrait d'une femme en robe avec un chapea...,"A portrait of a woman in a dress with a hat, a...",Een portret van een vrouw in een jurk met een ...,,,
2,10795,Tableau,emotion,Ferme le long d'un ruisseau où des vaches boie...,"Closes along a stream where cows drink, trees,...","Sluit langs een beekje waar koeien drinken, bo...","[Émotion neutre, Émotion joie]","[Neutral emotion, Joy emotion]","[Emotie neutraal, Vreugde-emotie]"
3,1820,Tableau,luminosity,Un femme avec un voile et un baton prie en reg...,A woman with a veil and a baton prays looking ...,Een vrouw met een sluier en een stokje bidt na...,[Luminosité sombre],[Dark luminosity],[Donkere helderheid]


In [29]:
# Set TESTING to True to run the notebook on the first k recordIDs only
TESTING = False
if TESTING:
    k = 5
    unique_recordIDs = list(DATA['recordID'].unique())
    first_k_recordIDs = unique_recordIDs[:k]
    # Keep only the rows belonging to the first k recordIDs
    filtered_DATA = DATA.loc[DATA['recordID'].isin(first_k_recordIDs)]
    DATA = filtered_DATA
DATA

Unnamed: 0,recordID,category,focus,caption_fr,caption_en,caption_nl,additional_info_fr,additional_info_en,additional_info_nl,tokenized_length_fr,tokenized_length_en,tokenized_length_nl
0,10869,Dessin,colors,Un prêtre prononce un sermon les mains écartée...,A priest delivers a sermon with his hands apar...,Een priester geeft een preek met zijn handen u...,"[Bicolore, Couleur rouge]","[Bicolor, Red color]","[Bicolor, Rode kleur]",33,22,38
1,6146,Tableau,colors,"Un homme, contre une charette, une femme assis...","A man, against a charette, a woman sitting on ...","Een man, tegen een charette, een vrouw die op ...","[Couleurs vives, Couleur verte]","[Bright colors, Green color]","[Heldere kleuren, Groene kleur]",45,34,47
2,959,Tableau,colors,Un portrait d'une femme avec un habit bleu et ...,A portrait of a woman with a blue and white dr...,Een portret van een vrouw met een blauwe en wi...,[Couleur neutre],[Neutral color],[Neutrale kleur],40,31,47
3,579,Tableau,colors,Portrait d'une jeune femme avec des cheveux no...,"Portrait of a young woman with black hair, whi...","Portret van een jonge vrouw met zwart haar, wi...",[Couleur noir et blanc],[Black and white color],[Zwart-wit kleur],34,24,39
4,10351,Dessin,colors,Un dessin d'un personnage couché avec un habit...,A drawing of a character lying with a large dr...,Een tekening van een personage liggend met een...,[Couleur noir et blanc],[Black and white color],[Zwart-wit kleur],25,18,26
...,...,...,...,...,...,...,...,...,...,...,...,...
1811,1488,Tableau,luminosity,"Une place avec une foule vêtue de noir, une fe...","A place with a crowd dressed in black, a woman...","Een plek met een menigte gekleed in het zwart,...",[Luminosité lumineuse],[Bright luminosity],[Heldere helderheid],61,40,65
1812,1724,Tableau,luminosity,"Un foule qui fait la fête en extérieur, vêteme...","A crowd partying outside, colorful clothes, vi...","Een menigte feesten buiten, kleurrijke kleren,...",[Luminosité neutre],[Neutral luminosity],[Neutrale helderheid],38,17,32
1813,7059,Dessin,luminosity,Un texte entouré d'enfants anges avec des femm...,A text surrounded by children angels with nake...,Een tekst omringd door kinderen engelen met na...,[Luminosité lumineuse],[Bright luminosity],[Heldere helderheid],34,20,39
1814,8081,Tableau,luminosity,Un homme torse nu et un homme avec un habit bl...,A man with a naked torso and a man with a blue...,Een man met een naakte romp en een man met een...,[Luminosité neutre],[Neutral luminosity],[Neutrale helderheid],54,42,56


In [30]:
# The first preview row must carry its additional information as a parsed list
assert type(example_df["additional_info_en"].iloc[0]) is list

# Get the model

In [31]:
# Stock CLIP checkpoints, by size
basic_mini = "openai/clip-vit-base-patch32"
basic_base = "openai/clip-vit-large-patch14"
basic_large = "openai/clip-vit-large-patch14-336"

# Checkpoints the art-* models are loaded on top of (same identifiers as the stock ones)
base_model_mini = basic_mini # art-mini
base_model_base = basic_base # art-base
base_model_large = basic_large # art-large

# The February and March fine-tunes both sit on the art-base checkpoint
base_model_february_finetuned = basic_base # art-base
base_model_march_finetuned = basic_base # art-base

In [32]:
# Model to benchmark. Names already marked "OK" in earlier runs:
#   basic-large, art-large,
#   art-base, basic-base, february_finetuned, march_finetuned,
#   art-mini, basic-mini
model_name = "art-base-TextFT"

print(f"Running benchmark on: {model_name}")

Running benchmark on: art-base-TextFT


In [33]:
# Create folder to export the results
# exist_ok=True makes re-running this cell harmless when the folder already exists.
RESULT_FOLDER = "../benchmarks/benchmark_1"
os.makedirs(RESULT_FOLDER, exist_ok=True)

In [34]:
# Folder holding the fine-tuned weight files (*.pt). File names are appended by plain
# string concatenation, so the trailing slash is required.
root = "../../private_data/MODELS/"

In [None]:
# model_name -> (Hugging Face checkpoint, fine-tuned weights file under `root` or None, batch size)
# NOTE(review): the original cell annotated the batch sizes 8 / 2 / 1 with "# 256" / "# 32" / "# 16";
# presumably those are the sizes used on larger hardware — confirm.
MODEL_CONFIGS = {
    "february_finetuned": (base_model_base, "2025-02-05 17_09_07_allFocus_5.pt", 2),
    "march_finetuned": (base_model_base, "2025-03-29 16 59 53_allFocus_5.pt", 2),
    "art-mini": (base_model_mini, "art-mini.pt", 8),
    "art-base": (base_model_base, "art-base.pt", 2),
    "art-large": (base_model_large, "art-large.pt", 1),
    "basic-mini": (basic_mini, None, 8),
    "basic-base": (basic_base, None, 2),
    "basic-large": (basic_large, None, 1),
}

# Any other name is treated as a fine-tune of the base checkpoint stored as "<model_name>.pt"
checkpoint_name, weights_filename, BATCH_SIZE = MODEL_CONFIGS.get(
    model_name, (base_model_base, f"{model_name}.pt", 2)
)

processor = CLIPProcessor.from_pretrained(checkpoint_name)
model = CLIPModel.from_pretrained(checkpoint_name).to(device)

if weights_filename is not None:
    model_weights_path = root + weights_filename
    # map_location lets a checkpoint saved on GPU load when `device` is the CPU fallback
    model.load_state_dict(torch.load(model_weights_path, map_location=device, weights_only=True))

print(f"Running benchmark on: {model_name} with batch size: {BATCH_SIZE}")

Running benchmark on: art-base-TextFT with batch size: 2


# Configure the benchmark

In [36]:
# Number of DataLoader worker processes passed as num_workers to every loader below
best_workers = 0
# NOTE(review): criterion is not referenced by any cell visible in this notebook —
# confirm whether it is still needed.
criterion = nn.CrossEntropyLoss()

# Create the -PROMPT and -MIXED datasets

In [37]:
# Sorted unique recordIDs; this order defines the row order of every embedding matrix below
artworks_recordIDs = sorted(DATA['recordID'].unique())
print(len(artworks_recordIDs))

454


## Create a dataset to compute the embeddings of every image and every text

In [38]:
class ArtworksImages(Dataset):
    """Dataset that yields the image of each artwork, in the order of ``recordIDs``.

    Image paths are looked up in the module-level ``recordID_to_imageLowResFilename``
    mapping.
    """

    def __init__(self, recordIDs):
        self.recordIDs = recordIDs

    def __len__(self):
        return len(self.recordIDs)

    def __getitem__(self, idx):
        # Indices may arrive as tensors; convert them to plain Python values first
        position = idx.tolist() if torch.is_tensor(idx) else idx
        image_path = recordID_to_imageLowResFilename[self.recordIDs[position]]
        return Image.open(image_path)

def ArtworksImagesBBuilder(images):
    """Collate a list of images into processor inputs, pairing each image with an empty text."""
    placeholder_texts = [""] * len(images)
    return processor(
        text=placeholder_texts,
        images=images,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )

artworks_dataset = ArtworksImages(artworks_recordIDs)
# shuffle stays off: the resulting embeddings are matched to artworks_recordIDs by position
artworks_dataloader = DataLoader(
    dataset=artworks_dataset,
    collate_fn=ArtworksImagesBBuilder,
    batch_size=BATCH_SIZE,
    num_workers=best_workers,
    shuffle=False,
    pin_memory=True,
)

def compute_images_embeddings(dataloader, model, device):
    """Compute L2-normalised image embeddings for every batch of ``dataloader``.

    Args:
        dataloader: iterable of batches; each batch is indexed with ``'pixel_values'``.
            Any other key in the batch is ignored.
        model: object exposing ``eval()`` and ``get_image_features(pixel_values=...)``.
        device: device the pixel values are moved to before the forward pass.

    Returns:
        A numpy array with one row per image, in dataloader order.
    """
    model.eval()
    embeddings = []
    with torch.no_grad():
        for sample in tqdm(dataloader, desc="Computing artworks embeddings", unit="batch"):
            # Only the pixel values are needed for image features; the text fields
            # of the batch are neither read nor moved to the device.
            pixel_values = sample['pixel_values'].to(device)

            # Compute image embeddings
            image_features = model.get_image_features(pixel_values=pixel_values)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            image_features = image_features.flatten(1)

            # Move each batch to the CPU right away so results do not pile up on the device
            embeddings.append(image_features.cpu())

    return torch.cat(embeddings, dim=0).numpy()

artworks_embeddings_np_array = compute_images_embeddings(artworks_dataloader, model, device)
# Row i of the matrix belongs to the i-th recordID of artworks_recordIDs
artworks_embeddings = {
    recordID: artworks_embeddings_np_array[row_index]
    for row_index, recordID in enumerate(artworks_recordIDs)
}

Computing artworks embeddings: 100%|██████████| 227/227 [00:12<00:00, 18.15batch/s]


In [39]:
class PromptCaptions(Dataset):
    """Dataset of the -PROMPT captions for one language and one focus.

    For the "content" focus the caption is used as is. For every other focus the
    additional information items are appended to the caption, separated from it
    by a space and from each other by ", ".

    Args:
        recordIDs: recordIDs to build captions for; captions keep this order.
        df: benchmark frame with the columns "recordID", "focus",
            "caption_<lang>" and "additional_info_<lang>".
        lang: one of 'en', 'fr', 'nl'.
        focus: one of the values of ``df.focus``.

    Raises:
        IndexError: if a recordID has no row for ``focus``.
        TypeError: if the additional information of a non-content row cannot be joined.
    """

    def __init__(self, recordIDs, df, lang, focus):
        assert lang in ['en', 'fr', 'nl'], "Language must be one of ['en', 'fr', 'nl']"
        assert focus in df.focus.unique(), f"Focus must be one of {df.focus.unique()}"

        self.recordIDs = recordIDs
        self.captions = []

        # Build the recordID lookups once instead of filtering the frame per recordID.
        # keep='first' means the first matching row is used if a recordID appears twice.
        rows_with_focus = df[df['focus'] == focus].drop_duplicates(subset='recordID', keep='first')
        captions_by_recordID = dict(zip(rows_with_focus['recordID'], rows_with_focus[f"caption_{lang}"]))
        infos_by_recordID = dict(zip(rows_with_focus['recordID'], rows_with_focus[f"additional_info_{lang}"]))

        for recordID in recordIDs:
            if recordID not in captions_by_recordID:
                raise IndexError(f"No row found for recordID {recordID} with focus '{focus}'")

            caption = captions_by_recordID[recordID]

            if focus == "content":
                # additional_info is NaN, we just use the caption
                merged_caption = caption
            else:
                # additional_info is not NaN, we use the caption + additional_info
                additional_info = infos_by_recordID[recordID]
                try:
                    merged_caption = caption + " " + ", ".join(additional_info)
                except TypeError as err:
                    raise TypeError(
                        f"additional_info_{lang} of recordID {recordID} (focus '{focus}') "
                        f"cannot be merged with its caption: {additional_info!r}"
                    ) from err

            self.captions.append(merged_caption)

    def __len__(self):
        return len(self.recordIDs)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        return self.captions[idx]

def PromptCaptionsBBuilder(samples):
    """Collate a list of captions into padded, truncated processor inputs."""
    return processor(
        text=samples,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )

def compute_captions_embeddings(dataloader, model, device, tqdm_text):
    """Return the L2-normalised text embeddings of every batch as one numpy array.

    ``tqdm_text`` is appended to the progress bar description. Rows follow
    dataloader order.
    """
    model.eval()
    progress = tqdm(dataloader, desc=f"Computing captions embeddings: {tqdm_text}", unit="batch")
    normalised_batches = []
    with torch.no_grad():
        for batch in progress:
            token_ids = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)

            # Embed the captions, then scale every row to unit length
            features = model.get_text_features(input_ids=token_ids, attention_mask=mask)
            unit_features = features / features.norm(dim=-1, keepdim=True)
            normalised_batches.append(unit_features.flatten(1))

    return torch.cat(normalised_batches, dim=0).cpu().numpy()

# prompt_embeddings[lang][focus] -> caption embedding matrix, rows ordered like artworks_recordIDs
prompt_embeddings = {}
for lang in ['en', 'fr', 'nl']:
    embeddings_by_focus = prompt_embeddings.setdefault(lang, {})
    for focus in unique_focus:
        prompt_dataset = PromptCaptions(artworks_recordIDs, DATA, lang, focus)
        # shuffle stays off so rows keep the order of artworks_recordIDs
        prompt_dataloader = DataLoader(
            dataset=prompt_dataset,
            collate_fn=PromptCaptionsBBuilder,
            batch_size=BATCH_SIZE,
            num_workers=best_workers,
            shuffle=False,
            pin_memory=True,
        )

        tqdm_text = f"Language: {lang}, Focus: {focus}"
        embeddings_by_focus[focus] = compute_captions_embeddings(prompt_dataloader, model, device, tqdm_text)

Computing captions embeddings: Language: en, Focus: colors: 100%|██████████| 227/227 [00:01<00:00, 117.37batch/s]
Computing captions embeddings: Language: en, Focus: content: 100%|██████████| 227/227 [00:01<00:00, 118.14batch/s]
Computing captions embeddings: Language: en, Focus: emotion: 100%|██████████| 227/227 [00:01<00:00, 114.47batch/s]
Computing captions embeddings: Language: en, Focus: luminosity: 100%|██████████| 227/227 [00:01<00:00, 120.74batch/s]
Computing captions embeddings: Language: fr, Focus: colors: 100%|██████████| 227/227 [00:01<00:00, 119.92batch/s]
Computing captions embeddings: Language: fr, Focus: content: 100%|██████████| 227/227 [00:01<00:00, 120.36batch/s]
Computing captions embeddings: Language: fr, Focus: emotion: 100%|██████████| 227/227 [00:01<00:00, 119.41batch/s]
Computing captions embeddings: Language: fr, Focus: luminosity: 100%|██████████| 227/227 [00:01<00:00, 118.45batch/s]
Computing captions embeddings: Language: nl, Focus: colors: 100%|██████████|

In [40]:
class AdditionalInformation(Dataset):
    """Dataset that serves the given additional-information texts by position."""

    def __init__(self, texts):
        self.texts = texts

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # Indices may arrive as tensors; convert them to plain Python values first
        position = idx.tolist() if torch.is_tensor(idx) else idx
        return self.texts[position]

def AdditionalInformationBBuilder(samples):
    """Collate a list of additional-information texts into padded, truncated processor inputs."""
    return processor(
        text=samples,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )

def compute_additional_infos_embeddings(dataloader, model, device):
    """Return the L2-normalised text embeddings of every batch as one numpy array.

    Rows follow dataloader order.
    """
    model.eval()
    progress = tqdm(dataloader, desc="Computing additional informations embeddings", unit="batch")
    normalised_batches = []
    with torch.no_grad():
        for batch in progress:
            token_ids = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)

            # Embed the texts, then scale every row to unit length
            features = model.get_text_features(input_ids=token_ids, attention_mask=mask)
            unit_features = features / features.norm(dim=-1, keepdim=True)
            normalised_batches.append(unit_features.flatten(1))

    return torch.cat(normalised_batches, dim=0).cpu().numpy()

# Every focus but "content", which carries no additional information
unique_focus_except_content = list(unique_focus)
unique_focus_except_content.remove("content")

# additional_infos_embeddings[lang][focus][text] -> embedding of that additional-information text
additional_infos_embeddings = {}
for lang in ['en', 'fr', 'nl']:
    additional_infos_embeddings[lang] = {}
    columnName = f"additional_info_{lang}"
    for focus in unique_focus_except_content:
        rows_with_focus = DATA[DATA['focus'] == focus]

        # Distinct texts across all rows of this focus, sorted so embedding rows
        # line up with a stable order
        unique_additional_infos = sorted(
            {value for add_info in rows_with_focus[columnName] for value in add_info}
        )

        prompt_dataset = AdditionalInformation(unique_additional_infos)
        prompt_dataloader = DataLoader(
            dataset=prompt_dataset,
            collate_fn=AdditionalInformationBBuilder,
            batch_size=BATCH_SIZE,
            num_workers=best_workers,
            shuffle=False,
            pin_memory=True,
        )

        raw_embeddings = compute_additional_infos_embeddings(prompt_dataloader, model, device)
        additional_infos_embeddings[lang][focus] = {
            text: raw_embeddings[row_index]
            for row_index, text in enumerate(unique_additional_infos)
        }

Computing additional informations embeddings: 100%|██████████| 4/4 [00:00<00:00, 114.29batch/s]
Computing additional informations embeddings: 100%|██████████| 4/4 [00:00<00:00, 112.69batch/s]
Computing additional informations embeddings: 100%|██████████| 2/2 [00:00<00:00, 111.12batch/s]
Computing additional informations embeddings: 100%|██████████| 4/4 [00:00<00:00, 115.94batch/s]
Computing additional informations embeddings: 100%|██████████| 4/4 [00:00<00:00, 101.27batch/s]
Computing additional informations embeddings: 100%|██████████| 2/2 [00:00<00:00, 114.27batch/s]
Computing additional informations embeddings: 100%|██████████| 4/4 [00:00<00:00, 117.65batch/s]
Computing additional informations embeddings: 100%|██████████| 4/4 [00:00<00:00, 117.64batch/s]
Computing additional informations embeddings: 100%|██████████| 2/2 [00:00<00:00, 121.21batch/s]


# Benchmarking methods

In [41]:
def get_average_position(cosine_similarities):
    """Return the mean 1-based rank of the matching item, plus the full rankings.

    Row ``i`` holds the similarities of query ``i``; its matching item is column ``i``.
    The second return value lists, per query, the column indices sorted by
    decreasing similarity (as plain ints).
    """
    positions = []
    rank_total = 0
    for query_index, scores in enumerate(cosine_similarities):
        ranking = np.argsort(scores)[::-1]
        rank_total += np.where(ranking == query_index)[0][0] + 1
        positions.append([int(column) for column in ranking])
    return rank_total / len(cosine_similarities), positions

def get_MRR(cosine_similarities):
    """Return the mean reciprocal rank, where query ``i`` matches column ``i``."""
    reciprocal_ranks = []
    for query_index, scores in enumerate(cosine_similarities):
        ranking = np.argsort(scores)[::-1]
        rank = np.where(ranking == query_index)[0][0] + 1
        reciprocal_ranks.append(1 / rank)
    return sum(reciprocal_ranks) / len(cosine_similarities)

def get_recall_at_k(cosine_similarities, k):
    """Return the share of queries whose matching column ``i`` is among the top ``k``."""
    hits = 0
    for query_index, scores in enumerate(cosine_similarities):
        top_k = np.argsort(scores)[::-1][:k]
        if query_index in top_k:
            hits += 1
    return hits / len(cosine_similarities)

def get_nDCG_at_k(cosine_similarities, k):
    """Return the mean nDCG@k with a single relevant item (column ``i``) per query.

    A query contributes ``1 / log2(rank + 1)`` when its matching item is ranked
    within the top ``k``, and 0 otherwise.
    """
    gains = []
    for query_index, scores in enumerate(cosine_similarities):
        ranking = np.argsort(scores)[::-1]
        rank = np.where(ranking == query_index)[0][0] + 1
        if rank <= k:
            gains.append(1 / np.log2(rank + 1))
        else:
            gains.append(0)
    return sum(gains) / len(cosine_similarities)

def get_metrics_row(cosine_similarities):
    """Compute every benchmark metric for one similarity matrix.

    Returns:
        ``(metrics, positions)`` where ``metrics`` is
        ``[average_position, mrr, recall@1, recall@3, recall@5, recall@10,
        nDCG@1, nDCG@3, nDCG@5, nDCG@10]`` and ``positions`` is the per-query
        ranking returned by ``get_average_position``.
    """
    cutoffs = (1, 3, 5, 10)

    average_position, positions = get_average_position(cosine_similarities)
    mrr = get_MRR(cosine_similarities)
    recalls = [get_recall_at_k(cosine_similarities, cutoff) for cutoff in cutoffs]
    nDCGs = [get_nDCG_at_k(cosine_similarities, cutoff) for cutoff in cutoffs]

    return [average_position, mrr, *recalls, *nDCGs], positions

# -PROMPT variant
Steps:
1. For each lang *l*
2. For each focus *f*
3. Get all the artworks embeddings
4. Get all the textual embeddings of the tasks for lang *l* and focus *f*
5. Measure the cosine similarity between the artworks and the textual embeddings
6. Run the metrics methods

In [42]:
def run_prompt_benchmark():
    """Run the -PROMPT variant for every (language, focus) pair.

    Reads the module-level ``unique_focus``, ``prompt_embeddings`` and
    ``artworks_embeddings_np_array``. For each pair the caption embeddings are
    compared with the artwork embeddings by dot product and scored with
    ``get_metrics_row``.

    Returns:
        ``(results, positions_per_lang)``: a DataFrame with one row per pair,
        and a dict mapping ``(lang, focus)`` to the per-query rankings.
    """
    columns = [
        "variant",
        "lang",
        "focus",
        "average_position",
        "mrr",
        "recall@1",
        "recall@3",
        "recall@5",
        "recall@10",
        "nDCG@1",
        "nDCG@3",
        "nDCG@5",
        "nDCG@10"
    ]

    runs = [(lang, focus) for lang in ["en", "fr", "nl"] for focus in unique_focus]

    # Rows are collected in a list and turned into a DataFrame once at the end
    rows = []
    positions_per_lang = {}

    # The context manager closes the progress bar even if a run fails
    with tqdm(total=len(runs), desc="Running -PROMPT benchmark", unit="run") as tqdm_bar:
        for lang, focus in runs:
            # Get the textual embeddings of the tasks
            textual_embeddings = prompt_embeddings[lang][focus]
            # Get the cosine similarities (both sides are stored L2-normalised)
            cosine_similarities = textual_embeddings @ artworks_embeddings_np_array.T
            # Measure the metrics
            metrics, positions = get_metrics_row(cosine_similarities)

            rows.append(["-PROMPT", lang, focus] + metrics)
            positions_per_lang[(lang, focus)] = positions

            tqdm_bar.update(1)
            tqdm_bar.set_description(f"Last result: {lang}-{focus} : MRR: {metrics[1]:.4f}")

    results = pd.DataFrame(rows, columns=columns)
    return results, positions_per_lang

# -MIXED variant
Steps:
1. For each lang *l*
2. For each focus *f* (except *f=content*)
3. Get all the artworks embeddings
4. For a task $t=(l,f,recordID)$, get the textual embedding of the caption $c$ and the textual embeddings $a_i$, $i \in \{1,2,\ldots,n\}$, of its $n$ additional-information items
5. Mix $c$ and the $a_i$ embeddings by taking their element-wise mean
6. Normalize the mixed embedding
7. Measure the cosine similarity between the artworks and the mixed embeddings
8. Run the metrics methods

In [43]:
def run_mixed_benchmark():
    """Run the -MIXED variant for every (language, focus) pair except the "content" focus.

    Reads the module-level ``unique_focus_except_content``, ``DATA``,
    ``artworks_recordIDs``, ``prompt_embeddings``, ``additional_infos_embeddings``
    and ``artworks_embeddings_np_array``. For each artwork the "content" caption
    embedding is averaged with the embeddings of its additional information,
    re-normalised, compared with the artwork embeddings by dot product and
    scored with ``get_metrics_row``.

    Returns:
        ``(results, positions_per_lang)``: a DataFrame with one row per pair,
        and a dict mapping ``(lang, focus)`` to the per-query rankings.

    Raises:
        IndexError: if a recordID has no row for the focus being processed.
    """
    columns = [
        "variant",
        "lang",
        "focus",
        "average_position",
        "mrr",
        "recall@1",
        "recall@3",
        "recall@5",
        "recall@10",
        "nDCG@1",
        "nDCG@3",
        "nDCG@5",
        "nDCG@10"
    ]

    runs = [(lang, focus) for lang in ["en", "fr", "nl"] for focus in unique_focus_except_content]

    # Rows are collected in a list and turned into a DataFrame once at the end
    rows = []
    positions_per_lang = {}

    # The context manager closes the progress bar even if a run fails
    with tqdm(total=len(runs), desc="Running -MIXED benchmark", unit="run") as tqdm_bar:
        for lang, focus in runs:
            # One lookup table per run instead of filtering the frame for every recordID.
            # keep='first' means the first matching row is used if a recordID appears twice.
            rows_focus = DATA[DATA['focus'] == focus].drop_duplicates(subset='recordID', keep='first')
            infos_by_recordID = dict(zip(rows_focus['recordID'], rows_focus[f"additional_info_{lang}"]))

            # We always use the content captions as it is not merged with additional info
            caption_embeddings = prompt_embeddings[lang]["content"]
            info_embeddings = additional_infos_embeddings[lang][focus]

            mixed_embeddings = []
            # caption_embeddings rows follow artworks_recordIDs, so the enumerate
            # index is the row to read
            for recordID_index, recordID in enumerate(artworks_recordIDs):
                if recordID not in infos_by_recordID:
                    raise IndexError(f"No row found for recordID {recordID} with focus '{focus}'")

                embeddings_for_recordID = [caption_embeddings[recordID_index]]
                embeddings_for_recordID.extend(
                    info_embeddings[additional_info] for additional_info in infos_by_recordID[recordID]
                )

                # Get the average of the embeddings
                mixed_embedding = np.mean(np.stack(embeddings_for_recordID), axis=0)
                # Normalize the embedding
                mixed_embedding = mixed_embedding / np.linalg.norm(mixed_embedding)
                mixed_embeddings.append(mixed_embedding)

            mixed_embeddings = np.stack(mixed_embeddings)

            # Get the cosine similarities (both sides are L2-normalised)
            cosine_similarities = mixed_embeddings @ artworks_embeddings_np_array.T
            # Measure the metrics
            metrics, positions = get_metrics_row(cosine_similarities)

            rows.append(["-MIXED", lang, focus] + metrics)
            positions_per_lang[(lang, focus)] = positions

            tqdm_bar.update(1)
            tqdm_bar.set_description(f"Last result: {lang}-{focus} : MRR: {metrics[1]:.4f}")

    results = pd.DataFrame(rows, columns=columns)
    return results, positions_per_lang

In [44]:
# Run the benchmarks !
# (pickle is imported with the other modules in the first cell)
results_prompt, positions_prompt = run_prompt_benchmark()
results_mixed, positions_mixed = run_mixed_benchmark()

# One table holding both variants, exported next to the per-query rankings
results = pd.concat([results_prompt, results_mixed], ignore_index=True)
results.to_csv(RESULT_FOLDER + f"/{model_name}_benchmark.csv", index=False)
# Save positions_prompt
with open(RESULT_FOLDER + f"/{model_name}_positions_prompt.pkl", "wb") as f:
    pickle.dump(positions_prompt, f)
# Save positions_mixed
with open(RESULT_FOLDER + f"/{model_name}_positions_mixed.pkl", "wb") as f:
    pickle.dump(positions_mixed, f)
results

Last result: nl-luminosity : MRR: 0.0488: 100%|██████████| 12/12 [00:00<00:00, 12.00run/s]
Last result: nl-luminosity : MRR: 0.0322: 100%|██████████| 9/9 [00:01<00:00,  4.81run/s]


Unnamed: 0,variant,lang,focus,average_position,mrr,recall@1,recall@3,recall@5,recall@10,nDCG@1,nDCG@3,nDCG@5,nDCG@10
0,-PROMPT,en,colors,11.422907,0.511141,0.356828,0.605727,0.696035,0.801762,0.356828,0.503772,0.540639,0.574943
1,-PROMPT,en,content,10.779736,0.506275,0.356828,0.605727,0.685022,0.797357,0.356828,0.500311,0.532724,0.569507
2,-PROMPT,en,emotion,15.38326,0.438448,0.290749,0.519824,0.627753,0.742291,0.290749,0.42259,0.46608,0.503758
3,-PROMPT,en,luminosity,15.030837,0.454326,0.312775,0.524229,0.618943,0.762115,0.312775,0.435517,0.474378,0.520435
4,-PROMPT,fr,colors,110.162996,0.082663,0.035242,0.077093,0.10793,0.151982,0.035242,0.059051,0.07156,0.085649
5,-PROMPT,fr,content,103.975771,0.091934,0.04185,0.0837,0.121145,0.169604,0.04185,0.066236,0.081397,0.096577
6,-PROMPT,fr,emotion,116.502203,0.070786,0.028634,0.063877,0.081498,0.136564,0.028634,0.048563,0.055959,0.073835
7,-PROMPT,fr,luminosity,117.070485,0.073006,0.030837,0.055066,0.094714,0.143172,0.030837,0.044393,0.060986,0.076548
8,-PROMPT,nl,colors,138.444934,0.0502,0.019824,0.037445,0.052863,0.090308,0.019824,0.030076,0.036331,0.048603
9,-PROMPT,nl,content,133.028634,0.056834,0.028634,0.039648,0.057269,0.092511,0.028634,0.034429,0.041632,0.053112


In [None]:
# Test load the pkl
# NOTE(review): pickle.load executes arbitrary code from the file — only load files written by this notebook.
positions_prompt_file = RESULT_FOLDER + f"/{model_name}_positions_prompt.pkl"
with open(positions_prompt_file, "rb") as f:
    positions_prompt = pickle.load(f)

# Print the dimensions of the rankings stored under every (lang, focus) key
for key, rankings in positions_prompt.items():
    print(f"{key}: {len(rankings)}x{len(rankings[0])}")

('en', 'colors'): 454x454
('en', 'content'): 454x454
('en', 'emotion'): 454x454
('en', 'luminosity'): 454x454
('fr', 'colors'): 454x454
('fr', 'content'): 454x454
('fr', 'emotion'): 454x454
('fr', 'luminosity'): 454x454
('nl', 'colors'): 454x454
('nl', 'content'): 454x454
('nl', 'emotion'): 454x454
('nl', 'luminosity'): 454x454


: 