In [1]:
import os
import inspect
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
from tracr.rasp import rasp
from tracr.compiler import compiling, rasp_to_graph
import numpy as np
from tracr.compiler.validating import validate
from typing import Union, TypeVar
from rasp_generator import map_primitives, sampling, utils
import networkx as nx

from tracr.rasp.rasp import Map, SequenceMap, LinearSequenceMap, Select, Aggregate, Comparison, SelectorWidth, indices, tokens 


def compile_rasp_to_model(sop: rasp.SOp, vocab=None, max_seq_len=5, compiler_bos="BOS", **kwargs):
    """Compile a RASP program using this notebook's default settings.

    Thin wrapper around ``tracr.compiler.compiling.compile_rasp_to_model``
    that fills in defaults suited to the small test inputs used here.

    Args:
        sop: Output SOp of the RASP program to compile.
        vocab: Set of token values passed to the compiler as ``vocab``.
            Defaults to ``{1, 2, 3, 4}``. A fresh set is built on every call,
            so the default can never be shared or mutated across calls
            (the previous set-literal default was a single shared object).
        max_seq_len: Passed through as the compiler's ``max_seq_len``.
        compiler_bos: Passed through as the compiler's ``compiler_bos`` token.
        **kwargs: Any further keyword arguments are forwarded unchanged to
            ``compiling.compile_rasp_to_model``.

    Returns:
        Whatever ``compiling.compile_rasp_to_model`` returns.
    """
    if vocab is None:
        vocab = {1, 2, 3, 4}
    return compiling.compile_rasp_to_model(
        sop,
        vocab=vocab,
        max_seq_len=max_seq_len,
        compiler_bos=compiler_bos,
        **kwargs,
    )


# Example input sequence; printed programs are annotated with their outputs on it.
TEST_INPUT = list(range(1, 5))  # -> [1, 2, 3, 4]

# Program Generation

In [2]:
# Remaining problems
# - Sampling a categorical Aggregate sometimes fails (hits the max-retries limit).
# - Some sampled programs don't depend on rasp.tokens at all.
# - Some programs are trivial (e.g. the output is all Nones).
# - I suspect some programs have constant output. TODO: check on multiple inputs.


# TODOs
# - Remove SOps whose outputs are all (or mostly) None.
# - Consider removing or down-weighting constant SOps.

In [14]:
# Sample one random RASP program. `sample()` appears to return the list of
# errors hit while sampling (e.g. the EmptyScopeErrors shown in the output).
# NOTE(review): nothing in this notebook seeds the sampler, so if
# ProgramSampler is stochastic a re-run will produce a different program --
# TODO seed its RNG once the API is confirmed.
sampler = sampling.ProgramSampler(validate_compilation=True)
errs = sampler.sample()

# Evaluate on the shared TEST_INPUT instead of a duplicated literal, so every
# cell exercises the same input.
print(sampler.program(TEST_INPUT))
print("Errors encountered:")
print(errs)

[False, True, False, False]
Errors encountered:
[EmptyScopeError('No SOps of type float in scope.'), EmptyScopeError('No SOps of type float in scope.')]


In [4]:
# Print one sampled SOp's program. Index 2 presumably skips the primitive SOps
# at the front of `sampler.sops` (likely `tokens` and `indices`; the output
# below shows a SequenceMap over both) -- TODO confirm the ordering.
utils.print_program(sampler.sops[2])

sequence_map_1 = SequenceMap(lambda x, y: x*y, indices, tokens)    # type: categorical


In [5]:
# Show one sampled program two ways: first the bare definitions, then the same
# definitions annotated with each SOp's output on TEST_INPUT.
# `k` counts back from the end of `sampler.sops` (k = 1 -> last SOp sampled).
k = 1
utils.print_program(sampler.sops[-k])
print("\n")  # same as two bare print() calls: two blank lines between listings
utils.print_program(sampler.sops[-k], TEST_INPUT)

sequence_map_1 = SequenceMap(lambda x, y: x*y, indices, tokens)    # type: categorical
select_5 = Select(sequence_map_1, indices, predicate=Comparison.EQ)
aggregate_4 = Aggregate(select_5, sequence_map_1)    # type: categorical
sequence_map_3 = SequenceMap(lambda x, y: x*y, aggregate_4, sequence_map_1)    # type: categorical
sequence_map_2 = SequenceMap(lambda x, y: x*y, tokens, sequence_map_3)    # type: categorical


sequence_map_1 = SequenceMap(lambda x, y: x*y, indices, tokens)    # output: [0, 2, 6, 12]
select_5 = Select(sequence_map_1, indices, predicate=Comparison.EQ)
aggregate_4 = Aggregate(select_5, sequence_map_1)    # output: [0, None, 2, None]
sequence_map_3 = SequenceMap(lambda x, y: x*y, aggregate_4, sequence_map_1)    # output: [0, None, 12, None]
sequence_map_2 = SequenceMap(lambda x, y: x*y, tokens, sequence_map_3)    # output: [0, None, 36, None]
