In [1]:
import haiku as hk
import jax
import jax.numpy as jnp
import functools
import numpy as np
import optax
import tqdm
import tree
import pandas as pd
import matplotlib.pyplot as plt

from absl import logging
from typing import Any

from data import utm_data_generator as utm_dg_lib
from data import chomsky_data_generator as chomsky_sampler_lib
from helpers import make_chomsky_generator, utm_data_generator, make_model, init_params, save_params, evaluate_transformer_decoder, CHOMSKY_ALPHABET_SIZE

In [2]:
# Random seed for the experiment.
seed = 1

# Training budget and batch shape (follows the paper's parameters).
TRAINING_STEPS = 100_000
BATCH_SIZE = 32

# Settings forwarded to the UTM data generator.
EXECUTION_STEPS = 1_000
MEMORY_SIZE = 200

# Setting forwarded to the Chomsky evaluation generator.
USE_DELIMITERS = True

# Experiment variant.
# USE_MARKOV is one of: True, False.
USE_MARKOV = True
# MODEL_SIZE is one of: "small", "medium", "large".
MODEL_SIZE = "large"

In [3]:
def _make_loss_fn(model: hk.Transformed) -> Any:
  """Builds the sequence log-loss closure used by `_update_parameters`.

  Args:
    model: Transformed model whose `apply(params=..., targets=..., rng=...)`
      output is indexed by token along its last axis.

  Returns:
    A function `loss_fn(params, sequences, mask)` producing a scalar loss.
  """

  def loss_fn(
      params: hk.Params,
      sequences: jax.Array,
      mask: jax.Array,
  ) -> jnp.float32:
    """Computes the batch-averaged negative log-likelihood of `sequences`.

    Args:
      params: Parameters handed to `model.apply`.
      sequences: Integer tokens. They are fed to the model as targets and also
        used as indices into the last axis of the model output.
      mask: Boolean array; positions that are True contribute 0 to the loss,
        positions that are False are scored.

    Returns:
      Minus the batch mean of the per-sequence summed (unmasked) values picked
      from the model output.
    """
    # Per-token conditionals predicted by the model (treated as log-probs).
    log_probs = model.apply(params=params, targets=sequences, rng=None)

    # Pick, at every position, the entry belonging to the observed token.
    token_index = sequences[..., None]
    picked = jnp.take_along_axis(log_probs, token_index, axis=-1)[..., 0]

    # Masked positions are replaced by 0 so they do not affect the sum.
    kept = jnp.where(mask, 0.0, picked)

    # Sum over the sequence axis, then average over the batch and negate so
    # the result can be minimized.
    sequence_log_likelihood = kept.sum(axis=1)  # Shape (B,).
    return -sequence_log_likelihood.mean()

  return loss_fn


@functools.partial(
    jax.jit, static_argnames=('optimizer', 'grad_fn', 'normalize_gradients')
)
def _update_parameters(
    params: hk.Params,
    opt_state: optax.OptState,
    sequences: jax.Array,
    mask: jax.Array,
    grad_fn: Any,
    optimizer: optax.GradientTransformation,
    normalize_gradients: bool = True,
) -> tuple[hk.Params, optax.OptState, dict[str, Any]]:
  """Performs one jitted optimization step and reports loss / gradient norm.

  The gradient is taken over the whole batch of sequences in a single call to
  `grad_fn`.

  Args:
    params: Current network parameters.
    opt_state: Current optimizer state.
    sequences: Batch of sequences the gradient is computed on.
    mask: Boolean array forwarded to `grad_fn`; True marks positions that the
      loss should skip.
    grad_fn: Callable invoked as `grad_fn(params, sequences, mask)` that
      returns a `(loss, gradient)` pair. Treated as a static jit argument.
    optimizer: Optax gradient transformation. Treated as a static jit argument.
    normalize_gradients: If True, every gradient leaf is divided by
      `sequences.shape[1]` so the gradient scale is comparable across
      sequence lengths.

  Returns:
    A `(new_params, new_opt_state, logs)` tuple, where `logs` holds the
    entries 'loss' and 'grad_norm_unclipped'. The norm is computed on the
    gradient after the optional length normalization.
  """
  loss_value, gradients = grad_fn(params, sequences, mask)

  if normalize_gradients:
    seq_len = float(sequences.shape[1])

    def _rescale(leaf):
      return leaf / seq_len

    gradients = tree.map_structure(_rescale, gradients)

  param_updates, next_opt_state = optimizer.update(gradients, opt_state)
  next_params = optax.apply_updates(params, param_updates)

  metrics = dict(
      loss=loss_value,
      grad_norm_unclipped=optax.global_norm(gradients),
  )
  return next_params, next_opt_state, metrics

