<a href="https://colab.research.google.com/github/eugeneteoh/ai-algorithms/blob/transformer/transformer/transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Reference:

https://github.com/Atcold/pytorch-Deep-Learning/blob/master/15-transformer.ipynb

https://nn.labml.ai/transformers

In [None]:
%pip install torchdata torchtext pytorch-lightning

In [1]:
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
import numpy as np


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


In [3]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model=512, num_heads=8):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = self.d_v = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)

        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V):
        A = Q @ K.transpose(2, 3)
        A /= np.sqrt(self.d_k)
        A = F.softmax(A, dim=-1)
        attn = A @ V
        return attn

    def forward(self, X_q, X_k, X_v):
        batch_size, seq_length, dim = X_q.shape

        Q = self.W_q(X_q)
        K = self.W_k(X_k)
        V = self.W_v(X_v)

        # Split heads
        Q = Q.view(batch_size, self.num_heads, seq_length, self.d_k)
        K = K.view(batch_size, self.num_heads, seq_length, self.d_k)
        V = V.view(batch_size, self.num_heads, seq_length, self.d_v)

        H_cat = self.scaled_dot_product_attention(Q, K, V)
        H_cat = H_cat.view(batch_size, seq_length, dim)

        out = self.W_o(H_cat)
        return out

mha = MultiHeadAttention()

In [4]:
test_K = torch.tensor(
    [[10, 0, 0],
     [ 0,10, 0],
     [ 0, 0,10],
     [ 0, 0,10]]
).float()[None,None]

test_V = torch.tensor(
    [[   1,0,0],
     [  10,0,0],
     [ 100,5,0],
     [1000,6,0]]
).float()[None,None]

test_Q = torch.tensor(
    [[0, 10, 0]]
).float()[None, None]

test_X_k = torch.randn((1, 1, 512))
test_X_v = torch.randn((1, 1, 512))
test_X_q = torch.randn((1, 1, 512))

# Test scaled_dot_product_attention shape
output = mha.scaled_dot_product_attention(test_Q, test_K, test_V)
assert test_Q.shape == output.shape

# Test mha output shape
output = mha(test_X_q, test_X_k, test_X_v)
assert test_X_q.shape == output.shape

In [5]:
class EncoderLayer(nn.Module):
    def __init__(self, d_model=512, num_heads=8, ff_hidden_dim=2048):
        super().__init__()

        self.mha = MultiHeadAttention(d_model, num_heads)
        self.ff = nn.Sequential(
            nn.Linear(d_model, ff_hidden_dim),
            nn.ReLU(),
            nn.Linear(ff_hidden_dim, d_model)
        )

        self.layernorm1 = nn.LayerNorm(normalized_shape=d_model)
        self.layernorm2 = nn.LayerNorm(normalized_shape=d_model)

    def forward(self, x):
        out = self.layernorm1(x + self.mha(x, x, x))
        out = self.layernorm2(out + self.ff(out))
        return out

In [6]:
class Embedding(nn.Module):
    def __init__(self, d_model=512, vocab_size=10000, max_len=5000):
        super().__init__()

        self.word_embeddings = nn.Embedding(vocab_size, d_model, padding_idx=1)
        self.register_buffer("positional_encodings", self.get_positional_encoding(d_model, max_len))
        self.layernorm = nn.LayerNorm(d_model)

    def get_positional_encoding(self, d_model, max_len):
        encodings = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        two_i = torch.arange(0, d_model, 2)
        denominator = 10000 ** (two_i / d_model)
        div = position / denominator
        encodings[:, 0::2] = torch.sin(div)
        encodings[:, 1::2] = torch.cos(div)
        encodings.requires_grad_ = False

        return encodings

    def forward(self, x):
        # seq_length = x.shape[1]
        # position_ids = torch.arange(seq_length, dtype=torch.long, device=x.device) # (max_seq_length)
        # position_ids = position_ids.unsqueeze(0).expand_as(x)                      # (bs, max_seq_length)

        word_embeddings = self.word_embeddings(x)

        embeddings = word_embeddings + self.positional_encodings
        
        return self.layernorm(embeddings)


In [7]:
class Encoder(nn.Module):
    def __init__(self, vocab_size=10000, max_seq_len=5000, num_layers=6, d_model=512, num_heads=8, ff_hidden_dim=2048):
        super().__init__()

        self.embedding_layer = Embedding(d_model=d_model, vocab_size=vocab_size, max_len=max_seq_len)
        self.enc_layers = nn.Sequential(*[EncoderLayer(d_model=d_model, num_heads=num_heads, ff_hidden_dim=ff_hidden_dim) for _ in range(num_layers)])

    def forward(self, x):
        x = self.embedding_layer(x)
        return self.enc_layers(x)


