In [None]:
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
from pathlib import Path
import time
import sys
from collections import defaultdict 
from typing import Union, TypeVar
import h5py
import traceback

import pandas as pd
import tqdm
import jax
import flax
import chex
from jaxtyping import ArrayLike
import numpy as np
import matplotlib.pyplot as plt
import jax.numpy as jnp
import haiku as hk
import flax.linen as nn
import optax
import einops

from tracr.compiler.validating import validate
from tracr.rasp.rasp import Map, SequenceMap, LinearSequenceMap, Select, Aggregate, Comparison, SelectorWidth, indices, tokens 
from tracr.rasp import rasp
from tracr.compiler import compiling
from tracr.compiler.assemble import AssembledTransformerModel
from tracr.compiler.craft_model_to_transformer import NoTokensError
from tracr.compiler.basis_inference import InvalidValueSetError
from tracr.compiler import craft_graph_to_model
from tracr.compiler import rasp_to_graph
from tracr.compiler import lib as tracr_lib
from tracr.compiler import assemble
from tracr.transformer import model
from tracr.transformer.model import CompiledTransformerModel
from tracr.transformer.encoder import CategoricalEncoder
from tracr.compiler.assemble import AssembledTransformerModel

from decompile_tracr.dataset import lib
from decompile_tracr.dataset import data_utils
from decompile_tracr.dataset import dataloading
from decompile_tracr.dataset import config
from decompile_tracr.dataset import compile as comp
from decompile_tracr.tokenizing import tokenizer
from decompile_tracr.tokenizing import vocab
from decompile_tracr.sampling import sampling
from decompile_tracr.sampling import rasp_utils
from decompile_tracr.sampling.map_primitives import FunctionWithRepr
from decompile_tracr.training import autoencoder
from decompile_tracr.training import transformer
from decompile_tracr.training.transformer import Residuals
from decompile_tracr.training.metrics import Accuracy, Embed, Unembed, Decode

from metamodels_for_rasp.train import Updater, TrainState


def _compile(program):
    """Compile `program` with tracr using the vocab {0..4} and max_seq_len 5."""
    token_vocab = set(range(5))
    return compiling.compile_rasp_to_model(
        program, vocab=token_vocab, max_seq_len=5)


# One seed shared by the NumPy generator and the JAX PRNG key.
SEED = 0
rng = np.random.default_rng(SEED)
key = jax.random.key(SEED)

# Switch for the plotting cells in this notebook.
PLOT = True

In [None]:
# for _ in range(1000):
#     try:
#         program_toks = tokenizer.tokenize(sampling.sample(
#             rng, program_length=5, only_categorical=True))
#         assembled_model = comp.compile_tokens_to_model(program_toks)
#         d_model = assembled_model.params['token_embed']['embeddings'].shape[-1]
#         ds.append(d_model)
#     except:
#         continue
# ds = np.array(ds)
# 
# plt.hist(ds)
# 
# d_model = 25
# print(np.mean(ds), np.std(ds))
# print("Frac too small: ", np.mean(ds < d_model))
# print("Frac too large: ", np.mean(ds > 1.5*d_model))

In [None]:

# Sample one program (program_length=5, only_categorical=True) and tokenize it.
sampled_program = sampling.sample(
    rng, program_length=5, only_categorical=True)
program_toks = tokenizer.tokenize(sampled_program)

# Compile the tokens to a model and read the width of its token embeddings.
assembled_model = comp.compile_tokens_to_model(program_toks)
token_embeddings = assembled_model.params['token_embed']['embeddings']
d_model = token_embeddings.shape[-1]
print(d_model)

residuals_sampler = autoencoder.ResidualsSampler(
    model=assembled_model, seq_len=6, batch_size=2**12,
    flatten_leading_axes=False)

# Metric helpers bound to the compiled model.
# NOTE(review): `decode` is rebound further down by a plain function that
# calls the autoencoder; after that cell runs, this Decode instance is no
# longer reachable under this name.
embed = Embed(assembled_model=assembled_model)
unembed = Unembed(assembled_model=assembled_model)
accuracy = Accuracy(assembled_model=assembled_model)
decode = Decode(assembled_model=assembled_model)

In [None]:
# Train the autoencoder on a fresh PRNG subkey and report wall-clock time.
t = time.time()
key, subkey = jax.random.split(key)
ae_state, ae_log, ae_model = autoencoder.train_autoencoder(
    subkey,
    assembled_model,
    nsteps=50_000,
    lr=1e-3,
    hidden_size=50,
)
print(f'training autoencoder took {time.time() - t:.2f}s')

In [None]:
if PLOT:
    # Training-loss curve of the autoencoder on log-log axes.
    ae_train_losses = [entry['train/loss'] for entry in ae_log]
    plt.plot(ae_train_losses)
    plt.xscale('log')
    plt.yscale('log')
    print('final loss:', ae_train_losses[-1])

