In [1]:
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
from pathlib import Path
import time
import sys
from collections import defaultdict 
from typing import Union, TypeVar
import h5py
import traceback

import pandas as pd
import tqdm
import jax
import flax
import chex
from jaxtyping import ArrayLike
import numpy as np
import matplotlib.pyplot as plt
import jax.numpy as jnp
import haiku as hk
import flax.linen as nn
import optax

from tracr.compiler.validating import validate
from tracr.rasp.rasp import Map, SequenceMap, LinearSequenceMap, Select, Aggregate, Comparison, SelectorWidth, indices, tokens 
from tracr.rasp import rasp
from tracr.compiler import compiling
from tracr.compiler.assemble import AssembledTransformerModel
from tracr.compiler.craft_model_to_transformer import NoTokensError
from tracr.compiler.basis_inference import InvalidValueSetError
from tracr.compiler import craft_graph_to_model
from tracr.compiler import rasp_to_graph
from tracr.compiler import lib as tracr_lib
from tracr.compiler import assemble
from tracr.transformer import model

from decompile_tracr.dataset import lib
from decompile_tracr.dataset import data_utils
from decompile_tracr.dataset import dataloading
from decompile_tracr.dataset import config
from decompile_tracr.dataset import compile as comp
from decompile_tracr.tokenizing import tokenizer
from decompile_tracr.tokenizing import vocab
from decompile_tracr.sampling import sampling
from decompile_tracr.sampling import rasp_utils
from decompile_tracr.sampling.map_primitives import FunctionWithRepr
from decompile_tracr.tokenizing.str_to_rasp import split_list
from decompile_tracr.dataset.compile import get_weights
from decompile_tracr.training.autoencoder import Autoencoder, get_loss_fn

from metamodels_for_rasp.model import TransformerConfig, AddPositionEmbs, Encoder1DBlock, MLPBlock
from metamodels_for_rasp.train import Updater


def _compile(program, vocab=None, max_seq_len=5):
    """Compile a RASP program into an assembled tracr model.

    Args:
        program: the RASP program to compile.
        vocab: set of input token values to compile for. Defaults to
            ``{0, 1, 2, 3, 4}``; the default set is built fresh on every call
            so no mutable default is shared between calls.
        max_seq_len: maximum sequence length handed to the compiler.
            Defaults to 5.

    Returns:
        The result of ``compiling.compile_rasp_to_model`` for these settings.
    """
    if vocab is None:
        vocab = set(range(5))
    return compiling.compile_rasp_to_model(
        program,
        vocab=vocab,
        max_seq_len=max_seq_len,
    )


# One fixed seed shared by the NumPy generator and the JAX PRNG key.
SEED = 0
rng = np.random.default_rng(SEED)
key = jax.random.key(SEED)

2024-05-06 14:55:19.669892: W external/xla/xla/service/gpu/nvptx_compiler.cc:718] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.3.107). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


Assume m is shorthand for an AssembledTransformerModel, e.g.
```
m = _compile(program)
```
Then the following all commute:
* m.apply: inputs --> out
* m.forward: tokens --> out
* m.input_encoder.encode: inputs --> tokens
* compiled_model.embed: tokens --> emb
* transformer: emb --> out

So overall we have
* inputs --> tokens --> emb --> out
* inputs --> out (via m.apply)
* tokens --> out (via m.forward)
* emb --> out (via transformer)

ETA: that's only approximately true. The methods apply, forward, and transformer have different output types:
* m.apply returns AssembledTransformerModelOutput
* m.forward returns CompiledTransformerModelOutput
* transformer returns TransformerOutput (included as attribute in the other two)

In [2]:
# Sample a program, round-trip it through the tokenizer, then compile it.
# NOTE(review): the meaning of the `5` passed to `sampling.sample` is not
# visible here — confirm against decompile_tracr.sampling.
sampled_program = sampling.sample(rng, 5)
p = tokenizer.detokenize(tokenizer.tokenize(sampled_program))
m = _compile(p)

