In [93]:
import torch as t
import torch.nn as nn
from torch import Tensor
from fancy_einsum import einsum
from einops import rearrange, repeat
import math
from dataclasses import dataclass

def singlehead_attention(Q: t.Tensor, K: t.Tensor, V: t.Tensor):
    '''
    Should return the results of self-attention (see the "Self-Attention in Detail" section of the Illustrated Transformer).

    With this function, you can ignore masking.

    Q: shape (b, s, c)
    K: shape (b, s, c)
    V: shape (b, s, c)
    b = batch
    s = seq_len
    c = dims

    Return: shape (b s s)
    '''
    d_k = math.sqrt(Q.shape[-1])
    scaled_dot_prod: Tensor = einsum('b s1 c, b s2 c -> b s1 s2', Q, K) / d_k
    return scaled_dot_prod.softmax(dim=-1) @ V

def masked_attention(Q: t.Tensor, K: t.Tensor, V: t.Tensor, mask: t.Tensor):
    '''
    Q: shape (b, s, c)
    K: shape (b, s, c)
    V: shape (b, s, c)
    mask: shape (b, s, s)
    b = batch
    s = seq_len
    c = dims

    Return: shape (b s s)
    '''
    d_k = math.sqrt(Q.shape[-1])
    scaled_dot_prod: Tensor = einsum('b s1 c, b s2 c -> b s1 s2', Q, K) / d_k
    if mask is not None:
        scaled_dot_prod = scaled_dot_prod.masked_fill(mask == 0, -1e9)
    return scaled_dot_prod.softmax(dim=-1) @ V

def multihead_attention(Q: t.Tensor, K: t.Tensor, V: t.Tensor, n_heads: int):
    '''
    Q: shape (b, s1, e)
    K: shape (b, s2, e)
    V: shape (b, s2, e)

    e = nheads * h
    b = batch
    s = seq_len
    h = hidden

    Return: shape (b s e)
    '''

    assert Q.shape[-1] % n_heads == 0
    assert K.shape[-1] % n_heads == 0
    assert V.shape[-1] % n_heads == 0
    assert K.shape[-1] == V.shape[-1]

    # mask for autoencoder
    mask = t.triu(t.ones(Q.shape[1], K.shape[1]), diagonal=1).bool()

    Q = rearrange(Q, 'b s (nheads h) -> b nheads s h', nheads=n_heads)
    K = rearrange(K, 'b s (nheads h) -> b nheads s h', nheads=n_heads)
    V = rearrange(V, 'b s (nheads h) -> b nheads s h', nheads=n_heads)

    scaled_dot_prod = einsum('b nheads s1 h, b nheads s2 h -> b nheads s2 s1', K, Q) / math.sqrt(Q.shape[-1])
    # if mask is not None:
    #     if mask.dim() == 2:
    #         mask = repeat(mask, 's1 s2 -> b s1 s2', b=Q.shape[0])
    #     else:
    #         mask = mask.unsqueeze(1)
    #     # print(mask.shape, scaled_dot_prod.shape)
    #     scaled_dot_prod = scaled_dot_prod.masked_fill(mask == 1, -1e9)
    mask_filter = t.triu(t.full_like(scaled_dot_prod, -t.inf), 1)
    scaled_dot_prod += mask_filter
    attention_probs = scaled_dot_prod.softmax(dim=-1)
    attention_vals = einsum('b nheads s1 s2, b nheads s2 c -> b nheads s1 c', attention_probs, V)
    attention = rearrange(attention_vals, 'b nheads s c -> b s (nheads c)')
    return attention

class MultiheadMaskedAttention(nn.Module):

    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.W_QKV = nn.Linear(hidden_size, hidden_size * 3)
        self.W_O = nn.Linear(hidden_size, hidden_size)
        self.num_heads = num_heads

    def forward(self, x: t.Tensor, mask=None) -> t.Tensor:
        '''
        x: shape (batch, seq, hidden_size)
        Return: shape (batch, seq, hidden_size)
        '''
        Q, K, V = self.W_QKV(x).chunk(3, dim=-1)
        att = multihead_attention(Q, K, V, self.num_heads)
        return self.W_O(att)


