# Encoder-only transformer model for AG News classification

In this notebook, I train a encoder-only transformer to do text classification on the AG_NEWS dataset.
Text classification seems to be a pretty simple task, and using transformer is probably overkill. But this is my first time implementing the transformer structure from scratch (including the self-attention module), and it was fun :-)

In [None]:
# some commands in th is notebook require torchtext 0.12.0
# !pip install  torchtext --upgrade --quiet
# !pip install torchdata --quiet
# !pip install torchinfo --quiet

In [None]:
import collections
import math
from dataclasses import dataclass
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as functional
import torchtext
import torchdata
from torch.utils.data import DataLoader
from tqdm import tqdm
import torchinfo
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(DEVICE)

cuda


## Data processing

In [16]:
# One can easily modify the data processing part of this code to accommodate for   other datasets for text classification listed in https://pytorch.org/text/stable/datasets.html#text-classification
from torchtext.datasets import AG_NEWS
train_iter, test_iter = AG_NEWS()
num_classes = len(set([label for (label, text) in train_iter]))
tokenizer = torchtext.data.utils.get_tokenizer('basic_english')

In [17]:
# see an example of the dateset
next(iter(train_iter))

(3,
 "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.")

In [None]:
# convert the labels to be in range(0, num_classes)
y_train = torch.tensor([label-1 for (label, text) in train_iter])
y_test  = torch.tensor([label-1 for (label, text) in test_iter])

# There are many "\\" in the texts in the AG_news dataset, we get rid of them.
train_iter = ((label, text.replace("\\", " ")) for label, text in train_iter)
test_iter  = ((label, text.replace("\\", " ")) for label, text in test_iter)

# tokenize the texts, and truncate the number of words in each text to max_seq_len
max_seq_len = 100
x_train_texts = [tokenizer(text.lower())[0:max_seq_len]
                 for (label, text) in train_iter]
x_test_texts  = [tokenizer(text.lower())[0:max_seq_len]
                 for (label, text) in test_iter]

In [None]:
# build the vocabulary and word-to-integer map
counter = collections.Counter()
for text in x_train_texts:
    counter.update(text)

vocab_size = 15000
most_common_words = np.array(counter.most_common(vocab_size - 2))
vocab = most_common_words[:,0]

# indexes for the padding token, and unknown tokens
PAD = 0
UNK = 1
word_to_id = {vocab[i]: i + 2 for i in range(len(vocab))}

In [None]:
# map the words in the training and test texts to integers
x_train = [torch.tensor([word_to_id.get(word, UNK) for word in text])
           for text in x_train_texts]
x_test  = [torch.tensor([word_to_id.get(word, UNK) for word in text])
          for text in x_test_texts]
x_test = torch.nn.utils.rnn.pad_sequence(x_test,
                                batch_first=True, padding_value = PAD)

In [None]:
# constructing the dataset in order to be compatible with torch.utils.data.Dataloader
class AGNewsDataset:
    def __init__(self, features, labels):
        self.features = features
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, item):
        return self.features[item], self.labels[item]


train_dataset = AGNewsDataset(x_train, y_train)
test_dataset  = AGNewsDataset(x_test, y_test)

In [None]:
# collate_fn to be used in torch.utils.data.DataLoader().
# It pads the texts in each batch such that they have the same sequence length.
def pad_sequence(batch):
    texts  = [text for text, label in batch]
    labels = torch.tensor([label for text, label in batch])
    texts_padded = torch.nn.utils.rnn.pad_sequence(texts,
                                batch_first=True, padding_value = PAD)
    return texts_padded, labels

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True,
                        collate_fn = pad_sequence)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=True,
                        collate_fn = pad_sequence)

## Building the encoder-only transformer model for text classification

In [None]:
# from transformer_blocks import Encoder
# One can also import Encoder from transformer_blocks.py in my Github repository.
# I copied the code here so that this notebook is self-contained.  