In [4]:
def train_transformer_decoder(
    data_generator: utm_dg_lib.UTMDataGenerator,
    training_steps: int,
    log_every: int,
    batch_size: int,
    use_tqdm: bool = True,
    with_markov: bool = False,
    size: str = "large",
    eval_data_generator: chomsky_sampler_lib.ChomskyDataGenerator = None,
) -> tuple[hk.Params, float, list[float], list[float], list[float]]:
    """Trains a neural network on some synthetic data.

    We train a decoder-only transformer on batches, minimizing the log-loss
    objective built by `_make_loss_fn`, using Adam with a learning rate of 1e-4.

    Args:
      data_generator: Used to generate batches of data to train on. Its
        `sample(with_markov=...)` must return a `(batch, log_dict)` pair where
        `batch` is one-hot along its last axis.
      training_steps: Number of batches to train on.
      log_every: How often (in steps) to log the loss. If negative or 0, no
        logging and no periodic evaluation happen at all.
      batch_size: The number of sequences in a batch (used to initialize the
        parameters).
      use_tqdm: Whether to use a progress bar or not.
      with_markov: Forwarded to `data_generator.sample`.
      size: Model size identifier forwarded to `make_model` and to
        `evaluate_transformer_decoder`.
      eval_data_generator: Optional generator used for periodic evaluation.
        When given (and `log_every > 0`), evaluation runs every
        `10 * log_every` steps, starting at step 0.

    Returns:
      A tuple `(params, last_loss, eval_losses, eval_accs, eval_final_accs)`:
      the final parameters, the training loss of the last step executed
      (0.0 if `training_steps` is 0), and three lists holding, per
      evaluation, the first, second and third value returned by
      `evaluate_transformer_decoder`.
    """
    print("Vocab Size:", data_generator.feature_size)
    print("Model Size:", size)
    print("Batch Size:", batch_size)
    print("With Markov:", with_markov)
    model = make_model(data_generator, size)

    params = init_params(model, data_generator, batch_size)

    # Make gradient function.
    loss_fn = _make_loss_fn(model)
    grad_fn = jax.value_and_grad(loss_fn, has_aux=False)

    # Make optimizer, to apply the gradients.
    optimizer = optax.adam(learning_rate=1e-4)
    opt_state = optimizer.init(params)

    logging.info("Initialization done, starting training...")

    # A non-positive `log_every` disables logging; it must also disable the
    # derived evaluation cadence, otherwise `step % 0` raises
    # ZeroDivisionError.
    logging_enabled = log_every > 0
    eval_every = log_every * 10 if logging_enabled else 0
    evaluation_enabled = eval_data_generator is not None and eval_every > 0

    last_loss = 0.0
    eval_losses = []
    eval_accs = []
    eval_final_accs = []

    for step in tqdm.trange(training_steps, disable=not use_tqdm):
        batch, log_dict = data_generator.sample(with_markov=with_markov)
        # Transform one-hots to integer tokens.
        batch = np.argmax(batch, axis=-1)
        if "loss_mask" in log_dict:
            loss_mask = log_dict["loss_mask"]
        else:
            # In the loss, True means "skip this position". Without an explicit
            # mask every position must be scored, so the fallback is all-False
            # (an all-True mask would make the loss identically zero).
            loss_mask = np.zeros(batch.shape[:2], dtype=bool)

        params, opt_state, logs = _update_parameters(
            params=params,
            opt_state=opt_state,
            sequences=batch,
            grad_fn=grad_fn,
            optimizer=optimizer,
            mask=loss_mask,
        )
        # Keep a reference to the most recent loss; the conversion to a Python
        # float is deferred to the end to avoid a host sync on every step.
        last_loss = logs["loss"]

        if logging_enabled and step % log_every == 0:
            logging.info(
                "Step %d, Loss (avg cumulative nats) %f, Grad norm %f",
                step,
                logs["loss"],
                logs["grad_norm_unclipped"],
            )

        if evaluation_enabled and step % eval_every == 0:
            eval_loss, eval_acc, eval_final_acc = evaluate_transformer_decoder(
                eval_data_generator, params, data_generator, size=size
            )
            eval_losses.append(eval_loss)
            eval_accs.append(eval_acc)
            eval_final_accs.append(eval_final_acc)
            print(
                f"Step {step}, Eval acc: {eval_acc}, Eval final acc: {eval_final_acc}"
            )

    return params, float(last_loss), eval_losses, eval_accs, eval_final_accs

