In this notebook, I train an encoder-only Transformer to classify arXiv papers into their main category from their titles and abstracts. This is a simple task; the purpose of this notebook is to make myself more familiar with the Transformer architecture by implementing it from scratch.

In [1]:
# import packages
import torch
import json
import torchtext
import collections
import numpy as np
import torch.nn as nn
import math
import copy
import torch.nn.functional as functional
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.data import random_split
from collections import Counter

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(DEVICE)

cpu


# Data processing

In [5]:
# Kaggle download URL for the Cornell arXiv metadata snapshot.
arxiv_url = (
    "https://www.kaggle.com/datasets/"
    "Cornell-University/arxiv/download"
)

In [6]:
import urllib.request

# NOTE(review): `html` is not used by any later cell, and this Kaggle endpoint
# presumably requires an authenticated session — confirm whether this cell is
# still needed (the data is read from a local file below).
with urllib.request.urlopen(arxiv_url) as response:
    raw_bytes = response.read()
    html = raw_bytes.decode('utf-8')

In [None]:
# The arXiv metadata snapshot is large (~3.3 GB — TODO confirm) and can be
# downloaded from Kaggle:
# https://www.kaggle.com/datasets/Cornell-University/arxiv
# Each line of the file is parsed below as one JSON record (one paper).

data_file = '/Users/hongbinchen/Downloads/arxiv-metadata-oai-snapshot.json'


def get_metadata(path=None):
    """Lazily yield the raw lines of the arXiv metadata file.

    Streaming line by line with ``yield`` avoids loading the whole JSON file
    into memory at once; callers parse each line with ``json.loads``.

    Args:
        path: File to read. Defaults to the module-level ``data_file``.

    Yields:
        Each line of the file as a ``str`` (trailing newline included).
    """
    if path is None:
        path = data_file
    # Decode explicitly as UTF-8 instead of the platform default encoding, which
    # may not be UTF-8 (e.g. on Windows) and would fail on non-ASCII text.
    with open(path, 'r', encoding='utf-8') as f:
        yield from f

In [None]:
# Collect the main (archive-level) category of every paper, taken from its
# first listed category, e.g. "math.CO cs.AI" -> "math".
metadata = get_metadata()
categories = []
for paper in metadata:
    cates = json.loads(paper)['categories']
    # The field is a space-separated list of categories. Split on whitespace first:
    # splitting on "." alone glues several categories together
    # (e.g. "hep-ph math.CO" -> "hep-ph math").
    categories.append(cates.split()[0].split(".")[0])

In [None]:
# Count papers by their first listed category, e.g. "hep-ph math.CO" -> "hep-ph".
metadata = get_metadata()
counter = Counter(
    json.loads(paper)['categories'].split()[0] for paper in metadata
)

In [None]:
# Keep the `num_classes` most frequent first-listed categories and assign each an
# integer label in frequency order (most frequent -> 0, next -> 1, ...).
num_classes = 5
classes = {
    name: label
    for label, (name, _count) in enumerate(counter.most_common(num_classes))
}

In [None]:
import torchtext

# torchtext's built-in 'basic_english' tokenizer, used for both training texts
# and the inference example.
tokenizer = torchtext.data.utils.get_tokenizer(tokenizer='basic_english')

In [None]:
# Build a capped set of (tokens, label) examples: at most `max_num_each_class`
# papers per class, each truncated to its first `max_len` tokens.
metadata = get_metadata()
max_len = 100
dataset = []
max_num_each_class = 10000
num_each_class = {cate: 0 for cate in classes}
# Classes that still need examples; once none are left, the rest of the (very
# large) file cannot contribute anything and is skipped.
classes_left = len(num_each_class)
for paper in metadata:
    if classes_left == 0:
        break
    paper_data = json.loads(paper)
    cate = paper_data['categories'].split()[0]
    if cate in classes and num_each_class[cate] < max_num_each_class:
        temp = paper_data['title'] + " " + paper_data['abstract']
        # str.replace returns a new string; the result must be reassigned.
        temp = temp.replace('\n', ' ')
        temp = tokenizer(temp)[0:max_len]
        dataset.append((temp, classes[cate]))
        num_each_class[cate] += 1
        if num_each_class[cate] == max_num_each_class:
            classes_left -= 1
