In [None]:
import numpy as np
import random
import torch
import torch.nn as nn
from torch.nn import functional as F
import matplotlib.pyplot as plt
import seaborn as sns
import math
from itertools import combinations, combinations_with_replacement

In [None]:
# Run on the first CUDA device when one is visible, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Report which of the two cases applies.
print('Using gpu: %s ' % torch.cuda.is_available())

# [Thinking like Transformers](https://arxiv.org/abs/2106.06981)

Here we code our 'toy' GPT without any training in order to compute histograms.

Start by coding your self-attention layer (do not forget to choose your initialization properly).

In [None]:
class SelfAttentionLayer(nn.Module):
    """Single-head self-attention layer (exercise skeleton).

    Three linear maps are built from ``config``: queries and keys have
    ``config.key_channels`` features, values have ``config.out_channels``
    features, and all three read ``config.in_channels`` input features.
    ``_init_hist`` and the attention computation in ``forward`` are left to
    be completed.
    """
    def __init__(self, config):
        super().__init__()
        # config only needs to expose in_channels, out_channels and key_channels.
        self.in_channels = config.in_channels
        self.out_channels = config.out_channels
        self.key_channels = config.key_channels
        self.Query = nn.Linear(self.in_channels, self.key_channels)
        self.Key = nn.Linear(self.in_channels, self.key_channels)
        self.Value = nn.Linear(self.in_channels, self.out_channels)
    
    def _init_hist(self):
        # Exercise placeholder. Block_hist calls this right after building the
        # layer; presumably it should set the Query/Key/Value weights by hand
        # so that the layer computes histograms without training.
        #
        # your code here
        #
          
        
    def forward(self, x): # x (bs, T, ic)
        Q = self.Query(x) # (bs, T, kc)
        # The 1/sqrt(key_channels) scaling of the attention scores is applied to K here.
        K = self.Key(x)/math.sqrt(self.key_channels) # (bs, T, kc)
        V = self.Value(x) # (bs, T, oc)
        # Exercise placeholder: compute the attention matrix A from Q and K and
        # the output y from A and V (presumably A is (bs, T, T) and y is
        # (bs, T, oc) -- to be confirmed by the implementation).
        A = #
        # your code here
        #
        return y, A

Check your implementation.

In [None]:
class toy_config:
    # Tiny configuration used to smoke-test SelfAttentionLayer: the input,
    # output and key dimensions are all equal to 3.
    in_channels = out_channels = key_channels = 3
    
sa_toy = SelfAttentionLayer(toy_config)

In [None]:
# Random batch of 5 sequences of length 10 with 3 features per position
# (3 matches toy_config.in_channels). The tensor is deliberately not called
# `input`, which would shadow the Python builtin for the rest of the session.
toy_input = torch.randn(5,10,3)
y,A = sa_toy(toy_input)

In [None]:
y.shape

In [None]:
A.shape

Now code your 'toy' transformer block and your 'toy' GPT to compute histograms:

In [None]:
class Block_hist(nn.Module):
    """Untrained block: self-attention followed by ``final_function``.

    Exercise skeleton: ``final_function`` is left to be filled in.
    """
    def __init__(self, config):
        super().__init__()
        self.attn = SelfAttentionLayer(config)
        self.final_function = # your code here
        # Replace the default nn.Linear initialization of the attention layer
        # (presumably with hand-chosen weights; _init_hist is part of the exercise).
        self.attn._init_hist()

    def forward(self, x):
        # Only the attention output is used; the attention matrix is discarded.
        x,_ = self.attn(x)
        x = self.final_function(x)
        return x

In [None]:
class GPT_hist(nn.Module):
    """'Toy' GPT: a token embedding followed by a single ``Block_hist``.

    Exercise skeleton: ``_init_weights`` is left to be filled in.
    """
    def __init__(self, config):
        super().__init__()
        self.in_channels = config.in_channels
        # The embedding table is square: in_channels token ids, each mapped to
        # a vector of in_channels features.
        self.tok_emb = nn.Embedding(self.in_channels,self.in_channels)
        self.block = Block_hist(config)
        self._init_weights()
        
    def _init_weights(self):
        # Exercise placeholder (presumably sets the embedding weights by hand,
        # since this model is used without training).
        #
        # your code here
        #
        
    def forward(self, idx):
        # idx holds integer token ids; they are embedded, then passed through the block.
        x = self.tok_emb(idx)
        x = self.block(x)
        return x

Check your implementation by first choosing your configuration properly:

In [None]:
# Number of distinct digits appearing in the sequences.
nb_digits = 4


class config:
    # One channel per token id: the digits 1..nb_digits plus the leading token 0.
    # Queries and keys use the same width as the input.
    in_channels = key_channels = nb_digits + 1
    # The block outputs a single number per position.
    out_channels = 1
    #max_hist = 20

In [None]:
gh = GPT_hist(config)

