#### Sentiment classification with DistilBERT

In [1]:
### Import and prepare dataset

from torchtext.datasets import IMDB
# Read the IMDB training split and hold its samples in a plain list, so that
# it has a length and can be indexed by the split performed further down.
train_ds = list(IMDB('./data/imdb/train', split='train'))

# The test split object stays available as `test_ds`; `test_dataset` is the
# list of its samples that the test DataLoader consumes.
test_ds = IMDB('./data/imdb/test', split='test')
test_dataset = list(test_ds)

In [2]:
import torch
from torch.utils.data.dataset import random_split

# Hold out a fixed number of training samples for validation. The split is
# driven by a dedicated, seeded generator so that Restart & Run All produces
# the same train/validation partition every time.
SPLIT_SEED = 42
VALID_SIZE = 5000
train_dataset, valid_dataset = random_split(
    train_ds,
    # Derive the training size from the data instead of hard-coding 20000, so
    # the lengths always sum to len(train_ds) as random_split requires.
    [len(train_ds) - VALID_SIZE, VALID_SIZE],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [3]:
from transformers import DistilBertTokenizer, DistilBertModel, DistilBertForSequenceClassification
# Single source of truth for the checkpoint, so the tokenizer and the model
# are always loaded from the same pretrained weights/vocabulary.
MODEL_CHECKPOINT = 'distilbert-base-uncased'

tokenizer = DistilBertTokenizer.from_pretrained(MODEL_CHECKPOINT)
# Two output labels: the classifier is used for binary sentiment prediction.
model = DistilBertForSequenceClassification.from_pretrained(
    MODEL_CHECKPOINT,
    num_labels=2,
)


Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.weight', 'vocab_transform.weight', 'vocab_projector.weight', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.bias']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier

In [4]:
import torch

def label_pipeline(label):
    """Convert a raw dataset label into a binary class index.

    Args:
        label: Raw label of one sample. The string ``'pos'`` and the integer
            ``2`` are both treated as the positive class.
            NOTE(review): the integer form is accepted because some torchtext
            releases appear to yield IMDB labels as 1 (negative) / 2
            (positive) instead of 'neg' / 'pos' -- confirm against the
            installed torchtext version.

    Returns:
        ``1`` for a positive label, ``0`` for anything else.
    """
    return 1 if label in ('pos', 2) else 0

def collate_batch(batch):
    """Turn a list of ``(label, text)`` samples into model-ready inputs.

    Args:
        batch: Iterable of ``(label, text)`` pairs.

    Returns:
        A ``(encodings, labels)`` tuple where ``encodings`` is whatever the
        module-level ``tokenizer`` returns for the batch of texts (called with
        ``return_tensors='pt'``, ``padding=True``, ``truncation=True`` and
        ``max_length=512``) and ``labels`` is a ``torch.long`` tensor holding
        ``label_pipeline(label)`` for every sample, in batch order.
    """
    labels, texts = [], []
    for label, text in batch:
        labels.append(label_pipeline(label))
        texts.append(text)
    labels = torch.tensor(labels, dtype=torch.long)
    # Tokenise the whole batch in one call so padding is applied consistently
    # across the texts of this batch.
    encodings = tokenizer(texts, return_tensors='pt', truncation=True, padding=True, max_length=512)
    return encodings, labels

In [5]:
from torch.utils.data import DataLoader
import os

batch_size = 32
# Never request more worker processes than the machine has CPUs: a fixed 20
# workers per loader oversubscribes smaller machines and makes DataLoader warn.
num_workers = min(20, os.cpu_count() or 1)

# Options shared by all three loaders; only the dataset and shuffling differ.
loader_kwargs = dict(batch_size=batch_size, collate_fn=collate_batch, num_workers=num_workers)

# Only the training data is shuffled; validation/test keep a fixed order.
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
valid_dataloader = DataLoader(valid_dataset, shuffle=False, **loader_kwargs)
test_dataloader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)



In [6]:
from torch.functional import F


