## Model training

In [1]:
from transformers import TrainingArguments, Trainer
import pickle
from transformers import (
    AutoModelForSequenceClassification,
    AutoModelForSeq2SeqLM,
    AutoConfig,
    BertModel,
)
from transformers import AutoTokenizer
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import torch
import torch.nn as nn
from transformers.modeling_outputs import SequenceClassifierOutput

  from .autonotebook import tqdm as notebook_tqdm
2023-07-26 03:13:09.437773: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


### Loading the data


In [2]:
with open('train_dataset_tokenized.pkl', 'rb') as file:
    train_dataset = pickle.load(file)

with open('val_data_tokenized.pkl', 'rb') as file:
    val_dataset = pickle.load(file)

with open('test_data_tokenized.pkl', 'rb') as file:
    test_dataset = pickle.load(file)

with open('train_dataset_full_tokenized.pkl', 'rb') as file:
    train_dataset_full = [pickle.load(file)]

with open('augmented_train_dataset_tokenized.pkl', 'rb') as file:
    train_dataset_augmented = pickle.load(file)

### Setting up the training arguments

In [8]:
learning_rates = [5e-5, 4e-5, 3e-5, 2e-5]

pre_trained_BERTmodel='bert-large-uncased'
BERT_tokenizer=AutoTokenizer.from_pretrained(pre_trained_BERTmodel)

### Modifying Bert for our classification Task

In [4]:
class BertModelWithCustomLossFunction(nn.Module):
    def __init__(self):
        super(BertModelWithCustomLossFunction, self).__init__()
        self.num_labels = 64
        self.bert = BertModel.from_pretrained(
            pre_trained_BERTmodel, num_labels=self.num_labels
        )
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(1024, self.num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )

        output = self.dropout(outputs.pooler_output)
        logits = self.classifier(output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

### Setting up metrics for accuracy, precision, recall and f1

In [5]:
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    accuracy = accuracy_score(labels, preds)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1
    }

### Training the model

In [6]:
import warnings
warnings.filterwarnings("ignore")


In [10]:
for i, train in enumerate([train_dataset_full, train_dataset, train_dataset_augmented]):
  best_lr = learning_rates[0]
  for lr in learning_rates:
    if i != 0:
      learning_rate = best_lr
    else:
      learning_rate = lr
    best_accuracy = 0
    for train_data in train:
      BERT_model = BertModelWithCustomLossFunction()
      trainer = Trainer(
            model = BERT_model,
            args = TrainingArguments(
          output_dir="./output",
          evaluation_strategy="epoch",
          save_strategy="epoch",
          learning_rate=learning_rate,
          per_device_train_batch_size=8 ,
          per_device_eval_batch_size=8 ,
          num_train_epochs=3,
          warmup_ratio= 0.1,
          weight_decay= 0.001,
          load_best_model_at_end=True,
          metric_for_best_model="accuracy",
          save_total_limit=1,
              ),
            train_dataset=train_data,
            eval_dataset=val_dataset,
            tokenizer=BERT_tokenizer,
            compute_metrics=compute_metrics,)
      trainer.train()
      evaluation_metrics = trainer.predict(test_dataset)
      accuracy = evaluation_metrics.metrics['test_accuracy']
      if accuracy > best_accuracy:
        best_accuracy = accuracy
        if i == 0:
          best_lr = learning_rate
      best_accuracy = max(accuracy, best_accuracy)
      torch.cuda.empty_cache()
    print(f"Best Test Accuracy for this training dataset: {accuracy}")

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,0.584568,0.875126,0.885359,0.875126,0.869749
2,1.415400,0.356534,0.915408,0.92041,0.915408,0.914232
3,1.415400,0.30087,0.934542,0.938047,0.934542,0.93377


Best Test Accuracy for this training dataset: 0.9144186046511628


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,0.654746,0.868077,0.869433,0.868077,0.857108
2,1.588600,0.37692,0.903323,0.902579,0.903323,0.899305
3,1.588600,0.318915,0.926485,0.928335,0.926485,0.924946


Best Test Accuracy for this training dataset: 0.9162790697674419


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,0.876044,0.850957,0.850631,0.850957,0.833813
2,1.887000,0.387274,0.910373,0.910434,0.910373,0.906959
3,1.887000,0.324552,0.924471,0.928111,0.924471,0.92333


Best Test Accuracy for this training dataset: 0.9097674418604651


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,1.231819,0.820745,0.792514,0.820745,0.796396
2,2.146300,0.504549,0.898288,0.900764,0.898288,0.892497
3,2.146300,0.409723,0.913394,0.914449,0.913394,0.910041


Best Test Accuracy for this training dataset: 0.9032558139534884


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,4.215529,0.021148,0.001559,0.021148,0.002771
2,No log,3.979859,0.056395,0.046541,0.056395,0.037328
3,No log,3.768556,0.175227,0.224955,0.175227,0.158207


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,4.02825,0.06143,0.062097,0.06143,0.048562
2,No log,3.730567,0.186304,0.242597,0.186304,0.143178
3,No log,3.574939,0.278953,0.347509,0.278953,0.236599


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,4.228258,0.018127,0.002274,0.018127,0.003375
2,No log,4.227074,0.028197,0.002149,0.028197,0.003061
3,No log,4.203496,0.02719,0.001659,0.02719,0.002347


