In [1]:
from collections import defaultdict
import os
import inspect
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
from tracr.rasp import rasp
from tracr.compiler import compiling, rasp_to_graph
import numpy as np
from tracr.compiler.validating import validate
from rasp_generator import map_primitives, sampling, utils
import networkx as nx

from tracr.rasp.rasp import Map, SequenceMap, LinearSequenceMap, Select, Aggregate, Comparison, SelectorWidth, indices, tokens 


def compile_rasp_to_model(sop: rasp.SOp, vocab=None, max_seq_len=5, compiler_bos="BOS"):
    """Compile a RASP program into a tracr transformer model.

    Thin wrapper around ``tracr.compiler.compiling.compile_rasp_to_model`` with
    defaults matching the test inputs used in this notebook.

    Args:
        sop: The RASP SOp (program output) to compile.
        vocab: Input token vocabulary. ``None`` (the default) means
            ``{0, 1, 2, 3, 4}``; a fresh set is built on every call.
        max_seq_len: Maximum input sequence length passed to tracr.
        compiler_bos: BOS token passed to tracr; callers prepend it to inputs
            when applying the compiled model.

    Returns:
        Whatever tracr's ``compile_rasp_to_model`` returns (the compiled model).
    """
    if vocab is None:
        # A set literal used as a default argument is evaluated once and shared
        # across calls; build the default per call so it can never be mutated
        # by one compilation and leak into the next.
        vocab = {0, 1, 2, 3, 4}
    return compiling.compile_rasp_to_model(
        sop,
        vocab=vocab,
        max_seq_len=max_seq_len,
        compiler_bos=compiler_bos,
    )


# Inputs used to compare RASP evaluation against the compiled model.
# Every token must be in the default compiler vocab {0, 1, 2, 3, 4} and each
# sequence must be at most max_seq_len (5) tokens long.
TEST_INPUTS = [
    [1, 2, 3, 4],
    [0, 1, 2, 3, 4],
    [1, 1, 1, 1],
    [0, 0, 0, 0],
    [1, 0, 1, 0],
    [2, 4, 2, 1],
    [4, 4, 4, 3],
]

# Alternative input set from an earlier iteration. It used to be assigned to
# TEST_INPUTS and then immediately overwritten by the list above, so it was
# never actually used; kept here under its own name for easy swapping.
ALT_TEST_INPUTS = [
    [1, 2, 3, 4],
    [0, 1, 2, 3, 4],
    [4, 3, 2, 1, 0],
    [0, 3, 2, 4],
    [1, 1, 1, 1],
    [0, 0, 0, 0],
    [4, 4, 4, 4],
    [0, 0, 4, 0],
]

In [2]:
# Sample a program of 30 SOps without checking that it compiles.
sampler = sampling.ProgramSampler(
    validate_compilation=False,
    disable_categorical_aggregate=False,
)
errs = sampler.sample(n_sops=30)

# sample() returns the errors that triggered retries while building the program.
print("Retries encountered:", end="\n\n")
for err in errs:
    print(err)
print()

# Evaluate the sampled RASP program directly on each test input.
print("Program outputs:")
for x in TEST_INPUTS:
    print(sampler.program(x))

Retries encountered:

EmptyScopeError('Filter failed. No SOps of type bool in scope.') (tried to sample numerical_aggregate)
SamplingError("Could not sample categorical Aggregate with valid output domain (Maximum retries reached). This because the sampler couldn't find a selector with width 1, and other sampled selectors don't result in an output domain that is a subset of the input domain.") (tried to sample categorical_aggregate)
EmptyScopeError('Filter failed. No SOps of type float in scope.') (tried to sample linear_sequence_map)
EmptyScopeError('Filter failed. No SOps of type float in scope.') (tried to sample linear_sequence_map)
EmptyScopeError('Filter failed. No SOps of type float in scope.') (tried to sample linear_sequence_map)
EmptyScopeError('Filter failed. No SOps of type float in scope.') (tried to sample linear_sequence_map)
SamplingError("Could not sample categorical Aggregate with valid output domain (Maximum retries reached). This because the sampler couldn't find a s

In [3]:
# Pretty-print the k-th-from-last entry of sampler.sops (k = 1 -> last entry):
# first on its own, then again with TEST_INPUTS[0] passed as an example input.
k = 1
utils.print_program(sampler.sops[-k])
print("\n")  # two blank lines between the two listings
utils.print_program(sampler.sops[-k], TEST_INPUTS[0])

map_6 = Map(lambda x: x != 4, tokens)    # type: bool
map_8 = Map(lambda x: x < 1, tokens)    # type: bool
select_11 = Select(tokens, tokens, predicate=Comparison.EQ)
select_17 = Select(tokens, tokens, predicate=Comparison.GEQ)
select_15 = Select(tokens, indices, predicate=Comparison.GEQ)
selector_width_16 = SelectorWidth(select_17)    # type: categorical
selector_width_13 = SelectorWidth(select_15)    # type: categorical
sequence_map_14 = SequenceMap(lambda x, y: x*y, tokens, selector_width_16)    # type: categorical
select_12 = Select(selector_width_13, sequence_map_14, predicate=Comparison.LT)
selector_width_10 = SelectorWidth(select_12)    # type: categorical
select_7 = Select(selector_width_10, selector_width_10, predicate=Comparison.LEQ)
aggregate_9 = Aggregate(select_11, selector_width_10)    # type: categorical
aggregate_4 = Aggregate(select_7, map_8)    # type: float
select_5 = Select(tokens, aggregate_9, predicate=Comparison.LT)
aggregate_3 = Aggregate(select_5, map_6)    # t

## Validate compilation

In [10]:
# Scratch check: a dict stores the object bound at insertion time, so
# rebinding `a` afterwards leaves d['test'] unchanged.
a = 5
d = {'test': a}
a = 17
d

{'test': 5}

In [11]:
# Sample several programs, compile each with tracr, and check that the compiled
# model reproduces the RASP program's output on every test input. All failures
# are collected in `errs` by stage ('retries', 'sampling', 'compilation',
# 'validation') instead of stopping the run.
errs = defaultdict(list)
results = []
for _ in range(4):
    try:
        sampler = sampling.ProgramSampler()
        retries = sampler.sample(n_sops=30)
        # `errs['retries'] + retries` built a new list and threw it away;
        # extend in place so retries are actually recorded.
        errs['retries'].extend(retries)
        results.append(dict(program=sampler.program))
    except Exception as err:
        errs['sampling'].append(err)

print("Done sampling.")

for r in results:
    try:
        r['model'] = compile_rasp_to_model(r['program'])
        r['label'] = r['program'].label
    except Exception as err:
        errs['compilation'].append(err)
        r['compilation_error'] = err

print("Done compiling.")

for r in results:
    if 'model' not in r:
        # Compilation failed (already recorded in errs['compilation']);
        # there is no model to validate.
        continue
    for test_input in TEST_INPUTS:
        rasp_out = r['program'](test_input)
        # RASP evaluation can yield None at some positions; map those to 0 so
        # the numeric comparison works.
        # NOTE(review): assumes the compiled model decodes such positions as 0 — confirm.
        rasp_out_sanitized = [0 if v is None else v for v in rasp_out]
        # "BOS" must match the compiler_bos used by compile_rasp_to_model;
        # [1:] drops the BOS position from the decoded output.
        model_out = r['model'].apply(["BOS"] + test_input).decoded[1:]
        try:
            matches = np.allclose(model_out, rasp_out_sanitized, rtol=1e-3, atol=1e-3)
        except (TypeError, ValueError):
            # Length mismatch or non-numeric values: treat as a mismatch and
            # record it rather than aborting validation of the other programs.
            matches = False
        if not matches:
            err = ValueError(f"Compiled program {r['program'].label} does not match RASP output.\n"
                                f"Compiled output: {model_out}\n"
                                f"RASP output: {rasp_out}\n"
                                f"Test input: {test_input}\n")
            errs['validation'].append(err)
            r['validation_error'] = err
            break


Done sampling.
Done compiling.


In [13]:
# Show each compiled model's decoded output (BOS position included) on one
# example input; programs that failed to compile have no 'model' and are skipped.
for r in results:
    if 'model' not in r:
        print(f"(no model: compilation failed: {r.get('compilation_error')!r})")
        continue
    print(r['model'].apply(["BOS", 4, 2, 1, 3]).decoded)

['BOS', 1.1180505680385977e-05, 0.9999963045120239, 0.9999963045120239, 0.7500052452087402]
['BOS', 4, 2, 4, 4]
['BOS', 0, 4, 8, 243]
['BOS', 0, 4, 6, 15]


In [6]:
# Errors raised while compiling the sampled programs (empty list -> every program compiled).
errs['compilation']

[]

In [7]:
# Compiled-vs-RASP output mismatches found during validation (at most one per program).
errs['validation']

[]