## Train

* Load the training set
* Load the pretrained BERT model
* Fine-tune BERT on the training set
* Export the fine-tuned model state
* Evaluate the fine-tuned model on the test set

In [14]:
import os
import time

import torch
from matplotlib.pylab import plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from torch.utils.data import DataLoader
from transformers import (
    BertForSequenceClassification,
    get_linear_schedule_with_warmup,
)

from shared import BERT_MODEL, tokenized_train_dataset_path, model_path, WORKDIR, MODELDIR

#TODO from google.colab import drive

In [15]:
#TODO drive.mount("/content/drive")
# Choose the accelerator: CUDA is preferred, Apple Metal (MPS) is the fallback.
# There is deliberately no CPU path; the cell fails when neither is available.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    if not torch.backends.mps.is_available():
        raise ValueError("Cuda or Apple Metal required but neither is available.")
    device = torch.device("mps")
print(f"Using device: {device}")

Using device: mps


In [16]:
# Define hyperparameters
# Seed torch's global RNG first so that batch shuffling, dropout and the freshly
# initialised classifier head are repeatable between runs (as far as the backend allows).
SEED = 42
torch.manual_seed(SEED)

TRAIN_EPOCHS = 10  # number of full passes over the training set
TRAIN_BATCH_SIZE = 32  # samples per batch, i.e. per optimizer step
NUM_WORKERS = 2  # worker count handed to the DataLoader

In [17]:
# Load datasets
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code, so only
# load dataset files that were produced by this project.
# weights_only=False is passed explicitly because the file presumably holds a pickled
# dataset object rather than plain tensors (TODO confirm); PyTorch 2.6+ defaults to
# weights_only=True, which rejects such objects.
train_dataset = torch.load(tokenized_train_dataset_path, weights_only=False)

In [18]:
# Pretrained encoder with a two-label classification head; attention maps and
# hidden states are not returned by the forward pass.
classifier_options = {
    "num_labels": 2,
    "output_attentions": False,
    "output_hidden_states": False,
}
model = BertForSequenceClassification.from_pretrained(BERT_MODEL, **classifier_options)
model.to(device)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12,

In [19]:
# AdamW over every model parameter; the step size and numerical-stability
# epsilon are named so they are easy to find and tune.
LEARNING_RATE = 6e-6
ADAM_EPSILON = 1e-8
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    eps=ADAM_EPSILON,
)

In [20]:
# Shuffled mini-batches of the training set, loaded with NUM_WORKERS workers.
train_loader = DataLoader(
    train_dataset,
    batch_size=TRAIN_BATCH_SIZE,
    num_workers=NUM_WORKERS,
    shuffle=True,
)

In [21]:
# The schedule is sized to cover every batch of every epoch.
total_steps = TRAIN_EPOCHS * len(train_loader)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_training_steps=total_steps,
    num_warmup_steps=0,  # no warmup phase
)

In [22]:
# Fine-tune the model for TRAIN_EPOCHS epochs. After each epoch the weights are
# saved and the checkpoint written by the previous epoch is deleted; the final
# epoch writes to model_path instead of a per-epoch checkpoint file.

# exist_ok=True creates the directory only when it is missing, without a
# separate check-then-create step.
os.makedirs(MODELDIR, exist_ok=True)

model.train()

total_step = len(train_loader)
for epoch in range(TRAIN_EPOCHS):
    start_time = time.time()
    total_loss = 0
    for ids, masks, labels in train_loader:
        ids = ids.to(device)
        masks = masks.to(device)
        labels = labels.to(device)

        # Labels are passed in, and the first element of the model output is used as the loss.
        loss = model(ids, token_type_ids=None, attention_mask=masks, labels=labels)[0]

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()  # the learning-rate schedule advances once per batch

        total_loss += loss.item()

    end_time = time.time()
    epoch_time = end_time - start_time
    loss_per_step = total_loss / total_step  # mean loss per batch for this epoch
    print(f"{epoch + 1} epoch with loss: {loss_per_step:.4f}, duration: {epoch_time:.2f} seconds")

    previous_checkpoint_path = MODELDIR / f"BERT_finetuned_epoch_{epoch}.ckpt"
    current_checkpoint_path = MODELDIR / f"BERT_finetuned_epoch_{epoch+1}.ckpt"
    if epoch == TRAIN_EPOCHS - 1:
        # Last epoch: export the final weights to the shared model path.
        torch.save(model.state_dict(), model_path)
    else:
        torch.save(model.state_dict(), current_checkpoint_path)

    # Delete the older checkpoint only after the new weights have been written.
    if os.path.exists(previous_checkpoint_path):
        os.remove(previous_checkpoint_path)


1 epoch with loss: 0.0951, duration: 1095.82 seconds
2 epoch with loss: 0.0320, duration: 2925.69 seconds
3 epoch with loss: 0.0152, duration: 744.40 seconds
4 epoch with loss: 0.0076, duration: 2747.85 seconds
5 epoch with loss: 0.0036, duration: 1586.92 seconds
6 epoch with loss: 0.0025, duration: 3290.63 seconds
7 epoch with loss: 0.0018, duration: 677.84 seconds
8 epoch with loss: 0.0009, duration: 675.29 seconds
9 epoch with loss: 0.0008, duration: 678.67 seconds
10 epoch with loss: 0.0004, duration: 696.01 seconds