In [11]:
# Seed the generator from the config cell instead of repeating the literal,
# so changing `seed` there actually changes the run.
rng = np.random.default_rng(seed=seed)

# Training data: sequences produced by the UTM data generator.
utm_generator = utm_data_generator(
    rng,
    maximum_steps=EXECUTION_STEPS,
    maximum_program_length=100,
    memory_size=MEMORY_SIZE,
    alphabet_size=CHOMSKY_ALPHABET_SIZE,
    batch_size=BATCH_SIZE,
)

# Evaluation data: a Chomsky task generator. It is handed the same `rng`
# object as the training generator above.
chomsky_generator = make_chomsky_generator(rng, use_delimiters=USE_DELIMITERS, task_str="modular_arithmetic")

def training_loop(model_sizes=None, use_markovs=None, log_every=10):
    """Trains one model per (use_markov, model_size) pair and saves the results.

    For every combination this calls `train_transformer_decoder` on
    `utm_generator` (evaluating on `chomsky_generator`), then writes the
    parameters to `params_<suffix>_transformer_<size>.npz` and the evaluation
    metrics to `metrics_<suffix>_transformer_<size>.csv` in the working
    directory, where `<suffix>` is "markov" or "original".

    Args:
      model_sizes: Iterable of model size identifiers to train, e.g.
        ["small", "medium", "large"]. Defaults to `[MODEL_SIZE]` from the
        config cell.
      use_markovs: Iterable of booleans for the `with_markov` setting, e.g.
        [True, False]. Defaults to `[USE_MARKOV]` from the config cell.
      log_every: Logging period in steps, forwarded to
        `train_transformer_decoder`.
    """
    # Fall back to the config cell so that editing MODEL_SIZE / USE_MARKOV
    # there is enough to change what gets trained.
    if model_sizes is None:
        model_sizes = [MODEL_SIZE]
    if use_markovs is None:
        use_markovs = [USE_MARKOV]

    for use_markov in use_markovs:
        for model_size in model_sizes:
            params, loss, eval_losses, eval_accs, eval_final_accs = train_transformer_decoder(
                data_generator=utm_generator,
                training_steps=TRAINING_STEPS,
                batch_size=BATCH_SIZE,
                log_every=log_every,
                with_markov=use_markov,
                size=model_size,
                eval_data_generator=chomsky_generator,
            )
            logging.info("Final loss: %f", loss)

            suffix = "markov" if use_markov else "original"

            file_name = f"params_{suffix}_transformer_{model_size}.npz"
            save_params(params, file_name)

            logging.info("Parameters saved in file %s", file_name)

            # One row per evaluation round, one column per metric.
            eval_df = pd.DataFrame(
                {
                    "eval_losses": eval_losses,
                    "eval_accs": eval_accs,
                    "eval_final_accs": eval_final_accs,
                }
            )

            # Save the DataFrame to a CSV file
            metrics_name = f"metrics_{suffix}_transformer_{model_size}.csv"
            eval_df.to_csv(metrics_name, index=False)

            logging.info("Evaluation metrics saved to %s", metrics_name)

