In [55]:
import random
import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW, Optimizer
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
from torch.nn.utils import clip_grad_norm_

from sentence_transformers import SentenceTransformer
from tqdm import tqdm

In [4]:
model_name = "intfloat/multilingual-e5-small"
model = SentenceTransformer(model_name)


In [17]:
class LinearAdapter(nn.Module):
    def __init__(self, input_dim: int):
        super().__init__()
        # Create a linear layer of size input_dim in both ends. This will match our original embedding 
        self.linear: nn.Linear = nn.Linear(input_dim, input_dim) 

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)

In [103]:
class TripletDataset(Dataset):
    def __init__(self, data: pd.DataFrame, base_model: SentenceTransformer):
        self.data = data
        self.base_model = base_model

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data.iloc[idx]
        query = item['anchor']
        positive = item['positive']
        negative = item['negative']
        
        query_emb = self.base_model.encode(query, convert_to_tensor=True)
        positive_emb = self.base_model.encode(positive, convert_to_tensor=True)
        negative_emb = self.base_model.encode(negative, convert_to_tensor=True)
        
        return query_emb, positive_emb, negative_emb

In [104]:
df_train = pd.read_json("data/triplet_data_train.json")
df_test = pd.read_json("data/triplet_data_test.json")

In [105]:
dataset_train = TripletDataset(df_train, model)
dataset_test = TripletDataset(df_test, model)

In [37]:
def get_linear_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int) -> LRScheduler:
    def lr_lambda(current_step: int) -> float:
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))
    return LambdaLR(optimizer, lr_lambda)


In [129]:
def train(
    base_model: SentenceTransformer,
    dataset_train: TripletDataset,
    dataset_test: TripletDataset,
    epochs: int,
    batch_size: int,
    learning_rate: float,
    warmup_steps: int,
    max_grad_norm: float,
    margin: int,
    save_every_epoch: int
) -> LinearAdapter:
    
    device = "cpu"
    if torch.mps.is_available():
        device = "mps"
    if torch.cuda.is_available():
        device = "cuda"
    
    adapter = LinearAdapter(base_model.get_sentence_embedding_dimension()).to(device)

    triplet_loss = nn.TripletMarginLoss()
    optimizer = AdamW(adapter.parameters(), lr=learning_rate)
    
    dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
    
    dataloader_test = DataLoader(dataset_test, batch_size=batch_size, shuffle=True)
    
    total_steps = len(dataloader_train) * epochs
    scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, total_steps)
    
    with tqdm(total=epochs, desc="Training") as pbar:
        for epoch in range(epochs):
            total_loss_train = 0
            total_loss_test = 0
            
            for batch in dataloader_train:
                query_emb, positive_emb, negative_emb = [x.to(device) for x in batch]
                
                # Forward pass
                adapted_query_emb = adapter(query_emb)
                
                # Compute loss
                train_loss = triplet_loss(adapted_query_emb, positive_emb, negative_emb)
                
                # Backward pass and optimization
                optimizer.zero_grad()
                train_loss.backward()
                
                # Gradient clipping
                clip_grad_norm_(adapter.parameters(), max_grad_norm)
                
                optimizer.step()
                scheduler.step()
                
                total_loss_train += train_loss.item()
                
            for batch in dataloader_train:
                query_emb, positive_emb, negative_emb = [x.to(device) for x in batch]
                
                # Forward pass
                adapted_query_emb = adapter(query_emb)
                
                # Compute loss
                test_loss = triplet_loss(adapted_query_emb, positive_emb, negative_emb)
                
                total_loss_test += test_loss.item()
            pbar.update(1)
            print(f"Epoch {epoch+1}/{epochs}, Train Loss: {total_loss_train/len(dataloader_train):.4f}, Test Loss: {total_loss_test/len(dataloader_test):.4f}")
            if (epoch + 1) % save_every_epoch == 0:
                save_dict = {
                    'adapter_state_dict': adapter.state_dict(),
                    'adapter_kwargs': {
                        'epochs': (epoch + 1),
                        'batch_size': batch_size,
                        'learning_rate': learning_rate,
                        'warmup_steps': warmup_steps,
                        'max_grad_norm': max_grad_norm,
                        'margin': margin,
                    }
                }
                torch.save(save_dict, f"models/adapter_{epoch}.pth")

    save_dict = {
        'adapter_state_dict': adapter.state_dict(),
        'adapter_kwargs': {
            'epochs': (epoch + 1),
            'batch_size': batch_size,
            'learning_rate': learning_rate,
            'warmup_steps': warmup_steps,
            'max_grad_norm': max_grad_norm,
            'margin': margin,
        }
    }
    torch.save(save_dict, f"models/adapter_{epoch}_final.pth")
    return adapter
    