In [None]:
# compare to original
# Sample a batch of residuals with a fresh subkey.
key, subkey = jax.random.split(key)
test_data = residuals_sampler.sample_residuals(subkey)

#x = np.array([assembled_model.input_encoder.bos_encoding] + inputs)
#x = np.expand_dims(x, 0)
#assembled_out = assembled_model.apply(['compiler_bos'] + inputs)

original = np.squeeze(np.array(test_data.residuals))
reconstruction = ae_model.apply({'params': ae_state.params}, original)
# Round to the nearest integer, then store as an integer NumPy array.
decoded = np.array(np.round(reconstruction, 0), dtype=int)

In [None]:
# Accuracy between original and decoded residuals over the whole batch,
# taken at index -1 of axis 1 (the axis the plots below label "Layer").
ae_acc = accuracy(original[:, -1], decoded[:, -1])
print(ae_acc)

In [None]:
# Show the hidden and output sizes of the trained autoencoder.
print(ae_model.hidden_size)
print(ae_model.output_size)

In [None]:
def get_range(*arrays):
    """Return the global ``(min, max)`` over all elements of ``arrays``.

    Each input is flattened before concatenation, so the inputs may differ
    in shape or rank.

    Args:
        *arrays: one or more array-likes.

    Returns:
        Tuple ``(minimum, maximum)`` taken over every element of every input.

    Raises:
        ValueError: if no arrays are passed (raised by ``np.concatenate``).
    """
    # Named `flat` rather than `all` so the builtin `all` is not shadowed.
    flat = np.concatenate([np.ravel(a) for a in arrays])
    return flat.min(), flat.max()

In [None]:
if PLOT:
    # First example of the batch: original vs. autoencoder output.
    x, y = original[0], decoded[0]
    _min, _max = get_range(x, y)

    fig, axs = plt.subplots(len(x), 2, figsize=[10, 10])
    axs[0, 0].set_title('Original')
    axs[0, 1].set_title('Decoded')

    # Column 0 shows the original, column 1 the decoded version; both share
    # the same color limits so a single colorbar applies to every panel.
    for col, layers in enumerate((x, y)):
        for i, res in enumerate(layers):
            im = axs[i, col].imshow(res, vmin=_min, vmax=_max)
            if col == 0:
                axs[i, col].set_ylabel(f'Layer {i}')

    xticks = np.arange(x.shape[-1], step=2)
    for ax in axs.flatten():
        ax.set_xticks(xticks)
        ax.set_xticklabels(xticks)

    fig.colorbar(im, ax=axs, orientation='horizontal')

In [None]:
if PLOT:
    # Element-wise difference at the last index of the first axis of x and y.
    last_layer_diff = x[-1] - y[-1]
    plt.imshow(last_layer_diff)
    plt.colorbar(orientation='horizontal')

In [None]:
# Accuracy for the first example only. Stored under its own name so the
# batch-level `ae_acc` computed earlier is not overwritten; the final cell
# prints `ae_acc` next to a batch-level transformer accuracy.
ae_acc_first = accuracy(original[0, -1], decoded[0, -1])
print(ae_acc_first)

In [None]:
# Display the first entry of the sampled test inputs.
test_data.inputs[0]

In [None]:
# Display the first example's residual at index -1 of the second axis.
original[0, -1]

In [None]:
# Run the compiled model on "compiler_bos" followed by 5 random ints in [0, 5).
random_input = rng.integers(0, 5, size=(5,)).tolist()
out = assembled_model.apply(["compiler_bos"] + random_input)
print(out.decoded)


In [None]:
# True iff every element of the unembedded batch (index -1 of axis 1) equals 4.
np.all(unembed(original[:, -1]) == 4)

In [None]:
# Element-wise reconstruction error for the first example at index -1.
original[0, -1] - decoded[0, -1]

In [None]:
# Rebinds the notebook-global `x` (also assigned in the plotting cell above)
# to the first example's residual at index -1, then shows its shape.
x = original[0, -1]
x.shape

In [None]:
# Unembed `x` after rounding it to zero decimals.
unembed(np.round(x, 0))

In [None]:
# Unembed `x` after adding normal noise scaled by 1e-5.
unembed(x + rng.normal(size=x.shape) * 1e-5)

In [None]:
# Unembed the rounded autoencoder output for the first example at index -1.
unembed(decoded[0][-1])

In [None]:
def encode(x):
    """Apply the trained autoencoder's `encode` method to `x`."""
    variables = {'params': ae_state.params}
    return ae_model.apply(variables, x, method=ae_model.encode)


