## ***Load the dataset and split it into training, validation, and test sets***

In [2]:
!pip install datasets

Collecting datasets
  Downloading datasets-3.0.2-py3-none-any.whl.metadata (20 kB)
Collecting filelock (from datasets)
  Downloading filelock-3.16.1-py3-none-any.whl.metadata (2.9 kB)
Collecting pyarrow>=15.0.0 (from datasets)
  Downloading pyarrow-17.0.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (3.3 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting requests>=2.32.2 (from datasets)
  Downloading requests-2.32.3-py3-none-any.whl.metadata (4.6 kB)
Collecting tqdm>=4.66.3 (from datasets)
  Downloading tqdm-4.66.5-py3-none-any.whl.metadata (57 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py312-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9

In [2]:
from datasets import load_from_disk, DatasetDict


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Load the full (unsplit) MemeCap dataset whose rows carry the raw OCR output in "extracted_text"
dataset = load_from_disk("./Dataset/full_meme_cap_ocr")

def get_ocr_text_list(split):
    """
    Flatten the OCR output of every row into a single string.

    Each entry of ``split["extracted_text"]`` is expected to look like
    ``{"<OCR_WITH_REGION>": {"labels": [...], "quad_boxes": [...]}}``. The non-None
    labels are joined with single spaces. Rows with no OCR region, no "labels" key,
    ``labels=None``, or a None entry yield "".

    Parameters:
    split: a dataset (or mapping) with an "extracted_text" column.

    Returns:
    list of str: one OCR string per row, in row order.
    """
    ocr_texts = []
    for item in split["extracted_text"]:
        # Treat every missing level (row, region, labels) as "no text"
        region = (item or {}).get("<OCR_WITH_REGION>") or {}
        labels = region.get("labels") or []
        ocr_texts.append(" ".join(label for label in labels if label is not None))
    return ocr_texts


# Attach the flattened OCR string of every row as a new "OCR_text" column
dataset = dataset.add_column(
    name="OCR_text",
    column=get_ocr_text_list(dataset),
)

In [4]:
# Inspect one row: meme metadata, the raw OCR output in "extracted_text", and the new "OCR_text"
dataset[160]

{'category': 'memes',
 'img_captions': ['Two men in brown vests are standing outside.'],
 'meme_captions': ['Meme poster will say Kanye is a klan member from a sketch rather than himself.',
  'Meme poster will tell the kids Kanye was just as much of a white supremacist as a klan member.',
  'The poster is making fun of the black man and saying they are going to tell their children in the future that he was Kanye.',
  'Poster vows to present Ye as a racist in the future.'],
 'title': 'The Ye Reicht',
 'url': 'https://farm66.staticflickr.com/65535/52761419771_382d97602b.png',
 'img_fname': 'memes_zap0cv.png',
 'metaphors': [{'meaning': 'Dave Chapelle Klan Character and white klan member',
   'metaphor': 'Two men'}],
 'post_id': 'zap0cv',
 'extracted_text': {'<OCR_WITH_REGION>': {'labels': ['</s>Gonna tell future generations this was',
    'Kanye'],
   'quad_boxes': [[18.04800033569336,
     0.10649999976158142,
     640.89599609375,
     0.10649999976158142,
     640.89599609375,
     0.

## ***Preprocess the contextual `meme_captions` and `OCR_text` columns***


In [4]:
!pip install nltk

import nltk
# Ensure NLTK resources are available
nltk.download('punkt')
nltk.download('stopwords')
nltk.download('wordnet')
nltk.download('omw-1.4')
nltk.download('punkt_tab')



[nltk_data] Downloading package punkt to
[nltk_data]     /home/kareem.elzeky/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     /home/kareem.elzeky/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package wordnet to
[nltk_data]     /home/kareem.elzeky/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to
[nltk_data]     /home/kareem.elzeky/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
[nltk_data] Downloading package punkt_tab to
[nltk_data]     /home/kareem.elzeky/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

In [5]:
import re
import nltk
from nltk.corpus import stopwords


# Define stop words (NLTK English list; compared against lowercased tokens)
stop_words = set(stopwords.words('english'))

# Sequence markers the OCR model leaves in its labels, e.g. "</s>" glued to the first label
_SPECIAL_TOKEN_PATTERN = re.compile(r'</?s>')
# Anything that is not an ASCII letter, digit or whitespace
_NON_ALNUM_PATTERN = re.compile(r'[^a-zA-Z0-9\s]')


def clean_text(text):
    """
    Normalize a caption or OCR string for retrieval.

    Steps:
    1. Strip "<s>" / "</s>" markers.
    2. Drop every character that is not an ASCII letter, digit or whitespace.
    3. Lowercase.
    4. Tokenize with NLTK and remove English stop words.

    Parameters:
    text (str | None): raw text; None or "" yields "".

    Returns:
    str: the remaining tokens joined by single spaces ("" if nothing is left).
    """
    if not text:
        return ''

    # Markers must go *before* punctuation stripping: otherwise "</s>Gonna"
    # becomes "sGonna" and then the single token "sgonna".
    text = _SPECIAL_TOKEN_PATTERN.sub('', text)
    text = _NON_ALNUM_PATTERN.sub('', text)
    text = text.lower()

    words = nltk.word_tokenize(text)
    return ' '.join(word for word in words if word not in stop_words)

def preprocess(dataset):
    """
    Add cleaned text columns to every row via ``dataset.map``.

    New columns:
      - 'cleaned_meme_captions': all of the row's meme captions joined by spaces, then cleaned
      - 'cleaned_OCR': the row's 'OCR_text', cleaned

    Returns the dataset produced by ``dataset.map``.
    """
    def _clean_row(row):
        joined_captions = " ".join(row['meme_captions'])
        return {
            'cleaned_meme_captions': clean_text(joined_captions),
            'cleaned_OCR': clean_text(row['OCR_text']),
        }

    return dataset.map(_clean_row)


# Preprocess the dataset
dataset = preprocess(dataset)

In [6]:
# Show the columns and row count after adding cleaned_meme_captions / cleaned_OCR
dataset

Dataset({
    features: ['category', 'img_captions', 'meme_captions', 'title', 'url', 'img_fname', 'metaphors', 'post_id', 'extracted_text', 'OCR_text', 'cleaned_meme_captions', 'cleaned_OCR'],
    num_rows: 6382
})

In [7]:
# Row count before dropping rows with empty cleaned text
len(dataset)

6382

In [8]:
# Drop rows whose cleaned caption or cleaned OCR text is empty, e.g. memes with no
# detected text, or text consisting only of stop words / symbols
dataset = dataset.filter(
    lambda row: not (row['cleaned_meme_captions'] == '' or row['cleaned_OCR'] == '')
)

In [9]:
# Row count after the empty-text filter
len(dataset)

5754

## ***Split the dataset with clean contextual captions and OCR texts and save the splits***

In [10]:
# 80% train / 20% held out; the fixed seed makes the split reproducible
train_val_test = dataset.train_test_split(test_size=0.2, seed=42)
train_set, val_test = train_val_test["train"], train_val_test["test"]

# Halve the held-out 20% -> 10% validation, 10% test
val_test_split = val_test.train_test_split(test_size=0.5, seed=42)
val_set, test_set = val_test_split["train"], val_test_split["test"]

print(
    f"Train size: {len(train_set)}\n"
    f"Validation size: {len(val_set)}\n"
    f"Test size: {len(test_set)}"
)

Train size: 4603
Validation size: 575
Test size: 576


In [11]:
# Bundle the three splits under the standard split names
dataset_dict = DatasetDict(
    train=train_set,
    validation=val_set,
    test=test_set,
)

In [20]:

# Persist the splits (with OCR_text and the cleaned_* columns) so later sections can reload them
dataset_dict.save_to_disk("./Dataset/meme_cap_splits_with_ocr_text")

print("Dataset with splits and new 'OCR_text' columns saved.")

Saving the dataset (0/1 shards):   0%|          | 0/5105 [00:00<?, ? examples/s]

Saving the dataset (1/1 shards): 100%|██████████| 5105/5105 [00:00<00:00, 13578.07 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 638/638 [00:00<00:00, 15717.58 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 639/639 [00:00<00:00, 15415.36 examples/s]

Dataset with splits and new 'OCR_text and embeddings' columns saved.





## ***Load the dataset splits with cleaned contextual captions and OCR texts***

In [1]:
from datasets import load_from_disk

# Reload the splits written by the "save the splits" cell above. The path must match
# that save_to_disk call (it previously pointed at "..._with_ocr_text_and_embeddings").
dataset_dict = load_from_disk("./Dataset/meme_cap_splits_with_ocr_text")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Unpack the reloaded splits in (train, validation, test) order
train_set, val_set, test_set = (
    dataset_dict[split_name] for split_name in ("train", "validation", "test")
)

## ***Create the triplet datasets***

In [4]:

# torch must be imported here: the first import in this notebook comes in a later cell,
# so on a fresh kernel this cell would otherwise raise NameError.
import torch

# Report which GPU (if any) PyTorch can see before training
if torch.cuda.is_available():
    print(f"CUDA is available. Number of GPUs: {torch.cuda.device_count()}")
    print(f"Current CUDA device: {torch.cuda.current_device()}")
    print(f"CUDA device name: {torch.cuda.get_device_name(0)}")
else:
    print("CUDA is not available.")

CUDA is available. Number of GPUs: 1
Current CUDA device: 0
CUDA device name: Quadro RTX 6000


In [17]:
import os
import random
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel




# Step 1: Create Triplets
def create_triplets(dataset, k=5, seed=None):
    """
    Build [anchor, positive, negative] triplets for contrastive training.

    For every row:
      - the anchor is its cleaned meme caption;
      - the positive is its own cleaned OCR text;
      - ``k`` distinct negatives are drawn at random from the split's OCR texts,
        skipping any text equal to the row's positive or anchor.

    Parameters:
    dataset: split with "cleaned_meme_captions" and "cleaned_OCR" columns (a
        ``datasets.Dataset`` or any mapping from column name to a sequence).
    k (int): negatives, and therefore triplets, per row.
    seed (int | None): if given, sampling uses a private ``random.Random(seed)``;
        if None (default), it uses the global ``random`` module as before.

    Returns:
    list of [anchor, positive, negative] lists: k per row, rows in dataset order,
    negatives in the order they were drawn.

    Raises:
    ValueError: if some row has fewer than k distinct candidate negatives
        (previously this case looped forever).
    """
    rng = random if seed is None else random.Random(seed)

    # Read each column once; indexing a Dataset row by row decodes every column each time.
    anchors = list(dataset['cleaned_meme_captions'])
    all_OCR_texts = list(dataset['cleaned_OCR'])  # sampling pool; duplicates keep their weight
    distinct_OCR_texts = set(all_OCR_texts)

    triplets = []
    for row_idx, (anchor, positive) in enumerate(zip(anchors, all_OCR_texts)):
        # Number of distinct texts that can still serve as a negative for this row
        excluded = {positive, anchor} & distinct_OCR_texts
        available = len(distinct_OCR_texts) - len(excluded)
        if available < k:
            raise ValueError(
                f"Row {row_idx}: only {available} distinct negative(s) available, need k={k}"
            )

        # Rejection-sample until k distinct negatives are found. A list (not a set)
        # keeps draw order, so output does not depend on string hash randomization.
        negatives = []
        while len(negatives) < k:
            negative = rng.choice(all_OCR_texts)
            if negative != positive and negative != anchor and negative not in negatives:
                negatives.append(negative)

        triplets.extend([anchor, positive, negative] for negative in negatives)

    return triplets

# Generate triplets for training and validation. Seed the global RNG first so the
# random negatives (and hence the saved triplet datasets) are reproducible.
random.seed(42)
triplet_data_train = create_triplets(train_set, k=1)
triplet_data_val = create_triplets(val_set, k=1)

In [18]:
# Generate triplets for the test split. The misspelled `tiplet_data_test` is kept as
# an alias because the next cell reads that name.
triplet_data_test = tiplet_data_test = create_triplets(test_set, k=1)

# Display one training triplet: [anchor caption, positive OCR text, negative OCR text].
# It must be the last expression for Jupyter to render it.
triplet_data_train[0]

In [19]:
from datasets import Dataset, DatasetDict

# Column names, in the same order as each [anchor, positive, negative] triplet
columns = ["anchor", "positive", "negative"]


def _triplets_to_columns(triplets):
    """Transpose a list of [anchor, positive, negative] rows into {column name: values}."""
    return {name: [row[position] for row in triplets] for position, name in enumerate(columns)}


# Column-oriented dicts, one per split
train_dict = _triplets_to_columns(triplet_data_train)
val_dict = _triplets_to_columns(triplet_data_val)
test_dict = _triplets_to_columns(tiplet_data_test)

# One Hugging Face Dataset per split
train_dataset = Dataset.from_dict(train_dict)
val_dataset = Dataset.from_dict(val_dict)
test_dataset = Dataset.from_dict(test_dict)

# Bundle the splits under the standard split names
dataset_dict = DatasetDict(
    train=train_dataset,
    validation=val_dataset,
    test=test_dataset,
)



In [20]:
# Save the triplet splits; the next section reloads them from this same path
dataset_dict.save_to_disk("./Dataset/triplet_datasets")

Saving the dataset (1/1 shards): 100%|██████████| 46030/46030 [00:00<00:00, 1252896.37 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 5750/5750 [00:00<00:00, 256806.88 examples/s]


Saving the dataset (1/1 shards): 100%|██████████| 5760/5760 [00:00<00:00, 616180.14 examples/s]


## ***Load triplet datasets***

In [1]:
from datasets import load_from_disk

# Reload the saved triplet splits so training/evaluation can run without regenerating them
triplets_dict = load_from_disk("./Dataset/triplet_datasets/")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Show split names, columns (anchor / positive / negative) and row counts
triplets_dict

DatasetDict({
    train: Dataset({
        features: ['anchor', 'positive', 'negative'],
        num_rows: 4603
    })
    validation: Dataset({
        features: ['anchor', 'positive', 'negative'],
        num_rows: 575
    })
    test: Dataset({
        features: ['anchor', 'positive', 'negative'],
        num_rows: 576
    })
})

In [3]:
# Unpack the triplet splits; the validation split serves as the trainer's eval_dataset
train_dataset, eval_dataset, test_dataset = (
    triplets_dict[split_name] for split_name in ("train", "validation", "test")
)

In [4]:
# Number of validation triplets (presumably one per validation row, since triplets were built with k=1)
len(eval_dataset)

575

## ***Evaluate retrieval with Recall@k and MRR@k (reciprocal rank truncated at k)***

In [None]:
# Peek at the test queries and their OCR targets (Jupyter displays only the last expression)
test_set['cleaned_meme_captions']
test_set['cleaned_OCR']

In [13]:
from sentence_transformers import SentenceTransformer
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

def evaluate_top_k(captions, ocr_texts, model_path, k=5):
    """
    Evaluate caption -> OCR-text retrieval with a SentenceTransformer model.

    Caption ``i`` is the query and ``ocr_texts[i]`` is its only relevant item. Every
    entry of ``ocr_texts`` is a candidate, and candidates are ranked by the cosine
    similarity of their embeddings to the caption embedding.

    Parameters:
    captions (list of str): Query texts, one per example.
    ocr_texts (list of str): Candidate texts; ``ocr_texts[i]`` matches ``captions[i]``.
    model_path (str): Model name or path passed to ``SentenceTransformer``.
    k (int): Rank cut-off for all metrics.

    Returns:
    dict with keys:
        "recall@k": fraction of captions whose matching OCR text is in the top k.
        "mrr": mean reciprocal rank truncated at k (MRR@k). A match ranked below k
            contributes 0, so the value can only grow as k grows.
        "map@k": mean average precision at k. With exactly one relevant item per
            query it equals "mrr".

    Raises:
    ValueError: if ``captions`` is empty, the two lists differ in length, or k < 1.
    """
    captions = list(captions)
    ocr_texts = list(ocr_texts)
    if not captions:
        raise ValueError("captions must not be empty")
    if len(captions) != len(ocr_texts):
        raise ValueError(
            f"captions and ocr_texts must have the same length "
            f"({len(captions)} != {len(ocr_texts)})"
        )
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")

    # Load the fine-tuned (or pretrained) model
    model = SentenceTransformer(model_path)

    # Encode captions and OCR texts
    caption_embeddings = model.encode(captions, convert_to_tensor=True).cpu().numpy()
    ocr_embeddings = model.encode(ocr_texts, convert_to_tensor=True).cpu().numpy()

    # One (n_captions x n_ocr) similarity matrix instead of one call per caption
    similarity_matrix = cosine_similarity(caption_embeddings, ocr_embeddings)

    top_k_hits = 0  # For Recall@k
    reciprocal_ranks = []
    for idx, similarities in enumerate(similarity_matrix):
        # Indices of the k highest scores, best first
        top_k_indices = np.argsort(similarities)[-k:][::-1]
        hit_positions = np.flatnonzero(top_k_indices == idx)
        if hit_positions.size:
            top_k_hits += 1
            reciprocal_ranks.append(1.0 / (hit_positions[0] + 1))  # 1-based rank
        else:
            reciprocal_ranks.append(0.0)  # correct OCR text not in the top k

    recall_at_k = top_k_hits / len(captions)
    mrr_at_k = float(np.mean(reciprocal_ranks))

    return {
        "recall@k": recall_at_k,
        "mrr": mrr_at_k,
        "map@k": mrr_at_k,
    }


# # Example lists of captions and OCR texts
# captions = test_set['cleaned_meme_captions'] # List of captions
# ocr_texts = test_set['cleaned_OCR']  # List of OCR texts corresponding to captions
# model_path = "models/mpnet-base-caption-ocr-triplet/final2"



# metrics = evaluate_top_k(captions, ocr_texts, model_path, k=5)

# print(f"Recall@5: {metrics['recall@k'] * 100:.2f}%")
# print(f"Mean Reciprocal Rank (MRR): {metrics['mrr']:.4f}")
# print(f"Mean Average Precision at 5 (MAP@5): {metrics['map@k']:.4f}")


### Evaluate the vanilla (pretrained, not fine-tuned) Sentence Transformer

In [14]:

# Zero-shot baseline: pretrained all-mpnet-base-v2, no fine-tuning
captions = test_set['cleaned_meme_captions']  # queries
ocr_texts = test_set['cleaned_OCR']  # ocr_texts[i] is the match for captions[i]
model_path = "all-mpnet-base-v2"

metrics = evaluate_top_k(captions, ocr_texts, model_path, k=1)

print(f"Recall@1: {metrics['recall@k'] * 100:.2f}%")
# evaluate_top_k scores ranks beyond k as 0, so this is MRR@1, not full-list MRR
print(f"Mean Reciprocal Rank (MRR@1): {metrics['mrr']:.4f}")




Recall@1: 54.34%
Mean Reciprocal Rank (MRR): 0.5434


In [21]:
# Zero-shot baseline: pretrained all-mpnet-base-v2, no fine-tuning
captions = test_set['cleaned_meme_captions']  # queries
ocr_texts = test_set['cleaned_OCR']  # ocr_texts[i] is the match for captions[i]
model_path = "all-mpnet-base-v2"

metrics = evaluate_top_k(captions, ocr_texts, model_path, k=5)

print(f"Recall@5: {metrics['recall@k'] * 100:.2f}%")
# evaluate_top_k scores ranks beyond k as 0, so this is MRR@5, not full-list MRR
print(f"Mean Reciprocal Rank (MRR@5): {metrics['mrr']:.4f}")




Recall@5: 71.35%
Mean Reciprocal Rank (MRR): 0.6064


In [23]:
# Zero-shot baseline: pretrained all-mpnet-base-v2, no fine-tuning
captions = test_set['cleaned_meme_captions']  # queries
ocr_texts = test_set['cleaned_OCR']  # ocr_texts[i] is the match for captions[i]
model_path = "all-mpnet-base-v2"

metrics = evaluate_top_k(captions, ocr_texts, model_path, k=10)

print(f"Recall@10: {metrics['recall@k'] * 100:.2f}%")
# evaluate_top_k scores ranks beyond k as 0, so this is MRR@10, not full-list MRR
print(f"Mean Reciprocal Rank (MRR@10): {metrics['mrr']:.4f}")




Recall@10: 76.39%
Mean Reciprocal Rank (MRR): 0.6132


## ***Training sentence transformer model using Triplet Loss***

In [4]:
from sentence_transformers.evaluation import TripletEvaluator


from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    losses
)

from sentence_transformers.evaluation import TripletEvaluator


import torch

# Check if CUDA is available and set the device
device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Using device: {device}")

# Start from the same pretrained checkpoint that was evaluated zero-shot above
model = SentenceTransformer("all-mpnet-base-v2").to(device)

# Triplet loss over the (anchor, positive, negative) columns of train_dataset
loss = losses.TripletLoss(model=model)

# Validation triplet accuracy, computed before training and at every eval step
dev_evaluator = TripletEvaluator(
    anchors=eval_dataset["anchor"],
    positives=eval_dataset["positive"],
    negatives=eval_dataset["negative"],
    name="caption-ocr-dev",
)
dev_evaluator(model)  # baseline score of the untrained model

args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir="models/caption-ocr-triplet",
    # Optional training parameters:
    num_train_epochs=5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    # NOTE(review): 1e-7 is 100x smaller than the 1e-5 used for the MNRL run and the
    # logged training loss barely moves (4.74 -> 4.69); confirm this value is intended.
    learning_rate=1e-7,
    warmup_ratio=0.1,
    # Mixed precision only when a CUDA device was found; transformers rejects fp16 on CPU
    fp16=(device == "cuda"),
    bf16=False,  # Set to True if you have a GPU that supports BF16
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    logging_steps=100,
    # run_name="mpnet-base-all-nli-triplet",  # Will be used in W&B if `wandb` is installed
)

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=loss,
    evaluator=dev_evaluator,
)
trainer.train()

# Save the trained model (the path the evaluation cells below load from)
model.save_pretrained("models/mpnet-base-caption-ocr-triplet/final2")


Using device: cuda




Step,Training Loss,Validation Loss,Caption-ocr-dev Cosine Accuracy,Caption-ocr-dev Dot Accuracy,Caption-ocr-dev Manhattan Accuracy,Caption-ocr-dev Euclidean Accuracy,Caption-ocr-dev Max Accuracy
100,4.7384,4.707828,0.958261,0.041739,0.961739,0.958261,0.961739
200,4.721,4.694885,0.958261,0.041739,0.963478,0.958261,0.963478
300,4.7148,4.684834,0.963478,0.036522,0.961739,0.963478,0.963478
400,4.7036,4.677135,0.966957,0.033043,0.961739,0.966957,0.966957
500,4.6977,4.671768,0.966957,0.033043,0.961739,0.966957,0.966957
600,4.6983,4.668466,0.966957,0.033043,0.961739,0.966957,0.966957
700,4.692,4.667116,0.966957,0.033043,0.961739,0.966957,0.966957


                                                                             

In [5]:
# (Optional) Evaluate the trained model on the test set
# Triplet accuracy of the in-memory `model` from the TripletLoss training cell on the test triplets
test_evaluator = TripletEvaluator(
    anchors=test_dataset["anchor"],
    positives=test_dataset["positive"],
    negatives=test_dataset["negative"],
    name="mpnet-caption-ocr-test",
)
test_evaluator(model)

{'mpnet-caption-ocr-test_cosine_accuracy': 0.9618055555555556,
 'mpnet-caption-ocr-test_dot_accuracy': 0.03819444444444445,
 'mpnet-caption-ocr-test_manhattan_accuracy': 0.953125,
 'mpnet-caption-ocr-test_euclidean_accuracy': 0.9618055555555556,
 'mpnet-caption-ocr-test_max_accuracy': 0.9618055555555556}

### Evaluate the model fine-tuned with Triplet Loss

In [13]:
# Retrieval on the test split with the TripletLoss fine-tuned model
captions = test_set['cleaned_meme_captions']  # queries
ocr_texts = test_set['cleaned_OCR']  # ocr_texts[i] is the match for captions[i]
model_path = "models/mpnet-base-caption-ocr-triplet/final2"

metrics = evaluate_top_k(captions, ocr_texts, model_path, k=1)

print(f"Recall@1: {metrics['recall@k'] * 100:.2f}%")
# evaluate_top_k scores ranks beyond k as 0, so this is MRR@1, not full-list MRR
print(f"Mean Reciprocal Rank (MRR@1): {metrics['mrr']:.4f}")

Recall@1: 56.77%
Mean Reciprocal Rank (MRR): 0.5677


In [11]:
# Retrieval on the test split with the TripletLoss fine-tuned model
captions = test_set['cleaned_meme_captions']  # queries
ocr_texts = test_set['cleaned_OCR']  # ocr_texts[i] is the match for captions[i]
model_path = "models/mpnet-base-caption-ocr-triplet/final2"

metrics = evaluate_top_k(captions, ocr_texts, model_path, k=5)

print(f"Recall@5: {metrics['recall@k'] * 100:.2f}%")
# evaluate_top_k scores ranks beyond k as 0, so this is MRR@5, not full-list MRR
print(f"Mean Reciprocal Rank (MRR@5): {metrics['mrr']:.4f}")

Recall@5: 76.04%
Mean Reciprocal Rank (MRR): 0.6437


In [12]:
# Retrieval on the test split with the TripletLoss fine-tuned model
captions = test_set['cleaned_meme_captions']  # queries
ocr_texts = test_set['cleaned_OCR']  # ocr_texts[i] is the match for captions[i]
model_path = "models/mpnet-base-caption-ocr-triplet/final2"

metrics = evaluate_top_k(captions, ocr_texts, model_path, k=10)

print(f"Recall@10: {metrics['recall@k'] * 100:.2f}%")
# evaluate_top_k scores ranks beyond k as 0, so this is MRR@10, not full-list MRR
print(f"Mean Reciprocal Rank (MRR@10): {metrics['mrr']:.4f}")

Recall@10: 79.86%
Mean Reciprocal Rank (MRR): 0.6487


## ***Train using Multiple Negatives Ranking Loss***

In [5]:
from datasets import load_dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.evaluation import TripletEvaluator



import torch

# Check if CUDA is available and set the device
device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Using device: {device}")

# 1. Load the pretrained model to fine-tune
model = SentenceTransformer("all-mpnet-base-v2").to(device)


# 2. Define the loss. With (anchor, positive, negative) columns, the negative column
#    acts as an extra hard negative on top of the in-batch negatives.
loss = MultipleNegativesRankingLoss(model)

# 3. Specify training arguments
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir="models/mpnet-base-all-captions-triplet-8",
    # Optional training parameters:
    num_train_epochs=5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    learning_rate=1e-5,
    warmup_ratio=0.1,
    # Mixed precision only when a CUDA device was found; transformers rejects fp16 on CPU
    fp16=(device == "cuda"),
    bf16=False,  # Set to True if you have a GPU that supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,  # must be a multiple of eval_steps when load_best_model_at_end=True
    save_total_limit=8,
    logging_steps=100,
    load_best_model_at_end=True,  # Load the best checkpoint at the end of training
    # "eval_" + "<evaluator name>_max_accuracy"; must match the evaluator name in step 4
    metric_for_best_model="eval_all-captions-ocr-dev-8_max_accuracy",
    greater_is_better=True,  # accuracy: higher is better

    run_name="mpnet-base-all-captions-ocr-triplet-8",  # Will be used in W&B if `wandb` is installed
)

# 4. Create an evaluator & evaluate the base model
dev_evaluator = TripletEvaluator(
    anchors=eval_dataset["anchor"],
    positives=eval_dataset["positive"],
    negatives=eval_dataset["negative"],
    name="all-captions-ocr-dev-8",
)
dev_evaluator(model)

# 5. Create a trainer & train
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=loss,
    evaluator=dev_evaluator,
)
trainer.train()


# 6. Save the trained model (the path the evaluation cells below load from)
model.save_pretrained("models/mpnet-base-all-caption-ocr-triplet/final8")

Using device: cuda




Step,Training Loss,Validation Loss,All-captions-ocr-dev-8 Cosine Accuracy,All-captions-ocr-dev-8 Dot Accuracy,All-captions-ocr-dev-8 Manhattan Accuracy,All-captions-ocr-dev-8 Euclidean Accuracy,All-captions-ocr-dev-8 Max Accuracy
100,1.0525,0.81141,0.977391,0.022609,0.968696,0.977391,0.977391
200,0.8089,0.768663,0.975652,0.024348,0.973913,0.975652,0.975652
300,0.6442,0.760705,0.973913,0.026087,0.970435,0.973913,0.973913
400,0.5816,0.768794,0.975652,0.024348,0.966957,0.975652,0.975652
500,0.5143,0.760745,0.97913,0.02087,0.970435,0.97913,0.97913
600,0.4282,0.766439,0.972174,0.027826,0.968696,0.972174,0.972174
700,0.4196,0.766241,0.972174,0.027826,0.968696,0.972174,0.972174


                                                                             

In [6]:
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    losses
)

from sentence_transformers.evaluation import TripletEvaluator


# Reload the MNRL fine-tuned model from the path the training cell saved it to
model = SentenceTransformer("models/mpnet-base-all-caption-ocr-triplet/final8")
# (Optional) Evaluate the trained model on the test set
test_evaluator = TripletEvaluator(
    anchors=test_dataset["anchor"],
    positives=test_dataset["positive"],
    negatives=test_dataset["negative"],
    name="mpnet-caption-ocr-test",
)
test_evaluator(model)

{'mpnet-caption-ocr-test_cosine_accuracy': 0.9652777777777778,
 'mpnet-caption-ocr-test_dot_accuracy': 0.034722222222222224,
 'mpnet-caption-ocr-test_manhattan_accuracy': 0.9670138888888888,
 'mpnet-caption-ocr-test_euclidean_accuracy': 0.9652777777777778,
 'mpnet-caption-ocr-test_max_accuracy': 0.9670138888888888}

### Evaluate the model fine-tuned with Multiple Negatives Ranking Loss

In [14]:
# Retrieval on the test split with the MNRL fine-tuned model
captions = test_set['cleaned_meme_captions']  # queries
ocr_texts = test_set['cleaned_OCR']  # ocr_texts[i] is the match for captions[i]
model_path = "models/mpnet-base-all-caption-ocr-triplet/final8"

metrics = evaluate_top_k(captions, ocr_texts, model_path, k=1)

print(f"Recall@1: {metrics['recall@k'] * 100:.2f}%")
# evaluate_top_k scores ranks beyond k as 0, so this is MRR@1, not full-list MRR
print(f"Mean Reciprocal Rank (MRR@1): {metrics['mrr']:.4f}")


Recall@1: 67.19%
Mean Reciprocal Rank (MRR): 0.6719


In [15]:
# Retrieval on the test split with the MNRL fine-tuned model
captions = test_set['cleaned_meme_captions']  # queries
ocr_texts = test_set['cleaned_OCR']  # ocr_texts[i] is the match for captions[i]
model_path = "models/mpnet-base-all-caption-ocr-triplet/final8"

metrics = evaluate_top_k(captions, ocr_texts, model_path, k=5)

print(f"Recall@5: {metrics['recall@k'] * 100:.2f}%")
# evaluate_top_k scores ranks beyond k as 0, so this is MRR@5, not full-list MRR
print(f"Mean Reciprocal Rank (MRR@5): {metrics['mrr']:.4f}")


Recall@5: 79.69%
Mean Reciprocal Rank (MRR): 0.7198


In [16]:
# Retrieval on the test split with the MNRL fine-tuned model
captions = test_set['cleaned_meme_captions']  # queries
ocr_texts = test_set['cleaned_OCR']  # ocr_texts[i] is the match for captions[i]
model_path = "models/mpnet-base-all-caption-ocr-triplet/final8"

metrics = evaluate_top_k(captions, ocr_texts, model_path, k=10)

print(f"Recall@10: {metrics['recall@k'] * 100:.2f}%")
# evaluate_top_k scores ranks beyond k as 0, so this is MRR@10, not full-list MRR
print(f"Mean Reciprocal Rank (MRR@10): {metrics['mrr']:.4f}")


Recall@10: 84.20%
Mean Reciprocal Rank (MRR): 0.7258
