In [1]:
import haiku as hk
import jax
import jax.numpy as jnp
import functools
import numpy as np
import optax
import tqdm
import tree
import pandas as pd
import matplotlib.pyplot as plt

from absl import logging
from typing import Any

from data import utm_data_generator as utm_dg_lib
from data import chomsky_data_generator as chomsky_sampler_lib
from helpers import make_chomsky_generator, utm_data_generator, make_model, init_params, save_params, evaluate_transformer_decoder, CHOMSKY_ALPHABET_SIZE

In [2]:
seed = 1


class SAMPLE_TYPE:
    """Integer tags selecting how training data is sampled in `training_loop`."""

    ORIGINAL = 0  # train on UTM samples with with_markov=False
    MARKOV = 1  # train on UTM samples with with_markov=True
    RANDOM = 2  # baseline: 0 training steps, parameters stay at initialization


# Experiment settings; these follow the paper's parameters.
TRAINING_STEPS = 2_000  # not referenced by the training loop in this notebook
EXECUTION_STEPS = 1_000  # passed as maximum_steps to the UTM data generator
USE_DELIMITERS = True  # passed to make_chomsky_generator
MEMORY_SIZE = 200  # UTM memory size
BATCH_SIZE = 32  # batch size of the UTM data generator and of each training run

In [3]:
def _make_loss_fn(model: hk.Transformed) -> Any:
  """Builds the training objective for a Haiku-transformed decoder.

  Args:
    model: Transformed model whose `apply(params=..., targets=..., rng=None)`
      is expected to return per-position log-probabilities over the vocabulary.

  Returns:
    A function `loss_fn(params, sequences, mask)` giving the batch-mean
    negative log-likelihood of `sequences`, ignoring masked positions.
  """

  def loss_fn(
      params: hk.Params,
      sequences: jax.Array,
      mask: jax.Array,
  ) -> jnp.float32:
    """Masked negative log-likelihood averaged over the batch.

    Args:
      params: Model parameters, forwarded unchanged to `model.apply`.
      sequences: Integer token ids. Axis 1 is summed over, so presumably
        shape (batch, time) -- TODO confirm.
      mask: Binary array; True (1) marks positions EXCLUDED from the loss.

    Returns:
      Scalar `-mean_b sum_t log p(sequences[b, t])` over unmasked positions.
    """
    log_probs = model.apply(params=params, targets=sequences, rng=None)

    # At every position, gather the log-probability of the token that was
    # actually observed there, then drop the singleton gather axis.
    observed_log_probs = jnp.squeeze(
        jnp.take_along_axis(log_probs, sequences[..., jnp.newaxis], axis=-1),
        axis=-1,
    )

    # Masked positions contribute log(1) = 0, i.e. they are skipped.
    observed_log_probs = jnp.where(mask, 0.0, observed_log_probs)

    sequence_log_likelihood = observed_log_probs.sum(axis=1)  # Shape (B,).
    return -sequence_log_likelihood.mean()

  return loss_fn


@functools.partial(
    jax.jit, static_argnames=('optimizer', 'grad_fn', 'normalize_gradients')
)
def _update_parameters(
    params: hk.Params,
    opt_state: optax.OptState,
    sequences: jax.Array,
    mask: jax.Array,
    grad_fn: Any,
    optimizer: optax.GradientTransformation,
    normalize_gradients: bool = True,
) -> tuple[hk.Params, optax.OptState, dict[str, Any]]:
  """Performs one jitted optimisation step on a whole batch of sequences.

  Backpropagation runs over the full sequence length.

  Args:
    params: Current network parameters.
    opt_state: Current optimizer state.
    sequences: Token ids to train on; `sequences.shape[1]` is used as the
      sequence length when normalizing.
    mask: Binary array; True (1) marks positions excluded from the loss.
    grad_fn: Called as `grad_fn(params, sequences, mask)` and expected to
      return `(loss, gradients)` (e.g. `jax.value_and_grad` of a loss).
      Static under jit, so it must be hashable.
    optimizer: Optax gradient transformation (static under jit).
    normalize_gradients: If True, divide every gradient leaf by the sequence
      length so gradient scale is comparable across sequence lengths/tasks.

  Returns:
    `(new_params, new_opt_state, logs)` where `logs` holds the batch
    `'loss'` and `'grad_norm_unclipped'` (global norm of the gradients after
    any normalization).
  """
  loss_value, gradients = grad_fn(params, sequences, mask)

  if normalize_gradients:
    # The shape is static under jit, so this is a Python float constant.
    seq_len = float(sequences.shape[1])
    gradients = tree.map_structure(lambda g: g / seq_len, gradients)

  param_updates, updated_opt_state = optimizer.update(gradients, opt_state)
  updated_params = optax.apply_updates(params, param_updates)

  return updated_params, updated_opt_state, dict(
      loss=loss_value,
      grad_norm_unclipped=optax.global_norm(gradients),
  )