# Closing the generator releases the file it holds open if we broke out early.
metadata.close()

In [None]:
# Build the vocabulary from the tokenized dataset: the `vocab_size - 2` most
# frequent tokens get ids 2, 3, ...; id 0 is padding (PAD), id 1 unknown (UNK).
# NOTE(review): the vocabulary is built before the train/test split below, so
# test-set tokens influence it.
counter = Counter()
for text, _label in dataset:
    counter.update(text)
vocab_size = 10000
vocab = counter.most_common(vocab_size - 2)
UNK = 1
PAD = 0
word_to_int = {word: i + 2 for i, (word, _count) in enumerate(vocab)}


def word_to_int_fn(text):
    """Map a list of tokens to a 1-D ``torch.long`` tensor of vocabulary ids.

    Out-of-vocabulary tokens map to ``UNK``. The explicit dtype keeps an empty
    token list from producing a float tensor, which ``nn.Embedding`` rejects.
    """
    return torch.tensor([word_to_int.get(word, UNK) for word in text],
                        dtype=torch.long)


dataset = [(word_to_int_fn(text), label) for text, label in dataset]

In [None]:
# Split the dataset into 80% training and 20% test data (fixed seed for
# reproducibility). random_split returns the subsets in the same order as the
# lengths list, so the training size must come first.
dataset_size = len(dataset)
train_size = dataset_size * 8 // 10
train_data, test_data = random_split(
    dataset,
    [train_size, dataset_size - train_size],
    generator=torch.Generator().manual_seed(42),
)

In [None]:
# Ad-hoc lookup: print the full metadata record of every paper whose author
# string contains 'Hongbin Chen'.
# NOTE(review): this resets the `categories` list built earlier.
metadata = get_metadata()
categories = []
for line in metadata:
    record = json.loads(line)
    if 'Hongbin Chen' in record['authors']:
        print(record)

In [None]:
# Separate the (already integer-encoded) token tensors from their labels for
# each split; labels are collected into a single tensor per split.
x_train = [features for features, _ in train_data]
y_train = torch.tensor([label for _, label in train_data])
x_test = [features for features, _ in test_data]
y_test = torch.tensor([label for _, label in test_data])

In [None]:
# Show the number of test and training examples (in that order).
for split_labels in (y_test, y_train):
    print(len(split_labels))

In [None]:
class ArXivDataset:
    """Minimal map-style dataset pairing each feature with its label.

    Indexing returns ``(features[item], labels[item])`` and the length is the
    number of labels, so it can be consumed directly by a ``DataLoader``.
    """

    def __init__(self, features, labels):
        self.features, self.labels = features, labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, item):
        feature = self.features[item]
        label = self.labels[item]
        return feature, label


# Wrap each split so a DataLoader can index (features, label) pairs.
train_dataset, test_dataset = (
    ArXivDataset(x_train, y_train),
    ArXivDataset(x_test, y_test),
)

In [None]:
def pad_sequence(batch):
    """Collate a list of ``(token_ids, label)`` pairs into batch tensors.

    The token tensors are padded with ``PAD`` (batch-first) to the length of
    the longest one in the batch.

    Returns:
        ``(texts_padded, labels)``: a ``(batch, longest_len)`` tensor and a
        tensor of the labels.
    """
    texts, labels = [], []
    for text, label in batch:
        texts.append(text)
        labels.append(label)
    label_tensor = torch.tensor(labels)
    texts_padded = torch.nn.utils.rnn.pad_sequence(
        texts, batch_first=True, padding_value=PAD
    )
    return texts_padded, label_tensor

# Each batch returned by a dataloader is padded (via `pad_sequence`) so that all
# texts in the batch have the length of the longest text in that batch.
train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True,
                              collate_fn=pad_sequence)
# The test set is only used for evaluation, so keep a fixed order. Shuffling it
# gains nothing and makes the reported test loss (an average of per-batch mean
# losses, where the last batch can be smaller) vary from run to run.
test_dataloader = DataLoader(test_dataset, batch_size=128, shuffle=False,
                             collate_fn=pad_sequence)

In [None]:
# Peek at one collated test batch: (padded token-id tensor, label tensor).
first_test_batch = next(iter(test_dataloader))
first_test_batch

# Building the encoder-only transformer model for text classification

