In [4]:
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
from tracr.rasp import rasp
from tracr.compiler import compiling
import numpy as np
from tracr.compiler.validating import validate
from typing import Union, TypeVar
import jax


from tracr.compiler import assemble
from tracr.compiler import basis_inference
from tracr.compiler import craft_graph_to_model
from tracr.compiler import craft_model_to_transformer
from tracr.compiler import expr_to_craft_graph
from tracr.compiler import rasp_to_graph
from tracr.compiler import validating
from tracr.craft import bases
from tracr.rasp import rasp


def make_length():
    """Build the RASP expression for the width of a select-everything selector.

    A ``Select`` comparing ``rasp.tokens`` against ``rasp.tokens`` with the
    always-true predicate is wrapped in a ``SelectorWidth``. As the function
    name indicates, this is meant to yield the sequence length.

    Returns:
        The ``rasp.SelectorWidth`` expression built on that selector.
    """
    tokens = rasp.tokens
    select_all = rasp.Select(tokens, tokens, rasp.Comparison.TRUE)
    return rasp.SelectorWidth(select_all)


def compile_rasp_to_model(x: rasp.SOp,
                          vocab={0, 1, 2, 3},
                          max_seq_len=5,
                          compiler_bos="BOS"):
    """Compile a RASP program with notebook-friendly default settings.

    Thin wrapper that forwards everything to
    ``compiling.compile_rasp_to_model``.

    Args:
        x: The RASP program (an ``SOp``) to compile.
        vocab: Token vocabulary handed to the compiler. The default set object
            is shared between calls; it is only passed through here, never
            modified in this function.
        max_seq_len: Maximum sequence length handed to the compiler.
        compiler_bos: Beginning-of-sequence token handed to the compiler.

    Returns:
        Whatever ``compiling.compile_rasp_to_model`` returns for these settings.
    """
    compile_options = dict(
        vocab=vocab,
        max_seq_len=max_seq_len,
        compiler_bos=compiler_bos,
    )
    return compiling.compile_rasp_to_model(x, **compile_options)
 

def rasp_to_layerwise_representation(program: rasp.SOp,
                                     vocab=None,
                                     max_seq_len=5,
                                     compiler_bos="BOS",
                                     mlp_exactness=100):
    """Map every layer of the compiled model to the RASP nodes it implements.

    The first steps mirror those used for compiling a RASP program to a model
    (graph extraction, basis inference, attaching craft components). The
    per-node layer allocation is then inverted into a per-layer list of nodes.

    Args:
        program: The RASP program (an ``SOp``) to analyse.
        vocab: Token vocabulary passed to basis inference. ``None`` (the
            default) selects ``{1, 2, 3, 4}``, the value this helper has
            always used.
        max_seq_len: Maximum sequence length passed to basis inference.
        compiler_bos: Value used for the BOS basis direction.
        mlp_exactness: Forwarded to
            ``expr_to_craft_graph.add_craft_components_to_rasp_graph``.

    Returns:
        A ``(graph, layers_to_nodes)`` tuple. ``graph`` is the extracted RASP
        graph after the processing steps above. ``layers_to_nodes`` maps each
        layer index in ``range(n_layers)`` to the list of node ids allocated
        to that layer; layers with no allocated node map to an empty list.
    """
    if vocab is None:
        # Build the default per call so no set object is shared between calls.
        vocab = {1, 2, 3, 4}

    # first do the same steps as in `compile_rasp_to_model`
    extracted = rasp_to_graph.extract_rasp_graph(program)
    graph, sources, sink = extracted.graph, extracted.sources, extracted.sink
    basis_inference.infer_bases(
        graph,
        sink,
        vocab,
        max_seq_len,
    )

    expr_to_craft_graph.add_craft_components_to_rasp_graph(
        graph,
        bos_dir=bases.BasisDirection(rasp.tokens.label, compiler_bos),
        mlp_exactness=mlp_exactness,
    )

    # NOTE: `_allocate_modules_to_layers` is a private tracr helper; it may
    # change between tracr versions.
    nodes_to_layers = craft_graph_to_model._allocate_modules_to_layers(graph, sources)

    # we want a dictionary the other way around, i.e. mapping from layer to RASP operations
    # `default=-1` gives zero layers (an empty mapping) when nothing was
    # allocated, instead of `max()` raising ValueError on an empty sequence.
    n_layers = max(nodes_to_layers.values(), default=-1) + 1
    if n_layers % 2 != 0:
        n_layers += 1  # n_layers is always even (tracr will add dummy MLP block at the end)
    layers_to_nodes = {layer: [] for layer in range(n_layers)}
    for node_id, layer in nodes_to_layers.items():
        layers_to_nodes[layer].append(node_id)

    return graph, layers_to_nodes


# sample rasp program
def count_x(x=1):
    """Build a RASP program from the width of a "token equals ``x``" selector.

    A ``Map`` over ``rasp.indices`` produces the constant ``x`` at every
    position; a ``Select`` compares ``rasp.tokens`` against it with
    ``Comparison.EQ``; the result is wrapped in a ``SelectorWidth``.

    Args:
        x: The constant value the tokens are compared against.

    Returns:
        The ``rasp.SelectorWidth`` expression built on that selector.
    """
    return rasp.SelectorWidth(
        rasp.Select(
            rasp.tokens,
            # Constant sequence: ignores the index and always yields `x`.
            rasp.Map(lambda _unused: x, rasp.indices),
            rasp.Comparison.EQ,
        )
    )


# Instantiate the sample program with the default argument (x=1).
count = count_x()

In [5]:
def print_expr(expr: rasp.RASPExpr):
    """Print a one-line, assignment-like summary of a RASP expression.

    The line has the form ``<label> = <name>(<child labels>)``. For a
    ``rasp.Select`` the predicate is appended to the argument list; for a
    ``rasp.Map`` the mapped function ``f`` is appended.

    Args:
        expr: The RASP expression to describe.

    Returns:
        None. The summary is written to stdout.
    """
    child_labels = ", ".join(child.label for child in expr.children)

    # Optional trailing argument that depends on the expression type.
    if isinstance(expr, rasp.Select):
        extra = f", predicate={expr.predicate}"
    elif isinstance(expr, rasp.Map):
        extra = f", f={expr.f}"
    else:
        extra = ""

    print(f"{expr.label} = {expr.name}({child_labels}{extra})")

In [6]:
# Compute the graph and the layer -> node-ids mapping for the sample program.
graph, layers_to_nodes = rasp_to_layerwise_representation(count)

In [7]:
# Print, layer by layer, the RASP expressions allocated to each layer.
for layer, node_ids in layers_to_nodes.items():
    print(f"Layer {layer}")
    for node_id in node_ids:
        # The RASP expression is read from the graph node's "EXPR" entry.
        expr = graph.nodes[node_id]["EXPR"]
        print_expr(expr)
    print()  # blank line between layers

Layer 0

Layer 1
map_3 = map(indices, f=<function count_x.<locals>.<lambda> at 0x7f33597d8af0>)

Layer 2
selector_width_1 = selector_width(select_2)

Layer 3

