Orphaned notebook that computes values to "glue" into other documentation pages using [myst-nb](https://myst-nb.readthedocs.io/en/v0.13.2/use/glue.html).

In [1]:
import os

import myst_nb

import penzai
from penzai import pz

import jax
import jax.numpy as jnp
import orbax.checkpoint
from jax.experimental import mesh_utils

import IPython.display
import IPython.utils.capture

In [2]:
# Configure how penzai/treescope renders values in this notebook.
# NOTE(review): per its name, this registers treescope (`pz.ts`) as the default
# renderer; `streaming=False` presumably makes each output a single
# self-contained HTML fragment, which suits capturing it for myst-nb glue —
# confirm against the penzai docs.
pz.ts.register_as_default(streaming=False)
# Provides the `%%autovisualize` cell magic used by the capture cell below.
pz.ts.register_autovisualize_magic()
# NOTE(review): exact effect of the interactive context is not visible from
# this notebook — see penzai's documentation.
pz.enable_interactive_context()

In [3]:
from penzai.example_models import gemma

In [4]:
# Set to True to download the real Gemma 2B checkpoint (~3.66 GB) even when not
# running on TPU. When False, the real checkpoint is used only on TPU and a
# small randomly-initialized model is built everywhere else.
FORCE_LOAD_GEMMA = False

try:
  import kagglehub
  # A single expression replaces the old if/else, which was immediately
  # overwritten by an unconditional `load_gemma = True` (dead code).
  load_gemma = FORCE_LOAD_GEMMA or jax.devices()[0].platform == "tpu"
except ImportError:
  # Without kagglehub the checkpoint cannot be downloaded at all.
  kagglehub = None
  load_gemma = False

In [5]:
if load_gemma:
  # Download the Flax Gemma 2B checkpoint and restore it with Orbax.
  weights_dir = kagglehub.model_download('google/gemma/Flax/2b')
  ckpt_path = os.path.join(weights_dir, '2b')

  checkpointer = orbax.checkpoint.PyTreeCheckpointer()
  metadata = checkpointer.metadata(ckpt_path)

  # Lay the devices out along one named mesh axis. Each restored array is split
  # along its last dimension and replicated along all other dimensions (same
  # layout as the deprecated PositionalSharding reshape it replaces).
  n_devices = jax.local_device_count()
  sharding_devices = mesh_utils.create_device_mesh((n_devices,))
  device_mesh = jax.sharding.Mesh(sharding_devices, ("devices",))

  def _last_axis_sharding(array_metadata):
    """Returns a sharding that splits only the last axis over `device_mesh`."""
    ndim = len(array_metadata.shape)
    if ndim == 0:
      # Scalars have no axis to split; replicate them on every device.
      spec = jax.sharding.PartitionSpec()
    else:
      spec = jax.sharding.PartitionSpec(*([None] * (ndim - 1)), "devices")
    return jax.sharding.NamedSharding(device_mesh, spec)

  restore_args = jax.tree_util.tree_map(
      lambda m: orbax.checkpoint.ArrayRestoreArgs(
          restore_type=jax.Array,
          sharding=_last_axis_sharding(m),
      ),
      metadata,
  )
  flat_params = checkpointer.restore(ckpt_path, restore_args=restore_args)

  model = gemma.model_core.GemmaTransformer.from_pretrained(
      flat_params, upcast_activations_to_float32=False
  )

else:
  # Small randomly-initialized Gemma-architecture model so the notebook can run
  # without the real checkpoint; the fixed key keeps the figure reproducible.
  model = pz.nn.initialize_parameters(
      gemma.model_core.GemmaTransformer.from_config(
          gemma.model_core.GemmaTransformerConfig(
              num_heads=8,
              embedding_dim=256,
              projection_dim=32,
              single_kv_head=False,
              mlp_hidden_dim=512,
              num_decoder_blocks=10,
              vocab_size=1000,
              parameter_dtype=jnp.float32,
              activation_dtype=jnp.float32,
          )
      ),
      jax.random.key(1),
  )



Downloading from https://www.kaggle.com/api/v1/models/google/gemma/Flax/2b/2/download...


  0%|          | 0.00/3.66G [00:00<?, ?B/s]

  0%|          | 5.00M/3.66G [00:00<01:56, 33.7MB/s]

  0%|          | 9.00M/3.66G [00:00<01:53, 34.7MB/s]

  1%|          | 25.0M/3.66G [00:00<01:08, 57.1MB/s]

  1%|          | 41.0M/3.66G [00:00<00:49, 77.9MB/s]

  2%|▏         | 57.0M/3.66G [00:00<00:54, 71.7MB/s]

  2%|▏         | 73.0M/3.66G [00:01<00:49, 77.6MB/s]

  2%|▏         | 89.0M/3.66G [00:01<00:43, 88.8MB/s]

  3%|▎         | 105M/3.66G [00:01<00:37, 102MB/s]  

  3%|▎         | 121M/3.66G [00:01<00:34, 111MB/s]

  4%|▎         | 133M/3.66G [00:01<00:34, 110MB/s]

  4%|▍         | 150M/3.66G [00:01<00:29, 126MB/s]

  4%|▍         | 163M/3.66G [00:01<00:36, 102MB/s]

  5%|▍         | 185M/3.66G [00:02<00:32, 116MB/s]

  5%|▌         | 201M/3.66G [00:02<00:40, 91.2MB/s]

  6%|▌         | 225M/3.66G [00:02<00:41, 88.7MB/s]

  7%|▋         | 249M/3.66G [00:02<00:32, 111MB/s] 

  7%|▋         | 262M/3.66G [00:02<00:35, 104MB/s]

  7%|▋         | 274M/3.66G [00:03<00:39, 92.5MB/s]

  8%|▊         | 297M/3.66G [00:03<00:32, 111MB/s] 

  8%|▊         | 309M/3.66G [00:03<00:36, 99.2MB/s]

  9%|▉         | 329M/3.66G [00:03<00:36, 99.4MB/s]

  9%|▉         | 351M/3.66G [00:03<00:28, 124MB/s] 

 10%|▉         | 370M/3.66G [00:03<00:25, 139MB/s]

 10%|█         | 385M/3.66G [00:04<00:31, 114MB/s]

 11%|█         | 398M/3.66G [00:04<00:29, 118MB/s]

 11%|█         | 417M/3.66G [00:04<00:29, 117MB/s]

 12%|█▏        | 440M/3.66G [00:04<00:24, 144MB/s]

 12%|█▏        | 456M/3.66G [00:05<00:50, 67.9MB/s]

 13%|█▎        | 473M/3.66G [00:05<00:44, 77.7MB/s]

 13%|█▎        | 497M/3.66G [00:05<00:46, 73.2MB/s]

 14%|█▍        | 521M/3.66G [00:05<00:35, 96.6MB/s]

 14%|█▍        | 536M/3.66G [00:05<00:39, 85.8MB/s]

 15%|█▍        | 561M/3.66G [00:06<00:32, 102MB/s] 

 16%|█▌        | 587M/3.66G [00:06<00:25, 130MB/s]

 16%|█▌        | 604M/3.66G [00:06<00:25, 130MB/s]

 17%|█▋        | 620M/3.66G [00:06<00:26, 125MB/s]

 17%|█▋        | 643M/3.66G [00:06<00:21, 149MB/s]

 18%|█▊        | 660M/3.66G [00:06<00:22, 143MB/s]

 18%|█▊        | 679M/3.66G [00:06<00:20, 155MB/s]

 19%|█▊        | 697M/3.66G [00:06<00:20, 154MB/s]

 19%|█▉        | 713M/3.66G [00:07<00:20, 154MB/s]

 19%|█▉        | 729M/3.66G [00:07<00:23, 132MB/s]

 20%|██        | 752M/3.66G [00:07<00:20, 157MB/s]

 21%|██        | 769M/3.66G [00:07<00:30, 103MB/s]

 21%|██        | 791M/3.66G [00:07<00:24, 126MB/s]

 22%|██▏       | 807M/3.66G [00:08<00:30, 100MB/s]

 22%|██▏       | 828M/3.66G [00:08<00:25, 122MB/s]

 23%|██▎       | 849M/3.66G [00:08<00:21, 141MB/s]

 23%|██▎       | 866M/3.66G [00:08<00:28, 107MB/s]

 24%|██▎       | 890M/3.66G [00:08<00:22, 134MB/s]

 24%|██▍       | 912M/3.66G [00:08<00:19, 153MB/s]

 25%|██▍       | 931M/3.66G [00:08<00:18, 162MB/s]

 25%|██▌       | 949M/3.66G [00:09<00:27, 108MB/s]

 26%|██▌       | 969M/3.66G [00:09<00:23, 126MB/s]

 26%|██▋       | 991M/3.66G [00:09<00:19, 147MB/s]

 27%|██▋       | 0.99G/3.66G [00:09<00:24, 119MB/s]

 27%|██▋       | 1.00G/3.66G [00:09<00:27, 105MB/s]

 28%|██▊       | 1.02G/3.66G [00:09<00:21, 130MB/s]

 28%|██▊       | 1.04G/3.66G [00:09<00:20, 137MB/s]

 29%|██▉       | 1.05G/3.66G [00:10<00:21, 130MB/s]

 29%|██▉       | 1.07G/3.66G [00:10<00:26, 104MB/s]

 30%|██▉       | 1.09G/3.66G [00:10<00:23, 119MB/s]

 30%|███       | 1.10G/3.66G [00:10<00:22, 123MB/s]

 30%|███       | 1.12G/3.66G [00:10<00:20, 133MB/s]

 31%|███       | 1.13G/3.66G [00:10<00:22, 123MB/s]

 31%|███▏      | 1.15G/3.66G [00:10<00:19, 141MB/s]

 32%|███▏      | 1.17G/3.66G [00:11<00:21, 126MB/s]

 32%|███▏      | 1.18G/3.66G [00:11<00:18, 140MB/s]

 33%|███▎      | 1.20G/3.66G [00:11<00:18, 147MB/s]

 33%|███▎      | 1.21G/3.66G [00:11<00:22, 119MB/s]

 34%|███▎      | 1.23G/3.66G [00:11<00:21, 119MB/s]

 34%|███▍      | 1.24G/3.66G [00:11<00:20, 128MB/s]

 34%|███▍      | 1.26G/3.66G [00:11<00:22, 117MB/s]

 35%|███▍      | 1.27G/3.66G [00:12<00:27, 94.8MB/s]

 35%|███▌      | 1.29G/3.66G [00:12<00:20, 123MB/s] 

 36%|███▌      | 1.31G/3.66G [00:12<00:26, 95.1MB/s]

 36%|███▌      | 1.33G/3.66G [00:12<00:20, 121MB/s] 

 37%|███▋      | 1.34G/3.66G [00:12<00:29, 85.1MB/s]

 37%|███▋      | 1.36G/3.66G [00:13<00:22, 110MB/s] 

 38%|███▊      | 1.38G/3.66G [00:13<00:26, 92.9MB/s]

 38%|███▊      | 1.40G/3.66G [00:13<00:20, 117MB/s] 

 39%|███▉      | 1.42G/3.66G [00:13<00:17, 141MB/s]

 39%|███▉      | 1.44G/3.66G [00:13<00:20, 118MB/s]

 40%|███▉      | 1.46G/3.66G [00:13<00:17, 138MB/s]

 40%|████      | 1.48G/3.66G [00:14<00:19, 121MB/s]

 41%|████      | 1.50G/3.66G [00:14<00:16, 145MB/s]

 41%|████▏     | 1.52G/3.66G [00:14<00:15, 146MB/s]

 42%|████▏     | 1.53G/3.66G [00:14<00:15, 145MB/s]

 42%|████▏     | 1.55G/3.66G [00:14<00:17, 129MB/s]

 43%|████▎     | 1.57G/3.66G [00:14<00:15, 145MB/s]

 43%|████▎     | 1.59G/3.66G [00:14<00:13, 162MB/s]

 44%|████▍     | 1.60G/3.66G [00:15<00:21, 101MB/s]

 44%|████▍     | 1.62G/3.66G [00:15<00:19, 113MB/s]

 45%|████▍     | 1.64G/3.66G [00:15<00:16, 132MB/s]

 45%|████▌     | 1.65G/3.66G [00:15<00:16, 127MB/s]

 46%|████▌     | 1.67G/3.66G [00:15<00:15, 141MB/s]

 46%|████▋     | 1.69G/3.66G [00:15<00:12, 164MB/s]

 47%|████▋     | 1.71G/3.66G [00:15<00:16, 125MB/s]

 47%|████▋     | 1.73G/3.66G [00:15<00:14, 147MB/s]

 48%|████▊     | 1.75G/3.66G [00:16<00:14, 144MB/s]

 48%|████▊     | 1.77G/3.66G [00:16<00:19, 106MB/s]

 49%|████▊     | 1.78G/3.66G [00:16<00:20, 99.3MB/s]

 49%|████▉     | 1.80G/3.66G [00:16<00:16, 121MB/s] 

 50%|████▉     | 1.82G/3.66G [00:16<00:15, 128MB/s]

 50%|█████     | 1.84G/3.66G [00:16<00:14, 132MB/s]

 51%|█████     | 1.86G/3.66G [00:17<00:13, 149MB/s]

 51%|█████     | 1.88G/3.66G [00:17<00:11, 169MB/s]

 52%|█████▏    | 1.89G/3.66G [00:17<00:11, 163MB/s]

 52%|█████▏    | 1.92G/3.66G [00:17<00:10, 181MB/s]

 53%|█████▎    | 1.93G/3.66G [00:17<00:12, 153MB/s]

 53%|█████▎    | 1.96G/3.66G [00:17<00:10, 171MB/s]

 54%|█████▍    | 1.97G/3.66G [00:17<00:10, 178MB/s]

 54%|█████▍    | 1.99G/3.66G [00:17<00:10, 179MB/s]

 55%|█████▍    | 2.01G/3.66G [00:18<00:12, 146MB/s]

 55%|█████▌    | 2.03G/3.66G [00:18<00:12, 142MB/s]

 56%|█████▌    | 2.04G/3.66G [00:18<00:17, 102MB/s]

 56%|█████▋    | 2.06G/3.66G [00:18<00:13, 128MB/s]

 57%|█████▋    | 2.08G/3.66G [00:18<00:17, 94.6MB/s]

 57%|█████▋    | 2.10G/3.66G [00:18<00:14, 115MB/s] 

 58%|█████▊    | 2.11G/3.66G [00:19<00:13, 122MB/s]

 58%|█████▊    | 2.13G/3.66G [00:19<00:11, 139MB/s]

 59%|█████▊    | 2.15G/3.66G [00:19<00:15, 108MB/s]

 59%|█████▉    | 2.16G/3.66G [00:19<00:13, 120MB/s]

 60%|█████▉    | 2.18G/3.66G [00:19<00:11, 137MB/s]

 60%|█████▉    | 2.20G/3.66G [00:19<00:11, 134MB/s]

 60%|██████    | 2.21G/3.66G [00:19<00:10, 144MB/s]

 61%|██████    | 2.23G/3.66G [00:19<00:09, 154MB/s]

 61%|██████▏   | 2.25G/3.66G [00:20<00:09, 158MB/s]

 62%|██████▏   | 2.26G/3.66G [00:20<00:09, 162MB/s]

 62%|██████▏   | 2.28G/3.66G [00:20<00:12, 122MB/s]

 63%|██████▎   | 2.29G/3.66G [00:20<00:12, 117MB/s]

 63%|██████▎   | 2.31G/3.66G [00:20<00:14, 101MB/s]

 64%|██████▎   | 2.33G/3.66G [00:20<00:11, 127MB/s]

 64%|██████▍   | 2.34G/3.66G [00:21<00:13, 105MB/s]

 64%|██████▍   | 2.36G/3.66G [00:21<00:12, 110MB/s]

 65%|██████▌   | 2.38G/3.66G [00:21<00:10, 136MB/s]

 65%|██████▌   | 2.40G/3.66G [00:21<00:09, 140MB/s]

 66%|██████▌   | 2.41G/3.66G [00:21<00:10, 125MB/s]

 66%|██████▋   | 2.43G/3.66G [00:21<00:10, 128MB/s]

 67%|██████▋   | 2.45G/3.66G [00:21<00:08, 151MB/s]

 67%|██████▋   | 2.47G/3.66G [00:21<00:08, 150MB/s]

 68%|██████▊   | 2.48G/3.66G [00:22<00:08, 156MB/s]

 68%|██████▊   | 2.50G/3.66G [00:22<00:08, 149MB/s]

 69%|██████▊   | 2.52G/3.66G [00:22<00:11, 105MB/s]

 69%|██████▉   | 2.54G/3.66G [00:22<00:09, 128MB/s]

 70%|██████▉   | 2.55G/3.66G [00:22<00:11, 105MB/s]

 70%|███████   | 2.57G/3.66G [00:22<00:09, 126MB/s]

 71%|███████   | 2.59G/3.66G [00:23<00:09, 118MB/s]

 71%|███████   | 2.60G/3.66G [00:23<00:09, 126MB/s]

 71%|███████▏  | 2.62G/3.66G [00:23<00:09, 117MB/s]

 72%|███████▏  | 2.64G/3.66G [00:23<00:07, 139MB/s]

 72%|███████▏  | 2.65G/3.66G [00:23<00:10, 105MB/s]

 73%|███████▎  | 2.67G/3.66G [00:23<00:08, 130MB/s]

 73%|███████▎  | 2.69G/3.66G [00:23<00:08, 116MB/s]

 74%|███████▍  | 2.71G/3.66G [00:24<00:07, 138MB/s]

 74%|███████▍  | 2.73G/3.66G [00:24<00:12, 80.3MB/s]

 75%|███████▍  | 2.75G/3.66G [00:24<00:09, 102MB/s] 

 75%|███████▌  | 2.76G/3.66G [00:24<00:10, 89.4MB/s]

 76%|███████▌  | 2.78G/3.66G [00:24<00:09, 99.1MB/s]

 76%|███████▋  | 2.79G/3.66G [00:25<00:07, 118MB/s] 

 77%|███████▋  | 2.81G/3.66G [00:25<00:07, 118MB/s]

 77%|███████▋  | 2.83G/3.66G [00:25<00:06, 141MB/s]

 78%|███████▊  | 2.85G/3.66G [00:25<00:05, 162MB/s]

 78%|███████▊  | 2.87G/3.66G [00:25<00:06, 132MB/s]

 79%|███████▊  | 2.88G/3.66G [00:25<00:05, 140MB/s]

 79%|███████▉  | 2.90G/3.66G [00:25<00:05, 144MB/s]

 80%|███████▉  | 2.92G/3.66G [00:26<00:07, 113MB/s]

 80%|████████  | 2.93G/3.66G [00:26<00:06, 130MB/s]

 81%|████████  | 2.95G/3.66G [00:26<00:05, 147MB/s]

 81%|████████  | 2.97G/3.66G [00:26<00:04, 150MB/s]

 82%|████████▏ | 2.99G/3.66G [00:26<00:04, 167MB/s]

 82%|████████▏ | 3.01G/3.66G [00:26<00:04, 174MB/s]

 83%|████████▎ | 3.03G/3.66G [00:26<00:05, 128MB/s]

 83%|████████▎ | 3.05G/3.66G [00:26<00:04, 152MB/s]

 84%|████████▎ | 3.07G/3.66G [00:27<00:06, 96.5MB/s]

 84%|████████▍ | 3.09G/3.66G [00:27<00:05, 119MB/s] 

 85%|████████▍ | 3.10G/3.66G [00:27<00:05, 111MB/s]

 85%|████████▌ | 3.12G/3.66G [00:27<00:04, 123MB/s]

 86%|████████▌ | 3.14G/3.66G [00:27<00:04, 139MB/s]

 86%|████████▌ | 3.15G/3.66G [00:27<00:04, 126MB/s]

 87%|████████▋ | 3.17G/3.66G [00:28<00:04, 129MB/s]

 87%|████████▋ | 3.19G/3.66G [00:28<00:05, 95.6MB/s]

 87%|████████▋ | 3.20G/3.66G [00:28<00:04, 106MB/s] 

 88%|████████▊ | 3.22G/3.66G [00:28<00:03, 125MB/s]

 88%|████████▊ | 3.24G/3.66G [00:28<00:03, 138MB/s]

 89%|████████▉ | 3.26G/3.66G [00:28<00:03, 120MB/s]

 89%|████████▉ | 3.28G/3.66G [00:28<00:02, 141MB/s]

 90%|████████▉ | 3.29G/3.66G [00:29<00:02, 142MB/s]

 90%|█████████ | 3.31G/3.66G [00:29<00:02, 163MB/s]

 91%|█████████ | 3.33G/3.66G [00:29<00:02, 161MB/s]

 91%|█████████▏| 3.35G/3.66G [00:29<00:01, 172MB/s]

 92%|█████████▏| 3.37G/3.66G [00:29<00:02, 137MB/s]

 92%|█████████▏| 3.38G/3.66G [00:29<00:02, 111MB/s]

 93%|█████████▎| 3.40G/3.66G [00:29<00:02, 129MB/s]

 93%|█████████▎| 3.42G/3.66G [00:30<00:02, 125MB/s]

 94%|█████████▍| 3.44G/3.66G [00:30<00:01, 139MB/s]

 94%|█████████▍| 3.45G/3.66G [00:30<00:01, 148MB/s]

 95%|█████████▍| 3.47G/3.66G [00:30<00:01, 117MB/s]

 95%|█████████▌| 3.49G/3.66G [00:30<00:01, 103MB/s]

 96%|█████████▌| 3.50G/3.66G [00:30<00:01, 124MB/s]

 96%|█████████▌| 3.52G/3.66G [00:31<00:01, 102MB/s]

 97%|█████████▋| 3.54G/3.66G [00:31<00:01, 129MB/s]

 97%|█████████▋| 3.56G/3.66G [00:31<00:00, 124MB/s]

 97%|█████████▋| 3.57G/3.66G [00:31<00:01, 90.2MB/s]

 98%|█████████▊| 3.59G/3.66G [00:31<00:00, 113MB/s] 

 98%|█████████▊| 3.61G/3.66G [00:31<00:00, 123MB/s]

 99%|█████████▉| 3.62G/3.66G [00:31<00:00, 122MB/s]

 99%|█████████▉| 3.64G/3.66G [00:32<00:00, 138MB/s]

100%|█████████▉| 3.66G/3.66G [00:32<00:00, 104MB/s]

100%|██████████| 3.66G/3.66G [00:32<00:00, 121MB/s]


Extracting model files...


In [6]:
%%autovisualize

with IPython.utils.capture.capture_output() as capturer:
  pz.select(model).at(lambda root: (
      root.body.body.body.sublayers[2].sublayers[0].delta.sublayers[1].input_to_query,
      root.body.body.body.sublayers[2].sublayers[1].delta.sublayers[1],
  )).show_value()

In [7]:
# Join the HTML rendered by `show_value()` in the capture cell and register it
# with myst-nb under the key "penzai_teaser" so other pages can embed it.
teaser_html_fragments = [
    output.data['text/html']
    for output in capturer.outputs
    if 'text/html' in output.data
]
if not teaser_html_fragments:
  # Fail loudly instead of silently gluing an empty figure into the docs.
  raise ValueError("No HTML output was captured for the penzai teaser.")
myst_nb.glue(
    "penzai_teaser",
    IPython.display.HTML("".join(teaser_html_fragments)),
)