In [1]:
import sys
sys.path.append("/home/dibya/nfs/practice/transformer_from_scratch/")

In [2]:
import jax
import jax.numpy as jnp
from gemma import test_equality, transformer, sampler
import dataclasses

In [None]:
def test_sampling():
    model = transformer.Gemma.create(jax.random.PRNGKey(0), transformer.GemmaConfig.gemma2_2b(jnp.bfloat16))
    print(dataclasses.asdict(model))
    sampler.sample(model, jnp.ones((1, 1024), dtype=jnp.int32), jnp.ones((1, 1024), dtype=jnp.int32), 1024, 1, jax.random.PRNGKey(0))

jax.eval_shape(test_sampling)

{'embedder': {'input_embedding': Traced<ShapedArray(float32[256000,2304])>with<DynamicJaxprTrace>}, 'transformer': {'layers': {'pre_attention_norm': {'scale': Traced<ShapedArray(float32[26,2304])>with<DynamicJaxprTrace>}, 'attn': {'q_einsum': {'w': Traced<ShapedArray(float32[26,8,2304,256])>with<DynamicJaxprTrace>, 'dtype': <class 'jax.numpy.bfloat16'>}, 'kv_einsum': {'w': Traced<ShapedArray(float32[26,2,4,2304,256])>with<DynamicJaxprTrace>, 'dtype': <class 'jax.numpy.bfloat16'>}, 'attn_vec_einsum': {'w': Traced<ShapedArray(float32[26,8,256,2304])>with<DynamicJaxprTrace>, 'dtype': <class 'jax.numpy.bfloat16'>}, 'num_heads': 8, 'num_kv_heads': 4, 'head_dim': 256, 'dtype': <class 'jax.numpy.bfloat16'>, 'attn_logits_softcap': 50.0, 'query_pre_attn_norm': 'rsqrt_head_dim'}, 'post_attention_norm': {'scale': Traced<ShapedArray(float32[26,2304])>with<DynamicJaxprTrace>}, 'pre_ffw_norm': {'scale': Traced<ShapedArray(float32[26,2304])>with<DynamicJaxprTrace>}, 'mlp': {'gating_einsum': Traced<Sh

In [4]:
params = test_equality._load_params()



In [5]:
tokenizer = test_equality._load_tokenizer()

In [6]:
tokens = tokenizer.encode("Continue this sequence: 1 1 2 3 5", add_bos=True)
tokens = jnp.array(tokens)[None]

2025-03-16 21:26:53.086453: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-03-16 21:26:53.104907: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-03-16 21:26:53.110471: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [7]:
for tok in tokens[0].tolist():
    print(tok, tokenizer.id_to_piece(tok))

2 <bos>
20017 Continue
736 ▁this
10629 ▁sequence
235292 :
235248 ▁
235274 1
235248 ▁
235274 1
235248 ▁
235284 2
235248 ▁
235304 3
235248 ▁
235308 5


In [8]:
logits, extra = test_equality._run_reference(params, tokens)

In [9]:
jax.tree.map(lambda x: x.shape, extra)

{'encoded': (1, 15, 2304),
 'logits': (1, 15, 256128),
 'logits_pre_norm': (1, 15, 256128),
 'pre_logits': (1, 15, 2304)}

In [10]:
pred_logits, pred_extra = test_equality._run_ours(params, tokens)

In [11]:
jax.tree.map(lambda x: x.shape, pred_extra)

{'embeddings': (1, 15, 2304),
 'encoded': (1, 15, 2304),
 'logits': (1, 15, 256128),
 'logits_pre_norm': (1, 15, 256128),
 'pre_logits': (1, 15, 2304),
 'scan': {}}

In [19]:
extra['encoded'] - pred_extra['encoded']

Array([[[ 0.01154709,  0.00688124,  0.02609921, ...,  0.00097656,
         -0.00624847,  0.00311089],
        [-0.01195002,  0.08745241,  0.05194855, ...,  0.00266457,
          0.06879759, -0.0241518 ],
        [-0.01238871,  0.03646362, -0.01284981, ...,  0.03772545,
         -0.03418159,  0.03597641],
        ...,
        [ 0.03705287, -0.0365696 ,  0.01210785, ..., -0.01609993,
          0.00453806,  0.00426555],
        [ 0.01983333, -0.00200248, -0.08643913, ...,  0.08524203,
          0.00840425,  0.08208084],
        [ 0.02864361,  0.08491135, -0.13770962, ..., -0.04216766,
          0.00235653, -0.04208452]]], dtype=float32)

In [12]:
pred_logits.dtype

dtype('float32')

In [13]:
o1 = jax.nn.softmax(logits, axis=-1)
o2 = jax.nn.softmax(pred_logits, axis=-1)

In [14]:
o1.shape, o2.shape

((1, 15, 256128), (1, 15, 256128))

In [15]:
jnp.abs(o1 - o2).sum(axis=-1)

Array([[0.00160499, 0.00971379, 0.00632929, 0.0070101 , 0.00223727,
        0.00165324, 0.00538984, 0.00191486, 0.00532418, 0.00274373,
        0.00077039, 0.00405245, 0.00143561, 0.00265977, 0.0007084 ]],      dtype=float32)

In [16]:
jax.lax.top_k(o1, 5)

[Array([[[0.13441935, 0.07789737, 0.07040887, 0.06043119, 0.05333681],
         [0.26655453, 0.20126286, 0.16760115, 0.0621862 , 0.03980314],
         [0.10803294, 0.07695038, 0.07177233, 0.05175554, 0.05008181],
         [0.26113826, 0.11728791, 0.11223355, 0.06043551, 0.05512523],
         [0.3845569 , 0.24846284, 0.11104581, 0.04259438, 0.02696753],
         [0.32959858, 0.1319176 , 0.10494106, 0.07236732, 0.07091951],
         [0.324436  , 0.1523442 , 0.08446122, 0.07489645, 0.05883117],
         [0.43126646, 0.14894664, 0.14168763, 0.08984134, 0.06583921],
         [0.4894781 , 0.07629847, 0.06626438, 0.06558751, 0.05672293],
         [0.44705915, 0.2899046 , 0.10973729, 0.05108383, 0.03355091],
         [0.94157904, 0.01166774, 0.00475539, 0.00426424, 0.00411787],
         [0.4281638 , 0.20188926, 0.13267156, 0.12836124, 0.05609043],
         [0.94731784, 0.00885544, 0.00622879, 0.0051957 , 0.00248374],
         [0.6293795 , 0.18396144, 0.05174865, 0.05107486, 0.03111089],
      

In [17]:
@jax.jit
def test_sampling(params, tokens):
    model = transformer.Gemma.create_from_pretrained(params, transformer.GemmaConfig.gemma2_2b())
    return sampler.sample(model, tokens, jnp.ones_like(tokens, dtype=jnp.int32), 32, 1, jax.random.PRNGKey(0))

sampled_tokens = test_sampling(params, tokens)

1 15


In [18]:
sampled_tokens.input_mask

AttributeError: 'jaxlib.xla_extension.ArrayImpl' object has no attribute 'input_mask'

In [29]:
tokenizer.decode(sampled_tokens.token_buffer[0].tolist())

' 8 13 21 34 55 87 144 233 377 61'

In [18]:
import flax
import numpy as np
import nnjax
from gemma import transformer

model = jax.eval_shape(
    lambda: transformer.Gemma.create(key=jax.random.PRNGKey(0), config=transformer.GemmaConfig.gemma2_2b())
)
new_params = jax.tree.map(lambda x: x, params) # Copy
new_params["embedder"] = new_params["transformer"].pop("embedder")
new_params = jax.tree.map(jnp.asarray, new_params)
d1 = flax.traverse_util.flatten_dict(new_params)
d2 = flax.traverse_util.flatten_dict(nnjax.asdict(model))
print("keys in update but not in model:")
print(set(d1.keys()) - set(d2.keys()))
print("keys in model but not in update:")
print(set(d2.keys()) - set(d1.keys()))

keys in update but not in model:
set()
keys in model but not in update:
set()


In [None]:
from gemma import reference
import jax
reference_model = reference.Model(
    **{**reference.get_config("gemma2_2b").to_dict(), "vocab_size": 256_128}
)
jax.eval_shape(reference_model.init, jax.random.PRNGKey(0), jnp.array(tokens)[None])['params']

ValueError: too many values to unpack (expected 3)

In [18]:
params

{'transformer': {'embedder': {'input_embedding': array([[0.0341797, -0.0319824, 0.0732422, ..., 0.0200195, 0.0493164,
           -0.0327148],
          [-0.019165, 0.0498047, -0.0380859, ..., -0.00524902, -0.0258789,
           -0.0137939],
          [0.000125885, -0.00585938, 0.022583, ..., 0.0146484, -0.00799561,
           -0.0122681],
          ...,
          [0.0213623, -0.0412598, 0.0310059, ..., 0.0410156, 0.0125732,
           -0.0283203],
          [0.0307617, -0.0410156, 0.0358887, ..., 0.0393066, 0.0205078,
           -0.0241699],
          [0.019043, -0.0446777, 0.0319824, ..., 0.0407715, 0.0187988,
           -0.027832]], dtype=bfloat16)},
  'final_norm': {'scale': array([2.32812, 2.35938, 2.28125, ..., 4.65625, 2.53125, 2.4375],
         dtype=bfloat16)},
  'layers': {'attn': {'attn_vec_einsum': {'w': array([[[[0.00872803, 0.010376, 0.0148315, ..., 0.00270081,
               -0.00558472, 0.00891113],
              [0.0115967, 0.0206299, 0.00970459, ..., 0.000938416,
     

In [17]:
from gemma import transformer, sampler
import dataclasses

# import orbax.checkpoint as ocp
# checkpointer = ocp.StandardCheckpointer()
# params = checkpointer.restore("/nfs/nfs2/users/dibya/gemma-2/gemma2-2b/")

# model = transformer.Gemma.create_from_pretrained(params, dataclasses.replace(transformer.GemmaConfig.gemma2_2b(), dtype="bfloat16"))

1 1024


TypeError: scan body function carry input and carry output must have equal types (e.g. shapes and dtypes of arrays), but they differ:

The input carry x has type bfloat16[1,1024,2304] but the corresponding output carry component has type float32[1,1024,2304], so the dtypes do not match.

Revise the function so that all output types (e.g. shapes and dtypes) match the corresponding input types.

In [8]:
import flax
import jax
import jax.numpy as jnp
params = flax.traverse_util.flatten_dict(params, sep="/")


In [9]:

new_params = {k: v for k, v in params.items() if "transformer/layer_" not in k}

for k in filter(lambda x: "transformer/layer_0/" in x, params.keys()):
    new_params[k.replace("/layer_0/", "/layers/")] = jnp.stack([
        params[k.replace("layer_0", f"layer_{i}")]
        for i in range(26)
    ])
new_params = flax.traverse_util.unflatten_dict(new_params, sep="/")

In [10]:
import jax
jax.tree.map(lambda x: x.shape, new_params)

{'transformer': {'embedder': {'input_embedding': (256128, 2304)},
  'final_norm': {'scale': (2304,)},
  'layers': {'attn': {'attn_vec_einsum': {'w': (26, 8, 256, 2304)},
    'kv_einsum': {'w': (26, 2, 4, 2304, 256)},
    'q_einsum': {'w': (26, 8, 2304, 256)}},
   'mlp': {'gating_einsum': (26, 2, 2304, 9216), 'linear': (26, 9216, 2304)},
   'post_attention_norm': {'scale': (26, 2304)},
   'post_ffw_norm': {'scale': (26, 2304)},
   'pre_attention_norm': {'scale': (26, 2304)},
   'pre_ffw_norm': {'scale': (26, 2304)}}}}

In [11]:
import nnjax
flattened_target = flax.traverse_util.flatten_dict(nnjax.asdict(model))
flattened_src = flax.traverse_util.flatten_dict(new_params)

NameError: name 'model' is not defined

In [8]:
print("keys in target but not in src:")
print(set(flattened_target.keys()) - set(flattened_src.keys()))
print("keys in src but not in target:")
print(set(flattened_src.keys()) - set(flattened_target.keys()))


keys in target but not in src:
{('embedder', 'input_embedding')}
keys in src but not in target:
{('transformer', 'embedder', 'input_embedding')}
