# Customizable full experimentation pipeline

This notebook provides (easily configurable) infrastructure to:


1.  Pretrain a network (LSTM or Transformer) on data from one of the data generators. Pretraining is optional and can be skipped if you want to investigate weight tuning / prefix tuning of untrained networks.
2.  Tune the network to a tuning distribution (tuning data generator) using various tuning methods. This can create several tuned models from the same pretrained base network.
3.  Evaluate the sequential prediction performance of all tuned networks on one or more evaluation data generators. Each evaluation can be repeated several times with different random seeds (which produces error bands in the comparison plots).
4.  Compare all tuned models on each evaluation data generator. The comparison also includes the non-tuned base model, the Bayes predictor for the respective evaluation data generator (the "Bayes optimal" baseline), the Bayes predictor for the pretraining distribution (and a prompt-tuned version of it), and untuned (random) prefixes.

The full experiment can be customized via configuration objects for each part. The standard settings can be changed via Colab forms in the "Main User Settings" section below; additional settings are accessible in the code of the configuration sections that follow.

> Comparing many tuning methods over many repetitions (with different seeds) can quickly produce dozens of tuning runs. Keep models small and the number of tuning steps low to avoid very long run times.
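To see how quickly the run count grows, multiply the number of enabled gradient-based tuning methods by the number of repetitions. This sketch assumes each enabled method is tuned once per repetition. It also assumes the exhaustive hard prefix search and the baselines do not count as tuning runs. `count_tuning_runs` is a hypothetical helper and is not part of Thunnini.

```python
# Hypothetical helper (not Thunnini API): assumes every enabled gradient-based
# method is tuned once per repetition; HardPT and baselines are not counted.
GRADIENT_METHODS = (
    "SimplexPT", "RealPT", "SoftPT",
    "FullWT", "LoRAWT",
    "EmbedWT", "UnembedWT", "Un+EmbedWT",
)


def count_tuning_runs(enabled_methods, num_repetitions):
    """Returns the number of tuning runs (methods x repetitions)."""
    num_methods = sum(1 for m in enabled_methods if m in GRADIENT_METHODS)
    return num_methods * num_repetitions


# Default form settings: all eight gradient-based methods, 10 repetitions.
print(count_tuning_runs(GRADIENT_METHODS, 10))  # 80
```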


## Tuning methods

The following tuning methods are available:
* Full weight tuning.
* Tuning of embedding or unembedding layer only, or both.
* LoRA (transformers only): tuning of additive low rank matrices for all linear layers of the inner transformer block.
* Gradient-based tuning of a prompt prefix:
  * Simplex: the prefix is a sequence of vectors in the simplex spanned by the one-hot tokens.
  * Real: the prefix is a sequence of unconstrained real-valued vectors.
  * Soft: the embeddings of the prefix tokens are tuned instead of the tokens themselves.
* Hard prefix: exhaustive search over all hard token sequences of the given prefix length.

> Hard token prefix tuning uses exhaustive search, which is only feasible for short prefixes: with binary tokens there are $2^L$ candidate prefixes of length $L$. Disable it for long prefixes.
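The sketch below shows how fast the hard prefix search space grows. It enumerates every binary hard prefix of a given length. Each candidate then has to be evaluated on a batch of tuning sequences. The string representation mirrors how hard prefixes are printed later in the notebook, but `all_binary_prefixes` is an illustrative helper, not Thunnini code.

```python
import itertools


def all_binary_prefixes(prefix_length):
    """All hard prefixes over a binary vocabulary, as strings like '0110'."""
    return ["".join(bits) for bits in itertools.product("01", repeat=prefix_length)]


for length in (4, 6, 10, 16):
    print(length, len(all_binary_prefixes(length)))
# 4 16
# 6 64
# 10 1024
# 16 65536
```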


## Data generators

Thunnini currently implements three families of categorical distributions; this notebook restricts tokens to be binary (one-hot), leading to the following data generators:
*  Single coin: coin with fixed bias (Binomial distribution).
*  Mixture of two coins: two coins with different biases and a user definable mixture proportion (mixture of Binomials).
*  Random coins: Beta distribution over coin biases (uniform by default).
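The sketch below shows the sampling process behind the three generator families, using Python's `random` module. It makes two simplifications:

* Tokens are integers 0/1 instead of one-hot vectors.
* Each bias is treated as the probability of token 1, whereas Thunnini's generator configs specify per-token probability vectors.

All function names are hypothetical.

```python
import random


def sample_single_coin(rng, length, p_one):
    """Single coin: every token is 1 with the fixed probability p_one."""
    return [int(rng.random() < p_one) for _ in range(length)]


def sample_two_coin_mixture(rng, length, p_ones, mixing_weights):
    """Mixture: pick one coin per sequence, then flip it `length` times."""
    p_one = rng.choices(p_ones, weights=mixing_weights, k=1)[0]
    return sample_single_coin(rng, length, p_one)


def sample_random_coin(rng, length, alpha=1.0, beta=1.0):
    """Random coins: draw the bias from Beta(alpha, beta) once per sequence."""
    p_one = rng.betavariate(alpha, beta)
    return sample_single_coin(rng, length, p_one)


rng = random.Random(0)
print(sample_single_coin(rng, 10, 0.8))
print(sample_two_coin_mixture(rng, 10, [0.2, 0.8], [0.5, 0.5]))
print(sample_random_coin(rng, 10))  # Beta(1, 1) is the uniform prior.
```

The key structural difference is *when* randomness enters. The single coin has no latent variable. The other two draw a latent bias $\tau$ once per sequence, which is what a Bayes predictor has to infer from the observed tokens.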


## Performance metric (regret)

The main comparison metric shown in most plots is "regret": a model's excess prediction error, i.e., its expected (cumulative) log loss minus the lowest possible log loss. That lowest loss is attained by an oracle (the data generator) that knows the emission probabilities at each step. A regret of zero therefore does not mean zero prediction error; it means the lowest theoretically achievable prediction error (given the hidden knowledge).

Let the neural predictor with parameters $\theta$ be $\pi_\theta$, and let the data generator for sequences $x_{1:N}$ of length $N$ be a family of sources parameterized by $\tau$: $\xi(x_{1:N}\vert\tau)$. For instance, $\xi(\cdot \vert\tau)$ could be the family of coins with bias $\tau$. Given a distribution over $\tau$ (e.g., the uniform distribution), the data distribution is:  
$\xi(x) = \int \xi(x\vert \tau) p(\tau) d\tau$.
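For coins with a uniform prior on the bias (the default for Random Coins), this integral has a closed form. It depends only on the numbers of ones $n_1$ and zeros $n_0$ in $x_{1:N}$:

$$\xi(x_{1:N}) = \int_0^1 \tau^{n_1}(1-\tau)^{n_0}\, d\tau = \frac{n_1!\, n_0!}{(N+1)!}.$$

The sketch below checks this against a direct numerical integration over $\tau$. The function names are illustrative.

```python
import math


def marginal_closed_form(n_ones, n_zeros):
    """xi(x_{1:N}) under a uniform prior on the coin bias tau.

    The integral of tau^n1 (1 - tau)^n0 over [0, 1] is the Beta function
    B(n1 + 1, n0 + 1) = n1! n0! / (n1 + n0 + 1)!.
    """
    return (math.factorial(n_ones) * math.factorial(n_zeros)
            / math.factorial(n_ones + n_zeros + 1))


def marginal_numeric(n_ones, n_zeros, num_points=100_000):
    """Midpoint-rule approximation of the same integral over tau."""
    h = 1.0 / num_points
    return h * sum(((i + 0.5) * h) ** n_ones * (1 - (i + 0.5) * h) ** n_zeros
                   for i in range(num_points))


# x = 1101 has three ones and one zero; only the counts matter.
print(marginal_closed_form(3, 1))  # 0.05
print(marginal_numeric(3, 1))
```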

The expected log loss of the predictor over the data distribution is:  
$\mathcal{L}_\xi(\pi_\theta)=\mathbb{E}_{\xi} \left[ \sum_{i=1}^N - \log \pi_\theta(x_i \vert x_{<i}) \right]$.  
In practice, the expectation over $\xi$ is replaced with the average over a sample $\{x^{1}, \ldots, x^{D}\}$. Each sequence is drawn by first sampling $\tau^* \sim p(\tau)$ and then sampling $x^{k} \sim \xi(\cdot \vert \tau^*)$.

The cumulative regret of the predictor is its log loss relative to the best possible (oracle) prediction:  
$\mathcal{R}_\xi(\pi_\theta)=\frac{1}{D} \sum_{k=1}^D \left[ \sum_{i=1}^N -\log \pi_\theta(x^k_i \vert x^k_{<i}) + \log \xi(x^k_i \vert x^k_{<i}, \tau=\tau^*) \right]$,    
where the term in the inner sum is the instantaneous regret per time step, and $\tau^*$ is the ground-truth (coin bias) that is only known to the data generator.
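The regret estimate is just a sum of per-step log-probability differences. The minimal sketch below takes, for each sequence and step, the probability that the model and the oracle assigned to the observed token. It is an illustrative helper, not Thunnini's implementation.

```python
import math


def cumulative_regret(model_probs, oracle_probs):
    """Sample estimate of the cumulative regret R_xi(pi_theta).

    model_probs[k][i]:  pi_theta(x^k_i | x^k_{<i}) for the observed token.
    oracle_probs[k][i]: xi(x^k_i | x^k_{<i}, tau = tau^*) for the same token.
    """
    total = 0.0
    for model_seq, oracle_seq in zip(model_probs, oracle_probs):
        # Sum the instantaneous regrets -log pi_theta + log xi over time steps.
        total += sum(-math.log(p) + math.log(q) for p, q in zip(model_seq, oracle_seq))
    return total / len(model_probs)


# A predictor that always says 0.5 on the sequence "11" from a coin with bias 0.8:
print(cumulative_regret([[0.5, 0.5]], [[0.8, 0.8]]))  # 2 * log(1.6), about 0.94
```

A single sequence can have negative regret if the model happens to assign the observed tokens more probability than the oracle. Only the expectation is guaranteed to be nonnegative, because the expected instantaneous regret is a KL divergence.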

Without additional oracle knowledge, zero regret is generally not achievable in this setting. Instead, the lowest achievable expected regret is attained by the Bayes predictor for the data generator:  
$\pi_{\text{Bayes}}(x) = \int \xi(x\vert \tau)  p(\tau) d\tau = \xi(x)$, whose sequential predictions are $\pi_{\text{Bayes}}(x_i \vert x_{<i}) = \xi(x_{1:i}) / \xi(x_{<i})$.
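For two of the data generators, the Bayes predictor has a short closed form:

* **Random Coins with a uniform prior:** Laplace's rule of succession, $P(x_{t+1}=1) = (n_1+1)/(t+2)$.
* **Two-Coin Mixture:** the posterior-weighted average of the coin biases.

The sketch below is an independent illustration, not Thunnini's `solve` implementation. Chaining the sequential predictions recovers the sequence probability; for `1101` this gives $3!\,1!/5! = 0.05$.

```python
import math


def laplace_predictive(n_ones, n_zeros):
    """P(next token = 1) of the Bayes predictor for a uniform Beta(1, 1) prior."""
    return (n_ones + 1) / (n_ones + n_zeros + 2)


def mixture_predictive(n_ones, n_zeros, p_ones, mixing_weights):
    """P(next token = 1) of the Bayes predictor for a finite mixture of coins."""
    # Posterior weight of each coin: prior weight times likelihood of the counts.
    # (Fine for short sequences; long ones would need log-space arithmetic.)
    posts = [w * p**n_ones * (1 - p) ** n_zeros for p, w in zip(p_ones, mixing_weights)]
    return sum(post * p for post, p in zip(posts, p_ones)) / sum(posts)


def sequence_log_prob(tokens, predictive):
    """log pi_Bayes(x_{1:N}) as the sum of the sequential log predictions."""
    n_ones = n_zeros = 0
    log_prob = 0.0
    for x in tokens:
        p_one = predictive(n_ones, n_zeros)
        log_prob += math.log(p_one if x == 1 else 1.0 - p_one)
        n_ones += x
        n_zeros += 1 - x
    return log_prob


def two_coins(n_ones, n_zeros):
    """Two-Coin Mixture settings from the notebook (biases 0.2/0.8, equal weights)."""
    return mixture_predictive(n_ones, n_zeros, [0.2, 0.8], [0.5, 0.5])


print(math.exp(sequence_log_prob([1, 1, 0, 1], laplace_predictive)))  # ~0.05
print(two_coins(1, 0))  # ~0.68: one observed 1 shifts belief toward the 0.8 coin
```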

Through pretraining, neural predictors can (if the architecture has enough capacity and training converges properly) become Bayes optimal on their pretraining distribution: their predictions and regrets become indistinguishable from those of the Bayes predictor for that distribution. Fine tuning with data from a particular value of $\tau^*$, e.g., a coin with a particular bias, lets the tuned models indirectly acquire some or all of the "oracle info" (in their weights or via the tuned prefix). They may therefore initially outperform the Bayes predictor for the pretraining distribution (not the tuning distribution), but only on the particular downstream task $\tau^*$ they were tuned for. In an abstract sense, this is what fine tuning aims for. There is additional knowledge about the downstream task(s) that was not available at pretraining time (e.g., that the downstream task is answering questions from a high-school biology exam). The goal is to make this information available to the tuned predictor, by tuning weights or prompts, or even just by writing an ad hoc prompt. After a sufficient number of observations, the Bayes predictor catches up in terms of instantaneous regret, but the cumulative regret retains the initial difference.
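The catch-up behaviour can be computed exactly in a simple case: a Random Coins pretraining Bayes predictor (the Laplace rule, assuming the uniform prior) evaluated on a single coin with bias 0.8. At each step, the expected instantaneous regret against the oracle is a KL divergence averaged over the binomially distributed counts. A model that had fully absorbed $\tau^* = 0.8$ would have zero regret throughout. The sketch is illustrative; the bias value and function names are assumptions.

```python
import itertools
import math


def kl_bernoulli(p, q):
    """KL(Bern(p) || Bern(q)) in nats: the expected instantaneous regret vs. the oracle."""
    kl = 0.0
    if p > 0:
        kl += p * math.log(p / q)
    if p < 1:
        kl += (1 - p) * math.log((1 - p) / (1 - q))
    return kl


def expected_laplace_regret(tau, num_steps):
    """Exact expected instantaneous regret of the Laplace rule on a coin with bias tau."""
    regrets = []
    for t in range(num_steps):
        # Before step t, the number of ones k is Binomial(t, tau); the Laplace
        # rule then predicts P(1) = (k + 1) / (t + 2).
        regrets.append(sum(
            math.comb(t, k) * tau**k * (1 - tau) ** (t - k)
            * kl_bernoulli(tau, (k + 1) / (t + 2))
            for k in range(t + 1)
        ))
    return regrets


instant = expected_laplace_regret(tau=0.8, num_steps=200)
cumulative = list(itertools.accumulate(instant))
print(f"step 1: {instant[0]:.3f}, step 200: {instant[-1]:.4f}, cumulative: {cumulative[-1]:.2f}")
```

The instantaneous regret shrinks by roughly two orders of magnitude over 200 steps, while the cumulative regret keeps growing slowly and never gives back the early losses.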


> By default, evaluation plots show the median over evaluation repetitions (different seeds for the fine tuning runs, i.e., different samples from the tuning data generator and different random initial prefixes for prefix tuning). They also show 25th-75th percentile error bars or shaded areas, which cover the middle 50% of the repetitions. Thin lines show individual repetitions.
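The sketch below shows how such a per-step median and 25th-75th percentile band could be computed across repetitions. It uses Python's `statistics` module with linear interpolation. It is an illustration; the plotting code in `plot_utils` may aggregate differently.

```python
import statistics


def aggregate_repetitions(curves):
    """Per-step median and 25th/75th percentiles across repetitions.

    curves: one equally long list per repetition (e.g. cumulative regret per
    step); needs at least two repetitions.
    """
    medians, lower, upper = [], [], []
    for step_values in zip(*curves):
        # "inclusive" interpolates linearly between order statistics.
        q25, q50, q75 = statistics.quantiles(step_values, n=4, method="inclusive")
        lower.append(q25)
        medians.append(q50)
        upper.append(q75)
    return medians, lower, upper


reps = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0], [5.0, 50.0]]
print(aggregate_repetitions(reps))  # ([3.0, 30.0], [2.0, 20.0], [4.0, 40.0])
```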


## Effect of prefixing on internal dynamics

The final section of this notebook records the internal states of the neural predictors during evaluation. These states are projected to 2D via PCA, which usually reveals a well-structured representation: typically, one dimension aligns with the counts of heads/tails and the other with the number of observations. Plotting individual trajectories with various prefixes into the same projection shows the effect of the different prefix types.

> This analysis only makes sense for comparing prefix tuning methods, since it is unclear how to project activations from models with different weights (due to weight tuning or LoRA) into the same lower-dimensional space. Intuitively, different prefixes set different initial points in the same activation space, whereas weight tuning methods reshape the activation space to give it a better geometry for the downstream task.
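The notebook imports scikit-learn's `PCA` for this projection. To make explicit what the 2D coordinates are, the stdlib sketch below does the same computation by hand:

1. Center the activation vectors.
2. Find the top two eigenvectors of their covariance by power iteration with deflation.
3. Project onto them.

It is a stand-in for illustration only. The "activation" data are hypothetical, and the sketch assumes the data vary in at least two directions.

```python
import math
import random


def pca_2d(points, num_iters=200, seed=0):
    """Projects points (rows of equal length) onto their top two principal axes.

    Stand-in for sklearn's PCA(n_components=2).fit_transform (up to the sign
    of each axis): power iteration with deflation on the covariance matrix.
    """
    n, d = len(points), len(points[0])
    mean = [sum(p[j] for p in points) / n for j in range(d)]
    centered = [[p[j] - mean[j] for j in range(d)] for p in points]
    cov = [[sum(c[a] * c[b] for c in centered) / (n - 1) for b in range(d)]
           for a in range(d)]
    rng = random.Random(seed)
    axes = []
    for _ in range(2):
        v = [rng.gauss(0.0, 1.0) for _ in range(d)]
        for _ in range(num_iters):
            w = [sum(cov[a][b] * v[b] for b in range(d)) for a in range(d)]
            norm = math.sqrt(sum(x * x for x in w))
            v = [x / norm for x in w]
        # Rayleigh quotient = eigenvalue of the axis just found.
        eigval = sum(v[a] * cov[a][b] * v[b] for a in range(d) for b in range(d))
        axes.append(v)
        # Deflate so that the next power iteration finds the next axis.
        cov = [[cov[a][b] - eigval * v[a] * v[b] for b in range(d)] for a in range(d)]
    return [[sum(c[j] * axis[j] for j in range(d)) for axis in axes] for c in centered]


# Hypothetical 3-D "activations" that mostly vary along the first coordinate.
states = [[float(i), 0.1 * ((i * 7) % 5), 0.0] for i in range(20)]
print(pca_2d(states)[:3])
```

Because every prefixed trajectory is projected with the same two axes, the starting points set by different prefixes become directly comparable in the plot.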

# Imports

In [None]:
# @title Global imports
import collections
import dataclasses
import logging
import pathlib

# Set default logging level
logging.basicConfig(level=logging.WARNING,
                    force = True)

# Utils
from matplotlib import pyplot as plt
from sklearn.decomposition import PCA

# NNs / Linear algebra
import numpy as np
import jax
import jax.numpy as jnp

jax.config.update("jax_debug_nans", False)
%matplotlib inline

In [None]:
#@title Install Thunnini and its dependencies when running on Colab
try:  # When on Google Colab, clone (or update) the repository and install dependencies.
    import google.colab  # Only importable on Colab.
    repo_path = 'thunnini'
    !git -C $repo_path pull origin || git clone https://github.com/google-deepmind/thunnini $repo_path
    # Each `!` line runs in its own subshell, so `cd` or `export` would not persist.
    # The clone lives in the working directory, which is already on `sys.path`.
    !pip install -r $repo_path/requirements.txt
except ImportError:
    repo_path = '.'  # Use the local path if not on Google Colab

In [None]:
# @title Thunnini imports
from thunnini.src import builders
from thunnini.src import config as config_lib
from thunnini.src import evaluation
from thunnini.src import plot_utils
from thunnini.src import training
from thunnini.src import tuning
from thunnini.src import types

# Experiment Configuration

In [None]:
# @title Main user settings

# @markdown **Architecture**
architecture = "Transformer"  # @param ["LSTM","LSTM_untrained", "Transformer", "Transformer_untrained", "Linear", "Linear_untrained"]
embedding_dim = 128  # @param {"type":"integer"}
# @markdown ---


# @markdown **Pretraining**
pretraining_source = "Random Coins"  # @param ["Single Coin", "Two-Coin Mixture", "Random Coins"]
pretraining_sequence_length = 100  # @param {"type":"integer"}
pretraining_batch_size = 256  # @param {"type":"integer"}
pretraining_num_steps = 1000  # @param {"type":"integer"}
# @markdown ---

# @markdown **Tuning**
tuning_source = "Two-Coin Mixture"  # @param ["Single Coin", "Two-Coin Mixture", "Random Coins"]
tuning_sequence_length = 50  # @param {"type":"integer"}
tuning_batch_size = 256  # @param {"type":"integer"}
tuning_num_steps = 1000  # @param {"type":"integer"}
prefix_length = 6  # @param {"type":"integer"}
tuning_num_repetitions = 10  # @param {"type":"integer"}

# @markdown Prefix tuning
hard_token_tuning = True  # @param {type:"boolean"}
simplex_pf_tuning = True  # @param {type:"boolean"}
real_pf_tuning = True  # @param {type:"boolean"}
soft_pf_tuning = True  # @param {type:"boolean"}
# @markdown Fine tuning
full_fine_tuning = True  # @param {type:"boolean"}
lora_tuning = True  # @param {type:"boolean"}
# @markdown Embedding tuning
embedding_tuning = True  # @param {type:"boolean"}
unembedding_tuning = True  # @param {type:"boolean"}
un_embedding_tuning = True  # @param {type:"boolean"}
# @markdown Baselines
random_prefix = True  # @param {type:"boolean"}
pretrain_bayes = True  # @param {type:"boolean"}
pretrain_bayes_pt = True  # @param {type:"boolean"}

tuning_names = []
if simplex_pf_tuning:
  tuning_names.append("SimplexPT")
if real_pf_tuning:
  tuning_names.append("RealPT")
if soft_pf_tuning:
  tuning_names.append("SoftPT")
if full_fine_tuning:
  if architecture.endswith("_untrained"):
    full_fine_tuning = False
    print("Torso non-trainable. Full fine tuning was disabled.")
  else:
    tuning_names.append("FullWT")
if lora_tuning:
  if not architecture.startswith("Transformer"):
    raise ValueError("LoRA tuning is only supported for transformers.")
  tuning_names.append("LoRAWT")
if embedding_tuning:
  tuning_names.append("EmbedWT")
if unembedding_tuning:
  tuning_names.append("UnembedWT")
if un_embedding_tuning:
  tuning_names.append("Un+EmbedWT")
if hard_token_tuning:
  tuning_names.append("HardPT")
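if random_prefix:
  if not real_pf_tuning:
    # RandomPF reuses the initial prefixes of the RealPT tuning runs.
    random_prefix = False
    print("RandomPF requires RealPT. Random prefix baseline was disabled.")
  else:
    tuning_names.append("RandomPF")
if pretrain_bayes and architecture.endswith("_untrained"):
  print("PreBayes is only supported for pretrained models.")
  pretrain_bayes = False
if pretrain_bayes_pt and architecture.endswith("_untrained"):
  print("PreBayesPT is only supported for pretrained models.")
  pretrain_bayes_pt = False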
# @markdown ---

# @markdown **Evaluation**
eval_sequence_length = 200  # @param {"type":"integer"}
eval_num_sequences = 2048  # @param {"type":"integer"}
eval_single_coin = True  # @param {type:"boolean"}
eval_two_coin_mixture = True  # @param {type:"boolean"}
eval_random_coins = True  # @param {type:"boolean"}

eval_names = []
if eval_single_coin:
  eval_names.append("Single Coin")
if eval_two_coin_mixture:
  eval_names.append("Two-Coin Mixture")
if eval_random_coins:
  eval_names.append("Random Coins")

# @markdown ---
store_results = True  # @param {type:"boolean"}
store_path = "/tmp/thunnini_exp/" # @param {type:"string"}
if store_results:
    spath = pathlib.Path(store_path)
    spath.mkdir(parents=True, exist_ok=True)

In [None]:
# @title Data sources general configurations

all_data_sources = {
    "Single Coin": config_lib.CategoricalGeneratorConfig(
        batch_size=128,
        sequence_length=100,
        vocab_size=2,
        biases=np.array([0.2, 0.8]),
    ),
    "Two-Coin Mixture": config_lib.MixtureOfCategoricalsGeneratorConfig(
        batch_size=128,
        sequence_length=100,
        vocab_size=2,
        biases=np.array([[0.2, 0.8], [0.8, 0.2]]),
        mixing_weights=np.array([0.5, 0.5]),
    ),
    "Random Coins": config_lib.DirichletCategoricalGeneratorConfig(
        batch_size=128,
        sequence_length=100,
        vocab_size=2,
        alphas=np.array([1, 1]),
    ),
}

In [None]:
# @title Predictor configuration
predictor_config = config_lib.PredictorConfig(
    token_dimensionality=2,
    embedding_dimensionality=embedding_dim,
)

train = not architecture.endswith("_untrained")
if architecture.startswith("LSTM"):
  torso_config = config_lib.LSTMTorsoConfig(
      is_trainable=train, hidden_sizes=[128], return_hidden_states=False
  )
elif architecture.startswith("Transformer"):
  torso_config = config_lib.TransformerTorsoConfig(
      is_trainable=train,
      hidden_sizes=[128],  # One layer per entry. Only width of MLP
      # block is affected though.
      num_attention_heads=4,
      positional_encoding="SinCos",
      return_hidden_states=False,
      use_bias=False,
      widening_factor=4,
      normalize_qk=True,
      use_lora=True,
      reduced_rank=4,
  )
elif architecture.startswith("Linear"):
  torso_config = config_lib.LinearTorsoConfig(
      is_trainable=train,
      hidden_sizes=[64, 32],
  )
else:
  raise ValueError(f"Unknown architecture: {architecture}")
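The Transformer torso is built with `use_lora=True` and `reduced_rank=4`. In LoRA, a frozen weight matrix $W$ is augmented with a trainable low-rank product $BA$. The sketch below shows the LoRA forward pass and the parameter savings in plain Python. It assumes `reduced_rank` is the rank $r$ of $A$ and $B$. Whether Thunnini applies a scaling factor or initializes $B$ to zero is not shown here, so `scale` and the helper names are illustrative.

```python
def matvec(m, v):
    """Plain-Python matrix-vector product (m is a list of rows)."""
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in m]


def lora_linear(w, a, b, x, scale=1.0):
    """LoRA layer: y = W x + scale * B (A x), with W frozen and only A, B tuned.

    Shapes: W is d_out x d_in, A is r x d_in, B is d_out x r, with small rank r.
    """
    base = matvec(w, x)
    update = matvec(b, matvec(a, x))
    return [y_i + scale * u_i for y_i, u_i in zip(base, update)]


def lora_param_counts(d_in, d_out, rank):
    """(parameters of W, parameters of the LoRA pair A, B)."""
    return d_in * d_out, rank * (d_in + d_out)


# A 128 x 128 projection (embedding_dim = 128), assuming reduced_rank=4 is the LoRA rank:
print(lora_param_counts(128, 128, 4))  # (16384, 1024)
```

A common LoRA choice is to initialize $B$ at zero. The tuned layer then starts out identical to the pretrained one, and tuning only moves it within a rank-$r$ subspace of updates.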

In [None]:
# @title Pretraining configuration
training_data_config = dataclasses.replace(
    all_data_sources[pretraining_source],
    batch_size=pretraining_batch_size,
    sequence_length=pretraining_sequence_length,
)

training_config = config_lib.TrainingConfig(
    num_training_steps=pretraining_num_steps,
    learning_rate=1e-3,
    max_grad_norm=1.0,
    data_gen_seed=0,
    predictor_init_seed=0,
)

In [None]:
# @title Tuning configuration
tuning_data_config = dataclasses.replace(
    all_data_sources[tuning_source],
    batch_size=tuning_batch_size,
    sequence_length=tuning_sequence_length,
)

# Base configuration shared by all gradient-based tuning methods (overridden per method below).
tuning_config_base = config_lib.TuningConfig(
    num_tuning_steps=tuning_num_steps,
    learning_rate=5e-3,
    max_grad_norm=1.0,
    data_gen_seed=10,
    prefix_init_seed=11,
    tuning_method="prefix_real",
    num_tuning_repetitions=tuning_num_repetitions,
    prefix_length=prefix_length,
    prefix_init_method="one_hot",  # ["one_hot", "simplex", "zeros"]
    iterate_datagen_seed_over_repetitions=True,
)

tuning_configs = collections.OrderedDict()

for tuning_name in tuning_names:
  match tuning_name:
    case "SimplexPT":
      tuning_configs[tuning_name] = dataclasses.replace(
          tuning_config_base, tuning_method="prefix_simplex"
      )
    case "RealPT":
      tuning_configs[tuning_name] = dataclasses.replace(
          tuning_config_base, tuning_method="prefix_real"
      )
    case "SoftPT":
      tuning_configs[tuning_name] = dataclasses.replace(
          tuning_config_base, tuning_method="prefix_soft"
      )
    case "FullWT":
      tuning_configs[tuning_name] = dataclasses.replace(
          tuning_config_base,
          tuning_method="full_parameters",
          prefix_length=None,
          prefix_init_method=None,
      )
    case "LoRAWT":
      if not torso_config.use_lora:
        raise ValueError("Torso not set to support LoRA.")
      tuning_configs[tuning_name] = dataclasses.replace(
          tuning_config_base,
          tuning_method="lora_finetune",
          prefix_length=None,
          prefix_init_method=None,
      )
    case "EmbedWT":
      tuning_configs[tuning_name] = dataclasses.replace(
          tuning_config_base,
          tuning_method="embedding",
          prefix_length=None,
          prefix_init_method=None,
      )
    case "UnembedWT":
      tuning_configs[tuning_name] = dataclasses.replace(
          tuning_config_base,
          tuning_method="unembedding",
          prefix_length=None,
          prefix_init_method=None,
      )
    case "Un+EmbedWT":
      tuning_configs[tuning_name] = dataclasses.replace(
          tuning_config_base,
          tuning_method="embedding_unembedding",
          prefix_length=None,
          prefix_init_method=None,
      )

In [None]:
# @title Evaluation configuration
eval_data_configs = {}
for eval_name in eval_names:
  eval_data_configs[eval_name] = dataclasses.replace(
      all_data_sources[eval_name],
      batch_size=eval_num_sequences,
      sequence_length=eval_sequence_length,
  )

In [None]:
# Assign a unique color to each tuning method (for consistent colors)
tuning_method_index = collections.OrderedDict([
    ("HardPT", 0),
    ("SimplexPT", 1),
    ("RealPT", 2),
    ("SoftPT", 3),
    ("FullWT", 4),
    ("LoRAWT", 5),
    ("EmbedWT", 6),
    ("UnembedWT", 7),
    ("Un+EmbedWT", 8),
    ("TargetBayes", 9), ("EvalBayes", 9),  # These are equivalent.
    ("PreBayes", 10),
    ("NoTuning", 11),
    ("PreBayesPT", 12),
    ("RandomPF", 13),
    ("ground_truth", 14),
])

default_color_cycler = plt.cycler(
    color=[
        "deepskyblue",  # 0
        "lightseagreen",  # 1
        "tab:blue",  # 2
        "navy",  # 3
        "goldenrod",  # 4
        "tab:orange",  # 5
        "sienna",  # 6
        "tab:red",  # 7
        "maroon",  # 8
        "black",  # 9
        "dimgray",  # 10
        "olivedrab",  # 11
        "darkgray",  # 12
        "limegreen",  # 13
        "palegreen"  # 14
    ]
)

# Set the default color cycle
plt.rcParams["axes.prop_cycle"] = default_color_cycler
# Default plot settings
rc_context = {
    "axes.facecolor": "whitesmoke",
    "grid.color": "gainsboro",
    "axes.edgecolor": "whitesmoke",
    "axes.labelcolor": "#242c2e",
    "text.color": "#242c2e",
    "ytick.color": "#242c2e",
    "xtick.color": "#242c2e",
    "legend.edgecolor": "none",
    "axes.labelsize": "xx-large",
    "xtick.labelsize": "x-large",
    "ytick.labelsize": "x-large",
    "axes.titlesize": "xx-large",
}

In [None]:
# @title Write configs to file

if store_results:
  with open(store_path + "configs.txt", "w") as f:
    print("--- Architecture configuration ---", file=f)
    print(predictor_config, file=f)
    print(torso_config, file=f)

    print("\n--- Training configuration ---", file=f)
    print(training_config, file=f)
    print(f"Training data generator: {pretraining_source}", file=f)
    print("\t", training_data_config, file=f)

    print("\n--- Tuning data configuration ---", file=f)
    print(f"Tuning data generator: {tuning_source}", file=f)
    print("\t", tuning_data_config, file=f)
    print("\n--- Tuning configurations ---", file=f)
    for tuning_name in tuning_names:
      if tuning_name in tuning_configs:
        print(f"\t{tuning_name}", file=f)
        print("\t", tuning_configs[tuning_name], file=f)
      else:
        print(f"\t{tuning_name}", file=f)

    print("\n--- Evaluation data configuration ---", file=f)
    for eval_name in eval_names:
      print(f"\t{eval_name}", file=f)
      print("\t", eval_data_configs[eval_name], file=f)
  print("Configs written to", store_path+"configs.txt")

# Pretraining

In [None]:
print("--- Architecture configuration ---")
print(predictor_config)
print(torso_config)
print("\n--- Training configuration ---")
print(training_config)
print(training_data_config)

In [None]:
# @title Pretrain predictor
trained_params, train_results = training.train(
    training_config=training_config,
    predictor_config=predictor_config,
    torso_config=torso_config,
    data_config=training_data_config,
)

In [None]:
# Plot loss curve
with plt.rc_context(rc_context):
  if train_results:
    ax = plot_utils.plot_performance_metric(
        {architecture: [train_results["loss"]]},
        "Training loss",
        aggregate_fn_only=True,  # No variability band, single repetition.
        show_gridlines=True,
    )
    ax.set_xlabel("Training Step")
    ax.set_title(f"Pretraining on {pretraining_source}.")
    if store_results:
      plt.savefig(store_path + "pretraining_loss_curve.pdf", bbox_inches="tight")
      print("Figure written to:", store_path + "pretraining_loss_curve.pdf")
  else:
    if torso_config.is_trainable:
      raise ValueError("Training failed but torso is trainable. Aborting.")
    print(
        "Predictor initialized but training skipped since torso is not"
        " trainable."
    )

# Tuning and Evaluation

In [None]:
print("--- Tuning data configuration ---")
print(tuning_data_config)

print("\n--- Tuning methods ---")
for tuning_name in tuning_names:
  if tuning_name in tuning_configs:
    print(f"{tuning_name} ({tuning_configs[tuning_name].tuning_method})")
  else:
    print(f"{tuning_name}")

In [None]:
# @title Run main tuning experiment and evaluations

if "FullWT" in tuning_configs and not torso_config.is_trainable:
  raise ValueError("Full weight tuning requires trainable torso.")

# Set logging level so we see tuning experiment progress in colab.
logging.basicConfig(level=logging.INFO,
                    force = True)

results, sequences = tuning.run_tuning_experiment(
  predictor_config=predictor_config,
  torso_config=torso_config,
  predictor_params=trained_params,
  tuning_configs=tuning_configs,
  tuning_data_config=tuning_data_config,
  eval_data_configs=eval_data_configs,
  eval_datgen_seed=0,
  eval_batching_batch_size=-1,
  evaluate_untuned_predictor=True,
  return_tuned_prefix=True,
)

# Reset logging level
logging.basicConfig(level=logging.WARNING,
                    force = True)

In [None]:
# @title Tune hard prefix (exhaustive search)

def int_to_binary_one_hot(number: int, binary_length: int) -> np.ndarray:
  """Returns one-hot binary representation of number as array."""
  binary_rep_list = list(np.binary_repr(number).zfill(binary_length))
  binary_array = np.array(binary_rep_list, dtype=np.uint8)
  one_hot = np.vstack([binary_array, 1 - binary_array]).transpose()
  return one_hot


def one_hot_binary_to_str(one_hot: np.ndarray) -> str:
  """Returns string representation of one-hot binary representation."""
  return "".join(one_hot[:, 0].astype(str))


if ("HardPT" in tuning_names) or pretrain_bayes_pt:
  hard_pf_tuning_batch_size = 2048

  # Exhaustively generate all hard prefixes
  all_hard_prefixes = [
      int_to_binary_one_hot(i, prefix_length) for i in range(2**prefix_length)
  ]

  # Draw sequences for tuning, making sure to reach the batch size since
  # tuning_data_config.batch_size may be < hard_pf_tuning_batch_size.
  hard_pf_tune_dg = builders.build_datagen(tuning_data_config)
  rng_key = jax.random.PRNGKey(5)
  seq_batches = []
  num_drawn = 0
  while num_drawn < hard_pf_tuning_batch_size:
    # Split off a fresh key per batch; reusing the same key would make every
    # batch an identical copy of the first one.
    rng_key, batch_key = jax.random.split(rng_key)
    seqs = hard_pf_tune_dg.generate(
        rng_key=batch_key, return_ground_truth_log_probs=False
    )
    seq_batches.append(seqs)
    num_drawn += seqs.shape[0]
  hard_pf_tune_sequences = jnp.concatenate(seq_batches, axis=0)[
      :hard_pf_tuning_batch_size
  ]

  if "HardPT" in tuning_names:
    # Evaluate all hard prefixes on tuning "batch"
    hard_pf_logits, hard_pf_losses = evaluation.evaluate_prefix_list(
        prefix_list=all_hard_prefixes,
        prefix_type="prepend",
        predictor_config=predictor_config,
        torso_config=torso_config,
        predictor_params=trained_params,
        sequences=hard_pf_tune_sequences,
        batch_size=-1,  # Set to -1 to evaluate all sequences as single batch
    )
else:
  print("Hard prefix search skipped.")

In [None]:
# @title Process hard prefix tuning results

if "HardPT" in tuning_names:
  all_losses = np.array(hard_pf_losses)
  cum_losses = np.sum(all_losses, axis=-1)
  avg_cum_losses = np.mean(cum_losses, axis=-1)

  sort_inds = np.argsort(avg_cum_losses, axis=0)
  sorted_avg_cum_losses = avg_cum_losses[sort_inds]
  sorted_hard_prefixes = np.array(all_hard_prefixes)[sort_inds]

  best_hard_prefix = sorted_hard_prefixes[0]
  best_hard_prefix_str = one_hot_binary_to_str(best_hard_prefix)

  print(
      "Best mean cumulative loss:",
      sorted_avg_cum_losses[0],
      "(",
      best_hard_prefix_str,
      ")",
  )
  print(
      "Worst mean cumulative loss:",
      sorted_avg_cum_losses[-1],
      "(",
      one_hot_binary_to_str(sorted_hard_prefixes[-1]),
      ")",
  )
else:
  print("Hard prefix tuning skipped.")

In [None]:
# @title Plot hard prefix tuning results

if "HardPT" in tuning_names:
  with plt.rc_context(rc_context):
    fig = plt.figure(figsize=(15, 3.5))
    ax = fig.gca()
    xvec = range(len(all_hard_prefixes))
    ax.bar(xvec, sorted_avg_cum_losses, zorder=3)
    ax.set_xticks(xvec)
    ax.set_xticklabels(map(one_hot_binary_to_str, sorted_hard_prefixes))
    ax.tick_params(axis="x", labelrotation=90)
    ax.set_ylabel("Cumulative loss [nats]")
    ax.set_xlabel("Hard prefix")
    if torso_config.is_trainable:
      ax.set_title(
          f"Hard prefix tuning ({hard_pf_tuning_batch_size} trajectories)\n{architecture}: {pretraining_source} → {tuning_source}."
      )
    else:
      ax.set_title(
          f"Hard prefix tuning ({hard_pf_tuning_batch_size} trajectories)\n{architecture} → {tuning_source}."
      )
    ax.grid(True, axis="y", zorder=0)
  if store_results:
    plt.savefig(store_path + "hard_prefix_tuning.pdf", bbox_inches="tight")
    print("Figure written to:", store_path + "hard_prefix_tuning.pdf")
else:
  print("Hard prefix tuning skipped.")

In [None]:
# @title Evaluate best hard prefix and add to results

if "HardPT" in tuning_names:
  results["HardPT"] = collections.OrderedDict()
  results["HardPT"]["tuned_prefix"] = best_hard_prefix
  for eval_name, eval_sequences in sequences.items():
    hard_pf_eval_logits, hard_pf_eval_losses = evaluation.evaluate_prefix_list(
        prefix_list=[best_hard_prefix],
        prefix_type="prepend",
        predictor_config=predictor_config,
        torso_config=torso_config,
        predictor_params=trained_params,
        sequences=eval_sequences,
        batch_size=-1,  # Set to -1 to evaluate all sequences as single batch
    )
    results["HardPT"][eval_name] = {}
    results["HardPT"][eval_name]["logits"] = hard_pf_eval_logits
    results["HardPT"][eval_name]["losses"] = hard_pf_eval_losses
else:
  print("Hard prefix tuning skipped.")

In [None]:
# @title Evaluate untuned real prefixes and add to results (as a control)

if "RandomPF" in tuning_names:
  # Collect all initial prefixes from tuning
  all_init_prefixes = results["RealPT"]["initial_prefix"]

  results["RandomPF"] = collections.OrderedDict()
  results["RandomPF"]["initial_prefix"] = all_init_prefixes
  for eval_name, eval_sequences in sequences.items():
    init_pf_logits, init_pf_losses = evaluation.evaluate_prefix_list(
        prefix_list=all_init_prefixes,
        prefix_type="prepend",
        predictor_config=predictor_config,
        torso_config=torso_config,
        predictor_params=trained_params,
        sequences=eval_sequences,
        batch_size=-1,  # Set to -1 to evaluate all sequences as single batch
    )
    results["RandomPF"][eval_name] = {}
    results["RandomPF"][eval_name]["logits"] = init_pf_logits
    results["RandomPF"][eval_name]["losses"] = init_pf_losses

In [None]:
# @title Evaluate Bayesian predictor for pretraining distr. on eval distr.

if pretrain_bayes or pretrain_bayes_pt:
  pretraining_dg = builders.build_datagen(training_data_config)
  if pretrain_bayes:
    results["PreBayes"] = collections.OrderedDict()
    for eval_name, eval_sequences in sequences.items():
      # Evaluate Pretraining Bayes
      preb_logits, preb_losses = pretraining_dg.solve(eval_sequences)
      results["PreBayes"][eval_name] = {}
      results["PreBayes"][eval_name]["logits"] = preb_logits
      results["PreBayes"][eval_name]["losses"] = preb_losses

  if pretrain_bayes_pt:
    results["PreBayesPT"] = collections.OrderedDict()
    # Prefix-tune the pretraining Bayes predictor by exhaustive search over all
    # hard prefixes on the tuning sequences. The search does not depend on the
    # evaluation data, so it is done once rather than per evaluation source.
    best_prefix = None
    best_prefix_cum_loss = np.inf
    for prefix in all_hard_prefixes:
      prefixes = jnp.tile(prefix, (hard_pf_tune_sequences.shape[0], 1, 1))
      _, prebpt_losses = pretraining_dg.solve(
          jnp.concatenate((prefixes, hard_pf_tune_sequences), axis=1)
      )
      # Discard the losses on the prefix tokens themselves.
      prebpt_losses = prebpt_losses[:, prefix_length:]
      prebpt_cum_loss = jnp.mean(jnp.sum(prebpt_losses, axis=-1), axis=-1)
      if prebpt_cum_loss < best_prefix_cum_loss:
        best_prefix = prefix
        best_prefix_cum_loss = prebpt_cum_loss

    print("Best prefix (PreBayesPT):", one_hot_binary_to_str(best_prefix))
    print("Best prefix cumulative loss (PreBayesPT):", best_prefix_cum_loss, "nats.")

    # Evaluate the pretraining Bayes predictor with the best prefix on each
    # evaluation source, dropping the outputs at the prefix positions.
    for eval_name, eval_sequences in sequences.items():
      best_prefixes = jnp.tile(best_prefix, (eval_sequences.shape[0], 1, 1))
      prebpt_logits, prebpt_losses = pretraining_dg.solve(
          jnp.concatenate((best_prefixes, eval_sequences), axis=1)
      )
      results["PreBayesPT"][eval_name] = {}
      results["PreBayesPT"][eval_name]["logits"] = prebpt_logits[:, prefix_length:, :]
      results["PreBayesPT"][eval_name]["losses"] = prebpt_losses[:, prefix_length:]
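The PreBayesPT search has a simple interpretation when the pretraining source is Random Coins with its default uniform prior (`alphas=[1, 1]`), assuming `solve` computes the exact Bayes posterior predictive. Conditioning on a hard prefix is just a posterior update. So the prefixed Bayes predictor is Laplace's rule with pseudo-counts $(1 + h, 1 + t)$, where $h$ and $t$ count the prefix's ones and zeros. Two consequences follow:

* Only $h$ matters, so the $2^L$ hard prefixes collapse to $L + 1$ distinct predictors.
* The prior strength is fixed at $L + 2$ pseudo-counts.

The sketch uses integer tokens instead of the notebook's one-hot encoding, and the helper name is hypothetical.

```python
def laplace_with_prefix(prefix, tokens):
    """P(next token = 1) at each step of `tokens` for the uniform-prior Bayes
    predictor that first conditions on the hard binary `prefix`."""
    # Conditioning on the prefix is a posterior update: Beta(1 + h, 1 + t)
    # pseudo-counts, where h / t are the numbers of ones / zeros in the prefix.
    ones = 1 + sum(prefix)
    zeros = 1 + len(prefix) - sum(prefix)
    preds = []
    for x in tokens:
        preds.append(ones / (ones + zeros))
        ones += x
        zeros += 1 - x
    return preds


print(laplace_with_prefix([1, 1, 1, 1, 1, 1], [1]))  # [0.875]
print(laplace_with_prefix([1, 0, 1], [0, 1]) == laplace_with_prefix([0, 1, 1], [0, 1]))  # True
```

With `prefix_length = 6`, the best hard prefix can at most shift the prior mean of the next-token probability to $7/8$. That is a hard limit on how much "oracle info" a hard prefix can inject into this baseline.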

# Visualization of all tuning results

In [None]:
# @title Postprocess all results (compute regrets and prepare for plotting)

# Compute regrets
(gt_losses, instant_regrets, cumulative_regrets, tuning_loss_curves) = (
    plot_utils.postprocess_tuning_experiment_results(
        results, tuning_names, eval_names
    )
)


# Rename 'Bayes' to 'TargetBayes' or 'EvalBayes' and move HardPT and NoTuning
# to group prefix tuning methods and baselines.
for eval_name in eval_names:
  if 'Bayes' in results:  # this will also move to end
    if eval_name == tuning_source:
      instant_regrets[eval_name]['TargetBayes'] = instant_regrets[eval_name].pop('Bayes')
      cumulative_regrets[eval_name]['TargetBayes'] = cumulative_regrets[eval_name].pop('Bayes')
    else:
      instant_regrets[eval_name]['EvalBayes'] = instant_regrets[eval_name].pop('Bayes')
      cumulative_regrets[eval_name]['EvalBayes'] = cumulative_regrets[eval_name].pop('Bayes')
  if pretrain_bayes:
    # Manually process PreBayes results (not covered by postprocess above)
    preb_losses = results["PreBayes"][eval_name]["losses"]
    instant_regret_preb = jnp.mean(preb_losses-gt_losses[eval_name], axis=0)
    instant_regrets[eval_name]["PreBayes"] = [instant_regret_preb]
    cumulative_regrets[eval_name]["PreBayes"] = [np.cumsum(instant_regret_preb)]
  if 'HardPT' in results:
    instant_regrets[eval_name].move_to_end('HardPT', last=False)
    cumulative_regrets[eval_name].move_to_end('HardPT', last=False)
  if 'NoTuning' in results:
    instant_regrets[eval_name].move_to_end('NoTuning')
    cumulative_regrets[eval_name].move_to_end('NoTuning')
  if pretrain_bayes_pt:
    # Manually process PreBayesPT results (not covered by postprocess above)
    prebpt_losses = results["PreBayesPT"][eval_name]["losses"]
    instant_regret_prebpt = jnp.mean(prebpt_losses-gt_losses[eval_name], axis=0)
    instant_regrets[eval_name]["PreBayesPT"] = [instant_regret_prebpt]
    cumulative_regrets[eval_name]["PreBayesPT"] = [np.cumsum(instant_regret_prebpt)]
  if 'RandomPF' in results:
    instant_regrets[eval_name].move_to_end('RandomPF')
    cumulative_regrets[eval_name].move_to_end('RandomPF')

In [None]:
# @title Construct legends grouped per method

pt_methods = []
wt_methods = []
baselines = []
for mname in tuning_method_index:
  if mname in tuning_names:
    if mname.endswith("PT"):
      pt_methods.append(mname)
    elif mname.endswith("WT"):
      wt_methods.append(mname)
  else:
    if mname == "TargetBayes":
      baselines.append("TargetBayes")
    elif mname == "PreBayes" and pretrain_bayes:
      baselines.append("PreBayes")
    elif mname == "NoTuning":
      baselines.append("NoTuning")
    elif mname == "PreBayesPT" and pretrain_bayes_pt:
      baselines.append("PreBayesPT")

if "RandomPF" in tuning_names:
  baselines.append("RandomPF")

n_pt_methods = len(pt_methods)
n_wt_methods = len(wt_methods)
n_baselines = len(baselines)

In [None]:
# @title Tuning loss curves

colors = [f"C{tuning_method_index[name]}" for name in tuning_loss_curves]

with plt.rc_context(rc_context):
  ax = plot_utils.plot_performance_metric(
      tuning_loss_curves,
      "Tuning loss",
      aggregate_fn_only=False,
      show_gridlines=True,
      show_individual_lines=True,
      colors=colors,
  )
  if torso_config.is_trainable:
    ax.set_title(
        f"{architecture}: {pretraining_source} → {tuning_source}"
    )
  else:
    ax.set_title(f"{architecture}: untrained → {tuning_source}")
  ax.set_xlabel("Tuning step")
  ax.get_legend().set_title("Tuning method")

if store_results:
  plt.savefig(store_path + "tuning_loss_curves.pdf", bbox_inches="tight")
  print("Figure written to:", store_path + "tuning_loss_curves.pdf")

In [None]:
# @title Evaluation performance of tuned models

with plt.rc_context(rc_context):
  fig, axes = plt.subplots(
      nrows=2, ncols=len(eval_names), figsize=(6.5 * len(eval_names), 8)
  )
  for i, eval_name in enumerate(eval_names):
    if len(axes.shape) == 1:
      ax = axes[0]
      ax2 = axes[1]
    else:
      ax = axes[0, i]
      ax2 = axes[1, i]
    colors = [
        f"C{tuning_method_index[name]}"
        for name in cumulative_regrets[eval_name]
    ]
    plot_utils.plot_performance_metric(
        instant_regrets[eval_name],
        "Instant regret [nats]",
        axis=ax,
        aggregate_fn_only=False,
        show_gridlines=True,
        show_individual_lines=True,
        colors=colors,
    )
    ax.set_title("Evaluation on " + eval_name)
    ax.set_xlabel("")
    ax.get_legend().remove()

    plot_utils.plot_performance_metric(
        cumulative_regrets[eval_name],
        "Cumulative regret [nats]",
        axis=ax2,
        aggregate_fn_only=False,
        show_gridlines=True,
        show_individual_lines=True,
        colors=colors,
    )
    if i > 0:
      ax2.get_legend().set_visible(False)
      ax.set_ylabel("")
      ax2.set_ylabel("")
    else:
      leg = ax2.legend(
          #title="Tuning method",
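          # Integer column count; max(1, ...) guards against a zero or
          # negative divisor when there are 4 or more evaluation generators.
          ncols=int(np.ceil(len(colors) / max(1, 4 - len(eval_names)))),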
          loc="lower center",
          bbox_to_anchor=(0.5, -0.05),
          bbox_transform=fig.transFigure,
          edgecolor="dimgray",
      )
    ax2.set_xlabel("Step $n$")

  # Add markers for tuning- and pretraining-length
  for ax in axes.flatten():
    ylims = ax.get_ylim()
    ax.vlines(
        [
            tuning_data_config.sequence_length - 1,
            training_data_config.sequence_length - 1,
            eval_sequence_length - 1,
        ],
        ymin=ylims[0],
        ymax=ylims[1],
        colors=["black"],
        linewidth=1.5,
        linestyle=":",
    )
    ax.text(
        x=tuning_data_config.sequence_length,
        y=ylims[1],
        s="$N_\\text{tune}$",
        rotation=0,
        ha="center",
        va="top",
        fontsize=11,
        backgroundcolor="gainsboro",
        color="#242c2e",
    )
    ax.text(
        x=training_data_config.sequence_length,
        y=ylims[1],
        s="$N_\\text{train}$",
        rotation=0,
        ha="center",
        va="top",
        fontsize=11,
        backgroundcolor="gainsboro",
        color="#242c2e",
    )
    tb = ax.text(
        x=eval_sequence_length,
        y=ylims[1],
        s="$N_\\text{eval}$",
        rotation=0,
        ha="center",
        va="top",
        fontsize=11,
        backgroundcolor="gainsboro",
        color="#242c2e",
    )

  if torso_config.is_trainable:
    fig.suptitle(
        f"{architecture}: {pretraining_source} → {tuning_source}",
        fontsize=18,
        y=0.95,
    )
  else:
    fig.suptitle(
        f"{architecture} → {tuning_source}",
        fontsize=18,
        y=0.95,
    )

if store_results:
  plt.savefig(store_path + "evaluation_results.pdf", bbox_inches="tight")
  print("Figure written to:", store_path + "evaluation_results.pdf")

In [None]:
#@title Quantitative performance at important locations

share_y_limits = True  # Share y-axis limits across subplots?
n_timesteps = 3  # Must equal len(bar_locations) and len(bar_labels) below.

with plt.rc_context(rc_context):
  for eval_name in eval_names:
    fig, axes = plt.subplots(
        ncols=n_timesteps,
        nrows=1,
        figsize=(5.5 * n_timesteps, 4),
        sharey=share_y_limits,
    )
    bar_locations = [
        tuning_data_config.sequence_length - 1,
        training_data_config.sequence_length - 1,
        eval_data_configs[eval_name].sequence_length - 1,
    ]
    bar_labels = [
        "$N_\\text{tune}$",
        "$N_\\text{train}$",
        "$N_\\text{eval}$"
    ]

    colors = [
        f"C{tuning_method_index[name]}"
        for name in cumulative_regrets[eval_name]
    ]
    for j, bar_location in enumerate(bar_locations):
      ax = axes[j]
      plot_utils.plot_performance_metric(
          cumulative_regrets[eval_name],
          "Cumulative regret [nats]",
          axis=ax,
          bar_plot=True,
          bar_plot_line_index=bar_location,
          aggregate_fn_only=False,
          show_gridlines=True,
          show_individual_lines=False,
          colors=colors,
      )
      for i, label in enumerate(ax.get_xticklabels()):
        label.set_color(colors[i])

      ax.set_title(f"$n={bar_location}$ (" + bar_labels[j] + ")")
      if j > 0:
        ax.set_ylabel("")

    if torso_config.is_trainable:
      fig.suptitle(
          f"{architecture}: {pretraining_source} → {tuning_source}\nEvaluation on {eval_name}",
          y=1.10,
          fontsize=18,
      )
    else:
      fig.suptitle(
          f"{architecture} → {tuning_source}\nEvaluation on {eval_name}",
          y=1.10,
          fontsize=18,
      )

    # Add markers for separating tuning method types / baselines.
    ylim_top = 0
    for ax in axes:
      yl_top = ax.get_ylim()[1]
      if yl_top > ylim_top:
        ylim_top = yl_top
    for ax in axes:
      margins = ax.margins()
      marg_increase = 0.15
      ax.margins(margins[0], margins[1]+marg_increase)  # make vertical space by increasing margins
      if not share_y_limits:
        ylim_top = ax.get_ylim()[1]
      ylim_marg = ylim_top * (1+margins[1]+marg_increase)
      ax.margins(0.05, 0.05)
      # Bar k sits at x = k, so group boundaries fall at half-integers.
      separator_locations = np.array([
          0,
          n_pt_methods,
          n_wt_methods + n_pt_methods,
          n_wt_methods + n_pt_methods + n_baselines,
      ]) - 0.5
      separator_strings = ["Prefix T.", "Weight T.", "Baselines"]
      ax.vlines(
          separator_locations[1:],
          ymin=ax.get_ylim()[0],
          ymax=ylim_marg,
          colors=["black"],
          linewidth=1.5,
          linestyle=":",
      )
      sep_widths = (separator_locations[1:] - separator_locations[:-1])/2.0
      sep_centers = separator_locations[:-1] + sep_widths
      for i, location in enumerate(sep_centers):
        tb = ax.text(
            x=location,
            y=ylim_marg,
            s=separator_strings[i],
            rotation=0,
            ha="center",
            va="top",
            fontsize=11,
            backgroundcolor="gainsboro",
            color="#242c2e",
            fontweight="heavy",
        )
        tb.set_bbox(dict(color="gainsboro", alpha=0.45))

    if store_results:
      path = store_path + "evaluation_results_detail_" + eval_name.replace(" ", "_") + ".pdf"
      plt.savefig(path, bbox_inches="tight")
      print("Figure written to:", path)
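
The separator logic in the cell above draws dotted lines between groups of bars and centers a group label over each group. Because bar k is drawn at x = k, group boundaries fall at half-integers. A minimal sketch of that arithmetic, assuming the bars are ordered as prefix-tuning methods, then weight-tuning methods, then baselines (the order the separators presuppose); `group_separators` is a hypothetical helper:

```python
import numpy as np


def group_separators(group_sizes):
  """Returns (boundaries, centers) for consecutive groups of bars.

  Hypothetical helper: bar k sits at x = k, so the boundaries between groups
  are the cumulative group sizes shifted by -0.5, and each label is centered
  halfway between its two boundaries.
  """
  boundaries = np.concatenate([[0], np.cumsum(group_sizes)]) - 0.5
  half_widths = (boundaries[1:] - boundaries[:-1]) / 2.0
  centers = boundaries[:-1] + half_widths
  return boundaries, centers


# E.g. 3 prefix-tuning methods, 2 weight-tuning methods, 4 baselines.
boundaries, centers = group_separators([3, 2, 4])
# boundaries -> [-0.5, 2.5, 4.5, 8.5]; centers -> [1.0, 3.5, 6.5]
```

The notebook draws only `separator_locations[1:]`, so no line appears before the first bar, and the last line closes the baseline group.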

# Internal dynamics

In [None]:
# Collect hidden states of the base network (pretrained, or untrained if
# pretraining was skipped) on sequences from the pretraining distribution.
torso_config_w_states = dataclasses.replace(
    torso_config, return_hidden_states=True
)
predictor = builders.build_predictor(predictor_config, torso_config_w_states)

training_datagen = builders.build_datagen(training_data_config)
training_sequences, training_log_probs = training_datagen.generate(
    rng_key=jax.random.PRNGKey(5), return_ground_truth_log_probs=True
)
training_probs = np.exp(training_log_probs)

_, training_states, _, training_prefix_states = (
    evaluation.predictions_and_states_from_sequences(
        predictor_config=None,
        torso_config=None,
        predictor_params=trained_params,
        sequences=training_sequences,
        predictor_instance=predictor,
        prefix_type="none",
        prefix=None,
    )
)

In [None]:
# Helper function to concatenate states across layers and reduce datapoints
def concat_states(
    states: types.Hidden,
    prefix_states: types.PrefixHidden,
    state_string: str,
    dim_red_no_sequences: int,
    dim_red_sequence_length: int,
) -> np.ndarray | None:
  """Concatenates matching states across layers and selects a subset."""
  states_concat = None
  if states is None:
    return None

  for name in states:
    if name.endswith(state_string):
      # Concatenate prefix (if given) with states on time axis
      if prefix_states is None:
        state_sequence = states[name]
      else:
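        # np.concatenate (rather than the NumPy 2.x alias np.concat) keeps
        # this working on NumPy 1.x.
        state_sequence = np.concatenate(
            [prefix_states[name], states[name]], axis=1
        )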

      # Concatenate across layers on feature axis
      if states_concat is None:
        states_concat = state_sequence[
            :dim_red_no_sequences, :dim_red_sequence_length
        ]
      else:
        states_concat = np.concatenate(
            [
                states_concat,
                state_sequence[:dim_red_no_sequences, :dim_red_sequence_length],
            ],
            axis=-1,
        )

  return states_concat
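
To illustrate what `concat_states` returns, the sketch below runs a NumPy-only mirror of the same logic on synthetic state dictionaries. The layer names (`layer_0/cell`, ...) and all shapes are made up for illustration; in the notebook the dictionaries come from `evaluation.predictions_and_states_from_sequences`, and the `(num_sequences, time, features)` layout is an assumption inferred from the slicing above. `concat_states_sketch` is hypothetical and, unlike the original, does not handle `states=None`.

```python
import numpy as np


def concat_states_sketch(states, prefix_states, state_string, n_seq, n_steps):
  """Hypothetical NumPy-only mirror of `concat_states` (without type hints)."""
  selected = []
  for name in states:
    if not name.endswith(state_string):
      continue
    seq = states[name]
    if prefix_states is not None:
      # Prefix steps come first on the time axis (axis=1).
      seq = np.concatenate([prefix_states[name], seq], axis=1)
    # Keep the first n_seq sequences and the first n_steps steps.
    selected.append(seq[:n_seq, :n_steps])
  # Stack the matching layers along the feature axis (last axis).
  return np.concatenate(selected, axis=-1) if selected else None


rng = np.random.default_rng(0)
# Made-up layer names; assumed layout (num_sequences, time, features).
states = {
    "layer_0/cell": rng.normal(size=(8, 30, 4)),
    "layer_0/hidden": rng.normal(size=(8, 30, 4)),
    "layer_1/cell": rng.normal(size=(8, 30, 6)),
}
toy_prefix_len = 5
prefix_states = {
    name: rng.normal(size=(8, toy_prefix_len, x.shape[-1]))
    for name, x in states.items()
}

out = concat_states_sketch(states, prefix_states, "cell", n_seq=3, n_steps=10)
print(out.shape)  # (3, 10, 10): 4 + 6 features from the two "cell" layers
```

Note that the step budget (`dim_red_sequence_length` in the notebook) also counts prefix steps: with a prefix of length 5, only the last 5 of the 10 retained steps belong to the actual sequence. Assuming the prefix states span `prefix_length` steps, this is why the trajectory plot marks `seq[prefix_length]` as the first state after the prefix.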

In [None]:
# Dimensionality reduction (PCA) of internal state
dim_red_no_sequences = 100
dim_red_sequence_length = 50

state_string = None
if isinstance(torso_config, config_lib.LSTMTorsoConfig):
  state_string = "cell"  # "cell" or "hidden"
elif isinstance(torso_config, config_lib.TransformerTorsoConfig):
  state_string = "attention_out"
else:
  raise ValueError("Invalid torso type for state analysis.")

# Concatenate states across layers and reduce datapoints
training_states_concat = concat_states(
    training_states,
    training_prefix_states,
    state_string,
    dim_red_no_sequences,
    dim_red_sequence_length,
)
# Flatten non-feature dimensions
training_states_flat = training_states_concat.reshape(
    -1, training_states_concat.shape[-1]
)

# Fit a whitened 2-component PCA on training states (`model` is reused below)
model = PCA(n_components=2, whiten=True)
training_states_projected = model.fit_transform(training_states_flat)

# Unflatten non-feature dimensions
training_states_projected = training_states_projected.reshape(
    dim_red_no_sequences, dim_red_sequence_length, -1
)
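
The PCA cell fits a whitened two-component PCA on flattened training states, and the prefix cell further down reuses the fitted `model` to project evaluation states into the same coordinates. Below is a NumPy-only sketch of what scikit-learn's `PCA(n_components=2, whiten=True)` computes in `fit_transform` and `transform`: center with the training mean, project onto the top principal directions, and divide by each component's standard deviation. It is a re-derivation for illustration, so component signs (and, for large inputs where scikit-learn picks a randomized solver, small numerical details) may differ; `fit_whitened_pca` is a hypothetical name.

```python
import numpy as np


def fit_whitened_pca(x: np.ndarray, n_components: int = 2):
  """Fits a whitened PCA on the rows of x and returns a projection function.

  Hypothetical sketch of PCA(n_components, whiten=True): center with the
  training mean, project onto the top right-singular vectors, and divide by
  each component's standard deviation so projections have unit variance.
  """
  mean = x.mean(axis=0)
  _, s, vt = np.linalg.svd(x - mean, full_matrices=False)
  components = vt[:n_components]  # (n_components, n_features)
  std = s[:n_components] / np.sqrt(x.shape[0] - 1)  # sqrt(explained variance)

  def transform(y: np.ndarray) -> np.ndarray:
    # New data is centered with the *training* mean, like PCA.transform.
    return (y - mean) @ components.T / std

  return transform


# Synthetic stand-in for concatenated states: (sequences, steps, features).
rng = np.random.default_rng(0)
n_seq, seq_len, n_feat = 100, 50, 16
states = rng.normal(size=(n_seq, seq_len, n_feat))

flat = states.reshape(-1, n_feat)  # one row per (sequence, step) pair
transform = fit_whitened_pca(flat, n_components=2)
projected = transform(flat).reshape(n_seq, seq_len, -1)  # unflatten
print(projected.shape)  # (100, 50, 2)

# Unseen data (e.g. eval states) is projected with the same fitted transform.
new_states = rng.normal(size=(5, seq_len, n_feat))
new_projected = transform(new_states.reshape(-1, n_feat)).reshape(5, seq_len, -1)
```

Fitting only on pretraining-distribution states and reusing that basis for every prefix keeps all trajectory plots in one shared coordinate system, so differences between panels reflect the prefixes rather than a refitted projection.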

In [None]:
# Plot PC projection of non-prefixed net with training sequences

dark_plot = True  # Dark background? Disable for print-friendly version
if dark_plot:
  rc_context_internal = {
      "figure.facecolor": "#1b1e21",
      "text.color": "lightgray",
      "legend.facecolor": "dimgray",
      "legend.edgecolor": "none",
      "ytick.color": "lightgray",
      "axes.labelcolor": "lightgray",
  }
else:
  rc_context_internal = rc_context

with plt.rc_context(rc_context_internal):
  # Two subplots: points colored by step index (left) and by ground-truth
  # probability (right).
  fig, axes = plt.subplots(
      nrows=1, ncols=2, figsize=(10, 4), sharex=True, sharey=True
  )
  for j, ax in enumerate(axes):
    for i in range(dim_red_no_sequences):
      # Individual sequences as gray lines and colored scatter points
      ax.plot(
          training_states_projected[i, :, 0],
          training_states_projected[i, :, 1],
          color="grey",
          marker=None,
          alpha=0.1,
          linewidth=1,
          zorder=-1,
      )
      if j == 0:
        # Color by step index.
        sc_color = np.linspace(0, 1, dim_red_sequence_length)
        sc_cmap = plt.cm.plasma
      else:
        # Color by ground-truth probability
        sc_color = training_probs[i, :dim_red_sequence_length, 0]
        sc_cmap = plt.cm.viridis
      sc = ax.scatter(
          training_states_projected[i, :, 0],
          training_states_projected[i, :, 1],
          c=sc_color,
          marker=".",
          alpha=0.5,
          zorder=3,
          cmap=sc_cmap,
          vmin=0,
          vmax=1,
          s=12,
      )
      if i == dim_red_no_sequences - 1:
        if j == 0:
          # Colors run from 0 (step 0) to 1 (step dim_red_sequence_length - 1).
          cbar = plt.colorbar(
              sc, ax=ax, label="Step $n$", ticks=[0, 1], aspect=35
          )
          cbar.ax.set_yticklabels(["0", f"{dim_red_sequence_length - 1}"])
        else:
          cbar = plt.colorbar(
              sc,
              ax=ax,
              label="Ground-truth probability $\\tau$",
              ticks=[0, 0.5, 1],
              aspect=35,
          )
    ax.axis("off")
    cbar.ax.yaxis.set_ticks_position("left")
    cbar.solids.set(alpha=1)
    cbar.outline.set_visible(False)
    cbar.ax.invert_yaxis()

  if torso_config.is_trainable:
    fig.suptitle(
        f"{architecture} pretrained on {pretraining_source}\nState ({state_string}) on {dim_red_no_sequences} {pretraining_source} sequences (length: {dim_red_sequence_length})",
        fontsize=18,
    )
  else:
    fig.suptitle(
        f"{architecture}\nState ({state_string}) on {dim_red_no_sequences} {pretraining_source} sequences (length: {dim_red_sequence_length})",
        fontsize=18,
    )
  fig.tight_layout()

if store_results:
  plt.savefig(store_path + "internal_states.pdf", bbox_inches="tight")
  print("Figure written to:", store_path + "internal_states.pdf")

In [None]:
# Evaluate various prefixes (on all eval datagens) and plot projected state.
no_sequences_to_show = 20  # Spread evenly across repetitions; at least 1 per repetition.
pca_results = {}

# Only show prefix methods that share the pretrained weights used above
# (weight-tuned models, Bayes predictors, and ground truth are excluded).
method_exclude_list = [
    "TargetBayes",
    "EvalBayes",
    "Bayes",
    "PreBayes",
    "PreBayesPT",
    "ground_truth",
    "FullWT",
    "LoRAWT",
    "EmbedWT",
    "UnembedWT",
    "Un+EmbedWT",
]

if dark_plot:
  init_color = "white"
else:
  init_color = "tomato"

with plt.rc_context(rc_context_internal):
  for eval_name in eval_names:
    # draw sequences from eval generator
    eval_datagen = builders.build_datagen(eval_data_configs[eval_name])
    eval_sequences = eval_datagen.generate(
        rng_key=jax.random.PRNGKey(5), return_ground_truth_log_probs=False
    )

    for i, method in enumerate(results):
      if method in method_exclude_list:
        continue

      # Figure out prefix-type and prefix
      match method:
        case "NoTuning":
          prefix_type = "none"
          prefixes = [None]
        case "RandomPF":
          prefix_type = "prepend"
          prefixes = results[method]["initial_prefix"]
        case "HardPT":
          prefix_type = "prepend"
          prefixes = [results[method]["tuned_prefix"]]
        case "SimplexPT":
          prefix_type = "simplex"
          prefixes = results[method]["tuned_prefix"]
        case "RealPT":
          prefix_type = "prepend"
          prefixes = results[method]["tuned_prefix"]
        case "SoftPT":
          prefix_type = "embedding"
          prefixes = results[method]["tuned_prefix"]
        case _:
          raise ValueError(f"Unknown or unsupported tuning method: {method}.")

      # Evaluate prefix on sequences, preprocess, and project to PCs
      eval_projected_states = []
      no_sequences_per_repetition = int(
          np.ceil(no_sequences_to_show / len(prefixes))
      )
      print(
          f"Processing: {method} ({len(prefixes)} repetitions,"
          f" {no_sequences_per_repetition} sequences per repetition)"
      )
      for prefix in prefixes:
        _, eval_states, _, eval_prefix_states = (
            evaluation.predictions_and_states_from_sequences(
                predictor_config=None,
                torso_config=None,
                predictor_params=trained_params,
                sequences=eval_sequences,
                predictor_instance=predictor,  # Reuse predictor instance from above.
                prefix_type=prefix_type,
                prefix=prefix,
            )
        )
        # Concatenate states across layers and reduce datapoints
        eval_states_concat = concat_states(
            eval_states,
            eval_prefix_states,
            state_string,
            no_sequences_per_repetition,
            dim_red_sequence_length,
        )
        # Flatten non-feature dimension
        eval_states_flat = eval_states_concat.reshape(
            -1, eval_states_concat.shape[-1]
        )
        # Project onto the principal components fitted on training states
        eval_states_projected = model.transform(eval_states_flat)
        # Unflatten non-feature dimensions
        eval_states_projected = eval_states_projected.reshape(
            no_sequences_per_repetition, dim_red_sequence_length, -1
        )
        eval_projected_states.append(eval_states_projected)
      pca_results[method] = np.concatenate(eval_projected_states, axis=0)

    # Plot all projections
    n_subplots = len(pca_results)
    fig, axes = plt.subplots(
        nrows=1,
        ncols=n_subplots,
        figsize=(2.8 * n_subplots, 4),
        sharex=True,
        sharey=True,
    )
    for j, method in enumerate(pca_results):
      ax = axes[j]
      ax.set_title(f"{method}")

      for k in range(dim_red_no_sequences):
        # Plot grid from training distribution
        (pretrain_lines,) = ax.plot(
            training_states_projected[k, :, 0],
            training_states_projected[k, :, 1],
            color="grey",
            marker=None,
            alpha=0.1,
            linewidth=1,
            zorder=-1,
        )

      color = plot_utils.get_color_by_index(tuning_method_index[method])
      for seq in pca_results[method]:
        (eval_lines,) = ax.plot(
            seq[:, 0],
            seq[:, 1],
            color=color,
            marker=".",
            alpha=0.6,
            linewidth=1,
            zorder=-1,
        )
        # Mark the first state after the prefix (step 0 if there is no prefix)
        if method == "NoTuning":
          (init_marker,) = ax.plot(
              seq[0, 0],
              seq[0, 1],
              marker="o",
              ls="",
              color=init_color,
              fillstyle="none",
          )
        else:
          (init_marker,) = ax.plot(
              seq[prefix_length, 0],
              seq[prefix_length, 1],
              marker="o",
              ls="",
              color=init_color,
              fillstyle="none",
          )
      ax.axis("off")

    init_marker.set_label("$n=L_{\\text{prefix}}"+f"={prefix_length}"+"$")
    eval_lines.set_label(f"{eval_name}")
    pretrain_lines.set_label(f"{pretraining_source} (no prefix)")
    legend = fig.legend(frameon=True, fontsize=14, ncols=3, loc="lower right")
    for lh in legend.legend_handles:
      lh.set_alpha(1)
    if torso_config.is_trainable:
      fig.suptitle(
          f"{architecture}: {pretraining_source} → {tuning_source}\nState ({state_string}) on {no_sequences_to_show} {eval_name} sequences (length: {dim_red_sequence_length})",
          fontsize=18,
      )
    else:
      fig.suptitle(
          f"{architecture} → {tuning_source}\nState ({state_string}) on {no_sequences_to_show} {eval_name} sequences (length: {dim_red_sequence_length})",
          fontsize=18,
      )
    fig.tight_layout()

    # Save inside the loop so that every evaluation generator gets its own file.
    if store_results:
      path = (
          store_path
          + "internal_states_prefix_trajectories_"
          + eval_name.replace(" ", "_")
          + ".pdf"
      )
      plt.savefig(path, bbox_inches="tight")
      print("Figure written to:", path)
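
The prefix-trajectory cell splits `no_sequences_to_show` across tuning repetitions with a ceiling division, so every repetition contributes at least one sequence. When the total does not divide evenly, more sequences are drawn than requested, while the figure titles still report `no_sequences_to_show`. A small sketch of that arithmetic (`sequences_per_repetition` is a hypothetical helper equivalent to the `int(np.ceil(...))` expression in the cell):

```python
import math


def sequences_per_repetition(no_sequences_to_show: int, n_repetitions: int) -> int:
  """Ceiling division: at least 1 sequence per repetition."""
  return math.ceil(no_sequences_to_show / n_repetitions)


for n_reps in (1, 3, 20, 30):
  per_rep = sequences_per_repetition(20, n_reps)
  print(n_reps, per_rep, n_reps * per_rep)
# 1 20 20
# 3 7 21
# 20 1 20
# 30 1 30
```

Since `concat_states` slices the first `no_sequences_per_repetition` sequences of the evaluation batch, this assumes the evaluation generator produces at least that many sequences.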