# 1. All-or-nothing Thinking

## 1. 모델 정의

### 1-1. 기본 준비

In [None]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
from tqdm import tqdm
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split, ConcatDataset

device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
data1 = pd.read_csv('d01_preprocessed_revised.csv')
data2 = pd.read_csv('d02_preprocessed.csv')
data3 = pd.read_csv('d03_preprocessed.csv')

### 1-2. 데이터 증강 및 구조화

In [None]:
for idx, i in enumerate(data2['thought']):
    if type(i) != str:
        data2['thought'][idx] = ''

In [None]:
# has_distortion == 1 필터링

data1 = data1[data1['has_distortion'] == 1].reset_index(drop=True)
data2 = data2[data2['has_distortion'] == 1].reset_index(drop=True)
data3 = data3[data3['has_distortion'] == 1].reset_index(drop=True)

In [None]:
data1_1 = data1['situation']+' '+data1['thought']
data2_1 = data2['situation']
data3_1 = data3['situation']+' '+data3['thought']

In [None]:
data1_1.drop_duplicates(inplace = True)
data2_1.drop_duplicates(inplace = True)
data3_1.drop_duplicates(inplace = True)

In [None]:
def normalize_text(s):
    # Removing articles and punctuation, and standardizing whitespace
    import string, re

    def remove_articles(text):
        regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
        return re.sub(regex, " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


In [None]:
from transformers import BertTokenizer, BertModel, BertConfig

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased')
bert_model.to(device)
bert_config = BertConfig.from_pretrained('bert-base-uncased')

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

In [None]:
# Embedding

def tokenize_and_pad(data, tokenizer, max_len=512):
    tokenized_data = []
    for text in data:
        encoded = tokenizer(normalize_text(text), return_tensors="pt", padding='max_length', truncation=True, max_length=max_len)
        tokenized_data.append(encoded)
    return tokenized_data

data1_1_encoded = tokenize_and_pad(data1_1, tokenizer)
data2_1_encoded = tokenize_and_pad(data2_1, tokenizer)
data3_1_encoded = tokenize_and_pad(data3_1, tokenizer)

In [None]:
data1.columns

Index(['situation', 'thought', 'reframe', 'has_distortion',
       'all-or-nothing thinking', 'comparing and despairing',
       'disqualifying the positive', 'emotional reasoning', 'fortune telling',
       'labeling', 'magnification', 'mind reading', 'overgeneralizing',
       'should statements', 'mental filter', 'personalization and blaming'],
      dtype='object')

In [None]:
# Add labels
data1_1_labels = list(data1['all-or-nothing thinking'][data1_1.index])
data2_1_labels = list(data2['all-or-nothing thinking'][data2_1.index])
data3_1_labels = list(data3['all-or-nothing thinking'][data3_1.index])

In [None]:
# Merging Data
data_encoded = data1_1_encoded + data2_1_encoded + data3_1_encoded
data_labels = data1_1_labels + data2_1_labels + data3_1_labels

In [None]:
class CustomDatasetWithLabels(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return {"input_ids": self.data[idx]['input_ids'].squeeze(),
                "attention_mask": self.data[idx]['attention_mask'].squeeze(),
                "y": self.labels[idx]}

In [None]:
dataset_with_labels = CustomDatasetWithLabels(data_encoded, data_labels)

# Define proportions for splitting
train_size = int(0.8 * len(dataset_with_labels))
val_size = int(0.1 * len(dataset_with_labels))
test_size = len(dataset_with_labels) - train_size - val_size

# Split the dataset
train_dataset, val_dataset, test_dataset = random_split(dataset_with_labels, [train_size, val_size, test_size])

In [None]:
from imblearn.over_sampling import SMOTE
smote = SMOTE(sampling_strategy='auto', random_state=0)

In [None]:
dataset_encoded = dataset_with_labels.data
dataset_labels = dataset_with_labels.labels

# Convert list of BatchEncoding objects to a 2D numpy array for SMOTE features
feature_vectors = []
max_len = 512 # Assuming this from tokenize_and_pad function
for encoded_dict in dataset_encoded:
    input_ids_np = encoded_dict['input_ids'].squeeze().cpu().numpy()
    attention_mask_np = encoded_dict['attention_mask'].squeeze().cpu().numpy()
    # Concatenate input_ids and attention_mask to create a single feature vector per sample
    feature_vectors.append(np.concatenate([input_ids_np, attention_mask_np]))

# 'temp1' becomes the feature matrix, 'temp2' becomes the labels vector
temp1 = np.array(feature_vectors)
temp2 = np.array(dataset_labels)

# Apply SMOTE
temp1_resampled, temp2_resampled = smote.fit_resample(temp1, temp2)

# Reconstruct data_encoded and data_labels from the resampled data
reconstructed_data_encoded = []
for resampled_features in temp1_resampled:
    # Split the combined feature vector back into input_ids and attention_mask parts
    reconstructed_input_ids = resampled_features[:max_len]
    reconstructed_attention_mask = resampled_features[max_len:]

    # Convert back to torch tensors and a dictionary format compatible with CustomDatasetWithLabels
    input_ids_tensor = torch.tensor(reconstructed_input_ids, dtype=torch.long).unsqueeze(0)
    attention_mask_tensor = torch.tensor(reconstructed_attention_mask, dtype=torch.long).unsqueeze(0)

    reconstructed_data_encoded.append({
        'input_ids': input_ids_tensor,
        'attention_mask': attention_mask_tensor
    })

# Update data_encoded and data_labels with the resampled data
dataset_encoded = reconstructed_data_encoded
dataset_labels = temp2_resampled.tolist()

train_dataset = CustomDatasetWithLabels(dataset_encoded, dataset_labels)

In [None]:
# Create DataLoaders for each set
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=64)
test_dataloader = DataLoader(test_dataset, batch_size=64)

input_dim = bert_config.hidden_size
num_labels = len(set(data_labels))
label_emb = torch.randn(num_labels, input_dim)

### 1-3. 모델 평가 함수 정의

In [None]:
from sklearn.metrics import accuracy_score, f1_score

def evaluate(model, dataloader, device="cpu"):
    model.eval()
    all_preds, all_labels = [], []

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            y = batch["y"].to(device)

            # Get embeddings from the BERT model
            outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask)
            embeddings = outputs.last_hidden_state[:, 0, :] # Get the [CLS] token embedding

            logits = model(embeddings)
            preds = torch.argmax(logits, dim=1)
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(y.cpu().tolist())

    acc = accuracy_score(all_labels, all_preds)
    f1_macro = f1_score(all_labels, all_preds, average="macro", zero_division=0)

    return {"accuracy": acc, "f1_macro": f1_macro}

In [None]:
# Classifier that uses label embeddings to make predictions
class InnerProductClassifier(nn.Module):
    def __init__(self, input_dim, label_embeddings, trainable_label_emb=True):
        super().__init__()
        # Project input features into the same dimension as label embeddings
        self.proj = nn.Linear(input_dim, label_embeddings.size(1))

        if trainable_label_emb:
            # Label embeddings are trainable parameters
            self.label_emb = nn.Parameter(label_embeddings.clone())
        else:
            # Label embeddings are fixed (not updated during training)
            self.register_buffer("label_emb", label_embeddings.clone())

    def forward(self, x):
        # Project input feature vectors
        x_proj = self.proj(x)
        # Compute logits as similarity with each label embedding
        logits = torch.matmul(x_proj, self.label_emb.T)
        return logits

### 1-4. 모델 생성

In [None]:
# Instantiate the InnerProductClassifier model
model1 = InnerProductClassifier(input_dim, label_emb).to(device)

# Define loss function
criterion = nn.CrossEntropyLoss()

# Define optimizer
optimizer = torch.optim.AdamW(model1.parameters(), lr=2e-4)

## 2. 모델 학습

In [None]:
EPOCHS = 7

# Move the BERT model to the device
bert_model.to(device)

for epoch in range(EPOCHS):
    model1.train()
    for batch in tqdm(train_dataloader, desc=f"Epoch {epoch+1}"):

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        y = batch["y"].to(device)

        # Get embeddings from the BERT model
        with torch.no_grad():
            outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask)
            embeddings = outputs.last_hidden_state[:, 0, :] # Get the [CLS] token embedding

        # Pass embeddings to the InnerProductClassifier
        logits = model1(embeddings)
        loss = criterion(logits, y)

        # Backpropagate and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Loss: {loss.item()}")

    # Evaluate on the validation set after each epoch
    val_metrics = evaluate(model1, val_dataloader, device)
    print(f"Epoch {epoch+1} Validation Accuracy: {val_metrics['accuracy']}, F1-macro: {val_metrics['f1_macro']}")

Epoch 1:   2%|▏         | 1/58 [00:02<02:17,  2.41s/it]

Loss: 2.830264091491699


Epoch 1:   3%|▎         | 2/58 [00:04<01:49,  1.96s/it]

Loss: 1.9927647113800049


Epoch 1:   5%|▌         | 3/58 [00:05<01:41,  1.84s/it]

Loss: 0.8688162565231323


Epoch 1:   7%|▋         | 4/58 [00:07<01:36,  1.78s/it]

Loss: 0.634597659111023


Epoch 1:   9%|▊         | 5/58 [00:09<01:32,  1.75s/it]

Loss: 0.16858726739883423


Epoch 1:  10%|█         | 6/58 [00:10<01:30,  1.74s/it]

Loss: 0.6435409784317017


Epoch 1:  12%|█▏        | 7/58 [00:12<01:28,  1.73s/it]

Loss: 0.9940181970596313


Epoch 1:  14%|█▍        | 8/58 [00:14<01:26,  1.73s/it]

Loss: 1.941710114479065


Epoch 1:  16%|█▌        | 9/58 [00:16<01:24,  1.72s/it]

Loss: 2.6971497535705566


Epoch 1:  17%|█▋        | 10/58 [00:17<01:22,  1.72s/it]

Loss: 1.3925105333328247


Epoch 1:  19%|█▉        | 11/58 [00:19<01:20,  1.72s/it]

Loss: 1.647416591644287


Epoch 1:  21%|██        | 12/58 [00:21<01:19,  1.73s/it]

Loss: 1.1870920658111572


Epoch 1:  22%|██▏       | 13/58 [00:22<01:17,  1.73s/it]

Loss: 2.1355559825897217


Epoch 1:  24%|██▍       | 14/58 [00:24<01:16,  1.74s/it]

Loss: 0.6520906090736389


Epoch 1:  26%|██▌       | 15/58 [00:26<01:15,  1.74s/it]

Loss: 0.7609651684761047


Epoch 1:  28%|██▊       | 16/58 [00:28<01:13,  1.76s/it]

Loss: 0.8111122846603394


Epoch 1:  29%|██▉       | 17/58 [00:29<01:12,  1.76s/it]

Loss: 0.7627554535865784


Epoch 1:  31%|███       | 18/58 [00:31<01:10,  1.77s/it]

Loss: 2.2225475311279297


Epoch 1:  33%|███▎      | 19/58 [00:33<01:09,  1.78s/it]

Loss: 0.3839457333087921


Epoch 1:  34%|███▍      | 20/58 [00:35<01:07,  1.79s/it]

Loss: 0.43300679326057434


Epoch 1:  36%|███▌      | 21/58 [00:37<01:06,  1.80s/it]

Loss: 2.063870429992676


Epoch 1:  38%|███▊      | 22/58 [00:39<01:05,  1.81s/it]

Loss: 2.3545804023742676


Epoch 1:  40%|███▉      | 23/58 [00:40<01:03,  1.82s/it]

Loss: 1.430389165878296


Epoch 1:  41%|████▏     | 24/58 [00:42<01:02,  1.83s/it]

Loss: 0.281577467918396


Epoch 1:  43%|████▎     | 25/58 [00:44<01:00,  1.83s/it]

Loss: 0.9321359395980835


Epoch 1:  45%|████▍     | 26/58 [00:46<00:58,  1.84s/it]

Loss: 1.5726203918457031


Epoch 1:  47%|████▋     | 27/58 [00:48<00:57,  1.84s/it]

Loss: 0.45346951484680176


Epoch 1:  48%|████▊     | 28/58 [00:50<00:55,  1.85s/it]

Loss: 1.7021980285644531


Epoch 1:  50%|█████     | 29/58 [00:52<00:53,  1.86s/it]

Loss: 0.9140793085098267


Epoch 1:  52%|█████▏    | 30/58 [00:53<00:52,  1.87s/it]

Loss: 1.1286548376083374


Epoch 1:  53%|█████▎    | 31/58 [00:55<00:50,  1.88s/it]

Loss: 0.6385588645935059


Epoch 1:  55%|█████▌    | 32/58 [00:57<00:49,  1.89s/it]

Loss: 0.4879584312438965


Epoch 1:  57%|█████▋    | 33/58 [00:59<00:47,  1.90s/it]

Loss: 0.6499820947647095


Epoch 1:  59%|█████▊    | 34/58 [01:01<00:45,  1.91s/it]

Loss: 0.841675341129303


Epoch 1:  60%|██████    | 35/58 [01:03<00:44,  1.92s/it]

Loss: 1.075608491897583


Epoch 1:  62%|██████▏   | 36/58 [01:05<00:42,  1.93s/it]

Loss: 0.8388368487358093


Epoch 1:  64%|██████▍   | 37/58 [01:07<00:40,  1.94s/it]

Loss: 0.8300610780715942


Epoch 1:  66%|██████▌   | 38/58 [01:09<00:39,  1.95s/it]

Loss: 1.249332070350647


Epoch 1:  67%|██████▋   | 39/58 [01:11<00:37,  1.97s/it]

Loss: 1.648390769958496


Epoch 1:  69%|██████▉   | 40/58 [01:13<00:35,  1.98s/it]

Loss: 0.7646337151527405


Epoch 1:  71%|███████   | 41/58 [01:15<00:33,  1.99s/it]

Loss: 0.26428619027137756


Epoch 1:  72%|███████▏  | 42/58 [01:17<00:31,  2.00s/it]

Loss: 0.8611202239990234


Epoch 1:  74%|███████▍  | 43/58 [01:19<00:30,  2.01s/it]

Loss: 1.3136298656463623


Epoch 1:  76%|███████▌  | 44/58 [01:21<00:28,  2.02s/it]

Loss: 1.1398117542266846


Epoch 1:  78%|███████▊  | 45/58 [01:23<00:26,  2.03s/it]

Loss: 0.6010174751281738


Epoch 1:  79%|███████▉  | 46/58 [01:25<00:24,  2.03s/it]

Loss: 1.1665419340133667


Epoch 1:  81%|████████  | 47/58 [01:27<00:22,  2.04s/it]

Loss: 0.7025517225265503


Epoch 1:  83%|████████▎ | 48/58 [01:29<00:20,  2.03s/it]

Loss: 0.7653785943984985


Epoch 1:  84%|████████▍ | 49/58 [01:31<00:18,  2.03s/it]

Loss: 0.7736606597900391


Epoch 1:  86%|████████▌ | 50/58 [01:33<00:16,  2.02s/it]

Loss: 2.1147427558898926


Epoch 1:  88%|████████▊ | 51/58 [01:35<00:14,  2.01s/it]

Loss: 0.9152438640594482


Epoch 1:  90%|████████▉ | 52/58 [01:37<00:12,  2.00s/it]

Loss: 0.8367862701416016


Epoch 1:  91%|█████████▏| 53/58 [01:39<00:09,  1.99s/it]

Loss: 0.8434372544288635


Epoch 1:  93%|█████████▎| 54/58 [01:41<00:07,  1.98s/it]

Loss: 0.5564929246902466


Epoch 1:  95%|█████████▍| 55/58 [01:43<00:05,  1.97s/it]

Loss: 0.07778539508581161


Epoch 1:  97%|█████████▋| 56/58 [01:45<00:03,  1.96s/it]

Loss: 0.3517853021621704


Epoch 1:  98%|█████████▊| 57/58 [01:47<00:01,  1.95s/it]

Loss: 0.2321624457836151


Epoch 1: 100%|██████████| 58/58 [01:47<00:00,  1.86s/it]

Loss: 0.47925734519958496





Epoch 1 Validation Accuracy: 0.5467980295566502, F1-macro: 0.44482758620689655


Epoch 2:   2%|▏         | 1/58 [00:01<01:48,  1.90s/it]

Loss: 0.6087865233421326


Epoch 2:   3%|▎         | 2/58 [00:03<01:46,  1.90s/it]

Loss: 0.940849244594574


Epoch 2:   5%|▌         | 3/58 [00:05<01:44,  1.90s/it]

Loss: 0.13239699602127075


Epoch 2:   7%|▋         | 4/58 [00:07<01:42,  1.90s/it]

Loss: 0.14226102828979492


Epoch 2:   9%|▊         | 5/58 [00:09<01:40,  1.90s/it]

Loss: 0.49826139211654663


Epoch 2:  10%|█         | 6/58 [00:11<01:38,  1.90s/it]

Loss: 0.5722050070762634


Epoch 2:  12%|█▏        | 7/58 [00:13<01:36,  1.90s/it]

Loss: 1.5629236698150635


Epoch 2:  14%|█▍        | 8/58 [00:15<01:35,  1.90s/it]

Loss: 0.9002200365066528


Epoch 2:  16%|█▌        | 9/58 [00:17<01:33,  1.90s/it]

Loss: 1.5416090488433838


Epoch 2:  17%|█▋        | 10/58 [00:19<01:31,  1.90s/it]

Loss: 0.3077189326286316


Epoch 2:  19%|█▉        | 11/58 [00:20<01:29,  1.91s/it]

Loss: 0.8542036414146423


Epoch 2:  21%|██        | 12/58 [00:22<01:27,  1.91s/it]

Loss: 0.9119237065315247


Epoch 2:  22%|██▏       | 13/58 [00:24<01:25,  1.91s/it]

Loss: 0.21220095455646515


Epoch 2:  24%|██▍       | 14/58 [00:26<01:24,  1.91s/it]

Loss: 0.9929426312446594


Epoch 2:  26%|██▌       | 15/58 [00:28<01:22,  1.92s/it]

Loss: 0.7183569669723511


Epoch 2:  28%|██▊       | 16/58 [00:30<01:20,  1.93s/it]

Loss: 0.2893369197845459


Epoch 2:  29%|██▉       | 17/58 [00:32<01:19,  1.93s/it]

Loss: 0.4519590735435486


Epoch 2:  31%|███       | 18/58 [00:34<01:17,  1.94s/it]

Loss: 0.1386086791753769


Epoch 2:  33%|███▎      | 19/58 [00:36<01:15,  1.94s/it]

Loss: 0.012537076137959957


Epoch 2:  34%|███▍      | 20/58 [00:38<01:13,  1.95s/it]

Loss: 0.1913217008113861


Epoch 2:  36%|███▌      | 21/58 [00:40<01:12,  1.95s/it]

Loss: 0.2611577808856964


Epoch 2:  38%|███▊      | 22/58 [00:42<01:10,  1.95s/it]

Loss: 2.4403278827667236


Epoch 2:  40%|███▉      | 23/58 [00:44<01:08,  1.96s/it]

Loss: 1.4317342042922974


Epoch 2:  41%|████▏     | 24/58 [00:46<01:06,  1.96s/it]

Loss: 0.5999241471290588


Epoch 2:  43%|████▎     | 25/58 [00:48<01:04,  1.96s/it]

Loss: 0.963765025138855


Epoch 2:  45%|████▍     | 26/58 [00:50<01:02,  1.96s/it]

Loss: 0.9490680694580078


Epoch 2:  47%|████▋     | 27/58 [00:52<01:00,  1.97s/it]

Loss: 0.46202462911605835


Epoch 2:  48%|████▊     | 28/58 [00:54<00:59,  1.97s/it]

Loss: 0.8565297722816467


Epoch 2:  50%|█████     | 29/58 [00:56<00:57,  1.97s/it]

Loss: 1.0142476558685303


Epoch 2:  52%|█████▏    | 30/58 [00:58<00:55,  1.97s/it]

Loss: 0.8672146797180176


Epoch 2:  53%|█████▎    | 31/58 [00:59<00:53,  1.97s/it]

Loss: 0.09095658361911774


Epoch 2:  55%|█████▌    | 32/58 [01:01<00:51,  1.97s/it]

Loss: 0.3589514493942261


Epoch 2:  57%|█████▋    | 33/58 [01:03<00:49,  1.96s/it]

Loss: 0.5607534646987915


Epoch 2:  59%|█████▊    | 34/58 [01:05<00:47,  1.96s/it]

Loss: 0.7447068691253662


Epoch 2:  60%|██████    | 35/58 [01:07<00:44,  1.96s/it]

Loss: 0.07979481667280197


Epoch 2:  62%|██████▏   | 36/58 [01:09<00:42,  1.95s/it]

Loss: 0.09914685785770416


Epoch 2:  64%|██████▍   | 37/58 [01:11<00:41,  1.95s/it]

Loss: 0.7506805062294006


Epoch 2:  66%|██████▌   | 38/58 [01:13<00:39,  1.95s/it]

Loss: 0.6127263307571411


Epoch 2:  67%|██████▋   | 39/58 [01:15<00:37,  1.95s/it]

Loss: 1.4216351509094238


Epoch 2:  69%|██████▉   | 40/58 [01:17<00:35,  1.95s/it]

Loss: 0.23745892941951752


Epoch 2:  71%|███████   | 41/58 [01:19<00:33,  1.95s/it]

Loss: 1.2039824724197388


Epoch 2:  72%|███████▏  | 42/58 [01:21<00:31,  1.95s/it]

Loss: 0.7428565621376038


Epoch 2:  74%|███████▍  | 43/58 [01:23<00:29,  1.94s/it]

Loss: 0.8127214908599854


Epoch 2:  76%|███████▌  | 44/58 [01:25<00:27,  1.94s/it]

Loss: 0.7701293230056763


Epoch 2:  78%|███████▊  | 45/58 [01:27<00:25,  1.94s/it]

Loss: 0.5390831232070923


Epoch 2:  79%|███████▉  | 46/58 [01:29<00:23,  1.94s/it]

Loss: 0.44256094098091125


Epoch 2:  81%|████████  | 47/58 [01:31<00:21,  1.94s/it]

Loss: 0.229066863656044


Epoch 2:  83%|████████▎ | 48/58 [01:33<00:19,  1.94s/it]

Loss: 0.325749009847641


Epoch 2:  84%|████████▍ | 49/58 [01:34<00:17,  1.94s/it]

Loss: 0.2168164998292923


Epoch 2:  86%|████████▌ | 50/58 [01:36<00:15,  1.94s/it]

Loss: 0.5185225605964661


Epoch 2:  88%|████████▊ | 51/58 [01:38<00:13,  1.94s/it]

Loss: 0.5993419289588928


Epoch 2:  90%|████████▉ | 52/58 [01:40<00:11,  1.94s/it]

Loss: 0.5140336751937866


Epoch 2:  91%|█████████▏| 53/58 [01:42<00:09,  1.94s/it]

Loss: 0.5308155417442322


Epoch 2:  93%|█████████▎| 54/58 [01:44<00:07,  1.94s/it]

Loss: 1.044755458831787


Epoch 2:  95%|█████████▍| 55/58 [01:46<00:05,  1.94s/it]

Loss: 0.5961322784423828


Epoch 2:  97%|█████████▋| 56/58 [01:48<00:03,  1.94s/it]

Loss: 0.7299649715423584


Epoch 2:  98%|█████████▊| 57/58 [01:50<00:01,  1.94s/it]

Loss: 0.9596312046051025


Epoch 2: 100%|██████████| 58/58 [01:50<00:00,  1.91s/it]

Loss: 1.094288472813787e-05





Epoch 2 Validation Accuracy: 0.8669950738916257, F1-macro: 0.49867374005305043


Epoch 3:   2%|▏         | 1/58 [00:01<01:51,  1.95s/it]

Loss: 1.7990829944610596


Epoch 3:   3%|▎         | 2/58 [00:03<01:49,  1.95s/it]

Loss: 0.36273258924484253


Epoch 3:   5%|▌         | 3/58 [00:05<01:47,  1.95s/it]

Loss: 0.5947802066802979


Epoch 3:   7%|▋         | 4/58 [00:07<01:45,  1.95s/it]

Loss: 0.23688265681266785


Epoch 3:   9%|▊         | 5/58 [00:09<01:43,  1.95s/it]

Loss: 0.4223078489303589


Epoch 3:  10%|█         | 6/58 [00:11<01:41,  1.95s/it]

Loss: 0.2847039997577667


Epoch 3:  12%|█▏        | 7/58 [00:13<01:39,  1.96s/it]

Loss: 0.34971684217453003


Epoch 3:  14%|█▍        | 8/58 [00:15<01:37,  1.96s/it]

Loss: 0.4698375463485718


Epoch 3:  16%|█▌        | 9/58 [00:17<01:35,  1.96s/it]

Loss: 0.2767379581928253


Epoch 3:  17%|█▋        | 10/58 [00:19<01:33,  1.96s/it]

Loss: 0.2687254846096039


Epoch 3:  19%|█▉        | 11/58 [00:21<01:31,  1.96s/it]

Loss: 0.2723861336708069


Epoch 3:  21%|██        | 12/58 [00:23<01:30,  1.96s/it]

Loss: 0.3526747524738312


Epoch 3:  22%|██▏       | 13/58 [00:25<01:27,  1.96s/it]

Loss: 0.0464298278093338


Epoch 3:  24%|██▍       | 14/58 [00:27<01:26,  1.96s/it]

Loss: 0.5455335974693298


Epoch 3:  26%|██▌       | 15/58 [00:29<01:24,  1.96s/it]

Loss: 0.20612874627113342


Epoch 3:  28%|██▊       | 16/58 [00:31<01:22,  1.96s/it]

Loss: 0.19738931953907013


Epoch 3:  29%|██▉       | 17/58 [00:33<01:20,  1.96s/it]

Loss: 0.2842678725719452


Epoch 3:  31%|███       | 18/58 [00:35<01:18,  1.96s/it]

Loss: 0.6344041228294373


Epoch 3:  33%|███▎      | 19/58 [00:37<01:16,  1.96s/it]

Loss: 0.3306114971637726


Epoch 3:  34%|███▍      | 20/58 [00:39<01:14,  1.96s/it]

Loss: 0.6644992828369141


Epoch 3:  36%|███▌      | 21/58 [00:41<01:12,  1.96s/it]

Loss: 0.44163915514945984


Epoch 3:  38%|███▊      | 22/58 [00:43<01:10,  1.96s/it]

Loss: 0.5635349750518799


Epoch 3:  40%|███▉      | 23/58 [00:44<01:08,  1.96s/it]

Loss: 0.3812696933746338


Epoch 3:  41%|████▏     | 24/58 [00:46<01:06,  1.95s/it]

Loss: 0.1454026699066162


Epoch 3:  43%|████▎     | 25/58 [00:48<01:04,  1.95s/it]

Loss: 0.7640582919120789


Epoch 3:  45%|████▍     | 26/58 [00:50<01:02,  1.95s/it]

Loss: 0.6182737350463867


Epoch 3:  47%|████▋     | 27/58 [00:52<01:00,  1.95s/it]

Loss: 0.16014796495437622


Epoch 3:  48%|████▊     | 28/58 [00:54<00:58,  1.95s/it]

Loss: 0.39838480949401855


Epoch 3:  50%|█████     | 29/58 [00:56<00:56,  1.95s/it]

Loss: 0.6023502349853516


Epoch 3:  52%|█████▏    | 30/58 [00:58<00:54,  1.95s/it]

Loss: 0.45042508840560913


Epoch 3:  53%|█████▎    | 31/58 [01:00<00:52,  1.95s/it]

Loss: 0.11234259605407715


Epoch 3:  55%|█████▌    | 32/58 [01:02<00:50,  1.95s/it]

Loss: 0.29172903299331665


Epoch 3:  57%|█████▋    | 33/58 [01:04<00:48,  1.95s/it]

Loss: 0.6437633037567139


Epoch 3:  59%|█████▊    | 34/58 [01:06<00:46,  1.95s/it]

Loss: 0.5463245511054993


Epoch 3:  60%|██████    | 35/58 [01:08<00:44,  1.95s/it]

Loss: 0.4369821548461914


Epoch 3:  62%|██████▏   | 36/58 [01:10<00:42,  1.95s/it]

Loss: 0.7372825741767883


Epoch 3:  64%|██████▍   | 37/58 [01:12<00:40,  1.95s/it]

Loss: 0.2659824788570404


Epoch 3:  66%|██████▌   | 38/58 [01:14<00:38,  1.95s/it]

Loss: 0.16780336201190948


Epoch 3:  67%|██████▋   | 39/58 [01:16<00:36,  1.95s/it]

Loss: 0.4568471312522888


Epoch 3:  69%|██████▉   | 40/58 [01:18<00:35,  1.95s/it]

Loss: 0.339406281709671


Epoch 3:  71%|███████   | 41/58 [01:20<00:33,  1.95s/it]

Loss: 0.45672744512557983


Epoch 3:  72%|███████▏  | 42/58 [01:22<00:31,  1.95s/it]

Loss: 0.38483792543411255


Epoch 3:  74%|███████▍  | 43/58 [01:23<00:29,  1.95s/it]

Loss: 1.1717464923858643


Epoch 3:  76%|███████▌  | 44/58 [01:25<00:27,  1.95s/it]

Loss: 0.2301466017961502


Epoch 3:  78%|███████▊  | 45/58 [01:27<00:25,  1.95s/it]

Loss: 0.6922109723091125


Epoch 3:  79%|███████▉  | 46/58 [01:29<00:23,  1.95s/it]

Loss: 0.4618515074253082


Epoch 3:  81%|████████  | 47/58 [01:31<00:21,  1.95s/it]

Loss: 0.8365387320518494


Epoch 3:  83%|████████▎ | 48/58 [01:33<00:19,  1.95s/it]

Loss: 0.2417476773262024


Epoch 3:  84%|████████▍ | 49/58 [01:35<00:17,  1.95s/it]

Loss: 0.3590749204158783


Epoch 3:  86%|████████▌ | 50/58 [01:37<00:15,  1.95s/it]

Loss: 0.3705160915851593


Epoch 3:  88%|████████▊ | 51/58 [01:39<00:13,  1.95s/it]

Loss: 0.209654301404953


Epoch 3:  90%|████████▉ | 52/58 [01:41<00:11,  1.95s/it]

Loss: 0.6098619699478149


Epoch 3:  91%|█████████▏| 53/58 [01:43<00:09,  1.95s/it]

Loss: 0.8444563150405884


Epoch 3:  93%|█████████▎| 54/58 [01:45<00:07,  1.95s/it]

Loss: 0.5620706081390381


Epoch 3:  95%|█████████▍| 55/58 [01:47<00:05,  1.95s/it]

Loss: 0.23062846064567566


Epoch 3:  97%|█████████▋| 56/58 [01:49<00:03,  1.95s/it]

Loss: 0.023850340396165848


Epoch 3:  98%|█████████▊| 57/58 [01:51<00:01,  1.94s/it]

Loss: 0.07293781638145447


Epoch 3: 100%|██████████| 58/58 [01:51<00:00,  1.92s/it]

Loss: 0.002071327529847622





Epoch 3 Validation Accuracy: 0.8620689655172413, F1-macro: 0.6282051282051282


Epoch 4:   2%|▏         | 1/58 [00:01<01:51,  1.95s/it]

Loss: 0.6596744060516357


Epoch 4:   3%|▎         | 2/58 [00:03<01:49,  1.95s/it]

Loss: 0.3572586178779602


Epoch 4:   5%|▌         | 3/58 [00:05<01:47,  1.95s/it]

Loss: 0.576952338218689


Epoch 4:   7%|▋         | 4/58 [00:07<01:45,  1.95s/it]

Loss: 0.425690233707428


Epoch 4:   9%|▊         | 5/58 [00:09<01:43,  1.95s/it]

Loss: 0.07587546110153198


Epoch 4:  10%|█         | 6/58 [00:11<01:41,  1.94s/it]

Loss: 0.332665354013443


Epoch 4:  12%|█▏        | 7/58 [00:13<01:39,  1.94s/it]

Loss: 0.4664117097854614


Epoch 4:  14%|█▍        | 8/58 [00:15<01:37,  1.94s/it]

Loss: 0.12197357416152954


Epoch 4:  16%|█▌        | 9/58 [00:17<01:35,  1.94s/it]

Loss: 0.16167576611042023


Epoch 4:  17%|█▋        | 10/58 [00:19<01:33,  1.94s/it]

Loss: 0.054793380200862885


Epoch 4:  19%|█▉        | 11/58 [00:21<01:31,  1.94s/it]

Loss: 0.023911423981189728


Epoch 4:  21%|██        | 12/58 [00:23<01:29,  1.94s/it]

Loss: 0.42930901050567627


Epoch 4:  22%|██▏       | 13/58 [00:25<01:27,  1.94s/it]

Loss: 0.31346023082733154


Epoch 4:  24%|██▍       | 14/58 [00:27<01:25,  1.94s/it]

Loss: 0.37900006771087646


Epoch 4:  26%|██▌       | 15/58 [00:29<01:23,  1.94s/it]

Loss: 0.29913461208343506


Epoch 4:  28%|██▊       | 16/58 [00:31<01:21,  1.94s/it]

Loss: 0.5402259230613708


Epoch 4:  29%|██▉       | 17/58 [00:33<01:19,  1.95s/it]

Loss: 0.23928020894527435


Epoch 4:  31%|███       | 18/58 [00:35<01:17,  1.94s/it]

Loss: 0.19006642699241638


Epoch 4:  33%|███▎      | 19/58 [00:36<01:15,  1.94s/it]

Loss: 0.2302553951740265


Epoch 4:  34%|███▍      | 20/58 [00:38<01:13,  1.95s/it]

Loss: 0.17336896061897278


Epoch 4:  36%|███▌      | 21/58 [00:40<01:12,  1.95s/it]

Loss: 0.4544551968574524


Epoch 4:  38%|███▊      | 22/58 [00:42<01:10,  1.95s/it]

Loss: 0.014247067272663116


Epoch 4:  40%|███▉      | 23/58 [00:44<01:08,  1.95s/it]

Loss: 0.42228177189826965


Epoch 4:  41%|████▏     | 24/58 [00:46<01:06,  1.95s/it]

Loss: 0.23230774700641632


Epoch 4:  43%|████▎     | 25/58 [00:48<01:04,  1.95s/it]

Loss: 0.3399079144001007


Epoch 4:  45%|████▍     | 26/58 [00:50<01:02,  1.95s/it]

Loss: 0.30443277955055237


Epoch 4:  47%|████▋     | 27/58 [00:52<01:00,  1.95s/it]

Loss: 0.2557704746723175


Epoch 4:  48%|████▊     | 28/58 [00:54<00:58,  1.95s/it]

Loss: 0.4715060293674469


Epoch 4:  50%|█████     | 29/58 [00:56<00:56,  1.95s/it]

Loss: 0.2819488048553467


Epoch 4:  52%|█████▏    | 30/58 [00:58<00:54,  1.95s/it]

Loss: 0.3294944167137146


Epoch 4:  53%|█████▎    | 31/58 [01:00<00:52,  1.94s/it]

Loss: 0.18493422865867615


Epoch 4:  55%|█████▌    | 32/58 [01:02<00:50,  1.94s/it]

Loss: 0.11972981691360474


Epoch 4:  57%|█████▋    | 33/58 [01:04<00:48,  1.94s/it]

Loss: 0.03791134059429169


Epoch 4:  59%|█████▊    | 34/58 [01:06<00:46,  1.94s/it]

Loss: 0.04116737097501755


Epoch 4:  60%|██████    | 35/58 [01:08<00:44,  1.94s/it]

Loss: 0.1204739585518837


Epoch 4:  62%|██████▏   | 36/58 [01:10<00:42,  1.94s/it]

Loss: 0.42062586545944214


Epoch 4:  64%|██████▍   | 37/58 [01:11<00:40,  1.94s/it]

Loss: 0.1798158884048462


Epoch 4:  66%|██████▌   | 38/58 [01:13<00:38,  1.94s/it]

Loss: 0.3295385539531708


Epoch 4:  67%|██████▋   | 39/58 [01:15<00:36,  1.94s/it]

Loss: 0.5277460217475891


Epoch 4:  69%|██████▉   | 40/58 [01:17<00:34,  1.94s/it]

Loss: 0.9796425104141235


Epoch 4:  71%|███████   | 41/58 [01:19<00:33,  1.94s/it]

Loss: 0.19759728014469147


Epoch 4:  72%|███████▏  | 42/58 [01:21<00:31,  1.94s/it]

Loss: 1.0998557806015015


Epoch 4:  74%|███████▍  | 43/58 [01:23<00:29,  1.94s/it]

Loss: 0.5283557772636414


Epoch 4:  76%|███████▌  | 44/58 [01:25<00:27,  1.95s/it]

Loss: 0.6473812460899353


Epoch 4:  78%|███████▊  | 45/58 [01:27<00:25,  1.95s/it]

Loss: 0.3642162084579468


Epoch 4:  79%|███████▉  | 46/58 [01:29<00:23,  1.95s/it]

Loss: 0.5255529880523682


Epoch 4:  81%|████████  | 47/58 [01:31<00:21,  1.94s/it]

Loss: 0.23785671591758728


Epoch 4:  83%|████████▎ | 48/58 [01:33<00:19,  1.95s/it]

Loss: 0.5958970189094543


Epoch 4:  84%|████████▍ | 49/58 [01:35<00:17,  1.95s/it]

Loss: 0.5103540420532227


Epoch 4:  86%|████████▌ | 50/58 [01:37<00:15,  1.95s/it]

Loss: 0.541925311088562


Epoch 4:  88%|████████▊ | 51/58 [01:39<00:13,  1.95s/it]

Loss: 0.569609522819519


Epoch 4:  90%|████████▉ | 52/58 [01:41<00:11,  1.95s/it]

Loss: 0.1795065999031067


Epoch 4:  91%|█████████▏| 53/58 [01:43<00:09,  1.95s/it]

Loss: 0.14948029816150665


Epoch 4:  93%|█████████▎| 54/58 [01:45<00:07,  1.95s/it]

Loss: 0.16718313097953796


Epoch 4:  95%|█████████▍| 55/58 [01:46<00:05,  1.95s/it]

Loss: 0.5265937447547913


Epoch 4:  97%|█████████▋| 56/58 [01:48<00:03,  1.95s/it]

Loss: 0.4824461340904236


Epoch 4:  98%|█████████▊| 57/58 [01:50<00:01,  1.95s/it]

Loss: 0.3176005184650421


Epoch 4: 100%|██████████| 58/58 [01:51<00:00,  1.92s/it]

Loss: 0.0854668840765953





Epoch 4 Validation Accuracy: 0.8719211822660099, F1-macro: 0.6037537537537538


Epoch 5:   2%|▏         | 1/58 [00:01<01:50,  1.94s/it]

Loss: 0.322506844997406


Epoch 5:   3%|▎         | 2/58 [00:03<01:48,  1.95s/it]

Loss: 0.4663240909576416


Epoch 5:   5%|▌         | 3/58 [00:05<01:47,  1.95s/it]

Loss: 0.28578075766563416


Epoch 5:   7%|▋         | 4/58 [00:07<01:45,  1.95s/it]

Loss: 0.06809434294700623


Epoch 5:   9%|▊         | 5/58 [00:09<01:43,  1.95s/it]

Loss: 0.19628648459911346


Epoch 5:  10%|█         | 6/58 [00:11<01:41,  1.95s/it]

Loss: 0.342151939868927


Epoch 5:  12%|█▏        | 7/58 [00:13<01:39,  1.95s/it]

Loss: 0.36021071672439575


Epoch 5:  14%|█▍        | 8/58 [00:15<01:37,  1.95s/it]

Loss: 0.2741885781288147


Epoch 5:  16%|█▌        | 9/58 [00:17<01:35,  1.95s/it]

Loss: 0.056199219077825546


Epoch 5:  17%|█▋        | 10/58 [00:19<01:33,  1.95s/it]

Loss: 0.3711503744125366


Epoch 5:  19%|█▉        | 11/58 [00:21<01:31,  1.95s/it]

Loss: 0.34067419171333313


Epoch 5:  21%|██        | 12/58 [00:23<01:29,  1.95s/it]

Loss: 0.31590062379837036


Epoch 5:  22%|██▏       | 13/58 [00:25<01:27,  1.95s/it]

Loss: 0.2402966320514679


Epoch 5:  24%|██▍       | 14/58 [00:27<01:25,  1.95s/it]

Loss: 0.40416303277015686


Epoch 5:  26%|██▌       | 15/58 [00:29<01:23,  1.95s/it]

Loss: 0.2125750333070755


Epoch 5:  28%|██▊       | 16/58 [00:31<01:21,  1.95s/it]

Loss: 0.21096482872962952


Epoch 5:  29%|██▉       | 17/58 [00:33<01:19,  1.95s/it]

Loss: 0.48467037081718445


Epoch 5:  31%|███       | 18/58 [00:35<01:18,  1.95s/it]

Loss: 0.19765101373195648


Epoch 5:  33%|███▎      | 19/58 [00:37<01:16,  1.95s/it]

Loss: 0.7983646392822266


Epoch 5:  34%|███▍      | 20/58 [00:39<01:14,  1.95s/it]

Loss: 0.26221030950546265


Epoch 5:  36%|███▌      | 21/58 [00:40<01:12,  1.95s/it]

Loss: 0.8978036046028137


Epoch 5:  38%|███▊      | 22/58 [00:42<01:10,  1.95s/it]

Loss: 0.3438029885292053


Epoch 5:  40%|███▉      | 23/58 [00:44<01:08,  1.95s/it]

Loss: 0.6625804305076599


Epoch 5:  41%|████▏     | 24/58 [00:46<01:06,  1.95s/it]

Loss: 0.5059599876403809


Epoch 5:  43%|████▎     | 25/58 [00:48<01:04,  1.95s/it]

Loss: 0.9593782424926758


Epoch 5:  45%|████▍     | 26/58 [00:50<01:02,  1.96s/it]

Loss: 0.4196731448173523


Epoch 5:  47%|████▋     | 27/58 [00:52<01:00,  1.95s/it]

Loss: 0.2009747475385666


Epoch 5:  48%|████▊     | 28/58 [00:54<00:58,  1.96s/it]

Loss: 0.5306682586669922


Epoch 5:  50%|█████     | 29/58 [00:56<00:56,  1.96s/it]

Loss: 0.37773510813713074


Epoch 5:  52%|█████▏    | 30/58 [00:58<00:54,  1.96s/it]

Loss: 0.19184866547584534


Epoch 5:  53%|█████▎    | 31/58 [01:00<00:52,  1.96s/it]

Loss: 0.7040073871612549


Epoch 5:  55%|█████▌    | 32/58 [01:02<00:50,  1.96s/it]

Loss: 0.18164733052253723


Epoch 5:  57%|█████▋    | 33/58 [01:04<00:49,  1.96s/it]

Loss: 0.711841344833374


Epoch 5:  59%|█████▊    | 34/58 [01:06<00:47,  1.96s/it]

Loss: 0.25669264793395996


Epoch 5:  60%|██████    | 35/58 [01:08<00:45,  1.96s/it]

Loss: 0.17507657408714294


Epoch 5:  62%|██████▏   | 36/58 [01:10<00:43,  1.96s/it]

Loss: 0.5342174172401428


Epoch 5:  64%|██████▍   | 37/58 [01:12<00:41,  1.96s/it]

Loss: 0.49703913927078247


Epoch 5:  66%|██████▌   | 38/58 [01:14<00:39,  1.96s/it]

Loss: 0.25666430592536926


Epoch 5:  67%|██████▋   | 39/58 [01:16<00:37,  1.96s/it]

Loss: 0.349263459444046


Epoch 5:  69%|██████▉   | 40/58 [01:18<00:35,  1.95s/it]

Loss: 0.5499904751777649


Epoch 5:  71%|███████   | 41/58 [01:20<00:33,  1.95s/it]

Loss: 0.8622071146965027


Epoch 5:  72%|███████▏  | 42/58 [01:22<00:31,  1.95s/it]

Loss: 0.13941963016986847


Epoch 5:  74%|███████▍  | 43/58 [01:23<00:29,  1.95s/it]

Loss: 0.4592769742012024


Epoch 5:  76%|███████▌  | 44/58 [01:25<00:27,  1.95s/it]

Loss: 0.22441880404949188


Epoch 5:  78%|███████▊  | 45/58 [01:27<00:25,  1.95s/it]

Loss: 0.13354510068893433


Epoch 5:  79%|███████▉  | 46/58 [01:29<00:23,  1.94s/it]

Loss: 0.14880052208900452


Epoch 5:  81%|████████  | 47/58 [01:31<00:21,  1.94s/it]

Loss: 0.3626568615436554


Epoch 5:  83%|████████▎ | 48/58 [01:33<00:19,  1.94s/it]

Loss: 0.05750938504934311


Epoch 5:  84%|████████▍ | 49/58 [01:35<00:17,  1.94s/it]

Loss: 0.13646268844604492


Epoch 5:  86%|████████▌ | 50/58 [01:37<00:15,  1.94s/it]

Loss: 0.2593265771865845


Epoch 5:  88%|████████▊ | 51/58 [01:39<00:13,  1.94s/it]

Loss: 0.07670285552740097


Epoch 5:  90%|████████▉ | 52/58 [01:41<00:11,  1.94s/it]

Loss: 0.35883307456970215


Epoch 5:  91%|█████████▏| 53/58 [01:43<00:09,  1.94s/it]

Loss: 0.2517825961112976


Epoch 5:  93%|█████████▎| 54/58 [01:45<00:07,  1.94s/it]

Loss: 0.2530379295349121


Epoch 5:  95%|█████████▍| 55/58 [01:47<00:05,  1.94s/it]

Loss: 0.2950349450111389


Epoch 5:  97%|█████████▋| 56/58 [01:49<00:03,  1.94s/it]

Loss: 0.20865584909915924


Epoch 5:  98%|█████████▊| 57/58 [01:51<00:01,  1.93s/it]

Loss: 0.4091988205909729


Epoch 5: 100%|██████████| 58/58 [01:51<00:00,  1.92s/it]

Loss: 0.34048011898994446





Epoch 5 Validation Accuracy: 0.6009852216748769, F1-macro: 0.5435669673837613


Epoch 6:   2%|▏         | 1/58 [00:01<01:49,  1.93s/it]

Loss: 0.48448535799980164


Epoch 6:   3%|▎         | 2/58 [00:03<01:48,  1.93s/it]

Loss: 0.28761082887649536


Epoch 6:   5%|▌         | 3/58 [00:05<01:46,  1.94s/it]

Loss: 0.2751504182815552


Epoch 6:   7%|▋         | 4/58 [00:07<01:44,  1.94s/it]

Loss: 0.28533145785331726


Epoch 6:   9%|▊         | 5/58 [00:09<01:42,  1.93s/it]

Loss: 0.23519454896450043


Epoch 6:  10%|█         | 6/58 [00:11<01:40,  1.94s/it]

Loss: 1.0358198881149292


Epoch 6:  12%|█▏        | 7/58 [00:13<01:38,  1.94s/it]

Loss: 0.28808605670928955


Epoch 6:  14%|█▍        | 8/58 [00:15<01:36,  1.94s/it]

Loss: 0.5102371573448181


Epoch 6:  16%|█▌        | 9/58 [00:17<01:35,  1.94s/it]

Loss: 0.34379810094833374


Epoch 6:  17%|█▋        | 10/58 [00:19<01:33,  1.94s/it]

Loss: 0.5936598181724548


Epoch 6:  19%|█▉        | 11/58 [00:21<01:31,  1.94s/it]

Loss: 0.15738335251808167


Epoch 6:  21%|██        | 12/58 [00:23<01:29,  1.94s/it]

Loss: 0.3018108308315277


Epoch 6:  22%|██▏       | 13/58 [00:25<01:27,  1.94s/it]

Loss: 0.2643575668334961


Epoch 6:  24%|██▍       | 14/58 [00:27<01:25,  1.94s/it]

Loss: 0.32286715507507324


Epoch 6:  26%|██▌       | 15/58 [00:29<01:23,  1.94s/it]

Loss: 0.937076210975647


Epoch 6:  28%|██▊       | 16/58 [00:31<01:21,  1.95s/it]

Loss: 0.20195429027080536


Epoch 6:  29%|██▉       | 17/58 [00:32<01:19,  1.94s/it]

Loss: 0.5517659187316895


Epoch 6:  31%|███       | 18/58 [00:34<01:17,  1.95s/it]

Loss: 0.18104951083660126


Epoch 6:  33%|███▎      | 19/58 [00:36<01:15,  1.94s/it]

Loss: 0.21776819229125977


Epoch 6:  34%|███▍      | 20/58 [00:38<01:13,  1.94s/it]

Loss: 0.9074355363845825


Epoch 6:  36%|███▌      | 21/58 [00:40<01:11,  1.94s/it]

Loss: 0.37539243698120117


Epoch 6:  38%|███▊      | 22/58 [00:42<01:09,  1.94s/it]

Loss: 0.6149697303771973


Epoch 6:  40%|███▉      | 23/58 [00:44<01:08,  1.94s/it]

Loss: 0.5144698023796082


Epoch 6:  41%|████▏     | 24/58 [00:46<01:06,  1.94s/it]

Loss: 0.4874318540096283


Epoch 6:  43%|████▎     | 25/58 [00:48<01:04,  1.94s/it]

Loss: 0.435801237821579


Epoch 6:  45%|████▍     | 26/58 [00:50<01:02,  1.94s/it]

Loss: 0.1164877712726593


Epoch 6:  47%|████▋     | 27/58 [00:52<01:00,  1.94s/it]

Loss: 0.194224551320076


Epoch 6:  48%|████▊     | 28/58 [00:54<00:58,  1.94s/it]

Loss: 0.3420568108558655


Epoch 6:  50%|█████     | 29/58 [00:56<00:56,  1.94s/it]

Loss: 0.1612992286682129


Epoch 6:  52%|█████▏    | 30/58 [00:58<00:54,  1.94s/it]

Loss: 0.6542655229568481


Epoch 6:  53%|█████▎    | 31/58 [01:00<00:52,  1.94s/it]

Loss: 0.3981376886367798


Epoch 6:  55%|█████▌    | 32/58 [01:02<00:50,  1.94s/it]

Loss: 0.44061294198036194


Epoch 6:  57%|█████▋    | 33/58 [01:04<00:48,  1.94s/it]

Loss: 0.813747763633728


Epoch 6:  59%|█████▊    | 34/58 [01:06<00:46,  1.94s/it]

Loss: 0.22033095359802246


Epoch 6:  60%|██████    | 35/58 [01:07<00:44,  1.94s/it]

Loss: 0.5526136159896851


Epoch 6:  62%|██████▏   | 36/58 [01:09<00:42,  1.94s/it]

Loss: 0.2175927460193634


Epoch 6:  64%|██████▍   | 37/58 [01:11<00:40,  1.95s/it]

Loss: 0.12556175887584686


Epoch 6:  66%|██████▌   | 38/58 [01:13<00:38,  1.95s/it]

Loss: 0.26442351937294006


Epoch 6:  67%|██████▋   | 39/58 [01:15<00:36,  1.95s/it]

Loss: 0.1925412118434906


Epoch 6:  69%|██████▉   | 40/58 [01:17<00:35,  1.95s/it]

Loss: 0.3312857151031494


Epoch 6:  71%|███████   | 41/58 [01:19<00:33,  1.95s/it]

Loss: 0.5843605995178223


Epoch 6:  72%|███████▏  | 42/58 [01:21<00:31,  1.95s/it]

Loss: 0.2821248173713684


Epoch 6:  74%|███████▍  | 43/58 [01:23<00:29,  1.95s/it]

Loss: 0.5827996134757996


Epoch 6:  76%|███████▌  | 44/58 [01:25<00:27,  1.95s/it]

Loss: 0.5327693819999695


Epoch 6:  78%|███████▊  | 45/58 [01:27<00:25,  1.95s/it]

Loss: 0.32114505767822266


Epoch 6:  79%|███████▉  | 46/58 [01:29<00:23,  1.95s/it]

Loss: 0.6212891936302185


Epoch 6:  81%|████████  | 47/58 [01:31<00:21,  1.95s/it]

Loss: 0.30685845017433167


Epoch 6:  83%|████████▎ | 48/58 [01:33<00:19,  1.95s/it]

Loss: 0.4519484341144562


Epoch 6:  84%|████████▍ | 49/58 [01:35<00:17,  1.95s/it]

Loss: 0.2688957154750824


Epoch 6:  86%|████████▌ | 50/58 [01:37<00:15,  1.94s/it]

Loss: 0.2149115949869156


Epoch 6:  88%|████████▊ | 51/58 [01:39<00:13,  1.94s/it]

Loss: 0.1261146068572998


Epoch 6:  90%|████████▉ | 52/58 [01:41<00:11,  1.94s/it]

Loss: 0.07282332330942154


Epoch 6:  91%|█████████▏| 53/58 [01:43<00:09,  1.94s/it]

Loss: 0.10702966153621674


Epoch 6:  93%|█████████▎| 54/58 [01:44<00:07,  1.94s/it]

Loss: 0.538499653339386


Epoch 6:  95%|█████████▍| 55/58 [01:46<00:05,  1.94s/it]

Loss: 0.2369641214609146


Epoch 6:  97%|█████████▋| 56/58 [01:48<00:03,  1.94s/it]

Loss: 0.38245096802711487


Epoch 6:  98%|█████████▊| 57/58 [01:50<00:01,  1.94s/it]

Loss: 0.3408973515033722


Epoch 6: 100%|██████████| 58/58 [01:51<00:00,  1.92s/it]

Loss: 1.4235196113586426





Epoch 6 Validation Accuracy: 0.8669950738916257, F1-macro: 0.7070393928056016


Epoch 7:   2%|▏         | 1/58 [00:01<01:50,  1.94s/it]

Loss: 0.5243138670921326


Epoch 7:   3%|▎         | 2/58 [00:03<01:49,  1.95s/it]

Loss: 0.40001773834228516


Epoch 7:   5%|▌         | 3/58 [00:05<01:47,  1.95s/it]

Loss: 0.3671640157699585


Epoch 7:   7%|▋         | 4/58 [00:07<01:45,  1.95s/it]

Loss: 0.5541096329689026


Epoch 7:   9%|▊         | 5/58 [00:09<01:43,  1.95s/it]

Loss: 0.28289374709129333


Epoch 7:  10%|█         | 6/58 [00:11<01:41,  1.95s/it]

Loss: 0.4250253438949585


Epoch 7:  12%|█▏        | 7/58 [00:13<01:39,  1.95s/it]

Loss: 0.3523319661617279


Epoch 7:  14%|█▍        | 8/58 [00:15<01:37,  1.94s/it]

Loss: 0.4181288778781891


Epoch 7:  16%|█▌        | 9/58 [00:17<01:35,  1.94s/it]

Loss: 0.29646971821784973


Epoch 7:  17%|█▋        | 10/58 [00:19<01:33,  1.95s/it]

Loss: 0.0025606867857277393


Epoch 7:  19%|█▉        | 11/58 [00:21<01:31,  1.95s/it]

Loss: 0.6827436089515686


Epoch 7:  21%|██        | 12/58 [00:23<01:29,  1.95s/it]

Loss: 0.28229841589927673


Epoch 7:  22%|██▏       | 13/58 [00:25<01:27,  1.95s/it]

Loss: 0.4513014853000641


Epoch 7:  24%|██▍       | 14/58 [00:27<01:25,  1.95s/it]

Loss: 0.4900665879249573


Epoch 7:  26%|██▌       | 15/58 [00:29<01:23,  1.95s/it]

Loss: 0.8260520696640015


Epoch 7:  28%|██▊       | 16/58 [00:31<01:21,  1.95s/it]

Loss: 0.38268372416496277


Epoch 7:  29%|██▉       | 17/58 [00:33<01:19,  1.95s/it]

Loss: 0.15284030139446259


Epoch 7:  31%|███       | 18/58 [00:35<01:17,  1.95s/it]

Loss: 0.2177354395389557


Epoch 7:  33%|███▎      | 19/58 [00:36<01:15,  1.95s/it]

Loss: 0.23687544465065002


Epoch 7:  34%|███▍      | 20/58 [00:38<01:13,  1.95s/it]

Loss: 0.36692339181900024


Epoch 7:  36%|███▌      | 21/58 [00:40<01:11,  1.95s/it]

Loss: 0.6378346681594849


Epoch 7:  38%|███▊      | 22/58 [00:42<01:10,  1.95s/it]

Loss: 0.2504766881465912


Epoch 7:  40%|███▉      | 23/58 [00:44<01:08,  1.95s/it]

Loss: 0.31094905734062195


Epoch 7:  41%|████▏     | 24/58 [00:46<01:06,  1.95s/it]

Loss: 0.10621372610330582


Epoch 7:  43%|████▎     | 25/58 [00:48<01:04,  1.94s/it]

Loss: 0.7336927056312561


Epoch 7:  45%|████▍     | 26/58 [00:50<01:02,  1.94s/it]

Loss: 0.8095418214797974


Epoch 7:  47%|████▋     | 27/58 [00:52<01:00,  1.94s/it]

Loss: 0.21876533329486847


Epoch 7:  48%|████▊     | 28/58 [00:54<00:58,  1.94s/it]

Loss: 0.577507495880127


Epoch 7:  50%|█████     | 29/58 [00:56<00:56,  1.94s/it]

Loss: 0.4156729280948639


Epoch 7:  52%|█████▏    | 30/58 [00:58<00:54,  1.95s/it]

Loss: 0.23371905088424683


Epoch 7:  53%|█████▎    | 31/58 [01:00<00:52,  1.95s/it]

Loss: 0.22281041741371155


Epoch 7:  55%|█████▌    | 32/58 [01:02<00:50,  1.94s/it]

Loss: 0.5523836016654968


Epoch 7:  57%|█████▋    | 33/58 [01:04<00:48,  1.94s/it]

Loss: 0.45493084192276


Epoch 7:  59%|█████▊    | 34/58 [01:06<00:46,  1.94s/it]

Loss: 0.26497742533683777


Epoch 7:  60%|██████    | 35/58 [01:08<00:44,  1.94s/it]

Loss: 0.46423256397247314


Epoch 7:  62%|██████▏   | 36/58 [01:10<00:42,  1.94s/it]

Loss: 0.271241694688797


Epoch 7:  64%|██████▍   | 37/58 [01:11<00:40,  1.94s/it]

Loss: 0.28592729568481445


Epoch 7:  66%|██████▌   | 38/58 [01:13<00:38,  1.94s/it]

Loss: 0.1065845713019371


Epoch 7:  67%|██████▋   | 39/58 [01:15<00:36,  1.94s/it]

Loss: 0.014329835772514343


Epoch 7:  69%|██████▉   | 40/58 [01:17<00:34,  1.94s/it]

Loss: 0.9129806756973267


Epoch 7:  71%|███████   | 41/58 [01:19<00:33,  1.94s/it]

Loss: 0.9427326321601868


Epoch 7:  72%|███████▏  | 42/58 [01:21<00:31,  1.94s/it]

Loss: 0.5972774028778076


Epoch 7:  74%|███████▍  | 43/58 [01:23<00:29,  1.94s/it]

Loss: 0.5051041841506958


Epoch 7:  76%|███████▌  | 44/58 [01:25<00:27,  1.95s/it]

Loss: 0.2527148425579071


Epoch 7:  78%|███████▊  | 45/58 [01:27<00:25,  1.94s/it]

Loss: 0.3164837658405304


Epoch 7:  79%|███████▉  | 46/58 [01:29<00:23,  1.95s/it]

Loss: 0.2547462284564972


Epoch 7:  81%|████████  | 47/58 [01:31<00:21,  1.95s/it]

Loss: 0.028261393308639526


Epoch 7:  83%|████████▎ | 48/58 [01:33<00:19,  1.95s/it]

Loss: 0.11204127222299576


Epoch 7:  84%|████████▍ | 49/58 [01:35<00:17,  1.95s/it]

Loss: 0.6391795873641968


Epoch 7:  86%|████████▌ | 50/58 [01:37<00:15,  1.95s/it]

Loss: 0.16935095191001892


Epoch 7:  88%|████████▊ | 51/58 [01:39<00:13,  1.95s/it]

Loss: 0.22897674143314362


Epoch 7:  90%|████████▉ | 52/58 [01:41<00:11,  1.95s/it]

Loss: 0.07813728600740433


Epoch 7:  91%|█████████▏| 53/58 [01:43<00:09,  1.95s/it]

Loss: 0.48140430450439453


Epoch 7:  93%|█████████▎| 54/58 [01:45<00:07,  1.95s/it]

Loss: 0.6325671672821045


Epoch 7:  95%|█████████▍| 55/58 [01:47<00:05,  1.95s/it]

Loss: 0.2869773507118225


Epoch 7:  97%|█████████▋| 56/58 [01:48<00:03,  1.94s/it]

Loss: 0.39603930711746216


Epoch 7:  98%|█████████▊| 57/58 [01:50<00:01,  1.94s/it]

Loss: 0.2854057848453522


Epoch 7: 100%|██████████| 58/58 [01:51<00:00,  1.92s/it]

Loss: 0.000992117915302515





Epoch 7 Validation Accuracy: 0.8522167487684729, F1-macro: 0.721765350877193


## 3. 모델 테스트

In [None]:
# Evaluate on the test set
test_metrics = evaluate(model1, test_dataloader, device)
print(f"Test Accuracy: {test_metrics['accuracy']}, F1-macro: {test_metrics['f1_macro']}")

Test Accuracy: 0.8725490196078431, F1-macro: 0.7225360954174513


In [None]:
distortion_types = [
    'all-or-nothing thinking', 'comparing and despairing',
    'disqualifying the positive', 'emotional reasoning', 'fortune telling',
    'labeling', 'magnification', 'mind reading', 'overgeneralizing',
    'should statements', 'mental filter', 'personalization and blaming'
]

In [None]:
results_df = pd.DataFrame({
    "distortion_type": distortion_types,
    "test_accuracy": [np.nan] * 12,
    "f1_macro": [np.nan] * 12
})

In [None]:
current_type = 'all-or-nothing thinking'

results_df.loc[
    results_df["distortion_type"] == current_type,
    ["test_accuracy", "f1_macro"]
] = [test_metrics['accuracy'], test_metrics['f1_macro']]

# 2. Comparing and Despairing

In [None]:
# Add labels
data1_1_labels = list(data1['comparing and despairing'][data1_1.index])
# comparing and despairing은 data1에만 있음.

# Merging Data
data_encoded = data1_1_encoded
data_labels = data1_1_labels

dataset_with_labels = CustomDatasetWithLabels(data_encoded, data_labels)

# Define proportions for splitting
train_size = int(0.8 * len(dataset_with_labels))
val_size = int(0.1 * len(dataset_with_labels))
test_size = len(dataset_with_labels) - train_size - val_size

# Split the dataset
train_dataset, val_dataset, test_dataset = random_split(dataset_with_labels, [train_size, val_size, test_size])

In [None]:
dataset_encoded = dataset_with_labels.data
dataset_labels = dataset_with_labels.labels

# Convert list of BatchEncoding objects to a 2D numpy array for SMOTE features
feature_vectors = []
max_len = 512 # Assuming this from tokenize_and_pad function
for encoded_dict in dataset_encoded:
    input_ids_np = encoded_dict['input_ids'].squeeze().cpu().numpy()
    attention_mask_np = encoded_dict['attention_mask'].squeeze().cpu().numpy()
    # Concatenate input_ids and attention_mask to create a single feature vector per sample
    feature_vectors.append(np.concatenate([input_ids_np, attention_mask_np]))

# 'temp1' becomes the feature matrix, 'temp2' becomes the labels vector
temp1 = np.array(feature_vectors)
temp2 = np.array(dataset_labels)

# Apply SMOTE
temp1_resampled, temp2_resampled = smote.fit_resample(temp1, temp2)

# Reconstruct data_encoded and data_labels from the resampled data
reconstructed_data_encoded = []
for resampled_features in temp1_resampled:
    # Split the combined feature vector back into input_ids and attention_mask parts
    reconstructed_input_ids = resampled_features[:max_len]
    reconstructed_attention_mask = resampled_features[max_len:]

    # Convert back to torch tensors and a dictionary format compatible with CustomDatasetWithLabels
    input_ids_tensor = torch.tensor(reconstructed_input_ids, dtype=torch.long).unsqueeze(0)
    attention_mask_tensor = torch.tensor(reconstructed_attention_mask, dtype=torch.long).unsqueeze(0)

    reconstructed_data_encoded.append({
        'input_ids': input_ids_tensor,
        'attention_mask': attention_mask_tensor
    })

# Update data_encoded and data_labels with the resampled data
dataset_encoded = reconstructed_data_encoded
dataset_labels = temp2_resampled.tolist()

train_dataset = CustomDatasetWithLabels(dataset_encoded, dataset_labels)

In [None]:
# Create DataLoaders for each set
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=64)
test_dataloader = DataLoader(test_dataset, batch_size=64)

input_dim = bert_config.hidden_size
num_labels = len(set(data_labels))
label_emb = torch.randn(num_labels, input_dim)

In [None]:
# Instantiate the InnerProductClassifier model
model2 = InnerProductClassifier(input_dim, label_emb).to(device)

# Define loss function
criterion = nn.CrossEntropyLoss()

# Define optimizer
optimizer = torch.optim.AdamW(model2.parameters(), lr=2e-4)

In [None]:
EPOCHS = 7

# Move the BERT model to the device
bert_model.to(device)

for epoch in range(EPOCHS):
    model2.train()
    for batch in tqdm(train_dataloader, desc=f"Epoch {epoch+1}"):

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        y = batch["y"].to(device)

        # Get embeddings from the BERT model
        with torch.no_grad():
            outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask)
            embeddings = outputs.last_hidden_state[:, 0, :] # Get the [CLS] token embedding

        # Pass embeddings to the InnerProductClassifier
        logits = model2(embeddings)
        loss = criterion(logits, y)

        # Backpropagate and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Loss: {loss.item()}")

    # Evaluate on the validation set after each epoch
    val_metrics = evaluate(model2, val_dataloader, device)
    print(f"Epoch {epoch+1} Validation Accuracy: {val_metrics['accuracy']}, F1-macro: {val_metrics['f1_macro']}")

Epoch 1:  12%|█▎        | 1/8 [00:01<00:13,  1.91s/it]

Loss: 3.1400747299194336


Epoch 1:  25%|██▌       | 2/8 [00:03<00:11,  1.93s/it]

Loss: 0.5828450918197632


Epoch 1:  38%|███▊      | 3/8 [00:05<00:09,  1.93s/it]

Loss: 0.18146011233329773


Epoch 1:  50%|█████     | 4/8 [00:07<00:07,  1.93s/it]

Loss: 0.46900343894958496


Epoch 1:  62%|██████▎   | 5/8 [00:09<00:05,  1.94s/it]

Loss: 0.5580494999885559


Epoch 1:  75%|███████▌  | 6/8 [00:11<00:03,  1.94s/it]

Loss: 0.4469476640224457


Epoch 1:  88%|████████▊ | 7/8 [00:13<00:01,  1.94s/it]

Loss: 1.2782621383666992


Epoch 1: 100%|██████████| 8/8 [00:13<00:00,  1.74s/it]

Loss: 0.0





Epoch 1 Validation Accuracy: 1.0, F1-macro: 1.0


Epoch 2:  12%|█▎        | 1/8 [00:01<00:13,  1.93s/it]

Loss: 0.6110848784446716


Epoch 2:  25%|██▌       | 2/8 [00:03<00:11,  1.94s/it]

Loss: 0.9117472767829895


Epoch 2:  38%|███▊      | 3/8 [00:05<00:09,  1.94s/it]

Loss: 1.0597869157791138


Epoch 2:  50%|█████     | 4/8 [00:07<00:07,  1.94s/it]

Loss: 2.23286660911981e-05


Epoch 2:  62%|██████▎   | 5/8 [00:09<00:05,  1.94s/it]

Loss: 0.13361521065235138


Epoch 2:  75%|███████▌  | 6/8 [00:11<00:03,  1.94s/it]

Loss: 0.43689674139022827


Epoch 2:  88%|████████▊ | 7/8 [00:13<00:01,  1.94s/it]

Loss: 0.5030040144920349


Epoch 2: 100%|██████████| 8/8 [00:13<00:00,  1.75s/it]

Loss: 0.7121627330780029





Epoch 2 Validation Accuracy: 1.0, F1-macro: 1.0


Epoch 3:  12%|█▎        | 1/8 [00:01<00:13,  1.93s/it]

Loss: 0.488451212644577


Epoch 3:  25%|██▌       | 2/8 [00:03<00:11,  1.94s/it]

Loss: 0.8204947113990784


Epoch 3:  38%|███▊      | 3/8 [00:05<00:09,  1.94s/it]

Loss: 0.0


Epoch 3:  50%|█████     | 4/8 [00:07<00:07,  1.94s/it]

Loss: 0.9301796555519104


Epoch 3:  62%|██████▎   | 5/8 [00:09<00:05,  1.94s/it]

Loss: 0.36463481187820435


Epoch 3:  75%|███████▌  | 6/8 [00:11<00:03,  1.94s/it]

Loss: 0.5457184314727783


Epoch 3:  88%|████████▊ | 7/8 [00:13<00:01,  1.94s/it]

Loss: 1.8626450382086546e-09


Epoch 3: 100%|██████████| 8/8 [00:13<00:00,  1.74s/it]

Loss: 0.0





Epoch 3 Validation Accuracy: 1.0, F1-macro: 1.0


Epoch 4:  12%|█▎        | 1/8 [00:01<00:13,  1.93s/it]

Loss: 0.3496825098991394


Epoch 4:  25%|██▌       | 2/8 [00:03<00:11,  1.94s/it]

Loss: 0.48968127369880676


Epoch 4:  38%|███▊      | 3/8 [00:05<00:09,  1.94s/it]

Loss: 0.38145846128463745


Epoch 4:  50%|█████     | 4/8 [00:07<00:07,  1.94s/it]

Loss: 0.4767157733440399


Epoch 4:  62%|██████▎   | 5/8 [00:09<00:05,  1.94s/it]

Loss: 0.22579634189605713


Epoch 4:  75%|███████▌  | 6/8 [00:11<00:03,  1.94s/it]

Loss: 0.2493652105331421


Epoch 4:  88%|████████▊ | 7/8 [00:13<00:01,  1.94s/it]

Loss: 0.21258723735809326


Epoch 4: 100%|██████████| 8/8 [00:13<00:00,  1.75s/it]

Loss: 0.5080733299255371





Epoch 4 Validation Accuracy: 1.0, F1-macro: 1.0


Epoch 5:  12%|█▎        | 1/8 [00:01<00:13,  1.93s/it]

Loss: 0.49594998359680176


Epoch 5:  25%|██▌       | 2/8 [00:03<00:11,  1.94s/it]

Loss: 4.824177608497848e-07


Epoch 5:  38%|███▊      | 3/8 [00:05<00:09,  1.94s/it]

Loss: 0.436614066362381


Epoch 5:  50%|█████     | 4/8 [00:07<00:07,  1.94s/it]

Loss: 0.9811663627624512


Epoch 5:  62%|██████▎   | 5/8 [00:09<00:05,  1.94s/it]

Loss: 0.8044140934944153


Epoch 5:  75%|███████▌  | 6/8 [00:11<00:03,  1.94s/it]

Loss: 0.0


Epoch 5:  88%|████████▊ | 7/8 [00:13<00:01,  1.94s/it]

Loss: 2.235172580355993e-08


Epoch 5: 100%|██████████| 8/8 [00:13<00:00,  1.74s/it]

Loss: 0.0





Epoch 5 Validation Accuracy: 1.0, F1-macro: 1.0


Epoch 6:  12%|█▎        | 1/8 [00:01<00:13,  1.93s/it]

Loss: 0.31534725427627563


Epoch 6:  25%|██▌       | 2/8 [00:03<00:11,  1.94s/it]

Loss: 0.4044690728187561


Epoch 6:  38%|███▊      | 3/8 [00:05<00:09,  1.94s/it]

Loss: 4.6937952902226243e-07


Epoch 6:  50%|█████     | 4/8 [00:07<00:07,  1.94s/it]

Loss: 0.6029065251350403


Epoch 6:  62%|██████▎   | 5/8 [00:09<00:05,  1.94s/it]

Loss: 0.3967027962207794


Epoch 6:  75%|███████▌  | 6/8 [00:11<00:03,  1.94s/it]

Loss: 5.960460569554016e-08


Epoch 6:  88%|████████▊ | 7/8 [00:13<00:01,  1.94s/it]

Loss: 0.276751309633255


Epoch 6: 100%|██████████| 8/8 [00:13<00:00,  1.74s/it]

Loss: 0.0029751500114798546





Epoch 6 Validation Accuracy: 0.9565217391304348, F1-macro: 0.4888888888888889


Epoch 7:  12%|█▎        | 1/8 [00:01<00:13,  1.93s/it]

Loss: 0.5555042028427124


Epoch 7:  25%|██▌       | 2/8 [00:03<00:11,  1.93s/it]

Loss: 0.24232031404972076


Epoch 7:  38%|███▊      | 3/8 [00:05<00:09,  1.94s/it]

Loss: 0.1432649940252304


Epoch 7:  50%|█████     | 4/8 [00:07<00:07,  1.94s/it]

Loss: 7.498712511733174e-05


Epoch 7:  62%|██████▎   | 5/8 [00:09<00:05,  1.94s/it]

Loss: 3.4720847907010466e-05


Epoch 7:  75%|███████▌  | 6/8 [00:11<00:03,  1.94s/it]

Loss: 0.004393632989376783


Epoch 7:  88%|████████▊ | 7/8 [00:13<00:01,  1.93s/it]

Loss: 0.0


Epoch 7: 100%|██████████| 8/8 [00:13<00:00,  1.74s/it]

Loss: 1.5492181777954102





Epoch 7 Validation Accuracy: 1.0, F1-macro: 1.0


In [None]:
# Evaluate on the test set
test_metrics = evaluate(model2, test_dataloader, device)
print(f"Test Accuracy: {test_metrics['accuracy']}, F1-macro: {test_metrics['f1_macro']}")

Test Accuracy: 0.92, F1-macro: 0.4791666666666667


In [None]:
current_type = 'comparing and despairing'

results_df.loc[
    results_df["distortion_type"] == current_type,
    ["test_accuracy", "f1_macro"]
] = [test_metrics['accuracy'], test_metrics['f1_macro']]

# 3. Disqualifying the Positive

In [None]:
# Add labels
data1_1_labels = list(data1['disqualifying the positive'][data1_1.index])
data2_1_labels = list(data2['disqualifiying the positive'][data2_1.index]) # 자료에서 오타 있었음!
data3_1_labels = list(data3['disqualifying the positive'][data3_1.index])

# Merging Data
data_encoded = data1_1_encoded + data2_1_encoded + data3_1_encoded
data_labels = data1_1_labels + data2_1_labels + data3_1_labels

dataset_with_labels = CustomDatasetWithLabels(data_encoded, data_labels)

# Define proportions for splitting
train_size = int(0.8 * len(dataset_with_labels))
val_size = int(0.1 * len(dataset_with_labels))
test_size = len(dataset_with_labels) - train_size - val_size

# Split the dataset
train_dataset, val_dataset, test_dataset = random_split(dataset_with_labels, [train_size, val_size, test_size])

In [None]:
dataset_encoded = dataset_with_labels.data
dataset_labels = dataset_with_labels.labels

# Convert list of BatchEncoding objects to a 2D numpy array for SMOTE features
feature_vectors = []
max_len = 512 # Assuming this from tokenize_and_pad function
for encoded_dict in dataset_encoded:
    input_ids_np = encoded_dict['input_ids'].squeeze().cpu().numpy()
    attention_mask_np = encoded_dict['attention_mask'].squeeze().cpu().numpy()
    # Concatenate input_ids and attention_mask to create a single feature vector per sample
    feature_vectors.append(np.concatenate([input_ids_np, attention_mask_np]))

# 'temp1' becomes the feature matrix, 'temp2' becomes the labels vector
temp1 = np.array(feature_vectors)
temp2 = np.array(dataset_labels)

# Apply SMOTE
temp1_resampled, temp2_resampled = smote.fit_resample(temp1, temp2)

# Reconstruct data_encoded and data_labels from the resampled data
reconstructed_data_encoded = []
for resampled_features in temp1_resampled:
    # Split the combined feature vector back into input_ids and attention_mask parts
    reconstructed_input_ids = resampled_features[:max_len]
    reconstructed_attention_mask = resampled_features[max_len:]

    # Convert back to torch tensors and a dictionary format compatible with CustomDatasetWithLabels
    input_ids_tensor = torch.tensor(reconstructed_input_ids, dtype=torch.long).unsqueeze(0)
    attention_mask_tensor = torch.tensor(reconstructed_attention_mask, dtype=torch.long).unsqueeze(0)

    reconstructed_data_encoded.append({
        'input_ids': input_ids_tensor,
        'attention_mask': attention_mask_tensor
    })

# Update data_encoded and data_labels with the resampled data
dataset_encoded = reconstructed_data_encoded
dataset_labels = temp2_resampled.tolist()

train_dataset = CustomDatasetWithLabels(dataset_encoded, dataset_labels)

In [None]:
# Create DataLoaders for each set
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=64)
test_dataloader = DataLoader(test_dataset, batch_size=64)

input_dim = bert_config.hidden_size
num_labels = len(set(data_labels))
label_emb = torch.randn(num_labels, input_dim)

In [None]:
# Instantiate the InnerProductClassifier model
model3 = InnerProductClassifier(input_dim, label_emb).to(device)

# Define loss function
criterion = nn.CrossEntropyLoss()

# Define optimizer
optimizer = torch.optim.AdamW(model3.parameters(), lr=2e-4)

In [None]:
EPOCHS = 7

# Move the BERT model to the device
bert_model.to(device)

for epoch in range(EPOCHS):
    model3.train()
    for batch in tqdm(train_dataloader, desc=f"Epoch {epoch+1}"):

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        y = batch["y"].to(device)

        # Get embeddings from the BERT model
        with torch.no_grad():
            outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask)
            embeddings = outputs.last_hidden_state[:, 0, :] # Get the [CLS] token embedding

        # Pass embeddings to the InnerProductClassifier
        logits = model3(embeddings)
        loss = criterion(logits, y)

        # Backpropagate and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Loss: {loss.item()}")

    # Evaluate on the validation set after each epoch
    val_metrics = evaluate(model3, val_dataloader, device)
    print(f"Epoch {epoch+1} Validation Accuracy: {val_metrics['accuracy']}, F1-macro: {val_metrics['f1_macro']}")

Epoch 1:   2%|▏         | 1/63 [00:01<01:58,  1.91s/it]

Loss: 2.279573917388916


Epoch 1:   3%|▎         | 2/63 [00:03<01:57,  1.93s/it]

Loss: 1.9934914112091064


Epoch 1:   5%|▍         | 3/63 [00:05<01:55,  1.93s/it]

Loss: 0.9882683753967285


Epoch 1:   6%|▋         | 4/63 [00:07<01:54,  1.93s/it]

Loss: 0.7294929027557373


Epoch 1:   8%|▊         | 5/63 [00:09<01:52,  1.94s/it]

Loss: 1.1850336790084839


Epoch 1:  10%|▉         | 6/63 [00:11<01:50,  1.94s/it]

Loss: 2.510990270820912e-05


Epoch 1:  11%|█         | 7/63 [00:13<01:48,  1.94s/it]

Loss: 0.0


Epoch 1:  13%|█▎        | 8/63 [00:15<01:46,  1.94s/it]

Loss: 0.07003438472747803


Epoch 1:  14%|█▍        | 9/63 [00:17<01:44,  1.94s/it]

Loss: 1.4715051651000977


Epoch 1:  16%|█▌        | 10/63 [00:19<01:42,  1.94s/it]

Loss: 1.158620834350586


Epoch 1:  17%|█▋        | 11/63 [00:21<01:41,  1.94s/it]

Loss: 0.09213083237409592


Epoch 1:  19%|█▉        | 12/63 [00:23<01:39,  1.94s/it]

Loss: 0.05929294973611832


Epoch 1:  21%|██        | 13/63 [00:25<01:37,  1.94s/it]

Loss: 5.103477815282531e-05


Epoch 1:  22%|██▏       | 14/63 [00:27<01:35,  1.94s/it]

Loss: 0.0


Epoch 1:  24%|██▍       | 15/63 [00:29<01:33,  1.94s/it]

Loss: 0.8095245957374573


Epoch 1:  25%|██▌       | 16/63 [00:31<01:31,  1.94s/it]

Loss: 0.0


Epoch 1:  27%|██▋       | 17/63 [00:32<01:29,  1.94s/it]

Loss: 0.0


Epoch 1:  29%|██▊       | 18/63 [00:34<01:27,  1.94s/it]

Loss: 0.0


Epoch 1:  30%|███       | 19/63 [00:36<01:25,  1.94s/it]

Loss: 0.0


Epoch 1:  32%|███▏      | 20/63 [00:38<01:23,  1.94s/it]

Loss: 1.014812005450949e-05


Epoch 1:  33%|███▎      | 21/63 [00:40<01:21,  1.94s/it]

Loss: 0.02174016647040844


Epoch 1:  35%|███▍      | 22/63 [00:42<01:19,  1.94s/it]

Loss: 1.3022358417510986


Epoch 1:  37%|███▋      | 23/63 [00:44<01:17,  1.94s/it]

Loss: 0.8460754156112671


Epoch 1:  38%|███▊      | 24/63 [00:46<01:15,  1.94s/it]

Loss: 0.0


Epoch 1:  40%|███▉      | 25/63 [00:48<01:13,  1.95s/it]

Loss: 0.4391113221645355


Epoch 1:  41%|████▏     | 26/63 [00:50<01:11,  1.95s/it]

Loss: 3.054717865325074e-07


Epoch 1:  43%|████▎     | 27/63 [00:52<01:10,  1.95s/it]

Loss: 0.09501364082098007


Epoch 1:  44%|████▍     | 28/63 [00:54<01:08,  1.94s/it]

Loss: 0.0006302495021373034


Epoch 1:  46%|████▌     | 29/63 [00:56<01:06,  1.94s/it]

Loss: 1.0127657333214302e-05


Epoch 1:  48%|████▊     | 30/63 [00:58<01:04,  1.94s/it]

Loss: 0.033527590334415436


Epoch 1:  49%|████▉     | 31/63 [01:00<01:02,  1.94s/it]

Loss: 0.00160736043471843


Epoch 1:  51%|█████     | 32/63 [01:02<01:00,  1.94s/it]

Loss: 0.4149492681026459


Epoch 1:  52%|█████▏    | 33/63 [01:04<00:58,  1.94s/it]

Loss: 0.4089888334274292


Epoch 1:  54%|█████▍    | 34/63 [01:06<00:56,  1.95s/it]

Loss: 0.23339830338954926


Epoch 1:  56%|█████▌    | 35/63 [01:07<00:54,  1.95s/it]

Loss: 1.7998856492340565e-05


Epoch 1:  57%|█████▋    | 36/63 [01:09<00:52,  1.95s/it]

Loss: 0.37299132347106934


Epoch 1:  59%|█████▊    | 37/63 [01:11<00:50,  1.95s/it]

Loss: 2.854172998922877e-05


Epoch 1:  60%|██████    | 38/63 [01:13<00:48,  1.95s/it]

Loss: 1.12391197681427


Epoch 1:  62%|██████▏   | 39/63 [01:15<00:46,  1.95s/it]

Loss: 0.6173432469367981


Epoch 1:  63%|██████▎   | 40/63 [01:17<00:44,  1.95s/it]

Loss: 2.2287616957328282e-05


Epoch 1:  65%|██████▌   | 41/63 [01:19<00:42,  1.95s/it]

Loss: 4.360880484455265e-05


Epoch 1:  67%|██████▋   | 42/63 [01:21<00:40,  1.95s/it]

Loss: 0.20442670583724976


Epoch 1:  68%|██████▊   | 43/63 [01:23<00:38,  1.94s/it]

Loss: 0.24034008383750916


Epoch 1:  70%|██████▉   | 44/63 [01:25<00:36,  1.95s/it]

Loss: 0.9293310046195984


Epoch 1:  71%|███████▏  | 45/63 [01:27<00:35,  1.94s/it]

Loss: 0.3996698260307312


Epoch 1:  73%|███████▎  | 46/63 [01:29<00:33,  1.95s/it]

Loss: 0.0007610290194861591


Epoch 1:  75%|███████▍  | 47/63 [01:31<00:31,  1.95s/it]

Loss: 0.24103306233882904


Epoch 1:  76%|███████▌  | 48/63 [01:33<00:29,  1.95s/it]

Loss: 0.6916902661323547


Epoch 1:  78%|███████▊  | 49/63 [01:35<00:27,  1.95s/it]

Loss: 0.41217559576034546


Epoch 1:  79%|███████▉  | 50/63 [01:37<00:25,  1.94s/it]

Loss: 0.26664477586746216


Epoch 1:  81%|████████  | 51/63 [01:39<00:23,  1.95s/it]

Loss: 0.06102760136127472


Epoch 1:  83%|████████▎ | 52/63 [01:41<00:21,  1.94s/it]

Loss: 0.1393040120601654


Epoch 1:  84%|████████▍ | 53/63 [01:43<00:19,  1.95s/it]

Loss: 0.002859368920326233


Epoch 1:  86%|████████▌ | 54/63 [01:44<00:17,  1.94s/it]

Loss: 0.8854067921638489


Epoch 1:  87%|████████▋ | 55/63 [01:46<00:15,  1.94s/it]

Loss: 4.392435585032217e-05


Epoch 1:  89%|████████▉ | 56/63 [01:48<00:13,  1.94s/it]

Loss: 0.06912898272275925


Epoch 1:  90%|█████████ | 57/63 [01:50<00:11,  1.94s/it]

Loss: 0.8465149402618408


Epoch 1:  92%|█████████▏| 58/63 [01:52<00:09,  1.94s/it]

Loss: 0.03875114396214485


Epoch 1:  94%|█████████▎| 59/63 [01:54<00:07,  1.94s/it]

Loss: 1.0058251120881323e-07


Epoch 1:  95%|█████████▌| 60/63 [01:56<00:05,  1.94s/it]

Loss: 3.7252860352054995e-08


Epoch 1:  97%|█████████▋| 61/63 [01:58<00:03,  1.94s/it]

Loss: 2.248051032438525e-06


Epoch 1:  98%|█████████▊| 62/63 [02:00<00:01,  1.94s/it]

Loss: 0.0


Epoch 1: 100%|██████████| 63/63 [02:01<00:00,  1.93s/it]

Loss: 8.514947857918287e-09





Epoch 1 Validation Accuracy: 0.9901477832512315, F1-macro: 0.4975247524752475


Epoch 2:   2%|▏         | 1/63 [00:01<02:00,  1.94s/it]

Loss: 0.0


Epoch 2:   3%|▎         | 2/63 [00:03<01:58,  1.94s/it]

Loss: 0.0014612249797210097


Epoch 2:   5%|▍         | 3/63 [00:05<01:56,  1.94s/it]

Loss: 1.062349557876587


Epoch 2:   6%|▋         | 4/63 [00:07<01:54,  1.94s/it]

Loss: 0.19811956584453583


Epoch 2:   8%|▊         | 5/63 [00:09<01:52,  1.94s/it]

Loss: 0.0


Epoch 2:  10%|▉         | 6/63 [00:11<01:50,  1.94s/it]

Loss: 0.0


Epoch 2:  11%|█         | 7/63 [00:13<01:48,  1.95s/it]

Loss: 0.8569377660751343


Epoch 2:  13%|█▎        | 8/63 [00:15<01:47,  1.95s/it]

Loss: 0.4434083104133606


Epoch 2:  14%|█▍        | 9/63 [00:17<01:45,  1.95s/it]

Loss: 0.3659966289997101


Epoch 2:  16%|█▌        | 10/63 [00:19<01:43,  1.95s/it]

Loss: 2.7620649234449957e-06


Epoch 2:  17%|█▋        | 11/63 [00:21<01:41,  1.95s/it]

Loss: 0.20610332489013672


Epoch 2:  19%|█▉        | 12/63 [00:23<01:39,  1.95s/it]

Loss: 4.3107989768031985e-05


Epoch 2:  21%|██        | 13/63 [00:25<01:37,  1.94s/it]

Loss: 0.3063558340072632


Epoch 2:  22%|██▏       | 14/63 [00:27<01:35,  1.94s/it]

Loss: 8.661197625770001e-07


Epoch 2:  24%|██▍       | 15/63 [00:29<01:33,  1.94s/it]

Loss: 6.295123966992833e-06


Epoch 2:  25%|██▌       | 16/63 [00:31<01:31,  1.94s/it]

Loss: 0.040243156254291534


Epoch 2:  27%|██▋       | 17/63 [00:33<01:29,  1.94s/it]

Loss: 0.4071715176105499


Epoch 2:  29%|██▊       | 18/63 [00:34<01:27,  1.94s/it]

Loss: 0.03198960795998573


Epoch 2:  30%|███       | 19/63 [00:36<01:25,  1.94s/it]

Loss: 9.015057003125548e-07


Epoch 2:  32%|███▏      | 20/63 [00:38<01:23,  1.94s/it]

Loss: 0.08188842236995697


Epoch 2:  33%|███▎      | 21/63 [00:40<01:21,  1.94s/it]

Loss: 0.1087769865989685


Epoch 2:  35%|███▍      | 22/63 [00:42<01:19,  1.94s/it]

Loss: 0.0004109180299565196


Epoch 2:  37%|███▋      | 23/63 [00:44<01:17,  1.94s/it]

Loss: 0.0001313836983172223


Epoch 2:  38%|███▊      | 24/63 [00:46<01:15,  1.94s/it]

Loss: 0.006130803842097521


Epoch 2:  40%|███▉      | 25/63 [00:48<01:13,  1.94s/it]

Loss: 0.0


Epoch 2:  41%|████▏     | 26/63 [00:50<01:11,  1.94s/it]

Loss: 4.880054120803834e-07


Epoch 2:  43%|████▎     | 27/63 [00:52<01:10,  1.94s/it]

Loss: 3.073334369219083e-07


Epoch 2:  44%|████▍     | 28/63 [00:54<01:08,  1.94s/it]

Loss: 0.7624544501304626


Epoch 2:  46%|████▌     | 29/63 [00:56<01:06,  1.94s/it]

Loss: 1.2340401411056519


Epoch 2:  48%|████▊     | 30/63 [00:58<01:04,  1.94s/it]

Loss: 0.19949497282505035


Epoch 2:  49%|████▉     | 31/63 [01:00<01:02,  1.94s/it]

Loss: 5.458459327201126e-06


Epoch 2:  51%|█████     | 32/63 [01:02<01:00,  1.94s/it]

Loss: 2.5614586775191128e-05


Epoch 2:  52%|█████▏    | 33/63 [01:04<00:58,  1.94s/it]

Loss: 0.4777096211910248


Epoch 2:  54%|█████▍    | 34/63 [01:06<00:56,  1.94s/it]

Loss: 0.16339486837387085


Epoch 2:  56%|█████▌    | 35/63 [01:08<00:54,  1.94s/it]

Loss: 0.5836958885192871


Epoch 2:  57%|█████▋    | 36/63 [01:10<00:52,  1.95s/it]

Loss: 0.00028728225152008235


Epoch 2:  59%|█████▊    | 37/63 [01:11<00:50,  1.95s/it]

Loss: 0.396247535943985


Epoch 2:  60%|██████    | 38/63 [01:13<00:48,  1.95s/it]

Loss: 0.0003907523932866752


Epoch 2:  62%|██████▏   | 39/63 [01:15<00:46,  1.95s/it]

Loss: 0.023643488064408302


Epoch 2:  63%|██████▎   | 40/63 [01:17<00:44,  1.95s/it]

Loss: 0.21503789722919464


Epoch 2:  65%|██████▌   | 41/63 [01:19<00:42,  1.95s/it]

Loss: 0.6611822843551636


Epoch 2:  67%|██████▋   | 42/63 [01:21<00:40,  1.95s/it]

Loss: 0.002382280072197318


Epoch 2:  68%|██████▊   | 43/63 [01:23<00:38,  1.95s/it]

Loss: 1.3511616998584941e-05


Epoch 2:  70%|██████▉   | 44/63 [01:25<00:36,  1.95s/it]

Loss: 0.003923224285244942


Epoch 2:  71%|███████▏  | 45/63 [01:27<00:35,  1.95s/it]

Loss: 0.09276648610830307


Epoch 2:  73%|███████▎  | 46/63 [01:29<00:33,  1.95s/it]

Loss: 0.0005158514832146466


Epoch 2:  75%|███████▍  | 47/63 [01:31<00:31,  1.94s/it]

Loss: 1.1734624649761827e-07


Epoch 2:  76%|███████▌  | 48/63 [01:33<00:29,  1.95s/it]

Loss: 2.048908953611317e-08


Epoch 2:  78%|███████▊  | 49/63 [01:35<00:27,  1.95s/it]

Loss: 0.6097781658172607


Epoch 2:  79%|███████▉  | 50/63 [01:37<00:25,  1.94s/it]

Loss: 4.898681709164521e-07


Epoch 2:  81%|████████  | 51/63 [01:39<00:23,  1.94s/it]

Loss: 0.02916601113975048


Epoch 2:  83%|████████▎ | 52/63 [01:41<00:21,  1.94s/it]

Loss: 0.40838301181793213


Epoch 2:  84%|████████▍ | 53/63 [01:43<00:19,  1.95s/it]

Loss: 0.059058837592601776


Epoch 2:  86%|████████▌ | 54/63 [01:45<00:17,  1.94s/it]

Loss: 0.1298559606075287


Epoch 2:  87%|████████▋ | 55/63 [01:46<00:15,  1.94s/it]

Loss: 0.2079148143529892


Epoch 2:  89%|████████▉ | 56/63 [01:48<00:13,  1.94s/it]

Loss: 0.4033694565296173


Epoch 2:  90%|█████████ | 57/63 [01:50<00:11,  1.95s/it]

Loss: 0.007315017748624086


Epoch 2:  92%|█████████▏| 58/63 [01:52<00:09,  1.95s/it]

Loss: 0.22855190932750702


Epoch 2:  94%|█████████▎| 59/63 [01:54<00:07,  1.94s/it]

Loss: 1.009523316497507e-06


Epoch 2:  95%|█████████▌| 60/63 [01:56<00:05,  1.94s/it]

Loss: 0.28359565138816833


Epoch 2:  97%|█████████▋| 61/63 [01:58<00:03,  1.94s/it]

Loss: 0.36501312255859375


Epoch 2:  98%|█████████▊| 62/63 [02:00<00:01,  1.94s/it]

Loss: 0.000924420019146055


Epoch 2: 100%|██████████| 63/63 [02:01<00:00,  1.93s/it]

Loss: 6.166868843138218e-06





Epoch 2 Validation Accuracy: 0.9901477832512315, F1-macro: 0.7475124378109452


Epoch 3:   2%|▏         | 1/63 [00:01<01:59,  1.93s/it]

Loss: 6.51924452199637e-08


Epoch 3:   3%|▎         | 2/63 [00:03<01:58,  1.94s/it]

Loss: 0.043341465294361115


Epoch 3:   5%|▍         | 3/63 [00:05<01:56,  1.94s/it]

Loss: 0.045802272856235504


Epoch 3:   6%|▋         | 4/63 [00:07<01:54,  1.94s/it]

Loss: 1.3038476254223497e-07


Epoch 3:   8%|▊         | 5/63 [00:09<01:52,  1.94s/it]

Loss: 4.99180941915256e-07


Epoch 3:  10%|▉         | 6/63 [00:11<01:50,  1.94s/it]

Loss: 2.4231137558672344e-06


Epoch 3:  11%|█         | 7/63 [00:13<01:48,  1.94s/it]

Loss: 0.010182376019656658


Epoch 3:  13%|█▎        | 8/63 [00:15<01:47,  1.95s/it]

Loss: 2.235172580355993e-08


Epoch 3:  14%|█▍        | 9/63 [00:17<01:45,  1.95s/it]

Loss: 0.6199249029159546


Epoch 3:  16%|█▌        | 10/63 [00:19<01:43,  1.95s/it]

Loss: 3.6321216612122953e-07


Epoch 3:  17%|█▋        | 11/63 [00:21<01:41,  1.95s/it]

Loss: 0.018349135294556618


Epoch 3:  19%|█▉        | 12/63 [00:23<01:39,  1.95s/it]

Loss: 0.08203952759504318


Epoch 3:  21%|██        | 13/63 [00:25<01:37,  1.95s/it]

Loss: 0.00286033283919096


Epoch 3:  22%|██▏       | 14/63 [00:27<01:35,  1.95s/it]

Loss: 0.00039356385241262615


Epoch 3:  24%|██▍       | 15/63 [00:29<01:33,  1.95s/it]

Loss: 6.127999654381711e-07


Epoch 3:  25%|██▌       | 16/63 [00:31<01:31,  1.94s/it]

Loss: 0.06068965420126915


Epoch 3:  27%|██▋       | 17/63 [00:33<01:29,  1.94s/it]

Loss: 0.00021470132924150676


Epoch 3:  29%|██▊       | 18/63 [00:35<01:27,  1.94s/it]

Loss: 0.04740242660045624


Epoch 3:  30%|███       | 19/63 [00:36<01:25,  1.94s/it]

Loss: 0.0009272423340007663


Epoch 3:  32%|███▏      | 20/63 [00:38<01:23,  1.94s/it]

Loss: 0.04331953823566437


Epoch 3:  33%|███▎      | 21/63 [00:40<01:21,  1.94s/it]

Loss: 7.703119626967236e-05


Epoch 3:  35%|███▍      | 22/63 [00:42<01:19,  1.94s/it]

Loss: 0.0001890139828901738


Epoch 3:  37%|███▋      | 23/63 [00:44<01:17,  1.94s/it]

Loss: 1.303851338008144e-08


Epoch 3:  38%|███▊      | 24/63 [00:46<01:15,  1.94s/it]

Loss: 0.24456565082073212


Epoch 3:  40%|███▉      | 25/63 [00:48<01:13,  1.94s/it]

Loss: 1.2386350363158272e-06


Epoch 3:  41%|████▏     | 26/63 [00:50<01:11,  1.94s/it]

Loss: 1.8626450382086546e-09


Epoch 3:  43%|████▎     | 27/63 [00:52<01:10,  1.94s/it]

Loss: 0.05846654251217842


Epoch 3:  44%|████▍     | 28/63 [00:54<01:08,  1.94s/it]

Loss: 0.23033228516578674


Epoch 3:  46%|████▌     | 29/63 [00:56<01:06,  1.94s/it]

Loss: 0.00166120077483356


Epoch 3:  48%|████▊     | 30/63 [00:58<01:04,  1.95s/it]

Loss: 0.15596112608909607


Epoch 3:  49%|████▉     | 31/63 [01:00<01:02,  1.94s/it]

Loss: 0.2046751081943512


Epoch 3:  51%|█████     | 32/63 [01:02<01:00,  1.94s/it]

Loss: 0.3822169303894043


Epoch 3:  52%|█████▏    | 33/63 [01:04<00:58,  1.94s/it]

Loss: 0.23341922461986542


Epoch 3:  54%|█████▍    | 34/63 [01:06<00:56,  1.94s/it]

Loss: 0.08277405798435211


Epoch 3:  56%|█████▌    | 35/63 [01:08<00:54,  1.94s/it]

Loss: 1.598099152033683e-05


Epoch 3:  57%|█████▋    | 36/63 [01:09<00:52,  1.94s/it]

Loss: 0.457139790058136


Epoch 3:  59%|█████▊    | 37/63 [01:11<00:50,  1.94s/it]

Loss: 0.010387225076556206


Epoch 3:  60%|██████    | 38/63 [01:13<00:48,  1.94s/it]

Loss: 0.0017357994802296162


Epoch 3:  62%|██████▏   | 39/63 [01:15<00:46,  1.94s/it]

Loss: 0.004292598459869623


Epoch 3:  63%|██████▎   | 40/63 [01:17<00:44,  1.94s/it]

Loss: 0.20336005091667175


Epoch 3:  65%|██████▌   | 41/63 [01:19<00:42,  1.94s/it]

Loss: 0.20504295825958252


Epoch 3:  67%|██████▋   | 42/63 [01:21<00:40,  1.94s/it]

Loss: 0.4200913608074188


Epoch 3:  68%|██████▊   | 43/63 [01:23<00:38,  1.94s/it]

Loss: 0.12835583090782166


Epoch 3:  70%|██████▉   | 44/63 [01:25<00:36,  1.95s/it]

Loss: 8.69843802320247e-07


Epoch 3:  71%|███████▏  | 45/63 [01:27<00:35,  1.95s/it]

Loss: 0.37039443850517273


Epoch 3:  73%|███████▎  | 46/63 [01:29<00:33,  1.94s/it]

Loss: 0.23504090309143066


Epoch 3:  75%|███████▍  | 47/63 [01:31<00:31,  1.94s/it]

Loss: 0.0


Epoch 3:  76%|███████▌  | 48/63 [01:33<00:29,  1.95s/it]

Loss: 0.00014100311091169715


Epoch 3:  78%|███████▊  | 49/63 [01:35<00:27,  1.95s/it]

Loss: 0.08944400399923325


Epoch 3:  79%|███████▉  | 50/63 [01:37<00:25,  1.94s/it]

Loss: 6.959497568459483e-06


Epoch 3:  81%|████████  | 51/63 [01:39<00:23,  1.94s/it]

Loss: 2.2351736461700966e-08


Epoch 3:  83%|████████▎ | 52/63 [01:41<00:21,  1.94s/it]

Loss: 0.36840009689331055


Epoch 3:  84%|████████▍ | 53/63 [01:43<00:19,  1.94s/it]

Loss: 0.00020743718778248876


Epoch 3:  86%|████████▌ | 54/63 [01:44<00:17,  1.94s/it]

Loss: 0.0708681046962738


Epoch 3:  87%|████████▋ | 55/63 [01:46<00:15,  1.94s/it]

Loss: 5.398556822910905e-05


Epoch 3:  89%|████████▉ | 56/63 [01:48<00:13,  1.94s/it]

Loss: 0.3173408508300781


Epoch 3:  90%|█████████ | 57/63 [01:50<00:11,  1.94s/it]

Loss: 0.05946078896522522


Epoch 3:  92%|█████████▏| 58/63 [01:52<00:09,  1.94s/it]

Loss: 0.0008507512975484133


Epoch 3:  94%|█████████▎| 59/63 [01:54<00:07,  1.94s/it]

Loss: 0.3338732123374939


Epoch 3:  95%|█████████▌| 60/63 [01:56<00:05,  1.94s/it]

Loss: 0.006155117880553007


Epoch 3:  97%|█████████▋| 61/63 [01:58<00:03,  1.94s/it]

Loss: 0.3341110050678253


Epoch 3:  98%|█████████▊| 62/63 [02:00<00:01,  1.94s/it]

Loss: 0.0017428480787202716


Epoch 3: 100%|██████████| 63/63 [02:01<00:00,  1.93s/it]

Loss: 0.00036370393354445696





Epoch 3 Validation Accuracy: 0.9901477832512315, F1-macro: 0.7475124378109452


Epoch 4:   2%|▏         | 1/63 [00:01<01:59,  1.92s/it]

Loss: 0.1328863948583603


Epoch 4:   3%|▎         | 2/63 [00:03<01:57,  1.93s/it]

Loss: 0.02263312041759491


Epoch 4:   5%|▍         | 3/63 [00:05<01:56,  1.94s/it]

Loss: 3.2968563346003066e-07


Epoch 4:   6%|▋         | 4/63 [00:07<01:54,  1.94s/it]

Loss: 0.00016220616817008704


Epoch 4:   8%|▊         | 5/63 [00:09<01:52,  1.94s/it]

Loss: 0.0008484727004542947


Epoch 4:  10%|▉         | 6/63 [00:11<01:50,  1.94s/it]

Loss: 2.402803147560917e-07


Epoch 4:  11%|█         | 7/63 [00:13<01:48,  1.94s/it]

Loss: 0.0009267063578590751


Epoch 4:  13%|█▎        | 8/63 [00:15<01:46,  1.94s/it]

Loss: 1.760353916324675e-05


Epoch 4:  14%|█▍        | 9/63 [00:17<01:44,  1.94s/it]

Loss: 1.0462817044754047e-05


Epoch 4:  16%|█▌        | 10/63 [00:19<01:42,  1.94s/it]

Loss: 0.23779897391796112


Epoch 4:  17%|█▋        | 11/63 [00:21<01:40,  1.94s/it]

Loss: 0.3682723045349121


Epoch 4:  19%|█▉        | 12/63 [00:23<01:38,  1.94s/it]

Loss: 0.00022394060215447098


Epoch 4:  21%|██        | 13/63 [00:25<01:37,  1.94s/it]

Loss: 0.002382562030106783


Epoch 4:  22%|██▏       | 14/63 [00:27<01:35,  1.94s/it]

Loss: 0.1875530481338501


Epoch 4:  24%|██▍       | 15/63 [00:29<01:33,  1.94s/it]

Loss: 0.00033171079121530056


Epoch 4:  25%|██▌       | 16/63 [00:31<01:31,  1.94s/it]

Loss: 0.01431120466440916


Epoch 4:  27%|██▋       | 17/63 [00:32<01:29,  1.94s/it]

Loss: 0.001230392255820334


Epoch 4:  29%|██▊       | 18/63 [00:34<01:27,  1.94s/it]

Loss: 0.06478267163038254


Epoch 4:  30%|███       | 19/63 [00:36<01:25,  1.94s/it]

Loss: 0.0490952730178833


Epoch 4:  32%|███▏      | 20/63 [00:38<01:23,  1.94s/it]

Loss: 0.018571874126791954


Epoch 4:  33%|███▎      | 21/63 [00:40<01:21,  1.94s/it]

Loss: 0.2202025055885315


Epoch 4:  35%|███▍      | 22/63 [00:42<01:19,  1.94s/it]

Loss: 0.00013983114331495017


Epoch 4:  37%|███▋      | 23/63 [00:44<01:17,  1.94s/it]

Loss: 0.02774432860314846


Epoch 4:  38%|███▊      | 24/63 [00:46<01:15,  1.94s/it]

Loss: 0.004963572137057781


Epoch 4:  40%|███▉      | 25/63 [00:48<01:13,  1.94s/it]

Loss: 0.28393375873565674


Epoch 4:  41%|████▏     | 26/63 [00:50<01:11,  1.94s/it]

Loss: 0.1762993335723877


Epoch 4:  43%|████▎     | 27/63 [00:52<01:09,  1.94s/it]

Loss: 1.258744941878831e-05


Epoch 4:  44%|████▍     | 28/63 [00:54<01:08,  1.94s/it]

Loss: 0.003095767227932811


Epoch 4:  46%|████▌     | 29/63 [00:56<01:06,  1.94s/it]

Loss: 0.10728926956653595


Epoch 4:  48%|████▊     | 30/63 [00:58<01:04,  1.94s/it]

Loss: 0.16426336765289307


Epoch 4:  49%|████▉     | 31/63 [01:00<01:02,  1.94s/it]

Loss: 0.0714806616306305


Epoch 4:  51%|█████     | 32/63 [01:02<01:00,  1.94s/it]

Loss: 0.006356927566230297


Epoch 4:  52%|█████▏    | 33/63 [01:04<00:58,  1.94s/it]

Loss: 0.32964009046554565


Epoch 4:  54%|█████▍    | 34/63 [01:05<00:56,  1.94s/it]

Loss: 5.178118271942367e-07


Epoch 4:  56%|█████▌    | 35/63 [01:07<00:54,  1.94s/it]

Loss: 4.8793690439197235e-06


Epoch 4:  57%|█████▋    | 36/63 [01:09<00:52,  1.94s/it]

Loss: 0.03392281383275986


Epoch 4:  59%|█████▊    | 37/63 [01:11<00:50,  1.94s/it]

Loss: 0.05103315785527229


Epoch 4:  60%|██████    | 38/63 [01:13<00:48,  1.94s/it]

Loss: 0.8295183181762695


Epoch 4:  62%|██████▏   | 39/63 [01:15<00:46,  1.94s/it]

Loss: 0.6633437275886536


Epoch 4:  63%|██████▎   | 40/63 [01:17<00:44,  1.94s/it]

Loss: 0.6224340200424194


Epoch 4:  65%|██████▌   | 41/63 [01:19<00:42,  1.94s/it]

Loss: 0.0


Epoch 4:  67%|██████▋   | 42/63 [01:21<00:40,  1.94s/it]

Loss: 5.3038114856462926e-05


Epoch 4:  68%|██████▊   | 43/63 [01:23<00:38,  1.94s/it]

Loss: 7.823093284287097e-08


Epoch 4:  70%|██████▉   | 44/63 [01:25<00:36,  1.94s/it]

Loss: 7.320025474655267e-07


Epoch 4:  71%|███████▏  | 45/63 [01:27<00:34,  1.94s/it]

Loss: 0.0


Epoch 4:  73%|███████▎  | 46/63 [01:29<00:32,  1.94s/it]

Loss: 1.0840219601959689e-06


Epoch 4:  75%|███████▍  | 47/63 [01:31<00:31,  1.94s/it]

Loss: 0.0


Epoch 4:  76%|███████▌  | 48/63 [01:33<00:29,  1.94s/it]

Loss: 4.172272838331992e-07


Epoch 4:  78%|███████▊  | 49/63 [01:35<00:27,  1.94s/it]

Loss: 4.470284977742267e-07


Epoch 4:  79%|███████▉  | 50/63 [01:37<00:25,  1.94s/it]

Loss: 9.951288666343316e-05


Epoch 4:  81%|████████  | 51/63 [01:38<00:23,  1.94s/it]

Loss: 0.12619104981422424


Epoch 4:  83%|████████▎ | 52/63 [01:40<00:21,  1.94s/it]

Loss: 1.303851249190302e-08


Epoch 4:  84%|████████▍ | 53/63 [01:42<00:19,  1.94s/it]

Loss: 0.14294663071632385


Epoch 4:  86%|████████▌ | 54/63 [01:44<00:17,  1.94s/it]

Loss: 0.013520574197173119


Epoch 4:  87%|████████▋ | 55/63 [01:46<00:15,  1.94s/it]

Loss: 0.018006615340709686


Epoch 4:  89%|████████▉ | 56/63 [01:48<00:13,  1.94s/it]

Loss: 0.0019553641323000193


Epoch 4:  90%|█████████ | 57/63 [01:50<00:11,  1.94s/it]

Loss: 0.20499825477600098


Epoch 4:  92%|█████████▏| 58/63 [01:52<00:09,  1.94s/it]

Loss: 0.39974749088287354


Epoch 4:  94%|█████████▎| 59/63 [01:54<00:07,  1.94s/it]

Loss: 0.3049946129322052


Epoch 4:  95%|█████████▌| 60/63 [01:56<00:05,  1.94s/it]

Loss: 0.0002429006708553061


Epoch 4:  97%|█████████▋| 61/63 [01:58<00:03,  1.94s/it]

Loss: 0.14985422790050507


Epoch 4:  98%|█████████▊| 62/63 [02:00<00:01,  1.94s/it]

Loss: 0.010752800852060318


Epoch 4: 100%|██████████| 63/63 [02:01<00:00,  1.93s/it]

Loss: 0.39446398615837097





Epoch 4 Validation Accuracy: 0.9950738916256158, F1-macro: 0.8987531172069826


Epoch 5:   2%|▏         | 1/63 [00:01<01:59,  1.92s/it]

Loss: 3.5390250729960826e-08


Epoch 5:   3%|▎         | 2/63 [00:03<01:58,  1.94s/it]

Loss: 4.0605189610687376e-07


Epoch 5:   5%|▍         | 3/63 [00:05<01:56,  1.94s/it]

Loss: 0.056942421942949295


Epoch 5:   6%|▋         | 4/63 [00:07<01:54,  1.94s/it]

Loss: 0.1580503135919571


Epoch 5:   8%|▊         | 5/63 [00:09<01:52,  1.94s/it]

Loss: 0.0


Epoch 5:  10%|▉         | 6/63 [00:11<01:50,  1.94s/it]

Loss: 0.34863364696502686


Epoch 5:  11%|█         | 7/63 [00:13<01:48,  1.94s/it]

Loss: 0.3143450915813446


Epoch 5:  13%|█▎        | 8/63 [00:15<01:46,  1.94s/it]

Loss: 6.0639929870376363e-05


Epoch 5:  14%|█▍        | 9/63 [00:17<01:44,  1.94s/it]

Loss: 2.5145649829028116e-07


Epoch 5:  16%|█▌        | 10/63 [00:19<01:42,  1.94s/it]

Loss: 0.0017314444994553924


Epoch 5:  17%|█▋        | 11/63 [00:21<01:40,  1.94s/it]

Loss: 2.6751160476123914e-05


Epoch 5:  19%|█▉        | 12/63 [00:23<01:38,  1.94s/it]

Loss: 0.009744218550622463


Epoch 5:  21%|██        | 13/63 [00:25<01:37,  1.94s/it]

Loss: 0.0002496626984793693


Epoch 5:  22%|██▏       | 14/63 [00:27<01:35,  1.94s/it]

Loss: 0.059299420565366745


Epoch 5:  24%|██▍       | 15/63 [00:29<01:33,  1.94s/it]

Loss: 5.3846448281547055e-05


Epoch 5:  25%|██▌       | 16/63 [00:31<01:31,  1.94s/it]

Loss: 0.0002464434946887195


Epoch 5:  27%|██▋       | 17/63 [00:32<01:29,  1.94s/it]

Loss: 0.3079722821712494


Epoch 5:  29%|██▊       | 18/63 [00:34<01:27,  1.94s/it]

Loss: 0.2396431863307953


Epoch 5:  30%|███       | 19/63 [00:36<01:25,  1.94s/it]

Loss: 0.00021310755982995033


Epoch 5:  32%|███▏      | 20/63 [00:38<01:23,  1.94s/it]

Loss: 3.7252898543727042e-09


Epoch 5:  33%|███▎      | 21/63 [00:40<01:21,  1.94s/it]

Loss: 1.0788440704345703e-05


Epoch 5:  35%|███▍      | 22/63 [00:42<01:19,  1.94s/it]

Loss: 0.0


Epoch 5:  37%|███▋      | 23/63 [00:44<01:17,  1.94s/it]

Loss: 0.10254450887441635


Epoch 5:  38%|███▊      | 24/63 [00:46<01:15,  1.94s/it]

Loss: 0.035844892263412476


Epoch 5:  40%|███▉      | 25/63 [00:48<01:13,  1.94s/it]

Loss: 0.45787614583969116


Epoch 5:  41%|████▏     | 26/63 [00:50<01:11,  1.94s/it]

Loss: 0.023085013031959534


Epoch 5:  43%|████▎     | 27/63 [00:52<01:09,  1.94s/it]

Loss: 0.5397155284881592


Epoch 5:  44%|████▍     | 28/63 [00:54<01:07,  1.94s/it]

Loss: 0.0


Epoch 5:  46%|████▌     | 29/63 [00:56<01:06,  1.94s/it]

Loss: 0.41287118196487427


Epoch 5:  48%|████▊     | 30/63 [00:58<01:04,  1.94s/it]

Loss: 0.32367420196533203


Epoch 5:  49%|████▉     | 31/63 [01:00<01:02,  1.94s/it]

Loss: 0.5209528207778931


Epoch 5:  51%|█████     | 32/63 [01:02<01:00,  1.94s/it]

Loss: 0.7388336658477783


Epoch 5:  52%|█████▏    | 33/63 [01:04<00:58,  1.94s/it]

Loss: 0.2802000641822815


Epoch 5:  54%|█████▍    | 34/63 [01:05<00:56,  1.94s/it]

Loss: 0.013011216185986996


Epoch 5:  56%|█████▌    | 35/63 [01:07<00:54,  1.94s/it]

Loss: 1.2945616617798805e-05


Epoch 5:  57%|█████▋    | 36/63 [01:09<00:52,  1.94s/it]

Loss: 0.22720275819301605


Epoch 5:  59%|█████▊    | 37/63 [01:11<00:50,  1.94s/it]

Loss: 0.0005582963931374252


Epoch 5:  60%|██████    | 38/63 [01:13<00:48,  1.94s/it]

Loss: 7.450578820566989e-09


Epoch 5:  62%|██████▏   | 39/63 [01:15<00:46,  1.94s/it]

Loss: 0.5981507897377014


Epoch 5:  63%|██████▎   | 40/63 [01:17<00:44,  1.94s/it]

Loss: 0.0


Epoch 5:  65%|██████▌   | 41/63 [01:19<00:42,  1.94s/it]

Loss: 0.0


Epoch 5:  67%|██████▋   | 42/63 [01:21<00:40,  1.94s/it]

Loss: 1.1175869119028903e-08


Epoch 5:  68%|██████▊   | 43/63 [01:23<00:38,  1.94s/it]

Loss: 1.8626450382086546e-09


Epoch 5:  70%|██████▉   | 44/63 [01:25<00:36,  1.94s/it]

Loss: 0.3398001492023468


Epoch 5:  71%|███████▏  | 45/63 [01:27<00:34,  1.94s/it]

Loss: 0.3347412347793579


Epoch 5:  73%|███████▎  | 46/63 [01:29<00:32,  1.94s/it]

Loss: 1.4528566794069775e-07


Epoch 5:  75%|███████▍  | 47/63 [01:31<00:31,  1.94s/it]

Loss: 0.09249922633171082


Epoch 5:  76%|███████▌  | 48/63 [01:33<00:29,  1.94s/it]

Loss: 0.30815792083740234


Epoch 5:  78%|███████▊  | 49/63 [01:35<00:27,  1.94s/it]

Loss: 0.03105083480477333


Epoch 5:  79%|███████▉  | 50/63 [01:37<00:25,  1.94s/it]

Loss: 0.013032570481300354


Epoch 5:  81%|████████  | 51/63 [01:38<00:23,  1.94s/it]

Loss: 0.001893134438432753


Epoch 5:  83%|████████▎ | 52/63 [01:40<00:21,  1.94s/it]

Loss: 0.2649923264980316


Epoch 5:  84%|████████▍ | 53/63 [01:42<00:19,  1.94s/it]

Loss: 1.8626450382086546e-09


Epoch 5:  86%|████████▌ | 54/63 [01:44<00:17,  1.94s/it]

Loss: 1.8933487808681093e-05


Epoch 5:  87%|████████▋ | 55/63 [01:46<00:15,  1.94s/it]

Loss: 0.0


Epoch 5:  89%|████████▉ | 56/63 [01:48<00:13,  1.94s/it]

Loss: 0.511905312538147


Epoch 5:  90%|█████████ | 57/63 [01:50<00:11,  1.94s/it]

Loss: 0.12028072029352188


Epoch 5:  92%|█████████▏| 58/63 [01:52<00:09,  1.94s/it]

Loss: 0.3893241286277771


Epoch 5:  94%|█████████▎| 59/63 [01:54<00:07,  1.94s/it]

Loss: 2.0515510186669417e-05


Epoch 5:  95%|█████████▌| 60/63 [01:56<00:05,  1.94s/it]

Loss: 0.0005685288924723864


Epoch 5:  97%|█████████▋| 61/63 [01:58<00:03,  1.94s/it]

Loss: 0.07934524118900299


Epoch 5:  98%|█████████▊| 62/63 [02:00<00:01,  1.94s/it]

Loss: 0.1342536211013794


Epoch 5: 100%|██████████| 63/63 [02:01<00:00,  1.93s/it]

Loss: 0.20039211213588715





Epoch 5 Validation Accuracy: 0.9802955665024631, F1-macro: 0.6616666666666666


Epoch 6:   2%|▏         | 1/63 [00:01<01:59,  1.93s/it]

Loss: 0.003769536269828677


Epoch 6:   3%|▎         | 2/63 [00:03<01:58,  1.94s/it]

Loss: 6.606241367990151e-05


Epoch 6:   5%|▍         | 3/63 [00:05<01:56,  1.94s/it]

Loss: 5.587935003603661e-09


Epoch 6:   6%|▋         | 4/63 [00:07<01:54,  1.94s/it]

Loss: 0.08978375792503357


Epoch 6:   8%|▊         | 5/63 [00:09<01:52,  1.94s/it]

Loss: 0.313484787940979


Epoch 6:  10%|▉         | 6/63 [00:11<01:50,  1.94s/it]

Loss: 1.9371400128420646e-07


Epoch 6:  11%|█         | 7/63 [00:13<01:48,  1.94s/it]

Loss: 3.6622855986934155e-05


Epoch 6:  13%|█▎        | 8/63 [00:15<01:46,  1.94s/it]

Loss: 1.3446257071336731e-05


Epoch 6:  14%|█▍        | 9/63 [00:17<01:44,  1.94s/it]

Loss: 0.3494965136051178


Epoch 6:  16%|█▌        | 10/63 [00:19<01:42,  1.94s/it]

Loss: 0.000802934227976948


Epoch 6:  17%|█▋        | 11/63 [00:21<01:40,  1.94s/it]

Loss: 0.09255653619766235


Epoch 6:  19%|█▉        | 12/63 [00:23<01:38,  1.94s/it]

Loss: 4.7881280806905124e-06


Epoch 6:  21%|██        | 13/63 [00:25<01:37,  1.94s/it]

Loss: 1.8626450382086546e-09


Epoch 6:  22%|██▏       | 14/63 [00:27<01:35,  1.94s/it]

Loss: 3.539021875553772e-08


Epoch 6:  24%|██▍       | 15/63 [00:29<01:33,  1.94s/it]

Loss: 6.912615390319843e-06


Epoch 6:  25%|██▌       | 16/63 [00:31<01:31,  1.94s/it]

Loss: 3.725290076417309e-09


Epoch 6:  27%|██▋       | 17/63 [00:32<01:29,  1.94s/it]

Loss: 0.0


Epoch 6:  29%|██▊       | 18/63 [00:34<01:27,  1.94s/it]

Loss: 0.29276058077812195


Epoch 6:  30%|███       | 19/63 [00:36<01:25,  1.94s/it]

Loss: 2.2351732908987287e-08


Epoch 6:  32%|███▏      | 20/63 [00:38<01:23,  1.94s/it]

Loss: 0.00018862372962757945


Epoch 6:  33%|███▎      | 21/63 [00:40<01:21,  1.94s/it]

Loss: 0.0006499801529571414


Epoch 6:  35%|███▍      | 22/63 [00:42<01:19,  1.94s/it]

Loss: 0.04868881031870842


Epoch 6:  37%|███▋      | 23/63 [00:44<01:17,  1.94s/it]

Loss: 8.754413727274368e-08


Epoch 6:  38%|███▊      | 24/63 [00:46<01:15,  1.94s/it]

Loss: 5.52176925339154e-06


Epoch 6:  40%|███▉      | 25/63 [00:48<01:13,  1.94s/it]

Loss: 5.587934559514451e-09


Epoch 6:  41%|████▏     | 26/63 [00:50<01:11,  1.94s/it]

Loss: 0.01970903016626835


Epoch 6:  43%|████▎     | 27/63 [00:52<01:09,  1.94s/it]

Loss: 6.33298071761601e-08


Epoch 6:  44%|████▍     | 28/63 [00:54<01:08,  1.94s/it]

Loss: 8.828695854390389e-07


Epoch 6:  46%|████▌     | 29/63 [00:56<01:06,  1.94s/it]

Loss: 0.0


Epoch 6:  48%|████▊     | 30/63 [00:58<01:04,  1.94s/it]

Loss: 0.48925673961639404


Epoch 6:  49%|████▉     | 31/63 [01:00<01:02,  1.94s/it]

Loss: 1.4714832730078342e-07


Epoch 6:  51%|█████     | 32/63 [01:02<01:00,  1.94s/it]

Loss: 0.552259087562561


Epoch 6:  52%|█████▏    | 33/63 [01:04<00:58,  1.94s/it]

Loss: 0.3490712642669678


Epoch 6:  54%|█████▍    | 34/63 [01:05<00:56,  1.94s/it]

Loss: 0.09265734255313873


Epoch 6:  56%|█████▌    | 35/63 [01:07<00:54,  1.94s/it]

Loss: 0.2837566137313843


Epoch 6:  57%|█████▋    | 36/63 [01:09<00:52,  1.94s/it]

Loss: 0.35340607166290283


Epoch 6:  59%|█████▊    | 37/63 [01:11<00:50,  1.94s/it]

Loss: 0.003298832569271326


Epoch 6:  60%|██████    | 38/63 [01:13<00:48,  1.94s/it]

Loss: 0.07315574586391449


Epoch 6:  62%|██████▏   | 39/63 [01:15<00:46,  1.94s/it]

Loss: 0.1141747236251831


Epoch 6:  63%|██████▎   | 40/63 [01:17<00:44,  1.94s/it]

Loss: 0.014501557685434818


Epoch 6:  65%|██████▌   | 41/63 [01:19<00:42,  1.94s/it]

Loss: 0.0


Epoch 6:  67%|██████▋   | 42/63 [01:21<00:40,  1.94s/it]

Loss: 0.0


Epoch 6:  68%|██████▊   | 43/63 [01:23<00:38,  1.94s/it]

Loss: 0.1314208060503006


Epoch 6:  70%|██████▉   | 44/63 [01:25<00:36,  1.94s/it]

Loss: 0.1431894302368164


Epoch 6:  71%|███████▏  | 45/63 [01:27<00:35,  1.95s/it]

Loss: 7.450578820566989e-09


Epoch 6:  73%|███████▎  | 46/63 [01:29<00:33,  1.95s/it]

Loss: 0.0


Epoch 6:  75%|███████▍  | 47/63 [01:31<00:31,  1.94s/it]

Loss: 0.00027436972595751286


Epoch 6:  76%|███████▌  | 48/63 [01:33<00:29,  1.94s/it]

Loss: 0.9670175909996033


Epoch 6:  78%|███████▊  | 49/63 [01:35<00:27,  1.94s/it]

Loss: 0.3493081331253052


Epoch 6:  79%|███████▉  | 50/63 [01:37<00:25,  1.94s/it]

Loss: 2.8402757834555814e-06


Epoch 6:  81%|████████  | 51/63 [01:39<00:23,  1.94s/it]

Loss: 6.128836503194179e-06


Epoch 6:  83%|████████▎ | 52/63 [01:40<00:21,  1.94s/it]

Loss: 6.554540595971048e-05


Epoch 6:  84%|████████▍ | 53/63 [01:42<00:19,  1.94s/it]

Loss: 0.061866164207458496


Epoch 6:  86%|████████▌ | 54/63 [01:44<00:17,  1.94s/it]

Loss: 0.0002963550796266645


Epoch 6:  87%|████████▋ | 55/63 [01:46<00:15,  1.94s/it]

Loss: 0.6538980603218079


Epoch 6:  89%|████████▉ | 56/63 [01:48<00:13,  1.94s/it]

Loss: 0.01625022105872631


Epoch 6:  90%|█████████ | 57/63 [01:50<00:11,  1.94s/it]

Loss: 0.26115548610687256


Epoch 6:  92%|█████████▏| 58/63 [01:52<00:09,  1.94s/it]

Loss: 0.11390005052089691


Epoch 6:  94%|█████████▎| 59/63 [01:54<00:07,  1.94s/it]

Loss: 0.00017126323655247688


Epoch 6:  95%|█████████▌| 60/63 [01:56<00:05,  1.94s/it]

Loss: 0.11519448459148407


Epoch 6:  97%|█████████▋| 61/63 [01:58<00:03,  1.94s/it]

Loss: 0.405337393283844


Epoch 6:  98%|█████████▊| 62/63 [02:00<00:01,  1.94s/it]

Loss: 0.0001750377705320716


Epoch 6: 100%|██████████| 63/63 [02:01<00:00,  1.93s/it]

Loss: 0.6663104891777039





Epoch 6 Validation Accuracy: 0.9802955665024631, F1-macro: 0.7449748743718593


Epoch 7:   2%|▏         | 1/63 [00:01<01:59,  1.93s/it]

Loss: 0.16480040550231934


Epoch 7:   3%|▎         | 2/63 [00:03<01:58,  1.94s/it]

Loss: 0.08362768590450287


Epoch 7:   5%|▍         | 3/63 [00:05<01:56,  1.94s/it]

Loss: 0.31595486402511597


Epoch 7:   6%|▋         | 4/63 [00:07<01:54,  1.94s/it]

Loss: 0.0010463126236572862


Epoch 7:   8%|▊         | 5/63 [00:09<01:52,  1.94s/it]

Loss: 0.10011029988527298


Epoch 7:  10%|▉         | 6/63 [00:11<01:50,  1.94s/it]

Loss: 8.269927320725401e-07


Epoch 7:  11%|█         | 7/63 [00:13<01:48,  1.94s/it]

Loss: 0.22416166961193085


Epoch 7:  13%|█▎        | 8/63 [00:15<01:46,  1.94s/it]

Loss: 0.001307699829339981


Epoch 7:  14%|█▍        | 9/63 [00:17<01:44,  1.94s/it]

Loss: 0.16050979495048523


Epoch 7:  16%|█▌        | 10/63 [00:19<01:43,  1.94s/it]

Loss: 0.006374908145517111


Epoch 7:  17%|█▋        | 11/63 [00:21<01:41,  1.94s/it]

Loss: 0.006763014942407608


Epoch 7:  19%|█▉        | 12/63 [00:23<01:39,  1.94s/it]

Loss: 0.0028798195999115705


Epoch 7:  21%|██        | 13/63 [00:25<01:37,  1.94s/it]

Loss: 0.002250553574413061


Epoch 7:  22%|██▏       | 14/63 [00:27<01:35,  1.94s/it]

Loss: 2.315301389899105e-05


Epoch 7:  24%|██▍       | 15/63 [00:29<01:33,  1.94s/it]

Loss: 7.636836585334095e-08


Epoch 7:  25%|██▌       | 16/63 [00:31<01:31,  1.94s/it]

Loss: 4.0991449168359395e-06


Epoch 7:  27%|██▋       | 17/63 [00:33<01:29,  1.94s/it]

Loss: 0.06347792595624924


Epoch 7:  29%|██▊       | 18/63 [00:34<01:27,  1.94s/it]

Loss: 0.040834806859493256


Epoch 7:  30%|███       | 19/63 [00:36<01:25,  1.94s/it]

Loss: 0.2732974886894226


Epoch 7:  32%|███▏      | 20/63 [00:38<01:23,  1.94s/it]

Loss: 1.2665935855693533e-07


Epoch 7:  33%|███▎      | 21/63 [00:40<01:21,  1.95s/it]

Loss: 1.1821463886008132e-05


Epoch 7:  35%|███▍      | 22/63 [00:42<01:19,  1.95s/it]

Loss: 2.0489082430685812e-08


Epoch 7:  37%|███▋      | 23/63 [00:44<01:17,  1.94s/it]

Loss: 0.20129041373729706


Epoch 7:  38%|███▊      | 24/63 [00:46<01:15,  1.94s/it]

Loss: 0.0


Epoch 7:  40%|███▉      | 25/63 [00:48<01:13,  1.94s/it]

Loss: 0.0


Epoch 7:  41%|████▏     | 26/63 [00:50<01:11,  1.94s/it]

Loss: 0.10482511669397354


Epoch 7:  43%|████▎     | 27/63 [00:52<01:09,  1.94s/it]

Loss: 1.9165605408488773e-05


Epoch 7:  44%|████▍     | 28/63 [00:54<01:08,  1.95s/it]

Loss: 0.12517088651657104


Epoch 7:  46%|████▌     | 29/63 [00:56<01:06,  1.95s/it]

Loss: 0.05296821892261505


Epoch 7:  48%|████▊     | 30/63 [00:58<01:04,  1.94s/it]

Loss: 7.506544989155373e-06


Epoch 7:  49%|████▉     | 31/63 [01:00<01:02,  1.94s/it]

Loss: 9.825590677792206e-05


Epoch 7:  51%|█████     | 32/63 [01:02<01:00,  1.94s/it]

Loss: 9.499477471308637e-08


Epoch 7:  52%|█████▏    | 33/63 [01:04<00:58,  1.94s/it]

Loss: 0.11497856676578522


Epoch 7:  54%|█████▍    | 34/63 [01:06<00:56,  1.94s/it]

Loss: 1.2389783478283789e-05


Epoch 7:  56%|█████▌    | 35/63 [01:08<00:54,  1.94s/it]

Loss: 0.09783017635345459


Epoch 7:  57%|█████▋    | 36/63 [01:09<00:52,  1.94s/it]

Loss: 0.06321142613887787


Epoch 7:  59%|█████▊    | 37/63 [01:11<00:50,  1.94s/it]

Loss: 0.13499200344085693


Epoch 7:  60%|██████    | 38/63 [01:13<00:48,  1.94s/it]

Loss: 0.06823862344026566


Epoch 7:  62%|██████▏   | 39/63 [01:15<00:46,  1.94s/it]

Loss: 0.21155354380607605


Epoch 7:  63%|██████▎   | 40/63 [01:17<00:44,  1.95s/it]

Loss: 0.0005198183353058994


Epoch 7:  65%|██████▌   | 41/63 [01:19<00:42,  1.94s/it]

Loss: 0.00012875058746431023


Epoch 7:  67%|██████▋   | 42/63 [01:21<00:40,  1.94s/it]

Loss: 8.009355667581985e-08


Epoch 7:  68%|██████▊   | 43/63 [01:23<00:38,  1.94s/it]

Loss: 0.0


Epoch 7:  70%|██████▉   | 44/63 [01:25<00:36,  1.94s/it]

Loss: 4.1908964476533583e-07


Epoch 7:  71%|███████▏  | 45/63 [01:27<00:35,  1.95s/it]

Loss: 4.135029598728579e-07


Epoch 7:  73%|███████▎  | 46/63 [01:29<00:33,  1.94s/it]

Loss: 3.7252898543727042e-09


Epoch 7:  75%|███████▍  | 47/63 [01:31<00:31,  1.94s/it]

Loss: 0.20199289917945862


Epoch 7:  76%|███████▌  | 48/63 [01:33<00:29,  1.95s/it]

Loss: 1.303851249190302e-08


Epoch 7:  78%|███████▊  | 49/63 [01:35<00:27,  1.95s/it]

Loss: 0.011368250474333763


Epoch 7:  79%|███████▉  | 50/63 [01:37<00:25,  1.94s/it]

Loss: 0.0


Epoch 7:  81%|████████  | 51/63 [01:39<00:23,  1.95s/it]

Loss: 0.2287771999835968


Epoch 7:  83%|████████▎ | 52/63 [01:41<00:21,  1.94s/it]

Loss: 0.006729947403073311


Epoch 7:  84%|████████▍ | 53/63 [01:43<00:19,  1.95s/it]

Loss: 9.411376959178597e-05


Epoch 7:  86%|████████▌ | 54/63 [01:44<00:17,  1.95s/it]

Loss: 0.07612982392311096


Epoch 7:  87%|████████▋ | 55/63 [01:46<00:15,  1.95s/it]

Loss: 0.008540182374417782


Epoch 7:  89%|████████▉ | 56/63 [01:48<00:13,  1.94s/it]

Loss: 3.9674108620602055e-07


Epoch 7:  90%|█████████ | 57/63 [01:50<00:11,  1.95s/it]

Loss: 1.1034137969545554e-05


Epoch 7:  92%|█████████▏| 58/63 [01:52<00:09,  1.95s/it]

Loss: 0.011228327639400959


Epoch 7:  94%|█████████▎| 59/63 [01:54<00:07,  1.95s/it]

Loss: 1.6763797461294416e-08


Epoch 7:  95%|█████████▌| 60/63 [01:56<00:05,  1.95s/it]

Loss: 0.0


Epoch 7:  97%|█████████▋| 61/63 [01:58<00:03,  1.95s/it]

Loss: 0.0


Epoch 7:  98%|█████████▊| 62/63 [02:00<00:01,  1.95s/it]

Loss: 0.0


Epoch 7: 100%|██████████| 63/63 [02:01<00:00,  1.93s/it]

Loss: 0.06817980110645294





Epoch 7 Validation Accuracy: 0.9901477832512315, F1-macro: 0.4975247524752475


In [None]:
# Evaluate on the test set
test_metrics = evaluate(model3, test_dataloader, device)
print(f"Test Accuracy: {test_metrics['accuracy']}, F1-macro: {test_metrics['f1_macro']}")

Test Accuracy: 0.9901960784313726, F1-macro: 0.4975369458128079


In [None]:
current_type = 'disqualifying the positive'

results_df.loc[
    results_df["distortion_type"] == current_type,
    ["test_accuracy", "f1_macro"]
] = [test_metrics['accuracy'], test_metrics['f1_macro']]

# 4. Emotional Reasoning

In [None]:
# Add labels
data1_1_labels = list(data1['emotional reasoning'][data1_1.index])
data2_1_labels = list(data2['emotional reasoning'][data2_1.index])
data3_1_labels = list(data3['emotional reasoning'][data3_1.index])

# Merging Data
data_encoded = data1_1_encoded + data2_1_encoded + data3_1_encoded
data_labels = data1_1_labels + data2_1_labels + data3_1_labels

dataset_with_labels = CustomDatasetWithLabels(data_encoded, data_labels)

# Define proportions for splitting
train_size = int(0.8 * len(dataset_with_labels))
val_size = int(0.1 * len(dataset_with_labels))
test_size = len(dataset_with_labels) - train_size - val_size

# Split the dataset
train_dataset, val_dataset, test_dataset = random_split(dataset_with_labels, [train_size, val_size, test_size])

In [None]:
dataset_encoded = dataset_with_labels.data
dataset_labels = dataset_with_labels.labels

# Convert list of BatchEncoding objects to a 2D numpy array for SMOTE features
feature_vectors = []
max_len = 512 # Assuming this from tokenize_and_pad function
for encoded_dict in dataset_encoded:
    input_ids_np = encoded_dict['input_ids'].squeeze().cpu().numpy()
    attention_mask_np = encoded_dict['attention_mask'].squeeze().cpu().numpy()
    # Concatenate input_ids and attention_mask to create a single feature vector per sample
    feature_vectors.append(np.concatenate([input_ids_np, attention_mask_np]))

# 'temp1' becomes the feature matrix, 'temp2' becomes the labels vector
temp1 = np.array(feature_vectors)
temp2 = np.array(dataset_labels)

# Apply SMOTE
temp1_resampled, temp2_resampled = smote.fit_resample(temp1, temp2)

# Reconstruct data_encoded and data_labels from the resampled data
reconstructed_data_encoded = []
for resampled_features in temp1_resampled:
    # Split the combined feature vector back into input_ids and attention_mask parts
    reconstructed_input_ids = resampled_features[:max_len]
    reconstructed_attention_mask = resampled_features[max_len:]

    # Convert back to torch tensors and a dictionary format compatible with CustomDatasetWithLabels
    input_ids_tensor = torch.tensor(reconstructed_input_ids, dtype=torch.long).unsqueeze(0)
    attention_mask_tensor = torch.tensor(reconstructed_attention_mask, dtype=torch.long).unsqueeze(0)

    reconstructed_data_encoded.append({
        'input_ids': input_ids_tensor,
        'attention_mask': attention_mask_tensor
    })

# Update data_encoded and data_labels with the resampled data
dataset_encoded = reconstructed_data_encoded
dataset_labels = temp2_resampled.tolist()

train_dataset = CustomDatasetWithLabels(dataset_encoded, dataset_labels)

In [None]:
# Create DataLoaders for each set
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=64)
test_dataloader = DataLoader(test_dataset, batch_size=64)

input_dim = bert_config.hidden_size
num_labels = len(set(data_labels))
label_emb = torch.randn(num_labels, input_dim)

In [None]:
# Instantiate the InnerProductClassifier model
model4 = InnerProductClassifier(input_dim, label_emb).to(device)

# Define loss function
criterion = nn.CrossEntropyLoss()

# Define optimizer
optimizer = torch.optim.AdamW(model4.parameters(), lr=2e-4)

In [None]:
EPOCHS = 7

# Move the BERT model to the device
bert_model.to(device)

for epoch in range(EPOCHS):
    model4.train()
    for batch in tqdm(train_dataloader, desc=f"Epoch {epoch+1}"):

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        y = batch["y"].to(device)

        # Get embeddings from the BERT model
        with torch.no_grad():
            outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask)
            embeddings = outputs.last_hidden_state[:, 0, :] # Get the [CLS] token embedding

        # Pass embeddings to the InnerProductClassifier
        logits = model4(embeddings)
        loss = criterion(logits, y)

        # Backpropagate and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Loss: {loss.item()}")

    # Evaluate on the validation set after each epoch
    val_metrics = evaluate(model4, val_dataloader, device)
    print(f"Epoch {epoch+1} Validation Accuracy: {val_metrics['accuracy']}, F1-macro: {val_metrics['f1_macro']}")

Epoch 1:   2%|▏         | 1/58 [00:01<01:49,  1.92s/it]

Loss: 8.413392066955566


Epoch 1:   3%|▎         | 2/58 [00:03<01:48,  1.93s/it]

Loss: 3.1930437088012695


Epoch 1:   5%|▌         | 3/58 [00:05<01:46,  1.94s/it]

Loss: 3.1411406993865967


Epoch 1:   7%|▋         | 4/58 [00:07<01:44,  1.94s/it]

Loss: 1.2990366220474243


Epoch 1:   9%|▊         | 5/58 [00:09<01:42,  1.94s/it]

Loss: 2.4656472206115723


Epoch 1:  10%|█         | 6/58 [00:11<01:41,  1.94s/it]

Loss: 0.48280978202819824


Epoch 1:  12%|█▏        | 7/58 [00:13<01:39,  1.94s/it]

Loss: 0.37598374485969543


Epoch 1:  14%|█▍        | 8/58 [00:15<01:37,  1.94s/it]

Loss: 1.3097928762435913


Epoch 1:  16%|█▌        | 9/58 [00:17<01:35,  1.94s/it]

Loss: 1.420471429824829


Epoch 1:  17%|█▋        | 10/58 [00:19<01:33,  1.94s/it]

Loss: 0.6681303977966309


Epoch 1:  19%|█▉        | 11/58 [00:21<01:31,  1.94s/it]

Loss: 1.4186376333236694


Epoch 1:  21%|██        | 12/58 [00:23<01:29,  1.94s/it]

Loss: 1.2196605205535889


Epoch 1:  22%|██▏       | 13/58 [00:25<01:27,  1.94s/it]

Loss: 1.2657285928726196


Epoch 1:  24%|██▍       | 14/58 [00:27<01:25,  1.94s/it]

Loss: 1.083047866821289


Epoch 1:  26%|██▌       | 15/58 [00:29<01:23,  1.94s/it]

Loss: 1.4697176218032837


Epoch 1:  28%|██▊       | 16/58 [00:31<01:21,  1.94s/it]

Loss: 1.1849303245544434


Epoch 1:  29%|██▉       | 17/58 [00:33<01:19,  1.94s/it]

Loss: 2.3181724548339844


Epoch 1:  31%|███       | 18/58 [00:34<01:17,  1.94s/it]

Loss: 1.1452821493148804


Epoch 1:  33%|███▎      | 19/58 [00:36<01:15,  1.94s/it]

Loss: 1.8120415210723877


Epoch 1:  34%|███▍      | 20/58 [00:38<01:13,  1.94s/it]

Loss: 1.7439218759536743


Epoch 1:  36%|███▌      | 21/58 [00:40<01:11,  1.95s/it]

Loss: 2.062145471572876


Epoch 1:  38%|███▊      | 22/58 [00:42<01:10,  1.95s/it]

Loss: 1.025383710861206


Epoch 1:  40%|███▉      | 23/58 [00:44<01:08,  1.95s/it]

Loss: 0.5554223656654358


Epoch 1:  41%|████▏     | 24/58 [00:46<01:06,  1.95s/it]

Loss: 1.359886646270752


Epoch 1:  43%|████▎     | 25/58 [00:48<01:04,  1.95s/it]

Loss: 0.552059531211853


Epoch 1:  45%|████▍     | 26/58 [00:50<01:02,  1.95s/it]

Loss: 0.9021655917167664


Epoch 1:  47%|████▋     | 27/58 [00:52<01:00,  1.95s/it]

Loss: 0.40566954016685486


Epoch 1:  48%|████▊     | 28/58 [00:54<00:58,  1.95s/it]

Loss: 1.308524250984192


Epoch 1:  50%|█████     | 29/58 [00:56<00:56,  1.95s/it]

Loss: 0.2200939655303955


Epoch 1:  52%|█████▏    | 30/58 [00:58<00:54,  1.95s/it]

Loss: 0.7495399117469788


Epoch 1:  53%|█████▎    | 31/58 [01:00<00:52,  1.95s/it]

Loss: 1.4139596223831177


Epoch 1:  55%|█████▌    | 32/58 [01:02<00:50,  1.94s/it]

Loss: 0.686848521232605


Epoch 1:  57%|█████▋    | 33/58 [01:04<00:48,  1.94s/it]

Loss: 0.9572499990463257


Epoch 1:  59%|█████▊    | 34/58 [01:06<00:46,  1.94s/it]

Loss: 1.290059208869934


Epoch 1:  60%|██████    | 35/58 [01:08<00:44,  1.94s/it]

Loss: 0.44803833961486816


Epoch 1:  62%|██████▏   | 36/58 [01:09<00:42,  1.94s/it]

Loss: 1.2810202836990356


Epoch 1:  64%|██████▍   | 37/58 [01:11<00:40,  1.94s/it]

Loss: 0.6092582941055298


Epoch 1:  66%|██████▌   | 38/58 [01:13<00:38,  1.94s/it]

Loss: 0.3009883463382721


Epoch 1:  67%|██████▋   | 39/58 [01:15<00:36,  1.94s/it]

Loss: 0.4760441780090332


Epoch 1:  69%|██████▉   | 40/58 [01:17<00:34,  1.94s/it]

Loss: 0.09185276925563812


Epoch 1:  71%|███████   | 41/58 [01:19<00:33,  1.94s/it]

Loss: 0.610228955745697


Epoch 1:  72%|███████▏  | 42/58 [01:21<00:31,  1.94s/it]

Loss: 0.3669804036617279


Epoch 1:  74%|███████▍  | 43/58 [01:23<00:29,  1.94s/it]

Loss: 0.5872947573661804


Epoch 1:  76%|███████▌  | 44/58 [01:25<00:27,  1.94s/it]

Loss: 0.31974390149116516


Epoch 1:  78%|███████▊  | 45/58 [01:27<00:25,  1.95s/it]

Loss: 0.5155179500579834


Epoch 1:  79%|███████▉  | 46/58 [01:29<00:23,  1.94s/it]

Loss: 0.5525348782539368


Epoch 1:  81%|████████  | 47/58 [01:31<00:21,  1.94s/it]

Loss: 0.7280563116073608


Epoch 1:  83%|████████▎ | 48/58 [01:33<00:19,  1.94s/it]

Loss: 0.5718065500259399


Epoch 1:  84%|████████▍ | 49/58 [01:35<00:17,  1.94s/it]

Loss: 0.8709481954574585


Epoch 1:  86%|████████▌ | 50/58 [01:37<00:15,  1.94s/it]

Loss: 0.7284170985221863


Epoch 1:  88%|████████▊ | 51/58 [01:39<00:13,  1.94s/it]

Loss: 0.43347111344337463


Epoch 1:  90%|████████▉ | 52/58 [01:41<00:11,  1.94s/it]

Loss: 0.45160922408103943


Epoch 1:  91%|█████████▏| 53/58 [01:43<00:09,  1.94s/it]

Loss: 0.5752536654472351


Epoch 1:  93%|█████████▎| 54/58 [01:44<00:07,  1.94s/it]

Loss: 0.8396052122116089


Epoch 1:  95%|█████████▍| 55/58 [01:46<00:05,  1.94s/it]

Loss: 0.7413108348846436


Epoch 1:  97%|█████████▋| 56/58 [01:48<00:03,  1.95s/it]

Loss: 0.17626601457595825


Epoch 1: 100%|██████████| 58/58 [01:50<00:00,  1.91s/it]

Loss: 1.0869135856628418
Loss: 0.002871558303013444





Epoch 1 Validation Accuracy: 0.8719211822660099, F1-macro: 0.5013227513227513


Epoch 2:   2%|▏         | 1/58 [00:01<01:50,  1.95s/it]

Loss: 0.12222835421562195


Epoch 2:   3%|▎         | 2/58 [00:03<01:49,  1.95s/it]

Loss: 0.2614959180355072


Epoch 2:   5%|▌         | 3/58 [00:05<01:47,  1.95s/it]

Loss: 0.5054685473442078


Epoch 2:   7%|▋         | 4/58 [00:07<01:45,  1.95s/it]

Loss: 0.4800778031349182


Epoch 2:   9%|▊         | 5/58 [00:09<01:43,  1.95s/it]

Loss: 0.7346932888031006


Epoch 2:  10%|█         | 6/58 [00:11<01:41,  1.95s/it]

Loss: 0.8358803391456604


Epoch 2:  12%|█▏        | 7/58 [00:13<01:39,  1.94s/it]

Loss: 2.2139551639556885


Epoch 2:  14%|█▍        | 8/58 [00:15<01:37,  1.95s/it]

Loss: 1.264811396598816


Epoch 2:  16%|█▌        | 9/58 [00:17<01:35,  1.94s/it]

Loss: 0.2594364285469055


Epoch 2:  17%|█▋        | 10/58 [00:19<01:33,  1.95s/it]

Loss: 0.4723152220249176


Epoch 2:  19%|█▉        | 11/58 [00:21<01:31,  1.95s/it]

Loss: 0.2783021628856659


Epoch 2:  21%|██        | 12/58 [00:23<01:29,  1.94s/it]

Loss: 0.621973991394043


Epoch 2:  22%|██▏       | 13/58 [00:25<01:27,  1.94s/it]

Loss: 0.586865246295929


Epoch 2:  24%|██▍       | 14/58 [00:27<01:25,  1.94s/it]

Loss: 0.4588603973388672


Epoch 2:  26%|██▌       | 15/58 [00:29<01:23,  1.94s/it]

Loss: 0.30143678188323975


Epoch 2:  28%|██▊       | 16/58 [00:31<01:21,  1.94s/it]

Loss: 1.4400649070739746


Epoch 2:  29%|██▉       | 17/58 [00:33<01:19,  1.94s/it]

Loss: 0.3775812089443207


Epoch 2:  31%|███       | 18/58 [00:35<01:17,  1.94s/it]

Loss: 0.6877859830856323


Epoch 2:  33%|███▎      | 19/58 [00:36<01:15,  1.94s/it]

Loss: 0.2245982587337494


Epoch 2:  34%|███▍      | 20/58 [00:38<01:13,  1.94s/it]

Loss: 0.6253248453140259


Epoch 2:  36%|███▌      | 21/58 [00:40<01:11,  1.94s/it]

Loss: 0.7991265058517456


Epoch 2:  38%|███▊      | 22/58 [00:42<01:10,  1.94s/it]

Loss: 0.6369913816452026


Epoch 2:  40%|███▉      | 23/58 [00:44<01:08,  1.94s/it]

Loss: 0.42036330699920654


Epoch 2:  41%|████▏     | 24/58 [00:46<01:06,  1.94s/it]

Loss: 0.964809238910675


Epoch 2:  43%|████▎     | 25/58 [00:48<01:04,  1.94s/it]

Loss: 0.5593526363372803


Epoch 2:  45%|████▍     | 26/58 [00:50<01:02,  1.94s/it]

Loss: 0.38833969831466675


Epoch 2:  47%|████▋     | 27/58 [00:52<01:00,  1.94s/it]

Loss: 0.232689768075943


Epoch 2:  48%|████▊     | 28/58 [00:54<00:58,  1.94s/it]

Loss: 0.23606887459754944


Epoch 2:  50%|█████     | 29/58 [00:56<00:56,  1.94s/it]

Loss: 0.7933674454689026


Epoch 2:  52%|█████▏    | 30/58 [00:58<00:54,  1.95s/it]

Loss: 0.6232742667198181


Epoch 2:  53%|█████▎    | 31/58 [01:00<00:52,  1.95s/it]

Loss: 0.8737708330154419


Epoch 2:  55%|█████▌    | 32/58 [01:02<00:50,  1.95s/it]

Loss: 0.11710032820701599


Epoch 2:  57%|█████▋    | 33/58 [01:04<00:48,  1.95s/it]

Loss: 0.8855361938476562


Epoch 2:  59%|█████▊    | 34/58 [01:06<00:46,  1.95s/it]

Loss: 0.7046105265617371


Epoch 2:  60%|██████    | 35/58 [01:08<00:44,  1.95s/it]

Loss: 0.5738207697868347


Epoch 2:  62%|██████▏   | 36/58 [01:10<00:42,  1.95s/it]

Loss: 0.19048510491847992


Epoch 2:  64%|██████▍   | 37/58 [01:11<00:40,  1.95s/it]

Loss: 0.5277894735336304


Epoch 2:  66%|██████▌   | 38/58 [01:13<00:38,  1.95s/it]

Loss: 0.7986167073249817


Epoch 2:  67%|██████▋   | 39/58 [01:15<00:36,  1.95s/it]

Loss: 0.463217169046402


Epoch 2:  69%|██████▉   | 40/58 [01:17<00:35,  1.95s/it]

Loss: 0.13377204537391663


Epoch 2:  71%|███████   | 41/58 [01:19<00:33,  1.95s/it]

Loss: 0.2510143518447876


Epoch 2:  72%|███████▏  | 42/58 [01:21<00:31,  1.95s/it]

Loss: 0.6173641681671143


Epoch 2:  74%|███████▍  | 43/58 [01:23<00:29,  1.94s/it]

Loss: 1.0147910118103027


Epoch 2:  76%|███████▌  | 44/58 [01:25<00:27,  1.95s/it]

Loss: 0.3537861406803131


Epoch 2:  78%|███████▊  | 45/58 [01:27<00:25,  1.94s/it]

Loss: 0.3216482102870941


Epoch 2:  79%|███████▉  | 46/58 [01:29<00:23,  1.94s/it]

Loss: 0.3372737765312195


Epoch 2:  81%|████████  | 47/58 [01:31<00:21,  1.94s/it]

Loss: 0.88414466381073


Epoch 2:  83%|████████▎ | 48/58 [01:33<00:19,  1.94s/it]

Loss: 0.2202981412410736


Epoch 2:  84%|████████▍ | 49/58 [01:35<00:17,  1.94s/it]

Loss: 0.537606954574585


Epoch 2:  86%|████████▌ | 50/58 [01:37<00:15,  1.94s/it]

Loss: 0.7257038354873657


Epoch 2:  88%|████████▊ | 51/58 [01:39<00:13,  1.94s/it]

Loss: 0.1974335014820099


Epoch 2:  90%|████████▉ | 52/58 [01:41<00:11,  1.94s/it]

Loss: 0.33108779788017273


Epoch 2:  91%|█████████▏| 53/58 [01:43<00:09,  1.95s/it]

Loss: 0.6306208968162537


Epoch 2:  93%|█████████▎| 54/58 [01:45<00:07,  1.95s/it]

Loss: 0.27065789699554443


Epoch 2:  95%|█████████▍| 55/58 [01:46<00:05,  1.94s/it]

Loss: 0.5733587741851807


Epoch 2:  97%|█████████▋| 56/58 [01:48<00:03,  1.94s/it]

Loss: 0.3078506886959076


Epoch 2: 100%|██████████| 58/58 [01:51<00:00,  1.91s/it]

Loss: 0.5347627401351929
Loss: 0.6401492357254028





Epoch 2 Validation Accuracy: 0.8768472906403941, F1-macro: 0.4671916010498688


Epoch 3:   2%|▏         | 1/58 [00:01<01:50,  1.95s/it]

Loss: 0.28310683369636536


Epoch 3:   3%|▎         | 2/58 [00:03<01:48,  1.95s/it]

Loss: 1.7539464235305786


Epoch 3:   5%|▌         | 3/58 [00:05<01:47,  1.95s/it]

Loss: 1.5730692148208618


Epoch 3:   7%|▋         | 4/58 [00:07<01:45,  1.95s/it]

Loss: 4.268716812133789


Epoch 3:   9%|▊         | 5/58 [00:09<01:43,  1.95s/it]

Loss: 2.7053346633911133


Epoch 3:  10%|█         | 6/58 [00:11<01:41,  1.95s/it]

Loss: 1.423915982246399


Epoch 3:  12%|█▏        | 7/58 [00:13<01:39,  1.95s/it]

Loss: 3.133221387863159


Epoch 3:  14%|█▍        | 8/58 [00:15<01:37,  1.95s/it]

Loss: 1.0926148891448975


Epoch 3:  16%|█▌        | 9/58 [00:17<01:35,  1.95s/it]

Loss: 0.7004075646400452


Epoch 3:  17%|█▋        | 10/58 [00:19<01:33,  1.95s/it]

Loss: 1.0192950963974


Epoch 3:  19%|█▉        | 11/58 [00:21<01:31,  1.95s/it]

Loss: 0.960250735282898


Epoch 3:  21%|██        | 12/58 [00:23<01:29,  1.95s/it]

Loss: 2.740022659301758


Epoch 3:  22%|██▏       | 13/58 [00:25<01:27,  1.95s/it]

Loss: 0.6373597383499146


Epoch 3:  24%|██▍       | 14/58 [00:27<01:25,  1.95s/it]

Loss: 0.6634739637374878


Epoch 3:  26%|██▌       | 15/58 [00:29<01:23,  1.95s/it]

Loss: 0.3451407849788666


Epoch 3:  28%|██▊       | 16/58 [00:31<01:21,  1.95s/it]

Loss: 1.909627079963684


Epoch 3:  29%|██▉       | 17/58 [00:33<01:19,  1.95s/it]

Loss: 2.857442855834961


Epoch 3:  31%|███       | 18/58 [00:35<01:17,  1.95s/it]

Loss: 2.8269574642181396


Epoch 3:  33%|███▎      | 19/58 [00:36<01:15,  1.95s/it]

Loss: 0.49219420552253723


Epoch 3:  34%|███▍      | 20/58 [00:38<01:13,  1.95s/it]

Loss: 1.1526122093200684


Epoch 3:  36%|███▌      | 21/58 [00:40<01:11,  1.95s/it]

Loss: 2.1343371868133545


Epoch 3:  38%|███▊      | 22/58 [00:42<01:10,  1.95s/it]

Loss: 1.0169285535812378


Epoch 3:  40%|███▉      | 23/58 [00:44<01:08,  1.95s/it]

Loss: 1.6449155807495117


Epoch 3:  41%|████▏     | 24/58 [00:46<01:06,  1.95s/it]

Loss: 0.9643580317497253


Epoch 3:  43%|████▎     | 25/58 [00:48<01:04,  1.95s/it]

Loss: 0.33789244294166565


Epoch 3:  45%|████▍     | 26/58 [00:50<01:02,  1.95s/it]

Loss: 0.7021580934524536


Epoch 3:  47%|████▋     | 27/58 [00:52<01:00,  1.95s/it]

Loss: 1.8489599227905273


Epoch 3:  48%|████▊     | 28/58 [00:54<00:58,  1.95s/it]

Loss: 0.19008564949035645


Epoch 3:  50%|█████     | 29/58 [00:56<00:56,  1.95s/it]

Loss: 0.40738001465797424


Epoch 3:  52%|█████▏    | 30/58 [00:58<00:54,  1.95s/it]

Loss: 0.009980481117963791


Epoch 3:  53%|█████▎    | 31/58 [01:00<00:52,  1.95s/it]

Loss: 1.025479793548584


Epoch 3:  55%|█████▌    | 32/58 [01:02<00:50,  1.95s/it]

Loss: 0.9210587739944458


Epoch 3:  57%|█████▋    | 33/58 [01:04<00:48,  1.95s/it]

Loss: 0.5929011106491089


Epoch 3:  59%|█████▊    | 34/58 [01:06<00:46,  1.95s/it]

Loss: 0.9085420966148376


Epoch 3:  60%|██████    | 35/58 [01:08<00:44,  1.95s/it]

Loss: 0.5781463384628296


Epoch 3:  62%|██████▏   | 36/58 [01:10<00:42,  1.95s/it]

Loss: 0.26293620467185974


Epoch 3:  64%|██████▍   | 37/58 [01:12<00:40,  1.95s/it]

Loss: 0.9329637289047241


Epoch 3:  66%|██████▌   | 38/58 [01:13<00:38,  1.95s/it]

Loss: 0.44653505086898804


Epoch 3:  67%|██████▋   | 39/58 [01:15<00:37,  1.95s/it]

Loss: 0.425820529460907


Epoch 3:  69%|██████▉   | 40/58 [01:17<00:35,  1.95s/it]

Loss: 1.0337172746658325


Epoch 3:  71%|███████   | 41/58 [01:19<00:33,  1.95s/it]

Loss: 0.29843512177467346


Epoch 3:  72%|███████▏  | 42/58 [01:21<00:31,  1.95s/it]

Loss: 0.28815126419067383


Epoch 3:  74%|███████▍  | 43/58 [01:23<00:29,  1.95s/it]

Loss: 0.3087536096572876


Epoch 3:  76%|███████▌  | 44/58 [01:25<00:27,  1.95s/it]

Loss: 0.4272306561470032


Epoch 3:  78%|███████▊  | 45/58 [01:27<00:25,  1.95s/it]

Loss: 0.7173044085502625


Epoch 3:  79%|███████▉  | 46/58 [01:29<00:23,  1.95s/it]

Loss: 0.3999018669128418


Epoch 3:  81%|████████  | 47/58 [01:31<00:21,  1.95s/it]

Loss: 0.5079784989356995


Epoch 3:  83%|████████▎ | 48/58 [01:33<00:19,  1.95s/it]

Loss: 0.7305182218551636


Epoch 3:  84%|████████▍ | 49/58 [01:35<00:17,  1.95s/it]

Loss: 0.026887495070695877


Epoch 3:  86%|████████▌ | 50/58 [01:37<00:15,  1.94s/it]

Loss: 0.6494048833847046


Epoch 3:  88%|████████▊ | 51/58 [01:39<00:13,  1.94s/it]

Loss: 0.3462386727333069


Epoch 3:  90%|████████▉ | 52/58 [01:41<00:11,  1.94s/it]

Loss: 0.9682185649871826


Epoch 3:  91%|█████████▏| 53/58 [01:43<00:09,  1.94s/it]

Loss: 0.23113563656806946


Epoch 3:  93%|█████████▎| 54/58 [01:45<00:07,  1.95s/it]

Loss: 1.0439471006393433


Epoch 3:  95%|█████████▍| 55/58 [01:47<00:05,  1.95s/it]

Loss: 0.6882972717285156


Epoch 3:  97%|█████████▋| 56/58 [01:49<00:03,  1.95s/it]

Loss: 0.28705692291259766


Epoch 3: 100%|██████████| 58/58 [01:51<00:00,  1.92s/it]

Loss: 0.3572129011154175
Loss: 0.0





Epoch 3 Validation Accuracy: 0.8768472906403941, F1-macro: 0.4671916010498688


Epoch 4:   2%|▏         | 1/58 [00:01<01:50,  1.95s/it]

Loss: 0.911142110824585


Epoch 4:   3%|▎         | 2/58 [00:03<01:49,  1.95s/it]

Loss: 0.8317936062812805


Epoch 4:   5%|▌         | 3/58 [00:05<01:47,  1.95s/it]

Loss: 5.0789385568350554e-05


Epoch 4:   7%|▋         | 4/58 [00:07<01:45,  1.95s/it]

Loss: 0.47328731417655945


Epoch 4:   9%|▊         | 5/58 [00:09<01:43,  1.95s/it]

Loss: 0.0458931103348732


Epoch 4:  10%|█         | 6/58 [00:11<01:41,  1.95s/it]

Loss: 0.2906569838523865


Epoch 4:  12%|█▏        | 7/58 [00:13<01:39,  1.95s/it]

Loss: 0.42604774236679077


Epoch 4:  14%|█▍        | 8/58 [00:15<01:37,  1.95s/it]

Loss: 0.19680923223495483


Epoch 4:  16%|█▌        | 9/58 [00:17<01:35,  1.95s/it]

Loss: 0.28795725107192993


Epoch 4:  17%|█▋        | 10/58 [00:19<01:33,  1.95s/it]

Loss: 0.0729219913482666


Epoch 4:  19%|█▉        | 11/58 [00:21<01:31,  1.95s/it]

Loss: 0.23938564956188202


Epoch 4:  21%|██        | 12/58 [00:23<01:29,  1.95s/it]

Loss: 0.4201454520225525


Epoch 4:  22%|██▏       | 13/58 [00:25<01:27,  1.95s/it]

Loss: 0.493030846118927


Epoch 4:  24%|██▍       | 14/58 [00:27<01:25,  1.95s/it]

Loss: 0.7879126071929932


Epoch 4:  26%|██▌       | 15/58 [00:29<01:23,  1.95s/it]

Loss: 0.08696748316287994


Epoch 4:  28%|██▊       | 16/58 [00:31<01:21,  1.95s/it]

Loss: 0.08326750248670578


Epoch 4:  29%|██▉       | 17/58 [00:33<01:19,  1.95s/it]

Loss: 0.11426384747028351


Epoch 4:  31%|███       | 18/58 [00:35<01:17,  1.95s/it]

Loss: 0.9695732593536377


Epoch 4:  33%|███▎      | 19/58 [00:36<01:15,  1.95s/it]

Loss: 0.14336441457271576


Epoch 4:  34%|███▍      | 20/58 [00:38<01:13,  1.95s/it]

Loss: 0.5556765198707581


Epoch 4:  36%|███▌      | 21/58 [00:40<01:12,  1.95s/it]

Loss: 0.5193219184875488


Epoch 4:  38%|███▊      | 22/58 [00:42<01:10,  1.95s/it]

Loss: 0.42423129081726074


Epoch 4:  40%|███▉      | 23/58 [00:44<01:08,  1.95s/it]

Loss: 0.4633423089981079


Epoch 4:  41%|████▏     | 24/58 [00:46<01:06,  1.95s/it]

Loss: 0.4592336416244507


Epoch 4:  43%|████▎     | 25/58 [00:48<01:04,  1.95s/it]

Loss: 0.6636835336685181


Epoch 4:  45%|████▍     | 26/58 [00:50<01:02,  1.95s/it]

Loss: 1.1977564096450806


Epoch 4:  47%|████▋     | 27/58 [00:52<01:00,  1.95s/it]

Loss: 0.2429528683423996


Epoch 4:  48%|████▊     | 28/58 [00:54<00:58,  1.95s/it]

Loss: 0.24807904660701752


Epoch 4:  50%|█████     | 29/58 [00:56<00:56,  1.95s/it]

Loss: 1.1281721591949463


Epoch 4:  52%|█████▏    | 30/58 [00:58<00:54,  1.95s/it]

Loss: 0.18112221360206604


Epoch 4:  53%|█████▎    | 31/58 [01:00<00:52,  1.95s/it]

Loss: 0.10741817206144333


Epoch 4:  55%|█████▌    | 32/58 [01:02<00:50,  1.95s/it]

Loss: 0.9772332906723022


Epoch 4:  57%|█████▋    | 33/58 [01:04<00:48,  1.95s/it]

Loss: 0.17214122414588928


Epoch 4:  59%|█████▊    | 34/58 [01:06<00:46,  1.95s/it]

Loss: 0.5516605973243713


Epoch 4:  60%|██████    | 35/58 [01:08<00:44,  1.95s/it]

Loss: 0.1956007480621338


Epoch 4:  62%|██████▏   | 36/58 [01:10<00:42,  1.95s/it]

Loss: 0.38120758533477783


Epoch 4:  64%|██████▍   | 37/58 [01:12<00:40,  1.94s/it]

Loss: 0.3822356164455414


Epoch 4:  66%|██████▌   | 38/58 [01:13<00:38,  1.94s/it]

Loss: 0.46816134452819824


Epoch 4:  67%|██████▋   | 39/58 [01:15<00:36,  1.94s/it]

Loss: 1.0984549522399902


Epoch 4:  69%|██████▉   | 40/58 [01:17<00:34,  1.94s/it]

Loss: 0.529030978679657


Epoch 4:  71%|███████   | 41/58 [01:19<00:33,  1.94s/it]

Loss: 0.1927288919687271


Epoch 4:  72%|███████▏  | 42/58 [01:21<00:31,  1.94s/it]

Loss: 0.37183281779289246


Epoch 4:  74%|███████▍  | 43/58 [01:23<00:29,  1.94s/it]

Loss: 0.7981616258621216


Epoch 4:  76%|███████▌  | 44/58 [01:25<00:27,  1.94s/it]

Loss: 0.7122855186462402


Epoch 4:  78%|███████▊  | 45/58 [01:27<00:25,  1.94s/it]

Loss: 0.20507583022117615


Epoch 4:  79%|███████▉  | 46/58 [01:29<00:23,  1.94s/it]

Loss: 0.1623198688030243


Epoch 4:  81%|████████  | 47/58 [01:31<00:21,  1.94s/it]

Loss: 0.7437862157821655


Epoch 4:  83%|████████▎ | 48/58 [01:33<00:19,  1.95s/it]

Loss: 0.43338748812675476


Epoch 4:  84%|████████▍ | 49/58 [01:35<00:17,  1.95s/it]

Loss: 0.5005561113357544


Epoch 4:  86%|████████▌ | 50/58 [01:37<00:15,  1.95s/it]

Loss: 0.7678862810134888


Epoch 4:  88%|████████▊ | 51/58 [01:39<00:13,  1.95s/it]

Loss: 0.4866968095302582


Epoch 4:  90%|████████▉ | 52/58 [01:41<00:11,  1.95s/it]

Loss: 0.3900945782661438


Epoch 4:  91%|█████████▏| 53/58 [01:43<00:09,  1.95s/it]

Loss: 0.8355628252029419


Epoch 4:  93%|█████████▎| 54/58 [01:45<00:07,  1.95s/it]

Loss: 0.7679635882377625


Epoch 4:  95%|█████████▍| 55/58 [01:47<00:05,  1.95s/it]

Loss: 0.4499484598636627


Epoch 4:  97%|█████████▋| 56/58 [01:49<00:03,  1.95s/it]

Loss: 0.2650010287761688


Epoch 4: 100%|██████████| 58/58 [01:51<00:00,  1.92s/it]

Loss: 0.7389612793922424
Loss: 0.0036187744699418545





Epoch 4 Validation Accuracy: 0.8768472906403941, F1-macro: 0.5358090185676393


Epoch 5:   2%|▏         | 1/58 [00:01<01:50,  1.94s/it]

Loss: 0.4820844829082489


Epoch 5:   3%|▎         | 2/58 [00:03<01:48,  1.95s/it]

Loss: 0.06562718749046326


Epoch 5:   5%|▌         | 3/58 [00:05<01:47,  1.95s/it]

Loss: 0.22054734826087952


Epoch 5:   7%|▋         | 4/58 [00:07<01:45,  1.94s/it]

Loss: 0.4165117144584656


Epoch 5:   9%|▊         | 5/58 [00:09<01:43,  1.94s/it]

Loss: 0.48484110832214355


Epoch 5:  10%|█         | 6/58 [00:11<01:41,  1.94s/it]

Loss: 0.27186861634254456


Epoch 5:  12%|█▏        | 7/58 [00:13<01:39,  1.94s/it]

Loss: 0.2413676530122757


Epoch 5:  14%|█▍        | 8/58 [00:15<01:37,  1.94s/it]

Loss: 0.264826238155365


Epoch 5:  16%|█▌        | 9/58 [00:17<01:35,  1.94s/it]

Loss: 0.2740345001220703


Epoch 5:  17%|█▋        | 10/58 [00:19<01:33,  1.94s/it]

Loss: 0.6919326782226562


Epoch 5:  19%|█▉        | 11/58 [00:21<01:31,  1.94s/it]

Loss: 0.48463645577430725


Epoch 5:  21%|██        | 12/58 [00:23<01:29,  1.94s/it]

Loss: 0.25096017122268677


Epoch 5:  22%|██▏       | 13/58 [00:25<01:27,  1.94s/it]

Loss: 0.35448703169822693


Epoch 5:  24%|██▍       | 14/58 [00:27<01:25,  1.95s/it]

Loss: 0.2993593215942383


Epoch 5:  26%|██▌       | 15/58 [00:29<01:23,  1.94s/it]

Loss: 0.20852455496788025


Epoch 5:  28%|██▊       | 16/58 [00:31<01:21,  1.95s/it]

Loss: 0.6733502149581909


Epoch 5:  29%|██▉       | 17/58 [00:33<01:19,  1.94s/it]

Loss: 0.21807697415351868


Epoch 5:  31%|███       | 18/58 [00:35<01:17,  1.95s/it]

Loss: 0.46238404512405396


Epoch 5:  33%|███▎      | 19/58 [00:36<01:15,  1.94s/it]

Loss: 0.2581290602684021


Epoch 5:  34%|███▍      | 20/58 [00:38<01:13,  1.94s/it]

Loss: 0.3062700927257538


Epoch 5:  36%|███▌      | 21/58 [00:40<01:11,  1.94s/it]

Loss: 0.08920363336801529


Epoch 5:  38%|███▊      | 22/58 [00:42<01:10,  1.95s/it]

Loss: 0.37203359603881836


Epoch 5:  40%|███▉      | 23/58 [00:44<01:08,  1.95s/it]

Loss: 0.4104570746421814


Epoch 5:  41%|████▏     | 24/58 [00:46<01:06,  1.95s/it]

Loss: 0.5523136854171753


Epoch 5:  43%|████▎     | 25/58 [00:48<01:04,  1.95s/it]

Loss: 0.3039824366569519


Epoch 5:  45%|████▍     | 26/58 [00:50<01:02,  1.95s/it]

Loss: 0.24728000164031982


Epoch 5:  47%|████▋     | 27/58 [00:52<01:00,  1.95s/it]

Loss: 0.5907233953475952


Epoch 5:  48%|████▊     | 28/58 [00:54<00:58,  1.95s/it]

Loss: 0.21636371314525604


Epoch 5:  50%|█████     | 29/58 [00:56<00:56,  1.95s/it]

Loss: 0.005260602571070194


Epoch 5:  52%|█████▏    | 30/58 [00:58<00:54,  1.95s/it]

Loss: 0.06418438255786896


Epoch 5:  53%|█████▎    | 31/58 [01:00<00:52,  1.95s/it]

Loss: 0.4687455892562866


Epoch 5:  55%|█████▌    | 32/58 [01:02<00:50,  1.95s/it]

Loss: 0.41295546293258667


Epoch 5:  57%|█████▋    | 33/58 [01:04<00:48,  1.95s/it]

Loss: 0.7534447908401489


Epoch 5:  59%|█████▊    | 34/58 [01:06<00:46,  1.95s/it]

Loss: 0.9947587251663208


Epoch 5:  60%|██████    | 35/58 [01:08<00:44,  1.95s/it]

Loss: 0.03471647948026657


Epoch 5:  62%|██████▏   | 36/58 [01:10<00:42,  1.95s/it]

Loss: 0.5734329223632812


Epoch 5:  64%|██████▍   | 37/58 [01:11<00:40,  1.94s/it]

Loss: 0.33068495988845825


Epoch 5:  66%|██████▌   | 38/58 [01:13<00:38,  1.94s/it]

Loss: 0.7174440622329712


Epoch 5:  67%|██████▋   | 39/58 [01:15<00:36,  1.94s/it]

Loss: 0.5101745128631592


Epoch 5:  69%|██████▉   | 40/58 [01:17<00:35,  1.94s/it]

Loss: 0.2824743390083313


Epoch 5:  71%|███████   | 41/58 [01:19<00:33,  1.95s/it]

Loss: 0.6845731735229492


Epoch 5:  72%|███████▏  | 42/58 [01:21<00:31,  1.94s/it]

Loss: 0.4995472729206085


Epoch 5:  74%|███████▍  | 43/58 [01:23<00:29,  1.94s/it]

Loss: 0.966660737991333


Epoch 5:  76%|███████▌  | 44/58 [01:25<00:27,  1.94s/it]

Loss: 0.129219189286232


Epoch 5:  78%|███████▊  | 45/58 [01:27<00:25,  1.94s/it]

Loss: 0.8188387155532837


Epoch 5:  79%|███████▉  | 46/58 [01:29<00:23,  1.94s/it]

Loss: 1.0512313842773438


Epoch 5:  81%|████████  | 47/58 [01:31<00:21,  1.94s/it]

Loss: 0.0870533138513565


Epoch 5:  83%|████████▎ | 48/58 [01:33<00:19,  1.94s/it]

Loss: 0.05599789321422577


Epoch 5:  84%|████████▍ | 49/58 [01:35<00:17,  1.94s/it]

Loss: 0.02301665022969246


Epoch 5:  86%|████████▌ | 50/58 [01:37<00:15,  1.94s/it]

Loss: 0.23355518281459808


Epoch 5:  88%|████████▊ | 51/58 [01:39<00:13,  1.94s/it]

Loss: 0.8082348108291626


Epoch 5:  90%|████████▉ | 52/58 [01:41<00:11,  1.94s/it]

Loss: 0.48487305641174316


Epoch 5:  91%|█████████▏| 53/58 [01:43<00:09,  1.94s/it]

Loss: 0.22458720207214355


Epoch 5:  93%|█████████▎| 54/58 [01:45<00:07,  1.94s/it]

Loss: 0.07079998403787613


Epoch 5:  95%|█████████▍| 55/58 [01:46<00:05,  1.94s/it]

Loss: 0.7683740854263306


Epoch 5:  97%|█████████▋| 56/58 [01:48<00:03,  1.94s/it]

Loss: 0.6778827905654907


Epoch 5: 100%|██████████| 58/58 [01:50<00:00,  1.91s/it]

Loss: 0.9715015888214111
Loss: 0.0031051719561219215





Epoch 5 Validation Accuracy: 0.5467980295566502, F1-macro: 0.4913943355119825


Epoch 6:   2%|▏         | 1/58 [00:01<01:50,  1.93s/it]

Loss: 1.293868899345398


Epoch 6:   3%|▎         | 2/58 [00:03<01:48,  1.94s/it]

Loss: 0.2531464099884033


Epoch 6:   5%|▌         | 3/58 [00:05<01:46,  1.94s/it]

Loss: 0.5690792202949524


Epoch 6:   7%|▋         | 4/58 [00:07<01:45,  1.95s/it]

Loss: 0.597630500793457


Epoch 6:   9%|▊         | 5/58 [00:09<01:43,  1.95s/it]

Loss: 0.806471049785614


Epoch 6:  10%|█         | 6/58 [00:11<01:41,  1.94s/it]

Loss: 0.9378979206085205


Epoch 6:  12%|█▏        | 7/58 [00:13<01:39,  1.94s/it]

Loss: 0.3853902816772461


Epoch 6:  14%|█▍        | 8/58 [00:15<01:37,  1.94s/it]

Loss: 0.28014662861824036


Epoch 6:  16%|█▌        | 9/58 [00:17<01:35,  1.94s/it]

Loss: 0.35351383686065674


Epoch 6:  17%|█▋        | 10/58 [00:19<01:33,  1.94s/it]

Loss: 0.31174227595329285


Epoch 6:  19%|█▉        | 11/58 [00:21<01:31,  1.94s/it]

Loss: 0.7488536834716797


Epoch 6:  21%|██        | 12/58 [00:23<01:29,  1.94s/it]

Loss: 0.19506362080574036


Epoch 6:  22%|██▏       | 13/58 [00:25<01:27,  1.94s/it]

Loss: 0.2158387005329132


Epoch 6:  24%|██▍       | 14/58 [00:27<01:25,  1.94s/it]

Loss: 0.09495532512664795


Epoch 6:  26%|██▌       | 15/58 [00:29<01:23,  1.94s/it]

Loss: 0.4450477361679077


Epoch 6:  28%|██▊       | 16/58 [00:31<01:21,  1.94s/it]

Loss: 0.7356858253479004


Epoch 6:  29%|██▉       | 17/58 [00:33<01:19,  1.94s/it]

Loss: 0.22523760795593262


Epoch 6:  31%|███       | 18/58 [00:34<01:17,  1.94s/it]

Loss: 0.6156546473503113


Epoch 6:  33%|███▎      | 19/58 [00:36<01:15,  1.94s/it]

Loss: 0.19828125834465027


Epoch 6:  34%|███▍      | 20/58 [00:38<01:13,  1.94s/it]

Loss: 0.2682083249092102


Epoch 6:  36%|███▌      | 21/58 [00:40<01:11,  1.94s/it]

Loss: 0.3081657290458679


Epoch 6:  38%|███▊      | 22/58 [00:42<01:10,  1.94s/it]

Loss: 0.4042510390281677


Epoch 6:  40%|███▉      | 23/58 [00:44<01:08,  1.94s/it]

Loss: 0.3873395025730133


Epoch 6:  41%|████▏     | 24/58 [00:46<01:06,  1.95s/it]

Loss: 0.19132472574710846


Epoch 6:  43%|████▎     | 25/58 [00:48<01:04,  1.94s/it]

Loss: 0.5916454195976257


Epoch 6:  45%|████▍     | 26/58 [00:50<01:02,  1.94s/it]

Loss: 0.05381502956151962


Epoch 6:  47%|████▋     | 27/58 [00:52<01:00,  1.94s/it]

Loss: 0.20403984189033508


Epoch 6:  48%|████▊     | 28/58 [00:54<00:58,  1.94s/it]

Loss: 0.3655116558074951


Epoch 6:  50%|█████     | 29/58 [00:56<00:56,  1.95s/it]

Loss: 0.05222836881875992


Epoch 6:  52%|█████▏    | 30/58 [00:58<00:54,  1.95s/it]

Loss: 0.2156141698360443


Epoch 6:  53%|█████▎    | 31/58 [01:00<00:52,  1.95s/it]

Loss: 0.37159401178359985


Epoch 6:  55%|█████▌    | 32/58 [01:02<00:50,  1.95s/it]

Loss: 0.489362895488739


Epoch 6:  57%|█████▋    | 33/58 [01:04<00:48,  1.95s/it]

Loss: 0.3258713185787201


Epoch 6:  59%|█████▊    | 34/58 [01:06<00:46,  1.95s/it]

Loss: 0.29758840799331665


Epoch 6:  60%|██████    | 35/58 [01:08<00:44,  1.95s/it]

Loss: 0.38330501317977905


Epoch 6:  62%|██████▏   | 36/58 [01:10<00:42,  1.95s/it]

Loss: 0.15644867718219757


Epoch 6:  64%|██████▍   | 37/58 [01:11<00:40,  1.95s/it]

Loss: 0.23276899755001068


Epoch 6:  66%|██████▌   | 38/58 [01:13<00:38,  1.95s/it]

Loss: 0.614116370677948


Epoch 6:  67%|██████▋   | 39/58 [01:15<00:36,  1.95s/it]

Loss: 0.19981563091278076


Epoch 6:  69%|██████▉   | 40/58 [01:17<00:35,  1.95s/it]

Loss: 0.39128661155700684


Epoch 6:  71%|███████   | 41/58 [01:19<00:33,  1.95s/it]

Loss: 0.33097144961357117


Epoch 6:  72%|███████▏  | 42/58 [01:21<00:31,  1.95s/it]

Loss: 0.05401770770549774


Epoch 6:  74%|███████▍  | 43/58 [01:23<00:29,  1.94s/it]

Loss: 0.2607414722442627


Epoch 6:  76%|███████▌  | 44/58 [01:25<00:27,  1.95s/it]

Loss: 0.1335204839706421


Epoch 6:  78%|███████▊  | 45/58 [01:27<00:25,  1.95s/it]

Loss: 0.5245375633239746


Epoch 6:  79%|███████▉  | 46/58 [01:29<00:23,  1.94s/it]

Loss: 0.09713058918714523


Epoch 6:  81%|████████  | 47/58 [01:31<00:21,  1.94s/it]

Loss: 0.1712024211883545


Epoch 6:  83%|████████▎ | 48/58 [01:33<00:19,  1.94s/it]

Loss: 0.44744691252708435


Epoch 6:  84%|████████▍ | 49/58 [01:35<00:17,  1.94s/it]

Loss: 0.31191378831863403


Epoch 6:  86%|████████▌ | 50/58 [01:37<00:15,  1.94s/it]

Loss: 0.18846052885055542


Epoch 6:  88%|████████▊ | 51/58 [01:39<00:13,  1.94s/it]

Loss: 0.307758629322052


Epoch 6:  90%|████████▉ | 52/58 [01:41<00:11,  1.94s/it]

Loss: 0.26768770813941956


Epoch 6:  91%|█████████▏| 53/58 [01:43<00:09,  1.94s/it]

Loss: 0.3594382703304291


Epoch 6:  93%|█████████▎| 54/58 [01:45<00:07,  1.94s/it]

Loss: 0.22869999706745148


Epoch 6:  95%|█████████▍| 55/58 [01:46<00:05,  1.94s/it]

Loss: 0.19961294531822205


Epoch 6:  97%|█████████▋| 56/58 [01:48<00:03,  1.94s/it]

Loss: 0.1522245705127716


Epoch 6: 100%|██████████| 58/58 [01:50<00:00,  1.91s/it]

Loss: 0.20891936123371124
Loss: 0.00012139468162786216





Epoch 6 Validation Accuracy: 0.8817733990147784, F1-macro: 0.5396825396825397


Epoch 7:   2%|▏         | 1/58 [00:01<01:49,  1.93s/it]

Loss: 0.021382944658398628


Epoch 7:   3%|▎         | 2/58 [00:03<01:48,  1.94s/it]

Loss: 0.37832462787628174


Epoch 7:   5%|▌         | 3/58 [00:05<01:46,  1.94s/it]

Loss: 0.2127527892589569


Epoch 7:   7%|▋         | 4/58 [00:07<01:44,  1.94s/it]

Loss: 0.6507871150970459


Epoch 7:   9%|▊         | 5/58 [00:09<01:42,  1.94s/it]

Loss: 0.11602680385112762


Epoch 7:  10%|█         | 6/58 [00:11<01:41,  1.94s/it]

Loss: 1.0016448497772217


Epoch 7:  12%|█▏        | 7/58 [00:13<01:39,  1.94s/it]

Loss: 0.39587104320526123


Epoch 7:  14%|█▍        | 8/58 [00:15<01:37,  1.95s/it]

Loss: 0.117952860891819


Epoch 7:  16%|█▌        | 9/58 [00:17<01:35,  1.95s/it]

Loss: 0.5638911128044128


Epoch 7:  17%|█▋        | 10/58 [00:19<01:33,  1.95s/it]

Loss: 0.33678561449050903


Epoch 7:  19%|█▉        | 11/58 [00:21<01:31,  1.95s/it]

Loss: 0.6685609817504883


Epoch 7:  21%|██        | 12/58 [00:23<01:29,  1.95s/it]

Loss: 0.2855525612831116


Epoch 7:  22%|██▏       | 13/58 [00:25<01:27,  1.95s/it]

Loss: 0.9035701155662537


Epoch 7:  24%|██▍       | 14/58 [00:27<01:25,  1.95s/it]

Loss: 0.44417670369148254


Epoch 7:  26%|██▌       | 15/58 [00:29<01:23,  1.95s/it]

Loss: 0.28588590025901794


Epoch 7:  28%|██▊       | 16/58 [00:31<01:21,  1.95s/it]

Loss: 0.13334347307682037


Epoch 7:  29%|██▉       | 17/58 [00:33<01:19,  1.94s/it]

Loss: 0.15625566244125366


Epoch 7:  31%|███       | 18/58 [00:35<01:17,  1.94s/it]

Loss: 0.2734034061431885


Epoch 7:  33%|███▎      | 19/58 [00:36<01:15,  1.94s/it]

Loss: 0.3625238835811615


Epoch 7:  34%|███▍      | 20/58 [00:38<01:13,  1.94s/it]

Loss: 0.5020101070404053


Epoch 7:  36%|███▌      | 21/58 [00:40<01:11,  1.94s/it]

Loss: 0.0810399204492569


Epoch 7:  38%|███▊      | 22/58 [00:42<01:09,  1.94s/it]

Loss: 0.2929864525794983


Epoch 7:  40%|███▉      | 23/58 [00:44<01:08,  1.94s/it]

Loss: 0.11466898024082184


Epoch 7:  41%|████▏     | 24/58 [00:46<01:06,  1.94s/it]

Loss: 0.31935596466064453


Epoch 7:  43%|████▎     | 25/58 [00:48<01:04,  1.94s/it]

Loss: 0.0384037047624588


Epoch 7:  45%|████▍     | 26/58 [00:50<01:02,  1.94s/it]

Loss: 0.1022777408361435


Epoch 7:  47%|████▋     | 27/58 [00:52<01:00,  1.94s/it]

Loss: 0.40839385986328125


Epoch 7:  48%|████▊     | 28/58 [00:54<00:58,  1.94s/it]

Loss: 0.1579839289188385


Epoch 7:  50%|█████     | 29/58 [00:56<00:56,  1.94s/it]

Loss: 0.345877081155777


Epoch 7:  52%|█████▏    | 30/58 [00:58<00:54,  1.94s/it]

Loss: 0.15556645393371582


Epoch 7:  53%|█████▎    | 31/58 [01:00<00:52,  1.94s/it]

Loss: 0.3234781324863434


Epoch 7:  55%|█████▌    | 32/58 [01:02<00:50,  1.94s/it]

Loss: 0.07744784653186798


Epoch 7:  57%|█████▋    | 33/58 [01:04<00:48,  1.94s/it]

Loss: 0.32354897260665894


Epoch 7:  59%|█████▊    | 34/58 [01:06<00:46,  1.94s/it]

Loss: 0.32054951786994934


Epoch 7:  60%|██████    | 35/58 [01:08<00:44,  1.95s/it]

Loss: 0.260833203792572


Epoch 7:  62%|██████▏   | 36/58 [01:10<00:42,  1.95s/it]

Loss: 0.3020411729812622


Epoch 7:  64%|██████▍   | 37/58 [01:11<00:40,  1.95s/it]

Loss: 0.21380990743637085


Epoch 7:  66%|██████▌   | 38/58 [01:13<00:38,  1.95s/it]

Loss: 0.2543909549713135


Epoch 7:  67%|██████▋   | 39/58 [01:15<00:36,  1.94s/it]

Loss: 0.045531682670116425


Epoch 7:  69%|██████▉   | 40/58 [01:17<00:34,  1.94s/it]

Loss: 0.6364452838897705


Epoch 7:  71%|███████   | 41/58 [01:19<00:33,  1.94s/it]

Loss: 0.28151625394821167


Epoch 7:  72%|███████▏  | 42/58 [01:21<00:31,  1.95s/it]

Loss: 0.5420610904693604


Epoch 7:  74%|███████▍  | 43/58 [01:23<00:29,  1.94s/it]

Loss: 0.7324241995811462


Epoch 7:  76%|███████▌  | 44/58 [01:25<00:27,  1.94s/it]

Loss: 0.4007521867752075


Epoch 7:  78%|███████▊  | 45/58 [01:27<00:25,  1.94s/it]

Loss: 0.10522225499153137


Epoch 7:  79%|███████▉  | 46/58 [01:29<00:23,  1.94s/it]

Loss: 0.3452218770980835


Epoch 7:  81%|████████  | 47/58 [01:31<00:21,  1.94s/it]

Loss: 0.5932996273040771


Epoch 7:  83%|████████▎ | 48/58 [01:33<00:19,  1.94s/it]

Loss: 0.896527886390686


Epoch 7:  84%|████████▍ | 49/58 [01:35<00:17,  1.94s/it]

Loss: 0.942075252532959


Epoch 7:  86%|████████▌ | 50/58 [01:37<00:15,  1.94s/it]

Loss: 0.5237204432487488


Epoch 7:  88%|████████▊ | 51/58 [01:39<00:13,  1.94s/it]

Loss: 0.7793439626693726


Epoch 7:  90%|████████▉ | 52/58 [01:41<00:11,  1.94s/it]

Loss: 0.5073126554489136


Epoch 7:  91%|█████████▏| 53/58 [01:43<00:09,  1.94s/it]

Loss: 0.18412475287914276


Epoch 7:  93%|█████████▎| 54/58 [01:44<00:07,  1.94s/it]

Loss: 0.46703845262527466


Epoch 7:  95%|█████████▍| 55/58 [01:46<00:05,  1.94s/it]

Loss: 0.548534095287323


Epoch 7:  97%|█████████▋| 56/58 [01:48<00:03,  1.94s/it]

Loss: 0.18926510214805603


Epoch 7: 100%|██████████| 58/58 [01:50<00:00,  1.91s/it]

Loss: 0.06566423177719116
Loss: 7.544970139861107e-05





Epoch 7 Validation Accuracy: 0.8866995073891626, F1-macro: 0.6880053458068828


In [None]:
# Evaluate on the test set
test_metrics = evaluate(model4, test_dataloader, device)
print(f"Test Accuracy: {test_metrics['accuracy']}, F1-macro: {test_metrics['f1_macro']}")

Test Accuracy: 0.8872549019607843, F1-macro: 0.6739628934750885


In [None]:
current_type = 'emotional reasoning'

results_df.loc[
    results_df["distortion_type"] == current_type,
    ["test_accuracy", "f1_macro"]
] = [test_metrics['accuracy'], test_metrics['f1_macro']]

# 5. Fortune Telling

In [None]:
# Add labels
data1_1_labels = list(data1['fortune telling'][data1_1.index])
data2_1_labels = list(data2['fortune telling'][data2_1.index])
data3_1_labels = list(data3['fortune telling'][data3_1.index])

# Merging Data
data_encoded = data1_1_encoded + data2_1_encoded + data3_1_encoded
data_labels = data1_1_labels + data2_1_labels + data3_1_labels

dataset_with_labels = CustomDatasetWithLabels(data_encoded, data_labels)

# Define proportions for splitting
train_size = int(0.8 * len(dataset_with_labels))
val_size = int(0.1 * len(dataset_with_labels))
test_size = len(dataset_with_labels) - train_size - val_size

# Split the dataset
train_dataset, val_dataset, test_dataset = random_split(dataset_with_labels, [train_size, val_size, test_size])

# Create DataLoaders for each set
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=64)
test_dataloader = DataLoader(test_dataset, batch_size=64)

In [None]:
dataset_encoded = dataset_with_labels.data
dataset_labels = dataset_with_labels.labels

# Convert list of BatchEncoding objects to a 2D numpy array for SMOTE features
feature_vectors = []
max_len = 512 # Assuming this from tokenize_and_pad function
for encoded_dict in dataset_encoded:
    input_ids_np = encoded_dict['input_ids'].squeeze().cpu().numpy()
    attention_mask_np = encoded_dict['attention_mask'].squeeze().cpu().numpy()
    # Concatenate input_ids and attention_mask to create a single feature vector per sample
    feature_vectors.append(np.concatenate([input_ids_np, attention_mask_np]))

# 'temp1' becomes the feature matrix, 'temp2' becomes the labels vector
temp1 = np.array(feature_vectors)
temp2 = np.array(dataset_labels)

# Apply SMOTE
temp1_resampled, temp2_resampled = smote.fit_resample(temp1, temp2)

# Reconstruct data_encoded and data_labels from the resampled data
reconstructed_data_encoded = []
for resampled_features in temp1_resampled:
    # Split the combined feature vector back into input_ids and attention_mask parts
    reconstructed_input_ids = resampled_features[:max_len]
    reconstructed_attention_mask = resampled_features[max_len:]

    # Convert back to torch tensors and a dictionary format compatible with CustomDatasetWithLabels
    input_ids_tensor = torch.tensor(reconstructed_input_ids, dtype=torch.long).unsqueeze(0)
    attention_mask_tensor = torch.tensor(reconstructed_attention_mask, dtype=torch.long).unsqueeze(0)

    reconstructed_data_encoded.append({
        'input_ids': input_ids_tensor,
        'attention_mask': attention_mask_tensor
    })

# Update data_encoded and data_labels with the resampled data
dataset_encoded = reconstructed_data_encoded
dataset_labels = temp2_resampled.tolist()

train_dataset = CustomDatasetWithLabels(dataset_encoded, dataset_labels)

In [None]:
# Create DataLoaders for each set
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=64)
test_dataloader = DataLoader(test_dataset, batch_size=64)

input_dim = bert_config.hidden_size
num_labels = len(set(data_labels))
label_emb = torch.randn(num_labels, input_dim)

In [None]:
# Instantiate the InnerProductClassifier model
model5 = InnerProductClassifier(input_dim, label_emb).to(device)

# Define loss function
criterion = nn.CrossEntropyLoss()

# Define optimizer
optimizer = torch.optim.AdamW(model5.parameters(), lr=2e-4)

In [None]:
EPOCHS = 7

# Move the BERT model to the device
bert_model.to(device)

for epoch in range(EPOCHS):
    model5.train()
    for batch in tqdm(train_dataloader, desc=f"Epoch {epoch+1}"):

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        y = batch["y"].to(device)

        # Get embeddings from the BERT model
        with torch.no_grad():
            outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask)
            embeddings = outputs.last_hidden_state[:, 0, :] # Get the [CLS] token embedding

        # Pass embeddings to the InnerProductClassifier
        logits = model5(embeddings)
        loss = criterion(logits, y)

        # Backpropagate and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Loss: {loss.item()}")

    # Evaluate on the validation set after each epoch
    val_metrics = evaluate(model5, val_dataloader, device)
    print(f"Epoch {epoch+1} Validation Accuracy: {val_metrics['accuracy']}, F1-macro: {val_metrics['f1_macro']}")

Epoch 1:   2%|▏         | 1/55 [00:01<01:43,  1.92s/it]

Loss: 9.603532791137695


Epoch 1:   4%|▎         | 2/55 [00:03<01:42,  1.93s/it]

Loss: 2.1768786907196045


Epoch 1:   5%|▌         | 3/55 [00:05<01:40,  1.94s/it]

Loss: 0.8796324729919434


Epoch 1:   7%|▋         | 4/55 [00:07<01:38,  1.94s/it]

Loss: 6.353041172027588


Epoch 1:   9%|▉         | 5/55 [00:09<01:37,  1.94s/it]

Loss: 4.87765645980835


Epoch 1:  11%|█         | 6/55 [00:11<01:35,  1.94s/it]

Loss: 5.893891334533691


Epoch 1:  13%|█▎        | 7/55 [00:13<01:33,  1.94s/it]

Loss: 5.053342819213867


Epoch 1:  15%|█▍        | 8/55 [00:15<01:31,  1.94s/it]

Loss: 5.242246627807617


Epoch 1:  16%|█▋        | 9/55 [00:17<01:29,  1.94s/it]

Loss: 0.623563826084137


Epoch 1:  18%|█▊        | 10/55 [00:19<01:27,  1.94s/it]

Loss: 2.4179940223693848


Epoch 1:  20%|██        | 11/55 [00:21<01:25,  1.94s/it]

Loss: 1.3108011484146118


Epoch 1:  22%|██▏       | 12/55 [00:23<01:23,  1.94s/it]

Loss: 1.115107536315918


Epoch 1:  24%|██▎       | 13/55 [00:25<01:21,  1.95s/it]

Loss: 3.0573782920837402


Epoch 1:  25%|██▌       | 14/55 [00:27<01:19,  1.94s/it]

Loss: 1.8877918720245361


Epoch 1:  27%|██▋       | 15/55 [00:29<01:17,  1.94s/it]

Loss: 0.5093073844909668


Epoch 1:  29%|██▉       | 16/55 [00:31<01:15,  1.94s/it]

Loss: 1.19645094871521


Epoch 1:  31%|███       | 17/55 [00:33<01:13,  1.94s/it]

Loss: 2.3355283737182617


Epoch 1:  33%|███▎      | 18/55 [00:34<01:11,  1.94s/it]

Loss: 1.4099924564361572


Epoch 1:  35%|███▍      | 19/55 [00:36<01:10,  1.94s/it]

Loss: 2.0671255588531494


Epoch 1:  36%|███▋      | 20/55 [00:38<01:08,  1.95s/it]

Loss: 1.5333082675933838


Epoch 1:  38%|███▊      | 21/55 [00:40<01:06,  1.95s/it]

Loss: 1.0408979654312134


Epoch 1:  40%|████      | 22/55 [00:42<01:04,  1.95s/it]

Loss: 0.9689686298370361


Epoch 1:  42%|████▏     | 23/55 [00:44<01:02,  1.94s/it]

Loss: 0.47339776158332825


Epoch 1:  44%|████▎     | 24/55 [00:46<01:00,  1.94s/it]

Loss: 1.9739335775375366


Epoch 1:  45%|████▌     | 25/55 [00:48<00:58,  1.94s/it]

Loss: 0.42217880487442017


Epoch 1:  47%|████▋     | 26/55 [00:50<00:56,  1.94s/it]

Loss: 0.676784098148346


Epoch 1:  49%|████▉     | 27/55 [00:52<00:54,  1.94s/it]

Loss: 0.5275653004646301


Epoch 1:  51%|█████     | 28/55 [00:54<00:52,  1.94s/it]

Loss: 1.1124149560928345


Epoch 1:  53%|█████▎    | 29/55 [00:56<00:50,  1.94s/it]

Loss: 1.1070936918258667


Epoch 1:  55%|█████▍    | 30/55 [00:58<00:48,  1.94s/it]

Loss: 0.4571172595024109


Epoch 1:  56%|█████▋    | 31/55 [01:00<00:46,  1.94s/it]

Loss: 0.29161661863327026


Epoch 1:  58%|█████▊    | 32/55 [01:02<00:44,  1.94s/it]

Loss: 0.4685628414154053


Epoch 1:  60%|██████    | 33/55 [01:04<00:42,  1.94s/it]

Loss: 0.899436891078949


Epoch 1:  62%|██████▏   | 34/55 [01:06<00:40,  1.94s/it]

Loss: 1.3837206363677979


Epoch 1:  64%|██████▎   | 35/55 [01:08<00:38,  1.94s/it]

Loss: 1.2868435382843018


Epoch 1:  65%|██████▌   | 36/55 [01:09<00:36,  1.94s/it]

Loss: 1.725940465927124


Epoch 1:  67%|██████▋   | 37/55 [01:11<00:35,  1.94s/it]

Loss: 0.30563533306121826


Epoch 1:  69%|██████▉   | 38/55 [01:13<00:33,  1.94s/it]

Loss: 0.7021969556808472


Epoch 1:  71%|███████   | 39/55 [01:15<00:31,  1.94s/it]

Loss: 1.0124943256378174


Epoch 1:  73%|███████▎  | 40/55 [01:17<00:29,  1.94s/it]

Loss: 1.0645979642868042


Epoch 1:  75%|███████▍  | 41/55 [01:19<00:27,  1.95s/it]

Loss: 0.9658134579658508


Epoch 1:  76%|███████▋  | 42/55 [01:21<00:25,  1.95s/it]

Loss: 1.1391466856002808


Epoch 1:  78%|███████▊  | 43/55 [01:23<00:23,  1.94s/it]

Loss: 0.8669151067733765


Epoch 1:  80%|████████  | 44/55 [01:25<00:21,  1.94s/it]

Loss: 1.1162891387939453


Epoch 1:  82%|████████▏ | 45/55 [01:27<00:19,  1.94s/it]

Loss: 0.7871640920639038


Epoch 1:  84%|████████▎ | 46/55 [01:29<00:17,  1.95s/it]

Loss: 0.6183763146400452


Epoch 1:  85%|████████▌ | 47/55 [01:31<00:15,  1.95s/it]

Loss: 1.4537122249603271


Epoch 1:  87%|████████▋ | 48/55 [01:33<00:13,  1.95s/it]

Loss: 0.5537384748458862


Epoch 1:  89%|████████▉ | 49/55 [01:35<00:11,  1.95s/it]

Loss: 2.0757696628570557


Epoch 1:  91%|█████████ | 50/55 [01:37<00:09,  1.95s/it]

Loss: 0.5712948441505432


Epoch 1:  93%|█████████▎| 51/55 [01:39<00:07,  1.95s/it]

Loss: 1.6499346494674683


Epoch 1:  95%|█████████▍| 52/55 [01:41<00:05,  1.95s/it]

Loss: 0.760746955871582


Epoch 1:  96%|█████████▋| 53/55 [01:43<00:03,  1.95s/it]

Loss: 1.2012101411819458


Epoch 1:  98%|█████████▊| 54/55 [01:44<00:01,  1.95s/it]

Loss: 1.1565234661102295


Epoch 1: 100%|██████████| 55/55 [01:45<00:00,  1.92s/it]

Loss: 0.8580764532089233





Epoch 1 Validation Accuracy: 0.8177339901477833, F1-macro: 0.44986449864498645


Epoch 2:   2%|▏         | 1/55 [00:01<01:44,  1.93s/it]

Loss: 0.7712557911872864


Epoch 2:   4%|▎         | 2/55 [00:03<01:43,  1.95s/it]

Loss: 1.1322906017303467


Epoch 2:   5%|▌         | 3/55 [00:05<01:41,  1.95s/it]

Loss: 1.0151495933532715


Epoch 2:   7%|▋         | 4/55 [00:07<01:39,  1.95s/it]

Loss: 0.5862223505973816


Epoch 2:   9%|▉         | 5/55 [00:09<01:37,  1.95s/it]

Loss: 0.3882697820663452


Epoch 2:  11%|█         | 6/55 [00:11<01:35,  1.94s/it]

Loss: 0.23641400039196014


Epoch 2:  13%|█▎        | 7/55 [00:13<01:33,  1.94s/it]

Loss: 0.656880259513855


Epoch 2:  15%|█▍        | 8/55 [00:15<01:31,  1.94s/it]

Loss: 0.5094080567359924


Epoch 2:  16%|█▋        | 9/55 [00:17<01:29,  1.94s/it]

Loss: 0.5830252170562744


Epoch 2:  18%|█▊        | 10/55 [00:19<01:27,  1.94s/it]

Loss: 0.6074978113174438


Epoch 2:  20%|██        | 11/55 [00:21<01:25,  1.94s/it]

Loss: 0.39650237560272217


Epoch 2:  22%|██▏       | 12/55 [00:23<01:23,  1.94s/it]

Loss: 0.6533844470977783


Epoch 2:  24%|██▎       | 13/55 [00:25<01:21,  1.94s/it]

Loss: 0.3581882119178772


Epoch 2:  25%|██▌       | 14/55 [00:27<01:19,  1.94s/it]

Loss: 0.5290446281433105


Epoch 2:  27%|██▋       | 15/55 [00:29<01:17,  1.94s/it]

Loss: 0.9091468453407288


Epoch 2:  29%|██▉       | 16/55 [00:31<01:15,  1.94s/it]

Loss: 1.2796568870544434


Epoch 2:  31%|███       | 17/55 [00:33<01:13,  1.94s/it]

Loss: 0.5096369981765747


Epoch 2:  33%|███▎      | 18/55 [00:34<01:11,  1.94s/it]

Loss: 0.2556043267250061


Epoch 2:  35%|███▍      | 19/55 [00:36<01:09,  1.94s/it]

Loss: 0.6359245181083679


Epoch 2:  36%|███▋      | 20/55 [00:38<01:08,  1.94s/it]

Loss: 0.4867285192012787


Epoch 2:  38%|███▊      | 21/55 [00:40<01:06,  1.94s/it]

Loss: 0.43991291522979736


Epoch 2:  40%|████      | 22/55 [00:42<01:04,  1.94s/it]

Loss: 0.6751141548156738


Epoch 2:  42%|████▏     | 23/55 [00:44<01:02,  1.94s/it]

Loss: 0.6472328901290894


Epoch 2:  44%|████▎     | 24/55 [00:46<01:00,  1.95s/it]

Loss: 0.31129103899002075


Epoch 2:  45%|████▌     | 25/55 [00:48<00:58,  1.95s/it]

Loss: 0.5768184661865234


Epoch 2:  47%|████▋     | 26/55 [00:50<00:56,  1.95s/it]

Loss: 0.3630336821079254


Epoch 2:  49%|████▉     | 27/55 [00:52<00:54,  1.95s/it]

Loss: 0.6476148366928101


Epoch 2:  51%|█████     | 28/55 [00:54<00:52,  1.95s/it]

Loss: 0.22292892634868622


Epoch 2:  53%|█████▎    | 29/55 [00:56<00:50,  1.95s/it]

Loss: 1.1134917736053467


Epoch 2:  55%|█████▍    | 30/55 [00:58<00:48,  1.95s/it]

Loss: 0.8550925254821777


Epoch 2:  56%|█████▋    | 31/55 [01:00<00:46,  1.95s/it]

Loss: 0.38388174772262573


Epoch 2:  58%|█████▊    | 32/55 [01:02<00:44,  1.95s/it]

Loss: 0.46896928548812866


Epoch 2:  60%|██████    | 33/55 [01:04<00:42,  1.95s/it]

Loss: 0.6653263568878174


Epoch 2:  62%|██████▏   | 34/55 [01:06<00:40,  1.95s/it]

Loss: 0.45655444264411926


Epoch 2:  64%|██████▎   | 35/55 [01:08<00:38,  1.94s/it]

Loss: 0.3711967170238495


Epoch 2:  65%|██████▌   | 36/55 [01:10<00:36,  1.94s/it]

Loss: 0.353169322013855


Epoch 2:  67%|██████▋   | 37/55 [01:11<00:34,  1.94s/it]

Loss: 0.9153546094894409


Epoch 2:  69%|██████▉   | 38/55 [01:13<00:33,  1.94s/it]

Loss: 0.1786460280418396


Epoch 2:  71%|███████   | 39/55 [01:15<00:31,  1.94s/it]

Loss: 0.6435128450393677


Epoch 2:  73%|███████▎  | 40/55 [01:17<00:29,  1.94s/it]

Loss: 0.6388583779335022


Epoch 2:  75%|███████▍  | 41/55 [01:19<00:27,  1.94s/it]

Loss: 0.8575252890586853


Epoch 2:  76%|███████▋  | 42/55 [01:21<00:25,  1.94s/it]

Loss: 0.8358783721923828


Epoch 2:  78%|███████▊  | 43/55 [01:23<00:23,  1.94s/it]

Loss: 0.499861478805542


Epoch 2:  80%|████████  | 44/55 [01:25<00:21,  1.94s/it]

Loss: 0.8400373458862305


Epoch 2:  82%|████████▏ | 45/55 [01:27<00:19,  1.94s/it]

Loss: 0.61842280626297


Epoch 2:  84%|████████▎ | 46/55 [01:29<00:17,  1.94s/it]

Loss: 0.16059890389442444


Epoch 2:  85%|████████▌ | 47/55 [01:31<00:15,  1.94s/it]

Loss: 0.3855377435684204


Epoch 2:  87%|████████▋ | 48/55 [01:33<00:13,  1.95s/it]

Loss: 0.8694460391998291


Epoch 2:  89%|████████▉ | 49/55 [01:35<00:11,  1.94s/it]

Loss: 0.3363775610923767


Epoch 2:  91%|█████████ | 50/55 [01:37<00:09,  1.94s/it]

Loss: 1.1650567054748535


Epoch 2:  93%|█████████▎| 51/55 [01:39<00:07,  1.94s/it]

Loss: 1.081785798072815


Epoch 2:  95%|█████████▍| 52/55 [01:41<00:05,  1.94s/it]

Loss: 0.30419814586639404


Epoch 2:  96%|█████████▋| 53/55 [01:43<00:03,  1.94s/it]

Loss: 0.4267053008079529


Epoch 2:  98%|█████████▊| 54/55 [01:45<00:01,  1.94s/it]

Loss: 0.38035136461257935


Epoch 2: 100%|██████████| 55/55 [01:45<00:00,  1.92s/it]

Loss: 0.3671230673789978





Epoch 2 Validation Accuracy: 0.7684729064039408, F1-macro: 0.6306662021445438


Epoch 3:   2%|▏         | 1/55 [00:01<01:45,  1.95s/it]

Loss: 0.24376316368579865


Epoch 3:   4%|▎         | 2/55 [00:03<01:43,  1.95s/it]

Loss: 0.556287407875061


Epoch 3:   5%|▌         | 3/55 [00:05<01:41,  1.95s/it]

Loss: 0.9776394367218018


Epoch 3:   7%|▋         | 4/55 [00:07<01:39,  1.95s/it]

Loss: 1.2894597053527832


Epoch 3:   9%|▉         | 5/55 [00:09<01:37,  1.95s/it]

Loss: 0.3764629065990448


Epoch 3:  11%|█         | 6/55 [00:11<01:35,  1.95s/it]

Loss: 0.7358694076538086


Epoch 3:  13%|█▎        | 7/55 [00:13<01:33,  1.95s/it]

Loss: 0.5366089344024658


Epoch 3:  15%|█▍        | 8/55 [00:15<01:31,  1.95s/it]

Loss: 0.14732924103736877


Epoch 3:  16%|█▋        | 9/55 [00:17<01:29,  1.95s/it]

Loss: 0.9793306589126587


Epoch 3:  18%|█▊        | 10/55 [00:19<01:27,  1.95s/it]

Loss: 0.5542278289794922


Epoch 3:  20%|██        | 11/55 [00:21<01:25,  1.95s/it]

Loss: 0.3833335340023041


Epoch 3:  22%|██▏       | 12/55 [00:23<01:23,  1.95s/it]

Loss: 0.33557599782943726


Epoch 3:  24%|██▎       | 13/55 [00:25<01:21,  1.95s/it]

Loss: 0.28046393394470215


Epoch 3:  25%|██▌       | 14/55 [00:27<01:19,  1.94s/it]

Loss: 0.2484775185585022


Epoch 3:  27%|██▋       | 15/55 [00:29<01:17,  1.95s/it]

Loss: 0.6861164569854736


Epoch 3:  29%|██▉       | 16/55 [00:31<01:15,  1.95s/it]

Loss: 0.39956241846084595


Epoch 3:  31%|███       | 17/55 [00:33<01:13,  1.95s/it]

Loss: 0.7597461938858032


Epoch 3:  33%|███▎      | 18/55 [00:35<01:11,  1.95s/it]

Loss: 0.7080533504486084


Epoch 3:  35%|███▍      | 19/55 [00:36<01:10,  1.94s/it]

Loss: 0.2303454875946045


Epoch 3:  36%|███▋      | 20/55 [00:38<01:08,  1.94s/it]

Loss: 0.16545775532722473


Epoch 3:  38%|███▊      | 21/55 [00:40<01:06,  1.94s/it]

Loss: 0.3013116717338562


Epoch 3:  40%|████      | 22/55 [00:42<01:04,  1.94s/it]

Loss: 0.47938767075538635


Epoch 3:  42%|████▏     | 23/55 [00:44<01:02,  1.94s/it]

Loss: 0.3892223834991455


Epoch 3:  44%|████▎     | 24/55 [00:46<01:00,  1.94s/it]

Loss: 0.39426565170288086


Epoch 3:  45%|████▌     | 25/55 [00:48<00:58,  1.94s/it]

Loss: 0.3009188771247864


Epoch 3:  47%|████▋     | 26/55 [00:50<00:56,  1.94s/it]

Loss: 0.36839771270751953


Epoch 3:  49%|████▉     | 27/55 [00:52<00:54,  1.94s/it]

Loss: 0.4584793448448181


Epoch 3:  51%|█████     | 28/55 [00:54<00:52,  1.95s/it]

Loss: 0.10781873762607574


Epoch 3:  53%|█████▎    | 29/55 [00:56<00:50,  1.94s/it]

Loss: 0.5981972217559814


Epoch 3:  55%|█████▍    | 30/55 [00:58<00:48,  1.95s/it]

Loss: 0.4650483727455139


Epoch 3:  56%|█████▋    | 31/55 [01:00<00:46,  1.94s/it]

Loss: 0.18099844455718994


Epoch 3:  58%|█████▊    | 32/55 [01:02<00:44,  1.94s/it]

Loss: 0.22102612257003784


Epoch 3:  60%|██████    | 33/55 [01:04<00:42,  1.94s/it]

Loss: 0.16744087636470795


Epoch 3:  62%|██████▏   | 34/55 [01:06<00:40,  1.94s/it]

Loss: 0.41034334897994995


Epoch 3:  64%|██████▎   | 35/55 [01:08<00:38,  1.94s/it]

Loss: 0.27285581827163696


Epoch 3:  65%|██████▌   | 36/55 [01:10<00:36,  1.95s/it]

Loss: 0.19287700951099396


Epoch 3:  67%|██████▋   | 37/55 [01:11<00:35,  1.95s/it]

Loss: 0.2532603442668915


Epoch 3:  69%|██████▉   | 38/55 [01:13<00:33,  1.95s/it]

Loss: 0.43757563829421997


Epoch 3:  71%|███████   | 39/55 [01:15<00:31,  1.95s/it]

Loss: 0.31947335600852966


Epoch 3:  73%|███████▎  | 40/55 [01:17<00:29,  1.95s/it]

Loss: 0.3539310693740845


Epoch 3:  75%|███████▍  | 41/55 [01:19<00:27,  1.95s/it]

Loss: 0.3174993395805359


Epoch 3:  76%|███████▋  | 42/55 [01:21<00:25,  1.95s/it]

Loss: 0.5776681303977966


Epoch 3:  78%|███████▊  | 43/55 [01:23<00:23,  1.95s/it]

Loss: 0.18862637877464294


Epoch 3:  80%|████████  | 44/55 [01:25<00:21,  1.95s/it]

Loss: 0.4103034734725952


Epoch 3:  82%|████████▏ | 45/55 [01:27<00:19,  1.95s/it]

Loss: 0.5565869212150574


Epoch 3:  84%|████████▎ | 46/55 [01:29<00:17,  1.95s/it]

Loss: 0.5223522186279297


Epoch 3:  85%|████████▌ | 47/55 [01:31<00:15,  1.95s/it]

Loss: 0.4196915626525879


Epoch 3:  87%|████████▋ | 48/55 [01:33<00:13,  1.94s/it]

Loss: 0.4669065475463867


Epoch 3:  89%|████████▉ | 49/55 [01:35<00:11,  1.94s/it]

Loss: 0.3868548572063446


Epoch 3:  91%|█████████ | 50/55 [01:37<00:09,  1.94s/it]

Loss: 0.16297928988933563


Epoch 3:  93%|█████████▎| 51/55 [01:39<00:07,  1.94s/it]

Loss: 0.5627274513244629


Epoch 3:  95%|█████████▍| 52/55 [01:41<00:05,  1.94s/it]

Loss: 0.5675663352012634


Epoch 3:  96%|█████████▋| 53/55 [01:43<00:03,  1.94s/it]

Loss: 0.6185736656188965


Epoch 3:  98%|█████████▊| 54/55 [01:45<00:01,  1.94s/it]

Loss: 1.081851840019226


Epoch 3: 100%|██████████| 55/55 [01:45<00:00,  1.92s/it]

Loss: 0.2661549746990204





Epoch 3 Validation Accuracy: 0.7931034482758621, F1-macro: 0.6953259005145798


Epoch 4:   2%|▏         | 1/55 [00:01<01:44,  1.94s/it]

Loss: 0.3853308856487274


Epoch 4:   4%|▎         | 2/55 [00:03<01:42,  1.94s/it]

Loss: 0.33841434121131897


Epoch 4:   5%|▌         | 3/55 [00:05<01:40,  1.94s/it]

Loss: 0.3127311170101166


Epoch 4:   7%|▋         | 4/55 [00:07<01:39,  1.94s/it]

Loss: 0.7912967801094055


Epoch 4:   9%|▉         | 5/55 [00:09<01:37,  1.95s/it]

Loss: 0.3734217882156372


Epoch 4:  11%|█         | 6/55 [00:11<01:35,  1.95s/it]

Loss: 0.23429110646247864


Epoch 4:  13%|█▎        | 7/55 [00:13<01:33,  1.95s/it]

Loss: 0.4718720316886902


Epoch 4:  15%|█▍        | 8/55 [00:15<01:31,  1.95s/it]

Loss: 0.267255961894989


Epoch 4:  16%|█▋        | 9/55 [00:17<01:29,  1.95s/it]

Loss: 0.7115012407302856


Epoch 4:  18%|█▊        | 10/55 [00:19<01:27,  1.95s/it]

Loss: 0.08302337676286697


Epoch 4:  20%|██        | 11/55 [00:21<01:25,  1.94s/it]

Loss: 0.36258208751678467


Epoch 4:  22%|██▏       | 12/55 [00:23<01:23,  1.94s/it]

Loss: 0.27202194929122925


Epoch 4:  24%|██▎       | 13/55 [00:25<01:21,  1.95s/it]

Loss: 0.12199459224939346


Epoch 4:  25%|██▌       | 14/55 [00:27<01:19,  1.95s/it]

Loss: 0.3745853006839752


Epoch 4:  27%|██▋       | 15/55 [00:29<01:17,  1.95s/it]

Loss: 0.2511255741119385


Epoch 4:  29%|██▉       | 16/55 [00:31<01:15,  1.95s/it]

Loss: 0.1509677767753601


Epoch 4:  31%|███       | 17/55 [00:33<01:13,  1.95s/it]

Loss: 0.45133769512176514


Epoch 4:  33%|███▎      | 18/55 [00:35<01:12,  1.95s/it]

Loss: 0.6422135829925537


Epoch 4:  35%|███▍      | 19/55 [00:36<01:10,  1.95s/it]

Loss: 0.30881527066230774


Epoch 4:  36%|███▋      | 20/55 [00:38<01:08,  1.95s/it]

Loss: 0.5861104130744934


Epoch 4:  38%|███▊      | 21/55 [00:40<01:06,  1.95s/it]

Loss: 0.47634556889533997


Epoch 4:  40%|████      | 22/55 [00:42<01:04,  1.95s/it]

Loss: 0.8441054224967957


Epoch 4:  42%|████▏     | 23/55 [00:44<01:02,  1.95s/it]

Loss: 0.18143808841705322


Epoch 4:  44%|████▎     | 24/55 [00:46<01:00,  1.95s/it]

Loss: 0.7270483374595642


Epoch 4:  45%|████▌     | 25/55 [00:48<00:58,  1.95s/it]

Loss: 0.3781229257583618


Epoch 4:  47%|████▋     | 26/55 [00:50<00:56,  1.94s/it]

Loss: 0.6086850166320801


Epoch 4:  49%|████▉     | 27/55 [00:52<00:54,  1.94s/it]

Loss: 0.1265028715133667


Epoch 4:  51%|█████     | 28/55 [00:54<00:52,  1.94s/it]

Loss: 0.19865001738071442


Epoch 4:  53%|█████▎    | 29/55 [00:56<00:50,  1.95s/it]

Loss: 0.4834885001182556


Epoch 4:  55%|█████▍    | 30/55 [00:58<00:48,  1.95s/it]

Loss: 0.3341825604438782


Epoch 4:  56%|█████▋    | 31/55 [01:00<00:46,  1.94s/it]

Loss: 0.7475098371505737


Epoch 4:  58%|█████▊    | 32/55 [01:02<00:44,  1.94s/it]

Loss: 1.0903899669647217


Epoch 4:  60%|██████    | 33/55 [01:04<00:42,  1.94s/it]

Loss: 0.17623865604400635


Epoch 4:  62%|██████▏   | 34/55 [01:06<00:40,  1.94s/it]

Loss: 1.4448972940444946


Epoch 4:  64%|██████▎   | 35/55 [01:08<00:38,  1.94s/it]

Loss: 0.37871482968330383


Epoch 4:  65%|██████▌   | 36/55 [01:10<00:36,  1.94s/it]

Loss: 0.46119415760040283


Epoch 4:  67%|██████▋   | 37/55 [01:11<00:34,  1.94s/it]

Loss: 0.5402485132217407


Epoch 4:  69%|██████▉   | 38/55 [01:13<00:33,  1.94s/it]

Loss: 1.084728717803955


Epoch 4:  71%|███████   | 39/55 [01:15<00:31,  1.94s/it]

Loss: 0.4089829623699188


Epoch 4:  73%|███████▎  | 40/55 [01:17<00:29,  1.94s/it]

Loss: 1.200613021850586


Epoch 4:  75%|███████▍  | 41/55 [01:19<00:27,  1.94s/it]

Loss: 0.4632529020309448


Epoch 4:  76%|███████▋  | 42/55 [01:21<00:25,  1.94s/it]

Loss: 1.731222152709961


Epoch 4:  78%|███████▊  | 43/55 [01:23<00:23,  1.94s/it]

Loss: 0.3801916241645813


Epoch 4:  80%|████████  | 44/55 [01:25<00:21,  1.94s/it]

Loss: 0.2550056576728821


Epoch 4:  82%|████████▏ | 45/55 [01:27<00:19,  1.94s/it]

Loss: 0.6757807731628418


Epoch 4:  84%|████████▎ | 46/55 [01:29<00:17,  1.94s/it]

Loss: 0.4989874064922333


Epoch 4:  85%|████████▌ | 47/55 [01:31<00:15,  1.94s/it]

Loss: 0.47735846042633057


Epoch 4:  87%|████████▋ | 48/55 [01:33<00:13,  1.95s/it]

Loss: 0.4143703579902649


Epoch 4:  89%|████████▉ | 49/55 [01:35<00:11,  1.95s/it]

Loss: 0.27592918276786804


Epoch 4:  91%|█████████ | 50/55 [01:37<00:09,  1.95s/it]

Loss: 0.34348443150520325


Epoch 4:  93%|█████████▎| 51/55 [01:39<00:07,  1.95s/it]

Loss: 0.25922679901123047


Epoch 4:  95%|█████████▍| 52/55 [01:41<00:05,  1.95s/it]

Loss: 0.4810197651386261


Epoch 4:  96%|█████████▋| 53/55 [01:43<00:03,  1.95s/it]

Loss: 1.0049164295196533


Epoch 4:  98%|█████████▊| 54/55 [01:45<00:01,  1.95s/it]

Loss: 0.34553390741348267


Epoch 4: 100%|██████████| 55/55 [01:45<00:00,  1.92s/it]

Loss: 0.592072606086731





Epoch 4 Validation Accuracy: 0.5024630541871922, F1-macro: 0.48439582547466364


Epoch 5:   2%|▏         | 1/55 [00:01<01:45,  1.95s/it]

Loss: 1.0772392749786377


Epoch 5:   4%|▎         | 2/55 [00:03<01:43,  1.95s/it]

Loss: 0.17856884002685547


Epoch 5:   5%|▌         | 3/55 [00:05<01:41,  1.95s/it]

Loss: 0.3573142886161804


Epoch 5:   7%|▋         | 4/55 [00:07<01:39,  1.95s/it]

Loss: 0.4152928292751312


Epoch 5:   9%|▉         | 5/55 [00:09<01:37,  1.95s/it]

Loss: 1.1681832075119019


Epoch 5:  11%|█         | 6/55 [00:11<01:35,  1.95s/it]

Loss: 0.44398820400238037


Epoch 5:  13%|█▎        | 7/55 [00:13<01:33,  1.95s/it]

Loss: 0.5636944770812988


Epoch 5:  15%|█▍        | 8/55 [00:15<01:31,  1.95s/it]

Loss: 1.3789118528366089


Epoch 5:  16%|█▋        | 9/55 [00:17<01:29,  1.95s/it]

Loss: 0.25806015729904175


Epoch 5:  18%|█▊        | 10/55 [00:19<01:27,  1.95s/it]

Loss: 0.7367281913757324


Epoch 5:  20%|██        | 11/55 [00:21<01:25,  1.95s/it]

Loss: 0.3691375255584717


Epoch 5:  22%|██▏       | 12/55 [00:23<01:23,  1.95s/it]

Loss: 0.9703987836837769


Epoch 5:  24%|██▎       | 13/55 [00:25<01:21,  1.95s/it]

Loss: 0.23915664851665497


Epoch 5:  25%|██▌       | 14/55 [00:27<01:19,  1.95s/it]

Loss: 0.6593707799911499


Epoch 5:  27%|██▋       | 15/55 [00:29<01:17,  1.95s/it]

Loss: 0.4080987572669983


Epoch 5:  29%|██▉       | 16/55 [00:31<01:15,  1.95s/it]

Loss: 0.3769315481185913


Epoch 5:  31%|███       | 17/55 [00:33<01:13,  1.95s/it]

Loss: 0.3093952238559723


Epoch 5:  33%|███▎      | 18/55 [00:35<01:12,  1.95s/it]

Loss: 0.24185818433761597


Epoch 5:  35%|███▍      | 19/55 [00:37<01:10,  1.95s/it]

Loss: 0.565005362033844


Epoch 5:  36%|███▋      | 20/55 [00:38<01:08,  1.95s/it]

Loss: 0.40285593271255493


Epoch 5:  38%|███▊      | 21/55 [00:40<01:06,  1.94s/it]

Loss: 0.7449422478675842


Epoch 5:  40%|████      | 22/55 [00:42<01:04,  1.95s/it]

Loss: 0.6454346179962158


Epoch 5:  42%|████▏     | 23/55 [00:44<01:02,  1.94s/it]

Loss: 0.6596195101737976


Epoch 5:  44%|████▎     | 24/55 [00:46<01:00,  1.95s/it]

Loss: 0.5866026878356934


Epoch 5:  45%|████▌     | 25/55 [00:48<00:58,  1.95s/it]

Loss: 0.5057952404022217


Epoch 5:  47%|████▋     | 26/55 [00:50<00:56,  1.95s/it]

Loss: 0.20004898309707642


Epoch 5:  49%|████▉     | 27/55 [00:52<00:54,  1.95s/it]

Loss: 0.2302703857421875


Epoch 5:  51%|█████     | 28/55 [00:54<00:52,  1.95s/it]

Loss: 0.3463565409183502


Epoch 5:  53%|█████▎    | 29/55 [00:56<00:50,  1.95s/it]

Loss: 0.6409050226211548


Epoch 5:  55%|█████▍    | 30/55 [00:58<00:48,  1.95s/it]

Loss: 0.4681970179080963


Epoch 5:  56%|█████▋    | 31/55 [01:00<00:46,  1.95s/it]

Loss: 0.3897576928138733


Epoch 5:  58%|█████▊    | 32/55 [01:02<00:44,  1.95s/it]

Loss: 0.04118183255195618


Epoch 5:  60%|██████    | 33/55 [01:04<00:42,  1.95s/it]

Loss: 0.6101206541061401


Epoch 5:  62%|██████▏   | 34/55 [01:06<00:40,  1.95s/it]

Loss: 0.5060606002807617


Epoch 5:  64%|██████▎   | 35/55 [01:08<00:38,  1.95s/it]

Loss: 0.36226916313171387


Epoch 5:  65%|██████▌   | 36/55 [01:10<00:36,  1.95s/it]

Loss: 0.4656832218170166


Epoch 5:  67%|██████▋   | 37/55 [01:12<00:35,  1.95s/it]

Loss: 0.39377641677856445


Epoch 5:  69%|██████▉   | 38/55 [01:13<00:33,  1.95s/it]

Loss: 0.3883510231971741


Epoch 5:  71%|███████   | 39/55 [01:15<00:31,  1.94s/it]

Loss: 0.3000129461288452


Epoch 5:  73%|███████▎  | 40/55 [01:17<00:29,  1.94s/it]

Loss: 0.5250066518783569


Epoch 5:  75%|███████▍  | 41/55 [01:19<00:27,  1.94s/it]

Loss: 0.250896155834198


Epoch 5:  76%|███████▋  | 42/55 [01:21<00:25,  1.94s/it]

Loss: 1.1115527153015137


Epoch 5:  78%|███████▊  | 43/55 [01:23<00:23,  1.95s/it]

Loss: 0.2606285810470581


Epoch 5:  80%|████████  | 44/55 [01:25<00:21,  1.94s/it]

Loss: 1.3765654563903809


Epoch 5:  82%|████████▏ | 45/55 [01:27<00:19,  1.94s/it]

Loss: 1.0126609802246094


Epoch 5:  84%|████████▎ | 46/55 [01:29<00:17,  1.94s/it]

Loss: 0.46106570959091187


Epoch 5:  85%|████████▌ | 47/55 [01:31<00:15,  1.94s/it]

Loss: 0.4811401963233948


Epoch 5:  87%|████████▋ | 48/55 [01:33<00:13,  1.94s/it]

Loss: 0.5096823573112488


Epoch 5:  89%|████████▉ | 49/55 [01:35<00:11,  1.94s/it]

Loss: 2.028902530670166


Epoch 5:  91%|█████████ | 50/55 [01:37<00:09,  1.94s/it]

Loss: 0.5956838726997375


Epoch 5:  93%|█████████▎| 51/55 [01:39<00:07,  1.94s/it]

Loss: 1.0325822830200195


Epoch 5:  95%|█████████▍| 52/55 [01:41<00:05,  1.94s/it]

Loss: 0.542503297328949


Epoch 5:  96%|█████████▋| 53/55 [01:43<00:03,  1.94s/it]

Loss: 0.8519130945205688


Epoch 5:  98%|█████████▊| 54/55 [01:45<00:01,  1.94s/it]

Loss: 1.615035057067871


Epoch 5: 100%|██████████| 55/55 [01:45<00:00,  1.92s/it]

Loss: 0.03393686190247536





Epoch 5 Validation Accuracy: 0.8325123152709359, F1-macro: 0.6763878469617405


Epoch 6:   2%|▏         | 1/55 [00:01<01:44,  1.93s/it]

Loss: 0.3972313404083252


Epoch 6:   4%|▎         | 2/55 [00:03<01:42,  1.94s/it]

Loss: 0.8647222518920898


Epoch 6:   5%|▌         | 3/55 [00:05<01:40,  1.94s/it]

Loss: 0.7235463857650757


Epoch 6:   7%|▋         | 4/55 [00:07<01:38,  1.94s/it]

Loss: 0.3609570860862732


Epoch 6:   9%|▉         | 5/55 [00:09<01:36,  1.94s/it]

Loss: 1.04177725315094


Epoch 6:  11%|█         | 6/55 [00:11<01:34,  1.94s/it]

Loss: 1.1147944927215576


Epoch 6:  13%|█▎        | 7/55 [00:13<01:32,  1.94s/it]

Loss: 1.1219205856323242


Epoch 6:  15%|█▍        | 8/55 [00:15<01:31,  1.94s/it]

Loss: 0.4493711292743683


Epoch 6:  16%|█▋        | 9/55 [00:17<01:29,  1.94s/it]

Loss: 0.7534476518630981


Epoch 6:  18%|█▊        | 10/55 [00:19<01:27,  1.94s/it]

Loss: 2.4401352405548096


Epoch 6:  20%|██        | 11/55 [00:21<01:25,  1.94s/it]

Loss: 0.9386676549911499


Epoch 6:  22%|██▏       | 12/55 [00:23<01:23,  1.94s/it]

Loss: 0.2714006006717682


Epoch 6:  24%|██▎       | 13/55 [00:25<01:21,  1.94s/it]

Loss: 0.7058612704277039


Epoch 6:  25%|██▌       | 14/55 [00:27<01:19,  1.94s/it]

Loss: 2.0185818672180176


Epoch 6:  27%|██▋       | 15/55 [00:29<01:17,  1.94s/it]

Loss: 0.39351391792297363


Epoch 6:  29%|██▉       | 16/55 [00:31<01:15,  1.94s/it]

Loss: 2.50895094871521


Epoch 6:  31%|███       | 17/55 [00:33<01:13,  1.94s/it]

Loss: 0.9375314116477966


Epoch 6:  33%|███▎      | 18/55 [00:34<01:11,  1.94s/it]

Loss: 0.15858979523181915


Epoch 6:  35%|███▍      | 19/55 [00:36<01:09,  1.94s/it]

Loss: 0.6212058663368225


Epoch 6:  36%|███▋      | 20/55 [00:38<01:08,  1.94s/it]

Loss: 1.445995807647705


Epoch 6:  38%|███▊      | 21/55 [00:40<01:06,  1.95s/it]

Loss: 0.426318883895874


Epoch 6:  40%|████      | 22/55 [00:42<01:04,  1.94s/it]

Loss: 0.43826788663864136


Epoch 6:  42%|████▏     | 23/55 [00:44<01:02,  1.94s/it]

Loss: 1.4397984743118286


Epoch 6:  44%|████▎     | 24/55 [00:46<01:00,  1.94s/it]

Loss: 1.1479758024215698


Epoch 6:  45%|████▌     | 25/55 [00:48<00:58,  1.94s/it]

Loss: 0.8730035424232483


Epoch 6:  47%|████▋     | 26/55 [00:50<00:56,  1.94s/it]

Loss: 0.7349388599395752


Epoch 6:  49%|████▉     | 27/55 [00:52<00:54,  1.94s/it]

Loss: 0.6356425285339355


Epoch 6:  51%|█████     | 28/55 [00:54<00:52,  1.94s/it]

Loss: 0.3792825937271118


Epoch 6:  53%|█████▎    | 29/55 [00:56<00:50,  1.95s/it]

Loss: 0.3012048602104187


Epoch 6:  55%|█████▍    | 30/55 [00:58<00:48,  1.94s/it]

Loss: 1.1220550537109375


Epoch 6:  56%|█████▋    | 31/55 [01:00<00:46,  1.94s/it]

Loss: 0.011088140308856964


Epoch 6:  58%|█████▊    | 32/55 [01:02<00:44,  1.94s/it]

Loss: 0.7180272340774536


Epoch 6:  60%|██████    | 33/55 [01:04<00:42,  1.94s/it]

Loss: 0.4086363911628723


Epoch 6:  62%|██████▏   | 34/55 [01:06<00:40,  1.94s/it]

Loss: 0.5927761793136597


Epoch 6:  64%|██████▎   | 35/55 [01:08<00:38,  1.94s/it]

Loss: 0.9196023941040039


Epoch 6:  65%|██████▌   | 36/55 [01:09<00:36,  1.94s/it]

Loss: 0.39917466044425964


Epoch 6:  67%|██████▋   | 37/55 [01:11<00:35,  1.94s/it]

Loss: 0.27828842401504517


Epoch 6:  69%|██████▉   | 38/55 [01:13<00:33,  1.95s/it]

Loss: 0.4973439574241638


Epoch 6:  71%|███████   | 39/55 [01:15<00:31,  1.95s/it]

Loss: 0.15893106162548065


Epoch 6:  73%|███████▎  | 40/55 [01:17<00:29,  1.95s/it]

Loss: 0.3712509274482727


Epoch 6:  75%|███████▍  | 41/55 [01:19<00:27,  1.95s/it]

Loss: 0.2229205220937729


Epoch 6:  76%|███████▋  | 42/55 [01:21<00:25,  1.95s/it]

Loss: 0.8954724073410034


Epoch 6:  78%|███████▊  | 43/55 [01:23<00:23,  1.95s/it]

Loss: 0.48932403326034546


Epoch 6:  80%|████████  | 44/55 [01:25<00:21,  1.95s/it]

Loss: 0.2788841128349304


Epoch 6:  82%|████████▏ | 45/55 [01:27<00:19,  1.95s/it]

Loss: 0.7449605464935303


Epoch 6:  84%|████████▎ | 46/55 [01:29<00:17,  1.95s/it]

Loss: 0.1496041864156723


Epoch 6:  85%|████████▌ | 47/55 [01:31<00:15,  1.95s/it]

Loss: 0.9734448790550232


Epoch 6:  87%|████████▋ | 48/55 [01:33<00:13,  1.94s/it]

Loss: 0.18272021412849426


Epoch 6:  89%|████████▉ | 49/55 [01:35<00:11,  1.94s/it]

Loss: 0.35834598541259766


Epoch 6:  91%|█████████ | 50/55 [01:37<00:09,  1.95s/it]

Loss: 0.44028735160827637


Epoch 6:  93%|█████████▎| 51/55 [01:39<00:07,  1.95s/it]

Loss: 0.2582387924194336


Epoch 6:  95%|█████████▍| 52/55 [01:41<00:05,  1.94s/it]

Loss: 0.5706257224082947


Epoch 6:  96%|█████████▋| 53/55 [01:43<00:03,  1.94s/it]

Loss: 0.6687071323394775


Epoch 6:  98%|█████████▊| 54/55 [01:44<00:01,  1.95s/it]

Loss: 0.4338250160217285


Epoch 6: 100%|██████████| 55/55 [01:45<00:00,  1.92s/it]

Loss: 1.7302122116088867





Epoch 6 Validation Accuracy: 0.8374384236453202, F1-macro: 0.7050587769119007


Epoch 7:   2%|▏         | 1/55 [00:01<01:44,  1.94s/it]

Loss: 0.5851547718048096


Epoch 7:   4%|▎         | 2/55 [00:03<01:43,  1.95s/it]

Loss: 2.5243396759033203


Epoch 7:   5%|▌         | 3/55 [00:05<01:41,  1.95s/it]

Loss: 0.5114637017250061


Epoch 7:   7%|▋         | 4/55 [00:07<01:39,  1.94s/it]

Loss: 0.7723526358604431


Epoch 7:   9%|▉         | 5/55 [00:09<01:37,  1.94s/it]

Loss: 1.7246448993682861


Epoch 7:  11%|█         | 6/55 [00:11<01:35,  1.94s/it]

Loss: 1.797196865081787


Epoch 7:  13%|█▎        | 7/55 [00:13<01:33,  1.94s/it]

Loss: 0.9837242364883423


Epoch 7:  15%|█▍        | 8/55 [00:15<01:31,  1.94s/it]

Loss: 0.5999372005462646


Epoch 7:  16%|█▋        | 9/55 [00:17<01:29,  1.94s/it]

Loss: 0.46343153715133667


Epoch 7:  18%|█▊        | 10/55 [00:19<01:27,  1.94s/it]

Loss: 1.7667375802993774


Epoch 7:  20%|██        | 11/55 [00:21<01:25,  1.94s/it]

Loss: 0.6922119855880737


Epoch 7:  22%|██▏       | 12/55 [00:23<01:23,  1.94s/it]

Loss: 1.0296761989593506


Epoch 7:  24%|██▎       | 13/55 [00:25<01:21,  1.94s/it]

Loss: 2.213559150695801


Epoch 7:  25%|██▌       | 14/55 [00:27<01:19,  1.95s/it]

Loss: 2.1266541481018066


Epoch 7:  27%|██▋       | 15/55 [00:29<01:17,  1.95s/it]

Loss: 1.6355419158935547


Epoch 7:  29%|██▉       | 16/55 [00:31<01:15,  1.95s/it]

Loss: 0.7963660359382629


Epoch 7:  31%|███       | 17/55 [00:33<01:13,  1.95s/it]

Loss: 0.20152756571769714


Epoch 7:  33%|███▎      | 18/55 [00:35<01:12,  1.95s/it]

Loss: 0.44184964895248413


Epoch 7:  35%|███▍      | 19/55 [00:36<01:10,  1.95s/it]

Loss: 1.056045413017273


Epoch 7:  36%|███▋      | 20/55 [00:38<01:08,  1.95s/it]

Loss: 0.46776825189590454


Epoch 7:  38%|███▊      | 21/55 [00:40<01:06,  1.95s/it]

Loss: 0.45617496967315674


Epoch 7:  40%|████      | 22/55 [00:42<01:04,  1.95s/it]

Loss: 0.36006203293800354


Epoch 7:  42%|████▏     | 23/55 [00:44<01:02,  1.95s/it]

Loss: 0.09586139768362045


Epoch 7:  44%|████▎     | 24/55 [00:46<01:00,  1.95s/it]

Loss: 0.7294588088989258


Epoch 7:  45%|████▌     | 25/55 [00:48<00:58,  1.95s/it]

Loss: 0.09020400792360306


Epoch 7:  47%|████▋     | 26/55 [00:50<00:56,  1.95s/it]

Loss: 0.3815178871154785


Epoch 7:  49%|████▉     | 27/55 [00:52<00:54,  1.94s/it]

Loss: 0.22471919655799866


Epoch 7:  51%|█████     | 28/55 [00:54<00:52,  1.94s/it]

Loss: 0.6028327941894531


Epoch 7:  53%|█████▎    | 29/55 [00:56<00:50,  1.94s/it]

Loss: 0.0765138790011406


Epoch 7:  55%|█████▍    | 30/55 [00:58<00:48,  1.94s/it]

Loss: 0.3584687113761902


Epoch 7:  56%|█████▋    | 31/55 [01:00<00:46,  1.94s/it]

Loss: 0.7118393182754517


Epoch 7:  58%|█████▊    | 32/55 [01:02<00:44,  1.94s/it]

Loss: 0.15162977576255798


Epoch 7:  60%|██████    | 33/55 [01:04<00:42,  1.94s/it]

Loss: 0.21590353548526764


Epoch 7:  62%|██████▏   | 34/55 [01:06<00:40,  1.94s/it]

Loss: 0.43676653504371643


Epoch 7:  64%|██████▎   | 35/55 [01:08<00:38,  1.94s/it]

Loss: 0.258781373500824


Epoch 7:  65%|██████▌   | 36/55 [01:10<00:36,  1.94s/it]

Loss: 0.26221880316734314


Epoch 7:  67%|██████▋   | 37/55 [01:11<00:34,  1.94s/it]

Loss: 0.15153449773788452


Epoch 7:  69%|██████▉   | 38/55 [01:13<00:33,  1.94s/it]

Loss: 0.012356169521808624


Epoch 7:  71%|███████   | 39/55 [01:15<00:31,  1.94s/it]

Loss: 0.7106233835220337


Epoch 7:  73%|███████▎  | 40/55 [01:17<00:29,  1.94s/it]

Loss: 0.2377225011587143


Epoch 7:  75%|███████▍  | 41/55 [01:19<00:27,  1.94s/it]

Loss: 0.4987248182296753


Epoch 7:  76%|███████▋  | 42/55 [01:21<00:25,  1.94s/it]

Loss: 0.46651381254196167


Epoch 7:  78%|███████▊  | 43/55 [01:23<00:23,  1.95s/it]

Loss: 0.5893684029579163


Epoch 7:  80%|████████  | 44/55 [01:25<00:21,  1.94s/it]

Loss: 0.3437155485153198


Epoch 7:  82%|████████▏ | 45/55 [01:27<00:19,  1.94s/it]

Loss: 0.3016568422317505


Epoch 7:  84%|████████▎ | 46/55 [01:29<00:17,  1.94s/it]

Loss: 0.276727557182312


Epoch 7:  85%|████████▌ | 47/55 [01:31<00:15,  1.94s/it]

Loss: 0.14036154747009277


Epoch 7:  87%|████████▋ | 48/55 [01:33<00:13,  1.95s/it]

Loss: 0.9519059658050537


Epoch 7:  89%|████████▉ | 49/55 [01:35<00:11,  1.95s/it]

Loss: 1.1935317516326904


Epoch 7:  91%|█████████ | 50/55 [01:37<00:09,  1.95s/it]

Loss: 0.42625367641448975


Epoch 7:  93%|█████████▎| 51/55 [01:39<00:07,  1.95s/it]

Loss: 0.5782488584518433


Epoch 7:  95%|█████████▍| 52/55 [01:41<00:05,  1.95s/it]

Loss: 0.3613772988319397


Epoch 7:  96%|█████████▋| 53/55 [01:43<00:03,  1.95s/it]

Loss: 0.15289390087127686


Epoch 7:  98%|█████████▊| 54/55 [01:45<00:01,  1.95s/it]

Loss: 0.8189567923545837


Epoch 7: 100%|██████████| 55/55 [01:45<00:00,  1.92s/it]

Loss: 0.00138014554977417





Epoch 7 Validation Accuracy: 0.8719211822660099, F1-macro: 0.7307142857142856


In [None]:
# Evaluate on the test set
test_metrics = evaluate(model5, test_dataloader, device)
print(f"Test Accuracy: {test_metrics['accuracy']}, F1-macro: {test_metrics['f1_macro']}")

Test Accuracy: 0.8823529411764706, F1-macro: 0.7166666666666667


In [None]:
current_type = 'fortune telling'

results_df.loc[
    results_df["distortion_type"] == current_type,
    ["test_accuracy", "f1_macro"]
] = [test_metrics['accuracy'], test_metrics['f1_macro']]

# 6. Labeling

In [None]:
# Add labels
data1_1_labels = list(data1['labeling'][data1_1.index])
data2_1_labels = list(data2['labeling'][data2_1.index])
data3_1_labels = list(data3['labeling'][data3_1.index])

# Merging Data
data_encoded = data1_1_encoded + data2_1_encoded + data3_1_encoded
data_labels = data1_1_labels + data2_1_labels + data3_1_labels

dataset_with_labels = CustomDatasetWithLabels(data_encoded, data_labels)

# Define proportions for splitting
train_size = int(0.8 * len(dataset_with_labels))
val_size = int(0.1 * len(dataset_with_labels))
test_size = len(dataset_with_labels) - train_size - val_size

# Split the dataset
train_dataset, val_dataset, test_dataset = random_split(dataset_with_labels, [train_size, val_size, test_size])

In [None]:
dataset_encoded = dataset_with_labels.data
dataset_labels = dataset_with_labels.labels

# Convert list of BatchEncoding objects to a 2D numpy array for SMOTE features
feature_vectors = []
max_len = 512 # Assuming this from tokenize_and_pad function
for encoded_dict in dataset_encoded:
    input_ids_np = encoded_dict['input_ids'].squeeze().cpu().numpy()
    attention_mask_np = encoded_dict['attention_mask'].squeeze().cpu().numpy()
    # Concatenate input_ids and attention_mask to create a single feature vector per sample
    feature_vectors.append(np.concatenate([input_ids_np, attention_mask_np]))

# 'temp1' becomes the feature matrix, 'temp2' becomes the labels vector
temp1 = np.array(feature_vectors)
temp2 = np.array(dataset_labels)

# Apply SMOTE
temp1_resampled, temp2_resampled = smote.fit_resample(temp1, temp2)

# Reconstruct data_encoded and data_labels from the resampled data
reconstructed_data_encoded = []
for resampled_features in temp1_resampled:
    # Split the combined feature vector back into input_ids and attention_mask parts
    reconstructed_input_ids = resampled_features[:max_len]
    reconstructed_attention_mask = resampled_features[max_len:]

    # Convert back to torch tensors and a dictionary format compatible with CustomDatasetWithLabels
    input_ids_tensor = torch.tensor(reconstructed_input_ids, dtype=torch.long).unsqueeze(0)
    attention_mask_tensor = torch.tensor(reconstructed_attention_mask, dtype=torch.long).unsqueeze(0)

    reconstructed_data_encoded.append({
        'input_ids': input_ids_tensor,
        'attention_mask': attention_mask_tensor
    })

# Update data_encoded and data_labels with the resampled data
dataset_encoded = reconstructed_data_encoded
dataset_labels = temp2_resampled.tolist()

train_dataset = CustomDatasetWithLabels(dataset_encoded, dataset_labels)

In [None]:
# Create DataLoaders for each set
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=64)
test_dataloader = DataLoader(test_dataset, batch_size=64)

input_dim = bert_config.hidden_size
num_labels = len(set(data_labels))
label_emb = torch.randn(num_labels, input_dim)

In [None]:
# Instantiate the InnerProductClassifier model
model6 = InnerProductClassifier(input_dim, label_emb).to(device)

# Define loss function
criterion = nn.CrossEntropyLoss()

# Define optimizer
optimizer = torch.optim.AdamW(model6.parameters(), lr=2e-4)

In [None]:
EPOCHS = 7

# Move the BERT model to the device
bert_model.to(device)

for epoch in range(EPOCHS):
    model6.train()
    for batch in tqdm(train_dataloader, desc=f"Epoch {epoch+1}"):

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        y = batch["y"].to(device)

        # Get embeddings from the BERT model
        with torch.no_grad():
            outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask)
            embeddings = outputs.last_hidden_state[:, 0, :] # Get the [CLS] token embedding

        # Pass embeddings to the InnerProductClassifier
        logits = model6(embeddings)
        loss = criterion(logits, y)

        # Backpropagate and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Loss: {loss.item()}")

    # Evaluate on the validation set after each epoch
    val_metrics = evaluate(model6, val_dataloader, device)
    print(f"Epoch {epoch+1} Validation Accuracy: {val_metrics['accuracy']}, F1-macro: {val_metrics['f1_macro']}")

Epoch 1:   2%|▏         | 1/54 [00:01<01:41,  1.92s/it]

Loss: 4.968138694763184


Epoch 1:   4%|▎         | 2/54 [00:03<01:40,  1.93s/it]

Loss: 2.6190247535705566


Epoch 1:   6%|▌         | 3/54 [00:05<01:38,  1.94s/it]

Loss: 2.149172067642212


Epoch 1:   7%|▋         | 4/54 [00:07<01:37,  1.94s/it]

Loss: 2.8206803798675537


Epoch 1:   9%|▉         | 5/54 [00:09<01:35,  1.94s/it]

Loss: 3.2088003158569336


Epoch 1:  11%|█         | 6/54 [00:11<01:33,  1.94s/it]

Loss: 2.24906063079834


Epoch 1:  13%|█▎        | 7/54 [00:13<01:31,  1.94s/it]

Loss: 0.9006586670875549


Epoch 1:  15%|█▍        | 8/54 [00:15<01:29,  1.94s/it]

Loss: 0.6896635293960571


Epoch 1:  17%|█▋        | 9/54 [00:17<01:27,  1.94s/it]

Loss: 3.268418788909912


Epoch 1:  19%|█▊        | 10/54 [00:19<01:25,  1.94s/it]

Loss: 1.5255005359649658


Epoch 1:  20%|██        | 11/54 [00:21<01:23,  1.94s/it]

Loss: 1.6626591682434082


Epoch 1:  22%|██▏       | 12/54 [00:23<01:21,  1.94s/it]

Loss: 1.4160187244415283


Epoch 1:  24%|██▍       | 13/54 [00:25<01:19,  1.94s/it]

Loss: 1.8635518550872803


Epoch 1:  26%|██▌       | 14/54 [00:27<01:17,  1.94s/it]

Loss: 1.9912514686584473


Epoch 1:  28%|██▊       | 15/54 [00:29<01:15,  1.94s/it]

Loss: 1.489635944366455


Epoch 1:  30%|██▉       | 16/54 [00:31<01:13,  1.94s/it]

Loss: 1.9479906558990479


Epoch 1:  31%|███▏      | 17/54 [00:33<01:11,  1.94s/it]

Loss: 1.0103542804718018


Epoch 1:  33%|███▎      | 18/54 [00:34<01:09,  1.94s/it]

Loss: 2.039121627807617


Epoch 1:  35%|███▌      | 19/54 [00:36<01:08,  1.94s/it]

Loss: 1.6327234506607056


Epoch 1:  37%|███▋      | 20/54 [00:38<01:06,  1.94s/it]

Loss: 1.6433418989181519


Epoch 1:  39%|███▉      | 21/54 [00:40<01:04,  1.95s/it]

Loss: 2.435340404510498


Epoch 1:  41%|████      | 22/54 [00:42<01:02,  1.95s/it]

Loss: 3.4518239498138428


Epoch 1:  43%|████▎     | 23/54 [00:44<01:00,  1.95s/it]

Loss: 0.7167133092880249


Epoch 1:  44%|████▍     | 24/54 [00:46<00:58,  1.95s/it]

Loss: 1.3614131212234497


Epoch 1:  46%|████▋     | 25/54 [00:48<00:56,  1.95s/it]

Loss: 1.648253083229065


Epoch 1:  48%|████▊     | 26/54 [00:50<00:54,  1.95s/it]

Loss: 3.5307021141052246


Epoch 1:  50%|█████     | 27/54 [00:52<00:52,  1.95s/it]

Loss: 1.2948771715164185


Epoch 1:  52%|█████▏    | 28/54 [00:54<00:50,  1.95s/it]

Loss: 1.8186960220336914


Epoch 1:  54%|█████▎    | 29/54 [00:56<00:48,  1.95s/it]

Loss: 0.9594228863716125


Epoch 1:  56%|█████▌    | 30/54 [00:58<00:46,  1.95s/it]

Loss: 2.6749157905578613


Epoch 1:  57%|█████▋    | 31/54 [01:00<00:44,  1.95s/it]

Loss: 1.7125614881515503


Epoch 1:  59%|█████▉    | 32/54 [01:02<00:42,  1.95s/it]

Loss: 0.4848390221595764


Epoch 1:  61%|██████    | 33/54 [01:04<00:40,  1.95s/it]

Loss: 1.048205852508545


Epoch 1:  63%|██████▎   | 34/54 [01:06<00:38,  1.95s/it]

Loss: 1.1739318370819092


Epoch 1:  65%|██████▍   | 35/54 [01:08<00:36,  1.94s/it]

Loss: 0.9764727354049683


Epoch 1:  67%|██████▋   | 36/54 [01:10<00:35,  1.94s/it]

Loss: 0.6515074372291565


Epoch 1:  69%|██████▊   | 37/54 [01:11<00:33,  1.94s/it]

Loss: 2.372706413269043


Epoch 1:  70%|███████   | 38/54 [01:13<00:31,  1.94s/it]

Loss: 0.9370478391647339


Epoch 1:  72%|███████▏  | 39/54 [01:15<00:29,  1.94s/it]

Loss: 1.5097668170928955


Epoch 1:  74%|███████▍  | 40/54 [01:17<00:27,  1.94s/it]

Loss: 0.6610038876533508


Epoch 1:  76%|███████▌  | 41/54 [01:19<00:25,  1.94s/it]

Loss: 0.6618665456771851


Epoch 1:  78%|███████▊  | 42/54 [01:21<00:23,  1.94s/it]

Loss: 0.9198046326637268


Epoch 1:  80%|███████▉  | 43/54 [01:23<00:21,  1.94s/it]

Loss: 1.0150729417800903


Epoch 1:  81%|████████▏ | 44/54 [01:25<00:19,  1.94s/it]

Loss: 1.0392498970031738


Epoch 1:  83%|████████▎ | 45/54 [01:27<00:17,  1.94s/it]

Loss: 0.9088526964187622


Epoch 1:  85%|████████▌ | 46/54 [01:29<00:15,  1.94s/it]

Loss: 0.4369950294494629


Epoch 1:  87%|████████▋ | 47/54 [01:31<00:13,  1.94s/it]

Loss: 0.8525949716567993


Epoch 1:  89%|████████▉ | 48/54 [01:33<00:11,  1.94s/it]

Loss: 0.6106762886047363


Epoch 1:  91%|█████████ | 49/54 [01:35<00:09,  1.94s/it]

Loss: 0.6934980750083923


Epoch 1:  93%|█████████▎| 50/54 [01:37<00:07,  1.95s/it]

Loss: 0.19827181100845337


Epoch 1:  94%|█████████▍| 51/54 [01:39<00:05,  1.94s/it]

Loss: 1.0010355710983276


Epoch 1:  96%|█████████▋| 52/54 [01:41<00:03,  1.94s/it]

Loss: 0.5970242023468018


Epoch 1:  98%|█████████▊| 53/54 [01:43<00:01,  1.94s/it]

Loss: 0.6364238262176514


Epoch 1: 100%|██████████| 54/54 [01:44<00:00,  1.93s/it]

Loss: 0.2124028205871582





Epoch 1 Validation Accuracy: 0.8029556650246306, F1-macro: 0.5977011494252874


Epoch 2:   2%|▏         | 1/54 [00:01<01:43,  1.95s/it]

Loss: 0.6739187240600586


Epoch 2:   4%|▎         | 2/54 [00:03<01:41,  1.95s/it]

Loss: 0.6769562363624573


Epoch 2:   6%|▌         | 3/54 [00:05<01:39,  1.95s/it]

Loss: 0.5067214965820312


Epoch 2:   7%|▋         | 4/54 [00:07<01:37,  1.95s/it]

Loss: 0.6967073678970337


Epoch 2:   9%|▉         | 5/54 [00:09<01:35,  1.95s/it]

Loss: 0.4836614727973938


Epoch 2:  11%|█         | 6/54 [00:11<01:33,  1.95s/it]

Loss: 0.8647056818008423


Epoch 2:  13%|█▎        | 7/54 [00:13<01:31,  1.95s/it]

Loss: 0.2923196852207184


Epoch 2:  15%|█▍        | 8/54 [00:15<01:29,  1.95s/it]

Loss: 0.5707568526268005


Epoch 2:  17%|█▋        | 9/54 [00:17<01:27,  1.95s/it]

Loss: 0.5565588474273682


Epoch 2:  19%|█▊        | 10/54 [00:19<01:25,  1.95s/it]

Loss: 0.33986225724220276


Epoch 2:  20%|██        | 11/54 [00:21<01:23,  1.95s/it]

Loss: 0.6150661706924438


Epoch 2:  22%|██▏       | 12/54 [00:23<01:21,  1.95s/it]

Loss: 0.5080087184906006


Epoch 2:  24%|██▍       | 13/54 [00:25<01:19,  1.95s/it]

Loss: 0.20288343727588654


Epoch 2:  26%|██▌       | 14/54 [00:27<01:17,  1.95s/it]

Loss: 0.6042908430099487


Epoch 2:  28%|██▊       | 15/54 [00:29<01:15,  1.95s/it]

Loss: 0.49858707189559937


Epoch 2:  30%|██▉       | 16/54 [00:31<01:13,  1.95s/it]

Loss: 0.45914196968078613


Epoch 2:  31%|███▏      | 17/54 [00:33<01:11,  1.95s/it]

Loss: 0.3399270176887512


Epoch 2:  33%|███▎      | 18/54 [00:35<01:10,  1.95s/it]

Loss: 0.30811214447021484


Epoch 2:  35%|███▌      | 19/54 [00:36<01:08,  1.95s/it]

Loss: 0.3100414574146271


Epoch 2:  37%|███▋      | 20/54 [00:38<01:06,  1.94s/it]

Loss: 0.23918451368808746


Epoch 2:  39%|███▉      | 21/54 [00:40<01:04,  1.94s/it]

Loss: 0.3581509590148926


Epoch 2:  41%|████      | 22/54 [00:42<01:02,  1.95s/it]

Loss: 0.40943288803100586


Epoch 2:  43%|████▎     | 23/54 [00:44<01:00,  1.95s/it]

Loss: 0.20471468567848206


Epoch 2:  44%|████▍     | 24/54 [00:46<00:58,  1.95s/it]

Loss: 0.3826276361942291


Epoch 2:  46%|████▋     | 25/54 [00:48<00:56,  1.95s/it]

Loss: 0.39439666271209717


Epoch 2:  48%|████▊     | 26/54 [00:50<00:54,  1.95s/it]

Loss: 0.25478485226631165


Epoch 2:  50%|█████     | 27/54 [00:52<00:52,  1.95s/it]

Loss: 0.4087403416633606


Epoch 2:  52%|█████▏    | 28/54 [00:54<00:50,  1.95s/it]

Loss: 0.23040375113487244


Epoch 2:  54%|█████▎    | 29/54 [00:56<00:48,  1.95s/it]

Loss: 0.505953311920166


Epoch 2:  56%|█████▌    | 30/54 [00:58<00:46,  1.95s/it]

Loss: 0.5391722321510315


Epoch 2:  57%|█████▋    | 31/54 [01:00<00:44,  1.95s/it]

Loss: 0.33173060417175293


Epoch 2:  59%|█████▉    | 32/54 [01:02<00:42,  1.95s/it]

Loss: 0.48776817321777344


Epoch 2:  61%|██████    | 33/54 [01:04<00:40,  1.95s/it]

Loss: 0.28257280588150024


Epoch 2:  63%|██████▎   | 34/54 [01:06<00:38,  1.95s/it]

Loss: 0.0962955504655838


Epoch 2:  65%|██████▍   | 35/54 [01:08<00:37,  1.95s/it]

Loss: 0.7490977048873901


Epoch 2:  67%|██████▋   | 36/54 [01:10<00:35,  1.95s/it]

Loss: 0.6955032348632812


Epoch 2:  69%|██████▊   | 37/54 [01:12<00:33,  1.95s/it]

Loss: 0.6606618762016296


Epoch 2:  70%|███████   | 38/54 [01:13<00:31,  1.95s/it]

Loss: 0.6500204801559448


Epoch 2:  72%|███████▏  | 39/54 [01:15<00:29,  1.95s/it]

Loss: 0.5583505630493164


Epoch 2:  74%|███████▍  | 40/54 [01:17<00:27,  1.95s/it]

Loss: 0.3908732235431671


Epoch 2:  76%|███████▌  | 41/54 [01:19<00:25,  1.95s/it]

Loss: 0.6920629739761353


Epoch 2:  78%|███████▊  | 42/54 [01:21<00:23,  1.95s/it]

Loss: 0.7207856178283691


Epoch 2:  80%|███████▉  | 43/54 [01:23<00:21,  1.95s/it]

Loss: 0.2750478982925415


Epoch 2:  81%|████████▏ | 44/54 [01:25<00:19,  1.95s/it]

Loss: 0.9612364172935486


Epoch 2:  83%|████████▎ | 45/54 [01:27<00:17,  1.95s/it]

Loss: 0.5662598609924316


Epoch 2:  85%|████████▌ | 46/54 [01:29<00:15,  1.95s/it]

Loss: 0.35416463017463684


Epoch 2:  87%|████████▋ | 47/54 [01:31<00:13,  1.95s/it]

Loss: 0.6902412176132202


Epoch 2:  89%|████████▉ | 48/54 [01:33<00:11,  1.94s/it]

Loss: 0.3981456160545349


Epoch 2:  91%|█████████ | 49/54 [01:35<00:09,  1.95s/it]

Loss: 0.5136240124702454


Epoch 2:  93%|█████████▎| 50/54 [01:37<00:07,  1.94s/it]

Loss: 0.9016827940940857


Epoch 2:  94%|█████████▍| 51/54 [01:39<00:05,  1.94s/it]

Loss: 0.5593562126159668


Epoch 2:  96%|█████████▋| 52/54 [01:41<00:03,  1.95s/it]

Loss: 0.5029026865959167


Epoch 2:  98%|█████████▊| 53/54 [01:43<00:01,  1.94s/it]

Loss: 0.46974268555641174


Epoch 2: 100%|██████████| 54/54 [01:44<00:00,  1.93s/it]

Loss: 0.843901515007019





Epoch 2 Validation Accuracy: 0.812807881773399, F1-macro: 0.7196947674418605


Epoch 3:   2%|▏         | 1/54 [00:01<01:43,  1.95s/it]

Loss: 0.3324086666107178


Epoch 3:   4%|▎         | 2/54 [00:03<01:41,  1.95s/it]

Loss: 0.36747127771377563


Epoch 3:   6%|▌         | 3/54 [00:05<01:39,  1.95s/it]

Loss: 0.5119043588638306


Epoch 3:   7%|▋         | 4/54 [00:07<01:37,  1.95s/it]

Loss: 0.21196502447128296


Epoch 3:   9%|▉         | 5/54 [00:09<01:35,  1.94s/it]

Loss: 0.6499415636062622


Epoch 3:  11%|█         | 6/54 [00:11<01:33,  1.95s/it]

Loss: 0.1730639934539795


Epoch 3:  13%|█▎        | 7/54 [00:13<01:31,  1.95s/it]

Loss: 0.7238473296165466


Epoch 3:  15%|█▍        | 8/54 [00:15<01:29,  1.95s/it]

Loss: 0.5299890041351318


Epoch 3:  17%|█▋        | 9/54 [00:17<01:27,  1.95s/it]

Loss: 0.5038081407546997


Epoch 3:  19%|█▊        | 10/54 [00:19<01:25,  1.94s/it]

Loss: 0.883384108543396


Epoch 3:  20%|██        | 11/54 [00:21<01:23,  1.94s/it]

Loss: 0.433910071849823


Epoch 3:  22%|██▏       | 12/54 [00:23<01:21,  1.94s/it]

Loss: 0.9829641580581665


Epoch 3:  24%|██▍       | 13/54 [00:25<01:19,  1.94s/it]

Loss: 1.2313232421875


Epoch 3:  26%|██▌       | 14/54 [00:27<01:17,  1.95s/it]

Loss: 0.7310385704040527


Epoch 3:  28%|██▊       | 15/54 [00:29<01:15,  1.95s/it]

Loss: 1.015066385269165


Epoch 3:  30%|██▉       | 16/54 [00:31<01:13,  1.95s/it]

Loss: 0.1704668551683426


Epoch 3:  31%|███▏      | 17/54 [00:33<01:12,  1.95s/it]

Loss: 0.36214953660964966


Epoch 3:  33%|███▎      | 18/54 [00:35<01:10,  1.95s/it]

Loss: 1.3283213376998901


Epoch 3:  35%|███▌      | 19/54 [00:36<01:08,  1.95s/it]

Loss: 0.3485885560512543


Epoch 3:  37%|███▋      | 20/54 [00:38<01:06,  1.95s/it]

Loss: 1.1006730794906616


Epoch 3:  39%|███▉      | 21/54 [00:40<01:04,  1.95s/it]

Loss: 1.274863600730896


Epoch 3:  41%|████      | 22/54 [00:42<01:02,  1.95s/it]

Loss: 1.041795253753662


Epoch 3:  43%|████▎     | 23/54 [00:44<01:00,  1.95s/it]

Loss: 0.4064624309539795


Epoch 3:  44%|████▍     | 24/54 [00:46<00:58,  1.95s/it]

Loss: 0.4631385803222656


Epoch 3:  46%|████▋     | 25/54 [00:48<00:56,  1.95s/it]

Loss: 0.650471568107605


Epoch 3:  48%|████▊     | 26/54 [00:50<00:54,  1.95s/it]

Loss: 0.31657740473747253


Epoch 3:  50%|█████     | 27/54 [00:52<00:52,  1.95s/it]

Loss: 0.4289717674255371


Epoch 3:  52%|█████▏    | 28/54 [00:54<00:50,  1.95s/it]

Loss: 0.2158486247062683


Epoch 3:  54%|█████▎    | 29/54 [00:56<00:48,  1.95s/it]

Loss: 0.7086206674575806


Epoch 3:  56%|█████▌    | 30/54 [00:58<00:46,  1.95s/it]

Loss: 0.5005953907966614


Epoch 3:  57%|█████▋    | 31/54 [01:00<00:44,  1.94s/it]

Loss: 0.8178369998931885


Epoch 3:  59%|█████▉    | 32/54 [01:02<00:42,  1.94s/it]

Loss: 0.5777438879013062


Epoch 3:  61%|██████    | 33/54 [01:04<00:40,  1.95s/it]

Loss: 0.5481843948364258


Epoch 3:  63%|██████▎   | 34/54 [01:06<00:38,  1.95s/it]

Loss: 0.6771419048309326


Epoch 3:  65%|██████▍   | 35/54 [01:08<00:36,  1.94s/it]

Loss: 0.20882965624332428


Epoch 3:  67%|██████▋   | 36/54 [01:10<00:35,  1.95s/it]

Loss: 0.4805721640586853


Epoch 3:  69%|██████▊   | 37/54 [01:11<00:33,  1.95s/it]

Loss: 0.6519918441772461


Epoch 3:  70%|███████   | 38/54 [01:13<00:31,  1.95s/it]

Loss: 0.41167524456977844


Epoch 3:  72%|███████▏  | 39/54 [01:15<00:29,  1.95s/it]

Loss: 1.100914716720581


Epoch 3:  74%|███████▍  | 40/54 [01:17<00:27,  1.95s/it]

Loss: 0.16448551416397095


Epoch 3:  76%|███████▌  | 41/54 [01:19<00:25,  1.95s/it]

Loss: 0.369925320148468


Epoch 3:  78%|███████▊  | 42/54 [01:21<00:23,  1.95s/it]

Loss: 0.3305910527706146


Epoch 3:  80%|███████▉  | 43/54 [01:23<00:21,  1.94s/it]

Loss: 0.3243005871772766


Epoch 3:  81%|████████▏ | 44/54 [01:25<00:19,  1.94s/it]

Loss: 0.33204638957977295


Epoch 3:  83%|████████▎ | 45/54 [01:27<00:17,  1.94s/it]

Loss: 0.5562961101531982


Epoch 3:  85%|████████▌ | 46/54 [01:29<00:15,  1.95s/it]

Loss: 1.4291198253631592


Epoch 3:  87%|████████▋ | 47/54 [01:31<00:13,  1.95s/it]

Loss: 0.5273333191871643


Epoch 3:  89%|████████▉ | 48/54 [01:33<00:11,  1.95s/it]

Loss: 2.2043967247009277


Epoch 3:  91%|█████████ | 49/54 [01:35<00:09,  1.95s/it]

Loss: 0.455021470785141


Epoch 3:  93%|█████████▎| 50/54 [01:37<00:07,  1.95s/it]

Loss: 0.22609686851501465


Epoch 3:  94%|█████████▍| 51/54 [01:39<00:05,  1.95s/it]

Loss: 0.943690299987793


Epoch 3:  96%|█████████▋| 52/54 [01:41<00:03,  1.95s/it]

Loss: 0.6593271493911743


Epoch 3:  98%|█████████▊| 53/54 [01:43<00:01,  1.95s/it]

Loss: 3.4436895847320557


Epoch 3: 100%|██████████| 54/54 [01:44<00:00,  1.93s/it]

Loss: 0.8990033864974976





Epoch 3 Validation Accuracy: 0.8078817733990148, F1-macro: 0.5126500461680518


Epoch 4:   2%|▏         | 1/54 [00:01<01:43,  1.95s/it]

Loss: 0.5504113435745239


Epoch 4:   4%|▎         | 2/54 [00:03<01:41,  1.95s/it]

Loss: 0.9745998382568359


Epoch 4:   6%|▌         | 3/54 [00:05<01:39,  1.94s/it]

Loss: 0.46381229162216187


Epoch 4:   7%|▋         | 4/54 [00:07<01:37,  1.94s/it]

Loss: 0.908503532409668


Epoch 4:   9%|▉         | 5/54 [00:09<01:35,  1.95s/it]

Loss: 0.9991005659103394


Epoch 4:  11%|█         | 6/54 [00:11<01:33,  1.95s/it]

Loss: 0.39245885610580444


Epoch 4:  13%|█▎        | 7/54 [00:13<01:31,  1.95s/it]

Loss: 0.7779760956764221


Epoch 4:  15%|█▍        | 8/54 [00:15<01:29,  1.95s/it]

Loss: 1.8001984357833862


Epoch 4:  17%|█▋        | 9/54 [00:17<01:27,  1.95s/it]

Loss: 1.1620762348175049


Epoch 4:  19%|█▊        | 10/54 [00:19<01:25,  1.95s/it]

Loss: 0.8183298110961914


Epoch 4:  20%|██        | 11/54 [00:21<01:23,  1.95s/it]

Loss: 0.45053526759147644


Epoch 4:  22%|██▏       | 12/54 [00:23<01:21,  1.95s/it]

Loss: 0.271276593208313


Epoch 4:  24%|██▍       | 13/54 [00:25<01:19,  1.95s/it]

Loss: 0.4788280129432678


Epoch 4:  26%|██▌       | 14/54 [00:27<01:17,  1.94s/it]

Loss: 0.8423408269882202


Epoch 4:  28%|██▊       | 15/54 [00:29<01:15,  1.94s/it]

Loss: 0.9990348815917969


Epoch 4:  30%|██▉       | 16/54 [00:31<01:13,  1.94s/it]

Loss: 0.4792523980140686


Epoch 4:  31%|███▏      | 17/54 [00:33<01:11,  1.94s/it]

Loss: 1.2095205783843994


Epoch 4:  33%|███▎      | 18/54 [00:35<01:10,  1.94s/it]

Loss: 0.9551796913146973


Epoch 4:  35%|███▌      | 19/54 [00:36<01:08,  1.94s/it]

Loss: 0.43351632356643677


Epoch 4:  37%|███▋      | 20/54 [00:38<01:06,  1.95s/it]

Loss: 0.3353460431098938


Epoch 4:  39%|███▉      | 21/54 [00:40<01:04,  1.95s/it]

Loss: 0.4074995517730713


Epoch 4:  41%|████      | 22/54 [00:42<01:02,  1.95s/it]

Loss: 0.3147990107536316


Epoch 4:  43%|████▎     | 23/54 [00:44<01:00,  1.95s/it]

Loss: 0.30078089237213135


Epoch 4:  44%|████▍     | 24/54 [00:46<00:58,  1.95s/it]

Loss: 0.5361140370368958


Epoch 4:  46%|████▋     | 25/54 [00:48<00:56,  1.95s/it]

Loss: 0.32348042726516724


Epoch 4:  48%|████▊     | 26/54 [00:50<00:54,  1.95s/it]

Loss: 0.6196005344390869


Epoch 4:  50%|█████     | 27/54 [00:52<00:52,  1.95s/it]

Loss: 0.32328036427497864


Epoch 4:  52%|█████▏    | 28/54 [00:54<00:50,  1.95s/it]

Loss: 0.6971736550331116


Epoch 4:  54%|█████▎    | 29/54 [00:56<00:48,  1.95s/it]

Loss: 0.4726513624191284


Epoch 4:  56%|█████▌    | 30/54 [00:58<00:46,  1.95s/it]

Loss: 1.048935055732727


Epoch 4:  57%|█████▋    | 31/54 [01:00<00:44,  1.95s/it]

Loss: 0.3556552231311798


Epoch 4:  59%|█████▉    | 32/54 [01:02<00:42,  1.95s/it]

Loss: 1.0479985475540161


Epoch 4:  61%|██████    | 33/54 [01:04<00:40,  1.95s/it]

Loss: 0.13359346985816956


Epoch 4:  63%|██████▎   | 34/54 [01:06<00:38,  1.95s/it]

Loss: 1.3443011045455933


Epoch 4:  65%|██████▍   | 35/54 [01:08<00:36,  1.95s/it]

Loss: 0.7848737239837646


Epoch 4:  67%|██████▋   | 36/54 [01:10<00:35,  1.95s/it]

Loss: 0.34614861011505127


Epoch 4:  69%|██████▊   | 37/54 [01:12<00:33,  1.95s/it]

Loss: 0.7136027812957764


Epoch 4:  70%|███████   | 38/54 [01:13<00:31,  1.95s/it]

Loss: 0.44804081320762634


Epoch 4:  72%|███████▏  | 39/54 [01:15<00:29,  1.95s/it]

Loss: 0.09142041951417923


Epoch 4:  74%|███████▍  | 40/54 [01:17<00:27,  1.95s/it]

Loss: 0.23411478102207184


Epoch 4:  76%|███████▌  | 41/54 [01:19<00:25,  1.94s/it]

Loss: 0.29289209842681885


Epoch 4:  78%|███████▊  | 42/54 [01:21<00:23,  1.94s/it]

Loss: 0.7200309634208679


Epoch 4:  80%|███████▉  | 43/54 [01:23<00:21,  1.94s/it]

Loss: 0.19561535120010376


Epoch 4:  81%|████████▏ | 44/54 [01:25<00:19,  1.94s/it]

Loss: 0.5951955318450928


Epoch 4:  83%|████████▎ | 45/54 [01:27<00:17,  1.94s/it]

Loss: 0.5249878168106079


Epoch 4:  85%|████████▌ | 46/54 [01:29<00:15,  1.94s/it]

Loss: 0.40270352363586426


Epoch 4:  87%|████████▋ | 47/54 [01:31<00:13,  1.94s/it]

Loss: 0.1903308928012848


Epoch 4:  89%|████████▉ | 48/54 [01:33<00:11,  1.94s/it]

Loss: 0.3855847120285034


Epoch 4:  91%|█████████ | 49/54 [01:35<00:09,  1.94s/it]

Loss: 0.4700344204902649


Epoch 4:  93%|█████████▎| 50/54 [01:37<00:07,  1.94s/it]

Loss: 0.44653022289276123


Epoch 4:  94%|█████████▍| 51/54 [01:39<00:05,  1.94s/it]

Loss: 0.39896178245544434


Epoch 4:  96%|█████████▋| 52/54 [01:41<00:03,  1.94s/it]

Loss: 0.3651401698589325


Epoch 4:  98%|█████████▊| 53/54 [01:43<00:01,  1.94s/it]

Loss: 0.28401172161102295


Epoch 4: 100%|██████████| 54/54 [01:44<00:00,  1.93s/it]

Loss: 0.4183810353279114





Epoch 4 Validation Accuracy: 0.8325123152709359, F1-macro: 0.7248006379585328


Epoch 5:   2%|▏         | 1/54 [00:01<01:42,  1.94s/it]

Loss: 0.22610795497894287


Epoch 5:   4%|▎         | 2/54 [00:03<01:41,  1.95s/it]

Loss: 0.42383110523223877


Epoch 5:   6%|▌         | 3/54 [00:05<01:39,  1.95s/it]

Loss: 0.2717760503292084


Epoch 5:   7%|▋         | 4/54 [00:07<01:37,  1.95s/it]

Loss: 0.16839732229709625


Epoch 5:   9%|▉         | 5/54 [00:09<01:35,  1.95s/it]

Loss: 0.1273527294397354


Epoch 5:  11%|█         | 6/54 [00:11<01:33,  1.95s/it]

Loss: 0.2907770276069641


Epoch 5:  13%|█▎        | 7/54 [00:13<01:31,  1.95s/it]

Loss: 0.18000371754169464


Epoch 5:  15%|█▍        | 8/54 [00:15<01:29,  1.95s/it]

Loss: 0.4461726248264313


Epoch 5:  17%|█▋        | 9/54 [00:17<01:27,  1.95s/it]

Loss: 0.18395084142684937


Epoch 5:  19%|█▊        | 10/54 [00:19<01:25,  1.95s/it]

Loss: 0.49313366413116455


Epoch 5:  20%|██        | 11/54 [00:21<01:23,  1.95s/it]

Loss: 0.32878798246383667


Epoch 5:  22%|██▏       | 12/54 [00:23<01:21,  1.95s/it]

Loss: 0.2741880714893341


Epoch 5:  24%|██▍       | 13/54 [00:25<01:19,  1.95s/it]

Loss: 0.15810146927833557


Epoch 5:  26%|██▌       | 14/54 [00:27<01:17,  1.95s/it]

Loss: 1.4580824375152588


Epoch 5:  28%|██▊       | 15/54 [00:29<01:15,  1.95s/it]

Loss: 0.7877786159515381


Epoch 5:  30%|██▉       | 16/54 [00:31<01:13,  1.94s/it]

Loss: 0.8793401718139648


Epoch 5:  31%|███▏      | 17/54 [00:33<01:11,  1.94s/it]

Loss: 1.3205162286758423


Epoch 5:  33%|███▎      | 18/54 [00:35<01:10,  1.94s/it]

Loss: 0.5282649993896484


Epoch 5:  35%|███▌      | 19/54 [00:36<01:08,  1.94s/it]

Loss: 0.13829578459262848


Epoch 5:  37%|███▋      | 20/54 [00:38<01:06,  1.94s/it]

Loss: 0.8325160145759583


Epoch 5:  39%|███▉      | 21/54 [00:40<01:04,  1.94s/it]

Loss: 1.6737751960754395


Epoch 5:  41%|████      | 22/54 [00:42<01:02,  1.94s/it]

Loss: 0.7173913717269897


Epoch 5:  43%|████▎     | 23/54 [00:44<01:00,  1.94s/it]

Loss: 1.6304703950881958


Epoch 5:  44%|████▍     | 24/54 [00:46<00:58,  1.94s/it]

Loss: 0.9375636577606201


Epoch 5:  46%|████▋     | 25/54 [00:48<00:56,  1.94s/it]

Loss: 0.8879435062408447


Epoch 5:  48%|████▊     | 26/54 [00:50<00:54,  1.94s/it]

Loss: 0.965502142906189


Epoch 5:  50%|█████     | 27/54 [00:52<00:52,  1.94s/it]

Loss: 3.4391095638275146


Epoch 5:  52%|█████▏    | 28/54 [00:54<00:50,  1.94s/it]

Loss: 0.7230034470558167


Epoch 5:  54%|█████▎    | 29/54 [00:56<00:48,  1.94s/it]

Loss: 0.6755132079124451


Epoch 5:  56%|█████▌    | 30/54 [00:58<00:46,  1.94s/it]

Loss: 1.8424994945526123


Epoch 5:  57%|█████▋    | 31/54 [01:00<00:44,  1.94s/it]

Loss: 0.3806653916835785


Epoch 5:  59%|█████▉    | 32/54 [01:02<00:42,  1.94s/it]

Loss: 4.391875743865967


Epoch 5:  61%|██████    | 33/54 [01:04<00:40,  1.94s/it]

Loss: 2.985238790512085


Epoch 5:  63%|██████▎   | 34/54 [01:06<00:38,  1.94s/it]

Loss: 1.1029503345489502


Epoch 5:  65%|██████▍   | 35/54 [01:08<00:36,  1.94s/it]

Loss: 1.126692533493042


Epoch 5:  67%|██████▋   | 36/54 [01:10<00:34,  1.94s/it]

Loss: 0.6918399333953857


Epoch 5:  69%|██████▊   | 37/54 [01:11<00:32,  1.94s/it]

Loss: 0.8803329467773438


Epoch 5:  70%|███████   | 38/54 [01:13<00:31,  1.94s/it]

Loss: 2.846895456314087


Epoch 5:  72%|███████▏  | 39/54 [01:15<00:29,  1.94s/it]

Loss: 1.0287803411483765


Epoch 5:  74%|███████▍  | 40/54 [01:17<00:27,  1.94s/it]

Loss: 0.6322565078735352


Epoch 5:  76%|███████▌  | 41/54 [01:19<00:25,  1.94s/it]

Loss: 1.728158712387085


Epoch 5:  78%|███████▊  | 42/54 [01:21<00:23,  1.94s/it]

Loss: 2.084975242614746


Epoch 5:  80%|███████▉  | 43/54 [01:23<00:21,  1.94s/it]

Loss: 2.815370798110962


Epoch 5:  81%|████████▏ | 44/54 [01:25<00:19,  1.94s/it]

Loss: 2.1317906379699707


Epoch 5:  83%|████████▎ | 45/54 [01:27<00:17,  1.94s/it]

Loss: 1.6208281517028809


Epoch 5:  85%|████████▌ | 46/54 [01:29<00:15,  1.94s/it]

Loss: 0.7380365133285522


Epoch 5:  87%|████████▋ | 47/54 [01:31<00:13,  1.94s/it]

Loss: 1.190429925918579


Epoch 5:  89%|████████▉ | 48/54 [01:33<00:11,  1.94s/it]

Loss: 1.5305441617965698


Epoch 5:  91%|█████████ | 49/54 [01:35<00:09,  1.94s/it]

Loss: 0.6761813163757324


Epoch 5:  93%|█████████▎| 50/54 [01:37<00:07,  1.94s/it]

Loss: 0.6049808263778687


Epoch 5:  94%|█████████▍| 51/54 [01:39<00:05,  1.95s/it]

Loss: 1.6898009777069092


Epoch 5:  96%|█████████▋| 52/54 [01:41<00:03,  1.94s/it]

Loss: 0.7961987853050232


Epoch 5:  98%|█████████▊| 53/54 [01:43<00:01,  1.94s/it]

Loss: 0.25265270471572876


Epoch 5: 100%|██████████| 54/54 [01:44<00:00,  1.93s/it]

Loss: 0.30014294385910034





Epoch 5 Validation Accuracy: 0.7783251231527094, F1-macro: 0.6908106278558132


Epoch 6:   2%|▏         | 1/54 [00:01<01:42,  1.94s/it]

Loss: 0.7168608903884888


Epoch 6:   4%|▎         | 2/54 [00:03<01:40,  1.94s/it]

Loss: 0.44955649971961975


Epoch 6:   6%|▌         | 3/54 [00:05<01:39,  1.95s/it]

Loss: 0.5633834600448608


Epoch 6:   7%|▋         | 4/54 [00:07<01:37,  1.95s/it]

Loss: 1.0154459476470947


Epoch 6:   9%|▉         | 5/54 [00:09<01:35,  1.95s/it]

Loss: 0.04899820685386658


Epoch 6:  11%|█         | 6/54 [00:11<01:33,  1.95s/it]

Loss: 0.21606114506721497


Epoch 6:  13%|█▎        | 7/54 [00:13<01:31,  1.94s/it]

Loss: 0.5216854214668274


Epoch 6:  15%|█▍        | 8/54 [00:15<01:29,  1.94s/it]

Loss: 0.2318165898323059


Epoch 6:  17%|█▋        | 9/54 [00:17<01:27,  1.94s/it]

Loss: 0.3571743965148926


Epoch 6:  19%|█▊        | 10/54 [00:19<01:25,  1.94s/it]

Loss: 0.5365191102027893


Epoch 6:  20%|██        | 11/54 [00:21<01:23,  1.94s/it]

Loss: 0.6134186387062073


Epoch 6:  22%|██▏       | 12/54 [00:23<01:21,  1.94s/it]

Loss: 0.1202094703912735


Epoch 6:  24%|██▍       | 13/54 [00:25<01:19,  1.94s/it]

Loss: 0.3930920958518982


Epoch 6:  26%|██▌       | 14/54 [00:27<01:17,  1.94s/it]

Loss: 0.19323067367076874


Epoch 6:  28%|██▊       | 15/54 [00:29<01:15,  1.95s/it]

Loss: 0.5473357439041138


Epoch 6:  30%|██▉       | 16/54 [00:31<01:13,  1.94s/it]

Loss: 0.40549546480178833


Epoch 6:  31%|███▏      | 17/54 [00:33<01:11,  1.95s/it]

Loss: 0.33891797065734863


Epoch 6:  33%|███▎      | 18/54 [00:35<01:10,  1.95s/it]

Loss: 0.6454309225082397


Epoch 6:  35%|███▌      | 19/54 [00:36<01:08,  1.95s/it]

Loss: 0.6625553965568542


Epoch 6:  37%|███▋      | 20/54 [00:38<01:06,  1.95s/it]

Loss: 0.3143730163574219


Epoch 6:  39%|███▉      | 21/54 [00:40<01:04,  1.95s/it]

Loss: 0.32792776823043823


Epoch 6:  41%|████      | 22/54 [00:42<01:02,  1.95s/it]

Loss: 0.19173410534858704


Epoch 6:  43%|████▎     | 23/54 [00:44<01:00,  1.95s/it]

Loss: 0.36310333013534546


Epoch 6:  44%|████▍     | 24/54 [00:46<00:58,  1.95s/it]

Loss: 0.4486837387084961


Epoch 6:  46%|████▋     | 25/54 [00:48<00:56,  1.95s/it]

Loss: 0.1419866383075714


Epoch 6:  48%|████▊     | 26/54 [00:50<00:54,  1.95s/it]

Loss: 0.4929548501968384


Epoch 6:  50%|█████     | 27/54 [00:52<00:52,  1.95s/it]

Loss: 0.26715371012687683


Epoch 6:  52%|█████▏    | 28/54 [00:54<00:50,  1.95s/it]

Loss: 0.342235803604126


Epoch 6:  54%|█████▎    | 29/54 [00:56<00:48,  1.94s/it]

Loss: 0.06491813063621521


Epoch 6:  56%|█████▌    | 30/54 [00:58<00:46,  1.94s/it]

Loss: 0.591500461101532


Epoch 6:  57%|█████▋    | 31/54 [01:00<00:44,  1.94s/it]

Loss: 0.355909526348114


Epoch 6:  59%|█████▉    | 32/54 [01:02<00:42,  1.94s/it]

Loss: 0.5256137847900391


Epoch 6:  61%|██████    | 33/54 [01:04<00:40,  1.94s/it]

Loss: 0.32906627655029297


Epoch 6:  63%|██████▎   | 34/54 [01:06<00:38,  1.94s/it]

Loss: 0.347579687833786


Epoch 6:  65%|██████▍   | 35/54 [01:08<00:36,  1.94s/it]

Loss: 0.3213866353034973


Epoch 6:  67%|██████▋   | 36/54 [01:10<00:34,  1.94s/it]

Loss: 0.23915547132492065


Epoch 6:  69%|██████▊   | 37/54 [01:11<00:33,  1.94s/it]

Loss: 0.6997865438461304


Epoch 6:  70%|███████   | 38/54 [01:13<00:31,  1.94s/it]

Loss: 0.8665074110031128


Epoch 6:  72%|███████▏  | 39/54 [01:15<00:29,  1.94s/it]

Loss: 0.747635006904602


Epoch 6:  74%|███████▍  | 40/54 [01:17<00:27,  1.94s/it]

Loss: 0.37365594506263733


Epoch 6:  76%|███████▌  | 41/54 [01:19<00:25,  1.94s/it]

Loss: 0.739152193069458


Epoch 6:  78%|███████▊  | 42/54 [01:21<00:23,  1.95s/it]

Loss: 2.2900402545928955


Epoch 6:  80%|███████▉  | 43/54 [01:23<00:21,  1.94s/it]

Loss: 1.3942482471466064


Epoch 6:  81%|████████▏ | 44/54 [01:25<00:19,  1.94s/it]

Loss: 0.6282952427864075


Epoch 6:  83%|████████▎ | 45/54 [01:27<00:17,  1.95s/it]

Loss: 1.3193575143814087


Epoch 6:  85%|████████▌ | 46/54 [01:29<00:15,  1.94s/it]

Loss: 2.2718734741210938


Epoch 6:  87%|████████▋ | 47/54 [01:31<00:13,  1.94s/it]

Loss: 0.9290861487388611


Epoch 6:  89%|████████▉ | 48/54 [01:33<00:11,  1.94s/it]

Loss: 0.8503162860870361


Epoch 6:  91%|█████████ | 49/54 [01:35<00:09,  1.94s/it]

Loss: 1.2457466125488281


Epoch 6:  93%|█████████▎| 50/54 [01:37<00:07,  1.95s/it]

Loss: 0.7523881196975708


Epoch 6:  94%|█████████▍| 51/54 [01:39<00:05,  1.95s/it]

Loss: 1.4606821537017822


Epoch 6:  96%|█████████▋| 52/54 [01:41<00:03,  1.95s/it]

Loss: 1.062506079673767


Epoch 6:  98%|█████████▊| 53/54 [01:43<00:01,  1.95s/it]

Loss: 0.9761267304420471


Epoch 6: 100%|██████████| 54/54 [01:44<00:00,  1.93s/it]

Loss: 0.42554688453674316





Epoch 6 Validation Accuracy: 0.5714285714285714, F1-macro: 0.5541164886768158


Epoch 7:   2%|▏         | 1/54 [00:01<01:43,  1.95s/it]

Loss: 0.8969721794128418


Epoch 7:   4%|▎         | 2/54 [00:03<01:41,  1.95s/it]

Loss: 0.5890957117080688


Epoch 7:   6%|▌         | 3/54 [00:05<01:39,  1.95s/it]

Loss: 0.4901808798313141


Epoch 7:   7%|▋         | 4/54 [00:07<01:37,  1.95s/it]

Loss: 0.665351927280426


Epoch 7:   9%|▉         | 5/54 [00:09<01:35,  1.95s/it]

Loss: 1.4106049537658691


Epoch 7:  11%|█         | 6/54 [00:11<01:33,  1.95s/it]

Loss: 0.656987190246582


Epoch 7:  13%|█▎        | 7/54 [00:13<01:31,  1.95s/it]

Loss: 0.397825688123703


Epoch 7:  15%|█▍        | 8/54 [00:15<01:29,  1.95s/it]

Loss: 0.6099762916564941


Epoch 7:  17%|█▋        | 9/54 [00:17<01:27,  1.95s/it]

Loss: 0.5746814012527466


Epoch 7:  19%|█▊        | 10/54 [00:19<01:25,  1.95s/it]

Loss: 0.8362974524497986


Epoch 7:  20%|██        | 11/54 [00:21<01:23,  1.95s/it]

Loss: 0.45825082063674927


Epoch 7:  22%|██▏       | 12/54 [00:23<01:21,  1.95s/it]

Loss: 0.49747592210769653


Epoch 7:  24%|██▍       | 13/54 [00:25<01:20,  1.95s/it]

Loss: 1.5062344074249268


Epoch 7:  26%|██▌       | 14/54 [00:27<01:17,  1.95s/it]

Loss: 0.9819414019584656


Epoch 7:  28%|██▊       | 15/54 [00:29<01:15,  1.95s/it]

Loss: 0.44818878173828125


Epoch 7:  30%|██▉       | 16/54 [00:31<01:14,  1.95s/it]

Loss: 0.2894425690174103


Epoch 7:  31%|███▏      | 17/54 [00:33<01:12,  1.95s/it]

Loss: 0.4742649793624878


Epoch 7:  33%|███▎      | 18/54 [00:35<01:10,  1.95s/it]

Loss: 0.5607212781906128


Epoch 7:  35%|███▌      | 19/54 [00:37<01:08,  1.95s/it]

Loss: 0.44073808193206787


Epoch 7:  37%|███▋      | 20/54 [00:38<01:06,  1.95s/it]

Loss: 0.3502166271209717


Epoch 7:  39%|███▉      | 21/54 [00:40<01:04,  1.95s/it]

Loss: 0.7658902406692505


Epoch 7:  41%|████      | 22/54 [00:42<01:02,  1.95s/it]

Loss: 0.6875450015068054


Epoch 7:  43%|████▎     | 23/54 [00:44<01:00,  1.95s/it]

Loss: 0.201893612742424


Epoch 7:  44%|████▍     | 24/54 [00:46<00:58,  1.95s/it]

Loss: 0.4966488182544708


Epoch 7:  46%|████▋     | 25/54 [00:48<00:56,  1.95s/it]

Loss: 0.23294825851917267


Epoch 7:  48%|████▊     | 26/54 [00:50<00:54,  1.95s/it]

Loss: 0.3794766068458557


Epoch 7:  50%|█████     | 27/54 [00:52<00:52,  1.95s/it]

Loss: 0.4193994104862213


Epoch 7:  52%|█████▏    | 28/54 [00:54<00:50,  1.95s/it]

Loss: 0.3956276774406433


Epoch 7:  54%|█████▎    | 29/54 [00:56<00:48,  1.95s/it]

Loss: 1.5006976127624512


Epoch 7:  56%|█████▌    | 30/54 [00:58<00:46,  1.95s/it]

Loss: 1.0597052574157715


Epoch 7:  57%|█████▋    | 31/54 [01:00<00:44,  1.95s/it]

Loss: 0.6543099284172058


Epoch 7:  59%|█████▉    | 32/54 [01:02<00:42,  1.95s/it]

Loss: 0.6278570890426636


Epoch 7:  61%|██████    | 33/54 [01:04<00:40,  1.95s/it]

Loss: 0.4425557553768158


Epoch 7:  63%|██████▎   | 34/54 [01:06<00:38,  1.95s/it]

Loss: 0.2703331708908081


Epoch 7:  65%|██████▍   | 35/54 [01:08<00:37,  1.95s/it]

Loss: 0.33988720178604126


Epoch 7:  67%|██████▋   | 36/54 [01:10<00:35,  1.95s/it]

Loss: 1.0593725442886353


Epoch 7:  69%|██████▊   | 37/54 [01:12<00:33,  1.95s/it]

Loss: 1.0175609588623047


Epoch 7:  70%|███████   | 38/54 [01:14<00:31,  1.95s/it]

Loss: 0.48096394538879395


Epoch 7:  72%|███████▏  | 39/54 [01:15<00:29,  1.95s/it]

Loss: 0.5547370910644531


Epoch 7:  74%|███████▍  | 40/54 [01:17<00:27,  1.95s/it]

Loss: 1.0670102834701538


Epoch 7:  76%|███████▌  | 41/54 [01:19<00:25,  1.95s/it]

Loss: 0.4047877788543701


Epoch 7:  78%|███████▊  | 42/54 [01:21<00:23,  1.95s/it]

Loss: 0.1403338611125946


Epoch 7:  80%|███████▉  | 43/54 [01:23<00:21,  1.95s/it]

Loss: 0.5616320967674255


Epoch 7:  81%|████████▏ | 44/54 [01:25<00:19,  1.95s/it]

Loss: 0.8055210709571838


Epoch 7:  83%|████████▎ | 45/54 [01:27<00:17,  1.95s/it]

Loss: 1.1239351034164429


Epoch 7:  85%|████████▌ | 46/54 [01:29<00:15,  1.95s/it]

Loss: 0.4817538261413574


Epoch 7:  87%|████████▋ | 47/54 [01:31<00:13,  1.95s/it]

Loss: 1.0587190389633179


Epoch 7:  89%|████████▉ | 48/54 [01:33<00:11,  1.95s/it]

Loss: 0.6458524465560913


Epoch 7:  91%|█████████ | 49/54 [01:35<00:09,  1.95s/it]

Loss: 0.48361319303512573


Epoch 7:  93%|█████████▎| 50/54 [01:37<00:07,  1.95s/it]

Loss: 0.6638341546058655


Epoch 7:  94%|█████████▍| 51/54 [01:39<00:05,  1.95s/it]

Loss: 0.9336161017417908


Epoch 7:  96%|█████████▋| 52/54 [01:41<00:03,  1.95s/it]

Loss: 1.0508346557617188


Epoch 7:  98%|█████████▊| 53/54 [01:43<00:01,  1.95s/it]

Loss: 0.6914868354797363


Epoch 7: 100%|██████████| 54/54 [01:44<00:00,  1.93s/it]

Loss: 0.09898076206445694





Epoch 7 Validation Accuracy: 0.8177339901477833, F1-macro: 0.7198328919392741


In [None]:
# Evaluate on the test set
test_metrics = evaluate(model6, test_dataloader, device)
print(f"Test Accuracy: {test_metrics['accuracy']}, F1-macro: {test_metrics['f1_macro']}")

Test Accuracy: 0.7990196078431373, F1-macro: 0.6781962987187873


In [None]:
current_type = 'labeling'

results_df.loc[
    results_df["distortion_type"] == current_type,
    ["test_accuracy", "f1_macro"]
] = [test_metrics['accuracy'], test_metrics['f1_macro']]

# 7. Magnification

In [None]:
# Add labels
data1_1_labels = list(data1['labeling'][data1_1.index])
data2_1_labels = list(data2['labeling'][data2_1.index])
data3_1_labels = list(data3['labeling'][data3_1.index])

# Merging Data
data_encoded = data1_1_encoded + data2_1_encoded + data3_1_encoded
data_labels = data1_1_labels + data2_1_labels + data3_1_labels

dataset_with_labels = CustomDatasetWithLabels(data_encoded, data_labels)

# Define proportions for splitting
train_size = int(0.8 * len(dataset_with_labels))
val_size = int(0.1 * len(dataset_with_labels))
test_size = len(dataset_with_labels) - train_size - val_size

# Split the dataset
train_dataset, val_dataset, test_dataset = random_split(dataset_with_labels, [train_size, val_size, test_size])

In [None]:
dataset_encoded = dataset_with_labels.data
dataset_labels = dataset_with_labels.labels

# Convert list of BatchEncoding objects to a 2D numpy array for SMOTE features
feature_vectors = []
max_len = 512 # Assuming this from tokenize_and_pad function
for encoded_dict in dataset_encoded:
    input_ids_np = encoded_dict['input_ids'].squeeze().cpu().numpy()
    attention_mask_np = encoded_dict['attention_mask'].squeeze().cpu().numpy()
    # Concatenate input_ids and attention_mask to create a single feature vector per sample
    feature_vectors.append(np.concatenate([input_ids_np, attention_mask_np]))

# 'temp1' becomes the feature matrix, 'temp2' becomes the labels vector
temp1 = np.array(feature_vectors)
temp2 = np.array(dataset_labels)

# Apply SMOTE
temp1_resampled, temp2_resampled = smote.fit_resample(temp1, temp2)

# Reconstruct data_encoded and data_labels from the resampled data
reconstructed_data_encoded = []
for resampled_features in temp1_resampled:
    # Split the combined feature vector back into input_ids and attention_mask parts
    reconstructed_input_ids = resampled_features[:max_len]
    reconstructed_attention_mask = resampled_features[max_len:]

    # Convert back to torch tensors and a dictionary format compatible with CustomDatasetWithLabels
    input_ids_tensor = torch.tensor(reconstructed_input_ids, dtype=torch.long).unsqueeze(0)
    attention_mask_tensor = torch.tensor(reconstructed_attention_mask, dtype=torch.long).unsqueeze(0)

    reconstructed_data_encoded.append({
        'input_ids': input_ids_tensor,
        'attention_mask': attention_mask_tensor
    })

# Update data_encoded and data_labels with the resampled data
dataset_encoded = reconstructed_data_encoded
dataset_labels = temp2_resampled.tolist()

train_dataset = CustomDatasetWithLabels(dataset_encoded, dataset_labels)

In [None]:
# Create DataLoaders for each set
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=64)
test_dataloader = DataLoader(test_dataset, batch_size=64)

input_dim = bert_config.hidden_size
num_labels = len(set(data_labels))
label_emb = torch.randn(num_labels, input_dim)

In [None]:
# Instantiate the InnerProductClassifier model
model7 = InnerProductClassifier(input_dim, label_emb).to(device)

# Define loss function
criterion = nn.CrossEntropyLoss()

# Define optimizer
optimizer = torch.optim.AdamW(model7.parameters(), lr=2e-4)

In [None]:
EPOCHS = 7

# Move the BERT model to the device
bert_model.to(device)

for epoch in range(EPOCHS):
    model7.train()
    for batch in tqdm(train_dataloader, desc=f"Epoch {epoch+1}"):

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        y = batch["y"].to(device)

        # Get embeddings from the BERT model
        with torch.no_grad():
            outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask)
            embeddings = outputs.last_hidden_state[:, 0, :] # Get the [CLS] token embedding

        # Pass embeddings to the InnerProductClassifier
        logits = model7(embeddings)
        loss = criterion(logits, y)

        # Backpropagate and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Loss: {loss.item()}")

    # Evaluate on the validation set after each epoch
    val_metrics = evaluate(model7, val_dataloader, device)
    print(f"Epoch {epoch+1} Validation Accuracy: {val_metrics['accuracy']}, F1-macro: {val_metrics['f1_macro']}")

Epoch 1:   2%|▏         | 1/54 [00:01<01:42,  1.93s/it]

Loss: 2.9581737518310547


Epoch 1:   4%|▎         | 2/54 [00:03<01:40,  1.94s/it]

Loss: 2.8381199836730957


Epoch 1:   6%|▌         | 3/54 [00:05<01:39,  1.95s/it]

Loss: 1.3028273582458496


Epoch 1:   7%|▋         | 4/54 [00:07<01:37,  1.95s/it]

Loss: 1.1218528747558594


Epoch 1:   9%|▉         | 5/54 [00:09<01:35,  1.95s/it]

Loss: 1.759148359298706


Epoch 1:  11%|█         | 6/54 [00:11<01:33,  1.95s/it]

Loss: 1.3124700784683228


Epoch 1:  13%|█▎        | 7/54 [00:13<01:31,  1.95s/it]

Loss: 0.5341469645500183


Epoch 1:  15%|█▍        | 8/54 [00:15<01:29,  1.95s/it]

Loss: 0.5899149775505066


Epoch 1:  17%|█▋        | 9/54 [00:17<01:27,  1.95s/it]

Loss: 0.8485533595085144


Epoch 1:  19%|█▊        | 10/54 [00:19<01:25,  1.95s/it]

Loss: 1.104833960533142


Epoch 1:  20%|██        | 11/54 [00:21<01:23,  1.95s/it]

Loss: 1.7882593870162964


Epoch 1:  22%|██▏       | 12/54 [00:23<01:21,  1.95s/it]

Loss: 1.0116374492645264


Epoch 1:  24%|██▍       | 13/54 [00:25<01:19,  1.95s/it]

Loss: 1.733417272567749


Epoch 1:  26%|██▌       | 14/54 [00:27<01:17,  1.95s/it]

Loss: 1.219309687614441


Epoch 1:  28%|██▊       | 15/54 [00:29<01:15,  1.95s/it]

Loss: 0.9456576108932495


Epoch 1:  30%|██▉       | 16/54 [00:31<01:13,  1.95s/it]

Loss: 0.7209353446960449


Epoch 1:  31%|███▏      | 17/54 [00:33<01:11,  1.94s/it]

Loss: 2.4964747428894043


Epoch 1:  33%|███▎      | 18/54 [00:35<01:10,  1.95s/it]

Loss: 1.5334522724151611


Epoch 1:  35%|███▌      | 19/54 [00:36<01:08,  1.94s/it]

Loss: 1.0257437229156494


Epoch 1:  37%|███▋      | 20/54 [00:38<01:06,  1.94s/it]

Loss: 1.6632331609725952


Epoch 1:  39%|███▉      | 21/54 [00:40<01:04,  1.94s/it]

Loss: 1.1146681308746338


Epoch 1:  41%|████      | 22/54 [00:42<01:02,  1.94s/it]

Loss: 1.6752296686172485


Epoch 1:  43%|████▎     | 23/54 [00:44<01:00,  1.94s/it]

Loss: 1.455482840538025


Epoch 1:  44%|████▍     | 24/54 [00:46<00:58,  1.94s/it]

Loss: 0.9336014986038208


Epoch 1:  46%|████▋     | 25/54 [00:48<00:56,  1.94s/it]

Loss: 0.3582126200199127


Epoch 1:  48%|████▊     | 26/54 [00:50<00:54,  1.94s/it]

Loss: 1.3796260356903076


Epoch 1:  50%|█████     | 27/54 [00:52<00:52,  1.94s/it]

Loss: 1.2024998664855957


Epoch 1:  52%|█████▏    | 28/54 [00:54<00:50,  1.94s/it]

Loss: 0.6130086779594421


Epoch 1:  54%|█████▎    | 29/54 [00:56<00:48,  1.94s/it]

Loss: 0.34104031324386597


Epoch 1:  56%|█████▌    | 30/54 [00:58<00:46,  1.94s/it]

Loss: 0.5555890202522278


Epoch 1:  57%|█████▋    | 31/54 [01:00<00:44,  1.94s/it]

Loss: 0.9982726573944092


Epoch 1:  59%|█████▉    | 32/54 [01:02<00:42,  1.94s/it]

Loss: 0.7568826675415039


Epoch 1:  61%|██████    | 33/54 [01:04<00:40,  1.94s/it]

Loss: 1.1412954330444336


Epoch 1:  63%|██████▎   | 34/54 [01:06<00:38,  1.94s/it]

Loss: 0.6674009561538696


Epoch 1:  65%|██████▍   | 35/54 [01:08<00:36,  1.94s/it]

Loss: 0.6656533479690552


Epoch 1:  67%|██████▋   | 36/54 [01:10<00:35,  1.94s/it]

Loss: 0.3251005709171295


Epoch 1:  69%|██████▊   | 37/54 [01:11<00:33,  1.94s/it]

Loss: 0.6020065546035767


Epoch 1:  70%|███████   | 38/54 [01:13<00:31,  1.95s/it]

Loss: 0.6236777901649475


Epoch 1:  72%|███████▏  | 39/54 [01:15<00:29,  1.94s/it]

Loss: 0.5160297751426697


Epoch 1:  74%|███████▍  | 40/54 [01:17<00:27,  1.95s/it]

Loss: 0.2180735319852829


Epoch 1:  76%|███████▌  | 41/54 [01:19<00:25,  1.95s/it]

Loss: 0.8208615779876709


Epoch 1:  78%|███████▊  | 42/54 [01:21<00:23,  1.95s/it]

Loss: 1.0986838340759277


Epoch 1:  80%|███████▉  | 43/54 [01:23<00:21,  1.95s/it]

Loss: 0.814424455165863


Epoch 1:  81%|████████▏ | 44/54 [01:25<00:19,  1.94s/it]

Loss: 1.019606590270996


Epoch 1:  83%|████████▎ | 45/54 [01:27<00:17,  1.94s/it]

Loss: 0.3589959740638733


Epoch 1:  85%|████████▌ | 46/54 [01:29<00:15,  1.94s/it]

Loss: 0.6067132949829102


Epoch 1:  87%|████████▋ | 47/54 [01:31<00:13,  1.94s/it]

Loss: 1.7222256660461426


Epoch 1:  89%|████████▉ | 48/54 [01:33<00:11,  1.94s/it]

Loss: 1.511324405670166


Epoch 1:  91%|█████████ | 49/54 [01:35<00:09,  1.94s/it]

Loss: 1.068016767501831


Epoch 1:  93%|█████████▎| 50/54 [01:37<00:07,  1.94s/it]

Loss: 1.9661741256713867


Epoch 1:  94%|█████████▍| 51/54 [01:39<00:05,  1.94s/it]

Loss: 1.3943982124328613


Epoch 1:  96%|█████████▋| 52/54 [01:41<00:03,  1.94s/it]

Loss: 0.9181615710258484


Epoch 1:  98%|█████████▊| 53/54 [01:43<00:01,  1.94s/it]

Loss: 0.6846941709518433


Epoch 1: 100%|██████████| 54/54 [01:44<00:00,  1.93s/it]

Loss: 0.8705906867980957





Epoch 1 Validation Accuracy: 0.8571428571428571, F1-macro: 0.46153846153846156


Epoch 2:   2%|▏         | 1/54 [00:01<01:42,  1.93s/it]

Loss: 1.3523401021957397


Epoch 2:   4%|▎         | 2/54 [00:03<01:40,  1.94s/it]

Loss: 2.1281471252441406


Epoch 2:   6%|▌         | 3/54 [00:05<01:39,  1.94s/it]

Loss: 2.715977191925049


Epoch 2:   7%|▋         | 4/54 [00:07<01:37,  1.94s/it]

Loss: 1.2024496793746948


Epoch 2:   9%|▉         | 5/54 [00:09<01:35,  1.94s/it]

Loss: 0.5735216736793518


Epoch 2:  11%|█         | 6/54 [00:11<01:33,  1.94s/it]

Loss: 1.0248548984527588


Epoch 2:  13%|█▎        | 7/54 [00:13<01:31,  1.94s/it]

Loss: 1.3059898614883423


Epoch 2:  15%|█▍        | 8/54 [00:15<01:29,  1.94s/it]

Loss: 0.6379618644714355


Epoch 2:  17%|█▋        | 9/54 [00:17<01:27,  1.94s/it]

Loss: 0.7881428003311157


Epoch 2:  19%|█▊        | 10/54 [00:19<01:25,  1.94s/it]

Loss: 0.6188132762908936


Epoch 2:  20%|██        | 11/54 [00:21<01:23,  1.94s/it]

Loss: 0.7264317274093628


Epoch 2:  22%|██▏       | 12/54 [00:23<01:21,  1.94s/it]

Loss: 0.358759343624115


Epoch 2:  24%|██▍       | 13/54 [00:25<01:19,  1.94s/it]

Loss: 0.5646288394927979


Epoch 2:  26%|██▌       | 14/54 [00:27<01:17,  1.94s/it]

Loss: 0.6301057934761047


Epoch 2:  28%|██▊       | 15/54 [00:29<01:15,  1.94s/it]

Loss: 0.6642233729362488


Epoch 2:  30%|██▉       | 16/54 [00:31<01:13,  1.94s/it]

Loss: 0.5058779120445251


Epoch 2:  31%|███▏      | 17/54 [00:33<01:11,  1.94s/it]

Loss: 0.48148319125175476


Epoch 2:  33%|███▎      | 18/54 [00:34<01:09,  1.94s/it]

Loss: 0.1704866588115692


Epoch 2:  35%|███▌      | 19/54 [00:36<01:08,  1.94s/it]

Loss: 0.48896655440330505


Epoch 2:  37%|███▋      | 20/54 [00:38<01:06,  1.95s/it]

Loss: 0.7595116496086121


Epoch 2:  39%|███▉      | 21/54 [00:40<01:04,  1.95s/it]

Loss: 0.47883835434913635


Epoch 2:  41%|████      | 22/54 [00:42<01:02,  1.95s/it]

Loss: 0.42447060346603394


Epoch 2:  43%|████▎     | 23/54 [00:44<01:00,  1.95s/it]

Loss: 0.5077114701271057


Epoch 2:  44%|████▍     | 24/54 [00:46<00:58,  1.95s/it]

Loss: 0.5934484004974365


Epoch 2:  46%|████▋     | 25/54 [00:48<00:56,  1.95s/it]

Loss: 0.6152911186218262


Epoch 2:  48%|████▊     | 26/54 [00:50<00:54,  1.95s/it]

Loss: 0.46954283118247986


Epoch 2:  50%|█████     | 27/54 [00:52<00:52,  1.95s/it]

Loss: 0.8511114120483398


Epoch 2:  52%|█████▏    | 28/54 [00:54<00:50,  1.95s/it]

Loss: 0.34903576970100403


Epoch 2:  54%|█████▎    | 29/54 [00:56<00:48,  1.95s/it]

Loss: 0.2120833694934845


Epoch 2:  56%|█████▌    | 30/54 [00:58<00:46,  1.95s/it]

Loss: 1.240315556526184


Epoch 2:  57%|█████▋    | 31/54 [01:00<00:44,  1.95s/it]

Loss: 0.9748745560646057


Epoch 2:  59%|█████▉    | 32/54 [01:02<00:42,  1.94s/it]

Loss: 0.29967406392097473


Epoch 2:  61%|██████    | 33/54 [01:04<00:40,  1.95s/it]

Loss: 0.7128817439079285


Epoch 2:  63%|██████▎   | 34/54 [01:06<00:38,  1.94s/it]

Loss: 0.9721196889877319


Epoch 2:  65%|██████▍   | 35/54 [01:08<00:36,  1.95s/it]

Loss: 1.6437463760375977


Epoch 2:  67%|██████▋   | 36/54 [01:09<00:35,  1.94s/it]

Loss: 0.36249852180480957


Epoch 2:  69%|██████▊   | 37/54 [01:11<00:33,  1.94s/it]

Loss: 0.4120866358280182


Epoch 2:  70%|███████   | 38/54 [01:13<00:31,  1.94s/it]

Loss: 1.560442328453064


Epoch 2:  72%|███████▏  | 39/54 [01:15<00:29,  1.94s/it]

Loss: 0.383984237909317


Epoch 2:  74%|███████▍  | 40/54 [01:17<00:27,  1.94s/it]

Loss: 1.6489503383636475


Epoch 2:  76%|███████▌  | 41/54 [01:19<00:25,  1.94s/it]

Loss: 1.2126551866531372


Epoch 2:  78%|███████▊  | 42/54 [01:21<00:23,  1.94s/it]

Loss: 1.206515908241272


Epoch 2:  80%|███████▉  | 43/54 [01:23<00:21,  1.94s/it]

Loss: 0.5576353073120117


Epoch 2:  81%|████████▏ | 44/54 [01:25<00:19,  1.94s/it]

Loss: 1.2366917133331299


Epoch 2:  83%|████████▎ | 45/54 [01:27<00:17,  1.94s/it]

Loss: 0.8326709270477295


Epoch 2:  85%|████████▌ | 46/54 [01:29<00:15,  1.94s/it]

Loss: 0.5031461715698242


Epoch 2:  87%|████████▋ | 47/54 [01:31<00:13,  1.94s/it]

Loss: 0.41658324003219604


Epoch 2:  89%|████████▉ | 48/54 [01:33<00:11,  1.94s/it]

Loss: 0.4396347403526306


Epoch 2:  91%|█████████ | 49/54 [01:35<00:09,  1.94s/it]

Loss: 1.8395771980285645


Epoch 2:  93%|█████████▎| 50/54 [01:37<00:07,  1.94s/it]

Loss: 0.5438184142112732


Epoch 2:  94%|█████████▍| 51/54 [01:39<00:05,  1.94s/it]

Loss: 0.6382846832275391


Epoch 2:  96%|█████████▋| 52/54 [01:41<00:03,  1.95s/it]

Loss: 0.9120664596557617


Epoch 2:  98%|█████████▊| 53/54 [01:43<00:01,  1.95s/it]

Loss: 0.48869767785072327


Epoch 2: 100%|██████████| 54/54 [01:44<00:00,  1.93s/it]

Loss: 1.1818920373916626





Epoch 2 Validation Accuracy: 0.8620689655172413, F1-macro: 0.6692271880819367


Epoch 3:   2%|▏         | 1/54 [00:01<01:42,  1.94s/it]

Loss: 0.35914164781570435


Epoch 3:   4%|▎         | 2/54 [00:03<01:41,  1.95s/it]

Loss: 0.34957897663116455


Epoch 3:   6%|▌         | 3/54 [00:05<01:39,  1.95s/it]

Loss: 0.19094187021255493


Epoch 3:   7%|▋         | 4/54 [00:07<01:37,  1.95s/it]

Loss: 0.18255630135536194


Epoch 3:   9%|▉         | 5/54 [00:09<01:35,  1.95s/it]

Loss: 0.24299457669258118


Epoch 3:  11%|█         | 6/54 [00:11<01:33,  1.94s/it]

Loss: 0.5608829259872437


Epoch 3:  13%|█▎        | 7/54 [00:13<01:31,  1.94s/it]

Loss: 0.1967840939760208


Epoch 3:  15%|█▍        | 8/54 [00:15<01:29,  1.94s/it]

Loss: 0.39430755376815796


Epoch 3:  17%|█▋        | 9/54 [00:17<01:27,  1.94s/it]

Loss: 0.6028551459312439


Epoch 3:  19%|█▊        | 10/54 [00:19<01:25,  1.94s/it]

Loss: 0.95527583360672


Epoch 3:  20%|██        | 11/54 [00:21<01:23,  1.94s/it]

Loss: 0.5885430574417114


Epoch 3:  22%|██▏       | 12/54 [00:23<01:21,  1.94s/it]

Loss: 0.7028604745864868


Epoch 3:  24%|██▍       | 13/54 [00:25<01:19,  1.94s/it]

Loss: 0.2688194513320923


Epoch 3:  26%|██▌       | 14/54 [00:27<01:17,  1.94s/it]

Loss: 0.36971476674079895


Epoch 3:  28%|██▊       | 15/54 [00:29<01:15,  1.94s/it]

Loss: 0.3966371417045593


Epoch 3:  30%|██▉       | 16/54 [00:31<01:13,  1.94s/it]

Loss: 0.5387200117111206


Epoch 3:  31%|███▏      | 17/54 [00:33<01:11,  1.94s/it]

Loss: 0.424344539642334


Epoch 3:  33%|███▎      | 18/54 [00:35<01:09,  1.94s/it]

Loss: 0.4197768568992615


Epoch 3:  35%|███▌      | 19/54 [00:36<01:08,  1.94s/it]

Loss: 0.4318119287490845


Epoch 3:  37%|███▋      | 20/54 [00:38<01:06,  1.94s/it]

Loss: 0.5297017097473145


Epoch 3:  39%|███▉      | 21/54 [00:40<01:04,  1.94s/it]

Loss: 0.19458197057247162


Epoch 3:  41%|████      | 22/54 [00:42<01:02,  1.94s/it]

Loss: 0.514952540397644


Epoch 3:  43%|████▎     | 23/54 [00:44<01:00,  1.94s/it]

Loss: 0.7452551126480103


Epoch 3:  44%|████▍     | 24/54 [00:46<00:58,  1.94s/it]

Loss: 0.3522195518016815


Epoch 3:  46%|████▋     | 25/54 [00:48<00:56,  1.94s/it]

Loss: 0.549340546131134


Epoch 3:  48%|████▊     | 26/54 [00:50<00:54,  1.94s/it]

Loss: 0.7376586198806763


Epoch 3:  50%|█████     | 27/54 [00:52<00:52,  1.94s/it]

Loss: 0.47453373670578003


Epoch 3:  52%|█████▏    | 28/54 [00:54<00:50,  1.94s/it]

Loss: 0.6306174993515015


Epoch 3:  54%|█████▎    | 29/54 [00:56<00:48,  1.94s/it]

Loss: 0.7377102375030518


Epoch 3:  56%|█████▌    | 30/54 [00:58<00:46,  1.95s/it]

Loss: 0.49128642678260803


Epoch 3:  57%|█████▋    | 31/54 [01:00<00:44,  1.95s/it]

Loss: 0.5792176723480225


Epoch 3:  59%|█████▉    | 32/54 [01:02<00:42,  1.95s/it]

Loss: 1.0779643058776855


Epoch 3:  61%|██████    | 33/54 [01:04<00:40,  1.95s/it]

Loss: 0.4461224675178528


Epoch 3:  63%|██████▎   | 34/54 [01:06<00:38,  1.95s/it]

Loss: 1.2737085819244385


Epoch 3:  65%|██████▍   | 35/54 [01:08<00:36,  1.95s/it]

Loss: 0.34199702739715576


Epoch 3:  67%|██████▋   | 36/54 [01:10<00:35,  1.95s/it]

Loss: 0.8331683874130249


Epoch 3:  69%|██████▊   | 37/54 [01:11<00:33,  1.95s/it]

Loss: 0.7396748065948486


Epoch 3:  70%|███████   | 38/54 [01:13<00:31,  1.95s/it]

Loss: 0.4238469898700714


Epoch 3:  72%|███████▏  | 39/54 [01:15<00:29,  1.95s/it]

Loss: 0.4658581614494324


Epoch 3:  74%|███████▍  | 40/54 [01:17<00:27,  1.95s/it]

Loss: 0.4092610478401184


Epoch 3:  76%|███████▌  | 41/54 [01:19<00:25,  1.95s/it]

Loss: 0.7265193462371826


Epoch 3:  78%|███████▊  | 42/54 [01:21<00:23,  1.94s/it]

Loss: 0.8370306491851807


Epoch 3:  80%|███████▉  | 43/54 [01:23<00:21,  1.95s/it]

Loss: 0.5544949769973755


Epoch 3:  81%|████████▏ | 44/54 [01:25<00:19,  1.94s/it]

Loss: 1.3852945566177368


Epoch 3:  83%|████████▎ | 45/54 [01:27<00:17,  1.94s/it]

Loss: 1.272315502166748


Epoch 3:  85%|████████▌ | 46/54 [01:29<00:15,  1.95s/it]

Loss: 0.6531920433044434


Epoch 3:  87%|████████▋ | 47/54 [01:31<00:13,  1.94s/it]

Loss: 0.8022189140319824


Epoch 3:  89%|████████▉ | 48/54 [01:33<00:11,  1.94s/it]

Loss: 1.0448648929595947


Epoch 3:  91%|█████████ | 49/54 [01:35<00:09,  1.94s/it]

Loss: 0.11912192404270172


Epoch 3:  93%|█████████▎| 50/54 [01:37<00:07,  1.94s/it]

Loss: 0.7155832648277283


Epoch 3:  94%|█████████▍| 51/54 [01:39<00:05,  1.94s/it]

Loss: 0.6417585611343384


Epoch 3:  96%|█████████▋| 52/54 [01:41<00:03,  1.94s/it]

Loss: 0.5972945690155029


Epoch 3:  98%|█████████▊| 53/54 [01:43<00:01,  1.94s/it]

Loss: 0.6761804819107056


Epoch 3: 100%|██████████| 54/54 [01:44<00:00,  1.93s/it]

Loss: 0.9054646492004395





Epoch 3 Validation Accuracy: 0.8226600985221675, F1-macro: 0.589438202247191


Epoch 4:   2%|▏         | 1/54 [00:01<01:42,  1.94s/it]

Loss: 0.22975678741931915


Epoch 4:   4%|▎         | 2/54 [00:03<01:41,  1.95s/it]

Loss: 0.5247559547424316


Epoch 4:   6%|▌         | 3/54 [00:05<01:39,  1.95s/it]

Loss: 0.39642518758773804


Epoch 4:   7%|▋         | 4/54 [00:07<01:37,  1.95s/it]

Loss: 0.36293017864227295


Epoch 4:   9%|▉         | 5/54 [00:09<01:35,  1.95s/it]

Loss: 0.4789028763771057


Epoch 4:  11%|█         | 6/54 [00:11<01:33,  1.95s/it]

Loss: 0.5575743913650513


Epoch 4:  13%|█▎        | 7/54 [00:13<01:31,  1.95s/it]

Loss: 0.4568314850330353


Epoch 4:  15%|█▍        | 8/54 [00:15<01:29,  1.95s/it]

Loss: 0.5609147548675537


Epoch 4:  17%|█▋        | 9/54 [00:17<01:27,  1.95s/it]

Loss: 0.5178573131561279


Epoch 4:  19%|█▊        | 10/54 [00:19<01:25,  1.95s/it]

Loss: 0.31079643964767456


Epoch 4:  20%|██        | 11/54 [00:21<01:23,  1.95s/it]

Loss: 0.6212645769119263


Epoch 4:  22%|██▏       | 12/54 [00:23<01:21,  1.95s/it]

Loss: 0.818246066570282


Epoch 4:  24%|██▍       | 13/54 [00:25<01:19,  1.95s/it]

Loss: 0.479271799325943


Epoch 4:  26%|██▌       | 14/54 [00:27<01:17,  1.95s/it]

Loss: 0.12090383470058441


Epoch 4:  28%|██▊       | 15/54 [00:29<01:15,  1.95s/it]

Loss: 0.560142457485199


Epoch 4:  30%|██▉       | 16/54 [00:31<01:13,  1.95s/it]

Loss: 0.6141359806060791


Epoch 4:  31%|███▏      | 17/54 [00:33<01:11,  1.95s/it]

Loss: 0.1272718906402588


Epoch 4:  33%|███▎      | 18/54 [00:35<01:10,  1.95s/it]

Loss: 0.4678688943386078


Epoch 4:  35%|███▌      | 19/54 [00:36<01:08,  1.94s/it]

Loss: 0.38157403469085693


Epoch 4:  37%|███▋      | 20/54 [00:38<01:06,  1.94s/it]

Loss: 0.43149179220199585


Epoch 4:  39%|███▉      | 21/54 [00:40<01:04,  1.94s/it]

Loss: 0.4761240482330322


Epoch 4:  41%|████      | 22/54 [00:42<01:02,  1.94s/it]

Loss: 0.45333462953567505


Epoch 4:  43%|████▎     | 23/54 [00:44<01:00,  1.94s/it]

Loss: 0.4019954204559326


Epoch 4:  44%|████▍     | 24/54 [00:46<00:58,  1.94s/it]

Loss: 0.6198751926422119


Epoch 4:  46%|████▋     | 25/54 [00:48<00:56,  1.94s/it]

Loss: 0.5136098861694336


Epoch 4:  48%|████▊     | 26/54 [00:50<00:54,  1.94s/it]

Loss: 1.4814788103103638


Epoch 4:  50%|█████     | 27/54 [00:52<00:52,  1.94s/it]

Loss: 1.536087989807129


Epoch 4:  52%|█████▏    | 28/54 [00:54<00:50,  1.94s/it]

Loss: 0.3716070353984833


Epoch 4:  54%|█████▎    | 29/54 [00:56<00:48,  1.94s/it]

Loss: 1.1230626106262207


Epoch 4:  56%|█████▌    | 30/54 [00:58<00:46,  1.94s/it]

Loss: 0.4829424321651459


Epoch 4:  57%|█████▋    | 31/54 [01:00<00:44,  1.94s/it]

Loss: 1.3542760610580444


Epoch 4:  59%|█████▉    | 32/54 [01:02<00:42,  1.94s/it]

Loss: 0.46718376874923706


Epoch 4:  61%|██████    | 33/54 [01:04<00:40,  1.95s/it]

Loss: 0.5177377462387085


Epoch 4:  63%|██████▎   | 34/54 [01:06<00:38,  1.94s/it]

Loss: 0.2757425606250763


Epoch 4:  65%|██████▍   | 35/54 [01:08<00:36,  1.95s/it]

Loss: 0.5092132091522217


Epoch 4:  67%|██████▋   | 36/54 [01:10<00:35,  1.94s/it]

Loss: 1.31131911277771


Epoch 4:  69%|██████▊   | 37/54 [01:11<00:33,  1.94s/it]

Loss: 1.8524296283721924


Epoch 4:  70%|███████   | 38/54 [01:13<00:31,  1.94s/it]

Loss: 1.2859147787094116


Epoch 4:  72%|███████▏  | 39/54 [01:15<00:29,  1.94s/it]

Loss: 0.44840824604034424


Epoch 4:  74%|███████▍  | 40/54 [01:17<00:27,  1.94s/it]

Loss: 1.5158709287643433


Epoch 4:  76%|███████▌  | 41/54 [01:19<00:25,  1.94s/it]

Loss: 1.1979432106018066


Epoch 4:  78%|███████▊  | 42/54 [01:21<00:23,  1.94s/it]

Loss: 0.5927489399909973


Epoch 4:  80%|███████▉  | 43/54 [01:23<00:21,  1.95s/it]

Loss: 1.599513292312622


Epoch 4:  81%|████████▏ | 44/54 [01:25<00:19,  1.95s/it]

Loss: 1.237703800201416


Epoch 4:  83%|████████▎ | 45/54 [01:27<00:17,  1.95s/it]

Loss: 0.4724974036216736


Epoch 4:  85%|████████▌ | 46/54 [01:29<00:15,  1.95s/it]

Loss: 1.012216329574585


Epoch 4:  87%|████████▋ | 47/54 [01:31<00:13,  1.95s/it]

Loss: 0.833186149597168


Epoch 4:  89%|████████▉ | 48/54 [01:33<00:11,  1.95s/it]

Loss: 1.6589076519012451


Epoch 4:  91%|█████████ | 49/54 [01:35<00:09,  1.95s/it]

Loss: 0.3083961009979248


Epoch 4:  93%|█████████▎| 50/54 [01:37<00:07,  1.95s/it]

Loss: 0.8396024703979492


Epoch 4:  94%|█████████▍| 51/54 [01:39<00:05,  1.95s/it]

Loss: 0.26791298389434814


Epoch 4:  96%|█████████▋| 52/54 [01:41<00:03,  1.95s/it]

Loss: 0.326323002576828


Epoch 4:  98%|█████████▊| 53/54 [01:43<00:01,  1.94s/it]

Loss: 0.7494848370552063


Epoch 4: 100%|██████████| 54/54 [01:44<00:00,  1.93s/it]

Loss: 0.19271044433116913





Epoch 4 Validation Accuracy: 0.8522167487684729, F1-macro: 0.6892857142857143


Epoch 5:   2%|▏         | 1/54 [00:01<01:42,  1.93s/it]

Loss: 0.0951300859451294


Epoch 5:   4%|▎         | 2/54 [00:03<01:40,  1.94s/it]

Loss: 0.29092511534690857


Epoch 5:   6%|▌         | 3/54 [00:05<01:38,  1.94s/it]

Loss: 0.5624577403068542


Epoch 5:   7%|▋         | 4/54 [00:07<01:37,  1.94s/it]

Loss: 0.5968696475028992


Epoch 5:   9%|▉         | 5/54 [00:09<01:35,  1.94s/it]

Loss: 0.3314632773399353


Epoch 5:  11%|█         | 6/54 [00:11<01:33,  1.94s/it]

Loss: 0.5506131649017334


Epoch 5:  13%|█▎        | 7/54 [00:13<01:31,  1.94s/it]

Loss: 0.6653591990470886


Epoch 5:  15%|█▍        | 8/54 [00:15<01:29,  1.94s/it]

Loss: 0.3902691602706909


Epoch 5:  17%|█▋        | 9/54 [00:17<01:27,  1.94s/it]

Loss: 0.2645385265350342


Epoch 5:  19%|█▊        | 10/54 [00:19<01:25,  1.94s/it]

Loss: 0.4038640260696411


Epoch 5:  20%|██        | 11/54 [00:21<01:23,  1.94s/it]

Loss: 0.5649248361587524


Epoch 5:  22%|██▏       | 12/54 [00:23<01:21,  1.94s/it]

Loss: 0.5846007466316223


Epoch 5:  24%|██▍       | 13/54 [00:25<01:19,  1.94s/it]

Loss: 0.8093151450157166


Epoch 5:  26%|██▌       | 14/54 [00:27<01:17,  1.94s/it]

Loss: 0.17205044627189636


Epoch 5:  28%|██▊       | 15/54 [00:29<01:15,  1.94s/it]

Loss: 0.4115450382232666


Epoch 5:  30%|██▉       | 16/54 [00:31<01:13,  1.94s/it]

Loss: 0.20428729057312012


Epoch 5:  31%|███▏      | 17/54 [00:33<01:11,  1.94s/it]

Loss: 0.47751665115356445


Epoch 5:  33%|███▎      | 18/54 [00:34<01:09,  1.94s/it]

Loss: 0.40296369791030884


Epoch 5:  35%|███▌      | 19/54 [00:36<01:08,  1.94s/it]

Loss: 0.9130699634552002


Epoch 5:  37%|███▋      | 20/54 [00:38<01:06,  1.94s/it]

Loss: 0.4108811318874359


Epoch 5:  39%|███▉      | 21/54 [00:40<01:04,  1.94s/it]

Loss: 0.41561833024024963


Epoch 5:  41%|████      | 22/54 [00:42<01:02,  1.95s/it]

Loss: 0.25951725244522095


Epoch 5:  43%|████▎     | 23/54 [00:44<01:00,  1.94s/it]

Loss: 0.05850401148200035


Epoch 5:  44%|████▍     | 24/54 [00:46<00:58,  1.94s/it]

Loss: 0.4846148192882538


Epoch 5:  46%|████▋     | 25/54 [00:48<00:56,  1.94s/it]

Loss: 1.11790132522583


Epoch 5:  48%|████▊     | 26/54 [00:50<00:54,  1.94s/it]

Loss: 0.34911584854125977


Epoch 5:  50%|█████     | 27/54 [00:52<00:52,  1.94s/it]

Loss: 0.3735559582710266


Epoch 5:  52%|█████▏    | 28/54 [00:54<00:50,  1.94s/it]

Loss: 0.0851043164730072


Epoch 5:  54%|█████▎    | 29/54 [00:56<00:48,  1.94s/it]

Loss: 0.16365063190460205


Epoch 5:  56%|█████▌    | 30/54 [00:58<00:46,  1.94s/it]

Loss: 0.6732932329177856


Epoch 5:  57%|█████▋    | 31/54 [01:00<00:44,  1.94s/it]

Loss: 0.4117882251739502


Epoch 5:  59%|█████▉    | 32/54 [01:02<00:42,  1.94s/it]

Loss: 0.08086861670017242


Epoch 5:  61%|██████    | 33/54 [01:04<00:40,  1.94s/it]

Loss: 0.4587821960449219


Epoch 5:  63%|██████▎   | 34/54 [01:06<00:38,  1.94s/it]

Loss: 0.8478637337684631


Epoch 5:  65%|██████▍   | 35/54 [01:07<00:36,  1.94s/it]

Loss: 0.2747823894023895


Epoch 5:  67%|██████▋   | 36/54 [01:09<00:34,  1.94s/it]

Loss: 0.20239771902561188


Epoch 5:  69%|██████▊   | 37/54 [01:11<00:32,  1.94s/it]

Loss: 0.5402055382728577


Epoch 5:  70%|███████   | 38/54 [01:13<00:31,  1.94s/it]

Loss: 0.6423106789588928


Epoch 5:  72%|███████▏  | 39/54 [01:15<00:29,  1.94s/it]

Loss: 0.4928863048553467


Epoch 5:  74%|███████▍  | 40/54 [01:17<00:27,  1.94s/it]

Loss: 0.40254852175712585


Epoch 5:  76%|███████▌  | 41/54 [01:19<00:25,  1.94s/it]

Loss: 0.41969138383865356


Epoch 5:  78%|███████▊  | 42/54 [01:21<00:23,  1.94s/it]

Loss: 0.5732333660125732


Epoch 5:  80%|███████▉  | 43/54 [01:23<00:21,  1.94s/it]

Loss: 0.5941177606582642


Epoch 5:  81%|████████▏ | 44/54 [01:25<00:19,  1.94s/it]

Loss: 0.28146421909332275


Epoch 5:  83%|████████▎ | 45/54 [01:27<00:17,  1.94s/it]

Loss: 0.26236283779144287


Epoch 5:  85%|████████▌ | 46/54 [01:29<00:15,  1.94s/it]

Loss: 1.324559211730957


Epoch 5:  87%|████████▋ | 47/54 [01:31<00:13,  1.94s/it]

Loss: 0.5614739656448364


Epoch 5:  89%|████████▉ | 48/54 [01:33<00:11,  1.94s/it]

Loss: 0.18504658341407776


Epoch 5:  91%|█████████ | 49/54 [01:35<00:09,  1.94s/it]

Loss: 0.8379651308059692


Epoch 5:  93%|█████████▎| 50/54 [01:37<00:07,  1.94s/it]

Loss: 0.3206367492675781


Epoch 5:  94%|█████████▍| 51/54 [01:39<00:05,  1.94s/it]

Loss: 0.2777063548564911


Epoch 5:  96%|█████████▋| 52/54 [01:41<00:03,  1.94s/it]

Loss: 0.454608678817749


Epoch 5:  98%|█████████▊| 53/54 [01:42<00:01,  1.94s/it]

Loss: 0.5771341323852539


Epoch 5: 100%|██████████| 54/54 [01:44<00:00,  1.93s/it]

Loss: 0.44194284081459045





Epoch 5 Validation Accuracy: 0.7192118226600985, F1-macro: 0.6192873358997137


Epoch 6:   2%|▏         | 1/54 [00:01<01:42,  1.93s/it]

Loss: 0.8870230317115784


Epoch 6:   4%|▎         | 2/54 [00:03<01:40,  1.94s/it]

Loss: 0.28275877237319946


Epoch 6:   6%|▌         | 3/54 [00:05<01:38,  1.94s/it]

Loss: 0.5122531652450562


Epoch 6:   7%|▋         | 4/54 [00:07<01:36,  1.94s/it]

Loss: 0.24682065844535828


Epoch 6:   9%|▉         | 5/54 [00:09<01:34,  1.94s/it]

Loss: 0.21999254822731018


Epoch 6:  11%|█         | 6/54 [00:11<01:33,  1.94s/it]

Loss: 0.5551844239234924


Epoch 6:  13%|█▎        | 7/54 [00:13<01:31,  1.94s/it]

Loss: 0.35291314125061035


Epoch 6:  15%|█▍        | 8/54 [00:15<01:29,  1.94s/it]

Loss: 0.30330124497413635


Epoch 6:  17%|█▋        | 9/54 [00:17<01:27,  1.94s/it]

Loss: 0.19751819968223572


Epoch 6:  19%|█▊        | 10/54 [00:19<01:25,  1.94s/it]

Loss: 0.1085415780544281


Epoch 6:  20%|██        | 11/54 [00:21<01:23,  1.94s/it]

Loss: 0.16221094131469727


Epoch 6:  22%|██▏       | 12/54 [00:23<01:21,  1.94s/it]

Loss: 0.3226052522659302


Epoch 6:  24%|██▍       | 13/54 [00:25<01:19,  1.94s/it]

Loss: 0.21816813945770264


Epoch 6:  26%|██▌       | 14/54 [00:27<01:17,  1.94s/it]

Loss: 0.3336935043334961


Epoch 6:  28%|██▊       | 15/54 [00:29<01:15,  1.94s/it]

Loss: 0.4125359058380127


Epoch 6:  30%|██▉       | 16/54 [00:31<01:13,  1.94s/it]

Loss: 0.2350960224866867


Epoch 6:  31%|███▏      | 17/54 [00:32<01:11,  1.94s/it]

Loss: 0.19643889367580414


Epoch 6:  33%|███▎      | 18/54 [00:34<01:09,  1.94s/it]

Loss: 0.450106680393219


Epoch 6:  35%|███▌      | 19/54 [00:36<01:07,  1.94s/it]

Loss: 0.6254396438598633


Epoch 6:  37%|███▋      | 20/54 [00:38<01:06,  1.94s/it]

Loss: 0.38168275356292725


Epoch 6:  39%|███▉      | 21/54 [00:40<01:04,  1.94s/it]

Loss: 0.49041157960891724


Epoch 6:  41%|████      | 22/54 [00:42<01:02,  1.94s/it]

Loss: 0.39517372846603394


Epoch 6:  43%|████▎     | 23/54 [00:44<01:00,  1.94s/it]

Loss: 0.6509851217269897


Epoch 6:  44%|████▍     | 24/54 [00:46<00:58,  1.94s/it]

Loss: 0.9294198751449585


Epoch 6:  46%|████▋     | 25/54 [00:48<00:56,  1.94s/it]

Loss: 0.34436583518981934


Epoch 6:  48%|████▊     | 26/54 [00:50<00:54,  1.94s/it]

Loss: 0.5277226567268372


Epoch 6:  50%|█████     | 27/54 [00:52<00:52,  1.94s/it]

Loss: 0.5887857675552368


Epoch 6:  52%|█████▏    | 28/54 [00:54<00:50,  1.94s/it]

Loss: 0.23982615768909454


Epoch 6:  54%|█████▎    | 29/54 [00:56<00:48,  1.94s/it]

Loss: 0.6044721007347107


Epoch 6:  56%|█████▌    | 30/54 [00:58<00:46,  1.94s/it]

Loss: 0.3512619733810425


Epoch 6:  57%|█████▋    | 31/54 [01:00<00:44,  1.94s/it]

Loss: 1.5455889701843262


Epoch 6:  59%|█████▉    | 32/54 [01:02<00:42,  1.94s/it]

Loss: 0.7077702283859253


Epoch 6:  61%|██████    | 33/54 [01:04<00:40,  1.94s/it]

Loss: 0.31809085607528687


Epoch 6:  63%|██████▎   | 34/54 [01:05<00:38,  1.94s/it]

Loss: 0.5336344838142395


Epoch 6:  65%|██████▍   | 35/54 [01:07<00:36,  1.94s/it]

Loss: 0.2875150144100189


Epoch 6:  67%|██████▋   | 36/54 [01:09<00:34,  1.94s/it]

Loss: 0.20148113369941711


Epoch 6:  69%|██████▊   | 37/54 [01:11<00:33,  1.94s/it]

Loss: 0.27778077125549316


Epoch 6:  70%|███████   | 38/54 [01:13<00:31,  1.94s/it]

Loss: 1.832188367843628


Epoch 6:  72%|███████▏  | 39/54 [01:15<00:29,  1.95s/it]

Loss: 0.12075789272785187


Epoch 6:  74%|███████▍  | 40/54 [01:17<00:27,  1.95s/it]

Loss: 0.6103820204734802


Epoch 6:  76%|███████▌  | 41/54 [01:19<00:25,  1.95s/it]

Loss: 0.307040274143219


Epoch 6:  78%|███████▊  | 42/54 [01:21<00:23,  1.95s/it]

Loss: 0.6683582663536072


Epoch 6:  80%|███████▉  | 43/54 [01:23<00:21,  1.95s/it]

Loss: 0.6530942916870117


Epoch 6:  81%|████████▏ | 44/54 [01:25<00:19,  1.95s/it]

Loss: 0.2944961190223694


Epoch 6:  83%|████████▎ | 45/54 [01:27<00:17,  1.94s/it]

Loss: 0.41173070669174194


Epoch 6:  85%|████████▌ | 46/54 [01:29<00:15,  1.94s/it]

Loss: 0.19632180035114288


Epoch 6:  87%|████████▋ | 47/54 [01:31<00:13,  1.94s/it]

Loss: 0.36724135279655457


Epoch 6:  89%|████████▉ | 48/54 [01:33<00:11,  1.94s/it]

Loss: 0.5270246863365173


Epoch 6:  91%|█████████ | 49/54 [01:35<00:09,  1.94s/it]

Loss: 0.20292501151561737


Epoch 6:  93%|█████████▎| 50/54 [01:37<00:07,  1.94s/it]

Loss: 0.8551437854766846


Epoch 6:  94%|█████████▍| 51/54 [01:39<00:05,  1.94s/it]

Loss: 0.19926448166370392


Epoch 6:  96%|█████████▋| 52/54 [01:40<00:03,  1.94s/it]

Loss: 0.16381245851516724


Epoch 6:  98%|█████████▊| 53/54 [01:42<00:01,  1.94s/it]

Loss: 0.1885581910610199


Epoch 6: 100%|██████████| 54/54 [01:44<00:00,  1.93s/it]

Loss: 0.3213236927986145





Epoch 6 Validation Accuracy: 0.8275862068965517, F1-macro: 0.6428391896647061


Epoch 7:   2%|▏         | 1/54 [00:01<01:43,  1.94s/it]

Loss: 0.24201048910617828


Epoch 7:   4%|▎         | 2/54 [00:03<01:41,  1.95s/it]

Loss: 0.1913243532180786


Epoch 7:   6%|▌         | 3/54 [00:05<01:39,  1.95s/it]

Loss: 0.22380147874355316


Epoch 7:   7%|▋         | 4/54 [00:07<01:37,  1.95s/it]

Loss: 0.5499308109283447


Epoch 7:   9%|▉         | 5/54 [00:09<01:35,  1.95s/it]

Loss: 0.1265781819820404


Epoch 7:  11%|█         | 6/54 [00:11<01:33,  1.95s/it]

Loss: 0.7807167768478394


Epoch 7:  13%|█▎        | 7/54 [00:13<01:31,  1.95s/it]

Loss: 0.7285398840904236


Epoch 7:  15%|█▍        | 8/54 [00:15<01:29,  1.95s/it]

Loss: 0.23927858471870422


Epoch 7:  17%|█▋        | 9/54 [00:17<01:27,  1.95s/it]

Loss: 0.3198600113391876


Epoch 7:  19%|█▊        | 10/54 [00:19<01:25,  1.95s/it]

Loss: 0.27177008986473083


Epoch 7:  20%|██        | 11/54 [00:21<01:23,  1.95s/it]

Loss: 0.6248807907104492


Epoch 7:  22%|██▏       | 12/54 [00:23<01:21,  1.95s/it]

Loss: 0.17911764979362488


Epoch 7:  24%|██▍       | 13/54 [00:25<01:19,  1.95s/it]

Loss: 0.27630066871643066


Epoch 7:  26%|██▌       | 14/54 [00:27<01:17,  1.95s/it]

Loss: 0.2721799314022064


Epoch 7:  28%|██▊       | 15/54 [00:29<01:16,  1.95s/it]

Loss: 0.15153613686561584


Epoch 7:  30%|██▉       | 16/54 [00:31<01:14,  1.95s/it]

Loss: 0.34473317861557007


Epoch 7:  31%|███▏      | 17/54 [00:33<01:12,  1.95s/it]

Loss: 0.27388742566108704


Epoch 7:  33%|███▎      | 18/54 [00:35<01:10,  1.95s/it]

Loss: 0.23419442772865295


Epoch 7:  35%|███▌      | 19/54 [00:37<01:08,  1.95s/it]

Loss: 0.30882027745246887


Epoch 7:  37%|███▋      | 20/54 [00:38<01:06,  1.95s/it]

Loss: 0.9516499042510986


Epoch 7:  39%|███▉      | 21/54 [00:40<01:04,  1.95s/it]

Loss: 0.30754196643829346


Epoch 7:  41%|████      | 22/54 [00:42<01:02,  1.95s/it]

Loss: 0.4636307656764984


Epoch 7:  43%|████▎     | 23/54 [00:44<01:00,  1.95s/it]

Loss: 0.056492097675800323


Epoch 7:  44%|████▍     | 24/54 [00:46<00:58,  1.95s/it]

Loss: 0.2419600784778595


Epoch 7:  46%|████▋     | 25/54 [00:48<00:56,  1.95s/it]

Loss: 0.07628551125526428


Epoch 7:  48%|████▊     | 26/54 [00:50<00:54,  1.95s/it]

Loss: 0.4629817008972168


Epoch 7:  50%|█████     | 27/54 [00:52<00:52,  1.95s/it]

Loss: 0.2783905863761902


Epoch 7:  52%|█████▏    | 28/54 [00:54<00:50,  1.95s/it]

Loss: 0.3969423472881317


Epoch 7:  54%|█████▎    | 29/54 [00:56<00:48,  1.95s/it]

Loss: 0.5612339377403259


Epoch 7:  56%|█████▌    | 30/54 [00:58<00:46,  1.95s/it]

Loss: 0.3836442828178406


Epoch 7:  57%|█████▋    | 31/54 [01:00<00:44,  1.95s/it]

Loss: 0.4586353003978729


Epoch 7:  59%|█████▉    | 32/54 [01:02<00:42,  1.95s/it]

Loss: 0.24265627562999725


Epoch 7:  61%|██████    | 33/54 [01:04<00:41,  1.96s/it]

Loss: 0.4449237585067749


Epoch 7:  63%|██████▎   | 34/54 [01:06<00:39,  1.96s/it]

Loss: 0.3028119206428528


Epoch 7:  65%|██████▍   | 35/54 [01:08<00:37,  1.96s/it]

Loss: 0.34620848298072815


Epoch 7:  67%|██████▋   | 36/54 [01:10<00:35,  1.96s/it]

Loss: 0.33114102482795715


Epoch 7:  69%|██████▊   | 37/54 [01:12<00:33,  1.96s/it]

Loss: 0.36929985880851746


Epoch 7:  70%|███████   | 38/54 [01:14<00:31,  1.95s/it]

Loss: 0.3811148405075073


Epoch 7:  72%|███████▏  | 39/54 [01:16<00:29,  1.96s/it]

Loss: 0.052839495241642


Epoch 7:  74%|███████▍  | 40/54 [01:18<00:27,  1.95s/it]

Loss: 0.38064318895339966


Epoch 7:  76%|███████▌  | 41/54 [01:19<00:25,  1.95s/it]

Loss: 0.2743312120437622


Epoch 7:  78%|███████▊  | 42/54 [01:21<00:23,  1.95s/it]

Loss: 0.269942969083786


Epoch 7:  80%|███████▉  | 43/54 [01:23<00:21,  1.95s/it]

Loss: 0.35888195037841797


Epoch 7:  81%|████████▏ | 44/54 [01:25<00:19,  1.95s/it]

Loss: 0.3793320655822754


Epoch 7:  83%|████████▎ | 45/54 [01:27<00:17,  1.95s/it]

Loss: 0.48851844668388367


Epoch 7:  85%|████████▌ | 46/54 [01:29<00:15,  1.95s/it]

Loss: 0.3137332797050476


Epoch 7:  87%|████████▋ | 47/54 [01:31<00:13,  1.95s/it]

Loss: 0.49126023054122925


Epoch 7:  89%|████████▉ | 48/54 [01:33<00:11,  1.95s/it]

Loss: 0.17219902575016022


Epoch 7:  91%|█████████ | 49/54 [01:35<00:09,  1.95s/it]

Loss: 0.2709920406341553


Epoch 7:  93%|█████████▎| 50/54 [01:37<00:07,  1.95s/it]

Loss: 0.5870295763015747


Epoch 7:  94%|█████████▍| 51/54 [01:39<00:05,  1.95s/it]

Loss: 0.37315279245376587


Epoch 7:  96%|█████████▋| 52/54 [01:41<00:03,  1.94s/it]

Loss: 0.29962560534477234


Epoch 7:  98%|█████████▊| 53/54 [01:43<00:01,  1.94s/it]

Loss: 1.0412416458129883


Epoch 7: 100%|██████████| 54/54 [01:44<00:00,  1.93s/it]

Loss: 0.77287757396698





Epoch 7 Validation Accuracy: 0.8571428571428571, F1-macro: 0.6228457940931513


In [None]:
# Evaluate on the test set
test_metrics = evaluate(model7, test_dataloader, device)
print(f"Test Accuracy: {test_metrics['accuracy']}, F1-macro: {test_metrics['f1_macro']}")

Test Accuracy: 0.8431372549019608, F1-macro: 0.6353072625698324


In [None]:
current_type = 'magnification'

results_df.loc[
    results_df["distortion_type"] == current_type,
    ["test_accuracy", "f1_macro"]
] = [test_metrics['accuracy'], test_metrics['f1_macro']]

# 8. Mind Reading

In [None]:
# Add labels
data1_1_labels = list(data1['mind reading'][data1_1.index])
data2_1_labels = list(data2['mind reading'][data2_1.index])
data3_1_labels = list(data3['mind reading'][data3_1.index])

# Merging Data
data_encoded = data1_1_encoded + data2_1_encoded + data3_1_encoded
data_labels = data1_1_labels + data2_1_labels + data3_1_labels

dataset_with_labels = CustomDatasetWithLabels(data_encoded, data_labels)

# Define proportions for splitting
train_size = int(0.8 * len(dataset_with_labels))
val_size = int(0.1 * len(dataset_with_labels))
test_size = len(dataset_with_labels) - train_size - val_size

# Split the dataset
train_dataset, val_dataset, test_dataset = random_split(dataset_with_labels, [train_size, val_size, test_size])

In [None]:
dataset_encoded = dataset_with_labels.data
dataset_labels = dataset_with_labels.labels

# Convert list of BatchEncoding objects to a 2D numpy array for SMOTE features
feature_vectors = []
max_len = 512 # Assuming this from tokenize_and_pad function
for encoded_dict in dataset_encoded:
    input_ids_np = encoded_dict['input_ids'].squeeze().cpu().numpy()
    attention_mask_np = encoded_dict['attention_mask'].squeeze().cpu().numpy()
    # Concatenate input_ids and attention_mask to create a single feature vector per sample
    feature_vectors.append(np.concatenate([input_ids_np, attention_mask_np]))

# 'temp1' becomes the feature matrix, 'temp2' becomes the labels vector
temp1 = np.array(feature_vectors)
temp2 = np.array(dataset_labels)

# Apply SMOTE
temp1_resampled, temp2_resampled = smote.fit_resample(temp1, temp2)

# Reconstruct data_encoded and data_labels from the resampled data
reconstructed_data_encoded = []
for resampled_features in temp1_resampled:
    # Split the combined feature vector back into input_ids and attention_mask parts
    reconstructed_input_ids = resampled_features[:max_len]
    reconstructed_attention_mask = resampled_features[max_len:]

    # Convert back to torch tensors and a dictionary format compatible with CustomDatasetWithLabels
    input_ids_tensor = torch.tensor(reconstructed_input_ids, dtype=torch.long).unsqueeze(0)
    attention_mask_tensor = torch.tensor(reconstructed_attention_mask, dtype=torch.long).unsqueeze(0)

    reconstructed_data_encoded.append({
        'input_ids': input_ids_tensor,
        'attention_mask': attention_mask_tensor
    })

# Update data_encoded and data_labels with the resampled data
dataset_encoded = reconstructed_data_encoded
dataset_labels = temp2_resampled.tolist()

train_dataset = CustomDatasetWithLabels(dataset_encoded, dataset_labels)

In [None]:
# Create DataLoaders for each set
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=64)
test_dataloader = DataLoader(test_dataset, batch_size=64)

input_dim = bert_config.hidden_size
num_labels = len(set(data_labels))
label_emb = torch.randn(num_labels, input_dim)

In [None]:
# Instantiate the InnerProductClassifier model
model8 = InnerProductClassifier(input_dim, label_emb).to(device)

# Define loss function
criterion = nn.CrossEntropyLoss()

# Define optimizer
optimizer = torch.optim.AdamW(model8.parameters(), lr=2e-4)

In [None]:
EPOCHS = 7

# Move the BERT model to the device
bert_model.to(device)

for epoch in range(EPOCHS):
    model8.train()
    for batch in tqdm(train_dataloader, desc=f"Epoch {epoch+1}"):

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        y = batch["y"].to(device)

        # Get embeddings from the BERT model
        with torch.no_grad():
            outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask)
            embeddings = outputs.last_hidden_state[:, 0, :] # Get the [CLS] token embedding

        # Pass embeddings to the InnerProductClassifier
        logits = model8(embeddings)
        loss = criterion(logits, y)

        # Backpropagate and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Loss: {loss.item()}")

    # Evaluate on the validation set after each epoch
    val_metrics = evaluate(model8, val_dataloader, device)
    print(f"Epoch {epoch+1} Validation Accuracy: {val_metrics['accuracy']}, F1-macro: {val_metrics['f1_macro']}")

Epoch 1:   2%|▏         | 1/52 [00:01<01:39,  1.95s/it]

Loss: 13.617969512939453


Epoch 1:   4%|▍         | 2/52 [00:03<01:38,  1.98s/it]

Loss: 1.2222529649734497


Epoch 1:   6%|▌         | 3/52 [00:05<01:37,  2.00s/it]

Loss: 4.503809452056885


Epoch 1:   8%|▊         | 4/52 [00:08<01:36,  2.01s/it]

Loss: 7.794356822967529


Epoch 1:  10%|▉         | 5/52 [00:10<01:35,  2.03s/it]

Loss: 4.211381912231445


Epoch 1:  12%|█▏        | 6/52 [00:12<01:34,  2.05s/it]

Loss: 12.162936210632324


Epoch 1:  13%|█▎        | 7/52 [00:14<01:32,  2.06s/it]

Loss: 5.768294811248779


Epoch 1:  15%|█▌        | 8/52 [00:16<01:31,  2.08s/it]

Loss: 3.0863819122314453


Epoch 1:  17%|█▋        | 9/52 [00:18<01:29,  2.09s/it]

Loss: 1.5943543910980225


Epoch 1:  19%|█▉        | 10/52 [00:20<01:27,  2.09s/it]

Loss: 1.4043023586273193


Epoch 1:  21%|██        | 11/52 [00:22<01:25,  2.09s/it]

Loss: 3.7163963317871094


Epoch 1:  23%|██▎       | 12/52 [00:24<01:23,  2.08s/it]

Loss: 2.7067084312438965


Epoch 1:  25%|██▌       | 13/52 [00:26<01:20,  2.07s/it]

Loss: 1.013426423072815


Epoch 1:  27%|██▋       | 14/52 [00:28<01:18,  2.05s/it]

Loss: 1.2741615772247314


Epoch 1:  29%|██▉       | 15/52 [00:30<01:15,  2.04s/it]

Loss: 1.5556508302688599


Epoch 1:  31%|███       | 16/52 [00:32<01:12,  2.02s/it]

Loss: 2.2120561599731445


Epoch 1:  33%|███▎      | 17/52 [00:34<01:10,  2.00s/it]

Loss: 1.6078697443008423


Epoch 1:  35%|███▍      | 18/52 [00:36<01:07,  1.99s/it]

Loss: 1.887317419052124


Epoch 1:  37%|███▋      | 19/52 [00:38<01:05,  1.98s/it]

Loss: 2.29382586479187


Epoch 1:  38%|███▊      | 20/52 [00:40<01:02,  1.96s/it]

Loss: 1.7424499988555908


Epoch 1:  40%|████      | 21/52 [00:42<01:00,  1.95s/it]

Loss: 0.7894774675369263


Epoch 1:  42%|████▏     | 22/52 [00:44<00:58,  1.94s/it]

Loss: 2.2802844047546387


Epoch 1:  44%|████▍     | 23/52 [00:46<00:55,  1.93s/it]

Loss: 1.8274509906768799


Epoch 1:  46%|████▌     | 24/52 [00:48<00:53,  1.92s/it]

Loss: 1.057798981666565


Epoch 1:  48%|████▊     | 25/52 [00:50<00:51,  1.91s/it]

Loss: 1.4911068677902222


Epoch 1:  50%|█████     | 26/52 [00:51<00:49,  1.90s/it]

Loss: 1.2607204914093018


Epoch 1:  52%|█████▏    | 27/52 [00:53<00:47,  1.90s/it]

Loss: 0.9185010194778442


Epoch 1:  54%|█████▍    | 28/52 [00:55<00:45,  1.89s/it]

Loss: 1.7425259351730347


Epoch 1:  56%|█████▌    | 29/52 [00:57<00:43,  1.89s/it]

Loss: 1.13216233253479


Epoch 1:  58%|█████▊    | 30/52 [00:59<00:41,  1.89s/it]

Loss: 0.4531126022338867


Epoch 1:  60%|█████▉    | 31/52 [01:01<00:39,  1.89s/it]

Loss: 0.9412263035774231


Epoch 1:  62%|██████▏   | 32/52 [01:03<00:37,  1.89s/it]

Loss: 0.813326358795166


Epoch 1:  63%|██████▎   | 33/52 [01:05<00:35,  1.89s/it]

Loss: 0.795461118221283


Epoch 1:  65%|██████▌   | 34/52 [01:07<00:33,  1.89s/it]

Loss: 1.000733733177185


Epoch 1:  67%|██████▋   | 35/52 [01:08<00:32,  1.89s/it]

Loss: 0.7738784551620483


Epoch 1:  69%|██████▉   | 36/52 [01:10<00:30,  1.89s/it]

Loss: 0.984500527381897


Epoch 1:  71%|███████   | 37/52 [01:12<00:28,  1.90s/it]

Loss: 0.4091149568557739


Epoch 1:  73%|███████▎  | 38/52 [01:14<00:26,  1.90s/it]

Loss: 1.0117297172546387


Epoch 1:  75%|███████▌  | 39/52 [01:16<00:24,  1.91s/it]

Loss: 0.8270672559738159


Epoch 1:  77%|███████▋  | 40/52 [01:18<00:22,  1.91s/it]

Loss: 0.6612032651901245


Epoch 1:  79%|███████▉  | 41/52 [01:20<00:21,  1.92s/it]

Loss: 0.9245905876159668


Epoch 1:  81%|████████  | 42/52 [01:22<00:19,  1.93s/it]

Loss: 0.948735237121582


Epoch 1:  83%|████████▎ | 43/52 [01:24<00:17,  1.93s/it]

Loss: 0.519250750541687


Epoch 1:  85%|████████▍ | 44/52 [01:26<00:15,  1.94s/it]

Loss: 0.9890555143356323


Epoch 1:  87%|████████▋ | 45/52 [01:28<00:13,  1.94s/it]

Loss: 0.566311776638031


Epoch 1:  88%|████████▊ | 46/52 [01:30<00:11,  1.95s/it]

Loss: 0.40166112780570984


Epoch 1:  90%|█████████ | 47/52 [01:32<00:09,  1.96s/it]

Loss: 0.4421583414077759


Epoch 1:  92%|█████████▏| 48/52 [01:34<00:07,  1.96s/it]

Loss: 0.6636829376220703


Epoch 1:  94%|█████████▍| 49/52 [01:36<00:05,  1.97s/it]

Loss: 0.623089075088501


Epoch 1:  96%|█████████▌| 50/52 [01:38<00:03,  1.97s/it]

Loss: 0.5475401878356934


Epoch 1:  98%|█████████▊| 51/52 [01:40<00:01,  1.97s/it]

Loss: 0.30603164434432983


Epoch 1: 100%|██████████| 52/52 [01:40<00:00,  1.93s/it]

Loss: 1.164352536201477





Epoch 1 Validation Accuracy: 0.7635467980295566, F1-macro: 0.6518010291595198


Epoch 2:   2%|▏         | 1/52 [00:01<01:40,  1.97s/it]

Loss: 0.49140337109565735


Epoch 2:   4%|▍         | 2/52 [00:03<01:38,  1.97s/it]

Loss: 0.8737891912460327


Epoch 2:   6%|▌         | 3/52 [00:05<01:36,  1.97s/it]

Loss: 0.7902560830116272


Epoch 2:   8%|▊         | 4/52 [00:07<01:34,  1.97s/it]

Loss: 0.7311127185821533


Epoch 2:  10%|▉         | 5/52 [00:09<01:32,  1.96s/it]

Loss: 1.4857964515686035


Epoch 2:  12%|█▏        | 6/52 [00:11<01:30,  1.96s/it]

Loss: 0.5363197326660156


Epoch 2:  13%|█▎        | 7/52 [00:13<01:27,  1.95s/it]

Loss: 1.010597825050354


Epoch 2:  15%|█▌        | 8/52 [00:15<01:25,  1.95s/it]

Loss: 0.4412560760974884


Epoch 2:  17%|█▋        | 9/52 [00:17<01:23,  1.95s/it]

Loss: 1.076155424118042


Epoch 2:  19%|█▉        | 10/52 [00:19<01:21,  1.95s/it]

Loss: 0.6277296543121338


Epoch 2:  21%|██        | 11/52 [00:21<01:19,  1.95s/it]

Loss: 0.4471645653247833


Epoch 2:  23%|██▎       | 12/52 [00:23<01:17,  1.94s/it]

Loss: 1.1688506603240967


Epoch 2:  25%|██▌       | 13/52 [00:25<01:15,  1.94s/it]

Loss: 0.26790672540664673


Epoch 2:  27%|██▋       | 14/52 [00:27<01:13,  1.94s/it]

Loss: 0.2791900038719177


Epoch 2:  29%|██▉       | 15/52 [00:29<01:11,  1.93s/it]

Loss: 0.6515279412269592


Epoch 2:  31%|███       | 16/52 [00:31<01:09,  1.93s/it]

Loss: 0.7456988096237183


Epoch 2:  33%|███▎      | 17/52 [00:33<01:07,  1.93s/it]

Loss: 0.890978991985321


Epoch 2:  35%|███▍      | 18/52 [00:35<01:05,  1.93s/it]

Loss: 0.41176363825798035


Epoch 2:  37%|███▋      | 19/52 [00:36<01:03,  1.93s/it]

Loss: 0.7749165892601013


Epoch 2:  38%|███▊      | 20/52 [00:38<01:01,  1.93s/it]

Loss: 1.0579211711883545


Epoch 2:  40%|████      | 21/52 [00:40<00:59,  1.93s/it]

Loss: 0.3039994239807129


Epoch 2:  42%|████▏     | 22/52 [00:42<00:57,  1.93s/it]

Loss: 0.76416015625


Epoch 2:  44%|████▍     | 23/52 [00:44<00:55,  1.93s/it]

Loss: 0.7877297401428223


Epoch 2:  46%|████▌     | 24/52 [00:46<00:54,  1.93s/it]

Loss: 0.8524911403656006


Epoch 2:  48%|████▊     | 25/52 [00:48<00:52,  1.93s/it]

Loss: 0.5204318165779114


Epoch 2:  50%|█████     | 26/52 [00:50<00:50,  1.93s/it]

Loss: 0.29070717096328735


Epoch 2:  52%|█████▏    | 27/52 [00:52<00:48,  1.94s/it]

Loss: 0.8980843424797058


Epoch 2:  54%|█████▍    | 28/52 [00:54<00:46,  1.94s/it]

Loss: 0.6911989450454712


Epoch 2:  56%|█████▌    | 29/52 [00:56<00:44,  1.94s/it]

Loss: 0.49758902192115784


Epoch 2:  58%|█████▊    | 30/52 [00:58<00:42,  1.94s/it]

Loss: 0.7469960451126099


Epoch 2:  60%|█████▉    | 31/52 [01:00<00:40,  1.94s/it]

Loss: 0.9255990386009216


Epoch 2:  62%|██████▏   | 32/52 [01:02<00:38,  1.94s/it]

Loss: 0.08590742945671082


Epoch 2:  63%|██████▎   | 33/52 [01:04<00:36,  1.94s/it]

Loss: 0.8622455596923828


Epoch 2:  65%|██████▌   | 34/52 [01:06<00:34,  1.94s/it]

Loss: 0.6693377494812012


Epoch 2:  67%|██████▋   | 35/52 [01:07<00:33,  1.94s/it]

Loss: 0.32675349712371826


Epoch 2:  69%|██████▉   | 36/52 [01:09<00:31,  1.95s/it]

Loss: 1.0717225074768066


Epoch 2:  71%|███████   | 37/52 [01:11<00:29,  1.95s/it]

Loss: 0.7678137421607971


Epoch 2:  73%|███████▎  | 38/52 [01:13<00:27,  1.95s/it]

Loss: 0.48594141006469727


Epoch 2:  75%|███████▌  | 39/52 [01:15<00:25,  1.95s/it]

Loss: 0.5983990430831909


Epoch 2:  77%|███████▋  | 40/52 [01:17<00:23,  1.95s/it]

Loss: 0.5336483716964722


Epoch 2:  79%|███████▉  | 41/52 [01:19<00:21,  1.95s/it]

Loss: 0.26814186573028564


Epoch 2:  81%|████████  | 42/52 [01:21<00:19,  1.95s/it]

Loss: 0.5809496641159058


Epoch 2:  83%|████████▎ | 43/52 [01:23<00:17,  1.95s/it]

Loss: 0.1380012184381485


Epoch 2:  85%|████████▍ | 44/52 [01:25<00:15,  1.95s/it]

Loss: 0.33695968985557556


Epoch 2:  87%|████████▋ | 45/52 [01:27<00:13,  1.95s/it]

Loss: 0.20036964118480682


Epoch 2:  88%|████████▊ | 46/52 [01:29<00:11,  1.95s/it]

Loss: 0.48418402671813965


Epoch 2:  90%|█████████ | 47/52 [01:31<00:09,  1.95s/it]

Loss: 0.44927528500556946


Epoch 2:  92%|█████████▏| 48/52 [01:33<00:07,  1.95s/it]

Loss: 0.49546951055526733


Epoch 2:  94%|█████████▍| 49/52 [01:35<00:05,  1.95s/it]

Loss: 0.29311782121658325


Epoch 2:  96%|█████████▌| 50/52 [01:37<00:03,  1.95s/it]

Loss: 0.5473558902740479


Epoch 2:  98%|█████████▊| 51/52 [01:39<00:01,  1.94s/it]

Loss: 0.36812669038772583


Epoch 2: 100%|██████████| 52/52 [01:39<00:00,  1.92s/it]

Loss: 0.2460414618253708





Epoch 2 Validation Accuracy: 0.8226600985221675, F1-macro: 0.4774027459954233


Epoch 3:   2%|▏         | 1/52 [00:01<01:38,  1.93s/it]

Loss: 0.9048973321914673


Epoch 3:   4%|▍         | 2/52 [00:03<01:36,  1.94s/it]

Loss: 1.0053050518035889


Epoch 3:   6%|▌         | 3/52 [00:05<01:35,  1.94s/it]

Loss: 0.44295892119407654


Epoch 3:   8%|▊         | 4/52 [00:07<01:33,  1.94s/it]

Loss: 0.6201142072677612


Epoch 3:  10%|▉         | 5/52 [00:09<01:31,  1.94s/it]

Loss: 0.6019782423973083


Epoch 3:  12%|█▏        | 6/52 [00:11<01:29,  1.94s/it]

Loss: 0.355165034532547


Epoch 3:  13%|█▎        | 7/52 [00:13<01:27,  1.94s/it]

Loss: 0.16686588525772095


Epoch 3:  15%|█▌        | 8/52 [00:15<01:25,  1.94s/it]

Loss: 0.7834713459014893


Epoch 3:  17%|█▋        | 9/52 [00:17<01:23,  1.94s/it]

Loss: 0.9387195110321045


Epoch 3:  19%|█▉        | 10/52 [00:19<01:21,  1.94s/it]

Loss: 0.761937141418457


Epoch 3:  21%|██        | 11/52 [00:21<01:19,  1.94s/it]

Loss: 1.0892853736877441


Epoch 3:  23%|██▎       | 12/52 [00:23<01:17,  1.94s/it]

Loss: 0.24384713172912598


Epoch 3:  25%|██▌       | 13/52 [00:25<01:15,  1.94s/it]

Loss: 0.5387037396430969


Epoch 3:  27%|██▋       | 14/52 [00:27<01:13,  1.94s/it]

Loss: 1.2133440971374512


Epoch 3:  29%|██▉       | 15/52 [00:29<01:11,  1.94s/it]

Loss: 0.48356467485427856


Epoch 3:  31%|███       | 16/52 [00:31<01:09,  1.94s/it]

Loss: 0.4541712999343872


Epoch 3:  33%|███▎      | 17/52 [00:32<01:07,  1.94s/it]

Loss: 0.8402218818664551


Epoch 3:  35%|███▍      | 18/52 [00:34<01:06,  1.94s/it]

Loss: 0.8003031611442566


Epoch 3:  37%|███▋      | 19/52 [00:36<01:04,  1.94s/it]

Loss: 0.7520748376846313


Epoch 3:  38%|███▊      | 20/52 [00:38<01:02,  1.94s/it]

Loss: 0.9371331334114075


Epoch 3:  40%|████      | 21/52 [00:40<01:00,  1.95s/it]

Loss: 0.8007060289382935


Epoch 3:  42%|████▏     | 22/52 [00:42<00:58,  1.95s/it]

Loss: 1.145139455795288


Epoch 3:  44%|████▍     | 23/52 [00:44<00:56,  1.95s/it]

Loss: 0.42179521918296814


Epoch 3:  46%|████▌     | 24/52 [00:46<00:54,  1.95s/it]

Loss: 1.0071007013320923


Epoch 3:  48%|████▊     | 25/52 [00:48<00:52,  1.95s/it]

Loss: 1.1872501373291016


Epoch 3:  50%|█████     | 26/52 [00:50<00:50,  1.95s/it]

Loss: 0.6763348579406738


Epoch 3:  52%|█████▏    | 27/52 [00:52<00:48,  1.95s/it]

Loss: 0.7406423091888428


Epoch 3:  54%|█████▍    | 28/52 [00:54<00:46,  1.95s/it]

Loss: 0.6999824047088623


Epoch 3:  56%|█████▌    | 29/52 [00:56<00:44,  1.95s/it]

Loss: 1.1531689167022705


Epoch 3:  58%|█████▊    | 30/52 [00:58<00:42,  1.95s/it]

Loss: 0.7970591187477112


Epoch 3:  60%|█████▉    | 31/52 [01:00<00:40,  1.95s/it]

Loss: 0.5583657026290894


Epoch 3:  62%|██████▏   | 32/52 [01:02<00:38,  1.95s/it]

Loss: 0.6211893558502197


Epoch 3:  63%|██████▎   | 33/52 [01:04<00:36,  1.95s/it]

Loss: 1.6218311786651611


Epoch 3:  65%|██████▌   | 34/52 [01:06<00:35,  1.94s/it]

Loss: 0.690615177154541


Epoch 3:  67%|██████▋   | 35/52 [01:07<00:33,  1.95s/it]

Loss: 0.5454742312431335


Epoch 3:  69%|██████▉   | 36/52 [01:09<00:31,  1.95s/it]

Loss: 1.1060612201690674


Epoch 3:  71%|███████   | 37/52 [01:11<00:29,  1.95s/it]

Loss: 1.307541012763977


Epoch 3:  73%|███████▎  | 38/52 [01:13<00:27,  1.94s/it]

Loss: 0.49372828006744385


Epoch 3:  75%|███████▌  | 39/52 [01:15<00:25,  1.94s/it]

Loss: 0.8220500946044922


Epoch 3:  77%|███████▋  | 40/52 [01:17<00:23,  1.94s/it]

Loss: 0.8208376169204712


Epoch 3:  79%|███████▉  | 41/52 [01:19<00:21,  1.94s/it]

Loss: 0.6181893348693848


Epoch 3:  81%|████████  | 42/52 [01:21<00:19,  1.94s/it]

Loss: 0.26655298471450806


Epoch 3:  83%|████████▎ | 43/52 [01:23<00:17,  1.94s/it]

Loss: 1.839095115661621


Epoch 3:  85%|████████▍ | 44/52 [01:25<00:15,  1.94s/it]

Loss: 0.22109392285346985


Epoch 3:  87%|████████▋ | 45/52 [01:27<00:13,  1.94s/it]

Loss: 0.7623448371887207


Epoch 3:  88%|████████▊ | 46/52 [01:29<00:11,  1.94s/it]

Loss: 0.8087955713272095


Epoch 3:  90%|█████████ | 47/52 [01:31<00:09,  1.94s/it]

Loss: 0.27247005701065063


Epoch 3:  92%|█████████▏| 48/52 [01:33<00:07,  1.94s/it]

Loss: 0.4820931553840637


Epoch 3:  94%|█████████▍| 49/52 [01:35<00:05,  1.94s/it]

Loss: 0.6471074819564819


Epoch 3:  96%|█████████▌| 50/52 [01:37<00:03,  1.95s/it]

Loss: 0.5929776430130005


Epoch 3:  98%|█████████▊| 51/52 [01:39<00:01,  1.95s/it]

Loss: 0.2645420730113983


Epoch 3: 100%|██████████| 52/52 [01:39<00:00,  1.92s/it]

Loss: 0.37628674507141113





Epoch 3 Validation Accuracy: 0.8669950738916257, F1-macro: 0.6866746698679472


Epoch 4:   2%|▏         | 1/52 [00:01<01:39,  1.95s/it]

Loss: 0.8362804055213928


Epoch 4:   4%|▍         | 2/52 [00:03<01:37,  1.95s/it]

Loss: 0.28525397181510925


Epoch 4:   6%|▌         | 3/52 [00:05<01:35,  1.95s/it]

Loss: 0.4788188934326172


Epoch 4:   8%|▊         | 4/52 [00:07<01:33,  1.95s/it]

Loss: 0.2623763680458069


Epoch 4:  10%|▉         | 5/52 [00:09<01:31,  1.95s/it]

Loss: 1.1177194118499756


Epoch 4:  12%|█▏        | 6/52 [00:11<01:29,  1.95s/it]

Loss: 0.7444233894348145


Epoch 4:  13%|█▎        | 7/52 [00:13<01:27,  1.95s/it]

Loss: 0.589174211025238


Epoch 4:  15%|█▌        | 8/52 [00:15<01:25,  1.95s/it]

Loss: 1.4777964353561401


Epoch 4:  17%|█▋        | 9/52 [00:17<01:23,  1.95s/it]

Loss: 0.4296989142894745


Epoch 4:  19%|█▉        | 10/52 [00:19<01:21,  1.95s/it]

Loss: 0.6181502342224121


Epoch 4:  21%|██        | 11/52 [00:21<01:19,  1.95s/it]

Loss: 0.8754400014877319


Epoch 4:  23%|██▎       | 12/52 [00:23<01:17,  1.95s/it]

Loss: 0.3224969506263733


Epoch 4:  25%|██▌       | 13/52 [00:25<01:15,  1.95s/it]

Loss: 0.40326225757598877


Epoch 4:  27%|██▋       | 14/52 [00:27<01:13,  1.95s/it]

Loss: 0.7115463018417358


Epoch 4:  29%|██▉       | 15/52 [00:29<01:12,  1.95s/it]

Loss: 0.5491150617599487


Epoch 4:  31%|███       | 16/52 [00:31<01:10,  1.95s/it]

Loss: 0.8009923100471497


Epoch 4:  33%|███▎      | 17/52 [00:33<01:08,  1.95s/it]

Loss: 0.9046390056610107


Epoch 4:  35%|███▍      | 18/52 [00:35<01:06,  1.95s/it]

Loss: 0.4717026352882385


Epoch 4:  37%|███▋      | 19/52 [00:37<01:04,  1.95s/it]

Loss: 0.4108591079711914


Epoch 4:  38%|███▊      | 20/52 [00:38<01:02,  1.95s/it]

Loss: 0.8966715931892395


Epoch 4:  40%|████      | 21/52 [00:40<01:00,  1.95s/it]

Loss: 0.5822535753250122


Epoch 4:  42%|████▏     | 22/52 [00:42<00:58,  1.95s/it]

Loss: 0.18584749102592468


Epoch 4:  44%|████▍     | 23/52 [00:44<00:56,  1.95s/it]

Loss: 0.4733377695083618


Epoch 4:  46%|████▌     | 24/52 [00:46<00:54,  1.95s/it]

Loss: 0.461201548576355


Epoch 4:  48%|████▊     | 25/52 [00:48<00:52,  1.95s/it]

Loss: 0.35063135623931885


Epoch 4:  50%|█████     | 26/52 [00:50<00:50,  1.95s/it]

Loss: 0.49363434314727783


Epoch 4:  52%|█████▏    | 27/52 [00:52<00:48,  1.94s/it]

Loss: 0.6323971748352051


Epoch 4:  54%|█████▍    | 28/52 [00:54<00:46,  1.95s/it]

Loss: 0.48056328296661377


Epoch 4:  56%|█████▌    | 29/52 [00:56<00:44,  1.94s/it]

Loss: 0.09761206805706024


Epoch 4:  58%|█████▊    | 30/52 [00:58<00:42,  1.94s/it]

Loss: 0.7168542742729187


Epoch 4:  60%|█████▉    | 31/52 [01:00<00:40,  1.94s/it]

Loss: 0.3022305369377136


Epoch 4:  62%|██████▏   | 32/52 [01:02<00:38,  1.94s/it]

Loss: 0.43817412853240967


Epoch 4:  63%|██████▎   | 33/52 [01:04<00:36,  1.94s/it]

Loss: 0.6314566731452942


Epoch 4:  65%|██████▌   | 34/52 [01:06<00:35,  1.95s/it]

Loss: 0.5945274829864502


Epoch 4:  67%|██████▋   | 35/52 [01:08<00:33,  1.95s/it]

Loss: 0.41110044717788696


Epoch 4:  69%|██████▉   | 36/52 [01:10<00:31,  1.95s/it]

Loss: 0.795239269733429


Epoch 4:  71%|███████   | 37/52 [01:12<00:29,  1.95s/it]

Loss: 0.2561236619949341


Epoch 4:  73%|███████▎  | 38/52 [01:13<00:27,  1.95s/it]

Loss: 0.8797138333320618


Epoch 4:  75%|███████▌  | 39/52 [01:15<00:25,  1.95s/it]

Loss: 0.6811864376068115


Epoch 4:  77%|███████▋  | 40/52 [01:17<00:23,  1.95s/it]

Loss: 0.8873951435089111


Epoch 4:  79%|███████▉  | 41/52 [01:19<00:21,  1.95s/it]

Loss: 0.806449830532074


Epoch 4:  81%|████████  | 42/52 [01:21<00:19,  1.95s/it]

Loss: 0.8854451179504395


Epoch 4:  83%|████████▎ | 43/52 [01:23<00:17,  1.95s/it]

Loss: 0.46404123306274414


Epoch 4:  85%|████████▍ | 44/52 [01:25<00:15,  1.95s/it]

Loss: 0.30709347128868103


Epoch 4:  87%|████████▋ | 45/52 [01:27<00:13,  1.95s/it]

Loss: 1.1356379985809326


Epoch 4:  88%|████████▊ | 46/52 [01:29<00:11,  1.94s/it]

Loss: 1.261202096939087


Epoch 4:  90%|█████████ | 47/52 [01:31<00:09,  1.95s/it]

Loss: 0.5928906202316284


Epoch 4:  92%|█████████▏| 48/52 [01:33<00:07,  1.94s/it]

Loss: 1.1824911832809448


Epoch 4:  94%|█████████▍| 49/52 [01:35<00:05,  1.94s/it]

Loss: 0.48201751708984375


Epoch 4:  96%|█████████▌| 50/52 [01:37<00:03,  1.94s/it]

Loss: 0.3494080901145935


Epoch 4:  98%|█████████▊| 51/52 [01:39<00:01,  1.94s/it]

Loss: 0.3124213218688965


Epoch 4: 100%|██████████| 52/52 [01:39<00:00,  1.92s/it]

Loss: 0.08318398147821426





Epoch 4 Validation Accuracy: 0.8374384236453202, F1-macro: 0.617046818727491


Epoch 5:   2%|▏         | 1/52 [00:01<01:38,  1.93s/it]

Loss: 1.0675909519195557


Epoch 5:   4%|▍         | 2/52 [00:03<01:36,  1.94s/it]

Loss: 0.14889153838157654


Epoch 5:   6%|▌         | 3/52 [00:05<01:35,  1.94s/it]

Loss: 0.2297784984111786


Epoch 5:   8%|▊         | 4/52 [00:07<01:33,  1.94s/it]

Loss: 0.7856435775756836


Epoch 5:  10%|▉         | 5/52 [00:09<01:31,  1.94s/it]

Loss: 0.44209015369415283


Epoch 5:  12%|█▏        | 6/52 [00:11<01:29,  1.94s/it]

Loss: 0.9705029129981995


Epoch 5:  13%|█▎        | 7/52 [00:13<01:27,  1.94s/it]

Loss: 0.1118730902671814


Epoch 5:  15%|█▌        | 8/52 [00:15<01:25,  1.94s/it]

Loss: 0.6197925806045532


Epoch 5:  17%|█▋        | 9/52 [00:17<01:23,  1.94s/it]

Loss: 0.3232787549495697


Epoch 5:  19%|█▉        | 10/52 [00:19<01:21,  1.94s/it]

Loss: 0.4867827594280243


Epoch 5:  21%|██        | 11/52 [00:21<01:19,  1.94s/it]

Loss: 0.6738503575325012


Epoch 5:  23%|██▎       | 12/52 [00:23<01:17,  1.94s/it]

Loss: 0.4241029620170593


Epoch 5:  25%|██▌       | 13/52 [00:25<01:15,  1.95s/it]

Loss: 0.37696903944015503


Epoch 5:  27%|██▋       | 14/52 [00:27<01:13,  1.95s/it]

Loss: 0.6396459341049194


Epoch 5:  29%|██▉       | 15/52 [00:29<01:11,  1.95s/it]

Loss: 0.27589064836502075


Epoch 5:  31%|███       | 16/52 [00:31<01:10,  1.95s/it]

Loss: 0.37510791420936584


Epoch 5:  33%|███▎      | 17/52 [00:33<01:08,  1.95s/it]

Loss: 0.4887906610965729


Epoch 5:  35%|███▍      | 18/52 [00:34<01:06,  1.95s/it]

Loss: 0.3325379192829132


Epoch 5:  37%|███▋      | 19/52 [00:36<01:04,  1.95s/it]

Loss: 0.3036459982395172


Epoch 5:  38%|███▊      | 20/52 [00:38<01:02,  1.95s/it]

Loss: 0.3901047110557556


Epoch 5:  40%|████      | 21/52 [00:40<01:00,  1.95s/it]

Loss: 1.0097501277923584


Epoch 5:  42%|████▏     | 22/52 [00:42<00:58,  1.95s/it]

Loss: 0.27112242579460144


Epoch 5:  44%|████▍     | 23/52 [00:44<00:56,  1.95s/it]

Loss: 0.7193920612335205


Epoch 5:  46%|████▌     | 24/52 [00:46<00:54,  1.94s/it]

Loss: 0.43789607286453247


Epoch 5:  48%|████▊     | 25/52 [00:48<00:52,  1.94s/it]

Loss: 0.6765501499176025


Epoch 5:  50%|█████     | 26/52 [00:50<00:50,  1.94s/it]

Loss: 0.6719346642494202


Epoch 5:  52%|█████▏    | 27/52 [00:52<00:48,  1.94s/it]

Loss: 0.37875258922576904


Epoch 5:  54%|█████▍    | 28/52 [00:54<00:46,  1.94s/it]

Loss: 0.05068139731884003


Epoch 5:  56%|█████▌    | 29/52 [00:56<00:44,  1.94s/it]

Loss: 0.7015295624732971


Epoch 5:  58%|█████▊    | 30/52 [00:58<00:42,  1.94s/it]

Loss: 0.27703529596328735


Epoch 5:  60%|█████▉    | 31/52 [01:00<00:40,  1.94s/it]

Loss: 0.18935935199260712


Epoch 5:  62%|██████▏   | 32/52 [01:02<00:38,  1.94s/it]

Loss: 0.7197282910346985


Epoch 5:  63%|██████▎   | 33/52 [01:04<00:36,  1.94s/it]

Loss: 0.33264654874801636


Epoch 5:  65%|██████▌   | 34/52 [01:06<00:34,  1.94s/it]

Loss: 0.09942470490932465


Epoch 5:  67%|██████▋   | 35/52 [01:08<00:33,  1.94s/it]

Loss: 1.2061944007873535


Epoch 5:  69%|██████▉   | 36/52 [01:10<00:31,  1.94s/it]

Loss: 0.8987337350845337


Epoch 5:  71%|███████   | 37/52 [01:11<00:29,  1.94s/it]

Loss: 0.39949682354927063


Epoch 5:  73%|███████▎  | 38/52 [01:13<00:27,  1.94s/it]

Loss: 0.6951659917831421


Epoch 5:  75%|███████▌  | 39/52 [01:15<00:25,  1.94s/it]

Loss: 0.5082906484603882


Epoch 5:  77%|███████▋  | 40/52 [01:17<00:23,  1.94s/it]

Loss: 0.5873836874961853


Epoch 5:  79%|███████▉  | 41/52 [01:19<00:21,  1.95s/it]

Loss: 0.822415292263031


Epoch 5:  81%|████████  | 42/52 [01:21<00:19,  1.95s/it]

Loss: 0.5829626321792603


Epoch 5:  83%|████████▎ | 43/52 [01:23<00:17,  1.94s/it]

Loss: 0.5501938462257385


Epoch 5:  85%|████████▍ | 44/52 [01:25<00:15,  1.94s/it]

Loss: 0.9903848171234131


Epoch 5:  87%|████████▋ | 45/52 [01:27<00:13,  1.94s/it]

Loss: 0.7051900029182434


Epoch 5:  88%|████████▊ | 46/52 [01:29<00:11,  1.94s/it]

Loss: 0.6394542455673218


Epoch 5:  90%|█████████ | 47/52 [01:31<00:09,  1.94s/it]

Loss: 0.3156771659851074


Epoch 5:  92%|█████████▏| 48/52 [01:33<00:07,  1.94s/it]

Loss: 0.5499163866043091


Epoch 5:  94%|█████████▍| 49/52 [01:35<00:05,  1.95s/it]

Loss: 0.32546207308769226


Epoch 5:  96%|█████████▌| 50/52 [01:37<00:03,  1.94s/it]

Loss: 0.5449600219726562


Epoch 5:  98%|█████████▊| 51/52 [01:39<00:01,  1.94s/it]

Loss: 0.656557559967041


Epoch 5: 100%|██████████| 52/52 [01:39<00:00,  1.92s/it]

Loss: 0.0914333388209343





Epoch 5 Validation Accuracy: 0.8374384236453202, F1-macro: 0.5523554961577013


Epoch 6:   2%|▏         | 1/52 [00:01<01:38,  1.93s/it]

Loss: 0.8132109642028809


Epoch 6:   4%|▍         | 2/52 [00:03<01:37,  1.94s/it]

Loss: 0.6352262496948242


Epoch 6:   6%|▌         | 3/52 [00:05<01:35,  1.94s/it]

Loss: 1.318056583404541


Epoch 6:   8%|▊         | 4/52 [00:07<01:33,  1.94s/it]

Loss: 0.42989563941955566


Epoch 6:  10%|▉         | 5/52 [00:09<01:31,  1.94s/it]

Loss: 0.902338445186615


Epoch 6:  12%|█▏        | 6/52 [00:11<01:29,  1.94s/it]

Loss: 0.8339760899543762


Epoch 6:  13%|█▎        | 7/52 [00:13<01:27,  1.94s/it]

Loss: 0.6348267793655396


Epoch 6:  15%|█▌        | 8/52 [00:15<01:25,  1.94s/it]

Loss: 1.1108111143112183


Epoch 6:  17%|█▋        | 9/52 [00:17<01:23,  1.94s/it]

Loss: 0.8306418657302856


Epoch 6:  19%|█▉        | 10/52 [00:19<01:21,  1.94s/it]

Loss: 0.5624710917472839


Epoch 6:  21%|██        | 11/52 [00:21<01:19,  1.95s/it]

Loss: 0.49498632550239563


Epoch 6:  23%|██▎       | 12/52 [00:23<01:17,  1.94s/it]

Loss: 0.45065248012542725


Epoch 6:  25%|██▌       | 13/52 [00:25<01:15,  1.94s/it]

Loss: 0.6743779182434082


Epoch 6:  27%|██▋       | 14/52 [00:27<01:13,  1.94s/it]

Loss: 0.4045962393283844


Epoch 6:  29%|██▉       | 15/52 [00:29<01:11,  1.94s/it]

Loss: 0.300620973110199


Epoch 6:  31%|███       | 16/52 [00:31<01:09,  1.94s/it]

Loss: 0.31829625368118286


Epoch 6:  33%|███▎      | 17/52 [00:33<01:08,  1.94s/it]

Loss: 0.22833624482154846


Epoch 6:  35%|███▍      | 18/52 [00:34<01:06,  1.94s/it]

Loss: 0.17780737578868866


Epoch 6:  37%|███▋      | 19/52 [00:36<01:04,  1.94s/it]

Loss: 0.5025780200958252


Epoch 6:  38%|███▊      | 20/52 [00:38<01:02,  1.94s/it]

Loss: 0.24665595591068268


Epoch 6:  40%|████      | 21/52 [00:40<01:00,  1.94s/it]

Loss: 0.3126264214515686


Epoch 6:  42%|████▏     | 22/52 [00:42<00:58,  1.94s/it]

Loss: 0.3755795359611511


Epoch 6:  44%|████▍     | 23/52 [00:44<00:56,  1.94s/it]

Loss: 0.22931155562400818


Epoch 6:  46%|████▌     | 24/52 [00:46<00:54,  1.94s/it]

Loss: 0.42740944027900696


Epoch 6:  48%|████▊     | 25/52 [00:48<00:52,  1.94s/it]

Loss: 0.6303613781929016


Epoch 6:  50%|█████     | 26/52 [00:50<00:50,  1.94s/it]

Loss: 0.3107335567474365


Epoch 6:  52%|█████▏    | 27/52 [00:52<00:48,  1.94s/it]

Loss: 0.18505625426769257


Epoch 6:  54%|█████▍    | 28/52 [00:54<00:46,  1.95s/it]

Loss: 0.24615855515003204


Epoch 6:  56%|█████▌    | 29/52 [00:56<00:44,  1.94s/it]

Loss: 0.3681130111217499


Epoch 6:  58%|█████▊    | 30/52 [00:58<00:42,  1.95s/it]

Loss: 0.3423286974430084


Epoch 6:  60%|█████▉    | 31/52 [01:00<00:40,  1.95s/it]

Loss: 0.26812297105789185


Epoch 6:  62%|██████▏   | 32/52 [01:02<00:38,  1.95s/it]

Loss: 0.3240671455860138


Epoch 6:  63%|██████▎   | 33/52 [01:04<00:36,  1.95s/it]

Loss: 0.19675534963607788


Epoch 6:  65%|██████▌   | 34/52 [01:06<00:35,  1.95s/it]

Loss: 0.41015806794166565


Epoch 6:  67%|██████▋   | 35/52 [01:08<00:33,  1.95s/it]

Loss: 0.5762788653373718


Epoch 6:  69%|██████▉   | 36/52 [01:10<00:31,  1.95s/it]

Loss: 0.26666831970214844


Epoch 6:  71%|███████   | 37/52 [01:11<00:29,  1.95s/it]

Loss: 0.38188663125038147


Epoch 6:  73%|███████▎  | 38/52 [01:13<00:27,  1.95s/it]

Loss: 0.4071752429008484


Epoch 6:  75%|███████▌  | 39/52 [01:15<00:25,  1.95s/it]

Loss: 0.3366554379463196


Epoch 6:  77%|███████▋  | 40/52 [01:17<00:23,  1.95s/it]

Loss: 0.46701860427856445


Epoch 6:  79%|███████▉  | 41/52 [01:19<00:21,  1.95s/it]

Loss: 0.4365101158618927


Epoch 6:  81%|████████  | 42/52 [01:21<00:19,  1.95s/it]

Loss: 0.17943106591701508


Epoch 6:  83%|████████▎ | 43/52 [01:23<00:17,  1.95s/it]

Loss: 0.4026126265525818


Epoch 6:  85%|████████▍ | 44/52 [01:25<00:15,  1.95s/it]

Loss: 0.4351159930229187


Epoch 6:  87%|████████▋ | 45/52 [01:27<00:13,  1.95s/it]

Loss: 0.2508699595928192


Epoch 6:  88%|████████▊ | 46/52 [01:29<00:11,  1.95s/it]

Loss: 0.5410090684890747


Epoch 6:  90%|█████████ | 47/52 [01:31<00:09,  1.95s/it]

Loss: 0.29280877113342285


Epoch 6:  92%|█████████▏| 48/52 [01:33<00:07,  1.95s/it]

Loss: 0.3933085799217224


Epoch 6:  94%|█████████▍| 49/52 [01:35<00:05,  1.95s/it]

Loss: 0.3381924033164978


Epoch 6:  96%|█████████▌| 50/52 [01:37<00:03,  1.95s/it]

Loss: 0.28900399804115295


Epoch 6:  98%|█████████▊| 51/52 [01:39<00:01,  1.95s/it]

Loss: 0.505707859992981


Epoch 6: 100%|██████████| 52/52 [01:39<00:00,  1.92s/it]

Loss: 0.3083687424659729





Epoch 6 Validation Accuracy: 0.8177339901477833, F1-macro: 0.7561759454633987


Epoch 7:   2%|▏         | 1/52 [00:01<01:39,  1.94s/it]

Loss: 0.4665215313434601


Epoch 7:   4%|▍         | 2/52 [00:03<01:37,  1.94s/it]

Loss: 0.3789844512939453


Epoch 7:   6%|▌         | 3/52 [00:05<01:35,  1.94s/it]

Loss: 0.7283488512039185


Epoch 7:   8%|▊         | 4/52 [00:07<01:33,  1.94s/it]

Loss: 0.33735156059265137


Epoch 7:  10%|▉         | 5/52 [00:09<01:31,  1.94s/it]

Loss: 0.444180965423584


Epoch 7:  12%|█▏        | 6/52 [00:11<01:29,  1.94s/it]

Loss: 0.15787164866924286


Epoch 7:  13%|█▎        | 7/52 [00:13<01:27,  1.94s/it]

Loss: 0.3892994225025177


Epoch 7:  15%|█▌        | 8/52 [00:15<01:25,  1.94s/it]

Loss: 0.237339586019516


Epoch 7:  17%|█▋        | 9/52 [00:17<01:23,  1.94s/it]

Loss: 0.7511974573135376


Epoch 7:  19%|█▉        | 10/52 [00:19<01:21,  1.94s/it]

Loss: 0.20017002522945404


Epoch 7:  21%|██        | 11/52 [00:21<01:19,  1.95s/it]

Loss: 0.7181634306907654


Epoch 7:  23%|██▎       | 12/52 [00:23<01:17,  1.95s/it]

Loss: 0.3055815100669861


Epoch 7:  25%|██▌       | 13/52 [00:25<01:15,  1.95s/it]

Loss: 0.546861469745636


Epoch 7:  27%|██▋       | 14/52 [00:27<01:13,  1.95s/it]

Loss: 0.5050177574157715


Epoch 7:  29%|██▉       | 15/52 [00:29<01:12,  1.95s/it]

Loss: 0.5431795120239258


Epoch 7:  31%|███       | 16/52 [00:31<01:10,  1.95s/it]

Loss: 0.3824036419391632


Epoch 7:  33%|███▎      | 17/52 [00:33<01:08,  1.95s/it]

Loss: 0.7172315120697021


Epoch 7:  35%|███▍      | 18/52 [00:35<01:06,  1.95s/it]

Loss: 0.44125843048095703


Epoch 7:  37%|███▋      | 19/52 [00:36<01:04,  1.95s/it]

Loss: 0.52199387550354


Epoch 7:  38%|███▊      | 20/52 [00:38<01:02,  1.95s/it]

Loss: 0.32118508219718933


Epoch 7:  40%|████      | 21/52 [00:40<01:00,  1.94s/it]

Loss: 0.7508596181869507


Epoch 7:  42%|████▏     | 22/52 [00:42<00:58,  1.94s/it]

Loss: 0.5420923233032227


Epoch 7:  44%|████▍     | 23/52 [00:44<00:56,  1.94s/it]

Loss: 0.5113760828971863


Epoch 7:  46%|████▌     | 24/52 [00:46<00:54,  1.95s/it]

Loss: 0.48389023542404175


Epoch 7:  48%|████▊     | 25/52 [00:48<00:52,  1.94s/it]

Loss: 0.4879216253757477


Epoch 7:  50%|█████     | 26/52 [00:50<00:50,  1.95s/it]

Loss: 0.24473613500595093


Epoch 7:  52%|█████▏    | 27/52 [00:52<00:48,  1.94s/it]

Loss: 0.4707897901535034


Epoch 7:  54%|█████▍    | 28/52 [00:54<00:46,  1.94s/it]

Loss: 0.6357547044754028


Epoch 7:  56%|█████▌    | 29/52 [00:56<00:44,  1.94s/it]

Loss: 0.5994035005569458


Epoch 7:  58%|█████▊    | 30/52 [00:58<00:42,  1.94s/it]

Loss: 0.15791167318820953


Epoch 7:  60%|█████▉    | 31/52 [01:00<00:40,  1.94s/it]

Loss: 0.5553436875343323


Epoch 7:  62%|██████▏   | 32/52 [01:02<00:38,  1.94s/it]

Loss: 0.604362964630127


Epoch 7:  63%|██████▎   | 33/52 [01:04<00:36,  1.94s/it]

Loss: 0.44823166728019714


Epoch 7:  65%|██████▌   | 34/52 [01:06<00:34,  1.94s/it]

Loss: 0.28561320900917053


Epoch 7:  67%|██████▋   | 35/52 [01:08<00:33,  1.94s/it]

Loss: 0.4443594813346863


Epoch 7:  69%|██████▉   | 36/52 [01:10<00:31,  1.94s/it]

Loss: 1.3932644128799438


Epoch 7:  71%|███████   | 37/52 [01:11<00:29,  1.94s/it]

Loss: 0.303214430809021


Epoch 7:  73%|███████▎  | 38/52 [01:13<00:27,  1.94s/it]

Loss: 0.8183974027633667


Epoch 7:  75%|███████▌  | 39/52 [01:15<00:25,  1.94s/it]

Loss: 0.5130518674850464


Epoch 7:  77%|███████▋  | 40/52 [01:17<00:23,  1.94s/it]

Loss: 1.0158538818359375


Epoch 7:  79%|███████▉  | 41/52 [01:19<00:21,  1.94s/it]

Loss: 0.7380481958389282


Epoch 7:  81%|████████  | 42/52 [01:21<00:19,  1.94s/it]

Loss: 0.4838845133781433


Epoch 7:  83%|████████▎ | 43/52 [01:23<00:17,  1.94s/it]

Loss: 0.24192635715007782


Epoch 7:  85%|████████▍ | 44/52 [01:25<00:15,  1.94s/it]

Loss: 1.492514729499817


Epoch 7:  87%|████████▋ | 45/52 [01:27<00:13,  1.94s/it]

Loss: 0.15037257969379425


Epoch 7:  88%|████████▊ | 46/52 [01:29<00:11,  1.94s/it]

Loss: 0.48475202918052673


Epoch 7:  90%|█████████ | 47/52 [01:31<00:09,  1.94s/it]

Loss: 2.413121223449707


Epoch 7:  92%|█████████▏| 48/52 [01:33<00:07,  1.94s/it]

Loss: 1.1546530723571777


Epoch 7:  94%|█████████▍| 49/52 [01:35<00:05,  1.94s/it]

Loss: 0.5786861181259155


Epoch 7:  96%|█████████▌| 50/52 [01:37<00:03,  1.94s/it]

Loss: 1.184607982635498


Epoch 7:  98%|█████████▊| 51/52 [01:39<00:01,  1.94s/it]

Loss: 0.508968710899353


Epoch 7: 100%|██████████| 52/52 [01:39<00:00,  1.92s/it]

Loss: 1.1984925270080566





Epoch 7 Validation Accuracy: 0.8768472906403941, F1-macro: 0.71969069317868


In [None]:
# Evaluate on the test set
test_metrics = evaluate(model8, test_dataloader, device)
print(f"Test Accuracy: {test_metrics['accuracy']}, F1-macro: {test_metrics['f1_macro']}")

Test Accuracy: 0.8431372549019608, F1-macro: 0.6473638720829732


In [None]:
current_type = 'mind reading'

results_df.loc[
    results_df["distortion_type"] == current_type,
    ["test_accuracy", "f1_macro"]
] = [test_metrics['accuracy'], test_metrics['f1_macro']]

# 9. Overgeneralizing

In [None]:
# Add labels
data1_1_labels = list(data1['overgeneralizing'][data1_1.index]) # data1에서는 명칭이 다름.
data2_1_labels = list(data2['overgeneralization'][data2_1.index])
data3_1_labels = list(data3['overgeneralization'][data3_1.index])

# Merging Data
data_encoded = data1_1_encoded + data2_1_encoded + data3_1_encoded
data_labels = data1_1_labels + data2_1_labels + data3_1_labels

dataset_with_labels = CustomDatasetWithLabels(data_encoded, data_labels)

# Define proportions for splitting
train_size = int(0.8 * len(dataset_with_labels))
val_size = int(0.1 * len(dataset_with_labels))
test_size = len(dataset_with_labels) - train_size - val_size

# Split the dataset
train_dataset, val_dataset, test_dataset = random_split(dataset_with_labels, [train_size, val_size, test_size])

In [None]:
dataset_encoded = dataset_with_labels.data
dataset_labels = dataset_with_labels.labels

# Convert list of BatchEncoding objects to a 2D numpy array for SMOTE features
feature_vectors = []
max_len = 512 # Assuming this from tokenize_and_pad function
for encoded_dict in dataset_encoded:
    input_ids_np = encoded_dict['input_ids'].squeeze().cpu().numpy()
    attention_mask_np = encoded_dict['attention_mask'].squeeze().cpu().numpy()
    # Concatenate input_ids and attention_mask to create a single feature vector per sample
    feature_vectors.append(np.concatenate([input_ids_np, attention_mask_np]))

# 'temp1' becomes the feature matrix, 'temp2' becomes the labels vector
temp1 = np.array(feature_vectors)
temp2 = np.array(dataset_labels)

# Apply SMOTE
temp1_resampled, temp2_resampled = smote.fit_resample(temp1, temp2)

# Reconstruct data_encoded and data_labels from the resampled data
reconstructed_data_encoded = []
for resampled_features in temp1_resampled:
    # Split the combined feature vector back into input_ids and attention_mask parts
    reconstructed_input_ids = resampled_features[:max_len]
    reconstructed_attention_mask = resampled_features[max_len:]

    # Convert back to torch tensors and a dictionary format compatible with CustomDatasetWithLabels
    input_ids_tensor = torch.tensor(reconstructed_input_ids, dtype=torch.long).unsqueeze(0)
    attention_mask_tensor = torch.tensor(reconstructed_attention_mask, dtype=torch.long).unsqueeze(0)

    reconstructed_data_encoded.append({
        'input_ids': input_ids_tensor,
        'attention_mask': attention_mask_tensor
    })

# Update data_encoded and data_labels with the resampled data
dataset_encoded = reconstructed_data_encoded
dataset_labels = temp2_resampled.tolist()

train_dataset = CustomDatasetWithLabels(dataset_encoded, dataset_labels)

In [None]:
# Create DataLoaders for each set
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=64)
test_dataloader = DataLoader(test_dataset, batch_size=64)

input_dim = bert_config.hidden_size
num_labels = len(set(data_labels))
label_emb = torch.randn(num_labels, input_dim)

In [None]:
# Instantiate the InnerProductClassifier model
model9 = InnerProductClassifier(input_dim, label_emb).to(device)

# Define loss function
criterion = nn.CrossEntropyLoss()

# Define optimizer
optimizer = torch.optim.AdamW(model9.parameters(), lr=2e-4)

In [None]:
EPOCHS = 7

# Move the BERT model to the device
bert_model.to(device)

for epoch in range(EPOCHS):
    model9.train()
    for batch in tqdm(train_dataloader, desc=f"Epoch {epoch+1}"):

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        y = batch["y"].to(device)

        # Get embeddings from the BERT model
        with torch.no_grad():
            outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask)
            embeddings = outputs.last_hidden_state[:, 0, :] # Get the [CLS] token embedding

        # Pass embeddings to the InnerProductClassifier
        logits = model9(embeddings)
        loss = criterion(logits, y)

        # Backpropagate and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Loss: {loss.item()}")

    # Evaluate on the validation set after each epoch
    val_metrics = evaluate(model9, val_dataloader, device)
    print(f"Epoch {epoch+1} Validation Accuracy: {val_metrics['accuracy']}, F1-macro: {val_metrics['f1_macro']}")

Epoch 1:   2%|▏         | 1/53 [00:01<01:39,  1.91s/it]

Loss: 3.3430519104003906


Epoch 1:   4%|▍         | 2/53 [00:03<01:38,  1.92s/it]

Loss: 6.321318626403809


Epoch 1:   6%|▌         | 3/53 [00:05<01:36,  1.93s/it]

Loss: 3.401599407196045


Epoch 1:   8%|▊         | 4/53 [00:07<01:34,  1.93s/it]

Loss: 7.984327793121338


Epoch 1:   9%|▉         | 5/53 [00:09<01:32,  1.94s/it]

Loss: 3.76102876663208


Epoch 1:  11%|█▏        | 6/53 [00:11<01:31,  1.94s/it]

Loss: 4.535874366760254


Epoch 1:  13%|█▎        | 7/53 [00:13<01:29,  1.94s/it]

Loss: 2.083153247833252


Epoch 1:  15%|█▌        | 8/53 [00:15<01:27,  1.94s/it]

Loss: 1.556077480316162


Epoch 1:  17%|█▋        | 9/53 [00:17<01:25,  1.94s/it]

Loss: 3.684605598449707


Epoch 1:  19%|█▉        | 10/53 [00:19<01:23,  1.94s/it]

Loss: 5.7671966552734375


Epoch 1:  21%|██        | 11/53 [00:21<01:21,  1.94s/it]

Loss: 1.2725586891174316


Epoch 1:  23%|██▎       | 12/53 [00:23<01:19,  1.94s/it]

Loss: 1.0686392784118652


Epoch 1:  25%|██▍       | 13/53 [00:25<01:17,  1.94s/it]

Loss: 1.706366777420044


Epoch 1:  26%|██▋       | 14/53 [00:27<01:15,  1.94s/it]

Loss: 2.288534164428711


Epoch 1:  28%|██▊       | 15/53 [00:29<01:13,  1.95s/it]

Loss: 4.010605812072754


Epoch 1:  30%|███       | 16/53 [00:31<01:11,  1.95s/it]

Loss: 4.130504131317139


Epoch 1:  32%|███▏      | 17/53 [00:33<01:10,  1.95s/it]

Loss: 3.1286370754241943


Epoch 1:  34%|███▍      | 18/53 [00:34<01:08,  1.94s/it]

Loss: 3.035583972930908


Epoch 1:  36%|███▌      | 19/53 [00:36<01:06,  1.95s/it]

Loss: 2.648099184036255


Epoch 1:  38%|███▊      | 20/53 [00:38<01:04,  1.95s/it]

Loss: 1.9269146919250488


Epoch 1:  40%|███▉      | 21/53 [00:40<01:02,  1.95s/it]

Loss: 0.6521969437599182


Epoch 1:  42%|████▏     | 22/53 [00:42<01:00,  1.95s/it]

Loss: 3.2803139686584473


Epoch 1:  43%|████▎     | 23/53 [00:44<00:58,  1.95s/it]

Loss: 3.756528377532959


Epoch 1:  45%|████▌     | 24/53 [00:46<00:56,  1.95s/it]

Loss: 1.5852954387664795


Epoch 1:  47%|████▋     | 25/53 [00:48<00:54,  1.95s/it]

Loss: 1.1301658153533936


Epoch 1:  49%|████▉     | 26/53 [00:50<00:52,  1.95s/it]

Loss: 2.1746699810028076


Epoch 1:  51%|█████     | 27/53 [00:52<00:50,  1.95s/it]

Loss: 0.8935071229934692


Epoch 1:  53%|█████▎    | 28/53 [00:54<00:48,  1.95s/it]

Loss: 0.9745696187019348


Epoch 1:  55%|█████▍    | 29/53 [00:56<00:46,  1.95s/it]

Loss: 1.8905152082443237


Epoch 1:  57%|█████▋    | 30/53 [00:58<00:44,  1.95s/it]

Loss: 2.0076282024383545


Epoch 1:  58%|█████▊    | 31/53 [01:00<00:42,  1.95s/it]

Loss: 0.6818002462387085


Epoch 1:  60%|██████    | 32/53 [01:02<00:40,  1.94s/it]

Loss: 0.926742672920227


Epoch 1:  62%|██████▏   | 33/53 [01:04<00:38,  1.95s/it]

Loss: 0.6051234006881714


Epoch 1:  64%|██████▍   | 34/53 [01:06<00:36,  1.95s/it]

Loss: 1.0785431861877441


Epoch 1:  66%|██████▌   | 35/53 [01:08<00:35,  1.94s/it]

Loss: 1.2040594816207886


Epoch 1:  68%|██████▊   | 36/53 [01:09<00:33,  1.95s/it]

Loss: 1.267329216003418


Epoch 1:  70%|██████▉   | 37/53 [01:11<00:31,  1.94s/it]

Loss: 0.9614041447639465


Epoch 1:  72%|███████▏  | 38/53 [01:13<00:29,  1.95s/it]

Loss: 0.984655499458313


Epoch 1:  74%|███████▎  | 39/53 [01:15<00:27,  1.95s/it]

Loss: 0.8285216093063354


Epoch 1:  75%|███████▌  | 40/53 [01:17<00:25,  1.95s/it]

Loss: 1.6624131202697754


Epoch 1:  77%|███████▋  | 41/53 [01:19<00:23,  1.95s/it]

Loss: 1.2017416954040527


Epoch 1:  79%|███████▉  | 42/53 [01:21<00:21,  1.95s/it]

Loss: 0.670602560043335


Epoch 1:  81%|████████  | 43/53 [01:23<00:19,  1.95s/it]

Loss: 1.0510693788528442


Epoch 1:  83%|████████▎ | 44/53 [01:25<00:17,  1.95s/it]

Loss: 1.0215543508529663


Epoch 1:  85%|████████▍ | 45/53 [01:27<00:15,  1.95s/it]

Loss: 0.9483853578567505


Epoch 1:  87%|████████▋ | 46/53 [01:29<00:13,  1.95s/it]

Loss: 1.747252345085144


Epoch 1:  89%|████████▊ | 47/53 [01:31<00:11,  1.95s/it]

Loss: 0.5723990797996521


Epoch 1:  91%|█████████ | 48/53 [01:33<00:09,  1.95s/it]

Loss: 0.2663899064064026


Epoch 1:  92%|█████████▏| 49/53 [01:35<00:07,  1.95s/it]

Loss: 1.0735195875167847


Epoch 1:  94%|█████████▍| 50/53 [01:37<00:05,  1.95s/it]

Loss: 0.8216530084609985


Epoch 1:  96%|█████████▌| 51/53 [01:39<00:03,  1.95s/it]

Loss: 1.2661662101745605


Epoch 1:  98%|█████████▊| 52/53 [01:41<00:01,  1.95s/it]

Loss: 0.6819205284118652


Epoch 1: 100%|██████████| 53/53 [01:42<00:00,  1.94s/it]

Loss: 1.4041553735733032





Epoch 1 Validation Accuracy: 0.7931034482758621, F1-macro: 0.6200534759358289


Epoch 2:   2%|▏         | 1/53 [00:01<01:41,  1.95s/it]

Loss: 0.30223533511161804


Epoch 2:   4%|▍         | 2/53 [00:03<01:39,  1.95s/it]

Loss: 0.6611200571060181


Epoch 2:   6%|▌         | 3/53 [00:05<01:37,  1.95s/it]

Loss: 1.148531198501587


Epoch 2:   8%|▊         | 4/53 [00:07<01:35,  1.95s/it]

Loss: 0.7252760529518127


Epoch 2:   9%|▉         | 5/53 [00:09<01:33,  1.95s/it]

Loss: 1.0364506244659424


Epoch 2:  11%|█▏        | 6/53 [00:11<01:31,  1.95s/it]

Loss: 1.370017170906067


Epoch 2:  13%|█▎        | 7/53 [00:13<01:29,  1.95s/it]

Loss: 0.536587119102478


Epoch 2:  15%|█▌        | 8/53 [00:15<01:27,  1.95s/it]

Loss: 1.53660249710083


Epoch 2:  17%|█▋        | 9/53 [00:17<01:25,  1.95s/it]

Loss: 0.22739411890506744


Epoch 2:  19%|█▉        | 10/53 [00:19<01:23,  1.95s/it]

Loss: 0.29119980335235596


Epoch 2:  21%|██        | 11/53 [00:21<01:21,  1.94s/it]

Loss: 0.49349138140678406


Epoch 2:  23%|██▎       | 12/53 [00:23<01:19,  1.94s/it]

Loss: 0.6111910939216614


Epoch 2:  25%|██▍       | 13/53 [00:25<01:17,  1.94s/it]

Loss: 0.526952862739563


Epoch 2:  26%|██▋       | 14/53 [00:27<01:15,  1.94s/it]

Loss: 0.6035200953483582


Epoch 2:  28%|██▊       | 15/53 [00:29<01:13,  1.94s/it]

Loss: 1.118587613105774


Epoch 2:  30%|███       | 16/53 [00:31<01:11,  1.94s/it]

Loss: 0.3004453480243683


Epoch 2:  32%|███▏      | 17/53 [00:33<01:09,  1.94s/it]

Loss: 0.5804165601730347


Epoch 2:  34%|███▍      | 18/53 [00:35<01:08,  1.94s/it]

Loss: 0.3264563977718353


Epoch 2:  36%|███▌      | 19/53 [00:36<01:06,  1.94s/it]

Loss: 0.7874266505241394


Epoch 2:  38%|███▊      | 20/53 [00:38<01:04,  1.94s/it]

Loss: 0.7175108194351196


Epoch 2:  40%|███▉      | 21/53 [00:40<01:02,  1.94s/it]

Loss: 0.31747403740882874


Epoch 2:  42%|████▏     | 22/53 [00:42<01:00,  1.94s/it]

Loss: 0.5292078256607056


Epoch 2:  43%|████▎     | 23/53 [00:44<00:58,  1.94s/it]

Loss: 0.8170437812805176


Epoch 2:  45%|████▌     | 24/53 [00:46<00:56,  1.94s/it]

Loss: 0.5807532668113708


Epoch 2:  47%|████▋     | 25/53 [00:48<00:54,  1.94s/it]

Loss: 0.5503995418548584


Epoch 2:  49%|████▉     | 26/53 [00:50<00:52,  1.94s/it]

Loss: 0.9228861927986145


Epoch 2:  51%|█████     | 27/53 [00:52<00:50,  1.94s/it]

Loss: 0.5174089670181274


Epoch 2:  53%|█████▎    | 28/53 [00:54<00:48,  1.94s/it]

Loss: 0.6098700761795044


Epoch 2:  55%|█████▍    | 29/53 [00:56<00:46,  1.94s/it]

Loss: 0.5261983275413513


Epoch 2:  57%|█████▋    | 30/53 [00:58<00:44,  1.94s/it]

Loss: 0.6934105157852173


Epoch 2:  58%|█████▊    | 31/53 [01:00<00:42,  1.94s/it]

Loss: 0.5204009413719177


Epoch 2:  60%|██████    | 32/53 [01:02<00:40,  1.95s/it]

Loss: 0.6805012226104736


Epoch 2:  62%|██████▏   | 33/53 [01:04<00:38,  1.95s/it]

Loss: 0.41519320011138916


Epoch 2:  64%|██████▍   | 34/53 [01:06<00:36,  1.95s/it]

Loss: 0.7291510105133057


Epoch 2:  66%|██████▌   | 35/53 [01:08<00:35,  1.95s/it]

Loss: 0.9097462892532349


Epoch 2:  68%|██████▊   | 36/53 [01:10<00:33,  1.95s/it]

Loss: 0.486516535282135


Epoch 2:  70%|██████▉   | 37/53 [01:11<00:31,  1.95s/it]

Loss: 1.4208621978759766


Epoch 2:  72%|███████▏  | 38/53 [01:13<00:29,  1.95s/it]

Loss: 1.76094388961792


Epoch 2:  74%|███████▎  | 39/53 [01:15<00:27,  1.95s/it]

Loss: 0.5218466520309448


Epoch 2:  75%|███████▌  | 40/53 [01:17<00:25,  1.95s/it]

Loss: 1.8085124492645264


Epoch 2:  77%|███████▋  | 41/53 [01:19<00:23,  1.95s/it]

Loss: 0.6392173171043396


Epoch 2:  79%|███████▉  | 42/53 [01:21<00:21,  1.95s/it]

Loss: 1.4004607200622559


Epoch 2:  81%|████████  | 43/53 [01:23<00:19,  1.95s/it]

Loss: 1.802977442741394


Epoch 2:  83%|████████▎ | 44/53 [01:25<00:17,  1.95s/it]

Loss: 1.2389652729034424


Epoch 2:  85%|████████▍ | 45/53 [01:27<00:15,  1.95s/it]

Loss: 2.0703110694885254


Epoch 2:  87%|████████▋ | 46/53 [01:29<00:13,  1.95s/it]

Loss: 1.0832353830337524


Epoch 2:  89%|████████▊ | 47/53 [01:31<00:11,  1.95s/it]

Loss: 0.6146660447120667


Epoch 2:  91%|█████████ | 48/53 [01:33<00:09,  1.95s/it]

Loss: 0.2925443947315216


Epoch 2:  92%|█████████▏| 49/53 [01:35<00:07,  1.94s/it]

Loss: 0.5622729063034058


Epoch 2:  94%|█████████▍| 50/53 [01:37<00:05,  1.94s/it]

Loss: 1.2181687355041504


Epoch 2:  96%|█████████▌| 51/53 [01:39<00:03,  1.94s/it]

Loss: 0.7571991682052612


Epoch 2:  98%|█████████▊| 52/53 [01:41<00:01,  1.94s/it]

Loss: 0.555982232093811


Epoch 2: 100%|██████████| 53/53 [01:42<00:00,  1.94s/it]

Loss: 0.5019423961639404





Epoch 2 Validation Accuracy: 0.8620689655172413, F1-macro: 0.71


Epoch 3:   2%|▏         | 1/53 [00:01<01:40,  1.93s/it]

Loss: 0.11061159521341324


Epoch 3:   4%|▍         | 2/53 [00:03<01:38,  1.94s/it]

Loss: 0.6670234203338623


Epoch 3:   6%|▌         | 3/53 [00:05<01:37,  1.94s/it]

Loss: 0.2903062701225281


Epoch 3:   8%|▊         | 4/53 [00:07<01:35,  1.94s/it]

Loss: 0.5374547243118286


Epoch 3:   9%|▉         | 5/53 [00:09<01:33,  1.94s/it]

Loss: 0.678351104259491


Epoch 3:  11%|█▏        | 6/53 [00:11<01:31,  1.94s/it]

Loss: 0.6934041380882263


Epoch 3:  13%|█▎        | 7/53 [00:13<01:29,  1.94s/it]

Loss: 0.6118756532669067


Epoch 3:  15%|█▌        | 8/53 [00:15<01:27,  1.94s/it]

Loss: 0.4688050150871277


Epoch 3:  17%|█▋        | 9/53 [00:17<01:25,  1.94s/it]

Loss: 0.3165283203125


Epoch 3:  19%|█▉        | 10/53 [00:19<01:23,  1.94s/it]

Loss: 0.7091320753097534


Epoch 3:  21%|██        | 11/53 [00:21<01:21,  1.94s/it]

Loss: 0.356947660446167


Epoch 3:  23%|██▎       | 12/53 [00:23<01:19,  1.94s/it]

Loss: 0.5778224468231201


Epoch 3:  25%|██▍       | 13/53 [00:25<01:17,  1.95s/it]

Loss: 0.5520329475402832


Epoch 3:  26%|██▋       | 14/53 [00:27<01:15,  1.95s/it]

Loss: 0.5388452410697937


Epoch 3:  28%|██▊       | 15/53 [00:29<01:13,  1.95s/it]

Loss: 0.36450862884521484


Epoch 3:  30%|███       | 16/53 [00:31<01:12,  1.95s/it]

Loss: 0.47091156244277954


Epoch 3:  32%|███▏      | 17/53 [00:33<01:10,  1.95s/it]

Loss: 0.2668417692184448


Epoch 3:  34%|███▍      | 18/53 [00:34<01:08,  1.95s/it]

Loss: 0.5949985980987549


Epoch 3:  36%|███▌      | 19/53 [00:36<01:06,  1.95s/it]

Loss: 0.4036062955856323


Epoch 3:  38%|███▊      | 20/53 [00:38<01:04,  1.95s/it]

Loss: 0.6349580883979797


Epoch 3:  40%|███▉      | 21/53 [00:40<01:02,  1.95s/it]

Loss: 0.5716361999511719


Epoch 3:  42%|████▏     | 22/53 [00:42<01:00,  1.95s/it]

Loss: 0.5220049619674683


Epoch 3:  43%|████▎     | 23/53 [00:44<00:58,  1.95s/it]

Loss: 0.5333953499794006


Epoch 3:  45%|████▌     | 24/53 [00:46<00:56,  1.95s/it]

Loss: 0.39066392183303833


Epoch 3:  47%|████▋     | 25/53 [00:48<00:54,  1.95s/it]

Loss: 1.0719618797302246


Epoch 3:  49%|████▉     | 26/53 [00:50<00:52,  1.94s/it]

Loss: 0.3865888714790344


Epoch 3:  51%|█████     | 27/53 [00:52<00:50,  1.94s/it]

Loss: 0.4265316128730774


Epoch 3:  53%|█████▎    | 28/53 [00:54<00:48,  1.94s/it]

Loss: 0.23581072688102722


Epoch 3:  55%|█████▍    | 29/53 [00:56<00:46,  1.94s/it]

Loss: 0.8682464361190796


Epoch 3:  57%|█████▋    | 30/53 [00:58<00:44,  1.94s/it]

Loss: 0.4713776409626007


Epoch 3:  58%|█████▊    | 31/53 [01:00<00:42,  1.94s/it]

Loss: 0.43592679500579834


Epoch 3:  60%|██████    | 32/53 [01:02<00:40,  1.94s/it]

Loss: 0.3894451856613159


Epoch 3:  62%|██████▏   | 33/53 [01:04<00:38,  1.94s/it]

Loss: 0.7564976215362549


Epoch 3:  64%|██████▍   | 34/53 [01:06<00:36,  1.94s/it]

Loss: 0.36996251344680786


Epoch 3:  66%|██████▌   | 35/53 [01:08<00:34,  1.94s/it]

Loss: 0.7242215871810913


Epoch 3:  68%|██████▊   | 36/53 [01:10<00:33,  1.94s/it]

Loss: 0.32329821586608887


Epoch 3:  70%|██████▉   | 37/53 [01:11<00:31,  1.94s/it]

Loss: 0.46144312620162964


Epoch 3:  72%|███████▏  | 38/53 [01:13<00:29,  1.94s/it]

Loss: 0.5765817165374756


Epoch 3:  74%|███████▎  | 39/53 [01:15<00:27,  1.94s/it]

Loss: 0.601689338684082


Epoch 3:  75%|███████▌  | 40/53 [01:17<00:25,  1.94s/it]

Loss: 0.8334507942199707


Epoch 3:  77%|███████▋  | 41/53 [01:19<00:23,  1.94s/it]

Loss: 0.2740817368030548


Epoch 3:  79%|███████▉  | 42/53 [01:21<00:21,  1.94s/it]

Loss: 0.46479493379592896


Epoch 3:  81%|████████  | 43/53 [01:23<00:19,  1.94s/it]

Loss: 0.4205693006515503


Epoch 3:  83%|████████▎ | 44/53 [01:25<00:17,  1.94s/it]

Loss: 0.394564151763916


Epoch 3:  85%|████████▍ | 45/53 [01:27<00:15,  1.94s/it]

Loss: 0.47213226556777954


Epoch 3:  87%|████████▋ | 46/53 [01:29<00:13,  1.94s/it]

Loss: 0.26390397548675537


Epoch 3:  89%|████████▊ | 47/53 [01:31<00:11,  1.94s/it]

Loss: 0.24864330887794495


Epoch 3:  91%|█████████ | 48/53 [01:33<00:09,  1.94s/it]

Loss: 0.4001489281654358


Epoch 3:  92%|█████████▏| 49/53 [01:35<00:07,  1.94s/it]

Loss: 0.5766547918319702


Epoch 3:  94%|█████████▍| 50/53 [01:37<00:05,  1.94s/it]

Loss: 0.5325489044189453


Epoch 3:  96%|█████████▌| 51/53 [01:39<00:03,  1.94s/it]

Loss: 0.2576374411582947


Epoch 3:  98%|█████████▊| 52/53 [01:41<00:01,  1.94s/it]

Loss: 0.3465653657913208


Epoch 3: 100%|██████████| 53/53 [01:42<00:00,  1.93s/it]

Loss: 0.677928626537323





Epoch 3 Validation Accuracy: 0.8522167487684729, F1-macro: 0.5645022883295194


Epoch 4:   2%|▏         | 1/53 [00:01<01:40,  1.93s/it]

Loss: 0.6756558418273926


Epoch 4:   4%|▍         | 2/53 [00:03<01:38,  1.93s/it]

Loss: 0.4744998812675476


Epoch 4:   6%|▌         | 3/53 [00:05<01:36,  1.94s/it]

Loss: 0.6732824444770813


Epoch 4:   8%|▊         | 4/53 [00:07<01:34,  1.94s/it]

Loss: 0.6192629933357239


Epoch 4:   9%|▉         | 5/53 [00:09<01:33,  1.94s/it]

Loss: 0.1944095641374588


Epoch 4:  11%|█▏        | 6/53 [00:11<01:31,  1.94s/it]

Loss: 0.4254651665687561


Epoch 4:  13%|█▎        | 7/53 [00:13<01:29,  1.94s/it]

Loss: 0.6410691738128662


Epoch 4:  15%|█▌        | 8/53 [00:15<01:27,  1.94s/it]

Loss: 0.5583268404006958


Epoch 4:  17%|█▋        | 9/53 [00:17<01:25,  1.94s/it]

Loss: 0.5471900701522827


Epoch 4:  19%|█▉        | 10/53 [00:19<01:23,  1.94s/it]

Loss: 0.5751073360443115


Epoch 4:  21%|██        | 11/53 [00:21<01:21,  1.94s/it]

Loss: 0.46835994720458984


Epoch 4:  23%|██▎       | 12/53 [00:23<01:19,  1.94s/it]

Loss: 0.7842063307762146


Epoch 4:  25%|██▍       | 13/53 [00:25<01:17,  1.94s/it]

Loss: 0.4520557224750519


Epoch 4:  26%|██▋       | 14/53 [00:27<01:15,  1.94s/it]

Loss: 0.322628378868103


Epoch 4:  28%|██▊       | 15/53 [00:29<01:13,  1.94s/it]

Loss: 0.5143697261810303


Epoch 4:  30%|███       | 16/53 [00:31<01:11,  1.94s/it]

Loss: 0.35760414600372314


Epoch 4:  32%|███▏      | 17/53 [00:33<01:09,  1.94s/it]

Loss: 0.21240803599357605


Epoch 4:  34%|███▍      | 18/53 [00:34<01:08,  1.94s/it]

Loss: 0.557580828666687


Epoch 4:  36%|███▌      | 19/53 [00:36<01:06,  1.94s/it]

Loss: 0.9456014633178711


Epoch 4:  38%|███▊      | 20/53 [00:38<01:04,  1.94s/it]

Loss: 0.46030473709106445


Epoch 4:  40%|███▉      | 21/53 [00:40<01:02,  1.94s/it]

Loss: 0.8514543771743774


Epoch 4:  42%|████▏     | 22/53 [00:42<01:00,  1.94s/it]

Loss: 0.41708141565322876


Epoch 4:  43%|████▎     | 23/53 [00:44<00:58,  1.94s/it]

Loss: 0.6612405180931091


Epoch 4:  45%|████▌     | 24/53 [00:46<00:56,  1.94s/it]

Loss: 0.5547227263450623


Epoch 4:  47%|████▋     | 25/53 [00:48<00:54,  1.94s/it]

Loss: 0.45260074734687805


Epoch 4:  49%|████▉     | 26/53 [00:50<00:52,  1.94s/it]

Loss: 0.41452187299728394


Epoch 4:  51%|█████     | 27/53 [00:52<00:50,  1.95s/it]

Loss: 0.15512359142303467


Epoch 4:  53%|█████▎    | 28/53 [00:54<00:48,  1.95s/it]

Loss: 0.96002197265625


Epoch 4:  55%|█████▍    | 29/53 [00:56<00:46,  1.95s/it]

Loss: 0.5521188974380493


Epoch 4:  57%|█████▋    | 30/53 [00:58<00:44,  1.95s/it]

Loss: 0.8649682998657227


Epoch 4:  58%|█████▊    | 31/53 [01:00<00:42,  1.95s/it]

Loss: 0.34169334173202515


Epoch 4:  60%|██████    | 32/53 [01:02<00:40,  1.95s/it]

Loss: 0.721293568611145


Epoch 4:  62%|██████▏   | 33/53 [01:04<00:38,  1.95s/it]

Loss: 0.3737154006958008


Epoch 4:  64%|██████▍   | 34/53 [01:06<00:36,  1.95s/it]

Loss: 0.8386316895484924


Epoch 4:  66%|██████▌   | 35/53 [01:08<00:35,  1.95s/it]

Loss: 0.34698837995529175


Epoch 4:  68%|██████▊   | 36/53 [01:09<00:33,  1.95s/it]

Loss: 0.2253051996231079


Epoch 4:  70%|██████▉   | 37/53 [01:11<00:31,  1.95s/it]

Loss: 1.5706367492675781


Epoch 4:  72%|███████▏  | 38/53 [01:13<00:29,  1.95s/it]

Loss: 1.3592311143875122


Epoch 4:  74%|███████▎  | 39/53 [01:15<00:27,  1.95s/it]

Loss: 0.8339100480079651


Epoch 4:  75%|███████▌  | 40/53 [01:17<00:25,  1.94s/it]

Loss: 0.32960203289985657


Epoch 4:  77%|███████▋  | 41/53 [01:19<00:23,  1.94s/it]

Loss: 1.5872313976287842


Epoch 4:  79%|███████▉  | 42/53 [01:21<00:21,  1.94s/it]

Loss: 1.2975870370864868


Epoch 4:  81%|████████  | 43/53 [01:23<00:19,  1.94s/it]

Loss: 0.4613449275493622


Epoch 4:  83%|████████▎ | 44/53 [01:25<00:17,  1.94s/it]

Loss: 1.724965214729309


Epoch 4:  85%|████████▍ | 45/53 [01:27<00:15,  1.94s/it]

Loss: 1.8903868198394775


Epoch 4:  87%|████████▋ | 46/53 [01:29<00:13,  1.94s/it]

Loss: 2.9957475662231445


Epoch 4:  89%|████████▊ | 47/53 [01:31<00:11,  1.94s/it]

Loss: 2.445826292037964


Epoch 4:  91%|█████████ | 48/53 [01:33<00:09,  1.94s/it]

Loss: 3.1676888465881348


Epoch 4:  92%|█████████▏| 49/53 [01:35<00:07,  1.94s/it]

Loss: 2.3206429481506348


Epoch 4:  94%|█████████▍| 50/53 [01:37<00:05,  1.94s/it]

Loss: 0.5899895429611206


Epoch 4:  96%|█████████▌| 51/53 [01:39<00:03,  1.95s/it]

Loss: 1.6216622591018677


Epoch 4:  98%|█████████▊| 52/53 [01:41<00:01,  1.95s/it]

Loss: 2.613618850708008


Epoch 4: 100%|██████████| 53/53 [01:42<00:00,  1.93s/it]

Loss: 0.6625992059707642





Epoch 4 Validation Accuracy: 0.8472906403940886, F1-macro: 0.45866666666666667


Epoch 5:   2%|▏         | 1/53 [00:01<01:41,  1.95s/it]

Loss: 0.9272546172142029


Epoch 5:   4%|▍         | 2/53 [00:03<01:39,  1.95s/it]

Loss: 1.187955379486084


Epoch 5:   6%|▌         | 3/53 [00:05<01:37,  1.94s/it]

Loss: 1.172562837600708


Epoch 5:   8%|▊         | 4/53 [00:07<01:35,  1.95s/it]

Loss: 2.4508461952209473


Epoch 5:   9%|▉         | 5/53 [00:09<01:33,  1.95s/it]

Loss: 1.1145477294921875


Epoch 5:  11%|█▏        | 6/53 [00:11<01:31,  1.95s/it]

Loss: 1.9407408237457275


Epoch 5:  13%|█▎        | 7/53 [00:13<01:29,  1.95s/it]

Loss: 1.6778247356414795


Epoch 5:  15%|█▌        | 8/53 [00:15<01:27,  1.95s/it]

Loss: 1.2870224714279175


Epoch 5:  17%|█▋        | 9/53 [00:17<01:25,  1.95s/it]

Loss: 1.8072879314422607


Epoch 5:  19%|█▉        | 10/53 [00:19<01:23,  1.95s/it]

Loss: 0.4248250722885132


Epoch 5:  21%|██        | 11/53 [00:21<01:21,  1.95s/it]

Loss: 1.0559382438659668


Epoch 5:  23%|██▎       | 12/53 [00:23<01:19,  1.95s/it]

Loss: 2.3175625801086426


Epoch 5:  25%|██▍       | 13/53 [00:25<01:17,  1.95s/it]

Loss: 0.7747032642364502


Epoch 5:  26%|██▋       | 14/53 [00:27<01:16,  1.95s/it]

Loss: 0.49490851163864136


Epoch 5:  28%|██▊       | 15/53 [00:29<01:14,  1.95s/it]

Loss: 1.8117384910583496


Epoch 5:  30%|███       | 16/53 [00:31<01:12,  1.95s/it]

Loss: 3.1901895999908447


Epoch 5:  32%|███▏      | 17/53 [00:33<01:10,  1.95s/it]

Loss: 1.8582763671875


Epoch 5:  34%|███▍      | 18/53 [00:35<01:08,  1.95s/it]

Loss: 0.4557574987411499


Epoch 5:  36%|███▌      | 19/53 [00:36<01:06,  1.95s/it]

Loss: 0.5387909412384033


Epoch 5:  38%|███▊      | 20/53 [00:38<01:04,  1.95s/it]

Loss: 1.2089462280273438


Epoch 5:  40%|███▉      | 21/53 [00:40<01:02,  1.95s/it]

Loss: 2.0024967193603516


Epoch 5:  42%|████▏     | 22/53 [00:42<01:00,  1.95s/it]

Loss: 0.7280158996582031


Epoch 5:  43%|████▎     | 23/53 [00:44<00:58,  1.94s/it]

Loss: 1.1326113939285278


Epoch 5:  45%|████▌     | 24/53 [00:46<00:56,  1.94s/it]

Loss: 0.4884161651134491


Epoch 5:  47%|████▋     | 25/53 [00:48<00:54,  1.94s/it]

Loss: 0.8874096274375916


Epoch 5:  49%|████▉     | 26/53 [00:50<00:52,  1.94s/it]

Loss: 0.32849007844924927


Epoch 5:  51%|█████     | 27/53 [00:52<00:50,  1.94s/it]

Loss: 0.4484093487262726


Epoch 5:  53%|█████▎    | 28/53 [00:54<00:48,  1.95s/it]

Loss: 0.345869779586792


Epoch 5:  55%|█████▍    | 29/53 [00:56<00:46,  1.94s/it]

Loss: 0.6080652475357056


Epoch 5:  57%|█████▋    | 30/53 [00:58<00:44,  1.94s/it]

Loss: 0.692500114440918


Epoch 5:  58%|█████▊    | 31/53 [01:00<00:42,  1.94s/it]

Loss: 0.4866640567779541


Epoch 5:  60%|██████    | 32/53 [01:02<00:40,  1.94s/it]

Loss: 0.4072389304637909


Epoch 5:  62%|██████▏   | 33/53 [01:04<00:38,  1.94s/it]

Loss: 0.3822396397590637


Epoch 5:  64%|██████▍   | 34/53 [01:06<00:36,  1.94s/it]

Loss: 0.5406292676925659


Epoch 5:  66%|██████▌   | 35/53 [01:08<00:34,  1.94s/it]

Loss: 0.2914027273654938


Epoch 5:  68%|██████▊   | 36/53 [01:10<00:33,  1.94s/it]

Loss: 0.5752647519111633


Epoch 5:  70%|██████▉   | 37/53 [01:11<00:31,  1.94s/it]

Loss: 0.6365193128585815


Epoch 5:  72%|███████▏  | 38/53 [01:13<00:29,  1.94s/it]

Loss: 0.6787089705467224


Epoch 5:  74%|███████▎  | 39/53 [01:15<00:27,  1.94s/it]

Loss: 0.414513498544693


Epoch 5:  75%|███████▌  | 40/53 [01:17<00:25,  1.95s/it]

Loss: 0.4363047778606415


Epoch 5:  77%|███████▋  | 41/53 [01:19<00:23,  1.95s/it]

Loss: 0.8004828691482544


Epoch 5:  79%|███████▉  | 42/53 [01:21<00:21,  1.95s/it]

Loss: 0.7206594347953796


Epoch 5:  81%|████████  | 43/53 [01:23<00:19,  1.95s/it]

Loss: 0.4522044062614441


Epoch 5:  83%|████████▎ | 44/53 [01:25<00:17,  1.95s/it]

Loss: 0.1499912440776825


Epoch 5:  85%|████████▍ | 45/53 [01:27<00:15,  1.95s/it]

Loss: 0.9455636739730835


Epoch 5:  87%|████████▋ | 46/53 [01:29<00:13,  1.95s/it]

Loss: 0.4441559314727783


Epoch 5:  89%|████████▊ | 47/53 [01:31<00:11,  1.95s/it]

Loss: 0.48853904008865356


Epoch 5:  91%|█████████ | 48/53 [01:33<00:09,  1.95s/it]

Loss: 0.6960932016372681


Epoch 5:  92%|█████████▏| 49/53 [01:35<00:07,  1.95s/it]

Loss: 0.8125278949737549


Epoch 5:  94%|█████████▍| 50/53 [01:37<00:05,  1.95s/it]

Loss: 0.1267203837633133


Epoch 5:  96%|█████████▌| 51/53 [01:39<00:03,  1.94s/it]

Loss: 0.41377079486846924


Epoch 5:  98%|█████████▊| 52/53 [01:41<00:01,  1.95s/it]

Loss: 0.5375315546989441


Epoch 5: 100%|██████████| 53/53 [01:42<00:00,  1.94s/it]

Loss: 0.4041426181793213





Epoch 5 Validation Accuracy: 0.8423645320197044, F1-macro: 0.635056179775281


Epoch 6:   2%|▏         | 1/53 [00:01<01:40,  1.93s/it]

Loss: 0.2450515329837799


Epoch 6:   4%|▍         | 2/53 [00:03<01:38,  1.94s/it]

Loss: 0.26483771204948425


Epoch 6:   6%|▌         | 3/53 [00:05<01:36,  1.94s/it]

Loss: 0.6091230511665344


Epoch 6:   8%|▊         | 4/53 [00:07<01:35,  1.94s/it]

Loss: 0.3842184841632843


Epoch 6:   9%|▉         | 5/53 [00:09<01:33,  1.94s/it]

Loss: 0.08161388337612152


Epoch 6:  11%|█▏        | 6/53 [00:11<01:31,  1.94s/it]

Loss: 0.5465922951698303


Epoch 6:  13%|█▎        | 7/53 [00:13<01:29,  1.94s/it]

Loss: 0.44401654601097107


Epoch 6:  15%|█▌        | 8/53 [00:15<01:27,  1.94s/it]

Loss: 0.6898338198661804


Epoch 6:  17%|█▋        | 9/53 [00:17<01:25,  1.94s/it]

Loss: 0.26547810435295105


Epoch 6:  19%|█▉        | 10/53 [00:19<01:23,  1.94s/it]

Loss: 0.4372798204421997


Epoch 6:  21%|██        | 11/53 [00:21<01:21,  1.94s/it]

Loss: 1.027515172958374


Epoch 6:  23%|██▎       | 12/53 [00:23<01:19,  1.94s/it]

Loss: 0.8471168279647827


Epoch 6:  25%|██▍       | 13/53 [00:25<01:17,  1.94s/it]

Loss: 0.39819931983947754


Epoch 6:  26%|██▋       | 14/53 [00:27<01:15,  1.94s/it]

Loss: 0.6799513101577759


Epoch 6:  28%|██▊       | 15/53 [00:29<01:13,  1.94s/it]

Loss: 0.37424129247665405


Epoch 6:  30%|███       | 16/53 [00:31<01:11,  1.94s/it]

Loss: 0.5968586206436157


Epoch 6:  32%|███▏      | 17/53 [00:33<01:10,  1.94s/it]

Loss: 0.47189587354660034


Epoch 6:  34%|███▍      | 18/53 [00:34<01:08,  1.94s/it]

Loss: 0.2637573182582855


Epoch 6:  36%|███▌      | 19/53 [00:36<01:06,  1.94s/it]

Loss: 0.5059250593185425


Epoch 6:  38%|███▊      | 20/53 [00:38<01:04,  1.95s/it]

Loss: 0.4040524959564209


Epoch 6:  40%|███▉      | 21/53 [00:40<01:02,  1.95s/it]

Loss: 0.15485559403896332


Epoch 6:  42%|████▏     | 22/53 [00:42<01:00,  1.95s/it]

Loss: 0.2088593989610672


Epoch 6:  43%|████▎     | 23/53 [00:44<00:58,  1.95s/it]

Loss: 0.32007884979248047


Epoch 6:  45%|████▌     | 24/53 [00:46<00:56,  1.95s/it]

Loss: 0.3000057339668274


Epoch 6:  47%|████▋     | 25/53 [00:48<00:54,  1.95s/it]

Loss: 0.38419193029403687


Epoch 6:  49%|████▉     | 26/53 [00:50<00:52,  1.95s/it]

Loss: 0.38985562324523926


Epoch 6:  51%|█████     | 27/53 [00:52<00:50,  1.95s/it]

Loss: 0.2051706165075302


Epoch 6:  53%|█████▎    | 28/53 [00:54<00:48,  1.95s/it]

Loss: 0.3206445574760437


Epoch 6:  55%|█████▍    | 29/53 [00:56<00:46,  1.95s/it]

Loss: 0.09703919291496277


Epoch 6:  57%|█████▋    | 30/53 [00:58<00:44,  1.95s/it]

Loss: 0.5159028768539429


Epoch 6:  58%|█████▊    | 31/53 [01:00<00:42,  1.94s/it]

Loss: 0.7345021963119507


Epoch 6:  60%|██████    | 32/53 [01:02<00:40,  1.94s/it]

Loss: 0.4984329640865326


Epoch 6:  62%|██████▏   | 33/53 [01:04<00:38,  1.94s/it]

Loss: 0.7382375597953796


Epoch 6:  64%|██████▍   | 34/53 [01:06<00:36,  1.94s/it]

Loss: 0.34926897287368774


Epoch 6:  66%|██████▌   | 35/53 [01:08<00:34,  1.94s/it]

Loss: 0.8904365301132202


Epoch 6:  68%|██████▊   | 36/53 [01:10<00:33,  1.94s/it]

Loss: 0.6594096422195435


Epoch 6:  70%|██████▉   | 37/53 [01:11<00:31,  1.94s/it]

Loss: 0.30460917949676514


Epoch 6:  72%|███████▏  | 38/53 [01:13<00:29,  1.94s/it]

Loss: 0.4467218220233917


Epoch 6:  74%|███████▎  | 39/53 [01:15<00:27,  1.94s/it]

Loss: 0.527725338935852


Epoch 6:  75%|███████▌  | 40/53 [01:17<00:25,  1.94s/it]

Loss: 1.0470927953720093


Epoch 6:  77%|███████▋  | 41/53 [01:19<00:23,  1.94s/it]

Loss: 0.3011932373046875


Epoch 6:  79%|███████▉  | 42/53 [01:21<00:21,  1.94s/it]

Loss: 0.19365572929382324


Epoch 6:  81%|████████  | 43/53 [01:23<00:19,  1.94s/it]

Loss: 1.1234099864959717


Epoch 6:  83%|████████▎ | 44/53 [01:25<00:17,  1.94s/it]

Loss: 0.35224390029907227


Epoch 6:  85%|████████▍ | 45/53 [01:27<00:15,  1.94s/it]

Loss: 0.40817874670028687


Epoch 6:  87%|████████▋ | 46/53 [01:29<00:13,  1.95s/it]

Loss: 0.7440739274024963


Epoch 6:  89%|████████▊ | 47/53 [01:31<00:11,  1.94s/it]

Loss: 0.5619679093360901


Epoch 6:  91%|█████████ | 48/53 [01:33<00:09,  1.94s/it]

Loss: 0.4423750936985016


Epoch 6:  92%|█████████▏| 49/53 [01:35<00:07,  1.95s/it]

Loss: 0.7337962985038757


Epoch 6:  94%|█████████▍| 50/53 [01:37<00:05,  1.95s/it]

Loss: 0.5585544109344482


Epoch 6:  96%|█████████▌| 51/53 [01:39<00:03,  1.95s/it]

Loss: 0.4649938941001892


Epoch 6:  98%|█████████▊| 52/53 [01:41<00:01,  1.95s/it]

Loss: 0.11467303335666656


Epoch 6: 100%|██████████| 53/53 [01:42<00:00,  1.94s/it]

Loss: 0.6723465919494629





Epoch 6 Validation Accuracy: 0.7881773399014779, F1-macro: 0.7001271084544298


Epoch 7:   2%|▏         | 1/53 [00:01<01:41,  1.94s/it]

Loss: 0.25503087043762207


Epoch 7:   4%|▍         | 2/53 [00:03<01:39,  1.95s/it]

Loss: 0.44640225172042847


Epoch 7:   6%|▌         | 3/53 [00:05<01:37,  1.95s/it]

Loss: 0.13394978642463684


Epoch 7:   8%|▊         | 4/53 [00:07<01:35,  1.95s/it]

Loss: 0.0715661272406578


Epoch 7:   9%|▉         | 5/53 [00:09<01:33,  1.95s/it]

Loss: 0.40368586778640747


Epoch 7:  11%|█▏        | 6/53 [00:11<01:31,  1.95s/it]

Loss: 0.46295303106307983


Epoch 7:  13%|█▎        | 7/53 [00:13<01:29,  1.95s/it]

Loss: 0.38855570554733276


Epoch 7:  15%|█▌        | 8/53 [00:15<01:27,  1.95s/it]

Loss: 0.23217517137527466


Epoch 7:  17%|█▋        | 9/53 [00:17<01:25,  1.95s/it]

Loss: 1.1697903871536255


Epoch 7:  19%|█▉        | 10/53 [00:19<01:23,  1.95s/it]

Loss: 0.6137369871139526


Epoch 7:  21%|██        | 11/53 [00:21<01:21,  1.95s/it]

Loss: 1.7441233396530151


Epoch 7:  23%|██▎       | 12/53 [00:23<01:19,  1.95s/it]

Loss: 0.7355106472969055


Epoch 7:  25%|██▍       | 13/53 [00:25<01:18,  1.95s/it]

Loss: 0.7127772569656372


Epoch 7:  26%|██▋       | 14/53 [00:27<01:16,  1.95s/it]

Loss: 0.48117488622665405


Epoch 7:  28%|██▊       | 15/53 [00:29<01:14,  1.95s/it]

Loss: 0.9196212291717529


Epoch 7:  30%|███       | 16/53 [00:31<01:12,  1.95s/it]

Loss: 0.3533572852611542


Epoch 7:  32%|███▏      | 17/53 [00:33<01:10,  1.95s/it]

Loss: 0.7575163841247559


Epoch 7:  34%|███▍      | 18/53 [00:35<01:08,  1.96s/it]

Loss: 0.48542487621307373


Epoch 7:  36%|███▌      | 19/53 [00:37<01:06,  1.95s/it]

Loss: 0.5572633147239685


Epoch 7:  38%|███▊      | 20/53 [00:39<01:04,  1.95s/it]

Loss: 0.18473118543624878


Epoch 7:  40%|███▉      | 21/53 [00:40<01:02,  1.95s/it]

Loss: 0.730634868144989


Epoch 7:  42%|████▏     | 22/53 [00:42<01:00,  1.95s/it]

Loss: 0.695181131362915


Epoch 7:  43%|████▎     | 23/53 [00:44<00:58,  1.95s/it]

Loss: 0.5501725077629089


Epoch 7:  45%|████▌     | 24/53 [00:46<00:56,  1.95s/it]

Loss: 0.6796125173568726


Epoch 7:  47%|████▋     | 25/53 [00:48<00:54,  1.95s/it]

Loss: 0.6149588227272034


Epoch 7:  49%|████▉     | 26/53 [00:50<00:52,  1.95s/it]

Loss: 0.5645414590835571


Epoch 7:  51%|█████     | 27/53 [00:52<00:50,  1.95s/it]

Loss: 0.3831474184989929


Epoch 7:  53%|█████▎    | 28/53 [00:54<00:48,  1.95s/it]

Loss: 0.4175863265991211


Epoch 7:  55%|█████▍    | 29/53 [00:56<00:46,  1.95s/it]

Loss: 0.617238461971283


Epoch 7:  57%|█████▋    | 30/53 [00:58<00:44,  1.95s/it]

Loss: 0.31580325961112976


Epoch 7:  58%|█████▊    | 31/53 [01:00<00:43,  1.95s/it]

Loss: 0.7139705419540405


Epoch 7:  60%|██████    | 32/53 [01:02<00:41,  1.95s/it]

Loss: 0.21330440044403076


Epoch 7:  62%|██████▏   | 33/53 [01:04<00:39,  1.95s/it]

Loss: 0.3722612261772156


Epoch 7:  64%|██████▍   | 34/53 [01:06<00:37,  1.95s/it]

Loss: 0.5431424379348755


Epoch 7:  66%|██████▌   | 35/53 [01:08<00:35,  1.95s/it]

Loss: 0.6966017484664917


Epoch 7:  68%|██████▊   | 36/53 [01:10<00:33,  1.95s/it]

Loss: 0.3749393820762634


Epoch 7:  70%|██████▉   | 37/53 [01:12<00:31,  1.95s/it]

Loss: 0.4072568416595459


Epoch 7:  72%|███████▏  | 38/53 [01:14<00:29,  1.95s/it]

Loss: 0.5551541447639465


Epoch 7:  74%|███████▎  | 39/53 [01:16<00:27,  1.95s/it]

Loss: 0.5181354284286499


Epoch 7:  75%|███████▌  | 40/53 [01:18<00:25,  1.95s/it]

Loss: 0.33902060985565186


Epoch 7:  77%|███████▋  | 41/53 [01:20<00:23,  1.95s/it]

Loss: 0.15001703798770905


Epoch 7:  79%|███████▉  | 42/53 [01:22<00:21,  1.95s/it]

Loss: 0.6914613842964172


Epoch 7:  81%|████████  | 43/53 [01:23<00:19,  1.96s/it]

Loss: 0.7128506898880005


Epoch 7:  83%|████████▎ | 44/53 [01:25<00:17,  1.95s/it]

Loss: 0.38370972871780396


Epoch 7:  85%|████████▍ | 45/53 [01:27<00:15,  1.96s/it]

Loss: 0.5416319370269775


Epoch 7:  87%|████████▋ | 46/53 [01:29<00:13,  1.95s/it]

Loss: 0.3899966776371002


Epoch 7:  89%|████████▊ | 47/53 [01:31<00:11,  1.96s/it]

Loss: 0.6444442868232727


Epoch 7:  91%|█████████ | 48/53 [01:33<00:09,  1.96s/it]

Loss: 0.2782423198223114


Epoch 7:  92%|█████████▏| 49/53 [01:35<00:07,  1.96s/it]

Loss: 0.2343081831932068


Epoch 7:  94%|█████████▍| 50/53 [01:37<00:05,  1.96s/it]

Loss: 0.3889617621898651


Epoch 7:  96%|█████████▌| 51/53 [01:39<00:03,  1.95s/it]

Loss: 0.3204595446586609


Epoch 7:  98%|█████████▊| 52/53 [01:41<00:01,  1.95s/it]

Loss: 0.1818055659532547


Epoch 7: 100%|██████████| 53/53 [01:42<00:00,  1.94s/it]

Loss: 0.25177276134490967





Epoch 7 Validation Accuracy: 0.8669950738916257, F1-macro: 0.5778975741239892


In [None]:
# Evaluate on the test set
test_metrics = evaluate(model9, test_dataloader, device)
print(f"Test Accuracy: {test_metrics['accuracy']}, F1-macro: {test_metrics['f1_macro']}")

Test Accuracy: 0.7941176470588235, F1-macro: 0.48546721114580826


In [None]:
current_type = 'overgeneralizing'

results_df.loc[
    results_df["distortion_type"] == current_type,
    ["test_accuracy", "f1_macro"]
] = [test_metrics['accuracy'], test_metrics['f1_macro']]

# 10. Should Statements

In [None]:
# Add labels
data1_1_labels = list(data1['should statements'][data1_1.index])
data2_1_labels = list(data2['should statements'][data2_1.index])
data3_1_labels = list(data3['should statements'][data3_1.index])

# Merging Data
data_encoded = data1_1_encoded + data2_1_encoded + data3_1_encoded
data_labels = data1_1_labels + data2_1_labels + data3_1_labels

dataset_with_labels = CustomDatasetWithLabels(data_encoded, data_labels)

# Define proportions for splitting
train_size = int(0.8 * len(dataset_with_labels))
val_size = int(0.1 * len(dataset_with_labels))
test_size = len(dataset_with_labels) - train_size - val_size

# Split the dataset
train_dataset, val_dataset, test_dataset = random_split(dataset_with_labels, [train_size, val_size, test_size])

In [None]:
dataset_encoded = dataset_with_labels.data
dataset_labels = dataset_with_labels.labels

# Convert list of BatchEncoding objects to a 2D numpy array for SMOTE features
feature_vectors = []
max_len = 512 # Assuming this from tokenize_and_pad function
for encoded_dict in dataset_encoded:
    input_ids_np = encoded_dict['input_ids'].squeeze().cpu().numpy()
    attention_mask_np = encoded_dict['attention_mask'].squeeze().cpu().numpy()
    # Concatenate input_ids and attention_mask to create a single feature vector per sample
    feature_vectors.append(np.concatenate([input_ids_np, attention_mask_np]))

# 'temp1' becomes the feature matrix, 'temp2' becomes the labels vector
temp1 = np.array(feature_vectors)
temp2 = np.array(dataset_labels)

# Apply SMOTE
temp1_resampled, temp2_resampled = smote.fit_resample(temp1, temp2)

# Reconstruct data_encoded and data_labels from the resampled data
reconstructed_data_encoded = []
for resampled_features in temp1_resampled:
    # Split the combined feature vector back into input_ids and attention_mask parts
    reconstructed_input_ids = resampled_features[:max_len]
    reconstructed_attention_mask = resampled_features[max_len:]

    # Convert back to torch tensors and a dictionary format compatible with CustomDatasetWithLabels
    input_ids_tensor = torch.tensor(reconstructed_input_ids, dtype=torch.long).unsqueeze(0)
    attention_mask_tensor = torch.tensor(reconstructed_attention_mask, dtype=torch.long).unsqueeze(0)

    reconstructed_data_encoded.append({
        'input_ids': input_ids_tensor,
        'attention_mask': attention_mask_tensor
    })

# Update data_encoded and data_labels with the resampled data
dataset_encoded = reconstructed_data_encoded
dataset_labels = temp2_resampled.tolist()

train_dataset = CustomDatasetWithLabels(dataset_encoded, dataset_labels)

In [None]:
# Create DataLoaders for each set
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=64)
test_dataloader = DataLoader(test_dataset, batch_size=64)

input_dim = bert_config.hidden_size
num_labels = len(set(data_labels))
label_emb = torch.randn(num_labels, input_dim)

In [None]:
# Instantiate the InnerProductClassifier model
model10 = InnerProductClassifier(input_dim, label_emb).to(device)

# Define loss function
criterion = nn.CrossEntropyLoss()

# Define optimizer
optimizer = torch.optim.AdamW(model10.parameters(), lr=2e-4)

In [None]:
EPOCHS = 7

# Move the BERT model to the device
bert_model.to(device)

for epoch in range(EPOCHS):
    model10.train()
    for batch in tqdm(train_dataloader, desc=f"Epoch {epoch+1}"):

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        y = batch["y"].to(device)

        # Get embeddings from the BERT model
        with torch.no_grad():
            outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask)
            embeddings = outputs.last_hidden_state[:, 0, :] # Get the [CLS] token embedding

        # Pass embeddings to the InnerProductClassifier
        logits = model10(embeddings)
        loss = criterion(logits, y)

        # Backpropagate and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Loss: {loss.item()}")

    # Evaluate on the validation set after each epoch
    val_metrics = evaluate(model10, val_dataloader, device)
    print(f"Epoch {epoch+1} Validation Accuracy: {val_metrics['accuracy']}, F1-macro: {val_metrics['f1_macro']}")

Epoch 1:   2%|▏         | 1/59 [00:01<01:52,  1.93s/it]

Loss: 15.077157974243164


Epoch 1:   3%|▎         | 2/59 [00:03<01:50,  1.94s/it]

Loss: 1.3542100191116333


Epoch 1:   5%|▌         | 3/59 [00:05<01:49,  1.95s/it]

Loss: 2.3644967079162598


Epoch 1:   7%|▋         | 4/59 [00:07<01:47,  1.95s/it]

Loss: 3.2541565895080566


Epoch 1:   8%|▊         | 5/59 [00:09<01:45,  1.95s/it]

Loss: 1.412664532661438


Epoch 1:  10%|█         | 6/59 [00:11<01:43,  1.95s/it]

Loss: 1.8309557437896729


Epoch 1:  12%|█▏        | 7/59 [00:13<01:41,  1.95s/it]

Loss: 3.178715229034424


Epoch 1:  14%|█▎        | 8/59 [00:15<01:39,  1.96s/it]

Loss: 1.5599370002746582


Epoch 1:  15%|█▌        | 9/59 [00:17<01:37,  1.95s/it]

Loss: 2.6483659744262695


Epoch 1:  17%|█▋        | 10/59 [00:19<01:35,  1.95s/it]

Loss: 1.6734009981155396


Epoch 1:  19%|█▊        | 11/59 [00:21<01:33,  1.95s/it]

Loss: 0.767415463924408


Epoch 1:  20%|██        | 12/59 [00:23<01:31,  1.96s/it]

Loss: 0.6960598826408386


Epoch 1:  22%|██▏       | 13/59 [00:25<01:29,  1.95s/it]

Loss: 1.4053858518600464


Epoch 1:  24%|██▎       | 14/59 [00:27<01:27,  1.95s/it]

Loss: 2.4309189319610596


Epoch 1:  25%|██▌       | 15/59 [00:29<01:26,  1.96s/it]

Loss: 0.7555765509605408


Epoch 1:  27%|██▋       | 16/59 [00:31<01:23,  1.95s/it]

Loss: 1.2987923622131348


Epoch 1:  29%|██▉       | 17/59 [00:33<01:22,  1.95s/it]

Loss: 0.46113479137420654


Epoch 1:  31%|███       | 18/59 [00:35<01:20,  1.95s/it]

Loss: 1.0162138938903809


Epoch 1:  32%|███▏      | 19/59 [00:37<01:18,  1.95s/it]

Loss: 1.197083830833435


Epoch 1:  34%|███▍      | 20/59 [00:39<01:16,  1.95s/it]

Loss: 1.0344048738479614


Epoch 1:  36%|███▌      | 21/59 [00:41<01:14,  1.95s/it]

Loss: 1.4857330322265625


Epoch 1:  37%|███▋      | 22/59 [00:42<01:12,  1.95s/it]

Loss: 1.2526130676269531


Epoch 1:  39%|███▉      | 23/59 [00:44<01:10,  1.95s/it]

Loss: 2.1485180854797363


Epoch 1:  41%|████      | 24/59 [00:46<01:08,  1.95s/it]

Loss: 1.3944158554077148


Epoch 1:  42%|████▏     | 25/59 [00:48<01:06,  1.95s/it]

Loss: 0.4608602225780487


Epoch 1:  44%|████▍     | 26/59 [00:50<01:04,  1.95s/it]

Loss: 0.9398626089096069


Epoch 1:  46%|████▌     | 27/59 [00:52<01:02,  1.95s/it]

Loss: 0.5671420097351074


Epoch 1:  47%|████▋     | 28/59 [00:54<01:00,  1.96s/it]

Loss: 0.47753438353538513


Epoch 1:  49%|████▉     | 29/59 [00:56<00:58,  1.95s/it]

Loss: 1.206534504890442


Epoch 1:  51%|█████     | 30/59 [00:58<00:56,  1.96s/it]

Loss: 0.9203774929046631


Epoch 1:  53%|█████▎    | 31/59 [01:00<00:54,  1.95s/it]

Loss: 1.1911916732788086


Epoch 1:  54%|█████▍    | 32/59 [01:02<00:52,  1.96s/it]

Loss: 0.46041715145111084


Epoch 1:  56%|█████▌    | 33/59 [01:04<00:50,  1.96s/it]

Loss: 0.20826727151870728


Epoch 1:  58%|█████▊    | 34/59 [01:06<00:48,  1.95s/it]

Loss: 1.201711654663086


Epoch 1:  59%|█████▉    | 35/59 [01:08<00:46,  1.95s/it]

Loss: 1.1847867965698242


Epoch 1:  61%|██████    | 36/59 [01:10<00:44,  1.95s/it]

Loss: 0.997319757938385


Epoch 1:  63%|██████▎   | 37/59 [01:12<00:42,  1.95s/it]

Loss: 0.5807517766952515


Epoch 1:  64%|██████▍   | 38/59 [01:14<00:41,  1.96s/it]

Loss: 0.3121638298034668


Epoch 1:  66%|██████▌   | 39/59 [01:16<00:39,  1.96s/it]

Loss: 1.096268653869629


Epoch 1:  68%|██████▊   | 40/59 [01:18<00:37,  1.96s/it]

Loss: 0.8384091854095459


Epoch 1:  69%|██████▉   | 41/59 [01:20<00:35,  1.96s/it]

Loss: 0.46676844358444214


Epoch 1:  71%|███████   | 42/59 [01:22<00:33,  1.95s/it]

Loss: 0.5188405513763428


Epoch 1:  73%|███████▎  | 43/59 [01:24<00:31,  1.96s/it]

Loss: 0.14515244960784912


Epoch 1:  75%|███████▍  | 44/59 [01:25<00:29,  1.96s/it]

Loss: 0.23417668044567108


Epoch 1:  76%|███████▋  | 45/59 [01:27<00:27,  1.96s/it]

Loss: 0.2906077802181244


Epoch 1:  78%|███████▊  | 46/59 [01:29<00:25,  1.95s/it]

Loss: 0.45033735036849976


Epoch 1:  80%|███████▉  | 47/59 [01:31<00:23,  1.96s/it]

Loss: 1.2083275318145752


Epoch 1:  81%|████████▏ | 48/59 [01:33<00:21,  1.96s/it]

Loss: 0.4345712661743164


Epoch 1:  83%|████████▎ | 49/59 [01:35<00:19,  1.96s/it]

Loss: 1.3678178787231445


Epoch 1:  85%|████████▍ | 50/59 [01:37<00:17,  1.96s/it]

Loss: 0.3841387629508972


Epoch 1:  86%|████████▋ | 51/59 [01:39<00:15,  1.96s/it]

Loss: 0.6337598562240601


Epoch 1:  88%|████████▊ | 52/59 [01:41<00:13,  1.96s/it]

Loss: 0.5405066609382629


Epoch 1:  90%|████████▉ | 53/59 [01:43<00:11,  1.96s/it]

Loss: 0.3995380401611328


Epoch 1:  92%|█████████▏| 54/59 [01:45<00:09,  1.96s/it]

Loss: 0.7293453812599182


Epoch 1:  93%|█████████▎| 55/59 [01:47<00:07,  1.95s/it]

Loss: 0.33452725410461426


Epoch 1:  95%|█████████▍| 56/59 [01:49<00:05,  1.95s/it]

Loss: 0.20076335966587067


Epoch 1:  97%|█████████▋| 57/59 [01:51<00:03,  1.96s/it]

Loss: 0.16224700212478638


Epoch 1:  98%|█████████▊| 58/59 [01:53<00:01,  1.96s/it]

Loss: 0.09756917506456375


Epoch 1: 100%|██████████| 59/59 [01:53<00:00,  1.93s/it]

Loss: 0.006539992988109589





Epoch 1 Validation Accuracy: 0.916256157635468, F1-macro: 0.5306677546579628


Epoch 2:   2%|▏         | 1/59 [00:01<01:53,  1.95s/it]

Loss: 0.3105177879333496


Epoch 2:   3%|▎         | 2/59 [00:03<01:51,  1.95s/it]

Loss: 0.44659483432769775


Epoch 2:   5%|▌         | 3/59 [00:05<01:49,  1.96s/it]

Loss: 0.14019393920898438


Epoch 2:   7%|▋         | 4/59 [00:07<01:47,  1.96s/it]

Loss: 0.4371439218521118


Epoch 2:   8%|▊         | 5/59 [00:09<01:45,  1.96s/it]

Loss: 0.4597073197364807


Epoch 2:  10%|█         | 6/59 [00:11<01:43,  1.96s/it]

Loss: 0.7783339023590088


Epoch 2:  12%|█▏        | 7/59 [00:13<01:41,  1.96s/it]

Loss: 0.20828378200531006


Epoch 2:  14%|█▎        | 8/59 [00:15<01:39,  1.96s/it]

Loss: 0.2433129847049713


Epoch 2:  15%|█▌        | 9/59 [00:17<01:37,  1.96s/it]

Loss: 0.24065491557121277


Epoch 2:  17%|█▋        | 10/59 [00:19<01:35,  1.96s/it]

Loss: 0.442287415266037


Epoch 2:  19%|█▊        | 11/59 [00:21<01:33,  1.96s/it]

Loss: 0.17638537287712097


Epoch 2:  20%|██        | 12/59 [00:23<01:32,  1.96s/it]

Loss: 0.45479047298431396


Epoch 2:  22%|██▏       | 13/59 [00:25<01:30,  1.96s/it]

Loss: 0.15938839316368103


Epoch 2:  24%|██▎       | 14/59 [00:27<01:28,  1.96s/it]

Loss: 0.14735081791877747


Epoch 2:  25%|██▌       | 15/59 [00:29<01:26,  1.96s/it]

Loss: 0.15663085877895355


Epoch 2:  27%|██▋       | 16/59 [00:31<01:24,  1.96s/it]

Loss: 1.0573102235794067


Epoch 2:  29%|██▉       | 17/59 [00:33<01:22,  1.96s/it]

Loss: 0.23315836489200592


Epoch 2:  31%|███       | 18/59 [00:35<01:20,  1.96s/it]

Loss: 0.4109156131744385


Epoch 2:  32%|███▏      | 19/59 [00:37<01:18,  1.96s/it]

Loss: 0.05282101035118103


Epoch 2:  34%|███▍      | 20/59 [00:39<01:16,  1.96s/it]

Loss: 0.18894758820533752


Epoch 2:  36%|███▌      | 21/59 [00:41<01:14,  1.95s/it]

Loss: 0.42032742500305176


Epoch 2:  37%|███▋      | 22/59 [00:43<01:12,  1.95s/it]

Loss: 0.23858851194381714


Epoch 2:  39%|███▉      | 23/59 [00:44<01:10,  1.95s/it]

Loss: 0.1672155261039734


Epoch 2:  41%|████      | 24/59 [00:46<01:08,  1.95s/it]

Loss: 0.45345795154571533


Epoch 2:  42%|████▏     | 25/59 [00:48<01:06,  1.95s/it]

Loss: 0.9508146047592163


Epoch 2:  44%|████▍     | 26/59 [00:50<01:04,  1.95s/it]

Loss: 0.20856720209121704


Epoch 2:  46%|████▌     | 27/59 [00:52<01:02,  1.95s/it]

Loss: 0.5703271627426147


Epoch 2:  47%|████▋     | 28/59 [00:54<01:00,  1.94s/it]

Loss: 0.07110921293497086


Epoch 2:  49%|████▉     | 29/59 [00:56<00:58,  1.94s/it]

Loss: 0.4516144394874573


Epoch 2:  51%|█████     | 30/59 [00:58<00:56,  1.94s/it]

Loss: 0.5342456102371216


Epoch 2:  53%|█████▎    | 31/59 [01:00<00:54,  1.94s/it]

Loss: 0.2845393419265747


Epoch 2:  54%|█████▍    | 32/59 [01:02<00:52,  1.94s/it]

Loss: 0.7146341800689697


Epoch 2:  56%|█████▌    | 33/59 [01:04<00:50,  1.94s/it]

Loss: 0.487490713596344


Epoch 2:  58%|█████▊    | 34/59 [01:06<00:48,  1.94s/it]

Loss: 0.8844263553619385


Epoch 2:  59%|█████▉    | 35/59 [01:08<00:46,  1.94s/it]

Loss: 2.7226647944189608e-05


Epoch 2:  61%|██████    | 36/59 [01:10<00:44,  1.94s/it]

Loss: 0.28155550360679626


Epoch 2:  63%|██████▎   | 37/59 [01:12<00:42,  1.94s/it]

Loss: 0.9467551708221436


Epoch 2:  64%|██████▍   | 38/59 [01:14<00:40,  1.94s/it]

Loss: 0.3175065517425537


Epoch 2:  66%|██████▌   | 39/59 [01:16<00:38,  1.94s/it]

Loss: 1.336836576461792


Epoch 2:  68%|██████▊   | 40/59 [01:18<00:36,  1.94s/it]

Loss: 0.08952106535434723


Epoch 2:  69%|██████▉   | 41/59 [01:19<00:34,  1.94s/it]

Loss: 0.24584369361400604


Epoch 2:  71%|███████   | 42/59 [01:21<00:32,  1.94s/it]

Loss: 1.3270174264907837


Epoch 2:  73%|███████▎  | 43/59 [01:23<00:30,  1.94s/it]

Loss: 0.8185262680053711


Epoch 2:  75%|███████▍  | 44/59 [01:25<00:29,  1.94s/it]

Loss: 0.6579375267028809


Epoch 2:  76%|███████▋  | 45/59 [01:27<00:27,  1.94s/it]

Loss: 1.8394192457199097


Epoch 2:  78%|███████▊  | 46/59 [01:29<00:25,  1.94s/it]

Loss: 1.086275577545166


Epoch 2:  80%|███████▉  | 47/59 [01:31<00:23,  1.94s/it]

Loss: 0.3273686468601227


Epoch 2:  81%|████████▏ | 48/59 [01:33<00:21,  1.94s/it]

Loss: 1.1371660232543945


Epoch 2:  83%|████████▎ | 49/59 [01:35<00:19,  1.94s/it]

Loss: 0.3910086154937744


Epoch 2:  85%|████████▍ | 50/59 [01:37<00:17,  1.94s/it]

Loss: 0.5456113815307617


Epoch 2:  86%|████████▋ | 51/59 [01:39<00:15,  1.94s/it]

Loss: 0.27840256690979004


Epoch 2:  88%|████████▊ | 52/59 [01:41<00:13,  1.94s/it]

Loss: 0.00033296545734629035


Epoch 2:  90%|████████▉ | 53/59 [01:43<00:11,  1.94s/it]

Loss: 1.3869130611419678


Epoch 2:  92%|█████████▏| 54/59 [01:45<00:09,  1.94s/it]

Loss: 0.6319311857223511


Epoch 2:  93%|█████████▎| 55/59 [01:47<00:07,  1.94s/it]

Loss: 1.833927869796753


Epoch 2:  95%|█████████▍| 56/59 [01:49<00:05,  1.94s/it]

Loss: 1.790597677230835


Epoch 2:  97%|█████████▋| 57/59 [01:50<00:03,  1.94s/it]

Loss: 0.34280726313591003


Epoch 2:  98%|█████████▊| 58/59 [01:52<00:01,  1.94s/it]

Loss: 0.9282210469245911


Epoch 2: 100%|██████████| 59/59 [01:53<00:00,  1.92s/it]

Loss: 1.0766315460205078





Epoch 2 Validation Accuracy: 0.9211822660098522, F1-macro: 0.4794871794871795


Epoch 3:   2%|▏         | 1/59 [00:01<01:52,  1.93s/it]

Loss: 0.3977600932121277


Epoch 3:   3%|▎         | 2/59 [00:03<01:50,  1.94s/it]

Loss: 0.964133083820343


Epoch 3:   5%|▌         | 3/59 [00:05<01:48,  1.94s/it]

Loss: 1.286025047302246


Epoch 3:   7%|▋         | 4/59 [00:07<01:46,  1.94s/it]

Loss: 0.4287242591381073


Epoch 3:   8%|▊         | 5/59 [00:09<01:44,  1.94s/it]

Loss: 0.9571217894554138


Epoch 3:  10%|█         | 6/59 [00:11<01:43,  1.94s/it]

Loss: 1.3951849937438965


Epoch 3:  12%|█▏        | 7/59 [00:13<01:41,  1.94s/it]

Loss: 0.7817893028259277


Epoch 3:  14%|█▎        | 8/59 [00:15<01:39,  1.94s/it]

Loss: 1.873643159866333


Epoch 3:  15%|█▌        | 9/59 [00:17<01:37,  1.94s/it]

Loss: 0.8389343619346619


Epoch 3:  17%|█▋        | 10/59 [00:19<01:35,  1.94s/it]

Loss: 0.5706836581230164


Epoch 3:  19%|█▊        | 11/59 [00:21<01:33,  1.94s/it]

Loss: 0.9771358966827393


Epoch 3:  20%|██        | 12/59 [00:23<01:31,  1.94s/it]

Loss: 0.5106932520866394


Epoch 3:  22%|██▏       | 13/59 [00:25<01:29,  1.94s/it]

Loss: 0.0683894157409668


Epoch 3:  24%|██▎       | 14/59 [00:27<01:27,  1.94s/it]

Loss: 0.8335192799568176


Epoch 3:  25%|██▌       | 15/59 [00:29<01:25,  1.94s/it]

Loss: 0.7002936601638794


Epoch 3:  27%|██▋       | 16/59 [00:31<01:23,  1.94s/it]

Loss: 0.27205759286880493


Epoch 3:  29%|██▉       | 17/59 [00:33<01:21,  1.94s/it]

Loss: 1.080094575881958


Epoch 3:  31%|███       | 18/59 [00:34<01:19,  1.94s/it]

Loss: 1.5104974508285522


Epoch 3:  32%|███▏      | 19/59 [00:36<01:17,  1.94s/it]

Loss: 0.32827335596084595


Epoch 3:  34%|███▍      | 20/59 [00:38<01:15,  1.94s/it]

Loss: 0.4760565161705017


Epoch 3:  36%|███▌      | 21/59 [00:40<01:13,  1.94s/it]

Loss: 0.11888635903596878


Epoch 3:  37%|███▋      | 22/59 [00:42<01:11,  1.94s/it]

Loss: 0.581527829170227


Epoch 3:  39%|███▉      | 23/59 [00:44<01:09,  1.94s/it]

Loss: 0.15967577695846558


Epoch 3:  41%|████      | 24/59 [00:46<01:08,  1.95s/it]

Loss: 0.5214354991912842


Epoch 3:  42%|████▏     | 25/59 [00:48<01:06,  1.95s/it]

Loss: 0.359253853559494


Epoch 3:  44%|████▍     | 26/59 [00:50<01:04,  1.95s/it]

Loss: 0.43458056449890137


Epoch 3:  46%|████▌     | 27/59 [00:52<01:02,  1.95s/it]

Loss: 0.8769062161445618


Epoch 3:  47%|████▋     | 28/59 [00:54<01:00,  1.95s/it]

Loss: 1.1534457206726074


Epoch 3:  49%|████▉     | 29/59 [00:56<00:58,  1.95s/it]

Loss: 0.48993611335754395


Epoch 3:  51%|█████     | 30/59 [00:58<00:56,  1.95s/it]

Loss: 0.4563257396221161


Epoch 3:  53%|█████▎    | 31/59 [01:00<00:54,  1.95s/it]

Loss: 0.36862024664878845


Epoch 3:  54%|█████▍    | 32/59 [01:02<00:52,  1.95s/it]

Loss: 0.00887826643884182


Epoch 3:  56%|█████▌    | 33/59 [01:04<00:50,  1.95s/it]

Loss: 0.10862401872873306


Epoch 3:  58%|█████▊    | 34/59 [01:06<00:48,  1.94s/it]

Loss: 0.25623971223831177


Epoch 3:  59%|█████▉    | 35/59 [01:08<00:46,  1.95s/it]

Loss: 0.8417578339576721


Epoch 3:  61%|██████    | 36/59 [01:10<00:44,  1.95s/it]

Loss: 0.5536448359489441


Epoch 3:  63%|██████▎   | 37/59 [01:11<00:42,  1.95s/it]

Loss: 0.6106124520301819


Epoch 3:  64%|██████▍   | 38/59 [01:13<00:40,  1.95s/it]

Loss: 0.23334833979606628


Epoch 3:  66%|██████▌   | 39/59 [01:15<00:38,  1.94s/it]

Loss: 0.3646959066390991


Epoch 3:  68%|██████▊   | 40/59 [01:17<00:36,  1.94s/it]

Loss: 0.28643184900283813


Epoch 3:  69%|██████▉   | 41/59 [01:19<00:34,  1.94s/it]

Loss: 0.33535972237586975


Epoch 3:  71%|███████   | 42/59 [01:21<00:33,  1.94s/it]

Loss: 0.21762478351593018


Epoch 3:  73%|███████▎  | 43/59 [01:23<00:31,  1.94s/it]

Loss: 1.1906861066818237


Epoch 3:  75%|███████▍  | 44/59 [01:25<00:29,  1.94s/it]

Loss: 0.7033023238182068


Epoch 3:  76%|███████▋  | 45/59 [01:27<00:27,  1.95s/it]

Loss: 0.8789637684822083


Epoch 3:  78%|███████▊  | 46/59 [01:29<00:25,  1.94s/it]

Loss: 0.68091881275177


Epoch 3:  80%|███████▉  | 47/59 [01:31<00:23,  1.95s/it]

Loss: 1.3864994049072266


Epoch 3:  81%|████████▏ | 48/59 [01:33<00:21,  1.95s/it]

Loss: 0.38231706619262695


Epoch 3:  83%|████████▎ | 49/59 [01:35<00:19,  1.95s/it]

Loss: 0.6452957987785339


Epoch 3:  85%|████████▍ | 50/59 [01:37<00:17,  1.95s/it]

Loss: 1.5370872020721436


Epoch 3:  86%|████████▋ | 51/59 [01:39<00:15,  1.95s/it]

Loss: 0.41123369336128235


Epoch 3:  88%|████████▊ | 52/59 [01:41<00:13,  1.95s/it]

Loss: 0.5440179109573364


Epoch 3:  90%|████████▉ | 53/59 [01:43<00:11,  1.95s/it]

Loss: 0.5499943494796753


Epoch 3:  92%|█████████▏| 54/59 [01:45<00:09,  1.95s/it]

Loss: 0.4426204264163971


Epoch 3:  93%|█████████▎| 55/59 [01:47<00:07,  1.95s/it]

Loss: 0.43840041756629944


Epoch 3:  95%|█████████▍| 56/59 [01:48<00:05,  1.95s/it]

Loss: 0.6831470727920532


Epoch 3:  97%|█████████▋| 57/59 [01:50<00:03,  1.95s/it]

Loss: 0.42379331588745117


Epoch 3:  98%|█████████▊| 58/59 [01:52<00:01,  1.95s/it]

Loss: 0.5567232370376587


Epoch 3: 100%|██████████| 59/59 [01:53<00:00,  1.92s/it]

Loss: 6.304336420726031e-05





Epoch 3 Validation Accuracy: 0.9310344827586207, F1-macro: 0.48214285714285715


Epoch 4:   2%|▏         | 1/59 [00:01<01:53,  1.95s/it]

Loss: 0.2881499230861664


Epoch 4:   3%|▎         | 2/59 [00:03<01:51,  1.95s/it]

Loss: 0.59441077709198


Epoch 4:   5%|▌         | 3/59 [00:05<01:49,  1.95s/it]

Loss: 0.8258185386657715


Epoch 4:   7%|▋         | 4/59 [00:07<01:47,  1.95s/it]

Loss: 1.313844919204712


Epoch 4:   8%|▊         | 5/59 [00:09<01:45,  1.95s/it]

Loss: 1.698453426361084


Epoch 4:  10%|█         | 6/59 [00:11<01:43,  1.95s/it]

Loss: 1.2271130084991455


Epoch 4:  12%|█▏        | 7/59 [00:13<01:41,  1.95s/it]

Loss: 0.4150539040565491


Epoch 4:  14%|█▎        | 8/59 [00:15<01:39,  1.95s/it]

Loss: 0.455310583114624


Epoch 4:  15%|█▌        | 9/59 [00:17<01:37,  1.95s/it]

Loss: 1.0911786556243896


Epoch 4:  17%|█▋        | 10/59 [00:19<01:35,  1.94s/it]

Loss: 0.36830779910087585


Epoch 4:  19%|█▊        | 11/59 [00:21<01:33,  1.95s/it]

Loss: 0.5200117230415344


Epoch 4:  20%|██        | 12/59 [00:23<01:31,  1.95s/it]

Loss: 0.1652005910873413


Epoch 4:  22%|██▏       | 13/59 [00:25<01:29,  1.94s/it]

Loss: 1.248160719871521


Epoch 4:  24%|██▎       | 14/59 [00:27<01:27,  1.95s/it]

Loss: 0.5018278956413269


Epoch 4:  25%|██▌       | 15/59 [00:29<01:25,  1.94s/it]

Loss: 0.569265604019165


Epoch 4:  27%|██▋       | 16/59 [00:31<01:23,  1.94s/it]

Loss: 0.34564951062202454


Epoch 4:  29%|██▉       | 17/59 [00:33<01:21,  1.94s/it]

Loss: 0.4546801447868347


Epoch 4:  31%|███       | 18/59 [00:35<01:19,  1.94s/it]

Loss: 0.5166032314300537


Epoch 4:  32%|███▏      | 19/59 [00:36<01:17,  1.94s/it]

Loss: 0.3328089118003845


Epoch 4:  34%|███▍      | 20/59 [00:38<01:15,  1.94s/it]

Loss: 0.2499425709247589


Epoch 4:  36%|███▌      | 21/59 [00:40<01:13,  1.94s/it]

Loss: 0.37913909554481506


Epoch 4:  37%|███▋      | 22/59 [00:42<01:11,  1.94s/it]

Loss: 1.0764586925506592


Epoch 4:  39%|███▉      | 23/59 [00:44<01:09,  1.94s/it]

Loss: 0.2924318313598633


Epoch 4:  41%|████      | 24/59 [00:46<01:08,  1.94s/it]

Loss: 0.2904283106327057


Epoch 4:  42%|████▏     | 25/59 [00:48<01:06,  1.94s/it]

Loss: 0.5250231027603149


Epoch 4:  44%|████▍     | 26/59 [00:50<01:04,  1.94s/it]

Loss: 0.10832276940345764


Epoch 4:  46%|████▌     | 27/59 [00:52<01:02,  1.94s/it]

Loss: 0.2907918095588684


Epoch 4:  47%|████▋     | 28/59 [00:54<01:00,  1.94s/it]

Loss: 0.6540954113006592


Epoch 4:  49%|████▉     | 29/59 [00:56<00:58,  1.94s/it]

Loss: 0.11792580038309097


Epoch 4:  51%|█████     | 30/59 [00:58<00:56,  1.94s/it]

Loss: 0.10971944034099579


Epoch 4:  53%|█████▎    | 31/59 [01:00<00:54,  1.94s/it]

Loss: 0.2461806833744049


Epoch 4:  54%|█████▍    | 32/59 [01:02<00:52,  1.95s/it]

Loss: 0.4078877866268158


Epoch 4:  56%|█████▌    | 33/59 [01:04<00:50,  1.95s/it]

Loss: 0.19264420866966248


Epoch 4:  58%|█████▊    | 34/59 [01:06<00:48,  1.95s/it]

Loss: 0.32569897174835205


Epoch 4:  59%|█████▉    | 35/59 [01:08<00:46,  1.95s/it]

Loss: 0.05652964115142822


Epoch 4:  61%|██████    | 36/59 [01:10<00:44,  1.95s/it]

Loss: 0.44785088300704956


Epoch 4:  63%|██████▎   | 37/59 [01:11<00:42,  1.95s/it]

Loss: 0.08415040373802185


Epoch 4:  64%|██████▍   | 38/59 [01:13<00:40,  1.95s/it]

Loss: 0.48810404539108276


Epoch 4:  66%|██████▌   | 39/59 [01:15<00:38,  1.95s/it]

Loss: 0.3685600757598877


Epoch 4:  68%|██████▊   | 40/59 [01:17<00:36,  1.95s/it]

Loss: 0.3391992151737213


Epoch 4:  69%|██████▉   | 41/59 [01:19<00:35,  1.95s/it]

Loss: 0.009196067228913307


Epoch 4:  71%|███████   | 42/59 [01:21<00:33,  1.95s/it]

Loss: 0.38138043880462646


Epoch 4:  73%|███████▎  | 43/59 [01:23<00:31,  1.95s/it]

Loss: 0.382290780544281


Epoch 4:  75%|███████▍  | 44/59 [01:25<00:29,  1.95s/it]

Loss: 0.4116332232952118


Epoch 4:  76%|███████▋  | 45/59 [01:27<00:27,  1.94s/it]

Loss: 0.05945998430252075


Epoch 4:  78%|███████▊  | 46/59 [01:29<00:25,  1.94s/it]

Loss: 0.14094381034374237


Epoch 4:  80%|███████▉  | 47/59 [01:31<00:23,  1.94s/it]

Loss: 1.4043352603912354


Epoch 4:  81%|████████▏ | 48/59 [01:33<00:21,  1.94s/it]

Loss: 4.5820800664841954e-07


Epoch 4:  83%|████████▎ | 49/59 [01:35<00:19,  1.94s/it]

Loss: 1.4086642265319824


Epoch 4:  85%|████████▍ | 50/59 [01:37<00:17,  1.94s/it]

Loss: 0.9906014204025269


Epoch 4:  86%|████████▋ | 51/59 [01:39<00:15,  1.94s/it]

Loss: 1.3895443677902222


Epoch 4:  88%|████████▊ | 52/59 [01:41<00:13,  1.94s/it]

Loss: 0.284635990858078


Epoch 4:  90%|████████▉ | 53/59 [01:43<00:11,  1.94s/it]

Loss: 0.392341673374176


Epoch 4:  92%|█████████▏| 54/59 [01:45<00:09,  1.94s/it]

Loss: 1.4736988544464111


Epoch 4:  93%|█████████▎| 55/59 [01:46<00:07,  1.94s/it]

Loss: 0.5029432773590088


Epoch 4:  95%|█████████▍| 56/59 [01:48<00:05,  1.94s/it]

Loss: 0.0006999967154115438


Epoch 4:  97%|█████████▋| 57/59 [01:50<00:03,  1.94s/it]

Loss: 0.00021170209220144898


Epoch 4:  98%|█████████▊| 58/59 [01:52<00:01,  1.94s/it]

Loss: 1.3006205558776855


Epoch 4: 100%|██████████| 59/59 [01:53<00:00,  1.92s/it]

Loss: 2.0265558475784928e-07





Epoch 4 Validation Accuracy: 0.9359605911330049, F1-macro: 0.48346055979643765


Epoch 5:   2%|▏         | 1/59 [00:01<01:52,  1.93s/it]

Loss: 1.113073706626892


Epoch 5:   3%|▎         | 2/59 [00:03<01:50,  1.94s/it]

Loss: 0.30798712372779846


Epoch 5:   5%|▌         | 3/59 [00:05<01:48,  1.94s/it]

Loss: 1.0126080513000488


Epoch 5:   7%|▋         | 4/59 [00:07<01:46,  1.94s/it]

Loss: 0.8925387859344482


Epoch 5:   8%|▊         | 5/59 [00:09<01:44,  1.94s/it]

Loss: 0.4054145812988281


Epoch 5:  10%|█         | 6/59 [00:11<01:43,  1.94s/it]

Loss: 0.05168045684695244


Epoch 5:  12%|█▏        | 7/59 [00:13<01:41,  1.95s/it]

Loss: 0.43357378244400024


Epoch 5:  14%|█▎        | 8/59 [00:15<01:39,  1.95s/it]

Loss: 1.30438232421875


Epoch 5:  15%|█▌        | 9/59 [00:17<01:37,  1.95s/it]

Loss: 0.14000338315963745


Epoch 5:  17%|█▋        | 10/59 [00:19<01:35,  1.95s/it]

Loss: 0.48570889234542847


Epoch 5:  19%|█▊        | 11/59 [00:21<01:33,  1.95s/it]

Loss: 0.38696420192718506


Epoch 5:  20%|██        | 12/59 [00:23<01:31,  1.95s/it]

Loss: 0.483796626329422


Epoch 5:  22%|██▏       | 13/59 [00:25<01:29,  1.95s/it]

Loss: 0.3051341772079468


Epoch 5:  24%|██▎       | 14/59 [00:27<01:27,  1.95s/it]

Loss: 0.5818697214126587


Epoch 5:  25%|██▌       | 15/59 [00:29<01:25,  1.95s/it]

Loss: 0.08968908339738846


Epoch 5:  27%|██▋       | 16/59 [00:31<01:23,  1.95s/it]

Loss: 0.3340712785720825


Epoch 5:  29%|██▉       | 17/59 [00:33<01:21,  1.95s/it]

Loss: 0.1698959767818451


Epoch 5:  31%|███       | 18/59 [00:35<01:19,  1.95s/it]

Loss: 0.04843290150165558


Epoch 5:  32%|███▏      | 19/59 [00:36<01:17,  1.94s/it]

Loss: 0.21023862063884735


Epoch 5:  34%|███▍      | 20/59 [00:38<01:15,  1.95s/it]

Loss: 0.300226092338562


Epoch 5:  36%|███▌      | 21/59 [00:40<01:13,  1.95s/it]

Loss: 0.004195378627628088


Epoch 5:  37%|███▋      | 22/59 [00:42<01:11,  1.94s/it]

Loss: 0.3945316672325134


Epoch 5:  39%|███▉      | 23/59 [00:44<01:10,  1.95s/it]

Loss: 0.14570610225200653


Epoch 5:  41%|████      | 24/59 [00:46<01:08,  1.94s/it]

Loss: 0.1705215722322464


Epoch 5:  42%|████▏     | 25/59 [00:48<01:06,  1.94s/it]

Loss: 0.23580659925937653


Epoch 5:  44%|████▍     | 26/59 [00:50<01:04,  1.94s/it]

Loss: 0.19957956671714783


Epoch 5:  46%|████▌     | 27/59 [00:52<01:02,  1.94s/it]

Loss: 0.3044270873069763


Epoch 5:  47%|████▋     | 28/59 [00:54<01:00,  1.94s/it]

Loss: 0.5101717710494995


Epoch 5:  49%|████▉     | 29/59 [00:56<00:58,  1.94s/it]

Loss: 0.33941221237182617


Epoch 5:  51%|█████     | 30/59 [00:58<00:56,  1.95s/it]

Loss: 0.36978745460510254


Epoch 5:  53%|█████▎    | 31/59 [01:00<00:54,  1.94s/it]

Loss: 0.15993443131446838


Epoch 5:  54%|█████▍    | 32/59 [01:02<00:52,  1.94s/it]

Loss: 0.016717450693249702


Epoch 5:  56%|█████▌    | 33/59 [01:04<00:50,  1.94s/it]

Loss: 0.1744939088821411


Epoch 5:  58%|█████▊    | 34/59 [01:06<00:48,  1.94s/it]

Loss: 0.19663147628307343


Epoch 5:  59%|█████▉    | 35/59 [01:08<00:46,  1.94s/it]

Loss: 0.3009468913078308


Epoch 5:  61%|██████    | 36/59 [01:10<00:44,  1.94s/it]

Loss: 0.32536575198173523


Epoch 5:  63%|██████▎   | 37/59 [01:11<00:42,  1.94s/it]

Loss: 0.015612047165632248


Epoch 5:  64%|██████▍   | 38/59 [01:13<00:40,  1.94s/it]

Loss: 0.39676183462142944


Epoch 5:  66%|██████▌   | 39/59 [01:15<00:38,  1.95s/it]

Loss: 0.17771610617637634


Epoch 5:  68%|██████▊   | 40/59 [01:17<00:36,  1.94s/it]

Loss: 0.6775806546211243


Epoch 5:  69%|██████▉   | 41/59 [01:19<00:35,  1.94s/it]

Loss: 0.13693749904632568


Epoch 5:  71%|███████   | 42/59 [01:21<00:33,  1.95s/it]

Loss: 0.4260551929473877


Epoch 5:  73%|███████▎  | 43/59 [01:23<00:31,  1.95s/it]

Loss: 0.21843871474266052


Epoch 5:  75%|███████▍  | 44/59 [01:25<00:29,  1.95s/it]

Loss: 0.8680476546287537


Epoch 5:  76%|███████▋  | 45/59 [01:27<00:27,  1.95s/it]

Loss: 0.3374379873275757


Epoch 5:  78%|███████▊  | 46/59 [01:29<00:25,  1.95s/it]

Loss: 0.3564090430736542


Epoch 5:  80%|███████▉  | 47/59 [01:31<00:23,  1.95s/it]

Loss: 0.19645501673221588


Epoch 5:  81%|████████▏ | 48/59 [01:33<00:21,  1.95s/it]

Loss: 0.29437169432640076


Epoch 5:  83%|████████▎ | 49/59 [01:35<00:19,  1.95s/it]

Loss: 0.1414114534854889


Epoch 5:  85%|████████▍ | 50/59 [01:37<00:17,  1.95s/it]

Loss: 0.30906984210014343


Epoch 5:  86%|████████▋ | 51/59 [01:39<00:15,  1.94s/it]

Loss: 0.09338211268186569


Epoch 5:  88%|████████▊ | 52/59 [01:41<00:13,  1.94s/it]

Loss: 0.033809252083301544


Epoch 5:  90%|████████▉ | 53/59 [01:43<00:11,  1.94s/it]

Loss: 0.3550184369087219


Epoch 5:  92%|█████████▏| 54/59 [01:45<00:09,  1.94s/it]

Loss: 0.2156546711921692


Epoch 5:  93%|█████████▎| 55/59 [01:46<00:07,  1.94s/it]

Loss: 0.2652851343154907


Epoch 5:  95%|█████████▍| 56/59 [01:48<00:05,  1.94s/it]

Loss: 0.26856422424316406


Epoch 5:  97%|█████████▋| 57/59 [01:50<00:03,  1.94s/it]

Loss: 0.2786792814731598


Epoch 5:  98%|█████████▊| 58/59 [01:52<00:01,  1.94s/it]

Loss: 0.673712968826294


Epoch 5: 100%|██████████| 59/59 [01:53<00:00,  1.92s/it]

Loss: 0.0002128803462255746





Epoch 5 Validation Accuracy: 0.9064039408866995, F1-macro: 0.6682150537634408


Epoch 6:   2%|▏         | 1/59 [00:01<01:52,  1.93s/it]

Loss: 0.13482727110385895


Epoch 6:   3%|▎         | 2/59 [00:03<01:50,  1.94s/it]

Loss: 0.10632288455963135


Epoch 6:   5%|▌         | 3/59 [00:05<01:48,  1.94s/it]

Loss: 0.11422260105609894


Epoch 6:   7%|▋         | 4/59 [00:07<01:46,  1.94s/it]

Loss: 0.1734873503446579


Epoch 6:   8%|▊         | 5/59 [00:09<01:44,  1.94s/it]

Loss: 0.04533042386174202


Epoch 6:  10%|█         | 6/59 [00:11<01:43,  1.94s/it]

Loss: 0.21576952934265137


Epoch 6:  12%|█▏        | 7/59 [00:13<01:41,  1.94s/it]

Loss: 0.14577651023864746


Epoch 6:  14%|█▎        | 8/59 [00:15<01:39,  1.94s/it]

Loss: 0.20263315737247467


Epoch 6:  15%|█▌        | 9/59 [00:17<01:37,  1.94s/it]

Loss: 0.18078017234802246


Epoch 6:  17%|█▋        | 10/59 [00:19<01:35,  1.94s/it]

Loss: 0.1321253776550293


Epoch 6:  19%|█▊        | 11/59 [00:21<01:33,  1.95s/it]

Loss: 0.22994935512542725


Epoch 6:  20%|██        | 12/59 [00:23<01:31,  1.95s/it]

Loss: 0.4421980381011963


Epoch 6:  22%|██▏       | 13/59 [00:25<01:29,  1.95s/it]

Loss: 0.011992101557552814


Epoch 6:  24%|██▎       | 14/59 [00:27<01:27,  1.95s/it]

Loss: 0.38087934255599976


Epoch 6:  25%|██▌       | 15/59 [00:29<01:25,  1.95s/it]

Loss: 0.5454033017158508


Epoch 6:  27%|██▋       | 16/59 [00:31<01:23,  1.95s/it]

Loss: 0.4857906699180603


Epoch 6:  29%|██▉       | 17/59 [00:33<01:21,  1.95s/it]

Loss: 0.46073657274246216


Epoch 6:  31%|███       | 18/59 [00:35<01:19,  1.95s/it]

Loss: 0.3292718529701233


Epoch 6:  32%|███▏      | 19/59 [00:36<01:17,  1.95s/it]

Loss: 0.0056542884558439255


Epoch 6:  34%|███▍      | 20/59 [00:38<01:15,  1.95s/it]

Loss: 0.04633177071809769


Epoch 6:  36%|███▌      | 21/59 [00:40<01:13,  1.95s/it]

Loss: 0.22887566685676575


Epoch 6:  37%|███▋      | 22/59 [00:42<01:12,  1.95s/it]

Loss: 0.1694452464580536


Epoch 6:  39%|███▉      | 23/59 [00:44<01:10,  1.95s/it]

Loss: 0.15048837661743164


Epoch 6:  41%|████      | 24/59 [00:46<01:08,  1.94s/it]

Loss: 0.2523564100265503


Epoch 6:  42%|████▏     | 25/59 [00:48<01:06,  1.94s/it]

Loss: 0.2453581839799881


Epoch 6:  44%|████▍     | 26/59 [00:50<01:04,  1.94s/it]

Loss: 0.21600818634033203


Epoch 6:  46%|████▌     | 27/59 [00:52<01:02,  1.94s/it]

Loss: 0.2266073226928711


Epoch 6:  47%|████▋     | 28/59 [00:54<01:00,  1.94s/it]

Loss: 0.47464197874069214


Epoch 6:  49%|████▉     | 29/59 [00:56<00:58,  1.95s/it]

Loss: 0.39397910237312317


Epoch 6:  51%|█████     | 30/59 [00:58<00:56,  1.94s/it]

Loss: 0.25873103737831116


Epoch 6:  53%|█████▎    | 31/59 [01:00<00:54,  1.95s/it]

Loss: 0.20447297394275665


Epoch 6:  54%|█████▍    | 32/59 [01:02<00:52,  1.94s/it]

Loss: 0.21934446692466736


Epoch 6:  56%|█████▌    | 33/59 [01:04<00:50,  1.94s/it]

Loss: 0.06170669198036194


Epoch 6:  58%|█████▊    | 34/59 [01:06<00:48,  1.94s/it]

Loss: 0.22929342091083527


Epoch 6:  59%|█████▉    | 35/59 [01:08<00:46,  1.94s/it]

Loss: 0.12517701089382172


Epoch 6:  61%|██████    | 36/59 [01:10<00:44,  1.94s/it]

Loss: 0.15395480394363403


Epoch 6:  63%|██████▎   | 37/59 [01:11<00:42,  1.94s/it]

Loss: 0.14589127898216248


Epoch 6:  64%|██████▍   | 38/59 [01:13<00:40,  1.94s/it]

Loss: 0.23114489018917084


Epoch 6:  66%|██████▌   | 39/59 [01:15<00:38,  1.94s/it]

Loss: 0.1656162589788437


Epoch 6:  68%|██████▊   | 40/59 [01:17<00:36,  1.94s/it]

Loss: 0.24023422598838806


Epoch 6:  69%|██████▉   | 41/59 [01:19<00:34,  1.94s/it]

Loss: 0.361759752035141


Epoch 6:  71%|███████   | 42/59 [01:21<00:33,  1.94s/it]

Loss: 0.744986355304718


Epoch 6:  73%|███████▎  | 43/59 [01:23<00:31,  1.94s/it]

Loss: 0.1635119616985321


Epoch 6:  75%|███████▍  | 44/59 [01:25<00:29,  1.94s/it]

Loss: 0.11229691654443741


Epoch 6:  76%|███████▋  | 45/59 [01:27<00:27,  1.94s/it]

Loss: 0.22600044310092926


Epoch 6:  78%|███████▊  | 46/59 [01:29<00:25,  1.94s/it]

Loss: 0.26751187443733215


Epoch 6:  80%|███████▉  | 47/59 [01:31<00:23,  1.94s/it]

Loss: 0.20033244788646698


Epoch 6:  81%|████████▏ | 48/59 [01:33<00:21,  1.95s/it]

Loss: 0.20528237521648407


Epoch 6:  83%|████████▎ | 49/59 [01:35<00:19,  1.95s/it]

Loss: 0.05478030443191528


Epoch 6:  85%|████████▍ | 50/59 [01:37<00:17,  1.95s/it]

Loss: 0.3748505711555481


Epoch 6:  86%|████████▋ | 51/59 [01:39<00:15,  1.95s/it]

Loss: 0.09483657777309418


Epoch 6:  88%|████████▊ | 52/59 [01:41<00:13,  1.95s/it]

Loss: 0.18780675530433655


Epoch 6:  90%|████████▉ | 53/59 [01:43<00:11,  1.95s/it]

Loss: 0.1651676595211029


Epoch 6:  92%|█████████▏| 54/59 [01:45<00:09,  1.95s/it]

Loss: 0.13111786544322968


Epoch 6:  93%|█████████▎| 55/59 [01:46<00:07,  1.95s/it]

Loss: 0.6003957390785217


Epoch 6:  95%|█████████▍| 56/59 [01:48<00:05,  1.95s/it]

Loss: 0.6124197840690613


Epoch 6:  97%|█████████▋| 57/59 [01:50<00:03,  1.95s/it]

Loss: 0.45758768916130066


Epoch 6:  98%|█████████▊| 58/59 [01:52<00:01,  1.95s/it]

Loss: 0.4643424153327942


Epoch 6: 100%|██████████| 59/59 [01:53<00:00,  1.92s/it]

Loss: 0.3327786326408386





Epoch 6 Validation Accuracy: 0.9310344827586207, F1-macro: 0.6318652849740932


Epoch 7:   2%|▏         | 1/59 [00:01<01:51,  1.93s/it]

Loss: 0.2302354872226715


Epoch 7:   3%|▎         | 2/59 [00:03<01:50,  1.94s/it]

Loss: 0.4735119640827179


Epoch 7:   5%|▌         | 3/59 [00:05<01:48,  1.94s/it]

Loss: 0.18581220507621765


Epoch 7:   7%|▋         | 4/59 [00:07<01:46,  1.94s/it]

Loss: 0.5223615169525146


Epoch 7:   8%|▊         | 5/59 [00:09<01:44,  1.94s/it]

Loss: 0.15437473356723785


Epoch 7:  10%|█         | 6/59 [00:11<01:42,  1.94s/it]

Loss: 0.5214560031890869


Epoch 7:  12%|█▏        | 7/59 [00:13<01:41,  1.94s/it]

Loss: 0.2542475163936615


Epoch 7:  14%|█▎        | 8/59 [00:15<01:39,  1.94s/it]

Loss: 0.2274843007326126


Epoch 7:  15%|█▌        | 9/59 [00:17<01:37,  1.94s/it]

Loss: 0.16784755885601044


Epoch 7:  17%|█▋        | 10/59 [00:19<01:35,  1.94s/it]

Loss: 0.6517511606216431


Epoch 7:  19%|█▊        | 11/59 [00:21<01:33,  1.94s/it]

Loss: 0.631962239742279


Epoch 7:  20%|██        | 12/59 [00:23<01:31,  1.94s/it]

Loss: 0.2734907865524292


Epoch 7:  22%|██▏       | 13/59 [00:25<01:29,  1.94s/it]

Loss: 0.13429579138755798


Epoch 7:  24%|██▎       | 14/59 [00:27<01:27,  1.94s/it]

Loss: 0.1156461089849472


Epoch 7:  25%|██▌       | 15/59 [00:29<01:25,  1.94s/it]

Loss: 0.5774555802345276


Epoch 7:  27%|██▋       | 16/59 [00:31<01:23,  1.94s/it]

Loss: 0.11711174249649048


Epoch 7:  29%|██▉       | 17/59 [00:33<01:21,  1.94s/it]

Loss: 0.13738781213760376


Epoch 7:  31%|███       | 18/59 [00:34<01:19,  1.94s/it]

Loss: 0.28447145223617554


Epoch 7:  32%|███▏      | 19/59 [00:36<01:17,  1.94s/it]

Loss: 0.5917090177536011


Epoch 7:  34%|███▍      | 20/59 [00:38<01:15,  1.94s/it]

Loss: 0.9075313806533813


Epoch 7:  36%|███▌      | 21/59 [00:40<01:13,  1.94s/it]

Loss: 0.10982967913150787


Epoch 7:  37%|███▋      | 22/59 [00:42<01:11,  1.95s/it]

Loss: 0.5730013847351074


Epoch 7:  39%|███▉      | 23/59 [00:44<01:10,  1.95s/it]

Loss: 0.5125678777694702


Epoch 7:  41%|████      | 24/59 [00:46<01:08,  1.95s/it]

Loss: 0.5545644760131836


Epoch 7:  42%|████▏     | 25/59 [00:48<01:06,  1.95s/it]

Loss: 0.15778601169586182


Epoch 7:  44%|████▍     | 26/59 [00:50<01:04,  1.95s/it]

Loss: 0.26254406571388245


Epoch 7:  46%|████▌     | 27/59 [00:52<01:02,  1.95s/it]

Loss: 0.21209010481834412


Epoch 7:  47%|████▋     | 28/59 [00:54<01:00,  1.95s/it]

Loss: 0.11509199440479279


Epoch 7:  49%|████▉     | 29/59 [00:56<00:58,  1.95s/it]

Loss: 0.47872868180274963


Epoch 7:  51%|█████     | 30/59 [00:58<00:56,  1.95s/it]

Loss: 0.5628600120544434


Epoch 7:  53%|█████▎    | 31/59 [01:00<00:54,  1.95s/it]

Loss: 0.1840953528881073


Epoch 7:  54%|█████▍    | 32/59 [01:02<00:52,  1.94s/it]

Loss: 0.4411015212535858


Epoch 7:  56%|█████▌    | 33/59 [01:04<00:50,  1.95s/it]

Loss: 0.2679794430732727


Epoch 7:  58%|█████▊    | 34/59 [01:06<00:48,  1.94s/it]

Loss: 0.4560684263706207


Epoch 7:  59%|█████▉    | 35/59 [01:08<00:46,  1.94s/it]

Loss: 0.2453637719154358


Epoch 7:  61%|██████    | 36/59 [01:10<00:44,  1.94s/it]

Loss: 0.07677533477544785


Epoch 7:  63%|██████▎   | 37/59 [01:11<00:42,  1.94s/it]

Loss: 0.5698248147964478


Epoch 7:  64%|██████▍   | 38/59 [01:13<00:40,  1.94s/it]

Loss: 0.19676408171653748


Epoch 7:  66%|██████▌   | 39/59 [01:15<00:38,  1.94s/it]

Loss: 0.6285462975502014


Epoch 7:  68%|██████▊   | 40/59 [01:17<00:36,  1.94s/it]

Loss: 0.2528902292251587


Epoch 7:  69%|██████▉   | 41/59 [01:19<00:35,  1.94s/it]

Loss: 0.4726535379886627


Epoch 7:  71%|███████   | 42/59 [01:21<00:33,  1.95s/it]

Loss: 0.31845349073410034


Epoch 7:  73%|███████▎  | 43/59 [01:23<00:31,  1.94s/it]

Loss: 0.0013633898925036192


Epoch 7:  75%|███████▍  | 44/59 [01:25<00:29,  1.94s/it]

Loss: 0.5384628772735596


Epoch 7:  76%|███████▋  | 45/59 [01:27<00:27,  1.94s/it]

Loss: 0.908598780632019


Epoch 7:  78%|███████▊  | 46/59 [01:29<00:25,  1.94s/it]

Loss: 0.7288422584533691


Epoch 7:  80%|███████▉  | 47/59 [01:31<00:23,  1.94s/it]

Loss: 0.2749953269958496


Epoch 7:  81%|████████▏ | 48/59 [01:33<00:21,  1.94s/it]

Loss: 0.10604531317949295


Epoch 7:  83%|████████▎ | 49/59 [01:35<00:19,  1.94s/it]

Loss: 0.7530714273452759


Epoch 7:  85%|████████▍ | 50/59 [01:37<00:17,  1.94s/it]

Loss: 0.35903891921043396


Epoch 7:  86%|████████▋ | 51/59 [01:39<00:15,  1.94s/it]

Loss: 0.36840569972991943


Epoch 7:  88%|████████▊ | 52/59 [01:41<00:13,  1.94s/it]

Loss: 0.837547779083252


Epoch 7:  90%|████████▉ | 53/59 [01:43<00:11,  1.94s/it]

Loss: 0.9525700211524963


Epoch 7:  92%|█████████▏| 54/59 [01:45<00:09,  1.94s/it]

Loss: 0.6239595413208008


Epoch 7:  93%|█████████▎| 55/59 [01:46<00:07,  1.95s/it]

Loss: 0.2358693778514862


Epoch 7:  95%|█████████▍| 56/59 [01:48<00:05,  1.95s/it]

Loss: 0.5443882346153259


Epoch 7:  97%|█████████▋| 57/59 [01:50<00:03,  1.95s/it]

Loss: 0.1326400339603424


Epoch 7:  98%|█████████▊| 58/59 [01:52<00:01,  1.95s/it]

Loss: 0.19606003165245056


Epoch 7: 100%|██████████| 59/59 [01:53<00:00,  1.92s/it]

Loss: 0.011419190093874931





Epoch 7 Validation Accuracy: 0.9261083743842364, F1-macro: 0.5858833129334966


In [None]:
# Evaluate on the test set
test_metrics = evaluate(model10, test_dataloader, device)
print(f"Test Accuracy: {test_metrics['accuracy']}, F1-macro: {test_metrics['f1_macro']}")

Test Accuracy: 0.9313725490196079, F1-macro: 0.6319587628865979


In [None]:
current_type = 'should statements'

results_df.loc[
    results_df["distortion_type"] == current_type,
    ["test_accuracy", "f1_macro"]
] = [test_metrics['accuracy'], test_metrics['f1_macro']]

# 11. Mental Filter

In [None]:
# Add labels
data1_1_labels = list(data1['mental filter'][data1_1.index])
data2_1_labels = list(data2['mental filter'][data2_1.index])
data3_1_labels = list(data3['mental filter'][data3_1.index])

# Merging Data
data_encoded = data1_1_encoded + data2_1_encoded + data3_1_encoded
data_labels = data1_1_labels + data2_1_labels + data3_1_labels

dataset_with_labels = CustomDatasetWithLabels(data_encoded, data_labels)

# Define proportions for splitting
train_size = int(0.8 * len(dataset_with_labels))
val_size = int(0.1 * len(dataset_with_labels))
test_size = len(dataset_with_labels) - train_size - val_size

# Split the dataset
train_dataset, val_dataset, test_dataset = random_split(dataset_with_labels, [train_size, val_size, test_size])

In [None]:
dataset_encoded = dataset_with_labels.data
dataset_labels = dataset_with_labels.labels

# Convert list of BatchEncoding objects to a 2D numpy array for SMOTE features
feature_vectors = []
max_len = 512 # Assuming this from tokenize_and_pad function
for encoded_dict in dataset_encoded:
    input_ids_np = encoded_dict['input_ids'].squeeze().cpu().numpy()
    attention_mask_np = encoded_dict['attention_mask'].squeeze().cpu().numpy()
    # Concatenate input_ids and attention_mask to create a single feature vector per sample
    feature_vectors.append(np.concatenate([input_ids_np, attention_mask_np]))

# 'temp1' becomes the feature matrix, 'temp2' becomes the labels vector
temp1 = np.array(feature_vectors)
temp2 = np.array(dataset_labels)

# Apply SMOTE
temp1_resampled, temp2_resampled = smote.fit_resample(temp1, temp2)

# Reconstruct data_encoded and data_labels from the resampled data
reconstructed_data_encoded = []
for resampled_features in temp1_resampled:
    # Split the combined feature vector back into input_ids and attention_mask parts
    reconstructed_input_ids = resampled_features[:max_len]
    reconstructed_attention_mask = resampled_features[max_len:]

    # Convert back to torch tensors and a dictionary format compatible with CustomDatasetWithLabels
    input_ids_tensor = torch.tensor(reconstructed_input_ids, dtype=torch.long).unsqueeze(0)
    attention_mask_tensor = torch.tensor(reconstructed_attention_mask, dtype=torch.long).unsqueeze(0)

    reconstructed_data_encoded.append({
        'input_ids': input_ids_tensor,
        'attention_mask': attention_mask_tensor
    })

# Update data_encoded and data_labels with the resampled data
dataset_encoded = reconstructed_data_encoded
dataset_labels = temp2_resampled.tolist()

train_dataset = CustomDatasetWithLabels(dataset_encoded, dataset_labels)

In [None]:
# Create DataLoaders for each set
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=64)
test_dataloader = DataLoader(test_dataset, batch_size=64)

input_dim = bert_config.hidden_size
num_labels = len(set(data_labels))
label_emb = torch.randn(num_labels, input_dim)

In [None]:
# Instantiate the InnerProductClassifier model
model11 = InnerProductClassifier(input_dim, label_emb).to(device)

# Define loss function
criterion = nn.CrossEntropyLoss()

# Define optimizer
optimizer = torch.optim.AdamW(model11.parameters(), lr=2e-4)

In [None]:
EPOCHS = 7

# Move the BERT model to the device
bert_model.to(device)

for epoch in range(EPOCHS):
    model11.train()
    for batch in tqdm(train_dataloader, desc=f"Epoch {epoch+1}"):

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        y = batch["y"].to(device)

        # Get embeddings from the BERT model
        with torch.no_grad():
            outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask)
            embeddings = outputs.last_hidden_state[:, 0, :] # Get the [CLS] token embedding

        # Pass embeddings to the InnerProductClassifier
        logits = model11(embeddings)
        loss = criterion(logits, y)

        # Backpropagate and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Loss: {loss.item()}")

    # Evaluate on the validation set after each epoch
    val_metrics = evaluate(model11, val_dataloader, device)
    print(f"Epoch {epoch+1} Validation Accuracy: {val_metrics['accuracy']}, F1-macro: {val_metrics['f1_macro']}")

Epoch 1:   2%|▏         | 1/59 [00:01<01:50,  1.91s/it]

Loss: 3.844313859939575


Epoch 1:   3%|▎         | 2/59 [00:03<01:49,  1.93s/it]

Loss: 1.3824689388275146


Epoch 1:   5%|▌         | 3/59 [00:05<01:48,  1.94s/it]

Loss: 3.3161845207214355


Epoch 1:   7%|▋         | 4/59 [00:07<01:46,  1.94s/it]

Loss: 4.849904537200928


Epoch 1:   8%|▊         | 5/59 [00:09<01:44,  1.94s/it]

Loss: 0.5749248266220093


Epoch 1:  10%|█         | 6/59 [00:11<01:42,  1.94s/it]

Loss: 5.3599348068237305


Epoch 1:  12%|█▏        | 7/59 [00:13<01:41,  1.94s/it]

Loss: 1.2581868171691895


Epoch 1:  14%|█▎        | 8/59 [00:15<01:39,  1.94s/it]

Loss: 5.215400022962058e-08


Epoch 1:  15%|█▌        | 9/59 [00:17<01:37,  1.94s/it]

Loss: 1.814711332321167


Epoch 1:  17%|█▋        | 10/59 [00:19<01:35,  1.94s/it]

Loss: 5.347845077514648


Epoch 1:  19%|█▊        | 11/59 [00:21<01:33,  1.94s/it]

Loss: 4.979660987854004


Epoch 1:  20%|██        | 12/59 [00:23<01:31,  1.94s/it]

Loss: 2.7071969509124756


Epoch 1:  22%|██▏       | 13/59 [00:25<01:29,  1.94s/it]

Loss: 1.3535370826721191


Epoch 1:  24%|██▎       | 14/59 [00:27<01:27,  1.94s/it]

Loss: 0.476940393447876


Epoch 1:  25%|██▌       | 15/59 [00:29<01:25,  1.94s/it]

Loss: 1.000394582748413


Epoch 1:  27%|██▋       | 16/59 [00:31<01:23,  1.94s/it]

Loss: 1.7350536584854126


Epoch 1:  29%|██▉       | 17/59 [00:33<01:21,  1.94s/it]

Loss: 2.085695743560791


Epoch 1:  31%|███       | 18/59 [00:34<01:19,  1.95s/it]

Loss: 1.7072515487670898


Epoch 1:  32%|███▏      | 19/59 [00:36<01:17,  1.95s/it]

Loss: 0.8197923898696899


Epoch 1:  34%|███▍      | 20/59 [00:38<01:15,  1.95s/it]

Loss: 0.945354700088501


Epoch 1:  36%|███▌      | 21/59 [00:40<01:13,  1.95s/it]

Loss: 2.0200414657592773


Epoch 1:  37%|███▋      | 22/59 [00:42<01:12,  1.95s/it]

Loss: 1.847456932067871


Epoch 1:  39%|███▉      | 23/59 [00:44<01:10,  1.95s/it]

Loss: 1.2212473154067993


Epoch 1:  41%|████      | 24/59 [00:46<01:08,  1.95s/it]

Loss: 0.5372313857078552


Epoch 1:  42%|████▏     | 25/59 [00:48<01:06,  1.95s/it]

Loss: 1.5358836650848389


Epoch 1:  44%|████▍     | 26/59 [00:50<01:04,  1.95s/it]

Loss: 1.4018127918243408


Epoch 1:  46%|████▌     | 27/59 [00:52<01:02,  1.95s/it]

Loss: 0.07245815545320511


Epoch 1:  47%|████▋     | 28/59 [00:54<01:00,  1.95s/it]

Loss: 1.156989574432373


Epoch 1:  49%|████▉     | 29/59 [00:56<00:58,  1.95s/it]

Loss: 1.5199227333068848


Epoch 1:  51%|█████     | 30/59 [00:58<00:56,  1.95s/it]

Loss: 0.416370153427124


Epoch 1:  53%|█████▎    | 31/59 [01:00<00:54,  1.95s/it]

Loss: 0.10191048681735992


Epoch 1:  54%|█████▍    | 32/59 [01:02<00:52,  1.95s/it]

Loss: 0.6638753414154053


Epoch 1:  56%|█████▌    | 33/59 [01:04<00:50,  1.95s/it]

Loss: 0.9422229528427124


Epoch 1:  58%|█████▊    | 34/59 [01:06<00:48,  1.95s/it]

Loss: 0.3377228379249573


Epoch 1:  59%|█████▉    | 35/59 [01:08<00:46,  1.95s/it]

Loss: 0.954345166683197


Epoch 1:  61%|██████    | 36/59 [01:10<00:44,  1.95s/it]

Loss: 0.6439030170440674


Epoch 1:  63%|██████▎   | 37/59 [01:11<00:42,  1.95s/it]

Loss: 0.22031787037849426


Epoch 1:  64%|██████▍   | 38/59 [01:13<00:41,  1.95s/it]

Loss: 0.39757463335990906


Epoch 1:  66%|██████▌   | 39/59 [01:15<00:39,  1.95s/it]

Loss: 0.34808140993118286


Epoch 1:  68%|██████▊   | 40/59 [01:17<00:37,  1.95s/it]

Loss: 0.7674084305763245


Epoch 1:  69%|██████▉   | 41/59 [01:19<00:35,  1.95s/it]

Loss: 0.6228353977203369


Epoch 1:  71%|███████   | 42/59 [01:21<00:33,  1.95s/it]

Loss: 0.6126949787139893


Epoch 1:  73%|███████▎  | 43/59 [01:23<00:31,  1.95s/it]

Loss: 0.4815964698791504


Epoch 1:  75%|███████▍  | 44/59 [01:25<00:29,  1.95s/it]

Loss: 0.794887900352478


Epoch 1:  76%|███████▋  | 45/59 [01:27<00:27,  1.95s/it]

Loss: 1.1243317127227783


Epoch 1:  78%|███████▊  | 46/59 [01:29<00:25,  1.95s/it]

Loss: 0.23403605818748474


Epoch 1:  80%|███████▉  | 47/59 [01:31<00:23,  1.95s/it]

Loss: 0.4243316054344177


Epoch 1:  81%|████████▏ | 48/59 [01:33<00:21,  1.94s/it]

Loss: 0.8622473478317261


Epoch 1:  83%|████████▎ | 49/59 [01:35<00:19,  1.94s/it]

Loss: 0.5310439467430115


Epoch 1:  85%|████████▍ | 50/59 [01:37<00:17,  1.94s/it]

Loss: 0.4818387031555176


Epoch 1:  86%|████████▋ | 51/59 [01:39<00:15,  1.94s/it]

Loss: 1.4421234130859375


Epoch 1:  88%|████████▊ | 52/59 [01:41<00:13,  1.94s/it]

Loss: 0.4074392020702362


Epoch 1:  90%|████████▉ | 53/59 [01:43<00:11,  1.94s/it]

Loss: 0.24233831465244293


Epoch 1:  92%|█████████▏| 54/59 [01:45<00:09,  1.94s/it]

Loss: 0.1002233624458313


Epoch 1:  93%|█████████▎| 55/59 [01:47<00:07,  1.94s/it]

Loss: 0.8022488951683044


Epoch 1:  95%|█████████▍| 56/59 [01:48<00:05,  1.94s/it]

Loss: 0.22216397523880005


Epoch 1:  97%|█████████▋| 57/59 [01:50<00:03,  1.94s/it]

Loss: 0.23015610873699188


Epoch 1: 100%|██████████| 59/59 [01:52<00:00,  1.91s/it]

Loss: 0.7094781398773193
Loss: 0.010556332767009735





Epoch 1 Validation Accuracy: 0.5862068965517241, F1-macro: 0.48214285714285715


Epoch 2:   2%|▏         | 1/59 [00:01<01:51,  1.93s/it]

Loss: 0.7982749938964844


Epoch 2:   3%|▎         | 2/59 [00:03<01:50,  1.93s/it]

Loss: 0.5164329409599304


Epoch 2:   5%|▌         | 3/59 [00:05<01:48,  1.94s/it]

Loss: 0.765476405620575


Epoch 2:   7%|▋         | 4/59 [00:07<01:46,  1.93s/it]

Loss: 1.1129887104034424


Epoch 2:   8%|▊         | 5/59 [00:09<01:44,  1.93s/it]

Loss: 0.8924276828765869


Epoch 2:  10%|█         | 6/59 [00:11<01:42,  1.94s/it]

Loss: 0.6278345584869385


Epoch 2:  12%|█▏        | 7/59 [00:13<01:40,  1.93s/it]

Loss: 0.5874854922294617


Epoch 2:  14%|█▎        | 8/59 [00:15<01:38,  1.94s/it]

Loss: 0.9299474954605103


Epoch 2:  15%|█▌        | 9/59 [00:17<01:36,  1.93s/it]

Loss: 0.5120036005973816


Epoch 2:  17%|█▋        | 10/59 [00:19<01:34,  1.93s/it]

Loss: 1.3288092613220215


Epoch 2:  19%|█▊        | 11/59 [00:21<01:32,  1.93s/it]

Loss: 0.3102114498615265


Epoch 2:  20%|██        | 12/59 [00:23<01:30,  1.94s/it]

Loss: 0.4678932726383209


Epoch 2:  22%|██▏       | 13/59 [00:25<01:29,  1.94s/it]

Loss: 0.7201919555664062


Epoch 2:  24%|██▎       | 14/59 [00:27<01:27,  1.94s/it]

Loss: 0.6515856981277466


Epoch 2:  25%|██▌       | 15/59 [00:29<01:25,  1.94s/it]

Loss: 0.4791266918182373


Epoch 2:  27%|██▋       | 16/59 [00:30<01:23,  1.94s/it]

Loss: 0.3467656672000885


Epoch 2:  29%|██▉       | 17/59 [00:32<01:21,  1.94s/it]

Loss: 0.5219783782958984


Epoch 2:  31%|███       | 18/59 [00:34<01:19,  1.94s/it]

Loss: 0.20904111862182617


Epoch 2:  32%|███▏      | 19/59 [00:36<01:17,  1.94s/it]

Loss: 0.5246599912643433


Epoch 2:  34%|███▍      | 20/59 [00:38<01:15,  1.94s/it]

Loss: 0.5144715309143066


Epoch 2:  36%|███▌      | 21/59 [00:40<01:13,  1.94s/it]

Loss: 0.13763204216957092


Epoch 2:  37%|███▋      | 22/59 [00:42<01:11,  1.94s/it]

Loss: 0.47511225938796997


Epoch 2:  39%|███▉      | 23/59 [00:44<01:09,  1.94s/it]

Loss: 0.22132839262485504


Epoch 2:  41%|████      | 24/59 [00:46<01:08,  1.94s/it]

Loss: 0.6947842836380005


Epoch 2:  42%|████▏     | 25/59 [00:48<01:06,  1.94s/it]

Loss: 0.2995190918445587


Epoch 2:  44%|████▍     | 26/59 [00:50<01:04,  1.94s/it]

Loss: 0.34791290760040283


Epoch 2:  46%|████▌     | 27/59 [00:52<01:02,  1.94s/it]

Loss: 0.0663420781493187


Epoch 2:  47%|████▋     | 28/59 [00:54<01:00,  1.94s/it]

Loss: 0.33135083317756653


Epoch 2:  49%|████▉     | 29/59 [00:56<00:58,  1.94s/it]

Loss: 0.1216353103518486


Epoch 2:  51%|█████     | 30/59 [00:58<00:56,  1.94s/it]

Loss: 0.3016658425331116


Epoch 2:  53%|█████▎    | 31/59 [01:00<00:54,  1.94s/it]

Loss: 0.20151008665561676


Epoch 2:  54%|█████▍    | 32/59 [01:02<00:52,  1.94s/it]

Loss: 0.05108688771724701


Epoch 2:  56%|█████▌    | 33/59 [01:04<00:50,  1.94s/it]

Loss: 0.5867825150489807


Epoch 2:  58%|█████▊    | 34/59 [01:05<00:48,  1.94s/it]

Loss: 0.4228111207485199


Epoch 2:  59%|█████▉    | 35/59 [01:07<00:46,  1.94s/it]

Loss: 0.4779753088951111


Epoch 2:  61%|██████    | 36/59 [01:09<00:44,  1.95s/it]

Loss: 0.1520707905292511


Epoch 2:  63%|██████▎   | 37/59 [01:11<00:42,  1.95s/it]

Loss: 0.5437526702880859


Epoch 2:  64%|██████▍   | 38/59 [01:13<00:40,  1.95s/it]

Loss: 0.19695928692817688


Epoch 2:  66%|██████▌   | 39/59 [01:15<00:38,  1.95s/it]

Loss: 0.18988218903541565


Epoch 2:  68%|██████▊   | 40/59 [01:17<00:36,  1.95s/it]

Loss: 0.130279541015625


Epoch 2:  69%|██████▉   | 41/59 [01:19<00:35,  1.95s/it]

Loss: 0.12028973549604416


Epoch 2:  71%|███████   | 42/59 [01:21<00:33,  1.95s/it]

Loss: 0.20939214527606964


Epoch 2:  73%|███████▎  | 43/59 [01:23<00:31,  1.95s/it]

Loss: 0.20857559144496918


Epoch 2:  75%|███████▍  | 44/59 [01:25<00:29,  1.95s/it]

Loss: 0.13633698225021362


Epoch 2:  76%|███████▋  | 45/59 [01:27<00:27,  1.95s/it]

Loss: 0.3208909034729004


Epoch 2:  78%|███████▊  | 46/59 [01:29<00:25,  1.95s/it]

Loss: 0.3836216926574707


Epoch 2:  80%|███████▉  | 47/59 [01:31<00:23,  1.95s/it]

Loss: 0.3447641134262085


Epoch 2:  81%|████████▏ | 48/59 [01:33<00:21,  1.95s/it]

Loss: 0.2169111669063568


Epoch 2:  83%|████████▎ | 49/59 [01:35<00:19,  1.94s/it]

Loss: 0.2744925320148468


Epoch 2:  85%|████████▍ | 50/59 [01:37<00:17,  1.95s/it]

Loss: 0.24663153290748596


Epoch 2:  86%|████████▋ | 51/59 [01:39<00:15,  1.95s/it]

Loss: 0.05099596083164215


Epoch 2:  88%|████████▊ | 52/59 [01:41<00:13,  1.95s/it]

Loss: 0.4815746247768402


Epoch 2:  90%|████████▉ | 53/59 [01:42<00:11,  1.95s/it]

Loss: 0.5973976850509644


Epoch 2:  92%|█████████▏| 54/59 [01:44<00:09,  1.95s/it]

Loss: 0.18829184770584106


Epoch 2:  93%|█████████▎| 55/59 [01:46<00:07,  1.95s/it]

Loss: 1.19537353515625


Epoch 2:  95%|█████████▍| 56/59 [01:48<00:05,  1.95s/it]

Loss: 0.5218092203140259


Epoch 2:  97%|█████████▋| 57/59 [01:50<00:03,  1.95s/it]

Loss: 0.3116410970687866


Epoch 2: 100%|██████████| 59/59 [01:52<00:00,  1.91s/it]

Loss: 0.38583797216415405
Loss: 2.185218334197998





Epoch 2 Validation Accuracy: 0.9113300492610837, F1-macro: 0.47680412371134023


Epoch 3:   2%|▏         | 1/59 [00:01<01:52,  1.95s/it]

Loss: 0.42857247591018677


Epoch 3:   3%|▎         | 2/59 [00:03<01:50,  1.95s/it]

Loss: 1.1765022277832031


Epoch 3:   5%|▌         | 3/59 [00:05<01:48,  1.95s/it]

Loss: 0.9031074047088623


Epoch 3:   7%|▋         | 4/59 [00:07<01:46,  1.95s/it]

Loss: 2.0719003677368164


Epoch 3:   8%|▊         | 5/59 [00:09<01:44,  1.94s/it]

Loss: 1.1440327167510986


Epoch 3:  10%|█         | 6/59 [00:11<01:43,  1.94s/it]

Loss: 0.22724635899066925


Epoch 3:  12%|█▏        | 7/59 [00:13<01:41,  1.94s/it]

Loss: 0.2943043112754822


Epoch 3:  14%|█▎        | 8/59 [00:15<01:39,  1.95s/it]

Loss: 0.4483063220977783


Epoch 3:  15%|█▌        | 9/59 [00:17<01:37,  1.95s/it]

Loss: 1.7518360614776611


Epoch 3:  17%|█▋        | 10/59 [00:19<01:35,  1.95s/it]

Loss: 1.099657416343689


Epoch 3:  19%|█▊        | 11/59 [00:21<01:33,  1.95s/it]

Loss: 1.2699202299118042


Epoch 3:  20%|██        | 12/59 [00:23<01:31,  1.95s/it]

Loss: 1.0861188173294067


Epoch 3:  22%|██▏       | 13/59 [00:25<01:29,  1.95s/it]

Loss: 0.511814296245575


Epoch 3:  24%|██▎       | 14/59 [00:27<01:27,  1.95s/it]

Loss: 2.399074219283648e-05


Epoch 3:  25%|██▌       | 15/59 [00:29<01:25,  1.95s/it]

Loss: 1.3955585956573486


Epoch 3:  27%|██▋       | 16/59 [00:31<01:23,  1.95s/it]

Loss: 1.3922460079193115


Epoch 3:  29%|██▉       | 17/59 [00:33<01:21,  1.95s/it]

Loss: 1.915372371673584


Epoch 3:  31%|███       | 18/59 [00:35<01:19,  1.95s/it]

Loss: 0.824799120426178


Epoch 3:  32%|███▏      | 19/59 [00:36<01:17,  1.94s/it]

Loss: 0.4829259216785431


Epoch 3:  34%|███▍      | 20/59 [00:38<01:15,  1.94s/it]

Loss: 1.4318623542785645


Epoch 3:  36%|███▌      | 21/59 [00:40<01:13,  1.94s/it]

Loss: 1.2163565158843994


Epoch 3:  37%|███▋      | 22/59 [00:42<01:11,  1.94s/it]

Loss: 0.3642803430557251


Epoch 3:  39%|███▉      | 23/59 [00:44<01:09,  1.94s/it]

Loss: 0.5227693319320679


Epoch 3:  41%|████      | 24/59 [00:46<01:08,  1.94s/it]

Loss: 0.3070303201675415


Epoch 3:  42%|████▏     | 25/59 [00:48<01:06,  1.94s/it]

Loss: 1.0647722482681274


Epoch 3:  44%|████▍     | 26/59 [00:50<01:04,  1.94s/it]

Loss: 0.7567086815834045


Epoch 3:  46%|████▌     | 27/59 [00:52<01:02,  1.94s/it]

Loss: 0.6149530410766602


Epoch 3:  47%|████▋     | 28/59 [00:54<01:00,  1.94s/it]

Loss: 0.65263831615448


Epoch 3:  49%|████▉     | 29/59 [00:56<00:58,  1.94s/it]

Loss: 0.09730581939220428


Epoch 3:  51%|█████     | 30/59 [00:58<00:56,  1.94s/it]

Loss: 0.37651583552360535


Epoch 3:  53%|█████▎    | 31/59 [01:00<00:54,  1.94s/it]

Loss: 0.5384077429771423


Epoch 3:  54%|█████▍    | 32/59 [01:02<00:52,  1.94s/it]

Loss: 0.31019917130470276


Epoch 3:  56%|█████▌    | 33/59 [01:04<00:50,  1.94s/it]

Loss: 0.4155117869377136


Epoch 3:  58%|█████▊    | 34/59 [01:06<00:48,  1.94s/it]

Loss: 0.6387288570404053


Epoch 3:  59%|█████▉    | 35/59 [01:08<00:46,  1.94s/it]

Loss: 1.0652813911437988


Epoch 3:  61%|██████    | 36/59 [01:10<00:44,  1.94s/it]

Loss: 0.23239541053771973


Epoch 3:  63%|██████▎   | 37/59 [01:11<00:42,  1.94s/it]

Loss: 0.3696274757385254


Epoch 3:  64%|██████▍   | 38/59 [01:13<00:40,  1.94s/it]

Loss: 0.23659372329711914


Epoch 3:  66%|██████▌   | 39/59 [01:15<00:38,  1.94s/it]

Loss: 0.8120132088661194


Epoch 3:  68%|██████▊   | 40/59 [01:17<00:36,  1.94s/it]

Loss: 0.5830224752426147


Epoch 3:  69%|██████▉   | 41/59 [01:19<00:35,  1.94s/it]

Loss: 0.6377514004707336


Epoch 3:  71%|███████   | 42/59 [01:21<00:33,  1.94s/it]

Loss: 0.37181320786476135


Epoch 3:  73%|███████▎  | 43/59 [01:23<00:31,  1.95s/it]

Loss: 0.12949544191360474


Epoch 3:  75%|███████▍  | 44/59 [01:25<00:29,  1.95s/it]

Loss: 0.3660934865474701


Epoch 3:  76%|███████▋  | 45/59 [01:27<00:27,  1.94s/it]

Loss: 0.9465268850326538


Epoch 3:  78%|███████▊  | 46/59 [01:29<00:25,  1.94s/it]

Loss: 0.248117133975029


Epoch 3:  80%|███████▉  | 47/59 [01:31<00:23,  1.95s/it]

Loss: 0.5992376804351807


Epoch 3:  81%|████████▏ | 48/59 [01:33<00:21,  1.95s/it]

Loss: 1.3011150360107422


Epoch 3:  83%|████████▎ | 49/59 [01:35<00:19,  1.95s/it]

Loss: 0.6989749670028687


Epoch 3:  85%|████████▍ | 50/59 [01:37<00:17,  1.95s/it]

Loss: 1.2760581970214844


Epoch 3:  86%|████████▋ | 51/59 [01:39<00:15,  1.95s/it]

Loss: 0.6811465620994568


Epoch 3:  88%|████████▊ | 52/59 [01:41<00:13,  1.95s/it]

Loss: 1.0973584651947021


Epoch 3:  90%|████████▉ | 53/59 [01:43<00:11,  1.95s/it]

Loss: 0.1056610494852066


Epoch 3:  92%|█████████▏| 54/59 [01:45<00:09,  1.94s/it]

Loss: 0.35483697056770325


Epoch 3:  93%|█████████▎| 55/59 [01:46<00:07,  1.94s/it]

Loss: 0.5081303119659424


Epoch 3:  95%|█████████▍| 56/59 [01:48<00:05,  1.94s/it]

Loss: 0.20127514004707336


Epoch 3:  97%|█████████▋| 57/59 [01:50<00:03,  1.95s/it]

Loss: 0.8014360070228577


Epoch 3: 100%|██████████| 59/59 [01:52<00:00,  1.91s/it]

Loss: 0.45821094512939453
Loss: 0.0004870789125561714





Epoch 3 Validation Accuracy: 0.9211822660098522, F1-macro: 0.615530303030303


Epoch 4:   2%|▏         | 1/59 [00:01<01:51,  1.93s/it]

Loss: 0.399586021900177


Epoch 4:   3%|▎         | 2/59 [00:03<01:50,  1.94s/it]

Loss: 0.35512033104896545


Epoch 4:   5%|▌         | 3/59 [00:05<01:48,  1.94s/it]

Loss: 0.4392905831336975


Epoch 4:   7%|▋         | 4/59 [00:07<01:46,  1.94s/it]

Loss: 0.12708118557929993


Epoch 4:   8%|▊         | 5/59 [00:09<01:44,  1.94s/it]

Loss: 0.35331544280052185


Epoch 4:  10%|█         | 6/59 [00:11<01:43,  1.94s/it]

Loss: 0.44172537326812744


Epoch 4:  12%|█▏        | 7/59 [00:13<01:41,  1.94s/it]

Loss: 0.09911207109689713


Epoch 4:  14%|█▎        | 8/59 [00:15<01:39,  1.94s/it]

Loss: 0.4931345582008362


Epoch 4:  15%|█▌        | 9/59 [00:17<01:37,  1.94s/it]

Loss: 0.05305665358901024


Epoch 4:  17%|█▋        | 10/59 [00:19<01:35,  1.94s/it]

Loss: 0.21806049346923828


Epoch 4:  19%|█▊        | 11/59 [00:21<01:33,  1.94s/it]

Loss: 0.2621617913246155


Epoch 4:  20%|██        | 12/59 [00:23<01:31,  1.94s/it]

Loss: 0.8710517883300781


Epoch 4:  22%|██▏       | 13/59 [00:25<01:29,  1.94s/it]

Loss: 0.7102906703948975


Epoch 4:  24%|██▎       | 14/59 [00:27<01:27,  1.94s/it]

Loss: 0.11253421753644943


Epoch 4:  25%|██▌       | 15/59 [00:29<01:25,  1.94s/it]

Loss: 0.4232189357280731


Epoch 4:  27%|██▋       | 16/59 [00:31<01:23,  1.95s/it]

Loss: 0.4030526876449585


Epoch 4:  29%|██▉       | 17/59 [00:33<01:21,  1.95s/it]

Loss: 0.7245128750801086


Epoch 4:  31%|███       | 18/59 [00:34<01:19,  1.95s/it]

Loss: 0.27263450622558594


Epoch 4:  32%|███▏      | 19/59 [00:36<01:17,  1.95s/it]

Loss: 0.4762440621852875


Epoch 4:  34%|███▍      | 20/59 [00:38<01:15,  1.94s/it]

Loss: 0.20457874238491058


Epoch 4:  36%|███▌      | 21/59 [00:40<01:13,  1.94s/it]

Loss: 0.840288519859314


Epoch 4:  37%|███▋      | 22/59 [00:42<01:11,  1.94s/it]

Loss: 0.6568537950515747


Epoch 4:  39%|███▉      | 23/59 [00:44<01:09,  1.94s/it]

Loss: 0.46049773693084717


Epoch 4:  41%|████      | 24/59 [00:46<01:07,  1.94s/it]

Loss: 0.7392423152923584


Epoch 4:  42%|████▏     | 25/59 [00:48<01:06,  1.94s/it]

Loss: 0.1404399573802948


Epoch 4:  44%|████▍     | 26/59 [00:50<01:04,  1.94s/it]

Loss: 1.1550415754318237


Epoch 4:  46%|████▌     | 27/59 [00:52<01:02,  1.94s/it]

Loss: 1.3286491632461548


Epoch 4:  47%|████▋     | 28/59 [00:54<01:00,  1.94s/it]

Loss: 1.1006300449371338


Epoch 4:  49%|████▉     | 29/59 [00:56<00:58,  1.94s/it]

Loss: 1.9213365316390991


Epoch 4:  51%|█████     | 30/59 [00:58<00:56,  1.94s/it]

Loss: 0.5559866428375244


Epoch 4:  53%|█████▎    | 31/59 [01:00<00:54,  1.94s/it]

Loss: 0.3980799615383148


Epoch 4:  54%|█████▍    | 32/59 [01:02<00:52,  1.94s/it]

Loss: 0.47203493118286133


Epoch 4:  56%|█████▌    | 33/59 [01:04<00:50,  1.94s/it]

Loss: 1.140243649482727


Epoch 4:  58%|█████▊    | 34/59 [01:06<00:48,  1.94s/it]

Loss: 0.545653760433197


Epoch 4:  59%|█████▉    | 35/59 [01:08<00:46,  1.94s/it]

Loss: 0.5173066854476929


Epoch 4:  61%|██████    | 36/59 [01:09<00:44,  1.94s/it]

Loss: 0.8858671188354492


Epoch 4:  63%|██████▎   | 37/59 [01:11<00:42,  1.94s/it]

Loss: 0.1973526030778885


Epoch 4:  64%|██████▍   | 38/59 [01:13<00:40,  1.94s/it]

Loss: 0.43113234639167786


Epoch 4:  66%|██████▌   | 39/59 [01:15<00:38,  1.94s/it]

Loss: 1.0365121364593506


Epoch 4:  68%|██████▊   | 40/59 [01:17<00:36,  1.94s/it]

Loss: 0.15480190515518188


Epoch 4:  69%|██████▉   | 41/59 [01:19<00:35,  1.94s/it]

Loss: 0.1120825856924057


Epoch 4:  71%|███████   | 42/59 [01:21<00:33,  1.94s/it]

Loss: 0.4749297797679901


Epoch 4:  73%|███████▎  | 43/59 [01:23<00:31,  1.94s/it]

Loss: 0.7651844024658203


Epoch 4:  75%|███████▍  | 44/59 [01:25<00:29,  1.94s/it]

Loss: 0.2302255630493164


Epoch 4:  76%|███████▋  | 45/59 [01:27<00:27,  1.94s/it]

Loss: 0.07761164009571075


Epoch 4:  78%|███████▊  | 46/59 [01:29<00:25,  1.94s/it]

Loss: 0.2139863222837448


Epoch 4:  80%|███████▉  | 47/59 [01:31<00:23,  1.95s/it]

Loss: 1.21207594871521


Epoch 4:  81%|████████▏ | 48/59 [01:33<00:21,  1.94s/it]

Loss: 0.38224950432777405


Epoch 4:  83%|████████▎ | 49/59 [01:35<00:19,  1.94s/it]

Loss: 0.9447265267372131


Epoch 4:  85%|████████▍ | 50/59 [01:37<00:17,  1.94s/it]

Loss: 0.647728443145752


Epoch 4:  86%|████████▋ | 51/59 [01:39<00:15,  1.95s/it]

Loss: 0.5118680000305176


Epoch 4:  88%|████████▊ | 52/59 [01:41<00:13,  1.94s/it]

Loss: 0.4084559679031372


Epoch 4:  90%|████████▉ | 53/59 [01:43<00:11,  1.94s/it]

Loss: 0.38089293241500854


Epoch 4:  92%|█████████▏| 54/59 [01:44<00:09,  1.95s/it]

Loss: 0.8745127320289612


Epoch 4:  93%|█████████▎| 55/59 [01:46<00:07,  1.95s/it]

Loss: 0.18294657766819


Epoch 4:  95%|█████████▍| 56/59 [01:48<00:05,  1.95s/it]

Loss: 0.17414475977420807


Epoch 4:  97%|█████████▋| 57/59 [01:50<00:03,  1.95s/it]

Loss: 0.416978120803833


Epoch 4: 100%|██████████| 59/59 [01:52<00:00,  1.91s/it]

Loss: 0.25120458006858826
Loss: 0.2404792606830597





Epoch 4 Validation Accuracy: 0.9113300492610837, F1-macro: 0.47680412371134023


Epoch 5:   2%|▏         | 1/59 [00:01<01:51,  1.93s/it]

Loss: 0.2729767858982086


Epoch 5:   3%|▎         | 2/59 [00:03<01:50,  1.94s/it]

Loss: 0.705920398235321


Epoch 5:   5%|▌         | 3/59 [00:05<01:48,  1.94s/it]

Loss: 0.9028099179267883


Epoch 5:   7%|▋         | 4/59 [00:07<01:46,  1.94s/it]

Loss: 1.3321752548217773


Epoch 5:   8%|▊         | 5/59 [00:09<01:44,  1.94s/it]

Loss: 0.5625031590461731


Epoch 5:  10%|█         | 6/59 [00:11<01:42,  1.94s/it]

Loss: 0.28627216815948486


Epoch 5:  12%|█▏        | 7/59 [00:13<01:41,  1.94s/it]

Loss: 1.050990104675293


Epoch 5:  14%|█▎        | 8/59 [00:15<01:39,  1.94s/it]

Loss: 0.6566342115402222


Epoch 5:  15%|█▌        | 9/59 [00:17<01:37,  1.94s/it]

Loss: 0.4607877731323242


Epoch 5:  17%|█▋        | 10/59 [00:19<01:35,  1.94s/it]

Loss: 0.5077313184738159


Epoch 5:  19%|█▊        | 11/59 [00:21<01:33,  1.94s/it]

Loss: 0.3850666284561157


Epoch 5:  20%|██        | 12/59 [00:23<01:31,  1.94s/it]

Loss: 0.6403471827507019


Epoch 5:  22%|██▏       | 13/59 [00:25<01:29,  1.94s/it]

Loss: 0.17886728048324585


Epoch 5:  24%|██▎       | 14/59 [00:27<01:27,  1.94s/it]

Loss: 0.5929156541824341


Epoch 5:  25%|██▌       | 15/59 [00:29<01:25,  1.94s/it]

Loss: 0.451763391494751


Epoch 5:  27%|██▋       | 16/59 [00:31<01:23,  1.94s/it]

Loss: 0.5730876922607422


Epoch 5:  29%|██▉       | 17/59 [00:33<01:21,  1.94s/it]

Loss: 0.5535061955451965


Epoch 5:  31%|███       | 18/59 [00:34<01:19,  1.94s/it]

Loss: 0.8951019048690796


Epoch 5:  32%|███▏      | 19/59 [00:36<01:17,  1.94s/it]

Loss: 1.050670862197876


Epoch 5:  34%|███▍      | 20/59 [00:38<01:15,  1.94s/it]

Loss: 0.6662925481796265


Epoch 5:  36%|███▌      | 21/59 [00:40<01:13,  1.94s/it]

Loss: 0.46668335795402527


Epoch 5:  37%|███▋      | 22/59 [00:42<01:11,  1.94s/it]

Loss: 0.11270197480916977


Epoch 5:  39%|███▉      | 23/59 [00:44<01:10,  1.94s/it]

Loss: 0.6644558310508728


Epoch 5:  41%|████      | 24/59 [00:46<01:08,  1.95s/it]

Loss: 0.17387518286705017


Epoch 5:  42%|████▏     | 25/59 [00:48<01:06,  1.95s/it]

Loss: 0.5061209201812744


Epoch 5:  44%|████▍     | 26/59 [00:50<01:04,  1.95s/it]

Loss: 0.5017539858818054


Epoch 5:  46%|████▌     | 27/59 [00:52<01:02,  1.95s/it]

Loss: 0.23360390961170197


Epoch 5:  47%|████▋     | 28/59 [00:54<01:00,  1.95s/it]

Loss: 0.7012379169464111


Epoch 5:  49%|████▉     | 29/59 [00:56<00:58,  1.95s/it]

Loss: 0.13305915892124176


Epoch 5:  51%|█████     | 30/59 [00:58<00:56,  1.95s/it]

Loss: 0.31761497259140015


Epoch 5:  53%|█████▎    | 31/59 [01:00<00:54,  1.95s/it]

Loss: 0.5104542970657349


Epoch 5:  54%|█████▍    | 32/59 [01:02<00:52,  1.95s/it]

Loss: 0.5898494720458984


Epoch 5:  56%|█████▌    | 33/59 [01:04<00:50,  1.95s/it]

Loss: 0.683899998664856


Epoch 5:  58%|█████▊    | 34/59 [01:06<00:48,  1.95s/it]

Loss: 0.7973830699920654


Epoch 5:  59%|█████▉    | 35/59 [01:08<00:46,  1.95s/it]

Loss: 1.0408462285995483


Epoch 5:  61%|██████    | 36/59 [01:10<00:44,  1.94s/it]

Loss: 0.32307836413383484


Epoch 5:  63%|██████▎   | 37/59 [01:11<00:42,  1.94s/it]

Loss: 0.1882306933403015


Epoch 5:  64%|██████▍   | 38/59 [01:13<00:40,  1.94s/it]

Loss: 0.42474231123924255


Epoch 5:  66%|██████▌   | 39/59 [01:15<00:38,  1.94s/it]

Loss: 0.716934323310852


Epoch 5:  68%|██████▊   | 40/59 [01:17<00:36,  1.94s/it]

Loss: 0.5599002838134766


Epoch 5:  69%|██████▉   | 41/59 [01:19<00:35,  1.94s/it]

Loss: 0.8495256900787354


Epoch 5:  71%|███████   | 42/59 [01:21<00:33,  1.94s/it]

Loss: 0.4953065514564514


Epoch 5:  73%|███████▎  | 43/59 [01:23<00:31,  1.94s/it]

Loss: 0.2093251496553421


Epoch 5:  75%|███████▍  | 44/59 [01:25<00:29,  1.94s/it]

Loss: 0.74296635389328


Epoch 5:  76%|███████▋  | 45/59 [01:27<00:27,  1.94s/it]

Loss: 0.18136577308177948


Epoch 5:  78%|███████▊  | 46/59 [01:29<00:25,  1.94s/it]

Loss: 0.341116726398468


Epoch 5:  80%|███████▉  | 47/59 [01:31<00:23,  1.94s/it]

Loss: 0.2920476198196411


Epoch 5:  81%|████████▏ | 48/59 [01:33<00:21,  1.94s/it]

Loss: 0.48443296551704407


Epoch 5:  83%|████████▎ | 49/59 [01:35<00:19,  1.94s/it]

Loss: 0.4345252513885498


Epoch 5:  85%|████████▍ | 50/59 [01:37<00:17,  1.94s/it]

Loss: 0.40029531717300415


Epoch 5:  86%|████████▋ | 51/59 [01:39<00:15,  1.94s/it]

Loss: 0.5077462196350098


Epoch 5:  88%|████████▊ | 52/59 [01:41<00:13,  1.94s/it]

Loss: 0.3341512084007263


Epoch 5:  90%|████████▉ | 53/59 [01:43<00:11,  1.94s/it]

Loss: 0.4916689991950989


Epoch 5:  92%|█████████▏| 54/59 [01:45<00:09,  1.94s/it]

Loss: 0.5519655346870422


Epoch 5:  93%|█████████▎| 55/59 [01:46<00:07,  1.94s/it]

Loss: 0.7705011367797852


Epoch 5:  95%|█████████▍| 56/59 [01:48<00:05,  1.94s/it]

Loss: 0.19976504147052765


Epoch 5:  97%|█████████▋| 57/59 [01:50<00:03,  1.94s/it]

Loss: 0.34169572591781616


Epoch 5: 100%|██████████| 59/59 [01:52<00:00,  1.91s/it]

Loss: 0.267441987991333
Loss: 1.5864454507827759





Epoch 5 Validation Accuracy: 0.37438423645320196, F1-macro: 0.3435439105746951


Epoch 6:   2%|▏         | 1/59 [00:01<01:52,  1.93s/it]

Loss: 0.923936665058136


Epoch 6:   3%|▎         | 2/59 [00:03<01:50,  1.94s/it]

Loss: 1.1720353364944458


Epoch 6:   5%|▌         | 3/59 [00:05<01:48,  1.94s/it]

Loss: 0.4986383318901062


Epoch 6:   7%|▋         | 4/59 [00:07<01:46,  1.94s/it]

Loss: 0.199370875954628


Epoch 6:   8%|▊         | 5/59 [00:09<01:45,  1.95s/it]

Loss: 1.4781020879745483


Epoch 6:  10%|█         | 6/59 [00:11<01:43,  1.95s/it]

Loss: 0.5510052442550659


Epoch 6:  12%|█▏        | 7/59 [00:13<01:41,  1.95s/it]

Loss: 1.4479721784591675


Epoch 6:  14%|█▎        | 8/59 [00:15<01:39,  1.94s/it]

Loss: 1.5941067934036255


Epoch 6:  15%|█▌        | 9/59 [00:17<01:37,  1.94s/it]

Loss: 1.8389976024627686


Epoch 6:  17%|█▋        | 10/59 [00:19<01:35,  1.94s/it]

Loss: 1.2530131340026855


Epoch 6:  19%|█▊        | 11/59 [00:21<01:33,  1.95s/it]

Loss: 0.5465686917304993


Epoch 6:  20%|██        | 12/59 [00:23<01:31,  1.94s/it]

Loss: 1.62810218334198


Epoch 6:  22%|██▏       | 13/59 [00:25<01:29,  1.95s/it]

Loss: 0.8853925466537476


Epoch 6:  24%|██▎       | 14/59 [00:27<01:27,  1.94s/it]

Loss: 1.2505385875701904


Epoch 6:  25%|██▌       | 15/59 [00:29<01:25,  1.94s/it]

Loss: 1.641676902770996


Epoch 6:  27%|██▋       | 16/59 [00:31<01:23,  1.95s/it]

Loss: 2.6942076683044434


Epoch 6:  29%|██▉       | 17/59 [00:33<01:21,  1.95s/it]

Loss: 0.5322080850601196


Epoch 6:  31%|███       | 18/59 [00:35<01:19,  1.95s/it]

Loss: 0.7584151029586792


Epoch 6:  32%|███▏      | 19/59 [00:36<01:17,  1.94s/it]

Loss: 1.3471958637237549


Epoch 6:  34%|███▍      | 20/59 [00:38<01:15,  1.94s/it]

Loss: 0.0003191365103702992


Epoch 6:  36%|███▌      | 21/59 [00:40<01:13,  1.94s/it]

Loss: 0.7363303899765015


Epoch 6:  37%|███▋      | 22/59 [00:42<01:11,  1.94s/it]

Loss: 0.22448617219924927


Epoch 6:  39%|███▉      | 23/59 [00:44<01:09,  1.94s/it]

Loss: 1.0852855443954468


Epoch 6:  41%|████      | 24/59 [00:46<01:08,  1.94s/it]

Loss: 0.8315441608428955


Epoch 6:  42%|████▏     | 25/59 [00:48<01:06,  1.94s/it]

Loss: 0.45306527614593506


Epoch 6:  44%|████▍     | 26/59 [00:50<01:04,  1.94s/it]

Loss: 0.8831719160079956


Epoch 6:  46%|████▌     | 27/59 [00:52<01:02,  1.94s/it]

Loss: 0.4236658811569214


Epoch 6:  47%|████▋     | 28/59 [00:54<01:00,  1.94s/it]

Loss: 0.3660748600959778


Epoch 6:  49%|████▉     | 29/59 [00:56<00:58,  1.94s/it]

Loss: 0.41833099722862244


Epoch 6:  51%|█████     | 30/59 [00:58<00:56,  1.94s/it]

Loss: 0.6173322796821594


Epoch 6:  53%|█████▎    | 31/59 [01:00<00:54,  1.94s/it]

Loss: 0.2404058426618576


Epoch 6:  54%|█████▍    | 32/59 [01:02<00:52,  1.94s/it]

Loss: 0.01869492046535015


Epoch 6:  56%|█████▌    | 33/59 [01:04<00:50,  1.95s/it]

Loss: 0.2863329350948334


Epoch 6:  58%|█████▊    | 34/59 [01:06<00:48,  1.95s/it]

Loss: 0.548003613948822


Epoch 6:  59%|█████▉    | 35/59 [01:08<00:46,  1.95s/it]

Loss: 0.3072357475757599


Epoch 6:  61%|██████    | 36/59 [01:10<00:44,  1.95s/it]

Loss: 0.2922150790691376


Epoch 6:  63%|██████▎   | 37/59 [01:11<00:42,  1.95s/it]

Loss: 0.18379312753677368


Epoch 6:  64%|██████▍   | 38/59 [01:13<00:40,  1.95s/it]

Loss: 0.2774859666824341


Epoch 6:  66%|██████▌   | 39/59 [01:15<00:38,  1.95s/it]

Loss: 0.37863895297050476


Epoch 6:  68%|██████▊   | 40/59 [01:17<00:37,  1.95s/it]

Loss: 0.47813746333122253


Epoch 6:  69%|██████▉   | 41/59 [01:19<00:35,  1.95s/it]

Loss: 0.41104960441589355


Epoch 6:  71%|███████   | 42/59 [01:21<00:33,  1.95s/it]

Loss: 0.13558447360992432


Epoch 6:  73%|███████▎  | 43/59 [01:23<00:31,  1.95s/it]

Loss: 0.01423755194991827


Epoch 6:  75%|███████▍  | 44/59 [01:25<00:29,  1.95s/it]

Loss: 1.0438950061798096


Epoch 6:  76%|███████▋  | 45/59 [01:27<00:27,  1.94s/it]

Loss: 1.924388197949156e-05


Epoch 6:  78%|███████▊  | 46/59 [01:29<00:25,  1.94s/it]

Loss: 0.5774123668670654


Epoch 6:  80%|███████▉  | 47/59 [01:31<00:23,  1.94s/it]

Loss: 0.4916823208332062


Epoch 6:  81%|████████▏ | 48/59 [01:33<00:21,  1.94s/it]

Loss: 0.34648048877716064


Epoch 6:  83%|████████▎ | 49/59 [01:35<00:19,  1.94s/it]

Loss: 0.013991400599479675


Epoch 6:  85%|████████▍ | 50/59 [01:37<00:17,  1.94s/it]

Loss: 0.32612013816833496


Epoch 6:  86%|████████▋ | 51/59 [01:39<00:15,  1.94s/it]

Loss: 0.7298940420150757


Epoch 6:  88%|████████▊ | 52/59 [01:41<00:13,  1.94s/it]

Loss: 0.11917449533939362


Epoch 6:  90%|████████▉ | 53/59 [01:43<00:11,  1.94s/it]

Loss: 0.1830943375825882


Epoch 6:  92%|█████████▏| 54/59 [01:45<00:09,  1.94s/it]

Loss: 0.5672485828399658


Epoch 6:  93%|█████████▎| 55/59 [01:46<00:07,  1.94s/it]

Loss: 0.6538280248641968


Epoch 6:  95%|█████████▍| 56/59 [01:48<00:05,  1.94s/it]

Loss: 0.8159271478652954


Epoch 6:  97%|█████████▋| 57/59 [01:50<00:03,  1.94s/it]

Loss: 0.01126077026128769


Epoch 6: 100%|██████████| 59/59 [01:52<00:00,  1.91s/it]

Loss: 0.9639294147491455
Loss: 0.04405852407217026





Epoch 6 Validation Accuracy: 0.6600985221674877, F1-macro: 0.5326794114703233


Epoch 7:   2%|▏         | 1/59 [00:01<01:52,  1.94s/it]

Loss: 0.30642133951187134


Epoch 7:   3%|▎         | 2/59 [00:03<01:50,  1.95s/it]

Loss: 0.9154235124588013


Epoch 7:   5%|▌         | 3/59 [00:05<01:48,  1.95s/it]

Loss: 0.444840669631958


Epoch 7:   7%|▋         | 4/59 [00:07<01:47,  1.95s/it]

Loss: 0.6291126012802124


Epoch 7:   8%|▊         | 5/59 [00:09<01:45,  1.95s/it]

Loss: 0.68336421251297


Epoch 7:  10%|█         | 6/59 [00:11<01:43,  1.95s/it]

Loss: 0.19966742396354675


Epoch 7:  12%|█▏        | 7/59 [00:13<01:41,  1.95s/it]

Loss: 0.7764373421669006


Epoch 7:  14%|█▎        | 8/59 [00:15<01:39,  1.95s/it]

Loss: 0.01308753713965416


Epoch 7:  15%|█▌        | 9/59 [00:17<01:37,  1.95s/it]

Loss: 0.24230290949344635


Epoch 7:  17%|█▋        | 10/59 [00:19<01:35,  1.95s/it]

Loss: 0.5396811962127686


Epoch 7:  19%|█▊        | 11/59 [00:21<01:33,  1.95s/it]

Loss: 0.4896252751350403


Epoch 7:  20%|██        | 12/59 [00:23<01:31,  1.95s/it]

Loss: 0.26128077507019043


Epoch 7:  22%|██▏       | 13/59 [00:25<01:29,  1.95s/it]

Loss: 0.7292418479919434


Epoch 7:  24%|██▎       | 14/59 [00:27<01:27,  1.95s/it]

Loss: 0.6473487615585327


Epoch 7:  25%|██▌       | 15/59 [00:29<01:25,  1.95s/it]

Loss: 0.07282998412847519


Epoch 7:  27%|██▋       | 16/59 [00:31<01:23,  1.95s/it]

Loss: 0.6076146960258484


Epoch 7:  29%|██▉       | 17/59 [00:33<01:21,  1.94s/it]

Loss: 0.20558200776576996


Epoch 7:  31%|███       | 18/59 [00:35<01:19,  1.95s/it]

Loss: 0.16116230189800262


Epoch 7:  32%|███▏      | 19/59 [00:36<01:17,  1.95s/it]

Loss: 0.2651943266391754


Epoch 7:  34%|███▍      | 20/59 [00:38<01:15,  1.95s/it]

Loss: 0.23243601620197296


Epoch 7:  36%|███▌      | 21/59 [00:40<01:13,  1.95s/it]

Loss: 0.3407917618751526


Epoch 7:  37%|███▋      | 22/59 [00:42<01:11,  1.95s/it]

Loss: 0.10912779718637466


Epoch 7:  39%|███▉      | 23/59 [00:44<01:10,  1.95s/it]

Loss: 0.661831259727478


Epoch 7:  41%|████      | 24/59 [00:46<01:08,  1.94s/it]

Loss: 0.054812654852867126


Epoch 7:  42%|████▏     | 25/59 [00:48<01:06,  1.94s/it]

Loss: 0.15837158262729645


Epoch 7:  44%|████▍     | 26/59 [00:50<01:04,  1.94s/it]

Loss: 0.46523821353912354


Epoch 7:  46%|████▌     | 27/59 [00:52<01:02,  1.94s/it]

Loss: 0.4262081980705261


Epoch 7:  47%|████▋     | 28/59 [00:54<01:00,  1.94s/it]

Loss: 0.7026607394218445


Epoch 7:  49%|████▉     | 29/59 [00:56<00:58,  1.94s/it]

Loss: 0.5308095216751099


Epoch 7:  51%|█████     | 30/59 [00:58<00:56,  1.94s/it]

Loss: 0.3464794158935547


Epoch 7:  53%|█████▎    | 31/59 [01:00<00:54,  1.94s/it]

Loss: 0.4854680895805359


Epoch 7:  54%|█████▍    | 32/59 [01:02<00:52,  1.94s/it]

Loss: 0.6691405177116394


Epoch 7:  56%|█████▌    | 33/59 [01:04<00:50,  1.94s/it]

Loss: 0.2997414767742157


Epoch 7:  58%|█████▊    | 34/59 [01:06<00:48,  1.94s/it]

Loss: 0.5936816930770874


Epoch 7:  59%|█████▉    | 35/59 [01:08<00:46,  1.94s/it]

Loss: 0.19212065637111664


Epoch 7:  61%|██████    | 36/59 [01:10<00:44,  1.94s/it]

Loss: 0.6070382595062256


Epoch 7:  63%|██████▎   | 37/59 [01:11<00:42,  1.94s/it]

Loss: 0.19252482056617737


Epoch 7:  64%|██████▍   | 38/59 [01:13<00:40,  1.94s/it]

Loss: 0.18448232114315033


Epoch 7:  66%|██████▌   | 39/59 [01:15<00:38,  1.94s/it]

Loss: 0.2625691294670105


Epoch 7:  68%|██████▊   | 40/59 [01:17<00:36,  1.94s/it]

Loss: 0.9740912914276123


Epoch 7:  69%|██████▉   | 41/59 [01:19<00:35,  1.95s/it]

Loss: 0.4748140573501587


Epoch 7:  71%|███████   | 42/59 [01:21<00:33,  1.95s/it]

Loss: 0.0008778877090662718


Epoch 7:  73%|███████▎  | 43/59 [01:23<00:31,  1.95s/it]

Loss: 0.32400697469711304


Epoch 7:  75%|███████▍  | 44/59 [01:25<00:29,  1.95s/it]

Loss: 0.15403135120868683


Epoch 7:  76%|███████▋  | 45/59 [01:27<00:27,  1.95s/it]

Loss: 0.07376908510923386


Epoch 7:  78%|███████▊  | 46/59 [01:29<00:25,  1.95s/it]

Loss: 0.47043150663375854


Epoch 7:  80%|███████▉  | 47/59 [01:31<00:23,  1.95s/it]

Loss: 0.8167802691459656


Epoch 7:  81%|████████▏ | 48/59 [01:33<00:21,  1.95s/it]

Loss: 0.1665925830602646


Epoch 7:  83%|████████▎ | 49/59 [01:35<00:19,  1.95s/it]

Loss: 0.03287206590175629


Epoch 7:  85%|████████▍ | 50/59 [01:37<00:17,  1.95s/it]

Loss: 0.7218868136405945


Epoch 7:  86%|████████▋ | 51/59 [01:39<00:15,  1.95s/it]

Loss: 0.11198055744171143


Epoch 7:  88%|████████▊ | 52/59 [01:41<00:13,  1.94s/it]

Loss: 0.9970483779907227


Epoch 7:  90%|████████▉ | 53/59 [01:43<00:11,  1.94s/it]

Loss: 0.2967126667499542


Epoch 7:  92%|█████████▏| 54/59 [01:45<00:09,  1.94s/it]

Loss: 0.2849569320678711


Epoch 7:  93%|█████████▎| 55/59 [01:46<00:07,  1.94s/it]

Loss: 0.09543262422084808


Epoch 7:  95%|█████████▍| 56/59 [01:48<00:05,  1.94s/it]

Loss: 0.4370007812976837


Epoch 7:  97%|█████████▋| 57/59 [01:50<00:03,  1.94s/it]

Loss: 0.26722967624664307


Epoch 7: 100%|██████████| 59/59 [01:52<00:00,  1.91s/it]

Loss: 0.3195909261703491
Loss: 6.556505240951083e-07





Epoch 7 Validation Accuracy: 0.9113300492610837, F1-macro: 0.47680412371134023


In [None]:
# Evaluate on the test set
test_metrics = evaluate(model11, test_dataloader, device)
print(f"Test Accuracy: {test_metrics['accuracy']}, F1-macro: {test_metrics['f1_macro']}")

Test Accuracy: 0.9313725490196079, F1-macro: 0.5446428571428572


In [None]:
current_type = 'mental filter'

results_df.loc[
    results_df["distortion_type"] == current_type,
    ["test_accuracy", "f1_macro"]
] = [test_metrics['accuracy'], test_metrics['f1_macro']]

# 12. Personalization and Blaming

In [None]:
# Add labels
data1_1_labels = list(data1['personalization and blaming'][data1_1.index])
data3_1_labels = list(data3['personalization and blaming'][data3_1.index]) # data2에는 없음.

# Merging Data
data_encoded = data1_1_encoded + data3_1_encoded
data_labels = data1_1_labels + data3_1_labels

dataset_with_labels = CustomDatasetWithLabels(data_encoded, data_labels)

# Define proportions for splitting
train_size = int(0.8 * len(dataset_with_labels))
val_size = int(0.1 * len(dataset_with_labels))
test_size = len(dataset_with_labels) - train_size - val_size

# Split the dataset
train_dataset, val_dataset, test_dataset = random_split(dataset_with_labels, [train_size, val_size, test_size])

In [None]:
dataset_encoded = dataset_with_labels.data
dataset_labels = dataset_with_labels.labels

# Convert list of BatchEncoding objects to a 2D numpy array for SMOTE features
feature_vectors = []
max_len = 512 # Assuming this from tokenize_and_pad function
for encoded_dict in dataset_encoded:
    input_ids_np = encoded_dict['input_ids'].squeeze().cpu().numpy()
    attention_mask_np = encoded_dict['attention_mask'].squeeze().cpu().numpy()
    # Concatenate input_ids and attention_mask to create a single feature vector per sample
    feature_vectors.append(np.concatenate([input_ids_np, attention_mask_np]))

# 'temp1' becomes the feature matrix, 'temp2' becomes the labels vector
temp1 = np.array(feature_vectors)
temp2 = np.array(dataset_labels)

# Apply SMOTE
temp1_resampled, temp2_resampled = smote.fit_resample(temp1, temp2)

# Reconstruct data_encoded and data_labels from the resampled data
reconstructed_data_encoded = []
for resampled_features in temp1_resampled:
    # Split the combined feature vector back into input_ids and attention_mask parts
    reconstructed_input_ids = resampled_features[:max_len]
    reconstructed_attention_mask = resampled_features[max_len:]

    # Convert back to torch tensors and a dictionary format compatible with CustomDatasetWithLabels
    input_ids_tensor = torch.tensor(reconstructed_input_ids, dtype=torch.long).unsqueeze(0)
    attention_mask_tensor = torch.tensor(reconstructed_attention_mask, dtype=torch.long).unsqueeze(0)

    reconstructed_data_encoded.append({
        'input_ids': input_ids_tensor,
        'attention_mask': attention_mask_tensor
    })

# Update data_encoded and data_labels with the resampled data
dataset_encoded = reconstructed_data_encoded
dataset_labels = temp2_resampled.tolist()

train_dataset = CustomDatasetWithLabels(dataset_encoded, dataset_labels)

In [None]:
# Create DataLoaders for each set
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=64)
test_dataloader = DataLoader(test_dataset, batch_size=64)

input_dim = bert_config.hidden_size
num_labels = len(set(data_labels))
label_emb = torch.randn(num_labels, input_dim)

In [None]:
# Instantiate the InnerProductClassifier model
model12 = InnerProductClassifier(input_dim, label_emb).to(device)

# Define loss function
criterion = nn.CrossEntropyLoss()

# Define optimizer
optimizer = torch.optim.AdamW(model12.parameters(), lr=2e-4)

In [None]:
EPOCHS = 7

# Move the BERT model to the device
bert_model.to(device)

for epoch in range(EPOCHS):
    model12.train()
    for batch in tqdm(train_dataloader, desc=f"Epoch {epoch+1}"):

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        y = batch["y"].to(device)

        # Get embeddings from the BERT model
        with torch.no_grad():
            outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask)
            embeddings = outputs.last_hidden_state[:, 0, :] # Get the [CLS] token embedding

        # Pass embeddings to the InnerProductClassifier
        logits = model12(embeddings)
        loss = criterion(logits, y)

        # Backpropagate and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Loss: {loss.item()}")

    # Evaluate on the validation set after each epoch
    val_metrics = evaluate(model12, val_dataloader, device)
    print(f"Epoch {epoch+1} Validation Accuracy: {val_metrics['accuracy']}, F1-macro: {val_metrics['f1_macro']}")

Epoch 1:   9%|▉         | 1/11 [00:01<00:19,  1.91s/it]

Loss: 5.8179216384887695


Epoch 1:  18%|█▊        | 2/11 [00:03<00:17,  1.93s/it]

Loss: 1.3913652896881104


Epoch 1:  27%|██▋       | 3/11 [00:05<00:15,  1.94s/it]

Loss: 3.5879926681518555


Epoch 1:  36%|███▋      | 4/11 [00:07<00:13,  1.94s/it]

Loss: 1.5363249778747559


Epoch 1:  45%|████▌     | 5/11 [00:09<00:11,  1.94s/it]

Loss: 2.613809108734131


Epoch 1:  55%|█████▍    | 6/11 [00:11<00:09,  1.94s/it]

Loss: 1.1591774225234985


Epoch 1:  64%|██████▎   | 7/11 [00:13<00:07,  1.94s/it]

Loss: 2.605848789215088


Epoch 1:  73%|███████▎  | 8/11 [00:15<00:05,  1.94s/it]

Loss: 1.7183722257614136


Epoch 1:  82%|████████▏ | 9/11 [00:17<00:03,  1.94s/it]

Loss: 1.7972359657287598


Epoch 1:  91%|█████████ | 10/11 [00:19<00:01,  1.94s/it]

Loss: 1.221584439277649


Epoch 1: 100%|██████████| 11/11 [00:21<00:00,  1.91s/it]

Loss: 2.7589609622955322





Epoch 1 Validation Accuracy: 0.7441860465116279, F1-macro: 0.5015806111696522


Epoch 2:   9%|▉         | 1/11 [00:01<00:19,  1.93s/it]

Loss: 1.5927467346191406


Epoch 2:  18%|█▊        | 2/11 [00:03<00:17,  1.94s/it]

Loss: 0.8831673860549927


Epoch 2:  27%|██▋       | 3/11 [00:05<00:15,  1.94s/it]

Loss: 1.7355281114578247


Epoch 2:  36%|███▋      | 4/11 [00:07<00:13,  1.94s/it]

Loss: 0.7904282808303833


Epoch 2:  45%|████▌     | 5/11 [00:09<00:11,  1.94s/it]

Loss: 1.021864414215088


Epoch 2:  55%|█████▍    | 6/11 [00:11<00:09,  1.94s/it]

Loss: 0.7564519047737122


Epoch 2:  64%|██████▎   | 7/11 [00:13<00:07,  1.94s/it]

Loss: 0.909855842590332


Epoch 2:  73%|███████▎  | 8/11 [00:15<00:05,  1.94s/it]

Loss: 1.908266305923462


Epoch 2:  82%|████████▏ | 9/11 [00:17<00:03,  1.94s/it]

Loss: 0.8786427974700928


Epoch 2:  91%|█████████ | 10/11 [00:19<00:01,  1.94s/it]

Loss: 1.2478625774383545


Epoch 2: 100%|██████████| 11/11 [00:21<00:00,  1.91s/it]

Loss: 0.8830359578132629





Epoch 2 Validation Accuracy: 0.7906976744186046, F1-macro: 0.44155844155844154


Epoch 3:   9%|▉         | 1/11 [00:01<00:19,  1.93s/it]

Loss: 1.2978553771972656


Epoch 3:  18%|█▊        | 2/11 [00:03<00:17,  1.94s/it]

Loss: 0.9733742475509644


Epoch 3:  27%|██▋       | 3/11 [00:05<00:15,  1.94s/it]

Loss: 0.38740283250808716


Epoch 3:  36%|███▋      | 4/11 [00:07<00:13,  1.94s/it]

Loss: 1.253243088722229


Epoch 3:  45%|████▌     | 5/11 [00:09<00:11,  1.94s/it]

Loss: 0.623660683631897


Epoch 3:  55%|█████▍    | 6/11 [00:11<00:09,  1.94s/it]

Loss: 0.5902438759803772


Epoch 3:  64%|██████▎   | 7/11 [00:13<00:07,  1.94s/it]

Loss: 0.804965615272522


Epoch 3:  73%|███████▎  | 8/11 [00:15<00:05,  1.94s/it]

Loss: 0.8475781679153442


Epoch 3:  82%|████████▏ | 9/11 [00:17<00:03,  1.94s/it]

Loss: 0.6503203511238098


Epoch 3:  91%|█████████ | 10/11 [00:19<00:01,  1.94s/it]

Loss: 0.6979886889457703


Epoch 3: 100%|██████████| 11/11 [00:20<00:00,  1.91s/it]

Loss: 0.816682755947113





Epoch 3 Validation Accuracy: 0.7674418604651163, F1-macro: 0.6160714285714286


Epoch 4:   9%|▉         | 1/11 [00:01<00:19,  1.92s/it]

Loss: 0.4949750304222107


Epoch 4:  18%|█▊        | 2/11 [00:03<00:17,  1.93s/it]

Loss: 0.296988844871521


Epoch 4:  27%|██▋       | 3/11 [00:05<00:15,  1.93s/it]

Loss: 0.8189530968666077


Epoch 4:  36%|███▋      | 4/11 [00:07<00:13,  1.94s/it]

Loss: 1.168360948562622


Epoch 4:  45%|████▌     | 5/11 [00:09<00:11,  1.93s/it]

Loss: 0.36656713485717773


Epoch 4:  55%|█████▍    | 6/11 [00:11<00:09,  1.93s/it]

Loss: 0.7062708139419556


Epoch 4:  64%|██████▎   | 7/11 [00:13<00:07,  1.93s/it]

Loss: 0.43459558486938477


Epoch 4:  73%|███████▎  | 8/11 [00:15<00:05,  1.93s/it]

Loss: 0.4632538855075836


Epoch 4:  82%|████████▏ | 9/11 [00:17<00:03,  1.94s/it]

Loss: 0.3732840418815613


Epoch 4:  91%|█████████ | 10/11 [00:19<00:01,  1.94s/it]

Loss: 0.18211200833320618


Epoch 4: 100%|██████████| 11/11 [00:20<00:00,  1.90s/it]

Loss: 0.24107906222343445





Epoch 4 Validation Accuracy: 0.7674418604651163, F1-macro: 0.6160714285714286


Epoch 5:   9%|▉         | 1/11 [00:01<00:19,  1.93s/it]

Loss: 0.5681658387184143


Epoch 5:  18%|█▊        | 2/11 [00:03<00:17,  1.94s/it]

Loss: 0.49522051215171814


Epoch 5:  27%|██▋       | 3/11 [00:05<00:15,  1.94s/it]

Loss: 0.19257867336273193


Epoch 5:  36%|███▋      | 4/11 [00:07<00:13,  1.94s/it]

Loss: 0.18773972988128662


Epoch 5:  45%|████▌     | 5/11 [00:09<00:11,  1.94s/it]

Loss: 0.4481649696826935


Epoch 5:  55%|█████▍    | 6/11 [00:11<00:09,  1.94s/it]

Loss: 0.27545878291130066


Epoch 5:  64%|██████▎   | 7/11 [00:13<00:07,  1.94s/it]

Loss: 0.3098902404308319


Epoch 5:  73%|███████▎  | 8/11 [00:15<00:05,  1.94s/it]

Loss: 0.3561723828315735


Epoch 5:  82%|████████▏ | 9/11 [00:17<00:03,  1.94s/it]

Loss: 0.5099319815635681


Epoch 5:  91%|█████████ | 10/11 [00:19<00:01,  1.94s/it]

Loss: 0.2757761478424072


Epoch 5: 100%|██████████| 11/11 [00:21<00:00,  1.91s/it]

Loss: 0.08605227619409561





Epoch 5 Validation Accuracy: 0.8604651162790697, F1-macro: 0.5865384615384616


Epoch 6:   9%|▉         | 1/11 [00:01<00:19,  1.93s/it]

Loss: 0.2175627052783966


Epoch 6:  18%|█▊        | 2/11 [00:03<00:17,  1.94s/it]

Loss: 0.5893776416778564


Epoch 6:  27%|██▋       | 3/11 [00:05<00:15,  1.94s/it]

Loss: 0.5233111381530762


Epoch 6:  36%|███▋      | 4/11 [00:07<00:13,  1.94s/it]

Loss: 0.2769496440887451


Epoch 6:  45%|████▌     | 5/11 [00:09<00:11,  1.94s/it]

Loss: 0.20226310193538666


Epoch 6:  55%|█████▍    | 6/11 [00:11<00:09,  1.94s/it]

Loss: 0.4792367219924927


Epoch 6:  64%|██████▎   | 7/11 [00:13<00:07,  1.94s/it]

Loss: 0.6520830392837524


Epoch 6:  73%|███████▎  | 8/11 [00:15<00:05,  1.94s/it]

Loss: 0.19100430607795715


Epoch 6:  82%|████████▏ | 9/11 [00:17<00:03,  1.94s/it]

Loss: 0.23042672872543335


Epoch 6:  91%|█████████ | 10/11 [00:19<00:01,  1.94s/it]

Loss: 0.47300389409065247


Epoch 6: 100%|██████████| 11/11 [00:21<00:00,  1.91s/it]

Loss: 0.3535333275794983





Epoch 6 Validation Accuracy: 0.8372093023255814, F1-macro: 0.45569620253164556


Epoch 7:   9%|▉         | 1/11 [00:01<00:19,  1.93s/it]

Loss: 0.713222861289978


Epoch 7:  18%|█▊        | 2/11 [00:03<00:17,  1.94s/it]

Loss: 0.5055133700370789


Epoch 7:  27%|██▋       | 3/11 [00:05<00:15,  1.94s/it]

Loss: 0.47697651386260986


Epoch 7:  36%|███▋      | 4/11 [00:07<00:13,  1.94s/it]

Loss: 0.40517085790634155


Epoch 7:  45%|████▌     | 5/11 [00:09<00:11,  1.94s/it]

Loss: 0.27085983753204346


Epoch 7:  55%|█████▍    | 6/11 [00:11<00:09,  1.94s/it]

Loss: 0.21386951208114624


Epoch 7:  64%|██████▎   | 7/11 [00:13<00:07,  1.94s/it]

Loss: 0.3817305266857147


Epoch 7:  73%|███████▎  | 8/11 [00:15<00:05,  1.94s/it]

Loss: 0.0949937254190445


Epoch 7:  82%|████████▏ | 9/11 [00:17<00:03,  1.94s/it]

Loss: 0.1364101767539978


Epoch 7:  91%|█████████ | 10/11 [00:19<00:01,  1.94s/it]

Loss: 0.3498877286911011


Epoch 7: 100%|██████████| 11/11 [00:21<00:00,  1.91s/it]

Loss: 0.34534475207328796





Epoch 7 Validation Accuracy: 0.8837209302325582, F1-macro: 0.7734457323498419


In [None]:
# Evaluate on the test set
test_metrics = evaluate(model12, test_dataloader, device)
print(f"Test Accuracy: {test_metrics['accuracy']}, F1-macro: {test_metrics['f1_macro']}")

Test Accuracy: 0.8888888888888888, F1-macro: 0.8


In [None]:
current_type = 'personalization and blaming'

results_df.loc[
    results_df["distortion_type"] == current_type,
    ["test_accuracy", "f1_macro"]
] = [test_metrics['accuracy'], test_metrics['f1_macro']]

# Accuracy

In [None]:
results_df

Unnamed: 0,distortion_type,test_accuracy,f1_macro
0,all-or-nothing thinking,0.872549,0.722536
1,comparing and despairing,0.92,0.479167
2,disqualifying the positive,0.990196,0.497537
3,emotional reasoning,0.887255,0.673963
4,fortune telling,0.882353,0.716667
5,labeling,0.79902,0.678196
6,magnification,0.843137,0.635307
7,mind reading,0.843137,0.647364
8,overgeneralizing,0.794118,0.485467
9,should statements,0.931373,0.631959


# Savings

In [None]:
import torch
import os

output_dir = "./saved_smote_model"
os.makedirs(output_dir, exist_ok=True)

# Save the entire Binary Classifier model's state_dict
torch.save(model1.state_dict(), os.path.join(output_dir, "model1_weights.pth"))
torch.save(model2.state_dict(), os.path.join(output_dir, "model2_weights.pth"))
torch.save(model3.state_dict(), os.path.join(output_dir, "model3_weights.pth"))
torch.save(model4.state_dict(), os.path.join(output_dir, "model4_weights.pth"))
torch.save(model5.state_dict(), os.path.join(output_dir, "model5_weights.pth"))
torch.save(model6.state_dict(), os.path.join(output_dir, "model6_weights.pth"))
torch.save(model7.state_dict(), os.path.join(output_dir, "model7_weights.pth"))
torch.save(model8.state_dict(), os.path.join(output_dir, "model8_weights.pth"))
torch.save(model9.state_dict(), os.path.join(output_dir, "model9_weights.pth"))
torch.save(model10.state_dict(), os.path.join(output_dir, "model10_weights.pth"))
torch.save(model11.state_dict(), os.path.join(output_dir, "model11_weights.pth"))
torch.save(model12.state_dict(), os.path.join(output_dir, "model12_weights.pth"))

print(f"Model saved to {output_dir}")

Model saved to ./saved_smote_model