# Run the training sweep. Side effects: writes the params_*.npz and
# metrics_*.csv files for every trained configuration to the working directory.
training_loop()

Vocab Size: 128
Model Size: small
Batch Size: 32
With Markov: True


  0%|          | 0/2000 [00:00<?, ?it/s]

Chomsky Task:  even_pairs


  0%|          | 4/2000 [00:06<40:50,  1.23s/it]  

Step 0, Eval acc: 0.007958902046084404, Eval final acc: 0.007958405651152134


  5%|▌         | 100/2000 [00:10<01:12, 26.17it/s]

Chomsky Task:  even_pairs


  5%|▌         | 106/2000 [00:13<08:00,  3.94it/s]

Step 100, Eval acc: 0.01636669784784317, Eval final acc: 0.016366617754101753


 10%|▉         | 199/2000 [00:17<01:08, 26.32it/s]

Chomsky Task:  even_pairs


 10%|█         | 205/2000 [00:20<07:18,  4.09it/s]

Step 200, Eval acc: 0.02849704399704933, Eval final acc: 0.028496259823441505


 15%|█▍        | 298/2000 [00:23<01:03, 26.62it/s]

Chomsky Task:  even_pairs


 15%|█▌        | 304/2000 [00:26<05:08,  5.49it/s]

Step 300, Eval acc: 0.040885526686906815, Eval final acc: 0.04088742285966873


 20%|██        | 400/2000 [00:29<00:59, 26.96it/s]

Chomsky Task:  even_pairs


 20%|██        | 406/2000 [00:32<05:07,  5.19it/s]

Step 400, Eval acc: 0.056422971189022064, Eval final acc: 0.05642212554812431


 25%|██▍       | 499/2000 [00:35<00:58, 25.48it/s]

Chomsky Task:  even_pairs


 25%|██▌       | 505/2000 [00:38<04:29,  5.54it/s]

Step 500, Eval acc: 0.07645846158266068, Eval final acc: 0.07645781338214874


 30%|██▉       | 598/2000 [00:41<00:59, 23.69it/s]

Chomsky Task:  even_pairs


 30%|███       | 604/2000 [00:43<03:37,  6.42it/s]

Step 600, Eval acc: 0.10110951960086823, Eval final acc: 0.10110883414745331


 35%|███▌      | 700/2000 [00:47<00:50, 25.92it/s]

Chomsky Task:  even_pairs


 35%|███▌      | 706/2000 [00:50<04:53,  4.42it/s]

Step 700, Eval acc: 0.13346965610980988, Eval final acc: 0.13349561393260956


 40%|███▉      | 799/2000 [00:54<00:47, 25.30it/s]

Chomsky Task:  even_pairs


 40%|████      | 805/2000 [00:56<03:13,  6.18it/s]

Step 800, Eval acc: 0.16318237781524658, Eval final acc: 0.16315510869026184


 45%|████▍     | 898/2000 [00:59<00:47, 22.97it/s]

Chomsky Task:  even_pairs


 45%|████▌     | 903/2000 [01:01<02:54,  6.29it/s]

Step 900, Eval acc: 0.1966998130083084, Eval final acc: 0.19666343927383423


 50%|████▉     | 999/2000 [01:05<00:38, 25.67it/s]

Chomsky Task:  even_pairs


 50%|█████     | 1005/2000 [01:07<02:13,  7.47it/s]

Step 1000, Eval acc: 0.22849054634571075, Eval final acc: 0.22846587002277374


 55%|█████▍    | 1098/2000 [01:10<00:39, 22.86it/s]

Chomsky Task:  even_pairs


 55%|█████▌    | 1104/2000 [01:13<02:53,  5.17it/s]

