In [8]:
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
from tracr.rasp import rasp
from tracr.compiler import compiling
import numpy as np
from tracr.compiler.validating import validate
from typing import Union, TypeVar
from tracr.rasp.rasp import Map, SequenceMap, LinearSequenceMap, Select, Aggregate, Comparison, SelectorWidth, indices, tokens 
from rasp_generator import sampling, utils, map_primitives


def make_length():
    all_true_selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.TRUE)
    return rasp.SelectorWidth(all_true_selector)


def compile_rasp_to_model(x: rasp.SOp,
                 vocab={1, 2, 3, 4}, 
                 max_seq_len=5, 
                 compiler_bos="BOS"):
    return compiling.compile_rasp_to_model(
        x,
        vocab=vocab,
        max_seq_len=max_seq_len,
        compiler_bos=compiler_bos,
    )
 
# Apply via:
# out = model.apply([compiler_bos] + [v for v in vocab])
# return out.decoded


TEST_INPUT = [1,2,3,4]

In [9]:
from tracr.rasp import rasp
from tracr.compiler import compiling, validating

sel = rasp.Select(rasp.indices, rasp.tokens, rasp.Comparison.EQ)
sop = rasp.Aggregate(sel, rasp.indices)
program = rasp.Aggregate(sel, sop)

In [10]:
model = compiling.compile_rasp_to_model(program, vocab={1,2,3,4}, max_seq_len=5, compiler_bos="BOS")
compiled_output = model.apply(["BOS", 1, 2, 3, 4]).decoded
rasp_output = program([1, 2, 3, 4])


# The output of the compiled model does not match the output of the RASP program:
print(rasp_output)  # [2.0, 3.0, None, None]
print(compiled_output) # ['BOS', 2, 3, 0, 1]

# The validator doesn't catch the error:
print(validating.validate(program, [1, 2, 3, 4])) # []

[2, 3, None, None]
['BOS', 2, 3, 0, 1]
[TracrUnsupportedExpr(expr=<tracr.rasp.rasp.Aggregate object at 0x7f17dc0072b0>, reason='Categorical aggregate does not support Selectors with width > 1 that require aggregation (eg. averaging).')]


In [11]:
from typing import Set

from tracr.compiler import assemble
from tracr.compiler import basis_inference
from tracr.compiler import craft_graph_to_model
from tracr.compiler import craft_model_to_transformer
from tracr.compiler import expr_to_craft_graph
from tracr.compiler import rasp_to_graph
from tracr.compiler import validating
from tracr.craft import bases
from tracr.rasp import rasp


#program = rasp.Aggregate(
#    rasp.Select(
#        rasp.Map(lambda x: x**2, rasp.tokens),
#        rasp.indices,
#        rasp.Comparison.GEQ,
#    ),
#    rasp.tokens,
#)

vocab = {1, 2, 3, 4}
max_seq_len = 5
compiler_bos = "BOS"
mlp_exactness = 100


extracted = rasp_to_graph.extract_rasp_graph(program)
graph, sources, sink = extracted.graph, extracted.sources, extracted.sink

basis_inference.infer_bases(
    graph,
    sink,
    vocab,
    max_seq_len,
)

expr_to_craft_graph.add_craft_components_to_rasp_graph(
    graph,
    bos_dir=bases.BasisDirection(rasp.tokens.label, compiler_bos),
    mlp_exactness=mlp_exactness,
)

craft_model = craft_graph_to_model.craft_graph_to_model(graph, sources)


In [24]:
test_input_vector

VectorInBasis(basis_directions=[BasisDirection(name='indices', value=0), BasisDirection(name='indices', value=1), BasisDirection(name='indices', value=2), BasisDirection(name='indices', value=3), BasisDirection(name='indices', value=4), BasisDirection(name='rasp_to_craft_integration_test_ONE', value=None), BasisDirection(name='rasp_to_transformer_integration_test_BOS', value=None), BasisDirection(name='tokens', value=1), BasisDirection(name='tokens', value=2), BasisDirection(name='tokens', value=3), BasisDirection(name='tokens', value=4)], magnitudes=array([[0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0.],
       [0., 1., 0., 0., 0., 1., 0., 0., 1., 0., 0.],
       [0., 0., 1., 0., 0., 1., 0., 0., 0., 1., 0.],
       [0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1.]]))

In [25]:
from tracr.compiler.rasp_to_craft_integration_test import _make_input_space, nodes, _embed_input, _embed_output

categorical_output = rasp.is_categorical(program)
assert categorical_output
test_input = [1, 2, 3, 4]



input_space = _make_input_space(vocab, max_seq_len)
output_space = bases.VectorSpaceWithBasis(
    extracted.sink[nodes.OUTPUT_BASIS])

test_input_vector = _embed_input(test_input, input_space)
test_output = craft_model.apply(test_input_vector).project(output_space)

In [28]:
expected_output = _embed_output(
    output_seq=rasp_output,
    output_space=output_space,
    categorical_output=True)

In [29]:
test_output.magnitudes

array([[8.00000000e-02, 2.80000000e-01, 2.80000000e-01, 2.80000000e-01,
        0.00000000e+00],
       [5.20810637e-44, 8.92818234e-44, 1.00000000e+00, 8.92818234e-44,
        0.00000000e+00],
       [5.20810637e-44, 8.92818234e-44, 8.92818234e-44, 1.00000000e+00,
        0.00000000e+00],
       [2.00000000e-01, 2.00000000e-01, 2.00000000e-01, 2.00000000e-01,
        0.00000000e+00],
       [8.00000000e-02, 2.80000000e-01, 2.80000000e-01, 2.80000000e-01,
        0.00000000e+00]])