In [1]:
import torch
import torch.nn as nn
from sentence_transformers import SentenceTransformer

In [2]:
device = 'cuda'

# Task 1

In [112]:
model = SentenceTransformer("all-MiniLM-L6-v2")
model

SentenceTransformer(
  (0): Transformer({'max_seq_length': 256, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
  (2): Normalize()
)

In [4]:
sentences = [
    "Her brother is a king.",
    "His sister is a queen.",
    "They can't dance."
]

In [11]:
embeddings = model.encode(sentences)
embeddings

array([[ 1.3426487e-01,  6.8651125e-02, -7.1860326e-05, ...,
         5.4440424e-02, -8.0796704e-04,  4.2172940e-03],
       [ 1.3438945e-01,  6.8506449e-02, -2.4290488e-04, ...,
         5.4613810e-02, -6.0740544e-04,  3.9655287e-03],
       [ 1.3048099e-01,  7.0870683e-02, -3.6374270e-03, ...,
         5.7554800e-02,  8.8193323e-03,  2.7406658e-03]],
      shape=(3, 384), dtype=float32)

In [6]:
model.similarity(embeddings, embeddings)

tensor([[1.0000, 0.8068, 0.1882],
        [0.8068, 1.0000, 0.1774],
        [0.1882, 0.1774, 1.0000]])

## Alternative (Unfinished) Implementation

# Task 2

In [219]:
class ClassificationHead(nn.Module):
    def __init__(self, n):
        super(ClassificationHead, self).__init__()
        self.linear = nn.Linear(384, n)

    def forward(self, x):
        out = self.linear(x)
        return out

class SentenceTransformerWithHeads(nn.Module):
    def __init__(self, num_classes_a, num_classes_b):
        super(SentenceTransformerWithHeads, self).__init__()
        self.st = SentenceTransformer("all-MiniLM-L6-v2")
        self.num_classes_a = num_classes_a
        self.num_classes_b = num_classes_b
        self.head_a = ClassificationHead(n=num_classes_a)
        self.head_b = ClassificationHead(n=num_classes_b)

    def forward(self, x, head):
        features = torch.tensor(self.st.encode(x)).to(device)
        if head == 'A':
            out = self.head_a(features)
        elif head == 'B':
            out = self.head_b(features)
        else:
            raise ValueError("Please specify 'A' or 'B' for head")
        return out

## Task 4

In [227]:
from datasets import load_dataset
from torch.utils.data import DataLoader
from torcheval.metrics import MulticlassAccuracy, MulticlassF1Score, MulticlassPrecision, MulticlassRecall
from itertools import chain

In [221]:
train_dataset_a = load_dataset("fancyzhx/ag_news", split="train").with_format("torch")
test_dataset_a = load_dataset("fancyzhx/ag_news", split="test").with_format("torch")

In [222]:
batch_size_a = 2000
train_dataloader_a = DataLoader(train_dataset_a, batch_size=batch_size_a)
test_dataloader_a = DataLoader(test_dataset_a, batch_size=batch_size_a)
len(train_dataloader_a)

60

In [223]:
train_dataset_b = load_dataset("zeroshot/twitter-financial-news-sentiment", split="train").with_format("torch")
test_dataset_b = load_dataset("zeroshot/twitter-financial-news-sentiment", split="validation").with_format("torch")

In [224]:
batch_size_b = 160
train_dataloader_b = DataLoader(train_dataset_b, batch_size=batch_size_b)
test_dataloader_b = DataLoader(test_dataset_b, batch_size=batch_size_b)
len(train_dataloader_b)

60

In [238]:
def train(dataloader_a, dataloader_b, model, loss_fn, optimizer):
    num_batches = len(dataloader_a)
    model.train()
    for i, (batch_a, batch_b) in enumerate(zip(dataloader_a, dataloader_b)):
        X_a = batch_a['text']
        y_a = batch_a['label'].to(device)
        X_b = batch_b['text']
        y_b = batch_b['label'].to(device)
        
        # Compute prediction error
        pred_a = model(X_a, 'A')
        loss_a = loss_fn(pred_a, y_a)
        pred_b = model(X_b, 'B')
        loss_b = loss_fn(pred_b, y_b)
        loss = loss_a + loss_b

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if i % 5 == 4:
            loss = loss.item()
            print(f"loss: {loss:>7f}  [{i+1:>5d}/{num_batches:>5d}]")

def test(dataloader_a, dataloader_b, model, loss_fn):
    num_batches = len(dataloader_a)
    model.eval()
    test_loss = 0
    
    accuracy = MulticlassAccuracy()
    f1_a = MulticlassF1Score(num_classes=model.num_classes_a, average=None)
    precision_a = MulticlassPrecision(num_classes=model.num_classes_a, average=None)
    recall_a = MulticlassRecall(num_classes=model.num_classes_a, average=None)

    f1_b = MulticlassF1Score(num_classes=model.num_classes_b, average=None)
    precision_b = MulticlassPrecision(num_classes=model.num_classes_b, average=None)
    recall_b = MulticlassRecall(num_classes=model.num_classes_b, average=None)
    
    with torch.no_grad():
        for batch_a, batch_b in zip(dataloader_a, dataloader_b):
            X_a = batch_a['text']
            y_a = batch_a['label'].to(device)
            pred_a = model(X_a, 'A')

            test_loss += loss_fn(pred_a, y_a).item()
            accuracy.update(pred_a, y_a)
            f1_a.update(pred_a, y_a)
            precision_a.update(pred_a, y_a)
            recall_a.update(pred_a, y_a)

            X_b = batch_b['text']
            y_b = batch_b['label'].to(device)
            pred_b = model(X_b, 'B')

            test_loss += loss_fn(pred_b, y_b).item()
            accuracy.update(pred_b, y_b)
            f1_b.update(pred_b, y_b)
            precision_b.update(pred_b, y_b)
            recall_b.update(pred_b, y_b)
            
    test_loss /= num_batches
    loss_str = f"Avg loss: {test_loss:>8f}"
    accuracy_str = f"Accuracy: {(100*accuracy.compute()):>0.1f}%"

    precision_str_a = f"Precision (A): {precision_a.compute()}"
    recall_str_a = f"Recall (A): {recall_a.compute()}"
    f1_str_a = f"F1-Score (A): {f1_a.compute()}"

    precision_str_b = f"Precision (B): {precision_b.compute()}"
    recall_str_b = f"Recall (B): {recall_b.compute()}"
    f1_str_b = f"F1-Score (B): {f1_b.compute()}"

    print(f"""Test Error:
    {loss_str}
    {accuracy_str}
    
    {precision_str_a}
    {recall_str_a}
    {f1_str_a}
    
    {precision_str_b}
    {recall_str_b}
    {f1_str_b}
    """)

In [241]:
model = SentenceTransformerWithHeads(num_classes_a=4, num_classes_b=3).to(device)
model.train()

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

epochs = 30
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader_a, train_dataloader_b, model, loss_fn, optimizer)
    test(test_dataloader_a, test_dataloader_b, model, loss_fn)

Epoch 1
-------------------------------
loss: 2.488513  [    5/   60]
loss: 2.463017  [   10/   60]
loss: 2.446870  [   15/   60]
loss: 2.419731  [   20/   60]
loss: 2.386283  [   25/   60]
loss: 2.335083  [   30/   60]
loss: 2.325567  [   35/   60]
loss: 2.417871  [   40/   60]
loss: 2.217272  [   45/   60]
loss: 2.449257  [   50/   60]
loss: 2.140202  [   55/   60]
loss: 2.204949  [   60/   60]


        [1]]) classes have zero instances in both the predictions and the ground truth labels. Precision is still logged as zero.


