Note: ChatGPT was used throughout the creation of this notebook to better understand libraries and to find ways to optimize code.

In [None]:
! pip install transformers datasets
! pip install evaluate

In [131]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
from google.colab import drive
import re

In [132]:
from datasets import (Dataset,
                      DatasetDict)

from transformers import (DataCollatorWithPadding,
                          AutoTokenizer,
                          AutoModel,
                          AutoConfig,
                          get_scheduler)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim import AdamW

from tqdm.auto import tqdm

from evaluate import load

from collections import Counter

In [133]:
# Mount Google Drive at /content/drive/ so later cells can read and write files there
drive.mount('/content/drive/')

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


In [134]:
%cd /content/drive/MyDrive/fruits

/content/drive/.shortcut-targets-by-id/1-AdTPjFiouJTFk4KTDPVEUDC8lwsbPDJ/fruits


In [135]:
%pwd

'/content/drive/.shortcut-targets-by-id/1-AdTPjFiouJTFk4KTDPVEUDC8lwsbPDJ/fruits'

In [136]:
# Directory that receives model/optimizer checkpoints and the prediction CSVs.
# NOTE(review): absolute Colab shortcut path tied to one Drive account — confirm before reuse.
checkpoint_dir = '/content/drive/.shortcut-targets-by-id/1-AdTPjFiouJTFk4KTDPVEUDC8lwsbPDJ/fruits'
# Create the directory if it is missing; an existing directory is left untouched
os.makedirs(checkpoint_dir, exist_ok=True)

In [None]:
# Hugging Face Hub identifier of the pretrained checkpoint
model_name = 'vinai/bertweet-base'

# BERTweet base model; the config asks the encoder to also return
# attention maps and per-layer hidden states on every forward pass
config = AutoConfig.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
bertweet_model = AutoModel.from_pretrained(model_name, config=config)

# BERTweet tokenizer (use_fast=False requests the non-fast implementation;
# its normalizeTweet method is used during preprocessing)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

In [138]:
# Load the raw posts (one post per row).
# NOTE(review): read from /content, not from the Drive working directory — confirm the file is uploaded there.
raw_df = pd.read_csv('/content/one_post_per_row.csv')

# Report the row count before any filtering, then preview the first rows
print(f'Total number of samples before preprocessing: {raw_df.shape[0]}\n')
raw_df.head(5)

Total number of samples before preprocessing: 422845



Unnamed: 0,type,post,type_index
0,INFJ,'http://www.youtube.com/watch?v=qsXHcwe3krw,1
1,INFJ,http://41.media.tumblr.com/tumblr_lfouy03PMA1q...,1
2,INFJ,enfp and intj moments https://www.youtube.com...,1
3,INFJ,What has been the most life-changing experienc...,1
4,INFJ,http://www.youtube.com/watch?v=vXZeYwwRDw8 h...,1


In [139]:
# mbti_encoding = {
#     'INFP': 0, 'INFJ': 1, 'INTP': 2, 'INTJ': 3,
#     'ISFP': 4, 'ISFJ': 5, 'ISTP': 6, 'ISTJ': 7,
#     'ENFP': 8, 'ENFJ': 9, 'ENTP': 10, 'ENTJ': 11,
#     'ESFP': 12, 'ESFJ': 13, 'ESTP': 14, 'ESTJ': 15
# }

In [140]:
# Build the working frame in one pass:
#   1. rename the raw columns to sentence / label_int / label_str
#   2. drop rows that have no post text
#   3. normalize each post with the tokenizer's tweet normalizer (e.g. links become a special token)
#   4. keep only posts with more than three whitespace-separated words
df = (
    raw_df
    .rename(columns={'post': 'sentence', 'type_index': 'label_int', 'type': 'label_str'})
    .dropna(subset=['sentence'])
    .assign(sentence=lambda frame: frame['sentence'].apply(tokenizer.normalizeTweet))
    .loc[lambda frame: frame['sentence'].str.split().str.len() > 3]
    .reset_index(drop=True)
)

df.head(5)

Unnamed: 0,label_str,sentence,label_int
0,INFJ,enfp and intj moments HTTPURL sportscenter not...,1
1,INFJ,What has been the most life-changing experienc...,1
2,INFJ,HTTPURL HTTPURL On repeat for most of today .,1
3,INFJ,May the PerC Experience immerse you .,1
4,INFJ,The last thing my INFJ friend posted on his fa...,1


