## Import modules

In [None]:
import nlp
from transformers import AlbertTokenizer, AlbertConfig, AlbertForSequenceClassification
import torch
from tqdm.notebook import trange, tqdm
import random
import numpy as np
import pandas as pd
from sklearn.metrics import f1_score, accuracy_score, classification_report

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

## Load data

In [None]:
# Fetch the SciCite dataset and pull out its three splits by name.
scicite = nlp.load_dataset("scicite")

train, val, test = (scicite[split] for split in ("train", "validation", "test"))

## Initialize model configuration

In [None]:
# Configuration, tokenizer and classifier all start from the same pretrained checkpoint.
checkpoint = 'albert-base-v2'

config = AlbertConfig.from_pretrained(checkpoint)
config.num_labels = 3  # Size of the classification head (three output classes)
config.use_bfloat16 = True

tokenizer = AlbertTokenizer.from_pretrained(checkpoint, config=config)
model = AlbertForSequenceClassification.from_pretrained(checkpoint, config=config)

In [None]:
# print(summary(model, torch.zeros((BATCH_SIZE, MAX_LEN), dtype=torch.long), show_input=True))

## Format data into a DataLoader object

In [None]:
def encode(example):
    """Tokenize the ``"string"`` field of *example* with the module-level tokenizer.

    The text is truncated to at most 100 tokens and padded up to that same
    length, so every encoded instance has the same size.
    """
    text = example["string"]
    return tokenizer(
        text,
        truncation=True,
        max_length=100,
        padding="max_length",
    )

def format_data(train, val, test, batch_size=16, shuffle_train=True):
    """Tokenize the three data splits and wrap each one in a ``DataLoader``.

    Each split is expected to expose ``map`` and ``set_format`` (both are
    called on it below).

    Args:
        train: Training split.
        val: Validation split.
        test: Test split.
        batch_size: Number of instances per batch for all three loaders.
        shuffle_train: Whether the training loader reshuffles its instances
            every epoch. The validation and test loaders are never shuffled,
            so their predictions stay aligned with the stored order.

    Returns:
        A ``(train_dataloader, val_dataloader, test_dataloader)`` tuple.
    """
    # Column names are assigned by the model-specific tokenizer, so different
    # models may need a different list here.
    columns = ['input_ids', 'token_type_ids', 'attention_mask', 'label']

    def _to_dataloader(split, shuffle):
        # Tokenize every instance, expose the relevant columns as torch
        # tensors, and wrap the result so it can be iterated in batches.
        tokens = split.map(encode, batched=True)
        tokens.set_format(type='torch', columns=columns)
        return torch.utils.data.DataLoader(tokens, batch_size=batch_size, shuffle=shuffle)

    train_dataloader = _to_dataloader(train, shuffle_train)
    val_dataloader = _to_dataloader(val, False)
    test_dataloader = _to_dataloader(test, False)

    return train_dataloader, val_dataloader, test_dataloader

# Tokenize the three splits and wrap each one in a DataLoader
train_dataloader, val_dataloader, test_dataloader = format_data(
    train,
    val,
    test,
    batch_size=16,
)

## Train model

In [None]:
model.to(device)  # Move model to gpu if available

optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-5)  # Define optimizer with an adjustable learning rate

# Fine-tune for 5 epochs. After every epoch, report the training metrics and
# evaluate on the validation data.
for epoch in trange(5, desc="Epoch"):
    model.train()  # Set the model to train mode. Model layers have different behavior depending on train or eval mode (e.g. dropout is removed during eval)

    train_y_true = []
    train_y_pred = []
    losses = []

    for batch in tqdm(train_dataloader):
        batch["labels"] = batch.pop("label")  # Rename keyword
        batch = {k: v.to(device) for k, v in batch.items()}  # Move instances in batch to device for calculation
        outputs = model(**batch)  # Calculate model output on instances in batch
        loss = outputs[0]  # First element of the output is the loss
        loss.backward()  # Calculate the gradients
        losses.append(loss.item())  # Keep track of the loss
        optimizer.step()  # Update the parameters from the computed gradients
        optimizer.zero_grad()  # Clear computed gradients for next iteration

        train_y_true.extend(batch["labels"].tolist())  # Keep track of gold truth labels
        train_y_pred.extend(torch.argmax(outputs[1], dim=1).tolist())  # Keep track of predicted labels

    train_acc = accuracy_score(train_y_true, train_y_pred)  # Compute accuracy
    train_f1 = f1_score(train_y_true, train_y_pred, average="macro")  # Compute F1-score
    # Average the per-batch losses; guard against an empty training loader.
    mean_loss = sum(losses) / len(losses) if losses else float("nan")
    print(f"mean train loss: {mean_loss}\t", f"train acc: {train_acc}", f"train f1: {train_f1}")

    model.eval()  # Set eval mode for testing on validation data
    with torch.no_grad():  # Disable gradient calculation, which speeds up computation and reduces memory
        val_y_true = []
        val_y_pred = []
        for val_batch in tqdm(val_dataloader):
            val_batch["labels"] = val_batch.pop("label")
            val_batch = {k: v.to(device) for k, v in val_batch.items()}
            val_outputs = model(**val_batch)
            # No detach() needed: nothing is tracked for gradients under no_grad.
            val_y_true.extend(val_batch["labels"].cpu().tolist())
            val_y_pred.extend(torch.argmax(val_outputs[1], dim=1).cpu().tolist())
        val_acc = accuracy_score(val_y_true, val_y_pred)
        val_f1 = f1_score(val_y_true, val_y_pred, average="macro")
        print(f"val acc: {val_acc}", f"val f1: \t{val_f1}")

        print(classification_report(val_y_true, val_y_pred, labels=[0,1,2]))

## Save model

In [None]:
import os

# Create the output directory if it is missing. exist_ok=True keeps this cell
# re-runnable instead of raising FileExistsError on the second run.
os.makedirs("./models", exist_ok=True)
torch.save(model.state_dict(), "models/model.pt")  # Store only the model parameters

## Load trained model

In [None]:
# In order to use a saved model, you have to initialize the configuration, tokenizer, and classifier
# first, and then load the trained parameters into the freshly built classifier.

config = AlbertConfig.from_pretrained('albert-base-v2')
config.num_labels = 3
config.use_bfloat16 = True
tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2', config=config)
model = AlbertForSequenceClassification(config=config)

# Read the parameters from the same path the save step writes to
# ("models/model.pt"). map_location puts the tensors on the current device,
# so a checkpoint saved on a GPU also loads on a CPU-only machine.
model.load_state_dict(torch.load('models/model.pt', map_location=device))  # Load model parameters

## Test model

In [None]:
# Evaluate the model on the test data and print a per-class report.
model.to(device)
model.eval()

test_true, test_pred = [], []
with torch.no_grad():
    for test_batch in tqdm(test_dataloader):
        test_batch["labels"] = test_batch.pop("label")  # Rename keyword
        test_batch = {name: tensor.to(device) for name, tensor in test_batch.items()}
        test_outputs = model(**test_batch)

        gold = test_batch["labels"].detach().cpu()
        scores = test_outputs[1].detach().cpu()
        test_true.append(gold.tolist())
        test_pred.append(torch.argmax(scores, dim=1).tolist())

# Flatten the per-batch lists into one list of labels each
test_y_true = [label for chunk in test_true for label in chunk]
test_y_pred = [label for chunk in test_pred for label in chunk]

print(classification_report(test_y_true, test_y_pred, labels=[0,1,2]))