Orphan notebook that computes values to "glue" into other documentation pages using [myst-nb](https://myst-nb.readthedocs.io/en/v0.13.2/use/glue.html).

In [1]:
import os

import myst_nb

import penzai
from penzai import pz

import jax
import jax.numpy as jnp
import orbax.checkpoint
from jax.experimental import mesh_utils

import IPython.display
import IPython.utils.capture

In [2]:
# Notebook-wide rendering setup.
# - Register penzai's `pz.ts` renderer as the default display hook, with
#   streaming disabled. Presumably this makes each rendering arrive as one
#   complete HTML output that `IPython.utils.capture` can collect in the
#   teaser cell below (TODO confirm).
# - Register the `%%autovisualize` cell magic that the teaser cell uses.
# - Enable penzai's interactive context (exact semantics not visible here).
pz.ts.register_as_default(streaming=False)
pz.ts.register_autovisualize_magic()
pz.enable_interactive_context()

In [3]:
from penzai.example_models import gemma

In [4]:
# Decide whether to load the real Gemma 2B checkpoint. This happens only when
# `kagglehub` is installed (needed for the download) AND the first device is a
# TPU. Otherwise a small randomly initialized model is used instead.
# Note: a leftover unconditional `load_gemma = True` previously overrode this
# check, forcing the multi-GB download on every platform.
try:
  import kagglehub
  load_gemma = jax.devices()[0].platform == "tpu"
except ImportError:
  kagglehub = None
  load_gemma = False

In [5]:
if load_gemma:
  # Fetch the Gemma 2B Flax checkpoint from Kaggle and restore it with orbax.
  weights_dir = kagglehub.model_download('google/gemma/Flax/2b')
  ckpt_path = os.path.join(weights_dir, '2b')

  checkpointer = orbax.checkpoint.PyTreeCheckpointer()
  metadata = checkpointer.metadata(ckpt_path)

  # Shard the last axis of every parameter across the local devices and
  # replicate all other axes. This uses `NamedSharding` over a 1-D mesh
  # instead of the deprecated `jax.sharding.PositionalSharding`. The layout
  # matches `PositionalSharding(devices).reshape((1, ..., 1, n_devices))`.
  n_devices = jax.local_device_count()
  sharding_devices = mesh_utils.create_device_mesh((n_devices,))
  device_mesh = jax.sharding.Mesh(sharding_devices, ("devices",))

  def _last_axis_sharding(array_metadata):
    """Returns a sharding that splits only the last axis across `device_mesh`."""
    replicated_axes = [None] * (len(array_metadata.shape) - 1)
    return jax.sharding.NamedSharding(
        device_mesh,
        jax.sharding.PartitionSpec(*replicated_axes, "devices"),
    )

  restore_args = jax.tree_util.tree_map(
      lambda m: orbax.checkpoint.ArrayRestoreArgs(
          restore_type=jax.Array,
          sharding=_last_axis_sharding(m),
      ),
      metadata,
  )
  flat_params = checkpointer.restore(ckpt_path, restore_args=restore_args)

  model = gemma.model_core.GemmaTransformer.from_pretrained(
      flat_params, upcast_activations_to_float32=False
  )

else:
  # No checkpoint available: build a small, randomly initialized
  # Gemma-architecture model so the rest of the notebook still runs. The fixed
  # PRNG key keeps the rendered teaser reproducible across runs.
  model = pz.nn.initialize_parameters(
      gemma.model_core.GemmaTransformer.from_config(
          gemma.model_core.GemmaTransformerConfig(
              num_heads=8,
              embedding_dim=256,
              projection_dim=32,
              single_kv_head=False,
              mlp_hidden_dim=512,
              num_decoder_blocks=10,
              vocab_size=1000,
              parameter_dtype=jnp.float32,
              activation_dtype=jnp.float32,
          )
      ),
      jax.random.key(1),
  )

Downloading from https://www.kaggle.com/api/v1/models/google/gemma/Flax/2b/2/download...


  0%|          | 0.00/3.66G [00:00<?, ?B/s]

  0%|          | 5.00M/3.66G [00:00<01:22, 47.8MB/s]

  0%|          | 10.0M/3.66G [00:00<02:04, 31.6MB/s]

  0%|          | 17.0M/3.66G [00:00<01:40, 39.1MB/s]

  1%|          | 39.0M/3.66G [00:00<00:41, 94.7MB/s]

  1%|▏         | 50.0M/3.66G [00:00<00:40, 95.4MB/s]

  2%|▏         | 61.0M/3.66G [00:00<00:49, 78.2MB/s]

  2%|▏         | 81.0M/3.66G [00:01<00:39, 96.5MB/s]

  3%|▎         | 102M/3.66G [00:01<00:30, 124MB/s]  

  3%|▎         | 119M/3.66G [00:01<00:27, 137MB/s]

  4%|▎         | 134M/3.66G [00:01<00:37, 100MB/s]

  4%|▍         | 153M/3.66G [00:01<00:33, 113MB/s]

  5%|▍         | 169M/3.66G [00:01<00:30, 122MB/s]

  5%|▍         | 183M/3.66G [00:02<00:42, 88.0MB/s]

  5%|▌         | 201M/3.66G [00:02<00:52, 70.8MB/s]

  6%|▌         | 223M/3.66G [00:02<00:39, 94.6MB/s]

  6%|▋         | 236M/3.66G [00:02<00:36, 101MB/s] 

  7%|▋         | 257M/3.66G [00:02<00:29, 124MB/s]

  7%|▋         | 273M/3.66G [00:02<00:31, 114MB/s]

  8%|▊         | 294M/3.66G [00:03<00:26, 136MB/s]

  8%|▊         | 310M/3.66G [00:03<00:30, 118MB/s]

  9%|▉         | 332M/3.66G [00:03<00:25, 141MB/s]

  9%|▉         | 353M/3.66G [00:03<00:32, 109MB/s]

 10%|█         | 376M/3.66G [00:03<00:26, 133MB/s]

 10%|█         | 392M/3.66G [00:03<00:29, 121MB/s]

 11%|█         | 413M/3.66G [00:03<00:24, 140MB/s]

 11%|█▏        | 429M/3.66G [00:04<00:34, 102MB/s]

 12%|█▏        | 449M/3.66G [00:04<00:37, 92.6MB/s]

 13%|█▎        | 471M/3.66G [00:04<00:29, 115MB/s] 

 13%|█▎        | 486M/3.66G [00:04<00:34, 98.3MB/s]

 14%|█▎        | 508M/3.66G [00:04<00:27, 122MB/s] 

 14%|█▍        | 523M/3.66G [00:05<00:31, 106MB/s]

 14%|█▍        | 542M/3.66G [00:05<00:27, 123MB/s]

 15%|█▍        | 559M/3.66G [00:05<00:24, 135MB/s]

 15%|█▌        | 575M/3.66G [00:05<00:24, 139MB/s]

 16%|█▌        | 590M/3.66G [00:05<00:24, 137MB/s]

 16%|█▌        | 607M/3.66G [00:05<00:22, 147MB/s]

 17%|█▋        | 622M/3.66G [00:05<00:30, 106MB/s]

 17%|█▋        | 643M/3.66G [00:06<00:25, 129MB/s]

 18%|█▊        | 658M/3.66G [00:06<00:24, 132MB/s]

 18%|█▊        | 673M/3.66G [00:06<00:26, 122MB/s]

 18%|█▊        | 692M/3.66G [00:06<00:22, 140MB/s]

 19%|█▉        | 707M/3.66G [00:06<00:22, 142MB/s]

 19%|█▉        | 722M/3.66G [00:06<00:32, 98.1MB/s]

 20%|█▉        | 742M/3.66G [00:06<00:26, 120MB/s] 

 20%|██        | 759M/3.66G [00:07<00:23, 133MB/s]

 21%|██        | 781M/3.66G [00:07<00:19, 156MB/s]

 21%|██▏       | 799M/3.66G [00:07<00:18, 163MB/s]

 22%|██▏       | 819M/3.66G [00:07<00:17, 174MB/s]

 22%|██▏       | 837M/3.66G [00:07<00:17, 173MB/s]

 23%|██▎       | 855M/3.66G [00:07<00:19, 153MB/s]

 23%|██▎       | 871M/3.66G [00:07<00:22, 132MB/s]

 24%|██▎       | 885M/3.66G [00:07<00:24, 121MB/s]

 24%|██▍       | 907M/3.66G [00:08<00:20, 146MB/s]

 25%|██▍       | 927M/3.66G [00:08<00:18, 161MB/s]

 25%|██▌       | 944M/3.66G [00:08<00:23, 125MB/s]

 26%|██▌       | 965M/3.66G [00:08<00:20, 145MB/s]

 26%|██▌       | 981M/3.66G [00:08<00:26, 110MB/s]

 27%|██▋       | 995M/3.66G [00:08<00:26, 109MB/s]

 27%|██▋       | 0.98G/3.66G [00:08<00:25, 111MB/s]

 27%|██▋       | 1.00G/3.66G [00:09<00:22, 124MB/s]

 28%|██▊       | 1.02G/3.66G [00:09<00:23, 123MB/s]

 28%|██▊       | 1.03G/3.66G [00:09<00:20, 136MB/s]

 29%|██▊       | 1.05G/3.66G [00:09<00:20, 136MB/s]

 29%|██▉       | 1.06G/3.66G [00:09<00:19, 140MB/s]

 29%|██▉       | 1.08G/3.66G [00:09<00:23, 117MB/s]

 30%|██▉       | 1.10G/3.66G [00:09<00:20, 135MB/s]

 30%|███       | 1.11G/3.66G [00:10<00:22, 120MB/s]

 31%|███       | 1.12G/3.66G [00:10<00:26, 102MB/s]

 31%|███▏      | 1.15G/3.66G [00:10<00:19, 136MB/s]

 32%|███▏      | 1.16G/3.66G [00:10<00:22, 122MB/s]

 32%|███▏      | 1.18G/3.66G [00:10<00:22, 117MB/s]

 33%|███▎      | 1.20G/3.66G [00:10<00:19, 135MB/s]

 33%|███▎      | 1.21G/3.66G [00:10<00:22, 117MB/s]

 33%|███▎      | 1.22G/3.66G [00:11<00:23, 111MB/s]

 34%|███▎      | 1.24G/3.66G [00:11<00:22, 115MB/s]

 34%|███▍      | 1.26G/3.66G [00:11<00:18, 140MB/s]

 35%|███▍      | 1.27G/3.66G [00:11<00:18, 136MB/s]

 35%|███▌      | 1.29G/3.66G [00:11<00:17, 144MB/s]

 36%|███▌      | 1.31G/3.66G [00:11<00:15, 161MB/s]

 36%|███▌      | 1.33G/3.66G [00:11<00:15, 159MB/s]

 37%|███▋      | 1.34G/3.66G [00:11<00:16, 154MB/s]

 37%|███▋      | 1.36G/3.66G [00:12<00:29, 83.1MB/s]

 38%|███▊      | 1.38G/3.66G [00:12<00:22, 108MB/s] 

 38%|███▊      | 1.39G/3.66G [00:12<00:23, 104MB/s]

 39%|███▊      | 1.42G/3.66G [00:12<00:19, 121MB/s]

 39%|███▉      | 1.43G/3.66G [00:13<00:27, 86.2MB/s]

 40%|███▉      | 1.45G/3.66G [00:13<00:21, 112MB/s] 

 40%|████      | 1.47G/3.66G [00:13<00:19, 119MB/s]

 40%|████      | 1.48G/3.66G [00:13<00:18, 129MB/s]

 41%|████      | 1.50G/3.66G [00:13<00:17, 136MB/s]

 41%|████      | 1.51G/3.66G [00:13<00:17, 129MB/s]

 42%|████▏     | 1.52G/3.66G [00:13<00:18, 123MB/s]

 42%|████▏     | 1.55G/3.66G [00:13<00:15, 149MB/s]

 43%|████▎     | 1.56G/3.66G [00:13<00:16, 135MB/s]

 43%|████▎     | 1.58G/3.66G [00:14<00:15, 146MB/s]

 44%|████▎     | 1.60G/3.66G [00:14<00:13, 167MB/s]

 44%|████▍     | 1.62G/3.66G [00:14<00:16, 133MB/s]

 45%|████▍     | 1.64G/3.66G [00:14<00:14, 155MB/s]

 45%|████▌     | 1.65G/3.66G [00:14<00:17, 124MB/s]

 46%|████▌     | 1.67G/3.66G [00:14<00:14, 145MB/s]

 46%|████▌     | 1.69G/3.66G [00:15<00:29, 71.5MB/s]

 47%|████▋     | 1.71G/3.66G [00:15<00:23, 88.8MB/s]

 47%|████▋     | 1.72G/3.66G [00:15<00:21, 97.2MB/s]

 47%|████▋     | 1.74G/3.66G [00:15<00:18, 112MB/s] 

 48%|████▊     | 1.75G/3.66G [00:16<00:26, 77.6MB/s]

 48%|████▊     | 1.77G/3.66G [00:16<00:21, 93.7MB/s]

 49%|████▉     | 1.79G/3.66G [00:16<00:20, 97.6MB/s]

 49%|████▉     | 1.81G/3.66G [00:16<00:20, 99.5MB/s]

 50%|████▉     | 1.82G/3.66G [00:16<00:17, 114MB/s] 

 50%|█████     | 1.83G/3.66G [00:16<00:20, 96.1MB/s]

 51%|█████     | 1.85G/3.66G [00:16<00:17, 110MB/s] 

 51%|█████     | 1.86G/3.66G [00:16<00:16, 117MB/s]

 51%|█████     | 1.88G/3.66G [00:17<00:19, 99.4MB/s]

 52%|█████▏    | 1.90G/3.66G [00:17<00:14, 129MB/s] 

 52%|█████▏    | 1.92G/3.66G [00:17<00:13, 142MB/s]

 53%|█████▎    | 1.93G/3.66G [00:17<00:12, 147MB/s]

 53%|█████▎    | 1.95G/3.66G [00:17<00:12, 153MB/s]

 54%|█████▎    | 1.96G/3.66G [00:17<00:15, 118MB/s]

 54%|█████▍    | 1.98G/3.66G [00:17<00:14, 123MB/s]

 54%|█████▍    | 2.00G/3.66G [00:18<00:13, 136MB/s]

 55%|█████▍    | 2.01G/3.66G [00:18<00:18, 96.9MB/s]

 55%|█████▌    | 2.03G/3.66G [00:18<00:14, 123MB/s] 

 56%|█████▌    | 2.05G/3.66G [00:18<00:17, 98.9MB/s]

 56%|█████▌    | 2.06G/3.66G [00:18<00:21, 78.9MB/s]

 57%|█████▋    | 2.08G/3.66G [00:19<00:20, 83.4MB/s]

 57%|█████▋    | 2.09G/3.66G [00:19<00:19, 86.9MB/s]

 58%|█████▊    | 2.12G/3.66G [00:19<00:15, 109MB/s] 

 58%|█████▊    | 2.13G/3.66G [00:19<00:19, 84.1MB/s]

 59%|█████▊    | 2.15G/3.66G [00:19<00:14, 109MB/s] 

 59%|█████▉    | 2.17G/3.66G [00:19<00:12, 129MB/s]

 60%|█████▉    | 2.19G/3.66G [00:20<00:13, 120MB/s]

 60%|██████    | 2.20G/3.66G [00:20<00:12, 126MB/s]

 60%|██████    | 2.21G/3.66G [00:20<00:15, 98.4MB/s]

 61%|██████    | 2.23G/3.66G [00:20<00:12, 125MB/s] 

 62%|██████▏   | 2.26G/3.66G [00:20<00:09, 153MB/s]

 62%|██████▏   | 2.28G/3.66G [00:20<00:08, 168MB/s]

 63%|██████▎   | 2.30G/3.66G [00:20<00:10, 134MB/s]

 63%|██████▎   | 2.32G/3.66G [00:21<00:09, 154MB/s]

 64%|██████▍   | 2.34G/3.66G [00:21<00:10, 137MB/s]

 64%|██████▍   | 2.36G/3.66G [00:21<00:08, 157MB/s]

 65%|██████▍   | 2.38G/3.66G [00:21<00:07, 175MB/s]

 66%|██████▌   | 2.40G/3.66G [00:21<00:13, 103MB/s]

 66%|██████▌   | 2.42G/3.66G [00:21<00:10, 126MB/s]

 67%|██████▋   | 2.44G/3.66G [00:22<00:11, 117MB/s]

 67%|██████▋   | 2.46G/3.66G [00:22<00:09, 138MB/s]

 68%|██████▊   | 2.48G/3.66G [00:22<00:08, 145MB/s]

 68%|██████▊   | 2.49G/3.66G [00:22<00:08, 151MB/s]

 69%|██████▊   | 2.51G/3.66G [00:22<00:08, 147MB/s]

 69%|██████▉   | 2.53G/3.66G [00:22<00:07, 164MB/s]

 70%|██████▉   | 2.55G/3.66G [00:22<00:08, 144MB/s]

 70%|██████▉   | 2.56G/3.66G [00:23<00:11, 107MB/s]

 70%|███████   | 2.58G/3.66G [00:23<00:10, 112MB/s]

 71%|███████   | 2.59G/3.66G [00:23<00:08, 128MB/s]

 71%|███████   | 2.61G/3.66G [00:23<00:09, 124MB/s]

 72%|███████▏  | 2.62G/3.66G [00:23<00:08, 128MB/s]

 72%|███████▏  | 2.65G/3.66G [00:23<00:06, 160MB/s]

 73%|███████▎  | 2.66G/3.66G [00:23<00:07, 138MB/s]

 73%|███████▎  | 2.68G/3.66G [00:23<00:06, 159MB/s]

 74%|███████▎  | 2.70G/3.66G [00:24<00:07, 133MB/s]

 74%|███████▍  | 2.72G/3.66G [00:24<00:06, 152MB/s]

 75%|███████▍  | 2.74G/3.66G [00:24<00:10, 95.6MB/s]

 75%|███████▌  | 2.76G/3.66G [00:24<00:08, 117MB/s] 

 76%|███████▌  | 2.77G/3.66G [00:24<00:09, 98.4MB/s]

 76%|███████▋  | 2.80G/3.66G [00:25<00:07, 123MB/s] 

 77%|███████▋  | 2.81G/3.66G [00:25<00:06, 132MB/s]

 77%|███████▋  | 2.83G/3.66G [00:25<00:06, 144MB/s]

 78%|███████▊  | 2.85G/3.66G [00:25<00:09, 94.8MB/s]

 78%|███████▊  | 2.86G/3.66G [00:25<00:08, 96.2MB/s]

 78%|███████▊  | 2.87G/3.66G [00:25<00:08, 105MB/s] 

 79%|███████▉  | 2.89G/3.66G [00:25<00:07, 117MB/s]

 79%|███████▉  | 2.90G/3.66G [00:26<00:06, 123MB/s]

 80%|███████▉  | 2.92G/3.66G [00:26<00:05, 143MB/s]

 80%|████████  | 2.94G/3.66G [00:26<00:06, 122MB/s]

 81%|████████  | 2.96G/3.66G [00:26<00:05, 146MB/s]

 81%|████████  | 2.97G/3.66G [00:26<00:06, 109MB/s]

 82%|████████▏ | 2.99G/3.66G [00:26<00:05, 121MB/s]

 82%|████████▏ | 3.01G/3.66G [00:27<00:06, 103MB/s]

 83%|████████▎ | 3.03G/3.66G [00:27<00:05, 123MB/s]

 83%|████████▎ | 3.04G/3.66G [00:27<00:05, 120MB/s]

 83%|████████▎ | 3.06G/3.66G [00:27<00:05, 129MB/s]

 84%|████████▍ | 3.07G/3.66G [00:27<00:05, 114MB/s]

 84%|████████▍ | 3.09G/3.66G [00:27<00:05, 115MB/s]

 85%|████████▍ | 3.11G/3.66G [00:27<00:04, 138MB/s]

 85%|████████▌ | 3.12G/3.66G [00:27<00:04, 131MB/s]

 86%|████████▌ | 3.14G/3.66G [00:28<00:04, 132MB/s]

 86%|████████▌ | 3.15G/3.66G [00:28<00:04, 118MB/s]

 86%|████████▋ | 3.17G/3.66G [00:28<00:04, 109MB/s]

 87%|████████▋ | 3.18G/3.66G [00:28<00:04, 117MB/s]

 87%|████████▋ | 3.20G/3.66G [00:28<00:03, 128MB/s]

 88%|████████▊ | 3.22G/3.66G [00:28<00:03, 136MB/s]

 88%|████████▊ | 3.24G/3.66G [00:28<00:02, 156MB/s]

 89%|████████▉ | 3.26G/3.66G [00:29<00:03, 134MB/s]

 89%|████████▉ | 3.27G/3.66G [00:29<00:03, 121MB/s]

 90%|████████▉ | 3.28G/3.66G [00:29<00:04, 98.5MB/s]

 90%|████████▉ | 3.29G/3.66G [00:29<00:03, 103MB/s] 

 90%|█████████ | 3.30G/3.66G [00:29<00:03, 103MB/s]

 91%|█████████ | 3.32G/3.66G [00:29<00:02, 124MB/s]

 91%|█████████ | 3.33G/3.66G [00:29<00:03, 109MB/s]

 92%|█████████▏| 3.35G/3.66G [00:30<00:03, 87.5MB/s]

 92%|█████████▏| 3.37G/3.66G [00:30<00:02, 114MB/s] 

 92%|█████████▏| 3.39G/3.66G [00:30<00:02, 112MB/s]

 93%|█████████▎| 3.41G/3.66G [00:30<00:02, 131MB/s]

 93%|█████████▎| 3.42G/3.66G [00:30<00:02, 127MB/s]

 94%|█████████▍| 3.43G/3.66G [00:30<00:02, 103MB/s]

 94%|█████████▍| 3.45G/3.66G [00:31<00:02, 85.3MB/s]

 94%|█████████▍| 3.46G/3.66G [00:31<00:02, 94.1MB/s]

 95%|█████████▌| 3.48G/3.66G [00:31<00:01, 123MB/s] 

 95%|█████████▌| 3.49G/3.66G [00:31<00:02, 71.8MB/s]

 96%|█████████▌| 3.52G/3.66G [00:31<00:01, 91.4MB/s]

 96%|█████████▋| 3.53G/3.66G [00:32<00:01, 100MB/s] 

 97%|█████████▋| 3.54G/3.66G [00:32<00:01, 91.9MB/s]

 97%|█████████▋| 3.56G/3.66G [00:32<00:00, 112MB/s] 

 98%|█████████▊| 3.58G/3.66G [00:32<00:00, 122MB/s]

 98%|█████████▊| 3.59G/3.66G [00:32<00:00, 119MB/s]

 98%|█████████▊| 3.61G/3.66G [00:32<00:00, 123MB/s]

 99%|█████████▉| 3.62G/3.66G [00:32<00:00, 105MB/s]

 99%|█████████▉| 3.63G/3.66G [00:33<00:00, 93.9MB/s]

 99%|█████████▉| 3.64G/3.66G [00:33<00:00, 73.3MB/s]

100%|█████████▉| 3.66G/3.66G [00:33<00:00, 100MB/s] 

100%|██████████| 3.66G/3.66G [00:33<00:00, 118MB/s]


Extracting model files...


In [6]:
%%autovisualize

with IPython.utils.capture.capture_output() as capturer:
  pz.select(model).at(lambda root: (
      root.body.body.body.sublayers[2].sublayers[0].delta.sublayers[1].input_to_query,
      root.body.body.body.sublayers[2].sublayers[1].delta.sublayers[1],
  )).show_value()

In [7]:
# Concatenate the HTML captured by the previous cell into one snippet and
# register it with myst-nb under the key "penzai_teaser". Other documentation
# pages can then embed it via myst-nb's glue mechanism.
teaser_html_parts = [output.data['text/html'] for output in capturer.outputs]
# Fail loudly instead of silently gluing an empty teaser if nothing rendered.
assert teaser_html_parts, "No rich output was captured for the penzai teaser."
myst_nb.glue(
    "penzai_teaser",
    IPython.display.HTML("".join(teaser_html_parts)),
)