In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
print(torch.__version__)

1.10.0


# Data

In [2]:
import torchtext

In [3]:
dataset_train, dataset_test = torchtext.datasets.AG_NEWS()
print(len(dataset_train))
print(len(dataset_test))

120000
7600


In [4]:
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

tokenizer = get_tokenizer('basic_english')
train_iter = torchtext.datasets.AG_NEWS(split='train')

def yield_tokens(data_iter):
    for _, text in data_iter:
        yield tokenizer(text)

vocab = build_vocab_from_iterator(yield_tokens(train_iter), min_freq=2,
                                  specials=["[PAD]", "[UNK]", "[CLS]", "[MASK]"])
vocab.set_default_index(vocab["[UNK]"])

In [5]:
#vocab(['here', 'is', 'an', 'example'])
vocab(["the"])
full_vocab = vocab.vocab.get_stoi().keys()
print(len(full_vocab))

53130


In [6]:
import random
import math
import copy

# copied from tutorial, added padding
def text_pipeline(x, max_len, percentage_masked=0.15):
    vocab_list = np.array(vocab(tokenizer(x)))[:max_len]
    k = len(vocab_list)
    missing_len = max_len - k
    missing_list = missing_len * vocab(["[PAD]"])
    # true labels
    labels = np.concatenate([copy.deepcopy(vocab_list), missing_list])
    # padding mask
    mask = np.random.choice([True, False], k, p=[percentage_masked, 1-percentage_masked])
    #print(mask)
    mask_idxs = np.arange(len(vocab_list))[mask]
    random.shuffle(mask_idxs)
    mask_80 = mask_idxs[:math.floor(0.8*len(mask_idxs))]
    mask_10 = mask_idxs[math.floor(0.8*len(mask_idxs)):math.floor(0.9*len(mask_idxs))]
    #print(mask_idxs)
    #print(mask_80)
    #print(mask_10)
    vocab_list[mask_80] = vocab(["[MASK]"])
    # fill remaining 10 percent with random words
    random_words = np.random.choice(list(full_vocab), len(mask_10), replace=True)
    vocab_list[mask_10] = vocab(list(random_words))
    
    return(np.concatenate([vocab_list, missing_list]), labels, np.concatenate([mask, np.zeros(missing_len, dtype=bool)]))

print(text_pipeline('He married Mabel Scott in 1890, but they soon separated. Unable to get an English divorce, in 1900, he became the first celebrity to get one in Nevada, and remarried there, but the divorce was invalid in England. In June 1901, he was arrested for bigamy, and was convicted before the House of Lords, the last time a peer was convicted by the Lords.', 50))

