In [4]:
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
from tracr.rasp import rasp
from tracr.compiler import compiling
import numpy as np
from tracr.compiler.validating import validate
from typing import Union, TypeVar


def make_length():
    """Return an SOp whose value at every position is the sequence length.

    A selector built with ``Comparison.TRUE`` selects every key for every
    query, so its width at each position is the number of tokens.
    """
    select_all = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.TRUE)
    length_sop = rasp.SelectorWidth(select_all)
    return length_sop


def compile_rasp(x: "rasp.SOp",
                 vocab=None,
                 max_seq_len=5,
                 compiler_bos="BOS",
                 **compile_kwargs):
    """Compile a RASP program into a tracr transformer model.

    Args:
      x: the RASP SOp to compile.
      vocab: set of input token values. Defaults to {0, 1, 2, 3}; a fresh set
        is built on every call rather than sharing one mutable default object
        between calls.
      max_seq_len: maximum input length (excluding BOS) the model supports.
      compiler_bos: BOS token that must be prepended to every model input.
      **compile_kwargs: any further keyword arguments, forwarded unchanged to
        ``compiling.compile_rasp_to_model`` (see its signature for options).

    Returns:
      Whatever ``compiling.compile_rasp_to_model`` returns (the compiled model).

    Usage:
      model = compile_rasp(program, vocab={"a", "b", "c"})
      model.apply([compiler_bos, "a", "b", "c"]).decoded
    """
    if vocab is None:
        vocab = {0, 1, 2, 3}
    return compiling.compile_rasp_to_model(
        x,
        vocab=vocab,
        max_seq_len=max_seq_len,
        compiler_bos=compiler_bos,
        **compile_kwargs,
    )

In [7]:
from lauro_inverse_tracr.map_primitives import CAT_TO_CAT

In [8]:
# Take the first entry of CAT_TO_CAT (presumably a categorical -> categorical
# map primitive, judging by the name).
# NOTE(review): assumes CAT_TO_CAT is indexable by 0 and its entries are
# callable; confirm in lauro_inverse_tracr.map_primitives.
fn = CAT_TO_CAT[0]

In [11]:
# Sanity check: apply the primitive to a single value (output shown below).
fn(5)

5

In [3]:
# IPython help: show SequenceMap's signature (f, fst, snd) and docstring.
rasp.SequenceMap?