Step 1100, Eval acc: 0.2546849846839905, Eval final acc: 0.2546969950199127


 60%|██████    | 1200/2000 [01:16<00:32, 24.81it/s]

Chomsky Task:  even_pairs


 60%|██████    | 1206/2000 [01:18<01:51,  7.10it/s]

Step 1200, Eval acc: 0.27089476585388184, Eval final acc: 0.27090197801589966


 65%|██████▍   | 1299/2000 [01:22<00:32, 21.38it/s]

Chomsky Task:  even_pairs


 65%|██████▌   | 1305/2000 [01:23<01:12,  9.62it/s]

Step 1300, Eval acc: 0.2901690900325775, Eval final acc: 0.29018282890319824


 70%|██████▉   | 1398/2000 [01:27<00:22, 26.28it/s]

Chomsky Task:  even_pairs


 70%|███████   | 1404/2000 [01:28<01:05,  9.15it/s]

Step 1400, Eval acc: 0.3093836009502411, Eval final acc: 0.31167200207710266


 75%|███████▌  | 1500/2000 [01:32<00:19, 25.41it/s]

Chomsky Task:  even_pairs


 75%|███████▌  | 1503/2000 [01:33<01:04,  7.66it/s]

Step 1500, Eval acc: 0.3167768716812134, Eval final acc: 0.31671327352523804


 80%|███████▉  | 1598/2000 [01:36<00:16, 25.02it/s]

Chomsky Task:  even_pairs


 80%|████████  | 1604/2000 [01:38<00:42,  9.34it/s]

Step 1600, Eval acc: 0.33109691739082336, Eval final acc: 0.3310949206352234


 85%|████████▌ | 1700/2000 [01:41<00:11, 25.31it/s]

Chomsky Task:  even_pairs


 85%|████████▌ | 1706/2000 [01:43<00:38,  7.60it/s]

Step 1700, Eval acc: 0.335883766412735, Eval final acc: 0.3358166813850403


 90%|████████▉ | 1799/2000 [01:47<00:08, 23.01it/s]

Chomsky Task:  even_pairs


 90%|█████████ | 1805/2000 [01:48<00:20,  9.30it/s]

Step 1800, Eval acc: 0.35400864481925964, Eval final acc: 0.3539075553417206


 95%|█████████▍| 1898/2000 [01:52<00:04, 22.28it/s]

Chomsky Task:  even_pairs


 95%|█████████▌| 1904/2000 [01:54<00:14,  6.79it/s]

Step 1900, Eval acc: 0.3603024482727051, Eval final acc: 0.36032363772392273


100%|██████████| 2000/2000 [01:57<00:00, 16.95it/s]


Vocab Size: 128
Model Size: medium
Batch Size: 32
With Markov: True


  0%|          | 0/2000 [00:00<?, ?it/s]

Chomsky Task:  even_pairs


  0%|          | 3/2000 [00:06<53:33,  1.61s/it]  

Step 0, Eval acc: 0.01046818122267723, Eval final acc: 0.01043485477566719


  5%|▌         | 100/2000 [00:28<07:51,  4.03it/s]

Chomsky Task:  even_pairs


  5%|▌         | 104/2000 [00:32<18:44,  1.69it/s]

Step 100, Eval acc: 0.22034886479377747, Eval final acc: 0.218496173620224


 10%|█         | 200/2000 [00:54<06:43,  4.46it/s]

Chomsky Task:  even_pairs


 10%|█         | 204/2000 [00:58<17:09,  1.75it/s]

Step 200, Eval acc: 0.31026580929756165, Eval final acc: 0.31105977296829224


 15%|█▌        | 300/2000 [01:20<06:40,  4.24it/s]

Chomsky Task:  even_pairs


 15%|█▌        | 304/2000 [01:24<15:20,  1.84it/s]

Step 300, Eval acc: 0.3564973771572113, Eval final acc: 0.3564954996109009


 20%|██        | 400/2000 [01:45<06:09,  4.33it/s]

Chomsky Task:  even_pairs


 20%|██        | 404/2000 [01:49<14:30,  1.83it/s]

