In [195]:
import torch as t
import torch.nn as nn
from torch import Tensor
from fancy_einsum import einsum
from einops import rearrange, repeat
import math
from dataclasses import dataclass

def singlehead_attention(Q: t.Tensor, K: t.Tensor, V: t.Tensor):
    '''
    Unmasked scaled dot-product self-attention: softmax(Q K^T / sqrt(c)) V
    (see the "Self-Attention in Detail" section of the Illustrated Transformer).

    With this function, you can ignore masking.

    Q: shape (b, s, c)
    K: shape (b, s, c)
    V: shape (b, s, c)
    b = batch
    s = seq_len
    c = dims

    Any number of extra leading batch dimensions is also accepted, because only
    the last two axes take part in the matrix products.

    Return: shape (b, s, c) -- for every query position, the attention-weighted
    average of the rows of V.
    '''
    # Dividing by sqrt(c) keeps the logits at a scale where softmax does not saturate.
    scale = math.sqrt(Q.shape[-1])
    # scores[..., q, k] = <Q[..., q, :], K[..., k, :]> / sqrt(c)
    scores = (Q @ K.transpose(-2, -1)) / scale
    # Normalise over the key axis, then mix the value rows with those weights.
    return scores.softmax(dim=-1) @ V

def masked_attention(Q: t.Tensor, K: t.Tensor, V: t.Tensor, mask: t.Tensor):
    '''
    Scaled dot-product attention with an optional mask.

    Q: shape (b, s, c)
    K: shape (b, s, c)
    V: shape (b, s, c)
    mask: shape (b, s, s), or None to disable masking. Entries equal to 0 mark
          (query, key) pairs that must NOT attend; every other value keeps the pair.
    b = batch
    s = seq_len
    c = dims

    Return: shape (b, s, c)
    '''
    scale = math.sqrt(Q.shape[-1])
    # scores[b, q, k] = <Q[b, q], K[b, k]> / sqrt(c)
    scores = (Q @ K.transpose(-2, -1)) / scale
    if mask is not None:
        # A large negative logit makes the softmax weight of a masked pair
        # underflow to zero without producing NaNs for fully masked rows.
        scores = scores.masked_fill(mask == 0, -1e9)
    return scores.softmax(dim=-1) @ V

