In [1]:
import torch as t
from torch import nn
from torch.utils.data import Dataset, DataLoader
import plotly.express as px
from IPython.display import display
import pandas as pd
import numpy as np
import copy
from fancy_einsum import einsum
from dataclasses import dataclass
from tqdm.notebook import tqdm_notebook

from einops import rearrange, reduce, repeat

import utils
import cnn_modules as cm
import transformer_modules as tm

In [2]:
def single_head_attention(Q: t.Tensor, K: t.Tensor, V: t.Tensor) -> t.Tensor:
    '''
    Should return the results of self-attention (see the "Self-Attention in Detail" section of the Illustrated Transformer).

    With this function, you can ignore masking.

    Q: shape (batch, seq_len, embed_size)
    K: shape (batch, seq_len, embed_size)
    V: shape (batch, seq_len, embed_size)

    Return: shape (batch, seq_len, embed_size)
    '''
    scores = Q @ t.transpose(K, -2, -1)
    scores /= Q.shape[-1] ** 0.5
    scores = t.softmax(scores, dim=-1)
    Z = einsum('B Seq Score, B Seq Emb -> B Seq Emb', scores, V)
    return Z
    

In [3]:
def single_head_masked_attention(Q: t.Tensor, K: t.Tensor, V: t.Tensor) -> t.Tensor:
    '''
    Should return the results of masked self-attention.

    See "The Decoder Side" section of the Illustrated Transformer for an explanation of masking.

    Q: shape (batch, seq_len, embed_size)
    K: shape (batch, seq_len, embed_size)
    V: shape (batch, seq_len, embed_size)

    Return: shape (batch, seq_len, embed_size)
    '''
    batch_size, seq_len, embed_size = Q.shape
    scores = Q @ t.transpose(K, -2, -1)
    scores /= Q.shape[-1] ** 0.5
    
    # create lower-left triangle of ones, including the diagonal
    mask = t.tril(t.ones(seq_len, seq_len), diagonal=0)
    # fill with close-to-neg-inf values
    scores = scores.masked_fill(mask==0, -1e9)
    
    scores = t.softmax(scores, dim=-1)
    Z = einsum('B Seq Score, B Seq Emb -> B Seq Emb', scores, V)
    return Z
    

In [4]:
def multihead_masked_attention(Q: t.Tensor, K: t.Tensor, V: t.Tensor, num_heads: int):
    '''
    Implements multihead masked attention on the matrices Q, K and V.

    Q: shape (batch, seq, nheads*headsize)
    K: shape (batch, seq, nheads*headsize)
    V: shape (batch, seq, nheads*headsize)

    returns: shape (batch, seq, nheads*headsize)
    '''
    Q = rearrange(Q, 'B S (nheads headsize) -> B S nheads headsize', nheads = num_heads)
    K = rearrange(K, 'B S (nheads headsize) -> B S nheads headsize', nheads = num_heads)
    V = rearrange(V, 'B S (nheads headsize) -> B S nheads headsize', nheads = num_heads)

    batch_size, seq_len, nheads, headsize = Q.shape
    scores = einsum('B Qseq nheads headsize, B Kseq nheads headsize -> B nheads Qseq Kseq', Q, K)
    scores /= Q.shape[-1] ** 0.5

    # create lower-left triangle of ones, including the diagonal
    mask = t.tril(t.ones(seq_len, seq_len), diagonal=0)
    # fill with close-to-neg-inf values where mask==0
    scores = scores.masked_fill(mask==0, -1e9)

    scores = t.softmax(scores, dim=-1)
    Z = einsum('B nheads Qseq Kseq, B Qseq nheads headsize -> B Qseq nheads headsize', scores, V)
    Z = rearrange(Z, 'B Qseq nheads headsize -> B Qseq (nheads headsize)')
    return Z

In [5]:
T = t.randn(2, 10, 4)

In [6]:
multihead_masked_attention(T, T, T, 2)

tensor([[[-1.6461, -1.4714,  0.6447, -0.5935],
         [-0.0633,  1.3094, -0.3534,  0.8460],
         [ 1.1578, -1.1582, -2.2181,  1.2401],
         [-0.2999,  0.8019, -1.3173,  1.1405],
         [ 2.1100,  0.3969,  0.2178, -1.1141],
         [-0.7605, -0.4672, -0.9641,  1.6105],
         [-2.1233,  0.4106, -1.4063,  0.3354],
         [ 0.6793, -1.6125, -0.2777,  1.0745],
         [-0.7756, -0.4430,  0.1071, -1.1027],
         [-0.3306,  0.7261,  0.3428, -0.2202]],

        [[-1.8488,  1.2835,  1.5932,  0.7873],
         [-1.1313, -0.8555,  0.8753,  0.1680],
         [ 0.7252, -0.1598, -0.5458,  1.5460],
         [ 0.2635,  1.0143, -0.4506,  0.6162],
         [ 1.1178,  0.8769,  0.3000, -1.9224],
         [-0.7717,  1.0193,  0.7313,  0.1690],
         [ 0.5073, -1.6599,  1.4223, -0.4154],
         [ 0.0070,  1.1553,  1.8530, -0.3216],
         [ 0.6567, -1.0001,  0.2039,  1.1097],
         [ 0.9693, -0.5122,  0.9608,  0.2677]]])

