# Training a classifier with a transformer encoder
In this notebook, we'll train a transformer-encoder classifier on the IMDb dataset. IMDb is a movie-review dataset for binary sentiment classification: the task is to classify whether a text review is negative (0) or positive (1).



In [1]:
import datasets

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Fetch the IMDb reviews from the Hugging Face Hub (a local cache is reused on re-runs)
imdb = datasets.load_dataset(path='imdb')
imdb  # show the available splits and their sizes

Reusing dataset imdb (/root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1)
100%|██████████| 3/3 [00:00<00:00, 1130.95it/s]


DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    unsupervised: Dataset({
        features: ['text', 'label'],
        num_rows: 50000
    })
})

In [13]:
# Show the feature schema, then a few hand-picked training reviews
from pprint import pprint

train_split = imdb['train']
print(train_split.features)
for example_index in (0, 734, 19375, 20075):
    print(train_split[example_index], '\n')


{'text': Value(dtype='string', id=None), 'label': ClassLabel(num_classes=2, names=['neg', 'pos'], id=None)}
{'text': 'I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. I also heard that at first it was seized by U.S. customs if it ever tried to enter this country, therefore being a fan of films considered "controversial" I really had to see this for myself.<br /><br />The plot is centered around a young Swedish drama student named Lena who wants to learn everything she can about life. In particular she wants to focus her attentions to making some sort of documentary on what the average Swede thought about certain political issues such as the Vietnam War and race issues in the United States. In between asking politicians and ordinary denizens of Stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men.<br /><br />What kills me about I AM CURIOUS-YELLOW 

## Data preparation

We tokenize the entire dataset once during preprocessing. Padding, attention masks, and the conversion of inputs and labels to tensors then happen on the fly, batch by batch, in the DataLoader's collate function.

In [3]:
# Download the pretrained DistilBERT (uncased) tokenizer from the Hugging Face Hub
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="distilbert-base-uncased"
)
tokenizer  # display its configuration (vocab size, max length, special tokens)

PreTrainedTokenizerFast(name_or_path='distilbert-base-uncased', vocab_size=30522, model_max_len=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})

In [4]:
# Tokenization with the `batch_encode_plus` function.
# Every review is truncated to the tokenizer's max length (512 for this tokenizer).
# Padding and attention masks are deferred to the collate function, so each
# batch is only padded to its own longest sequence.
def tokenize_batch(examples):
    """Tokenize a batch of raw reviews into unpadded, truncated `input_ids`.

    `examples` is the batched dict that `Dataset.map(..., batched=True)` passes in;
    only its `text` column is read.
    """
    return tokenizer.batch_encode_plus(
        examples['text'], padding=False, return_attention_mask=False, truncation=True)


# Apply the same tokenization to both labelled splits instead of copy-pasting the call.
for split in ('train', 'test'):
    imdb[split] = imdb[split].map(tokenize_batch, batched=True)

100%|██████████| 25/25 [00:02<00:00, 11.77ba/s]
100%|██████████| 25/25 [00:02<00:00, 12.33ba/s]


In [55]:
# Inspect the first training example after tokenization; it should now carry the
# tokenizer's `input_ids` alongside the original `text` and `label` fields.
print(imdb['train'][0])