In [8]:
class TransformerClassifier(nn.Module):
    def __init__(self, num_outputs, vocab_size=10000, max_seq_len=5000, num_layers=6, d_model=512, num_heads=8, ff_hidden_dim=2048):
        super().__init__()

        self.encoder = Encoder(
            vocab_size=vocab_size, max_seq_len=max_seq_len, num_layers=num_layers, d_model=d_model, num_heads=num_heads, ff_hidden_dim=ff_hidden_dim
        )
        self.dense = nn.Linear(d_model, num_outputs)

    def forward(self, x):
        x = self.encoder(x)
        x, _ = torch.max(x, dim=1)
        x = self.dense(x)
        return x



In [9]:
from torchtext.datasets import AG_NEWS, IMDB
from collections import Counter
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import vocab
from torchtext.functional import truncate, to_tensor

from torch.utils.data import DataLoader
import torchtext.transforms as T


In [19]:
tokenizer = get_tokenizer('basic_english')  
train_iter = IMDB(split='train')
test_iter = IMDB(split='test')
counter = Counter()
for (label, line) in train_iter:
    counter.update(tokenizer(line))
for (label, line) in test_iter:
    counter.update(tokenizer(line))
data_vocab = vocab(counter, min_freq = 1, specials=('\<unk\>', '\<BOS\>', '\<EOS\>', '\<PAD\>'))

batch_size = 16
max_seq_len = 256

text_transform = T.Sequential(
    T.VocabTransform(data_vocab),
    T.Truncate(max_seq_len),
    T.ToTensor(),
    T.PadTransform(max_seq_len, 1),
)
text_pipeline = lambda x: text_transform(tokenizer(x))
label_pipeline = lambda x: int(x) - 1
apply_transform = lambda x: (label_pipeline(x[0]), text_pipeline(x[1]))

train_iter = train_iter.map(apply_transform)
train_iter = train_iter.batch(batch_size)
train_iter = train_iter.rows2columnar(["target", "token_ids"])
train_loader = DataLoader(train_iter, batch_size=None)

test_iter = test_iter.map(apply_transform)
test_iter = test_iter.batch(batch_size)
test_iter = test_iter.rows2columnar(["target", "token_ids"])
test_loader = DataLoader(test_iter, batch_size=None)



In [11]:
classifier = TransformerClassifier(num_outputs=2, vocab_size=len(data_vocab), max_seq_len=max_seq_len).to(device)

for i, batch in enumerate(train_loader):
    targets = torch.as_tensor(batch["target"], device=device)
    token_ids = torch.stack(batch["token_ids"]).to(device)

    out = classifier(token_ids)
    preds = torch.argmax(out, dim=1)
    # print(preds)
    break
    
    

In [12]:
import pytorch_lightning as pl
from torchmetrics.functional.classification import binary_accuracy

In [13]:
class TransformerClassifierLT(pl.LightningModule):
    def __init__(self, num_outputs, vocab_size=10000, max_seq_len=5000, num_layers=6, d_model=512, num_heads=8, ff_hidden_dim=2048):
        super().__init__()
        self.classifier = TransformerClassifier(
            num_outputs=num_outputs,
            vocab_size=vocab_size,
            max_seq_len=max_seq_len,
            num_layers=num_layers,
            d_model=d_model,
            num_heads=num_heads,
            ff_hidden_dim=ff_hidden_dim
        )

    def training_step(self, batch, batch_idx):
        targets = torch.as_tensor(batch["target"], device=self.device)
        token_ids = torch.stack(batch["token_ids"]).to(self.device)
        
        
        out = self.classifier(token_ids)

        loss = F.cross_entropy(out, targets)

        self.log("train_loss", loss)
        return loss

    def test_step(self, batch, batch_idx):
        targets = torch.as_tensor(batch["target"], device=self.device)
        token_ids = torch.stack(batch["token_ids"]).to(self.device)

        out = self.classifier(token_ids)
        preds = torch.argmax(out, dim=1)

        test_loss = F.cross_entropy(out, targets)
        test_acc = binary_accuracy(preds, targets)
        self.log("test_loss", test_loss)
        self.log("accuracy", test_acc)

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

classifier = TransformerClassifierLT(num_outputs=2, vocab_size=len(data_vocab), max_seq_len=max_seq_len, num_layers=1, d_model=16, num_heads=2, ff_hidden_dim=512)

In [20]:
trainer = pl.Trainer(
    # limit_train_batches=100,
    max_epochs=1,
    # accelerator="gpu",
    # devices=1
)
num_workers = 2
trainer.fit(model=classifier, train_dataloaders=train_loader)
trainer.test(model=classifier, dataloaders=test_loader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name       | Type                  | Params
-----------------------------------------------------
0 | classifier | TransformerClassifier | 2.4 M 
-----------------------------------------------------
2.4 M     Trainable params
0         Non-trainable params
2.4 M     Total params
9.491     Total estimated model params size (MB)
  rank_zero_warn(


Epoch 0: : 1563it [03:06,  8.40it/s, loss=0.000214, v_num=1]

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: : 1563it [03:06,  8.39it/s, loss=0.000214, v_num=1]


  rank_zero_warn(


Testing DataLoader 0: : 1563it [00:55, 27.98it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        accuracy            0.5001599192619324
        test_loss            4.225028991699219
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 4.225028991699219, 'accuracy': 0.5001599192619324}]