# Compiling & Visualizing Tracr Models

This notebook demonstrates how to compile a RASP program into a tracr model and provides some tools to visualize the model's residual stream or layer outputs for a given input sequence.

In [None]:
# Fetch the tracr fork used by this notebook and install it into the environment.
# On a re-run the clone step fails because the directory already exists; the
# `!` shell escape does not stop the cell, and the pull below fetches new commits.
!git clone https://github.com/jpsank/tracr
!git -C tracr pull
!pip install ./tracr
# librosa is only used by the commented-out audio preprocessing code below.
# !pip install librosa

In [None]:
#@title Imports
import jax
import numpy as np
import matplotlib.pyplot as plt
from scipy.io import wavfile

# The default of float16 can lead to discrepancies between outputs of
# the compiled model and the RASP program.
jax.config.update('jax_default_matmul_precision', 'float32')

from tracr.compiler import compiling
from tracr.compiler import lib
from tracr.rasp import rasp

In [None]:
# import librosa

# # Function to preprocess the audio file by getting decibel levels every 10ms
# def preprocess_audio(file_path):
#     # Load the audio file
#     y, sr = librosa.load(file_path, sr=None)
#     # Get amplitude average for every 10ms
#     amp = np.sqrt(np.mean(y.reshape(-1, sr//100)**2, axis=1))
#     # Get decibel levels
#     out = []
#     state = 0
#     current_sum = 0
#     current_len = 0
#     for i in range(0, len(amp)):
#         if state == 1:  # sound is on
#             if amp[i] < 1e-5:
#                 state = 0
#                 out.append(current_sum/current_len)
#             else:
#                 decibel = abs(20*np.log10(amp[i]))
#                 current_sum += decibel
#                 current_len += 1
#         elif state == 0:  # silence
#             if amp[i] > 1e-5:
#                 out.append(0)
#                 state = 1
#     return out


In [None]:
#@title Plotting functions
def tidy_label(label, value_width=5):
  """Formats a residual-stream label as ``name:value`` with a right-aligned value.

  A label such as ``"tokens:a"`` becomes ``"tokens:    a"`` (the value is
  right-aligned in ``value_width`` characters) so that tick labels line up
  when drawn in a monospace font. A label without ``:`` gets an empty, padded
  value part (e.g. ``"one"`` -> ``"one:     "``).

  Args:
    label: Residual label string, e.g. ``"indices:3"`` or ``"one"``.
    value_width: Minimum width of the right-aligned value part.

  Returns:
    The reformatted label string.
  """
  # Split on the first ':' only: a value that itself contains ':' (e.g. a
  # vocabulary token "a:b") would otherwise make the 2-way unpack raise.
  name, _, value = label.partition(':')
  return name + f":{value:>{value_width}}"


def add_residual_ticks(model, value_width=5, x=False, y=True):
  """Labels the current axes' ticks with the model's residual-stream labels.

  Tick positions are offset by 0.5 so each label sits in the middle of its
  ``pcolormesh`` cell. Labels are formatted with ``tidy_label`` and drawn in a
  monospace font.

  Args:
    model: Compiled model exposing ``residual_labels``.
    value_width: Width passed to ``tidy_label`` for the value part.
    x: If True, label the x axis (text rotated by 90 degrees).
    y: If True, label the y axis.
  """
  if not (x or y):
    return

  tick_positions = np.arange(len(model.residual_labels)) + 0.5
  tick_names = [tidy_label(name, value_width=value_width)
                for name in model.residual_labels]

  if y:
    plt.yticks(tick_positions, tick_names, family='monospace', fontsize=20)
  if x:
    plt.xticks(tick_positions, tick_names, family='monospace',
               rotation=90, fontsize=20)


