# SBERT Fine-Tuning

## Imports

In [1]:
import torch
from sentence_transformers import SentenceTransformer, losses, InputExample
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
import random
import pandas as pd
import numpy as np

In [None]:
# Pin every random number generator used in this notebook to one seed so that
# repeated runs draw the same random values.
SEED = 42
random.seed(SEED)        # Python's built-in RNG
np.random.seed(SEED)     # NumPy's global RNG
torch.manual_seed(SEED)  # PyTorch RNG
if torch.cuda.is_available():
    # Seed all visible GPUs as well when CUDA is present
    torch.cuda.manual_seed_all(SEED)

## Functions

In [3]:
def load_dataset(file_path, text_column, label_column):
    """
    Read a CSV file and return its text and label columns as plain lists.

    Args:
    file_path: Path (or file-like object) handed to ``pd.read_csv``.
    text_column (str): Name of the column holding the text samples.
    label_column (str): Name of the column holding the class labels.

    Returns:
    tuple: ``(texts, labels)`` -- two lists in the row order of the CSV.
    """
    frame = pd.read_csv(file_path)
    text_values = frame[text_column].tolist()
    label_values = frame[label_column].tolist()
    return text_values, label_values

In [None]:

def create_training_data(texts, labels, num_samples=10000):
    """
    Create random positive and negative text pairs for sentence-embedding training.

    A positive pair is two distinct samples drawn from the same label. A negative
    pair reuses the first text of that positive pair and matches it with a sample
    drawn from a different label. Every iteration appends one positive and one
    negative pair, so the returned list holds ``2 * num_samples`` examples.

    Args:
    texts (list): List of text samples (documents).
    labels (list): List of corresponding (hashable) labels for the text samples.
    num_samples (int): Number of sampling iterations. Default is 10,000.

    Returns:
    train_examples (list): List of InputExample instances; positive pairs carry
        label=1 and negative pairs carry label=0. Empty when num_samples <= 0.

    Raises:
    ValueError: If fewer than two distinct labels are present (no negative pair
        can be formed) or if no label has at least two texts (no positive pair
        can be formed).
    """
    # Nothing to generate: mirror the behaviour of an empty sampling loop.
    if num_samples <= 0:
        return []

    # Group the texts by label: {label: [text, text, ...]}.
    label_to_texts = {}
    for text, label in zip(texts, labels):
        label_to_texts.setdefault(label, []).append(text)

    unique_labels = list(label_to_texts)
    if len(unique_labels) < 2:
        raise ValueError(
            "At least two distinct labels are required to build negative pairs."
        )

    # Only labels with two or more texts can supply a positive pair; drawing the
    # positive label from this subset avoids random.sample failing on a
    # single-text label.
    positive_labels = [
        label for label in unique_labels if len(label_to_texts[label]) >= 2
    ]
    if not positive_labels:
        raise ValueError(
            "At least one label needs two or more texts to build positive pairs."
        )

    train_examples = []
    for _ in range(num_samples):
        pos_label = random.choice(positive_labels)
        neg_label = random.choice(
            [label for label in unique_labels if label != pos_label]
        )

        # random.sample / random.choice already draw uniformly at random, so the
        # per-label lists do not need to be shuffled beforehand.
        pos_pair = random.sample(label_to_texts[pos_label], 2)
        neg_sample = random.choice(label_to_texts[neg_label])

        # One positive pair (label=1) and one negative pair (label=0) that shares
        # its first text with the positive pair.
        train_examples.append(InputExample(texts=pos_pair, label=1))
        train_examples.append(InputExample(texts=[pos_pair[0], neg_sample], label=0))

    return train_examples

## Load in data

In [None]:
# Dataset location and schema -- adjust these three values for your own CSV
file_path = 'dataset.csv'  # CSV file to read
text_column = 'text'  # Column containing the text samples
label_column = 'label'  # Column containing the class labels

# Pull both columns out of the CSV as plain Python lists
texts, labels = load_dataset(
    file_path=file_path,
    text_column=text_column,
    label_column=label_column,
)

# Clean data if needed


## Perform Fine-Tuning

In [7]:
# Hold out 20% of the samples as a test set; the fixed random_state keeps the
# split identical across runs.
train_texts, test_texts, train_labels, test_labels = train_test_split(
    texts,
    labels,
    test_size=0.2,
    random_state=SEED,
)


In [8]:
# Build the positive/negative training pairs from the training split only
train_data = create_training_data(texts=train_texts, labels=train_labels)

In [9]:
# Start from the pre-trained 'all-MiniLM-L6-v2' SBERT checkpoint
model = SentenceTransformer(model_name_or_path='all-MiniLM-L6-v2')

In [None]:
# Batch the training pairs (16 per batch, reshuffled by the loader)
train_dataloader = DataLoader(dataset=train_data, shuffle=True, batch_size=16)

# Loss: compares the cosine similarity of each embedded pair with its 0/1 label,
# chosen because the goal is a well-shaped embedding space
train_loss = losses.CosineSimilarityLoss(model=model)


In [None]:
# Fine-tune the model on the single (dataloader, loss) objective
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=1,
    # Optimisation settings
    optimizer_params={'lr': 2e-5},
    weight_decay=0.01,
    max_grad_norm=0.5,
    warmup_steps=100,
    # TODO(review): the original note here read "scheduler 5 epochs" --
    # presumably a plan to train for 5 epochs / revisit the scheduler; confirm.
)


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Step,Training Loss
500,0.1747
1000,0.131


In [None]:
# Persist the fine-tuned model to the 'sbert_finetuned' directory
model.save(path='sbert_finetuned')