(array([5.2000e+01, 6.6170e+03, 4.7691e+04, 3.0000e+00, 1.0000e+01,
       1.0000e+00, 6.0000e+00, 4.8000e+01, 7.0000e+01, 7.4700e+02,
       3.0000e+00, 4.0000e+00, 2.6377e+04, 7.0000e+00, 2.2500e+02,
       3.3000e+01, 1.8870e+03, 1.3152e+04, 3.0000e+00, 1.0000e+01,
       1.0000e+00, 6.0000e+00, 5.2000e+01, 1.3610e+03, 5.0000e+00,
       5.0000e+01, 7.9570e+03, 7.0000e+00, 2.2500e+02, 6.4000e+01,
       1.0000e+01, 6.4220e+03, 3.0000e+00, 1.1000e+01, 1.0000e+00,
       2.3200e+02, 3.0000e+00, 4.8000e+01, 5.0000e+00, 1.3152e+04,
       3.8000e+01, 1.5491e+04, 1.0000e+01, 3.1800e+02, 4.0000e+00,
       1.0000e+01, 3.0000e+00, 3.0441e+04, 6.0000e+00, 5.2000e+01]), array([5.2000e+01, 6.6170e+03, 4.7691e+04, 2.6430e+03, 1.0000e+01,
       1.0000e+00, 6.0000e+00, 4.8000e+01, 7.0000e+01, 7.4700e+02,
       1.0618e+04, 4.0000e+00, 4.3080e+03, 7.0000e+00, 2.2500e+02,
       3.3000e+01, 1.8870e+03, 1.3152e+04, 6.0000e+00, 1.0000e+01,
       1.0000e+00, 6.0000e+00, 5.2000e+01, 1.3610e+03, 5.00

In [7]:
# copied from tutorial, removed offsets
from torch.utils.data import DataLoader
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

def collate_batch(batch):
    label_list, text_list, mask_list = [], [], []
    for _, text in batch:
        input_, label_, mask_ = text_pipeline(text, max_len=100, percentage_masked=0.15)
        text_list.append(torch.tensor(input_, dtype=torch.int64))
        label_list.append(label_)
        mask_list.append(torch.tensor(mask_))
    label_list = torch.tensor(label_list, dtype=torch.int64)
    text_list = torch.cat(text_list).view(len(label_list), -1)
    mask_list = torch.cat(mask_list).view(len(label_list), -1)
    return text_list.to(device), label_list.to(device), mask_list

train_iter = dataset_train
BATCH_SIZE = 32
dataloader = DataLoader(train_iter, batch_size=BATCH_SIZE, collate_fn=collate_batch)

cpu


# Architecture

In [8]:
### build classifier transformer
import torch.nn.functional as F

class MyBERT(nn.Module):
    
    def __init__(self, embedding_dim, heads, seq_length, vocab_size, depth=5, num_classes=2):
        super().__init__()

        self.vocab_size = vocab_size
        self.token_emb = nn.Embedding(vocab_size, embedding_dim)
        self.pos_emb = nn.Embedding(seq_length, embedding_dim)
        self.num_heads = heads

        # sequence of transformers
        tblocks = []
        for i in range(depth):
            tblocks.append(nn.TransformerEncoderLayer(d_model=embedding_dim,
                                                            nhead=self.num_heads, 
                                                            batch_first=True, dropout=0.1))
        self.tblocks = nn.Sequential(*tblocks)
        
        # final linear layer
        self.last_linear = nn.Linear(embedding_dim, vocab_size)

    def forward(self, x, mask_idxs):
        # generate token embeddings
        tokens = self.token_emb(x)
        batch_size, token_size, embed_size = tokens.size()

        # generate position embeddings
        positions = torch.arange(token_size)
        positions = self.pos_emb(positions).expand(batch_size, token_size, embed_size)

        x = tokens + positions
        x = self.tblocks(x)

        # only predict on the masked tokens
        masked_idxs = mask_idxs[:, :, None].expand(-1, -1, x.size(-1))
        #print("x size: ", x.size())
        #print("mask size: ", mask_idxs.size())
        masked_tokens = x[mask_idxs] #torch.gather(x, dim=1, index=masked_pos)
        out = self.last_linear(masked_tokens)
        return out

# Training

In [9]:
my_bert = MyBERT(embedding_dim=30, heads=5, 
                        seq_length=100, vocab_size=len(full_vocab),
                        depth=1)

optimizer = torch.optim.Adam(my_bert.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

In [10]:
# training
from tqdm.notebook import tqdm

num_epochs = 20

for epoch in range(num_epochs):
    print("epoch: ", epoch)
    training_loss = 0
    for i, (x, y, mask_idxs) in tqdm(enumerate(dataloader), total=len(dataset_train)//BATCH_SIZE):

        #print(x)
        if i > 50: break
        #mask_idxs = mask_idxs
        #print(mask_idxs.size())
        optimizer.zero_grad()
        #print(y.size())
        out = my_bert(x, mask_idxs)
        
        masked_idxs = mask_idxs[:, :, None].expand(-1, -1, y.size(-1))
#        print("x size: ", x.size())
#        print("mask size: ", mask_idxs.size())
        y_masked = y[mask_idxs]
        #y_masked = torch.gather(y, dim=1, index=mask_idxs)
#        loss = criterion(out.transpose(1,2), y_masked)
        loss = criterion(out, y_masked)
        training_loss += loss
        loss.backward()
        optimizer.step()
    
    print("training_loss: ", training_loss)

epoch:  0


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3750.0), HTML(value='')))

  label_list = torch.tensor(label_list, dtype=torch.int64)
  label_list = torch.tensor(label_list, dtype=torch.int64)





KeyboardInterrupt: 

In [None]:
# save model
BERT_PATH = "parameters/my_bert_small.pth"
print("saving model at: {}".format(BERT_PATH))
torch.save(my_bert.state_dict(), BERT_PATH)

# Testing

In [None]:
# load model
BERT_PATH = "parameters/my_bert_small.pth"
my_bert = MyBERT(embedding_dim=30, heads=5, 
                        seq_length=100, vocab_size=len(full_vocab),
                        depth=1)

print("loading model from: {}".format(BERT_PATH))
CIFAR10_model.load_state_dict(torch.load(BERT_PATH))