In [86]:
import os
import pandas as pd
import torch
from torch import nn
import numpy as np
from transformers import BertTokenizer, BertModel
from tqdm import tqdm
import bert

In [21]:
data_path = os.path.expanduser("~/data/bbc-text.csv")

In [45]:
df = pd.read_csv(data_path)
df.head()

Unnamed: 0,category,text
0,tech,tv future in the hands of viewers with home th...
1,business,worldcom boss left books alone former worldc...
2,sport,tigers wary of farrell gamble leicester say ...
3,sport,yeading face newcastle in fa cup premiership s...
4,entertainment,ocean s twelve raids box office ocean s twelve...


In [75]:
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
labels = {'business':0,
          'entertainment':1,
          'sport':2,
          'tech':3,
          'politics':4
          }

In [36]:
example_text = 'I will watch Memento tonight'
bert_input = tokenizer(example_text,padding='max_length', max_length=64, 
                       truncation=True, return_tensors="pt")

In [95]:
bert_input

{'input_ids': tensor([[  101,   146,  1209,  2824,  2508, 26173,  3568,   102,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

In [47]:
class Dataset(torch.utils.data.Dataset):

    def __init__(self, df):

        self.labels = [labels[label] for label in df['category']]
        self.texts = [tokenizer(text, 
                               padding='max_length', max_length = 512, truncation=True,
                                return_tensors="pt") for text in df['text']]

    def classes(self):
        return self.labels

    def __len__(self):
        return len(self.labels)

    def get_batch_labels(self, idx):
        # Fetch a batch of labels
        return np.array(self.labels[idx])

    def get_batch_texts(self, idx):
        # Fetch a batch of inputs
        return self.texts[idx]

    def __getitem__(self, idx):

        batch_texts = self.get_batch_texts(idx)
        batch_y = self.get_batch_labels(idx)

        return batch_texts, batch_y

In [69]:
class BertClassifier(nn.Module):
    def __init__(self, dropout=0.5):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-cased')
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(768, 5)
        self.relu = nn.ReLU()
    
    def forward(self, input_ids, mask):
        _, x = self.bert(input_ids=input_ids, attention_mask=mask, return_dict=False)
        x = self.linear(self.dropout(x))
        return self.relu(x)

In [70]:
df_train, df_val, df_test = np.split(df.sample(frac=1, random_state=42), 
                                     [int(.8*len(df)), int(.9*len(df))])


In [81]:
def train(model, train_data, val_data, learning_rate, epochs):
    train, val = Dataset(train_data), Dataset(val_data)

    train_dataloader = torch.utils.data.DataLoader(train, batch_size=4, shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(val, batch_size=4)
    
    device = 'mps'
    
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    
    model.to(device)
    
    for epoch_num in range(epochs):
        total_acc_train = 0
        total_loss_train = 0
        
        for train_input, train_label in tqdm(train_dataloader):
            train_label = train_label.to(device)
            mask = train_input["attention_mask"].to(device)
            input_ids = train_input["input_ids"].squeeze(1).to(device)
            
            output = model(input_ids, mask)
            
            batch_loss = criterion(output, train_label.long())
            total_loss_train += batch_loss.item()
            
            acc = (output.argmax(dim=1) == train_label).sum().item()
            total_acc_train += acc
            
            model.zero_grad()
            batch_loss.backward()
            optimizer.step()

        total_acc_val = 0
        total_loss_val = 0
        with torch.no_grad():
            for val_input, val_label in val_dataloader:
                val_label = val_label.to(device)
                mask = val_input['attention_mask'].to(device)
                input_ids = val_input['input_ids'].squeeze(1).to(device)

                output = model(input_ids, mask)

                batch_loss = criterion(output, val_label.long())
                total_loss_val += batch_loss.item()

                acc = (output.argmax(dim=1) == val_label).sum().item()
                total_acc_val += acc

        print(
            f'Epochs: {epoch_num + 1} | Train Loss: {total_loss_train / len(train_data): .3f} \
            | Train Accuracy: {total_acc_train / len(train_data): .3f} \
            | Val Loss: {total_loss_val / len(val_data): .3f} \
            | Val Accuracy: {total_acc_val / len(val_data): .3f}'
        )

EPOCHS = 5
model = BertClassifier()
LR = 1e-6

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [82]:
train(model, df_train, df_val, LR, EPOCHS)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 445/445 [04:20<00:00,  1.71it/s]


Epochs: 1 | Train Loss:  0.404             | Train Accuracy:  0.247             | Val Loss:  0.395             | Val Accuracy:  0.306


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 445/445 [03:19<00:00,  2.23it/s]


Epochs: 2 | Train Loss:  0.362             | Train Accuracy:  0.430             | Val Loss:  0.300             | Val Accuracy:  0.712


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 445/445 [03:17<00:00,  2.25it/s]


Epochs: 3 | Train Loss:  0.206             | Train Accuracy:  0.866             | Val Loss:  0.132             | Val Accuracy:  0.977


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 445/445 [03:16<00:00,  2.26it/s]


Epochs: 4 | Train Loss:  0.109             | Train Accuracy:  0.960             | Val Loss:  0.076             | Val Accuracy:  0.986


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 445/445 [03:17<00:00,  2.26it/s]


Epochs: 5 | Train Loss:  0.066             | Train Accuracy:  0.980             | Val Loss:  0.052             | Val Accuracy:  0.977


In [85]:
def evaluate(model, test_data):

    test = Dataset(test_data)

    test_dataloader = torch.utils.data.DataLoader(test, batch_size=4)

    device = "mps"

    total_acc_test = 0
    
    with torch.no_grad():
        for test_input, test_label in test_dataloader:
            test_label = test_label.to(device)
            mask = test_input['attention_mask'].to(device)
            input_id = test_input['input_ids'].squeeze(1).to(device)

            output = model(input_id, mask)

            acc = (output.argmax(dim=1) == test_label).sum().item()
            total_acc_test += acc
    
    print(f'Test Accuracy: {total_acc_test / len(test_data): .3f}')

evaluate(model, df_test)

Test Accuracy:  0.987


In [96]:
num_hiddens, ffn_num_hiddens, num_heads = 512, 32, 4
norm_shape, ffn_num_input, num_layers, dropout = [512], 512, 4, 0.15
my_bert = bert.BERTModel(
        len(tokenizer.vocab),
        num_hiddens,
        norm_shape,
        ffn_num_input,
        ffn_num_hiddens,
        num_heads,
        num_layers,
        dropout,
        max_len=64,
        key_size=512,
        query_size=512,
        value_size=512,
        hid_in_features=512,
        mlm_in_features=512,
        nsp_in_features=512,
)

In [113]:
class MyBertClassifier(nn.Module):
    def __init__(self, my_bert, dropout=0.5):
        super().__init__()
        self.bert = my_bert
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(512, 5)
        self.relu = nn.ReLU()
    
    def forward(self, input_ids, mask):
        _, x = self.bert(input_ids, torch.zeros(input_ids.shape), mask.sum(-1))
        x = self.linear(self.dropout(x))
        return self.relu(x)