In [1]:
import torch
import csv
from torch import nn
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm

import random

#### Load dataset

In [None]:
# Parse the Quora question-pair TSV into records holding both questions (id + text)
# and a boolean duplicate label.
# newline="" is required by the csv module so quoted fields that contain line breaks
# are parsed as a single field instead of being split across rows.
with open(
    "../../data/quora/quora_duplicate_questions.tsv", encoding="utf-8", newline=""
) as f:
    dataset = [
        {
            "questions": [
                {"id": row["qid1"], "text": row["question1"]},
                {"id": row["qid2"], "text": row["question2"]},
            ],
            # The label column is the string "1" for duplicates.
            "is_duplicate": row["is_duplicate"] == "1",
        }
        for row in csv.DictReader(f, delimiter="\t")
    ]
dataset[0]

In [None]:
# Keep only the pairs labelled as duplicates, as (anchor, positive) question texts.
# This is pair data, not triplets: negatives are sampled later by TripletDataset.
sentences = [
    tuple(question["text"] for question in item["questions"][:2])
    for item in tqdm(dataset)
    if item["is_duplicate"]
]

# Two successive 80/20 splits with the same seed give roughly
# 64% train / 16% validation / 20% test of the duplicate pairs.
SPLIT_SEED = 42
sentences, test_sentences = train_test_split(
    sentences, test_size=0.2, random_state=SPLIT_SEED
)
sentences, val_sentences = train_test_split(
    sentences, test_size=0.2, random_state=SPLIT_SEED
)

#### Prepare Datasets and DataLoaders

In [4]:
# Hugging Face checkpoint id, passed to both AutoTokenizer and AutoModel.from_pretrained.
MODEL_PATH = "prajjwal1/bert-tiny"

In [5]:
class TripletDataset(Dataset):
    """Serve tokenized (anchor, positive, negative) triplets built from sentence pairs.

    Each element of ``sentences`` is an ``(anchor, positive)`` pair of strings. The
    negative is drawn at random from the anchors (first elements) of the pairs on
    every access, so the same index can yield a different negative on each call.
    """

    def __init__(self, sentences, tokenizer, max_length=128):
        """
        Args:
            sentences: sequence of (anchor, positive) string pairs.
            tokenizer: callable invoked as ``tokenizer(text, truncation=True,
                padding="max_length", max_length=..., return_tensors="pt")``.
            max_length: pad/truncate every text to this many tokens.
        """
        self.sentences = sentences
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        """Return the tokenizer outputs for (anchor, positive, sampled negative)."""
        anchor, positive = self.sentences[idx]
        # Sample before tokenizing so RNG consumption order matches the data flow.
        negative = self._get_negative_sample(anchor, positive)
        return self._encode(anchor), self._encode(positive), self._encode(negative)

    def _encode(self, text):
        """Tokenize a single text, padded/truncated to ``self.max_length``."""
        return self.tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )

    def _get_negative_sample(self, anchor, positive, max_attempts=100):
        """Return a random anchor text that differs from both ``anchor`` and ``positive``.

        Tries rejection sampling first. If that fails ``max_attempts`` times, falls back
        to choosing among all valid candidates, so a degenerate dataset fails loudly.

        Raises:
            ValueError: if no sentence other than the given pair exists to sample from.
        """
        for _ in range(max_attempts):
            candidate = random.choice(self.sentences)[0]
            if candidate != anchor and candidate != positive:
                return candidate

        # Exhaustive fallback; the original unbounded loop would spin forever here
        # when every anchor equals the current anchor/positive (e.g. a 1-pair dataset).
        candidates = [
            pair[0]
            for pair in self.sentences
            if pair[0] != anchor and pair[0] != positive
        ]
        if not candidates:
            raise ValueError(
                "No negative sample available: every sentence matches the anchor or positive."
            )
        return random.choice(candidates)

In [None]:
# Tokenizer matching the encoder checkpoint in MODEL_PATH; shared by all three splits.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

In [7]:
# Wrap each split's (anchor, positive) pairs in a TripletDataset that shares one tokenizer.
train_ds, val_ds, test_ds = (
    TripletDataset(split_pairs, tokenizer)
    for split_pairs in (sentences, val_sentences, test_sentences)
)

In [8]:
BATCH_SIZE = 32

# Shuffle only the training set. Validation/test batches keep a fixed order,
# although TripletDataset still resamples negatives on every pass.
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
# Bind the validation loader to `val_dl`. The training loop reads that name, and
# `val_ds` must stay the Dataset rather than being overwritten by its loader.
val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
test_dl = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

#### Define the model

In [9]:
class TripletTransformer(nn.Module):
    """Embed anchor/positive/negative inputs with one shared encoder and mean pooling.

    A single encoder loaded via ``AutoModel.from_pretrained`` embeds all three inputs,
    so the three branches share weights. Each embedding is the attention-mask-weighted
    mean of the encoder's ``last_hidden_state``.
    """

    def __init__(self, model_name):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)

    @staticmethod
    def _mean_pool(last_hidden_state, attention_mask):
        """Average token embeddings over positions where ``attention_mask`` is non-zero.

        Assumes last_hidden_state is (batch, seq, hidden) and attention_mask is
        (batch, seq); TODO confirm for non-BERT encoders. The clamp keeps an
        all-padding row at zero instead of dividing by zero.
        """
        mask = attention_mask.unsqueeze(-1).float()
        summed = torch.sum(last_hidden_state * mask, dim=1)
        counts = torch.clamp(mask.sum(dim=1), min=1e-9)
        return summed / counts

    def _embed(self, inputs):
        """Run the encoder on one tokenized input dict and mean-pool its token states."""
        output = self.encoder(**inputs)
        return self._mean_pool(output.last_hidden_state, inputs["attention_mask"])

    def forward(self, anchor, positive, negative):
        """Return the (anchor, positive, negative) pooled embeddings."""
        return self._embed(anchor), self._embed(positive), self._embed(negative)

