In [6]:
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
from tracr.rasp import rasp
from tracr.compiler import compiling
import numpy as np
from tracr.compiler.validating import validate
from typing import Union, TypeVar


# useful primitives

# Morally, RASP has 4 types of SOps:
# int (categorical)
# str (categorical)
# float (numerical)
# bool (numerical)

def make_length():
    """Build the RASP `length` SOp.

    A selector whose predicate is always TRUE is built over the tokens, and
    its SelectorWidth (how many keys each query selects) is returned.
    """
    attend_everywhere = rasp.Select(
        rasp.tokens,
        rasp.tokens,
        rasp.Comparison.TRUE,
    )
    length_sop = rasp.SelectorWidth(attend_everywhere)
    return length_sop


# Elementwise functions that sample_map picks from (order matters for seeded sampling).
MAP_FNS = [
    lambda v: v,         # identity
    lambda v: v + 1,     # increment
    lambda v: v - 1,     # decrement
    lambda v: v * 2,     # double
    lambda v: v / 2,     # halve (true division)
    lambda v: v ** 2,    # square
    lambda v: v ** 0.5,  # square root
]


# Elementwise functions over boolean values: identity and logical negation.
BOOL_MAP_FNS = [
    lambda flag: flag,      # identity
    lambda flag: not flag,  # negation (always yields a Python bool)
]


# Binary elementwise functions that sample_sequence_map picks from
# (order matters for seeded sampling).
NONLINEAR_SEQMAP_FNS = [
    lambda x, y: x,          # project first argument
    lambda x, y: y,          # project second argument
    lambda x, y: x+y,
    lambda x, y: x-y,
    lambda x, y: x*y,
    lambda x, y: x/y,
    lambda x, y: x**y,
    # y-th root of x. The parentheses are required: `x**1/y` parses as
    # `(x**1)/y`, which is just a duplicate of the `x/y` entry above.
    lambda x, y: x**(1/y),
]

# Comparison predicates that sample_select picks from
# (order matters for seeded sampling).
PREDICATES = [
    getattr(rasp.Comparison, predicate_name)
    for predicate_name in (
        "EQ",
        "FALSE",
        "TRUE",
        "GEQ",
        "GT",
        "LEQ",
        "LT",
        "NEQ",
    )
]


In [None]:
# Moral types in RASP
# int (categorical)
# str (categorical)
# bool (categorical)
# float (numerical)
# bool (numerical)

## Proposal for program generation

In [None]:
# RASP instructions:
# - Map
# - SequenceMap
# - LinearSequenceMap (compiled exactly into MLP weights)
# - Select
# - Aggregate
# - SelectorWidth

# One way to do things is to sample realistic transformers:
# ie sample MLP, attn, MLP, attn, etc
# - MLP can be one of Map, SequenceMap, LinearSequenceMap
# - attn can be one of (Select, Aggregate) or (Select, SelectorWidth)

# The other way to do it is to just sample any reasonable RASP program.
# That way we generalize better to example programs that don't follow the
# transformer pattern.
# But what does 'reasonable RASP program' mean?

# - it follows the Tracr constraints so it compiles correctly
# - it has a reasonable number of instructions
# - it doesn't do anything trivial (eg return a constant)
# - it uses comparisons when appropriate (eg don't < compare two strings)

In [49]:
# 1) sample MLP layer
# 2) sample attn layer




# within a program, keep track of scope and distinguish between
# categorical and numerical variables.
# maybe: also distinguish 'boolean' variables (numerical and only 0,1)

In [27]:
# Attach a custom "type" annotation (value "bool") to the built-in tokens SOp.
tok = rasp.annotate(rasp.tokens, type="bool")

In [31]:
# Read back the custom "type" annotation stored on `tok`.
tok.annotations["type"]

'bool'

In [45]:
# Experiment: apply an arithmetic map (division by a string) to string tokens.
# NOTE(review): the recorded output of this cell is a TypeError (str / str).
# Presumably it is raised by the direct evaluation `out("abc")`, in which case
# the compile/apply lines below are never reached — confirm.
out = rasp.Map(lambda x: x/"b", rasp.tokens)
out("abc")

compiled = compiling.compile_rasp_to_model(out, vocab={"a", "b", "c"}, max_seq_len=5, compiler_bos="BOS")
compiled.apply(["BOS", "a", "b", "c"]).decoded

TypeError: unsupported operand type(s) for /: 'str' and 'str'

In [None]:
# 1, 2, 3 -> 2, 3, 4
# a, aa, aaa -> aa, aaa, aaaa

In [29]:
def boolean(v: rasp.SOp):
    """Return `v` as annotated by rasp.annotate with a custom type="bool" tag."""
    bool_tag = {"type": "bool"}
    return rasp.annotate(v, **bool_tag)


def categorical_scope(variable_scope: list):
    """Filter `variable_scope` down to its categorical SOps.

    Args:
        variable_scope: SOps currently available in the program.

    Returns:
        A new list, in the original order, of the entries for which
        rasp.is_categorical is truthy.
    """
    categorical = []
    for sop in variable_scope:
        if rasp.is_categorical(sop):
            categorical.append(sop)
    return categorical


def numerical_scope(variable_scope: list):
    """Filter `variable_scope` down to its numerical SOps.

    Args:
        variable_scope: SOps currently available in the program.

    Returns:
        A new list, in the original order, of the entries for which
        rasp.is_numerical is truthy.
    """
    numerical = []
    for sop in variable_scope:
        if rasp.is_numerical(sop):
            numerical.append(sop)
    return numerical


