In [6]:
import numpy as np
import haiku as hk
import jax
import json

from collections import defaultdict
from haiku._src.typing import Mapping
from jax._src.basearray import Array

from helpers import (
    evaluate_transformer_decoder,
    make_chomsky_generator,
    make_model,
    utm_data_generator,
)

In [7]:
seed = 1

# Evaluation tasks grouped by the Chomsky-hierarchy level they were listed
# under; ORDERED_TASKS concatenates the groups in that order.
REGULAR_TASKS = [
    "even_pairs",
    "modular_arithmetic",
    "parity_check",
    "cycle_navigation",
]
CONTEXT_FREE_TASKS = [
    "stack_manipulation",
    "reverse_string",
    "modular_arithmetic_brackets",
    "solve_equation",
]
CONTEXT_SENSITIVE_TASKS = [
    "duplicate_string",
    "missing_duplicate_string",
    "odds_first",
    "binary_addition",
    "binary_multiplication",
    "compute_sqrt",
    "bucket_sort",
]

ORDERED_TASKS = [*REGULAR_TASKS, *CONTEXT_FREE_TASKS, *CONTEXT_SENSITIVE_TASKS]

In [8]:
def load_model_params(
    data_generator, params_path: str, batch_size: int = 128
) -> hk.Params:
    """Load saved model parameters from an ``.npz`` checkpoint.

    The checkpoint is read as flattened parameter leaves stored under
    ``arr_0 .. arr_{n-1}`` plus a pickled ``PyTreeDef`` stored under
    ``tree_def``; the leaves are re-assembled with ``tree_unflatten``.

    Args:
        data_generator: Generator passed to ``make_model`` to build the model,
            and used to draw a dummy batch via ``sample_dummy(batch_size)``.
        params_path: Path to the saved ``.npz`` file containing model parameters.
        batch_size: Number of dummy examples used for the ``model.init`` smoke test.

    Returns:
        The loaded parameter pytree. Only the parameters are returned; the
        model object built here is not.

    Raises:
        KeyError: If the checkpoint has no ``tree_def`` entry.
    """
    # Create the same model configuration as used in training.
    model = make_model(data_generator)

    # NOTE(review): allow_pickle=True is required to restore the pickled
    # PyTreeDef, but unpickling can execute arbitrary code -- only load
    # checkpoints from trusted sources.
    # The context manager closes the underlying archive once the arrays have
    # been read into memory (the original left the file handle open).
    with np.load(params_path, allow_pickle=True) as loaded:
        tree_def = loaded["tree_def"].item()  # 0-d object array -> PyTreeDef
        # Count the leaf entries by name instead of assuming "tree_def" is the
        # only extra file, and index numerically so arr_10 follows arr_9.
        num_leaves = sum(1 for name in loaded.files if name.startswith("arr_"))
        flat_params = [loaded[f"arr_{i}"] for i in range(num_leaves)]
    loaded_params = jax.tree_util.tree_unflatten(tree_def, flat_params)

    # Smoke test: run model.init on a dummy batch, presumably so a
    # model/generator mismatch surfaces here. The freshly initialized
    # parameters are discarded; the loaded ones are returned.
    dummy_batch, _ = data_generator.sample_dummy(batch_size)
    # argmax over the last axis -> integer ids (assumes the dummy batch is
    # one-hot encoded on its last axis; TODO confirm against sample_dummy).
    dummy_batch = np.argmax(dummy_batch, axis=-1)
    model.init(jax.random.PRNGKey(0), dummy_batch)

    return loaded_params


# Checkpoints to evaluate: every (size, variant) pair, largest model first and
# the markov variant before the original one within each size. Add
# "random_initialized" to MODEL_VARIANTS to include that baseline as well.
MODEL_SIZES = ("large", "medium", "small")
MODEL_VARIANTS = ("markov", "original")
model_paths = [
    f"artifacts/params_{variant}_transformer_{size}.npz"
    for size in MODEL_SIZES
    for variant in MODEL_VARIANTS
]

rng = np.random.default_rng(seed=seed)
utm_generator = utm_data_generator(rng)

# Restore each checkpoint's parameters, keyed by its path.
model_params: Mapping[str, Mapping[str, Mapping[str, Array]]] = {}
for model_path in model_paths:
    params: Mapping[str, Mapping[str, Array]] = load_model_params(
        utm_generator, model_path
    )
    model_params[model_path] = params

