In [2]:
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
import sys
from collections import defaultdict 
import jax
import flax
import chex
from jaxtyping import ArrayLike
from typing import Union, TypeVar
import numpy as np
import matplotlib.pyplot as plt
import traceback
import jax.numpy as jnp

from tracr.compiler.validating import validate
from tracr.rasp.rasp import Map, SequenceMap, LinearSequenceMap, Select, Aggregate, Comparison, SelectorWidth, indices, tokens 
from tracr.rasp import rasp
from tracr.compiler import compiling
from tracr.compiler.assemble import AssembledTransformerModel
from tracr.compiler.craft_model_to_transformer import NoTokensError
from tracr.compiler.basis_inference import InvalidValueSetError
from tracr.compiler import rasp_to_graph


from decompile_tracr.dataset import lib
from decompile_tracr.dataset import data_utils
from decompile_tracr.dataset import config
from decompile_tracr.tokenizing import tokenizer
from decompile_tracr.tokenizing import vocab
from decompile_tracr.sampling import sampling
from decompile_tracr.sampling import rasp_utils


rng = np.random.default_rng(0)

In [3]:
with jax.default_device(jax.devices("cpu")[0]):
    data = data_utils.load_dataset_for_model_input(
        loadfile=config.data_dir / "full.h5",
        ndata=50,
    )

2024-04-26 14:34:36 - [INFO]: Loading data from /home/lauro/projects/meta-models/decompile-tracr/data/full.h5.
2024-04-26 14:34:36 - [INFO]: load_dataset_for_model_input: All data loading and processing took 0.06s.


In [4]:
print("keys:", list(data.keys()))
print("data shapes:", {k: v.shape for k, v in data.items()})

keys: ['layer_idx', 'n_layers', 'n_sops', 'tokens', 'weights']
data shapes: {'layer_idx': (50, 25), 'n_layers': (50,), 'n_sops': (50,), 'tokens': (50, 256), 'weights': (50, 65536)}


## Tokens

In [5]:
# check for duplicates among tokens
tokens = data["tokens"]
unique_tokens = defaultdict(list)

for i, token in enumerate(tokens):
    t = tuple(token.tolist())
    unique_tokens[t].append(i)

print(f"Found {len(unique_tokens)}/{len(tokens)} unique tokenized programs "
      f"({100 * len(unique_tokens) / len(tokens):.2f}%)")

Found 50/50 unique tokenized programs (100.00%)


In [6]:
def assert_sops_sorted(toks):
    decoded = tokenizer.decode(toks)
    sops = [x for x in decoded if x.startswith("sop_") 
            and not ("tokens" in x or "indices" in x)]
    first_occurences = []
    for sop in sops:
        if sop not in first_occurences:
            first_occurences.append(sop)
    assert sorted(first_occurences) == first_occurences, first_occurences


for i, toks in enumerate(tokens):
    assert_sops_sorted(toks)

In [7]:
tokenizer.decode(tokens[0])

