In [2]:
from absl import logging
import flax
import jax.numpy as jnp
import jax
import numpy as np

logging.set_verbosity(logging.INFO)

In [3]:
from transformers import FlaxGPT2LMHeadModel

model_hf = FlaxGPT2LMHeadModel.from_pretrained("openai-community/gpt2")

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(model_hf.params)))

initialized parameter shapes:
 {'transformer': {'h': {'0': {'attn': {'c_attn': {'bias': (2304,), 'kernel': (2304, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 768)}}, 'ln_1': {'bias': (768,), 'scale': (768,)}, 'ln_2': {'bias': (768,), 'scale': (768,)}, 'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 3072)}}}, '1': {'attn': {'c_attn': {'bias': (2304,), 'kernel': (2304, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 768)}}, 'ln_1': {'bias': (768,), 'scale': (768,)}, 'ln_2': {'bias': (768,), 'scale': (768,)}, 'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 3072)}}}, '10': {'attn': {'c_attn': {'bias': (2304,), 'kernel': (2304, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 768)}}, 'ln_1': {'bias': (768,), 'scale': (768,)}, 'ln_2': {'bias': (768,), 'scale': (768,)}, 'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 3072)}}}, 

In [6]:
jnp.shape(model_hf.params['transformer']['wpe']['embedding'])

(1024, 768)

In [7]:
model_hf.params['transformer']['h']['0']['attn'].keys()

dict_keys(['c_attn', 'c_proj'])

In [4]:
from flax import linen as nn
from functools import partial

# define a full GTP2 class in Flax and see if we can replicate the original paper along with
# Karpathy

class GPTConfig:
    block_size: int = 1024
    vocab_size: int = 50257
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768

class GPTMLP(nn.Module):
    config: GPTConfig

    def setup(self):
        # Simple MLP that upscales, runs through a gelu activation,
        # and then resamples back to the n_embd size (the model size)
        #self.c_fc = nn.Dense(4 * self.config.n_embd)
        # Had to use Einsum to match the matrix multiplication of GPT2 and pytorch. They accept 
        # both shapes and then multiply by the transpose of the matrix (basically means the 
        # shape of the matrix is transpose, but the operation is the same)
        self.c_fc = nn.Einsum((4 * self.config.n_embd, self.config.n_embd), '...ij,...kj->...ik')
        #self.c_proj = nn.Dense(self.config.n_embd)
        self.c_proj = nn.Einsum((self.config.n_embd, 4 * self.config.n_embd), '...ij,...kj->...ik')

    def __call__(self, inputs):
        x = inputs

        x = self.c_fc(x)
        x = nn.gelu(x, approximate=True)
        x = self.c_proj(x)

        return x
    
class GPTAttention(nn.Module):
    config: GPTConfig

    # we will need to roll our own attention module because the built in one has a bunch of different
    # naming and structure compared to the original GPT, which just references the projection layers
    def setup(self):
        # The first thing we do is project up to 3x the model size, because we are going to split
        # the data into q, v, k
        self.c_attn = nn.Einsum((3 * self.config.n_embd, self.config.n_embd), '...ij,...kj->...ik')

        # At the end we have to project everything back to the regular model size of n_embd
        self.c_proj = nn.Einsum((self.config.n_embd, self.config.n_embd), '...ij,...kj->...ik')

    def __call__(self, inputs):
        x = inputs
        B, T, C = jnp.shape(x)
        
        # Project to qkv
        qkv = self.c_attn(x)

        q, k, v = jnp.split(qkv, 3, axis=2)

        query_length, key_length = q.shape[1], k.shape[1]

        # Now reshape with a new head "batch" dimension
        k = jnp.reshape(k, (B, T, self.config.n_head, C // self.config.n_head)) # Shape is (batch, tokens, num_heads, size of head)
        q = jnp.reshape(q, (B, T, self.config.n_head, C // self.config.n_head)) # Shape is (batch, tokens, num_heads, size of head)
        v = jnp.reshape(v, (B, T, self.config.n_head, C // self.config.n_head)) # Shape is (batch, tokens, num_heads, size of head)

        # Make the attention mask
        # TODO: For round 1 I'm just trusting the linen functions. They seem to be doing the correct thing here, but I may have to 
        # return to this for a closer look at the math if I'm not getting the GPT2 results
        # Just copied this from the huggingface code to be consistent. First part just broadcasts the causal masks to the batch
        # dimensions (replicating it as a lower triangular matrix of truths). The attention_bias is a stripped down version of the 
        # huggingface code, but the bias has to be floats. They bias the attention softmax, so basically set to -inf where you
        # want it ignored. This will need to be OR'd with an attention mask in certain situations, like encoder/decoder networks.
        c_mask = nn.make_causal_mask(jnp.ones((1, self.config.block_size), dtype="bool"), dtype="bool")
        c_mask = c_mask[:, :, :query_length, :key_length]
        c_mask = jnp.broadcast_to(c_mask, (B,) + c_mask.shape[1:])

        attention_bias = jax.lax.select(
                c_mask > 0,
                jnp.full(c_mask.shape, 0.0).astype(jnp.float32),
                jnp.full(c_mask.shape, jnp.finfo(jnp.float32).min).astype(jnp.float32),
            )

        # use the built in flax libraries to calculate the attention matrix - attention weights are not returned, but could be 
        # I think bias gives more control compared to mask - i.e. bias can be a float. They might result in the same output with a 
        # boolean mask, but I will have to test that.
        y = nn.dot_product_attention(q, k, v, bias=attention_bias)

        # Merge the heads back together
        y = y.reshape((B, T, C))

        # Project output with a FC layer
        y = self.c_proj(y)

        return y

class GPTBlock(nn.Module):
    config: GPTConfig

    def setup(self):
        self.ln_1 = nn.LayerNorm(epsilon=1e-05)
        # I might have to write this manually to get the proper number of parameters, as the old GPT2 code 
        # migh have subtle differences from the Flax implementation
        self.attn = GPTAttention(self.config)
        self.ln_2 = nn.LayerNorm(epsilon=1e-05)
        self.mlp = GPTMLP(self.config)

    def __call__(self, inputs):
        x = inputs
        x = self.ln_1(x)
        x = inputs + self.attn(x)
        inputs2 = x
        x = self.ln_2(x)
        x = inputs2 + self.mlp(x)
        return x

class GPTLayers(nn.Module):
    config: GPTConfig
    
    def setup(self):
        self.blocks = [ GPTBlock(self.config, name=str(i)) for i in range(self.config.n_layer) ]

    def __call__(self, inputs):
        x = inputs

        for block in self.blocks:
            x = block(x)
        return x

class GPTModel(nn.Module):
    config: GPTConfig

    def setup(self):
        # This is a little confusing. vocab size is the number of embeddings we need.
        # n_embd is the dimension of each embedding (i.e. capacity for learning properties
        # about this embedding token)
        # input size = (1 x self.block_size) - 1 int for each token
        # output size = then (self.block_size x self.n_embd)
        self.wte = nn.Embed(self.config.vocab_size, self.config.n_embd)
        # embed is just a randomzied parameter matrix, so can be used for positional 
        # encoding as well. I think block size is the token length.
        # This has to match the size of the previous output, as we are just adding.
        self.wpe = nn.Embed(self.config.block_size, self.config.n_embd)
        # The attention layers
        self.h = GPTLayers(self.config)
        self.ln_f = nn.LayerNorm(epsilon=1e-05)

    def __call__(self, inputs):
        x = inputs
        input_shape = jnp.shape(x)       

        x = self.wte(x)

        # For the positional encodings we need an index that is simple the position
        # of each token. This will be the same shape as the input, but will simply
        # be repeating and increasing numbers from 1.
        # jnp.atleast_2d is needed so we can initialize with a batch size of 1
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(inputs).shape[-1]), input_shape)
        x_wpe = self.wpe(position_ids)
        
        x += x_wpe

        x = self.h(x)

        x = self.ln_f(x)

        return x


class GPT(nn.Module):
    config: GPTConfig

    def setup(self):
        self.transformer = GPTModel(self.config)
        # So the Flax model does not return parameters for lm_head, because lm_head is simply the 
        # inverse operation of the initial word embedding, so we can just reuse those weights. They
        # use the term 'tied' for this. The weights are tied together.
        self.lm_head = nn.Dense(self.config.vocab_size, use_bias=False)

    def __call__(self, inputs):
        x = inputs
        x = self.transformer(x)

        # This is from the huggingface code - might as well reuse it here.
        shared_kernel = self.transformer.variables['params']['wte']['embedding'].T
        lm_logits = self.lm_head.apply({'params': {'kernel': shared_kernel}}, x)
        return lm_logits

In [5]:
GPT2 = GPT(GPTConfig)

key1, key2 = jax.random.split(jax.random.key(0), 2)

# We never want to use int64 as it signficantly slows down training
x = np.random.randint(0, 50257, (1,1024), dtype='i4')

y = GPT2.init(key2, x)

In [376]:
res = GPT2.apply({'params': model_hf.params}, np.random.randint(0, 50257, (2,1024), dtype='i4'))

In [180]:
res.shape

(2, 1024, 50257)

In [138]:
import tiktoken

num_return_sequences = 5
max_length = 20

enc = tiktoken.get_encoding('gpt2')
tokens = enc.encode("A pulsar is")

tokens = jnp.array(tokens, dtype="i4")

x = jnp.atleast_2d(tokens).repeat(num_return_sequences, 0)

In [125]:
x

Array([[   32, 22271,   283,   318,   257,  2099,   286, 18758],
       [   32, 22271,   283,   318,   257, 22271,   283,   326]],      dtype=int32)

In [139]:
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

key = jax.random.key(2)
# This needs to use the jax loop functions to be fast
# This will be crazy slow without @jit compilation
while x.shape[1] < max_length:

    #logits = GPT2.apply({'params': model_hf.params}, x)
    logits = GPT2.apply({'params': model_hf.params}, x)

    logits = logits[:, -1, :]

    #probs = nn.softmax(logits)

    topk_probs, topk_indices = jax.lax.top_k(logits, 50)

    #dist = tfd.Multinomial(1, logits=topk_probs)

    #key, samp_key = jax.random.split(key)
    #ix=dist.sample(seed=samp_key)

    ix = jnp.expand_dims(jax.random.categorical(key, topk_probs, axis=-1), axis=-1)

    #print(ix)

    xcol=jnp.atleast_2d(jnp.take_along_axis(topk_indices, ix.astype(jnp.int32), axis=-1))

    #print(xcol.shape)

    x = jnp.append(x, xcol, -1)

In [45]:
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions

pad_width = max_length - x.shape[-1]

def generate_text(model_hf, max_length, key, initial_input):
    batch_size = initial_input.shape[0]
    initial_length = initial_input.shape[1]
    
    # Pad the initial input to max_length outside of JAX's traced functions
    #pad_width = max_length - initial_length
    x_padded = jnp.pad(initial_input, ((0, 0), (0, pad_width)), mode='constant', constant_values=0)
    
    def body_fn(state):
        x, key, i = state
        
        # Slice x to the current length using dynamic slicing
        x_slice = jax.lax.dynamic_slice(x, (0, 0), (batch_size, i + 1))
        
        logits = GPT2.apply({'params': params}, x_slice)
        logits = logits[:, -1, :]
        
        topk_probs, topk_indices = jax.lax.top_k(logits, 50)
        
        dist = tfd.Categorical(logits=topk_probs)
        
        key, samp_key = jax.random.split(key)
        ix = dist.sample(seed=samp_key)
        
        xcol = topk_indices[jnp.arange(ix.shape[0]), ix][:, None]
        
        # Update x with new token at position i using dynamic update
        x = jax.lax.dynamic_update_slice(x, xcol, (0, i))
        
        return x, key, i + 1
    
    def cond_fn(state):
        x, key, i = state
        return i < max_length

    # Initialize state with padded input
    state = (x_padded, key, initial_length)
    
    # Run the while loop
    x, key, _ = jax.lax.while_loop(cond_fn, body_fn, state)
    
    # Return the sequence up to max_length
    return x[:, :max_length]

# Example usage:
key = jax.random.PRNGKey(57)
initial_input = jnp.array(x, dtype=jnp.int32)  # Replace with your initial token(s)
max_length = 50  # Replace with desired sequence length

generated_sequence = jax.jit(generate_text)(model_hf.params, max_length, key, initial_input)
print(generated_sequence)

TypeError: Shapes must be 1D sequences of concrete values of integer type, got (2, Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=2/0)>).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function body_fn at /tmp/ipykernel_2479/446094793.py:17 for while_loop. This concrete value was not available in Python because it depends on the value of the argument state[2].

In [208]:
import tiktoken
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions

num_return_sequences = 20
max_length = 50

enc = tiktoken.get_encoding('gpt2')
tokens = enc.encode("The quick brown fox,")

tokens = jnp.array(tokens, dtype="i4")

initial_input = jnp.atleast_2d(tokens).repeat(num_return_sequences, 0)

init_length=jnp.shape(initial_input)[-1]

initial_input=jnp.pad(initial_input, ((0, 0), (0, max_length-init_length)), mode='constant', constant_values=0)

In [167]:
initial_input

Array([[   32, 22271,   283,   318,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0],
       [   32, 22271,   283,   318,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0],
       [   32, 22271,   283,   318,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0],
       [   32, 22271,   283,   318,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0],
       [   32, 22271,   283,   318,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0]], dtype=int32)

In [213]:
rng = jax.random.key(2)

def get_sentence(n, carry):
    x, rng = carry

    rng, new_rng = jax.random.split(rng)

    # pull out only the sequence up until the currently filled in
    # values
    logits = GPT2.apply({'params': model_hf.params}, x)

    # GPT2 predicted the new token int he last column of the output, 
    # so we only need this.
    logits = logits[:, n, :]

    # Sample from the top 50 tokens
    topk_probs, topk_indices = jax.lax.top_k(logits, 50)

    #dist = tfd.Multinomial(1, logits=topk_probs)

    #rng, samp_key = jax.random.split(rng)
    #ix=dist.sample(seed=samp_key)

    #xcol=jnp.take_along_axis(topk_indices, ix.astype(jnp.int32), axis=-1)

    ix = jnp.expand_dims(jax.random.categorical(rng, topk_probs/0.7, axis=-1), axis=-1)

    #print(ix)

    xcol=jnp.atleast_2d(jnp.take_along_axis(topk_indices, ix.astype(jnp.int32), axis=-1))

    #print(x.shape)

    #x = x.at[:, n+1].set(xcol)

    x = jax.lax.dynamic_update_slice(x, xcol, (0, n+1))

    #x = jnp.append(x, xcol, -1)

    return x, rng

x, rng = jax.lax.fori_loop(init_length-1, max_length-1, get_sentence, (initial_input, rng))


In [96]:
rng = jax.random.key(57)

generated_sequence = jax.jit(generate_text)(initial_input=x, rng=rng)

Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=2/0)>


TypeError: Shapes must be 1D sequences of concrete values of integer type, got (2, Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=2/0)>).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function scanned_fun at /home/don/miniconda3/envs/jax-env/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py:2007 for scan. This concrete value was not available in Python because it depends on the value of the argument loop_carry[1][2].

In [12]:
initial_input.shape

(1, 4)

In [214]:
for i in range(0, x.shape[0]):
    print(enc.decode(x[i,:].tolist()))

The quick brown fox, however, was not what I was worried about. It was a big bear, but if he'd grown to like it he would have been just as eager to hunt.

I could not quite say if there was danger
The quick brown fox, by the way, seems to have some sort of a strong instinct for violence. He is, in fact, one of the most cunning-looking foxes amongst the group's members, apparently using both paws to try and keep
The quick brown fox, at last seen in the sky behind my back

Took my hand for its big nose, but found me the long fox

The great bald eagle, at last seen on the horizon

He ran with me
The quick brown fox, for instance, needs to stay up on a day's sleep, which can mean making sure his coat doesn't become a habit, or maybe picking a litter of flies to feed to.

In many cases, these two
The quick brown fox, who had a big grin on his face, was smiling. "So who's in town? Are you looking for your father." A small fox with very bright yellow eyes and little blue teeth, who had an arm

In [170]:
logits = GPT2.apply({'params': model_hf.params}, x)


In [171]:
logits.shape

(5, 8, 50257)

In [137]:
ix.shape

(5, 50)

In [85]:
res.shape

(5, 8, 50257)

In [67]:
from jax.tree_util import tree_structure
print(tree_structure(y['params']['transformer']['h']['0']))

PyTreeDef({'attn': {'c_attn': {'bias': *, 'kernel': *}, 'c_proj': {'bias': *, 'kernel': *}}, 'ln_1': {'bias': *, 'scale': *}, 'ln_2': {'bias': *, 'scale': *}, 'mlp': {'c_fc': {'bias': *, 'kernel': *}, 'c_proj': {'bias': *, 'kernel': *}}})


In [170]:
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(y['params']['transformer'])))

