Orphaned notebook with values to "glue" using [myst-nb](https://myst-nb.readthedocs.io/en/v0.13.2/use/glue.html).

In [1]:
import os

import myst_nb

import penzai
from penzai.experimental.v2 import pz

import jax
import jax.numpy as jnp
import orbax.checkpoint
from jax.experimental import mesh_utils

import IPython.utils.capture

In [2]:
# Configure penzai's notebook integrations before anything is displayed.
# NOTE(review): `pz.ts` is presumably penzai's treescope renderer; streaming is
# disabled here, likely so each rendering is emitted as a single HTML output
# that can be captured below -- confirm.
pz.ts.register_as_default(streaming=False)
# Registers a magic; presumably the `%%autovisualize` cell magic used further
# down in this notebook.
pz.ts.register_autovisualize_magic()
# NOTE(review): presumably turns on penzai's interactive-use helpers -- confirm
# the exact semantics against the penzai docs.
pz.enable_interactive_context()

In [3]:
from penzai.experimental.v2.models.transformer.variants import gemma
from penzai.experimental.v2.models.transformer.variants import llamalike_common

In [4]:
# kagglehub is an optional dependency that is only needed to download the
# pretrained Gemma weights. `load_gemma` records whether it is available:
#   * kagglehub importable -> load_gemma = True
#   * kagglehub missing    -> load_gemma = False and kagglehub = None
#
# NOTE: this cell used to contain a TPU-only platform check, but its result
# was always overwritten by an unconditional `load_gemma = True`, so the
# check was dead code. It has been removed; the effective behavior is
# unchanged. Reintroduce a platform gate here if loading should be
# restricted to specific accelerators.
try:
  import kagglehub
except ImportError:
  kagglehub = None
  load_gemma = False
else:
  load_gemma = True

In [5]:
if not load_gemma:
  # No pretrained weights requested: build a small, randomly initialized
  # llama-like transformer from a fixed RNG key instead.
  model = llamalike_common.build_llamalike_transformer(
      llamalike_common.LlamalikeTransformerConfig(
          num_kv_heads=8,
          query_head_multiplier=1,
          embedding_dim=256,
          projection_dim=32,
          mlp_hidden_dim=512,
          num_decoder_blocks=10,
          vocab_size=1000,
          mlp_variant="geglu_approx",
          rope_wavelength=10_000,
          tie_embedder_and_logits=True,
          use_layer_stack=False,
          parameter_dtype=jnp.float32,
          activation_dtype=jnp.float32,
      ),
      init_base_rng=jax.random.key(42),
  )

else:
  # Fetch the Gemma checkpoint through kagglehub and locate the '2b'
  # subdirectory inside the downloaded folder.
  weights_dir = kagglehub.model_download('google/gemma/Flax/2b')
  ckpt_path = os.path.join(weights_dir, '2b')

  checkpointer = orbax.checkpoint.PyTreeCheckpointer()
  metadata = checkpointer.metadata(ckpt_path)

  # One-dimensional device mesh over all local devices.
  n_devices = jax.local_device_count()
  sharding_devices = mesh_utils.create_device_mesh((n_devices,))
  sharding = jax.sharding.PositionalSharding(sharding_devices)

  def _restore_args_for(m):
    """Builds the restore args for one checkpoint entry described by `m`.

    The sharding is reshaped to have size 1 along every axis except the
    last one, which spans all local devices.
    """
    leading_ones = (1,) * (len(m.shape) - 1)
    return orbax.checkpoint.ArrayRestoreArgs(
        restore_type=jax.Array,
        sharding=sharding.reshape(leading_ones + (n_devices,)),
    )

  restore_args = jax.tree_util.tree_map(_restore_args_for, metadata)
  flat_params = checkpointer.restore(ckpt_path, restore_args=restore_args)

  # Wrap the restored parameters in a penzai Gemma model, without upcasting
  # activations to float32.
  model = gemma.gemma_from_pretrained_checkpoint(
      flat_params, upcast_activations_to_float32=False
  )



In [6]:
%%autovisualize

with IPython.utils.capture.capture_output() as capturer:
  pz.select(model).at(lambda root: (
      root.body.sublayers[2].sublayers[0].delta.sublayers[1].input_to_query,
      root.body.sublayers[2].sublayers[1].delta.sublayers[1],
  )).show_value()

In [7]:
# Concatenate the HTML payload of every captured output into one fragment,
# then register it with myst-nb under the name "penzai_teaser".
teaser_html = "".join(
    captured.data['text/html'] for captured in capturer.outputs
)
myst_nb.glue("penzai_teaser", IPython.display.HTML(teaser_html))