In [1]:
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
from tracr.rasp import rasp
from tracr.compiler import compiling, validating
import numpy as np
from tracr.rasp.rasp import Map, SequenceMap, LinearSequenceMap, Select, Aggregate, Comparison, SelectorWidth, indices, tokens 


def make_length():
    all_true_selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.TRUE)
    return rasp.SelectorWidth(all_true_selector)


def compile_rasp(x: rasp.SOp,
                 vocab={0, 1, 2, 3}, 
                 max_seq_len=5, 
                 compiler_bos="BOS"):
    return compiling.compile_rasp_to_model(
        x,
        vocab=vocab,
        max_seq_len=max_seq_len,
        compiler_bos=compiler_bos,
    )

# Apply via:
# out = model.apply([compiler_bos] + [v for v in vocab])
# return out.decoded

# RASP programs that don't compile correctly

In [2]:

# sum of all inputs up to current index
def sum_of_inputs() -> rasp.SOp:
    before = rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.LEQ)
    means = rasp.Aggregate(before, rasp.tokens)  # returns sequence s_i = mean_{j<=i} input_j
    sums = rasp.SequenceMap(lambda x, y: x*y, means, rasp.indices+1)
    return sums

sums = sum_of_inputs()

compiled_model = compiling.compile_rasp_to_model(sums, vocab={1,2,3}, max_seq_len=5, compiler_bos="BOS")
compiled_output = compiled_model.apply(["BOS", 3, 2, 1, 1]).decoded
rasp_output = sums([3, 2, 1, 1])

print("The output of the compiled model does not match the output of the RASP program:")
print(rasp_output)  # output: [3.0, 5.0, 6.0, 7.0]
print(compiled_output)  # output: ['BOS', 3, 4, 3, 4]

print()
print("This error gets caught by the validator:")
print(validating.validate(sums, [1, 2, 3]))

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


The output of the compiled model does not match the output of the RASP program:
[3, 5.0, 6.0, 7.0]
['BOS', 3, 4, 3, 4]

This error gets caught by the validator:
[TracrUnsupportedExpr(expr=<tracr.rasp.rasp.Aggregate object at 0x7f0c60d69780>, reason='Categorical aggregate does not support Selectors with width > 1 that require aggregation (eg. averaging).')]


In [3]:
from tracr.compiler import validating

# sum of all inputs up to current index
def sum_of_inputs() -> rasp.SOp:
    before = rasp.numerical(rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.LEQ))
    means = rasp.numerical(rasp.Aggregate(before, rasp.tokens, default=0))  # returns sequence s_i = mean_{j<=i} input_j
#    sums = rasp.SequenceMap(lambda x, y: x*y, means, rasp.indices+1)
    sums = means * rasp.numerical(rasp.indices+1)
    return sums

sums = sum_of_inputs()

try:
    compiled_model = compiling.compile_rasp_to_model(sums, vocab={1,2,3}, max_seq_len=5, compiler_bos="BOS")
except NotImplementedError as err:
    print(err)
    # caught by compiler (as of a more recent commit)


# compiled_output = compiled_model.apply(["BOS", 3, 2, 1, 1]).decoded
# rasp_output = sums([3, 2, 1, 1])
# 
# print("The output of the compiled model does not match the output of the RASP program:")
# print(rasp_output)  # output: [3.0, 5.0, 6.0, 7.0]
# print(compiled_output)  # output: ['BOS', 3, 4, 3, 4]
# 
# print()
# print("This error gets caught by the validator:")
# print(validating.validate(sums, [1, 2, 3]))

Unsupported RASP expressions:
sequence_map: (Non-linear) SequenceMap only supports categorical inputs/outputs.
aggregate: An aggregate's output encoding must match its input encoding. Input: Encoding.NUMERICAL   Output: Encoding.CATEGORICAL  


In [4]:
# parenthesis matching: count number of unmatched left parens
def count_unmatched_left_parens():
    is_left = rasp.tokens == "("
    is_right = rasp.tokens == ")"
    before = rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.LEQ)
    count_left = rasp.Aggregate(before, is_left) * (rasp.indices + 1)
    count_right = rasp.Aggregate(before, is_right) * (rasp.indices + 1)
    left_unmatched = count_left - count_right
    return left_unmatched