In [None]:
# One can also replace this MultiHeadedAttention class with the
# torch.nn.MultiheadAttention provided by PyTorch (check its mask and
# batch_first conventions, which differ from the ones used here).
class MultiHeadedAttention(nn.Module):
    """Multi-head scaled dot-product attention."""

    def __init__(self, h, d_model, dropout=0.1):
        """Take in the number of heads ``h`` and the model size ``d_model``.

        ``d_model`` must be divisible by ``h``; each head works in
        ``d_k = d_model // h`` dimensions. ``dropout`` is applied to the
        attention probabilities.
        """
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0  # every head gets an equal slice of d_model
        self.d_k = d_model // h
        self.d_model = d_model
        self.h = h
        # 4 linear layers: WQ, WK, WV and the final output projection WO (`linear`).
        self.WQ = nn.Linear(d_model, d_model)
        self.WK = nn.Linear(d_model, d_model)
        self.WV = nn.Linear(d_model, d_model)
        self.linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x_query, x_key, x_value, mask=None):
        """Attend from ``x_query`` over ``x_key`` / ``x_value``.

        Args:
            x_query, x_key, x_value: tensors of shape (batch, seq, d_model).
            mask: optional boolean tensor broadcastable to
                (batch, h, seq_q, seq_k); positions that are ``True`` are
                excluded from attention. The training code in this notebook
                passes (batch, 1, 1, seq_k) padding masks.

        Returns:
            Tensor of shape (batch, seq_q, d_model).
        """
        nbatches = x_query.size(0)
        # 1) Project d_model => h x d_k and move the head axis forward:
        #    (batch, seq, d_model) -> (batch, h, seq, d_k).
        query = self.WQ(x_query).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
        key = self.WK(x_key).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
        value = self.WV(x_value).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
        # 2) Scaled dot-product attention; scores: (batch, h, seq_q, seq_k).
        # The dot products are over d_k dimensions per head, so scale by
        # sqrt(d_k), not sqrt(d_model).
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask, -1e9)
        p_attn = self.dropout(functional.softmax(scores, dim=-1))
        x = torch.matmul(p_attn, value)
        # 3) "Concat" the heads back to (batch, seq_q, h * d_k) and apply the
        #    final output projection.
        x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
        return self.linear(x)

