In [6]:
from collections import defaultdict
import os
import inspect
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
from tracr.rasp import rasp
from tracr.compiler import compiling, rasp_to_graph
import numpy as np
from tracr.compiler.validating import validate
from rasp_generator import map_primitives, sampling, utils
import networkx as nx
from pympler import asizeof
from tracr.rasp.rasp import Map, SequenceMap, LinearSequenceMap, Select, Aggregate, Comparison, SelectorWidth, indices, tokens 

# Unseeded generator backed by PCG64: it is initialised from fresh OS entropy,
# so anything drawn from it differs between runs.
# NOTE(review): `rng` is not passed to ProgramSampler in the cells shown here —
# confirm whether the sampler draws from it before relying on it for reproducibility.
rng = np.random.Generator(np.random.PCG64())

def compile_rasp_to_model(sop: rasp.SOp, vocab=None, max_seq_len=5, compiler_bos="BOS"):
    """Compile a RASP program with the defaults used throughout this notebook.

    Thin wrapper around ``compiling.compile_rasp_to_model`` that supplies the
    vocabulary, maximum sequence length and BOS token so cells can simply call
    ``compile_rasp_to_model(program)``.

    Args:
        sop: The RASP SOp to compile.
        vocab: Set of input token values passed to the compiler. Defaults to
            ``{0, 1, 2, 3, 4}``. A new set is built on every call, so the
            default is never shared (or mutated) across calls.
        max_seq_len: Maximum sequence length passed to the compiler.
        compiler_bos: Beginning-of-sequence token passed to the compiler.

    Returns:
        Whatever ``compiling.compile_rasp_to_model`` returns for these arguments.
    """
    if vocab is None:
        # A set literal in the signature would be a single object shared by
        # every call; build a fresh one instead.
        vocab = {0, 1, 2, 3, 4}
    return compiling.compile_rasp_to_model(
        sop,
        vocab=vocab,
        max_seq_len=max_seq_len,
        compiler_bos=compiler_bos,
    )


# Probe sequences used to run sampled programs. Every token is drawn from
# {0, 1, 2, 3, 4}, and no sequence is longer than 5. This matches the default
# vocab and max_seq_len of compile_rasp_to_model.
TEST_INPUTS = [
    [1, 2, 3, 4],       # ascending, no zero
    [0, 1, 2, 3, 4],    # whole vocabulary, longest allowed input
    [1] * 4,            # constant ones
    [0] * 4,            # constant zeros
    [1, 0] * 2,         # alternating
    [2, 4, 2, 1],       # mixed with repeats
    [4] * 3 + [3],      # run of the largest token, then a change
]

In [35]:
sampler = sampling.ProgramSampler(
    disable_categorical_aggregate=False,
    validate_compilation=False,
)
errs = sampler.sample(n_sops=30)

# List every retry the sampler reported, framed by blank lines.
print("Retries encountered:", end="\n\n")
for retry_error in errs:
    print(retry_error)
print()

# Evaluate the sampled program on each probe sequence.
print("Program outputs:")
for probe in TEST_INPUTS:
    program_output = sampler.program(probe)
    print(program_output)

Retries encountered:

SamplingError("Could not sample categorical Aggregate with valid output domain (Maximum retries reached). This because the sampler couldn't find a selector with width 1, and other sampled selectors don't result in an output domain that is a subset of the input domain.") (tried to sample categorical_aggregate)
EmptyScopeError('Filter failed. No SOps of type float in scope.') (tried to sample linear_sequence_map)
EmptyScopeError('Filter failed. No SOps of type float in scope.') (tried to sample linear_sequence_map)
SamplingError("Could not sample categorical Aggregate with valid output domain (Maximum retries reached). This because the sampler couldn't find a selector with width 1, and other sampled selectors don't result in an output domain that is a subset of the input domain.") (tried to sample categorical_aggregate)

Program outputs:
[0, False, 0.0, 0.0]
[False, 0.0, 0.0, 0.0, 0.0]
[0, 0.0, 0.0, 0.0]
[0.0, 0.0, 0.0, 0.0]
[0.0, 0.0, 0.0, 0.0]
[0, False, 0.0, 0.0]

In [36]:
# NOTE(review): this rebinds `sampler` to a freshly sampled program. Any later
# use of `sampler.program` refers to this program, not the one whose outputs
# were printed above.
sampler = sampling.ProgramSampler(
    disable_categorical_aggregate=False,
    validate_compilation=False,
)
errs = sampler.sample(n_sops=30)

In [37]:
# Compile the current `sampler.program`, relying on the wrapper's default
# vocab, max_seq_len and BOS token.
model = compile_rasp_to_model(sop=sampler.program)

In [20]:
# Print the selected SOp's program twice: first as plain RASP, then annotated
# with the values it produces on the first test input.
k = 1  # k = 1 selects the last entry of sampler.sops
chosen_sop = sampler.sops[-k]
utils.print_program(chosen_sop)
print(end="\n\n")
utils.print_program(chosen_sop, TEST_INPUTS[0])

map_21 = Map(lambda x: x + 3, tokens)    # type: categorical
sequence_map_22 = SequenceMap(lambda x, y: x*y, tokens, indices)    # type: categorical
select_20 = Select(map_21, sequence_map_22, predicate=Comparison.LEQ)
selector_width_19 = SelectorWidth(select_20)    # type: categorical
select_18 = Select(selector_width_19, tokens, predicate=Comparison.LT)
selector_width_17 = SelectorWidth(select_18)    # type: categorical
map_15 = Map(lambda x: x + 4, selector_width_17)    # type: categorical
select_16 = Select(selector_width_17, tokens, predicate=Comparison.GEQ)
selector_width_14 = SelectorWidth(select_16)    # type: categorical
sequence_map_13 = SequenceMap(lambda x, y: x*y, selector_width_14, map_15)    # type: categorical
select_12 = Select(tokens, sequence_map_13, predicate=Comparison.EQ)
selector_width_11 = SelectorWidth(select_12)    # type: categorical


map_21 = Map(lambda x: x + 3, tokens)    # output: [4, 5, 6, 7]
sequence_map_22 = SequenceMap(lambda x, y: x*y, tokens, indic