In [1]:
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
import sys
from collections import defaultdict 
import jax
import flax
import chex
from jaxtyping import ArrayLike
from typing import Union, TypeVar
import numpy as np
import matplotlib.pyplot as plt
import traceback
import jax.numpy as jnp

from tracr.compiler.validating import validate
from tracr.rasp.rasp import Map, SequenceMap, LinearSequenceMap, Select, Aggregate, Comparison, SelectorWidth, indices, tokens 
from tracr.rasp import rasp
from tracr.compiler import compiling
from tracr.compiler.assemble import AssembledTransformerModel
from tracr.compiler.craft_model_to_transformer import NoTokensError
from tracr.compiler.basis_inference import InvalidValueSetError
from tracr.compiler import rasp_to_graph


from decompile_tracr.dataset import lib
from decompile_tracr.dataset import data_utils
from decompile_tracr.dataset import config
from decompile_tracr.tokenizing import tokenizer
from decompile_tracr.tokenizing import vocab
from decompile_tracr.sampling import sampling
from decompile_tracr.sampling import rasp_utils


rng = np.random.default_rng(0)

In [2]:
VAL_DATA_RATIO = 0.1
MAX_RASP_LENGTH = config.MAX_RASP_LENGTH
MAX_WEIGHTS_LENGTH = config.MAX_WEIGHTS_LENGTH
FULL_DATA_DIR = config.full_dataset_dir
ALL_LAYERS_MULTIPLIER = 15
split_layers = False

In [3]:
data = data_utils.load_dataset_for_model_input(
    rng=rng,
    loaddir=FULL_DATA_DIR,
    max_data=1000,
    shuffle=True,
    d_model=128,
    max_rasp_len=MAX_RASP_LENGTH if split_layers else MAX_RASP_LENGTH * ALL_LAYERS_MULTIPLIER,
    max_weights_len=MAX_WEIGHTS_LENGTH if split_layers else MAX_WEIGHTS_LENGTH * ALL_LAYERS_MULTIPLIER,
    split_layers=split_layers,
)

2024-04-12 12:11:19 - [INFO]: Loading data from /home/lauro/projects/meta-models/decompile-tracr/data/full.


2024-04-12 12:11:19 - [INFO]: load_batches: Loaded 1079 >= 1000 datapoints. Stopping and truncating to 1000 datapoints.


2024-04-12 12:11:19.821075: E external/xla/xla/stream_executor/cuda/cuda_driver.cc:282] failed call to cuInit: CUDA_ERROR_COMPAT_NOT_SUPPORTED_ON_DEVICE: forward compatibility was attempted on non supported HW
2024-04-12 12:11:19.821160: E external/xla/xla/stream_executor/cuda/cuda_diagnostics.cc:256] kernel version 535.161.7 does not match DSO version 535.171.4 -- cannot find working devices in this configuration
CUDA backend failed to initialize: FAILED_PRECONDITION: No visible GPU devices. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [4]:
print("keys:", list(data.keys()))
print("data shapes:", {k: v.shape for k, v in data.items()})

keys: ['tokens', 'weights', 'n_sops', 'program_id', 'n_layers']
data shapes: {'tokens': (1000, 1920), 'weights': (1000, 1920, 128), 'n_sops': (1000,), 'program_id': (1000,), 'n_layers': (1000,)}


## Tokens

In [5]:
# check for duplicates among tokens
tokens = data["tokens"]
unique_tokens = defaultdict(list)

for i, token in enumerate(tokens):
    t = tuple(token.tolist())
    unique_tokens[t].append(i)

print(f"Found {len(unique_tokens)}/{len(tokens)} unique tokens "
      f"({100 * len(unique_tokens) / len(tokens):.2f}%)")

Found 1000/1000 unique tokens (100.00%)


In [6]:
# number of non-padding tokens
print(f"Number of nonzero tokens: {(tokens > 0).sum() / tokens.size * 100:0.1f}%")

Number of nonzero tokens: 1.5%


In [7]:
# distribution of token types

# encodings
cat, num = (tokenizer.encode(t) for t in ["categorical", "numerical"])
n_categorical = (tokens == cat).sum()
n_numerical = (tokens == num).sum()
total = n_categorical + n_numerical

print(f"Categorical sops: {100*n_categorical/total:0.1f}%")
print(f"Numerical sops: {100*n_numerical/total:0.1f}%")
print(f"Total sops: {total:,}")

Categorical sops: 64.7%
Numerical sops: 35.3%
Total sops: 3,223


In [8]:
ops = (tokenizer.encode(t) for t in vocab.ops)
op_counts = {vocab.vocab[op]: (tokens == op).sum() for op in ops}
total = sum(op_counts.values())

print("Operation counts:")
for op, count in op_counts.items():
    print(f"{op}: {100*count/total:.1f}%")

print(f"Total SOps: {total:,}")

Operation counts:
Map: 33.3%
SequenceMap: 31.1%
LinearSequenceMap: 3.6%
SelectAggregate: 29.8%
SelectorWidth: 2.2%
Total SOps: 3,223


## Programs

In [9]:
def get_test_inputs_and_outputs(
    programs: list[rasp.SOp],
    n_samples: int = 50,
):
    """Generate test inputs and pass forward through programs to get outputs."""
    test_inputs = [rasp_utils.sample_test_input(rng, max_seq_len=5, 
                                    min_seq_len=5, vocab=set(range(10))) 
                for _ in range(n_samples)]
    outputs = [[p(x) for x in test_inputs] for p in programs]
    outputs = np.array(outputs, dtype=float)
    outputs = np.nan_to_num(outputs, nan=0.0)
    return test_inputs, outputs