In [4]:
def train_transformer_decoder(
    data_generator: utm_dg_lib.UTMDataGenerator,
    training_steps: int,
    log_every: int,
    batch_size: int,
    use_tqdm: bool = True,
    with_markov: bool = False,
    size: str = "large",
    eval_data_generator: chomsky_sampler_lib.ChomskyDataGenerator = None,
) -> tuple[hk.Params, float, list[float], list[float], list[float]]:
    """Trains a decoder-only transformer on synthetic (UTM) data.

    Each step samples a batch, converts one-hot inputs to integer tokens and
    takes one Adam step (learning rate 1e-4) on the masked negative
    log-likelihood. The architecture is selected by `make_model` from `size`.

    Args:
      data_generator: Source of training batches. `sample(with_markov=...)`
        must return `(one_hot_batch, log_dict)`; `log_dict` may carry a
        "loss_mask" (True = position excluded from the loss).
      training_steps: Number of batches to train on. With 0 the initialized
        parameters are returned untouched.
      log_every: Log the training loss every `log_every` steps and, if
        `eval_data_generator` is given, evaluate every `10 * log_every` steps.
        If negative or 0, neither logging nor evaluation happens.
      batch_size: Batch size passed to `init_params` (presumably should match
        the generator's batch size).
      use_tqdm: Whether to show a progress bar.
      with_markov: Forwarded to `data_generator.sample`.
      size: Model size name forwarded to `make_model` and to the evaluator.
      eval_data_generator: Optional generator used for periodic evaluation
        via `evaluate_transformer_decoder`; None disables evaluation.

    Returns:
      `(params, last_loss, eval_losses, eval_accs, eval_final_accs)`: the
      final parameters, the training loss of the last step as a float (0.0 if
      no step ran), and one entry per evaluation in each of the three lists.
    """
    print("Vocab Size:", data_generator.feature_size)
    print("Model Size:", size)
    print("Batch Size:", batch_size)
    print("With Markov:", with_markov)
    model = make_model(data_generator, size)

    params = init_params(model, data_generator, batch_size)

    # Make gradient function.
    loss_fn = _make_loss_fn(model)
    grad_fn = jax.value_and_grad(loss_fn, has_aux=False)

    # Make optimizer, to apply the gradients.
    optimizer = optax.adam(learning_rate=1e-4)
    opt_state = optimizer.init(params)

    logging.info("Initialization done, starting training...")

    # Mask convention (see `_make_loss_fn`): True marks positions EXCLUDED
    # from the loss. The fallback must therefore be all-False so every token
    # is scored; an all-True mask would make the loss (and gradient) zero.
    def default_mask(tokens: np.ndarray) -> np.ndarray:
        return np.zeros(tokens.shape[:2], dtype=bool)

    logs = None
    eval_losses = []
    eval_accs = []
    eval_final_accs = []
    eval_every = 10 * log_every

    for step in tqdm.trange(training_steps, disable=not use_tqdm):
        batch, log_dict = data_generator.sample(with_markov=with_markov)
        # Transform one-hots to integer tokens.
        batch = np.argmax(batch, axis=-1)
        if "loss_mask" in log_dict:
            loss_mask = log_dict["loss_mask"]
        else:
            loss_mask = default_mask(batch)

        params, opt_state, logs = _update_parameters(
            params=params,
            opt_state=opt_state,
            sequences=batch,
            grad_fn=grad_fn,
            optimizer=optimizer,
            mask=loss_mask,
        )

        if log_every > 0 and step % log_every == 0:
            logging.info(
                "Step %d, Loss (avg cumulative nats) %f, Grad norm %f",
                step,
                logs["loss"],
                logs["grad_norm_unclipped"],
            )

        # `log_every > 0` guards the modulo against ZeroDivisionError.
        if (
            eval_data_generator is not None
            and log_every > 0
            and step % eval_every == 0
        ):
            eval_loss, eval_acc, eval_final_acc = evaluate_transformer_decoder(
                eval_data_generator, params, data_generator, size=size
            )
            eval_losses.append(eval_loss)
            eval_accs.append(eval_acc)
            eval_final_accs.append(eval_final_acc)
            print(
                f"Step {step}, Eval acc: {eval_acc}, Eval final acc: {eval_final_acc}"
            )

    # Read the loss once after the loop (not every step) so the host does not
    # force a device sync on each iteration.
    last_loss = float(logs["loss"]) if logs is not None else 0.0
    return params, last_loss, eval_losses, eval_accs, eval_final_accs