class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_embed, dropout=0.0):
        super(MultiHeadedAttention, self).__init__()
        assert d_embed % h == 0 # check the h number
        self.d_k = d_embed//h
        self.d_embed = d_embed
        self.h = h
        self.WQ = nn.Linear(d_embed, d_embed)
        self.WK = nn.Linear(d_embed, d_embed)
        self.WV = nn.Linear(d_embed, d_embed)
        self.linear = nn.Linear(d_embed, d_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x_query, x_key, x_value, mask=None):
        nbatch = x_query.size(0) # get batch size
        # 1) Linear projections to get the multi-head query, key and value tensors
        # x_query, x_key, x_value dimension: nbatch * seq_len * d_embed
        # LHS query, key, value dimensions: nbatch * h * seq_len * d_k
        query = self.WQ(x_query).view(nbatch, -1, self.h, self.d_k).transpose(1,2)
        key   = self.WK(x_key).view(nbatch, -1, self.h, self.d_k).transpose(1,2)
        value = self.WV(x_value).view(nbatch, -1, self.h, self.d_k).transpose(1,2)
        # 2) Attention
        # scores has dimensions: nbatch * h * seq_len * seq_len
        scores = torch.matmul(query, key.transpose(-2, -1))/math.sqrt(self.d_embed)
        # 3) Mask out padding tokens and future tokens
        if mask is not None:
            scores = scores.masked_fill(mask, float('-inf'))
        # p_atten dimensions: nbatch * h * seq_len * seq_len
        p_atten = torch.nn.functional.softmax(scores, dim=-1)
        p_atten = self.dropout(p_atten)
        # x dimensions: nbatch * h * seq_len * d_k
        x = torch.matmul(p_atten, value)
        # x now has dimensions:nbtach * seq_len * d_embed
        x = x.transpose(1, 2).contiguous().view(nbatch, -1, self.d_embed)
        return self.linear(x) # final linear layer


class ResidualConnection(nn.Module):
  '''residual connection: x + dropout(sublayer(layernorm(x))) '''
  def __init__(self, dim, dropout):
      super().__init__()
      self.drop = nn.Dropout(dropout)
      self.norm = nn.LayerNorm(dim)

  def forward(self, x, sublayer):
      return x + self.drop(sublayer(self.norm(x)))

# I simply let the model learn the positional embeddings in this notebook, since this
# almost produces identital results as using sin/cosin functions embeddings, as claimed
# in the original transformer paper. Note also that in the original paper, they multiplied
# the token embeddings by a factor of sqrt(d_embed), which I do not do here.

class Encoder(nn.Module):
    '''Encoder = token embedding + positional embedding -> a stack of N EncoderBlock -> layer norm'''
    def __init__(self, config):
        super().__init__()
        self.d_embed = config.d_embed
        self.tok_embed = nn.Embedding(config.encoder_vocab_size, config.d_embed)
        self.pos_embed = nn.Parameter(torch.zeros(1, config.max_seq_len, config.d_embed))
        self.encoder_blocks = nn.ModuleList([EncoderBlock(config) for _ in range(config.N_encoder)])
        self.dropout = nn.Dropout(config.dropout)
        self.norm = nn.LayerNorm(config.d_embed)

    def forward(self, input, mask=None):
        x = self.tok_embed(input)
        x_pos = self.pos_embed[:, :x.size(1), :]
        x = self.dropout(x + x_pos)
        for layer in self.encoder_blocks:
            x = layer(x, mask)
        return self.norm(x)


class EncoderBlock(nn.Module):
    '''EncoderBlock: self-attention -> position-wise fully connected feed-forward layer'''
    def __init__(self, config):
        super(EncoderBlock, self).__init__()
        self.atten = MultiHeadedAttention(config.h, config.d_embed, config.dropout)
        self.feed_forward = nn.Sequential(
            nn.Linear(config.d_embed, config.d_ff),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.d_ff, config.d_embed)
        )
        self.residual1 = ResidualConnection(config.d_embed, config.dropout)
        self.residual2 = ResidualConnection(config.d_embed, config.dropout)

    def forward(self, x, mask=None):
        # self-attention
        x = self.residual1(x, lambda x: self.atten(x, x, x, mask=mask))
        # position-wise fully connected feed-forward layer
        return self.residual2(x, self.feed_forward)


class Transformer(nn.Module):
    def __init__(self, config, num_classes):
        super().__init__()
        self.encoder = Encoder(config)
        self.linear = nn.Linear(config.d_embed, num_classes)

    def forward(self, x, pad_mask=None):
        x = self.encoder(x, pad_mask)
        return  self.linear(torch.mean(x,-2))

In [None]:
@dataclass
class ModelConfig:
    encoder_vocab_size: int
    d_embed: int
    # d_ff is the dimension of the fully-connected  feed-forward layer
    d_ff: int
    # h is the number of attention head
    h: int
    N_encoder: int
    max_seq_len: int
    dropout: float

def make_model(config):
    model = Transformer(config, num_classes).to(DEVICE)
    # initialize model parameters
    # it seems that this initialization is very important!
    for p in model.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
    return model

