Orphaned notebook that computes values for other documentation pages to "glue" in using [myst-nb](https://myst-nb.readthedocs.io/en/v0.13.2/use/glue.html).

In [None]:
import os
import functools

import myst_nb

import treescope

import jax
import jax.numpy as jnp

import IPython.utils.capture

from penzai import pz

In [None]:
from penzai.models.transformer.variants import llamalike_common

# Build a small Llama-like transformer: 10 decoder blocks, 256-dim
# embeddings, a 1000-token vocabulary, tied embedder/logits, float32
# parameters and activations, initialized from the fixed RNG key 42.
model = llamalike_common.build_llamalike_transformer(
    llamalike_common.LlamalikeTransformerConfig(
        num_kv_heads=8,
        query_head_multiplier=1,
        embedding_dim=256,
        projection_dim=32,
        mlp_hidden_dim=512,
        num_decoder_blocks=10,
        vocab_size=1000,
        mlp_variant="geglu_approx",
        rope_wavelength=10_000,
        tie_embedder_and_logits=True,
        use_layer_stack=False,
        parameter_dtype=jnp.float32,
        activation_dtype=jnp.float32,
    ),
    init_base_rng=jax.random.key(42),
)

# Pull the parameters out of the model; the unbound model is discarded.
# Each entry exposes `.label` and `.value.data_array` (used in the next cell).
# NOTE(review): `freeze=True` presumably yields frozen value snapshots rather
# than live variables — confirm against penzai's `unbind_params` docs.
_, params = pz.unbind_params(model, freeze=True)
# Reorder the named axes of the first parameter's value to
# ("embedding", "vocabulary").
# NOTE(review): assumes params[0] is the (tied) embedding table and that its
# axes are named "embedding" and "vocabulary" — this depends on the order in
# which unbind_params returns parameters; confirm.
params = pz.select(params).at(lambda root: root[0].value).apply(
    lambda x: x.order_as("embedding", "vocabulary")
)

In [None]:
# Turn the flat, "/"-separated parameter labels into a nested dict: every
# label component except the last becomes a dict level, and the last one
# maps to that parameter's raw array.
nested_params = {}
for param in params:
  *parent_keys, leaf_key = param.label.split("/")
  node = nested_params
  for key in parent_keys:
    node = node.setdefault(key, {})
  node[leaf_key] = param.value.data_array

In [None]:
# Glue the nested parameter dict as-is (without the treescope scoping used by
# the later cells) under the name "treescope_before".
myst_nb.glue("treescope_before", nested_params)

In [None]:
# Render `nested_params` with treescope, using the array autovisualizer and an
# expansion strategy capped at max_height=30. Capture the resulting display
# output in `capturer` instead of showing it, so the next cell can glue it.
# Multiple context managers in one `with` are entered left to right, which
# matches the original nesting order.
with IPython.utils.capture.capture_output() as capturer, \
     treescope.active_autovisualizer.set_scoped(treescope.ArrayAutovisualizer()), \
     treescope.using_expansion_strategy(max_height=30):
  treescope.display(nested_params)

In [None]:
# Import explicitly rather than relying on some other import having already
# loaded the `IPython.display` submodule as a side effect.
import IPython.display

# Glue the HTML captured by the previous cell under "treescope_after".
# Every captured output is indexed with 'text/html', so an output without an
# HTML representation raises KeyError here. That surfaces a broken render
# instead of silently gluing an incomplete page.
myst_nb.glue(
    "treescope_after",
    IPython.display.HTML(
        "".join(output.data['text/html'] for output in capturer.outputs)
    ),
)

In [None]:
# Render the full penzai `model` with treescope, using the array
# autovisualizer and an expansion strategy capped at max_height=30. Capture
# the display output in `capturer` (replacing the previous capture) so the
# next cell can glue it. The `with` items are entered left to right, the same
# order as the original nested form.
with IPython.utils.capture.capture_output() as capturer, \
     treescope.active_autovisualizer.set_scoped(treescope.ArrayAutovisualizer()), \
     treescope.using_expansion_strategy(max_height=30):
  treescope.display(model)

In [None]:
# Import explicitly rather than relying on some other import having already
# loaded the `IPython.display` submodule as a side effect.
import IPython.display

# Glue the HTML captured from rendering the model under "treescope_penzai".
# Every captured output is indexed with 'text/html', so an output without an
# HTML representation raises KeyError here rather than being silently dropped.
myst_nb.glue(
    "treescope_penzai",
    IPython.display.HTML(
        "".join(output.data['text/html'] for output in capturer.outputs)
    ),
)