Step 400, Eval acc: 0.3968147933483124, Eval final acc: 0.39681917428970337


 25%|██▌       | 500/2000 [02:11<05:49,  4.30it/s]

Chomsky Task:  even_pairs


 25%|██▌       | 504/2000 [02:15<14:59,  1.66it/s]

Step 500, Eval acc: 0.41329818964004517, Eval final acc: 0.4132329523563385


 30%|███       | 600/2000 [02:37<06:10,  3.78it/s]

Chomsky Task:  even_pairs


 30%|███       | 604/2000 [02:41<13:59,  1.66it/s]

Step 600, Eval acc: 0.4275030195713043, Eval final acc: 0.42751413583755493


 35%|███▌      | 700/2000 [03:03<05:14,  4.13it/s]

Chomsky Task:  even_pairs


 35%|███▌      | 704/2000 [03:07<12:11,  1.77it/s]

Step 700, Eval acc: 0.4358794093132019, Eval final acc: 0.4358961582183838


 40%|████      | 800/2000 [03:29<04:27,  4.49it/s]

Chomsky Task:  even_pairs


 40%|████      | 803/2000 [03:32<13:13,  1.51it/s]

Step 800, Eval acc: 0.44898638129234314, Eval final acc: 0.44901242852211


 45%|████▌     | 900/2000 [03:54<04:17,  4.26it/s]

Chomsky Task:  even_pairs


 45%|████▌     | 903/2000 [03:58<12:27,  1.47it/s]

Step 900, Eval acc: 0.4449065327644348, Eval final acc: 0.4448942244052887


 50%|█████     | 1000/2000 [04:20<04:07,  4.04it/s]

Chomsky Task:  even_pairs


 50%|█████     | 1004/2000 [04:39<43:23,  2.61s/it]  

Step 1000, Eval acc: 0.45243018865585327, Eval final acc: 0.4524269998073578


 55%|█████▌    | 1100/2000 [05:02<03:24,  4.41it/s]

Chomsky Task:  even_pairs


 55%|█████▌    | 1104/2000 [05:05<07:35,  1.97it/s]

Step 1100, Eval acc: 0.45805415511131287, Eval final acc: 0.4580756723880768


 60%|██████    | 1200/2000 [05:26<03:05,  4.30it/s]

Chomsky Task:  even_pairs


 60%|██████    | 1204/2000 [05:30<07:23,  1.79it/s]

Step 1200, Eval acc: 0.46260109543800354, Eval final acc: 0.46267908811569214


 65%|██████▌   | 1300/2000 [05:51<02:36,  4.48it/s]

Chomsky Task:  even_pairs


 65%|██████▌   | 1303/2000 [05:54<07:13,  1.61it/s]

Step 1300, Eval acc: 0.46420541405677795, Eval final acc: 0.46420708298683167


 70%|███████   | 1400/2000 [06:16<02:19,  4.30it/s]

Chomsky Task:  even_pairs


 70%|███████   | 1404/2000 [06:20<05:26,  1.82it/s]

Step 1400, Eval acc: 0.46998634934425354, Eval final acc: 0.469992071390152


 75%|███████▌  | 1500/2000 [06:41<01:56,  4.30it/s]

Chomsky Task:  even_pairs


 75%|███████▌  | 1504/2000 [06:45<04:32,  1.82it/s]

Step 1500, Eval acc: 0.46396350860595703, Eval final acc: 0.46398186683654785


 80%|████████  | 1600/2000 [07:07<01:38,  4.07it/s]

Chomsky Task:  even_pairs


 80%|████████  | 1604/2000 [07:10<03:44,  1.77it/s]

Step 1600, Eval acc: 0.45913347601890564, Eval final acc: 0.4590676724910736


 85%|████████▌ | 1700/2000 [07:33<01:08,  4.38it/s]

Chomsky Task:  even_pairs


 85%|████████▌ | 1704/2000 [07:36<02:35,  1.90it/s]