count = count_unmatched_left_parens()


try:
    compiled_model = compiling.compile_rasp_to_model(sums, vocab=set("abc()"), max_seq_len=6, compiler_bos="BOS")
except NotImplementedError as err:
    # caught by compiler now
    print(err)


#compiled_output = compiled_model.apply(["BOS", *list("((abc)")]).decoded
#rasp_output = count("((abc)")
#
#
#print("The output of the compiled model does not match the output of the RASP program:")
#print(rasp_output)
#print(compiled_output)
#
#print()
#print("This error gets caught by the validator:")
#print(validating.validate(count, list("((abc)")))

Unsupported RASP expressions:
sequence_map: (Non-linear) SequenceMap only supports categorical inputs/outputs.
aggregate: An aggregate's output encoding must match its input encoding. Input: Encoding.NUMERICAL   Output: Encoding.CATEGORICAL  


In [5]:
# count number of 'x' tokens in input
is_x = rasp.tokens == "x"
before = rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.LEQ)
freqs = rasp.Aggregate(before, is_x)
counts = freqs * (rasp.indices + 1)
rasp_output = counts("axxcdx")

counts_model = compile_rasp(counts, vocab={"a", "b", "c", "d", "x"}, max_seq_len=6, compiler_bos="BOS")
compiled_output = counts_model.apply(["BOS", "a", "x", "x", "c", "d", "x"]).decoded



print("The output of the compiled model does not match the output of the RASP program:")
print(rasp_output)
print(compiled_output)

print()
print("This error gets caught by the validator:")
print(validating.validate(counts, ["a", "x"]))

The output of the compiled model does not match the output of the RASP program:
[0, 1.0, 2.0, 2.0, 2.0, 3.0]
['BOS', 0, 0, 3, 0, 0, 0]

This error gets caught by the validator:
[TracrUnsupportedExpr(expr=<tracr.rasp.rasp.Aggregate object at 0x7f0c5f413bb0>, reason='Categorical aggregate does not support Selectors with width > 1 that require aggregation (eg. averaging).')]


In [6]:
sel = rasp.Select(rasp.indices, rasp.tokens, rasp.Comparison.EQ)
float_sop = rasp.Aggregate(sel, rasp.indices)
output = rasp.Aggregate(sel, float_sop)


model = compile_rasp(output, vocab={1,2,3,4}, max_seq_len=5, compiler_bos="BOS")
compiled_output = model.apply(["BOS", 1, 2, 3, 4]).decoded
rasp_output = output([1, 2, 3, 4])


print("The output of the compiled model does not match the output of the RASP program:")
print(rasp_output)  # [2.0, 3.0, None, None]
print(compiled_output) # ['BOS', 2, 3, 0, 1]

print()
print("This error is not caught by the validator:")
print(validating.validate(counts, [1, 2, 3, 4]))

The output of the compiled model does not match the output of the RASP program:
[2, 3, None, None]
['BOS', 2, 3, 0, 1]

This error is not caught by the validator:
[]


In [7]:
# numerical SOps can't be negative

program = rasp.numerical(rasp.Map(lambda x: x - 3, rasp.tokens))
#program = rasp.Map(lambda x: x - 3, rasp.tokens)
model = compile_rasp(program, vocab={1,2,3,4}, max_seq_len=5, compiler_bos="BOS")

rasp_output = program([1, 2, 3, 4])
compiled_output = model.apply(["BOS", 1, 2, 3, 4]).decoded

print(rasp_output)
print(compiled_output)


print()
print("This error is not caught by the validator:")
print(validating.validate(program, [1, 2, 3, 4]))

[-2, -1, 0, 1]
['BOS', 0.0, 0.0, 0.0, 1.0]

This error is not caught by the validator:
[]


In [8]:
def shift_by(offset: int, /, sop: rasp.SOp) -> rasp.SOp:
  """Returns the sop, shifted by `offset`, None-padded."""
  select_off_by_offset = rasp.Select(rasp.indices, rasp.indices,
                                     lambda k, q: q == k + offset)
  out = rasp.Aggregate(select_off_by_offset, sop, default=None)
  return out.named(f"shift_by({offset})")