def plot_computation_trace(model,
                           input_labels,
                           residuals_or_outputs,
                           add_input_layer=False,
                           figsize=(12, 9),
                           vmin=0,
                           vmax=1):
  """Plots one heatmap panel per layer of a compiled model's computation.

  Each panel draws ``layer[0].T`` with ``pcolormesh``: index 0 along the
  leading axis (presumably the batch axis — TODO confirm), transposed so that
  the second array axis runs along y and the first along x. Panels share the
  y axis, and only the first panel receives the residual-stream tick labels.

  Args:
    model: Compiled model; its ``residual_labels`` label the y axis.
    input_labels: Labels for the sequence positions (x axis).
    residuals_or_outputs: Non-empty sequence of per-layer arrays.
    add_input_layer: If True, the first panel is titled 'Input' and the
      following panels are numbered as layers starting from the first one.
    figsize: Matplotlib figure size.
    vmin: Lower colour-scale limit passed to ``pcolormesh``.
    vmax: Upper colour-scale limit passed to ``pcolormesh``.

  Raises:
    ValueError: If ``residuals_or_outputs`` is empty.
  """
  n_panels = len(residuals_or_outputs)
  if n_panels == 0:
    raise ValueError("residuals_or_outputs must contain at least one layer.")

  # squeeze=False keeps `axes` 2-D even for a single panel; by default
  # plt.subplots returns a bare (non-iterable) Axes when ncols == 1.
  _, axes = plt.subplots(nrows=1, ncols=n_panels, figsize=figsize,
                         sharey=True, squeeze=False)
  # Pad residual values to the widest input label so tick text lines up.
  value_width = max(map(len, map(str, input_labels))) + 1

  for i, (layer, ax) in enumerate(zip(residuals_or_outputs, axes[0])):
    # add_residual_ticks draws on the *current* axes, so select this panel.
    plt.sca(ax)
    plt.pcolormesh(layer[0].T, vmin=vmin, vmax=vmax)
    if i == 0:
      add_residual_ticks(model, value_width=value_width)
    plt.xticks(
        np.arange(len(input_labels)) + 0.5,
        input_labels,
        rotation=90,
        fontsize=20,
    )
    if add_input_layer and i == 0:
      title = 'Input'
    else:
      # Titles assume layers alternate attention / MLP, starting with attention.
      layer_no = i - 1 if add_input_layer else i
      layer_type = 'Attn' if layer_no % 2 == 0 else 'MLP'
      title = f'{layer_type} {layer_no // 2 + 1}'
    plt.title(title, fontsize=20)


def plot_residuals_and_input(model, inputs, figsize=(12, 9)):
  """Runs `model` on `inputs` and plots the residual stream at every layer.

  The model output's ``input_embeddings`` are prepended (as a new leading
  entry) to its ``residuals`` so the first panel, titled 'Input', shows the
  embeddings.
  """
  forward = model.apply(inputs)
  embedding_panel = forward.input_embeddings[np.newaxis]
  panels = np.concatenate([embedding_panel, forward.residuals], axis=0)
  plot_computation_trace(
      model=model,
      input_labels=inputs,
      residuals_or_outputs=panels,
      add_input_layer=True,
      figsize=figsize,
  )


def plot_layer_outputs(model, inputs, figsize=(12, 9)):
  """Runs `model` on `inputs` and plots what each layer outputs.

  Panels come from the model output's ``layer_outputs`` (not its
  ``residuals``), so no 'Input' panel is added.
  """
  per_layer = model.apply(inputs).layer_outputs
  plot_computation_trace(
      model=model,
      input_labels=inputs,
      residuals_or_outputs=per_layer,
      add_input_layer=False,
      figsize=figsize,
  )


In [None]:
# def make_apwm() -> rasp.SOp:
#   """Make the auditory parametric working memory task.
#   Input is a raw audio waveform.
#   Step 1: Detect stimuli sequences a and b in the input, separated by silence.
#   Step 2: Calculate the decibel of the stimuli.
#   Step 3: Compare the decibel of the stimuli.
#   Returns 1 if the first tone is louder than the second tone, 0 otherwise.
#   """

