In [12]:
import numpy as np
import haiku as hk
import jax
import jax.numpy as jnp
import functools

from haiku._src.typing import Mapping
from jax._src.basearray import Array

from helpers import (
    evaluate_transformer_decoder,
    make_chomsky_generator,
    make_model,
    utm_data_generator,
)

In [14]:
seed = 1

ORDERED_TASKS = [
    # Regular.
    "even_pairs",
    "modular_arithmetic",
    "parity_check",
    "cycle_navigation",
    # Context free.
    "stack_manipulation",
    "reverse_string",
    "modular_arithmetic_brackets",
    "solve_equation",
    # Context sensitive.
    "duplicate_string",
    "missing_duplicate_string",
    "odds_first",
    "binary_addition",
    "binary_multiplication",
    "compute_sqrt",
    "bucket_sort",
]

In [19]:
def load_model_params(
    data_generator, params_path: str, batch_size: int = 128
) -> hk.Params:
    """Loads saved model parameters and returns the initialized model and params.

    Args:
        params_path: Path to the saved .npz file containing model parameters
        vocab_size: Size of the vocabulary used by the model

    Returns:
        A tuple containing (model, params) where model is the initialized Haiku model
        and params are the loaded parameters
    """
    # Create the same model configuration as used in training
    model = make_model(data_generator)

    # Load the saved parameters
    loaded = np.load(params_path, allow_pickle=True)
    tree_def = loaded["tree_def"].item()  # Get PyTreeDef
    flat_params = [loaded[f"arr_{i}"] for i in range(len(loaded.files) - 1)]
    loaded_params = jax.tree_util.tree_unflatten(tree_def, flat_params)

    # Initialize the model with a dummy batch to get the parameter structure
    dummy_batch, _ = data_generator.sample_dummy(batch_size)  # Minimal dummy input
    dummy_batch = np.argmax(dummy_batch, axis=-1)

    rng = jax.random.PRNGKey(0)
    model.init(rng, dummy_batch)

    return loaded_params


model_paths = [
    "artifacts/params_markov_old.npz",
    "artifacts/params_original_old.npz",
    # "params_markov_transformer_large.npz", 
    # "params_original_transformer_large.npz",
    # "params_random_initialized_transformer_large.npz",
    # "params_markov_transformer_medium.npz", 
    # "params_original_transformer_medium.npz",
    # "params_random_initialized_transformer_medium.npz",
    # "params_markov_transformer_small.npz",
    # "params_original_transformer_small.npz",
    # "params_random_initialized_transformer_small.npz",
]

rng = np.random.default_rng(seed=seed)
utm_generator = utm_data_generator(rng)

# Load the model and parameters
model_params: Mapping[str, Mapping[str, Mapping[str, Array]]] = {}
for model_path in model_paths:
    params: Mapping[str, Mapping[str, Array]] = load_model_params(utm_generator, model_path)
    model_params[model_path] = params

In [20]:
for task in ORDERED_TASKS:
    print("Chomsky Task:", task)
    try:
        # max_input_length = 256 is arbitrary, I dont think this is explicitly defined in the paper.
        chomsky_generator = make_chomsky_generator(
            rng, task_str=task
        )
        regret, total_accuracy, total_final_accuracy = evaluate_transformer_decoder(
            chomsky_generator, params, utm_generator
        )
        print(task, total_accuracy, total_final_accuracy)
    except Exception as e:
        print("Failed to evaluate task", task, e)


Chomsky Task: even_pairs
even_pairs 0.42429695 0.42429194
Chomsky Task: modular_arithmetic
Failed to evaluate task modular_arithmetic all the input array dimensions except for the concatenation axis must match exactly, but along dimension 2, the array at index 0 has size 7 and the array at index 1 has size 11
Chomsky Task: parity_check
parity_check 0.42754906 0.42755023
Chomsky Task: cycle_navigation
Failed to evaluate task cycle_navigation all the input array dimensions except for the concatenation axis must match exactly, but along dimension 2, the array at index 0 has size 5 and the array at index 1 has size 7
Chomsky Task: stack_manipulation
Failed to evaluate task stack_manipulation all the input array dimensions except for the concatenation axis must match exactly, but along dimension 2, the array at index 0 has size 5 and the array at index 1 has size 7
Chomsky Task: reverse_string
reverse_string 0.43992084 0.44297442
Chomsky Task: modular_arithmetic_brackets
Failed to evaluate 