In [1]:
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
from pathlib import Path
import time
import sys
from collections import defaultdict 
from typing import Union, TypeVar
import h5py
import traceback

import pandas as pd
import tqdm
import jax
import flax
import chex
from jaxtyping import ArrayLike
import numpy as np
import matplotlib.pyplot as plt
import jax.numpy as jnp
import haiku as hk
import flax.linen as nn
import optax
import einops

from tracr.compiler.validating import validate
from tracr.rasp.rasp import Map, SequenceMap, LinearSequenceMap, Select, Aggregate, Comparison, SelectorWidth, indices, tokens 
from tracr.rasp import rasp
from tracr.compiler import compiling
from tracr.compiler.assemble import AssembledTransformerModel
from tracr.compiler.craft_model_to_transformer import NoTokensError
from tracr.compiler.basis_inference import InvalidValueSetError
from tracr.compiler import craft_graph_to_model
from tracr.compiler import rasp_to_graph
from tracr.compiler import lib as tracr_lib
from tracr.compiler import assemble
from tracr.transformer import model
from tracr.transformer.model import CompiledTransformerModel
from tracr.transformer.encoder import CategoricalEncoder
from tracr.compiler.assemble import AssembledTransformerModel

from decompile_tracr.dataset import lib
from decompile_tracr.dataset import data_utils
from decompile_tracr.dataset import dataloading
from decompile_tracr.dataset import config
from decompile_tracr.dataset import compile as comp
from decompile_tracr.tokenizing import tokenizer
from decompile_tracr.tokenizing import vocab
from decompile_tracr.sampling import sampling
from decompile_tracr.sampling import rasp_utils
from decompile_tracr.sampling.map_primitives import FunctionWithRepr
from decompile_tracr.training.autoencoder import get_residuals_sampler
from decompile_tracr.training import autoencoder
from decompile_tracr.training import transformer
from decompile_tracr.training.transformer import Residuals

from metamodels_for_rasp.train import Updater, TrainState


def _compile(program):
    return compiling.compile_rasp_to_model(
        program,
        vocab=set(range(5)),
        max_seq_len=5,
    )


rng = np.random.default_rng(0)
key = jax.random.key(0)

PLOT = False

2024-05-13 15:30:15.208262: W external/xla/xla/service/gpu/nvptx_compiler.cc:718] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.3.107). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [2]:
program_toks = tokenizer.tokenize(sampling.sample(
    rng, program_length=5, only_categorical=True))
assembled_model = comp.compile_tokens_to_model(program_toks)

m = assembled_model

def sample_tokens(key: jax.random.PRNGKey):
    """Utility function to sample 
    a random sequence of input tokens.
    """
    batch_size = 1
    seq_len = 6
    bos: int = assembled_model.input_encoder.bos_encoding
    inputs = jax.random.randint(key, (batch_size, seq_len-1), 0, 5)
    inputs = jnp.concatenate(
        [bos * jnp.ones((batch_size, 1), dtype=int), inputs], axis=1)
    return inputs


@hk.without_apply_rng
@hk.transform
def _embed(tokens):
    compiled_model = assembled_model.get_compiled_model()
    return compiled_model.embed(tokens)


def embed(tokens: jnp.ndarray):
    return _embed.apply(assembled_model.params, tokens)


@hk.without_apply_rng
@hk.transform
def _unembed(x):
    cm = assembled_model.get_compiled_model()
    return cm.unembed(x, use_unembed_argmax=cm.use_unembed_argmax)

def unembed(x):
    return _unembed.apply(assembled_model.params, x)

In [3]:
t = time.time()
key, subkey = jax.random.split(key)
ae_state, ae_log, ae_model = autoencoder.train_autoencoder(
    subkey, assembled_model, nsteps=10000)
print(f'training autoencoder took {time.time() - t:.2f}s')

training autoencoder took 8.81s


In [4]:
if PLOT:
    plt.plot([x['train/loss'] for x in ae_log])
    plt.yscale('log')
    plt.xscale('log')
    print('final loss:', ae_log[-1]['train/loss'])

In [5]:
# compare to original
inputs = [0, 0, 4, 1, 2]
x = np.array([assembled_model.input_encoder.bos_encoding] + inputs)
x = np.expand_dims(x, 0)
assembled_out = assembled_model.apply(['compiler_bos'] + inputs)

original = np.concatenate(assembled_out.residuals)
decoded = ae_model.apply({'params': ae_state.params}, original)
decoded = np.array(decoded, dtype=np.float32)

In [6]:
def get_range(*arrays):
    all = np.concatenate(arrays)
    return all.min(), all.max()

