In this notebook, I train a encoder-only transformer to do text classification on the AG_NEWS dataset.
Text classification seems to be a pretty simple task, and using transformer is probably overkill. But this is my first time implementing the transformer structure from scratch (including the self-attention layers), and it was fun :-)

In [None]:
# some commands in th   is notebook require torchtext 0.12.0
!pip install --upgrade torchtext

In [1]:
import torch
import torchtext
import collections
import numpy as np
import torch.nn as nn
import math
import copy
import torch.nn.functional as functional
from torch.utils.data import DataLoader
from tqdm import tqdm
import torchdata

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(DEVICE)

cpu


# Data processing

In [32]:
# One can easily modify the data processing part of this code to accommodate for   other datasets for text classification listed in https://pytorch.org/text/stable/datasets.html#text-classification
from torchtext.datasets import AG_NEWS
train_iter, test_iter = AG_NEWS()
num_classes = len(set([label for (label, text) in train_iter]))
tokenizer = torchtext.data.utils.get_tokenizer('basic_english')

In [33]:
# see an example of the dateset
next(iter(train_iter))

(3,
 "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.")

In [6]:
# convert the labels to be in range(0, num_classes)
y_train = torch.tensor([label-1 for (label, text) in train_iter])
y_test  = torch.tensor([label-1 for (label, text) in test_iter])

# There are many "\\" in the texts in the AG_news dataset, we get rid of them.
train_iter = ((label, text.replace("\\", " ")) for label, text in train_iter)
test_iter  = ((label, text.replace("\\", " ")) for label, text in test_iter)

# tokenize the texts, and truncate the number of words in each text to max_seq_len
max_seq_len = 100
x_train_texts = [tokenizer(text.lower())[0:max_seq_len]
                 for (label, text) in train_iter]
x_test_texts  = [tokenizer(text.lower())[0:max_seq_len]
                 for (label, text) in test_iter]

In [7]:
# build the vocabulary and word-to-integer map
counter = collections.Counter()
for text in x_train_texts:
    counter.update(text)

vocab_size = 15000
most_common_words = np.array(counter.most_common(vocab_size - 2))
vocab = most_common_words[:,0]

# indexes for the padding token, and unknown tokens
PAD = 0
UNK = 1
word_to_id = {vocab[i]: i + 2 for i in range(len(vocab))}

In [8]:
# map the words in the training and test texts to integers
x_train = [torch.tensor([word_to_id.get(word, UNK) for word in text])
           for text in x_train_texts]
x_test  = [torch.tensor([word_to_id.get(word, UNK) for word in text])
          for text in x_test_texts]
x_test = torch.nn.utils.rnn.pad_sequence(x_test,
                                batch_first=True, padding_value = PAD)

In [12]:
# constructing the dataset in order to be compatible with torch.utils.data.Dataloader
class AGNewsDataset:
    def __init__(self, features, labels):
        self.features = features
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, item):
        return self.features[item], self.labels[item]


train_dataset = AGNewsDataset(x_train, y_train)
test_dataset  = AGNewsDataset(x_test, y_test)

In [17]:
# collate_fn to be used in torch.utils.data.DataLoader().
# It pads the texts in each batch such that they have the same sequence length.
def pad_sequence(batch):
    texts  = [text for text, label in batch]
    labels = torch.tensor([label for text, label in batch])
    texts_padded = torch.nn.utils.rnn.pad_sequence(texts,
                                batch_first=True, padding_value = PAD)
    return texts_padded, labels

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True,
                        collate_fn = pad_sequence)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=True,
                        collate_fn = pad_sequence)

# Building the encoder-only transformer model for text classification

In [25]:
# One can also replace the MultiHeadedAttention class here with
# torch.nn.MultiheadAttention provided by pytorch.
class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        "Take in model size and number of heads."
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0 # check the h number
        self.d_k = d_model//h
        self.d_model = d_model
        self.h = h
        # 4 linear layers: WQ WK WV and final linear mapping WO
        self.WQ = nn.Linear(d_model, d_model)
        self.WK = nn.Linear(d_model, d_model)
        self.WV = nn.Linear(d_model, d_model)
        self.linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x_query, x_key, x_value, mask=None):
        nbatches = x_query.size(0) # get batch size
        # 1) Do all the linear projections in batch from d_model => h x d_k
        # parttion into h sections，switch 2,3 axis for computation.
        #LHS query, key, value dimensions: nbatch*h*dseq*dk
        #x dimension nbatch*dseq*d_model
        query = self.WQ(x_query).view(nbatches, -1, self.h, self.d_k).transpose(1,2)
        key   = self.WK(x_key).view(nbatches, -1, self.h, self.d_k).transpose(1,2)
        value = self.WV(x_value).view(nbatches, -1, self.h, self.d_k).transpose(1,2)
        # 2) Apply attention on all the projected vectors in batch.
        # query, key, value all have size: nbatch*h*d_seq*d_k
        # scores has size: nbatch*h*d_seq*d_seq
        scores = torch.matmul(query, key.transpose(-2, -1))/math.sqrt(self.d_model)
        if mask is not None:
            scores = scores.masked_fill(mask, float('-inf'))
        p_attn = torch.nn.functional.softmax(scores, dim=-1)
        x = torch.matmul(p_attn, value)
        # 3) "Concat" using a view and apply a final linear.
        # x dimensions:nbtach*dseq*(h*dk)
        x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)

        return self.linear(x) # final linear layer


class ResidualConnection(nn.Module):
  '''residual connection: x + dropout(sublayer(layernorm(x))) '''
  def __init__(self, dim, dropout):
      super().__init__()
      self.drop = nn.Dropout(dropout)
      self.norm = nn.LayerNorm(dim)

  def forward(self, x, sublayer):
      return x + self.drop(sublayer(self.norm(x)))


