In [1]:
import os
# Must be set before `jax` is imported below: tells the XLA client not to
# preallocate a large fraction of GPU memory up front.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
import sys
from collections import defaultdict 
import jax
import flax
import chex
from jaxtyping import ArrayLike
from typing import Union, TypeVar
import numpy as np
import matplotlib.pyplot as plt
import traceback
import jax.numpy as jnp

from tracr.compiler.validating import validate
from tracr.rasp.rasp import Map, SequenceMap, LinearSequenceMap, Select, Aggregate, Comparison, SelectorWidth, indices, tokens 
from tracr.rasp import rasp
from tracr.compiler import compiling
from tracr.compiler.assemble import AssembledTransformerModel
from tracr.compiler.craft_model_to_transformer import NoTokensError
from tracr.compiler.basis_inference import InvalidValueSetError
from tracr.compiler import rasp_to_graph


from decompile_tracr.dataset import lib
from decompile_tracr.dataset import data_utils
from decompile_tracr.dataset import config
from decompile_tracr.tokenizing import tokenizer
from decompile_tracr.tokenizing import vocab
from decompile_tracr.sampling import sampling
from decompile_tracr.sampling import rasp_utils


# NumPy random Generator with a fixed seed (0) so random draws are repeatable.
rng = np.random.default_rng(0)

In [10]:
# Display the vocabulary: a tuple of token strings (RASP op names, lambda
# sources, sop_* labels, numeric literals and special tokens such as BOS/EOS/PAD).
vocab.vocab

('lambda x: x * 2',
 'lambda x: x - 1',
 'lambda x, y: x * y',
 '-3',
 'sop_16',
 'lambda x: int(x)',
 'SequenceMap',
 'sop_21',
 'categorical',
 'lambda x: x != 3',
 'lambda x: x == 0',
 'PAD',
 'sop_24',
 'sop_13',
 'indices',
 'sop_3',
 'sop_11',
 'sop_18',
 'lambda x: not x',
 'sop_17',
 '-1',
 'sop_6',
 'lambda x: x + 1',
 'lambda x, y: x + y % 10',
 'sop_0',
 '3',
 'sop_9',
 'lambda x: x < 2',
 'GEQ',
 'NEQ',
 '-2',
 'sop_7',
 'LinearSequenceMap',
 'tokens',
 'sop_22',
 'EQ',
 'LT',
 'sop_20',
 'lambda x: x + 0.5',
 'Map',
 'sop_5',
 'sop_2',
 'lambda x: x',
 'lambda x: x > 3.5',
 '2',
 'lambda x: x > 2',
 '1',
 'sop_19',
 'SelectorWidth',
 'EOS',
 'sop_12',
 'lambda x: x == 3',
 'lambda x: x > 3',
 'lambda x: x < 3',
 'lambda x: bool(x)',
 'lambda x, y: x and y',
 'lambda x: x == 2',
 'numerical',
 'sop_15',
 'LEQ',
 'sop_4',
 'sop_8',
 'sop_23',
 'lambda x: x < 3.5',
 'SEP',
 'BOS',
 'lambda x: x != 2',
 'SelectAggregate',
 'lambda x, y: x or y',
 'TRUE',
 'GT',
 'sop_1',
 'sop

In [3]:
# Id of the end-of-sequence token. The recorded value (49) matches the
# position of 'EOS' in the `vocab.vocab` listing shown earlier.
vocab.eos_id

49

In [5]:
# Id of the beginning-of-sequence token. The recorded value (65) matches the
# position of 'BOS' in the `vocab.vocab` listing shown earlier.
vocab.bos_id

65

In [6]:
# Reported vocabulary size (75 in the recorded output).
vocab.size

75

In [7]:
# Sanity check: number of entries in the vocab tuple; the recorded output (75)
# agrees with `vocab.size`.
len(vocab.vocab)

75

In [3]:
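# Disabled experiment, kept for reference: compile two Map-of-Map programs over tokens.
# Note: `x + 1 % 5` parses as `x + (1 % 5)`, i.e. `x + 1`, so the two inner
# lambdas compute the same values even though their source text differs.
#m1 = compiling.compile_rasp_to_model(
#    program=rasp.Map(lambda x: x - 1, rasp.Map(lambda x: x + 1, rasp.tokens)),
#    vocab=set(range(5)),
#    max_seq_len=5,
#)
#
#m2 = compiling.compile_rasp_to_model(
#    program=rasp.Map(lambda x: x - 1, rasp.Map(lambda x: x + 1 % 5, rasp.tokens)),
#    vocab=set(range(5)),
#    max_seq_len=5,
#)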

In [19]:
# Compile two numerical programs that differ only in the first operand of the
# inner SequenceMap (tokens vs. indices), using identical compiler settings,
# so that their parameters can be compared in the next cell.

# Program 1: identity Map over (tokens + tokens).
sum_tokens_tokens = rasp.SequenceMap(lambda x, y: x + y, rasp.tokens, rasp.tokens)
identity_tokens_tokens = rasp.numerical(rasp.Map(lambda x: x, sum_tokens_tokens))
m1 = compiling.compile_rasp_to_model(
    program=identity_tokens_tokens,
    vocab=set(range(5)),
    max_seq_len=5,
)

# Program 2: identity Map over (indices + tokens).
sum_indices_tokens = rasp.SequenceMap(lambda x, y: x + y, rasp.indices, rasp.tokens)
identity_indices_tokens = rasp.numerical(rasp.Map(lambda x: x, sum_indices_tokens))
m2 = compiling.compile_rasp_to_model(
    program=identity_indices_tokens,
    vocab=set(range(5)),
    max_seq_len=5,
)

# Keep the parameter trees of both compiled models for comparison.
p1 = m1.params
p2 = m2.params

In [20]:
# Raises an AssertionError unless p1 and p2 have the same tree structure and
# all leaves are numerically close; it returns nothing on success. The
# original note "True" records that the check passed when this cell was run.
chex.assert_trees_all_close(p1, p2)  # passed (no exception) when last run

In [None]:
# Compile two numerical LinearSequenceMap programs with the same factors
# (-2, -1) but with the two inputs swapped, using identical compiler settings,
# so that their parameters can be compared in the next cell.

# Program 1: inputs are (tokens + 0, indices + 0).
tokens_first = rasp.numerical(rasp.tokens + 0)
indices_second = rasp.numerical(rasp.indices + 0)
linear_tokens_indices = rasp.numerical(
    rasp.LinearSequenceMap(tokens_first, indices_second, -2, -1)
)
m1 = compiling.compile_rasp_to_model(
    program=linear_tokens_indices,
    vocab=set(range(5)),
    max_seq_len=5,
)

# Program 2: inputs are (indices + 0, tokens + 0).
indices_first = rasp.numerical(rasp.indices + 0)
tokens_second = rasp.numerical(rasp.tokens + 0)
linear_indices_tokens = rasp.numerical(
    rasp.LinearSequenceMap(indices_first, tokens_second, -2, -1)
)
m2 = compiling.compile_rasp_to_model(
    program=linear_indices_tokens,
    vocab=set(range(5)),
    max_seq_len=5,
)

# Keep the parameter trees of both compiled models for comparison.
p1 = m1.params
p2 = m2.params

In [7]:
# Raises an AssertionError unless p1 and p2 have the same tree structure and
# all leaves are numerically close; it returns nothing on success.
# NOTE(review): this cell's recorded execution count (7) is lower than that of
# earlier cells, and the LinearSequenceMap cell above it has no execution
# count, so the "True" note may refer to an older p1/p2 -- re-run to confirm.
chex.assert_trees_all_close(p1, p2)  # True

In [None]:
def plot_params(p, layer=0):
    """Plot the embeddings and one layer's MLP weights of a compiled model.

    Draws four vertically stacked heatmaps -- positional embeddings, token
    embeddings, and the two MLP weight matrices of the selected transformer
    layer -- each with a title naming the parameter it shows. It then prints
    the two MLP bias vectors of that layer.

    Args:
        p: Nested parameter mapping (e.g. ``m1.params``) containing the keys
            ``'pos_embed'``, ``'token_embed'`` and
            ``'transformer/layer_<layer>/mlp/linear_{1,2}'``.
        layer: Index of the transformer layer whose MLP is plotted.
            Defaults to 0, which matches the previously hard-coded behaviour.

    Returns:
        None. The figure is created through ``plt.subplots`` and the biases
        are written to stdout.
    """
    mlp_prefix = f'transformer/layer_{layer}/mlp'
    linear_1 = p[f'{mlp_prefix}/linear_1']
    linear_2 = p[f'{mlp_prefix}/linear_2']

    # (title, matrix) pairs in the order they are stacked top to bottom.
    panels = [
        ('pos_embed / embeddings', p['pos_embed']['embeddings']),
        ('token_embed / embeddings', p['token_embed']['embeddings']),
        (f'{mlp_prefix}/linear_1 / w', linear_1['w']),
        (f'{mlp_prefix}/linear_2 / w', linear_2['w']),
    ]

    fig, axs = plt.subplots(4, 1, figsize=(7, 15))
    for ax, (title, matrix) in zip(axs, panels):
        ax.imshow(matrix)
        # Titles let each panel be identified when the notebook is skimmed.
        ax.set_title(title)

    print(linear_1['b'])
    print(linear_2['b'])

In [None]:
# Display the parameter tree captured from m1 (p1 = m1.params).
p1

In [None]:
# Run m1 on the tokens 1, 2, 3 preceded by the "compiler_bos" marker and show
# the decoded output (presumably tracr's default BOS token -- confirm).
m1.apply(["compiler_bos", 1,2,3]).decoded

In [None]:
# Run m2 on the same input as m1 so the decoded outputs can be compared.
m2.apply(["compiler_bos", 1,2,3]).decoded

In [None]:
# Inspect the encoding_map attribute of m1's output encoder.
m1.output_encoder.encoding_map

In [None]:
# Inspect the encoding_map attribute of m2's output encoder, for comparison with m1.
m2.output_encoder.encoding_map

In [None]:
# Debugging aid: list the attribute names available on the compiled model m1.
m1.__dir__()