In [1]:
import numpy as np
import haiku as hk
import jax
import jax.numpy as jnp
from helpers import (
    evaluate_transformer_decoder,
    make_chomsky_generator,
    make_model,
    utm_data_generator,
)
from models import transformer
from data import utm_data_generator as utm_dg_lib
from data import utms as utms_lib
import functools

In [2]:
ORDERED_TASKS = [
    # Regular.
    "even_pairs",
    "modular_arithmetic",
    "parity_check",
    "cycle_navigation",
    # Context free.
    "stack_manipulation",
    "reverse_string",
    "modular_arithmetic_brackets",
    "solve_equation",
    # Context sensitive.
    "duplicate_string",
    "missing_duplicate_string",
    "odds_first",
    "binary_addition",
    "binary_multiplication",
    "compute_sqrt",
    "bucket_sort",
]


In [3]:
from haiku._src.typing import Mapping
from jax._src.basearray import Array


def load_model_params(
    data_generator, params_path: str, vocab_size: int, batch_size: int = 128
) -> hk.Params:
    """Loads saved model parameters and returns the initialized model and params.

    Args:
        params_path: Path to the saved .npz file containing model parameters
        vocab_size: Size of the vocabulary used by the model

    Returns:
        A tuple containing (model, params) where model is the initialized Haiku model
        and params are the loaded parameters
    """
    # Create the same model configuration as used in training
    model = make_model(data_generator)

    # Load the saved parameters
    loaded = np.load("params.npz", allow_pickle=True)
    tree_def = loaded["tree_def"].item()  # Get PyTreeDef
    flat_params = [loaded[f"arr_{i}"] for i in range(len(loaded.files) - 1)]
    loaded_params = jax.tree_util.tree_unflatten(tree_def, flat_params)

    # Initialize the model with a dummy batch to get the parameter structure
    dummy_batch, _ = data_generator.sample_dummy(128)  # Minimal dummy input
    dummy_batch = np.argmax(dummy_batch, axis=-1)

    rng = jax.random.PRNGKey(0)
    model.init(rng, dummy_batch)

    return loaded_params


    # Example usage
rng = np.random.default_rng(seed=1)
data_generator = utm_data_generator(rng)
# Load the model and parameters
params: Mapping[str, Mapping[str, Array]] = load_model_params(
    data_generator, "params_markov.npz", data_generator.feature_size
)




In [11]:
for task in ORDERED_TASKS:
    print("Task:", task)
    try:
        chomsky_generator = make_chomsky_generator(
            rng, use_delimiters=False, max_input_length=20, task_str=task
        )
        regret, total_accuracy, total_final_accuracy = evaluate_transformer_decoder(
            chomsky_generator, params, data_generator
        )
        print(task, total_accuracy, total_final_accuracy)
    except Exception as e:
        print("Failed to evaluate task", task, e)


Task: even_pairs
even_pairs 0.42573318 0.42573148
Task: modular_arithmetic
Failed to evaluate task modular_arithmetic all the input array dimensions except for the concatenation axis must match exactly, but along dimension 2, the array at index 0 has size 9 and the array at index 1 has size 5
Task: parity_check
parity_check 0.41762608 0.41762835
Task: cycle_navigation
Failed to evaluate task cycle_navigation all the input array dimensions except for the concatenation axis must match exactly, but along dimension 2, the array at index 0 has size 3 and the array at index 1 has size 5
Task: stack_manipulation
Failed to evaluate task stack_manipulation all the input array dimensions except for the concatenation axis must match exactly, but along dimension 2, the array at index 0 has size 5 and the array at index 1 has size 3
Task: reverse_string
reverse_string 0.43242103 0.4310565
Task: modular_arithmetic_brackets
Failed to evaluate task modular_arithmetic_brackets all the input array dimen