In [None]:
import random
from typing import Callable, Iterable
import numpy as np
import pandas as pd
import datasets
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm
from transformers import BertTokenizerFast, BertModel
from transformers.optimization import get_linear_schedule_with_warmup
from scipy import stats
from sklearn.metrics.pairwise import cosine_similarity

In [None]:
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

# Seed every RNG the notebook touches (Python, NumPy, torch CPU/GPU) and force
# deterministic cuDNN kernels so splits, shuffling and weight init are reproducible.
seed = 10
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

dataset = datasets.load_dataset("csv", data_files='paired_annotations.csv')

# Mean len() of each text column -- characters, assuming string columns.
# A rough guide for choosing MAX_LENGTH (which counts tokens, not characters).
student_lengths = [len(text) for text in dataset['train']['student_text']]
teacher_lengths = [len(text) for text in dataset['train']['teacher_text']]

print(np.mean(student_lengths))
print(np.mean(teacher_lengths))

Generating train split: 0 examples [00:00, ? examples/s]

100%|██████████| 2348/2348 [00:00<00:00, 16510.46it/s]

45.49574105621806
158.80706984667802





In [None]:
%%time

MAX_LENGTH = 128

# Tokenize each side of every pair into fixed-length (padded / truncated) id
# sequences: student_text -> premises, teacher_text -> hypotheses (in that order).
tokenized_premises, tokenized_hypothesis = (
    tokenizer(
        [row[column] for row in dataset['train']],
        max_length=MAX_LENGTH,
        padding="max_length",
        truncation=True,
        verbose=True,
    )
    for column in ("student_text", "teacher_text")
)

CPU times: user 1.15 s, sys: 203 ms, total: 1.35 s
Wall time: 615 ms


In [None]:
class Dataset(torch.utils.data.Dataset):
    """Map-style torch dataset of tokenized (premise, hypothesis, label) triples.

    Every example is converted to tensors once, in ``__init__``, and cached in
    ``self.data``; ``__getitem__`` simply indexes that list.

    Args:
        premise_tokens: tokenizer output for the premises; only its
            ``"input_ids"`` and ``"attention_mask"`` entries are read.
        hypothesis_tokens: the same, for the hypotheses.
        labels: one label per pair, each convertible to a ``torch.long``
            tensor (e.g. 0/1 ints). Any iterable is accepted, including a
            generator; it is materialized into a list.

    Raises:
        ValueError: if the numbers of premises, hypotheses and labels differ.
    """

    # The base class is spelled out as torch.utils.data.Dataset rather than the
    # bare name ``Dataset``: this class shadows that name, so re-running the cell
    # would otherwise subclass the previous definition (or whichever ``Dataset``
    # was imported most recently).

    def __init__(self, premise_tokens: dict, hypothesis_tokens: dict, labels: Iterable[int]):
        self.premise_tokens = premise_tokens
        self.hypothesis_tokens = hypothesis_tokens
        # Materialize so a generator is not left exhausted in self.labels and
        # so its length can be validated.
        self.labels = list(labels)
        self._init_data()

    def _init_data(self) -> None:
        """Build ``self.data``: one dict of long tensors per example."""
        n_premises = len(self.premise_tokens["input_ids"])
        n_hypotheses = len(self.hypothesis_tokens["input_ids"])
        n_labels = len(self.labels)
        # zip() would silently truncate to the shortest input; fail loudly instead.
        if not (n_premises == n_hypotheses == n_labels):
            raise ValueError(
                f"length mismatch: {n_premises} premises, "
                f"{n_hypotheses} hypotheses, {n_labels} labels"
            )

        self.data = []
        for pt_ids, pt_am, ht_ids, ht_am, label in zip(
            self.premise_tokens["input_ids"], self.premise_tokens["attention_mask"],
            self.hypothesis_tokens["input_ids"], self.hypothesis_tokens["attention_mask"],
            self.labels,
        ):
            self.data.append({
                "premise_input_ids": torch.tensor(pt_ids, dtype=torch.long),
                "premise_attention_mask": torch.tensor(pt_am, dtype=torch.long),
                "hypothesis_input_ids": torch.tensor(ht_ids, dtype=torch.long),
                "hypothesis_attention_mask": torch.tensor(ht_am, dtype=torch.long),
                "label": torch.tensor(label, dtype=torch.long),
            })

    def __getitem__(self, ix: int) -> dict[str, torch.Tensor]:
        return self.data[ix]

    def __len__(self) -> int:
        return len(self.data)