In [None]:
gh(torch.tensor([0,1,1,2,3,2,1]).unsqueeze(0))

# Generating your dataset

Now, we will use a 'micro' GPT to learn the task of computing histograms. We will use your 'toy' GPT to generate the dataset. Since GPT is permutation equivariant, the order of the digits in a sequence does not matter: it is enough to generate one input per possible histogram, and the number of possible histograms is not too high.

In [None]:
# Number of digit tokens per training sequence.
seq_train = 30
# Number of distinct digits.
nb_digits = 4
# Each element of comb is a non-decreasing tuple of nb_digits-1 cut points
# taken in 0..seq_train (inclusive).
# NOTE: comb is a one-shot iterator; it is exhausted after a single pass.
comb = combinations_with_replacement(range(0,seq_train+1), nb_digits-1)

def make_seq(c, seq_train):
    """Convert cut points into the sizes of the segments they delimit.

    The cut points in ``c`` are framed by 0 on the left and ``seq_train`` on
    the right; the result lists the gap between each pair of consecutive
    boundaries, so it has ``len(c) + 1`` entries that sum to ``seq_train``.
    """
    bounds = [0, *c, seq_train]
    return [right - left for left, right in zip(bounds, bounds[1:])]

# One list of nb_digits counts (summing to seq_train) per tuple of cut points.
# This consumes the comb iterator.
l_comb =  [make_seq(c,seq_train) for c in comb]

# Number of distinct histograms that were enumerated.
len(l_comb)

In [None]:
# Closed-form count of the tuples produced by combinations_with_replacement
# over seq_train+1 values taken nb_digits-1 at a time; should equal len(l_comb).
math.comb(seq_train+nb_digits-1, nb_digits-1)

In [None]:
def make_inputs(l_comb, nb_digits=nb_digits):
    """Build one token sequence per list of digit counts.

    For counts ``[c1, c2, ...]`` the sequence is a leading ``0`` token followed
    by ``c1`` copies of ``1``, then ``c2`` copies of ``2``, and so on.

    Args:
        l_comb: iterable of count lists, one per sequence to build.
        nb_digits: accepted for API compatibility; not used by the body.

    Returns:
        A list of 1-D integer tensors, in the same order as ``l_comb``.
    """
    def _expand(counts):
        # Token 0 always comes first; digit k is repeated counts[k-1] times.
        tokens = [0]
        for digit, count in enumerate(counts, start=1):
            tokens.extend(digit for _ in range(count))
        return torch.tensor(np.array(tokens))

    return [_expand(counts) for counts in l_comb]

def duplicate(l, n):
    """Return a new list in which each element of ``l`` appears ``n`` times in a row."""
    repeated = []
    for element in l:
        for _ in range(n):
            repeated.append(element)
    return repeated