def train(model, dataloader, optimizer, device, progress_bar=None):
    """Run one optimisation pass over ``dataloader`` and report its metrics.

    Args:
        model: Callable as ``model(input_ids, attention_mask=..., labels=...)``
            returning a mapping with ``'loss'`` and ``'logits'`` entries.
        dataloader: Iterable of ``(input_batch, labels_batch)`` pairs, where
            ``input_batch`` provides ``'input_ids'`` and ``'attention_mask'``.
        optimizer: Optimizer whose ``zero_grad``/``step`` are called per batch.
        device: Device the batch tensors are moved to before the forward pass.
        progress_bar: Optional object with an ``update(n)`` method; advanced by
            one after every batch. Omit it to train without progress output.

    Returns:
        ``(accuracy, loss)`` averaged over the samples actually processed.
        ``(0.0, 0.0)`` if the dataloader yielded no samples.
    """
    model.train()
    epoch_loss, epoch_acc = 0., 0.
    # Count samples as they are seen instead of trusting len(dataloader.dataset):
    # the two differ when the loader uses a sampler subset or drops a batch.
    num_samples = 0
    for input_batch, labels_batch in dataloader:

        text_batch = input_batch['input_ids'].to(device)
        attn_batch = input_batch['attention_mask'].to(device)
        labels_batch = labels_batch.to(device)

        model_output = model(text_batch, attention_mask=attn_batch, labels=labels_batch)
        loss, logits = model_output['loss'], model_output['logits']

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if progress_bar is not None:
            progress_bar.update(1)

        samples_in_batch = labels_batch.size(0)
        num_samples += samples_in_batch
        # Weight the batch loss by the batch size so a smaller last batch does
        # not skew the average (assumes the model reports a per-batch mean loss
        # -- TODO confirm for the model in use).
        epoch_loss += loss.item() * samples_in_batch
        epoch_acc += (torch.argmax(logits, 1) == labels_batch).float().sum().item()

    if num_samples == 0:
        # Nothing was processed; avoid dividing by zero.
        return 0.0, 0.0
    return epoch_acc / num_samples, epoch_loss / num_samples
        
        
def evaluate(model, dataloader, device):
    """Score ``model`` on ``dataloader`` without updating its parameters.

    The model is switched to evaluation mode and all forward passes run under
    ``torch.no_grad()``.

    Args:
        model: Callable as ``model(input_ids, attention_mask=..., labels=...)``
            returning a mapping with ``'loss'`` and ``'logits'`` entries.
        dataloader: Iterable of ``(input_batch, labels_batch)`` pairs, where
            ``input_batch`` provides ``'input_ids'`` and ``'attention_mask'``.
        device: Device the batch tensors are moved to before the forward pass.

    Returns:
        ``(accuracy, loss)`` averaged over the samples actually processed.
        ``(0.0, 0.0)`` if the dataloader yielded no samples.
    """
    model.eval()
    epoch_loss, epoch_acc = 0., 0.
    # Count samples as they are seen instead of trusting len(dataloader.dataset):
    # the two differ when the loader uses a sampler subset or drops a batch.
    num_samples = 0
    with torch.no_grad():
        for input_batch, labels_batch in dataloader:

            text_batch = input_batch['input_ids'].to(device)
            attn_batch = input_batch['attention_mask'].to(device)
            labels_batch = labels_batch.to(device)

            model_output = model(text_batch, attention_mask=attn_batch, labels=labels_batch)
            loss, logits = model_output['loss'], model_output['logits']

            samples_in_batch = labels_batch.size(0)
            num_samples += samples_in_batch
            # Weight by batch size so a smaller last batch does not skew the
            # average (assumes a per-batch mean loss -- TODO confirm).
            epoch_loss += loss.item() * samples_in_batch
            epoch_acc += (torch.argmax(logits, 1) == labels_batch).float().sum().item()

    if num_samples == 0:
        # Nothing was processed; avoid dividing by zero.
        return 0.0, 0.0
    return epoch_acc / num_samples, epoch_loss / num_samples

In [7]:
from tqdm.auto import tqdm
from torch import nn
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
# Use the GPU when one is present, but fall back to the CPU instead of
# crashing on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
epochs = 5

# One progress-bar tick per training batch across all epochs; the context
# manager closes the bar once training finishes or fails.
with tqdm(range(epochs*len(train_dataloader))) as progress_bar:
    for epoch in range(epochs):
        # train() and evaluate() both return (accuracy, loss), in that order.
        train_acc, train_loss = train(model, train_dataloader, optimizer, device, progress_bar)
        valid_acc, valid_loss = evaluate(model, valid_dataloader, device)
        print(f'Epoch: {epoch}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Valid Loss: {valid_loss:.4f}, Valid Acc: {valid_acc:.4f}')

  0%|          | 0/3125 [00:00<?, ?it/s]

Epoch: 0, Train Loss: 0.2698, Train Acc: 0.8872, Valid Loss: 0.2257, Valid Acc: 0.9078
Epoch: 1, Train Loss: 0.1391, Train Acc: 0.9496, Valid Loss: 0.2228, Valid Acc: 0.9148
Epoch: 2, Train Loss: 0.0715, Train Acc: 0.9758, Valid Loss: 0.2652, Valid Acc: 0.9220
Epoch: 3, Train Loss: 0.0408, Train Acc: 0.9869, Valid Loss: 0.2626, Valid Acc: 0.9214
Epoch: 4, Train Loss: 0.0334, Train Acc: 0.9889, Valid Loss: 0.2994, Valid Acc: 0.9134
