# Aim
In this notebook, I'll use a `TransformerEncoderLayer` to perform a text classification task. For simplicity, I'll use a pre-trained tokenizer directly, since the choice of tokenizer doesn't affect the classification very much.

In [None]:
import torch
from torch import nn, Tensor
from torch.nn.modules import TransformerEncoderLayer, TransformerEncoder
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import evaluate
import numpy as np
from datasets import load_dataset
from transformers import BertTokenizer
import time
import math
from tempfile import TemporaryDirectory
import os

# Load the dataset

In [None]:
# Load the train / validation / test splits of the 'dair-ai/emotion' dataset
# ('split' configuration), in that order, and report how many rows each has.
train, valid, test = (
    load_dataset('dair-ai/emotion', 'split', split=split_name)
    for split_name in ('train', 'validation', 'test')
)
print('size of train: {}, validation: {}, test: {}'.format(len(train), len(valid), len(test)))

# Pre-process data with a pre-trained tokenizer

In [None]:
train[0]

In [None]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [None]:
def tokenizing(record):
    """Tokenize the ``'text'`` entry of ``record`` with the module-level tokenizer.

    The value of ``record['text']`` is handed to ``tokenizer`` unchanged, with
    truncation enabled and ``max_length=300``.

    Returns:
        Whatever the ``tokenizer`` call returns.
    """
    texts = record['text']
    encoded = tokenizer(texts, truncation=True, max_length=300)
    return encoded

# Apply `tokenizing` to each split in batched mode (train, then validation,
# then test), keeping one tokenized dataset per split.
train_tokenized, valid_tokenized, test_tokenized = (
    split.map(tokenizing, batched=True) for split in (train, valid, test)
)

In [None]:
def func(l, num_classes=7):
    """Return a one-hot float vector with a 1 at position ``l``.

    The vector is created with ``torch.zeros`` so that every position other
    than ``l`` is exactly 0. (``torch.Tensor(n)`` only allocates memory and
    leaves its contents uninitialised.)

    Args:
        l: index (or any valid tensor index) of the position(s) to set to 1.
        num_classes: length of the returned vector. Defaults to 7, the length
            this helper has always produced.
            NOTE(review): ``id2label`` in this notebook lists 6 classes --
            confirm whether 7 is intended.

    Returns:
        A 1-D tensor of length ``num_classes`` in the default float dtype.
    """
    t = torch.zeros(num_classes)
    t[l] = 1
    return t
print(train_tokenized['label'][:2])
[func(l) for l in train_tokenized['label'][:2]]

In [None]:
train_tokenized[0]

In [None]:
# Class names in id order: position in the list is the integer class id.
id2label = dict(enumerate(['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']))
# Reverse lookup: class name -> integer class id.
label2id = dict(zip(id2label.values(), id2label.keys()))

# Define metrics

In [None]:
accuracy = evaluate.load('accuracy')
def metrics(pred):
    """Compute accuracy for a ``(predictions, labels)`` pair.

    Args:
        pred: a two-element sequence. The first element holds per-class
            scores and is reduced with ``argmax`` along axis 1; the second
            holds the reference labels.

    Returns:
        The result of ``accuracy.compute`` for the arg-max predictions.
    """
    scores, references = pred
    predicted_ids = np.argmax(scores, axis=1)
    return accuracy.compute(predictions=predicted_ids, references=references)

# Define model

In [None]:
class TransformerModel(nn.Module):
    """Transformer encoder stack followed by a two-layer linear head.

    Args:
        d_model: feature width expected by the encoder layers and by the
            first linear layer.
        nhead: number of attention heads in each encoder layer.
        d_hid: feed-forward width inside each encoder layer.
        nlayers: number of stacked encoder layers.
        out_features: width of the final linear layer's output.
        d_between: width between the two linear layers of the head.
        dropout: dropout probability used inside the encoder layers.

    NOTE(review): there is no embedding layer and the encoder layers keep the
    library default for ``batch_first`` -- confirm that the inputs passed to
    ``forward`` are laid out the way the encoder expects.
    """

    def __init__(self, d_model: int, nhead: int, d_hid: int, nlayers: int, out_features: int,
                 d_between: int = 128, dropout: float = 0.5):
        super().__init__()
        self.model_type = 'Transformer text classificator'
        # Encoder first, then the head, so sub-modules are registered in the
        # order: transformer_encoder, linear1, output.
        self.transformer_encoder = TransformerEncoder(
            TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=d_hid,
                dropout=dropout,
            ),
            num_layers=nlayers,
        )
        self.linear1 = nn.Linear(in_features=d_model, out_features=d_between)
        self.output = nn.Linear(in_features=d_between, out_features=out_features)

    def forward(self, src: Tensor, mask: Tensor) -> Tensor:
        """Encode ``src`` under ``mask`` and project it to ``out_features`` values.

        Args:
            src: input tensor whose last dimension is ``d_model``.
            mask: attention mask forwarded to the encoder as ``mask=``.

        Returns:
            Tensor with the same leading dimensions as the encoder output and
            a last dimension of ``out_features``. No softmax is applied.
        """
        encoded = self.transformer_encoder(src, mask=mask)
        hidden = self.linear1(encoded)
        hidden = F.relu(hidden)
        return self.output(hidden)