In [94]:
# test single_head_attention
Q = t.arange(2 * 7 * 3).reshape(2, 7, 3).type(t.float32)
K = Q * 0.5
V = Q * 0.8
print(singlehead_attention(Q,K,V))

tensor([[[14.2070, 15.0070, 15.8070],
         [14.3999, 15.1999, 15.9999],
         [14.4000, 15.2000, 16.0000],
         [14.4000, 15.2000, 16.0000],
         [14.4000, 15.2000, 16.0000],
         [14.4000, 15.2000, 16.0000],
         [14.4000, 15.2000, 16.0000]],

        [[31.2000, 32.0000, 32.8000],
         [31.2000, 32.0000, 32.8000],
         [31.2000, 32.0000, 32.8000],
         [31.2000, 32.0000, 32.8000],
         [31.2000, 32.0000, 32.8000],
         [31.2000, 32.0000, 32.8000],
         [31.2000, 32.0000, 32.8000]]])


In [95]:
Q = t.linspace(0, 10, 2 * 10 * 1).reshape(2, 10, 1)
K = t.linspace(5, 20, 2 * 10 * 1).reshape(2, 10, 1)
V = t.linspace(15, 2, 2 * 10 * 1).reshape(2, 10, 1)
# b = 2, s = 5, c = 4
multihead_attention(Q, K, V, n_heads=1)

tensor([[[15.0000],
         [14.5878],
         [13.9747],
         [13.2046],
         [12.4225],
         [11.6769],
         [10.9564],
         [10.2500],
         [ 9.5519],
         [ 8.8588]],

        [[ 8.1579],
         [ 7.4807],
         [ 6.7942],
         [ 6.1084],
         [ 5.4231],
         [ 4.7382],
         [ 4.0535],
         [ 3.3690],
         [ 2.6846],
         [ 2.0003]]])

In [96]:
t.manual_seed(420)
m = MultiheadMaskedAttention(6, 2)
x = t.linspace(0, 42, 2 * 3 * 6).reshape(2, 3, 6)
m(x)

tensor([[[ -0.7193,   0.4614,   0.4117,  -0.5813,   0.2754,  -0.5745],
         [ -0.7746,   0.6206,   0.5520,  -0.7370,   0.1787,  -0.7289],
         [ -1.1632,   1.7392,   1.5775,  -1.7907,  -0.5079,  -1.8103]],

        [[  0.0549,  -1.9665, -10.8756,  -7.1792,   3.4559,   0.9521],
         [ -0.3971,  -0.6652,  -9.6883,  -8.4108,   2.6582,  -0.3063],
         [ -0.8686,   0.6920,  -8.4500,  -9.6953,   1.8262,  -1.6189]]],
       grad_fn=<ViewBackward0>)

In [97]:
@dataclass(frozen=True)
class TransformerConfig:
    '''Constants used throughout your decoder-only transformer model.'''

    num_layers: int
    num_heads: int
    vocab_size: int
    hidden_size: int # also embedding dim or d_model
    max_seq_len: int = 5000 
    dropout: float = 0.1
    layer_norm_epsilon: float = 1e-05
    device = 'cpu'

config = TransformerConfig(
    num_layers = 6,
    num_heads = 4,
    vocab_size = 10,
    hidden_size = 96
)

In [98]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()

        d = d_model
        L = max_len
        D = d / 2

        angles = t.outer(t.arange(L), 1 / 10000 ** (2 * t.arange(D) / D))

        array_2d = t.zeros((L, d))
        array_2d[:, ::2] = t.sin(angles)
        array_2d[:, 1::2] = t.cos(angles)
        self.encoding = array_2d

    def forward(self, x: Tensor) -> Tensor:
        '''
        x: Tensor, shape [batch, seq_len, embedding_dim]
        ''' 
        batch_size, seq_len, embedding_dim = x.size()
        return self.encoding[:seq_len, :]

In [99]:
# from collections import OrderedDict