# Model width: size of the last axis of the token-embedding matrix.
embedding_matrix = m.params['token_embed']['embeddings']
d_model = embedding_matrix.shape[-1]
print("d_model:", d_model)

d_model: 28


In [3]:
# Encode raw inputs (the 'compiler_bos' marker followed by the value 4) with
# the model's input encoder; the cell displays the resulting token ids.
m.input_encoder.encode(['compiler_bos', 4])

[5, 4]

In [4]:
@hk.without_apply_rng
@hk.transform
def embed(tokens):
    """Embed `tokens` using the compiled model of the global `m`.

    Returns whatever `compiled_model.embed(tokens)` produces.
    """
    return m.get_compiled_model().embed(tokens)


# Embed four example token ids with the compiled model's parameters and
# display the shape of the result.
token_ids = np.array([1, 2, 3, 0])
e = embed.apply(m.params, token_ids)
e.shape

(4, 28)

In [9]:
# Run the assembled model directly on raw (un-encoded) inputs and display the
# type of what it returns.
raw_inputs = ["compiler_bos", 1, 2, 3]
out = m.apply(raw_inputs)
type(out)

tracr.compiler.assemble.AssembledTransformerModelOutput

In [18]:
# Types of the fields held by the model output, in order.
list(map(type, out.values()))

[tracr.transformer.model.TransformerOutput, jaxlib.xla_extension.ArrayImpl]

In [40]:
@hk.without_apply_rng
@hk.transform
def forward(tokens: ArrayLike):
    """Call the compiled model of the global `m` on `tokens`, dropout disabled.

    `tokens` must be an integer array.
    """
    run_model = m.get_compiled_model()
    return run_model(tokens, use_dropout=False)


# Feed a 5x5 batch of all-ones integer tokens through `forward`, then inspect
# the structure of the returned output.
dummy_tokens = np.ones((5, 5), dtype=int)
out = forward.apply(m.params, dummy_tokens)

print(out.keys())
# `end="\n\n"` emits the same blank separator line a bare `print()` would.
print(out.transformer_output.keys(), end="\n\n")
print(type(out))
print([type(value) for value in out.values()])

dict_keys(['transformer_output', 'unembedded_output'])
dict_keys(['layer_outputs', 'residuals', 'attn_logits', 'output', 'input_embeddings'])

<class 'tracr.transformer.model.CompiledTransformerModelOutput'>
[<class 'tracr.transformer.model.TransformerOutput'>, <class 'jaxlib.xla_extension.ArrayImpl'>]


In [21]:
@hk.without_apply_rng
@hk.transform
def transformer(embeddings: ArrayLike):
    """Run only the `transformer` component of the global `m`'s compiled model.

    `embeddings` must be a float array of shape (batch_size, seq_len, d_model).
    An all-ones array over the leading (batch_size, seq_len) axes is passed as
    the second positional argument (presumably the attention mask — confirm
    against tracr), and dropout is disabled.
    """
    stack = m.get_compiled_model().transformer
    all_ones = jnp.ones(embeddings.shape[:-1])
    return stack(embeddings, all_ones, use_dropout=False)


seq = 4  # sequence length of the dummy input
# Dummy all-ones float embeddings of shape (1, seq, d_model).
out = transformer.apply(m.params, np.ones((1, seq, d_model), dtype=float))
# A notebook cell only displays its final expression, so the bare
# `out.output.shape` that used to sit here was a silent no-op and has been
# removed. The displayed value (the output's type) is unchanged.
type(out)

tracr.transformer.model.TransformerOutput

In [31]:
# Shape of the `output` field of the most recently assigned `out`. The
# recorded result (1, 4, 28) equals the shape of the dummy embeddings passed
# to `transformer` (batch=1, seq=4, d_model=28).
out.output.shape

(1, 4, 28)

In [34]:
# Stack the `residuals` entries into one array to read off their joint shape;
# the leading axis indexes the entries of `out.residuals`.
np.asarray(out.residuals).shape

(2, 1, 4, 28)