In [22]:
import os
import pickle

from collections import Counter

# import pandas as pd
from sklearn.metrics import classification_report

import numpy as np
import torch
import torch.nn as nn

import transformers
from transformers import Trainer
from transformers import BertTokenizer
from transformers import BertForSequenceClassification
from transformers import Trainer, TrainingArguments
from transformers.data.data_collator import DataCollatorWithPadding

import datasets
from datasets import Dataset
from datasets import ClassLabel
from datasets import load_metric

import evaluate # type: ignore

## Global variables

In [29]:
# DATA_FOLDER = '/notebooks/Data/bert_sequence_classification'
DATA_FILE = 'emotion_analysis_comics/bert/datasets/comics_dataset_complete.pt'
RESULTS_FOLDER = 'emotion_analysis_comics/bert/outputs'

In [30]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [31]:
device

device(type='cuda')

## Load data

In [32]:
dataset = torch.load(DATA_FILE)

  dataset = torch.load(DATA_FILE)


In [33]:
dataset

DatasetDict({
    train: Dataset({
        features: ['file_name', 'page_nr', 'panel_nr', 'balloon_nr', 'utterance', 'raw_annotation', 'raw_emotion', 'raw_speaker_id', 'emotion', 'speaker_id', 'split', 'utterance_emotion', 'unique_emotion'],
        num_rows: 5075
    })
    test: Dataset({
        features: ['file_name', 'page_nr', 'panel_nr', 'balloon_nr', 'utterance', 'raw_annotation', 'raw_emotion', 'raw_speaker_id', 'emotion', 'speaker_id', 'split', 'utterance_emotion', 'unique_emotion'],
        num_rows: 1097
    })
    validation: Dataset({
        features: ['file_name', 'page_nr', 'panel_nr', 'balloon_nr', 'utterance', 'raw_annotation', 'raw_emotion', 'raw_speaker_id', 'emotion', 'speaker_id', 'split', 'utterance_emotion', 'unique_emotion'],
        num_rows: 564
    })
})

In [34]:
l = dataset['test']['unique_emotion']

In [35]:
len(l)

1097

In [36]:
dataset['train']['utterance'][230]

"I… didn't expect this. How should I play it?"

In [37]:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")



In [38]:
label_names = list(set(dataset['train']['unique_emotion']))
label_nb = len(label_names)
labels = ClassLabel(num_classes=label_nb, names=label_names)

In [39]:
labels

ClassLabel(names=['anger', 'fear', 'neutral', 'surprise', 'joy', 'disgust', 'sadness'], id=None)

In [70]:
list_labels = dataset['test']['unique_emotion']

In [71]:
len(list_labels)

1097

In [72]:
counter = Counter(list_labels)

In [73]:
counter

Counter({'anger': 321,
         'fear': 212,
         'joy': 195,
         'surprise': 177,
         'sadness': 129,
         'neutral': 42,
         'disgust': 21})

In [74]:
class_samples = []

for cl in labels.names:
    class_samples.append(counter[cl])

In [75]:
class_samples

[321, 212, 42, 177, 195, 21, 129]

In [40]:
# labels.num_classes

In [41]:
def tokenize(batch):
    tokens = tokenizer(batch['utterance'], truncation=True, padding=True, max_length=512)
    tokens['labels'] = labels.str2int(batch['unique_emotion'])
    return tokens

# this is just the text. if the results are nice, check transfer with text + topic 

In [42]:
dataset = dataset.map(tokenize, batched=True)

Map:   0%|          | 0/5075 [00:00<?, ? examples/s]

Map:   0%|          | 0/1097 [00:00<?, ? examples/s]

Map:   0%|          | 0/564 [00:00<?, ? examples/s]

In [43]:
dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'labels'])

In [44]:
dataset

