# Thunnini Demo

This notebook briefly showcases Thunnini's main features. It pretrains a network (LSTM or Transformer) on a distribution over coins with uniformly random bias (using the `DirichletCategorical` data generator). The network is then fine-tuned on a mixture of two coins with biases 0.2 and 0.8 via soft prefix tuning (i.e., by tuning the embeddings of the first 6 tokens). The tuned prefix is then used during evaluation on the same two-coin mixture. Finally, the performance of the tuned predictor is compared against the Bayes-optimal predictor for the two-coin mixture, the pretrained network, and the untrained network.

The main aim of this notebook is to show how easy it is to set up predictors and data generators, and to run pretraining, tuning, and evaluation with Thunnini.
See `ThunniniExperiment.ipynb` for a much more comprehensive notebook that features most of Thunnini's functionality and wraps it into an easily configurable interface.

# Imports

In [None]:
# @title Global imports

import dataclasses

# Utils
from matplotlib import pyplot as plt

# NNs / Linear algebra
import numpy as np
import jax

jax.config.update("jax_debug_nans", False)
%matplotlib inline

In [None]:
#@title Install Thunnini and its dependencies when running on Colab
try:  # When on Google Colab, clone the repository and install dependencies.
    import google.colab
    repo_path = 'thunnini'
    !git -C $repo_path pull origin || git clone https://github.com/google-deepmind/thunnini $repo_path
    !cd $repo_path
    !export PYTHONPATH=$(pwd)/..
    !pip install -r $repo_path/requirements.txt
except:
    repo_path = '.'  # Use the local path if not on Google Colab

In [None]:
# @title Thunnini imports
from thunnini.src import builders
from thunnini.src import config as config_lib
from thunnini.src import evaluation
from thunnini.src import plot_utils
from thunnini.src import training
from thunnini.src import tuning

# Experiment Configurations

In [None]:
#@title Predictor configuration

embedding_dim = 16
torso_type = "LSTM"  # One of: "LSTM", "Transformer".
hidden_sizes = [64, 32]

# Predictor-level settings; these do not depend on the torso choice below.
predictor_config = config_lib.PredictorConfig(
    token_dimensionality=2,  # binary tokens
    embedding_dimensionality=embedding_dim,
)

if torso_type == "LSTM":
  torso_config = config_lib.LSTMTorsoConfig(
      is_trainable=True,
      hidden_sizes=hidden_sizes,
      return_hidden_states=False,
  )
elif torso_type == "Transformer":
  torso_config = config_lib.TransformerTorsoConfig(
      is_trainable=True,
      hidden_sizes=hidden_sizes,
      num_attention_heads=4,
      positional_encoding='SinCos',
      return_hidden_states=False,
      use_bias=False,
      widening_factor=4,
      normalize_qk=True,
      use_lora=True,
      reduced_rank=4,
  )
else:
  # Fail loudly instead of silently falling back to a Transformer on a typo.
  raise ValueError(
      f'Unknown torso_type {torso_type!r}; expected "LSTM" or "Transformer".'
  )

In [None]:
#@title Training configuration

# Optimizer / loop settings for pretraining (fixed seeds for reproducibility).
training_config = config_lib.TrainingConfig(
    predictor_init_seed=0,
    data_gen_seed=0,
    num_training_steps=1000,
    learning_rate=5e-3,
    max_grad_norm=1.0,
)

# Pretraining data: binary sequences from the Dirichlet-categorical generator.
# alphas=[1, 1] is the symmetric Dirichlet(1, 1), i.e. the "coins with
# uniformly random bias" described in the introduction.
training_data_config = config_lib.DirichletCategoricalGeneratorConfig(
    vocab_size=2,
    alphas=np.array([1, 1]),
    sequence_length=50,
    batch_size=128,
)

In [None]:
#@title Tuning configuration

# Tuning data: mixture of two binary "coins". Each row of `biases` is
# presumably one component's token distribution, and `mixing_weights` the
# prior weight of each component -- confirm against config_lib.
tuning_data_config = config_lib.MixtureOfCategoricalsGeneratorConfig(
    vocab_size=2,
    biases=np.array([[0.2, 0.8], [0.8, 0.2]]),
    mixing_weights=np.array([0.25, 0.75]),
    sequence_length=50,
    batch_size=128,
)

# Soft prefix tuning with a 6-embedding prefix, initialized with the
# "one_hot" method; seeds are fixed for reproducibility.
tuning_config = config_lib.TuningConfig(
    tuning_method="prefix_soft",
    prefix_length=6,
    prefix_init_method="one_hot",
    prefix_init_seed=10,
    data_gen_seed=10,
    num_tuning_steps=1000,
    learning_rate=5e-3,
    max_grad_norm=1.0,
)

In [None]:
#@title Evaluation configuration