# Build the torch dataset of (student_text, teacher_text, focusing_question) pairs.
# It is bound to `snli_dataset` because the train/val split cell reads that name.
# The Hugging Face `dataset` is deliberately NOT overwritten, so the loading and
# tokenization cells above stay re-runnable.
snli_dataset = Dataset(tokenized_premises, tokenized_hypothesis,
                       [row["focusing_question"] for row in dataset['train']])

In [None]:
# 80 / 20 random train/validation split; the split follows the torch seed set above.
train_ratio = 0.80
n_total = len(snli_dataset)
n_train = int(n_total * train_ratio)
n_val = n_total - n_train
train_dataset, val_dataset = random_split(snli_dataset, [n_train, n_val])

batch_size = 16  # mentioned in the paper
# Only the training loader reshuffles each epoch; validation order is fixed.
train_dataloader, val_dataloader = (
    DataLoader(train_dataset, batch_size=batch_size, shuffle=True),
    DataLoader(val_dataset, batch_size=batch_size, shuffle=False),
)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizerFast
from datasets import Dataset

class Sbert(nn.Module):
    """Siamese BERT pair classifier that returns raw logits.

    Both texts of a pair go through the same ``BertModel``. Their last hidden
    states are mean-pooled over the attention mask into u (premise) and
    v (hypothesis), and ``self.linear`` maps the concatenation (u, v, |u - v|)
    to ``num_classes`` logits.

    NOTE(review): a later cell in this notebook defines another ``Sbert`` that
    replaces this one at runtime.
    """

    def __init__(self, max_length: int = 128, num_classes: int = 1):
        super().__init__()
        self.max_length = max_length  # stored only; forward() does not read it
        self.bert_model = BertModel.from_pretrained('bert-base-uncased')
        self.bert_tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')  # not used by forward()
        # The classifier input is (u, v, |u - v|), i.e. 3 * hidden_size wide.
        self.linear = nn.Linear(self.bert_model.config.hidden_size * 3, num_classes)

    def forward(self, data: dict[str, torch.Tensor]) -> torch.Tensor:
        """Score a batch of pairs.

        Args:
            data: mapping with ``premise_input_ids``, ``premise_attention_mask``,
                ``hypothesis_input_ids`` and ``hypothesis_attention_mask``
                (as produced by the DataLoader over this notebook's Dataset).

        Returns:
            Logits of shape ``(batch, num_classes)``; no sigmoid is applied.
        """
        # Follow the model's own device instead of "cuda whenever available":
        # the latter breaks as soon as the model is kept on the CPU on a GPU box.
        device = self.linear.weight.device

        premise_input_ids = data["premise_input_ids"].to(device)
        premise_attention_mask = data["premise_attention_mask"].to(device)
        hypothesis_input_ids = data["hypothesis_input_ids"].to(device)
        hypothesis_attention_mask = data["hypothesis_attention_mask"].to(device)

        premise_embeds = self.bert_model(
            input_ids=premise_input_ids, attention_mask=premise_attention_mask
        ).last_hidden_state
        hypothesis_embeds = self.bert_model(
            input_ids=hypothesis_input_ids, attention_mask=hypothesis_attention_mask
        ).last_hidden_state

        u = mean_pool(premise_embeds, premise_attention_mask)
        v = mean_pool(hypothesis_embeds, hypothesis_attention_mask)

        # Feature combination (u, v, |u - v|). Variants tried earlier were
        # (u, v, u*v, |u - v|) and (u, v, u*v); the former would need a
        # 4 * hidden_size wide linear layer.
        embeds = torch.cat([u, v, torch.abs(u - v)], dim=-1)
        return self.linear(embeds)

