In [1]:
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
from tracr.rasp import rasp
from tracr.compiler import compiling
import numpy as np
from tracr.compiler.validating import validate
from typing import Union, TypeVar
from tracr.rasp.rasp import Map, SequenceMap, LinearSequenceMap, Select, Aggregate, Comparison, SelectorWidth, indices, tokens 
from rasp_generator import sampling, utils, map_primitives


def make_length():
    """Return a RASP SOp giving the sequence length at every position.

    Every query attends to every key (``Comparison.TRUE``), so counting the
    selected keys with ``SelectorWidth`` yields the number of positions.
    """
    return rasp.SelectorWidth(
        rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.TRUE)
    )


def compile_rasp_to_model(x: rasp.SOp,
                          vocab=None,
                          max_seq_len=5,
                          compiler_bos="BOS",
                          **compile_kwargs):
    """Compile a RASP program with this notebook's default settings.

    Thin wrapper around ``tracr.compiler.compiling.compile_rasp_to_model``.

    Args:
        x: RASP SOp to compile.
        vocab: Set of input token values. Defaults to ``{1, 2, 3, 4}``. A fresh
            set is built on every call, so the default object is never shared
            between calls (or between compiled models).
        max_seq_len: Maximum sequence length, forwarded unchanged.
        compiler_bos: BOS token value, forwarded unchanged.
        **compile_kwargs: Any further keyword arguments accepted by tracr's
            compiler (e.g. ``mlp_exactness``), forwarded unchanged.

    Returns:
        Whatever tracr's ``compile_rasp_to_model`` returns (the compiled model).
    """
    # Avoid a mutable default argument: the default set would otherwise be
    # created once at definition time and reused by every call.
    if vocab is None:
        vocab = {1, 2, 3, 4}
    return compiling.compile_rasp_to_model(
        x,
        vocab=vocab,
        max_seq_len=max_seq_len,
        compiler_bos=compiler_bos,
        **compile_kwargs,
    )
 
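# Apply via:
# out = model.apply([compiler_bos] + TEST_INPUT)
# return out.decoded
# (Pass an explicit list such as TEST_INPUT rather than iterating over the
# `vocab` set, whose iteration order is not guaranteed.)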


# Example input sequence (BOS not included); same values as the default vocab.
TEST_INPUT = [1, 2, 3, 4]

In [2]:
from typing import Set

from tracr.compiler import assemble
from tracr.compiler import basis_inference
from tracr.compiler import craft_graph_to_model
from tracr.compiler import craft_model_to_transformer
from tracr.compiler import expr_to_craft_graph
from tracr.compiler import rasp_to_graph
from tracr.compiler import validating
from tracr.craft import bases
from tracr.rasp import rasp


# Program under study: each query position (its index) selects the key
# positions whose squared token compares GEQ against that index, then the raw
# tokens are aggregated over the selected positions.
# NOTE(review): tracr evaluates Select predicates as `key <op> query`, so this
# presumably selects positions with token**2 >= index — confirm.
# NOTE(review): Aggregate uses tracr's default encoding here (categorical, as
# far as I know); a GEQ selector can pick several positions per query, which a
# categorical aggregate is not designed for — verify this is intended.
squared_tokens = rasp.Map(lambda x: x**2, rasp.tokens)
program = rasp.Aggregate(
    rasp.Select(squared_tokens, rasp.indices, rasp.Comparison.GEQ),
    rasp.tokens,
)

# Compilation settings for the manual pipeline below.
vocab = {1, 2, 3, 4}   # input token values (same as compile_rasp_to_model's default)
max_seq_len = 5        # maximum sequence length passed to basis inference
compiler_bos = "BOS"   # BOS token value; used to build the BOS basis direction
mlp_exactness = 100    # forwarded to add_craft_components_to_rasp_graph


# Extract the expression graph of `program`, together with its source and
# sink nodes, for the manual compilation steps that follow.
extracted = rasp_to_graph.extract_rasp_graph(program)
graph = extracted.graph
sources = extracted.sources
sink = extracted.sink

# Infer value sets / bases for the graph from the vocabulary and maximum
# sequence length. The return value is discarded, so this presumably
# annotates `graph` in place — confirm against tracr's basis_inference.
basis_inference.infer_bases(graph, sink, vocab, max_seq_len)

# Attach craft components to the nodes of `graph`. The BOS direction is the
# `compiler_bos` value within the `tokens` basis.
bos_direction = bases.BasisDirection(rasp.tokens.label, compiler_bos)
expr_to_craft_graph.add_craft_components_to_rasp_graph(
    graph, bos_dir=bos_direction, mlp_exactness=mlp_exactness
)