In [1]:
# Standard library
import gzip
import math
import os
import random
import shutil
import sys
from argparse import ArgumentParser
from pathlib import Path

# Third-party
import numpy as np
import torch
import torch.nn.functional as F
import tqdm
from torch import nn
from torch.autograd import Variable
from torch.utils.tensorboard import SummaryWriter
from torchtext import data, datasets, vocab

In [2]:
def get_device():
    """Return the torch device string to use: the first CUDA GPU if one is available, else the CPU."""
    return 'cuda:0' if torch.cuda.is_available() else 'cpu'


# Module-level device string used to place the model and the data batches.
DEVICE = get_device()

In [3]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected into ``heads`` query/key/value subspaces, attention is
    computed independently per head, and the concatenated head outputs are
    projected back to the model dimension ``d`` so the result can be added to the
    residual input.

    Args:
        d: model (input and output) dimension.
        d_k: per-head dimension of queries and keys.
        d_v: per-head dimension of values.
        heads: number of attention heads.
    """

    def __init__(self, d: int, d_k: int, d_v: int, heads: int):
        super().__init__()
        self.d = d
        self.d_k = d_k
        self.d_v = d_v
        self.heads = heads
        # One projection per role computes all heads at once; forward() splits
        # the last dimension into `heads` chunks.
        self.to_K = nn.Linear(d, d_k * heads, bias=False)
        self.to_Q = nn.Linear(d, d_k * heads, bias=False)
        self.to_V = nn.Linear(d, d_v * heads, bias=False)
        # Output projection (heads * d_v) -> d. Mapping back to the model
        # dimension `d` (not `d_v`) keeps the output shape compatible with the
        # residual connection even when d_v != d.
        self.concat_heads = nn.Linear(heads * d_v, d)

    def forward(self, X):
        """Apply self-attention to ``X``.

        Args:
            X: tensor of shape (b, L, d).

        Returns:
            Y: tensor of shape (b, L, d).
            A: attention weights of shape (b * heads, L, L); each row sums to 1.
        """
        b, L, d = X.size()
        assert self.d == d, f"expected input dim {self.d}, got {d}"

        # (b, L, heads * d_*) -> (b, L, heads, d_*)
        Q = self.to_Q(X).view(b, L, self.heads, self.d_k)
        K = self.to_K(X).view(b, L, self.heads, self.d_k)
        V = self.to_V(X).view(b, L, self.heads, self.d_v)

        # Fold the heads into the batch dimension: (b, L, h, d_*) -> (b * h, L, d_*)
        K = K.transpose(1, 2).contiguous().view(b * self.heads, L, self.d_k)
        Q = Q.transpose(1, 2).contiguous().view(b * self.heads, L, self.d_k)
        V = V.transpose(1, 2).contiguous().view(b * self.heads, L, self.d_v)

        # Scale Q and K by d_k^(1/4) each, which scales their dot product by
        # sqrt(d_k) as in scaled dot-product attention.
        Q = Q / (self.d_k ** (1 / 4))
        K = K / (self.d_k ** (1 / 4))

        Z = torch.bmm(Q, K.transpose(1, 2))
        assert Z.size() == (b * self.heads, L, L)

        A = F.softmax(Z, dim=2)

        # (b * h, L, d_v) -> (b, h, L, d_v) -> (b, L, h * d_v) -> (b, L, d)
        Y = torch.bmm(A, V).view(b, self.heads, L, self.d_v)
        Y = Y.transpose(1, 2).contiguous().view(b, L, self.heads * self.d_v)
        Y = self.concat_heads(Y)
        return Y, A

class TransformerBlock(nn.Module):
    """Post-norm transformer block.

    Self-attention and a two-layer ReLU MLP, each wrapped in a residual
    connection followed by LayerNorm.

    Args:
        d: model dimension (input and output).
        heads: number of attention heads.
        layer_width: MLP hidden width as a multiple of ``d``.
    """

    def __init__(self, d: int, heads: int, layer_width: int = 8):
        super().__init__()
        # Keep this construction order (attention, norm1, norm2, mlp): it fixes
        # parameter initialisation order and the ordering of parameters().
        self.attention = MultiHeadAttention(d, d, d, heads)
        self.norm1, self.norm2 = nn.LayerNorm(d), nn.LayerNorm(d)
        hidden = d * layer_width
        self.mlp = nn.Sequential(nn.Linear(d, hidden), nn.ReLU(), nn.Linear(hidden, d))

    def forward(self, X):
        """Return the block output for ``X`` (attention sub-layer, then MLP sub-layer)."""
        attended, _ = self.attention(X)
        hidden = self.norm1(attended + X)
        return self.norm2(self.mlp(hidden) + hidden)

class Transformer(nn.Module):
    """Transformer encoder for sequence classification.

    Token embeddings and learned position embeddings are summed, passed through
    ``depth`` TransformerBlocks, mean-pooled over the sequence dimension and
    projected to class log-probabilities.

    Args:
        d: model dimension.
        heads: attention heads per block.
        depth: number of TransformerBlocks.
        seq_length: maximum sequence length (rows of the position-embedding table).
        num_tokens: vocabulary size.
        num_classes: number of output classes.
    """

    def __init__(self, d:int, heads:int, depth:int, seq_length:int, num_tokens:int, num_classes:int):
        super().__init__()

        self.num_tokens = num_tokens
        self.token_emb = nn.Embedding(num_tokens, d)
        self.pos_emb = nn.Embedding(seq_length, d)

        # The sequence of transformer blocks that does all the heavy lifting.
        self.tblocks = nn.Sequential(*[TransformerBlock(d=d, heads=heads) for _ in range(depth)])

        # Maps the pooled representation to class logits.
        self.toprobs = nn.Linear(d, num_classes)

    def forward(self, x):
        """
        :param x: A (b, t) tensor of integer token ids (in some predetermined
                  vocabulary), with t <= seq_length.
        :return: A (b, c) tensor of log-probabilities over the classes
                 (where c is the nr. of classes).
        :raises ValueError: if t exceeds the position-embedding table size.
        """
        # generate token embeddings
        tokens = self.token_emb(x)
        b, t, k = tokens.size()

        max_t = self.pos_emb.num_embeddings
        if t > max_t:
            raise ValueError(f"sequence length {t} exceeds the maximum of {max_t}; truncate the input")

        # Position ids are created on the input's device (not a global one) so
        # the model works wherever it and its input live.
        positions = torch.arange(t, device=x.device)
        positions = self.pos_emb(positions)[None, :, :].expand(b, t, k)

        x = tokens + positions
        x = self.tblocks(x)

        # Average-pool over the t dimension (every position, including any
        # padding tokens, contributes) and project to class log-probabilities.
        x = self.toprobs(x.mean(dim=1))
        return F.log_softmax(x, dim=1)

## Train a model

In [4]:
# CONFIG
clear_data = False     # if True, delete previous checkpoints and TensorBoard logs
vocab_size = 50_000    # includes the <unk> and <pad> specials (see build_vocab)
batch_size = 16
epochs = 100
lr = 0.0001
lr_warmup = 10_000     # warm-up length; the scheduler divides it by batch_size to get steps
d = 128                # model dimension
max_length = 512       # max sequence length; longer inputs are truncated
heads = 8
depth = 8              # number of transformer blocks
classes = 2
# Project root. Set TRANSFORMER_PROJECT_DIR to relocate it without editing
# this hard-coded default.
p_base = Path(os.environ.get("TRANSFORMER_PROJECT_DIR", "/home/matthias/projects/Transformer"))
p_checkpoints = p_base / "data" / "checkpoints"
p_tb = p_base / "data" / "tensorboard"
p_checkpoint = p_checkpoints / "model_cp_{}.cp"  # formatted with the epoch index

if clear_data:
    for p_dir in (p_checkpoints, p_tb):
        if p_dir.exists():
            shutil.rmtree(p_dir)

# parents=True also creates data/ (and the project root) on a fresh checkout,
# where a plain mkdir raises FileNotFoundError.
p_checkpoints.mkdir(parents=True, exist_ok=True)
p_tb.mkdir(parents=True, exist_ok=True)

In [None]:
# ALTERNATIVE CONFIG to try: smaller batches (4), fewer epochs (80) and a
# shallower model (depth 6). Running this cell overrides the CONFIG cell above.
clear_data = False     # if True, delete previous checkpoints and TensorBoard logs
vocab_size = 50_000    # includes the <unk> and <pad> specials (see build_vocab)
batch_size = 4
epochs = 80
lr = 0.0001
lr_warmup = 10_000     # warm-up length; the scheduler divides it by batch_size to get steps
d = 128                # model dimension
max_length = 512       # max sequence length; longer inputs are truncated
heads = 8
depth = 6              # number of transformer blocks
classes = 2
# Project root. Set TRANSFORMER_PROJECT_DIR to relocate it without editing
# this hard-coded default.
p_base = Path(os.environ.get("TRANSFORMER_PROJECT_DIR", "/home/matthias/projects/Transformer"))
p_checkpoints = p_base / "data" / "checkpoints"
p_tb = p_base / "data" / "tensorboard"
p_checkpoint = p_checkpoints / "model_cp_{}.cp"  # formatted with the epoch index

if clear_data:
    for p_dir in (p_checkpoints, p_tb):
        if p_dir.exists():
            shutil.rmtree(p_dir)

# parents=True also creates data/ (and the project root) on a fresh checkout,
# where a plain mkdir raises FileNotFoundError.
p_checkpoints.mkdir(parents=True, exist_ok=True)
p_tb.mkdir(parents=True, exist_ok=True)

In [5]:
# prepare data (legacy torchtext Field / Dataset / BucketIterator API)
TEXT = data.Field(lower=True, include_lengths=True, batch_first=True)
# NOTE(review): the training loop subtracts 1 from batch.label, presumably
# because index 0 of LABEL's vocab is <unk> — confirm against LABEL.vocab.itos.
LABEL = data.Field(sequential=False)

# IMDB's own test split is discarded; 20% of the IMDB training split is held
# out and used as the evaluation ("test") set.
tdata, _ = datasets.IMDB.splits(TEXT, LABEL)
# A fixed random_state makes the held-out split identical across kernel
# restarts. Otherwise a reloaded checkpoint would be evaluated on reviews it
# was trained on.
split_seed = 0
train, test = tdata.split(split_ratio=0.8, random_state=random.Random(split_seed).getstate())

TEXT.build_vocab(train, max_size=vocab_size - 2) # - 2 to make space for <unk> and <pad>
LABEL.build_vocab(train)

train_iter, test_iter = data.BucketIterator.splits((train, test), batch_size=batch_size, device=DEVICE)



In [6]:
# TensorBoard writer (log_dir is SummaryWriter's first positional argument);
# the training loop logs train loss and held-out accuracy here.
tbw = SummaryWriter(str(p_tb))

In [7]:
# Build the classifier from the config values, then move it to DEVICE.
model = Transformer(
    d=d,
    heads=heads,
    depth=depth,
    seq_length=max_length,  # position table covers the truncation length
    num_tokens=vocab_size,
    num_classes=classes,
)
model = model.to(DEVICE)

In [8]:
# Adam with a linear learning-rate warm-up. The scheduler is stepped once per
# batch, so `lr_warmup` is converted to a number of scheduler steps here.
warmup_steps = lr_warmup / batch_size


def warmup_factor(step, warmup_steps=warmup_steps):
    """LR multiplier: ramps linearly from 0 to 1 over `warmup_steps`, then stays at 1.

    `warmup_steps` is bound at definition time, so re-running a config cell
    later does not silently change the schedule of an existing scheduler.
    """
    if warmup_steps <= 0:
        return 1.0  # no warm-up configured; avoid dividing by zero
    return min(step / warmup_steps, 1.0)


opt = torch.optim.Adam(lr=lr, params=model.parameters())
sch = torch.optim.lr_scheduler.LambdaLR(opt, warmup_factor)

In [None]:
def _prepare_batch(batch):
    """Return (inputs, labels) for a torchtext batch.

    Inputs are truncated to the first `max_length` tokens (the size of the
    position-embedding table). Labels are shifted down by 1, presumably because
    index 0 of LABEL's vocab is <unk> — confirm against LABEL.vocab.
    """
    inputs = batch.text[0]  # include_lengths=True: batch.text is (tokens, lengths)
    labels = batch.label - 1
    return inputs[:, :max_length], labels


def evaluate(model, data_iter):
    """Return the fraction of instances in `data_iter` that `model` classifies correctly."""
    model.train(False)
    correct, total = 0.0, 0.0
    with torch.no_grad():
        for batch in data_iter:
            inputs, labels = _prepare_batch(batch)
            preds = model(inputs).argmax(dim=1)
            total += float(inputs.size(0))
            correct += float((labels == preds).sum().item())
    return correct / total


seen = 0  # training instances processed so far; x-axis of the loss curve
for e in range(epochs):
    print(f'\n epoch {e}')
    model.train(True)

    for batch in tqdm.tqdm(train_iter):
        opt.zero_grad()

        inputs, labels = _prepare_batch(batch)
        out = model(inputs)
        loss = F.nll_loss(out, labels)  # the model returns log-probabilities

        loss.backward()
        # Clip the global gradient norm to 1.0 before the update.
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        sch.step()  # the LR schedule advances once per batch

        seen += inputs.size(0)
        tbw.add_scalar('classification/train-loss', loss.item(), seen)

    acc = evaluate(model, test_iter)
    tbw.add_scalar('classification/acc', float(acc), e)
    print(f'-- "test" accuracy {acc:.3}')
    # Write this epoch's scalars to disk now rather than on the writer's timer.
    tbw.flush()
    # save parameters
    torch.save(model.state_dict(), str(p_checkpoint).format(e))

  0%|          | 0/1250 [00:00<?, ?it/s]


 epoch 0


100%|██████████| 1250/1250 [04:22<00:00,  4.76it/s]
  0%|          | 0/1250 [00:00<?, ?it/s]

-- "test" accuracy 0.683

 epoch 1


 10%|█         | 125/1250 [00:26<04:08,  4.53it/s]

In [None]:
# Reload the parameters saved after epoch `e` by the training loop (which saves
# to str(p_checkpoint).format(e)) and switch to evaluation mode. Uncomment to use.
# model.load_state_dict(torch.load(str(p_checkpoint).format(e), map_location=DEVICE))
# model.eval()