def make_loader(len_seq,nb_digits,size):
    """Build a shuffled DataLoader of (input sequence, label) pairs.

    Every way of splitting ``len_seq`` tokens among ``nb_digits`` digits is
    enumerated once, turned into a token sequence by ``make_inputs`` and
    labelled by running the module-level ``gh`` model on it. The pairs are
    then repeated so that the dataset holds about ``size`` samples.

    Args:
        len_seq: number of digit tokens per sequence (``make_inputs`` adds one
            leading ``0`` token on top of that).
        nb_digits: number of distinct digits.
        size: target number of samples. The actual number is the largest
            multiple of the number of distinct sequences that does not exceed
            ``size``, with at least one copy of each sequence.

    Returns:
        A tuple ``(loader, len_in, inputs)``: the DataLoader (batch size 128,
        shuffled), the number of samples it holds, and the list of distinct
        input tensors.
    """
    cut_points = combinations_with_replacement(range(0,len_seq+1), nb_digits-1)
    l_comb = [make_seq(c,len_seq) for c in cut_points]
    inputs = make_inputs(l_comb, nb_digits)
    # Labels are constants produced by the hand-crafted model: no autograd
    # graph is needed to compute them.
    with torch.no_grad():
        labels = [(gh(d.unsqueeze(0)).squeeze(0).squeeze(1)).type(torch.LongTensor) for d in inputs]
    dataset = list(zip(inputs,labels))
    # Divide by the real number of distinct sequences. The previous hard-coded
    # math.comb(len_seq+2, nb_digits-1) is only right for nb_digits == 3.
    # Keep at least one copy so that a small `size` cannot yield an empty
    # dataset.
    n_dup = max(1, size // len(l_comb))
    dataset_l = duplicate(dataset,n_dup)
    len_in = len(dataset_l)
    loader = torch.utils.data.DataLoader(dataset_l, batch_size=128, shuffle=True)
    return loader, len_in, inputs

In [None]:
train_loader, size_train, inputs_train = make_loader(seq_train,nb_digits,10000)

In [None]:
size_train

In [None]:
batch_in = next(iter(train_loader))

# Coding 'micro' GPT

Now we need to code the 'micro' GPT used for learning. The game here is to reuse your `SelfAttentionLayer` above without any modification. The only part that is modified is the hard-coded `final_function`, which is now replaced by an MLP.

In [None]:
class Block(nn.Module):
    """ an unassuming Transformer block """
    # Exercise skeleton: same attention layer as Block_hist, but the
    # hard-coded final function is replaced by a learnable `mlp`.

    def __init__(self, config):
        super().__init__()
        self.attn = SelfAttentionLayer(config)
        self.mlp = # your code here

    def forward(self, x): # x (bs, T,ic)
        # Exercise placeholder. GPT.forward can return an attention matrix A
        # when verbose=True, so this presumably has to expose the one
        # returned by self.attn -- to be confirmed by the implementation.
        #
        # your code here
        #

In [None]:
class GPT(nn.Module):
    """'Micro' GPT trained to reproduce the histogram task (exercise skeleton).

    The layers built in ``__init__`` and the computation of ``logits`` (and of
    the attention matrix ``A``) in ``forward`` are left to be filled in.
    """
    def __init__(self, config):
        super().__init__()
        self.in_channels = config.in_channels
        self.nb_digits = config.nb_digits
        #
        # your code here
        #
        
    def forward(self, idx, targets=None, verbose=False):
        """Return ``(logits, loss)``, or ``(logits, loss, A)`` when ``verbose`` is true.

        ``loss`` is ``None`` when no ``targets`` are given.
        """
        # shape of idx: (bs, len) 0=bos and 1...nb_digits
        # shape of targets: (bs, len)
        #
        # your code here
        #
        
        loss = None
        if targets is not None:
            loss = # your code here
        if verbose:
            return logits, loss, A
        else:
            return logits, loss

In [None]:
class config_gpt:
    # Configuration of the trainable 'micro' GPT.
    # Copy of the module-level digit count.
    nb_digits = nb_digits
    # Input and output widths are kept equal.
    in_channels = out_channels = 12
    key_channels = 128
    # One more than the sequence length (presumably the number of possible
    # count values 0..seq_train -- to be confirmed by the GPT implementation).
    max_hist = seq_train + 1
    
gptmini = GPT(config_gpt)

In [None]:
logits, _ = gptmini(batch_in[0])

In [None]:
logits.shape

In [None]:
_,preds = torch.max(logits,-1)

In [None]:
preds.shape

In [None]:
batch_in[0].shape

In [None]:
torch.sum(preds == batch_in[1])

In [None]:
def train_model(model, dataloader, size, epochs=1, optimizer=None):
    """Train ``model`` for ``epochs`` passes over ``dataloader`` (exercise skeleton).

    Args:
        model: module called as ``model(inputs, targets)`` and expected to
            return ``(logits, loss)``.
        dataloader: iterable yielding ``(inputs, targets)`` batches.
        size: not used by the lines visible here; presumably meant for
            normalizing the accuracy once the placeholders are completed.
        epochs: number of passes over ``dataloader``.
        optimizer: optimizer stepping the model parameters. It is used
            unconditionally, so the ``None`` default must be overridden.

    Prints the epoch loss and accuracy after every epoch; returns nothing.
    """
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        running_corrects = 0
        n_batch = 0
        for inputs,targets in dataloader:
            # Batches are moved to the module-level `device`.
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits, loss = model(inputs,targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            #
            # complete the code below:
            # Predicted class = index of the largest logit along the last axis.
            _,preds = torch.max(logits,-1)
           
            running_corrects += 
            running_loss += 
            n_batch += 1
        # Both statistics are averaged over the number of batches.
        epoch_loss = running_loss /n_batch
        # .data.item() requires running_corrects to be a tensor at this point.
        epoch_acc = running_corrects.data.item() /n_batch
        print('Loss: {:.4f} Acc: {:.4f}'.format(
                     epoch_loss, epoch_acc))

In [None]:
# Start from a fresh model (replacing the instance built earlier for the
# shape checks) and move it to the selected device.
gptmini = GPT(config_gpt)
gptmini = gptmini.to(device)
# Learning rate of the first training phase.
lr = 0.01
optimizer = torch.optim.Adam(gptmini.parameters(),lr = lr)

In [None]:
# Total number of predicted positions in the training set: every sample has
# seq_train digit tokens plus the leading 0 token.
len_train = (seq_train+1)*size_train
# Pass len_train as `size`, consistently with the second training call below
# (size_train was passed here before, leaving len_train unused).
train_model(gptmini,train_loader,len_train,25,optimizer)

In [None]:
# Second training phase with a 10x smaller learning rate. A new Adam instance
# is created, so the optimizer state from the first phase is not carried over.
lr = 0.001
optimizer = torch.optim.Adam(gptmini.parameters(),lr = lr)
train_model(gptmini,train_loader,len_train,15,optimizer)

# Generalization

Adapt your code so that it can deal with shorter sequences (possibly with fewer digits). Hint: use padding.