In [8]:
if PLOT:
    _min, _max = get_range(original, decoded)

    fig, axs = plt.subplots(len(original), 2, figsize=[10, 10])

    axs[0, 0].set_title('Original')
    for i, res in enumerate(original):
        im = axs[i, 0].imshow(res, vmin=_min, vmax=_max)
        axs[i, 0].set_ylabel(f'Layer {i}')

    axs[0, 1].set_title('Decoded')
    for i, res in enumerate(decoded):
        im = axs[i, 1].imshow(res, vmin=_min, vmax=_max)

    for x in axs.flatten():
        x.set_xticks(np.arange(original.shape[-1], step=2))
        x.set_xticklabels(np.arange(original.shape[-1], step=2))

    fig.colorbar(im, ax=axs, orientation='horizontal')

In [9]:
if PLOT:
    plt.imshow(original - decoded)
    plt.colorbar(orientation='horizontal')

In [10]:
def decode_layer(x: ArrayLike):
    unembedded = np.squeeze(unembed(x))
    unembedded = unembedded.tolist()
    print("unembedded:", unembedded)
    tokens = assembled_model.output_encoder.decode(unembedded)
    return ['compiler_bos'] + tokens[1:]

In [11]:
assembled_out.decoded

['compiler_bos', 3, 3, 2, 2, 2]

In [13]:
original.shape

(8, 6, 37)

In [14]:
# print(decode_layer(original))
# print(decode_layer(decoded))

In [15]:
# layer_idx = 0
# decoded_residuals = ae_model.apply(
#     {'params': ae_state.params}, assembled_out.residuals[layer_idx])
# 
# print(decode_layer(assembled_out.residuals[layer_idx]))
# print(decode_layer(decoded_residuals))

In [16]:
# x = np.squeeze(decoded_residuals).astype(np.float32)
# y = np.squeeze(assembled_out.residuals[layer_idx])

# fig, axs = plt.subplots(2, 1)
# axs[0].imshow(x)
# axs[1].imshow(y)

# Per layer training

In [17]:
def encode(x):
    return ae_model.apply({'params': ae_state.params}, x, method=ae_model.encode)

In [18]:
get_batch = transformer.DataGenerator(
    assembled_model=assembled_model,
    encode=encode,
    batch_size=32,
    seq_len=6,
)

model, state, log = transformer.train_transformer(
    subkey, 
    get_batch=get_batch, 
    args=transformer.TransformerTrainingArgs(
        nsteps=1000,
        learning_rate=1e-3,
    ),
)

In [19]:
if PLOT:
    plt.plot([x['train/loss'] for x in log])
    plt.yscale('log')
    plt.xscale('log')
    print('final loss:', log[-1]['train/loss'])

## Metrics

In [20]:
m = assembled_model

def sample_tokens(key: jax.random.PRNGKey):
    """Utility function to sample 
    a random sequence of input tokens.
    """
    batch_size = 1
    seq_len = 6
    bos: int = assembled_model.input_encoder.bos_encoding
    inputs = jax.random.randint(key, (batch_size, seq_len-1), 0, 5)
    inputs = jnp.concatenate(
        [bos * jnp.ones((batch_size, 1), dtype=int), inputs], axis=1)
    return inputs


@hk.without_apply_rng
@hk.transform
def _embed(tokens):
    compiled_model = assembled_model.get_compiled_model()
    return compiled_model.embed(tokens)


def embed(tokens: jnp.ndarray):
    return _embed.apply(assembled_model.params, tokens)


@hk.without_apply_rng
@hk.transform
def _unembed(x):
    cm = assembled_model.get_compiled_model()
    return cm.unembed(x, use_unembed_argmax=cm.use_unembed_argmax)

def unembed(x):
    return _unembed.apply(assembled_model.params, x)

In [21]:
def accuracy(x, y):
    x, y = unembed(x), unembed(y)
    return jnp.mean(x == y)

In [22]:
get_residuals = get_residuals_sampler(
    assembled_model,
    seq_len=6,
    batch_size=1,
    flatten_leading_axes=False,
)

In [23]:
key, subkey = jax.random.split(key)
res = get_residuals(subkey)
decoded = ae_model.apply({'params': ae_state.params}, res.residuals)

In [24]:
assert isinstance(m.output_encoder, CategoricalEncoder)

In [25]:
unembed(decoded)

Array([[[2, 1, 2, 5, 2, 4],
        [2, 1, 2, 5, 2, 4],
        [2, 3, 3, 3, 3, 3],
        [2, 3, 3, 3, 3, 3],
        [2, 1, 2, 1, 2, 1],
        [2, 1, 2, 1, 2, 1],
        [2, 1, 2, 2, 2, 1],
        [2, 1, 2, 2, 2, 1],
        [2, 1, 2, 2, 2, 2]]], dtype=int32)

In [26]:
unembed(res.residuals)

Array([[[0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 1, 2, 2, 2, 2]]], dtype=int32)

In [28]:
accuracy(res.residuals, decoded)

Array(0.0925926, dtype=float32)