Orphaned notebook with values to "glue" using [myst-nb](https://myst-nb.readthedocs.io/en/v0.13.2/use/glue.html).

In [1]:
import os

import myst_nb

import penzai
from penzai import pz

import jax
import jax.numpy as jnp
import orbax.checkpoint
from jax.experimental import mesh_utils

import IPython.utils.capture



In [2]:
# One-time penzai notebook configuration for this kernel session.
# NOTE(review): `pz.ts` is presumably penzai's treescope pretty-printer and this
# call presumably makes it the default renderer for cell outputs -- confirm
# against the penzai docs. Streaming mode is explicitly turned off here.
pz.ts.register_as_default(streaming=False)
# Registers the cell magic that the `%%autovisualize` cell further down uses.
pz.ts.register_autovisualize_magic()
# NOTE(review): presumably enables penzai's interactive context helpers for
# notebook use; the exact effect is not visible from this file -- confirm.
pz.enable_interactive_context()

In [3]:
from penzai.example_models import gemma

In [4]:
# Decide whether the next cell loads the pretrained Gemma 2B checkpoint
# (`load_gemma = True`) or builds a small randomly-initialized stand-in
# (`load_gemma = False`). Loading requires the optional `kagglehub` package.
try:
  import kagglehub
except ImportError:
  # kagglehub is optional; without it there is no way to fetch the weights.
  kagglehub = None
  load_gemma = False
else:
  on_tpu = jax.devices()[0].platform == "tpu"
  # Explicit override of the TPU-only gate: while this is True the checkpoint
  # is loaded on every platform where kagglehub is importable. Set it to False
  # to load the real weights only when the first JAX device is a TPU.
  # NOTE(review): confirm that forcing the load on non-TPU hosts is intended.
  force_load_gemma = True
  load_gemma = force_load_gemma or on_tpu

In [5]:
# Build `model`: either the pretrained Gemma 2B transformer restored from the
# Kaggle checkpoint, or a small randomly-initialized transformer as a fallback.
if not load_gemma:
  # Fallback path: a small Gemma-style configuration with fixed-seed
  # parameters, so no checkpoint download is needed.
  small_config = gemma.model_core.GemmaTransformerConfig(
      num_heads=8,
      embedding_dim=256,
      projection_dim=32,
      single_kv_head=False,
      mlp_hidden_dim=512,
      num_decoder_blocks=10,
      vocab_size=1000,
      parameter_dtype=jnp.float32,
      activation_dtype=jnp.float32,
  )
  model = pz.nn.initialize_parameters(
      gemma.model_core.GemmaTransformer.from_config(small_config),
      jax.random.key(1),
  )

else:
  # Fetch the checkpoint files via kagglehub and locate the '2b' subdirectory.
  weights_dir = kagglehub.model_download('google/gemma/Flax/2b')
  ckpt_path = os.path.join(weights_dir, '2b')

  checkpointer = orbax.checkpoint.PyTreeCheckpointer()
  metadata = checkpointer.metadata(ckpt_path)

  # A 1-D mesh over all local devices, used to shard each restored array.
  n_devices = jax.local_device_count()
  sharding_devices = mesh_utils.create_device_mesh((n_devices,))
  sharding = jax.sharding.PositionalSharding(sharding_devices)

  def _restore_args_for(leaf_meta):
    """Restore args that split an array's last axis across all local devices."""
    # Size-1 mesh axes for every leading dimension, all devices on the last.
    mesh_shape = (1,) * (len(leaf_meta.shape) - 1) + (n_devices,)
    return orbax.checkpoint.ArrayRestoreArgs(
        restore_type=jax.Array,
        sharding=sharding.reshape(mesh_shape),
    )

  restore_args = jax.tree_util.tree_map(_restore_args_for, metadata)
  flat_params = checkpointer.restore(ckpt_path, restore_args=restore_args)

  model = gemma.model_core.GemmaTransformer.from_pretrained(
      flat_params, upcast_activations_to_float32=False
  )

Downloading from https://www.kaggle.com/api/v1/models/google/gemma/Flax/2b/2/download...


  0%|          | 0.00/3.66G [00:00<?, ?B/s]

  0%|          | 9.00M/3.66G [00:00<00:47, 82.7MB/s]

  1%|          | 25.0M/3.66G [00:00<00:31, 124MB/s] 

  1%|          | 37.0M/3.66G [00:00<00:43, 89.3MB/s]

  1%|▏         | 49.0M/3.66G [00:00<00:38, 100MB/s] 

  2%|▏         | 65.0M/3.66G [00:00<00:35, 109MB/s]

  2%|▏         | 86.0M/3.66G [00:00<00:27, 140MB/s]

  3%|▎         | 109M/3.66G [00:00<00:22, 168MB/s] 

  3%|▎         | 127M/3.66G [00:00<00:21, 174MB/s]

  4%|▍         | 145M/3.66G [00:01<00:30, 125MB/s]

  5%|▍         | 169M/3.66G [00:01<00:24, 154MB/s]

  5%|▍         | 187M/3.66G [00:01<00:23, 159MB/s]

  6%|▌         | 209M/3.66G [00:01<00:31, 118MB/s]

  6%|▌         | 233M/3.66G [00:01<00:26, 138MB/s]

  7%|▋         | 256M/3.66G [00:01<00:23, 159MB/s]

  7%|▋         | 274M/3.66G [00:02<00:30, 120MB/s]

  8%|▊         | 289M/3.66G [00:02<00:35, 102MB/s]

  8%|▊         | 302M/3.66G [00:02<00:33, 108MB/s]

  9%|▊         | 325M/3.66G [00:02<00:26, 134MB/s]

  9%|▉         | 347M/3.66G [00:02<00:23, 155MB/s]

 10%|▉         | 370M/3.66G [00:02<00:20, 175MB/s]

 10%|█         | 392M/3.66G [00:02<00:18, 189MB/s]

 11%|█         | 412M/3.66G [00:03<00:20, 175MB/s]

 12%|█▏        | 437M/3.66G [00:03<00:17, 196MB/s]

 12%|█▏        | 459M/3.66G [00:03<00:16, 203MB/s]

 13%|█▎        | 480M/3.66G [00:03<00:18, 184MB/s]

 13%|█▎        | 500M/3.66G [00:03<00:17, 191MB/s]

 14%|█▍        | 523M/3.66G [00:03<00:16, 203MB/s]

 15%|█▍        | 544M/3.66G [00:03<00:18, 177MB/s]

 15%|█▍        | 562M/3.66G [00:03<00:18, 176MB/s]

 16%|█▌        | 582M/3.66G [00:04<00:18, 183MB/s]

 16%|█▌        | 601M/3.66G [00:04<00:17, 187MB/s]

 17%|█▋        | 620M/3.66G [00:04<00:17, 190MB/s]

 17%|█▋        | 639M/3.66G [00:04<00:17, 190MB/s]

 18%|█▊        | 658M/3.66G [00:04<00:20, 155MB/s]

 18%|█▊        | 677M/3.66G [00:04<00:19, 163MB/s]

 19%|█▊        | 698M/3.66G [00:04<00:18, 176MB/s]

 19%|█▉        | 721M/3.66G [00:05<00:25, 125MB/s]

 20%|█▉        | 742M/3.66G [00:05<00:21, 144MB/s]

 20%|██        | 763M/3.66G [00:05<00:19, 159MB/s]

 21%|██        | 781M/3.66G [00:05<00:21, 146MB/s]

 21%|██▏       | 804M/3.66G [00:05<00:18, 168MB/s]

 22%|██▏       | 822M/3.66G [00:05<00:22, 138MB/s]

 23%|██▎       | 845M/3.66G [00:05<00:19, 160MB/s]

 23%|██▎       | 865M/3.66G [00:05<00:17, 170MB/s]

 24%|██▎       | 883M/3.66G [00:06<00:20, 149MB/s]

 24%|██▍       | 899M/3.66G [00:06<00:21, 140MB/s]

 25%|██▍       | 921M/3.66G [00:06<00:18, 161MB/s]

 25%|██▌       | 945M/3.66G [00:06<00:16, 183MB/s]

 26%|██▌       | 964M/3.66G [00:06<00:18, 156MB/s]

 26%|██▌       | 981M/3.66G [00:06<00:19, 151MB/s]

 27%|██▋       | 0.98G/3.66G [00:06<00:16, 173MB/s]

 27%|██▋       | 1.00G/3.66G [00:06<00:14, 192MB/s]

 28%|██▊       | 1.03G/3.66G [00:07<00:13, 207MB/s]

 29%|██▊       | 1.05G/3.66G [00:07<00:17, 161MB/s]

 29%|██▉       | 1.07G/3.66G [00:07<00:19, 144MB/s]

 30%|██▉       | 1.09G/3.66G [00:07<00:16, 165MB/s]

 30%|███       | 1.11G/3.66G [00:07<00:17, 157MB/s]

 31%|███       | 1.12G/3.66G [00:07<00:18, 150MB/s]

 31%|███       | 1.14G/3.66G [00:07<00:16, 164MB/s]

 32%|███▏      | 1.16G/3.66G [00:07<00:14, 182MB/s]

 32%|███▏      | 1.18G/3.66G [00:08<00:14, 183MB/s]

 33%|███▎      | 1.20G/3.66G [00:08<00:13, 194MB/s]

 33%|███▎      | 1.22G/3.66G [00:08<00:14, 176MB/s]

 34%|███▍      | 1.24G/3.66G [00:08<00:14, 181MB/s]

 34%|███▍      | 1.26G/3.66G [00:08<00:22, 116MB/s]

 35%|███▍      | 1.28G/3.66G [00:08<00:20, 127MB/s]

 35%|███▌      | 1.29G/3.66G [00:08<00:17, 142MB/s]

 36%|███▌      | 1.31G/3.66G [00:09<00:20, 125MB/s]

 36%|███▌      | 1.32G/3.66G [00:09<00:24, 104MB/s]

 36%|███▋      | 1.34G/3.66G [00:09<00:23, 106MB/s]

 37%|███▋      | 1.35G/3.66G [00:09<00:21, 118MB/s]

 37%|███▋      | 1.37G/3.66G [00:09<00:17, 141MB/s]

 38%|███▊      | 1.40G/3.66G [00:09<00:14, 172MB/s]

 39%|███▊      | 1.42G/3.66G [00:09<00:14, 169MB/s]

 39%|███▉      | 1.44G/3.66G [00:09<00:12, 195MB/s]

 40%|███▉      | 1.46G/3.66G [00:10<00:11, 201MB/s]

 40%|████      | 1.48G/3.66G [00:10<00:12, 194MB/s]

 41%|████      | 1.50G/3.66G [00:10<00:11, 194MB/s]

 42%|████▏     | 1.52G/3.66G [00:10<00:11, 206MB/s]

 42%|████▏     | 1.54G/3.66G [00:10<00:12, 185MB/s]

 43%|████▎     | 1.57G/3.66G [00:10<00:10, 205MB/s]

 43%|████▎     | 1.59G/3.66G [00:10<00:10, 215MB/s]

 44%|████▍     | 1.62G/3.66G [00:10<00:09, 229MB/s]

 45%|████▍     | 1.64G/3.66G [00:10<00:10, 217MB/s]

 45%|████▌     | 1.66G/3.66G [00:11<00:09, 221MB/s]

 46%|████▌     | 1.68G/3.66G [00:11<00:09, 214MB/s]

 47%|████▋     | 1.71G/3.66G [00:11<00:09, 227MB/s]

 47%|████▋     | 1.73G/3.66G [00:11<00:14, 145MB/s]

 48%|████▊     | 1.75G/3.66G [00:11<00:14, 142MB/s]

 48%|████▊     | 1.77G/3.66G [00:11<00:13, 156MB/s]

 49%|████▉     | 1.79G/3.66G [00:11<00:11, 173MB/s]

 49%|████▉     | 1.81G/3.66G [00:12<00:11, 174MB/s]

 50%|█████     | 1.84G/3.66G [00:12<00:10, 194MB/s]

 51%|█████     | 1.86G/3.66G [00:12<00:10, 191MB/s]

 51%|█████     | 1.88G/3.66G [00:12<00:11, 169MB/s]

 52%|█████▏    | 1.89G/3.66G [00:12<00:11, 166MB/s]

 52%|█████▏    | 1.91G/3.66G [00:12<00:10, 174MB/s]

 53%|█████▎    | 1.93G/3.66G [00:12<00:13, 142MB/s]

 53%|█████▎    | 1.95G/3.66G [00:13<00:11, 159MB/s]

 54%|█████▍    | 1.98G/3.66G [00:13<00:09, 183MB/s]

 54%|█████▍    | 1.99G/3.66G [00:13<00:10, 168MB/s]

 55%|█████▍    | 2.01G/3.66G [00:13<00:09, 178MB/s]

 56%|█████▌    | 2.04G/3.66G [00:13<00:09, 194MB/s]

 56%|█████▌    | 2.06G/3.66G [00:13<00:08, 203MB/s]

 57%|█████▋    | 2.08G/3.66G [00:13<00:09, 175MB/s]

 57%|█████▋    | 2.10G/3.66G [00:13<00:09, 176MB/s]

 58%|█████▊    | 2.12G/3.66G [00:13<00:08, 194MB/s]

 58%|█████▊    | 2.14G/3.66G [00:14<00:07, 206MB/s]

 59%|█████▉    | 2.16G/3.66G [00:14<00:07, 202MB/s]

 60%|█████▉    | 2.19G/3.66G [00:14<00:07, 212MB/s]

 60%|██████    | 2.21G/3.66G [00:14<00:08, 193MB/s]

 61%|██████    | 2.22G/3.66G [00:14<00:08, 193MB/s]

 61%|██████▏   | 2.25G/3.66G [00:14<00:07, 211MB/s]

 62%|██████▏   | 2.27G/3.66G [00:14<00:06, 215MB/s]

 63%|██████▎   | 2.30G/3.66G [00:14<00:06, 229MB/s]

 63%|██████▎   | 2.32G/3.66G [00:14<00:06, 236MB/s]

 64%|██████▍   | 2.34G/3.66G [00:15<00:06, 234MB/s]

 65%|██████▍   | 2.37G/3.66G [00:15<00:06, 210MB/s]

 65%|██████▌   | 2.39G/3.66G [00:15<00:06, 202MB/s]

 66%|██████▌   | 2.41G/3.66G [00:15<00:06, 213MB/s]

 66%|██████▋   | 2.43G/3.66G [00:15<00:07, 172MB/s]

 67%|██████▋   | 2.45G/3.66G [00:15<00:07, 185MB/s]

 67%|██████▋   | 2.47G/3.66G [00:15<00:07, 175MB/s]

 68%|██████▊   | 2.50G/3.66G [00:15<00:06, 195MB/s]

 69%|██████▊   | 2.52G/3.66G [00:16<00:09, 134MB/s]

 69%|██████▉   | 2.54G/3.66G [00:16<00:07, 154MB/s]

 70%|██████▉   | 2.56G/3.66G [00:16<00:06, 179MB/s]

 71%|███████   | 2.58G/3.66G [00:16<00:06, 188MB/s]

 71%|███████   | 2.61G/3.66G [00:16<00:05, 203MB/s]

 72%|███████▏  | 2.63G/3.66G [00:16<00:05, 213MB/s]

 72%|███████▏  | 2.65G/3.66G [00:16<00:05, 200MB/s]

 73%|███████▎  | 2.67G/3.66G [00:16<00:05, 193MB/s]

 74%|███████▎  | 2.69G/3.66G [00:17<00:05, 201MB/s]

 74%|███████▍  | 2.72G/3.66G [00:17<00:04, 219MB/s]

 75%|███████▍  | 2.74G/3.66G [00:17<00:04, 207MB/s]

 75%|███████▌  | 2.76G/3.66G [00:17<00:05, 191MB/s]

 76%|███████▌  | 2.78G/3.66G [00:17<00:05, 181MB/s]

 77%|███████▋  | 2.80G/3.66G [00:17<00:04, 195MB/s]

 77%|███████▋  | 2.82G/3.66G [00:17<00:05, 172MB/s]

 78%|███████▊  | 2.85G/3.66G [00:17<00:04, 193MB/s]

 78%|███████▊  | 2.87G/3.66G [00:18<00:04, 188MB/s]

 79%|███████▉  | 2.89G/3.66G [00:18<00:04, 207MB/s]

 80%|███████▉  | 2.91G/3.66G [00:18<00:03, 206MB/s]

 80%|████████  | 2.94G/3.66G [00:18<00:03, 213MB/s]

 81%|████████  | 2.96G/3.66G [00:18<00:03, 195MB/s]

 81%|████████▏ | 2.98G/3.66G [00:18<00:03, 192MB/s]

 82%|████████▏ | 3.00G/3.66G [00:18<00:04, 168MB/s]

 82%|████████▏ | 3.02G/3.66G [00:18<00:03, 181MB/s]

 83%|████████▎ | 3.03G/3.66G [00:19<00:03, 178MB/s]

 83%|████████▎ | 3.05G/3.66G [00:19<00:03, 176MB/s]

 84%|████████▍ | 3.08G/3.66G [00:19<00:03, 198MB/s]

 85%|████████▍ | 3.10G/3.66G [00:19<00:02, 217MB/s]

 85%|████████▌ | 3.12G/3.66G [00:19<00:03, 175MB/s]

 86%|████████▌ | 3.14G/3.66G [00:19<00:02, 188MB/s]

 86%|████████▋ | 3.16G/3.66G [00:19<00:03, 171MB/s]

 87%|████████▋ | 3.18G/3.66G [00:20<00:03, 132MB/s]

 87%|████████▋ | 3.20G/3.66G [00:20<00:04, 107MB/s]

 88%|████████▊ | 3.21G/3.66G [00:20<00:03, 123MB/s]

 88%|████████▊ | 3.23G/3.66G [00:20<00:04, 115MB/s]

 88%|████████▊ | 3.24G/3.66G [00:20<00:03, 116MB/s]

 89%|████████▉ | 3.25G/3.66G [00:20<00:03, 112MB/s]

 89%|████████▉ | 3.27G/3.66G [00:20<00:03, 107MB/s]

 89%|████████▉ | 3.28G/3.66G [00:21<00:04, 85.4MB/s]

 90%|████████▉ | 3.29G/3.66G [00:21<00:04, 95.4MB/s]

 90%|█████████ | 3.30G/3.66G [00:21<00:05, 67.8MB/s]

 91%|█████████ | 3.33G/3.66G [00:21<00:03, 104MB/s] 

 91%|█████████▏| 3.34G/3.66G [00:21<00:03, 103MB/s]

 92%|█████████▏| 3.37G/3.66G [00:21<00:02, 133MB/s]

 93%|█████████▎| 3.39G/3.66G [00:22<00:01, 162MB/s]

 93%|█████████▎| 3.42G/3.66G [00:22<00:01, 186MB/s]

 94%|█████████▍| 3.44G/3.66G [00:22<00:01, 161MB/s]

 95%|█████████▍| 3.46G/3.66G [00:22<00:01, 185MB/s]

 95%|█████████▌| 3.49G/3.66G [00:22<00:00, 199MB/s]

 96%|█████████▌| 3.51G/3.66G [00:22<00:00, 212MB/s]

 97%|█████████▋| 3.54G/3.66G [00:22<00:00, 227MB/s]

 97%|█████████▋| 3.56G/3.66G [00:22<00:00, 239MB/s]

 98%|█████████▊| 3.59G/3.66G [00:22<00:00, 246MB/s]

 99%|█████████▊| 3.61G/3.66G [00:23<00:00, 253MB/s]

 99%|█████████▉| 3.64G/3.66G [00:23<00:00, 256MB/s]

100%|██████████| 3.66G/3.66G [00:23<00:00, 169MB/s]


Extracting model files...


In [6]:
%%autovisualize

# Capture the rich output produced inside this block into `capturer` so that
# the following cell can read it back from `capturer.outputs`.
with IPython.utils.capture.capture_output() as capturer:
  # Select two nodes inside the third sublayer of the model body (the
  # `input_to_query` node of its first child and a sublayer of its second
  # child) and render the selection with `show_value()`.
  # NOTE(review): the attribute path is tied to the exact layer nesting of the
  # Gemma model built above; it needs updating if that structure changes.
  pz.select(model).at(lambda root: (
      root.body.body.body.sublayers[2].sublayers[0].delta.sublayers[1].input_to_query,
      root.body.body.body.sublayers[2].sublayers[1].delta.sublayers[1],
  )).show_value()

In [7]:
# `IPython.display` is used below, but the notebook otherwise only imports
# `IPython.utils.capture`. A submodule is only guaranteed to be reachable as a
# package attribute once it has been imported, so import it explicitly rather
# than relying on some other library having loaded it first.
import IPython.display

# Concatenate the HTML payload of every captured output, in order, and
# register the result with myst-nb under the key "penzai_teaser".
teaser_html = "".join(
    output.data['text/html'] for output in capturer.outputs
)
myst_nb.glue("penzai_teaser", IPython.display.HTML(teaser_html))