# Evaluate on the same two-coin mixture that was used for tuning. Reading the
# biases, mixing weights and vocab size from `tuning_data_config` keeps the two
# configs in sync if the mixture is edited (copies avoid sharing arrays).
# Evaluation will be on a single batch of this generator, so we choose
# a large batch (and longer sequences than during tuning).
eval_data_config = config_lib.MixtureOfCategoricalsGeneratorConfig(
    batch_size=1024,
    sequence_length=100,
    vocab_size=tuning_data_config.vocab_size,
    biases=np.copy(tuning_data_config.biases),
    mixing_weights=np.copy(tuning_data_config.mixing_weights),
)

# Pretraining

In [None]:
#@title Pretrain predictor
# Trains the predictor from scratch on the Dirichlet-categorical data and
# returns its parameters plus a results dict (whose 'loss' entry is plotted
# in the next cell).
trained_params, train_results = training.train(
    predictor_config=predictor_config,
    torso_config=torso_config,
    data_config=training_data_config,
    training_config=training_config,
)

In [None]:
# Plot training loss curve
ax = plot_utils.plot_performance_metric(
    {torso_type: [train_results['loss']]},
    'Training loss',
    aggregate_fn_only=True,  # No variability band needed, single repetition.
    show_gridlines=True,
)
ax.set_xlabel('Training Step')
# Trailing `;` keeps the returned Text object's repr out of the cell output.
ax.set_title('Pretraining on ' + training_data_config.generator_type);

In [None]:
# @title Manually using the pretrained predictor

# This cell runs a predictor by hand on a batch of sequences. Thunnini also
# offers convenience functions that wrap these steps into a single call,
# e.g. `evaluation.evaluate_predictor_from_datagen`.

# Step 1: build the pretraining data generator and draw one batch from it.
datagen_tmp = builders.build_datagen(training_data_config)
batch_tmp = datagen_tmp.generate(
    rng_key=jax.random.PRNGKey(1337),
    return_ground_truth_log_probs=False,
)

# Step 2: rebuild the predictor with a torso that also returns its hidden
# states. The predictor is stateless, so a fresh instance can simply be
# applied with the pretrained parameters.
torso_config_tmp = dataclasses.replace(torso_config, return_hidden_states=True)
predictor_tmp = builders.build_predictor(predictor_config, torso_config_tmp)

# Step 3: forward pass without any prefix, then per-step log losses.
logits, hidden_states, prefix_logits, prefix_hidden = predictor_tmp.apply(
    trained_params,
    sequences=batch_tmp,
    prefix_type='None',
    prefix=None,
)
predictor_log_losses = datagen_tmp.instant_log_loss_from_logits(logits, batch_tmp)

print('Instant log loss shape:', predictor_log_losses.shape)
print('Hidden states dict keys:', hidden_states.keys())

# No prefix was used, so both prefix outputs print as None.
print('Prefix logits:', prefix_logits)
print('Prefix hidden states:', prefix_hidden)

# Tuning

In [None]:
#@title Tune pretrained predictor
# Starts from the pretrained parameters and tunes on the two-coin mixture.
# Returns the (possibly updated) parameters, the tuned prefix, and a results
# dict whose 'loss' entry is plotted in the next cell.
tuned_params, tuned_prefix, tuning_results = tuning.tune(
    predictor_params=trained_params,
    predictor_config=predictor_config,
    torso_config=torso_config,
    tuning_config=tuning_config,
    data_config=tuning_data_config,
)

In [None]:
# Plot tuning loss curve
ax = plot_utils.plot_performance_metric(
    {torso_type: [tuning_results['loss']]},
    'Tuning loss',
    aggregate_fn_only=True,  # No variability band needed, single repetition.
    show_gridlines=True,
)
ax.set_xlabel('Tuning Step')
# Trailing `;` keeps the returned Text object's repr out of the cell output.
ax.set_title(
    tuning_config.tuning_method
    + ' tuning on '
    + tuning_data_config.generator_type
);

# Evaluation

In [None]:
#@title Evaluate pretrained predictor

# Draw a single evaluation batch (seed 1337) and score the pretrained
# predictor on it. With return_gt_and_optimal_results=True the call also
# returns ground-truth (gt_*) and Bayes-optimal (bo_*) log probs and losses.
eval_results_trained = evaluation.evaluate_predictor_from_datagen(
    datagen_config=eval_data_config,
    datagen_seed=1337,
    datagen_num_batches=1,
    predictor_config=predictor_config,
    torso_config=torso_config,
    predictor_params=trained_params,
    return_gt_and_optimal_results=True,
)
(
    sequences,
    trained_logits,
    trained_log_losses,
    bo_log_probs,
    bo_losses,
    gt_log_probs,
    gt_losses,
) = eval_results_trained

In [None]:
#@title Evaluate tuned predictor on the same sequences

