In [1]:
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
import sys
from collections import defaultdict 
import jax
import flax
import chex
from jaxtyping import ArrayLike
from typing import Union, TypeVar
import numpy as np
import matplotlib.pyplot as plt
import traceback
import jax.numpy as jnp

from tracr.compiler.validating import validate
from tracr.rasp.rasp import Map, SequenceMap, LinearSequenceMap, Select, Aggregate, Comparison, SelectorWidth, indices, tokens 
from tracr.rasp import rasp
from tracr.compiler import compiling
from tracr.compiler.assemble import AssembledTransformerModel
from tracr.compiler.craft_model_to_transformer import NoTokensError
from tracr.compiler.basis_inference import InvalidValueSetError
from tracr.compiler import rasp_to_graph


from decompile_tracr.dataset import lib
from decompile_tracr.dataset import data_utils
from decompile_tracr.dataset import config
from decompile_tracr.tokenizing import tokenizer
from decompile_tracr.tokenizing import vocab
from decompile_tracr.sampling import sampling
from decompile_tracr.sampling import rasp_utils


rng = np.random.default_rng(0)

In [2]:
VAL_DATA_RATIO = 0.1
MAX_RASP_LENGTH = config.MAX_RASP_LENGTH
MAX_WEIGHTS_LENGTH = config.MAX_WEIGHTS_LENGTH
FULL_DATA_DIR = config.full_dataset_dir
ALL_LAYERS_MULTIPLIER = 15
split_layers = False

In [3]:
data = data_utils.load_dataset_for_model_input(
    rng=rng,
    loaddir=FULL_DATA_DIR,
    max_data=1000,
    shuffle=True,
    d_model=128,
    max_rasp_len=MAX_RASP_LENGTH if split_layers else MAX_RASP_LENGTH * ALL_LAYERS_MULTIPLIER,
    max_weights_len=MAX_WEIGHTS_LENGTH if split_layers else MAX_WEIGHTS_LENGTH * ALL_LAYERS_MULTIPLIER,
    split_layers=split_layers,
)

2024-04-12 19:34:53 - [INFO]: Loading data from /home/lauro/projects/meta-models/decompile-tracr/data/full.




2024-04-12 19:34:53.623084: E external/xla/xla/stream_executor/cuda/cuda_driver.cc:282] failed call to cuInit: CUDA_ERROR_COMPAT_NOT_SUPPORTED_ON_DEVICE: forward compatibility was attempted on non supported HW
2024-04-12 19:34:53.623171: E external/xla/xla/stream_executor/cuda/cuda_diagnostics.cc:256] kernel version 535.161.7 does not match DSO version 535.171.4 -- cannot find working devices in this configuration
CUDA backend failed to initialize: FAILED_PRECONDITION: No visible GPU devices. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [4]:
print("keys:", list(data.keys()))
print("data shapes:", {k: v.shape for k, v in data.items()})

keys: ['tokens', 'weights', 'n_sops', 'program_id', 'n_layers']
data shapes: {'tokens': (443, 1920), 'weights': (443, 1920, 128), 'n_sops': (443,), 'program_id': (443,), 'n_layers': (443,)}


## Tokens

In [5]:
# check for duplicates among tokens
tokens = data["tokens"]
unique_tokens = defaultdict(list)

for i, token in enumerate(tokens):
    t = tuple(token.tolist())
    unique_tokens[t].append(i)

print(f"Found {len(unique_tokens)}/{len(tokens)} unique tokens "
      f"({100 * len(unique_tokens) / len(tokens):.2f}%)")

Found 443/443 unique tokens (100.00%)


In [6]:
def assert_sops_sorted(toks):
    decoded = tokenizer.decode(toks)
    sops = [x for x in decoded if x.startswith("sop_")]
    first_occurences = []
    for sop in sops:
        if sop not in first_occurences:
            first_occurences.append(sop)
    assert sorted(first_occurences) == first_occurences, first_occurences


