In [1]:
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
from tracr.rasp import rasp
from tracr.compiler import compiling
import numpy as np


def make_length():
    all_true_selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.TRUE)
    return rasp.SelectorWidth(all_true_selector)


def compile_rasp(x: rasp.SOp,
                 vocab={0, 1, 2, 3}, 
                 max_seq_len=5, 
                 compiler_bos="BOS"):
    return compiling.compile_rasp_to_model(
        x,
        vocab=vocab,
        max_seq_len=max_seq_len,
        compiler_bos=compiler_bos,
    )

# Apply via:
# out = model.apply([compiler_bos] + [v for v in vocab])
# return out.decoded

# RASP programs that don't compile correctly

In [2]:
from tracr.compiler import validating

# sum of all inputs up to current index
def sum_of_inputs() -> rasp.SOp:
    before = rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.LEQ)
    means = rasp.Aggregate(before, rasp.tokens)  # returns sequence s_i = mean_{j<=i} input_j
    sums = rasp.SequenceMap(lambda x, y: x*y, means, rasp.indices+1)
    return sums

sums = sum_of_inputs()

compiled_model = compiling.compile_rasp_to_model(sums, vocab={1,2,3}, max_seq_len=5, compiler_bos="BOS")
compiled_output = compiled_model.apply(["BOS", 3, 2, 1, 1]).decoded
rasp_output = sums([3, 2, 1, 1])

print("The output of the compiled model does not match the output of the RASP program:")
print(rasp_output)  # output: [3.0, 5.0, 6.0, 7.0]
print(compiled_output)  # output: ['BOS', 3, 4, 3, 4]

print()
print("This error gets caught by the validator:")
print(validating.validate(sums, [1, 2, 3]))



The output of the compiled model does not match the output of the RASP program:
[3.0, 5.0, 6.0, 7.0]
['BOS', 3, 4, 3, 4]

This error gets caught by the validator:
[TracrUnsupportedExpr(expr=<tracr.rasp.rasp.Aggregate object at 0x7f5f52962530>, reason='Categorical aggregate does not support Selectors with width > 1 that require aggregation (eg. averaging).')]


In [8]:
from tracr.compiler import validating

# sum of all inputs up to current index
def sum_of_inputs() -> rasp.SOp:
    before = rasp.numerical(rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.LEQ))
    means = rasp.numerical(rasp.Aggregate(before, rasp.tokens, default=0))  # returns sequence s_i = mean_{j<=i} input_j
#    sums = rasp.SequenceMap(lambda x, y: x*y, means, rasp.indices+1)
    sums = means * rasp.numerical(rasp.indices+1)
    return sums

sums = sum_of_inputs()

compiled_model = compiling.compile_rasp_to_model(sums, vocab={1,2,3}, max_seq_len=5, compiler_bos="BOS")
compiled_output = compiled_model.apply(["BOS", 3, 2, 1, 1]).decoded
rasp_output = sums([3, 2, 1, 1])

print("The output of the compiled model does not match the output of the RASP program:")
print(rasp_output)  # output: [3.0, 5.0, 6.0, 7.0]
print(compiled_output)  # output: ['BOS', 3, 4, 3, 4]

print()
print("This error gets caught by the validator:")
print(validating.validate(sums, [1, 2, 3]))



NotImplementedError: Unsupported RASP expressions:
sequence_map: (Non-linear) SequenceMap only supports categorical inputs/outputs.
aggregate: An aggregate's output encoding must match its input encoding. Input: Encoding.NUMERICAL   Output: Encoding.CATEGORICAL  

In [4]:
# parenthesis matching: count number of unmatched left parens
def count_unmatched_left_parens():
    is_left = rasp.tokens == "("
    is_right = rasp.tokens == ")"
    before = rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.LEQ)
    count_left = rasp.Aggregate(before, is_left) * (rasp.indices + 1)
    count_right = rasp.Aggregate(before, is_right) * (rasp.indices + 1)
    left_unmatched = count_left - count_right
    return left_unmatched


count = count_unmatched_left_parens()


compiled_model = compiling.compile_rasp_to_model(sums, vocab=set("abc()"), max_seq_len=6, compiler_bos="BOS")
compiled_output = compiled_model.apply(["BOS", *list("((abc)")]).decoded
rasp_output = count("((abc)")


print("The output of the compiled model does not match the output of the RASP program:")
print(rasp_output)
print(compiled_output)

print()
print("This error gets caught by the validator:")
print(validating.validate(count, list("((abc)")))



The output of the compiled model does not match the output of the RASP program:
[1.0, 2.0, 2.0, 2.0, 2.0, 1.0]
['BOS', '(', '((', '(((', '((((', '(((((', '((((((']

This error gets caught by the validator:
[TracrUnsupportedExpr(expr=<tracr.rasp.rasp.Aggregate object at 0x7febe1e4c220>, reason='Categorical aggregate does not support Selectors with width > 1 that require aggregation (eg. averaging).'), TracrUnsupportedExpr(expr=<tracr.rasp.rasp.Aggregate object at 0x7fec28288b50>, reason='Categorical aggregate does not support Selectors with width > 1 that require aggregation (eg. averaging).')]


In [5]:
# count number of 'x' tokens in input
is_x = rasp.tokens == "x"
before = rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.LEQ)
freqs = rasp.Aggregate(before, is_x)
counts = freqs * (rasp.indices + 1)
rasp_output = counts("axxcdx")

counts_model = compile_rasp(counts, vocab={"a", "b", "c", "d", "x"}, max_seq_len=6, compiler_bos="BOS")
compiled_output = counts_model.apply(["BOS", "a", "x", "x", "c", "d", "x"]).decoded




print("The output of the compiled model does not match the output of the RASP program:")
print(rasp_output)
print(compiled_output)

print()
print("This error gets caught by the validator:")
print(validating.validate(counts, ["a", "x"]))



The output of the compiled model does not match the output of the RASP program:
[0.0, 1.0, 2.0, 2.0, 2.0, 3.0]
['BOS', 0, 0, 3, 0, 0, 0]

This error gets caught by the validator:
[TracrUnsupportedExpr(expr=<tracr.rasp.rasp.Aggregate object at 0x7febd83eb100>, reason='Categorical aggregate does not support Selectors with width > 1 that require aggregation (eg. averaging).')]