In [None]:
class EncoderBlock(nn.Module):
    """One post-norm Transformer encoder layer.

    Self-attention followed by a position-wise feed-forward network; each
    sub-layer is wrapped in dropout, a residual connection and a LayerNorm.
    """

    def __init__(self, h, d_model, d_ff, dropout):
        super(EncoderBlock, self).__init__()
        self.self_attn = MultiHeadedAttention(h, d_model, dropout)
        # Position-wise feed-forward; its trailing Dropout is this sub-layer's
        # residual dropout.
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        # Residual dropout for the attention sub-layer.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Apply the block to ``x``; ``mask`` is forwarded to self-attention.

        ``x`` is presumably (batch, seq, d_model); the output has the same shape.
        """
        # Dropout goes on the sub-layer output *before* the residual add, so the
        # residual stream itself is never zeroed out.
        x = self.norm1(x + self.dropout(self.self_attn(x, x, x, mask)))
        # Position-wise feed-forward sub-layer.
        return self.norm2(x + self.feed_forward(x))

In [None]:
# Transformer = word embedding + position embedding -> stack of N encoder blocks
# -> mean-pool over the (non-padding) positions -> fully connected layer.
class Transformer(nn.Module):
    """Encoder-only Transformer text classifier.

    Args:
        encoder_layer: module called as ``layer(x, mask)``; deep-copied ``N`` times.
        max_len: maximum sequence length supported by the learned position embedding.
        vocab_size: number of token ids.
        d_model: embedding / hidden dimension.
        dropout: dropout probability applied to the summed embeddings.
        N: number of stacked encoder layers.
        n_classes: number of output classes; when omitted, the notebook-level
            ``num_classes`` is used.
    """

    def __init__(self, encoder_layer, max_len, vocab_size, d_model, dropout, N,
                 n_classes=None):
        super(Transformer, self).__init__()
        self.embed = nn.Embedding(vocab_size, d_model)  # word embedding
        self.pos_embed = nn.Embedding(max_len, d_model)  # learned position embedding
        self.dropout = nn.Dropout(dropout)
        self.encoder_layer = nn.ModuleList(
            [copy.deepcopy(encoder_layer) for _ in range(N)]
        )
        self.linear = nn.Linear(d_model, num_classes if n_classes is None else n_classes)

    def forward(self, input, mask=None):
        """Return class logits of shape (batch, n_classes) for token ids ``input``.

        ``input`` is presumably a (batch, seq) tensor of token ids with
        seq <= max_len. ``mask`` is the boolean padding mask (``True`` =
        padding) passed to every encoder layer. As built by the training code
        it has shape (batch, 1, 1, seq). It is also used here so that mean
        pooling ignores padded positions.
        """
        # Build position ids on the same device as the input (not a global device).
        positions = torch.arange(input.size(-1), device=input.device)
        x = self.dropout(self.embed(input) + self.pos_embed(positions))
        for layer in self.encoder_layer:
            x = layer(x, mask)
        if mask is None:
            return self.linear(torch.mean(x, -2))
        # Masked mean: average only over real tokens, so the amount of padding
        # in a batch does not change a text's prediction.
        keep = (~mask).reshape(input.size(0), input.size(-1), 1).to(x.dtype)
        pooled = (x * keep).sum(dim=-2) / keep.sum(dim=-2).clamp(min=1)
        return self.linear(pooled)

In [None]:
# Model hyper-parameters.
d_model = 32   # embedding dimension
d_ff = 32      # hidden size of the position-wise feed-forward layer
h = 2          # number of attention heads
dropout = 0.1
max_len = max_len  # reuse the token cut-off from data processing as the position limit
# N is the number of stacked encoder blocks; for the text-classification task in
# this notebook a single block is already enough.
N = 1

encoder_block = EncoderBlock(h, d_model, d_ff, dropout)
model = Transformer(encoder_block, max_len, vocab_size, d_model, dropout, N)
model = model.to(DEVICE)

# Xavier-initialise every parameter with more than one dimension (the weight
# matrices and embedding tables); the author found this initialisation to be
# very important for training.
for param in model.parameters():
    if param.dim() > 1:
        nn.init.xavier_uniform_(param)

In [None]:
# Scratch check of the padding-mask construction on a toy batch without
# padding: `.eq(0)` marks padding, unsqueezing gives (batch, 1, 1, seq), and the
# final `.eq(0)` inverts it (True = real token).
torch.tensor([[1, 2, 4], [2, 4, 5]]).eq(0).unsqueeze(-2).unsqueeze(1).eq(0)

In [None]:
def _padding_mask(features):
    """Boolean mask that is True at padding positions (token id 0, i.e. PAD).

    Shape (batch, 1, 1, seq) so it broadcasts over attention heads and query
    positions.
    """
    return (features == 0).unsqueeze(-2).unsqueeze(1)


def train_epoch(model, dataloader):
    """Run one training epoch over ``dataloader``, updating ``model`` in place.

    Uses the notebook-level ``optimizer``, ``loss_fn``, ``test_dataloader`` and
    ``DEVICE``. Every 50 batches the model is evaluated on ``test_dataloader``
    and the running metrics are shown in the progress bar.

    Returns:
        ``(train_accuracy, mean_train_loss)`` for the epoch, or ``(nan, nan)``
        for an empty loader.
    """
    model.train()
    total_loss, acc, count, n_batches = 0.0, 0, 0, 0
    pbar = tqdm(enumerate(dataloader), total=len(dataloader))
    for idx, (x, y) in pbar:
        optimizer.zero_grad()
        features = x.to(DEVICE)
        labels = y.to(DEVICE)
        pred = model(features, _padding_mask(features))
        loss = loss_fn(pred, labels)
        loss.backward()
        optimizer.step()

        # Accumulate a plain float: summing the loss tensor itself would keep
        # every batch's autograd graph alive until the end of the epoch.
        total_loss += loss.item()
        acc += (pred.argmax(1) == labels).sum().item()
        count += len(labels)
        n_batches += 1
        # Report progress (includes a full pass over the test set).
        if idx % 50 == 0:
            val_acc, val_loss = evaluate(test_dataloader, net=model)
            pbar.set_description(f"Train acc={acc/count:.3f}, Train loss={total_loss/n_batches:.3f}, test acc = {val_acc:.3f}, test loss= {val_loss:.5f}")
    if n_batches == 0:
        return float('nan'), float('nan')
    return acc / count, total_loss / n_batches


def train(model, dataloader, epochs):
    """Train ``model`` for ``epochs`` epochs.

    Returns:
        List of per-epoch ``(train_accuracy, mean_train_loss)`` tuples.
    """
    return [train_epoch(model, dataloader) for _ in range(epochs)]


def evaluate(dataloder, net=None):
    """Compute accuracy and mean per-batch loss of a model on ``dataloder``.

    Args:
        dataloder: loader yielding ``(token_ids, labels)`` batches.
        net: model to evaluate; defaults to the notebook-level ``model``.

    Returns:
        ``(accuracy, mean_batch_loss)`` as floats, or ``(nan, nan)`` for an
        empty loader. The model's train/eval mode is restored before returning.
    """
    net = model if net is None else net
    was_training = net.training
    net.eval()
    total_loss, total_acc, count, n_batches = 0.0, 0, 0, 0
    try:
        with torch.no_grad():
            for x, y in dataloder:
                features = x.to(DEVICE)
                labels = y.to(DEVICE)
                pred = net(features, _padding_mask(features))
                total_loss += loss_fn(pred, labels).item()
                total_acc += (pred.argmax(1) == labels).sum().item()
                count += len(labels)
                n_batches += 1
    finally:
        # Without this, evaluating mid-epoch would leave dropout disabled for
        # the rest of training.
        net.train(was_training)
    if n_batches == 0:
        return float('nan'), float('nan')
    return total_acc / count, total_loss / n_batches

In [None]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters())
hist = train(model, train_dataloader, epochs=5)
# During the first epoch the test accuracy can exceed the training accuracy:
# the training figure is a running average over every batch seen so far in the
# epoch (including early, barely-trained ones), while the test figure is
# computed with the latest weights.

In [None]:
# Inverse of `classes`: integer label -> category name.
int_to_classes = {label: name for name, label in classes.items()}

In [None]:
# Display the integer label -> category name mapping.
int_to_classes

In [None]:
# Sample input for a sanity check of the trained classifier: title on the first
# line, abstract on the next (symbols such as "TT¯" are kept as copied).
ex_text = """2D Ising Field Theory in a Magnetic Field: The Yang-Lee Singularity
 We study Ising Field Theory (the scaling limit of Ising model near the Curie critical point) in pure imaginary external magnetic field. We put particular emphasis on the detailed structure of the Yang-Lee edge singularity. While the leading singular behavior is controlled by the Yang-Lee fixed point (= minimal CFT 2/5), the fine structure of the subleading singular terms is determined by the effective action which involves a tower of irrelevant operators. We use numerical data obtained through the "Truncated Free Fermion Space Approach" to estimate the couplings associated with two least irrelevant operators. One is the operator TT¯, and we use the universal properties of the TT¯ deformation to fix the contributions of higher orders in the corresponding coupling parameter α. Another irrelevant operator we deal with is the descendant L−4L¯−4ϕ of the relevant primary in 2/5. The significance of this operator is that it is the lowest dimension operator which breaks integrability of the effective theory. We also establish analytic properties of the particle mass M (= inverse correlation length) as the function of complex magnetic field.
"""

def predict_category(text):
    """Return the predicted category name for a raw "title + abstract" string.

    Applies the same tokenizer, ``max_len`` truncation and vocabulary as the
    training data. Relies on the notebook-level ``model``, ``tokenizer``,
    ``word_to_int``, ``UNK``, ``max_len``, ``int_to_classes`` and ``DEVICE``.
    """
    tokens = tokenizer(text.lower())[0:max_len]
    # dtype=torch.long: an empty token list would otherwise give a float
    # tensor, which nn.Embedding rejects.
    ids = torch.tensor([[word_to_int.get(word, UNK) for word in tokens]],
                       dtype=torch.long, device=DEVICE)
    model.eval()
    with torch.no_grad():
        # A single unpadded sequence, so no padding mask is needed.
        pred = model(ids).argmax(1).item()
    return int_to_classes[pred]


print(f"This is a {predict_category(ex_text)} paper")

In [None]:
# Number of trainable parameters in the model.
trainable = [param for param in model.parameters() if param.requires_grad]
total_params = sum(param.numel() for param in trainable)
total_params