In [1]:
import os
# Ask XLA's Python client not to preallocate device memory (see the JAX GPU
# memory allocation docs). Set ahead of `import jax` below so the variable is
# already in the environment when JAX starts up.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
from collections import defaultdict 
import jax
import flax
import chex
from jaxtyping import ArrayLike
from typing import Union, TypeVar
import numpy as np
import matplotlib.pyplot as plt

from tracr.compiler.validating import validate
from tracr.rasp.rasp import Map, SequenceMap, LinearSequenceMap, Select, Aggregate, Comparison, SelectorWidth, indices, tokens 
from tracr.rasp import rasp
from tracr.compiler import compiling
from tracr.compiler.assemble import AssembledTransformerModel
from tracr.compiler.craft_model_to_transformer import NoTokensError
from tracr.compiler.basis_inference import InvalidValueSetError

from rasp_tokenizer import tokenizer
from rasp_tokenizer import vocab
from rasp_tokenizer.compiling import COMPILER_BOS
from rasp_generator.utils import sample_test_input
from rasp_generator import sampling, utils, map_primitives
from rasp_tokenizer.utils import RaspFlatDatapoint


# NumPy random generator with a fixed seed (0) so that any sampling drawn from
# it is repeatable. NOTE(review): not referenced in the cells visible here.
rng = np.random.default_rng(0)

In [2]:
from time import time
import os
import pprint
import json
import pickle

import jax
import jax.numpy as jnp
from jax.typing import ArrayLike
import flax.linen as nn
import optax
import chex
from typing import Optional
from dataclasses import asdict
import numpy as np
import wandb
import argparse
from etils import epath
from tqdm import tqdm
import orbax.checkpoint
from etils import epath

from nn_utils import schedules
from meta_transformer import preprocessing, module_path, on_cluster, output_dir, interactive
from meta_transformer.meta_model import DecompilerModel, mup_adamw
from meta_transformer.train import Updater, Logger
from meta_transformer.logger_config import setup_logger
from meta_transformer.data import data_iterator
import meta_transformer.utils

from rasp_tokenizer import paths
from rasp_tokenizer import vocab
from rasp_tokenizer.utils import RaspFlatDatapoint
# Notebook-level logger; load_data() reports the dataset path through it.
logger = setup_logger(__name__)

In [3]:
def load_data():
    """Read the pickled program dataset from the project's data directory.

    The file location is logged before it is opened. Returns whatever object
    is stored in ``5000programs.pkl``.
    """
    datapath = paths.data_dir / "5000programs.pkl"
    logger.info(f"Loading train/val data from {datapath}.")
    # NOTE(review): unpickling can execute arbitrary code; only point this at
    # files produced by this project.
    with open(datapath, "rb") as handle:
        loaded = pickle.load(handle)
    return loaded


def pad_to(x: np.ndarray, max_len: int, pad_value: int = 0):
    """Right-pad a 1D array with a constant up to a given length. Not jittable.

    Args:
        x: 1D array-like; it is converted with `np.array` first.
        max_len: length of the returned array.
        pad_value: constant written into the appended positions.

    Returns:
        A numpy array of length `max_len`: the entries of `x` followed by
        `max_len - len(x)` copies of `pad_value`.

    Raises:
        AssertionError: if `x` is not rank 1, or if it is longer than
            `max_len`.
    """
    x = np.array(x)
    # Validate the rank before using len(x): len() raises TypeError on a 0-d
    # array and only reports the leading axis of a multi-dimensional one, so
    # checking the length first would hide the real problem.
    chex.assert_rank(x, 1)
    assert len(x) <= max_len, (
        f"Cannot pad array of length {len(x)} to shorter length {max_len}.")
    return np.pad(x, (0, max_len - len(x)), constant_values=pad_value)


def process_single_datapoint(
        x: RaspFlatDatapoint,
        d_model: int,
        max_program_len: int = 32,
        max_weights_len: int = 8192,
    ):
    """Turn one datapoint into fixed-size arrays for the model.

    The program tokens are padded with `vocab.pad_id` up to
    `max_program_len`. The weights are padded up to `max_weights_len` and
    then passed through `preprocessing.pad_and_chunk` with `d_model`.

    Returns:
        A dict with keys "program" and "weights".

    Raises:
        ValueError: if the program or the weights exceed their maximum length.
    """
    # Guard clauses: reject oversized inputs before any padding happens.
    if len(x.program) > max_program_len:
        raise ValueError(f"Program length ({len(x.program)}) exceeds "
                         f"max program length ({max_program_len}).")
    if len(x.weights) > max_weights_len:
        raise ValueError(f"Weights length ({len(x.weights)}) exceeds "
                         f"max weights length ({max_weights_len}).")

    padded_program = pad_to(x.program, max_program_len, pad_value=vocab.pad_id)
    flat_weights = pad_to(x.weights, max_weights_len)
    # (n_chunks, d_model)
    chunked_weights = preprocessing.pad_and_chunk(flat_weights, d_model)
    return dict(program=padded_program, weights=chunked_weights)


def process_data(
        data: set[RaspFlatDatapoint],
        d_model: int,
        max_program_len: int = 32,
        max_weights_len: int = 8192,
    ):
    """Process every datapoint and stack the results into two arrays.

    Each element of `data` goes through `process_single_datapoint`; the
    per-datapoint "program" and "weights" arrays are then stacked along a new
    leading axis.

    Returns:
        A dict with keys "program" and "weights". The shapes are asserted to
        be (n, max_program_len) and
        (n, max_weights_len // d_model, d_model), where n = len(data).
    """
    n = len(data)
    processed = [
        process_single_datapoint(x, d_model, max_program_len, max_weights_len)
        for x in data
    ]
    out = {
        key: np.stack([item[key] for item in processed])
        for key in ("program", "weights")
    }
    chex.assert_shape(out["program"], (n, max_program_len))
    chex.assert_shape(out["weights"], (n, max_weights_len//d_model, d_model))
    return out

In [4]:
# Load the pickled datapoints and convert them into padded program / weight
# arrays, chunking the weights with d_model=128.
data = load_data()
data = process_data(data=data, d_model=128)

# Standardize the weights using one mean and one standard deviation taken
# over the entire weights array.
w_mean = data["weights"].mean()
w_std = data["weights"].std()
data["weights"] = (data["weights"] - w_mean) / w_std

2024-01-26 14:40:37 [INFO]: Loading train/val data from /home/lauro/projects/meta-models/rasp-generator/scripts/data/5000programs.pkl.


In [6]:
# Shape of the processed weights: (n_datapoints, max_weights_len // d_model, d_model).
data['weights'].shape

(10589, 64, 128)

In [7]:
# Sanity check: number of NaN entries in the normalized weights.
np.isnan(data['weights']).sum()

0

In [8]:
# Sanity check: number of infinite entries in the normalized weights
# (the normalization divides by w_std).
np.isinf(data['weights']).sum()

0

In [9]:
# Sanity check: number of NaN entries in the padded program-token array.
np.isnan(data['program']).sum()

0

In [10]:
# Sanity check: number of infinite entries in the padded program-token array.
np.isinf(data['program']).sum()

0