In [1]:
from utils import d
from transformer import CTransformer

import torch
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import IMDB
from torch.utils.data.dataset import random_split
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

import numpy as np
import random, tqdm, sys, math, gzip

# Device that tensors and the model are moved to. d() is a helper from the
# local utils module; presumably it selects CUDA when available — TODO confirm.
device = d()


# Parameters

Define the hyperparameters used throughout the rest of the notebook.

In [2]:
# --- Optimisation ---
batch_size = 128          # examples per mini-batch (train and test loaders)
lr = 0.0001               # Adam learning rate reached once warm-up is over
lr_warmup = 10_000        # warm-up length; divided by batch_size to get scheduler steps
num_epochs = 80           # full passes over the training set
gradient_clipping = 1.0   # max gradient norm; a value <= 0 disables clipping

# --- Model ---
embedding_size = 128      # passed to CTransformer as `emb`
num_heads = 8             # passed to CTransformer as `heads`
depth = 6                 # passed to CTransformer as `depth`
max_pool = True           # passed to CTransformer as `max_pool`
NUM_CLS = 2               # number of output classes

# --- Data ---
max_length = 512          # sequence-length cap; a negative value means "derive it from the data"

# Import data

We will use the IMDB reviews dataset for sentiment classification. Each review carries one of two labels:

* Positive: 2

* Negative: 1

In [3]:
# Fetch both IMDB splits, then materialise them as lists so they can be
# indexed, measured with len() and iterated more than once.
train_iter, test_iter = IMDB(split=('train', 'test'))
train_list, test_list = list(train_iter), list(test_iter)

In [4]:
# Check which kind of object IMDB(...) returned for the training split.
type(train_iter)

torch.utils.data.datapipes.iter.sharding.ShardingFilterIterDataPipe

Let's check what the data looks like:

In [5]:
# Show one (label, text) pair from the training split.
train_list[500]