Best Test Accuracy for this training dataset: 0.018604651162790697


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,4.215529,0.021148,0.001559,0.021148,0.002771
2,No log,3.979859,0.056395,0.046541,0.056395,0.037328
3,No log,3.768556,0.175227,0.224955,0.175227,0.158207


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,4.02825,0.06143,0.062097,0.06143,0.048562
2,No log,3.730567,0.186304,0.242597,0.186304,0.143178
3,No log,3.574939,0.278953,0.347509,0.278953,0.236599


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,4.228258,0.018127,0.002274,0.018127,0.003375
2,No log,4.227074,0.028197,0.002149,0.028197,0.003061
3,No log,4.203496,0.02719,0.001659,0.02719,0.002347


Best Test Accuracy for this training dataset: 0.018604651162790697


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,4.215529,0.021148,0.001559,0.021148,0.002771
2,No log,3.979859,0.056395,0.046541,0.056395,0.037328
3,No log,3.768556,0.175227,0.224955,0.175227,0.158207


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,4.02825,0.06143,0.062097,0.06143,0.048562
2,No log,3.730567,0.186304,0.242597,0.186304,0.143178
3,No log,3.574939,0.278953,0.347509,0.278953,0.236599


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,4.228258,0.018127,0.002274,0.018127,0.003375
2,No log,4.227074,0.028197,0.002149,0.028197,0.003061
3,No log,4.203496,0.02719,0.001659,0.02719,0.002347


Best Test Accuracy for this training dataset: 0.018604651162790697


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,4.215529,0.021148,0.001559,0.021148,0.002771
2,No log,3.979859,0.056395,0.046541,0.056395,0.037328
3,No log,3.768556,0.175227,0.224955,0.175227,0.158207


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,4.02825,0.06143,0.062097,0.06143,0.048562
2,No log,3.730567,0.186304,0.242597,0.186304,0.143178
3,No log,3.574939,0.278953,0.347509,0.278953,0.236599


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,4.228258,0.018127,0.002274,0.018127,0.003375
2,No log,4.227074,0.028197,0.002149,0.028197,0.003061
3,No log,4.203496,0.02719,0.001659,0.02719,0.002347


Best Test Accuracy for this training dataset: 0.018604651162790697


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,4.217168,0.01712,0.007164,0.01712,0.0059
2,No log,4.142601,0.025176,0.026947,0.025176,0.012222
3,No log,4.077608,0.057402,0.036015,0.057402,0.030639


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,3.953987,0.07855,0.059735,0.07855,0.044614
2,No log,3.44874,0.342397,0.41746,0.342397,0.300834
3,No log,3.078752,0.526687,0.551175,0.526687,0.493253


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,4.232747,0.015106,0.000228,0.015106,0.00045
2,No log,4.180817,0.015106,0.022078,0.015106,0.003956
3,No log,4.147583,0.039275,0.01949,0.039275,0.0175


Best Test Accuracy for this training dataset: 0.03255813953488372


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,4.217168,0.01712,0.007164,0.01712,0.0059
2,No log,4.142601,0.025176,0.026947,0.025176,0.012222
3,No log,4.077608,0.057402,0.036015,0.057402,0.030639


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,3.953987,0.07855,0.059735,0.07855,0.044614
2,No log,3.44874,0.342397,0.41746,0.342397,0.300834
3,No log,3.078752,0.526687,0.551175,0.526687,0.493253


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,4.232747,0.015106,0.000228,0.015106,0.00045
2,No log,4.180817,0.015106,0.022078,0.015106,0.003956
3,No log,4.147583,0.039275,0.01949,0.039275,0.0175


Best Test Accuracy for this training dataset: 0.03255813953488372


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,4.217168,0.01712,0.007164,0.01712,0.0059
2,No log,4.142601,0.025176,0.026947,0.025176,0.012222
3,No log,4.077608,0.057402,0.036015,0.057402,0.030639


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,3.953987,0.07855,0.059735,0.07855,0.044614
2,No log,3.44874,0.342397,0.41746,0.342397,0.300834
3,No log,3.078752,0.526687,0.551175,0.526687,0.493253


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,4.232747,0.015106,0.000228,0.015106,0.00045
2,No log,4.180817,0.015106,0.022078,0.015106,0.003956
3,No log,4.147583,0.039275,0.01949,0.039275,0.0175


Best Test Accuracy for this training dataset: 0.03255813953488372


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,4.217168,0.01712,0.007164,0.01712,0.0059
2,No log,4.142601,0.025176,0.026947,0.025176,0.012222
3,No log,4.077608,0.057402,0.036015,0.057402,0.030639


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,3.953987,0.07855,0.059735,0.07855,0.044614
2,No log,3.44874,0.342397,0.41746,0.342397,0.300834
3,No log,3.078752,0.526687,0.551175,0.526687,0.493253


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,4.232747,0.015106,0.000228,0.015106,0.00045
2,No log,4.180817,0.015106,0.022078,0.015106,0.003956
3,No log,4.147583,0.039275,0.01949,0.039275,0.0175


Best Test Accuracy for this training dataset: 0.03255813953488372
