In [1]:
import haiku as hk
import jax
import jax.numpy as jnp
import functools
import numpy as np
import optax
import tqdm
import tree
import pandas as pd
import matplotlib.pyplot as plt

from absl import logging
from typing import Any

from data import utm_data_generator as utm_dg_lib
from data import chomsky_data_generator as chomsky_sampler_lib
from helpers import make_chomsky_generator, utm_data_generator, make_model, init_params, save_params, evaluate_transformer_decoder, CHOMSKY_ALPHABET_SIZE

In [2]:
seed = 1

# Experiment hyper-parameters, following the paper's settings.
TRAINING_STEPS = 100_000
EXECUTION_STEPS = 1_000
MEMORY_SIZE = 200
BATCH_SIZE = 32
USE_DELIMITERS = True

# Markov option. Allowed values: True, False.
USE_MARKOV = True

# Model size. Allowed values: "small", "medium", "large".
MODEL_SIZE = "large"

In [3]:
def _make_loss_fn(model: hk.Transformed) -> Any:
  """Builds the training objective consumed by `_update_parameters`.

  Args:
    model: The haiku-transformed model. Its `apply` is called with `params`,
      `targets` and `rng=None`, and its output is indexed by token id along
      the last axis, i.e. it is treated as per-token log-probabilities over
      the vocabulary (presumably shape (B, T, V) — TODO confirm against
      `make_model`).

  Returns:
    A function `loss_fn(params, sequences, mask)` returning the masked
    negative log-likelihood of `sequences`, summed over axis 1 and averaged
    over the batch.
  """

  def loss_fn(
      params: hk.Params,
      sequences: jax.Array,
      mask: jax.Array,
  ) -> jnp.float32:
    """Masked negative log-likelihood of a batch of token sequences.

    Args:
      params: Model parameters, forwarded to `model.apply`.
      sequences: Integer token ids. Axis 0 is averaged over (batch) and axis
        1 is summed over (time).
      mask: A binary array; True (1's) marks positions excluded from the loss.

    Returns:
      A scalar: minus the batch mean of the per-sequence summed log-probs.
    """
    log_probs = model.apply(params=params, targets=sequences, rng=None)

    # Select, at every position, the log-probability of the token that
    # actually occurred there.
    picked = jnp.take_along_axis(
        log_probs, jnp.expand_dims(sequences, axis=-1), axis=-1
    )
    token_log_probs = jnp.squeeze(picked, axis=-1)

    # Masked positions (True) contribute nothing to the objective.
    kept = jnp.where(mask, 0.0, token_log_probs)

    per_sequence = kept.sum(axis=1)  # Shape (B,).
    return -per_sequence.mean()

  return loss_fn


@functools.partial(
    jax.jit, static_argnames=('optimizer', 'grad_fn', 'normalize_gradients')
)
def _update_parameters(
    params: hk.Params,
    opt_state: optax.OptState,
    sequences: jax.Array,
    mask: jax.Array,
    grad_fn: Any,
    optimizer: optax.GradientTransformation,
    normalize_gradients: bool = True,
) -> tuple[hk.Params, optax.OptState, dict[str, Any]]:
  """Performs one jitted optimizer step and returns the training logs.

  Backpropagation is done through the whole sequence. The whole function is
  jitted; `grad_fn`, `optimizer` and `normalize_gradients` are static jit
  arguments, so they must be hashable.

  Args:
    params: The current parameters of the network.
    opt_state: The optimizer state for `optimizer`.
    sequences: Integer token sequences. Axis 1 is taken as the sequence
      length when normalizing gradients.
    mask: A binary array; True (1's) denote where to skip computing the loss.
    grad_fn: A callable `grad_fn(params, sequences, mask) -> (loss, grads)`,
      e.g. `jax.value_and_grad` of the function built by `_make_loss_fn`.
    optimizer: An optax optimizer.
    normalize_gradients: Whether to divide the gradients by the length of the
      sequences, or keep them as is. Using this option guarantees to have the
      same scale across various sequence lengths, and therefore tasks.

  Returns:
    `(new_params, new_opt_state, log_dict)`, where `log_dict` has:
      'loss': the loss returned by `grad_fn`.
      'grad_norm_unclipped': the global norm of the gradients that were fed
        to the optimizer (after length normalization, if enabled).
  """
  loss, grads = grad_fn(params, sequences, mask)
  if normalize_gradients:
    # Shapes are static under jit, so this is a plain Python float.
    length_sequence = float(sequences.shape[1])
    grads = jax.tree_util.tree_map(lambda g: g / length_sequence, grads)
  updates, new_opt_state = optimizer.update(grads, opt_state)
  new_params = optax.apply_updates(params, updates)

  log_dict = {
      'loss': loss,
      'grad_norm_unclipped': optax.global_norm(grads),
  }

  return new_params, new_opt_state, log_dict

In [4]:
def train_transformer_decoder(
    data_generator: utm_dg_lib.UTMDataGenerator,
    training_steps: int,
    log_every: int,
    batch_size: int,
    use_tqdm: bool = True,
    with_markov: bool = False,
    size: str = "large",
    eval_data_generator: chomsky_sampler_lib.ChomskyDataGenerator = None,
) -> tuple[hk.Params, float, list[float], list[float], list[float]]:
    """Trains a decoder-only transformer on batches of synthetic data.

    Each step samples a batch from `data_generator`, converts the one-hot
    batch to integer tokens, and takes one Adam (lr=1e-4) step on the masked
    log-loss. When `eval_data_generator` is given and `log_every > 0`, the
    current parameters are evaluated with `evaluate_transformer_decoder`
    every `10 * log_every` steps (including step 0).

    Args:
      data_generator: Used to generate batches of data to train on.
      training_steps: Number of batches to train on.
      log_every: How often to log the loss. If negative or 0, no log (and no
        periodic evaluation) at all.
      batch_size: The number of sequences in a batch.
      use_tqdm: Whether to use a progress bar or not.
      with_markov: Forwarded as `with_markov` to `data_generator.sample`.
      size: Model size forwarded to `make_model` and to the evaluation.
      eval_data_generator: Optional generator used for periodic evaluation.

    Returns:
      A tuple `(params, last_loss, eval_losses, eval_accs, eval_final_accs)`:
      the final parameters, the training loss of the last step as a float
      (0.0 if `training_steps` is 0), and the evaluation metrics in the
      order they were computed (empty lists if no evaluation ran).
    """
    print("Vocab Size:", data_generator.feature_size)
    print("Model Size:", size)
    print("Batch Size:", batch_size)
    print("With Markov:", with_markov)
    model = make_model(data_generator, size)

    params = init_params(model, data_generator, batch_size)

    # Make gradient function.
    loss_fn = _make_loss_fn(model)
    grad_fn = jax.value_and_grad(loss_fn, has_aux=False)

    # Make optimizer, to apply the gradients.
    optimizer = optax.adam(learning_rate=1e-4)
    opt_state = optimizer.init(params)

    logging.info("Initialization done, starting training...")

    # Evaluation cadence is tied to the logging cadence. Guarding on
    # `log_every > 0` avoids a modulo-by-zero when logging is disabled.
    eval_every = 10 * log_every if log_every > 0 else 0
    run_eval = eval_data_generator is not None and eval_every > 0

    last_loss = 0.0
    eval_losses = []
    eval_accs = []
    eval_final_accs = []

    for step in tqdm.trange(training_steps, disable=not use_tqdm):
        batch, log_dict = data_generator.sample(with_markov=with_markov)
        # Transform one-hots to integer tokens.
        batch = np.argmax(batch, axis=-1)
        if "loss_mask" in log_dict:
            loss_mask = log_dict["loss_mask"]
        else:
            # The loss skips positions where the mask is True, so "no mask"
            # must be all-False (every position contributes to the loss).
            loss_mask = np.zeros(batch.shape[:2], dtype=bool)

        params, opt_state, logs = _update_parameters(
            params=params,
            opt_state=opt_state,
            sequences=batch,
            grad_fn=grad_fn,
            optimizer=optimizer,
            mask=loss_mask,
        )
        # Track the loss of every step so the returned value is really the
        # final loss; it is converted to a Python float once, at the end.
        last_loss = logs["loss"]

        if log_every > 0 and step % log_every == 0:
            logging.info(
                "Step %d, Loss (avg cumulative nats) %f, Grad norm %f",
                step,
                logs["loss"],
                logs["grad_norm_unclipped"],
            )

        if run_eval and step % eval_every == 0:
            eval_loss, eval_acc, eval_final_acc = evaluate_transformer_decoder(
                eval_data_generator, params, data_generator, size=size
            )
            eval_losses.append(eval_loss)
            eval_accs.append(eval_acc)
            eval_final_accs.append(eval_final_acc)
            print(
                f"Step {step}, Eval acc: {eval_acc}, Eval final acc: {eval_final_acc}"
            )

    return params, float(last_loss), eval_losses, eval_accs, eval_final_accs

In [None]:
# One NumPy generator, seeded from the config cell's `seed`. It is shared by
# the training (UTM) and evaluation (Chomsky) data generators.
rng = np.random.default_rng(seed=seed)

utm_generator = utm_data_generator(
    rng,
    maximum_steps=EXECUTION_STEPS,
    maximum_program_length=100,
    memory_size=MEMORY_SIZE,
    alphabet_size=CHOMSKY_ALPHABET_SIZE,
    batch_size=BATCH_SIZE,
)

chomsky_generator = make_chomsky_generator(
    rng, use_delimiters=USE_DELIMITERS, task_str="modular_arithmetic"
)

def training_loop(model_sizes=None, use_markovs=None, log_every=10):
    """Trains one transformer per (Markov flag, model size) pair and saves results.

    Each run trains on `utm_generator` for `TRAINING_STEPS` steps with
    periodic evaluation on `chomsky_generator`. It then writes:
      - the parameters to `params_{suffix}_transformer_{size}.npz`, and
      - the evaluation metrics to `metrics_{suffix}_transformer_{size}.csv`,
    where `suffix` is "markov" or "original".

    Args:
      model_sizes: Model sizes to train ("small", "medium", "large").
        Defaults to `[MODEL_SIZE]` from the config cell.
      use_markovs: Values of `with_markov` to train with (True/False).
        Defaults to `[USE_MARKOV]` from the config cell.
      log_every: Logging interval forwarded to `train_transformer_decoder`.
    """
    if model_sizes is None:
        model_sizes = [MODEL_SIZE]
    if use_markovs is None:
        use_markovs = [USE_MARKOV]

    for use_markov in use_markovs:
        suffix = "markov" if use_markov else "original"
        for model_size in model_sizes:
            params, loss, eval_losses, eval_accs, eval_final_accs = train_transformer_decoder(
                data_generator=utm_generator,
                training_steps=TRAINING_STEPS,
                batch_size=BATCH_SIZE,
                log_every=log_every,
                with_markov=use_markov,
                size=model_size,
                eval_data_generator=chomsky_generator,
            )
            logging.info("Final loss: %f", loss)

            file_name = f"params_{suffix}_transformer_{model_size}.npz"
            save_params(params, file_name)
            logging.info("Parameters saved in file %s", file_name)

            # One row per evaluation, in the order the evaluations ran.
            eval_df = pd.DataFrame(
                {
                    "eval_losses": eval_losses,
                    "eval_accs": eval_accs,
                    "eval_final_accs": eval_final_accs,
                }
            )
            metrics_name = f"metrics_{suffix}_transformer_{model_size}.csv"
            eval_df.to_csv(metrics_name, index=False)
            logging.info("Evaluation metrics saved to %s", metrics_name)


training_loop()

Vocab Size: 128
Model Size: large
Batch Size: 32
With Markov: True


  0%|          | 0/100000 [00:00<?, ?it/s]

Chomsky Task:  modular_arithmetic


  0%|          | 3/100000 [00:20<151:06:12,  5.44s/it]

Step 0, Eval acc: 0.014642283320426941, Eval final acc: 0.01464389730244875


  0%|          | 100/100000 [01:59<27:36:13,  1.01it/s]

Chomsky Task:  modular_arithmetic


  0%|          | 104/100000 [02:17<76:49:35,  2.77s/it] 

Step 100, Eval acc: 0.06494100391864777, Eval final acc: 0.06494282186031342


  0%|          | 200/100000 [03:57<25:26:29,  1.09it/s]

Chomsky Task:  modular_arithmetic


  0%|          | 204/100000 [04:14<70:47:53,  2.55s/it] 

Step 200, Eval acc: 0.050692368298769, Eval final acc: 0.05068203806877136


  0%|          | 300/100000 [05:45<29:38:35,  1.07s/it]

Chomsky Task:  modular_arithmetic


  0%|          | 304/100000 [06:01<68:27:12,  2.47s/it] 

Step 300, Eval acc: 0.04888910800218582, Eval final acc: 0.048901207745075226


  0%|          | 400/100000 [07:31<25:42:47,  1.08it/s]

Chomsky Task:  modular_arithmetic


  0%|          | 404/100000 [07:47<66:31:22,  2.40s/it] 

Step 400, Eval acc: 0.04482599347829819, Eval final acc: 0.04482286423444748


  0%|          | 500/100000 [09:25<28:17:19,  1.02s/it]

Chomsky Task:  modular_arithmetic


  1%|          | 504/100000 [09:42<69:06:01,  2.50s/it] 

Step 500, Eval acc: 0.050147462636232376, Eval final acc: 0.05016108229756355


  1%|          | 600/100000 [11:17<29:50:53,  1.08s/it]

Chomsky Task:  modular_arithmetic


  1%|          | 604/100000 [11:31<62:46:24,  2.27s/it] 

Step 600, Eval acc: 0.055037520825862885, Eval final acc: 0.055041294544935226


  1%|          | 700/100000 [13:03<25:45:11,  1.07it/s]

Chomsky Task:  modular_arithmetic


  1%|          | 704/100000 [13:20<70:24:37,  2.55s/it] 

Step 700, Eval acc: 0.05121892690658569, Eval final acc: 0.051229726523160934


  1%|          | 800/100000 [14:51<25:41:20,  1.07it/s]

Chomsky Task:  modular_arithmetic


  1%|          | 804/100000 [15:06<60:48:01,  2.21s/it] 

Step 800, Eval acc: 0.0574614480137825, Eval final acc: 0.057457633316516876


  1%|          | 900/100000 [16:37<28:13:52,  1.03s/it]

Chomsky Task:  modular_arithmetic


  1%|          | 904/100000 [16:53<64:18:03,  2.34s/it] 

Step 900, Eval acc: 0.06419259309768677, Eval final acc: 0.06416437774896622


  1%|          | 1000/100000 [18:24<26:43:31,  1.03it/s]

Chomsky Task:  modular_arithmetic


  1%|          | 1004/100000 [18:38<60:03:16,  2.18s/it] 

Step 1000, Eval acc: 0.07412350177764893, Eval final acc: 0.07410893589258194


  1%|          | 1100/100000 [20:07<25:20:58,  1.08it/s]

Chomsky Task:  modular_arithmetic


  1%|          | 1104/100000 [20:22<62:40:25,  2.28s/it] 

Step 1100, Eval acc: 0.0866222158074379, Eval final acc: 0.08662828058004379


  1%|          | 1200/100000 [21:49<30:34:34,  1.11s/it]

Chomsky Task:  modular_arithmetic


  1%|          | 1204/100000 [22:03<59:47:52,  2.18s/it] 

Step 1200, Eval acc: 0.09155896306037903, Eval final acc: 0.0915369838476181


  1%|▏         | 1300/100000 [23:29<25:15:52,  1.09it/s]

Chomsky Task:  modular_arithmetic


  1%|▏         | 1304/100000 [23:42<52:29:37,  1.91s/it] 

Step 1300, Eval acc: 0.08717000484466553, Eval final acc: 0.08718611299991608


  1%|▏         | 1400/100000 [25:09<25:47:45,  1.06it/s]

Chomsky Task:  modular_arithmetic


  1%|▏         | 1404/100000 [25:23<55:33:49,  2.03s/it] 

Step 1400, Eval acc: 0.08923451602458954, Eval final acc: 0.08921435475349426


  2%|▏         | 1500/100000 [26:51<25:40:44,  1.07it/s]

Chomsky Task:  modular_arithmetic


  2%|▏         | 1504/100000 [27:05<56:21:10,  2.06s/it] 

Step 1500, Eval acc: 0.09456247836351395, Eval final acc: 0.09455831348896027


  2%|▏         | 1600/100000 [28:35<25:08:21,  1.09it/s]

Chomsky Task:  modular_arithmetic


  2%|▏         | 1604/100000 [28:51<66:18:07,  2.43s/it] 

Step 1600, Eval acc: 0.09457181394100189, Eval final acc: 0.09455403685569763


  2%|▏         | 1700/100000 [30:22<25:00:12,  1.09it/s]

Chomsky Task:  modular_arithmetic


  2%|▏         | 1704/100000 [30:35<56:18:08,  2.06s/it] 

Step 1700, Eval acc: 0.11750316619873047, Eval final acc: 0.11752823740243912


  2%|▏         | 1800/100000 [32:03<25:15:04,  1.08it/s]

Chomsky Task:  modular_arithmetic


  2%|▏         | 1804/100000 [32:15<52:03:19,  1.91s/it] 

Step 1800, Eval acc: 0.11755666881799698, Eval final acc: 0.11753208935260773


  2%|▏         | 1900/100000 [33:52<28:15:32,  1.04s/it]

Chomsky Task:  modular_arithmetic


  2%|▏         | 1904/100000 [34:06<57:56:10,  2.13s/it] 

Step 1900, Eval acc: 0.11264556646347046, Eval final acc: 0.11265194416046143


  2%|▏         | 2000/100000 [35:37<25:49:13,  1.05it/s]

Chomsky Task:  modular_arithmetic


  2%|▏         | 2004/100000 [35:50<56:41:20,  2.08s/it] 

Step 2000, Eval acc: 0.12557968497276306, Eval final acc: 0.1256023496389389


  2%|▏         | 2100/100000 [37:20<24:54:44,  1.09it/s]

Chomsky Task:  modular_arithmetic


  2%|▏         | 2104/100000 [37:34<56:24:23,  2.07s/it] 

Step 2100, Eval acc: 0.12620136141777039, Eval final acc: 0.1261848509311676


  2%|▏         | 2200/100000 [39:04<26:05:13,  1.04it/s]

Chomsky Task:  modular_arithmetic


  2%|▏         | 2204/100000 [39:17<52:18:51,  1.93s/it] 

Step 2200, Eval acc: 0.1309569776058197, Eval final acc: 0.13093748688697815


  2%|▏         | 2300/100000 [40:48<26:25:20,  1.03it/s]

Chomsky Task:  modular_arithmetic


  2%|▏         | 2304/100000 [41:01<54:33:32,  2.01s/it] 

Step 2300, Eval acc: 0.12328226864337921, Eval final acc: 0.12324991077184677


  2%|▏         | 2400/100000 [42:39<26:57:43,  1.01it/s]

Chomsky Task:  modular_arithmetic


  2%|▏         | 2404/100000 [42:52<57:10:08,  2.11s/it] 

Step 2400, Eval acc: 0.12150539457798004, Eval final acc: 0.12148203700780869


  2%|▎         | 2500/100000 [44:24<24:59:22,  1.08it/s]

Chomsky Task:  modular_arithmetic


  3%|▎         | 2504/100000 [44:37<55:37:43,  2.05s/it] 

Step 2500, Eval acc: 0.1427747905254364, Eval final acc: 0.14273867011070251


  3%|▎         | 2600/100000 [46:06<24:58:58,  1.08it/s]

Chomsky Task:  modular_arithmetic


  3%|▎         | 2604/100000 [46:19<53:27:44,  1.98s/it] 

Step 2600, Eval acc: 0.13391704857349396, Eval final acc: 0.13391318917274475


  3%|▎         | 2700/100000 [47:54<26:22:34,  1.02it/s]

Chomsky Task:  modular_arithmetic


  3%|▎         | 2704/100000 [48:06<51:57:53,  1.92s/it] 

Step 2700, Eval acc: 0.15557195246219635, Eval final acc: 0.15558139979839325


  3%|▎         | 2800/100000 [49:35<27:06:37,  1.00s/it]

Chomsky Task:  modular_arithmetic


  3%|▎         | 2804/100000 [49:48<52:19:01,  1.94s/it] 

Step 2800, Eval acc: 0.1456228643655777, Eval final acc: 0.14567571878433228


  3%|▎         | 2900/100000 [51:16<25:15:16,  1.07it/s]

Chomsky Task:  modular_arithmetic


  3%|▎         | 2904/100000 [51:29<51:14:34,  1.90s/it] 

Step 2900, Eval acc: 0.14225606620311737, Eval final acc: 0.14226698875427246


  3%|▎         | 3000/100000 [52:58<24:45:52,  1.09it/s]

Chomsky Task:  modular_arithmetic


  3%|▎         | 3004/100000 [53:09<49:39:23,  1.84s/it] 

Step 3000, Eval acc: 0.14108720421791077, Eval final acc: 0.1411161571741104


  3%|▎         | 3100/100000 [54:39<27:06:12,  1.01s/it]

Chomsky Task:  modular_arithmetic


  3%|▎         | 3104/100000 [54:51<52:34:11,  1.95s/it] 

Step 3100, Eval acc: 0.14719443023204803, Eval final acc: 0.14715120196342468


  3%|▎         | 3200/100000 [56:22<25:02:44,  1.07it/s]

Chomsky Task:  modular_arithmetic


  3%|▎         | 3204/100000 [56:35<50:56:59,  1.89s/it] 

Step 3200, Eval acc: 0.14627835154533386, Eval final acc: 0.14629584550857544


  3%|▎         | 3300/100000 [58:03<25:01:42,  1.07it/s]

Chomsky Task:  modular_arithmetic


  3%|▎         | 3304/100000 [58:15<50:07:12,  1.87s/it] 

Step 3300, Eval acc: 0.13882625102996826, Eval final acc: 0.13877323269844055


  3%|▎         | 3400/100000 [59:43<26:04:22,  1.03it/s]

Chomsky Task:  modular_arithmetic


  3%|▎         | 3404/100000 [59:56<52:30:17,  1.96s/it] 

Step 3400, Eval acc: 0.15327170491218567, Eval final acc: 0.15326711535453796


  4%|▎         | 3500/100000 [1:01:21<23:55:25,  1.12it/s]

Chomsky Task:  modular_arithmetic


  4%|▎         | 3504/100000 [1:01:33<48:19:28,  1.80s/it] 

Step 3500, Eval acc: 0.15604159235954285, Eval final acc: 0.15604352951049805


  4%|▎         | 3600/100000 [1:02:57<24:18:05,  1.10it/s]

Chomsky Task:  modular_arithmetic


  4%|▎         | 3604/100000 [1:03:09<48:56:59,  1.83s/it] 

Step 3600, Eval acc: 0.14521488547325134, Eval final acc: 0.14518818259239197


  4%|▎         | 3700/100000 [1:04:38<25:15:20,  1.06it/s]

Chomsky Task:  modular_arithmetic


  4%|▎         | 3704/100000 [1:04:51<50:47:50,  1.90s/it] 

Step 3700, Eval acc: 0.1455685794353485, Eval final acc: 0.14557942748069763


  4%|▍         | 3800/100000 [1:06:15<24:14:32,  1.10it/s]

Chomsky Task:  modular_arithmetic


  4%|▍         | 3804/100000 [1:06:27<48:52:42,  1.83s/it] 

Step 3800, Eval acc: 0.16766729950904846, Eval final acc: 0.16764938831329346


  4%|▍         | 3900/100000 [1:07:52<24:07:32,  1.11it/s]

Chomsky Task:  modular_arithmetic


  4%|▍         | 3904/100000 [1:08:03<47:42:27,  1.79s/it] 

Step 3900, Eval acc: 0.1528073400259018, Eval final acc: 0.15286216139793396


  4%|▍         | 4000/100000 [1:09:28<24:15:57,  1.10it/s]

Chomsky Task:  modular_arithmetic


  4%|▍         | 4004/100000 [1:09:39<47:53:31,  1.80s/it] 

Step 4000, Eval acc: 0.1480754315853119, Eval final acc: 0.1480807363986969


  4%|▍         | 4100/100000 [1:11:05<24:06:18,  1.11it/s]

Chomsky Task:  modular_arithmetic


  4%|▍         | 4104/100000 [1:11:17<47:21:22,  1.78s/it] 

Step 4100, Eval acc: 0.15229371190071106, Eval final acc: 0.1523023396730423


  4%|▍         | 4200/100000 [1:21:42<40:13:56,  1.51s/it]   

Chomsky Task:  modular_arithmetic


  4%|▍         | 4204/100000 [1:23:12<328:14:35, 12.34s/it]

Step 4200, Eval acc: 0.145132914185524, Eval final acc: 0.14515350759029388


  4%|▍         | 4300/100000 [1:28:26<27:20:24,  1.03s/it] 

Chomsky Task:  modular_arithmetic


  4%|▍         | 4304/100000 [1:28:39<51:54:04,  1.95s/it] 

Step 4300, Eval acc: 0.15235105156898499, Eval final acc: 0.15237991511821747


  4%|▍         | 4400/100000 [1:30:08<24:30:32,  1.08it/s]

Chomsky Task:  modular_arithmetic


  4%|▍         | 4403/100000 [1:30:20<61:01:29,  2.30s/it] 

Step 4400, Eval acc: 0.1323196142911911, Eval final acc: 0.13231059908866882


  4%|▍         | 4500/100000 [1:31:47<25:38:02,  1.03it/s]

Chomsky Task:  modular_arithmetic


  5%|▍         | 4504/100000 [1:32:01<54:59:01,  2.07s/it] 

Step 4500, Eval acc: 0.1602381467819214, Eval final acc: 0.16027073562145233


  5%|▍         | 4600/100000 [1:33:31<25:01:19,  1.06it/s]

Chomsky Task:  modular_arithmetic


  5%|▍         | 4604/100000 [1:33:44<53:25:57,  2.02s/it] 

Step 4600, Eval acc: 0.13914185762405396, Eval final acc: 0.13916060328483582


  5%|▍         | 4700/100000 [1:35:21<25:02:59,  1.06it/s]

Chomsky Task:  modular_arithmetic


  5%|▍         | 4704/100000 [1:35:33<52:14:58,  1.97s/it] 

Step 4700, Eval acc: 0.15699675679206848, Eval final acc: 0.157059907913208


  5%|▍         | 4800/100000 [1:37:08<26:55:30,  1.02s/it]

Chomsky Task:  modular_arithmetic


  5%|▍         | 4804/100000 [1:37:20<52:05:58,  1.97s/it] 

Step 4800, Eval acc: 0.16192933917045593, Eval final acc: 0.16193966567516327


  5%|▍         | 4900/100000 [1:38:53<30:51:59,  1.17s/it]

Chomsky Task:  modular_arithmetic


  5%|▍         | 4904/100000 [1:39:05<52:55:21,  2.00s/it] 

Step 4900, Eval acc: 0.15336664021015167, Eval final acc: 0.15336163341999054


  5%|▌         | 5000/100000 [1:40:34<26:20:08,  1.00it/s]

Chomsky Task:  modular_arithmetic


  5%|▌         | 5004/100000 [1:40:47<52:35:52,  1.99s/it] 

Step 5000, Eval acc: 0.15415874123573303, Eval final acc: 0.15411527454853058


  5%|▌         | 5100/100000 [1:42:20<27:05:00,  1.03s/it]

Chomsky Task:  modular_arithmetic


  5%|▌         | 5104/100000 [1:42:32<51:13:07,  1.94s/it] 

Step 5100, Eval acc: 0.18136171996593475, Eval final acc: 0.18144065141677856


  5%|▌         | 5200/100000 [1:44:01<24:43:35,  1.06it/s]

Chomsky Task:  modular_arithmetic


  5%|▌         | 5204/100000 [1:44:14<51:15:19,  1.95s/it] 

Step 5200, Eval acc: 0.16520646214485168, Eval final acc: 0.16520735621452332


  5%|▌         | 5300/100000 [1:45:43<25:53:50,  1.02it/s]

Chomsky Task:  modular_arithmetic


  5%|▌         | 5304/100000 [1:45:55<50:36:37,  1.92s/it] 

Step 5300, Eval acc: 0.16089779138565063, Eval final acc: 0.16088327765464783


  5%|▌         | 5400/100000 [1:47:25<25:02:26,  1.05it/s]

Chomsky Task:  modular_arithmetic


  5%|▌         | 5404/100000 [1:47:38<52:00:48,  1.98s/it] 

Step 5400, Eval acc: 0.17379114031791687, Eval final acc: 0.17382481694221497


  6%|▌         | 5500/100000 [1:49:08<25:27:17,  1.03it/s]

Chomsky Task:  modular_arithmetic


  6%|▌         | 5504/100000 [1:49:20<50:07:51,  1.91s/it] 

Step 5500, Eval acc: 0.17197206616401672, Eval final acc: 0.17203101515769958


  6%|▌         | 5600/100000 [1:50:49<25:20:29,  1.03it/s]

Chomsky Task:  modular_arithmetic


  6%|▌         | 5604/100000 [1:51:01<49:09:04,  1.87s/it] 

Step 5600, Eval acc: 0.1454876959323883, Eval final acc: 0.14551936089992523


  6%|▌         | 5700/100000 [1:52:30<25:07:02,  1.04it/s]

Chomsky Task:  modular_arithmetic


  6%|▌         | 5704/100000 [1:52:42<48:45:29,  1.86s/it] 

Step 5700, Eval acc: 0.168991819024086, Eval final acc: 0.16901499032974243


  6%|▌         | 5800/100000 [1:54:10<24:04:31,  1.09it/s]

Chomsky Task:  modular_arithmetic


  6%|▌         | 5804/100000 [1:54:22<49:58:42,  1.91s/it] 

Step 5800, Eval acc: 0.15017439424991608, Eval final acc: 0.1501655876636505


  6%|▌         | 5900/100000 [1:55:53<25:13:44,  1.04it/s]

Chomsky Task:  modular_arithmetic


  6%|▌         | 5904/100000 [1:56:11<67:25:44,  2.58s/it] 

Step 5900, Eval acc: 0.14435601234436035, Eval final acc: 0.14431986212730408


  6%|▌         | 6000/100000 [1:57:40<24:15:07,  1.08it/s]

Chomsky Task:  modular_arithmetic


  6%|▌         | 6004/100000 [1:57:52<48:44:28,  1.87s/it] 

Step 6000, Eval acc: 0.14535701274871826, Eval final acc: 0.1453121453523636


  6%|▌         | 6100/100000 [1:59:19<24:18:30,  1.07it/s]

Chomsky Task:  modular_arithmetic


  6%|▌         | 6104/100000 [1:59:31<48:35:16,  1.86s/it] 

Step 6100, Eval acc: 0.15310350060462952, Eval final acc: 0.15309938788414001


  6%|▌         | 6200/100000 [2:01:08<25:04:10,  1.04it/s]

Chomsky Task:  modular_arithmetic


  6%|▌         | 6204/100000 [2:01:21<50:51:39,  1.95s/it] 

Step 6200, Eval acc: 0.16251665353775024, Eval final acc: 0.1625455915927887


  6%|▋         | 6300/100000 [2:03:00<24:36:06,  1.06it/s]

Chomsky Task:  modular_arithmetic


  6%|▋         | 6304/100000 [2:03:12<49:45:21,  1.91s/it] 

Step 6300, Eval acc: 0.1726336032152176, Eval final acc: 0.1726003736257553


  6%|▋         | 6400/100000 [2:04:45<24:56:43,  1.04it/s]

Chomsky Task:  modular_arithmetic


  6%|▋         | 6404/100000 [2:04:57<50:17:20,  1.93s/it] 

Step 6400, Eval acc: 0.15953753888607025, Eval final acc: 0.1595081388950348


  6%|▋         | 6500/100000 [2:06:26<24:27:05,  1.06it/s]

Chomsky Task:  modular_arithmetic


  7%|▋         | 6504/100000 [2:06:38<50:17:51,  1.94s/it] 

Step 6500, Eval acc: 0.1454695463180542, Eval final acc: 0.14547036588191986


  7%|▋         | 6600/100000 [2:08:08<24:23:36,  1.06it/s]

Chomsky Task:  modular_arithmetic


  7%|▋         | 6604/100000 [2:08:20<48:23:48,  1.87s/it] 

Step 6600, Eval acc: 0.1573750376701355, Eval final acc: 0.15733352303504944


  7%|▋         | 6700/100000 [2:09:49<24:00:32,  1.08it/s]

Chomsky Task:  modular_arithmetic


  7%|▋         | 6704/100000 [2:10:01<47:52:55,  1.85s/it] 

Step 6700, Eval acc: 0.17338204383850098, Eval final acc: 0.17347732186317444


  7%|▋         | 6800/100000 [2:11:31<23:47:41,  1.09it/s]

Chomsky Task:  modular_arithmetic


  7%|▋         | 6804/100000 [2:11:43<47:23:36,  1.83s/it] 

Step 6800, Eval acc: 0.15345676243305206, Eval final acc: 0.15339821577072144


  7%|▋         | 6900/100000 [2:13:10<25:04:47,  1.03it/s]

Chomsky Task:  modular_arithmetic


  7%|▋         | 6904/100000 [2:13:23<50:26:12,  1.95s/it] 

Step 6900, Eval acc: 0.16161048412322998, Eval final acc: 0.1615002453327179


  7%|▋         | 7000/100000 [2:14:54<23:58:27,  1.08it/s]

Chomsky Task:  modular_arithmetic


  7%|▋         | 7004/100000 [2:15:05<47:29:50,  1.84s/it] 

Step 7000, Eval acc: 0.15072350203990936, Eval final acc: 0.15077510476112366


  7%|▋         | 7100/100000 [2:16:40<30:09:04,  1.17s/it]

Chomsky Task:  modular_arithmetic


  7%|▋         | 7104/100000 [2:16:53<52:49:39,  2.05s/it] 

Step 7100, Eval acc: 0.15600775182247162, Eval final acc: 0.15600356459617615


  7%|▋         | 7200/100000 [2:18:33<27:22:47,  1.06s/it]

Chomsky Task:  modular_arithmetic


  7%|▋         | 7204/100000 [2:18:46<55:22:51,  2.15s/it] 

Step 7200, Eval acc: 0.16102595627307892, Eval final acc: 0.16101260483264923


  7%|▋         | 7300/100000 [2:20:17<24:50:26,  1.04it/s]

Chomsky Task:  modular_arithmetic


  7%|▋         | 7304/100000 [2:20:30<51:04:19,  1.98s/it] 

Step 7300, Eval acc: 0.16531100869178772, Eval final acc: 0.1653333455324173


  7%|▋         | 7400/100000 [2:22:01<24:56:48,  1.03it/s]

Chomsky Task:  modular_arithmetic


  7%|▋         | 7404/100000 [2:22:14<48:49:57,  1.90s/it] 

Step 7400, Eval acc: 0.17398540675640106, Eval final acc: 0.1740311086177826


  8%|▊         | 7500/100000 [2:23:44<24:45:24,  1.04it/s]

Chomsky Task:  modular_arithmetic


  8%|▊         | 7504/100000 [2:23:59<57:44:40,  2.25s/it] 

Step 7500, Eval acc: 0.15517720580101013, Eval final acc: 0.1551615446805954


  8%|▊         | 7600/100000 [2:25:32<23:55:41,  1.07it/s]

Chomsky Task:  modular_arithmetic


  8%|▊         | 7604/100000 [2:25:44<48:35:02,  1.89s/it] 

Step 7600, Eval acc: 0.1594967395067215, Eval final acc: 0.15954826772212982


  8%|▊         | 7633/100000 [2:26:11<25:14:31,  1.02it/s]

In [6]:
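# Plot the average vs. final evaluation accuracy curves.
# NOTE: `eval_accs` and `eval_final_accs` are local to `training_loop()`, not
# module-level variables. Before uncommenting, load them from the metrics CSV
# it writes (columns "eval_accs" and "eval_final_accs").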
# plt.plot(eval_accs, label="avg", color="red")
# plt.plot(eval_final_accs, label="final", color="blue")
# plt.legend()
# plt.show()

In [7]:
# For testing only; the actual Chomsky evaluation is done in run_chomsky_experiments.ipynb.
# NOTE: `params` is local to `training_loop()` and is not defined at module level.
# regret, total_accuracy, total_final_accuracy = evaluate_transformer_decoder(
#     chomsky_generator, params, utm_generator, num_batches=10, size=MODEL_SIZE
# )
# print(total_accuracy, total_final_accuracy)