(1,
 'When I ordered this from Blockbuster\'s website I had no idea that it would be as terrible as it was. Who knows? Maybe I\'d forgotten to take my ADD meds that day. I do know that from the moment the cast drove up in their station wagon, donned in their late 70\'s-style wide collars, bell-bottoms and feathered hair, I knew that this misplaced gem of the disco era was glory bound for the dumpster.<br /><br />The first foretelling of just how bad things were to be was the narration at the beginning, trying to explain what cosmic forces were at play to wreak havoc upon the universe, forcing polyester and porno-quality music on the would-be viewer. From the opening scene with the poorly-done effects to the "monsters" from another world and then the house which jumps from universe to universe was as achingly painful as watching an elementary school production of \'The Vagina Monologues\'.<br /><br />Throughout the film, the sure sign something was about to happen was when a small ship 

# Preprocessing

Now we create our vocabulary: <span style="background-color: yellow;">**[Code Pointer 1]**</span>

In [6]:
# Word-level tokenizer using torchtext's 'basic_english' scheme.
tokenizer = get_tokenizer('basic_english')


def yield_tokens(data_iter):
    """Yield the token list of every (label, text) pair in ``data_iter``."""
    yield from (tokenizer(text.lower()) for _, text in data_iter)


# The vocabulary is built from the training split only, with two special
# entries for unknown words and for padding.
special_tokens = ["<unk>", "<pad>"]
vocab = build_vocab_from_iterator(yield_tokens(train_list), specials=special_tokens)

# Any token missing from the vocabulary is looked up as <unk>.
unk_index = vocab["<unk>"]
vocab.set_default_index(unk_index)

# The vocabulary size is later handed to the model as its number of tokens.
vocab_size = len(vocab)
vocab_size

100684

Every word in the vocabulary has an associated index:

In [7]:
# Two lookup tables taken from the vocabulary:
#   S2I: string -> index
#   I2S: index  -> string
S2I, I2S = vocab.get_stoi(), vocab.get_itos()

In [8]:
# Look up the token stored at vocabulary index 100.
I2S[100]

'made'

Now let us define the preprocessing pipelines for the text and the labels:

In [9]:
# Pipelines that turn a raw (label, text) pair into model-ready values.
def text_pipeline(x):
    """Tokenize the raw string ``x`` and return its list of vocabulary indices."""
    return [vocab[token.lower()] for token in tokenizer(x)]


def label_pipeline(x):
    """Map a dataset label to a class index: 2 (positive) -> 1, anything else -> 0."""
    return 1 if x == 2 else 0

Finally, we create dataloaders for training and testing

In [10]:
def collate_batch(batch, vocab, text_pipeline, label_pipeline):
    """Turn a list of (label, text) pairs into a pair of batch tensors.

    Args:
        batch: iterable of (label, text) pairs as stored in the dataset.
        vocab: mapping that provides the index of the '<pad>' token.
        text_pipeline: callable mapping a raw text to a list of token indices.
        label_pipeline: callable mapping a raw label to a class index.

    Returns:
        (labels, texts): ``labels`` is an int64 tensor of shape [batch_size];
        ``texts`` is an int64 tensor of shape [batch_size, seq_len], where
        shorter sequences are right-padded with the '<pad>' index. Both
        tensors live on ``device``.
    """
    label_list, text_list = [], []
    for label, text in batch:
        label_list.append(label_pipeline(label))
        text_list.append(torch.tensor(text_pipeline(text), dtype=torch.int64))

    labels = torch.tensor(label_list, dtype=torch.int64)
    # batch_first=True pads straight into [batch_size, seq_len], so no transpose
    # is needed and the result is contiguous.
    texts = pad_sequence(text_list, batch_first=True, padding_value=vocab['<pad>'])

    # Build everything on the CPU and transfer once per batch, instead of
    # issuing one host-to-device copy per example.
    return labels.to(device), texts.to(device)

# Create data loaders. Both share one collate function that binds the
# vocabulary and the two preprocessing pipelines.
def _collate(samples):
    """Collate ``samples`` with the notebook's vocabulary and pipelines."""
    return collate_batch(samples, vocab, text_pipeline, label_pipeline)


# Only the training loader shuffles its examples.
train_dataloader = DataLoader(train_list, batch_size=batch_size, shuffle=True, collate_fn=_collate)
test_dataloader = DataLoader(test_list, batch_size=batch_size, shuffle=False, collate_fn=_collate)

Let's take a look at a batch of sequences:

In [11]:
# Pull a single batch from the training loader; each batch is a
# (labels, texts) pair.
element = next(iter(train_dataloader))
sample_label, sample_text = element[0], element[1]

In [12]:
# Display the padded batch of token indices taken from the training loader.
sample_text

tensor([[ 14,  21,  10,  ...,   1,   1,   1],
        [ 13,   9, 152,  ...,   1,   1,   1],
        [ 13,  97,   9,  ...,   1,   1,   1],
        ...,
        [ 72,   4,  39,  ...,   1,   1,   1],
        [ 14,  21,  10,  ...,   1,   1,   1],
        [  2,  70, 160,  ...,   1,   1,   1]], device='cuda:0')

Let's compute the maximum sequence length:

In [13]:
print(f'- nr. of training examples {len(train_list)}')
print(f'- nr. of test examples {len(test_list)}')

# A negative max_length means "derive the limit from the data"; otherwise the
# configured value is used as-is.
if max_length < 0:
    # Each item yielded by the loader is a (labels, texts) pair. The text
    # tensor is [batch_size, seq_len], so shape[1] is the length of the
    # longest review in that batch.
    max_seq_length = max(
        (text_batch.shape[1] for _, text_batch in train_dataloader),
        default=0,
    )
else:
    max_seq_length = max_length

print(f'- maximum sequence length: {max_seq_length}')

- nr. of training examples 25000
- nr. of test examples 25000
- maximum sequence length: 512


# Model and optimizer

We create the model and define the optimizer.
<span style="background-color: yellow;">**[Code Pointer 6]**</span>

In [14]:
# Create the classifier.
model = CTransformer(emb=embedding_size, heads=num_heads, depth=depth, seq_length=max_seq_length, num_tokens=vocab_size, num_classes=NUM_CLS, max_pool=max_pool)

# Always place the model on the same device the batches are sent to. Moving to
# the device it already lives on is a no-op, whereas guarding on CUDA only
# would leave the model behind whenever `device` is some other accelerator.
model.to(device)

opt = torch.optim.Adam(lr=lr, params=model.parameters())

# Linear learning-rate warm-up: the multiplier grows from 0 to 1 over
# lr_warmup / batch_size scheduler steps and then stays at 1.
warmup_steps = lr_warmup / batch_size
sch = torch.optim.lr_scheduler.LambdaLR(opt, lambda i: min(i / warmup_steps, 1.0))

# Training

In [15]:
# Training loop: one optimisation pass over the training set per epoch,
# followed by an accuracy measurement on the test set.
seen = 0  # running count of training examples processed so far
for e in range(num_epochs):

    print(f'\n epoch {e}')
    model.train()

    for labels, tokens in tqdm.tqdm(train_dataloader):

        opt.zero_grad()

        # Limit text length to the sequence length the model was built with.
        # Slicing leaves shorter batches untouched.
        tokens = tokens[:, :max_seq_length]
        out = model(tokens)

        # Loss function (negative log likelihood).
        # NOTE(review): nll_loss expects log-probabilities, so the model
        # presumably ends in a log-softmax — confirm in CTransformer.
        loss = F.nll_loss(out, labels)

        loss.backward()

        # Clip gradients: if the total gradient norm exceeds the threshold,
        # scale it back down to the threshold.
        if gradient_clipping > 0.0:
            nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)

        opt.step()
        sch.step()

        seen += tokens.size(0)

    # Evaluation: switch off training-mode behaviour and gradient tracking.
    model.eval()
    with torch.no_grad():

        tot, cor = 0.0, 0.0

        for labels, tokens in tqdm.tqdm(test_dataloader):

            tokens = tokens[:, :max_seq_length]

            # The predicted class is the index with the highest output score.
            preds = model(tokens).argmax(dim=1)

            tot += float(tokens.shape[0])
            cor += float((labels == preds).sum().item())

        acc = cor / tot
        print(f'-- test accuracy {acc:.3}')



 epoch 0