initialized parameter shapes:
 {'h': {'0': {'attn': {'c_attn': {'bias': (1152,), 'kernel': (1152, 384)}, 'c_proj': {'bias': (384,), 'kernel': (384, 384)}}, 'ln_1': {'bias': (384,), 'scale': (384,)}, 'ln_2': {'bias': (384,), 'scale': (384,)}, 'mlp': {'c_fc': {'bias': (1536,), 'kernel': (1536, 384)}, 'c_proj': {'bias': (384,), 'kernel': (384, 1536)}}}, '1': {'attn': {'c_attn': {'bias': (1152,), 'kernel': (1152, 384)}, 'c_proj': {'bias': (384,), 'kernel': (384, 384)}}, 'ln_1': {'bias': (384,), 'scale': (384,)}, 'ln_2': {'bias': (384,), 'scale': (384,)}, 'mlp': {'c_fc': {'bias': (1536,), 'kernel': (1536, 384)}, 'c_proj': {'bias': (384,), 'kernel': (384, 1536)}}}, '2': {'attn': {'c_attn': {'bias': (1152,), 'kernel': (1152, 384)}, 'c_proj': {'bias': (384,), 'kernel': (384, 384)}}, 'ln_1': {'bias': (384,), 'scale': (384,)}, 'ln_2': {'bias': (384,), 'scale': (384,)}, 'mlp': {'c_fc': {'bias': (1536,), 'kernel': (1536, 384)}, 'c_proj': {'bias': (384,), 'kernel': (384, 1536)}}}, '3': {'attn': {'c

In [171]:
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(model_hf.params['transformer'])))

initialized parameter shapes:
 {'h': {'0': {'attn': {'c_attn': {'bias': (2304,), 'kernel': (2304, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 768)}}, 'ln_1': {'bias': (768,), 'scale': (768,)}, 'ln_2': {'bias': (768,), 'scale': (768,)}, 'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 3072)}}}, '1': {'attn': {'c_attn': {'bias': (2304,), 'kernel': (2304, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 768)}}, 'ln_1': {'bias': (768,), 'scale': (768,)}, 'ln_2': {'bias': (768,), 'scale': (768,)}, 'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 3072)}}}, '10': {'attn': {'c_attn': {'bias': (2304,), 'kernel': (2304, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 768)}}, 'ln_1': {'bias': (768,), 'scale': (768,)}, 'ln_2': {'bias': (768,), 'scale': (768,)}, 'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 3072)}}}, '11': {'attn': {

In [41]:
# Seems like pytorch Linear is not equal to flax Dense. The matrices are defined transpose. So, we can't actually use
# Dense with the original GPT2 weights

x = jax.random.uniform(jax.random.key(0), (1, 2, 5))

# This will transpose the kernel and multiply by the input,
# This should be the same as the GPT weights
lp = nn.Einsum((3*5, 5), '...ij,...kj->...ik')

y = lp.init(jax.random.key(0), x)

In [42]:
lp.apply(y, x)

Array([[[ 0.40416887, -0.15771458, -0.1418275 ,  0.21624644,
          0.01153318,  0.09767484, -0.16552436, -0.06833879,
         -0.07042018, -0.13803177, -0.00222424,  0.2503722 ,
         -0.3504156 ,  0.18865553, -0.33016038],
        [ 0.62610245, -0.35361993, -0.40769407,  0.22174977,
          0.1036396 ,  0.2302788 , -0.25040534, -0.5411675 ,
         -0.35929912,  0.05858286,  0.03919724,  0.3880062 ,
         -0.13242133,  0.4200734 , -0.47516495]]], dtype=float32)

In [43]:
lp2 = nn.Dense(3*5)

y2 = lp2.init(jax.random.key(0), x)

In [172]:
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(y['params'])))