Step 1700, Eval acc: 0.4729768633842468, Eval final acc: 0.4729904234409332


 90%|█████████ | 1800/2000 [07:57<00:45,  4.44it/s]

Chomsky Task:  even_pairs


 90%|█████████ | 1804/2000 [08:00<01:39,  1.96it/s]

Step 1800, Eval acc: 0.4673359990119934, Eval final acc: 0.4673217833042145


 95%|█████████▌| 1900/2000 [08:22<00:22,  4.38it/s]

Chomsky Task:  even_pairs


 95%|█████████▌| 1904/2000 [08:25<00:49,  1.96it/s]

Step 1900, Eval acc: 0.4664842486381531, Eval final acc: 0.4664286971092224


100%|██████████| 2000/2000 [08:47<00:00,  3.79it/s]


Vocab Size: 128
Model Size: large
Batch Size: 32
With Markov: True


  0%|          | 0/2000 [00:00<?, ?it/s]

Chomsky Task:  even_pairs


  0%|          | 4/2000 [00:10<1:04:35,  1.94s/it]

Step 0, Eval acc: 0.038628801703453064, Eval final acc: 0.038620077073574066


  5%|▌         | 100/2000 [01:41<30:27,  1.04it/s]

Chomsky Task:  even_pairs


  5%|▌         | 104/2000 [01:53<1:00:45,  1.92s/it]

Step 100, Eval acc: 0.45471954345703125, Eval final acc: 0.4547271728515625


 10%|█         | 200/2000 [04:07<37:22,  1.25s/it]  

Chomsky Task:  even_pairs


 10%|█         | 204/2000 [04:24<1:18:55,  2.64s/it]

Step 200, Eval acc: 0.4707615375518799, Eval final acc: 0.4707808494567871


 15%|█▌        | 300/2000 [05:59<26:33,  1.07it/s]  

Chomsky Task:  even_pairs


 15%|█▌        | 304/2000 [06:11<53:01,  1.88s/it]  

Step 300, Eval acc: 0.47670355439186096, Eval final acc: 0.476762056350708


 20%|██        | 400/2000 [07:53<24:15,  1.10it/s]  

Chomsky Task:  even_pairs


 20%|██        | 404/2000 [08:05<47:45,  1.80s/it]  

Step 400, Eval acc: 0.47783446311950684, Eval final acc: 0.4777982831001282


 25%|██▌       | 500/2000 [09:33<30:27,  1.22s/it]

Chomsky Task:  even_pairs


 25%|██▌       | 503/2000 [09:46<1:02:04,  2.49s/it]

Step 500, Eval acc: 0.47912994027137756, Eval final acc: 0.4791753888130188


 30%|███       | 600/2000 [11:17<22:18,  1.05it/s]  

Chomsky Task:  even_pairs


 30%|███       | 604/2000 [11:30<45:21,  1.95s/it]  

Step 600, Eval acc: 0.48145729303359985, Eval final acc: 0.48144450783729553


 35%|███▌      | 700/2000 [13:00<20:53,  1.04it/s]

Chomsky Task:  even_pairs


 35%|███▌      | 704/2000 [13:12<40:53,  1.89s/it]  

Step 700, Eval acc: 0.4858711361885071, Eval final acc: 0.4859001040458679


 40%|████      | 800/2000 [14:42<19:12,  1.04it/s]

Chomsky Task:  even_pairs


 40%|████      | 804/2000 [14:54<38:03,  1.91s/it]  

Step 800, Eval acc: 0.48320847749710083, Eval final acc: 0.4831902086734772


 45%|████▌     | 900/2000 [16:23<17:47,  1.03it/s]

Chomsky Task:  even_pairs


 45%|████▌     | 904/2000 [16:35<34:37,  1.90s/it]  

Step 900, Eval acc: 0.4858960211277008, Eval final acc: 0.4860699772834778


 50%|█████     | 1000/2000 [18:07<16:01,  1.04it/s]

Chomsky Task:  even_pairs


 50%|█████     | 1004/2000 [18:19<30:30,  1.84s/it]  