#### Instantiate Model and Define Loss & Optimizer

In [10]:
class TripletLoss(nn.Module):
    """Triplet margin loss on embedding tensors, wrapping ``nn.TripletMarginLoss``.

    With the default reduction this is
    ``mean(max(d(anchor, positive) - d(anchor, negative) + margin, 0))``,
    where ``d`` is the p-norm pairwise distance.
    """

    def __init__(self, margin=1.0, p=2):
        """
        Args:
            margin: required gap between positive and negative distances.
            p: norm degree for the pairwise distance (2 = Euclidean, the previous
                hard-coded value).
        """
        super().__init__()
        self.margin = margin
        self.p = p
        self.loss_fn = nn.TripletMarginLoss(margin=self.margin, p=self.p)

    def forward(self, anchor, positive, negative):
        """Return the scalar triplet loss for the given embedding batches."""
        return self.loss_fn(anchor, positive, negative)

In [None]:
# Training hyperparameters for the triplet objective.
MARGIN = 1.0
LEARNING_RATE = 1e-5

model = TripletTransformer(model_name=MODEL_PATH)
loss_fn = TripletLoss(margin=MARGIN)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

#### Start training

In [12]:
def train_epoch(model, dataloader, loss_fn, optimizer, device):
    """Run one optimization pass over ``dataloader`` and return the mean batch loss.

    Args:
        model: module whose forward takes (anchor, positive, negative) input dicts and
            returns the three embedding tensors.
        dataloader: iterable of (anchor, positive, negative) batches. Each element is a
            mapping of name -> tensor with an extra dim 1, which is squeezed here.
        loss_fn: callable(anchor_emb, positive_emb, negative_emb) -> scalar loss tensor.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        device: torch device; the model and every batch tensor are moved onto it.

    Returns:
        float: mean of the per-batch losses, or NaN if ``dataloader`` yields no batches.
    """
    # Module.to() moves parameters in place, so an optimizer created before this call
    # still references the same Parameter objects. The call is a no-op once on-device.
    model.to(device)
    model.train()
    total_loss = 0.0
    num_batches = 0
    for batch in dataloader:
        # Drop the dim-1 axis left by per-example return_tensors="pt" and co-locate
        # inputs with the model.
        anchor, positive, negative = (
            {key: val.squeeze(1).to(device) for key, val in encoding.items()}
            for encoding in batch
        )

        anchor_emb, positive_emb, negative_emb = model(anchor, positive, negative)
        loss = loss_fn(anchor_emb, positive_emb, negative_emb)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        return float("nan")
    return total_loss / num_batches

In [13]:
def evaluate(model, dataloader, loss_fn, device):
    """Compute the mean triplet loss over ``dataloader`` without updating the model.

    Args:
        model: module whose forward takes (anchor, positive, negative) input dicts and
            returns the three embedding tensors.
        dataloader: iterable of (anchor, positive, negative) batches. Each element is a
            mapping of name -> tensor with an extra dim 1, which is squeezed here.
        loss_fn: callable(anchor_emb, positive_emb, negative_emb) -> scalar loss tensor.
        device: torch device; the model and every batch tensor are moved onto it.

    Returns:
        float: mean of the per-batch losses, or NaN if ``dataloader`` yields no batches.
    """
    # Module.to() is in place and a no-op when already on ``device``.
    model.to(device)
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for batch in dataloader:
            anchor, positive, negative = (
                {key: val.squeeze(1).to(device) for key, val in encoding.items()}
                for encoding in batch
            )
            anchor_emb, positive_emb, negative_emb = model(anchor, positive, negative)
            total_loss += loss_fn(anchor_emb, positive_emb, negative_emb).item()
            num_batches += 1

    if num_batches == 0:
        return float("nan")
    return total_loss / num_batches

In [14]:
# Use the GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
n_epochs = 100
verbose = 2  # print a summary line every `verbose` epochs; 0 disables printing
train_losses = []
val_losses = []

# Seed both RNGs the loop consumes: TripletDataset samples negatives with `random`,
# and the shuffled training DataLoader draws from torch's generator.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

for epoch in tqdm(range(n_epochs)):
    train_loss = train_epoch(model, train_dl, loss_fn, optimizer, device)
    # evaluate() returns a single mean loss; there is no accuracy to unpack.
    val_loss = evaluate(model, val_dl, loss_fn, device)

    train_losses.append(train_loss)
    val_losses.append(val_loss)

    if verbose and (epoch + 1) % verbose == 0:
        print(
            f"Epoch {epoch + 1}/{n_epochs} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}"
        )

print("Training complete.")