# training arguments

In [None]:
# Run on CUDA when PyTorch reports it as available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
emsize = 300  # passed to TransformerModel as d_model; train()/evaluate() also zero-pad token-id rows to this width
d_hid = 200  # dimension of the feedforward network model in nn.TransformerEncoder
nlayers = 2  # number of nn.TransformerEncoderLayer in nn.TransformerEncoder
out_features = 6 # number of output categories (id2label lists 6 class names)
nhead = 2  # number of heads in nn.MultiheadAttention
dropout = 0.2  # dropout probability
BATCH_SIZE = 256  # rows per mini-batch in train() and evaluate()
# Build the model from the settings above and move it to the selected device.
model = TransformerModel(emsize, nhead, d_hid, nlayers, out_features, dropout=dropout).to(device)
# Bare expression: as the last line of a notebook cell this displays the module summary.
model

In [None]:
def generate_square_subsequent_mask(sz: int) -> Tensor:
    """Build an ``sz`` x ``sz`` float mask for causal attention.

    Entries strictly above the main diagonal are ``-inf``; the diagonal and
    everything below it are ``0``.
    """
    all_blocked = torch.full((sz, sz), float('-inf'))
    # triu(diagonal=1) keeps only the strictly-upper triangle and zeroes the rest.
    return all_blocked.triu(diagonal=1)

In [None]:
train_tokenized.shape

In [None]:
# Loss used by both train() and evaluate().
criterion = nn.CrossEntropyLoss()
lr = 5.0  # initial learning rate handed to SGD  # NOTE(review): unusually large -- confirm it is intended
# Plain SGD over every parameter of the module-level `model`.
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
# StepLR with step size 1 and gamma=0.95; scheduler.step() is called once per
# epoch by the training cell.
# NOTE(review): the step size is given as the float 1.0 -- confirm an int is not required.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)

def convert_label(l, num_classes=7):
    """Return a one-hot ``float32`` vector with a 1 at position ``l``.

    Args:
        l: index (or any valid tensor index) of the position(s) to set to 1.
        num_classes: length of the returned vector. Defaults to 7, the length
            this helper has always produced.
            NOTE(review): ``id2label`` in this notebook lists 6 classes --
            confirm whether 7 is intended.

    Returns:
        A 1-D ``torch.float32`` tensor of length ``num_classes``.
    """
    t = torch.zeros(num_classes, dtype=torch.float32)
    t[l] = 1
    return t

def train(model: nn.Module, epoch: int) -> None:
    """Train ``model`` for one pass over ``train_tokenized``.

    Rows are consumed in order, ``BATCH_SIZE`` at a time. For each batch the
    token-id lists are converted to float tensors, padded to a common length
    and then zero-padded on the right to width ``emsize``, before being fed
    to ``model`` together with a square subsequent mask. The loss is
    ``criterion(output, labels)``; gradients are clipped to norm 0.5 before
    each optimizer step. A progress line is printed every 10 batches.

    Uses the module-level ``train_tokenized``, ``BATCH_SIZE``, ``emsize``,
    ``device``, ``criterion``, ``optimizer`` and ``scheduler``.

    Args:
        model: module to optimise; called as ``model(input_ids, mask=src_mask)``.
        epoch: epoch number, used only in the progress line.

    NOTE(review): the mask is sized by the number of rows in the batch, and
    raw token ids are used directly as float features -- confirm both are
    intended.
    """
    model.train()  # turn on train mode
    total_loss = 0.
    log_interval = 10
    start_time = time.time()
    full_mask = generate_square_subsequent_mask(BATCH_SIZE).to(device)

    num_samples = len(train_tokenized)
    # Round up so the final, shorter batch is included in the displayed total.
    num_batches = math.ceil(num_samples / BATCH_SIZE)
    for batch, i in enumerate(range(0, num_samples, BATCH_SIZE)):
        # Slice the rows once per batch instead of re-reading whole columns.
        # Slicing clamps at the end of the data, so the final batch keeps
        # every remaining row (an end index of -1 would drop the last one).
        rows = train_tokenized[i: i + BATCH_SIZE]
        input_ids = rows['input_ids']
        labels = rows['label']

        # padding: first to the longest row in the batch, then out to emsize
        input_ids = pad_sequence([torch.tensor(ids, dtype=torch.float) for ids in input_ids], batch_first=True)
        input_ids = F.pad(input_ids, (0, emsize - input_ids.shape[1]))
        labels = torch.tensor(labels).to(device)

        # Trim the mask to the rows actually present (matters on the last batch).
        batch_len = input_ids.shape[0]
        src_mask = full_mask[:batch_len, :batch_len]

        input_ids = input_ids.to(device)
        output = model(input_ids, mask=src_mask)
        loss = criterion(output, labels)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        total_loss += loss.item()
        if batch % log_interval == 0 and batch > 0:
            lr = scheduler.get_last_lr()[0]
            ms_per_batch = (time.time() - start_time) * 1000 / log_interval
            cur_loss = total_loss / log_interval
            ppl = math.exp(cur_loss)
            print(f'| epoch {epoch:3d} | {batch:5d}/{num_batches:5d} batches | '
                  f'lr {lr:02.2f} | ms/batch {ms_per_batch:5.2f} | '
                  f'loss {cur_loss:5.2f} | ppl {ppl:8.2f}')
            total_loss = 0
            start_time = time.time()

