In [5]:
from collections import defaultdict
import os
import inspect
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
from tracr.rasp import rasp
from tracr.compiler import compiling, rasp_to_graph
import numpy as np
from tracr.compiler.validating import validate
from rasp_generator import map_primitives, sampling, utils
import networkx as nx

from tracr.rasp.rasp import Map, SequenceMap, LinearSequenceMap, Select, Aggregate, Comparison, SelectorWidth, indices, tokens 


def compile_rasp_to_model(sop: rasp.SOp, vocab={0,1,2,3,4}, max_seq_len=5, compiler_bos="BOS"):
    return compiling.compile_rasp_to_model(
        sop,
        vocab=vocab,
        max_seq_len=max_seq_len,
        compiler_bos=compiler_bos
    )


TEST_INPUTS = [
    [1,2,3,4],
    [0,1,2,3,4],
    [4,3,2,1,0],
    [0,3,2,4],
    [1,1,1,1],
    [0,0,0,0],
    [4,4,4,4],
    [0,0,4,0]
]

TEST_INPUTS = [
    [1,2,3,4],
    [0,1,2,3,4],
    [1,1,1,1],
    [0,0,0,0],
    [1,0,1,0],
    [2,4,2,1],
    [4,4,4,3],
]

In [2]:
sampler = sampling.ProgramSampler(validate_compilation=False, disable_categorical_aggregate=False)
errs = sampler.sample(n_sops=30)
print("Retries encountered:")
print()
for err in errs:
    print(err)
print()
print("Program outputs:")
for x in TEST_INPUTS:
    print(sampler.program(x))

Retries encountered:

EmptyScopeError('Filter failed. No SOps of type float in scope.') (tried to sample linear_sequence_map)
SamplingError("Could not sample categorical Aggregate with valid output domain (Maximum retries reached). This because the sampler couldn't find a selector with width 1, and other sampled selectors don't result in an output domain that is a subset of the input domain.") (tried to sample categorical_aggregate)
EmptyScopeError('Filter failed. No SOps of type bool in scope.') (tried to sample numerical_aggregate)
EmptyScopeError('Filter failed. No SOps of type float in scope.') (tried to sample linear_sequence_map)
SamplingError("Could not sample categorical Aggregate with valid output domain (Maximum retries reached). This because the sampler couldn't find a selector with width 1, and other sampled selectors don't result in an output domain that is a subset of the input domain.") (tried to sample categorical_aggregate)
SamplingError("Could not sample categorical A

In [3]:
# print program with and without test input
k = 1
utils.print_program(sampler.sops[-k])
print()
print()
utils.print_program(sampler.sops[-k], TEST_INPUTS[0])

select_10 = Select(tokens, tokens, predicate=Comparison.LEQ)
map_15 = Map(lambda x: x + 1, tokens)    # type: categorical
select_16 = Select(tokens, tokens, predicate=Comparison.LT)
select_17 = Select(tokens, tokens, predicate=Comparison.EQ)
map_9 = Map(lambda x: x != 0, map_15)    # type: bool
selector_width_11 = SelectorWidth(select_16)    # type: categorical
aggregate_5 = Aggregate(select_10, map_9)    # type: float
sequence_map_13 = SequenceMap(lambda x, y: x*y, selector_width_11, tokens)    # type: categorical
select_19 = Select(indices, selector_width_11, predicate=Comparison.EQ)
map_18 = Map(lambda x: int(x), aggregate_5)    # type: categorical
map_7 = Map(lambda x: x == 1, sequence_map_13)    # type: bool
aggregate_14 = Aggregate(select_19, selector_width_11)    # type: categorical
aggregate_12 = Aggregate(select_17, map_18)    # type: categorical
select_8 = Select(aggregate_14, tokens, predicate=Comparison.GT)
select_6 = Select(selector_width_11, aggregate_12, predicate=Compar

## Validate compilation

In [4]:
errs = defaultdict(list)
results = []
for _ in range(40):
    try:
        sampler = sampling.ProgramSampler()
        retries = sampler.sample(n_sops=30)
        errs['retries'] + retries
        results.append(dict(program=sampler.program))
    except Exception as err:
        errs['sampling'].append(err)
    
print("Done sampling.")

for r in results:
    try:
        model = compile_rasp_to_model(r['program'])
        r['model'] = model
    except Exception as err:
        errs['compilation'].append(err)
        r['compilation_error'] = err
    
print("Done compiling.")

for r in results:
    for x in TEST_INPUTS:
        rasp_out = r['program'](x)
        rasp_out_sanitized = [0 if x is None else x for x in rasp_out]
        model_out = r['model'].apply(["BOS"] + x).decoded[1:]
        if not np.allclose(model_out, rasp_out_sanitized, rtol=1e-3, atol=1e-3):
            err = ValueError(f"Compiled program {r['program'].label} does not match RASP output.\n"
                                f"Compiled output: {model_out}\n"
                                f"RASP output: {rasp_out}\n"
                                f"Test input: {x}\n")
            errs['validation'].append(err)
            r['validation_error'] = err
            break


Done sampling.
Done compiling.


ValueError: ('Inputs {0} not found in encoding ', dict_keys([1, 2, 3, 4, 'BOS', 'compiler_pad']))

In [None]:
errs['compilation']

[ValueError('Failed to find a node with label tokens. This is probably because your RASP program does not include rasp.tokens. A program must include rasp.tokens to be compiled.'),
 ValueError('Failed to find a node with label tokens. This is probably because your RASP program does not include rasp.tokens. A program must include rasp.tokens to be compiled.')]

In [None]:
errs['validation']

[ValueError('Failed to find a node with label tokens. This is probably because your RASP program does not include rasp.tokens. A program must include rasp.tokens to be compiled.'),
 ValueError('Compiled program selector_width_61 does not match RASP output.\nCompiled output: [1, 0, 0, 0, 0]\nRASP output: [2, 0, 0, 0, 0]\nTest input: [0, 1, 2, 3, 4]\nSOp: <tracr.rasp.rasp.SelectorWidth object at 0x7f4aa4c9cf40>'),
 ValueError('Failed to find a node with label tokens. This is probably because your RASP program does not include rasp.tokens. A program must include rasp.tokens to be compiled.'),
 ValueError('Compiled program selector_width_338 does not match RASP output.\nCompiled output: [4, 5, 5, 5]\nRASP output: [0, 0, 4, 4]\nTest input: [1, 2, 3, 4]\nSOp: <tracr.rasp.rasp.SelectorWidth object at 0x7f4aa4cea560>')]