['BOS',
 'sop_00',
 'categorical',
 'SelectAggregate',
 'sop_10_tokens',
 'sop_10_tokens',
 'EQ',
 'sop_10_tokens',
 'EOO',
 'EOL',
 'sop_01',
 'numerical',
 'Map',
 'lambda x: x + 0.5',
 'sop_10_tokens',
 'EOO',
 'EOL',
 'sop_02',
 'categorical',
 'SelectAggregate',
 'sop_10_tokens',
 'sop_00',
 'EQ',
 'sop_00',
 'EOO',
 'EOL',
 'sop_03',
 'categorical',
 'SequenceMap',
 'lambda x, y: x * y',
 'sop_02',
 'sop_00',
 'EOO',
 'sop_04',
 'categorical',
 'SequenceMap',
 'lambda x, y: x + y % 10',
 'sop_02',
 'sop_10_tokens',
 'EOO',
 'EOL',
 'EOL',
 'sop_05',
 'numerical',
 'Map',
 'lambda x: x > 3',
 'sop_04',
 'EOO',
 'EOL',
 'sop_06',
 'numerical',
 'SelectAggregate',
 'sop_11_indices',
 'sop_03',
 'NEQ',
 'sop_05',
 'EOO',
 'EOL',
 'sop_07',
 'numerical',
 'LinearSequenceMap',
 'sop_06',
 'sop_01',
 '3',
 '1',
 'EOO',
 'EOL',
 'EOS',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 '

In [8]:
# number of non-padding tokens
print(f"Number of nonzero tokens: {(tokens > 0).sum() / tokens.size * 100:0.1f}%")

Number of nonzero tokens: 26.6%


In [9]:
# distribution of token types

# encodings
cat, num = (tokenizer.encode_token(t) for t in ["categorical", "numerical"])
n_categorical = (tokens == cat).sum()
n_numerical = (tokens == num).sum()
total = n_categorical + n_numerical

print(f"Categorical sops: {100*n_categorical/total:0.1f}%")
print(f"Numerical sops: {100*n_numerical/total:0.1f}%")
print(f"Total sops: {total:,}")

Categorical sops: 44.7%
Numerical sops: 55.3%
Total sops: 405


In [10]:
print(" ".join(tokenizer.decode(data['tokens'][2])).replace("EOL ", "EOL\n"))

BOS EOL
sop_00 categorical Map lambda x: x - 1 sop_11_indices EOO sop_01 numerical Map lambda x: x != 3 sop_11_indices EOO EOL
sop_02 categorical SelectAggregate sop_00 sop_00 EQ sop_10_tokens EOO EOL
sop_03 categorical Map lambda x: not x sop_01 EOO sop_04 numerical Map lambda x: x sop_01 EOO EOL
sop_05 numerical SelectAggregate sop_10_tokens sop_02 GT sop_01 EOO sop_06 numerical SelectAggregate sop_10_tokens sop_03 NEQ sop_04 EOO EOL
sop_07 numerical LinearSequenceMap sop_06 sop_05 -3 3 EOO EOL
EOS PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PA

In [11]:
ops = tokenizer.encode(vocab.ops)
op_counts = {vocab.vocab[op]: (tokens == op).sum() for op in ops}
total = sum(op_counts.values())

print("Operation counts:")
for op, count in op_counts.items():
    print(f"{op}: {100*count/total:.1f}%")

print(f"Total SOps: {total:,}")

Operation counts:
Map: 34.6%
SequenceMap: 18.5%
LinearSequenceMap: 16.0%
SelectAggregate: 27.7%
SelectorWidth: 3.2%
Total SOps: 405


## Programs

In [12]:
def get_test_inputs_and_outputs(
    programs: list[rasp.SOp],
    n_samples: int = 50,
):
    """Generate test inputs and pass forward through programs to get outputs."""
    test_inputs = [rasp_utils.sample_test_input(rng, max_seq_len=5, 
                                    min_seq_len=5, vocab=set(range(10))) 
                for _ in range(n_samples)]
    outputs = [[p(x) for x in test_inputs] for p in programs]
    outputs = np.array(outputs, dtype=float)
    outputs = np.nan_to_num(outputs, nan=0.0)
    return test_inputs, outputs


def test_low_var(outputs: list):
    """Test that sampled programs have a reasonable amount of variance wrt input"""
    stds = np.std(outputs, axis=1).sum(axis=-1)  # std across test inputs; sum across output sequence
    are_low_var = stds < 0.01
    frac_low_var = sum(are_low_var) / len(stds)
    print(f"{frac_low_var*100}% of programs have low variance in output.")

In [13]:
programs = [tokenizer.detokenize(t) for t in tokens]
inputs, outputs = get_test_inputs_and_outputs(programs)
test_low_var(outputs)


outputs_buffer = outputs.copy()

program_data = []
for i, p in enumerate(programs):
    program_data.append({
        "program": p,
        "outputs": outputs[i],
        "std": np.std(outputs[i], axis=0).sum(),
    })


# sort by std
by_std = sorted(program_data, key=lambda x: np.std(x['outputs']))
len(by_std)
by_std = iter(by_std)

14.000000000000002% of programs have low variance in output.


In [14]:
p = next(by_std)
print("std:", np.std(p['outputs']))
print()
print('input: ', inputs[0])
rasp_utils.print_program(p['program'], test_input=inputs[0], full=True)
print()
print('sample outputs:', p['outputs'][:10])

std: 0.0

input:  [8, 6, 5, 2, 3]
select_24 = Select(tokens, tokens, predicate=Comparison.TRUE)
select_28 = Select(tokens, tokens, predicate=Comparison.EQ)
sop_01_23 = rasp.numerical(Map(lambda x: x == 0, indices, simplify=False))    # output: [True, False, False, False, False]
sop_00_27 = rasp.categorical(Aggregate(select_28, tokens))    # output: [8, 6, 5, 2, 3]
sop_02_21 = rasp.numerical(Aggregate(select_24, sop_01_23))    # output: [0.2, 0.2, 0.2, 0.2, 0.2]
sop_03_26 = rasp.categorical(SequenceMap(lambda x, y: x + y % 10, tokens, sop_00_27))    # output: [16, 12, 10, 4, 6]
sop_04_25 = rasp.categorical(SequenceMap(lambda x, y: x - y, indices, sop_03_26))    # output: [-16, -11, -8, -1, -2]
select_22 = Select(tokens, sop_04_25, predicate=Comparison.LEQ)
sop_05_20 = rasp.numerical(Aggregate(select_22, sop_01_23))    # output: [0, 0, 0, 0, 0]
sop_06_19 = rasp.numerical(LinearSequenceMap(sop_05_20, sop_02_21, -1, 2))    # output: [0.4, 0.4, 0.4, 0.4, 0.4]
sop_07_18 = rasp.numerical(Map(

## Weights

In [15]:
# check for duplicates among tokens
weights = data["weights"]
unique = defaultdict(list)
duplicate_weights = []

for i, w in enumerate(weights):
    w = tuple(w.flatten().tolist())
    if w in unique:
        duplicate_weights.append(i)
    
    unique[w].append(i)

print(f"Found {len(unique)}/{len(weights)} unique model params "
      f"({100 * len(unique) / len(weights):.2f}%)")

Found 50/50 unique model params (100.00%)


In [20]:
print(f"percent padding: {100 * (weights == 0.05).sum() / weights.size:0.1f}%")
print(f"percent zero: {100 * (weights == 0).sum() / weights.size:0.1f}%")
print(f"left over: {100 * np.logical_and(weights != 0, weights != 0.05).sum() / weights.size:0.1f}%")

percent padding: 64.7%
percent zero: 34.8%
left over: 0.5%


## Visualize

In [21]:
def get_percentages(idx):
    w = data["weights"][idx]

    print(f"percent padding: {100 * (w == 0.05).sum() / w.size:0.1f}%")
    print(f"percent zero: {100 * (w == 0).sum() / w.size:0.1f}%")
    print(f"left over: {100 * np.logical_and(w != 0, w != 0.05).sum() / w.size:0.1f}%")


def plot_datapoint(idx):
    t = tokens[idx]
    w = data["weights"][idx]
    w = w.flatten()

    plt.plot(w, ".")
    plt.yscale("symlog", linthresh=0.1)

    print(" ".join(tokenizer.decode(t)))


def imshow_datapoint(idx):
    t = tokens[idx]
    w = data["weights"][idx]
    _, d_model = w.shape
    w = w.flatten()

    is_padding = w == 0.05
    first_padding_idx = is_padding.tolist().index(True)
    idx = first_padding_idx + (d_model - first_padding_idx % d_model)
    reshaped_w = w[:idx].reshape(-1, d_model)
    reshaped_w[reshaped_w == 0] = np.nan
    plt.imshow(reshaped_w, aspect="auto", interpolation="nearest")


get_percentages(0)

percent padding: 56.4%
percent zero: 43.0%
left over: 0.5%


In [22]:
idx = 4
#plot_datapoint(idx)

In [29]:
w = data["weights"][idx]
w = data_utils.symlog(w, linear_thresh = 2.0)
w = w.flatten()
#plt.plot(w, ".")

## Investigate duplicates

In [30]:
dupe_idx = duplicate_weights[0]
dupe_w = tuple(weights[dupe_idx].flatten().tolist())
unique[dupe_w]

IndexError: list index out of range

In [31]:
a, b = unique[dupe_w][0], unique[dupe_w][1]
print(" ".join(tokenizer.decode(data['tokens'][a])).replace("EOL ", "EOL\n").replace(" PAD", ""))
print()
print(" ".join(tokenizer.decode(data['tokens'][b])).replace("EOL ", "EOL\n").replace(" PAD", ""))

NameError: name 'dupe_w' is not defined

In [32]:
for dupe_idx in duplicate_weights:
    w = weights[dupe_idx]
    t = tokens[dupe_idx]

    duplicates = unique[tuple(w.flatten().tolist())]

    if not all([np.all(tokens[i] == tokens[dupe_idx]) for i in duplicates]):
        print(f"dupe idx: {dupe_idx}")
        print("Found duplicates with different tokens:")
        for i in duplicates:
            print(" ".join(tokenizer.decode(tokens[i])))
        print()
        print()
        print()
        print()

## Check for close duplicates

In [None]:
# from tqdm import tqdm
# 
# for w in tqdm(weights):
#     close = [np.allclose(w, u) for u in unique.values()]