In [1]:
from absl import logging
import flax
import jax.numpy as jnp
import jax
import numpy as np

logging.set_verbosity(logging.INFO)

In [5]:
from transformers import FlaxGPT2LMHeadModel

model_hf = FlaxGPT2LMHeadModel.from_pretrained('gpt2')

AttributeError: 'FlaxGPT2LMHeadModel' object has no attribute 'state_dict'

In [57]:
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(model_hf.params)))

initialized parameter shapes:
 {'transformer': {'h': {'0': {'attn': {'c_attn': {'bias': (2304,), 'kernel': (2304, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 768)}}, 'ln_1': {'bias': (768,), 'scale': (768,)}, 'ln_2': {'bias': (768,), 'scale': (768,)}, 'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 3072)}}}, '1': {'attn': {'c_attn': {'bias': (2304,), 'kernel': (2304, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 768)}}, 'ln_1': {'bias': (768,), 'scale': (768,)}, 'ln_2': {'bias': (768,), 'scale': (768,)}, 'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 3072)}}}, '10': {'attn': {'c_attn': {'bias': (2304,), 'kernel': (2304, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 768)}}, 'ln_1': {'bias': (768,), 'scale': (768,)}, 'ln_2': {'bias': (768,), 'scale': (768,)}, 'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 3072)}}}, 

In [191]:
jnp.shape(model_hf.params['transformer']['wpe']['embedding'])

(1024, 768)

In [267]:
model_hf.params['transformer']['h']['0']['attn'].keys()

dict_keys(['c_attn', 'c_proj'])

In [202]:
jnp.broadcast_to(jnp.arange(10, dtype="i4")[None, :], (32, 10))

Array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5,

In [280]:
from flax import linen as nn
from functools import partial

# define a full GTP2 class in Flax and see if we can replicate the original paper along with
# Karpathy

class GPTConfig:
    block_size: int = 256
    vocab_size: int = 65
    n_layer: int = 6
    n_head: int = 6
    n_embd: int = 384

class GPTMLP(nn.Module):
    config: GPTConfig

    def setup(self):
        # Simple MLP that upscales, runs through a gelu activation,
        # and then resamples back to the n_embd size (the model size)
        self.c_fc = nn.Dense(4 * self.config.n_embd)
        self.c_proj = nn.Dense(self.config.n_embd)

    def __call__(self, inputs):
        x = inputs

        x = self.c_fc(x)
        x = nn.gelu(x, approximate=True)
        x = self.c_proj(x)

        return x

class GPTBlock(nn.Module):
    config: GPTConfig

    def setup(self):
        self.ln_1 = nn.LayerNorm()
        # I might have to write this manually to get the proper number of parameters, as the old GPT2 code 
        # migh have subtle differences from the Flax implementation
        #self.attn = nn.MultiHeadAttention(num_heads=self.config.n_head, qkv_features=self.config.n_head)
        self.ln_2 = nn.LayerNorm()
        self.mlp = GPTMLP(self.config)

    def __call__(self, inputs):
        x = inputs
        x = self.ln_1(x)
        #x = self.attn(x)
        x = self.ln_2(x)
        x = self.mlp(x)
        return x

class GPTLayers(nn.Module):
    config: GPTConfig
    
    def setup(self):
        self.blocks = [ GPTBlock(self.config, name=str(i)) for i in range(self.config.n_layer) ]

    def __call__(self, inputs):
        x = inputs

        for block in self.blocks:
            x = block(x)
        return inputs

class GPTModel(nn.Module):
    config: GPTConfig

    def setup(self):
        # This is a little confusing. vocab size is the number of embeddings we need.
        # n_embd is the dimension of each embedding (i.e. capacity for learning properties
        # about this embedding token)
        # input size = (1 x self.block_size) - 1 int for each token
        # output size = then (self.block_size x self.n_embd)
        self.wte = nn.Embed(self.config.vocab_size, self.config.n_embd)
        # embed is just a randomzied parameter matrix, so can be used for positional 
        # encoding as well. I think block size is the token length.
        # This has to match the size of the previous output, as we are just adding.
        self.wpe = nn.Embed(self.config.block_size, self.config.n_embd)
        # The attention layers
        self.h = GPTLayers(self.config)
        self.ln_f = nn.LayerNorm()

    def __call__(self, inputs):
        x = inputs
        x = self.wte(x)
        
        #x += self.wpe(x)

        x = self.h(x)

        x = self.ln_f(x)

        return x


class GPT(nn.Module):
    config: GPTConfig

    def setup(self):
        self.transformer = GPTModel(self.config)

    def __call__(self, inputs):
        x = inputs
        x = self.transformer(x)
        return x

In [281]:
GPT2 = GPT(GPTConfig)

key1, key2 = jax.random.split(jax.random.key(0), 2)

# We never want to use int64 as it signficantly slows down training
x = np.random.randint(0, 1000, (4,4), dtype='int32')

y = GPT2.init(key2, x)

In [282]:
from jax.tree_util import tree_structure
print(tree_structure(y['params']))

PyTreeDef({'transformer': {'h': {'0': {'ln_1': {'bias': *, 'scale': *}, 'ln_2': {'bias': *, 'scale': *}, 'mlp': {'c_fc': {'bias': *, 'kernel': *}, 'c_proj': {'bias': *, 'kernel': *}}}, '1': {'ln_1': {'bias': *, 'scale': *}, 'ln_2': {'bias': *, 'scale': *}, 'mlp': {'c_fc': {'bias': *, 'kernel': *}, 'c_proj': {'bias': *, 'kernel': *}}}, '2': {'ln_1': {'bias': *, 'scale': *}, 'ln_2': {'bias': *, 'scale': *}, 'mlp': {'c_fc': {'bias': *, 'kernel': *}, 'c_proj': {'bias': *, 'kernel': *}}}, '3': {'ln_1': {'bias': *, 'scale': *}, 'ln_2': {'bias': *, 'scale': *}, 'mlp': {'c_fc': {'bias': *, 'kernel': *}, 'c_proj': {'bias': *, 'kernel': *}}}, '4': {'ln_1': {'bias': *, 'scale': *}, 'ln_2': {'bias': *, 'scale': *}, 'mlp': {'c_fc': {'bias': *, 'kernel': *}, 'c_proj': {'bias': *, 'kernel': *}}}, '5': {'ln_1': {'bias': *, 'scale': *}, 'ln_2': {'bias': *, 'scale': *}, 'mlp': {'c_fc': {'bias': *, 'kernel': *}, 'c_proj': {'bias': *, 'kernel': *}}}}, 'ln_f': {'bias': *, 'scale': *}, 'wte': {'embedding': 

In [293]:
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(y['params']['transformer']['h']['0']['mlp'])))

initialized parameter shapes:
 {'c_fc': {'bias': (1536,), 'kernel': (384, 1536)}, 'c_proj': {'bias': (384,), 'kernel': (1536, 384)}}


In [296]:
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(model_hf.params['transformer']['h']['0']['mlp'])))

initialized parameter shapes:
 {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 3072)}}