[0;31mInit signature:[0m
[0mrasp[0m[0;34m.[0m[0mSequenceMap[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mf[0m[0;34m:[0m [0mCallable[0m[0;34m[[0m[0;34m[[0m[0mUnion[0m[0;34m[[0m[0mNoneType[0m[0;34m,[0m [0mint[0m[0;34m,[0m [0mfloat[0m[0;34m,[0m [0mstr[0m[0;34m,[0m [0mbool[0m[0;34m][0m[0;34m,[0m [0mUnion[0m[0;34m[[0m[0mNoneType[0m[0;34m,[0m [0mint[0m[0;34m,[0m [0mfloat[0m[0;34m,[0m [0mstr[0m[0;34m,[0m [0mbool[0m[0;34m][0m[0;34m][0m[0;34m,[0m [0mUnion[0m[0;34m[[0m[0mNoneType[0m[0;34m,[0m [0mint[0m[0;34m,[0m [0mfloat[0m[0;34m,[0m [0mstr[0m[0;34m,[0m [0mbool[0m[0;34m][0m[0;34m][0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mfst[0m[0;34m:[0m [0mtracr[0m[0;34m.[0m[0mrasp[0m[0;34m.[0m[0mrasp[0m[0;34m.[0m[0mSOp[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0msnd[0m[0;34m:[0m [0mtracr[0m[0;34m.[0m[0mrasp[0m[0;34m.[0m[0mrasp[0m[0;34m.[0m[0mSOp[0m[0;34m,[0m[0;34m[0m
[0;3

# Writing RASP programs

In [108]:
def make_reverse(sop: rasp.SOp) -> rasp.SOp:  # categorical -> categorical
    """Reverse ``sop`` along the sequence.

    Position i reads the value of ``sop`` at position ``length - i - 1``:
    the query at i carries that mirrored index, and the selector matches it
    against ``rasp.indices`` with equality, so exactly one key is selected.
    """
    seq_len = make_length()
    flipped_idx = (seq_len - rasp.indices - 1).named("reversed_indices")
    pick_mirror = rasp.Select(rasp.indices, flipped_idx, rasp.Comparison.EQ)
    return rasp.Aggregate(pick_mirror, sop)

rev = make_reverse(rasp.tokens)
compiled = compile_rasp(rev, vocab={"a", "b", "c"})

# Evaluate the same input with the RASP interpreter and the compiled model.
rasp_out = rev("abc")
model_out = compiled.apply(["BOS", "a", "b", "c"]).decoded
print(rasp_out)
print(model_out)
# The compiled transformer must agree with the interpreter; position 0 of the
# model output is the BOS token, so skip it.
assert model_out[1:] == rasp_out



['c', 'b', 'a']
['BOS', 'c', 'b', 'a']


In [109]:
# Count how many tokens in the whole sequence are equal to "x".
# (For a per-position boolean indicator instead, `rasp.tokens == "x"` suffices.)
def count_x():
    """Return an SOp holding, at every position, the number of "x" tokens.

    The result is a count (SelectorWidth), not a fraction. Every query is the
    constant "x", so every position selects the same set of keys and gets the
    same value.
    """
    # Constant "x" at every position; indices are used only to get an SOp of
    # the right length.
    all_x = rasp.Map(lambda _: "x", rasp.indices)
    # For each query ("x"), select the key positions whose token equals "x".
    is_x = rasp.Select(rasp.tokens, all_x, rasp.Comparison.EQ)
    return rasp.SelectorWidth(is_x)


# Alternative: the same count using a custom predicate instead of a constant query.
def count_x_2():
    """Return an SOp holding, at every position, the number of "x" tokens.

    Keys and queries are both ``rasp.tokens``, so the predicate's first
    argument is a token whichever way it is bound; the second is ignored.
    NOTE(review): unverified whether tracr's compiler accepts a plain callable
    here instead of a ``rasp.Comparison``.
    """
    def key_is_x(key, _query):
        return key == "x"

    return rasp.SelectorWidth(rasp.Select(rasp.tokens, rasp.tokens, key_is_x))


In [124]:
# Histogram
def histogram(x):
    """For each position, count how many positions hold the same value of ``x``.

    Example: histogram(rasp.tokens)("abcdd") == [1, 1, 1, 2, 2].
    """
    same_value = rasp.Select(x, x, rasp.Comparison.EQ)
    return rasp.SelectorWidth(same_value)



# Double histogram (a variant of the example in the RASP paper).
def double_histogram(x):
    """For each position, count the positions whose value occurs as often as this one's.

    Computed as ``histogram(histogram(x))``. The first pass gives each
    position the occurrence count of its value. The second pass counts how
    many positions share that occurrence count.

    This counts *positions*, not distinct values. For example, "abcdd" gives
    [3, 3, 3, 2, 2]: both "d" positions have count 2. Counting distinct values
    with the same frequency would instead give 1 at those positions.

    NOTE(review): the RASP paper's double-histogram appears to be defined over
    distinct tokens; confirm before treating this as the paper's version.
    """
    hist = histogram(x)
    return histogram(hist)


double_histogram(rasp.tokens)("abcdd")

[3, 3, 3, 2, 2]

In [195]:
# Fraction (not count) of tokens so far -- positions 0..i inclusive -- that are "x".
def count_prev_x():
    """Return a numerical SOp: at position i, the share of tokens[0..i] equal to "x".

    Despite the name, the result is a fraction. The numerical Aggregate
    averages a 0/1 indicator over the prefix selected by ``Comparison.LEQ``
    on indices, and that prefix includes the current position.
    """
    prefix = rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.LEQ)
    indicator = rasp.numerical(rasp.tokens == "x")
    averaged = rasp.Aggregate(prefix, indicator, default=0)
    return rasp.numerical(averaged)

## Running List of Tracr Constraints

In [185]:
# Attention patterns (i.e. Aggregate) can only average binary (0/1) values.
# E.g. the following runs in the RASP interpreter but does NOT compile:

def sum():
    """Running mean of (token + 1) over positions 0..i (inclusive).

    Despite the name, this is a mean, not a sum: Aggregate averages the
    selected values, so [1, 2, 3] -> [2.0, 2.5, 3.0]. It averages non-binary
    values, which violates the constraint described above.

    NOTE(review): this definition shadows Python's built-in ``sum`` for every
    later cell in the notebook; the name is kept only for compatibility with
    existing calls.
    """
    tok = rasp.numerical(rasp.tokens+1)
    previous = rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.LEQ)
    return rasp.numerical(rasp.Aggregate(previous, tok, default=0))

print(sum()([1,2,3]))  # runs in the interpreter, but would not compile


# By contrast, this version works because it only averages 0s and 1s
# (tracr's restriction on numerical Aggregate).
def frac_prev_x():
    """Running fraction of "x" tokens over positions 0..i (inclusive).

    ``Comparison.LEQ`` with indices as both keys and queries selects every
    position up to and including the current one. The numerical Aggregate
    then averages the 0/1 indicator of ``token == "x"`` over that prefix.
    """
    upto_here = rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.LEQ)
    x_indicator = rasp.numerical(rasp.tokens == "x")
    prefix_mean = rasp.Aggregate(upto_here, x_indicator, default=0)
    return rasp.numerical(prefix_mean)

print(frac_prev_x()("abxcx"))  # compiles fine

[2.0, 2.5, 3.0]
[0.0, 0.0, 0.3333333333333333, 0.25, 0.4]


# Other Experiments: numerical SOps and float equality

In [76]:
# Adding a float constant to the token SOp builds a derived SOp;
# rasp.numerical marks it as numerically encoded. On [1, 2, 3] it evaluates
# to [1.1, 2.1, 3.1].
tok = rasp.tokens
floats = rasp.numerical(tok + 0.1)
floats([1,2,3])

[1.1, 2.1, 3.1]

In [79]:
# `~` on a SOp builds a new SOp (presumably elementwise logical `not`).
# The numerical annotation of `floats` is not carried over, as the check
# below shows (False).
notfloats = ~floats

In [82]:
rasp.is_numerical(notfloats)

False

In [73]:
# Comparing a SOp with a constant builds a boolean SOp. Equality is exact:
# floats evaluates to 1.1 at position 0, which is not == 1.1000000000001.
bools = floats == 1.1000000000001
bools([1,2,3])

[False, False, False]

In [74]:
# Plain-Python check: 1.1 and 1.100000000000001 round to different doubles,
# so exact equality is False.
1.1 == 1.100000000000001

False

In [75]:
# Compile the boolean SOp and run the model on [1, 2, 3] (BOS token first).
compiled = compile_rasp(
    bools,
    vocab={0, 1, 2, 3},
)
compiled.apply(["BOS"] + [1, 2, 3]).decoded

['BOS', False, False, False]