# NOTE(review): this rebinds the name `decode`, which an earlier cell assigned
# to a `Decode(assembled_model=...)` instance.
def decode(x):
    """Apply the trained autoencoder's `decode` method to `x`."""
    variables = {'params': ae_state.params}
    return ae_model.apply(variables, x, method=ae_model.decode)

In [None]:
# Encode the first example's residual at index -1 and cast it to float32.
enc = encode(original[0, -1]).astype(np.float32)

In [None]:
# Guarded by PLOT like the other plotting cells, so turning the flag off
# disables every figure in the notebook.
if PLOT:
    plt.imshow(enc)
    plt.colorbar(orientation='horizontal')

# Per layer training (Ignore)

In [None]:
# Deliberate stop so that "Run All" does not continue into the cells below.
# A bare `raise` outside an exception handler also produces a RuntimeError,
# but with an uninformative message; state the reason explicitly instead.
raise RuntimeError(
    "Stopping here on purpose: the 'Per layer training' cells below are "
    "meant to be ignored.")

In [None]:
# Scratch calculation; the value is displayed but not assigned.
2**15

In [None]:
# Draw a fresh subkey for transformer training. The previous `subkey` was
# already consumed when sampling `test_data`; reusing it would correlate the
# two random streams.
key, subkey = jax.random.split(key)

# Batches of data derived from `assembled_model`, passed through the
# autoencoder's `encode` function.
get_batch = transformer.DataGenerator(
    assembled_model=assembled_model,
    encode=encode,
    batch_size=2**13,
    seq_len=6,
)

# NOTE(review): binding the result to `model` shadows the module imported at
# the top of the notebook via `from tracr.transformer import model`.
model, state, log = transformer.train_transformer(
    subkey,
    get_batch=get_batch,
    args=transformer.TransformerTrainingArgs(
        nsteps=50_000,
        learning_rate=1e-3,
    ),
)

In [None]:
if PLOT:
    # One curve per metric key found in the first log entry.
    for k in log[0].keys():
        # The legend label is the key without its first 11 characters.
        label = k[11:]
        series = [entry[k] for entry in log]
        plt.plot(series, label=label)
        print(f'Final loss at {label}:', series[-1])
    plt.xscale('log')
    plt.yscale('log')
    plt.legend()

In [None]:
# Run the trained transformer on the test inputs; the second return value is
# converted to a plain dict and its keys are displayed.
_, acts = model.apply({'params': state.params}, test_data.inputs)
acts = dict(acts)
acts.keys()

In [None]:
acts = dict(acts)
# Collect activations in the order given by transformer.layer_names(),
# stopping at the first name that is absent from `acts`.
tres = []
for k in transformer.layer_names():
    if k in acts:
        tres.append(acts[k])
        continue
    break
# Stack the per-layer list and move the leading (list) axis to position 1.
tres = einops.rearrange(tres, 'l b h w -> b l h w')
tres.shape

In [None]:
#original = np.squeeze(x.residuals).astype(np.float32)
# Decode the stacked transformer activations with the autoencoder and cast to
# float32. `decode` here is the function defined above that calls
# `ae_model.decode`. This overwrites the earlier `decoded` array.
decoded = decode(tres).astype(np.float32)

In [None]:
# Display the shape of the original residuals.
original.shape

In [None]:
if PLOT:
    # First example of the batch: original vs. decoded transformer activations.
    x, y = original[0], decoded[0]
    _min, _max = get_range(x, y)

    fig, axs = plt.subplots(len(x), 2, figsize=[10, 10])
    axs[0, 0].set_title('Original')
    axs[0, 1].set_title('Decoded')

    # Column 0 shows the original, column 1 the decoded version; both share
    # the same color limits so a single colorbar applies to every panel.
    for col, layers in enumerate((x, y)):
        for i, res in enumerate(layers):
            im = axs[i, col].imshow(res, vmin=_min, vmax=_max)
            if col == 0:
                axs[i, col].set_ylabel(f'Layer {i}')

    xticks = np.arange(x.shape[-1], step=2)
    for ax in axs.flatten():
        ax.set_xticks(xticks)
        ax.set_xticklabels(xticks)

    fig.colorbar(im, ax=axs, orientation='horizontal')

In [None]:
if PLOT:
    # Element-wise difference at the last index of the first axis of x and y.
    last_layer_diff = x[-1] - y[-1]
    plt.imshow(last_layer_diff)
    plt.colorbar(orientation='horizontal')

In [None]:
# Print the stored autoencoder accuracy next to the accuracy of the
# transformer-based reconstruction (`decoded` now holds decode(tres)).
# NOTE(review): `ae_acc` holds whatever was last assigned to it; an earlier
# cell reassigns it from a single example — confirm it is the batch value.
print('Autoencoder acc: ', ae_acc)
print('Transformer acc: ', accuracy(original[:, -1], decoded[:, -1]))