def sample_map(variable_scope: list):
    """Sample a Map, i.e. an elementwise function applied to one SOp.

    Args:
        variable_scope: SOps to pick the input from (categorical or numerical).

    Returns:
        Either the rasp.Map as constructed or a second, identical Map passed
        through rasp.numerical; one of the two is picked at random.
    """
    elementwise_fn = np.random.choice(MAP_FNS)
    source = np.random.choice(variable_scope)
    # Both variants are built up front (in this order), then one is drawn.
    plain_map = rasp.Map(elementwise_fn, source)
    numerical_map = rasp.numerical(rasp.Map(elementwise_fn, source))
    variants = [plain_map, numerical_map]
    return np.random.choice(variants)


def sample_sequence_map(variable_scope: list):
    """Sample a SequenceMap, i.e. a binary elementwise function of two SOps.

    Args:
        variable_scope: SOps in scope; only the categorical ones are used.

    Returns:
        A rasp.SequenceMap over two distinct categorical SOps (drawn without
        replacement) with a function drawn from NONLINEAR_SEQMAP_FNS.
    """
    combine = np.random.choice(NONLINEAR_SEQMAP_FNS)
    candidates = categorical_scope(variable_scope)
    first, second = np.random.choice(candidates, size=2, replace=False)
    return rasp.SequenceMap(combine, first, second)


def sample_linear_sequence_map(variable_scope: list):
    """Sample a LinearSequenceMap, i.e. a linear combination of two SOps.

    Args:
        variable_scope: SOps in scope; only the numerical ones are used.

    Returns:
        A rasp.LinearSequenceMap over two distinct numerical SOps (drawn
        without replacement) with normally distributed weights, passed
        through rasp.numerical.
    """
    candidates = numerical_scope(variable_scope)
    left, right = np.random.choice(candidates, size=2, replace=False)
    left_weight, right_weight = np.random.normal(size=2)
    combined = rasp.LinearSequenceMap(left, right, left_weight, right_weight)
    return rasp.numerical(combined)

def sample_select(variable_scope: list):
    """Sample a Select comparing two categorical SOps with a random predicate.

    Args:
        variable_scope: SOps in scope; only the categorical ones are used.

    Returns:
        A rasp.Select built from two sampled SOps and a predicate drawn from
        PREDICATES.
    """
    # Select needs a (keys, queries) pair, so draw two SOps. Sampling is with
    # replacement: comparing an SOp against itself is a valid selector.
    keys, queries = np.random.choice(categorical_scope(variable_scope), size=2)
    comparison = np.random.choice(PREDICATES)
    return rasp.Select(keys, queries, comparison)


def sample_aggregate(variable_scope: list, selector):
    """Sample an Aggregate that applies `selector` to a randomly chosen SOp.

    Args:
        variable_scope: SOps the aggregated SOp is drawn from.
        selector: selector passed to rasp.Aggregate.

    Returns:
        The newly built rasp.Aggregate.
    """
    sop_arg = np.random.choice(variable_scope)
    # TODO: how to choose sop_arg?
    # 1) If selector has width >1, then sop_arg must be 'boolean',
    # that is numerical and only takes values 0 or 1.
    # 2) If selector has width 1, then sop_arg can be anything.
    return rasp.Aggregate(selector, sop_arg)



def is_boolean(sop):
    """Placeholder for a predicate telling whether `sop` is a boolean variable.

    Raises:
        NotImplementedError: always; the check has not been written yet.
    """
    raise NotImplementedError




class GeneratedProgram:
    """A randomly sampled RASP program built from MLP and attention layers.

    `variable_scope` holds every SOp produced so far; each sampled layer draws
    its inputs from it and appends its output to it.
    """

    def __init__(self):
        self.tokens = rasp.tokens
        self.indices = rasp.indices
        # Seed the scope with the primitive SOps; with an empty scope the very
        # first sampler would have nothing to draw its inputs from.
        self.variable_scope = [self.tokens, self.indices]

    def sample_mlp(self):
        """Add an MLP layer to the program.

        Returns:
            The sampled SOp, which is also appended to `variable_scope`.
        """
        # Only offer samplers whose input requirements the scope can meet:
        # the two-input samplers draw two distinct SOps of one encoding.
        samplers = [sample_map]
        if len(categorical_scope(self.variable_scope)) >= 2:
            samplers.append(sample_sequence_map)
        if len(numerical_scope(self.variable_scope)) >= 2:
            samplers.append(sample_linear_sequence_map)
        # Pick the sampler first and run only that one, instead of building
        # all three candidate layers and throwing two away.
        sampler = samplers[np.random.randint(len(samplers))]
        mlp_out = sampler(self.variable_scope)
        self.variable_scope.append(mlp_out)
        return mlp_out

    def sample_attention(self):
        """Add an attention layer to the program.

        The layer is a sampled Select followed by either an Aggregate or a
        SelectorWidth.

        Returns:
            The sampled SOp (appended to `variable_scope`), or None if the
            chosen sampler produced nothing.
        """
        selector = sample_select(self.variable_scope)
        if np.random.randint(2):
            attn_out = rasp.SelectorWidth(selector)
        else:
            attn_out = sample_aggregate(self.variable_scope, selector)
        # Never put a missing result into scope: later layers draw from it.
        if attn_out is not None:
            self.variable_scope.append(attn_out)
        return attn_out

    def sample(self, n_layers=3):
        """Sample a program of `n_layers` MLP + attention layer pairs."""
        for _ in range(n_layers):
            self.sample_mlp()
            self.sample_attention()




In [17]:
# NOTE(review): in the cells visible here, `sample_mlp` exists only as a method
# of GeneratedProgram, so this bare call presumably relies on a function left in
# the kernel by an earlier version of a cell — confirm it survives
# Restart Kernel -> Run All.
sample_mlp()

tracr.rasp.rasp.Map