In [None]:
adapter_kwargs = {
    'epochs': 40,
    'batch_size': 128,
    'learning_rate': 2e-4,
    'warmup_steps': 60,
    'max_grad_norm': 1.0,
    'margin': 1.0,
    'save_every_epoch': 5
}
trained_adapter = train(model, dataset_train, dataset_test, **adapter_kwargs)


Training:   2%|█▎                                                | 1/40 [04:45<3:05:27, 285.33s/it]

Epoch 1/40, Train Loss: 0.9687, Test Loss: 3.5269


Training:   5%|██▌                                               | 2/40 [12:55<4:17:01, 405.82s/it]

Epoch 2/40, Train Loss: 0.8301, Test Loss: 2.9346


Training:   8%|███▊                                              | 3/40 [17:32<3:34:04, 347.15s/it]

Epoch 3/40, Train Loss: 0.7361, Test Loss: 2.8066


Training:  10%|█████                                             | 4/40 [22:08<3:11:20, 318.89s/it]

Epoch 4/40, Train Loss: 0.7209, Test Loss: 2.7830


Training:  12%|██████▎                                           | 5/40 [26:44<2:57:03, 303.52s/it]

Epoch 5/40, Train Loss: 0.7185, Test Loss: 2.7741


Training:  15%|███████▌                                          | 6/40 [31:20<2:46:40, 294.14s/it]

Epoch 6/40, Train Loss: 0.7161, Test Loss: 2.7710


Training:  18%|████████▊                                         | 7/40 [35:48<2:37:04, 285.59s/it]

Epoch 7/40, Train Loss: 0.7148, Test Loss: 2.7659


Training:  20%|██████████                                        | 8/40 [40:20<2:29:58, 281.20s/it]

Epoch 8/40, Train Loss: 0.7142, Test Loss: 2.7642


Training:  22%|███████████▎                                      | 9/40 [44:55<2:24:21, 279.39s/it]

Epoch 9/40, Train Loss: 0.7127, Test Loss: 2.7596


Training:  25%|████████████▎                                    | 10/40 [49:29<2:18:50, 277.69s/it]

Epoch 10/40, Train Loss: 0.7124, Test Loss: 2.7570


Training:  28%|█████████████▍                                   | 11/40 [53:57<2:12:49, 274.81s/it]

Epoch 11/40, Train Loss: 0.7120, Test Loss: 2.7563


Training:  30%|██████████████▋                                  | 12/40 [58:30<2:07:51, 273.98s/it]

Epoch 12/40, Train Loss: 0.7105, Test Loss: 2.7529


Training:  32%|███████████████▎                               | 13/40 [1:03:03<2:03:13, 273.82s/it]

Epoch 13/40, Train Loss: 0.7105, Test Loss: 2.7502


Training:  35%|████████████████▍                              | 14/40 [1:07:27<1:57:25, 270.96s/it]

Epoch 14/40, Train Loss: 0.7098, Test Loss: 2.7499


Training:  38%|█████████████████▋                             | 15/40 [1:11:53<1:52:14, 269.36s/it]

