In [None]:
% load_ext autoreload
% autoreload 2

import torch
from config import conf
from torch.utils.data import DataLoader
from torch.optim import Adam
from transformers import DistilBertTokenizerFast, DistilBertForQuestionAnswering
from tqdm import tqdm
from dataset import SquadDataset
from torch.nn import Module
from pathlib import Path
from datetime import datetime

# Bare expression: its value (whether PyTorch can use a CUDA device) is shown as the cell output.
torch.cuda.is_available()

In [None]:
# Run-mode switches and checkpoint locations used by the training and testing cells.
# True: run the training loop; False: the testing cell loads MODEL_LOAD_NAME instead.
TRAIN_MODEL = True
# When training, write the final state_dict into MODELS_FOLDER.
SAVE_MODEL = True
# Directory for saved checkpoints (created on save if it does not exist).
MODELS_FOLDER = "./models"
# Checkpoint file name, relative to MODELS_FOLDER, loaded when TRAIN_MODEL is False.
MODEL_LOAD_NAME = "model_0125.pt"

In [None]:
# Single source for the pretrained checkpoint name used by both the tokenizer and the model.
PRETRAINED_CHECKPOINT = 'distilbert-base-uncased'

# Build the dataset from the configured JSON file (the tokenizer is handed to
# SquadDataset) and split it according to the configured train ratio.
tokenizer = DistilBertTokenizerFast.from_pretrained(PRETRAINED_CHECKPOINT)
dataset = SquadDataset.from_json(conf['DATASET_FILE'], tokenizer)
train_dataset, val_dataset = dataset.train_val_split(conf['TRAIN_RATIO'])

# Question-answering model initialised from the same pretrained checkpoint.
model: Module = DistilBertForQuestionAnswering.from_pretrained(PRETRAINED_CHECKPOINT)

In [None]:
def compute_accuracy(pred: torch.Tensor, true: torch.Tensor) -> float:
    """Return the fraction of positions where `pred` equals `true`.

    Args:
        pred: predicted values (in this notebook: argmax'd start/end indices,
            one per example in the batch).
        true: ground-truth values, same length as `pred`.

    Returns:
        Accuracy in [0, 1] as a Python float. An empty input yields 0.0
        instead of NaN so that it cannot poison running averages.

    Raises:
        AssertionError: if `pred` and `true` differ in length.
    """
    assert len(pred) == len(true), "pred and true must have the same length"
    if len(pred) == 0:
        return 0.0
    # Use the function's own arguments (not notebook globals) so the result is
    # correct for both start and end positions and works on a fresh kernel.
    return (pred == true).float().mean().item()

In [None]:

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)


def _mean(values) -> float:
    """Average of `values`; 0.0 for an empty list so an empty loader cannot raise ZeroDivisionError."""
    return sum(values) / len(values) if values else 0.0