def test_low_var(outputs: list):
    """Test that sampled programs have a reasonable amount of variance wrt input"""
    stds = np.std(outputs, axis=1).sum(axis=-1)  # std across test inputs; sum across output sequence
    are_low_var = stds < 0.01
    frac_low_var = sum(are_low_var) / len(stds)
    print(f"{frac_low_var*100}% of programs have low variance in output.")

In [10]:
programs = [tokenizer.detokenize(t) for t in tokens]
inputs, outputs = get_test_inputs_and_outputs(programs)
test_low_var(outputs)


outputs_buffer = outputs.copy()

program_data = []
for i, p in enumerate(programs):
    program_data.append({
        "program": p,
        "outputs": outputs[i],
        "std": np.std(outputs[i], axis=0).sum(),
    })


# sort by std
by_std = sorted(program_data, key=lambda x: np.std(x['outputs']))
len(by_std)
by_std = iter(by_std)

6.4% of programs have low variance in output.


In [11]:
p = next(by_std)
print("std:", np.std(p['outputs']))
print()
print('input: ', inputs[0])
rasp_utils.print_program(p['program'], test_input=inputs[0], full=True)
print()
print('sample outputs:', p['outputs'][:10])

std: 0.0

input:  [9, 3, 2, 4, 1]
sop_0_20 = rasp.numerical(Map(lambda x: x == 1, indices))    # output: [False, True, False, False, False]
sop_1_21 = rasp.categorical(SequenceMap(lambda x, y: x - y, tokens, indices))    # output: [9, 2, 0, 1, -3]
select_19 = Select(sop_1_21, tokens, predicate=Comparison.TRUE)
sop_2_18 = rasp.numerical(Aggregate(select_19, sop_0_20))    # output: [0.2, 0.2, 0.2, 0.2, 0.2]

sample outputs: [[0.2 0.2 0.2 0.2 0.2]
 [0.2 0.2 0.2 0.2 0.2]
 [0.2 0.2 0.2 0.2 0.2]
 [0.2 0.2 0.2 0.2 0.2]
 [0.2 0.2 0.2 0.2 0.2]
 [0.2 0.2 0.2 0.2 0.2]
 [0.2 0.2 0.2 0.2 0.2]
 [0.2 0.2 0.2 0.2 0.2]
 [0.2 0.2 0.2 0.2 0.2]
 [0.2 0.2 0.2 0.2 0.2]]


## Weights

In [12]:
# check for duplicates among tokens
weights = data["weights"]
unique = defaultdict(list)
duplicate_weights = []

for i, w in enumerate(weights[:300]):
    w = tuple(w.flatten().tolist())
    if w in unique:
        duplicate_weights.append(i)
    
    unique[w].append(i)

print(f"Found {len(unique)}/{len(weights)} unique model params "
      f"({100 * len(unique) / len(weights):.2f}%)")

Found 300/1000 unique model params (30.00%)


In [13]:
print(f"percent padding: {100 * (weights == 0.05).sum() / weights.size:0.1f}%")
print(f"percent zero: {100 * (weights == 0).sum() / weights.size:0.1f}%")
print(f"left over: {100 * np.logical_and(weights != 0, weights != 0.05).sum() / weights.size:0.1f}%")

percent padding: 97.6%
percent zero: 2.3%
left over: 0.1%


## Visualize

In [14]:
def get_percentages(idx):
    w = data["weights"][idx]

    print(f"percent padding: {100 * (w == 0.05).sum() / w.size:0.1f}%")
    print(f"percent zero: {100 * (w == 0).sum() / w.size:0.1f}%")
    print(f"left over: {100 * np.logical_and(w != 0, w != 0.05).sum() / w.size:0.1f}%")


def plot_datapoint(idx):
    t = tokens[idx]
    w = data["weights"][idx]
    w = w.flatten()

    plt.plot(w, ".")
    plt.yscale("symlog", linthresh=0.1)

    print(" ".join(tokenizer.decode(t)))


def imshow_datapoint(idx):
    t = tokens[idx]
    w = data["weights"][idx]
    _, d_model = w.shape
    w = w.flatten()

    is_padding = w == 0.05
    first_padding_idx = is_padding.tolist().index(True)
    idx = first_padding_idx + (d_model - first_padding_idx % d_model)
    reshaped_w = w[:idx].reshape(-1, d_model)
    reshaped_w[reshaped_w == 0] = np.nan
    plt.imshow(reshaped_w, aspect="auto", interpolation="nearest")


get_percentages(0)

percent padding: 98.0%
percent zero: 1.9%
left over: 0.1%


In [15]:
idx = 4
#plot_datapoint(idx)

## Investigate duplicates

In [16]:
for dupe_idx in duplicate_weights:
    w = weights[dupe_idx]
    t = tokens[dupe_idx]

    duplicates = unique[tuple(w.flatten().tolist())]

    if not all([np.all(tokens[i] == tokens[dupe_idx]) for i in duplicates]):
        print(f"dupe idx: {dupe_idx}")
        print("Found duplicates with different tokens:")
        for i in duplicates:
            print(" ".join((tokenizer.decode(t) for t in tokens[i])))
        print()
        print()
        print()
        print()

## Check for close duplicates

In [17]:
# from tqdm import tqdm
# 
# for w in tqdm(weights):
#     close = [np.allclose(w, u) for u in unique.values()]