def evaluate(model: nn.Module, eval_data) -> float:
    """Return the mean per-sample loss of ``model`` over ``eval_data``.

    Rows are consumed in order, ``BATCH_SIZE`` at a time, with the same
    padding and masking as ``train``. Each batch's loss is weighted by the
    number of rows in that batch and the total is divided by
    ``len(eval_data)``, so the result is an average over the data that was
    actually passed in.

    Uses the module-level ``BATCH_SIZE``, ``emsize``, ``device`` and
    ``criterion``.

    NOTE: this function shares its name with the ``evaluate`` package
    imported at the top of the notebook and shadows it once defined.

    Args:
        model: module to evaluate; called as ``model(input_ids, mask=src_mask)``.
        eval_data: tokenized dataset supporting ``len()`` and slice indexing
            that yields ``'input_ids'`` and ``'label'`` columns.

    Returns:
        The average loss, or ``nan`` when ``eval_data`` is empty.
    """
    model.eval()  # turn on evaluation mode
    num_samples = len(eval_data)
    if num_samples == 0:
        # Nothing to average over.
        return float('nan')

    total_loss = 0.
    full_mask = generate_square_subsequent_mask(BATCH_SIZE).to(device)
    with torch.no_grad():
        for i in range(0, num_samples, BATCH_SIZE):
            # Slicing clamps at the end of the data, so the final batch keeps
            # every remaining row (an end index of -1 would drop the last one).
            rows = eval_data[i: i + BATCH_SIZE]
            input_ids = rows['input_ids']
            labels = rows['label']

            # padding: first to the longest row in the batch, then out to emsize
            input_ids = pad_sequence([torch.tensor(ids, dtype=torch.float) for ids in input_ids], batch_first=True)
            input_ids = F.pad(input_ids, (0, emsize - input_ids.shape[1]))
            labels = torch.tensor(labels).to(device)

            # Trim the mask to the rows actually present (matters on the last batch).
            batch_len = input_ids.shape[0]
            src_mask = full_mask[:batch_len, :batch_len]

            input_ids = input_ids.to(device)
            output = model(input_ids, mask=src_mask)
            # criterion returns a batch mean; re-weight by batch size so the
            # shorter final batch does not count as much as a full one.
            total_loss += batch_len * criterion(output, labels).item()
    return total_loss / num_samples

# Train

In [None]:
# Driver: train for a fixed number of epochs, checkpoint the weights with the
# lowest validation loss, and restore that checkpoint at the end.
best_val_loss = float('inf')  # lowest validation loss seen so far
epochs = 3

# The checkpoint lives in a temporary directory that is removed when the
# `with` block exits, so the best weights must be reloaded inside the block.
with TemporaryDirectory() as tempdir:
    best_model_params_path = os.path.join(tempdir, "best_model_params.pt")

    for epoch in range(1, epochs + 1):
        epoch_start_time = time.time()
        train(model, epoch)
        val_loss = evaluate(model, valid_tokenized)
        # Exponential of the validation loss, reported as "ppl" in the summary line.
        val_ppl = math.exp(val_loss)
        elapsed = time.time() - epoch_start_time
        print('-' * 89)
        print(f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | '
              f'valid loss {val_loss:5.2f} | valid ppl {val_ppl:8.2f}')
        print('-' * 89)

        # Keep only the parameters from the best-scoring epoch.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), best_model_params_path)

        # Advance the learning-rate schedule once per epoch.
        scheduler.step()
    model.load_state_dict(torch.load(best_model_params_path)) # load best model states