Test Error:
    Avg loss: 2.179438
    Accuracy: 82.6%

    Precision (A): tensor([0.8182, 0.9154, 0.7548, 0.8764])
    Recall (A): tensor([0.8621, 0.9400, 0.8637, 0.6795])
    F1-Score (A): tensor([0.8396, 0.9276, 0.8056, 0.7655])

    Precision (B): tensor([0.0000, 0.0000, 0.7063])
    Recall (B): tensor([0., 0., 1.])
    F1-Score (B): tensor([0.0000, 0.0000, 0.8278])
    
Epoch 2
-------------------------------
loss: 2.165391  [    5/   60]
loss: 2.073323  [   10/   60]
loss: 2.057193  [   15/   60]
loss: 2.170320  [   20/   60]
loss: 2.102693  [   25/   60]
loss: 2.010205  [   30/   60]
loss: 2.080139  [   35/   60]
loss: 2.327700  [   40/   60]
loss: 1.900751  [   45/   60]
loss: 2.367872  [   50/   60]
loss: 1.828688  [   55/   60]
loss: 1.963374  [   60/   60]


        [1]]) classes have zero instances in both the predictions and the ground truth labels. Precision is still logged as zero.


Test Error:
    Avg loss: 1.942279
    Accuracy: 84.2%

    Precision (A): tensor([0.8604, 0.9137, 0.7802, 0.8650])
    Recall (A): tensor([0.8532, 0.9584, 0.8595, 0.7421])
    F1-Score (A): tensor([0.8568, 0.9355, 0.8179, 0.7989])

    Precision (B): tensor([0.0000, 0.0000, 0.7063])
    Recall (B): tensor([0., 0., 1.])
    F1-Score (B): tensor([0.0000, 0.0000, 0.8278])
    
Epoch 3
-------------------------------
loss: 1.941443  [    5/   60]
loss: 1.790649  [   10/   60]
loss: 1.777463  [   15/   60]
loss: 1.987257  [   20/   60]
loss: 1.893212  [   25/   60]
loss: 1.769475  [   30/   60]
loss: 1.895535  [   35/   60]
loss: 2.253718  [   40/   60]
loss: 1.669129  [   45/   60]
loss: 2.305464  [   50/   60]
loss: 1.606261  [   55/   60]
loss: 1.785047  [   60/   60]


        [1]]) classes have zero instances in both the predictions and the ground truth labels. Precision is still logged as zero.