# Re-use the evaluation sequences from the previous cell; the tuned prefix is
# fed to the predictor as embeddings ahead of each sequence.
tuned_logits, tuned_log_losses = evaluation.evaluate_predictor_from_sequences(
    sequences=sequences,
    predictor_params=tuned_params,
    predictor_config=predictor_config,
    torso_config=torso_config,
    prefix=tuned_prefix,
    prefix_type="embedding",
    batch_size=-1,  # presumably "all sequences in one batch" -- confirm
)

In [None]:
#@title Construct and evaluate untrained predictor on the same sequences
# Freshly initialized parameters (seed 815) serve as the untrained baseline.
# The first ten evaluation sequences only act as example inputs for init.
untrained_predictor = builders.build_predictor(predictor_config, torso_config)
untrained_params = untrained_predictor.init(
    sequences=sequences[:10],
    rngs=jax.random.PRNGKey(815),
)

untrained_logits, untrained_log_losses = (
    evaluation.evaluate_predictor_from_sequences(
        sequences=sequences,
        predictor_params=untrained_params,
        predictor_config=predictor_config,
        torso_config=torso_config,
        batch_size=-1,
    )
)

In [None]:
#@title Compute regrets
# Instant regret: each predictor's log loss minus the ground-truth log loss,
# averaged over axis 0 (presumably the batch axis). Each value is a list with
# one entry per repetition (here: a single repetition).
instant_regret = {
    'Bayes-optimal (' + eval_data_config.generator_type + ')': [np.mean(bo_losses - gt_losses, axis=0)],
    'Pretrained ' + torso_type + ' (' + training_data_config.generator_type + ')': [np.mean(trained_log_losses - gt_losses, axis=0)],
    'Tuned ' + torso_type + ' (' + tuning_data_config.generator_type + ')': [np.mean(tuned_log_losses - gt_losses, axis=0)],
    'Untrained ' + torso_type: [np.mean(untrained_log_losses - gt_losses, axis=0)],
    }

# Cumulative regret, accumulated separately per repetition. Calling np.cumsum
# on the whole list would flatten all repetitions into one running sum, which
# is only correct while there is exactly one repetition.
cumulative_regret = {
    model: [np.cumsum(repetition) for repetition in regret]
    for model, regret in instant_regret.items()
}

In [None]:
# Plot evaluation results: instant regret (top) and cumulative regret (bottom).
fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(9, 8))

plot_utils.plot_performance_metric(
    instant_regret,
    'Instant regret [nats]',
    axis=axes[0],
    aggregate_fn_only=True,  # No variability band needed, single repetition.
    show_gridlines=True,
)
axes[0].set_title('Evaluation on ' + eval_data_config.generator_type)
axes[0].set_xlabel('')  # Drop the top panel's x label to reduce clutter.
axes[0].get_legend().remove()  # One legend (on the bottom panel) suffices.

# Trailing `;` keeps the returned object's repr out of the cell output.
plot_utils.plot_performance_metric(
    cumulative_regret,
    'Cumulative regret [nats]',
    axis=axes[1],
    aggregate_fn_only=True,  # No variability band needed, single repetition.
    show_gridlines=True,
);

In [None]:
# @title Show one trajectory

# Compare per-step predictions of all predictors on the first eval sequence.
fig, ax = plt.subplots(figsize=(9, 4))
seq_len = sequences.shape[1]
xvec = np.arange(seq_len)

# Plot observations from the first eval sequence (channel 0 of the tokens;
# presumably one-hot, i.e. 1 where token 0 was observed -- confirm).
ax.plot(xvec, sequences[0, :, 0], '.', label='Observations')
# Plot the ground-truth probability (gt_log_probs are the same for each timestep)
ax.hlines(
    np.exp(gt_log_probs[0, 0, 0]),
    xmin=0,
    xmax=seq_len,
    label='Ground truth',
    color='goldenrod',
    linewidth=4,
)

# Plot predictions: probability assigned to token 0 at each step.
ax.plot(xvec, jax.nn.softmax(trained_logits[0, :, :])[:, 0], label='Pretrained ' + torso_type, linewidth=2)
ax.plot(xvec, jax.nn.softmax(tuned_logits[0, :, :])[:, 0], label='Tuned ' + torso_type, linewidth=2)
ax.plot(xvec, jax.nn.softmax(untrained_logits[0, :, :])[:, 0], label='Untrained ' + torso_type, linewidth=2)
# Softmax of normalized log probs returns the probabilities themselves. Use a
# color that differs from the observations (which take the default 'C0').
ax.plot(xvec, jax.nn.softmax(bo_log_probs[0, :, :])[:, 0], label='Bayes-optimal', color='black', linewidth=2)

ax.legend()
ax.set_xlabel('Step')
ax.set_yticks([0, 0.5, 1])
ax.grid(True)  # A string like 'on' only worked by truthiness ('off' too).
ax.set_ylabel('Probability')
# Trailing `;` keeps the returned Text object's repr out of the cell output.
ax.set_title('Single trajectory predictions');