Step 1000, Eval acc: 0.48407354950904846, Eval final acc: 0.48409825563430786


 55%|█████▌    | 1100/2000 [19:50<14:08,  1.06it/s]

Chomsky Task:  even_pairs


 55%|█████▌    | 1104/2000 [20:03<28:26,  1.90s/it]  

Step 1100, Eval acc: 0.4635477066040039, Eval final acc: 0.46357792615890503


 60%|██████    | 1200/2000 [21:51<16:57,  1.27s/it]

Chomsky Task:  even_pairs


 60%|██████    | 1204/2000 [22:08<33:47,  2.55s/it]  

Step 1200, Eval acc: 0.4770578444004059, Eval final acc: 0.4770139157772064


 65%|██████▌   | 1300/2000 [23:59<14:18,  1.23s/it]

Chomsky Task:  even_pairs


 65%|██████▌   | 1304/2000 [24:13<26:01,  2.24s/it]

Step 1300, Eval acc: 0.47410717606544495, Eval final acc: 0.4741422235965729


 70%|███████   | 1400/2000 [25:53<09:40,  1.03it/s]

Chomsky Task:  even_pairs


 70%|███████   | 1404/2000 [26:05<18:04,  1.82s/it]

Step 1400, Eval acc: 0.47567200660705566, Eval final acc: 0.4756760597229004


 75%|███████▌  | 1500/2000 [27:30<07:36,  1.09it/s]

Chomsky Task:  even_pairs


 75%|███████▌  | 1504/2000 [27:42<14:42,  1.78s/it]

Step 1500, Eval acc: 0.482139527797699, Eval final acc: 0.48211774230003357


 80%|████████  | 1600/2000 [29:07<06:00,  1.11it/s]

Chomsky Task:  even_pairs


 80%|████████  | 1604/2000 [29:18<11:56,  1.81s/it]

Step 1600, Eval acc: 0.47045761346817017, Eval final acc: 0.4735780656337738


 85%|████████▌ | 1700/2000 [30:44<04:31,  1.11it/s]

Chomsky Task:  even_pairs


 85%|████████▌ | 1704/2000 [30:55<08:46,  1.78s/it]

Step 1700, Eval acc: 0.4767988324165344, Eval final acc: 0.47688227891921997


 90%|█████████ | 1800/2000 [32:20<03:01,  1.10it/s]

Chomsky Task:  even_pairs


 90%|█████████ | 1804/2000 [32:31<05:49,  1.78s/it]

Step 1800, Eval acc: 0.48402658104896545, Eval final acc: 0.4840807020664215


 95%|█████████▌| 1900/2000 [33:56<01:29,  1.12it/s]

Chomsky Task:  even_pairs


 95%|█████████▌| 1904/2000 [34:07<02:49,  1.77s/it]

Step 1900, Eval acc: 0.4705963730812073, Eval final acc: 0.4704767167568207


100%|██████████| 2000/2000 [35:32<00:00,  1.07s/it]


In [6]:
# NOTE(review): eval_accs / eval_final_accs are local variables of
# training_loop(), so they are not defined at notebook scope. Before
# uncommenting, load them from the metrics_*.csv file that training_loop()
# writes (columns "eval_accs" and "eval_final_accs").
# Plotting eval accs and eval final accs
# plt.plot(eval_accs, label="avg", color="red")
# plt.plot(eval_final_accs, label="final", color="blue")
# plt.legend()
# plt.show()

In [7]:
# For testing only. Chomsky evaluation is done in run_chomsky_experiments.ipynb
# NOTE(review): `params` is local to training_loop() and is not defined at
# notebook scope; it has to be obtained first (e.g. reloaded from the
# params_*.npz file written via save_params — confirm the matching loader).
# regret, total_accuracy, total_final_accuracy = evaluate_transformer_decoder(
#     chomsky_generator, params, utm_generator, num_batches=10, size=MODEL_SIZE
# )
# print(total_accuracy, total_final_accuracy)