In [None]:
import torch, math
torch.set_printoptions(linewidth=200, sci_mode=False, precision=2)

# Vocabulary size; also the side length of the lookup table built below.
n_tokens = 100

# Sample an n_tokens x n_tokens table of integer-valued floats in
# [0, n_tokens) and keep only its upper triangle (diagonal included).
memory = torch.triu(
    torch.randint_like(torch.ones(n_tokens, n_tokens), 0, n_tokens),
    diagonal=0,
)
# Mirror the upper triangle onto the lower one. The diagonal would be counted
# twice by the sum, so it is subtracted once. The result is a symmetric table:
# memory[a, b] == memory[b, a].
memory = (memory + memory.T - torch.diag(torch.diagonal(memory))).long()
memory

tensor([[60, 48,  2,  ..., 92, 34, 22],
        [48, 71,  8,  ..., 16, 57, 69],
        [ 2,  8, 50,  ..., 73, 23,  7],
        ...,
        [92, 16, 73,  ..., 52, 89, 27],
        [34, 57, 23,  ..., 89, 15,  2],
        [22, 69,  7,  ..., 27,  2, 33]])

In [None]:
import torch
import random
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm


class DecoderLayer(nn.Module):
    """One decoder block: causal self-attention followed by a ReLU feed-forward net.

    Each sub-block adds its output to its input (residual connection) and the
    sum is then passed through a LayerNorm. Intermediate activations are
    recorded in a per-call dictionary so that callers can request any of them
    through the ``history`` argument of :meth:`forward`.
    """

    def __init__(self, d_model, n_head, dim_feedforward, dropout=0.0):
        """Build the sub-modules.

        Args:
            d_model: width of the token representations.
            n_head: number of attention heads (``d_model`` is split across them).
            dim_feedforward: hidden width of the feed-forward net; the
                feed-forward block is skipped in ``forward`` when this is 0.
            dropout: dropout probability used inside ``nn.MultiheadAttention``
                and after the first feed-forward linear layer.
        """
        super().__init__()
        self.d_model = d_model
        self.n_head = n_head
        self.dim_feedforward = dim_feedforward

        self.self_attn = nn.MultiheadAttention(d_model, n_head, batch_first=True, dropout=dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    @property
    def device(self):
        """Device of the first registered parameter."""
        return next(self.parameters()).device

    def forward(self, tgt, skip_feedforward=False, skip_self_attn=False, linear_mask=None, history=None, custom_attention=False, head_mask=None):
        """Run the block on ``tgt`` of shape ``[batch_size, seq_len, d_model]``.

        Args:
            tgt: input activations.
            skip_feedforward: if True, bypass the feed-forward sub-block.
            skip_self_attn: if True, bypass the attention sub-block.
            linear_mask: optional tensor multiplied into the hidden
                feed-forward activations (after ReLU and dropout).
            history: optional key of a recorded activation (for example
                ``'norm1'`` or ``'linear2'``); when given, that activation is
                returned instead of the final output. A key that was not
                recorded during this call raises ``KeyError``.
            custom_attention: use the hand-expanded attention implementation,
                which additionally records ``'q'``, ``'k'``, ``'v'`` and
                ``'attn_output'``.
            head_mask: optional per-head multiplier of shape ``[n_head]``;
                only consulted by the hand-expanded attention path.

        Returns:
            ``(output, attn_heads)``. ``attn_heads`` is currently always
            ``None``; the slot is kept so callers can unpack two values.
        """
        hist = {}
        attn_heads = None
        if not skip_self_attn:
            tgt2 = hist['self_attn_non_residual'] = self.attn_forward(tgt, self.self_attn, custom_attention, hist, head_mask)
            tgt = hist['self_attn'] = tgt + tgt2
        tgt = hist['norm1'] = self.norm1(tgt)
        if self.dim_feedforward > 0 and not skip_feedforward:
            tgt2 = hist['linear1'] = nn.functional.relu(self.linear1(tgt))
            tgt2 = hist['linear1_dropout'] = self.dropout(tgt2)
            if linear_mask is not None:
                tgt2 = tgt2 * linear_mask
            tgt2 = hist['linear2_non_residual'] = self.linear2(tgt2)
            tgt = hist['linear2'] = tgt + tgt2
        tgt = hist['norm2'] = self.norm2(tgt)
        return tgt if history is None else hist[history], attn_heads

    def attn_forward(self, x, attn, custom_attention, history, head_mask=None):
        """Causal multi-head self-attention over ``x``.

        Args:
            x: ``[batch_size, seq_len, d_model]`` input.
            attn: the ``nn.MultiheadAttention`` module whose weights are used.
            custom_attention: if False, delegate to ``attn`` directly; if True,
                run an expanded re-implementation that exposes intermediates.
            history: dict that receives ``'q'``, ``'k'``, ``'v'`` and
                ``'attn_output'`` when ``custom_attention`` is True.
            head_mask: optional ``[n_head]`` tensor multiplied into each head's
                output (expanded path only).

        Returns:
            ``[batch_size, seq_len, d_model]`` attention output (before the
            residual connection).
        """
        seq_len = x.shape[1]
        # Boolean mask, True = "may NOT attend" (strictly future positions).
        # A float mask would be *added* to the attention scores by
        # nn.MultiheadAttention, so a 0/1 float matrix blocks nothing; the
        # boolean form gives both code paths the same causal semantics.
        future = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device),
            diagonal=1,
        )
        if not custom_attention:
            return attn(x, x, x, attn_mask=future)[0]

        head_dim = self.d_model // self.n_head
        x = x.transpose(0, 1)  # [seq_len, batch_size, d_model]
        # This is torch's attention, expanded so individual steps can be
        # inspected or modified. Attention dropout is not applied here.
        proj = F.linear(x, attn.in_proj_weight, attn.in_proj_bias)
        # The packed projection stores q, k and v back to back on the last axis.
        q, k, v = proj.chunk(3, dim=-1)
        # [seq_len, batch_size, d_model] -> [batch_size, n_head, seq_len, head_dim]
        q = q.unflatten(-1, (self.n_head, head_dim)).permute(1, 2, 0, 3)
        k = k.unflatten(-1, (self.n_head, head_dim)).permute(1, 2, 0, 3)
        v = v.unflatten(-1, (self.n_head, head_dim)).permute(1, 2, 0, 3)

        history.update({
            'q': q,
            'k': k,
            'v': v
        })

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)
        # -inf scores become exactly zero weight after the softmax.
        scores = scores.masked_fill(future, float('-inf'))
        weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(weights, v)
        attn_output = attn_output.permute(2, 0, 1, 3).contiguous()  # [seq_len, batch_size, n_head, d_model // n_head]
        # apply head mask
        if head_mask is not None:
            attn_output = attn_output * head_mask[None, None, :, None]

        history['attn_output'] = attn_output

        attn_output = attn_output.flatten(-2, -1)
        attn_output = F.linear(attn_output, attn.out_proj.weight, attn.out_proj.bias)
        return attn_output.transpose(0, 1)

