In [8]:
import os
import inspect
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
from tracr.rasp import rasp
from tracr.compiler import compiling, rasp_to_graph
import numpy as np
from tracr.compiler.validating import validate
from typing import Union, TypeVar
from rasp_generator import map_primitives, sampling, utils
import networkx as nx

from tracr.rasp.rasp import Map, SequenceMap, LinearSequenceMap, Select, Aggregate, Comparison, SelectorWidth, indices, tokens 


def compile_rasp_to_model(sop: "rasp.SOp", vocab=None, max_seq_len=5, compiler_bos="BOS", **compile_kwargs):
    """Compile a RASP program with tracr, using notebook-friendly defaults.

    Thin wrapper around ``compiling.compile_rasp_to_model`` from
    ``tracr.compiler``.

    Args:
        sop: The RASP SOp to compile.
        vocab: Set of input token values passed to the compiler. Defaults to
            ``{1, 2, 3, 4}``. A fresh set is created on every call, so the
            default is never shared between calls.
        max_seq_len: Maximum sequence length passed to the compiler.
        compiler_bos: Beginning-of-sequence token passed to the compiler.
        **compile_kwargs: Any extra keyword arguments, forwarded unchanged to
            ``compiling.compile_rasp_to_model``.

    Returns:
        Whatever ``compiling.compile_rasp_to_model`` returns.
    """
    if vocab is None:
        # A set literal as the default argument would be one object shared by
        # every call; build a new one per call instead.
        vocab = {1, 2, 3, 4}
    return compiling.compile_rasp_to_model(
        sop,
        vocab=vocab,
        max_seq_len=max_seq_len,
        compiler_bos=compiler_bos,
        **compile_kwargs,
    )


# Canonical probe input shared by the cells below: the tokens 1..4, i.e. the
# same values as compile_rasp_to_model's default vocab.
TEST_INPUT = list(range(1, 5))

# Program Generation

In [9]:
# Remaining problems
# - sometimes categorical Aggregate is hard to sample (reaches max retries)
# - sometimes a sampled program doesn't depend on rasp.tokens
# - sometimes programs are trivial (e.g. output is all Nones)
# - I suspect the output is sometimes constant across inputs. TODO: check with multiple inputs


# TODOs
# - remove SOps that are all (or mostly) None
# - maybe remove / downweight constant SOps?

In [10]:
# Sample a RASP program. validate_compilation=True presumably makes the sampler
# check that the program compiles. TODO confirm in rasp_generator.sampling.
# NOTE(review): nothing is seeded here, so a fresh run may sample a different
# program and the outputs below will change.
sampler = sampling.ProgramSampler(validate_compilation=True)
sampler.sample()
# Evaluate on the shared TEST_INPUT rather than a duplicated literal, so all
# cells probe the program with the same input.
print(sampler.output(TEST_INPUT))

[0, 0, 4, 4]


In [11]:
# Show the RASP source of a single intermediate SOp: the entry at index 2 of
# sampler.sops.
# NOTE(review): this assumes the sampled program has at least 3 SOps; a fresh
# sample with fewer would fail on the indexing below.
chosen_sop = sampler.sops[2]
utils.print_program(chosen_sop)

select_6 = Select(tokens, tokens, predicate=Comparison.LEQ)
selector_width_5 = SelectorWidth(select_6)    # type: categorical


In [12]:
# Print the k-th-from-last entry of sampler.sops twice: first as bare RASP,
# then annotated with its values on TEST_INPUT.
k = 1
utils.print_program(sampler.sops[-k])
print("\n")  # emits the same two blank separator lines as two bare print() calls
utils.print_program(sampler.sops[-k], TEST_INPUT)

select_10 = Select(tokens, tokens, predicate=Comparison.LEQ)
select_6 = Select(tokens, tokens, predicate=Comparison.LEQ)
selector_width_5 = SelectorWidth(select_6)    # type: categorical
select_12 = Select(tokens, selector_width_5, predicate=Comparison.EQ)
selector_width_11 = SelectorWidth(select_12)    # type: categorical
aggregate_9 = Aggregate(select_10, selector_width_11)    # type: categorical
select_8 = Select(aggregate_9, indices, predicate=Comparison.LT)
selector_width_7 = SelectorWidth(select_8)    # type: categorical


select_10 = Select(tokens, tokens, predicate=Comparison.LEQ)
select_6 = Select(tokens, tokens, predicate=Comparison.LEQ)
selector_width_5 = SelectorWidth(select_6)    # output: [1, 2, 3, 4]
select_12 = Select(tokens, selector_width_5, predicate=Comparison.EQ)
selector_width_11 = SelectorWidth(select_12)    # output: [1, 1, 1, 1]
aggregate_9 = Aggregate(select_10, selector_width_11)    # output: [1, 1, 1, 1]
select_8 = Select(aggregate_9, indices, predicate=Comp