In [141]:
def contains_mbti(mbti, text):
    """Report whether `text` mentions the MBTI type `mbti` as a standalone word.

    The type is escaped and matched literally, case-insensitively, between
    word boundaries; an optional trailing "s" (plural, e.g. "INFJs") is accepted.

    Args:
        mbti: The MBTI type string to look for.
        text: The text to search.

    Returns:
        True if the type occurs in `text`, otherwise False.
    """
    mention = re.compile(r'\b' + re.escape(mbti) + r's?\b', re.IGNORECASE)
    return mention.search(text) is not None

# Flag every post that names its own MBTI type, then keep only the unflagged rows
mask = pd.Series(
    [contains_mbti(mbti, text) for mbti, text in zip(df['label_str'], df['sentence'])],
    index=df.index,
)
df = df.loc[~mask]

print(f'Total number of samples after preprocessing: {df.shape[0]}')
df.head(5)

Total number of samples after preprocessing: 358230


Unnamed: 0,label_str,sentence,label_int
0,INFJ,enfp and intj moments HTTPURL sportscenter not...,1
1,INFJ,What has been the most life-changing experienc...,1
2,INFJ,HTTPURL HTTPURL On repeat for most of today .,1
3,INFJ,May the PerC Experience immerse you .,1
5,INFJ,Hello ENFJ 7 . Sorry to hear of your distress ...,1


In [142]:
# Subset the posts by the letter I / N / F / P in the matching position of the label
df_1, df_2, df_3, df_4 = (
    df[df['label_str'].str[position] == letter]
    for position, letter in enumerate('INFP')
)

# Share of all posts that fall into each subset
prop_I, prop_N, prop_F, prop_P = (
    len(subset) / len(df) for subset in (df_1, df_2, df_3, df_4)
)

# The opposite letter of each dimension gets the remainder
prop_E, prop_S, prop_T, prop_J = (
    1 - share for share in (prop_I, prop_N, prop_F, prop_P)
)

for letter, share in zip('INFP', (prop_I, prop_N, prop_F, prop_P)):
    print(f'Proportion {letter}: {round(share, 2)}')

Proportion I: 0.77
Proportion N: 0.86
Proportion F: 0.54
Proportion P: 0.61


In [143]:
# Wrap the preprocessed DataFrame as a Hugging Face Dataset
dataset = Dataset.from_pandas(df)

# 80% train; the held-out 20% is split again below (fixed seed for repeatable splits)
train_testvalid = dataset.train_test_split(test_size=0.2, seed=15, shuffle=True)

# 10% validation and 10% test
test_valid = train_testvalid['test'].train_test_split(test_size=0.5,seed=15, shuffle=True)

# Collect the three splits; the 'train' half of the second split serves as validation
dataset = DatasetDict({
    'train': train_testvalid['train'],
    'test': test_valid['test'],
    'valid': test_valid['train']})

dataset

DatasetDict({
    train: Dataset({
        features: ['label_str', 'sentence', 'label_int', '__index_level_0__'],
        num_rows: 286584
    })
    test: Dataset({
        features: ['label_str', 'sentence', 'label_int', '__index_level_0__'],
        num_rows: 35823
    })
    valid: Dataset({
        features: ['label_str', 'sentence', 'label_int', '__index_level_0__'],
        num_rows: 35823
    })
})

In [144]:
tokenizer