In [9]:
def _first_label_in(path, labels):
    """Return the first entry of ``labels`` that occurs in ``path``, else "unknown"."""
    return next((label for label in labels if label in path), "unknown")


def get_size_from_model_path_fn(path):
    """Classify a checkpoint path by model size: "large", "medium", "small" or "unknown"."""
    return _first_label_in(path, ("large", "medium", "small"))


def get_markov_from_model_path_fn(path):
    """Classify a checkpoint path by training variant: "markov", "original" or "unknown"."""
    return _first_label_in(path, ("markov", "original"))

# {model_path: {task: {"regret": ..., "total_accuracy": ..., "total_final_accuracy": ...}}}
model_task_results = defaultdict(dict)
# {model_path: {task: repr(exception)}} for every (model, task) pair whose
# evaluation raised; those pairs are absent from model_task_results.
model_task_failures = defaultdict(dict)

NUM_EVAL_BATCHES = 10  # num_batches passed to evaluate_transformer_decoder per task

for model_str, params in model_params.items():
    print("Model: ", model_str)
    # The size label depends only on the checkpoint path, so resolve it once per model.
    model_size = get_size_from_model_path_fn(model_str)
    for task in ORDERED_TASKS:
        try:
            # NOTE(review): an earlier comment mentioned max_input_length = 256, but no
            # such argument is passed here; make_chomsky_generator's own default applies.
            chomsky_generator = make_chomsky_generator(rng, task_str=task)
            regret, total_accuracy, total_final_accuracy = evaluate_transformer_decoder(
                chomsky_generator,
                params,
                utm_generator,
                num_batches=NUM_EVAL_BATCHES,
                size=model_size,
            )
        except Exception as e:
            # Best-effort: keep evaluating the remaining tasks, but record the
            # failure so gaps in model_task_results are visible afterwards.
            model_task_failures[model_str][task] = repr(e)
            print("Failed to evaluate task", task, e)
            continue
        model_task_results[model_str][task] = {
            "regret": regret,
            "total_accuracy": total_accuracy,
            "total_final_accuracy": total_final_accuracy,
        }
        print(task, total_accuracy, total_final_accuracy)

if model_task_failures:
    n_failed = sum(len(failed) for failed in model_task_failures.values())
    print(f"{n_failed} (model, task) evaluations failed; see model_task_failures.")


Model:  artifacts/params_markov_transformer_large.npz
Chomsky Task:  even_pairs
even_pairs 0.48771706 0.48606926
Model:  artifacts/params_markov_transformer_large.npz
Chomsky Task:  modular_arithmetic
modular_arithmetic 0.10822785 0.108205654
Model:  artifacts/params_markov_transformer_large.npz
Chomsky Task:  parity_check
parity_check 0.49603423 0.5011854
Model:  artifacts/params_markov_transformer_large.npz
Chomsky Task:  cycle_navigation
cycle_navigation 0.31283465 0.31278235
Model:  artifacts/params_markov_transformer_large.npz
Chomsky Task:  stack_manipulation
stack_manipulation 0.4173534 0.44632658
Model:  artifacts/params_markov_transformer_large.npz
Chomsky Task:  reverse_string
reverse_string 0.48629132 0.4766636
Model:  artifacts/params_markov_transformer_large.npz
Chomsky Task:  modular_arithmetic_brackets
modular_arithmetic_brackets 0.106273636 0.10625094
Model:  artifacts/params_markov_transformer_large.npz
Chomsky Task:  solve_equation
solve_equation 0.14586955 0.14582935

In [10]:
def _as_json_floats(results_by_model):
    """Copy the nested {model: {task: {metric: value}}} mapping with every value cast to float.

    Casting to a plain Python float makes the structure serializable by ``json``.
    Key insertion order is preserved at every level.
    """
    converted = {}
    for model_name, task_metrics in results_by_model.items():
        converted[model_name] = {
            task: {metric: float(value) for metric, value in metrics.items()}
            for task, metrics in task_metrics.items()
        }
    return converted


results_dict = _as_json_floats(model_task_results)

output_file = "artifacts/chomsky_results.json"

with open(output_file, 'w') as f:
    json.dump(results_dict, f, indent=4)

print(f"Results saved to {output_file}")

Results saved to artifacts/chomsky_results.json