# class MultiLayerPerceptron(nn.Module):  

#     def __init__(self, d_in: int, d_out: int):
#         super().__init__()
#         d_h = d_in * 4
#         self.model = nn.Sequential(OrderedDict([
#             ('linear1', nn.Linear(d_in, d_h)),
#             ('GELU', nn.GELU()),
#             ('linear2', nn.Linear(d_h, d_in)),   
#             ('dropout', nn.Dropout(p=0.1))
#         ]))

#     def forward(self, x: t.Tensor):
#         return self.model(x)
        
# class DecoderBlock(nn.Module):

#     def __init__(self, config: TransformerConfig):
#         super().__init__()
#         self.attention = MultiheadMaskedAttention(
#             hidden_size=config.hidden_size,
#             num_heads=config.num_heads
#         )
#         self.layernorm1 = nn.LayerNorm(config.hidden_size)
#         self.layernorm2 = nn.LayerNorm(config.hidden_size)
#         self.mlp = MultiLayerPerceptron(config.hidden_size, config.hidden_size)
    
#     def forward(self, x: t.Tensor):
#         att = self.attention(x) + x
#         h1 = self.layernorm1(att)
#         h2 = self.layernorm2(self.mlp(h1) + h1)
#         return h2

# class DecoderTransformer(nn.Module):

#     def __init__(self, config: TransformerConfig):
#         super().__init__()
#         decoders = [DecoderBlock(config) for i in range(config.num_layers)]
#         names = ['decoder' + str(i) for i in range(config.num_layers)]
#         self.decoderlayer = nn.Sequential(OrderedDict(zip(names, decoders)))
#         self.dropout = nn.Dropout(p=config.dropout)
#         self.layernorm = nn.LayerNorm(config.hidden_size) # why? come back to this later
#         self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
#         self.positional_embedding = PositionalEncoding(config.hidden_size)
#         self.last_linear = nn.Linear(config.hidden_size, config.vocab_size)

#     def forward(self, tokens):
#         embedding = self.embed(tokens) # (seq_len) -> (seq_len, embedding)
#         pos_embedding = self.positional_embedding(embedding)
#         final_embedding = embedding + pos_embedding
#         a = self.dropout(final_embedding)
#         b = self.decoderlayer(a)
#         c = self.layernorm(b) @ self.embed.weight.T
#         return c


In [100]:
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn.functional as F

class TestDataSet(Dataset):
    """A toy dataset to train a model to predict
     a random sequence of tokens."""
    
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.seq_len = 25
        self.total_size = 1000
        self.text = t.randint(0,config.vocab_size, (self.total_size, self.seq_len)).to(config.device)
        
    def __len__(self):
        return self.total_size

    def __getitem__(self, idx):
        label = self.text[idx,1:]
        text = self.text[idx,:-1]
        return (text, label)

class ReversedNumbers(Dataset):
    def __init__(self, vocab_size: int, seq_len: int, datasize: int):
        self.seqs = t.randint(0, vocab_size, (datasize, seq_len))

    def __len__(self):
            return len(self.seqs)

    def __getitem__(self, idx):
            input = self.seqs[idx]
            target = t.flip(input, dims=(0,))
            return (input, target)

# class ShakespeareDataset(Dataset):
#     def __init__(self, config):
#         self.data = open('shakespeare.txt', 'r').read()
#         self.config = config
#         chars = sorted(set(self.data))
#         self.vocab_size = len(chars)
#         self.char_to_idx = {ch: i for i, ch in enumerate(chars)}
#         self.idx_to_char = {i: ch for i, ch in enumerate(chars)}
#         print('data has %d characters, %d unique.' % (len(self.data), self.vocab_size))

#     def __getitem__(self, index):
#         x = self.char_to_idx[self.data[index]]
#         x = t.tensor([x])
#         x = F.one_hot(x, num_classes=self.vocab_size)
#         x = x.type(t.FloatTensor)
#         t = self.char_to_idx[self.data[index + (index < (self.__len__() - 1))]]
#         t = t.tensor([t])
#         return (x.to(self.config.device), t.to(self.config.device))

