In this notebook we implement the zero-layer transformer from "A Mathematical Framework for Transformer Circuits" (https://transformer-circuits.pub/2021/framework/index.html). We would like to do this with an observer-pattern framework similar to PyTorch Ignite. The steps, in order:

 - get data
 - write model
 - write training loop
 - visualization
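
To make the observer-pattern idea concrete, here is a minimal sketch of an Ignite-style engine in plain Python. It is not Ignite's API: the `Events` names, the `Engine` class, and the `on`/`fire`/`run` methods are hypothetical stand-ins chosen for illustration, and the "training step" just sums a batch.

```python
from collections import defaultdict
from enum import Enum, auto


class Events(Enum):
    # Hypothetical event names, loosely modeled on Ignite's.
    STARTED = auto()
    ITERATION_COMPLETED = auto()
    EPOCH_COMPLETED = auto()
    COMPLETED = auto()


class Engine:
    """Runs `process_fn` over batches and notifies registered handlers."""

    def __init__(self, process_fn):
        self.process_fn = process_fn
        self.handlers = defaultdict(list)
        self.epoch = 0
        self.iteration = 0
        self.output = None

    def on(self, event):
        """Decorator that registers a handler (observer) for `event`."""
        def decorator(handler):
            self.handlers[event].append(handler)
            return handler
        return decorator

    def fire(self, event):
        # Call every observer of this event with the engine as context.
        for handler in self.handlers[event]:
            handler(self)

    def run(self, data, max_epochs=1):
        # `data` must be re-iterable (e.g. a list) when max_epochs > 1.
        self.fire(Events.STARTED)
        for _ in range(max_epochs):
            self.epoch += 1
            for batch in data:
                self.iteration += 1
                self.output = self.process_fn(self, batch)
                self.fire(Events.ITERATION_COMPLETED)
            self.fire(Events.EPOCH_COMPLETED)
        self.fire(Events.COMPLETED)


# Usage: a stand-in "training step" that sums the batch.
trainer = Engine(lambda engine, batch: sum(batch))
log = []


@trainer.on(Events.ITERATION_COMPLETED)
def record_output(engine):
    log.append((engine.epoch, engine.iteration, engine.output))


trainer.run([[1, 2], [3, 4]], max_epochs=2)
```

The point of the design is that logging, checkpointing, and visualization can be attached as handlers without touching the loop itself. In the real notebook the step function would presumably wrap a JAX update rather than `sum`.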

# Write Model

We will implement the zero-layer transformer in Flax.

We have this paragraph for model details:

"The models used as examples in this paper are zero, one, and two layer decoder-only, attention-only transformers [1]. For all models, $d_\text{model} = n_\text{heads} * d_\text{head}$, typically with $n_\text{heads}=12$ and $d_\text{head}=64$, but also with one explicitly noted example where $n_\text{heads}=32$ and $d_\text{head}=128$.

Models have a context size of 2048 tokens and use dense attention. (Dense attention was preferred over sparse attention for simplicity, but made a smaller context preferable.) We use a positional mechanism similar to Press et al. [14], adding sinusoidal embeddings immediately before multiplying by $W_Q$ and $W_K$ to produce queries and keys. (This excludes pointer-arithmetic based algorithms without the distorted QK matrices like rotary.)"
  

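The zero-layer transformer consists only of an embedding matrix $W_E$ and an unembedding matrix $W_U$. Because it cannot move information between positions, it predicts the next token from the current token alone, so after training the product $W_U W_E$ should approximate the bigram log-likelihoods.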
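
As a reference point for that claim, the sketch below computes a bigram log-likelihood table directly from token counts. The function name, the toy token sequence, and the add-one smoothing are all assumptions made for illustration; they do not come from the paper. Row `i` of the table holds $\log p(\text{next} \mid \text{current}=i)$, which is the target a trained zero-layer model could be compared against.

```python
import math
from collections import Counter


def bigram_log_likelihood(tokens, vocab_size, smoothing=1.0):
    """Return table[current][next] = log p(next | current).

    Uses additive smoothing (an assumption, so unseen pairs keep a
    finite log-probability).
    """
    pair_counts = Counter(zip(tokens[:-1], tokens[1:]))
    table = []
    for current in range(vocab_size):
        row_total = sum(pair_counts[(current, nxt)] for nxt in range(vocab_size))
        denom = row_total + smoothing * vocab_size
        table.append([
            math.log((pair_counts[(current, nxt)] + smoothing) / denom)
            for nxt in range(vocab_size)
        ])
    return table


# Hypothetical toy corpus over a 3-token vocabulary.
toy_tokens = [0, 1, 2, 0, 1, 2, 0, 1, 0]
table = bigram_log_likelihood(toy_tokens, vocab_size=3)

# Each row is a log-distribution: exponentiating and summing gives 1.
row_sums = [sum(math.exp(x) for x in row) for row in table]
```

In this toy sequence token 0 is followed by token 1 every time, so row 0 puts most of its mass on token 1.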

In [86]:
from flax import linen as nn


class ZeroLayerTransformer(nn.Module):
    vocab_size: int
    embed_dim: int
    
    @nn.compact
    def __call__(self, input_ids):
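        """
        input_ids: batch of token ids with shape (n_examples, max_seq_len).
        Returns next-token probabilities with shape (n_examples, max_seq_len, vocab_size).
        """
        # Token embedding W_E: a (vocab_size, embed_dim) lookup table.
        embedded = nn.Embed(self.vocab_size, self.embed_dim, name='embedding_matrix')(input_ids)
        # Unembedding W_U: bias-free projection back to vocab_size logits.
        unembedded = nn.Dense(self.vocab_size, use_bias=False, name='unembedding_matrix')(embedded)
        # Softmax over the vocabulary axis turns the logits into probabilities.
        probs = nn.softmax(unembedded, axis=-1)
        return probs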


In [81]:
import os

# Select the GPU before JAX initializes its backend.
os.environ['CUDA_VISIBLE_DEVICES'] = '5'

from jax import numpy as jnp
from jax import random

key = random.PRNGKey(0)

def get_new_key():
    """Split the global PRNG key and return a fresh subkey."""
    global key
    key, subkey = random.split(key)
    return subkey


In [82]:
random.randint(get_new_key(), [1], 0, 10)

DeviceArray([6], dtype=int32)

In [87]:
vocab_size = 20
embed_dim = 10
max_seq_len = 16
n_examples = 5

sample_input_ids = random.randint(get_new_key(), (n_examples, max_seq_len), 0, vocab_size)
zltransformer = ZeroLayerTransformer(vocab_size, embed_dim)
params = zltransformer.init(get_new_key(), sample_input_ids)
sample_out = zltransformer.apply(params, sample_input_ids)

In [88]:
sample_out.shape

(5, 16, 20)

In [93]:
sample_out.sum(axis=-1)

DeviceArray([[0.99999994, 1.        , 1.0000001 , 0.9999999 , 1.0000001 ,
              1.0000001 , 1.        , 0.99999994, 1.        , 1.        ,
              1.        , 1.0000001 , 0.99999994, 1.0000001 , 1.        ,
              1.        ],
             [0.99999994, 0.99999994, 0.99999994, 0.99999976, 0.9999999 ,
              1.        , 1.        , 1.        , 1.0000001 , 1.        ,
              1.        , 1.0000001 , 0.99999994, 1.        , 0.99999994,
              1.0000001 ],
             [1.        , 1.        , 0.99999994, 1.        , 1.        ,
              1.        , 1.        , 1.        , 1.        , 1.        ,
              0.9999999 , 1.        , 1.        , 1.0000001 , 1.0000001 ,
              1.0000001 ],
             [0.9999999 , 1.        , 1.        , 0.99999994, 1.0000001 ,
              1.0000001 , 1.0000001 , 1.0000001 , 1.0000001 , 1.        ,
              1.0000001 , 1.0000001 , 1.        , 1.        , 1.0000001 ,
              1.0000001 ],
    

In [91]:
sample_out.argmax(axis=-1)

DeviceArray([[18, 17, 18, 15, 11, 12, 18,  1,  3, 17,  2, 11,  1,  4,  6,
              11],
             [14, 14, 14,  6, 15, 17, 11,  2, 18, 15,  6,  4, 18, 18,  1,
              12],
             [ 3, 15,  1,  2, 17, 18, 11, 11,  6, 18, 15, 17, 15,  4, 11,
               6],
             [15,  6,  3, 18,  4,  6, 12,  4, 18, 11, 12, 18,  2,  2,  4,
              18],
             [18, 17, 18, 14, 11, 12, 18,  3, 11, 15, 11, 18, 11,  2, 14,
              14]], dtype=int32)