BertweetTokenizer(name_or_path='vinai/bertweet-base', vocab_size=64000, model_max_length=1000000000000000019884624838656, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': '<mask>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	3: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	64000: AddedToken("<mask>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}

In [145]:
def tokenize(batch):
    """Tokenize a batch of posts so they can be fed to BERTweet.

    Args:
        batch: Mapping with a 'sentence' entry holding the post strings
            (the form `Dataset.map(..., batched=True)` passes in).

    Returns:
        The tokenizer's encoding for the batch, truncated to at most 128
        tokens per post and left unpadded.
    """
    # Padding is deliberately not done here: DataCollatorWithPadding pads each
    # training batch to its own longest sequence. Padding inside map() would
    # instead pad every post to the longest post of its whole map batch and
    # store those pad tokens in the dataset.
    return tokenizer(batch['sentence'], truncation=True, max_length=128)

# Tokenize every split in batches; the tokenizer outputs are added as new columns
tokenized_dataset = dataset.map(tokenize, batched=True)
tokenized_dataset

Map:   0%|          | 0/286584 [00:00<?, ? examples/s]

Map:   0%|          | 0/35823 [00:00<?, ? examples/s]

Map:   0%|          | 0/35823 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['label_str', 'sentence', 'label_int', '__index_level_0__', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 286584
    })
    test: Dataset({
        features: ['label_str', 'sentence', 'label_int', '__index_level_0__', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 35823
    })
    valid: Dataset({
        features: ['label_str', 'sentence', 'label_int', '__index_level_0__', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 35823
    })
})

In [146]:
def encode_mbti_label(label):
    """Convert an MBTI string into a list of 0/1 flags, one per character.

    E, S, T and J map to 0; I, N, F and P map to 1.

    Args:
        label: MBTI type string such as 'INTJ'.

    Returns:
        List of ints with the same length as `label`.

    Raises:
        KeyError: If `label` contains a character outside the eight MBTI letters.
    """
    bit_for_letter = dict.fromkeys('ESTJ', 0)
    bit_for_letter.update(dict.fromkeys('INFP', 1))
    return list(map(bit_for_letter.__getitem__, label))

# Add a 'label_vec' column: the 0/1 encoding of each example's MBTI string
# (four entries for a four-letter type)
tokenized_dataset = tokenized_dataset.map(lambda x: {'label_vec': encode_mbti_label(x['label_str'])})

Map:   0%|          | 0/286584 [00:00<?, ? examples/s]

Map:   0%|          | 0/35823 [00:00<?, ? examples/s]

Map:   0%|          | 0/35823 [00:00<?, ? examples/s]

In [147]:
# Show one training example: its text followed by the three label encodings
example = tokenized_dataset['train'][7]

print('Example of data sample:')
for field in ('sentence', 'label_str', 'label_int', 'label_vec'):
    print(f"  {example[field]}")

Example of data sample:
  My avatar is a picture of Red Rocks , which is a concert amphitheater , tourist venue with a few hiking trails scattered around it . It 's located in Morrison , Colorado which is around 45 minutes away ...
  INTJ
  3
  [1, 1, 0, 0]


In [148]:
# Return only these columns, as torch tensors, when the dataset is indexed;
# the remaining columns stay stored but are not returned
tokenized_dataset.set_format('torch',columns=['input_ids', 'attention_mask', 'label_int', 'label_vec'])

# Create data collator object used for batching / data loading
# (pads the examples of a batch with the tokenizer's padding settings)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [149]:
tokenized_dataset

DatasetDict({
    train: Dataset({
        features: ['label_str', 'sentence', 'label_int', '__index_level_0__', 'input_ids', 'token_type_ids', 'attention_mask', 'label_vec'],
        num_rows: 286584
    })
    test: Dataset({
        features: ['label_str', 'sentence', 'label_int', '__index_level_0__', 'input_ids', 'token_type_ids', 'attention_mask', 'label_vec'],
        num_rows: 35823
    })
    valid: Dataset({
        features: ['label_str', 'sentence', 'label_int', '__index_level_0__', 'input_ids', 'token_type_ids', 'attention_mask', 'label_vec'],
        num_rows: 35823
    })
})

In [150]:
data_collator

DataCollatorWithPadding(tokenizer=BertweetTokenizer(name_or_path='vinai/bertweet-base', vocab_size=64000, model_max_length=1000000000000000019884624838656, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': '<mask>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	3: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	64000: AddedToken("<mask>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}, padding=True, max_length=None,

In [151]:
# Number of positive (1) labels per MBTI dimension in the training split
pos_counts = tokenized_dataset['train']['label_vec'].sum(dim=0)
total_counts = len(tokenized_dataset['train'])

# Loss weight per dimension = negatives / positives, computed for all
# dimensions at once instead of element by element
weights = (total_counts - pos_counts) / pos_counts

weights

tensor([0.2998, 0.1578, 0.8487, 0.6507])

In [152]:
# Batch size and collation are shared by all three loaders; only training data is shuffled
loader_options = dict(batch_size=32, collate_fn=data_collator)

train_dataloader = DataLoader(tokenized_dataset['train'], shuffle=True, **loader_options)
eval_dataloader = DataLoader(tokenized_dataset['valid'], **loader_options)
test_dataloader = DataLoader(tokenized_dataset['test'], **loader_options)

# 16-Way Classification Using 2 Linear Layers With ReLU In-Between

In [153]:
class BERTweetClassifier(nn.Module):
    """BERTweet encoder followed by a two-layer head for multi-class prediction.

    The first-token vector of the encoder's last layer goes through dropout,
    a linear layer with ReLU, and a final linear layer producing `num_classes`
    logits.

    Note: the constructor changes `requires_grad` on the encoder that is passed
    in — every encoder parameter is frozen except those of the last two
    transformer layers. The encoder object is stored, not copied.

    Args:
        bertweet_model: Encoder module exposing `config.hidden_size` and
            `encoder.layer`, whose output provides `last_hidden_state`,
            `attentions` and `hidden_states` attributes.
        num_classes: Number of output logits (default 16).
    """

    def __init__(self, bertweet_model, num_classes=16):
        super().__init__()
        self.bertweet_model = bertweet_model

        self.num_classes = num_classes

        # Hidden size of the BERTweet model
        self.hidden_size = bertweet_model.config.hidden_size

        # Regularization
        self.dropout = nn.Dropout(0.1)

        # Linear layers: hidden_size -> hidden_size // 2 -> num_classes
        self.linear1 = nn.Linear(self.hidden_size, self.hidden_size // 2)
        self.linear2 = nn.Linear(self.hidden_size // 2, num_classes)

        # Freeze BERT model parameters
        for param in self.bertweet_model.parameters():
            param.requires_grad = False

        # Unfreeze the last 2 transformer layers
        for param in self.bertweet_model.encoder.layer[-2:].parameters():
            param.requires_grad = True

    def forward(self, input_ids, attention_mask=None):
        """Run the encoder and the classification head.

        Args:
            input_ids: Token ids passed to the encoder.
            attention_mask: Optional attention mask passed to the encoder.

        Returns:
            Tuple `(logits, attentions, hidden_states)`, where `logits` comes
            from the head and the other two are taken from the encoder output.
        """
        # Get the outputs from the BERTweet model
        outputs = self.bertweet_model(input_ids=input_ids, attention_mask=attention_mask)

        # Read the fields by name. Integer positions depend on which optional
        # fields the encoder returns: when a pooled output is present it sits
        # at index 1, so outputs[1] would not be the attentions.
        last_hidden_state = outputs.last_hidden_state
        attentions = outputs.attentions
        hidden_states = outputs.hidden_states

        # Take the first-token representation as the pooled output
        pooled_output = last_hidden_state[:, 0, :]
        pooled_output = self.dropout(pooled_output)

        # Pass it through the classification layers
        out = F.relu(self.linear1(pooled_output))
        logits = self.linear2(out)

        return logits, attentions, hidden_states

In [154]:
# 16-way classifier built on the BERTweet encoder loaded earlier
model = BERTweetClassifier(bertweet_model)
# AdamW over all of the model's parameters (the iterator includes the frozen encoder weights too)
optimizer = AdamW(model.parameters(), lr=5e-5)

In [156]:
# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
device

device(type='cuda')

In [157]:
num_epochs = 3
# One scheduler step is taken per training batch, so the schedule spans every batch of every epoch
num_training_steps = num_epochs * len(train_dataloader)
# Linear learning-rate schedule with no warmup steps
lr_scheduler = get_scheduler(
    name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)

In [None]:
model.train()

progress_bar = tqdm(range(num_training_steps))

# The loss module holds no per-step state, so build it once instead of on every batch
loss_fn = nn.CrossEntropyLoss()

for epoch in range(num_epochs):
    total_loss = 0

    for batch in train_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}

        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['label_int']

        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

        # The model also returns attentions and hidden states; training only needs the logits
        logits, _, _ = outputs

        # Compute the loss using cross-entropy
        loss = loss_fn(logits, labels)
        total_loss += loss.item()

        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)

    # Per-epoch logging, in the same format as the 4-head training loop
    avg_loss = total_loss / len(train_dataloader)
    print(f"Epoch {epoch + 1} - Avg Loss: {avg_loss:.4f}")

    # Save model and optimizer state at the end of every epoch
    checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_extra_layer_{epoch}.pt")
    torch.save(model.state_dict(), checkpoint_path)
    torch.save(optimizer.state_dict(), checkpoint_path + "optim")
    print(f"Checkpoint saved at: {checkpoint_path}")


In [158]:
# # Restore the best validation checkpoint
# checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_2.pt")
# model.load_state_dict(torch.load(checkpoint_path))
# optimizer.load_state_dict(torch.load(checkpoint_path + "optim"))

In [None]:
# Metric accumulators for the 16-way classifier on the validation split
accuracy_metric = load("accuracy")
precision_metric = load("precision")
recall_metric = load("recall")
f1_metric = load("f1")
all_metrics = (accuracy_metric, precision_metric, recall_metric, f1_metric)

model.eval()

progress_bar = tqdm(range(len(eval_dataloader)))

for batch in eval_dataloader:
    batch = {name: tensor.to(device) for name, tensor in batch.items()}
    input_ids = batch['input_ids']
    attention_mask = batch['attention_mask']
    labels = batch['label_int']

    # Inference only: no gradients are tracked
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits, _, _ = outputs

    # Predicted class = index of the largest logit
    predictions = torch.argmax(logits, dim=-1)

    for metric in all_metrics:
        metric.add_batch(predictions=predictions, references=labels)

    progress_bar.update(1)

# Accuracy is computed as-is; the other three are averaged with class-support weights
accuracy = accuracy_metric.compute()
precision, recall, f1 = (
    metric.compute(average='weighted')
    for metric in (precision_metric, recall_metric, f1_metric)
)

print("Accuracy:\n", accuracy['accuracy'])
print("Precision:\n", precision['precision'])
print("Recall:\n", recall['recall'])
print("F1 Score:\n", f1['f1'])

  0%|          | 0/1318 [00:00<?, ?it/s]

In [155]:
# mbti_encoding = {
#     'INFP': 0, 'INFJ': 1, 'INTP': 2, 'INTJ': 3,
#     'ISFP': 4, 'ISFJ': 5, 'ISTP': 6, 'ISTJ': 7,
#     'ENFP': 8, 'ENFJ': 9, 'ENTP': 10, 'ENTJ': 11,
#     'ESFP': 12, 'ESFJ': 13, 'ESTP': 14, 'ESTJ': 15
# }

In [None]:
# Find the distribution of the 16 MBTI types in the validation data.
# Read the label column in one go (the same column access used for the loss
# weights) instead of iterating row by row, which would build every formatted
# column of each example just to read a single integer.
labels = tokenized_dataset['valid']['label_int'].tolist()
label_counts = Counter(labels)
label_proportions = {label: round(count / len(labels), 2) for label, count in label_counts.items()}

# Printing label counts in order
print("Label Counts:")
for num, count in sorted(label_counts.items()):
    print(f"{num} -> {count}")

# Printing label proportions in order
print("\nLabel Proportions:")
for num, proportion in sorted(label_proportions.items()):
    print(f"{num} -> {proportion}")

Label Counts:
0 -> 8993
1 -> 7356
2 -> 6259
3 -> 5275
4 -> 1233
5 -> 832
6 -> 1630
7 -> 983
8 -> 3216
9 -> 939
10 -> 3316
11 -> 1065
12 -> 228
13 -> 210
14 -> 446
15 -> 194

Label Proportions:
0 -> 0.21
1 -> 0.17
2 -> 0.15
3 -> 0.13
4 -> 0.03
5 -> 0.02
6 -> 0.04
7 -> 0.02
8 -> 0.08
9 -> 0.02
10 -> 0.08
11 -> 0.03
12 -> 0.01
13 -> 0.0
14 -> 0.01
15 -> 0.0


# 4-Head Binary Classification

In [159]:
class BERTweetBinaryClassifier(nn.Module):
    """BERTweet encoder with a shared hidden layer and four single-logit heads.

    The first-token vector of the encoder's last layer goes through dropout and
    a 24-unit linear layer with ReLU; each of the four heads then maps that
    hidden vector to one logit.

    Note: the constructor sets `requires_grad = False` on every parameter of the
    encoder that is passed in. The encoder object is stored, not copied.

    Args:
        bertweet_model: Encoder module exposing `config.hidden_size`, whose
            output provides `last_hidden_state`, `attentions` and
            `hidden_states` attributes.
    """

    def __init__(self, bertweet_model):
        super().__init__()

        self.bertweet_model = bertweet_model

        # Hidden size of the BERTweet model
        self.hidden_size = bertweet_model.config.hidden_size

        # Regularization
        self.dropout = nn.Dropout(0.1)

        # First linear layer
        self.linear1 = nn.Linear(self.hidden_size, 24)

        # Linear classification heads, one logit each
        self.head1 = nn.Linear(24, 1)
        self.head2 = nn.Linear(24, 1)
        self.head3 = nn.Linear(24, 1)
        self.head4 = nn.Linear(24, 1)

        # Freeze all BERTweet parameters
        for param in self.bertweet_model.parameters():
            param.requires_grad = False

        # # Unfreeze only biases and normalization layers in the last BERTweet layers
        # last_layers = self.bertweet_model.encoder.layer[-1:]
        # for layer in last_layers:
        #     for name, param in layer.named_parameters():
        #         if "LayerNorm" in name or "bias" in name:
        #             param.requires_grad = True

        # # Optional: Verify which parameters are trainable
        # for name, param in self.bertweet_model.named_parameters():
        #     print(f"{name}: requires_grad = {param.requires_grad}")

    def forward(self, input_ids, attention_mask=None):
        """Run the encoder, the shared hidden layer and the four heads.

        Args:
            input_ids: Token ids passed to the encoder.
            attention_mask: Optional attention mask passed to the encoder.

        Returns:
            Tuple `(logit1, logit2, logit3, logit4, attentions, hidden_states)`.
            Each logit tensor is flattened to one dimension; the last two
            entries are taken from the encoder output.
        """
        # Get the outputs from the BERTweet model
        outputs = self.bertweet_model(input_ids=input_ids, attention_mask=attention_mask)

        # Read the fields by name. Integer positions depend on which optional
        # fields the encoder returns: when a pooled output is present it sits
        # at index 1, so outputs[1] would not be the attentions.
        last_hidden_state = outputs.last_hidden_state
        attentions = outputs.attentions
        hidden_states = outputs.hidden_states

        # Take the first-token representation as the pooled output
        pooled_output = last_hidden_state[:, 0, :]
        pooled_output = self.dropout(pooled_output)

        # Shared hidden layer, then one logit per head flattened to shape (batch,)
        hidden_layer = F.relu(self.linear1(pooled_output))
        logit1 = self.head1(hidden_layer).view(-1)
        logit2 = self.head2(hidden_layer).view(-1)
        logit3 = self.head3(hidden_layer).view(-1)
        logit4 = self.head4(hidden_layer).view(-1)

        # Return a tuple of logits
        return (logit1, logit2, logit3, logit4, attentions, hidden_states)

In [160]:
# 4-head binary classifier.
# NOTE(review): this wraps the same `bertweet_model` instance that the 16-way `model` wraps, so
# (a) any encoder weights updated while training `model` carry over here, and
# (b) constructing this classifier freezes the encoder layers that `model` had unfrozen.
# Reload a fresh encoder here if the two experiments are meant to be independent — confirm intent.
model_bin = BERTweetBinaryClassifier(bertweet_model)
# AdamW over all of the model's parameters (the iterator includes the frozen encoder weights too)
optimizer_bin = AdamW(model_bin.parameters(), lr=3e-4)

In [161]:
# # Restore the best validation checkpoint
# checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_4head_weighted_loss_0.1Dropout_7.pt")
# model_bin.load_state_dict(torch.load(checkpoint_path))
# optimizer_bin.load_state_dict(torch.load(checkpoint_path + ".optim"))

In [162]:
def count_trainable_parameters(model):
    """Return the number of scalar values in `model`'s trainable parameters.

    Only parameters whose `requires_grad` flag is set are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Report how many parameters of the 4-head model will actually be updated during training
num_trainable_params = count_trainable_parameters(model_bin)
print(f"Number of trainable parameters: {num_trainable_params}")

Number of trainable parameters: 18556


In [163]:
model_bin

BERTweetBinaryClassifier(
  (bertweet_model): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(64001, 768, padding_idx=1)
      (position_embeddings): Embedding(130, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0-11): 12 x RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              

In [164]:
# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model_bin.to(device)
device

device(type='cuda')

In [165]:
# NOTE: reassigns num_epochs / num_training_steps that the 16-way section also set
num_epochs = 5
# One scheduler step is taken per training batch, so the schedule spans every batch of every epoch
num_training_steps = num_epochs * len(train_dataloader)

# Linear learning-rate schedule with 500 warmup steps
lr_scheduler_bin = get_scheduler(
    name="linear", optimizer=optimizer_bin,
    num_warmup_steps=500, num_training_steps=num_training_steps
)

In [None]:
model_bin.train()

progress_bar = tqdm(range(num_training_steps))

# One weighted binary cross-entropy loss per MBTI dimension. Each pos_weight
# is moved to the training device explicitly instead of being left on the CPU
# while the logits live on `device`.
loss_head1 = nn.BCEWithLogitsLoss(pos_weight=weights[0].to(device))
loss_head2 = nn.BCEWithLogitsLoss(pos_weight=weights[1].to(device))
loss_head3 = nn.BCEWithLogitsLoss(pos_weight=weights[2].to(device))
loss_head4 = nn.BCEWithLogitsLoss(pos_weight=weights[3].to(device))
head_losses = (loss_head1, loss_head2, loss_head3, loss_head4)

for epoch in range(num_epochs):
    total_loss = 0

    for batch in train_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}

        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['label_vec']

        # The loss is computed against float targets; convert once per batch
        targets = labels.float()

        # Run forward pass
        outputs = model_bin(input_ids=input_ids, attention_mask=attention_mask)

        # The first four outputs are the per-dimension logits; the total loss
        # is the sum of the four individual binary cross-entropy losses
        loss = sum(
            head_loss(logit, targets[:, head_idx])
            for head_idx, (head_loss, logit) in enumerate(zip(head_losses, outputs[:4]))
        )
        total_loss += loss.item()

        # Backpropagation
        optimizer_bin.zero_grad()
        loss.backward()
        optimizer_bin.step()

        lr_scheduler_bin.step()
        progress_bar.update(1)

    # Logging and checkpointing
    avg_loss = total_loss / len(train_dataloader)
    print(f"Epoch {epoch + 1} - Avg Loss: {avg_loss:.4f}")

    # checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_4head_weighted_loss_0.1Dropout_no_bias_{epoch}.pt")
    # torch.save(model_bin.state_dict(), checkpoint_path)
    # torch.save(optimizer_bin.state_dict(), checkpoint_path + ".optim")
    # print(f"Checkpoint saved at: {checkpoint_path}")

In [167]:
# # Restore the best validation checkpoint
# checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_4head_weighted_loss_0.1Dropout_5.pt")
# model_bin.load_state_dict(torch.load(checkpoint_path))
# optimizer_bin.load_state_dict(torch.load(checkpoint_path + ".optim"))

### Calculate evaluation metrics for each component

In [None]:
# One independent set of metric accumulators per MBTI dimension ("task")
metric_names = ('accuracy', 'precision', 'recall', 'f1')
task_metrics = {
    f'task{task_idx + 1}': {metric_name: load(metric_name) for metric_name in metric_names}
    for task_idx in range(4)
}

model_bin.eval()
progress_bar = tqdm(range(len(eval_dataloader)))

for batch in eval_dataloader:
    batch = {name: tensor.to(device) for name, tensor in batch.items()}
    input_ids = batch['input_ids']
    attention_mask = batch['attention_mask']
    labels = batch['label_vec']

    # Inference only: no gradients are tracked
    with torch.no_grad():
        outputs = model_bin(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs[:4]

    # Convert logits into binary predictions for each task
    # Decision threshold of 0.5
    predictions = [torch.sigmoid(logit) >= 0.5 for logit in logits]

    # Feed this batch's predictions and gold labels to every metric of each task
    for i, pred in enumerate(predictions):
        task_key = f'task{i+1}'
        pred_np = pred.int().cpu().numpy()
        labels_np = labels[:, i].cpu().numpy()

        for metric in task_metrics[task_key].values():
            metric.add_batch(predictions=pred_np, references=labels_np)

    progress_bar.update(1)


# Compute and print each metric for each task
for task_key, metrics in task_metrics.items():
    print(f"Metrics for {task_key}:")
    accuracy = metrics['accuracy'].compute()
    precision = metrics['precision'].compute()
    recall = metrics['recall'].compute()
    f1 = metrics['f1'].compute()
    print(f"  Accuracy: {accuracy['accuracy'] * 100:.2f}%")
    print(f"  Precision: {precision['precision'] * 100:.2f}%")
    print(f"  Recall: {recall['recall'] * 100:.2f}%")
    print(f"  F1 Score: {f1['f1'] * 100:.2f}%")

### Store results of predictions (letters) into dataframes

In [168]:
def decode_mbti_label(binary_vector):
    """Translate a sequence of 0/1 flags back into MBTI letters.

    Position 0 picks E (0) or I (1), position 1 picks S or N, position 2
    picks T or F, and position 3 picks J or P.

    Args:
        binary_vector: Sequence of flags; each one is used to index a
            two-letter tuple, so ints and bools both work.

    Returns:
        The decoded letters joined into one string.

    Raises:
        KeyError: If the sequence has more than four entries.
    """
    letter_pairs = {0: ('E', 'I'), 1: ('S', 'N'), 2: ('T', 'F'), 3: ('J', 'P')}

    letters = []
    for position, bit in enumerate(binary_vector):
        letters.append(letter_pairs[position][bit])
    return ''.join(letters)

In [None]:
# Predicted and gold letters for the validation split, one list per MBTI dimension
predicted_letters = [[] for _ in range(4)]
labeled_letters = [[] for _ in range(4)]

model_bin.eval()
progress_bar = tqdm(range(len(eval_dataloader)))

for batch in eval_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    input_ids = batch['input_ids']
    attention_mask = batch['attention_mask']
    labels = batch['label_vec']

    # Inference only: no gradients are tracked
    with torch.no_grad():
        outputs = model_bin(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs[:4]

    # Convert logits into binary predictions for each task (threshold 0.5).
    # Each tensor is moved to plain Python lists once per batch instead of
    # calling .item() for every single element.
    predictions = [torch.sigmoid(logit) >= 0.5 for logit in logits]
    predicted_bits = [pred.cpu().tolist() for pred in predictions]
    label_rows = labels.cpu().tolist()

    # Convert predictions and labels to MBTI types, then file each letter under its dimension
    for j, label_vector in enumerate(label_rows):
        predicted_type = decode_mbti_label([bits[j] for bits in predicted_bits])
        labeled_type = decode_mbti_label(label_vector)

        for dim in range(4):
            predicted_letters[dim].append(predicted_type[dim])
            labeled_letters[dim].append(labeled_type[dim])

    progress_bar.update(1)

# Keep the per-dimension names the notebook has used so far
predicted_types_1, predicted_types_2, predicted_types_3, predicted_types_4 = predicted_letters
labeled_types_1, labeled_types_2, labeled_types_3, labeled_types_4 = labeled_letters

# Write one CSV per dimension. These files get their own name: the probability
# export writes prob_df_{1..4}.csv, and sharing that name would make whichever
# cell runs last overwrite the other's results.
for dim in range(4):
    str_type_df = pd.DataFrame({
        'Predicted Type': predicted_letters[dim],
        'Labeled Type': labeled_letters[dim]
    })
    str_type_df.to_csv(f"{checkpoint_dir}/str_type_df_{dim + 1}.csv", index=False)

### Store predicted probabilities and integer labels into dataframes

In [104]:
# Predicted probabilities and integer gold labels, one list per MBTI dimension
predicted_by_dim = [[] for _ in range(4)]
labeled_by_dim = [[] for _ in range(4)]

model_bin.eval()
progress_bar = tqdm(range(len(eval_dataloader)))

for batch in eval_dataloader:
    batch = {name: tensor.to(device) for name, tensor in batch.items()}
    input_ids = batch['input_ids']
    attention_mask = batch['attention_mask']
    labels = batch['label_vec']

    # Inference only: no gradients are tracked
    with torch.no_grad():
        outputs = model_bin(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs[:4]

    # Convert logits into probabilities for each task
    predictions = [torch.sigmoid(logit) for logit in logits]

    # Append this batch's values per dimension, preserving example order
    label_rows = labels.cpu().tolist()
    for dim, probs in enumerate(predictions):
        predicted_by_dim[dim].extend(probs.cpu().tolist())
        labeled_by_dim[dim].extend(row[dim] for row in label_rows)

    progress_bar.update(1)

# Expose the results under the per-dimension names used elsewhere in the notebook
predicted_types_1, predicted_types_2, predicted_types_3, predicted_types_4 = predicted_by_dim
labeled_types_1, labeled_types_2, labeled_types_3, labeled_types_4 = labeled_by_dim

# One CSV per dimension: prob_df_1.csv ... prob_df_4.csv
for dim in range(4):
    str_type_df = pd.DataFrame({
        'Predicted Type': predicted_by_dim[dim],
        'Labeled Type': labeled_by_dim[dim]
    })
    str_type_df.to_csv(f"{checkpoint_dir}/prob_df_{dim + 1}.csv", index=False)

  0%|          | 0/1120 [00:00<?, ?it/s]