# Encoder-only Transformer = words embedding + position embedding -> N stack of EncoderBlock ->full connected layer
class Transformer(nn.Module):
    def __init__(self, max_len, vocab_size, h, d_model, dropout, N):
        super().__init__()
        self.d_model = d_model
        self.embed = nn.Embedding(vocab_size, d_model) #words embedding
        self.pos_embed = nn.Embedding(max_len, d_model) #position embedding
        self.encoder_layer = nn.Sequential(
            *[EncoderBlock(h, d_model, 4*d_model) for _ in range(N)]
        )
        self.linear = nn.Linear(d_model, num_classes)

    def forward(self, input, mask=None):
        x = self.embed(input) * math.sqrt(self.d_model)
        x_pos = self.pos_embed(torch.tensor(range(input.size(-1))).to(DEVICE))
        x = x + x_pos
        for layer in self.encoder_layer:
            x = layer(x, mask)
        return self.linear(torch.mean(x,-2))



class EncoderBlock(nn.Module):
    def __init__(self, h, d_model, d_ff, dropout=0.1):
        super(EncoderBlock, self).__init__()
        self.attn = MultiHeadedAttention(h, d_model, dropout)
        #self.attn = nn.MultiheadAttention(d_model, h,  batch_first=True)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model)
        )
        self.residual1 = ResidualConnection(d_model, dropout)
        self.residual2 = ResidualConnection(d_model, dropout)

    def forward(self, x, mask=None):
        x = self.residual1(x, lambda x: self.attn(x, x, x, mask =mask))
        # positionwise feed-forwad
        x = self.residual2(x, lambda x: self.feed_forward(x))
        return x


In [28]:
# d_model is the embedding dimension
d_model = 32
# h is the number of attention head
# N is the number of encoder blocks
# For the text classification problem in this notebook, h=1 and N=1 are already enough.
h = 1
N = 1

dropout = 0.1
model =Transformer(max_seq_len, vocab_size, h, d_model, dropout, N).to(DEVICE)
# initialize model parameters
# it seems that this initialization is very important!
for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

In [29]:
def train_epoch(model, dataloader):
    model.train()
    total_loss, acc, count = 0,0,0
    pbar = tqdm(enumerate(dataloader), total=len(dataloader))
    for idx, (x, y)  in  pbar:
        optimizer.zero_grad()
        features= x.to(DEVICE)
        labels  = y.to(DEVICE)
        pred = model(features, (features==0).unsqueeze(-2).unsqueeze(1).to(DEVICE))

        loss = loss_fn(pred, labels).to(DEVICE)
        loss.backward()
        optimizer.step()

        total_loss += loss
        acc += (pred.argmax(1) == labels).sum().item()
        count += len(labels)
        # report progress
        if idx%50 == 0:
            val_acc, val_loss = evaluate(x_test, y_test)
            pbar.set_description(f"Train acc={acc/count:.3f}, Train loss={total_loss.item()/(idx+1):.3f}, test acc = {val_acc:.3f}, test loss= {val_loss:.5f}")

def train(model,dataloader, epochs):
    for ep in range(epochs):
        train_epoch(model,dataloader)

def evaluate(x_test, y_test):
    model.eval()
    with torch.no_grad():
        features= x_test.to(DEVICE)
        labels  = y_test.to(DEVICE)
        pred = model(features, (features==0).unsqueeze(-2).unsqueeze(1).to(DEVICE))
        loss = loss_fn(pred,labels).to(DEVICE)
        acc = (pred.argmax(1) == labels).sum().item()
        count = len(labels)
    return acc/count, loss.item()

In [30]:
optimizer = torch.optim.Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss()
hist = train(model, train_loader, epochs=5)
# strangely, the test accuracy is higher than the training accuracy during the first epoch

Train acc=0.903, Train loss=0.284, test acc = 0.921, test loss= 0.23204: 100%|██████████| 938/938 [01:13<00:00, 12.83it/s]
Train acc=0.944, Train loss=0.159, test acc = 0.920, test loss= 0.24258: 100%|██████████| 938/938 [01:11<00:00, 13.13it/s]
Train acc=0.959, Train loss=0.112, test acc = 0.915, test loss= 0.25909: 100%|██████████| 938/938 [01:13<00:00, 12.74it/s]
Train acc=0.969, Train loss=0.082, test acc = 0.917, test loss= 0.31659: 100%|██████████| 938/938 [01:13<00:00, 12.79it/s]
Train acc=0.977, Train loss=0.061, test acc = 0.912, test loss= 0.38330: 100%|██████████| 938/938 [01:11<00:00, 13.16it/s]


In [32]:

ag_news_label = {1: "World",
                 2: "Sports",
                 3: "Business",
                 4: "Sci/Tec"}

# The model correctly classifies a theoretical physics news as Sci/Tec news, :-)
ex_text = """The conformal bootstrapDavid Poland1,2and David Simmons-Duﬃn2*The conformal bootstrap was
proposed in the 1970s as a strategy for calculating the properties of second-order phasetransitions.
After spectacular success elucidating two-dimensional systems, little progress was made on systems in
 higher dimensions until a recent renaissance beginning in 2008. We report on some of the main results and
  ideas from thisrenaissance, focusing on new determinations of critical exponents and correlation
  functions in the three-dimensional Ising and O(N) models.
"""

x_ex_text = tokenizer(ex_text.lower())[0:max_seq_len]
x_ex_int = torch.tensor([[word_to_id.get(word, UNK) for word in x_ex_text]]).to(DEVICE)

model.eval()
with torch.no_grad():
    pred = model(x_ex_int).argmax(1).item() + 1

print(f"This is a {ag_news_label[pred]} news")

This is a Sci/Tec news