DatasetDict({
    train: Dataset({
        features: ['file_name', 'page_nr', 'panel_nr', 'balloon_nr', 'utterance', 'raw_annotation', 'raw_emotion', 'raw_speaker_id', 'emotion', 'speaker_id', 'split', 'utterance_emotion', 'unique_emotion', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 5075
    })
    test: Dataset({
        features: ['file_name', 'page_nr', 'panel_nr', 'balloon_nr', 'utterance', 'raw_annotation', 'raw_emotion', 'raw_speaker_id', 'emotion', 'speaker_id', 'split', 'utterance_emotion', 'unique_emotion', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 1097
    })
    validation: Dataset({
        features: ['file_name', 'page_nr', 'panel_nr', 'balloon_nr', 'utterance', 'raw_annotation', 'raw_emotion', 'raw_speaker_id', 'emotion', 'speaker_id', 'split', 'utterance_emotion', 'unique_emotion', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 564
    })
})

In [45]:
train_dataset = dataset['train'].shuffle(seed=42)
test_dataset = dataset['test'].shuffle(seed=42)
val_dataset = dataset['validation'].shuffle(seed=42)

In [46]:
dataset_d = {}
dataset_d['train'] = train_dataset
dataset_d['test'] = test_dataset
dataset_d['val'] = val_dataset

In [47]:
test_dataset

Dataset({
    features: ['file_name', 'page_nr', 'panel_nr', 'balloon_nr', 'utterance', 'raw_annotation', 'raw_emotion', 'raw_speaker_id', 'emotion', 'speaker_id', 'split', 'utterance_emotion', 'unique_emotion', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
    num_rows: 1097
})

In [48]:
tokenizer.decode(dataset['train'][1945]['input_ids'])

"[CLS] @ shit … are we getting interference on the mics? better not be screwing with the feed. i don't want to miss anything. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD

In [49]:
# sanity check
set(dataset_d['train']['split'])

{'TRAIN'}

In [50]:
# sanity check
set(dataset_d['val']['split'])

{'TRAIN'}

In [51]:
# sanity check
set(dataset_d['test']['split'])

{'TEST'}

In [52]:
# global variables
NUM_LABELS = label_nb
BATCH_SIZE = 256
NB_EPOCHS = 100

In [53]:
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=NUM_LABELS, device_map='cuda')
#model.to(device)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [54]:
model.device

device(type='cuda', index=0)

In [55]:
# https://huggingface.co/transformers/main_classes/trainer.html
class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.get("labels")
        outputs = model(**inputs)
        logits = outputs.get('logits')
        loss_fct = nn.CrossEntropyLoss()#(weight=class_weights)
        loss = loss_fct(logits, labels)
        return (loss, outputs) if return_outputs else loss

In [56]:
#metric = load_metric('f1', trust_remote_code=True)
metric = evaluate.load("f1")

def compute_metrics(eval_pred):
    
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    
    return metric.compute(predictions=predictions, references=labels, average='macro')

Downloading builder script:   0%|          | 0.00/6.77k [00:00<?, ?B/s]

In [57]:
training_args = TrainingArguments(
    
    # output
    output_dir=RESULTS_FOLDER,          
    
    # params
    num_train_epochs=NB_EPOCHS,               # nb of epochs
    per_device_train_batch_size=BATCH_SIZE,   # batch size per device during training
    per_device_eval_batch_size=BATCH_SIZE,    # cf. paper Sun et al.
    learning_rate=1e-5,#2e-5,                 # cf. paper Sun et al.
#     warmup_steps=500,                         # number of warmup steps for learning rate scheduler
    warmup_ratio=0.1,                         # cf. paper Sun et al.
    weight_decay=0.01,                        # strength of weight decay
    
    # eval
    eval_strategy="steps",              # cf. paper Sun et al.
    eval_steps=20,                            # cf. paper Sun et al.
    
    # log
    logging_dir="emotion_analysis_comics/bert/logs",  
    logging_strategy='steps',
    logging_steps=20,
    
    # save
    save_strategy='steps',
    save_total_limit=1,
    # save_steps=20, # default 500
    load_best_model_at_end=True,              # cf. paper Sun et al.
    # metric_for_best_model='eval_loss' 
    metric_for_best_model='f1'
)

In [58]:
trainer = CustomTrainer( # Trainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics
    # callbacks=[EarlyStoppingCallback(early_stopping_patience=5)]
)

In [59]:
trainer.train()

Step,Training Loss,Validation Loss,F1
20,1.9347,1.932102,0.077932
40,1.8985,1.880744,0.084087
60,1.8584,1.839619,0.146252
80,1.8144,1.787726,0.174386
100,1.7582,1.72898,0.163216
120,1.6944,1.673566,0.181267
140,1.638,1.625176,0.208123
160,1.5755,1.574609,0.312273
180,1.5102,1.52749,0.330426
200,1.4318,1.492725,0.361466


TrainOutput(global_step=2000, training_loss=0.3995676355063915, metrics={'train_runtime': 3731.3453, 'train_samples_per_second': 136.01, 'train_steps_per_second': 0.536, 'total_flos': 5.581339646924986e+16, 'train_loss': 0.3995676355063915, 'epoch': 100.0})

In [60]:
# save best model
#trainer.save_model(os.path.join("/notebooks/cascade_bert/saved_models", 'best-model-with-real-prev-probs'))

In [61]:
#model_file = os.path.join("/notebooks/cascade_bert/saved_models", 'best-model-with-real-prev-probs')

#model = BertForSequenceClassification.from_pretrained(model_file, num_labels=NUM_LABELS)
#model.to(device)
model.eval()

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

In [62]:
test_trainer = Trainer(model, data_collator=DataCollatorWithPadding(tokenizer))
test_raw_preds, test_labels, _ = test_trainer.predict(test_dataset)
test_preds = np.argmax(test_raw_preds, axis=1)

In [63]:
len(test_preds)

1097

In [64]:
test_labels

array([0, 3, 0, ..., 0, 0, 1])

In [65]:
test_preds

array([1, 3, 6, ..., 0, 1, 1])

In [66]:
# labels=['fear', 'anger', 'disgust', 'joy', 'sadness', 'surprise', 'neutral']

In [67]:
target_name = labels.int2str([0,1,2,3,4,5,6])
print(classification_report(test_labels, test_preds, target_names=target_name, digits=3)) # type: ignore

              precision    recall  f1-score   support

       anger      0.499     0.592     0.541       321
        fear      0.461     0.363     0.406       212
     neutral      0.056     0.048     0.051        42
    surprise      0.441     0.463     0.452       177
         joy      0.421     0.421     0.421       195
     disgust      0.000     0.000     0.000        21
     sadness      0.348     0.357     0.352       129

    accuracy                          0.437      1097
   macro avg      0.318     0.320     0.318      1097
weighted avg      0.424     0.437     0.428      1097



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