def multihead_attention(Q: t.Tensor, K: t.Tensor, V: t.Tensor, n_heads: int):
    '''
    Causally masked multi-head scaled dot-product attention.

    Q: shape (b, s1, e)
    K: shape (b, s2, e)
    V: shape (b, s2, e)

    e = nheads * h
    b = batch
    s1 = number of query positions, s2 = number of key/value positions
    h = per-head dims

    The embedding axis is split into n_heads contiguous chunks of size h. Each
    head attends independently and the head outputs are concatenated again.
    Query position i may only attend to key positions j <= i.

    Return: shape (b, s1, e)
    '''
    assert Q.shape[-1] % n_heads == 0
    assert K.shape[-1] % n_heads == 0
    assert V.shape[-1] % n_heads == 0
    assert K.shape[-1] == V.shape[-1]

    batch, q_len = Q.shape[0], Q.shape[1]
    k_len = K.shape[1]

    def split_heads(x: t.Tensor) -> t.Tensor:
        # (b, s, nheads * h) -> (b, nheads, s, h)
        return x.reshape(x.shape[0], x.shape[1], n_heads, x.shape[-1] // n_heads).transpose(1, 2)

    q, k, v = split_heads(Q), split_heads(K), split_heads(V)

    # scores[b, n, i, j] = <q[b, n, i], k[b, n, j]> / sqrt(h)
    scores = (q @ k.transpose(-2, -1)) / math.sqrt(q.shape[-1])

    # Causal (autoregressive) mask: True strictly above the diagonal, i.e. for
    # keys that lie in the future of the query. It is kept 2-D so it broadcasts
    # over both the batch and the head axis whatever their sizes are, and it is
    # created on the same device as the inputs.
    future = t.triu(t.ones(q_len, k_len, device=Q.device), diagonal=1).bool()
    scores = scores.masked_fill(future, -1e9)

    attention_probs = scores.softmax(dim=-1)
    attention_vals = attention_probs @ v  # (b, nheads, s1, h)

    # (b, nheads, s1, h) -> (b, s1, nheads * h): concatenate the heads.
    return attention_vals.transpose(1, 2).reshape(batch, q_len, -1)

class MultiheadMaskedAttention(nn.Module):
    '''
    Multi-head self-attention layer with learned projections.

    A single linear layer (W_QKV) produces queries, keys and values in one pass,
    multihead_attention mixes them, and W_O projects the concatenated heads back
    to hidden_size.
    '''

    def __init__(self, hidden_size: int, num_heads: int):
        '''
        hidden_size: width of the input and output embeddings
        num_heads: number of attention heads; must divide hidden_size

        Raises ValueError if hidden_size cannot be split evenly across the heads.
        '''
        super().__init__()
        # Fail at construction time instead of on the first forward pass.
        if num_heads <= 0 or hidden_size % num_heads != 0:
            raise ValueError(
                f'hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})'
            )
        self.W_QKV = nn.Linear(hidden_size, hidden_size * 3)
        self.W_O = nn.Linear(hidden_size, hidden_size)
        self.num_heads = num_heads

    def forward(self, x: t.Tensor, mask=None) -> t.Tensor:
        '''
        x: shape (batch, seq, hidden_size)
        mask: accepted for interface compatibility but currently unused;
              multihead_attention is called without it.
        Return: shape (batch, seq, hidden_size)
        '''
        # One projection, then three equal chunks along the feature axis.
        Q, K, V = self.W_QKV(x).chunk(3, dim=-1)
        return self.W_O(multihead_attention(Q, K, V, self.num_heads))


In [196]:
# Smoke test for singlehead_attention on a small deterministic input:
# Q counts 0..41 laid out as (batch=2, seq=7, dims=3); K and V are scaled copies of Q.
Q = t.arange(2 * 7 * 3).reshape(2, 7, 3).type(t.float32)
K = Q * 0.5
V = Q * 0.8
# The printed tensor has the same (2, 7, 3) shape as V.
print(singlehead_attention(Q,K,V))

tensor([[[14.2070, 15.0070, 15.8070],
         [14.3999, 15.1999, 15.9999],
         [14.4000, 15.2000, 16.0000],
         [14.4000, 15.2000, 16.0000],
         [14.4000, 15.2000, 16.0000],
         [14.4000, 15.2000, 16.0000],
         [14.4000, 15.2000, 16.0000]],

        [[31.2000, 32.0000, 32.8000],
         [31.2000, 32.0000, 32.8000],
         [31.2000, 32.0000, 32.8000],
         [31.2000, 32.0000, 32.8000],
         [31.2000, 32.0000, 32.8000],
         [31.2000, 32.0000, 32.8000],
         [31.2000, 32.0000, 32.8000]]])


In [197]:
Q = t.linspace(0, 10, 2 * 5 * 4).reshape(2, 5, 4)
K = t.linspace(5, 20, 2 * 5 * 4).reshape(2, 5, 4)
V = t.linspace(15, 2, 2 * 5 * 4).reshape(2, 5, 4)
# b = 2, s = 5, c = 4
multihead_attention(Q, K, V, n_heads=2)

tensor([[[15.0000, 14.6667, 14.3333, 14.0000],
         [13.7668, 13.4335, 13.0346, 12.7012],
         [12.3451, 12.0117, 11.6705, 11.3372],
         [11.0013, 10.6679, 10.3337, 10.0004],
         [ 9.6668,  9.3335,  9.0000,  8.6667]],

        [[ 8.3333,  8.0000,  7.6667,  7.3333],
         [ 7.0000,  6.6667,  6.3333,  6.0000],
         [ 5.6667,  5.3333,  5.0000,  4.6667],
         [ 4.3333,  4.0000,  3.6667,  3.3333],
         [ 3.0000,  2.6667,  2.3333,  2.0000]]])

In [198]:
t.manual_seed(420)
m = MultiheadMaskedAttention(6, 2)
x = t.linspace(0, 42, 2 * 3 * 6).reshape(2, 3, 6)
m(x)

tensor([[[ -0.7193,   0.4614,   0.4117,  -0.5813,   0.2754,  -0.5745],
         [ -0.7746,   0.6206,   0.5520,  -0.7370,   0.1787,  -0.7289],
         [ -1.1632,   1.7392,   1.5775,  -1.7907,  -0.5079,  -1.8103]],

        [[  0.0549,  -1.9665, -10.8756,  -7.1792,   3.4559,   0.9521],
         [ -0.3971,  -0.6652,  -9.6883,  -8.4108,   2.6582,  -0.3063],
         [ -0.8686,   0.6920,  -8.4500,  -9.6953,   1.8262,  -1.6189]]],
       grad_fn=<ViewBackward0>)

In [199]:
@dataclass(frozen=True)
class TransformerConfig:
    '''Constants used throughout your decoder-only transformer model.

    Instances are immutable (frozen). The first four fields are required;
    the remaining ones have defaults.
    '''

    num_layers: int  # number of stacked decoder blocks
    num_heads: int  # attention heads per block
    vocab_size: int
    hidden_size: int # also embedding dim or d_model
    max_seq_len: int = 5000
    dropout: float = 0.1
    layer_norm_epsilon: float = 1e-05
    # Annotated so that it is a real dataclass field: without the annotation it
    # was a plain class attribute that could not be passed to the constructor.
    device: str = 'cpu'

# Small configuration shared by the cells below; fields that are not listed
# keep the defaults declared on TransformerConfig.
config = TransformerConfig(
    num_layers = 6,
    num_heads = 2,
    vocab_size = 10,
    hidden_size = 10
)

In [200]:
class PositionalEncoding(nn.Module):
    '''
    Fixed sinusoidal positional encoding table.

    For position p and feature pair index i:
        encoding[p, 2i]     = sin(p / 10000 ** (2i / d_model))
        encoding[p, 2i + 1] = cos(p / 10000 ** (2i / d_model))

    The table is stored as a non-persistent buffer: it follows the module
    through .to(device) but is not written to the state_dict.
    '''

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        '''
        d_model: embedding width (odd values are supported)
        dropout: accepted for interface compatibility; not used by this module
        max_len: number of positions to precompute
        '''
        super().__init__()

        positions = t.arange(max_len, dtype=t.float32)
        # 2i for every (sin, cos) pair; has ceil(d_model / 2) entries.
        pair_index = t.arange(0, d_model, 2, dtype=t.float32)
        inv_freq = 1.0 / (10000.0 ** (pair_index / d_model))
        angles = t.outer(positions, inv_freq)  # (max_len, ceil(d_model / 2))

        table = t.zeros((max_len, d_model))
        table[:, 0::2] = t.sin(angles)
        # For odd d_model there is one cos column fewer than sin columns.
        table[:, 1::2] = t.cos(angles[:, : d_model // 2])
        self.register_buffer('encoding', table, persistent=False)

    def forward(self, x: Tensor) -> Tensor:
        '''
        x: Tensor, shape [batch, seq_len, embedding_dim]

        Return: shape [seq_len, d_model] -- the encoding rows for the first
        seq_len positions. x itself is NOT added here; the caller is expected
        to add the result to its embeddings. If seq_len exceeds max_len only
        max_len rows are returned.
        '''
        # Unpacking also enforces that x is 3-dimensional.
        _, seq_len, _ = x.size()
        return self.encoding[:seq_len, :]

In [201]:
from collections import OrderedDict

class MultiLayerPerceptron(nn.Module):
    '''
    Position-wise feed-forward network: Linear -> GELU -> Linear -> Dropout.

    The hidden layer is four times as wide as the input.
    '''

    def __init__(self, d_in: int, d_out: int):
        '''
        d_in: size of the last dimension of the input
        d_out: size of the last dimension of the output
        '''
        super().__init__()
        d_h = d_in * 4
        # Sub-module names are kept stable so state_dict keys do not change.
        self.model = nn.Sequential(OrderedDict([
            ('linear1', nn.Linear(d_in, d_h)),
            ('GELU', nn.GELU()),
            # Project to d_out (the argument used to be ignored in favour of d_in).
            ('linear2', nn.Linear(d_h, d_out)),
            ('dropout', nn.Dropout(p=0.1))
        ]))

    def forward(self, x: t.Tensor):
        '''
        x: shape (..., d_in)
        Return: shape (..., d_out)
        '''
        return self.model(x)
        
class DecoderBlock(nn.Module):
    '''
    One decoder layer: self-attention followed by an MLP, each wrapped in a
    residual connection and followed by layer normalisation.
    '''

    def __init__(self, config: TransformerConfig):
        '''
        config: supplies hidden_size, num_heads and layer_norm_epsilon
        '''
        super().__init__()
        self.attention = MultiheadMaskedAttention(
            hidden_size=config.hidden_size,
            num_heads=config.num_heads
        )
        # Honour the configured epsilon (its default equals nn.LayerNorm's own default).
        # NOTE(review): the same LayerNorm instance, and hence the same learned
        # scale/shift, is applied after both sub-layers -- confirm this is intended.
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = MultiLayerPerceptron(config.hidden_size, config.hidden_size)

    def forward(self, x: t.Tensor):
        '''
        x: shape (batch, seq, hidden_size)
        Return: shape (batch, seq, hidden_size)
        '''
        # Residual connection around attention, then normalise.
        attended = self.layernorm(self.attention(x) + x)
        # Residual connection around the MLP, then normalise again.
        return self.layernorm(self.mlp(attended) + attended)

class DecoderTransformer(nn.Module):
    '''
    Stack of DecoderBlocks with positional encoding, dropout and a final LayerNorm.

    Token embedding and unembedding are identity placeholders for now, so the
    model consumes and produces tensors of shape (batch, seq, hidden_size).
    '''

    def __init__(self, config: TransformerConfig):
        '''
        config: supplies num_layers, hidden_size, dropout, layer_norm_epsilon
                and max_seq_len (plus whatever DecoderBlock reads)
        '''
        super().__init__()
        # Blocks are registered as decoder0, decoder1, ... in execution order.
        blocks = OrderedDict(
            (f'decoder{i}', DecoderBlock(config)) for i in range(config.num_layers)
        )
        self.decoderlayer = nn.Sequential(blocks)
        self.dropout = nn.Dropout(p=config.dropout)
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) # why? come back to this later
        # nn.Identity instead of a lambda keeps the module picklable.
        self.embed = nn.Identity() # tokenizer does nothing at the moment
        self.positional_embedding = PositionalEncoding(config.hidden_size, max_len=config.max_seq_len)
        self.unembed = nn.Identity() # unembed does nothing at the moment

    def forward(self, tokens):
        '''
        tokens: shape (batch, seq, hidden_size); already embedded, since embed is the identity
        Return: shape (batch, seq, hidden_size)
        '''
        embedding = self.embed(tokens)
        # (seq, hidden_size); broadcasts over the batch axis in the addition below.
        pos_embedding = self.positional_embedding(tokens)
        hidden = self.dropout(embedding + pos_embedding)
        hidden = self.decoderlayer(hidden)
        hidden = self.layernorm(hidden)
        return self.unembed(hidden)


In [202]:
from torch.utils.data import Dataset, DataLoader

class TestDataSet(Dataset):
    """A toy dataset of next-step prediction pairs.

    One random sequence of shape (seq_len, hidden_size) is drawn once and
    repeated total_size times, so every sample is identical. For each sample
    the input is the sequence without its last step and the label is the
    sequence without its first step.
    """

    def __init__(self, config, seq_len: int = 25, total_size: int = 1000):
        """
        config: object providing hidden_size and device
        seq_len: length of the underlying sequence (samples have seq_len - 1 steps)
        total_size: number of samples reported by __len__
        """
        super().__init__()
        self.config = config
        self.seq_len = seq_len
        self.total_size = total_size
        # Shape (total_size, seq_len, hidden_size): the same sequence in every row.
        self.text = t.rand((self.seq_len,
                                config.hidden_size)).to(config.device).repeat(self.total_size,1,1)

    def __len__(self):
        return self.total_size

    def __getitem__(self, idx):
        """Return {'text': steps 0..seq_len-2, 'label': steps 1..seq_len-1} for sample idx."""
        label = self.text[idx,1:]
        text = self.text[idx,:-1]
        # Key order matters to consumers that unpack .values(): 'text' first, then 'label'.
        sample = {'text': text, 'label': label}
        return sample

# TODO: a train/validation split could be made with
# torch.utils.data.random_split(dataset, lengths, generator=...)
# Build the toy dataset from the shared config and iterate it in shuffled batches of 2.
ds = TestDataSet(config)
dl = DataLoader(ds, batch_size=2, shuffle=True)

In [203]:
from torch import optim

# NOTE(review): inputs and labels are float tensors of identical shape, so
# CrossEntropyLoss treats the labels as class probabilities over dim 1 (the
# sequence axis here) -- confirm this is the intended objective.
criterion = nn.CrossEntropyLoss()
model = DecoderTransformer(config)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

LOG_EVERY = 200  # number of mini-batches between progress reports

for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(dl, 0):
        # each batch is a dict; read by key instead of relying on dict ordering
        inputs, labels = data['text'], data['label']

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % LOG_EVERY == LOG_EVERY - 1:    # report every LOG_EVERY mini-batches
            # Average over exactly the batches accumulated since the last report.
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / LOG_EVERY:.3f}')
            running_loss = 0.0

print('Finished Training')

Finished Training