100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 196/196 [02:31<00:00,  1.30it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 196/196 [00:57<00:00,  3.43it/s]


-- test accuracy 0.665

 epoch 1


 54%|█████████████████████████████████████████████████████                                              | 105/196 [01:23<01:12,  1.26it/s]


KeyboardInterrupt: 

# Prediction

In [None]:
# Load a previously trained model from disk.
# NOTE(review): this checkpoint is not written by any cell shown in this
# notebook, so it must already exist next to it — confirm where it comes from.
# torch.load unpickles the file, which can execute arbitrary code: only load
# checkpoints from a trusted source.
# map_location places the loaded tensors on the current device, so a checkpoint
# saved on a GPU can still be opened on a CPU-only machine.
model_new = torch.load("sent-model-27-epochs.pt", map_location=device)

In [115]:
# Classify a single hand-written sentence with the loaded model.
text = "This is the bed reading group I have ever attended!"

# Encode the sentence and add a leading batch dimension of size 1.
token_ids = text_pipeline(text)
processed_text = torch.tensor(token_ids, dtype=torch.int64).to(device).unsqueeze(0)

# Keep at most max_seq_length tokens; shorter inputs are left untouched.
processed_text = processed_text[:, :max_seq_length]

# Evaluate the model without tracking gradients. The output is exponentiated
# before printing.
model_new.eval()
with torch.no_grad():
    model_output = model_new(processed_text)
    probs = model_output.exp()

print(probs)

tensor([[1.1776e-07, 1.0000e+00]], device='cuda:0')
