# Question Classification with DistilBert

Different Classes
1.   HUM for questions about humans
2.   ENTY for questions about entities
3.   DESC for questions asking you for a description
4.   NUM for questions where the answer is numerical
5.   LOC for questions where the answer is a location
6.   ABBR for questions asking about abbreviations

# 1. Import Libraries

In [4]:
# Install Transformers Package
!pip install transformers

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/fd/1a/41c644c963249fd7f3836d926afa1e3f1cc234a1c40d80c5f03ad8f6f1b2/transformers-4.8.2-py3-none-any.whl (2.5MB)
[K     |████████████████████████████████| 2.5MB 7.5MB/s 
Collecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/75/ee/67241dc87f266093c533a2d4d3d69438e57d7a90abb216fa076e7d475d4a/sacremoses-0.0.45-py3-none-any.whl (895kB)
[K     |████████████████████████████████| 901kB 46.3MB/s 
Collecting huggingface-hub==0.0.12
  Downloading https://files.pythonhosted.org/packages/2f/ee/97e253668fda9b17e968b3f97b2f8e53aa0127e8807d24a547687423fe0b/huggingface_hub-0.0.12-py3-none-any.whl
Collecting tokenizers<0.11,>=0.10.1
[?25l  Downloading https://files.pythonhosted.org/packages/d4/e2/df3543e8ffdab68f5acc73f613de9c2b155ac47f162e725dcac87c521c11/tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3MB)
[K     |█████

In [5]:
# Libraries

import matplotlib.pyplot as plt
import pandas as pd
import torch

import numpy as np
from torchtext.legacy import data
from torchtext.legacy import datasets
import torch.nn.functional as F

# Preliminaries

from torchtext.legacy.data import Field, LabelField, TabularDataset, BucketIterator, Iterator

# Models

import torch.nn as nn
from transformers import DistilBertTokenizer, DistilBertModel

# Training

import torch.optim as optim

# Evaluation

from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import seaborn as sns

# 2. Preprocessing--Build Fileds, Create Vocab and Dataloaders

In [6]:
# Load Tokenizer
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=28.0, style=ProgressStyle(description_w…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=466062.0, style=ProgressStyle(descripti…




In [7]:
# Decalre max sequence length
MAX_SEQ_LEN = 64

# Create fields
text_field = Field(sequential=True, use_vocab=False, tokenize=tokenizer.encode, batch_first=True,
                   fix_length = MAX_SEQ_LEN, pad_token=tokenizer.pad_token_id, unk_token=tokenizer.unk_token_id)
label_field = LabelField(sequential=False, batch_first=True)

# Split dataset with fields
train_data, test_data = datasets.TREC.splits(text_field, label_field)
train_data, valid_data = train_data.split([0.8, 0.2])

downloading train_5500.label


train_5500.label: 100%|██████████| 336k/336k [00:00<00:00, 1.32MB/s]


downloading TREC_10.label


TREC_10.label: 100%|██████████| 23.4k/23.4k [00:00<00:00, 362kB/s]


In [8]:
# Create vocab for label_field
label_field.build_vocab(train_data)
#vars(label_field.vocab)

In [9]:
# Create dataloaders with datasets
BATCH_SIZE = 128
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_iterator, valid_iterator, test_iterator = BucketIterator.splits((train_data, valid_data, test_data), 
                                                                            batch_size = BATCH_SIZE, 
                                                                            device = device)

# 3. Model Implementation

In [10]:
class Transformer(nn.Module):
    def __init__(self, output_dim, freeze=True):
        super().__init__()
        self.transformer = DistilBertModel.from_pretrained('distilbert-base-uncased')
        self.hidden_dim = int(768)
        self.fc = nn.Linear(self.hidden_dim, 256)
        self.linear = nn.Linear(256, output_dim)
        
        if freeze:
            for param in self.transformer.parameters():
                param.requires_grad = False
        
    def forward(self, ids):
        # ids = [batch size, seq len]
        output = self.transformer(ids, output_attentions=True)
        hidden = output.last_hidden_state
        # hidden = [batch size, seq len, hidden dim]
        attention = output.attentions[-1]
        # attention = [batch size, n heads, seq len, seq len]
        cls_hidden = hidden[:,0,:]
        fced = F.relu(self.fc(torch.tanh(cls_hidden)))
        # prediction = [batch size, output dim]
        return self.linear(fced)

In [12]:
# Count parameters
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# 4. Define Training and Evaluating Functions

In [13]:
# Define Train and Evaluate Function

def train(model, device, train_loader, valid_loader, optimizer, criterion, epoch, log_interval=1, dry_run=False, save=False):
    model.to(device)
    criterion.to(device)
    model.train()
    best_loss = np.inf
    for ep in range(1, epoch+1):
        current = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            current += len(data)
            if batch_idx % log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    ep, current, len(train_loader.dataset),
                    100. * current / len(train_loader.dataset), loss.item()))
                if dry_run:
                    break
        loss, _, _, _ = test(model, device, valid_loader, nn.CrossEntropyLoss(reduction='sum'), validation=True)
        if save:
            if loss < best_loss:
                best_loss = loss
                torch.save(model.state_dict(), 'checkpoint.pt')
                print('***Model Updated***')

def test(model, device, test_loader, criterion, validation=False):
    model.to(device)
    criterion.to(device)
    model.eval()
    test_loss = 0
    correct = 0
    wrongx = []
    wrongy = []
    truey = []
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # test_loss += criterion(output, target, reduction='sum').item()  # sum up batch loss
            test_loss += criterion(output, target).item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

            selector = ~pred.eq(target.view_as(pred)).squeeze(1)
            # print(selector.shape)
            wrongx.append(data[selector])
            wrongy.append(pred[selector])
            truey.append(target[selector])
            #print(selector.shape)
            #print(data[selector].shape)

    test_loss /= len(test_loader.dataset)
    if validation:
        print('\nValidation set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
            test_loss, correct, len(test_loader.dataset),
            100. * correct / len(test_loader.dataset)))
    else:
        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
            test_loss, correct, len(test_loader.dataset),
            100. * correct / len(test_loader.dataset)))
    
    wrongx = torch.cat( wrongx, dim=0 )
    wrongy = torch.cat( wrongy, dim=0 )
    truey = torch.cat( truey, dim=0 )
    return test_loss, wrongx, wrongy, truey

# 5. Train, Test, and Analysis

In [14]:
# Create Model
model = Transformer(output_dim=6, freeze=False)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)
print(f'The model has {count_parameters(model):,} trainable parameters')

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=442.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=267967963.0, style=ProgressStyle(descri…




Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_transform.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


The model has 66,561,286 trainable parameters


In [20]:
# -----------Run only to load model-----------
# LOAD_PATH = 'checkpoint.pt'
# model.load_state_dict(torch.load(LOAD_PATH))
# --------------------------------------------

<All keys matched successfully>

In [15]:
# Train
EPOCH=5
train(model, device, train_iterator, valid_iterator, optimizer, criterion, EPOCH, save=True)
a = test(model, device, test_iterator, nn.CrossEntropyLoss(reduction='sum'))


Validation set: Average loss: 1.3091, Accuracy: 717/1090 (65.78%)

***Model Updated***

Validation set: Average loss: 0.5653, Accuracy: 962/1090 (88.26%)

***Model Updated***

Validation set: Average loss: 0.3448, Accuracy: 1008/1090 (92.48%)

***Model Updated***

Validation set: Average loss: 0.2717, Accuracy: 1025/1090 (94.04%)

***Model Updated***

Validation set: Average loss: 0.2474, Accuracy: 1021/1090 (93.67%)

***Model Updated***

Test set: Average loss: 0.1745, Accuracy: 479/500 (95.80%)



In [16]:
def predict_class(model, sentence, min_len = 4):
    model.eval()
    tokenized = tokenizer(sentence)['input_ids']
    if len(tokenized) < min_len:
        tokenized += [tokenizer.pad_token_id] * (128 - len(tokenized))
    tensor = torch.LongTensor(tokenized).to(device)
    tensor = tensor.unsqueeze(0)
    preds = model(tensor)
    max_preds = preds.argmax(dim = 1)
    return max_preds.item()

In [17]:
# Test custom questions here
sentence = "What does WYSIWYG stand for?"
label_field.vocab.itos[predict_class(model, sentence)]

'ABBR'

In [None]:
'''
Configurations:
    Epoch = 5
    lr = 2e-5
    Accuracy on testset: 95.80 %
'''