In [2]:
from absl import logging
import flax
import jax.numpy as jnp
import jax
import numpy as np

logging.set_verbosity(logging.INFO)

In [4]:
from transformers import FlaxGPT2LMHeadModel

model_hf = FlaxGPT2LMHeadModel.from_pretrained('gpt2')

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(model_hf.params)))

initialized parameter shapes:
 {'transformer': {'h': {'0': {'attn': {'c_attn': {'bias': (2304,), 'kernel': (2304, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 768)}}, 'ln_1': {'bias': (768,), 'scale': (768,)}, 'ln_2': {'bias': (768,), 'scale': (768,)}, 'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 3072)}}}, '1': {'attn': {'c_attn': {'bias': (2304,), 'kernel': (2304, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 768)}}, 'ln_1': {'bias': (768,), 'scale': (768,)}, 'ln_2': {'bias': (768,), 'scale': (768,)}, 'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 3072)}}}, '10': {'attn': {'c_attn': {'bias': (2304,), 'kernel': (2304, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 768)}}, 'ln_1': {'bias': (768,), 'scale': (768,)}, 'ln_2': {'bias': (768,), 'scale': (768,)}, 'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 3072)}}}, 

In [6]:
jnp.shape(model_hf.params['transformer']['wpe']['embedding'])

(1024, 768)

In [7]:
model_hf.params['transformer']['h']['0']['attn'].keys()

dict_keys(['c_attn', 'c_proj'])

In [8]:
jnp.broadcast_to(jnp.arange(10, dtype="i4")[None, :], (32, 10))

Array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5,

In [58]:
from flax import linen as nn
from functools import partial

# define a full GTP2 class in Flax and see if we can replicate the original paper along with
# Karpathy

class GPTConfig:
    block_size: int = 256
    vocab_size: int = 65
    n_layer: int = 6
    n_head: int = 6
    n_embd: int = 384

class GPTMLP(nn.Module):
    config: GPTConfig

    def setup(self):
        # Simple MLP that upscales, runs through a gelu activation,
        # and then resamples back to the n_embd size (the model size)
        #self.c_fc = nn.Dense(4 * self.config.n_embd)
        self.c_fc = nn.Einsum((4 * self.config.n_embd, self.config.n_embd), '...ij,...kj->...ik')
        #self.c_proj = nn.Dense(self.config.n_embd)
        self.c_proj = nn.Einsum((self.config.n_embd, 4 * self.config.n_embd), '...ij,...kj->...ik')

    def __call__(self, inputs):
        x = inputs

        x = self.c_fc(x)
        x = nn.gelu(x, approximate=True)
        x = self.c_proj(x)

        return x
    
class GPTAttention(nn.Module):
    config: GPTConfig

    # we will need to roll our own attention module because the built in one has a bunch of different
    # naming and structure compared to the original GPT, which just references the projection layers
    def setup(self):
        self.c_proj = nn.Einsum((4 * self.config.n_embd, self.config.n_embd), '...ij,...kj->...ik')

    def __call__(self, inputs):
        return inputs

class GPTBlock(nn.Module):
    config: GPTConfig

    def setup(self):
        self.ln_1 = nn.LayerNorm()
        # I might have to write this manually to get the proper number of parameters, as the old GPT2 code 
        # migh have subtle differences from the Flax implementation
        self.attn = nn.MultiHeadAttention(num_heads=self.config.n_head, qkv_features=self.config.n_head)
        self.ln_2 = nn.LayerNorm()
        self.mlp = GPTMLP(self.config)

    def __call__(self, inputs):
        x = inputs
        x = self.ln_1(x)
        x = self.attn(x)
        x = self.ln_2(x)
        x = self.mlp(x)
        return x

class GPTLayers(nn.Module):
    config: GPTConfig
    
    def setup(self):
        self.blocks = [ GPTBlock(self.config, name=str(i)) for i in range(self.config.n_layer) ]

    def __call__(self, inputs):
        x = inputs

        for block in self.blocks:
            x = block(x)
        return inputs

class GPTModel(nn.Module):
    config: GPTConfig

    def setup(self):
        # This is a little confusing. vocab size is the number of embeddings we need.
        # n_embd is the dimension of each embedding (i.e. capacity for learning properties
        # about this embedding token)
        # input size = (1 x self.block_size) - 1 int for each token
        # output size = then (self.block_size x self.n_embd)
        self.wte = nn.Embed(self.config.vocab_size, self.config.n_embd)
        # embed is just a randomzied parameter matrix, so can be used for positional 
        # encoding as well. I think block size is the token length.
        # This has to match the size of the previous output, as we are just adding.
        self.wpe = nn.Embed(self.config.block_size, self.config.n_embd)
        # The attention layers
        self.h = GPTLayers(self.config)
        self.ln_f = nn.LayerNorm()

    def __call__(self, inputs):
        x = inputs
        x = self.wte(x)
        
        #x += self.wpe(x)

        x = self.h(x)

        x = self.ln_f(x)

        return x


class GPT(nn.Module):
    config: GPTConfig

    def setup(self):
        self.transformer = GPTModel(self.config)

    def __call__(self, inputs):
        x = inputs
        x = self.transformer(x)
        return x

In [59]:
GPT2 = GPT(GPTConfig)

key1, key2 = jax.random.split(jax.random.key(0), 2)

# We never want to use int64 as it signficantly slows down training
x = np.random.randint(0, 1000, (4,4), dtype='int32')

y = GPT2.init(key2, x)

In [60]:
from jax.tree_util import tree_structure
print(tree_structure(y['params']))

PyTreeDef({'transformer': {'h': {'0': {'attn': {'key': {'bias': *, 'kernel': *}, 'out': {'bias': *, 'kernel': *}, 'query': {'bias': *, 'kernel': *}, 'value': {'bias': *, 'kernel': *}}, 'ln_1': {'bias': *, 'scale': *}, 'ln_2': {'bias': *, 'scale': *}, 'mlp': {'c_fc': {'bias': *, 'kernel': *}, 'c_proj': {'bias': *, 'kernel': *}}}, '1': {'attn': {'key': {'bias': *, 'kernel': *}, 'out': {'bias': *, 'kernel': *}, 'query': {'bias': *, 'kernel': *}, 'value': {'bias': *, 'kernel': *}}, 'ln_1': {'bias': *, 'scale': *}, 'ln_2': {'bias': *, 'scale': *}, 'mlp': {'c_fc': {'bias': *, 'kernel': *}, 'c_proj': {'bias': *, 'kernel': *}}}, '2': {'attn': {'key': {'bias': *, 'kernel': *}, 'out': {'bias': *, 'kernel': *}, 'query': {'bias': *, 'kernel': *}, 'value': {'bias': *, 'kernel': *}}, 'ln_1': {'bias': *, 'scale': *}, 'ln_2': {'bias': *, 'scale': *}, 'mlp': {'c_fc': {'bias': *, 'kernel': *}, 'c_proj': {'bias': *, 'kernel': *}}}, '3': {'attn': {'key': {'bias': *, 'kernel': *}, 'out': {'bias': *, 'kerne

In [61]:
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(y['params']['transformer']['h']['0']['attn'])))

initialized parameter shapes:
 {'attn': {'key': {'bias': (6, 1), 'kernel': (384, 6, 1)}, 'out': {'bias': (384,), 'kernel': (6, 1, 384)}, 'query': {'bias': (6, 1), 'kernel': (384, 6, 1)}, 'value': {'bias': (6, 1), 'kernel': (384, 6, 1)}}, 'ln_1': {'bias': (384,), 'scale': (384,)}, 'ln_2': {'bias': (384,), 'scale': (384,)}, 'mlp': {'c_fc': {'bias': (1536,), 'kernel': (1536, 384)}, 'c_proj': {'bias': (384,), 'kernel': (384, 1536)}}}


In [62]:
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(model_hf.params['transformer']['h']['0']['attn'])))

initialized parameter shapes:
 {'attn': {'c_attn': {'bias': (2304,), 'kernel': (2304, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 768)}}, 'ln_1': {'bias': (768,), 'scale': (768,)}, 'ln_2': {'bias': (768,), 'scale': (768,)}, 'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 3072)}}}


In [41]:
# Seems like pytorch Linear is not equal to flax Dense. The matrices are defined transpose. So, we can't actually use
# Dense with the original GPT2 weights

x = jax.random.uniform(jax.random.key(0), (1, 2, 5))

# This will transpose the kernel and multiply by the input,
# This should be the same as the GPT weights
lp = nn.Einsum((3*5, 5), '...ij,...kj->...ik')

y = lp.init(jax.random.key(0), x)

In [42]:
lp.apply(y, x)

Array([[[ 0.40416887, -0.15771458, -0.1418275 ,  0.21624644,
          0.01153318,  0.09767484, -0.16552436, -0.06833879,
         -0.07042018, -0.13803177, -0.00222424,  0.2503722 ,
         -0.3504156 ,  0.18865553, -0.33016038],
        [ 0.62610245, -0.35361993, -0.40769407,  0.22174977,
          0.1036396 ,  0.2302788 , -0.25040534, -0.5411675 ,
         -0.35929912,  0.05858286,  0.03919724,  0.3880062 ,
         -0.13242133,  0.4200734 , -0.47516495]]], dtype=float32)

In [43]:
lp2 = nn.Dense(3*5)

y2 = lp2.init(jax.random.key(0), x)

In [46]:
y2['params']

{'kernel': Array([[ 0.65126026, -0.20171352,  0.7524083 ,  0.21318944,  0.66664237,
         -0.5474334 ,  0.5442594 , -0.66585183, -0.21363808, -0.22357357,
         -0.49834606,  0.24824531, -0.08310969, -0.48232296, -0.21830165],
        [ 0.23788127,  0.42619142,  0.03609344,  0.46733794, -0.27462825,
         -0.70285726,  0.40865424, -0.2941622 ,  0.544966  ,  0.0685642 ,
          0.06401048, -0.16370532,  0.50371426,  0.77258414, -0.4492223 ],
        [-0.6562224 ,  0.04016771, -0.08710401, -0.15943906, -0.01229531,
          0.27645266,  0.26468387,  0.20921318, -0.64844024, -0.9589507 ,
          0.580644  ,  0.07664067, -0.12084891, -0.6119603 , -0.550388  ],
        [-0.4711891 , -0.17671804, -0.25029108,  0.4280646 ,  0.13103884,
          0.60195965, -0.41966525, -0.01266516, -0.66290927,  0.5892318 ,
         -0.0306597 ,  0.27763686,  0.15033169,  0.42397204,  0.34783563],
        [-0.63989514, -0.8911751 ,  0.31193644,  0.50128835, -0.2739739 ,
          0.32711524, -0

In [47]:
y['params']

{'kernel': Array([[ 0.3760053 , -0.11645935,  0.43440315,  0.12308498,  0.38488615],
        [-0.3160608 ,  0.31422833, -0.38442972, -0.123344  , -0.12908025],
        [-0.28772023,  0.1433245 , -0.0479834 , -0.2784693 , -0.12603652],
        [ 0.13734081,  0.24606173,  0.02083855,  0.26981768, -0.15855668],
        [-0.40579483,  0.23593663, -0.16983464,  0.31463626,  0.03958556],
        [ 0.03695647, -0.09451531,  0.2908196 ,  0.44605166, -0.2593586 ],
        [-0.37887016,  0.02319084, -0.05028952, -0.09205218, -0.0070987 ],
        [ 0.15961002,  0.1528153 ,  0.12078928, -0.37437713, -0.55365044],
        [ 0.33523497,  0.04424851, -0.06977215, -0.3533154 , -0.31776664],
        [-0.27204117, -0.10202821, -0.14450562,  0.24714321,  0.07565531],
        [ 0.34754157, -0.24229383, -0.00731223, -0.38273084,  0.34019315],
        [-0.01770138,  0.16029371,  0.08679405,  0.24478038,  0.200823  ],
        [-0.36944363, -0.51452017,  0.18009658,  0.28941897, -0.15817891],
        [ 0.188

In [45]:
lp2.apply(y2, x)

Array([[[-0.22546573, -0.13162534,  0.29677063,  0.5539263 ,
          0.00525441, -0.25402257,  0.42499188, -0.2886829 ,
         -0.05022871, -0.18250518, -0.00925286, -0.01849659,
          0.20977122, -0.02310264, -0.5931318 ],
        [-0.6551649 , -0.74319994,  0.34639242,  0.85376245,
          0.12546577,  0.35563067,  0.02477333, -0.2340745 ,
         -0.37190086,  0.38618025, -0.30876082,  0.07604566,
          0.06765305, -0.19773354, -0.28735214]]], dtype=float32)

In [387]:
y['params']

{'kernel': Array([[ 0.00585502,  0.06663669, -0.02862518, ...,  0.03995296,
         -0.02848659,  0.03771213],
        [-0.01506117, -0.04920985, -0.02150657, ..., -0.00969302,
         -0.00604414, -0.06075063],
        [-0.01117269, -0.05698648, -0.02741355, ...,  0.01138836,
          0.02523999, -0.03665182],
        ...,
        [ 0.00031354,  0.03191991,  0.01749613, ..., -0.03667144,
          0.02852946,  0.02422978],
        [ 0.01442624,  0.02207879, -0.02500978, ..., -0.00365288,
          0.00623836,  0.01097682],
        [-0.03645209, -0.06231447, -0.00414094, ...,  0.03171502,
          0.06253771, -0.02757666]], dtype=float32),
 'bias': Array([0., 0., 0., ..., 0., 0., 0.], dtype=float32)}

In [38]:
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(y['params'])))

initialized parameter shapes:
 {'bias': (15,), 'kernel': (15, 5)}


In [369]:
A = np.array([[ 0,  1,  2, 3],
              [ 4,  5,  6, 7],
              [ 8,  9, 10, 11]])

B = np.array([[ 0,  1,  2,  3],
              [ 4,  5,  6,  7],
              [ 8,  9, 10, 11]])

np.einsum('ij,kj->ik', A, B)

array([[ 14,  38,  62],
       [ 38, 126, 214],
       [ 62, 214, 366]])