def mean_pool(token_embeds: torch.tensor, attention_mask: torch.tensor) -> torch.tensor:
    """Average token embeddings over dim 1, counting only positions where the mask is 1.

    The per-row token count is clamped at 1e-9, so a row with an all-zero mask
    yields zeros rather than NaN.
    """
    weights = attention_mask.unsqueeze(-1).expand(token_embeds.size()).float()
    summed = (token_embeds * weights).sum(dim=1)
    counts = weights.sum(dim=1).clamp(min=1e-9)
    return summed / counts



In [None]:
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizerFast
from datasets import Dataset

class Sbert(nn.Module):
    """Siamese BERT binary pair classifier.

    Both texts of a pair go through the same ``BertModel``. Their last hidden
    states are mean-pooled over the attention mask into u (premise) and
    v (hypothesis), and ``self.linear`` maps (u, v, |u - v|) to
    ``num_classes`` outputs.

    ``forward`` returns raw logits: that is what ``BCEWithLogitsLoss`` expects,
    and the training/validation steps apply sigmoid + rounding themselves.
    Returning ``round(sigmoid(.))`` from ``forward`` (as before) made every
    gradient zero, because ``torch.round`` has zero derivative, so the model
    could not learn. Use ``predict`` for hard 0/1 labels.
    """

    def __init__(self, max_length: int = 128, num_classes: int = 1):
        super().__init__()
        self.max_length = max_length  # stored only; forward() does not read it
        self.bert_model = BertModel.from_pretrained('bert-base-uncased')
        self.bert_tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')  # not used by forward()
        # The classifier input is (u, v, |u - v|), i.e. 3 * hidden_size wide.
        self.linear = nn.Linear(self.bert_model.config.hidden_size * 3, num_classes)
        self.sigmoid = nn.Sigmoid()  # used by predict()

    def forward(self, data: dict[str, torch.Tensor]) -> torch.Tensor:
        """Score a batch of pairs.

        Args:
            data: mapping with ``premise_input_ids``, ``premise_attention_mask``,
                ``hypothesis_input_ids`` and ``hypothesis_attention_mask``
                (as produced by the DataLoader over this notebook's Dataset).

        Returns:
            Logits of shape ``(batch, num_classes)``; no sigmoid is applied.
        """
        # Follow the model's own device instead of "cuda whenever available":
        # the latter breaks as soon as the model is kept on the CPU on a GPU box.
        device = self.linear.weight.device

        premise_input_ids = data["premise_input_ids"].to(device)
        premise_attention_mask = data["premise_attention_mask"].to(device)
        hypothesis_input_ids = data["hypothesis_input_ids"].to(device)
        hypothesis_attention_mask = data["hypothesis_attention_mask"].to(device)

        premise_embeds = self.bert_model(
            input_ids=premise_input_ids, attention_mask=premise_attention_mask
        ).last_hidden_state
        hypothesis_embeds = self.bert_model(
            input_ids=hypothesis_input_ids, attention_mask=hypothesis_attention_mask
        ).last_hidden_state

        u = mean_pool(premise_embeds, premise_attention_mask)
        v = mean_pool(hypothesis_embeds, hypothesis_attention_mask)

        # u, v, |u - v|
        embeds = torch.cat([u, v, torch.abs(u - v)], dim=-1)
        return self.linear(embeds)

    def predict(self, data: dict[str, torch.Tensor]) -> torch.Tensor:
        """Hard 0/1 predictions, ``round(sigmoid(logits))``, shaped like ``forward``'s output."""
        return torch.round(self.sigmoid(self.forward(data)))

