In [2]:
import torch, math
torch.set_printoptions(precision=2, sci_mode=False, linewidth=200)
def generate_memory(n_tokens):
    # memory = torch.randint(0, n_tokens, (n_tokens, n_tokens))
    memory = torch.triu(torch.ones(n_tokens, n_tokens), diagonal=0)
    memory = memory * torch.randint_like(memory, 0, n_tokens)
    memory = memory + memory.T - torch.diag(memory.diagonal())
    memory = memory.long()
    return memory

generate_memory(10)

tensor([[1, 0, 5, 2, 7, 0, 5, 8, 2, 5],
        [0, 0, 1, 4, 0, 0, 0, 5, 6, 2],
        [5, 1, 8, 9, 4, 9, 1, 0, 2, 2],
        [2, 4, 9, 8, 4, 3, 8, 4, 9, 6],
        [7, 0, 4, 4, 8, 1, 6, 4, 0, 1],
        [0, 0, 9, 3, 1, 0, 8, 6, 2, 3],
        [5, 0, 1, 8, 6, 8, 6, 2, 4, 8],
        [8, 5, 0, 4, 4, 6, 2, 1, 0, 0],
        [2, 6, 2, 9, 0, 2, 4, 0, 3, 6],
        [5, 2, 2, 6, 1, 3, 8, 0, 6, 9]])

In [8]:
import torch
import random
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm


class DecoderLayer(nn.Module):
    def __init__(self, d_model, n_head, dim_feedforward, dropout=0.0):
        super().__init__()
        self.d_model = d_model
        self.n_head = n_head
        self.dim_feedforward = dim_feedforward

        self.self_attn = nn.MultiheadAttention(d_model, n_head, batch_first=True, dropout=dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, tgt, skip_feedforward=False, skip_self_attn=False, linear_mask=None, history=None, custom_attention=False, head_mask=None):
        hist = {}
        attn_heads = None
        if not skip_self_attn:
            # mask = torch.triu(torch.ones(tgt.shape[1], tgt.shape[1]), diagonal=1).bool().to(self.device)
            # tgt2, attn_heads = hist['self_attn_non_residual'] = self.self_attn(tgt, tgt, tgt, attn_mask=mask, average_attn_weights=False)
            tgt2 = hist['self_attn_non_residual'] = self.attn_forward(tgt, self.self_attn, custom_attention, hist, head_mask)
            tgt = hist['self_attn'] = tgt + tgt2
        tgt = hist['norm1'] = self.norm1(tgt)
        if self.dim_feedforward > 0 and not skip_feedforward:
            tgt2 = hist['linear1'] = nn.functional.relu(self.linear1(tgt))
            tgt2 = hist['linear1_dropout'] = self.dropout(tgt2)
            if linear_mask is not None:
                tgt2 = tgt2 * linear_mask
            tgt2 = hist['linear2_non_residual'] = self.linear2(tgt2)
            tgt = hist['linear2'] = tgt + tgt2
        tgt = hist['norm2'] = self.norm2(tgt)
        return tgt if history is None else hist[history], attn_heads
    
    def attn_forward(self, x, attn, custom_attention, history, head_mask=None):
        attn_mask = torch.tril(torch.ones(x.shape[1], x.shape[1]), diagonal=0).to(self.device)
        if not custom_attention:
            return attn(x, x, x, attn_mask=attn_mask)[0]

        batch_size = x.shape[0]
        x = x.transpose(0, 1)
        # this is just torch's attention but expanded so we can modify it
        proj = F.linear(x, attn.in_proj_weight, attn.in_proj_bias)
        proj = proj.unflatten(-1, (3, self.d_model)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
        q, k, v = proj[0], proj[1], proj[2]
        q = q.unflatten(-1, (self.n_head, self.d_model // self.n_head)).permute(1, 2, 0, 3)
        k = k.unflatten(-1, (self.n_head, self.d_model // self.n_head)).permute(1, 2, 0, 3)
        v = v.unflatten(-1, (self.n_head, self.d_model // self.n_head)).permute(1, 2, 0, 3)
        
        history.update({
            'q': q,
            'k': k,
            'v': v
        })

        attn_mask = attn_mask.masked_fill(attn_mask == False, float('-inf'))
        
        attn_output = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1)) + attn_mask
        attn_output = F.softmax(attn_output, dim=-1)
        attn_output = torch.matmul(attn_output, v)
        attn_output = attn_output.permute(2, 0, 1, 3).contiguous()  # [seq_len, batch_size, n_head, d_model // n_head]
        # apply head mask
        if head_mask is not None:
            attn_output = attn_output * head_mask[None, None, :, None]
        
        history['attn_output'] = attn_output

        attn_output = attn_output.flatten(-2, -1)
        attn_output = F.linear(attn_output, attn.out_proj.weight, attn.out_proj.bias)
        return attn_output.transpose(0, 1)

class ToyTransformer(nn.Module):
    def __init__(self, n_layers, d_model, n_head, hidden_size, n_tokens, max_len, dropout=0.0):
        super().__init__()
        self.n_layers = n_layers
        self.d_model = d_model
        self.n_head = n_head
        self.hidden_size = hidden_size
        self.tokens = list(range(n_tokens))
        self.max_len = max_len

        self.embed = nn.Embedding(n_tokens, embedding_dim=d_model)

        self.layers = nn.ModuleList([
            DecoderLayer(d_model=d_model, n_head=n_head, dim_feedforward=hidden_size)
            for _ in range(n_layers)
        ])
        self.unembed = nn.Linear(d_model, n_tokens)

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(
        self,
        x,
        skip_feedforward=False,
        skip_self_attn=False,
        return_before_embedding=False,
        linear_mask=None,
        history=None,
        return_attn_weights=False,
        custom_attention=False,
        head_mask=None):
        if head_mask is not None:
            custom_attention = True
        tgt = self.embed(x)
        tgt = F.pad(tgt, (0, 0, 0, 1))  # [batch_size, seq_len + 1, d_model]
        for layer in self.layers:
            tgt, attn_heads = layer(
                tgt,
                skip_feedforward=skip_feedforward,
                skip_self_attn=skip_self_attn,
                linear_mask=linear_mask,
                history=history,
                custom_attention=custom_attention,
                head_mask=head_mask)
            if history is not None:
                return tgt
        if return_before_embedding:
            return tgt
        x = self.unembed(tgt)
        if return_attn_weights:
            return x, attn_heads
        return x

    def train(self, memory, episode, episode_length, lr=1e-3, batch_size=128, n_epochs=1000, verbose=True):
        optimizer = optim.Adam(self.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()

        for _ in tqdm(range(n_epochs)) if verbose else range(n_epochs):
            batch = self.generate_data(batch_size, memory, episode, episode_length)
            optimizer.zero_grad()
            input = batch[:, :-1]
            output = self(input)
            loss = criterion(output.reshape(-1, len(self.tokens)), batch.reshape(-1))

            loss.backward()
            optimizer.step()

        if verbose:
            print('loss: ', loss.item())

    def generate_data(self, batch_size, memory, episode, episode_length):
        random_indices1 = torch.randint(episode * episode_length, (episode + 1) * episode_length, (batch_size, 1))  # [batch_size, 1]
        random_indices2 = torch.randint(0, memory.shape[1], (batch_size, 1))  # [batch_size, 1]
        random_indices = torch.cat([random_indices1, random_indices2], dim=1)
        
        next_tokens = memory[random_indices[:, 0], random_indices[:, 1]].unsqueeze(1)  # [batch_size, 1]
        tensor = torch.cat([random_indices, next_tokens], dim=1)
        return tensor.to(self.device)



In [13]:
hidden_size = 64
n_tokens = 320
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model = ToyTransformer(n_layers=3, d_model=16, n_head=4, hidden_size=hidden_size, n_tokens=n_tokens, max_len=3, dropout=0.1).to(device)
memory = generate_memory(n_tokens)
history = []
episode_length = 320 // 64
for episode in range(64):
    for epoch in range(999):
        model.train(memory, episode, episode_length, lr=1e-3, n_epochs=100, batch_size=144 * 2, verbose=False)
        samples = 1000
        data = model.generate_data(samples, memory, episode, episode_length)
        output = (model(data[:,:-1])[:,-1,:].argmax(dim=-1))
        accuracy = output.eq(data[:,-1]).sum().item() / samples
        print('accuracy ', accuracy)
        if accuracy == 1.0:
            print(f'Finished training in {epoch} epochs')
            history.append(memory)
            memory = generate_memory(n_tokens)

            for ep, mem in enumerate(history):
                data = model.generate_data(samples, mem, ep, episode_length)
                output = (model(data[:,:-1])[:,-1,:].argmax(dim=-1))
                accuracy = output.eq(data[:,-1]).sum().item() / samples
                print(f'model accuracy on episode {ep}: ', accuracy)
            break


accuracy  0.01
accuracy  0.021
accuracy  0.028
accuracy  0.048
accuracy  0.074
accuracy  0.118
accuracy  0.16
accuracy  0.24
accuracy  0.27
accuracy  0.352
accuracy  0.469
accuracy  0.479
accuracy  0.554
accuracy  0.673
accuracy  0.72
accuracy  0.788
accuracy  0.874
accuracy  0.894
accuracy  0.919
accuracy  0.96
accuracy  0.981
accuracy  0.992
accuracy  0.988
accuracy  0.998
accuracy  1.0
Finished training in 24 epochs
model accuracy on episode 0:  1.0
accuracy  0.018
accuracy  0.049
accuracy  0.081
accuracy  0.099
accuracy  0.181
accuracy  0.229
accuracy  0.303
accuracy  0.388
accuracy  0.458
accuracy  0.528
accuracy  0.596
accuracy  0.648
accuracy  0.693
accuracy  0.767
accuracy  0.792
accuracy  0.835
accuracy  0.865
accuracy  0.926
accuracy  0.95
accuracy  0.958
accuracy  0.981
accuracy  0.995
accuracy  0.994
accuracy  0.995
accuracy  0.999
accuracy  0.998
accuracy  1.0
Finished training in 26 epochs
model accuracy on episode 0:  0.008
model accuracy on episode 1:  1.0
accuracy  0.0

KeyboardInterrupt: 