## Train the model

In [None]:
def train_epoch(model, dataloader):
    model.train()
    losses, acc, count = [], 0, 0
    pbar = tqdm(enumerate(dataloader), total=len(dataloader))
    for idx, (x, y)  in  pbar:
        optimizer.zero_grad()
        features= x.to(DEVICE)
        labels  = y.to(DEVICE)
        pad_mask = (features == PAD).view(features.size(0), 1, 1, features.size(-1))
        pred = model(features, pad_mask)

        loss = loss_fn(pred, labels).to(DEVICE)
        loss.backward()
        optimizer.step()

        losses.append(loss.item())
        acc += (pred.argmax(1) == labels).sum().item()
        count += len(labels)
        # report progress
        if idx>0 and idx%50 == 0:
            pbar.set_description(f'train loss={loss.item():.4f}, train_acc={acc/count:.4f}')
    return np.mean(losses), acc/count

def train(model, train_loader, test_loader, epochs):
    for ep in range(epochs):
        train_loss, train_acc = train_epoch(model, train_loader)
        val_loss, val_acc = evaluate(model, test_loader)
        print(f'ep {ep}: val_loss={val_loss:.4f}, val_acc={val_acc:.4f}')
def evaluate(model, dataloader):
    model.eval()
    losses = []
    with torch.no_grad():
        for x, y in dataloader:
            features = x_test.to(DEVICE)
            labels  = y_test.to(DEVICE)
            pad_mask = (features == PAD).view(features.size(0), 1, 1, features.size(-1))
            pred = model(features, pad_mask)
            loss = loss_fn(pred,labels).to(DEVICE)
            losses.append(loss.item())
            acc = (pred.argmax(1) == labels).sum().item()
            count = len(labels)
    return np.mean(losses), acc/count

In [None]:
config = ModelConfig(encoder_vocab_size = vocab_size,
                     d_embed = 32,
                     d_ff = 4*32,
                     h = 1,
                     N_encoder = 1,
                     max_seq_len = max_seq_len,
                     dropout = 0.1
                     )
model = make_model(config)
print(torchinfo.summary(model))
optimizer = torch.optim.Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss()

Layer (type:depth-idx)                             Param #
Transformer                                        --
├─Encoder: 1-1                                     --
│    └─Embedding: 2-1                              480,000
│    └─ModuleList: 2-2                             --
│    │    └─EncoderBlock: 3-1                      12,704
│    └─Dropout: 2-3                                --
│    └─LayerNorm: 2-4                              64
├─Linear: 1-2                                      132
Total params: 492,900
Trainable params: 492,900
Non-trainable params: 0


In [None]:
train(model, train_loader, test_loader, epochs=3)

train loss=0.2858, train_acc=0.8969: 100%|██████████| 938/938 [00:08<00:00, 116.30it/s]


ep 0: val_loss=0.2391, val_acc=0.9201


train loss=0.2265, train_acc=0.9404: 100%|██████████| 938/938 [00:08<00:00, 110.53it/s]


ep 1: val_loss=0.2436, val_acc=0.9200


train loss=0.1170, train_acc=0.9524: 100%|██████████| 938/938 [00:08<00:00, 112.94it/s]


ep 2: val_loss=0.2472, val_acc=0.9216


## News classification example

In [None]:
ag_news_label = {1: "World",
                 2: "Sports",
                 3: "Business",
                 4: "Sci/Tec"}

def classify_news(news):
    x_text = tokenizer(news.lower())[0:max_seq_len]
    x_int = torch.tensor([[word_to_id.get(word, UNK) for word in x_text]]).to(DEVICE)

    model.eval()
    with torch.no_grad():
        pred = model(x_int).argmax(1).item() + 1
    print(f"This is a {ag_news_label[pred]} news")

# The model correctly classifies a theoretical physics news as Sci/Tec news, :-)
news = """The conformal bootstrapDavid Poland1,2and David Simmons-Duﬃn2*The conformal bootstrap was
proposed in the 1970s as a strategy for calculating the properties of second-order phasetransitions.
After spectacular success elucidating two-dimensional systems, little progress was made on systems in
 higher dimensions until a recent renaissance beginning in 2008. We report on some of the main results and
  ideas from thisrenaissance, focusing on new determinations of critical exponents and correlation
  functions in the three-dimensional Ising and O(N) models.
"""
classify_news(news)

This is a Sci/Tec news
