In [1]:
from absl import logging
import flax
import jax.numpy as jnp
import jax
import numpy as np

logging.set_verbosity(logging.INFO)

In [52]:
from transformers import FlaxGPT2LMHeadModel

model_hf = FlaxGPT2LMHeadModel.from_pretrained("openai-community/gpt2")

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

flax_model.msgpack:   0%|          | 0.00/498M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

In [5]:
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(model_hf.params)))

initialized parameter shapes:
 {'transformer': {'h': {'0': {'attn': {'c_attn': {'bias': (2304,), 'kernel': (2304, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 768)}}, 'ln_1': {'bias': (768,), 'scale': (768,)}, 'ln_2': {'bias': (768,), 'scale': (768,)}, 'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 3072)}}}, '1': {'attn': {'c_attn': {'bias': (2304,), 'kernel': (2304, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 768)}}, 'ln_1': {'bias': (768,), 'scale': (768,)}, 'ln_2': {'bias': (768,), 'scale': (768,)}, 'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 3072)}}}, '10': {'attn': {'c_attn': {'bias': (2304,), 'kernel': (2304, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 768)}}, 'ln_1': {'bias': (768,), 'scale': (768,)}, 'ln_2': {'bias': (768,), 'scale': (768,)}, 'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 3072)}}}, 

In [6]:
jnp.shape(model_hf.params['transformer']['wpe']['embedding'])

(1024, 768)

In [7]:
model_hf.params['transformer']['h']['0']['attn'].keys()

dict_keys(['c_attn', 'c_proj'])

In [370]:
from flax import linen as nn
from functools import partial

# define a full GTP2 class in Flax and see if we can replicate the original paper along with
# Karpathy

class GPTConfig:
    block_size: int = 1024
    vocab_size: int = 50257
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768

class GPTMLP(nn.Module):
    config: GPTConfig

    def setup(self):
        # Simple MLP that upscales, runs through a gelu activation,
        # and then resamples back to the n_embd size (the model size)
        #self.c_fc = nn.Dense(4 * self.config.n_embd)
        # Had to use Einsum to match the matrix multiplication of GPT2 and pytorch. They accept 
        # both shapes and then multiply by the transpose of the matrix (basically means the 
        # shape of the matrix is transpose, but the operation is the same)
        self.c_fc = nn.Einsum((4 * self.config.n_embd, self.config.n_embd), '...ij,...kj->...ik')
        #self.c_proj = nn.Dense(self.config.n_embd)
        self.c_proj = nn.Einsum((self.config.n_embd, 4 * self.config.n_embd), '...ij,...kj->...ik')

    def __call__(self, inputs):
        x = inputs

        x = self.c_fc(x)
        x = nn.gelu(x, approximate=True)
        x = self.c_proj(x)

        return x
    
class GPTAttention(nn.Module):
    config: GPTConfig

    # we will need to roll our own attention module because the built in one has a bunch of different
    # naming and structure compared to the original GPT, which just references the projection layers
    def setup(self):
        # The first thing we do is project up to 3x the model size, because we are going to split
        # the data into q, v, k
        self.c_attn = nn.Einsum((3 * self.config.n_embd, self.config.n_embd), '...ij,...kj->...ik')

        # At the end we have to project everything back to the regular model size of n_embd
        self.c_proj = nn.Einsum((self.config.n_embd, self.config.n_embd), '...ij,...kj->...ik')

    def __call__(self, inputs):
        x = inputs
        B, T, C = jnp.shape(x)
        
        # Project to qkv
        qkv = self.c_attn(x)

        q, k, v = jnp.split(qkv, 3, axis=2)

        query_length, key_length = q.shape[1], k.shape[1]

        # Now reshape with a new head "batch" dimension
        k = jnp.reshape(k, (B, T, self.config.n_head, C // self.config.n_head)) # Shape is (batch, tokens, num_heads, size of head)
        q = jnp.reshape(q, (B, T, self.config.n_head, C // self.config.n_head)) # Shape is (batch, tokens, num_heads, size of head)
        v = jnp.reshape(v, (B, T, self.config.n_head, C // self.config.n_head)) # Shape is (batch, tokens, num_heads, size of head)

        # Make the attention mask
        # TODO: For round 1 I'm just trusting the linen functions. They seem to be doing the correct thing here, but I may have to 
        # return to this for a closer look at the math if I'm not getting the GPT2 results
        c_mask = nn.make_causal_mask(jnp.ones((1, self.config.block_size), dtype="bool"), dtype="bool")
        c_mask = c_mask[:, :, :query_length, :key_length]
        c_mask = jnp.broadcast_to(c_mask, (B,) + c_mask.shape[1:])

        attention_bias = jax.lax.select(
                c_mask > 0,
                jnp.full(c_mask.shape, 0.0).astype(jnp.float32),
                jnp.full(c_mask.shape, jnp.finfo(jnp.float32).min).astype(jnp.float32),
            )

        # use the built in flax libraries to calculate the attention matrix - attention weights are not returned, but could be 
        # I think bias gives more control compared to mask - i.e. bias can be a float. They might result in the same output with a 
        # boolean mask, but I will have to test that.
        y = nn.dot_product_attention(q, k, v, bias=attention_bias)

        # Merge the heads back together
        y = y.reshape((B, T, C))

        # Project output with a FC layer
        y = self.c_proj(y)

        return y

class GPTBlock(nn.Module):
    config: GPTConfig

    def setup(self):
        self.ln_1 = nn.LayerNorm(epsilon=1e-05)
        # I might have to write this manually to get the proper number of parameters, as the old GPT2 code 
        # migh have subtle differences from the Flax implementation
        self.attn = GPTAttention(self.config)
        self.ln_2 = nn.LayerNorm(epsilon=1e-05)
        self.mlp = GPTMLP(self.config)

    def __call__(self, inputs):
        x = inputs
        x = self.ln_1(x)
        x = inputs + self.attn(x)
        inputs2 = x
        x = self.ln_2(x)
        x = inputs2 + self.mlp(x)
        return x

class GPTLayers(nn.Module):
    config: GPTConfig
    
    def setup(self):
        self.blocks = [ GPTBlock(self.config, name=str(i)) for i in range(self.config.n_layer) ]

    def __call__(self, inputs):
        x = inputs

        for block in self.blocks:
            x = block(x)
        return x

class GPTModel(nn.Module):
    config: GPTConfig

    def setup(self):
        # This is a little confusing. vocab size is the number of embeddings we need.
        # n_embd is the dimension of each embedding (i.e. capacity for learning properties
        # about this embedding token)
        # input size = (1 x self.block_size) - 1 int for each token
        # output size = then (self.block_size x self.n_embd)
        self.wte = nn.Embed(self.config.vocab_size, self.config.n_embd)
        # embed is just a randomzied parameter matrix, so can be used for positional 
        # encoding as well. I think block size is the token length.
        # This has to match the size of the previous output, as we are just adding.
        self.wpe = nn.Embed(self.config.block_size, self.config.n_embd)
        # The attention layers
        self.h = GPTLayers(self.config)
        self.ln_f = nn.LayerNorm(epsilon=1e-05)

    def __call__(self, inputs):
        x = inputs
        input_shape = jnp.shape(x)       

        x = self.wte(x)

        # For the positional encodings we need an index that is simple the position
        # of each token. This will be the same shape as the input, but will simply
        # be repeating and increasing numbers from 1.
        # jnp.atleast_2d is needed so we can initialize with a batch size of 1
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(inputs).shape[-1]), input_shape)
        x_wpe = self.wpe(position_ids)
        
        x += x_wpe

        x = self.h(x)

        x = self.ln_f(x)

        return x


class GPT(nn.Module):
    config: GPTConfig

    def setup(self):
        self.transformer = GPTModel(self.config)
        # So the Flax model does not return parameters for lm_head, because lm_head is simply the 
        # inverse operation of the initial word embedding, so we can just reuse those weights. They
        # use the term 'tied' for this. The weights are tied together.
        self.lm_head = nn.Dense(self.config.vocab_size, use_bias=False)

    def __call__(self, inputs):
        x = inputs
        x = self.transformer(x)

        # This is from the huggingface code - might as well reuse it here.
        shared_kernel = self.transformer.variables['params']['wte']['embedding'].T
        lm_logits = self.lm_head.apply({'params': {'kernel': shared_kernel}}, x)
        return lm_logits

In [371]:
GPT2 = GPT(GPTConfig)

key1, key2 = jax.random.split(jax.random.key(0), 2)

# We never want to use int64 as it signficantly slows down training
x = np.random.randint(0, 50257, (1,1024), dtype='i4')

y = GPT2.init(key2, x)

In [364]:
class testConfig:
    n_embd = 10
    n_head = 2
    block_size = 5

class testGPTAttention(nn.Module):
    config: GPTConfig

    # we will need to roll our own attention module because the built in one has a bunch of different
    # naming and structure compared to the original GPT, which just references the projection layers
    def setup(self):
        # The first thing we do is project up to 3x the model size, because we are going to split
        # the data into q, v, k
        self.c_attn = nn.Einsum((3 * self.config.n_embd, self.config.n_embd), '...ij,...kj->...ik')

        # At the end we have to project everything back to the regular model size of n_embd
        self.c_proj = nn.Einsum((self.config.n_embd, self.config.n_embd), '...ij,...kj->...ik')

    def __call__(self, inputs):
        x = inputs
        B, T, C = jnp.shape(x)
        
        # Project to qkv
        qkv = self.c_attn(x)

        q, k, v = jnp.split(qkv, 3, axis=2)

        query_length, key_length = q.shape[1], k.shape[1]

        # Now reshape with a new head "batch" dimension
        k = jnp.reshape(k, (B, T, self.config.n_head, C // self.config.n_head)) # Shape is (batch, tokens, num_heads, size of head)
        q = jnp.reshape(q, (B, T, self.config.n_head, C // self.config.n_head)) # Shape is (batch, tokens, num_heads, size of head)
        v = jnp.reshape(v, (B, T, self.config.n_head, C // self.config.n_head)) # Shape is (batch, tokens, num_heads, size of head)

        # Make the attention mask
        # TODO: For round 1 I'm just trusting the linen functions. They seem to be doing the correct thing here, but I may have to 
        # return to this for a closer look at the math if I'm not getting the GPT2 results
        c_mask = nn.make_causal_mask(jnp.ones((1, self.config.block_size), dtype="bool"), dtype="bool")
        c_mask = c_mask[:, :, :query_length, :key_length]
        c_mask = jnp.broadcast_to(c_mask, (B,) + c_mask.shape[1:])

        attention_bias = jax.lax.select(
                c_mask > 0,
                jnp.full(c_mask.shape, 0.0).astype(jnp.float32),
                jnp.full(c_mask.shape, jnp.finfo(jnp.float32).min).astype(jnp.float32),
            )

        # use the built in flax libraries to calculate the attention matrix - attention weights are not returned, but could be 
        # I think bias gives more control compared to mask - i.e. bias can be a float. They might result in the same output with a 
        # boolean mask, but I will have to test that.
        y = nn.dot_product_attention(q, k, v, bias=attention_bias)

        # Merge the heads back together
        y = y.reshape((B, T, C))

        # Project output with a FC layer
        y = self.c_proj(y)

        return y

In [365]:
from typing import Any

class FlaxConv1D(nn.Module):
    features: int
    use_bias: bool = True
    dtype: Any = jnp.float32
    precision: Any = None

    @nn.compact
    def __call__(self, inputs):
        inputs = jnp.asarray(inputs, self.dtype)
        kernel = self.param("kernel", jax.nn.initializers.normal(stddev=0.02), (self.features, inputs.shape[-1]))
        kernel = jnp.asarray(kernel.transpose(), self.dtype)
        y = jax.lax.dot_general(inputs, kernel, (((inputs.ndim - 1,), (0,)), ((), ())), precision=self.precision)
        if self.use_bias:
            bias = self.param("bias", jax.nn.initializers.zeros, (self.features,))
            bias = jnp.asarray(bias, self.dtype)
            y = y + bias
        return y

class FlaxGPT2Attention(nn.Module):
    config: testConfig
    dtype: jnp.dtype = jnp.float32
    causal: bool = True
    is_cross_attention: bool = False

    def setup(self):
        config = self.config
        self.embed_dim = config.n_embd
        self.num_heads = config.n_head
        self.head_dim = self.embed_dim // self.num_heads

        if self.is_cross_attention:
            self.c_attn = FlaxConv1D(2 * self.embed_dim, dtype=self.dtype)
            self.q_attn = FlaxConv1D(self.embed_dim, dtype=self.dtype)
        else:
            self.c_attn = FlaxConv1D(3 * self.embed_dim, dtype=self.dtype)
        self.c_proj = FlaxConv1D(self.embed_dim, dtype=self.dtype)

        if self.causal:
            self.causal_mask =nn.make_causal_mask(
                jnp.ones((1, config.block_size), dtype="bool"), dtype="bool"
            )

    def _split_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))

    def _merge_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))


    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        output_attentions: bool = False,
    ):
        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        batch_size = hidden_states.shape[0]

        qkv_out = self.c_attn(hidden_states)
        query, key, value = jnp.split(qkv_out, 3, axis=2)

        query = self._split_heads(query)
        key = self._split_heads(key)
        value = self._split_heads(value)

        query_length, key_length = query.shape[1], key.shape[1]

        if self.causal:
            causal_mask = self.causal_mask[:, :, :query_length, :key_length]
            causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])

        attention_mask = causal_mask

        # transform boolean mask into float mask
        if attention_mask is not None:
            attention_bias = jax.lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
            )
        else:
            attention_bias = None

        # usual dot product attention
        attn_weights = nn.dot_product_attention_weights(
            query,
            key,
            bias=attention_bias,
            dtype=self.dtype,
            precision=None,
        )

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.c_proj(attn_output)

        outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
        return outputs

In [366]:
ln1 = testGPTAttention(testConfig)

key1, key2 = jax.random.split(jax.random.key(0), 2)

# We never want to use int64 as it signficantly slows down training
x = np.ones((1, 5, 10))

y = ln1.init(key2, x)

In [367]:
ln2 = FlaxGPT2Attention(testConfig)

key1, key2 = jax.random.split(jax.random.key(0), 2)

# We never want to use int64 as it signficantly slows down training
x = np.ones((1, 5, 10))

y2 = ln2.init(key2, x)

In [360]:
x1 = np.array([[[1, 2, 3, 4, 5, 6, 7, 8, 9, 0], [6, 7, 8, 9, 10, 11, 12, 13, 14, 15]]])

x1.shape

(1, 2, 10)

In [368]:
ln1.apply({'params': y2['params']}, x1)

Array([[[-0.02889265, -0.00587743,  0.013701  , -0.05318468,
         -0.01398449,  0.00399514,  0.00705284, -0.01679477,
         -0.00136347,  0.00796388],
        [-0.04165645, -0.01023988,  0.02881903, -0.07491028,
         -0.01374639,  0.00323178,  0.02439473, -0.03530604,
         -0.00574531,  0.00888677]]], dtype=float32)

In [369]:
ln2.apply({'params': y2['params']}, x1)

(Array([[[-0.02889265, -0.00587743,  0.013701  , -0.05318468,
          -0.01398449,  0.00399514,  0.00705284, -0.01679477,
          -0.00136347,  0.00796388],
         [-0.04165644, -0.01023988,  0.02881903, -0.07491027,
          -0.01374639,  0.00323177,  0.02439473, -0.03530603,
          -0.00574531,  0.00888677]]], dtype=float32),)

In [230]:
y['params']['kernel']

Array([[ 1.38365766e-02, -1.00438399e-02, -2.59855273e-03,
        -2.84782257e-02,  2.75602620e-02],
       [-5.10251932e-02,  2.19166707e-02, -1.04930252e-02,
         7.55112758e-03, -1.42556112e-02],
       [ 1.18566537e-02, -1.46680195e-02,  1.62017029e-02,
         2.88543049e-02, -1.16017004e-02],
       [ 1.61728784e-02, -7.08742533e-03, -1.32274013e-02,
        -2.91815065e-02, -5.62517811e-03],
       [-3.61472508e-03, -1.01082027e-02,  1.73087647e-05,
        -7.02232495e-02,  1.17731094e-02],
       [-3.53554785e-02,  1.93876233e-02,  6.20354293e-03,
         7.52076088e-03,  6.82975259e-03],
       [ 1.06677208e-02,  6.50583627e-03, -9.52924788e-03,
         1.89422257e-03, -9.77260433e-03],
       [ 2.60021482e-02,  1.21949790e-02,  8.63404572e-03,
        -4.96751908e-03, -7.38750445e-03],
       [ 1.95287429e-02,  6.33169524e-03, -2.40490027e-02,
         9.76687111e-03,  3.65712978e-02],
       [-2.29867529e-02, -2.23886129e-02, -3.12390504e-03,
        -1.47810709e-02

In [376]:
res = GPT2.apply({'params': model_hf.params}, np.random.randint(0, 50257, (2,1024), dtype='i4'))

In [180]:
res.shape

(2, 1024, 50257)

In [377]:
import tiktoken

num_return_sequences = 5
max_length = 30

enc = tiktoken.get_encoding('gpt2')
tokens = enc.encode("Neuroscience is ")

tokens = jnp.array(tokens, dtype="i4")

x = jnp.atleast_2d(tokens).repeat(num_return_sequences, 0)

In [379]:
x

Array([[15496,    11,   314,  1101,   257,  3303,  2746,    11,   290,
          340,   338,   257,  5802,   530,    13,  1892,   691,   326,
          475,   340,  1595,   470,   651,  1088,   262,  1109,   326,
          661,   508,   389],
       [15496,    11,   314,  1101,   257,  3303,  2746,    11,   257,
         1048,   526, 50256,   464,   649,   512,  1923,   329,   262,
         2097,  6796,  5531,  2523,   262,  2097, 10834,  6011,   683,
          720,  8054,    11],
       [15496,    11,   314,  1101,   257,  3303,  2746,    11,   616,
         2460,   290,   314,   389,    13,  3363,    11,  3303, 21128,
          318,   257,  5032,    13,   887,  3360,   661,   787, 10135,
          351,   340,    13],
       [15496,    11,   314,  1101,   257,  3303,  2746,    11,   290,
          314,  1254,   588,   340,   318,   257,   845,  1593,   636,
          286,   616,  1693,   553,   531,  1770,    13,  5741,    13,
          366,   464,   517],
       [15496,    11,   314,

In [183]:
key = jax.random.key(0)

In [378]:
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

key = jax.random.key(57)
# This needs to use the jax loop functions to be fast
while x.shape[1] < max_length:

    logits = GPT2.apply({'params': model_hf.params}, x)

    logits = logits[:, -1, :]

    #probs = nn.softmax(logits)

    topk_probs, topk_indices = jax.lax.top_k(logits, 50)

    dist = tfd.Multinomial(1, logits=topk_probs)

    key, samp_key = jax.random.split(key)
    ix=dist.sample(seed=samp_key)

    xcol=jnp.atleast_2d(topk_indices[ix.astype('bool')]).T

    x = jnp.append(x, xcol, -1)

In [303]:
dist = tfd.Multinomial(1, logits=topk_probs)
dist.sample(seed=samp_key)

Array([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 

In [386]:
print(enc.decode(x[0,:].tolist()))

Hello, I'm a language model, and it's a tough one. Not only that but it doesn't get around the fact that people who are


In [170]:
logits = GPT2.apply({'params': model_hf.params}, x)


In [171]:
logits.shape

(5, 8, 50257)

In [137]:
ix.shape

(5, 50)

In [85]:
res.shape

(5, 8, 50257)

In [67]:
from jax.tree_util import tree_structure
print(tree_structure(y['params']['transformer']['h']['0']))

PyTreeDef({'attn': {'c_attn': {'bias': *, 'kernel': *}, 'c_proj': {'bias': *, 'kernel': *}}, 'ln_1': {'bias': *, 'scale': *}, 'ln_2': {'bias': *, 'scale': *}, 'mlp': {'c_fc': {'bias': *, 'kernel': *}, 'c_proj': {'bias': *, 'kernel': *}}})


In [170]:
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(y['params']['transformer'])))

initialized parameter shapes:
 {'h': {'0': {'attn': {'c_attn': {'bias': (1152,), 'kernel': (1152, 384)}, 'c_proj': {'bias': (384,), 'kernel': (384, 384)}}, 'ln_1': {'bias': (384,), 'scale': (384,)}, 'ln_2': {'bias': (384,), 'scale': (384,)}, 'mlp': {'c_fc': {'bias': (1536,), 'kernel': (1536, 384)}, 'c_proj': {'bias': (384,), 'kernel': (384, 1536)}}}, '1': {'attn': {'c_attn': {'bias': (1152,), 'kernel': (1152, 384)}, 'c_proj': {'bias': (384,), 'kernel': (384, 384)}}, 'ln_1': {'bias': (384,), 'scale': (384,)}, 'ln_2': {'bias': (384,), 'scale': (384,)}, 'mlp': {'c_fc': {'bias': (1536,), 'kernel': (1536, 384)}, 'c_proj': {'bias': (384,), 'kernel': (384, 1536)}}}, '2': {'attn': {'c_attn': {'bias': (1152,), 'kernel': (1152, 384)}, 'c_proj': {'bias': (384,), 'kernel': (384, 384)}}, 'ln_1': {'bias': (384,), 'scale': (384,)}, 'ln_2': {'bias': (384,), 'scale': (384,)}, 'mlp': {'c_fc': {'bias': (1536,), 'kernel': (1536, 384)}, 'c_proj': {'bias': (384,), 'kernel': (384, 1536)}}}, '3': {'attn': {'c

In [171]:
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(model_hf.params['transformer'])))

initialized parameter shapes:
 {'h': {'0': {'attn': {'c_attn': {'bias': (2304,), 'kernel': (2304, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 768)}}, 'ln_1': {'bias': (768,), 'scale': (768,)}, 'ln_2': {'bias': (768,), 'scale': (768,)}, 'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 3072)}}}, '1': {'attn': {'c_attn': {'bias': (2304,), 'kernel': (2304, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 768)}}, 'ln_1': {'bias': (768,), 'scale': (768,)}, 'ln_2': {'bias': (768,), 'scale': (768,)}, 'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 3072)}}}, '10': {'attn': {'c_attn': {'bias': (2304,), 'kernel': (2304, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 768)}}, 'ln_1': {'bias': (768,), 'scale': (768,)}, 'ln_2': {'bias': (768,), 'scale': (768,)}, 'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 3072)}}}, '11': {'attn': {

In [41]:
# Seems like pytorch Linear is not equal to flax Dense. The matrices are defined transpose. So, we can't actually use
# Dense with the original GPT2 weights

x = jax.random.uniform(jax.random.key(0), (1, 2, 5))

# This will transpose the kernel and multiply by the input,
# This should be the same as the GPT weights
lp = nn.Einsum((3*5, 5), '...ij,...kj->...ik')

y = lp.init(jax.random.key(0), x)

In [42]:
lp.apply(y, x)

Array([[[ 0.40416887, -0.15771458, -0.1418275 ,  0.21624644,
          0.01153318,  0.09767484, -0.16552436, -0.06833879,
         -0.07042018, -0.13803177, -0.00222424,  0.2503722 ,
         -0.3504156 ,  0.18865553, -0.33016038],
        [ 0.62610245, -0.35361993, -0.40769407,  0.22174977,
          0.1036396 ,  0.2302788 , -0.25040534, -0.5411675 ,
         -0.35929912,  0.05858286,  0.03919724,  0.3880062 ,
         -0.13242133,  0.4200734 , -0.47516495]]], dtype=float32)

In [43]:
lp2 = nn.Dense(3*5)

y2 = lp2.init(jax.random.key(0), x)

In [172]:
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(y['params'])))

initialized parameter shapes:
 {'transformer': {'h': {'0': {'attn': {'c_attn': {'bias': (1152,), 'kernel': (1152, 384)}, 'c_proj': {'bias': (384,), 'kernel': (384, 384)}}, 'ln_1': {'bias': (384,), 'scale': (384,)}, 'ln_2': {'bias': (384,), 'scale': (384,)}, 'mlp': {'c_fc': {'bias': (1536,), 'kernel': (1536, 384)}, 'c_proj': {'bias': (384,), 'kernel': (384, 1536)}}}, '1': {'attn': {'c_attn': {'bias': (1152,), 'kernel': (1152, 384)}, 'c_proj': {'bias': (384,), 'kernel': (384, 384)}}, 'ln_1': {'bias': (384,), 'scale': (384,)}, 'ln_2': {'bias': (384,), 'scale': (384,)}, 'mlp': {'c_fc': {'bias': (1536,), 'kernel': (1536, 384)}, 'c_proj': {'bias': (384,), 'kernel': (384, 1536)}}}, '2': {'attn': {'c_attn': {'bias': (1152,), 'kernel': (1152, 384)}, 'c_proj': {'bias': (384,), 'kernel': (384, 384)}}, 'ln_1': {'bias': (384,), 'scale': (384,)}, 'ln_2': {'bias': (384,), 'scale': (384,)}, 'mlp': {'c_fc': {'bias': (1536,), 'kernel': (1536, 384)}, 'c_proj': {'bias': (384,), 'kernel': (384, 1536)}}}, '

In [15]:
jax.tree.map(lambda x: x.shape, y['params'])


{'transformer': {'h': {'0': {'attn': {'c_attn': {'bias': (2304,),
      'kernel': (2304, 768)},
     'c_proj': {'bias': (768,), 'kernel': (768, 768)}},
    'ln_1': {'bias': (768,), 'scale': (768,)},
    'ln_2': {'bias': (768,), 'scale': (768,)},
    'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)},
     'c_proj': {'bias': (768,), 'kernel': (768, 3072)}}},
   '1': {'attn': {'c_attn': {'bias': (2304,), 'kernel': (2304, 768)},
     'c_proj': {'bias': (768,), 'kernel': (768, 768)}},
    'ln_1': {'bias': (768,), 'scale': (768,)},
    'ln_2': {'bias': (768,), 'scale': (768,)},
    'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)},
     'c_proj': {'bias': (768,), 'kernel': (768, 3072)}}},
   '10': {'attn': {'c_attn': {'bias': (2304,), 'kernel': (2304, 768)},
     'c_proj': {'bias': (768,), 'kernel': (768, 768)}},
    'ln_1': {'bias': (768,), 'scale': (768,)},
    'ln_2': {'bias': (768,), 'scale': (768,)},
    'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)},
     'c_pro

In [16]:
jax.tree.map(lambda x: x.shape, model_hf.params)


{'transformer': {'h': {'0': {'attn': {'c_attn': {'bias': (2304,),
      'kernel': (2304, 768)},
     'c_proj': {'bias': (768,), 'kernel': (768, 768)}},
    'ln_1': {'bias': (768,), 'scale': (768,)},
    'ln_2': {'bias': (768,), 'scale': (768,)},
    'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)},
     'c_proj': {'bias': (768,), 'kernel': (768, 3072)}}},
   '1': {'attn': {'c_attn': {'bias': (2304,), 'kernel': (2304, 768)},
     'c_proj': {'bias': (768,), 'kernel': (768, 768)}},
    'ln_1': {'bias': (768,), 'scale': (768,)},
    'ln_2': {'bias': (768,), 'scale': (768,)},
    'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)},
     'c_proj': {'bias': (768,), 'kernel': (768, 3072)}}},
   '10': {'attn': {'c_attn': {'bias': (2304,), 'kernel': (2304, 768)},
     'c_proj': {'bias': (768,), 'kernel': (768, 768)}},
    'ln_1': {'bias': (768,), 'scale': (768,)},
    'ln_2': {'bias': (768,), 'scale': (768,)},
    'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)},
     'c_pro