initialized parameter shapes:
 {'transformer': {'h': {'0': {'attn': {'c_attn': {'bias': (1152,), 'kernel': (1152, 384)}, 'c_proj': {'bias': (384,), 'kernel': (384, 384)}}, 'ln_1': {'bias': (384,), 'scale': (384,)}, 'ln_2': {'bias': (384,), 'scale': (384,)}, 'mlp': {'c_fc': {'bias': (1536,), 'kernel': (1536, 384)}, 'c_proj': {'bias': (384,), 'kernel': (384, 1536)}}}, '1': {'attn': {'c_attn': {'bias': (1152,), 'kernel': (1152, 384)}, 'c_proj': {'bias': (384,), 'kernel': (384, 384)}}, 'ln_1': {'bias': (384,), 'scale': (384,)}, 'ln_2': {'bias': (384,), 'scale': (384,)}, 'mlp': {'c_fc': {'bias': (1536,), 'kernel': (1536, 384)}, 'c_proj': {'bias': (384,), 'kernel': (384, 1536)}}}, '2': {'attn': {'c_attn': {'bias': (1152,), 'kernel': (1152, 384)}, 'c_proj': {'bias': (384,), 'kernel': (384, 384)}}, 'ln_1': {'bias': (384,), 'scale': (384,)}, 'ln_2': {'bias': (384,), 'scale': (384,)}, 'mlp': {'c_fc': {'bias': (1536,), 'kernel': (1536, 384)}, 'c_proj': {'bias': (384,), 'kernel': (384, 1536)}}}, '

In [15]:
jax.tree.map(lambda x: x.shape, y['params'])


{'transformer': {'h': {'0': {'attn': {'c_attn': {'bias': (2304,),
      'kernel': (2304, 768)},
     'c_proj': {'bias': (768,), 'kernel': (768, 768)}},
    'ln_1': {'bias': (768,), 'scale': (768,)},
    'ln_2': {'bias': (768,), 'scale': (768,)},
    'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)},
     'c_proj': {'bias': (768,), 'kernel': (768, 3072)}}},
   '1': {'attn': {'c_attn': {'bias': (2304,), 'kernel': (2304, 768)},
     'c_proj': {'bias': (768,), 'kernel': (768, 768)}},
    'ln_1': {'bias': (768,), 'scale': (768,)},
    'ln_2': {'bias': (768,), 'scale': (768,)},
    'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)},
     'c_proj': {'bias': (768,), 'kernel': (768, 3072)}}},
   '10': {'attn': {'c_attn': {'bias': (2304,), 'kernel': (2304, 768)},
     'c_proj': {'bias': (768,), 'kernel': (768, 768)}},
    'ln_1': {'bias': (768,), 'scale': (768,)},
    'ln_2': {'bias': (768,), 'scale': (768,)},
    'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)},
     'c_pro

In [16]:
jax.tree.map(lambda x: x.shape, model_hf.params)


{'transformer': {'h': {'0': {'attn': {'c_attn': {'bias': (2304,),
      'kernel': (2304, 768)},
     'c_proj': {'bias': (768,), 'kernel': (768, 768)}},
    'ln_1': {'bias': (768,), 'scale': (768,)},
    'ln_2': {'bias': (768,), 'scale': (768,)},
    'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)},
     'c_proj': {'bias': (768,), 'kernel': (768, 3072)}}},
   '1': {'attn': {'c_attn': {'bias': (2304,), 'kernel': (2304, 768)},
     'c_proj': {'bias': (768,), 'kernel': (768, 768)}},
    'ln_1': {'bias': (768,), 'scale': (768,)},
    'ln_2': {'bias': (768,), 'scale': (768,)},
    'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)},
     'c_proj': {'bias': (768,), 'kernel': (768, 3072)}}},
   '10': {'attn': {'c_attn': {'bias': (2304,), 'kernel': (2304, 768)},
     'c_proj': {'bias': (768,), 'kernel': (768, 768)}},
    'ln_1': {'bias': (768,), 'scale': (768,)},
    'ln_2': {'bias': (768,), 'scale': (768,)},
    'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)},
     'c_pro