shift_by_one = shift_by(1, rasp.tokens)
program = shift_by_one
model = compiling.compile_rasp_to_model(program, vocab={1,2,3,4}, max_seq_len=5, compiler_bos="BOS")

rasp_output = program([1, 2, 3, 4])
compiled_output = model.apply(["BOS", 1, 2, 3, 4]).decoded

print(rasp_output)
print(compiled_output)


print()
print("This error is not caught by the validator:")
print(validating.validate(program, [1, 2, 3, 4]))

[None, 1, 2, 3]
['BOS', 1, 1, 2, 3]

This error is not caught by the validator:
[]


In [9]:
all_false = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.FALSE)
program = rasp.Aggregate(all_false, rasp.tokens)
model = compiling.compile_rasp_to_model(program, vocab={1,2,3,4}, max_seq_len=5, compiler_bos="BOS")


print(program([1, 2, 3, 4]))
print(model.apply(["BOS", 1, 2, 3, 4]).decoded)
print(validating.validate(program, [1, 2, 3, 4]))  # []

[None, None, None, None]
['BOS', 1, 1, 1, 1]
[]


In [10]:
sel = rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ)
program = rasp.Aggregate(sel, rasp.tokens)
model = compiling.compile_rasp_to_model(program, vocab={0,1,2,3,4,23}, max_seq_len=5, compiler_bos="BOS")


print(program([23, 3, 4]))
print(model.apply(["BOS", 23, 3, 4]).decoded)
print(validating.validate(program, [1, 2, 3, 4]))  # []

[None, None, None]
['BOS', 3, 3, 3]
[]


In [11]:
# list of other failure cases caught by compiler:

select_22 = Select(indices, tokens, predicate=Comparison.FALSE)
aggregate_21 = Aggregate(select_22, tokens)    # type: categorical


select_64 = Select(tokens, tokens, predicate=Comparison.FALSE)
aggregate_63 = Aggregate(select_64, tokens)    # type: categorical

select_77 = Select(tokens, indices, predicate=Comparison.EQ)
aggregate_76 = Aggregate(select_77, tokens)    # type: categorical

select_17 = Select(tokens, indices, predicate=Comparison.FALSE)
aggregate_16 = Aggregate(select_17, tokens)    # type: categorical

select_2 = Select(tokens, indices, predicate=Comparison.EQ)
aggregate_1 = Aggregate(select_2, tokens)    # type: categorical

select_2 = Select(tokens, indices, predicate=Comparison.EQ)
aggregate_1 = Aggregate(select_2, tokens)    # type: categorical

In [12]:
# numerical error?

map_14 = Map(lambda x: x + 3, indices)    # type: categorical
sequence_map_12 = SequenceMap(lambda x, y: x*y, indices, tokens)    # type: categorical
sequence_map_13 = SequenceMap(lambda x, y: x*y, indices, map_14)    # type: categorical
select_10 = Select(indices, sequence_map_12, predicate=Comparison.GT)
map_11 = rasp.numerical(Map(lambda x: x > 2, sequence_map_13))    # type: bool
aggregate_9 = rasp.numerical(Aggregate(select_10, map_11, default=0))    # type: float

program = aggregate_9
model = compiling.compile_rasp_to_model(program, vocab={1,2,3,4}, max_seq_len=5, compiler_bos="BOS")


print(program([1, 2, 3, 4]))
print(model.apply(["BOS", 1, 2, 3, 4]).decoded)
print(validating.validate(program, [1, 2, 3, 4]))  # []

[1.0, True, 0, 0]
['BOS', 0.9999998807907104, 0.9999994039535522, 1.6166625300684245e-06, 1.6166625300684245e-06]
[]


In [38]:
m = SequenceMap(lambda x, y: x+y, tokens, tokens)    # type: categorical
sel = Select(tokens, tokens, predicate=Comparison.EQ)
program = Aggregate(sel, m)    # type: categorical

model = compiling.compile_rasp_to_model(program, vocab={1,2,3,4}, max_seq_len=5, compiler_bos="BOS")


print(program([1, 2, 3, 4]))
print(model.apply(["BOS", 1, 2, 3, 4]).decoded)
print(validating.validate(program, [1, 2, 3, 4]))  # []




ValueError: could not broadcast input array from shape (26,7) into shape (26,6)