{'text': 'I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. I also heard that at first it was seized by U.S. customs if it ever tried to enter this country, therefore being a fan of films considered "controversial" I really had to see this for myself.<br /><br />The plot is centered around a young Swedish drama student named Lena who wants to learn everything she can about life. In particular she wants to focus her attentions to making some sort of documentary on what the average Swede thought about certain political issues such as the Vietnam War and race issues in the United States. In between asking politicians and ordinary denizens of Stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men.<br /><br />What kills me about I AM CURIOUS-YELLOW is that 40 years ago, this was considered pornographic. Really, the sex and nudity scenes are few and far be

In [62]:
# Take the first four training examples. Slicing a Dataset returns a dict of
# column lists, e.g. dummy_batch['label'] is a plain list of four labels.
dummy_batch = imdb['train'][:4]

In [57]:
# Labels of the dummy batch: 0 = 'neg' (negative), 1 = 'pos' (positive)
dummy_batch['label']

[0, 0, 0, 0]

In [69]:
# We use this padding function to pad the input_ids to the same length.
# `tokenizer.pad` also adds the matching `attention_mask`. The result gets its own
# name so it does not collide with the model `output` computed further down.
padded_dummy_batch = tokenizer.pad(dummy_batch)
assert len({len(ids) for ids in padded_dummy_batch['input_ids']}) == 1, \
    "tokenizer.pad should give every sequence in the batch the same length"

In [5]:
# The DataLoader calls collate_fn on every batch: it pads the variable-length
# examples and turns the padded columns into PyTorch tensors.
import torch
def collate_fn(batch):
    """Pad a list of tokenized examples and convert it into a dict of tensors.

    Returns `input_ids`, `attn_mask` (the tokenizer's `attention_mask`) and
    `labels` (the dataset's `label` column), each built with `torch.tensor`.
    """
    padded = tokenizer.pad(batch)
    # (key in the returned dict, key produced by tokenizer.pad)
    renamed_columns = (
        ('input_ids', 'input_ids'),
        ('attn_mask', 'attention_mask'),
        ('labels', 'label'),
    )
    return {out_key: torch.tensor(padded[src_key]) for out_key, src_key in renamed_columns}

In [6]:
# Wrap the tokenized training split in a DataLoader so we can iterate over
# shuffled batches; `collate_fn` pads and tensorizes each batch of 4 reviews.
from torch.utils.data import DataLoader

train_dataloader = DataLoader(
    dataset=imdb['train'],
    collate_fn=collate_fn,
    batch_size=4,
    shuffle=True,
)

In [7]:
# Pull one batch by hand to check what collate_fn produces. The iterator `it`
# is advanced again by the model smoke-test cell that follows.
it = iter(train_dataloader)
next(it)

{'input_ids': tensor([[  101,  2065,  2017,  ...,     0,     0,     0],
         [  101,  2074, 14395,  ...,     0,     0,     0],
         [  101,  2026, 15003,  ...,  3272,  2077,   102],
         [  101,  2023,  2265,  ...,     0,     0,     0]]),
 'attn_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 0, 0, 0]]),
 'labels': tensor([1, 0, 0, 0])}

In [8]:
# Let's now try to feed the batched data into our model: first build a small
# demo encoder classifier with two output classes.
from vit.encoder import TransformerEncoderClassifer

classifier = TransformerEncoderClassifer(
    num_class=2,                      # negative / positive
    vocab_size=tokenizer.vocab_size,  # presumably sizes the token embedding table
    d_model=512,
    num_layer=6,
    num_head=8,
    # NOTE(review): d_k=8 differs from d_model // num_head (= 64), whereas the
    # training cell uses d_k == d_model // num_head (256 // 8 = 32); confirm intended.
    d_k=8,
    dropout_rate=0.1,
)

In [13]:
# Pick whichever device you want! By default use the first GPU when one is
# available and fall back to the CPU, so the cell also runs on CPU-only machines.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# device = 'cpu'

# Draw the next batch from `it` and run a single forward pass on it.
train_batch = next(it)
classifier.to(device)
# Call the module itself rather than `.forward()` so any registered hooks also run.
output = classifier(
    x=train_batch['input_ids'].to(device),
    attn_mask=train_batch['attn_mask'].to(device)
)

In [14]:
# Unnormalised class scores: one row per review in the batch, one column per class
output

tensor([[ 0.2218, -0.1533],
        [ 0.1998, -0.1332],
        [ 0.0444, -0.1459],
        [-0.0451, -0.1249]], device='cuda:0', grad_fn=<AddmmBackward0>)

In [19]:
# Cross-entropy between the class scores and the gold labels (moved onto the same device)
labels_on_device = train_batch['labels'].to(device)
loss = torch.nn.functional.cross_entropy(output, labels_on_device)

In [20]:
# Now make sure the gradient flows through the entire classifier: after one
# backward pass every trainable parameter must have received a gradient.
# Re-running this cell raises, because backward() frees the graph after the first call.
loss.backward()
for name, param in classifier.named_parameters():
    if param.requires_grad:
        assert param.grad is not None, f"no gradient reached parameter {name!r}"

In [5]:
import os

import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from vit.encoder import TransformerEncoderClassifer

from transformers import AutoTokenizer
import datasets 
from tqdm import tqdm # tqdm is a python progress bar library

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

# Training configurations
seed = 0
num_epoch = 100
learning_rate = 1e-5
# Name of the optimizer to build; kept separate from the `optimizer` object created at the end.
optimizer_name = 'adam'
# Use the first GPU when available, otherwise fall back to the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Data configurations
batch_size = 100
# DataLoader worker processes, capped by the number of CPUs on this machine.
num_workers = min(24, os.cpu_count() or 1)

# Model configurations
vocab_size = tokenizer.vocab_size
d_model = 256
num_layer = 2
num_head = 8
d_k = 32  # equals d_model // num_head here
dropout_rate = 0.1
num_class = 2

# Seed torch's global RNG so weight initialisation and DataLoader shuffling are reproducible.
torch.manual_seed(seed)

imdb = datasets.load_dataset('imdb')


def tokenize_batch(examples):
    """Tokenize a batch of raw reviews into unpadded, truncated `input_ids`.

    No attention mask is produced here; `tokenizer.pad` in `collate_fn`
    creates it per batch together with the padding.
    """
    return tokenizer.batch_encode_plus(
        examples['text'], padding=False, return_attention_mask=False, truncation=True)


# Tokenize both labelled splits once, up front, with the same function.
for split in ('train', 'test'):
    imdb[split] = imdb[split].map(tokenize_batch, batched=True)


def collate_fn(batch):
    """Pad a list of tokenized examples and convert it into a dict of tensors."""
    batch = tokenizer.pad(batch)
    return {
        'input_ids': torch.tensor(batch['input_ids']),
        'attn_mask': torch.tensor(batch['attention_mask']),
        'labels': torch.tensor(batch['label'])
    }


train_dataloader = DataLoader(
    imdb['train'],
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=num_workers
)

# Evaluation order does not affect the metrics, so keep it deterministic.
test_dataloader = DataLoader(
    imdb['test'],
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=num_workers
)

classifier = TransformerEncoderClassifer(
    vocab_size=vocab_size,
    d_model=d_model,
    num_layer=num_layer,
    num_head=num_head,
    d_k=d_k,
    dropout_rate=dropout_rate,
    num_class=num_class
).to(device)

if optimizer_name != 'adam':
    raise ValueError(f"Unsupported optimizer: {optimizer_name!r}")
optimizer = Adam(classifier.parameters(), lr=learning_rate)




Reusing dataset imdb (/root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1)
100%|██████████| 3/3 [00:00<00:00, 970.98it/s]
Loading cached processed dataset at /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-e62e61132651b481.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-a39490da13cfcc47.arrow


In [6]:
# Cross-entropy over the raw class scores; built once instead of once per batch.
loss_fn = torch.nn.CrossEntropyLoss()


def run_epoch(dataloader, epoch, train):
    """Run one pass over `dataloader`; update the weights only when `train` is True.

    Shows the running, sample-weighted mean loss and accuracy in a tqdm bar and
    returns them as a `(mean_loss, accuracy)` tuple for the whole pass.
    """
    # train() enables dropout for training; train(False) == eval() disables it for validation.
    classifier.train(train)
    total_loss, total_correct, total_seen = 0.0, 0, 0
    stage = "Training" if train else "Test"
    # Gradients are only tracked during training; validation runs without autograd.
    with tqdm(dataloader) as progress, torch.set_grad_enabled(train):
        progress.set_description(f"{stage} Epoch {epoch}")
        for batch in progress:
            input_ids = batch['input_ids'].to(device)
            attn_mask = batch['attn_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = classifier(x=input_ids, attn_mask=attn_mask)
            loss = loss_fn(outputs, labels)

            if train:
                # Clear the previous step's gradients; otherwise they accumulate across batches.
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Use the real batch length: the final batch can be smaller than batch_size.
            batch_len = labels.size(0)
            total_loss += loss.item() * batch_len
            total_correct += (outputs.argmax(dim=1) == labels).sum().item()
            total_seen += batch_len
            progress.set_postfix({
                'Loss': total_loss / total_seen,
                'Accuracy': total_correct / total_seen
            })
    return total_loss / max(total_seen, 1), total_correct / max(total_seen, 1)


for epoch in range(num_epoch):
    run_epoch(train_dataloader, epoch, train=True)   # Train
    run_epoch(test_dataloader, epoch, train=False)   # Validate



Training Epoch 0:  33%|███▎      | 82/250 [03:49<07:49,  2.80s/it, Loss=0.694, Accuracy=0.54] 