In [7]:
class MultiheadMaskedAttention(nn.Module):
    W_QKV: nn.Linear
    W_O: nn.Linear

    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.query_size = int(hidden_size / num_heads)
        self.qkv = cm.Linear(hidden_size, 3*hidden_size)
        self.ff = cm.Linear(hidden_size, hidden_size)

    def multihead_masked_attention(self, Q: t.Tensor, K: t.Tensor, V: t.Tensor, num_heads: int):
        '''
        Implements multihead masked attention on the matrices Q, K and V.

        Q: shape (batch, seq, nheads*headsize)
        K: shape (batch, seq, nheads*headsize)
        V: shape (batch, seq, nheads*headsize)

        returns: shape (batch, seq, nheads*headsize)
        '''
        Q = rearrange(Q, 'B S (nheads headsize) -> B S nheads headsize', nheads = num_heads)
        K = rearrange(K, 'B S (nheads headsize) -> B S nheads headsize', nheads = num_heads)
        V = rearrange(V, 'B S (nheads headsize) -> B S nheads headsize', nheads = num_heads)

        batch_size, seq_len, nheads, headsize = Q.shape
        scores = einsum('B Qseq nheads headsize, B Kseq nheads headsize -> B nheads Qseq Kseq', Q, K)
        scores /= Q.shape[-1] ** 0.5

        # create lower-left triangle of ones, including the diagonal
        mask = t.tril(t.ones(seq_len, seq_len), diagonal=0)
        # fill with close-to-neg-inf values where mask==0
        scores = scores.masked_fill(mask==0, -1e9)

        scores = t.softmax(scores, dim=-1)
        Z = einsum('B nheads Qseq Kseq, B Qseq nheads headsize -> B Qseq nheads headsize', scores, V)
        Z = rearrange(Z, 'B Qseq nheads headsize -> B Qseq (nheads headsize)')
        return Z

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''
        x: shape (batch, seq, hidden_size)

        Return: shape (batch, seq, hidden_size)
        '''
        out = self.qkv(x)
        Q, K, V = t.tensor_split(out, 3, dim=-1)

        Z = self.multihead_masked_attention(Q, K, V, self.num_heads)

        return self.ff(Z)

In [8]:
@dataclass(frozen=True)
class TransformerConfig:
    '''Constants used throughout your decoder-only transformer model.'''

    num_layers: int
    num_heads: int
    vocab_size: int
    hidden_size: int
    max_seq_len: int
    dropout: float = 0.1
    layer_norm_epsilon: float = 1e-05

In [9]:
config = TransformerConfig(
    num_layers=4, 
    num_heads=2, 
    vocab_size=500, 
    hidden_size=64,
    max_seq_len=100,
    dropout=0.1)

In [10]:
class MLP(nn.Module):

    def __init__(self, hidden_size, dropout):
        super().__init__()
        self.linear1 = cm.Linear(hidden_size, 4 * hidden_size)
        self.gelu = tm.GELU()
        self.linear2 = cm.Linear(4 * hidden_size, hidden_size)
        self.dropout = tm.Dropout(p=dropout)

    def forward(self, x):
        out = self.gelu(self.linear1(x))
        out = self.dropout(self.linear2(out))
        return out

In [11]:
class DecoderBlock(nn.Module):

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.attn = tm.MultiheadMaskedAttention(config.hidden_size, config.num_heads)
        self.lnorm1 = tm.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = tm.MLP(config.hidden_size, config.dropout)
        self.lnorm2 = tm.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)

    def forward(self, x: t.Tensor) -> t.Tensor:
        normed_attn = self.lnorm1(self.attn(x))
        out = normed_attn + x
        normed_mlp = self.lnorm2(self.mlp(out))
        out = normed_mlp + out
        return out


In [12]:
class DecoderOnlyTransformer(nn.Module):

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.emb = tm.Embedding(config.vocab_size, config.hidden_size)
        self.pos_enc = tm.PositionalEncoding(config.max_seq_len, config.hidden_size)
        self.dropout = tm.Dropout(p=config.dropout)

        decoders = [DecoderBlock(config) for l in range(config.num_layers)]
        self.decoders = nn.Sequential(*decoders)
        
        self.post_norm = tm.LayerNorm(config.hidden_size)

    def forward(self, x: t.Tensor) -> t.Tensor:
        embedding = self.emb(x.long())
        embedding = self.pos_enc(embedding)
        embedding = self.dropout(embedding)
        embedding = embedding.to(t.float32)

        out = self.decoders(embedding)
        out = self.post_norm(out)

        out = einsum("B S E, V E -> B S V", out, self.emb.weight)

        return out

## Testing Transformer

In [13]:
from torch.utils.data import Dataset

class ReverseNumberDataset(Dataset):
    def __init__(self, seq_len, total_size):
        self.seq_len = seq_len
        self.total_size = total_size

    def __len__(self):
        return self.total_size

    def __getitem__(self, idx):
        text = t.randint(0, 9, (self.seq_len, ))
        label = text.flip(dims=[0])
        sample = (text, label)
        return sample

In [14]:
num_ds = ReverseNumberDataset(6, 1_000_000)
num_ds[5]

(tensor([7, 3, 4, 3, 7, 7]), tensor([7, 7, 3, 4, 3, 7]))

In [15]:
trainloader = DataLoader(num_ds, batch_size=256, shuffle=True)

In [16]:
config = TransformerConfig(
    num_layers=2, 
    num_heads=4, 
    vocab_size=10, 
    hidden_size=128,
    max_seq_len=6,
    dropout=0.1)

In [17]:
from typing import Callable


epochs = 1
loss_fn = nn.CrossEntropyLoss()
batch_size = 256

MODEL_FILENAME = "./w1d2_transformer_digits.pt"
device = t.device("cuda:0" if t.cuda.is_available() else "cpu")

def train_transformer(trainloader: DataLoader, epochs: int, loss_fn: Callable) -> list:
    '''
    Defines a Transformer from our custom modules, and trains it on the reversed digit dataset.
    '''
    model = tm.DecoderOnlyTransformer(config).to(device).train()
    optimizer = t.optim.Adam(model.parameters())
    loss_list = []
    accuracy_list = []

    for epoch in range(epochs):

        progress_bar = tqdm_notebook(trainloader)
        for (x, y) in progress_bar:

            x = x.to(t.float32)
            x = x.to(device)
            y = y.to(device)

            logits = model(x)
            print(logits.shape)
            logits = rearrange(logits, 'B S V -> (B S) V')
            print(logits)
            print(y.shape)
            y = rearrange(y, 'B S -> (B S)')


            loss = loss_fn(logits, y)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            loss_list.append(loss.item())

            with t.inference_mode():
                model.eval()
                preds = model(t.tensor([[1,2,3,4,5,6]]).to(device))
                preds = preds.argmax(dim=-1)

                random_case = t.randint(0, 9, (6, )).unsqueeze(dim=0)
                random_preds = model(random_case.to(device)).argmax(dim=-1)
                random_corrects = random_case.flip(dims=[0])
                accuracy = (random_preds == random_corrects).sum() / len(random_preds)
                model.train()

            progress_bar.set_description(f"Epoch = {epoch}, Preds = {preds.squeeze()}, Loss = {loss.item():.4f}, Accuracy = {accuracy}")

    print(f"Saving model to: {MODEL_FILENAME}")
    t.save(model, MODEL_FILENAME)
    return loss_list, accuracy_list

loss_list, accuracy_list = train_transformer(trainloader, epochs, loss_fn)

fig = px.line(y=loss_list, template="simple_white")
fig.update_layout(title="Cross entropy loss on number sequences", yaxis_range=[0, max(loss_list)])
fig.show()

  0%|          | 0/3907 [00:00<?, ?it/s]

torch.Size([256, 6, 10])
tensor([[-5.9481,  3.5969,  4.7205,  ...,  3.4312, 28.4936, -2.3196],
        [-1.9964,  4.4892,  0.1760,  ...,  6.7848, 31.2959,  1.6064],
        [ 6.3925,  6.2618, -1.2825,  ...,  9.7793,  7.8082,  2.2598],
        ...,
        [-5.2464, 14.2319,  3.8885,  ...,  0.6896,  5.7464, -5.1677],
        [-2.7793,  8.3704, 20.2672,  ...,  2.0676,  7.9293, -4.6318],
        [-4.0373,  8.3074,  0.7646,  ...,  3.7032,  6.5656, -1.0396]],
       grad_fn=<ReshapeAliasBackward0>)
torch.Size([256, 6])
torch.Size([256, 6, 10])
tensor([[-0.7548,  8.6877, 10.6746,  ...,  6.3010,  2.0463,  0.3177],
        [-0.3946,  6.2471, -0.8689,  ...,  4.2615, 17.1879, -3.3923],
        [ 1.1488,  7.1769,  0.8945,  ...,  8.4211, 16.1297, -1.7383],
        ...,
        [15.8172, 11.6096,  2.1616,  ...,  4.1508,  6.9062,  0.2759],
        [ 5.5574,  9.1630,  0.6309,  ...,  3.1758, 17.9395,  0.3557],
        [ 3.9945, 10.7707,  3.8309,  ...,  6.2362,  4.7034, -0.5024]],
       grad_fn=<Resha

KeyboardInterrupt: 