In [5]:
# Seed from the config constant rather than a duplicated literal.
rng = np.random.default_rng(seed=seed)

# One generator pair is shared by every run below. Both receive the same
# `rng`, so (presumably) the data a run sees depends on the runs before it.
utm_generator = utm_data_generator(
    rng,
    maximum_steps=EXECUTION_STEPS,
    maximum_program_length=100,
    memory_size=MEMORY_SIZE,
    alphabet_size=CHOMSKY_ALPHABET_SIZE,
    batch_size=BATCH_SIZE,
)

chomsky_generator = make_chomsky_generator(rng, use_delimiters=USE_DELIMITERS)


def training_loop():
    """Trains and saves one model per (sampling type, model size) pair.

    Each run trains on `utm_generator` while periodically evaluating on
    `chomsky_generator`. It then writes the parameters to
    `params_{suffix}_transformer_{size}.npz` and, for trained runs, the
    evaluation metrics to `metrics_{suffix}_transformer_{size}.csv`.
    RANDOM runs use 0 training steps: they save the untrained parameters as
    a baseline and produce no evaluation metrics.
    """
    model_sizes = ["small", "medium", "large"]

    sampling_types = [
        SAMPLE_TYPE.RANDOM,
        SAMPLE_TYPE.ORIGINAL,
        SAMPLE_TYPE.MARKOV,
    ]

    training_steps_map = {
        "small": 20000,
        "medium": 2000,
        "large": 2000,
    }

    # File-name suffix for each sampling regime.
    suffix_by_sampling_type = {
        SAMPLE_TYPE.ORIGINAL: "original",
        SAMPLE_TYPE.MARKOV: "markov",
        SAMPLE_TYPE.RANDOM: "random",
    }

    for sampling_type in sampling_types:
        suffix = suffix_by_sampling_type[sampling_type]
        is_random = sampling_type == SAMPLE_TYPE.RANDOM

        for model_size in model_sizes:
            training_steps = 0 if is_random else training_steps_map[model_size]

            params, loss, eval_losses, eval_accs, eval_final_accs = train_transformer_decoder(
                data_generator=utm_generator,
                training_steps=training_steps,
                batch_size=BATCH_SIZE,
                log_every=10,
                with_markov=(sampling_type == SAMPLE_TYPE.MARKOV),
                size=model_size,
                eval_data_generator=chomsky_generator,
            )
            logging.info("Final loss: %f", loss)

            file_name = f"params_{suffix}_transformer_{model_size}.npz"
            save_params(params, file_name)
            logging.info("Parameters saved in file %s", file_name)

            # RANDOM runs take no training steps and hence no evaluations, so
            # there is no CSV to write (and nothing to report as saved).
            if is_random:
                continue

            metrics_name = f"metrics_{suffix}_transformer_{model_size}.csv"
            eval_df = pd.DataFrame(
                {
                    "eval_losses": eval_losses,
                    "eval_accs": eval_accs,
                    "eval_final_accs": eval_final_accs,
                }
            )
            eval_df.to_csv(metrics_name, index=False)
            logging.info("Evaluation metrics saved to %s", metrics_name)


training_loop()

Vocab Size: 128
Model Size: small
Batch Size: 32
With Markov: False


0it [00:00, ?it/s]


Vocab Size: 128
Model Size: medium
Batch Size: 32
With Markov: False


0it [00:00, ?it/s]


Vocab Size: 128
Model Size: large
Batch Size: 32
With Markov: False


0it [00:00, ?it/s]


Vocab Size: 128
Model Size: small
Batch Size: 32
With Markov: False


  0%|          | 0/20000 [00:00<?, ?it/s]

Chomsky Task:  even_pairs


  0%|          | 10/20000 [00:08<3:17:36,  1.69it/s]

Step 0, Eval acc: 0.008060423657298088, Eval final acc: 0.008061978965997696


  0%|          | 100/20000 [00:11<13:14, 25.05it/s] 

Chomsky Task:  even_pairs


  1%|          | 110/20000 [00:15<57:23,  5.78it/s]  

Step 100, Eval acc: 0.015095042996108532, Eval final acc: 0.015102459117770195


  1%|          | 198/20000 [00:18<13:32, 24.38it/s]

Chomsky Task:  even_pairs


  1%|          | 200/20000 [00:19<32:06, 10.28it/s]


KeyboardInterrupt: 