class ToyTransformer(nn.Module):
    """Stack of ``DecoderLayer`` blocks between an embedding and an unembedding.

    The model is trained on triples ``(a, b, memory[a, b])`` drawn from the
    module-level ``memory`` table: it receives the first two tokens, a zero
    vector is appended as a third position, and logits are produced for all
    three positions.
    """

    def __init__(self, n_layers, d_model, n_head, hidden_size, n_tokens, max_len, dropout=0.0):
        """Build the model.

        Args:
            n_layers: number of decoder layers.
            d_model: embedding / residual width.
            n_head: attention heads per layer.
            hidden_size: feed-forward hidden width of every layer.
            n_tokens: vocabulary size.
            max_len: stored on the instance; not otherwise used by this class.
            dropout: dropout probability forwarded to every decoder layer.
        """
        super().__init__()
        self.n_layers = n_layers
        self.d_model = d_model
        self.n_head = n_head
        self.hidden_size = hidden_size
        self.tokens = list(range(n_tokens))
        self.max_len = max_len

        self.embed = nn.Embedding(n_tokens, embedding_dim=d_model)

        self.layers = nn.ModuleList([
            DecoderLayer(d_model=d_model, n_head=n_head, dim_feedforward=hidden_size, dropout=dropout)
            for _ in range(n_layers)
        ])
        self.unembed = nn.Linear(d_model, n_tokens)

    @property
    def device(self):
        """Device of the first registered parameter."""
        return next(self.parameters()).device

    def forward(
        self,
        x,
        skip_feedforward=False,
        skip_self_attn=False,
        return_before_embedding=False,
        linear_mask=None,
        history=None,
        return_attn_weights=False,
        custom_attention=False,
        head_mask=None,
        skip_layers=None):
        """Compute logits for token ids ``x`` of shape ``[batch_size, seq_len]``.

        Args:
            x: integer token ids.
            skip_feedforward: forwarded to every layer.
            skip_self_attn: forwarded to every layer.
            return_before_embedding: return the final hidden states instead of
                logits.
            linear_mask: forwarded to every layer.
            history: name of an intermediate activation; when given, that
                activation of the first executed layer is returned immediately.
            return_attn_weights: additionally return the second value produced
                by the last executed layer (``None`` if no layer ran).
            custom_attention: forwarded to every layer; forced to True when a
                ``head_mask`` is supplied.
            head_mask: forwarded to every layer.
            skip_layers: indices of layers to leave out.

        Returns:
            Logits of shape ``[batch_size, seq_len + 1, n_tokens]``, or one of
            the alternatives described above.
        """
        if skip_layers is None:
            skip_layers = []
        if head_mask is not None:
            custom_attention = True
        tgt = self.embed(x)
        # Append one all-zero position; its output is the prediction slot.
        tgt = F.pad(tgt, (0, 0, 0, 1))  # [batch_size, seq_len + 1, d_model]
        # Defined up front so return_attn_weights works even if every layer is skipped.
        attn_heads = None
        for i, layer in enumerate(self.layers):
            if i in skip_layers:
                continue
            tgt, attn_heads = layer(
                tgt,
                skip_feedforward=skip_feedforward,
                skip_self_attn=skip_self_attn,
                linear_mask=linear_mask,
                history=history,
                custom_attention=custom_attention,
                head_mask=head_mask)
            if history is not None:
                return tgt
        if return_before_embedding:
            return tgt
        x = self.unembed(tgt)
        if return_attn_weights:
            return x, attn_heads
        return x

    def train(self, lr=1e-3, batch_size=128, n_epochs=1000, eval_every=1000):
        """Optimise the model on freshly generated batches.

        This method shadows ``nn.Module.train(mode)``. ``nn.Module.eval()``
        calls ``self.train(False)``, so a boolean first argument is treated as
        a mode switch and delegated to ``nn.Module.train``; anything else
        starts an optimisation run. Calling ``train()`` with no arguments
        still runs the optimisation loop with the default hyper-parameters.

        Args:
            lr: Adam learning rate (or a bool, see above).
            batch_size: number of triples per optimisation step.
            n_epochs: number of optimisation steps.
            eval_every: print mean loss and a sampled accuracy every this many
                steps.
        """
        if isinstance(lr, bool):
            return super().train(lr)

        optimizer = optim.Adam(self.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()

        total_loss = 0
        for i in tqdm(range(n_epochs)):
            batch = self.generate_data(batch_size)
            optimizer.zero_grad()
            inputs = batch[:, :-1]
            output = self(inputs)
            # Every position is supervised: the two given tokens and the looked-up one.
            loss = criterion(output.reshape(-1, len(self.tokens)), batch.reshape(-1))
            loss.backward()
            total_loss += loss.item()
            optimizer.step()

            if (i + 1) % eval_every == 0:
                print('loss: ', total_loss / eval_every)
                total_loss = 0
                samples = 10000
                # Evaluate in eval mode and without building an autograd graph,
                # then restore whatever mode the model was in.
                was_training = self.training
                super().train(False)
                with torch.no_grad():
                    data = self.generate_data(samples)
                    output = (self(data[:,:-1])[:,-1,:].argmax(dim=-1))
                super().train(was_training)
                print('Accuracy: ', output.eq(data[:,-1]).sum().item() / samples)

    def generate_data(self, batch_size):
        """Sample ``batch_size`` triples ``(a, b, memory[a, b])``.

        Reads the module-level ``memory`` table. Indices are drawn from the
        table's own size so they can never fall outside it.

        Returns:
            Integer tensor of shape ``[batch_size, 3]`` on the model's device.
        """
        table_size = memory.shape[0]
        random_indices = torch.randint(0, table_size, (batch_size, 2))  # [batch_size, 2]
        next_tokens = memory[random_indices[:, 0], random_indices[:, 1]].unsqueeze(1)  # [batch_size, 1]
        tensor = torch.cat([random_indices, next_tokens], dim=1)
        return tensor.to(self.device)



In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Train one model per depth (1 to 12 layers) and report its accuracy on the
# third (looked-up) token of freshly sampled triples.
for num_layers in range(1, 13):
    print(f'==================== {num_layers} layers ====================')
    model = ToyTransformer(n_layers=num_layers, d_model=16, n_head=4, hidden_size=64, n_tokens=n_tokens, max_len=3, dropout=0.0).to(device)
    model.train(lr=1e-3, n_epochs=20000, batch_size=300, eval_every=4000)
    samples = 1000
    # Pure evaluation: no gradients are needed, so skip building the autograd graph.
    with torch.no_grad():
        data = model.generate_data(samples)
        output = (model(data[:,:-1])[:,-1,:].argmax(dim=-1))
    print('Accuracy: ', output.eq(data[:,-1]).sum().item() / samples)




 20%|██        | 4046/20000 [00:15<00:56, 281.02it/s]

loss:  1.631412913262844
Accuracy:  0.0651


 40%|████      | 8036/20000 [00:29<00:40, 293.74it/s]

loss:  1.3738710118830204
Accuracy:  0.0954


 60%|██████    | 12049/20000 [00:44<00:27, 293.25it/s]

loss:  1.3161240410506725
Accuracy:  0.1189


 80%|████████  | 16034/20000 [00:59<00:14, 272.01it/s]

loss:  1.280550512045622
Accuracy:  0.1433


100%|██████████| 20000/20000 [01:13<00:00, 270.54it/s]


loss:  1.2555710887610911
Accuracy:  0.1532
Accuracy:  0.163


 20%|██        | 4030/20000 [00:24<01:27, 182.38it/s]

loss:  1.5372524098157883
Accuracy:  0.1433


 40%|████      | 8025/20000 [00:48<01:04, 185.80it/s]

loss:  1.147566702529788
Accuracy:  0.2398


 60%|██████    | 12027/20000 [01:11<00:59, 133.78it/s]

loss:  1.0312764386683702
Accuracy:  0.2882


 80%|████████  | 16031/20000 [01:34<00:22, 179.27it/s]

loss:  0.9688733783811331
Accuracy:  0.3074


100%|██████████| 20000/20000 [01:57<00:00, 169.77it/s]


loss:  0.9274852876663208
Accuracy:  0.3333
Accuracy:  0.3


 20%|██        | 4025/20000 [00:31<01:59, 133.71it/s]

loss:  1.4659952946603299
Accuracy:  0.2091


 40%|████      | 8023/20000 [01:04<01:31, 130.26it/s]

loss:  0.9536597099006175
Accuracy:  0.3684


 60%|██████    | 12022/20000 [01:36<00:59, 133.80it/s]

loss:  0.7983411356806756
Accuracy:  0.4355


 80%|████████  | 16010/20000 [02:08<00:42, 94.43it/s]

loss:  0.7202694138735533
Accuracy:  0.4673


100%|██████████| 20000/20000 [02:40<00:00, 124.33it/s]


loss:  0.672420081987977
Accuracy:  0.4703
Accuracy:  0.448


 20%|██        | 4010/20000 [00:41<02:37, 101.54it/s]

loss:  1.396420836225152
Accuracy:  0.2972


 40%|████      | 8013/20000 [01:22<01:55, 103.93it/s]

loss:  0.8341145663261413
Accuracy:  0.4439


 60%|██████    | 12014/20000 [02:03<01:18, 102.16it/s]

loss:  0.6473211811929941
Accuracy:  0.5288


 80%|████████  | 16009/20000 [02:44<00:40, 99.58it/s] 

loss:  0.5527846050560474
Accuracy:  0.5788


100%|██████████| 20000/20000 [03:24<00:00, 97.80it/s] 


loss:  0.4952553524374962
Accuracy:  0.6159
Accuracy:  0.627


 20%|██        | 4005/20000 [00:48<04:20, 61.39it/s]

loss:  1.316667012169957
Accuracy:  0.3759


 40%|████      | 8007/20000 [01:37<02:21, 84.54it/s]

loss:  0.6766492197215557
Accuracy:  0.558


 60%|██████    | 12012/20000 [02:27<01:32, 85.92it/s]

loss:  0.4851742291003466
Accuracy:  0.6307


 80%|████████  | 16015/20000 [03:15<00:46, 84.81it/s]

loss:  0.3895584967508912
Accuracy:  0.6951


100%|██████████| 20000/20000 [04:05<00:00, 81.60it/s]


loss:  0.3339533253759146
Accuracy:  0.7039
Accuracy:  0.693


 20%|██        | 4013/20000 [00:57<03:36, 73.98it/s]

loss:  1.2563081759661436
Accuracy:  0.4466


 40%|████      | 8006/20000 [01:55<03:51, 51.80it/s]

loss:  0.5224560948610306
Accuracy:  0.6158


 60%|██████    | 12007/20000 [02:53<01:52, 70.76it/s]

loss:  0.32827521914616226
Accuracy:  0.7672


 80%|████████  | 16011/20000 [03:50<00:54, 73.36it/s]

loss:  0.2486867581754923
Accuracy:  0.785


100%|██████████| 20000/20000 [04:48<00:00, 69.41it/s]


loss:  0.20571891766786576
Accuracy:  0.7953
Accuracy:  0.816


 20%|██        | 4011/20000 [01:06<04:12, 63.30it/s]

loss:  1.1848834337890148
Accuracy:  0.5291


 40%|████      | 8008/20000 [02:13<03:09, 63.25it/s]

loss:  0.41961691026389597
Accuracy:  0.773


 60%|██████    | 12007/20000 [03:19<02:06, 62.96it/s]

loss:  0.23697573203220965
Accuracy:  0.8332


 80%|████████  | 16006/20000 [04:26<01:04, 62.30it/s]

loss:  0.17684553435631095
Accuracy:  0.8519


100%|██████████| 20000/20000 [05:32<00:00, 60.07it/s]


loss:  0.14684386380389333
Accuracy:  0.9335
Accuracy:  0.935


 20%|██        | 4003/20000 [01:15<06:52, 38.78it/s]

loss:  1.1455138104557991
Accuracy:  0.5752


 40%|████      | 8006/20000 [02:32<03:40, 54.28it/s]

loss:  0.32527873638272287
Accuracy:  0.8202


 60%|██████    | 12004/20000 [03:48<02:31, 52.84it/s]

loss:  0.17135150873474778
Accuracy:  0.873


 80%|████████  | 16007/20000 [05:03<01:13, 54.38it/s]

loss:  0.13314629878662526
Accuracy:  0.9462


100%|██████████| 20000/20000 [06:19<00:00, 52.63it/s]


loss:  0.11465606991783715
Accuracy:  0.8864
Accuracy:  0.88


 20%|██        | 4003/20000 [01:25<05:37, 47.36it/s]

loss:  1.0793773514777423
Accuracy:  0.6543


 40%|████      | 8006/20000 [02:51<05:52, 34.02it/s]

loss:  0.2517672010529786
Accuracy:  0.889


 60%|██████    | 12005/20000 [04:16<02:42, 49.20it/s]

loss:  0.12148481867276133
Accuracy:  1.0


 80%|████████  | 16005/20000 [05:41<01:22, 48.38it/s]

loss:  0.10057276537118014
Accuracy:  0.8864


100%|██████████| 20000/20000 [07:05<00:00, 46.96it/s]


loss:  0.05848035830832669
Accuracy:  1.0
Accuracy:  1.0


 20%|██        | 4006/20000 [01:34<06:02, 44.10it/s]

loss:  1.0363577803596855
Accuracy:  0.7186


 40%|████      | 8006/20000 [03:08<04:32, 44.06it/s]

loss:  0.2088630115222186
Accuracy:  0.901


 60%|██████    | 12003/20000 [04:42<03:20, 39.91it/s]

loss:  0.09540297432860825
Accuracy:  0.8412


 80%|████████  | 16005/20000 [06:17<01:32, 43.14it/s]

loss:  0.08047008900373476
Accuracy:  1.0


100%|██████████| 20000/20000 [07:54<00:00, 42.18it/s]


loss:  0.10299747950305754
Accuracy:  0.9246
Accuracy:  0.94


 20%|██        | 4006/20000 [01:47<06:57, 38.28it/s]

loss:  0.9853978796005249
Accuracy:  0.7573


 40%|████      | 8004/20000 [03:35<05:21, 37.29it/s]

loss:  0.16713561463262885
Accuracy:  0.9657


 60%|██████    | 12004/20000 [05:23<03:27, 38.56it/s]

loss:  0.09698200636162073
Accuracy:  0.9225


 80%|████████  | 16005/20000 [07:10<01:43, 38.50it/s]

loss:  0.017996264994348168
Accuracy:  1.0


100%|██████████| 20000/20000 [08:55<00:00, 37.33it/s]


loss:  0.15527878299837175
Accuracy:  1.0
Accuracy:  1.0


 20%|██        | 4005/20000 [01:50<07:34, 35.17it/s]

loss:  0.9647033274360001
Accuracy:  0.7139


 40%|████      | 8002/20000 [03:47<08:46, 22.77it/s]

loss:  0.15997169777564704
Accuracy:  0.9551


 60%|██████    | 12005/20000 [05:42<03:54, 34.12it/s]

loss:  0.01854490770769189
Accuracy:  1.0


 80%|████████  | 16003/20000 [07:39<01:59, 33.44it/s]

loss:  0.17118029225134523
Accuracy:  0.8158


100%|██████████| 20000/20000 [09:34<00:00, 34.81it/s]

loss:  0.006669308490279946
Accuracy:  1.0
Accuracy:  1.0