#     def __len__(self):
#         return len(self.data)

#     def params(self):
#         return self.vocab_size, self.char_to_idx, self.idx_to_char

# torch.utils.data.random_split(dataset, lengths, generator=<torch._C.Generator object>)
# dummy_ds = TestDataSet(config)
# dummy_dl = DataLoader(dummy_ds, batch_size=64, shuffle=True)
nums_ds = ReversedNumbers(vocab_size=10, seq_len=6, datasize=10000)
train_ds, val_ds = random_split(nums_ds, [8000, 2000])
nums_dl = DataLoader(train_ds, batch_size=64, shuffle=True)
nums_tl = DataLoader(val_ds, batch_size=64)

In [101]:
from torch import optim
from impl.transformer_modules import DecoderTransformer

criterion = nn.CrossEntropyLoss()
model = DecoderTransformer(config)
optimizer = optim.Adam(model.parameters(), lr=0.0001)

accuracy_list = []

for epoch in range(50):  # loop over the dataset multiple times
    accuracy = 0
    total = 0

    running_loss = 0.0
    for i, data in enumerate(nums_dl, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(
            rearrange(outputs, 'batch seq vocab -> batch vocab seq'),
            labels
        )
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 20 == 19:    # print every 20 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.5f}')
            running_loss = 0.0

        
        
    for (x, y) in nums_tl:

        x = x.to(config.device)
        y = y.to(config.device)

        y_hat = model(x)
        y_predictions = y_hat.argmax(2)
        accuracy += (y_predictions == y).sum().item()
        total += y.size(0) * 6

        accuracy_list.append(accuracy/total)
    print(f'accuracy: {accuracy/total:.3f}')

print('Finished Training')

[1,    20] loss: 0.25672
[1,    40] loss: 0.04750
[1,    60] loss: 0.03223
[1,    80] loss: 0.03015
[1,   100] loss: 0.02905
[1,   120] loss: 0.02844
accuracy: 0.098
[2,    20] loss: 0.02738
[2,    40] loss: 0.02679
[2,    60] loss: 0.02639
[2,    80] loss: 0.02627
[2,   100] loss: 0.02572
[2,   120] loss: 0.02523
accuracy: 0.100
[3,    20] loss: 0.02509
[3,    40] loss: 0.02483
[3,    60] loss: 0.02459
[3,    80] loss: 0.02451
[3,   100] loss: 0.02433
[3,   120] loss: 0.02413
accuracy: 0.113
[4,    20] loss: 0.02397
[4,    40] loss: 0.02402
[4,    60] loss: 0.02398
[4,    80] loss: 0.02380
[4,   100] loss: 0.02377
[4,   120] loss: 0.02358
accuracy: 0.123
[5,    20] loss: 0.02349
[5,    40] loss: 0.02344
[5,    60] loss: 0.02333
[5,    80] loss: 0.02330
[5,   100] loss: 0.02333
[5,   120] loss: 0.02321
accuracy: 0.134
[6,    20] loss: 0.02311
[6,    40] loss: 0.02304
[6,    60] loss: 0.02301
[6,    80] loss: 0.02286
[6,   100] loss: 0.02285
[6,   120] loss: 0.02288
accuracy: 0.174
[7, 

In [102]:
print(inputs.shape)
print(labels.shape)
print(outputs.shape)

torch.Size([64, 6])
torch.Size([64, 6])
torch.Size([64, 6, 10])


In [103]:
print(outputs.shape)
print(outputs.transpose(1,2).shape)

torch.Size([64, 6, 10])
torch.Size([64, 10, 6])


In [104]:
arr = t.randint(1,10, (1, 6))
print(arr)
model(arr).argmax(dim=-1)

tensor([[1, 3, 2, 5, 7, 9]])


tensor([[8, 4, 9, 2, 3, 1]])

In [105]:
for x, y in nums_dl:
    print(x[0])
    print(y[0])
    break

tensor([7, 9, 8, 5, 5, 6])
tensor([6, 5, 5, 8, 9, 7])