In [386]:
# Seems like pytorch Linear is not equal to flax Dense. The matrices are defined transpose. So, we can't actually use
# Dense with the original GPT2 weights

x = jax.random.uniform(jax.random.key(0), (32, 64, 384))

# This will transpose the kernel and multiply by the input,
# This should be the same as the GPT weights
lp = nn.Einsum((3*384, 384), '...ij,...kj->ik')

y = lp.init(jax.random.key(0), x)

In [387]:
y['params']

{'kernel': Array([[ 0.00585502,  0.06663669, -0.02862518, ...,  0.03995296,
         -0.02848659,  0.03771213],
        [-0.01506117, -0.04920985, -0.02150657, ..., -0.00969302,
         -0.00604414, -0.06075063],
        [-0.01117269, -0.05698648, -0.02741355, ...,  0.01138836,
          0.02523999, -0.03665182],
        ...,
        [ 0.00031354,  0.03191991,  0.01749613, ..., -0.03667144,
          0.02852946,  0.02422978],
        [ 0.01442624,  0.02207879, -0.02500978, ..., -0.00365288,
          0.00623836,  0.01097682],
        [-0.03645209, -0.06231447, -0.00414094, ...,  0.03171502,
          0.06253771, -0.02757666]], dtype=float32),
 'bias': Array([0., 0., 0., ..., 0., 0., 0.], dtype=float32)}

In [388]:
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(y['params'])))

initialized parameter shapes:
 {'bias': (1152,), 'kernel': (1152, 384)}


In [369]:
A = np.array([[ 0,  1,  2, 3],
              [ 4,  5,  6, 7],
              [ 8,  9, 10, 11]])

B = np.array([[ 0,  1,  2,  3],
              [ 4,  5,  6,  7],
              [ 8,  9, 10, 11]])

np.einsum('ij,kj->ik', A, B)

array([[ 14,  38,  62],
       [ 38, 126, 214],
       [ 62, 214, 366]])