In [1]:
import os
import sys
import traceback
from collections import defaultdict
from typing import Union, TypeVar

# Disable XLA's up-front GPU memory preallocation. This has to run before jax
# (or any library that imports it) is first imported, so keep it above the
# third-party imports below.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import pandas as pd
import jax
import jax.numpy as jnp
import flax
import chex
import numpy as np
import matplotlib.pyplot as plt
from jaxtyping import ArrayLike

from tracr.compiler.validating import validate
from tracr.rasp.rasp import Map, SequenceMap, LinearSequenceMap, Select, Aggregate, Comparison, SelectorWidth, indices, tokens
from tracr.rasp import rasp
from tracr.compiler import compiling
from tracr.compiler.assemble import AssembledTransformerModel
from tracr.compiler.craft_model_to_transformer import NoTokensError
from tracr.compiler.basis_inference import InvalidValueSetError
from tracr.compiler import rasp_to_graph
from tracr.compiler import assemble
from tracr.compiler import basis_inference
from tracr.compiler import craft_graph_to_model
from tracr.compiler import craft_model_to_transformer
from tracr.compiler import expr_to_craft_graph
from tracr.compiler import validating
from tracr.craft import bases

from decompile_tracr.dataset import lib
from decompile_tracr.dataset import data_utils
from decompile_tracr.dataset import config
from decompile_tracr.dataset import compile as comp
from decompile_tracr.tokenizing import tokenizer
from decompile_tracr.tokenizing import vocab
from decompile_tracr.sampling import sampling
from decompile_tracr.sampling import rasp_utils
from decompile_tracr.sampling.map_primitives import FunctionWithRepr
from decompile_tracr.tokenizing.str_to_rasp import split_list


# Seed for the NumPy generator used to sample RASP programs below; change it
# to get a different (but still reproducible) sequence of sampled programs.
SEED = 0
rng = np.random.default_rng(SEED)

In [13]:
# Pick the RASP program that the cells below compile and inspect.
# NOTE: `rng` is a stateful NumPy Generator, so every re-run of this cell
# draws a *different* program; restart the kernel (or re-seed `rng`) to
# reproduce a given draw.
# TODO(review): `15` is passed straight through to `sampling.sample`; give it
# a named constant once its meaning is confirmed.
USE_IDENTITY_PROGRAM = False  # True -> compile a trivial identity program instead

if USE_IDENTITY_PROGRAM:
    program = rasp.Map(lambda x: x, rasp.tokens)
else:
    program = sampling.sample(rng, 15)

In [14]:
# Run tracr's compilation stages by hand (RASP graph extraction -> basis
# inference -> craft components -> craft model) so that `graph`, `sources`
# and `craft_model` stay available for inspection in the cells below.

# Compiler settings: the input token vocabulary, the maximum input length,
# and the `mlp_exactness` passed to tracr when building MLP components.
INPUT_VOCAB = {0, 1, 2, 3, 4}
MAX_SEQ_LEN = 5
MLP_EXACTNESS = 100

extracted = rasp_to_graph.extract_rasp_graph(program)
graph, sources, sink = extracted.graph, extracted.sources, extracted.sink

basis_inference.infer_bases(
    graph,
    sink,
    vocab=INPUT_VOCAB,
    max_seq_len=MAX_SEQ_LEN,
)

expr_to_craft_graph.add_craft_components_to_rasp_graph(
    graph,
    bos_dir=bases.BasisDirection(rasp.tokens.label, "bos"),
    mlp_exactness=MLP_EXACTNESS,
)

craft_model = craft_graph_to_model.craft_graph_to_model(graph, sources)

In [15]:
# Compute the layer index tracr assigns to each node of `graph` so the
# allocation can be inspected directly. `_allocate_modules_to_layers` is a
# private tracr helper (leading underscore); its signature may change between
# tracr versions.
layer_allocation = craft_graph_to_model._allocate_modules_to_layers(graph, sources)

In [16]:
# Show the mapping from node label to allocated layer index computed above.
layer_allocation

defaultdict(<function tracr.compiler.craft_graph_to_model._allocate_modules_to_layers.<locals>.<lambda>()>,
            {'map_45': 1,
             'map_42': 1,
             'sequence_map_48': 1,
             'selector_width_57': 2,
             'sequence_map_50': 3,
             'map_59': 5,
             'aggregate_51': 4,
             'linear_sequence_map_64': 7,
             'aggregate_60': 6,
             'linear_sequence_map_71': 9,
             'linear_sequence_map_72': 11,
             'linear_sequence_map_82': 11,
             'linear_sequence_map_108': 13})

In [17]:
# Group the allocated nodes into per-layer attention / MLP sublayers.
# In the saved output below, allocation index n lands in `layer_{n // 2}`,
# under `attn` for even n and `mlp` for odd n.
# NOTE(review): this reaches `rasp_to_str` as an attribute of the `tokenizer`
# module, which relies on `tokenizer` importing it; importing `rasp_to_str`
# directly in the top cell would be more robust.
tokenizer.rasp_to_str.get_nodes_by_layer(layer_allocation)

{'layer_0/attn': [],
 'layer_0/mlp': ['map_42', 'map_45', 'sequence_map_48'],
 'layer_1/attn': ['selector_width_57'],
 'layer_1/mlp': ['sequence_map_50'],
 'layer_2/attn': ['aggregate_51'],
 'layer_2/mlp': ['map_59'],
 'layer_3/attn': ['aggregate_60'],
 'layer_3/mlp': ['linear_sequence_map_64'],
 'layer_4/attn': [],
 'layer_4/mlp': ['linear_sequence_map_71'],
 'layer_5/attn': [],
 'layer_5/mlp': ['linear_sequence_map_72', 'linear_sequence_map_82'],
 'layer_6/attn': [],
 'layer_6/mlp': ['linear_sequence_map_108']}

In [17]:
# List every node in the RASP graph extracted from `program`.
# NOTE(review): the saved output of this cell (only `map_34` and `tokens`)
# does not match the node labels in `layer_allocation` above (`map_42`,
# `aggregate_51`, ...), and the cell shares execution count 17 with the
# previous one. It presumably comes from an earlier kernel state (it looks
# like the identity program `rasp.Map(lambda x: x, rasp.tokens)`); use
# Restart & Run All to refresh it.
graph.nodes

NodeView(('map_34', 'tokens'))