def mean_pool(token_embeds: torch.tensor, attention_mask: torch.tensor) -> torch.tensor:
    """Masked mean over dim 1: padding positions (mask == 0) are ignored.

    The token count is floored at 1e-9 to avoid dividing by zero for all-padding rows.
    """
    mask_3d = attention_mask[..., None].expand_as(token_embeds).float()
    numerator = torch.sum(mask_3d * token_embeds, dim=1)
    denominator = torch.clamp(torch.sum(mask_3d, dim=1), min=1e-9)
    return numerator / denominator



In [None]:
model = Sbert()

# Optimizer, lr and warm-up fraction were picked from the paper.
# AdamW applies weight decay decoupled from the adaptive gradient scaling;
# torch.optim.Adam(weight_decay=...) instead adds an L2 term to the gradient.
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)

n_epochs = 10  # mentioned in the paper; must equal the number of epochs actually trained

# The linear schedule has to span *every* optimizer step of the run.
# len(train_dataloader) also counts the final partial batch.
# Previously the schedule covered only ~0.9 of one epoch, so the LR was 0 for
# the rest of training.
total_steps = len(train_dataloader) * n_epochs
warmup_steps = int(0.1 * total_steps)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps,
                                            num_training_steps=total_steps)

# Expects raw logits (Sbert.forward returns logits) and 0/1 float targets.
loss_fn = torch.nn.BCEWithLogitsLoss()

In [None]:
# nn.Module.to moves parameters/buffers in place and returns the module itself.
model = model.to(device=device)

In [None]:
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score
from typing import Callable
import torch
import numpy as np
from torch.utils.data import DataLoader

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def get_train_step_fn(
    model: torch.nn.Module, optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.LambdaLR, loss_fn: torch.nn.BCEWithLogitsLoss
) -> Callable[[dict, torch.Tensor], tuple[float, np.ndarray, np.ndarray]]:
    """Return a closure that performs one optimization step on a mini-batch.

    The returned ``train_step_fn(x, y)`` does the following:

    1. Puts ``model`` in train mode and computes ``model(x)``, which must
       contain exactly one logit per label in ``y`` (e.g. shape ``(batch, 1)``).
    2. Computes ``loss_fn(logits, y.float())``, back-propagates, steps the
       optimizer and the scheduler, then clears the gradients.
    3. Returns ``(loss, preds, labels)``: the loss as a Python float, 0/1
       predictions as a 1-D integer array, and ``y`` as a NumPy array.
    """

    def train_step_fn(x: dict, y: torch.Tensor) -> tuple[float, np.ndarray, np.ndarray]:
        model.train()
        # Reshape to y's shape instead of a bare .squeeze(): squeeze turns the
        # (1, 1) output of a size-1 final batch into a 0-d tensor, which
        # BCEWithLogitsLoss rejects against a (1,) target.
        logits = model(x).reshape(y.shape)
        loss = loss_fn(logits, y.float())
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        # sigmoid(z) > 0.5 exactly when z > 0, so this matches round(sigmoid(z))
        # (up to float rounding right at z ~ 0) without the extra sigmoid.
        preds = (logits.detach() > 0).long().cpu().numpy()
        return loss.item(), preds, y.cpu().numpy()

    return train_step_fn

