In [1]:
import transformers as T
from datasets import load_dataset
import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from tqdm import tqdm
from torchmetrics import SpearmanCorrCoef, Accuracy, F1Score
device = "cuda:0" if torch.cuda.is_available() else "cpu"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# 有些中文的標點符號在tokenizer編碼以後會變成[UNK]，所以將其換成英文標點
token_replacement = [
    ["：" , ":"],
    ["，" , ","],
    ["“" , "\""],
    ["”" , "\""],
    ["？" , "?"],
    ["……" , "..."],
    ["！" , "!"]
]

In [3]:
#model = MultiLabelModel().to(device)
tokenizer = T.BertTokenizer.from_pretrained("google-bert/bert-base-uncased", cache_dir="./cache/")

In [11]:
model = MultiLabelModel().to(device)

2024-11-20 20:08:20.001043: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-11-20 20:08:20.001177: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-11-20 20:08:20.231453: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-11-20 20:08:20.713873: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [4]:
class SemevalDataset(Dataset):
    def __init__(self, split="train") -> None:
        super().__init__()
        assert split in ["train", "validation", "test"]
        self.data = load_dataset(
            "sem_eval_2014_task_1", split=split, cache_dir="./cache/"
        ).to_list()

    def __getitem__(self, index):
        d = self.data[index]
        # 把中文標點替換掉
        for k in ["premise", "hypothesis"]:
            for tok in token_replacement:
                d[k] = d[k].replace(tok[0], tok[1])
        return d

    def __len__(self):
        return len(self.data)

data_sample = SemevalDataset(split="train").data[:3]
print(f"Dataset example: \n{data_sample[0]} \n{data_sample[1]} \n{data_sample[2]}")

Dataset example: 
{'sentence_pair_id': 1, 'premise': 'A group of kids is playing in a yard and an old man is standing in the background', 'hypothesis': 'A group of boys in a yard is playing and a man is standing in the background', 'relatedness_score': 4.5, 'entailment_judgment': 0} 
{'sentence_pair_id': 2, 'premise': 'A group of children is playing in the house and there is no man standing in the background', 'hypothesis': 'A group of kids is playing in a yard and an old man is standing in the background', 'relatedness_score': 3.200000047683716, 'entailment_judgment': 0} 
{'sentence_pair_id': 3, 'premise': 'The young boys are playing outdoors and the man is smiling nearby', 'hypothesis': 'The kids are playing outdoors near a man with a smile', 'relatedness_score': 4.699999809265137, 'entailment_judgment': 1}


In [5]:
# Define the hyperparameters
lr = 3e-5
epochs = 6
train_batch_size = 8
validation_batch_size = 8

In [6]:
# TODO1: Create batched data for DataLoader
# `collate_fn` is a function that defines how the data batch should be packed.
# This function will be called in the DataLoader to pack the data batch.

def collate_fn(batch):
    # TODO1-1: Implement the collate_fn function
    # Write your code here
    # The input parameter is a data batch (tuple), and this function packs it into tensors.
    # Use tokenizer to pack tokenize and pack the data and its corresponding labels.
    # Return the data batch and labels for each sub-task.
    premises = [item['premise'] for item in batch]
    hypotheses = [item['hypothesis'] for item in batch]
    labels_relatedness = torch.tensor([item['relatedness_score'] for item in batch], dtype=torch.float32)
    labels_entailment = torch.tensor([item['entailment_judgment'] for item in batch], dtype=torch.long)
    
    encoding = tokenizer(premises, hypotheses, return_tensors='pt', padding=True, truncation=True, max_length=128)
    input_ids = encoding['input_ids']
    attention_mask = encoding['attention_mask']
    
    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels_relatedness': labels_relatedness,
        'labels_entailment': labels_entailment
    }

# TODO1-2: Define your DataLoader
train_dataset = SemevalDataset(split="train")
validation_dataset = SemevalDataset(split="validation")

dl_train = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, collate_fn=collate_fn)# Write your code here
dl_validation = DataLoader(validation_dataset, batch_size=validation_batch_size, shuffle=False, collate_fn=collate_fn)# Write your code here

In [7]:
# check the first batch:
for batch in dl_train:
    print("Input IDs:", batch['input_ids'])
    print("Attention Mask:", batch['attention_mask'])
    print("Relatedness Labels:", batch['labels_relatedness'])
    print("Entailment Labels:", batch['labels_entailment'])
    break

