In [1]:
from torch.optim import AdamW

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoTokenizer,
    AutoModel,
    get_linear_schedule_with_warmup
)
from sklearn.model_selection import train_test_split
import logging
import os
from tqdm import tqdm
import random

In [2]:
class ImprovedSimCSEDataset(Dataset):
    """Map-style dataset yielding two tokenized views of each text.

    Each item holds the tokenization of the original text and of the text
    returned by ``augment_text`` (currently the unchanged text), plus the
    raw string itself.

    Args:
        texts: Indexable collection of texts; each entry is passed through ``str``.
        tokenizer: Callable invoked as ``tokenizer(text, truncation=..., padding=...,
            max_length=..., return_tensors='pt')`` that returns a mapping with
            ``'input_ids'`` and ``'attention_mask'`` tensors.
        max_length: Length every sequence is padded/truncated to.
        use_hard_negatives: Stored on the instance; not read by this class.
    """

    def __init__(self, texts, tokenizer, max_length=128, use_hard_negatives=True):
        self.texts, self.tokenizer = texts, tokenizer
        self.max_length, self.use_hard_negatives = max_length, use_hard_negatives

    def __len__(self):
        """Number of texts in the dataset."""
        return len(self.texts)

    def augment_text(self, text):
        """Produce the positive view of ``text``; this implementation returns it unchanged."""
        return text

    def _tokenize(self, sentence):
        """Tokenize one sentence to fixed-length PyTorch tensors."""
        return self.tokenizer(
            sentence,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

    def __getitem__(self, idx):
        """Return the flattened token ids / attention masks for both views of item ``idx``."""
        anchor = str(self.texts[idx])
        positive = self.augment_text(anchor)

        anchor_enc = self._tokenize(anchor)
        positive_enc = self._tokenize(positive)

        sample = {}
        # The tokenizer returns a leading batch axis of size 1; flatten drops it.
        sample['input_ids_orig'] = anchor_enc['input_ids'].flatten()
        sample['attention_mask_orig'] = anchor_enc['attention_mask'].flatten()
        sample['input_ids_pos'] = positive_enc['input_ids'].flatten()
        sample['attention_mask_pos'] = positive_enc['attention_mask'].flatten()
        sample['text'] = anchor
        return sample

In [3]:
class SimCSEModel(nn.Module):
    """Transformer sentence encoder with pooling, an optional projection head and mixed-rate dropout.

    NOTE: this notebook defines a class named ``SimCSEModel`` in more than one
    cell. When the notebook is run top to bottom, the definition in the later
    cell is the one in effect; keep the definitions in sync or delete one.

    Args:
        model_name: Identifier passed to ``AutoModel.from_pretrained``.
        pooling: ``'cls'`` (first token), ``'mean'`` (attention-mask-weighted
            average) or ``'max'`` (max over non-padding tokens).
        projection_dim: If truthy, output width of a two-layer MLP applied
            after pooling; otherwise no projection is applied.
        temperature: Stored on the model (``SimCSETrainer`` reads it); not
            used inside ``forward``.
    """

    def __init__(self, model_name='sentence-transformers/all-MiniLM-L6-v2',
                 pooling='mean', projection_dim=None, temperature=0.05):
        super(SimCSEModel, self).__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        self.pooling = pooling
        self.temperature = temperature
        self.hidden_size = self.encoder.config.hidden_size

        if projection_dim:
            self.projection = nn.Sequential(
                nn.Linear(self.hidden_size, projection_dim),
                nn.ReLU(),
                nn.Linear(projection_dim, projection_dim)
            )
            self.output_dim = projection_dim
        else:
            self.projection = None
            self.output_dim = self.hidden_size

        # Three dropout strengths; ``forward`` picks one per sample.
        self.dropout1 = nn.Dropout(0.1)
        self.dropout2 = nn.Dropout(0.15)
        self.dropout3 = nn.Dropout(0.2)

    def forward(self, input_ids, attention_mask, return_embeddings=False, dropout_mask=None):
        """Encode a batch and return one vector per sequence.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: 1 for real tokens, 0 for padding; also used for pooling.
            return_embeddings: If True, return the pooled (and projected)
                embeddings without applying the extra dropout.
            dropout_mask: Optional per-sample selector (tensor or sequence).
                0 selects ``dropout1``, 1 selects ``dropout2`` and any other
                value selects ``dropout3``. Drawn uniformly from {0, 1, 2}
                when omitted.

        Returns:
            Tensor with one row per input sequence.

        Raises:
            ValueError: If ``self.pooling`` is not ``'cls'``, ``'mean'`` or ``'max'``.
        """
        # Fail with a clear message instead of an UnboundLocalError further down.
        if self.pooling not in ('cls', 'mean', 'max'):
            raise ValueError(
                f"Unsupported pooling mode: {self.pooling!r} (expected 'cls', 'mean' or 'max')"
            )

        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden = outputs.last_hidden_state

        if self.pooling == 'cls':
            embeddings = last_hidden[:, 0, :]
        elif self.pooling == 'mean':
            token_mask = attention_mask.unsqueeze(-1).expand(last_hidden.size()).float()
            summed = torch.sum(last_hidden * token_mask, 1)
            # Clamp so an all-padding row does not divide by zero.
            counts = torch.clamp(token_mask.sum(1), min=1e-9)
            embeddings = summed / counts
        else:  # 'max'
            # Out-of-place fill: the encoder output is left untouched, which
            # keeps autograd and any other consumer of the tensor safe.
            padding_positions = attention_mask.unsqueeze(-1) == 0
            embeddings = last_hidden.masked_fill(padding_positions, -1e9).max(dim=1)[0]

        if self.projection is not None:
            embeddings = self.projection(embeddings)

        if return_embeddings:
            return embeddings

        if dropout_mask is None:
            dropout_mask = torch.randint(0, 3, (embeddings.shape[0],))
        # Bring the selector onto the embeddings' device so it can broadcast.
        dropout_mask = torch.as_tensor(dropout_mask, device=embeddings.device)

        # Apply each dropout to the whole batch once and pick, per row, the
        # variant its selector asks for (no per-sample Python loop).
        pick_first = (dropout_mask == 0).unsqueeze(-1)
        pick_second = (dropout_mask == 1).unsqueeze(-1)
        return torch.where(
            pick_first,
            self.dropout1(embeddings),
            torch.where(pick_second, self.dropout2(embeddings), self.dropout3(embeddings)),
        )

In [9]:
class SimCSEModel(nn.Module):
    """Transformer sentence encoder with pooling, an optional projection head and mixed-rate dropout.

    NOTE: an earlier cell of this notebook also defines ``SimCSEModel``; once
    this cell runs, this definition replaces the earlier one. Keep the two in
    sync or delete one of them.

    Args:
        model_name: Identifier passed to ``AutoModel.from_pretrained``.
        pooling: ``'cls'`` (first token), ``'mean'`` (attention-mask-weighted
            average) or ``'max'`` (max over non-padding tokens).
        projection_dim: If truthy, output width of a two-layer MLP applied
            after pooling; otherwise no projection is applied.
        temperature: Stored on the model (``SimCSETrainer`` reads it); not
            used inside ``forward``.
    """

    def __init__(self, model_name='sentence-transformers/all-MiniLM-L6-v2',
                 pooling='mean', projection_dim=None, temperature=0.05):
        super(SimCSEModel, self).__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        self.pooling = pooling
        self.temperature = temperature
        self.hidden_size = self.encoder.config.hidden_size

        if projection_dim:
            self.projection = nn.Sequential(
                nn.Linear(self.hidden_size, projection_dim),
                nn.ReLU(),
                nn.Linear(projection_dim, projection_dim)
            )
            self.output_dim = projection_dim
        else:
            self.projection = None
            self.output_dim = self.hidden_size

        # Three dropout strengths; ``forward`` picks one per sample.
        self.dropout1 = nn.Dropout(0.1)
        self.dropout2 = nn.Dropout(0.15)
        self.dropout3 = nn.Dropout(0.2)

    def forward(self, input_ids, attention_mask, return_embeddings=False, dropout_mask=None):
        """Encode a batch and return one vector per sequence.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: 1 for real tokens, 0 for padding; also used for pooling.
            return_embeddings: If True, return the pooled (and projected)
                embeddings without applying the extra dropout.
            dropout_mask: Optional per-sample selector (tensor or sequence).
                0 selects ``dropout1``, 1 selects ``dropout2`` and any other
                value selects ``dropout3``. Drawn uniformly from {0, 1, 2}
                when omitted.

        Returns:
            Tensor with one row per input sequence.

        Raises:
            ValueError: If ``self.pooling`` is not ``'cls'``, ``'mean'`` or ``'max'``.
        """
        # Fail with a clear message instead of an UnboundLocalError further down.
        if self.pooling not in ('cls', 'mean', 'max'):
            raise ValueError(
                f"Unsupported pooling mode: {self.pooling!r} (expected 'cls', 'mean' or 'max')"
            )

        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden = outputs.last_hidden_state

        if self.pooling == 'cls':
            embeddings = last_hidden[:, 0, :]
        elif self.pooling == 'mean':
            token_mask = attention_mask.unsqueeze(-1).expand(last_hidden.size()).float()
            summed = torch.sum(last_hidden * token_mask, 1)
            # Clamp so an all-padding row does not divide by zero.
            counts = torch.clamp(token_mask.sum(1), min=1e-9)
            embeddings = summed / counts
        else:  # 'max'
            # Out-of-place fill: the encoder output is left untouched, which
            # keeps autograd and any other consumer of the tensor safe.
            padding_positions = attention_mask.unsqueeze(-1) == 0
            embeddings = last_hidden.masked_fill(padding_positions, -1e9).max(dim=1)[0]

        if self.projection is not None:
            embeddings = self.projection(embeddings)

        if return_embeddings:
            return embeddings

        if dropout_mask is None:
            dropout_mask = torch.randint(0, 3, (embeddings.shape[0],))
        # Bring the selector onto the embeddings' device so it can broadcast.
        dropout_mask = torch.as_tensor(dropout_mask, device=embeddings.device)

        # Apply each dropout to the whole batch once and pick, per row, the
        # variant its selector asks for (no per-sample Python loop).
        pick_first = (dropout_mask == 0).unsqueeze(-1)
        pick_second = (dropout_mask == 1).unsqueeze(-1)
        return torch.where(
            pick_first,
            self.dropout1(embeddings),
            torch.where(pick_second, self.dropout2(embeddings), self.dropout3(embeddings)),
        )

class SimCSETrainer:
    """Trains, evaluates and runs inference for a ``SimCSEModel``.

    Args:
        model: Model exposing ``encoder``, ``temperature``, ``pooling`` and
            ``output_dim``; it is moved to ``device``.
        tokenizer: Tokenizer saved alongside checkpoints and used to build
            the inference dataset.
        device: Device string; defaults to CUDA when available.
    """

    def __init__(self, model, tokenizer, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.model = model.to(device)
        self.tokenizer = tokenizer
        self.device = device
        self.temperature = model.temperature

    def compute_advanced_contrastive_loss(self, z1, z2, hard_negatives_weight=0.5):
        """Symmetric InfoNCE loss with an optional hard-negative term.

        Row ``i`` of ``z1`` and row ``i`` of ``z2`` are treated as a positive
        pair; every other row in the batch is a negative.

        Args:
            z1: First view embeddings, one row per sample.
            z2: Second view embeddings, same shape as ``z1``.
            hard_negatives_weight: Weight of the extra term that contrasts
                each positive against its (up to) five most similar
                negatives. ``0`` disables the term.

        Returns:
            Scalar loss tensor.
        """
        batch_size = z1.shape[0]

        z1 = F.normalize(z1, dim=1)
        z2 = F.normalize(z2, dim=1)

        sim_matrix = torch.matmul(z1, z2.T) / self.temperature
        device = sim_matrix.device

        labels = torch.arange(batch_size, device=device)
        loss = (F.cross_entropy(sim_matrix, labels) + F.cross_entropy(sim_matrix.T, labels)) / 2

        # With a single sample there are no negatives, so the term is zero.
        if hard_negatives_weight > 0 and batch_size > 1:
            diagonal = torch.eye(batch_size, dtype=torch.bool, device=device)
            neg_sim = sim_matrix.masked_fill(diagonal, -float('inf'))

            k = min(5, batch_size - 1)
            hard_neg_sim, _ = torch.topk(neg_sim, k, dim=1)

            pos_sim = torch.diag(sim_matrix).unsqueeze(1)
            # -log(exp(pos) / (exp(pos) + sum(exp(neg)))) is cross-entropy with
            # the positive in column 0. Using cross_entropy keeps it in
            # log-space, so large logits (small temperature) cannot overflow
            # exp() into inf/inf = NaN.
            logits = torch.cat([pos_sim, hard_neg_sim], dim=1)
            targets = torch.zeros(batch_size, dtype=torch.long, device=device)
            hard_neg_loss = F.cross_entropy(logits, targets)

            loss = loss + hard_negatives_weight * hard_neg_loss

        return loss

    def train_epoch(self, dataloader, optimizer, scheduler, use_mixup=False, l2_weight=0.0):
        """Run one optimisation pass over ``dataloader``.

        Args:
            dataloader: Yields batches with the ``*_orig`` / ``*_pos`` keys
                produced by ``ImprovedSimCSEDataset``.
            optimizer: Optimizer stepped once per batch.
            scheduler: LR scheduler stepped once per batch.
            use_mixup: If True, 30% of batches mix each embedding with a
                randomly permuted batch-mate.
            l2_weight: Coefficient of an explicit L2 penalty over all model
                parameters. Defaults to 0 because the optimizer built in
                ``train`` already applies decoupled weight decay; an extra
                penalty over every pretrained weight duplicates that and adds
                a full pass over the parameters to every step.

        Returns:
            Mean training loss per batch.
        """
        self.model.train()
        total_loss = 0.0

        for batch in tqdm(dataloader, desc="Training"):
            input_ids_orig = batch['input_ids_orig'].to(self.device)
            attention_mask_orig = batch['attention_mask_orig'].to(self.device)
            input_ids_pos = batch['input_ids_pos'].to(self.device)
            attention_mask_pos = batch['attention_mask_pos'].to(self.device)

            batch_size = input_ids_orig.shape[0]

            # Independent dropout-strength selectors for the two views.
            dropout_mask1 = torch.randint(0, 3, (batch_size,))
            dropout_mask2 = torch.randint(0, 3, (batch_size,))

            z1 = self.model(input_ids_orig, attention_mask_orig, dropout_mask=dropout_mask1)
            z2 = self.model(input_ids_pos, attention_mask_pos, dropout_mask=dropout_mask2)

            if use_mixup and random.random() < 0.3:
                lam = np.random.beta(0.2, 0.2)
                # Same permutation for both views so positives stay aligned.
                indices = torch.randperm(batch_size, device=z1.device)
                z1 = lam * z1 + (1 - lam) * z1[indices]
                z2 = lam * z2 + (1 - lam) * z2[indices]

            loss = self.compute_advanced_contrastive_loss(z1, z2)

            if l2_weight > 0:
                loss = loss + l2_weight * sum(p.pow(2.0).sum() for p in self.model.parameters())

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()

            total_loss += loss.item()

        return total_loss / len(dataloader)

    def train(self, train_dataloader, val_dataloader, epochs=5, learning_rate=2e-5, warmup_ratio=0.1):
        """Fine-tune the model with early stopping on validation loss.

        The encoder is trained at one tenth of ``learning_rate``; all other
        parameters use ``learning_rate``. The best checkpoint is written to
        ``best_improved_simcse_model`` and reloaded before returning, so the
        trainer ends up holding the best weights, not the last epoch's.

        Args:
            train_dataloader: Training batches.
            val_dataloader: Validation batches.
            epochs: Maximum number of epochs.
            learning_rate: Base learning rate.
            warmup_ratio: Fraction of total steps used for linear warmup.
        """
        log = logging.getLogger(__name__)

        encoder_param_ids = {id(p) for p in self.model.encoder.parameters()}
        encoder_params = []
        other_params = []
        for p in self.model.parameters():
            (encoder_params if id(p) in encoder_param_ids else other_params).append(p)

        param_groups = [
            {'params': encoder_params, 'lr': learning_rate * 0.1},
            {'params': other_params, 'lr': learning_rate}
        ]
        # Without a projection head the second group is empty; skip it.
        param_groups = [group for group in param_groups if group['params']]
        optimizer = AdamW(param_groups, weight_decay=0.01)

        total_steps = len(train_dataloader) * epochs
        warmup_steps = int(warmup_ratio * total_steps)

        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps
        )

        checkpoint_dir = "best_improved_simcse_model"
        best_val_loss = float('inf')
        best_saved = False
        patience = 3
        patience_counter = 0

        for epoch in range(epochs):
            # Mixup is only switched on from the third epoch.
            use_mixup = epoch >= 2
            train_loss = self.train_epoch(train_dataloader, optimizer, scheduler, use_mixup)

            val_loss = self.evaluate(val_dataloader)
            log.info("epoch %d/%d - train_loss=%.4f val_loss=%.4f",
                     epoch + 1, epochs, train_loss, val_loss)

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                self.save_model(checkpoint_dir)
                best_saved = True
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    log.info("early stopping after epoch %d", epoch + 1)
                    break

        # Later epochs may have been worse than the best one; restore it.
        if best_saved:
            self.load_model(checkpoint_dir)

    def evaluate(self, dataloader):
        """Return the mean contrastive loss (no hard-negative term) over ``dataloader``."""
        self.model.eval()
        total_loss = 0.0

        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Evaluating"):
                input_ids_orig = batch['input_ids_orig'].to(self.device)
                attention_mask_orig = batch['attention_mask_orig'].to(self.device)
                input_ids_pos = batch['input_ids_pos'].to(self.device)
                attention_mask_pos = batch['attention_mask_pos'].to(self.device)
                z1 = self.model(input_ids_orig, attention_mask_orig, return_embeddings=True)
                z2 = self.model(input_ids_pos, attention_mask_pos, return_embeddings=True)

                loss = self.compute_advanced_contrastive_loss(z1, z2, hard_negatives_weight=0)
                total_loss += loss.item()

        return total_loss / len(dataloader)

    def generate_embeddings(self, texts, batch_size=32):
        """Embed ``texts`` in eval mode and return a 2-D NumPy array.

        Args:
            texts: Indexable collection of texts.
            batch_size: Inference batch size.

        Returns:
            Array with one row per text; zero rows when ``texts`` is empty.
        """
        self.model.eval()
        embeddings = []

        dataset = ImprovedSimCSEDataset(texts, self.tokenizer, use_hard_negatives=False)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Generating embeddings"):
                input_ids = batch['input_ids_orig'].to(self.device)
                attention_mask = batch['attention_mask_orig'].to(self.device)

                # One pass is enough: in eval mode with return_embeddings=True
                # the output is deterministic, so averaging repeated passes
                # only multiplies the cost.
                batch_embeddings = self.model(input_ids, attention_mask, return_embeddings=True)
                embeddings.append(batch_embeddings.cpu().numpy())

        if not embeddings:
            return np.empty((0, self.model.output_dim), dtype=np.float32)
        return np.vstack(embeddings)

    def save_model(self, path):
        """Write model weights, tokenizer files and a small config into directory ``path``."""
        os.makedirs(path, exist_ok=True)
        torch.save(self.model.state_dict(), os.path.join(path, "model.pt"))
        self.tokenizer.save_pretrained(path)

        config = {
            'pooling': self.model.pooling,
            'output_dim': self.model.output_dim,
            'temperature': self.temperature
        }
        torch.save(config, os.path.join(path, "config.pt"))

    def load_model(self, path):
        """Load weights saved by ``save_model`` from directory ``path``.

        NOTE(review): ``torch.load`` unpickles the file; only load
        checkpoints from a trusted source.
        """
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        state_dict = torch.load(os.path.join(path, "model.pt"), map_location=self.device)
        self.model.load_state_dict(state_dict)

In [10]:
def load_data(csv_path):
    """Read the ``text`` column of a CSV file as a list of clean strings.

    Missing values are dropped, every entry is converted to ``str`` and
    stripped of surrounding whitespace, and entries that end up empty are
    discarded.

    Args:
        csv_path: Path (or buffer) accepted by ``pandas.read_csv``.

    Returns:
        List of non-empty, stripped strings in file order.
    """
    frame = pd.read_csv(csv_path)
    raw_entries = frame['text'].dropna().astype(str).tolist()

    cleaned = []
    for entry in raw_entries:
        stripped = entry.strip()
        if stripped:
            cleaned.append(stripped)
    return cleaned


In [11]:
def post_process_embeddings(embeddings, method='pca', n_components=256):
    """Optionally standardise embeddings and reduce them with PCA.

    Args:
        embeddings: 2-D array-like of shape ``(n_samples, n_features)``.
        method: ``'pca'`` standardises then applies PCA; any other value
            returns the input untouched.
        n_components: Requested PCA size. An integer request is capped at
            ``min(n_samples, n_features)`` because PCA cannot produce more
            components than that and would raise otherwise.

    Returns:
        Tuple ``(processed, pca, scaler)``; ``pca`` and ``scaler`` are the
        fitted objects, or ``None`` when no processing was applied.
    """
    if method != 'pca':
        return embeddings, None, None

    # Imported here so the passthrough path does not need scikit-learn.
    from sklearn.decomposition import PCA
    from sklearn.preprocessing import StandardScaler

    if isinstance(n_components, int):
        n_samples, n_features = np.shape(embeddings)
        n_components = min(n_components, n_samples, n_features)

    scaler = StandardScaler()
    embeddings_scaled = scaler.fit_transform(embeddings)

    pca = PCA(n_components=n_components)
    embeddings_processed = pca.fit_transform(embeddings_scaled)

    return embeddings_processed, pca, scaler

In [None]:
def main(csv_path, output_path, seed=42):
    """Train a SimCSE model on a CSV of texts and save PCA-reduced embeddings.

    Steps: load texts, build model and tokenizer, split 80/20 into train and
    validation sets, train, embed every text, post-process and save.

    Args:
        csv_path: CSV file read by ``load_data``.
        output_path: Destination passed to ``np.save``.
        seed: Seed applied to ``random``, NumPy and PyTorch before any
            stochastic step, and to the train/validation split, so a rerun
            repeats the same shuffling, dropout selection and mixup draws.

    Returns:
        Tuple ``(embeddings_processed, trainer)``.
    """
    # Seed every RNG the pipeline draws from, not just the data split.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    texts = load_data(csv_path)

    # Tokenizer and encoder must come from the same checkpoint.
    model_name = 'sentence-transformers/all-MiniLM-L6-v2'
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = SimCSEModel(
        model_name=model_name,
        pooling='mean',
        projection_dim=384,
        temperature=0.05
    )

    train_texts, val_texts = train_test_split(texts, test_size=0.2, random_state=seed)

    train_dataset = ImprovedSimCSEDataset(train_texts, tokenizer)
    val_dataset = ImprovedSimCSEDataset(val_texts, tokenizer)

    train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False)

    trainer = SimCSETrainer(model, tokenizer)

    trainer.train(train_dataloader, val_dataloader, epochs=5)

    # Embed the full corpus (train and validation texts alike).
    embeddings = trainer.generate_embeddings(texts)

    embeddings_processed, pca, scaler = post_process_embeddings(embeddings)

    np.save(output_path, embeddings_processed)

    return embeddings_processed, trainer

# Entry point: runs the full pipeline (load, train, embed, save) via ``main``.
# NOTE(review): inside a notebook kernel ``__name__`` is normally "__main__",
# so executing this cell starts training immediately — confirm that is intended.
if __name__ == "__main__":
    # Input CSV; ``load_data`` reads its 'text' column.
    csv_path = "bbc_encoded.csv"
    # File the processed embeddings are written to with ``np.save``.
    output_path = "simcse_embeddings.npy"

    # Both results stay bound at module level for use in later cells.
    embeddings, trainer = main(csv_path, output_path)

Training:   4%|▎         | 2/54 [00:30<12:50, 14.81s/it]