In [None]:
from __future__ import annotations

import numpy as np
import zarr
from zarr import Array, Group


# Path of the zarr zip archive inspected in this notebook.
ARCHIVE_PATH = (
    "./activations_slimpajama_2025-08-12-021021/activations_part_0000.zarr.zip"
)

# Both the store and the hierarchy are opened in read-only mode.
store = zarr.storage.ZipStore(ARCHIVE_PATH, read_only=True)
z = zarr.open(store, mode="r")

In [44]:
# Display the hierarchy of the opened archive (its groups and arrays).
z.tree()

A single zarr archive is stored as a `Group` with three sub-elements -- `activations`, `attention_mask` and `input_ids`.
- The group `activations` consists of `n_layers` 3D sub-arrays (named `layer_0`, `layer_1`, ...), each of shape `(batch_size, sequence_length, hidden_dimension)`.
- The `attention_mask` array is particularly important because it determines the actual length of each sequence. Since some sequences may be shorter than `max_sequence_length`, the padding tokens must be removed before processing.
- `input_ids` contains the actual input ids used.


In [51]:
def get_seq_length(z: Group, sample_id: int) -> int:
    """Return the sum of one sample's attention-mask row as a Python scalar.

    Args:
        z: Opened archive group; only ``z["attention_mask"]`` is read.
        sample_id: Index of the sample (row of the attention mask).

    Returns:
        The sum of the mask row, i.e. the number of non-padding positions.
        NOTE(review): presumably the mask holds 1 for real tokens and 0 for
        padding -- confirm against the code that wrote the archive.
    """
    return z["attention_mask"][sample_id].sum().item()

In [None]:
def get_one_latent(
    z: Group,
    sample_id: int,
    layer_id: int,
    position: int,
) -> np.ndarray:
    """Retrieve the latent activation for a specific sample, layer, and position.

    Args:
        z: Opened archive group; ``z["attention_mask"]`` and
            ``z["activations"]["layer_<layer_id>"]`` are read.
        sample_id: Index of the sample along the first axis of the layer array.
        layer_id: Layer number; selects the array named ``layer_<layer_id>``.
        position: Zero-based token position. Must satisfy
            ``0 <= position < get_seq_length(z, sample_id)``.

    Returns:
        The values stored at ``[sample_id, position]`` of the selected layer.

    Raises:
        AssertionError: If ``position`` is negative or not smaller than the
            sample's sequence length.
    """
    # Read the sequence length once and reuse it in the error message.
    seq_length = get_seq_length(z, sample_id)

    # The lower bound matters: a negative index would pass a plain
    # `position < seq_length` check and wrap around to the end of the padded
    # axis, which may hold padding rather than a real token
    # (assumes right-padding -- TODO confirm).
    assert 0 <= position < seq_length, (  # noqa: S101
        f"Position out of bounds. Max position for sample {sample_id} is {seq_length - 1}."  # noqa: E501
    )

    return z["activations"][f"layer_{layer_id}"][sample_id, position]


In [53]:
# Activation vector of sample 1 at layer 0, token position 0.
get_one_latent(z, sample_id=1, layer_id=0, position=0)

array([ 0.006042, -0.2812  ,  1.086   , ...,  0.703   , -0.3828  ,
       -0.59    ], dtype=float16)

In [55]:
# Layer 0 activations at token position 0 for every sample in the archive.
z["activations"]["layer_0"][:, 0]

array([[ 0.006042, -0.2812  ,  1.086   , ...,  0.703   , -0.3828  ,
        -0.59    ],
       [ 0.006042, -0.2812  ,  1.086   , ...,  0.703   , -0.3828  ,
        -0.59    ],
       [ 0.006042, -0.2812  ,  1.086   , ...,  0.703   , -0.3828  ,
        -0.59    ],
       ...,
       [ 0.006042, -0.2812  ,  1.086   , ...,  0.703   , -0.3828  ,
        -0.59    ],
       [ 0.006042, -0.2812  ,  1.086   , ...,  0.703   , -0.3828  ,
        -0.59    ],
       [ 0.006042, -0.2812  ,  1.086   , ...,  0.703   , -0.3828  ,
        -0.59    ]], dtype=float16)