[![Dataflowr](https://raw.githubusercontent.com/dataflowr/website/master/_assets/dataflowr_logo.png)](https://dataflowr.github.io/website/)

In [None]:
import numpy as np
import random
import torch
import torch.nn as nn
from torch.nn import functional as F
import matplotlib.pyplot as plt
import seaborn as sns
import math
from itertools import combinations, combinations_with_replacement

In [None]:
# Prefer the first CUDA device when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

print(f'Using gpu: {torch.cuda.is_available()} ')

# [Thinking like Transformers](https://arxiv.org/abs/2106.06981)

Here we code our 'toy' GPT without any training in order to compute histograms. For the input sequence `<BOS>,a,a,b,a,b,c`, the output should be `0,3,3,2,3,2,1` as the letter `a` appears 3 times, the letter `b` 2 times and the letter `c` once. Each letter is replaced by its number of occurrences (except `<BOS>`, which is replaced by a `0`). 

## Self-Attention

First start by coding your Self-Attention layer (do not worry about initialization for the moment).

In [None]:
class SelfAttentionLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_channels = config.n_channels
        self.key_channels = config.key_channels
        self.Query = nn.Linear(self.n_channels, self.key_channels, bias=False)
        self.Key = nn.Linear(self.n_channels, self.key_channels, bias = False)
        self.Value = nn.Linear(self.n_channels, self.n_channels, bias = False)
           
    def _init_id(self):
        self.Query.weight.data = 100*torch.eye(self.key_channels, self.n_channels)
        self.Key.weight.data = 100*torch.eye(self.key_channels,self.n_channels)
        self.Value.weight.data = torch.eye(self.key_channels,self.n_channels)        
        
    def forward(self, x): # x (bs, T, ic)
        Q = self.Query(x) # (bs, T, kc)
        K = self.Key(x)/math.sqrt(self.key_channels) # (bs, T, kc)
        V = self.Value(x) # (bs, T, oc)
        A = # your code here
        y = # your code here
        return y, A

Check your implementation.

In [None]:
class toy_config:
    # Smallest useful setup for a smoke test: 3 input/value channels and
    # 3 query/key channels.
    n_channels = key_channels = 3


sa_toy = SelfAttentionLayer(config=toy_config)

In [None]:
# Random batch of 5 sequences of length 10 with 3 channels (matches toy_config).
# Named x_toy rather than `input` so the builtin input() is not shadowed.
x_toy = torch.randn(5,10,3)
y,A = sa_toy(x_toy)

In [None]:
# Shape of the first output of sa_toy for the (5, 10, 3) random batch.
y.shape

In [None]:
# Sum A over its last axis; presumably every entry should be 1 if A is a
# softmax over that axis (depends on how forward was completed).
torch.sum(A, dim=-1)

## identity GPT

We first start with a simple example where we want to construct the identity map. Clearly, in this case, we could just use the skip connections present in a real transformer block. Instead, we will ignore these skip connections and use the self-attention layer. In this practical, we will also ignore the layer norm.

To make our life simpler, we encode `<BOS>` with a `0`, letter `a` with a `1` and so on...

If we give as input the sequence `0,1,1,2,3,4,2,3,1`, we want to get the same sequence as output. This is clearly doable with a transformer block as follows:
- take one-hot encoding of each token 
- take Query and Key matrices as `100*Id`
- take Value matrix as `Id`

As a result, the output of the self-attention layer will be the same as the input.

Then take a Feed Forward Network which is simply the identity map as coded below:

In [None]:
class Block_id(nn.Module):
    """Attention block wired to act as the identity.

    The attention layer is initialised through ``_init_id`` and the
    feed-forward part is replaced by a function returning its argument.
    """

    def __init__(self, config):
        super().__init__()
        attention = SelfAttentionLayer(config)
        attention._init_id()
        self.attn = attention
        self.fake_mlp = lambda h: h

    def forward(self, x):
        """Return the block output together with the attention matrix."""
        attended, weights = self.attn(x)
        return self.fake_mlp(attended), weights

In [None]:
nb_digits = 4


class config:
    # One channel per letter plus one for <BOS>; queries and keys use the
    # same width as the inputs.
    key_channels = n_channels = nb_digits + 1

In [None]:
bid = Block_id(config)
# A batch of one sequence made of two one-hot rows (tokens 2 and 1); the batch
# axis is written directly in the literal.
one_sample = torch.tensor([[[0., 0., 1., 0., 0.],
                            [0., 1., 0., 0., 0.]]])
bid(one_sample)

Now, to really obtain the identity map, we need to project the one-hot encoding back to a token index; this can be done with a linear layer (with a well-chosen weight initialization).

In [None]:
# Toy GPT meant to reproduce its input: token embedding -> Block_id -> linear head.
# The weight initialisation (_init_weights) is left for the reader to complete.
class GPT_id(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_channels = config.n_channels
        # Square table: n_channels tokens, each embedded in n_channels dimensions.
        self.tok_emb = nn.Embedding(self.n_channels,self.n_channels)
        self.block = Block_id(config)
        # Maps the n_channels features of each position to one scalar, without bias.
        self.head = nn.Linear(self.n_channels, 1, bias = False)
        self._init_weights()
        
    def _init_weights(self):
        # Expected to set the weights of tok_emb and head (they are created with
        # PyTorch's default initialisation in __init__).
        #
        # your code here
        #
        
    def forward(self, idx):
        # idx holds integer token ids; returns the head output and the attention matrix.
        x = self.tok_emb(idx)
        x, A = self.block(x)
        return self.head(x), A

In [None]:
# Instantiate the identity GPT with the 5-channel configuration defined above.
gid = GPT_id(config)

In [None]:
# A batch containing one sequence of 9 token ids (0 encodes <BOS>).
one_sample = torch.tensor([0,1,1,2,3,4,2,3,1]).unsqueeze(0)
y, A = gid(one_sample)

In [None]:
# The head emits one float per position, i.e. a trailing axis of size 1: drop it
# before comparing, otherwise (1, T, 1) == (1, T) broadcasts to a (1, T, T) table.
# isclose keeps the check elementwise while tolerating float round-off.
torch.isclose(y.squeeze(-1), one_sample.to(y.dtype))

In [None]:
# Heat map of the attention matrix of the first (and only) sequence in the batch.
plt.imshow(A[0].detach().cpu(), interpolation='nearest', cmap='hot')
plt.colorbar()
plt.show()

## histogram GPT

Now we need to adapt the previous case to code our 'toy' transformer block and our 'toy' GPT to compute histograms:
- you will need to find a good initialization for the Query, Key and Value matrices
- for the feed forward network, you can fake the mlp with any function you'd like.

In [None]:
# Transformer block for the histogram task: a self-attention layer followed by a
# hand-written stand-in for the MLP. The stand-in and the attention
# initialisation are left for the reader to complete.
class Block_hist(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attn = SelfAttentionLayer(config)
        self.fake_mlp = # your code here
        self.attn._init_hist() # this needs to be coded in your self-attention layer

    def forward(self, x):
        # Returns the transformed features together with the attention matrix.
        x, A = self.attn(x)
        x = self.fake_mlp(x)
        return x, A

In [None]:
# Toy GPT for the histogram task: token embedding -> Block_hist, with no output
# head. The weight initialisation (_init_weights) is left for the reader.
class GPT_hist(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_channels = config.n_channels
        # Square table: n_channels tokens, each embedded in n_channels dimensions.
        self.tok_emb = nn.Embedding(self.n_channels,self.n_channels)
        self.block = Block_hist(config)
        self._init_weights()
        
    def _init_weights(self):
        #
        # your code here
        #
        
        
    def forward(self, idx):
        # idx holds integer token ids; the block output is returned as is,
        # together with the attention matrix.
        x = self.tok_emb(idx)
        x, A = self.block(x)
        return x, A

Check your implementation, after first choosing your configuration properly:

In [None]:
# Histogram GPT; this instance is reused later by make_loader to compute labels.
gh = GPT_hist(config)

In [None]:
# A batch containing one sequence of 9 token ids (0 encodes <BOS>).
one_sample = torch.tensor([0,1,1,2,3,4,2,3,1]).unsqueeze(0)
y, A = gh(one_sample)
# Display the model output for this sequence.
y

In [None]:
# Shape of the histogram model output for the single test sequence.
y.shape

In [None]:
# Heat map of the attention matrix of the first (and only) sequence in the batch.
plt.imshow(A[0].detach().cpu(), interpolation='nearest', cmap='hot')
plt.colorbar()
plt.show()

# Generating your dataset

Now, we will use a 'micro' GPT to learn the task of histograms. Before that, we will use our 'toy' GPT to generate the dataset. Since GPT is equivariant (a permutation of the input will permute the output), we can always take an ordered sequence as input. We can indeed compute all possible different inputs and this number is not too high. For a sequence of length `seq_train=s` with at most `nb_digits=n` distinct letters, there are ${s+n-1 \choose n-1}$ possibilities. Now for each such sequence, we pass it through our toy GPT to get the label.

In [None]:
seq_train = 30
nb_digits = 4
# Lazily enumerate every non-decreasing choice of nb_digits-1 cut points taken
# in 0..seq_train (the iterator can be consumed only once).
comb = combinations_with_replacement(range(seq_train + 1), nb_digits - 1)

def make_seq(c, seq_train):
    """Turn cut points into the lengths of the segments they delimit.

    The cut points ``c`` are framed by ``0`` and ``seq_train``; the result
    lists the successive differences, so it has ``len(c) + 1`` entries that
    sum to ``seq_train``.
    """
    bounds = [0, *c, seq_train]
    return [upper - lower for lower, upper in zip(bounds, bounds[1:])]

# Convert every tuple of cut points into per-letter counts (this consumes the
# `comb` iterator).
l_comb =  [make_seq(c,seq_train) for c in comb]

# Number of count vectors that were generated.
len(l_comb)

In [None]:
# Closed-form binomial count, to be compared with len(l_comb) above.
math.comb(seq_train+nb_digits-1, nb_digits-1)

In [None]:
def make_inputs(l_comb, nb_digits=nb_digits):
    """Expand count vectors into token sequences.

    Each element of ``l_comb`` is a list of counts; position ``i`` gives how
    many times token ``i + 1`` is repeated. Every sequence starts with the
    token ``0``. ``nb_digits`` is accepted but not used.
    Returns a list with one 1-D tensor per count vector.
    """
    def _expand(counts):
        tokens = [0]
        for digit, repeats in enumerate(counts, start=1):
            tokens.extend([digit] * repeats)
        return torch.tensor(np.array(tokens))

    return [_expand(counts) for counts in l_comb]

def make_loader(len_seq, nb_digits, batch_size=128):
    """Build the histogram dataset and wrap it in a shuffling DataLoader.

    All ordered inputs with ``len_seq`` letters drawn from ``nb_digits``
    letters are enumerated and labelled with the module-level model ``gh``.

    Args:
        len_seq: number of letters per sequence (the leading 0 token added by
            ``make_inputs`` is not counted).
        nb_digits: number of distinct letters.
        batch_size: batch size of the returned loader (128 by default).

    Returns:
        ``(loader, len_in, inputs)``: the DataLoader, the number of samples
        and the list of input tensors.
    """
    comb = combinations_with_replacement(range(len_seq + 1), nb_digits - 1)
    l_comb = [make_seq(c, len_seq) for c in comb]
    inputs = make_inputs(l_comb, nb_digits)
    # Labelling is pure inference, so no autograd graph is needed.
    with torch.no_grad():
        labels = [
            # Round before the integer cast: the cast truncates, so a float such
            # as 2.9999 would otherwise become the label 2.
            gh(d.unsqueeze(0))[0].squeeze(0).squeeze(1).round().type(torch.LongTensor)
            for d in inputs
        ]
    dataset = list(zip(inputs, labels))
    len_in = len(dataset)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return loader, len_in, inputs

In [None]:
# Build the training set for sequences of seq_train letters over nb_digits letters.
train_loader, size_train, inputs_train = make_loader(seq_train,nb_digits)

In [None]:
# Number of training samples returned by make_loader.
size_train

In [None]:
# Draw one batch from the loader; batch_in[0] holds the inputs and batch_in[1] the labels.
batch_in = next(iter(train_loader))

In [None]:
# Shape of the batched inputs.
batch_in[0].shape

In [None]:
# Shape of the batched labels.
batch_in[1].shape

In [None]:
# First input sequence of the batch.
batch_in[0][0]

In [None]:
# Labels of that first sequence, as produced by the toy histogram model.
batch_in[1][0]

# Coding 'micro' GPT

Now we need to code the 'micro' GPT used for learning. The game here is to reuse our `SelfAttentionLayer` above without any modification. The only part that is modified is the hard-coded `fake_mlp` which is now replaced by a real MLP.

In [None]:
# Trainable transformer block of the 'micro' GPT: the SelfAttentionLayer defined
# earlier followed by a real MLP. The MLP and the body of forward are left for
# the reader to complete.
class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attn = SelfAttentionLayer(config)
        self.mlp = # your code here

    def forward(self, x, verbose=False): # x (bs, T,ic)
        #
        # your code here
        #
        # NOTE: the code written above must define both x (the block output) and
        # A (the attention matrix), since they are returned below.
        if verbose:
            return x, A
        else:
            return x

In [None]:
# Trainable 'micro' GPT: token embedding -> Block -> output head. The embedding,
# the head, the forward computation and the loss are left for the reader.
class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_channels = config.n_channels
        self.nb_digits = config.nb_digits
        self.tok_emb = # your code here 
        self.block = Block(config)
        self.head = # your code here 
        
    def forward(self, idx, targets=None, verbose=False):
        # shape of idx: (bs, len) 0=bos and 1...nb_digits
        # shape of targets: (bs, len)
        #
        # your code here
        #
        # NOTE: the code written above must define logits, and A when verbose=True,
        # since they are returned below.
        
        # The loss stays None when no targets are given.
        loss = None
        if targets is not None:
            loss = # your code here
        if verbose:
            return logits, loss, A
        else:
            return logits, loss

In [None]:
class config_gpt:
    # Hyper-parameters of the trainable 'micro' GPT.
    n_channels = 32    # input/value width used by SelfAttentionLayer
    key_channels = 64  # query/key width used by SelfAttentionLayer
    nb_digits = nb_digits
    # A letter count lies in 0..seq_train, hence seq_train + 1 possible values.
    max_hist = seq_train + 1

In [None]:
# Freshly initialised (untrained) micro GPT, used for the shape checks below.
gptmini = GPT(config_gpt)

In [None]:
# Forward pass without targets: the second returned value (the loss) is None.
logits, _ = gptmini(batch_in[0])

In [None]:
# Shape of the logits for the sample batch.
logits.shape

In [None]:
# Predicted class = index of the largest logit along the last axis.
_,preds = torch.max(logits,-1)

In [None]:
# Shape of the predictions.
preds.shape

In [None]:
# Shape of the inputs, to be compared with the shape of the predictions above.
batch_in[0].shape

In [None]:
# Number of positions where the untrained model already matches the labels.
torch.sum(preds == batch_in[1])

In [None]:
def train_model(model, dataloader, size, epochs=1, optimizer=None):
    """Train ``model`` on ``dataloader`` and print the loss/accuracy of each epoch.

    Args:
        model: module called as ``model(inputs, targets)`` and returning
            ``(logits, loss)``; the predicted class is the argmax of
            ``logits`` over its last dimension.
        dataloader: iterable of ``(inputs, targets)`` tensor batches.
        size: unused; kept so that existing call sites keep working.
        epochs: number of passes over ``dataloader``.
        optimizer: optimizer stepping the model parameters. When ``None``, an
            Adam optimizer with default hyper-parameters is created.

    Raises:
        ValueError: if ``dataloader`` yields no batch during an epoch.
    """
    if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters())
    # Send the batches to wherever the model lives instead of relying on a
    # module-level `device` that may not match the model.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        n_correct = 0
        n_tokens = 0
        n_batch = 0
        for inputs, targets in dataloader:
            inputs = inputs.to(run_device)
            targets = targets.to(run_device)
            logits, loss = model(inputs, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            preds = logits.argmax(dim=-1)
            # Count hits and positions so that the epoch accuracy is exact even
            # when the last batch is smaller than the others.
            n_correct += (preds == targets).sum().item()
            n_tokens += targets.numel()
            running_loss += loss.item()
            n_batch += 1
        if n_batch == 0:
            raise ValueError("dataloader yielded no batches")
        epoch_loss = running_loss / n_batch
        epoch_acc = n_correct / max(n_tokens, 1)
        print('Loss: {:.4f} Acc: {:.4f}'.format(
                     epoch_loss, epoch_acc))

In [None]:
# Fresh model moved to the selected device, with a first Adam optimizer.
lr = 0.01
gptmini = GPT(config_gpt).to(device)
optimizer = torch.optim.Adam(params=gptmini.parameters(), lr=lr)

In [None]:
# Total number of positions in the training set: each sample holds seq_train
# letters plus the leading <BOS> token.
len_train = (seq_train+1)*size_train
# Pass len_train as the size argument, as the later training calls do.
train_model(gptmini,train_loader,len_train,15,optimizer)

In [None]:
# Next stage: a new Adam optimizer with a smaller step size, 15 more epochs.
lr = 0.005
optimizer = torch.optim.Adam(params=gptmini.parameters(), lr=lr)
train_model(gptmini, train_loader, len_train, epochs=15, optimizer=optimizer)

In [None]:
# Next stage: a new Adam optimizer with a smaller step size, 15 more epochs.
lr = 0.001
optimizer = torch.optim.Adam(params=gptmini.parameters(), lr=lr)
train_model(gptmini, train_loader, len_train, epochs=15, optimizer=optimizer)

In [None]:
# Last stage: a new Adam optimizer with the smallest step size, 15 more epochs.
lr = 0.0001
optimizer = torch.optim.Adam(params=gptmini.parameters(), lr=lr)
train_model(gptmini, train_loader, len_train, epochs=15, optimizer=optimizer)

In [None]:
# Run the trained model on the sample batch; verbose=True also returns the
# attention matrices. No targets are given, so loss is None.
one_batch = batch_in[0].to(device)
logits, loss, A = gptmini(one_batch,verbose=True)
A.shape

In [None]:
# Heat map of the learned attention matrix for sample k of the batch.
k = 45
plt.imshow(A[k].detach().cpu(), interpolation='nearest', cmap='hot')
plt.colorbar()
plt.show()

[![Dataflowr](https://raw.githubusercontent.com/dataflowr/website/master/_assets/dataflowr_logo.png)](https://dataflowr.github.io/website/)