def get_val_step_fn(
    model: torch.nn.Module, loss_fn: torch.nn.BCEWithLogitsLoss
) -> Callable[[dict, torch.Tensor], tuple[float, np.ndarray, np.ndarray]]:
    """Return a closure that evaluates the model on a mini-batch without updating it.

    The returned ``val_step_fn(x, y)`` puts ``model`` in eval mode and computes
    ``model(x)`` under ``torch.no_grad()``. The output must contain exactly one
    logit per label in ``y``. It returns ``(loss, preds, labels)``: the loss as
    a Python float, 0/1 predictions as a 1-D integer array, and ``y`` as a
    NumPy array.
    """

    def val_step_fn(x: dict, y: torch.Tensor) -> tuple[float, np.ndarray, np.ndarray]:
        model.eval()
        # no_grad here as well, so the step is safe even if a caller forgets it.
        with torch.no_grad():
            # reshape (not squeeze) so a size-1 batch keeps its batch dimension.
            logits = model(x).reshape(y.shape)
            loss = loss_fn(logits, y.float())
        # sigmoid(z) > 0.5 exactly when z > 0.
        preds = (logits > 0).long().cpu().numpy()
        return loss.item(), preds, y.cpu().numpy()

    return val_step_fn

def mini_batch(
    dataloader: DataLoader,
    step_fn: Callable[[dict, torch.Tensor], tuple[float, np.ndarray, np.ndarray]],
    is_training: bool = True,
    log_every: int = 100,
) -> tuple[float, list, float, float, float, float, list, list]:
    """Run ``step_fn`` over every batch of ``dataloader`` and aggregate metrics.

    Each batch (a dict) is passed to ``step_fn`` as ``x``, together with its
    ``"label"`` tensor moved to the global ``device`` as ``y``.

    Args:
        dataloader: yields dict batches containing a ``"label"`` entry.
        step_fn: train or validation step returning ``(loss, preds, labels)``.
        is_training: only selects the banner that is printed.
        log_every: print the loss every ``log_every`` steps (0 disables logging).

    Returns:
        ``(mean_loss, per_batch_losses, accuracy, precision, recall, f1,
        all_preds, all_labels)``. Precision/recall/F1 are support-weighted
        averages and are 0 for classes that were never predicted.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    mini_batch_losses = []
    all_preds = []
    all_labels = []

    print("\nTraining ..." if is_training else "\nValidating ...")
    n_steps = len(dataloader)
    for i, data in enumerate(dataloader):
        x, y = data, data["label"].to(device)

        loss, preds, labels = step_fn(x, y)
        mini_batch_losses.append(loss)
        # ravel: step functions that return (batch, 1) predictions still
        # contribute one flat label per sample.
        all_preds.extend(np.ravel(preds).tolist())
        all_labels.extend(np.ravel(labels).tolist())

        # The old interval (batch_size * 100) mixed samples with steps and
        # only ever fired at step 0 on this dataset.
        if log_every and i % log_every == 0:
            print(f"step {i:>5}/{n_steps}, loss = {loss: .5f}")

    if not mini_batch_losses:
        raise ValueError("dataloader yielded no batches; cannot compute metrics")

    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
    f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)

    return np.mean(mini_batch_losses), mini_batch_losses, accuracy, precision, recall, f1, all_preds, all_labels

n_epochs = 10  # mentioned in the paper

train_step_fn = get_train_step_fn(model, optimizer, scheduler, loss_fn)
val_step_fn = get_val_step_fn(model, loss_fn)

# Per-epoch histories (one entry per epoch), plus flat per-mini-batch loss
# traces that keep growing across all epochs.
train_losses = []
train_mini_batch_losses = []
train_accuracies = []
train_precisions = []
train_recalls = []
train_f1s = []

val_losses = []
val_mini_batch_losses = []
val_accuracies = []
val_precisions = []
val_recalls = []
val_f1s = []
val_trues = []  # one list of true labels per epoch
val_preds = []  # one list of predictions per epoch

for epoch in range(1, n_epochs + 1):
    # --- training pass (updates the model) ---
    (train_loss, _train_mini_batch_losses,
     train_accuracy, train_precision, train_recall, train_f1,
     _, _) = mini_batch(train_dataloader, train_step_fn)
    train_mini_batch_losses.extend(_train_mini_batch_losses)
    train_losses.append(train_loss)
    train_accuracies.append(train_accuracy)
    train_precisions.append(train_precision)
    train_recalls.append(train_recall)
    train_f1s.append(train_f1)

    # --- validation pass (no gradient tracking) ---
    with torch.no_grad():
        (val_loss, _val_mini_batch_losses,
         val_accuracy, val_precision, val_recall, val_f1,
         val_pred, val_true) = mini_batch(val_dataloader, val_step_fn, is_training=False)
    val_mini_batch_losses.extend(_val_mini_batch_losses)
    val_losses.append(val_loss)
    val_accuracies.append(val_accuracy)
    val_precisions.append(val_precision)
    val_recalls.append(val_recall)
    val_f1s.append(val_f1)
    val_trues.append(val_true)
    val_preds.append(val_pred)

    # One summary line per epoch.
    print(f'Epoch {epoch}/{n_epochs} - '
          f'Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, Train Precision: {train_precision:.4f}, Train Recall: {train_recall:.4f}, Train F1: {train_f1:.4f} - '
          f'Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}, Val Precision: {val_precision:.4f}, Val Recall: {val_recall:.4f}, Val F1: {val_f1:.4f}')


In [None]:
import matplotlib
import matplotlib.pyplot as plt

window_size = 32


def running_mean(values, window):
    """Mean of every full window of ``window`` consecutive values.

    Returns ``len(values) - window + 1`` points (the old loop dropped the
    last full window), or an empty list when there are fewer than ``window``
    values.

    Raises:
        ValueError: if ``window`` is not positive.
    """
    if window < 1:
        raise ValueError("window must be >= 1")
    return [np.mean(values[i:i + window]) for i in range(len(values) - window + 1)]


train_mb_running_loss = running_mean(train_mini_batch_losses, window_size)
val_mb_running_loss = running_mean(val_mini_batch_losses, window_size)

# Train and validation curves go on separate axes: their x-axes count
# different numbers of mini-batches.
fig, (ax, ax_val) = plt.subplots(1, 2, figsize=(12, 3))
ax.plot(range(len(train_mb_running_loss)), train_mb_running_loss)
ax.set(title=f"Train loss (running mean, window={window_size})",
       xlabel="training mini-batch", ylabel="BCE-with-logits loss")
ax_val.plot(range(len(val_mb_running_loss)), val_mb_running_loss)
ax_val.set(title=f"Validation loss (running mean, window={window_size})",
           xlabel="validation mini-batch", ylabel="BCE-with-logits loss")
fig.tight_layout();

In [None]:
def encode(
    input_texts: list[str], tokenizer: BertTokenizerFast, model: BertModel, device: str = "cpu"
) -> torch.Tensor:
    """Embed texts as mean-pooled BERT token states, one row per input text.

    Texts are tokenized to exactly ``MAX_LENGTH`` tokens (padded/truncated) and
    run through ``model`` in eval mode. The resulting ``last_hidden_state`` is
    mean-pooled over the attention mask. Runs under ``torch.no_grad()``:
    this is inference only, so no autograd graph is built and callers can call
    ``.cpu().numpy()`` directly.

    Args:
        input_texts: texts to embed.
        tokenizer: Hugging Face tokenizer matching ``model``.
        model: a BERT encoder exposing ``last_hidden_state``. Note that this
            function switches it to eval mode and leaves it there.
        device: where inputs are moved; must match the device ``model`` is on.

    Returns:
        Pooled embeddings on ``device``, presumably ``(len(input_texts), hidden_size)``.
    """
    model.eval()
    tokenized_texts = tokenizer(input_texts, max_length=MAX_LENGTH,
                                padding='max_length', truncation=True, return_tensors="pt")
    input_ids = tokenized_texts["input_ids"].to(device)
    attention_mask = tokenized_texts["attention_mask"].to(device)
    with torch.no_grad():
        token_embeds = model(input_ids, attention_mask).last_hidden_state
        return mean_pool(token_embeds, attention_mask)