Test Error:
    Avg loss: 1.767275
    Accuracy: 84.8%

    Precision (A): tensor([0.8728, 0.9150, 0.7964, 0.8567])
    Recall (A): tensor([0.8526, 0.9626, 0.8500, 0.7742])
    F1-Score (A): tensor([0.8626, 0.9382, 0.8223, 0.8134])

    Precision (B): tensor([0.0000, 0.0000, 0.7063])
    Recall (B): tensor([0., 0., 1.])
    F1-Score (B): tensor([0.0000, 0.0000, 0.8278])
    
Epoch 4
-------------------------------
loss: 1.778687  [    5/   60]
loss: 1.582820  [   10/   60]
loss: 1.573105  [   15/   60]
loss: 1.850649  [   20/   60]
loss: 1.735827  [   25/   60]
loss: 1.589973  [   30/   60]
loss: 1.754540  [   35/   60]
loss: 2.188859  [   40/   60]
loss: 1.498557  [   45/   60]
loss: 2.251021  [   50/   60]
loss: 1.445619  [   55/   60]
loss: 1.650118  [   60/   60]


        [1]]) classes have zero instances in both the predictions and the ground truth labels. Precision is still logged as zero.


Test Error:
    Avg loss: 1.635560
    Accuracy: 85.1%

    Precision (A): tensor([0.8795, 0.9173, 0.8022, 0.8522])
    Recall (A): tensor([0.8532, 0.9637, 0.8453, 0.7889])
    F1-Score (A): tensor([0.8662, 0.9399, 0.8232, 0.8193])

    Precision (B): tensor([0.0000, 0.0000, 0.7063])
    Recall (B): tensor([0., 0., 1.])
    F1-Score (B): tensor([0.0000, 0.0000, 0.8278])
    
Epoch 5
-------------------------------
loss: 1.657891  [    5/   60]
loss: 1.428973  [   10/   60]
loss: 1.421852  [   15/   60]
loss: 1.745885  [   20/   60]
loss: 1.615106  [   25/   60]
loss: 1.454816  [   30/   60]
loss: 1.644817  [   35/   60]
loss: 2.130281  [   40/   60]
loss: 1.371516  [   45/   60]
loss: 2.200944  [   50/   60]
loss: 1.327703  [   55/   60]
loss: 1.545658  [   60/   60]




Test Error:
    Avg loss: 1.534158
    Accuracy: 85.2%

    Precision (A): tensor([0.8809, 0.9202, 0.8058, 0.8490])
    Recall (A): tensor([0.8526, 0.9647, 0.8432, 0.7958])
    F1-Score (A): tensor([0.8665, 0.9419, 0.8241, 0.8215])

    Precision (B): tensor([0.0000, 1.0000, 0.7074])
    Recall (B): tensor([0.0000, 0.0090, 1.0000])
    F1-Score (B): tensor([0.0000, 0.0179, 0.8286])
    
Epoch 6
-------------------------------
loss: 1.566101  [    5/   60]
loss: 1.313316  [   10/   60]
loss: 1.307840  [   15/   60]
loss: 1.663432  [   20/   60]
loss: 1.520502  [   25/   60]
loss: 1.351629  [   30/   60]
loss: 1.557693  [   35/   60]
loss: 2.076861  [   40/   60]
loss: 1.275581  [   45/   60]
loss: 2.154396  [   50/   60]
loss: 1.239580  [   55/   60]
loss: 1.463140  [   60/   60]




Test Error:
    Avg loss: 1.454336
    Accuracy: 85.5%

    Precision (A): tensor([0.8857, 0.9240, 0.8108, 0.8477])
    Recall (A): tensor([0.8563, 0.9663, 0.8437, 0.8026])
    F1-Score (A): tensor([0.8708, 0.9447, 0.8269, 0.8245])

    Precision (B): tensor([0.0000, 1.0000, 0.7096])
    Recall (B): tensor([0.0000, 0.0270, 1.0000])
    F1-Score (B): tensor([0.0000, 0.0526, 0.8301])
    
Epoch 7
-------------------------------
loss: 1.494734  [    5/   60]
loss: 1.224703  [   10/   60]
loss: 1.220127  [   15/   60]
loss: 1.597078  [   20/   60]
loss: 1.444847  [   25/   60]
loss: 1.271567  [   30/   60]
loss: 1.487141  [   35/   60]
loss: 2.028066  [   40/   60]
loss: 1.202067  [   45/   60]
loss: 2.111342  [   50/   60]
loss: 1.172486  [   55/   60]
loss: 1.396800  [   60/   60]
Test Error:
    Avg loss: 1.390181
    Accuracy: 85.7%

    Precision (A): tensor([0.8873, 0.9258, 0.8122, 0.8476])
    Recall (A): tensor([0.8579, 0.9658, 0.8421, 0.8079])
    F1-Score (A): tensor([0.8724, 0.9