#   # Get tone A (0.25 - 0.65s)
#   sample_rate = 44100  # hard-coded for now
#   tone_a_selector = rasp.Select(rasp.indices, rasp.indices, lambda x: 0.25 * sample_rate < x < 0.65 * sample_rate)
  
#   # Get tone B (-0.85 - -0.45s) starting from the end
#   # This requires us to reverse the input sequence
#   reversed_sop = lib.make_reverse(rasp.indices)
#   tone_b_selector = rasp.Select(reversed_sop, reversed_sop, lambda x: 0.45 * sample_rate < x < 0.85 * sample_rate)

#   # Compute root mean square for each tone
#   squared = rasp.numerical(rasp.tokens * rasp.tokens)
#   tone_a_sum = rasp.Aggregate(tone_a_selector, squared, default=0)
#   tone_b_sum = rasp.Aggregate(tone_b_selector, squared, default=0)

#   # Compare the decibel of the stimuli
#   return rasp.numerical(tone_a_sum > tone_b_sum).named("apwm")

def make_clicks() -> rasp.SOp:
    """Make the pulse clicks task.

    Input is a sequence of zeros, ones, and twos, where zeros represent silence,
    ones represent a left click, and twos represent a right click.
    Step 1: Detect the clicks in the input.
    Step 2: Count the number of left and right clicks.
    Step 3: Compare the number of left and right clicks.
    Returns 1 if there are more left clicks than right clicks, 0 otherwise
    (the same value is produced at every position).
    """
    # Step 1: detect the clicks.
    # `rasp.Select(keys, queries, predicate)` evaluates `predicate(key, query)`,
    # so a one-argument lambda cannot be used. Instead each key token is
    # compared (EQ) against a constant query, which selects exactly the
    # positions holding that click type, whatever the query position.
    left_query = rasp.Map(lambda _: 1, rasp.tokens).named("left_query")
    right_query = rasp.Map(lambda _: 2, rasp.tokens).named("right_query")
    left_clicks = rasp.Select(
        rasp.tokens, left_query, rasp.Comparison.EQ).named("left_clicks")
    right_clicks = rasp.Select(
        rasp.tokens, right_query, rasp.Comparison.EQ).named("right_clicks")

    # Step 2: count the clicks. SelectorWidth yields the number of selected
    # positions. (`rasp.numerical` annotates an existing SOp; it cannot wrap a
    # plain int, so aggregating `rasp.numerical(1)` is not an option.)
    left_count = rasp.SelectorWidth(left_clicks).named("left_count")
    right_count = rasp.SelectorWidth(right_clicks).named("right_count")

    # Step 3: compare the counts. The comparison is a non-linear SequenceMap,
    # which tracr compiles over categorical values, so the result stays
    # categorical (0 or 1).
    return rasp.SequenceMap(
        lambda left, right: 1 if left > right else 0, left_count, right_count
    ).named("clicks")

In [None]:
#@title Define RASP programs
def get_program(program_name, max_seq_len):
  """Builds the named RASP program together with its input vocabulary.

  Args:
    program_name: Name of the program to build (a key of `builders` below).
    max_seq_len: Context size, forwarded to the programs that need it.

  Returns:
    A ``(program, vocab)`` tuple.

  Raises:
    NotImplementedError: If ``program_name`` is not recognised.
  """
  # Each entry lazily builds (program, vocab), so only the requested program
  # is constructed.
  builders = {
      "length": lambda: (lib.make_length(), {"a", "b", "c", "d"}),
      "frac_prevs": lambda: (
          lib.make_frac_prevs((rasp.tokens == "x").named("is_x")),
          {"a", "b", "c", "x"}),
      "dyck-2": lambda: (
          lib.make_shuffle_dyck(pairs=["()", "{}"]),
          {"(", ")", "{", "}"}),
      "dyck-3": lambda: (
          lib.make_shuffle_dyck(pairs=["()", "{}", "[]"]),
          {"(", ")", "{", "}", "[", "]"}),
      "sort": lambda: (
          lib.make_sort(rasp.tokens, rasp.tokens,
                        max_seq_len=max_seq_len, min_key=1),
          {1, 2, 3, 4, 5}),
      "sort_unique": lambda: (
          lib.make_sort_unique(rasp.tokens, rasp.tokens),
          {1, 2, 3, 4, 5}),
      "hist": lambda: (lib.make_hist(), {"a", "b", "c", "d"}),
      "sort_freq": lambda: (
          lib.make_sort_freq(max_seq_len=max_seq_len),
          {"a", "b", "c", "d"}),
      "pair_balance": lambda: (
          lib.make_pair_balance(
              sop=rasp.tokens, open_token="(", close_token=")"),
          {"(", ")"}),
      "clicks": lambda: (make_clicks(), {0, 1, 2}),
  }

  build = builders.get(program_name)
  if build is None:
    raise NotImplementedError(f"Program {program_name} not implemented.")
  return build()

