## Model training

In [32]:
from transformers import TrainingArguments, Trainer
import pickle
from transformers import (
    AutoModelForSequenceClassification,
    AutoModelForSeq2SeqLM,
    AutoConfig,
    BertModel,
)
from transformers import AutoTokenizer
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import torch
import torch.nn as nn
from transformers.modeling_outputs import SequenceClassifierOutput

### Loading the data


In [33]:
with open('train_dataset_tokenized.pkl', 'rb') as file:
    train_dataset = pickle.load(file)

with open('val_data_tokenized.pkl', 'rb') as file:
    val_dataset = pickle.load(file)

with open('test_data_tokenized.pkl', 'rb') as file:
    test_dataset = pickle.load(file)

### Setting up the training arguments

In [34]:
args = TrainingArguments(
        output_dir="./output",
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=3e-5,
        per_device_train_batch_size=8 ,
        per_device_eval_batch_size=8 ,
        num_train_epochs=20,
        warmup_ratio= 0.1,
        weight_decay= 0.001,
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        save_total_limit=1,
            )

pre_trained_BERTmodel='bert-large-uncased'
BERT_tokenizer=AutoTokenizer.from_pretrained(pre_trained_BERTmodel)

### Modifying Bert for our classification Task

In [36]:
class BertModelWithCustomLossFunction(nn.Module):
    def __init__(self):
        super(BertModelWithCustomLossFunction, self).__init__()
        self.num_labels = 64
        self.bert = BertModel.from_pretrained(
            pre_trained_BERTmodel, num_labels=self.num_labels
        )
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(1024, self.num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )

        output = self.dropout(outputs.pooler_output)
        logits = self.classifier(output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

### Setting up metrics for accuracy, precision, recall and f1

In [37]:
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    accuracy = accuracy_score(labels, preds)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1
    }

### Training the model

In [39]:
best_accuracy = 0
for train_data in train_dataset:
  BERT_model = BertModelWithCustomLossFunction()
  trainer = Trainer(
        model = BERT_model,
        args = args,
        train_dataset=train_data,
        eval_dataset=val_dataset,
        tokenizer=BERT_tokenizer,
        compute_metrics=compute_metrics,)
  trainer.train()
  evaluation_metrics = trainer.predict(test_dataset)
  accuracy = evaluation_metrics.metrics['test_accuracy']
  best_accuracy = max(accuracy, best_accuracy)
  print(f"Best Test Accuracy for this training dataset: {accuracy}")
  torch.cuda.empty_cache()

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,4.198321,0.020141,0.00628,0.020141,0.007179
2,No log,4.044622,0.06143,0.032681,0.06143,0.033894
3,No log,3.822736,0.131923,0.154717,0.131923,0.106068
4,No log,3.543237,0.317221,0.384951,0.317221,0.283106
5,No log,3.198873,0.476334,0.593818,0.476334,0.45599
6,No log,2.778737,0.602216,0.643916,0.602216,0.581602
7,No log,2.380214,0.654582,0.699252,0.654582,0.637272
8,No log,2.050138,0.715005,0.751127,0.715005,0.697447
9,No log,1.733713,0.7714,0.776218,0.7714,0.755475
10,No log,1.510513,0.796576,0.800479,0.796576,0.780965


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Best Test Accuracy for this training dataset: 0.8102325581395349




Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,4.161445,0.036254,0.006665,0.036254,0.008279
2,No log,4.017666,0.042296,0.032001,0.042296,0.020807
3,No log,3.812759,0.15005,0.179093,0.15005,0.130481
4,No log,3.503791,0.317221,0.411577,0.317221,0.298836
5,No log,3.158438,0.492447,0.566952,0.492447,0.475426
6,No log,2.751917,0.628399,0.66292,0.628399,0.613481
7,No log,2.386314,0.708963,0.739096,0.708963,0.701924
8,No log,2.036235,0.743202,0.76735,0.743202,0.73509
9,No log,1.730175,0.761329,0.771508,0.761329,0.753302
10,No log,1.516768,0.774421,0.784374,0.774421,0.766411


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Best Test Accuracy for this training dataset: 0.8111627906976744




Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,4.158862,0.038268,0.011958,0.038268,0.010013
2,No log,4.054665,0.045317,0.054383,0.045317,0.026426
3,No log,3.755263,0.160121,0.194702,0.160121,0.12763
4,No log,3.387835,0.364552,0.44708,0.364552,0.340851
5,No log,3.011528,0.560926,0.627555,0.560926,0.545099
6,No log,2.583475,0.65861,0.711271,0.65861,0.645447
7,No log,2.224372,0.728097,0.749868,0.728097,0.717012
8,No log,1.901034,0.758308,0.768195,0.758308,0.747038
9,No log,1.675509,0.757301,0.771879,0.757301,0.747796
10,No log,1.468249,0.786506,0.794668,0.786506,0.77758


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Best Test Accuracy for this training dataset: 0.8427906976744186


In [40]:
best_accuracy

0.8427906976744186