if TRAIN_MODEL:
    opt = Adam(model.parameters(), lr=5e-5)
    train_loader = DataLoader(train_dataset, batch_size=conf['BATCH_SIZE'], shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=conf['BATCH_SIZE'])

    for epoch in range(conf['N_EPOCHS']):
        # ================================ TRAINING ================================
        model.train()

        train_losses = []
        train_accs = []

        train_iter = tqdm(train_loader)
        train_iter.set_description(f'Epoch {epoch}')

        for train_batch in train_iter:
            input_ids = train_batch['input_ids'].to(device)
            attention_mask = train_batch['attention_mask'].to(device)
            start_true = train_batch['start_positions'].to(device)
            end_true = train_batch['end_positions'].to(device)

            # Passing the true positions makes the model return a 'loss' entry.
            outputs = model(input_ids,
                            attention_mask=attention_mask,
                            start_positions=start_true,
                            end_positions=end_true)
            start_pred = torch.argmax(outputs['start_logits'], dim=1)
            end_pred = torch.argmax(outputs['end_logits'], dim=1)

            loss = outputs['loss']
            loss_value = loss.item()
            start_acc = compute_accuracy(start_pred, start_true)
            end_acc = compute_accuracy(end_pred, end_true)
            train_iter.set_postfix(loss=loss_value,
                                   acc=(start_acc + end_acc) / 2)

            train_losses.append(loss_value)
            train_accs.append(start_acc)
            train_accs.append(end_acc)

            loss.backward()

            opt.step()
            # Clear gradients after the update so the next batch starts from zero.
            opt.zero_grad()

        # =============================== VALIDATION ===============================
        model.eval()

        val_losses = []
        val_accs = []

        with torch.no_grad():
            for val_batch in val_loader:
                input_ids = val_batch['input_ids'].to(device)
                attention_mask = val_batch['attention_mask'].to(device)
                start_true = val_batch['start_positions'].to(device)
                end_true = val_batch['end_positions'].to(device)

                # The 'loss' key is only present when the labels are supplied, so
                # pass them here too; otherwise the validation loss cannot be read.
                outputs = model(input_ids,
                                attention_mask=attention_mask,
                                start_positions=start_true,
                                end_positions=end_true)
                start_pred = torch.argmax(outputs['start_logits'], dim=1)
                end_pred = torch.argmax(outputs['end_logits'], dim=1)

                val_losses.append(outputs['loss'].item())
                val_accs.append(compute_accuracy(start_pred, start_true))
                val_accs.append(compute_accuracy(end_pred, end_true))

        # The progress bar has already finished iterating at this point, so a
        # postfix update on it may not be rendered; print the summary instead.
        print(f"Epoch {epoch}: "
              f"loss={_mean(train_losses):.4f}, "
              f"acc={_mean(train_accs):.4f}, "
              f"val_loss={_mean(val_losses):.4f}, "
              f"val_acc={_mean(val_accs):.4f}")
    # SAVING
    if SAVE_MODEL:
        Path(MODELS_FOLDER).mkdir(parents=True, exist_ok=True)
        # File name carries today's month and day, e.g. model_0125.pt.
        filepath = f"{MODELS_FOLDER}/model_{datetime.today().strftime('%m%d')}.pt"
        torch.save(model.state_dict(), filepath)
        print(f"Model saved in {filepath}")

In [None]:
# =============================== TESTING ===============================
test_dataset = val_dataset  # we don't actually have the testing ds yet
test_loader = DataLoader(test_dataset, batch_size=conf['BATCH_SIZE'])

if not TRAIN_MODEL:
    filepath = MODELS_FOLDER + '/' + MODEL_LOAD_NAME
    # map_location moves the stored tensors onto the current device, so a
    # checkpoint saved on a GPU can still be loaded on a CPU-only machine.
    model.load_state_dict(torch.load(filepath, map_location=device))
    print(f"Loaded model at {filepath}")

model.eval()
test_losses = []
test_accs = []

with torch.no_grad():
    for test_batch in test_loader:
        input_ids = test_batch['input_ids'].to(device)
        attention_mask = test_batch['attention_mask'].to(device)
        start_true = test_batch['start_positions'].to(device)
        end_true = test_batch['end_positions'].to(device)

        # The 'loss' key is only returned when the labels are supplied.
        outputs = model(input_ids,
                        attention_mask=attention_mask,
                        start_positions=start_true,
                        end_positions=end_true)
        start_pred = torch.argmax(outputs['start_logits'], dim=1)
        end_pred = torch.argmax(outputs['end_logits'], dim=1)

        test_losses.append(outputs['loss'].item())
        test_accs.append(compute_accuracy(start_pred, start_true))
        test_accs.append(compute_accuracy(end_pred, end_true))

if test_accs:
    print(f"Average test loss: {sum(test_losses) / len(test_losses)}")
    print(f"Average test accuracy: {sum(test_accs) / len(test_accs)}")
else:
    # Avoid a ZeroDivisionError when the loader yields no batches.
    print("Test loader produced no batches; nothing to report.")

In [None]:
# Show torchinfo's summary of the model (called with verbose=1).
# NOTE(review): this import lives outside the top import cell; consider moving it there.
from torchinfo import summary
summary(model, verbose=1)