Epoch 15/40, Train Loss: 0.7090, Test Loss: 2.7451


Training:  40%|██████████████████▊                            | 16/40 [1:16:22<1:47:45, 269.38s/it]

Epoch 16/40, Train Loss: 0.7087, Test Loss: 2.7482


Training:  42%|███████████████████▉                           | 17/40 [1:20:52<1:43:16, 269.41s/it]

Epoch 17/40, Train Loss: 0.7078, Test Loss: 2.7432


Training:  45%|█████████████████████▏                         | 18/40 [1:25:15<1:38:08, 267.65s/it]

Epoch 18/40, Train Loss: 0.7072, Test Loss: 2.7410


Training:  48%|██████████████████████▎                        | 19/40 [1:29:43<1:33:40, 267.63s/it]

Epoch 19/40, Train Loss: 0.7073, Test Loss: 2.7385


Training:  50%|███████████████████████▌                       | 20/40 [1:34:14<1:29:34, 268.74s/it]

Epoch 20/40, Train Loss: 0.7072, Test Loss: 2.7378


Training:  52%|████████████████████████▋                      | 21/40 [1:38:39<1:24:45, 267.64s/it]

Epoch 21/40, Train Loss: 0.7064, Test Loss: 2.7370


Training:  55%|█████████████████████████▊                     | 22/40 [1:43:05<1:20:05, 266.99s/it]

Epoch 22/40, Train Loss: 0.7068, Test Loss: 2.7357


Training:  57%|███████████████████████████                    | 23/40 [1:47:34<1:15:50, 267.70s/it]

Epoch 23/40, Train Loss: 0.7062, Test Loss: 2.7330


Training:  60%|████████████████████████████▏                  | 24/40 [1:52:04<1:11:32, 268.26s/it]

Epoch 24/40, Train Loss: 0.7054, Test Loss: 2.7347


Training:  62%|█████████████████████████████▍                 | 25/40 [1:56:28<1:06:43, 266.92s/it]

Epoch 25/40, Train Loss: 0.7050, Test Loss: 2.7325


Training:  65%|██████████████████████████████▌                | 26/40 [2:00:55<1:02:19, 267.11s/it]

Epoch 26/40, Train Loss: 0.7049, Test Loss: 2.7306


Training:  68%|█████████████████████████████████                | 27/40 [2:05:27<58:09, 268.42s/it]

Epoch 27/40, Train Loss: 0.7047, Test Loss: 2.7308


Training:  70%|██████████████████████████████████▎              | 28/40 [2:09:52<53:28, 267.35s/it]

Epoch 28/40, Train Loss: 0.7050, Test Loss: 2.7290


In [None]:
def accuracy(anchor: torch.Tensor, positive: torch.Tensor, negative: torch.Tensor) -> int:

    cos = torch.nn.CosineSimilarity(dim=-1)
    return 1 if cos(anchor, positive) > cos(anchor, negative) else 0

    
for epoch in range(0, 11, 5):
    if epoch == 0:
        accuracy_per = df_val.apply(
            lambda x: accuracy(
                torch.Tensor(x["embedded_anchor"]),
                torch.Tensor(x["embedded_positive"]),
                torch.Tensor(x["embedded_negative"])
            ), axis=1).sum() / df_val.shape[0]
    else:
        loaded_dict = torch.load(f"models/adapter_{epoch}.pth")
        adapter = LinearAdapter(model.get_sentence_embedding_dimension())  
        adapter.load_state_dict(loaded_dict['adapter_state_dict'])
        accuracy_per = df_val.apply(
            lambda x: accuracy(
                adapter(torch.Tensor(x["embedded_anchor"])),
                torch.Tensor(x["embedded_positive"]),
                torch.Tensor(x["embedded_negative"])
            ), axis=1).sum() / df_val.shape[0]

    print(f"Epoch: {epoch}, accuracy: {accuracy_per}")