In [5]:
import numpy as np
import haiku as hk
import jax

from haiku._src.typing import Mapping
from jax._src.basearray import Array

from helpers import (
    evaluate_transformer_decoder,
    make_chomsky_generator,
    make_model,
    utm_data_generator,
)

In [6]:
# Seed for the NumPy random generator created further down in the notebook.
seed = 1

# Task names to evaluate, grouped by formal-language class and kept in a
# fixed order so that results are always reported in the same sequence.
ORDERED_TASKS = [
    # Regular.
    *("even_pairs", "modular_arithmetic", "parity_check", "cycle_navigation"),
    # Context free.
    *(
        "stack_manipulation",
        "reverse_string",
        "modular_arithmetic_brackets",
        "solve_equation",
    ),
    # Context sensitive.
    *(
        "duplicate_string",
        "missing_duplicate_string",
        "odds_first",
        "binary_addition",
        "binary_multiplication",
        "compute_sqrt",
        "bucket_sort",
    ),
]

In [None]:
def load_model_params(
    data_generator, params_path: str, batch_size: int = 128
) -> hk.Params:
    """Loads saved model parameters from an ``.npz`` checkpoint.

    The archive is read as a pickled PyTreeDef stored under the key
    ``"tree_def"`` plus the flattened parameter leaves stored under
    ``"arr_0"``, ``"arr_1"``, ... in flattening order.

    Args:
        data_generator: Generator handed to ``make_model`` and used to sample
            the dummy batch (through its ``sample_dummy`` method).
        params_path: Path to the saved .npz file containing model parameters.
        batch_size: Batch size of the dummy batch used for the init call.

    Returns:
        The parameter pytree rebuilt from the archive. Only the parameters
        are returned; the model object built here is not.
    """
    # Create the model from the data generator, as the training code does.
    model = make_model(data_generator)

    # NOTE(security): allow_pickle=True is needed to read the pickled
    # tree_def, but unpickling can execute arbitrary code. Only load
    # checkpoints from trusted sources.
    # The context manager closes the underlying archive once all arrays have
    # been read into memory.
    with np.load(params_path, allow_pickle=True) as loaded:
        tree_def = loaded["tree_def"].item()  # Get PyTreeDef
        # Count only the leaf arrays so that extra metadata keys in the
        # archive do not shift the expected number of leaves.
        num_leaves = sum(name.startswith("arr_") for name in loaded.files)
        flat_params = [loaded[f"arr_{i}"] for i in range(num_leaves)]
    loaded_params = jax.tree_util.tree_unflatten(tree_def, flat_params)

    # Sample a dummy batch and collapse its last axis with argmax before
    # feeding it to the model.
    dummy_batch, _ = data_generator.sample_dummy(batch_size)
    dummy_batch = np.argmax(dummy_batch, axis=-1)

    # The freshly initialised parameters are discarded; the call is kept so
    # that a model/generator mismatch still surfaces here as an error.
    rng = jax.random.PRNGKey(0)
    model.init(rng, dummy_batch)

    return loaded_params


# Checkpoints to evaluate. Entries that are commented out are skipped;
# uncomment a line to include that checkpoint in the run.
model_paths = [
    "artifacts/params_markov_transformer_large.npz",
    # "artifacts/params_original_transformer_large.npz",
    # "artifacts/params_random_initialized_transformer_large.npz",
    # "artifacts/params_markov_transformer_medium.npz",
    # "artifacts/params_original_transformer_medium.npz",
    # "artifacts/params_random_initialized_transformer_medium.npz",
    # "artifacts/params_markov_transformer_small.npz",
    # "artifacts/params_original_transformer_small.npz",
    # "artifacts/params_random_initialized_transformer_small.npz",
]

# One seeded NumPy generator feeds the UTM data generator.
rng = np.random.default_rng(seed)
utm_generator = utm_data_generator(rng)

# Map each checkpoint path to the parameters loaded from it.
model_params: Mapping[str, Mapping[str, Mapping[str, Array]]] = {}
for model_path in model_paths:
    model_params[model_path] = load_model_params(utm_generator, model_path)

In [9]:
def get_size_from_model_path(path):
    """Infers the model size label from a checkpoint path.

    Returns the first of "large", "medium" and "small" that occurs in
    ``path``, or "unknown" when none of them does.
    """
    for size_label in ("large", "medium", "small"):
        if size_label in path:
            return size_label
    return "unknown"


# Evaluate every loaded checkpoint on every task, printing one result line
# per (model, task) pair.
for model_str, params in model_params.items():
    for task in ORDERED_TASKS:
        print("Model: ", model_str)
        try:
            # NOTE(review): an earlier note here called max_input_length = 256
            # an arbitrary choice not explicitly defined in the paper, but no
            # such argument is passed below - confirm which value applies.
            chomsky_generator = make_chomsky_generator(rng, task_str=task)
            regret, total_accuracy, total_final_accuracy = evaluate_transformer_decoder(
                chomsky_generator,
                params,
                utm_generator,
                num_batches=10,
                size=get_size_from_model_path(model_str),
            )
            print(task, total_accuracy, total_final_accuracy)
        except Exception as e:
            # Best effort: report the failing task and carry on with the rest.
            print("Failed to evaluate task", task, e)


Model:  artifacts/params_markov_transformer_large.npz
even_pairs 0.48148188 0.4815034
Model:  artifacts/params_markov_transformer_large.npz
modular_arithmetic 0.058838494 0.058833234
Model:  artifacts/params_markov_transformer_large.npz
parity_check 0.4728376 0.47282377
Model:  artifacts/params_markov_transformer_large.npz
cycle_navigation 0.3117885 0.3117702
Model:  artifacts/params_markov_transformer_large.npz
stack_manipulation 0.37461048 0.37046084
Model:  artifacts/params_markov_transformer_large.npz
reverse_string 0.48154837 0.48626065
Model:  artifacts/params_markov_transformer_large.npz
modular_arithmetic_brackets 0.07369607 0.073682815
Model:  artifacts/params_markov_transformer_large.npz
solve_equation 0.09740067 0.097401455
Model:  artifacts/params_markov_transformer_large.npz
duplicate_string 0.48207384 0.48070508
Model:  artifacts/params_markov_transformer_large.npz
missing_duplicate_string 0.47244197 0.46932322
Model:  artifacts/params_markov_transformer_large.npz
odds_fi