Input IDs: tensor([[  101,  1996,  2158,  2003,  8783,  3347, 17327,  2015,   102,  1996,
          2158,  2003, 13845,  3347, 17327,  2015,   102,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [  101,  1037,  2502,  3869,  2003,  9105,  1037,  2158,   102,  1037,
          2158,  2003,  9105,  1037,  3869,   102,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [  101,  1037,  2450,  2003,  2559, 16484,  2012,  1037,  2158,   102,
          1037,  2158,  2003,  2108, 16484,  2246,  2012,  2011,  1037,  2450,
           102,     0,     0,     0,     0,     0,     0,     0],
        [  101,  1037,  2711,  2003,  7367,  6299,  2075,  1996, 11756,  1997,
          2019,  4064,  5898,  9573,  2007,  1037,  4690,   102,  1037,  2158,
          2003,  6276,  1037,  9573,  2007,  1037,  4690,   102],
        [  101,  1996,  2158,  2003,  2652,  1996,  2858,   102,  1996,  2711,
          2003

In [None]:
# TODO2: Construct your model
class MultiLabelModel(torch.nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Write your code here
        # Define what modules you will use in the model
        super(MultiLabelModel, self).__init__()
        # bert-base-uncased
        self.bert = T.BertModel.from_pretrained("google-bert/bert-base-uncased", cache_dir="./cache/")
        # two layers
        self.regressor = torch.nn.Linear(self.bert.config.hidden_size, 1)  # reg.
        self.classifier = torch.nn.Linear(self.bert.config.hidden_size, 3)  # cls 3.
    def forward(self, **kwargs):
        # Write your code here
        # Forward pass
        # Use BERT
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output  # BERT pooling
        
        # output
        relatedness_score = self.regressor(pooled_output).squeeze(-1)  # minus dimension
        entailment_logits = self.classifier(pooled_output)
        
        return relatedness_score, entailment_logits

In [12]:
# TODO3: Define your optimizer and loss function

# TODO3-1: Define your Optimizer
optimizer = AdamW(model.parameters(), lr=lr)# Write your code here

# TODO3-2: Define your loss functions (you should have two)
# Write your code here
loss_fn_relatedness = torch.nn.MSELoss()  # reg. loss
loss_fn_entailment = torch.nn.CrossEntropyLoss()  # cls. loss

# scoring functions
spc = SpearmanCorrCoef().to(device)
acc = Accuracy(task="multiclass", num_classes=3).to(device)
f1 = F1Score(task="multiclass", num_classes=3, average='macro').to(device)



In [None]:
for ep in range(epochs):
    pbar = tqdm(dl_train)
    pbar.set_description(f"Training epoch [{ep+1}/{epochs}]")
    model.train()
    # TODO4: Write the training loop
    # Write your code here
    # train your model
    # clear gradient
    # forward pass
    # compute loss
    # back-propagation
    # model optimization
    for batch in pbar:
        optimizer.zero_grad()  # remove grad.
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels_relatedness = batch['labels_relatedness'].to(device)
        labels_entailment = batch['labels_entailment'].to(device)

        # Forward pass
        relatedness_score, entailment_logits = model(input_ids=input_ids, attention_mask=attention_mask)

        # Compute losses
        loss_relatedness = loss_fn_relatedness(relatedness_score, labels_relatedness)
        loss_entailment = loss_fn_entailment(entailment_logits, labels_entailment)
        loss = loss_relatedness + loss_entailment

        # Backward pass and optimization
        loss.backward()
        optimizer.step()


    pbar = tqdm(dl_validation)
    pbar.set_description(f"Validation epoch [{ep+1}/{epochs}]")
    model.eval()
    # TODO5: Write the evaluation loop
    # Write your code here
    total_loss_relatedness = 0
    total_loss_entailment = 0
    total_spc = 0
    total_acc = 0
    total_f1 = 0
    num_batches = 0
    with torch.no_grad():
        for batch in pbar:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels_relatedness = batch['labels_relatedness'].to(device)
            labels_entailment = batch['labels_entailment'].to(device)

            # Forward pass
            relatedness_score, entailment_logits = model(input_ids=input_ids, attention_mask=attention_mask)

            # Compute losses
            loss_relatedness = loss_fn_relatedness(relatedness_score, labels_relatedness)
            loss_entailment = loss_fn_entailment(entailment_logits, labels_entailment)
            total_loss_relatedness += loss_relatedness.item()
            total_loss_entailment += loss_entailment.item()

            # Compute metrics
            total_spc += spc(relatedness_score, labels_relatedness).item()
            total_acc += acc(entailment_logits, labels_entailment).item()
            total_f1 += f1(entailment_logits, labels_entailment).item()
            num_batches += 1
    # Evaluate your model
    avg_loss_relatedness = total_loss_relatedness / num_batches
    avg_loss_entailment = total_loss_entailment / num_batches
    avg_spc = total_spc / num_batches
    avg_acc = total_acc / num_batches
    avg_f1 = total_f1 / num_batches

    print(f"Validation Results - Epoch [{ep+1}/{epochs}]:")
    print(f"  Loss (Relatedness): {avg_loss_relatedness:.4f}")
    print(f"  Loss (Entailment): {avg_loss_entailment:.4f}")
    print(f"  Spearman Correlation: {avg_spc:.4f}")
    print(f"  Accuracy: {avg_acc:.4f}")
    print(f"  F1 Score: {avg_f1:.4f}")
    # Output all the evaluation scores (SpearmanCorrCoef, Accuracy, F1Score)
    torch.save(model, f'./saved_models/ep{ep}.ckpt')

Training epoch [1/6]: 100%|██████████| 563/563 [00:14<00:00, 38.46it/s]
Validation epoch [1/6]: 100%|██████████| 63/63 [00:01<00:00, 56.76it/s]


Validation Results - Epoch [1/6]:
  Loss (Relatedness): 0.2830
  Loss (Entailment): 0.5626
  Spearman Correlation: 0.7630
  Accuracy: 0.8611
  F1 Score: 0.8200


Training epoch [2/6]: 100%|██████████| 563/563 [00:14<00:00, 38.35it/s]
Validation epoch [2/6]: 100%|██████████| 63/63 [00:01<00:00, 54.24it/s]


Validation Results - Epoch [2/6]:
  Loss (Relatedness): 0.2839
  Loss (Entailment): 0.5263
  Spearman Correlation: 0.7526
  Accuracy: 0.8631
  F1 Score: 0.8362


Training epoch [3/6]: 100%|██████████| 563/563 [00:14<00:00, 38.50it/s]
Validation epoch [3/6]: 100%|██████████| 63/63 [00:01<00:00, 44.83it/s]


Validation Results - Epoch [3/6]:
  Loss (Relatedness): 0.2989
  Loss (Entailment): 0.6112
  Spearman Correlation: 0.7591
  Accuracy: 0.8611
  F1 Score: 0.8199


Training epoch [4/6]: 100%|██████████| 563/563 [00:14<00:00, 38.43it/s]
Validation epoch [4/6]: 100%|██████████| 63/63 [00:01<00:00, 49.75it/s]


Validation Results - Epoch [4/6]:
  Loss (Relatedness): 0.2756
  Loss (Entailment): 0.5931
  Spearman Correlation: 0.7497
  Accuracy: 0.8571
  F1 Score: 0.8011


Training epoch [5/6]: 100%|██████████| 563/563 [00:14<00:00, 38.47it/s]
Validation epoch [5/6]: 100%|██████████| 63/63 [00:01<00:00, 47.77it/s]


Validation Results - Epoch [5/6]:
  Loss (Relatedness): 0.2850
  Loss (Entailment): 0.5570
  Spearman Correlation: 0.7305
  Accuracy: 0.8611
  F1 Score: 0.8065


Training epoch [6/6]: 100%|██████████| 563/563 [00:14<00:00, 38.41it/s]
Validation epoch [6/6]: 100%|██████████| 63/63 [00:01<00:00, 46.00it/s]


Validation Results - Epoch [6/6]:
  Loss (Relatedness): 0.2998
  Loss (Entailment): 0.5605
  Spearman Correlation: 0.7314
  Accuracy: 0.8671
  F1 Score: 0.8183


For test set predictions, you can write perform evaluation simlar to #TODO5.

In [13]:
test_dataset = SemevalDataset(split="test")
dl_test = DataLoader(test_dataset, batch_size=validation_batch_size, shuffle=False, collate_fn=collate_fn)

In [31]:
pbar = tqdm(dl_test)
pbar.set_description(f"Test set predictions")
model.eval()
total_loss_relatedness = 0
total_loss_entailment = 0
total_spc = 0
total_acc = 0
total_f1 = 0
num_batches = 0
with torch.no_grad():
    for batch in pbar:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels_relatedness = batch['labels_relatedness'].to(device)
        labels_entailment = batch['labels_entailment'].to(device)

        # Forward pass
        relatedness_score, entailment_logits = model(input_ids=input_ids, attention_mask=attention_mask)

        # Compute losses
        loss_relatedness = loss_fn_relatedness(relatedness_score, labels_relatedness)
        loss_entailment = loss_fn_entailment(entailment_logits, labels_entailment)
        total_loss_relatedness += loss_relatedness.item()
        total_loss_entailment += loss_entailment.item()

        # Compute metrics
        total_spc += spc(relatedness_score, labels_relatedness).item()
        total_acc += acc(entailment_logits, labels_entailment).item()
        total_f1 += f1(entailment_logits, labels_entailment).item()
        num_batches += 1

# Output test set evaluation scores
avg_loss_relatedness = total_loss_relatedness / num_batches
avg_loss_entailment = total_loss_entailment / num_batches
avg_spc = total_spc / num_batches
avg_acc = total_acc / num_batches
avg_f1 = total_f1 / num_batches

print(f"Test Set Results:")
print(f"  Loss (Relatedness): {avg_loss_relatedness:.4f}")
print(f"  Loss (Entailment): {avg_loss_entailment:.4f}")
print(f"  Spearman Correlation: {avg_spc:.4f}")
print(f"  Accuracy: {avg_acc:.4f}")
print(f"  F1 Score: {avg_f1:.4f}")

Test set predictions: 100%|██████████| 616/616 [00:16<00:00, 36.71it/s]

Test Set Results:
  Loss (Relatedness): 0.2934
  Loss (Entailment): 0.6168
  Spearman Correlation: 0.7144
  Accuracy: 0.8610
  F1 Score: 0.8100





Compared with models trained separately on each of the sub-task, does multi-output learning improve the performance?

In [14]:
class RelatednessModel(torch.nn.Module):
    def __init__(self):
        super(RelatednessModel, self).__init__()
        self.bert = T.BertModel.from_pretrained("google-bert/bert-base-uncased", cache_dir="./cache/")
        self.regressor = torch.nn.Linear(self.bert.config.hidden_size, 1)  # Regression output

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        relatedness_score = self.regressor(pooled_output).squeeze(-1)
        return relatedness_score

In [15]:
relatedness_model = RelatednessModel().to(device)
optimizer_relatedness = AdamW(relatedness_model.parameters(), lr=lr)
loss_fn_relatedness = torch.nn.MSELoss()
spc = SpearmanCorrCoef().to(device)

In [16]:
for ep in range(epochs):
    pbar = tqdm(dl_train)
    pbar.set_description(f"Training Relatedness Model epoch [{ep+1}/{epochs}]")
    relatedness_model.train()
    for batch in pbar:
        optimizer_relatedness.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels_relatedness = batch['labels_relatedness'].to(device)

        # Forward pass
        relatedness_score = relatedness_model(input_ids=input_ids, attention_mask=attention_mask)

        # Compute loss
        loss_relatedness = loss_fn_relatedness(relatedness_score, labels_relatedness)
        loss_relatedness.backward()
        optimizer_relatedness.step()

    # Validation loop
    pbar = tqdm(dl_validation)
    pbar.set_description(f"Validation Relatedness Model epoch [{ep+1}/{epochs}]")
    relatedness_model.eval()
    total_loss_relatedness = 0
    total_spc = 0
    num_batches = 0
    with torch.no_grad():
        for batch in pbar:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels_relatedness = batch['labels_relatedness'].to(device)

            # Forward pass
            relatedness_score = relatedness_model(input_ids=input_ids, attention_mask=attention_mask)

            # Compute loss
            loss_relatedness = loss_fn_relatedness(relatedness_score, labels_relatedness)
            total_loss_relatedness += loss_relatedness.item()

            # Compute metric
            total_spc += spc(relatedness_score, labels_relatedness).item()
            num_batches += 1

    avg_loss_relatedness = total_loss_relatedness / num_batches
    avg_spc = total_spc / num_batches
    print(f"Validation Relatedness Model - Epoch [{ep+1}/{epochs}]:")
    print(f"  Loss (Relatedness): {avg_loss_relatedness:.4f}")
    print(f"  Spearman Correlation: {avg_spc:.4f}")

Training Relatedness Model epoch [1/6]: 100%|██████████| 563/563 [00:18<00:00, 29.75it/s]
Validation Relatedness Model epoch [1/6]: 100%|██████████| 63/63 [00:01<00:00, 45.00it/s]


Validation Relatedness Model - Epoch [1/6]:
  Loss (Relatedness): 0.3727
  Spearman Correlation: 0.7411


Training Relatedness Model epoch [2/6]: 100%|██████████| 563/563 [00:15<00:00, 35.88it/s]
Validation Relatedness Model epoch [2/6]: 100%|██████████| 63/63 [00:00<00:00, 102.48it/s]


Validation Relatedness Model - Epoch [2/6]:
  Loss (Relatedness): 0.4256
  Spearman Correlation: 0.7421


Training Relatedness Model epoch [3/6]: 100%|██████████| 563/563 [00:15<00:00, 35.85it/s]
Validation Relatedness Model epoch [3/6]: 100%|██████████| 63/63 [00:00<00:00, 93.61it/s]


Validation Relatedness Model - Epoch [3/6]:
  Loss (Relatedness): 0.2848
  Spearman Correlation: 0.7463


Training Relatedness Model epoch [4/6]: 100%|██████████| 563/563 [00:15<00:00, 36.09it/s]
Validation Relatedness Model epoch [4/6]: 100%|██████████| 63/63 [00:00<00:00, 86.69it/s]


Validation Relatedness Model - Epoch [4/6]:
  Loss (Relatedness): 0.3410
  Spearman Correlation: 0.7098


Training Relatedness Model epoch [5/6]: 100%|██████████| 563/563 [00:15<00:00, 35.89it/s]
Validation Relatedness Model epoch [5/6]: 100%|██████████| 63/63 [00:00<00:00, 80.75it/s]


Validation Relatedness Model - Epoch [5/6]:
  Loss (Relatedness): 0.2582
  Spearman Correlation: 0.7678


Training Relatedness Model epoch [6/6]: 100%|██████████| 563/563 [00:15<00:00, 36.05it/s]
Validation Relatedness Model epoch [6/6]: 100%|██████████| 63/63 [00:00<00:00, 75.76it/s]

Validation Relatedness Model - Epoch [6/6]:
  Loss (Relatedness): 0.2934
  Spearman Correlation: 0.7469





In [17]:
class EntailmentModel(torch.nn.Module):
    def __init__(self):
        super(EntailmentModel, self).__init__()
        self.bert = T.BertModel.from_pretrained("google-bert/bert-base-uncased", cache_dir="./cache/")
        self.classifier = torch.nn.Linear(self.bert.config.hidden_size, 3)  # Classification output

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        entailment_logits = self.classifier(pooled_output)
        return entailment_logits

In [18]:
entailment_model = EntailmentModel().to(device)
optimizer_entailment = AdamW(entailment_model.parameters(), lr=lr)
loss_fn_entailment = torch.nn.CrossEntropyLoss()
acc = Accuracy(task="multiclass", num_classes=3).to(device)

In [19]:
for ep in range(epochs):
    pbar = tqdm(dl_train)
    pbar.set_description(f"Training Entailment Model epoch [{ep+1}/{epochs}]")
    entailment_model.train()
    for batch in pbar:
        optimizer_entailment.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels_entailment = batch['labels_entailment'].to(device)

        # Forward pass
        entailment_logits = entailment_model(input_ids=input_ids, attention_mask=attention_mask)

        # Compute loss
        loss_entailment = loss_fn_entailment(entailment_logits, labels_entailment)
        loss_entailment.backward()
        optimizer_entailment.step()

    # Validation loop
    pbar = tqdm(dl_validation)
    pbar.set_description(f"Validation Entailment Model epoch [{ep+1}/{epochs}]")
    entailment_model.eval()
    total_loss_entailment = 0
    total_acc = 0
    num_batches = 0
    with torch.no_grad():
        for batch in pbar:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels_entailment = batch['labels_entailment'].to(device)

            # Forward pass
            entailment_logits = entailment_model(input_ids=input_ids, attention_mask=attention_mask)

            # Compute loss
            loss_entailment = loss_fn_entailment(entailment_logits, labels_entailment)
            total_loss_entailment += loss_entailment.item()

            # Compute metric
            total_acc += acc(entailment_logits, labels_entailment).item()
            num_batches += 1

    avg_loss_entailment = total_loss_entailment / num_batches
    avg_acc = total_acc / num_batches
    print(f"Validation Entailment Model - Epoch [{ep+1}/{epochs}]:")
    print(f"  Loss (Entailment): {avg_loss_entailment:.4f}")
    print(f"  Accuracy: {avg_acc:.4f}")

Training Entailment Model epoch [1/6]: 100%|██████████| 563/563 [00:14<00:00, 38.39it/s]
Validation Entailment Model epoch [1/6]: 100%|██████████| 63/63 [00:00<00:00, 102.05it/s]


Validation Entailment Model - Epoch [1/6]:
  Loss (Entailment): 0.4016
  Accuracy: 0.8333


Training Entailment Model epoch [2/6]: 100%|██████████| 563/563 [00:14<00:00, 38.83it/s]
Validation Entailment Model epoch [2/6]: 100%|██████████| 63/63 [00:00<00:00, 126.97it/s]


Validation Entailment Model - Epoch [2/6]:
  Loss (Entailment): 0.3718
  Accuracy: 0.8452


Training Entailment Model epoch [3/6]: 100%|██████████| 563/563 [00:14<00:00, 38.00it/s]
Validation Entailment Model epoch [3/6]: 100%|██████████| 63/63 [00:00<00:00, 126.01it/s]


Validation Entailment Model - Epoch [3/6]:
  Loss (Entailment): 0.4364
  Accuracy: 0.8571


Training Entailment Model epoch [4/6]: 100%|██████████| 563/563 [00:15<00:00, 35.79it/s]
Validation Entailment Model epoch [4/6]: 100%|██████████| 63/63 [00:00<00:00, 126.65it/s]


Validation Entailment Model - Epoch [4/6]:
  Loss (Entailment): 0.4823
  Accuracy: 0.8492


Training Entailment Model epoch [5/6]: 100%|██████████| 563/563 [00:15<00:00, 36.02it/s]
Validation Entailment Model epoch [5/6]: 100%|██████████| 63/63 [00:00<00:00, 125.82it/s]


Validation Entailment Model - Epoch [5/6]:
  Loss (Entailment): 0.4726
  Accuracy: 0.8512


Training Entailment Model epoch [6/6]: 100%|██████████| 563/563 [00:15<00:00, 35.81it/s]
Validation Entailment Model epoch [6/6]: 100%|██████████| 63/63 [00:00<00:00, 126.11it/s]

Validation Entailment Model - Epoch [6/6]:
  Loss (Entailment): 0.5327
  Accuracy: 0.8591





Why does your model fail to correctly predict some data points? Please provide an error analysis.

In [30]:
import pandas as pd

In [31]:
for ep in range(epochs):
    pbar = tqdm(dl_train)
    pbar.set_description(f"Training epoch [{ep+1}/{epochs}]")
    model.train()
    # TODO4: Write the training loop
    # Write your code here
    # train your model
    # clear gradient
    # forward pass
    # compute loss
    # back-propagation
    # model optimization
    for batch in pbar:
        optimizer.zero_grad()  # remove grad.
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels_relatedness = batch['labels_relatedness'].to(device)
        labels_entailment = batch['labels_entailment'].to(device)

        # Forward pass
        relatedness_score, entailment_logits = model(input_ids=input_ids, attention_mask=attention_mask)

        # Compute losses
        loss_relatedness = loss_fn_relatedness(relatedness_score, labels_relatedness)
        loss_entailment = loss_fn_entailment(entailment_logits, labels_entailment)
        loss = loss_relatedness + loss_entailment

        # Backward pass and optimization
        loss.backward()
        optimizer.step()


    pbar = tqdm(dl_validation)
    pbar.set_description(f"Validation epoch [{ep+1}/{epochs}]")
    model.eval()
    # TODO5: Write the evaluation loop
    # Write your code here
    total_loss_relatedness = 0
    total_loss_entailment = 0
    total_spc = 0
    total_acc = 0
    total_f1 = 0
    num_batches = 0
    error_analysis = []
    with torch.no_grad():
        for batch in pbar:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels_relatedness = batch['labels_relatedness'].to(device)
            labels_entailment = batch['labels_entailment'].to(device)

            # Forward pass
            relatedness_score, entailment_logits = model(input_ids=input_ids, attention_mask=attention_mask)

            # Compute losses
            loss_relatedness = loss_fn_relatedness(relatedness_score, labels_relatedness)
            loss_entailment = loss_fn_entailment(entailment_logits, labels_entailment)
            total_loss_relatedness += loss_relatedness.item()
            total_loss_entailment += loss_entailment.item()

            # Compute metrics
            total_spc += spc(relatedness_score, labels_relatedness).item()
            total_acc += acc(entailment_logits, labels_entailment).item()
            total_f1 += f1(entailment_logits, labels_entailment).item()
            num_batches += 1

            # Error analysis
            predicted_labels = torch.argmax(entailment_logits, dim=1)
            for i in range(len(labels_entailment)):
                if predicted_labels[i] != labels_entailment[i]:
                    error_analysis.append({
                        'premise': batch['input_ids'][i].cpu().numpy().tolist(),
                        'true_label': labels_entailment[i].item(),
                        'predicted_label': predicted_labels[i].item(),
                        'relatedness_true': labels_relatedness[i].item(),
                        'relatedness_predicted': relatedness_score[i].item()
                    })
    # Evaluate your model
    avg_loss_relatedness = total_loss_relatedness / num_batches
    avg_loss_entailment = total_loss_entailment / num_batches
    avg_spc = total_spc / num_batches
    avg_acc = total_acc / num_batches
    avg_f1 = total_f1 / num_batches

    print(f"Validation Results - Epoch [{ep+1}/{epochs}]:")
    print(f"  Loss (Relatedness): {avg_loss_relatedness:.4f}")
    print(f"  Loss (Entailment): {avg_loss_entailment:.4f}")
    print(f"  Spearman Correlation: {avg_spc:.4f}")
    print(f"  Accuracy: {avg_acc:.4f}")
    print(f"  F1 Score: {avg_f1:.4f}")
    # Output all the evaluation scores (SpearmanCorrCoef, Accuracy, F1Score)

    if len(error_analysis) > 0:
        df_error_analysis = pd.DataFrame(error_analysis)
        print("Top 5 Misclassified Examples:")
        print(df_error_analysis.head(5))

    torch.save(model, f'./saved_models/ep{ep}.ckpt')

Training epoch [1/6]: 100%|██████████| 563/563 [00:14<00:00, 38.89it/s]
Validation epoch [1/6]: 100%|██████████| 63/63 [00:01<00:00, 53.40it/s]


Validation Results - Epoch [1/6]:
  Loss (Relatedness): 0.2549
  Loss (Entailment): 0.6533
  Spearman Correlation: 0.7757
  Accuracy: 0.8631
  F1 Score: 0.8198
Top 5 Misclassified Examples:
                                             premise  true_label  \
0  [101, 1037, 2829, 1998, 2317, 3899, 2003, 2770...           0   
1  [101, 1037, 2177, 1997, 10158, 2024, 13039, 20...           0   
2  [101, 1996, 2450, 4147, 3165, 6471, 1010, 5061...           0   
3  [101, 1037, 2450, 2003, 2635, 2125, 1037, 1196...           2   
4  [101, 1037, 2711, 2003, 8218, 1037, 2600, 2007...           0   

   predicted_label  relatedness_true  relatedness_predicted  
0                1               4.4               4.458434  
1                1               4.1               4.612533  
2                2               3.3               3.549941  
3                0               3.5               3.627561  
4                1               4.6               4.393768  


Training epoch [2/6]: 100%|██████████| 563/563 [00:14<00:00, 38.62it/s]
Validation epoch [2/6]: 100%|██████████| 63/63 [00:01<00:00, 51.13it/s]


Validation Results - Epoch [2/6]:
  Loss (Relatedness): 0.2437
  Loss (Entailment): 0.5948
  Spearman Correlation: 0.7891
  Accuracy: 0.8690
  F1 Score: 0.8318
Top 5 Misclassified Examples:
                                             premise  true_label  \
0  [101, 1037, 2611, 1999, 2317, 2003, 5613, 102,...           1   
1  [101, 1037, 2829, 1998, 2317, 3899, 2003, 2770...           0   
2  [101, 1037, 2177, 1997, 10158, 2024, 13039, 20...           0   
3  [101, 1996, 2450, 4147, 3165, 6471, 1010, 5061...           0   
4  [101, 1037, 2450, 2003, 2635, 2125, 1037, 1196...           2   

   predicted_label  relatedness_true  relatedness_predicted  
0                0               4.9               4.188742  
1                1               4.4               4.442097  
2                1               4.1               4.769308  
3                2               3.3               3.699707  
4                0               3.5               3.553298  


Training epoch [3/6]: 100%|██████████| 563/563 [00:14<00:00, 38.65it/s]
Validation epoch [3/6]: 100%|██████████| 63/63 [00:01<00:00, 48.86it/s]


Validation Results - Epoch [3/6]:
  Loss (Relatedness): 0.2756
  Loss (Entailment): 0.6907
  Spearman Correlation: 0.7663
  Accuracy: 0.8571
  F1 Score: 0.8224
Top 5 Misclassified Examples:
                                             premise  true_label  \
0  [101, 2274, 2336, 2024, 3061, 1999, 2392, 1997...           0   
1  [101, 2619, 2003, 2006, 1037, 2304, 1998, 2317...           0   
2  [101, 1037, 2611, 1999, 2317, 2003, 5613, 102,...           1   
3  [101, 1037, 2829, 1998, 2317, 3899, 2003, 2770...           0   
4  [101, 1037, 2177, 1997, 10158, 2024, 13039, 20...           0   

   predicted_label  relatedness_true  relatedness_predicted  
0                1             4.200               4.648133  
1                1             3.165               4.323132  
2                0             4.900               4.160884  
3                1             4.400               4.435025  
4                1             4.100               4.855351  


Training epoch [4/6]: 100%|██████████| 563/563 [00:14<00:00, 38.56it/s]
Validation epoch [4/6]: 100%|██████████| 63/63 [00:01<00:00, 46.92it/s]


Validation Results - Epoch [4/6]:
  Loss (Relatedness): 0.2622
  Loss (Entailment): 0.5936
  Spearman Correlation: 0.7748
  Accuracy: 0.8690
  F1 Score: 0.8251
Top 5 Misclassified Examples:
                                             premise  true_label  \
0  [101, 2274, 2336, 2024, 3061, 1999, 2392, 1997...           0   
1  [101, 1037, 2611, 1999, 2317, 2003, 5613, 102,...           1   
2  [101, 1037, 2177, 1997, 10158, 2024, 13039, 20...           0   
3  [101, 1996, 2450, 4147, 3165, 6471, 1010, 5061...           0   
4  [101, 1037, 2450, 2003, 2635, 2125, 1037, 1196...           2   

   predicted_label  relatedness_true  relatedness_predicted  
0                1               4.2               4.829823  
1                0               4.9               4.128891  
2                1               4.1               4.762861  
3                2               3.3               3.472594  
4                0               3.5               3.205946  


Training epoch [5/6]: 100%|██████████| 563/563 [00:14<00:00, 38.53it/s]
Validation epoch [5/6]: 100%|██████████| 63/63 [00:01<00:00, 39.53it/s]


Validation Results - Epoch [5/6]:
  Loss (Relatedness): 0.2754
  Loss (Entailment): 0.8113
  Spearman Correlation: 0.7683
  Accuracy: 0.8433
  F1 Score: 0.8101
Top 5 Misclassified Examples:
                                             premise  true_label  \
0  [101, 2176, 2336, 2024, 2725, 2067, 10609, 510...           0   
1  [101, 2274, 2336, 2024, 3061, 1999, 2392, 1997...           0   
2  [101, 2619, 2003, 2006, 1037, 2304, 1998, 2317...           0   
3  [101, 2019, 2214, 1010, 2327, 3238, 2450, 2003...           0   
4  [101, 1037, 2829, 1998, 2317, 3899, 2003, 2770...           0   

   predicted_label  relatedness_true  relatedness_predicted  
0                1             3.800               4.167154  
1                1             4.200               4.488229  
2                1             3.165               4.124984  
3                1             3.800               4.538031  
4                1             4.400               4.389619  


Training epoch [6/6]: 100%|██████████| 563/563 [00:14<00:00, 38.51it/s]
Validation epoch [6/6]: 100%|██████████| 63/63 [00:01<00:00, 43.47it/s]


Validation Results - Epoch [6/6]:
  Loss (Relatedness): 0.2754
  Loss (Entailment): 0.5778
  Spearman Correlation: 0.7653
  Accuracy: 0.8790
  F1 Score: 0.8335
Top 5 Misclassified Examples:
                                             premise  true_label  \
0  [101, 2274, 2336, 2024, 3061, 1999, 2392, 1997...           0   
1  [101, 2048, 6077, 2024, 2652, 2011, 1037, 3392...           1   
2  [101, 1037, 2611, 1999, 2317, 2003, 5613, 102,...           1   
3  [101, 1037, 2177, 1997, 10158, 2024, 13039, 20...           0   
4  [101, 1996, 2450, 4147, 3165, 6471, 1010, 5061...           0   

   predicted_label  relatedness_true  relatedness_predicted  
0                1               4.2               4.497460  
1                0               4.6               4.215297  
2                0               4.9               4.093045  
3                1               4.1               4.594283  
4                2               3.3               3.182624  