for i, toks in enumerate(tokens):
    print(i)
    assert_sops_sorted(toks)

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
27

In [22]:
tokenizer.decode(tokens[0])

['BOS',
 'EOL',
 'sop_00',
 'numerical',
 'Map',
 'lambda x: x + 0.5',
 'tokens',
 'EOO',
 'sop_01',
 'numerical',
 'Map',
 'lambda x: x + 0.5',
 'tokens',
 'EOO',
 'EOL',
 'EOL',
 'sop_02',
 'numerical',
 'LinearSequenceMap',
 'sop_00',
 'sop_01',
 '-3',
 '-1',
 'EOO',
 'EOL',
 'EOL',
 'sop_03',
 'numerical',
 'LinearSequenceMap',
 'sop_02',
 'sop_00',
 '-1',
 '-3',
 'EOO',
 'EOL',
 'EOS',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD',
 'PAD'

In [8]:
# number of non-padding tokens
print(f"Number of nonzero tokens: {(tokens > 0).sum() / tokens.size * 100:0.1f}%")

Number of nonzero tokens: 1.5%


In [9]:
# distribution of token types

# encodings
cat, num = (tokenizer.encode_token(t) for t in ["categorical", "numerical"])
n_categorical = (tokens == cat).sum()
n_numerical = (tokens == num).sum()
total = n_categorical + n_numerical

print(f"Categorical sops: {100*n_categorical/total:0.1f}%")
print(f"Numerical sops: {100*n_numerical/total:0.1f}%")
print(f"Total sops: {total:,}")

Categorical sops: 65.2%
Numerical sops: 34.8%
Total sops: 1,427


In [10]:
ops = tokenizer.encode(vocab.ops)
op_counts = {vocab.vocab[op]: (tokens == op).sum() for op in ops}
total = sum(op_counts.values())

print("Operation counts:")
for op, count in op_counts.items():
    print(f"{op}: {100*count/total:.1f}%")

print(f"Total SOps: {total:,}")

Operation counts:
Map: 32.2%
SequenceMap: 31.3%
LinearSequenceMap: 4.2%
SelectAggregate: 30.1%
SelectorWidth: 2.2%
Total SOps: 1,427


## Programs

In [11]:
def get_test_inputs_and_outputs(
    programs: list[rasp.SOp],
    n_samples: int = 50,
):
    """Generate test inputs and pass forward through programs to get outputs."""
    test_inputs = [rasp_utils.sample_test_input(rng, max_seq_len=5, 
                                    min_seq_len=5, vocab=set(range(10))) 
                for _ in range(n_samples)]
    outputs = [[p(x) for x in test_inputs] for p in programs]
    outputs = np.array(outputs, dtype=float)
    outputs = np.nan_to_num(outputs, nan=0.0)
    return test_inputs, outputs


def test_low_var(outputs: list):
    """Test that sampled programs have a reasonable amount of variance wrt input"""
    stds = np.std(outputs, axis=1).sum(axis=-1)  # std across test inputs; sum across output sequence
    are_low_var = stds < 0.01
    frac_low_var = sum(are_low_var) / len(stds)
    print(f"{frac_low_var*100}% of programs have low variance in output.")

In [12]:
programs = [tokenizer.detokenize(t) for t in tokens]
inputs, outputs = get_test_inputs_and_outputs(programs)
test_low_var(outputs)


outputs_buffer = outputs.copy()

program_data = []
for i, p in enumerate(programs):
    program_data.append({
        "program": p,
        "outputs": outputs[i],
        "std": np.std(outputs[i], axis=0).sum(),
    })


# sort by std
by_std = sorted(program_data, key=lambda x: np.std(x['outputs']))
len(by_std)
by_std = iter(by_std)

10.15801354401806% of programs have low variance in output.


In [13]:
p = next(by_std)
print("std:", np.std(p['outputs']))
print()
print('input: ', inputs[0])
rasp_utils.print_program(p['program'], test_input=inputs[0], full=True)
print()
print('sample outputs:', p['outputs'][:10])

std: 0.0

input:  [9, 8, 1, 5, 3]
select_19 = Select(tokens, tokens, predicate=Comparison.TRUE)
sop_00_22 = rasp.categorical(Map(lambda x: x - 1, tokens))    # output: [8, 7, 0, 4, 2]
sop_01_21 = rasp.categorical(SequenceMap(lambda x, y: x - y, sop_00_22, tokens))    # output: [-1, -1, -1, -1, -1]
sop_02_20 = rasp.numerical(Map(lambda x: x == 3, sop_01_21))    # output: [False, False, False, False, False]
sop_03_18 = rasp.numerical(Aggregate(select_19, sop_02_20))    # output: [0.0, 0.0, 0.0, 0.0, 0.0]

sample outputs: [[0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]]


## Weights

In [14]:
# check for duplicates among tokens
weights = data["weights"]
unique = defaultdict(list)
duplicate_weights = []

for i, w in enumerate(weights[:300]):
    w = tuple(w.flatten().tolist())
    if w in unique:
        duplicate_weights.append(i)
    
    unique[w].append(i)

print(f"Found {len(unique)}/{len(weights)} unique model params "
      f"({100 * len(unique) / len(weights):.2f}%)")

Found 300/443 unique model params (67.72%)


In [15]:
print(f"percent padding: {100 * (weights == 0.05).sum() / weights.size:0.1f}%")
print(f"percent zero: {100 * (weights == 0).sum() / weights.size:0.1f}%")
print(f"left over: {100 * np.logical_and(weights != 0, weights != 0.05).sum() / weights.size:0.1f}%")

percent padding: 97.6%
percent zero: 2.4%
left over: 0.1%


## Visualize

In [16]:
def get_percentages(idx):
    w = data["weights"][idx]

    print(f"percent padding: {100 * (w == 0.05).sum() / w.size:0.1f}%")
    print(f"percent zero: {100 * (w == 0).sum() / w.size:0.1f}%")
    print(f"left over: {100 * np.logical_and(w != 0, w != 0.05).sum() / w.size:0.1f}%")


def plot_datapoint(idx):
    t = tokens[idx]
    w = data["weights"][idx]
    w = w.flatten()

    plt.plot(w, ".")
    plt.yscale("symlog", linthresh=0.1)

    print(" ".join(tokenizer.decode(t)))


def imshow_datapoint(idx):
    t = tokens[idx]
    w = data["weights"][idx]
    _, d_model = w.shape
    w = w.flatten()

    is_padding = w == 0.05
    first_padding_idx = is_padding.tolist().index(True)
    idx = first_padding_idx + (d_model - first_padding_idx % d_model)
    reshaped_w = w[:idx].reshape(-1, d_model)
    reshaped_w[reshaped_w == 0] = np.nan
    plt.imshow(reshaped_w, aspect="auto", interpolation="nearest")


get_percentages(0)

percent padding: 99.6%
percent zero: 0.4%
left over: 0.0%


In [17]:
idx = 4
#plot_datapoint(idx)

## Investigate duplicates

In [18]:
for dupe_idx in duplicate_weights:
    w = weights[dupe_idx]
    t = tokens[dupe_idx]

    duplicates = unique[tuple(w.flatten().tolist())]

    if not all([np.all(tokens[i] == tokens[dupe_idx]) for i in duplicates]):
        print(f"dupe idx: {dupe_idx}")
        print("Found duplicates with different tokens:")
        for i in duplicates:
            print(" ".join(tokenizer.decode(tokens[i])))
        print()
        print()
        print()
        print()

## Check for close duplicates

In [19]:
# from tqdm import tqdm
# 
# for w in tqdm(weights):
#     close = [np.allclose(w, u) for u in unique.values()]