In [None]:
#@title Assemble model
program_name = "clicks"  #@param ["length", "frac_prevs", "dyck-2", "dyck-3", "sort", "sort_unique", "hist", "sort_freq", "pair_balance", "clicks"]
max_seq_len = 100  #@param {type: "integer"}

program, vocab = get_program(program_name=program_name,
                             max_seq_len=max_seq_len)

print("Compiling...")
print(f"   Program: {program_name}")
print(f"   Input vocabulary: {vocab}")
print(f"   Context size: {max_seq_len}")

# causal=False: attention is not restricted to earlier positions.
# compiler_bos / compiler_pad name the special tokens used by the compiled
# model; inputs passed to `apply` must start with the BOS token ("bos").
# mlp_exactness: per tracr's documentation, larger values approximate the
# MLP layers more closely but very large values risk numerical issues.
assembled_model = compiling.compile_rasp_to_model(
    program=program,
    vocab=vocab,
    max_seq_len=max_seq_len,
    causal=False,
    compiler_bos="bos",
    compiler_pad="pad",
    mlp_exactness=100)

print("Done.")

In [None]:
# LOAD SAMPLE
# sample_rate, waveform = wavfile.read("sample.wav")
SAMPLE_SEED = 0  # fixed so the sample, and every output and plot below, is reproducible


def make_sample(min_len=10, max_len=100, rng=None):
    """Return a random click train: 0 = silence, 1 = left click, 2 = right click.

    Args:
      min_len: Smallest sequence length (inclusive).
      max_len: Largest sequence length (exclusive).
      rng: Optional ``np.random.Generator``; a fresh, unseeded one if None.

    Returns:
      A 1-D integer NumPy array of length in ``[min_len, max_len)``.
    """
    if rng is None:
        rng = np.random.default_rng()
    length = rng.integers(min_len, max_len)
    return rng.choice([0, 1, 2], size=length)


# Cap the length so BOS + sample fits in `max_seq_len` positions
# (10-99 click tokens for the default max_seq_len = 100).
sample = make_sample(min_len=min(10, max_seq_len - 1),
                     max_len=max_seq_len,
                     rng=np.random.default_rng(SAMPLE_SEED))

# Print whether there are more left clicks than right clicks
n_left = int(np.sum(sample == 1))
n_right = int(np.sum(sample == 2))
print(f"Sample: {sample}")
print(f"Left clicks: {n_left}")
print(f"Right clicks: {n_right}")
print(f"Expected output: {int(n_left > n_right)}")

# Plain Python ints keep NumPy scalar types out of the model input.
input_tokens = ["bos"] + sample.tolist()

In [None]:
#@title Forward pass
forward_result = assembled_model.apply(input_tokens)
forward_result.decoded

In [None]:
#@title Plot residual stream
plot_residuals_and_input(model=assembled_model, inputs=input_tokens, figsize=(10, 9))

In [None]:
#@title Plot layer outputs
plot_layer_outputs(model=assembled_model, inputs=input_tokens, figsize=(8, 9))