In [1]:
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
import sys
from collections import defaultdict 
import jax
import flax
import chex
from jaxtyping import ArrayLike
from typing import Union, TypeVar
import numpy as np
import matplotlib.pyplot as plt
import traceback
import jax.numpy as jnp

from tracr.compiler.validating import validate
from tracr.rasp.rasp import Map, SequenceMap, LinearSequenceMap, Select, Aggregate, Comparison, SelectorWidth, indices, tokens 
from tracr.rasp import rasp
from tracr.compiler import compiling
from tracr.compiler.assemble import AssembledTransformerModel
from tracr.compiler.craft_model_to_transformer import NoTokensError
from tracr.compiler.basis_inference import InvalidValueSetError
from tracr.compiler import rasp_to_graph


from decompile_tracr.dataset import lib
from decompile_tracr.dataset import data_utils
from decompile_tracr.dataset import config
from decompile_tracr.tokenizing import tokenizer
from decompile_tracr.tokenizing import vocab
from decompile_tracr.sampling import sampling
from decompile_tracr.sampling import rasp_utils


rng = np.random.default_rng(0)

In [2]:
VAL_DATA_RATIO = 0.1
MAX_RASP_LENGTH = config.MAX_RASP_LENGTH
MAX_WEIGHTS_LENGTH = config.MAX_WEIGHTS_LENGTH
FULL_DATA_DIR = config.full_dataset_dir
ALL_LAYERS_MULTIPLIER = 15
split_layers = False

In [3]:
data = data_utils.load_and_process_data(
    rng=rng,
    loaddir=FULL_DATA_DIR,
    max_data=5000,
    shuffle=True,
    name="train",
    d_model=128,
    max_rasp_len=MAX_RASP_LENGTH if split_layers else MAX_RASP_LENGTH * ALL_LAYERS_MULTIPLIER,
    max_weights_len=MAX_WEIGHTS_LENGTH if split_layers else MAX_WEIGHTS_LENGTH * ALL_LAYERS_MULTIPLIER,
    split_layers=split_layers,
)

2024-03-25 15:06:12 - [INFO]: Loading data from /home/lauro/projects/meta-models/decompile-tracr/data/full.


2024-03-25 15:06:15 - [INFO]: load_batches: Loaded 5037 >= 5000 datapoints. Stopping and truncating to 5000.


2024-03-25 15:06:17.132942: W external/xla/xla/service/gpu/nvptx_compiler.cc:698] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.3.107). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [None]:
print("keys:", data.keys())
print("data shapes:", {k: v.shape for k, v in data.items()})

keys: dict_keys(['tokens', 'weights', 'program_id', 'n_sops', 'n_layers'])
data shapes: {'tokens': (96, 480), 'weights': (96, 960, 128), 'program_id': (96,), 'n_sops': (96,), 'n_layers': (96,)}


## Tokens

In [None]:
# check for duplicates among tokens
tokens = data["tokens"]
unique_tokens = defaultdict(list)

for i, token in enumerate(tokens):
    t = tuple(token.tolist())
    unique_tokens[t].append(i)

print(f"Found {len(unique_tokens)}/{len(tokens)} unique tokens "
      f"({100 * len(unique_tokens) / len(tokens):.2f}%)")

Found 96/96 unique tokens (100.00%)


In [None]:
# number of non-padding tokens
print(f"Number of nonzero tokens: {(tokens > 0).sum() / tokens.size * 100:0.1f}%")

Number of nonzero tokens: 6.3%


In [None]:
# distribution of token types

# encodings
cat, num = tokenizer.encode(["categorical", "numerical"])
n_categorical = (tokens == cat).sum()
n_numerical = (tokens == num).sum()
total = n_categorical + n_numerical

print(f"Categorical sops: {100*n_categorical/total:0.1f}%")
print(f"Numerical sops: {100*n_numerical/total:0.1f}%")
print(f"Total sops: {total:,}")

Categorical sops: 84.3%
Numerical sops: 15.7%
Total sops: 306


In [None]:
ops = tokenizer.encode(vocab.ops)
op_counts = {vocab.vocab[op]: (tokens == op).sum() for op in ops}
total = sum(op_counts.values())

print("Operation counts:")
for op, count in op_counts.items():
    print(f"{op}: {100*count/total:.1f}%")

print(f"Total SOps: {total:,}")

Operation counts:
Map: 14.7%
SequenceMap: 30.4%
LinearSequenceMap: 1.0%
SelectAggregate: 18.3%
SelectorWidth: 35.6%
Total SOps: 306


## Weights

In [None]:
# check for duplicates among tokens
weights = data["weights"]
unique = defaultdict(list)
duplicate_weights = []

for i, w in enumerate(weights):
    w = tuple(w.flatten().tolist())
    if w in unique:
        duplicate_weights.append(i)
    
    unique[w].append(i)

print(f"Found {len(unique)}/{len(weights)} unique model params "
      f"({100 * len(unique) / len(weights):.2f}%)")

Found 96/96 unique model params (100.00%)


In [None]:
print(f"percent padding: {100 * (weights == 0.05).sum() / weights.size:0.1f}%")
print(f"percent zero: {100 * (weights == 0).sum() / weights.size:0.1f}%")
print(f"left over: {100 * np.logical_and(weights != 0, weights != 0.05).sum() / weights.size:0.1f}%")

percent padding: 93.7%
percent zero: 6.1%
left over: 0.2%


## Visualize

In [None]:
def get_percentages(idx):
    w = data["weights"][idx]

    print(f"percent padding: {100 * (w == 0.05).sum() / w.size:0.1f}%")
    print(f"percent zero: {100 * (w == 0).sum() / w.size:0.1f}%")
    print(f"left over: {100 * np.logical_and(w != 0, w != 0.05).sum() / w.size:0.1f}%")


def plot_datapoint(idx):
    t = tokens[idx]
    w = data["weights"][idx]
    w = w.flatten()

    plt.plot(w, ".")
    plt.yscale("symlog", linthresh=0.1)

    print(" ".join(tokenizer.decode(t)))


def imshow_datapoint(idx):
    t = tokens[idx]
    w = data["weights"][idx]
    _, d_model = w.shape
    w = w.flatten()

    is_padding = w == 0.05
    first_padding_idx = is_padding.tolist().index(True)
    idx = first_padding_idx + (d_model - first_padding_idx % d_model)
    reshaped_w = w[:idx].reshape(-1, d_model)
    reshaped_w[reshaped_w == 0] = np.nan
    plt.imshow(reshaped_w, aspect="auto", interpolation="nearest")


get_percentages(0)

percent padding: 93.1%
percent zero: 6.7%
left over: 0.2%


In [None]:
idx = 4
#plot_datapoint(idx)

## Investigate duplicates

In [None]:
for dupe_idx in duplicate_weights:
    w = weights[dupe_idx]
    t = tokens[dupe_idx]

    duplicates = unique[tuple(w.flatten().tolist())]

    if not all([np.all(tokens[i] == tokens[dupe_idx]) for i in duplicates]):
        print(f"dupe idx: {dupe_idx}")
        print("Found duplicates with different tokens:")
        for i in duplicates:
            print(" ".join(tokenizer.decode(tokens[i])))
        print()
        print()
        print()
        print()

## Check for close duplicates

In [None]:
from tqdm import tqdm

for w in tqdm(weights):
    close = [np.allclose(w, u) for u in unique.values()]

100%|██████████| 96/96 [00:04<00:00, 20.13it/s]
