# The Inductive Properties of Chronos

[Chronos](https://github.com/amazon-science/chronos-forecasting) is a family of time series forecasting models based on the T5 encoder-decoder architecture. Since the model is built upon the self-attention mechanism, we expect there to be induction heads, which are heads that attend very strongly to the most recent instance of the current token.

In this notebook, we will explore the inductive properties of the Chronos models by first finding evidence of induction heads and then studying interesting properties of inductive attention patterns.

## 0. Setup
Using the T4 GPU runtime is highly recommended. The following cells install and import the necessary libraries.

In [None]:
%pip install chronos-forecasting
%pip install kaleido

In [None]:
import torch
from chronos import BaseChronosPipeline

import numpy as np
import pandas as pd

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import ipywidgets as widgets
import plotly.offline as py
from IPython.display import display, HTML

from ipywidgets import Dropdown, Button, HBox, VBox, Output
from IPython.display import display, clear_output
from plotly.offline import init_notebook_mode
from plotly.io import to_html

import os
import pickle

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"

## 1. Finding Induction Heads

### 1.1 Background
#### 1.1.1 Attention Heads
For some collection of $T$ tokens, $[S_1, S_2, \ldots, S_T]$, we embed each token into a vector of dimension $d$ and collect them into a matrix $X \in \mathbb{R}^{T \times d}$. The self-attention mechanism, for each layer $\ell$ and head $h$, performs the following transformations to compute the attention head, $A^{(\ell, h)} \in \mathbb{R}^{1 \times d'}$:
\begin{align*}
q^{(\ell,h)}_T &= X^{(T)}W_Q^{(\ell,h)}, \\
K^{(\ell,h)} &= XW_K^{(\ell,h)}, \\
V^{(\ell,h)} &= XW_V^{(\ell,h)},
\end{align*}
$$A^{(\ell,h)} = \mathcal{S}\left(\frac{q^{(\ell, h)}_T\left(K^{(\ell, h)}\right)^\top}{\sqrt{d}}\right)V^{(\ell, h)}, $$
where $X^{(j)} \in \mathbb{R}^{1 \times d}$ is the embedding of the $j$-th token in the context (also the $j$-th row in $X$); the matrices $W_Q^{(\ell, h)}, W_K^{(\ell, h)}, W_V^{(\ell, h)} \in \mathbb{R}^{d \times d'}$ are the learnable parameters of the model; while $q^{(\ell, h)}_T \in \mathbb{R}^{1 \times d'} \text{ and } K^{(\ell, h)}, V^{(\ell, h)} \in \mathbb{R}^{T \times d'}$ are referred to as the query, key, and value matrices, respectively; and $\mathcal{S}$ is the row-wise softmax function.

The attention head can be interpreted as a weighted average over the rows of $V^{(\ell, h)}$, with the weight for the $j$-th token being equal to
\begin{equation*}
    \tilde{A}_{T,j} = \frac{\exp\left(q^{(\ell, h)}_T \left( k^{(\ell, h)}_j \right)^\top / \sqrt{d} \right)}{\sum_{i=1}^T \exp\left(q^{(\ell, h)}_T \left( k_i^{(\ell, h)} \right)^\top / \sqrt{d} \right)},
\end{equation*}
where $k^{(\ell, h)}_i$ is the $i$-th row of $K^{(\ell, h)}$. With this formulation, we can explicitly see that the attention head computes a weighted sum of the value vectors. The weight ascribed to the $j$-th token, $\tilde{A}_{T,j}$, is referred to as the 'attention score' and measures how much the $T$-th token 'attends' to the $j$-th token.
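
To make the weighted-average reading concrete, here is a minimal pure-Python sketch that computes the attention scores for a single query and then averages the value rows with them. The numbers are made up for illustration, and the helper names (`attention_weights`, `attention_output`) and the `scale` parameter are assumptions of this sketch rather than part of Chronos.

```python
import math

def attention_weights(query, keys, scale=1.0):
    """Softmax over the (scaled) dot products between one query and each key row."""
    logits = [sum(q * k for q, k in zip(query, key)) / scale for key in keys]
    max_logit = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(logit - max_logit) for logit in logits]
    total = sum(exps)
    return [e / total for e in exps]

def attention_output(weights, values):
    """Weighted average of the value rows, one entry per value dimension."""
    dim = len(values[0])
    return [sum(w * v[i] for w, v in zip(weights, values)) for i in range(dim)]

# Toy inputs (hypothetical): T = 3 tokens, head dimension d' = 2.
query = [1.0, 0.0]
keys = [[0.0, 1.0], [2.0, 0.0], [1.0, 1.0]]
values = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

weights = attention_weights(query, keys, scale=math.sqrt(2))
output = attention_output(weights, values)
print("attention scores:", [round(w, 3) for w in weights])
print("head output:", [round(o, 3) for o in output])
```

The scores sum to one, and the second token receives the largest score because its key has the largest dot product with the query.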

#### 1.1.2 Induction Heads
Induction heads are attention heads that weight the most recent previous instance of the $S_T$ token far above the uniform baseline, i.e. $\tilde{A}_{T,k} \gg 1/T$ where $k=\max \{ n < T \mid S_n = S_T \}$. Induction heads may also attend to the token *after* a previous instance of the $S_T$ token: $\tilde{A}_{T,k'} \gg 1/T$ where $k' = \max \{n<T \mid S_{n-1}=S_T\}$.
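
As a small illustration of the two index definitions, the sketch below finds $k$ and $k'$ for the last token of a sequence. It uses 0-based positions (the formulas above are 1-based), and the function name `induction_targets` is hypothetical.

```python
def induction_targets(tokens):
    """Return (k, k_prime) as 0-based positions for the last token in `tokens`.

    k is the most recent earlier position holding the same token as the last one.
    k_prime is the most recent earlier position whose predecessor holds that token.
    Either value is None when no such position exists.
    """
    last = len(tokens) - 1
    current = tokens[last]
    same_token = [n for n in range(last) if tokens[n] == current]
    after_same = [n for n in range(1, last) if tokens[n - 1] == current]
    return max(same_token, default=None), max(after_same, default=None)

# The final "B" matches the "B" at position 4; the token after it is at position 5.
print(induction_targets(list("ABCABCAB")))  # (4, 5)
```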

### 1.2 Repeated Random Tokens (RRT)
The repeated random tokens method described by [Olsson et al.](https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html) is a standard method for empirically finding evidence of induction heads in models. Since the Chronos models are based on encoder-decoder architectures, we will need to slightly adapt the procedure, using [Cynthia Chen's](https://chenxcynthia.github.io/projects/induction/) blog post as reference:

1. Create sequences of repeated random tokens. For example, let us create a sequence of length `sequence_length=3` and repeat it `repeat_factor=4` times: `[ABC][ABC][ABC][ABC]`. For our experiments, the random tokens will always be integers between 1911 and 2187, inclusive, since the encoder's input data is always scaled to fall within the range those token ids cover.
2. Then, append one more token to continue the pattern and use that as the input token ids for the encoder. So, the encoder will receive `ABCABCABCABCA<EOS>`, where `<EOS>` is a special token that marks the end of the encoder input sequence (note that there is an extra `A` token appended to the end of the sequence from step 1).
3. Then, pass the next token in the sequence to the decoder and do a forward pass of the model. Continuing our example, the decoder receives `<DEC_START>B` since the `B` token would normally follow the sequence given to the encoder. `<DEC_START>` is a special token that marks the start of the decoder input sequence.
4. For each head in each layer of the decoder, find the average attention score between the current decoder token `B` and the most recent occurrence of the `B` token in the encoder input ids. (In the Chronos models, this is given by the cross-attention scores.) Additionally, find the average attention score between the current decoder token `B` and the token *after* the most recent occurrence of `B` in the encoder input ids.

Once we compute the average attention towards the most recent occurrence of the current token as well as the token after the most recent occurrence, we isolate the heads that have an average attention score of at least $0.3$ as our induction heads.
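
The layout from steps 1 to 3 and the two encoder positions measured in step 4 can be sketched with symbolic tokens. This is only an illustration under stated assumptions: letters stand in for the integer token ids, the pattern must have at least two tokens, and the helper names `build_rrt_example` and `target_positions` are hypothetical.

```python
def build_rrt_example(pattern, repeat_factor):
    """Lay out the encoder and decoder inputs for one RRT example."""
    # Steps 1 and 2: repeat the pattern, continue it by one token, then end with <EOS>.
    encoder_input = list(pattern) * repeat_factor + [pattern[0], "<EOS>"]
    # Step 3: the decoder sees its start token followed by the next token in the pattern.
    decoder_input = ["<DEC_START>", pattern[1]]
    return encoder_input, decoder_input

def target_positions(encoder_input, current_token):
    """0-based encoder positions of the most recent occurrence of
    `current_token` and of the token directly after it (step 4)."""
    center = max(i for i, tok in enumerate(encoder_input) if tok == current_token)
    return center, center + 1

encoder_input, decoder_input = build_rrt_example("ABC", repeat_factor=4)
center, right = target_positions(encoder_input, decoder_input[-1])
print("".join(encoder_input))  # ABCABCABCABCA<EOS>
print(center, right)           # 10 11
```

For the `ABC` example, the decoder token `B` last appears at encoder position 10, and the token after it (`C`) sits at position 11.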

#### 1.2.1 Helper Function to Generate RRT Data

In [4]:
def generate_rrt_induction_data(
    vocab_range=(1911, 2187),
    batch_size=100,
    num_unique_sequences=1,
    repeat_factor=4,
    extension=0,
    sub_extension=1,
    sequence_length=10,
    device="cpu"
):
    """
    Generate data for RRT induction experiments.

    Args:
        vocab_range: Tuple of (min, max) vocab indices to use
        batch_size: Number of sequences in the batch
        num_unique_sequences: Number of unique sequences in the batch
        repeat_factor: Number of times to repeat the sequences
        extension: After repeating everything, repeat this many sequences again
        sub_extension: Some additional tokens in the next sequence
        sequence_length: Length of each sequence

    Returns:
        Tuple of (token_ids, attention_mask, decoder_input_ids)
    """
    # Define constants
    EOS_TOKEN_ID = torch.tensor(1, dtype=torch.int)

    # Define the vocab as all the tokens within the specified range
    vocab = torch.tensor([i for i in range(4096) if i >= vocab_range[0] and i <= vocab_range[1]])

    # Generate random token sequences (inlined from rrt.generate_random_token_ids)
    tokens = []
    for _ in range(num_unique_sequences):
        # random choice of vocab indices
        indices = torch.randint(0, vocab.shape[0], (batch_size, sequence_length))
        # select the tokens
        token_sequence = vocab[indices]
        tokens.append(token_sequence)

    # Stack sequences with repetition pattern (inlined from rrt.stack_sequences)
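    sequences_to_stack = (
        tokens * repeat_factor                      # full repetitions of every sequence
        + tokens[:extension]                        # whole sequences repeated once more
        + [tokens[extension][:, :sub_extension]]    # leading tokens of the next sequence
        + [torch.ones((batch_size, 1), dtype=torch.int) * EOS_TOKEN_ID]  # end-of-sequence token
    )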
    token_ids = torch.cat(sequences_to_stack, dim=-1)

    # Create attention mask
    attention_mask = torch.ones_like(token_ids, dtype=torch.bool)

    # Create decoder input ids
    decoder_input_ids = torch.cat([
        torch.zeros((batch_size, 1), dtype=torch.long),
        tokens[extension][:,sub_extension:sub_extension+1]
    ], dim=-1)

    token_ids = token_ids.to(device)
    attention_mask = attention_mask.to(device)
    decoder_input_ids = decoder_input_ids.to(device)

    return token_ids, attention_mask, decoder_input_ids


#### 1.2.2 Computing RRT

We will conduct the RRT procedure over sequences of length `sequence_length=10` that are repeated `repeat_factor=2` times. Both values are defined as lists in the next cell so that other configurations can be added to observe how the input sequences affect the induction heads that we detect.

In [5]:
# Define different values to experiment with
repeat_factors = [2]
sequence_lengths = [10]

model_names = ["amazon/chronos-t5-mini", "amazon/chronos-t5-small", "amazon/chronos-t5-base", "amazon/chronos-t5-large"]

In [None]:
from tqdm.notebook import tqdm

all_results = {}

total_iterations = len(repeat_factors) * len(sequence_lengths) * len(model_names)

with tqdm(total=total_iterations, desc="Processing configurations and models") as pbar:
  for repeat_factor in repeat_factors:
      for sequence_length in sequence_lengths:
          print(f"Processing repeat_factor={repeat_factor}, sequence_length={sequence_length}")

          num_unique_sequences = 1  # number of unique sequences in the batch
          extension = 0  # after repeating everything, repeat this many sequences again
          sub_extension = 1  # some additional tokens in the next sequence

          # Generate data with current configuration
          token_ids, attention_mask, decoder_input_ids = generate_rrt_induction_data(
              num_unique_sequences=num_unique_sequences,
              repeat_factor=repeat_factor,
              extension=extension,
              sub_extension=sub_extension,
              sequence_length=sequence_length,
              device=device
          )

          # Display shapes and sample values
          # print(f"token_ids.shape: {token_ids.shape}, decoder_input_ids.shape: {decoder_input_ids.shape}")

          # Store results for each model
          config_key = f"rf{repeat_factor}_sl{sequence_length}"
          all_results[config_key] = {
              "center_scores": {},
              "right_scores": {}
          }

          for model_name in model_names:
              pbar.set_description(f"Processing {model_name} for RF={repeat_factor}, SL={sequence_length}")

              pipeline = BaseChronosPipeline.from_pretrained(
                  model_name,
                  device_map=device,
                  torch_dtype=torch.bfloat16,
              )

              t5_model = pipeline.model.model

              outputs = t5_model.generate(
                  input_ids=token_ids,
                  attention_mask=attention_mask,
                  max_new_tokens=1,
                  decoder_input_ids=decoder_input_ids,
                  num_return_sequences=1,
                  do_sample=False,
                  use_cache=False,
                  output_attentions=True,
                  output_scores=True,
                  output_hidden_states=True,
                  return_dict_in_generate=True
              )

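              # Extract cross-attention probabilities.
              # cross_attentions holds one entry per generated token; each entry is a tuple with one
              # tensor per decoder layer, of shape [batch, heads, dec_length, enc_length]
              cross_attn_probs = outputs.cross_attentions

              # Decoder position of the current token (position 0 is the decoder start token)
              t_idx = 1
              # Encoder position of the most recent occurrence of the current token
              s_idx = sequence_length * ((repeat_factor-1)*num_unique_sequences + extension) + sub_extension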

              layers, heads = t5_model.config.num_decoder_layers, t5_model.config.num_heads
              mosaic_center = np.zeros((layers, heads)).tolist()
              mosaic_right = np.zeros((layers, heads)).tolist()

              for layer in range(layers):
                  for head in range(heads):
                      # mean over the batch
                      mosaic_center[layer][head] = float(cross_attn_probs[0][layer][:, head, t_idx, s_idx].to("cpu").mean())
                      mosaic_right[layer][head] = float(cross_attn_probs[0][layer][:, head, t_idx, s_idx+1].to("cpu").mean())

              # Store the scores for this model and configuration
              all_results[config_key]["center_scores"][model_name] = mosaic_center
              all_results[config_key]["right_scores"][model_name] = mosaic_right

              pbar.update(1)

# save the results in a pickle file
os.makedirs("variables", exist_ok=True)
with open("variables/results.pkl", "wb") as f:
    pickle.dump(all_results, f)
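
Once the scores are stored, the $0.3$ threshold rule from the RRT procedure can be applied to any one mosaic. The sketch below does this on a toy mosaic with made-up scores (not real model output); the helper name `heads_above_threshold` is hypothetical.

```python
def heads_above_threshold(mosaic, threshold=0.3):
    """Return (layer, head) pairs whose average attention score is at least `threshold`."""
    return [
        (layer, head)
        for layer, row in enumerate(mosaic)
        for head, score in enumerate(row)
        if score >= threshold
    ]

# Toy mosaic with 2 layers and 3 heads; the scores are invented for illustration.
toy_center = [
    [0.05, 0.41, 0.10],
    [0.30, 0.02, 0.29],
]
print(heads_above_threshold(toy_center))  # [(0, 1), (1, 0)]
```

With the real results, the same function could be called on, for example, `all_results["rf2_sl10"]["center_scores"]["amazon/chronos-t5-small"]`, since each stored mosaic is a list of per-layer lists.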

### 1.3 Visualizing Induction Heads in an Induction Mosaic

As a nice way to visualize the induction heads, we will plot the average attention scores to the previous tokens of interest as a heatmap where the y-axis represents the layer of the head and the x-axis represents different heads in each transformer layer.

While the layers are arranged in the order they are found in the model (for example, layer 1 precedes layer 2 in the model), the attention heads in each layer have no inherent order to them and are only given numeric labels for classification purposes.

#### 1.3.1 Helper Functions for Plotting the Induction Mosaic

In [7]:
def enable_plotly_in_cell():
  display(HTML("<script src=\"/static/components/requirejs/require.js\"></script>"))
  py.init_notebook_mode(connected=False)

def create_mosaic_figure(current_rf, current_sl, current_model):
    # Generate config key
    config_key = f"rf{current_rf}_sl{current_sl}"

    # Check if config exists in results
    if config_key not in all_results:
        return go.Figure().update_layout(
            annotations=[dict(
                text=f"No data available for configuration: {config_key}",
                showarrow=False,
                xref="paper",
                yref="paper",
                x=0.5,
                y=0.5
            )]
        )

    # Get scores for current configuration
    all_center_scores = all_results[config_key]["center_scores"]
    all_right_scores = all_results[config_key]["right_scores"]

    # Check if model exists in scores
    if current_model not in all_center_scores or current_model not in all_right_scores:
        return go.Figure().update_layout(
            annotations=[dict(
                text=f"No data available for model {current_model} with configuration {config_key}",
                showarrow=False,
                xref="paper",
                yref="paper",
                x=0.5,
                y=0.5
            )]
        )

    # Get data for current model
    mosaic_center = all_center_scores[current_model]
    mosaic_right = all_right_scores[current_model]

    # Create a subplot with 2 side-by-side heatmaps
    fig = make_subplots(rows=1, cols=2,
                        subplot_titles=("Current token", "Token to right of current"),
                        shared_yaxes=True)

    # Add heatmaps to the subplots
    fig.add_trace(
        go.Heatmap(z=mosaic_center, zmin=0, zmax=1, coloraxis="coloraxis"),
        row=1, col=1
    )
    fig.add_trace(
        go.Heatmap(z=mosaic_right, zmin=0, zmax=1, coloraxis="coloraxis"),
        row=1, col=2
    )

    # Update layout
    fig.update_layout(
        title_text=f"Attention Mosaics For Induction on RRTs - Repetitions: {current_rf}, Sequence Length: {current_sl}, Model: {current_model.split('/')[-1]}",
        height=500,
        width=1000,
        coloraxis=dict(cmin=0, cmax=1, colorbar=dict(title="Attention Score"))
    )

    # Add axes labels with integer ticks
    fig.update_xaxes(title_text="Head", row=1, col=1, title_font=dict(size=18),
                    tickmode='linear', tick0=0, dtick=1, showticklabels=True)
    fig.update_xaxes(title_text="Head", row=1, col=2, title_font=dict(size=18),
                    tickmode='linear', tick0=0, dtick=1, showticklabels=True)
    fig.update_yaxes(title_text="Layer", row=1, col=1, title_font=dict(size=18),
                    tickmode='linear', tick0=0, dtick=1, showticklabels=True)
    fig.update_yaxes(title_text="Layer", row=1, col=2, title_font=dict(size=18),
                    tickmode='linear', tick0=0, dtick=1, showticklabels=True)

    return fig

#### 1.3.2 Induction Mosaic

In [None]:
enable_plotly_in_cell()
init_notebook_mode(connected=True)

mosaic_model_dd = widgets.Dropdown(options=model_names, description='Model:')
mosaic_rf_dd    = widgets.Dropdown(options=repeat_factors, description='Repetitions:')
mosaic_sl_dd    = widgets.Dropdown(options=sequence_lengths, description='Seq Length:')
mosaic_save_btn = widgets.Button(description='Save Figure')
mosaic_controls = widgets.HBox([mosaic_model_dd, mosaic_rf_dd, mosaic_sl_dd, mosaic_save_btn])

fig_out = widgets.Output()
msg_out = widgets.Output()

def update_plot(change=None):
    fig = create_mosaic_figure(mosaic_rf_dd.value, mosaic_sl_dd.value, mosaic_model_dd.value)
    html_snippet = to_html(fig, include_plotlyjs=False, full_html=False)
    
    fig_out.clear_output(wait=True)
    with fig_out:
        display(HTML(html_snippet))

for w in (mosaic_model_dd, mosaic_rf_dd, mosaic_sl_dd):
    w.observe(update_plot, names='value')

def mosaic_on_save(b):
    msg_out.clear_output()
    fig   = create_mosaic_figure(mosaic_rf_dd.value, mosaic_sl_dd.value, mosaic_model_dd.value)
    short = mosaic_model_dd.value.split('/')[-1]
    base  = f"mosaic_rf{mosaic_rf_dd.value}_sl{mosaic_sl_dd.value}_{short}"
    with msg_out:
        try:
            fig.write_image(f"{base}.png")
            print(f"✓ Saved image: {base}.png")
        except Exception as e:
            print("✗ Image save failed:", e)
        try:
            fig.write_json(f"{base}.json")
            print(f"✓ Saved JSON: {base}.json")
        except Exception as e:
            print("✗ JSON save failed:", e)

mosaic_save_btn.on_click(mosaic_on_save)

display(mosaic_controls, msg_out, fig_out)
update_plot()

There are three important takeaways from the induction mosaics that we will explore more quantitatively:

1. Larger models have more induction heads that are identifiable with the RRT test.
2. Most of the induction heads attend to the token directly following the last instance of the current token rather than the last instance itself.
3. Across all the models, the induction heads that attend to the most recent instance of the current token tend to be in earlier layers, whereas the heads that attend to the token following it tend to be in later layers.

### 1.4 Model Size and Induction Heads Comparison

For the RRT procedure where `sequence_length=10` and `repeat_factor=2`, the number of induction heads in each model is given in the table below:

| Model                | Parameters | Current Token Induction Heads | Token to Right of Current Token Induction Heads | Total Induction Heads |
|----------------------|------------|-------------------------------|------------------------------------------------|-----------------------|
| chronos-t5-mini      | 20M        | 2 (6.2%)                      | 0 (0.0%)                                       | 2 (6.2%)              |
| chronos-t5-small     | 46M        | 3 (6.2%)                      | 8 (16.7%)                                      | 11 (22.9%)            |
| chronos-t5-base      | 200M       | 5 (3.5%)                      | 15 (10.4%)                                     | 20 (13.9%)            |
| chronos-t5-large     | 710M       | 15 (3.9%)                     | 33 (8.6%)                                      | 48 (12.5%)            |
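
The percentages in the table are each head count divided by the model's total number of decoder heads. The sketch below redoes that arithmetic. The head counts are copied from the table, while the totals (32, 48, 144, 384) are an assumption of this sketch: they are inferred from the table's percentages rather than read from the model configs.

```python
# (current-token heads, token-to-right heads), copied from the table above.
counts = {
    "chronos-t5-mini": (2, 0),
    "chronos-t5-small": (3, 8),
    "chronos-t5-base": (5, 15),
    "chronos-t5-large": (15, 33),
}

# Assumed total number of decoder heads per model (inferred, not from the configs).
total_heads = {
    "chronos-t5-mini": 32,
    "chronos-t5-small": 48,
    "chronos-t5-base": 144,
    "chronos-t5-large": 384,
}

def head_proportions(model):
    """Return (current, right, combined) proportions of induction heads for `model`."""
    center, right = counts[model]
    total = total_heads[model]
    return center / total, right / total, (center + right) / total

for model in counts:
    center_p, right_p, total_p = head_proportions(model)
    print(f"{model}: current={center_p:.3f}, right={right_p:.3f}, total={total_p:.3f}")
```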



#### 1.4.1 Helper Functions To Visualize Induction Heads Across Models

In [None]:
enable_plotly_in_cell()

# Function to get induction head counts for specific parameters
def get_induction_head_counts(model_names, all_results, repeat_factor=2, seq_length=10, threshold=0.3):
    results = []

    for model_name in model_names:
        # Extract model results for the specific repeat factor and sequence length
        rfsl_results = all_results[f"rf{repeat_factor}_sl{seq_length}"]
        center_scores = rfsl_results["center_scores"][model_name]
        right_scores = rfsl_results["right_scores"][model_name]

        # Count induction heads above threshold
        center_induction_heads = 0
        right_induction_heads = 0
        total_heads = 0

        layers = len(center_scores)
        heads = len(center_scores[0])

        for layer in range(layers):
            for head in range(heads):
                if center_scores[layer][head] >= threshold:
                    center_induction_heads += 1
                if right_scores[layer][head] >= threshold:
                    right_induction_heads += 1
                total_heads += 1

        # Calculate the proportion of heads in each category
        center_proportion = center_induction_heads / total_heads
        right_proportion = right_induction_heads / total_heads
        total_proportion = (center_induction_heads + right_induction_heads) / total_heads

        results.append((model_name, center_induction_heads, right_induction_heads, total_heads, center_proportion, right_proportion, total_proportion))

    return results


# Define helper functions for plotting induction head data
def prepare_induction_data(induction_stats):
    """Extract and prepare data for plotting from induction statistics."""
    models = [model.replace('amazon/', '') for model, _, _, _, _, _, _ in induction_stats]
    center_induction_heads = [center for _, center, _, _, _, _, _ in induction_stats]
    right_induction_heads = [right for _, _, right, _, _, _, _ in induction_stats]
    center_proportions = [center_prop for _, _, _, _, center_prop, _, _ in induction_stats]
    right_proportions = [right_prop for _, _, _, _, _, right_prop, _ in induction_stats]
    return models, center_induction_heads, right_induction_heads, center_proportions, right_proportions

def create_bar_trace(x, y, name, text_format, color, position='outside', showlegend=True, customdata=None):
    """Create a bar trace with consistent styling."""
    text = [text_format.format(val) for val in y]

    # Create customdata for selective text display
    if customdata is not None:
        show_text = customdata
    else:
        show_text = [True] * len(y)

    # Create final text list with conditional display
    final_text = []
    for i, t in enumerate(text):
        if show_text[i]:
            final_text.append(t)
        else:
            final_text.append("")

    if '%' in text_format:
        hover_format = '<b>%{x}</b><br>' + name + ': %{y:.1%}<extra></extra>'
    else:
        hover_format = '<b>%{x}</b><br>' + name + ': %{y}<extra></extra>'

    return go.Bar(
        x=x, y=y,
        name=name,
        text=final_text,
        textposition=position,
        textfont=dict(size=14, color='rgba(50,50,50,0.8)'),
        marker=dict(
            color=color,
            line=dict(color=color.replace('0.7', '1.0'), width=2),
        ),
        hovertemplate=hover_format,
        showlegend=showlegend
    )

def plot_induction_head_stats(induction_stats, rf=2, sl=10, threshold=0.3):
    """Create and display a bar chart for induction head statistics."""
    import plotly.graph_objects as go
    from plotly.subplots import make_subplots

    models, center_induction_heads, right_induction_heads, center_proportions, right_proportions = prepare_induction_data(induction_stats)

    # Create list to control text display - hide for mini model
    show_right_text = [model != "chronos-t5-mini" for model in models]

    # Create subplots
    fig = make_subplots(
        rows=1, cols=2,
        subplot_titles=('<b>Number of Induction Heads</b>', '<b>Proportion of Induction Heads</b>'),
        horizontal_spacing=0.15
    )

    # Add traces for count plot
    fig.add_trace(
        create_bar_trace(models, center_induction_heads, "Current Token", "{}", 'rgba(0, 123, 255, 0.7)'),
        row=1, col=1
    )

    fig.add_trace(
        create_bar_trace(models, right_induction_heads, "Token to Right", "{}", 'rgba(255, 123, 0, 0.7)',
                        customdata=show_right_text),
        row=1, col=1
    )

    # Add traces for proportion plot
    fig.add_trace(
        create_bar_trace(models, center_proportions, "Current Token", "{:.1%}", 'rgba(0, 123, 255, 0.7)',
                         showlegend=False),
        row=1, col=2
    )

    fig.add_trace(
        create_bar_trace(models, right_proportions, "Token to Right", "{:.1%}", 'rgba(255, 123, 0, 0.7)',
                         showlegend=False, customdata=show_right_text),
        row=1, col=2
    )

    # Update layout
    fig.update_layout(
        template='plotly_white',
        height=600, width=1100,
        barmode='stack',
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="center",
            x=0.5
        ),
        title=dict(
            text=f"<b>Induction Head Statistics</b> (RF={rf}, SL={sl}, threshold={threshold})",
            font=dict(size=20), x=0.5, y=0.95,
        ),
        margin=dict(t=100, b=80, l=80, r=80),
        font=dict(family="Arial, sans-serif", size=13),
        plot_bgcolor='white',
    )

    # Update axes
    fig.update_yaxes(
        title_text="<b>Number of Induction Heads</b>",
        title_font=dict(size=14),
        showgrid=True, gridwidth=1, gridcolor='rgba(220, 220, 220, 0.8)',
        row=1, col=1
    )
    fig.update_yaxes(
        title_text="<b>Proportion of Heads</b>",
        title_font=dict(size=14),
        showgrid=True, gridwidth=1, gridcolor='rgba(220, 220, 220, 0.8)',
        tickformat='.0%',
        row=1, col=2
    )
    fig.update_xaxes(
        tickangle=0,
        title_text="<b>Model</b>",
        title_font=dict(size=14),
        title_standoff=15,
        tickfont=dict(size=12),
    )

    return fig

#### 1.4.2 Induction Heads in Each Model

While the raw number of induction heads increases with model size, the ratio of induction heads to the total number of heads levels off at roughly 13% in the two largest models (13.9% for chronos-t5-base and 12.5% for chronos-t5-large).

Additionally, in every model except chronos-t5-mini, most of the induction heads (between about 69% and 75%) attend to the token directly after the most recent occurrence of the current token (labeled "Token to Right") rather than to the most recent occurrence itself (labeled "Current Token").

In [None]:
enable_plotly_in_cell()

induction_stats = get_induction_head_counts(model_names, all_results, repeat_factor=2, seq_length=10)
plot_induction_head_stats(induction_stats).show()

### 1.5 Relative Layer Position of Induction Heads

In section 1.3, we saw that the induction mosaic had a clear demarcation on the larger models where the induction heads switched from attending to the most recent instance of the current token, to attending to the token to the right of it. So, let's explore that a bit more.

#### 1.5.1 Helper Functions for Visualizations

In [11]:
def get_induction_head_layer_positions(model_names, all_results, repeat_factor=2, seq_length=10, threshold=0.3):
    """Get normalized layer positions of induction heads for each model."""
    results = []

    for model_name in model_names:
        # Extract model results for the specific repeat factor and sequence length
        rfsl_results = all_results[f"rf{repeat_factor}_sl{seq_length}"]
        center_scores = rfsl_results["center_scores"][model_name]
        right_scores = rfsl_results["right_scores"][model_name]

        # Store layer positions as normalized values (0-1)
        center_layer_positions = []
        right_layer_positions = []

        layers = len(center_scores)
        heads = len(center_scores[0])

        for layer in range(layers):
            layer_norm = (layer + 1) / layers  # Normalize to 0-1 (add 1 to make 1-indexed)

            for head in range(heads):
                if center_scores[layer][head] >= threshold:
                    center_layer_positions.append(layer_norm)
                if right_scores[layer][head] >= threshold:
                    right_layer_positions.append(layer_norm)

        # Calculate statistics about these positions
        model_short_name = model_name.split('/')[-1]
        total_layers = layers

        results.append((model_name, model_short_name, total_layers, center_layer_positions, right_layer_positions))

    return results

def plot_induction_head_layer_distribution(model_names, all_results, repeat_factor=2, seq_length=10, threshold=0.3):
    """Create vertical violin plots showing the distribution of induction head layers."""
    # Get layer positions data
    layer_positions_data = get_induction_head_layer_positions(
        model_names, all_results, repeat_factor, seq_length, threshold
    )

    # Extract data for plotting
    model_short_names = [data[1] for data in layer_positions_data]

    # Create figure with proper dimensions
    fig = go.Figure()

    # Track if we've added each type of data for legend purposes
    added_current = False
    added_right = False

    # Add violin plots for Current Token heads (now on negative side)
    for i, model in enumerate(model_short_names):
        center_positions = layer_positions_data[i][3]
        if center_positions:
            # Add scatter points for the actual data points
            fig.add_trace(go.Scatter(
                x=[model] * len(center_positions),
                y=center_positions,
                mode='markers',
                marker=dict(
                    color='rgba(0, 123, 255, 0.7)',
                    size=8,
                    line=dict(width=1, color='rgba(0, 123, 255, 1.0)')
                ),
                name="Current Token",
                showlegend=not added_current,
                hovertemplate='<b>%{x}</b><br>Layer: %{y:.0%}<extra>Current Token</extra>',
                legendgroup='current'
            ))
            added_current = True

            # Add violin plot - now on negative side
            fig.add_trace(go.Violin(
                x=[model] * len(center_positions),
                y=center_positions,
                name="Current Token",
                side='negative',  # Now on negative side
                width=1.0,
                line_color='rgba(0, 123, 255, 1.0)',
                fillcolor='rgba(0, 123, 255, 0.3)',
                points=False,  # Hide points as we're showing them separately
                showlegend=False,
                legendgroup='current',
                hoverinfo='skip',
                meanline_visible=True,  # Keep mean line visible
                box_visible=False,  # Make box invisible
                spanmode='hard',  # Don't extend beyond data points
                orientation='v',
                offsetgroup=0
            ))

    # Add violin plots for Token to Right heads (now on positive side)
    for i, model in enumerate(model_short_names):
        right_positions = layer_positions_data[i][4]
        if right_positions:
            # Add scatter points for the actual data points
            fig.add_trace(go.Scatter(
                x=[model] * len(right_positions),
                y=right_positions,
                mode='markers',
                marker=dict(
                    color='rgba(255, 123, 0, 0.7)',
                    size=8,
                    line=dict(width=1, color='rgba(255, 123, 0, 1.0)')
                ),
                name="Token to Right",
                showlegend=not added_right,
                hovertemplate='<b>%{x}</b><br>Layer: %{y:.0%}<extra>Token to Right</extra>',
                legendgroup='right'
            ))
            added_right = True

            # Add violin plot - now on positive side
            fig.add_trace(go.Violin(
                x=[model] * len(right_positions),
                y=right_positions,
                name="Token to Right",
                side='positive',  # Now on positive side
                width=1.0,
                line_color='rgba(255, 123, 0, 1.0)',
                fillcolor='rgba(255, 123, 0, 0.3)',
                points=False,  # Hide points as we're showing them separately
                showlegend=False,
                legendgroup='right',
                hoverinfo='skip',
                meanline_visible=True,  # Keep mean line visible
                box_visible=False,  # Make box invisible
                spanmode='hard',  # Don't extend beyond data points
                orientation='v',
                offsetgroup=1
            ))

    # Add explicit legend entries if needed
    if not added_current:
        fig.add_trace(go.Scatter(
            x=[None],
            y=[None],
            mode='markers',
            marker=dict(color='rgba(0, 123, 255, 0.7)', size=8),
            name="Current Token",
            showlegend=True
        ))

    if not added_right:
        fig.add_trace(go.Scatter(
            x=[None],
            y=[None],
            mode='markers',
            marker=dict(color='rgba(255, 123, 0, 0.7)', size=8),
            name="Token to Right",
            showlegend=True
        ))

    # Update layout
    fig.update_layout(
        template='plotly_white',
        height=600,
        width=1100,
        violinmode='group',
        violingap=0.2,
        violingroupgap=0.1,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="center",
            x=0.5
        ),
        title=dict(
            text=f"<b>Induction Head Layer Distribution</b> (RF={repeat_factor}, SL={seq_length}, threshold={threshold})",
            font=dict(size=20), x=0.5, y=0.95,
        ),
        margin=dict(t=100, b=80, l=80, r=80),
        font=dict(family="Arial, sans-serif", size=13),
        plot_bgcolor='white',
    )

    # Update y-axis to show percentages
    fig.update_yaxes(
        title_text="<b>Relative Layer Position</b>",
        title_font=dict(size=14),
        showgrid=True,
        gridwidth=1,
        gridcolor='rgba(220, 220, 220, 0.8)',
        tickformat='.0%',
        range=[0, 1.05]  # Ensure 0-100% is shown with a little padding
    )

    # Update x-axis
    fig.update_xaxes(
        title_text="<b>Model</b>",
        title_font=dict(size=14),
        title_standoff=15,
        tickfont=dict(size=12),
    )

    return fig

#### 1.5.2 Violin Plots of the Relative Layer Position of Induction Heads

Here, we plot the distribution of the relative layer positions (close to 0% meaning closer to the first layer and 100% meaning the last layer) of induction heads that attend to the most recent prior occurrence of the current token, as well as induction heads that attend to the token to the right of the most recent prior occurrence of the current token.

Especially in the larger models, there is a clear split: the earlier layers mostly hold heads that attend to copies of the current token, while the later layers mostly hold heads that attend to the token following copies of the current token.

In [None]:
enable_plotly_in_cell()
plot_induction_head_layer_distribution(model_names, all_results)

## 2. Induction with Multisine Series Data

To observe induction in a setting that is closer to the true inference tasks given to the Chronos models, we can study a multisine series.

For some set of $n$ frequencies $\{k_1, k_2, \ldots, k_n\}$ where each $k_i$ is a frequency in $\text{Hz}$, $n$ amplitudes $\{a_1, a_2, \ldots, a_n\}$, and $n$ phase shifts $\{\phi_1, \phi_2, \ldots, \phi_n\}$, a multisine series is simply
$$f(t) = \sum_{i=1}^n a_i \sin(2 \pi k_i t + \phi_i)$$

Since a multisine series is periodic (when every frequency is an integer multiple of the smallest one, as in the series used in Section 2.2, its period is $T_\text{min} = \left(\min\{k_1, k_2, \ldots, k_n\}\right)^{-1}$), we can treat it almost like a sequence of repeated random tokens. Any induction heads we are looking for would then attend with a period roughly equal to $T_\text{min}$ (and correspondingly a frequency of $k_\text{min} = \min\{k_1, k_2, \ldots, k_n\}$). Alternatively, we might see multiple sets of induction heads that attend to the data with different periods, corresponding to the frequencies $k_i$ with the highest amplitudes.
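
As a quick sanity check of the periodicity claim, the following sketch (not part of the original notebook; the helper name `multisine` and the 100 Hz sampling rate are illustrative assumptions) verifies numerically that shifting the series by $1/k_\text{min}$ leaves it unchanged when the frequencies are integer multiples of the smallest one. It also shows that the same shift does not work for a pair such as 2 Hz and 3 Hz, whose sum only repeats every 1 second.

```python
import numpy as np

def multisine(t, freqs_hz, amps, phases=None):
    """Evaluate sum_i a_i * sin(2*pi*k_i*t + phi_i) at the times in t."""
    freqs_hz = np.asarray(freqs_hz, dtype=float)
    amps = np.asarray(amps, dtype=float)
    phases = np.zeros_like(freqs_hz) if phases is None else np.asarray(phases, dtype=float)
    angles = 2 * np.pi * freqs_hz[:, None] * t[None, :] + phases[:, None]
    return (amps[:, None] * np.sin(angles)).sum(axis=0)

t = np.arange(300) / 100.0  # 3 seconds sampled at an assumed 100 Hz

# Harmonically related frequencies (same values as the series in Section 2.2)
freqs, amps = [5, 10, 20, 40], [1, 0.75, 0.5, 0.5]
shift = 1.0 / min(freqs)
harmonic_ok = np.allclose(multisine(t, freqs, amps), multisine(t + shift, freqs, amps))

# Counterexample: 2 Hz and 3 Hz repeat every 1 s, not every 1/2 s
freqs2, amps2 = [2, 3], [1, 1]
naive_ok = np.allclose(multisine(t, freqs2, amps2), multisine(t + 0.5, freqs2, amps2))
true_ok = np.allclose(multisine(t, freqs2, amps2), multisine(t + 1.0, freqs2, amps2))

print(harmonic_ok, naive_ok, true_ok)
```

The check relies only on the definition of $f(t)$ above, so it applies to any choice of amplitudes and phase shifts.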

An easy way to check this is to plot the attention scores for some multisine data, compute the Fourier transforms of both the data and the attention scores, and see whether the attention scores' spectra spike near the peaks in the Fourier transform of the original data.
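
To make the data side of this comparison concrete, here is a minimal sketch, assuming a 3-second window with 300 samples (an illustrative choice, so that every component completes a whole number of cycles). The helper name `dominant_frequencies` and the relative threshold are hypothetical and not part of the notebook's helper functions.

```python
import numpy as np

def dominant_frequencies(freqs_hz, amps, t_1=3.0, steps=300, rel_threshold=0.25):
    """Return the frequencies (Hz) whose FFT magnitude exceeds rel_threshold * max."""
    t = np.linspace(0.0, t_1, steps, endpoint=False)

    # Build the multisine series (phase shifts set to zero for simplicity)
    y = np.zeros(steps)
    for k, a in zip(freqs_hz, amps):
        y += a * np.sin(2 * np.pi * k * t)

    # One-sided FFT magnitude and the matching frequency axis in Hz
    magnitude = np.abs(np.fft.rfft(y))
    freq_axis = np.fft.rfftfreq(steps, d=t_1 / steps)

    return freq_axis[magnitude > rel_threshold * magnitude.max()]

peaks = dominant_frequencies([5, 10, 20, 40], [1, 0.75, 0.5, 0.5])
print(peaks)
```

Applying the same FFT step to a single head's attention scores gives the spectrum to compare against these peaks.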

### 2.1 Helper Functions

#### 2.1.1 Data Generation Helper Functions

In [13]:
def sinusoidal(k, t_1=1, steps=100, amp=1, phase_shift=0):
    """
    Generate a sinusoidal function with frequency k.

    Args:
        k: Frequency of the sinusoid in Hz or array of frequencies
        t_1: Total time period
        steps: Number of time steps
        amp: Amplitude of the sinusoid or array of amplitudes
        phase_shift: Phase shift in radians

    Returns:
        Tuple (t, y) of the time steps and the summed sinusoidal values
    """
    t = np.linspace(0, t_1, steps, endpoint=False)

    # Convert k to numpy array if it isn't already
    k_array = np.atleast_1d(k)

    # Handle amp: if k is array and amp is scalar, use the same amp for all frequencies
    if isinstance(amp, (int, float)) and len(k_array) > 1:
        amp_array = np.full_like(k_array, amp, dtype=float)
    else:
        amp_array = np.atleast_1d(amp)
        # Check that k and amp have the same length if amp is also an array
        if len(k_array) != len(amp_array):
            raise ValueError("k and amp must have the same length when passed as arrays")

    # Initialize y with zeros
    y = np.zeros(steps)

    # Sum the sinusoidal functions
    for k_i, amp_i in zip(k_array, amp_array):
        y += amp_i * np.sin(2 * np.pi * k_i * t + phase_shift)

    return t, y

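def tokenize_series(series, tokenizer, num_decoder=0):
    """
    Tokenize a batch of series into encoder and decoder input ids.

    If num_decoder > 0, the num_decoder tokens just before the final encoder
    token are moved to the decoder input, and the shortened encoder input is
    terminated with token id 1 (assumed here to be the EOS id).
    """
    assert len(series.shape) == 2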
    enc_ids, attention_mask, scale = tokenizer.context_input_transform(series)
    dec_ids = torch.zeros((enc_ids.shape[0],1), dtype=int)

    if num_decoder != 0:
        dec_ids = torch.cat([dec_ids, enc_ids[:,-(num_decoder+1):-1]], dim=-1)
        enc_ids = torch.cat([enc_ids[:,:-(num_decoder+1)], torch.ones((enc_ids.shape[0],1), dtype=int)], dim=-1)

    # Attend to every encoder position (this replaces the mask returned by the tokenizer)
    attention_mask = torch.ones_like(enc_ids)

    return enc_ids, dec_ids, attention_mask, scale

def get_model_outputs(model_name, series):
    """Load model and generate predictions with attention outputs"""
    pipeline = BaseChronosPipeline.from_pretrained(
        model_name,
        device_map="cpu",
        torch_dtype=torch.bfloat16,
    )
    tokenizer, t5_model = pipeline.tokenizer, pipeline.model.model

    enc_ids, dec_ids, attention_mask, scale = tokenize_series(series, tokenizer, num_decoder=0)

    outputs = t5_model.generate(input_ids=enc_ids,
                               attention_mask=attention_mask,
                               max_new_tokens=1,
                               decoder_input_ids=dec_ids,
                               num_return_sequences=1,
                               do_sample=False,
                               use_cache=False,
                               output_attentions=True,
                               output_scores=True,
                               output_hidden_states=True,
                               return_dict_in_generate=True)

    preds = tokenizer.output_transform(outputs.sequences[...,1:], scale)

    # Extract attention scores
    num_layers, num_heads = t5_model.config.num_decoder_layers, t5_model.config.num_heads
    target_token = 0

    attns = {}
    for layer in range(num_layers):
        attns[layer] = outputs.cross_attentions[0][layer][0, :, target_token, :-1]

    return preds, attns, num_layers, num_heads

#### 2.1.2 Data Visualization Helper Functions

In [14]:
def process_attention_data(attns, num_layers, num_heads, colors):
    """Process attention data and prepare for plotting"""
    attn_traces = []
    color_idx = 0

    for layer in range(num_layers):
        for head in range(num_heads):
            attn_scores = attns[layer][head].cpu().float().numpy()
            color = colors[color_idx % len(colors)]

            attn_traces.append({
                "layer": layer,
                "head": head,
                "scores": attn_scores,
                "color": color
            })

            color_idx += 1

    return attn_traces

def add_attention_traces(fig, attn_traces):
    """Add attention score traces to the figure"""
    for attn_data in attn_traces:
        layer, head = attn_data["layer"], attn_data["head"]
        attn_scores = attn_data["scores"]
        color = attn_data["color"]

        # Create trace for subplot 1 - attention scores
        trace = go.Scatter(
            x=np.arange(len(attn_scores)),
            y=attn_scores,
            mode='lines',
            line=dict(color=color),
            opacity=0.3,
            showlegend=False,
            hovertemplate=f"(l={layer+1},h={head+1})<br>Position: %{{x}}<br>Score: %{{y:.3f}}<extra></extra>"
        )
        fig.add_trace(trace, row=1, col=1, secondary_y=True)

def add_fft_traces(fig, attn_traces, series):
    """Add FFT traces for attention scores and compute max FFT magnitude"""
    max_attn_fft = 0

    for attn_data in attn_traces:
        layer, head = attn_data["layer"], attn_data["head"]
        attn_scores = attn_data["scores"]
        color = attn_data["color"]

        # Perform FFT on attention scores
        attn_fft = np.fft.fft(attn_scores)
        attn_fft_magnitude = np.abs(attn_fft)

        # Get frequencies in Hz; the sample spacing d assumes the series spans 3 seconds (t_1=3 in Section 2.2)
        attn_frequencies = np.fft.fftfreq(len(attn_scores), d=3/series.shape[1])

        # Add to FFT subplot with secondary y-axis
        fig.add_trace(go.Scatter(
            x=attn_frequencies[:len(attn_frequencies)//2],
            y=attn_fft_magnitude[:len(attn_fft_magnitude)//2],
            mode='lines',
            line=dict(color=color),
            opacity=0.3,
            showlegend=False,
            hovertemplate=f"(l={layer+1},h={head+1})<br>Frequency: %{{x:.3f}}<br>Magnitude: %{{y:.3f}}<extra></extra>"
        ), row=2, col=1, secondary_y=True)

        max_attn_fft = max(max_attn_fft, np.max(attn_fft_magnitude[:len(attn_fft_magnitude)//2]))

    return max_attn_fft

def add_original_and_prediction_traces(fig, series, preds):
    """Add original series and prediction traces"""
    original_color = 'royalblue'
    prediction_color = 'firebrick'

    # Add original series
    fig.add_trace(go.Scatter(
        x=np.arange(len(series[0])),
        y=series[0].numpy(),
        mode='lines',
        name='Original Series',
        line=dict(color=original_color, width=2),
        hoverinfo='y'
    ), row=1, col=1, secondary_y=False)

    # Add predictions
    fig.add_trace(go.Scatter(
        x=np.arange(len(series[0])-1, len(series[0]) + len(preds[0,0])),
        y=np.concatenate([[series[0].numpy()[-1]], preds[0,0].numpy()]),
        mode='lines',
        name='Prediction',
        line=dict(color=prediction_color, width=3),
        hoverinfo='y'
    ), row=1, col=1, secondary_y=False)

    return original_color

def add_original_fft_trace(fig, series, original_color):
    """Add FFT of the original series"""
    # Perform FFT on the original series
    fft_result = torch.fft.fft(series)
    fft_magnitude = torch.abs(fft_result)

    # Convert to numpy for plotting; the sample spacing d assumes the series spans 3 seconds (t_1=3 in Section 2.2)
    frequencies = np.fft.fftfreq(len(series[0]), d=3/series.shape[1])
    fft_magnitude_np = fft_magnitude[0].numpy()

    # Add the original series FFT
    fig.add_trace(go.Scatter(
        x=frequencies[:len(frequencies)//2],
        y=fft_magnitude_np[:len(fft_magnitude_np)//2],
        mode='lines',
        name='Original Series FFT',
        line=dict(width=3, color=original_color),
        hoverinfo='y'
    ), row=2, col=1, secondary_y=False)

def create_fft_figure(model_name, series):
    """Create the 2x1 grid visualization for a model"""
    # Get model outputs
    preds, attns, num_layers, num_heads = get_model_outputs(model_name, series)

    # Create 2x1 subplot figure
    fig = make_subplots(
        rows=2,
        cols=1,
        subplot_titles=["Data, Prediction, and Attention Scores", "FFT of Data and Attention Scores"],
        vertical_spacing=0.22,
        specs=[[{"secondary_y": True}], [{"secondary_y": True}]]
    )

    # Colors for consistent coloring across plots
    colors = px.colors.qualitative.Plotly

    # Process attention data
    attn_traces = process_attention_data(attns, num_layers, num_heads, colors)

    # Add attention score traces to first subplot
    add_attention_traces(fig, attn_traces)

    # Add FFT traces and get max FFT magnitude
    max_attn_fft = add_fft_traces(fig, attn_traces, series)

    # Add original series and prediction traces
    original_color = add_original_and_prediction_traces(fig, series, preds)

    # Add original series FFT
    add_original_fft_trace(fig, series, original_color)

    # Update layout for the figure
    fig.update_layout(
        title=f"Model: {model_name}",
        height=900,
        width=1000,
        template="plotly_white",
        showlegend=True,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        )
    )

    # Set x-axis range for first plot from 0 to last prediction point
    last_x_value = len(series[0]) + len(preds[0,0]) - 1
    fig.update_xaxes(title_text="Token Position", row=1, col=1, range=[0, last_x_value])
    fig.update_xaxes(title_text="Frequency", row=2, col=1)

    # Update y-axis titles
    fig.update_yaxes(title_text="Amplitude", secondary_y=False, row=1, col=1)
    fig.update_yaxes(title_text="Attention Score", secondary_y=True, row=1, col=1, range=[0, 1])

    fig.update_yaxes(title_text="FFT Magnitude (Series)", secondary_y=False, row=2, col=1)
    fig.update_yaxes(title_text="FFT Magnitude (Attention)", secondary_y=True, row=2, col=1, range=[0, max_attn_fft * 1.1])

    return fig


### 2.2 Attention Scores Throughout a Multisine Series

Note: switching between models is quite slow, so you may need to wait a minute or two for the plot to reflect a change in the dropdown menu.

In [None]:
enable_plotly_in_cell()
init_notebook_mode(connected=True)

# Define the model names
model_names = ["amazon/chronos-t5-mini", "amazon/chronos-t5-small", "amazon/chronos-t5-base", "amazon/chronos-t5-large"]

# Generate the sinusoidal data
_, series = sinusoidal(k=[5,10,20,40], amp=[1,0.75,0.5,0.5], t_1=3, steps=301)
series = torch.tensor(series).unsqueeze(0)

fft_model_dd = widgets.Dropdown(
    options=model_names,
    description='Model:',
    layout=widgets.Layout(width='50%')
)

# One Output area for the FFT plot
fft_out = widgets.Output()

def update_fft_plot(change=None):
    # build FFT figure
    fig = create_fft_figure(fft_model_dd.value, series)
    # extract only the <div>+<script> snippet
    html_snip = to_html(fig, include_plotlyjs=False, full_html=False)
    
    # clear old, display new
    fft_out.clear_output(wait=True)
    with fft_out:
        display(HTML(html_snip))

fft_model_dd.observe(update_fft_plot, names='value')

# Trigger initial draw
display(fft_model_dd, fft_out)
update_fft_plot()

#### 2.2.1 Interpreting the FFT of the Attention Scores

As expected, we do see quite a few heads spike up around the frequencies of the original data. However, there is also something even more interesting at play here with the periodicity of the FFT of the attention scores.

To be more specific, notice how the period of the FFT of many of the attention heads coincides with the smallest difference between the frequencies of the original data. For example, if the original data has frequencies `[5,10,20,40]`, then the FFT of the attention scores of many heads will be periodic with a period of `5 Hz`, since the two closest frequencies, `5 Hz` and `10 Hz`, differ by `5 Hz`. This pattern holds across all the models and many different multisine series.
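
To build intuition for what a periodic spectrum looks like, the sketch below uses purely synthetic data; it is an illustration under stated assumptions, not a claim about what the Chronos heads compute. It constructs a hypothetical "attention-like" signal that spikes once every 0.2 seconds (the period of a `5 Hz` component) over a 3-second window with 300 samples, and checks that the peaks of its FFT are spaced `5 Hz` apart. The function name `comb_spectrum_peaks` and all parameter values are assumptions made for this example.

```python
import numpy as np

def comb_spectrum_peaks(period_s=0.2, t_1=3.0, steps=300):
    """Return the frequencies (Hz) at which the FFT of a spike train peaks."""
    sample_spacing = t_1 / steps
    samples_per_period = int(round(period_s / sample_spacing))

    # Synthetic signal: one unit spike at the start of every period
    comb = np.zeros(steps)
    comb[::samples_per_period] = 1.0

    # One-sided FFT magnitude and the matching frequency axis in Hz
    magnitude = np.abs(np.fft.rfft(comb))
    freqs = np.fft.rfftfreq(steps, d=sample_spacing)

    # Keep only the bins that carry a substantial share of the maximum
    return freqs[magnitude > 0.5 * magnitude.max()]

peak_freqs = comb_spectrum_peaks()
peak_spacing = np.diff(peak_freqs)
print(peak_freqs)
print(peak_spacing)
```

A signal that repeats every $T$ seconds has spectral peaks spaced $1/T$ apart, so a `5 Hz` spacing in the FFT is what one would expect from attention that recurs every 0.2 seconds.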