In [1]:
import torch
from chronos import BaseChronosPipeline
import plotly.express as px
import numpy as np
import pandas as pd

import rrt_utils as rrt
import attn_lens as attn

[2025-05-11 10:26:29,966] [INFO] [real_accelerator.py:239:get_accelerator] Setting ds_accelerator to cpu (auto detect)


In [2]:
# Token-based Chronos (T5) checkpoint to inspect. The rest of this notebook
# uses `pipeline.tokenizer` and `pipeline.model.model`, so it assumes the
# tokenizer-based T5 pipeline; a Chronos-Bolt checkpoint is not a drop-in swap here.
pipeline = BaseChronosPipeline.from_pretrained(
    "amazon/chronos-t5-small",
    device_map="cpu",  # run on CPU; use e.g. "cuda" to run on a GPU instead
    torch_dtype=torch.bfloat16,
)

# AirPassengers dataset; the "#Passengers" column is the univariate series.
AIR_PASSENGERS_URL = (
    "https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/"
    "master/data/AirPassengers.csv"
)
df = pd.read_csv(AIR_PASSENGERS_URL)
# Shape (1, T): a single series presented as a batch of one. The dtype is made
# explicit (float32) instead of inheriting the CSV column's inferred dtype.
series = torch.tensor(df["#Passengers"].to_numpy(), dtype=torch.float32).unsqueeze(0)

tokenizer, t5_model = pipeline.tokenizer, pipeline.model.model
# Project helper: presumably returns encoder token ids and their attention mask
# for `series` — TODO confirm against attn_lens.tokenize_series.
input_ids, attention_mask = attn.tokenize_series(series, tokenizer)

In [4]:
# Run the project's attention-lens helper for a single generation step.
# NOTE(review): the plotting cell indexes the second return value as
# [step][layer][batch][head, decoder_pos, encoder_pos] — confirm against attn_lens.
outputs, cross_attention_probs = attn.attn_lens(
    input_ids,
    attention_mask,
    t5_model,
    max_new_tokens=1,
)

In [7]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Create figure with 2x3 grid of subplots, each with secondary y-axis
# Decoder layers / heads to plot. These values match chronos-t5-small.
# NOTE(review): hard-coded rather than read from t5_model's config; update them
# (and head_colors) if a different checkpoint is loaded.
n_layers, n_heads = 6, 8
n_rows, n_cols = 2, 3
assert n_rows * n_cols >= n_layers, "subplot grid too small for n_layers"

# One subplot per decoder layer, each with a secondary y-axis so the raw
# series (left axis) and the attention weights (right axis) share an x-axis.
fig = make_subplots(
    rows=n_rows, cols=n_cols,
    specs=[[{"secondary_y": True} for _ in range(n_cols)] for _ in range(n_rows)],
    subplot_titles=[f"Layer {i}" for i in range(n_layers)],
    horizontal_spacing=0.03,  # tight horizontal gap between subplots
    vertical_spacing=0.1,     # vertical gap between subplot rows
)

t_idx = 0  # decoder position whose cross-attention is plotted

# Consistent colors from Plotly's default palette: blue for the input series,
# and one fixed color per attention head across all subplots.
context_color = '#1f77b4'
head_colors = ['#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
               '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22']

for layer_idx in range(n_layers):
    # make_subplots addresses cells with 1-based (row, col).
    row = layer_idx // n_cols + 1
    col = layer_idx % n_cols + 1
    # Only the first subplot contributes legend entries. legendgroup ties the
    # other subplots' traces to those entries, so clicking a legend item
    # toggles the matching trace in every layer, not just in Layer 0.
    is_first = layer_idx == 0

    fig.add_trace(
        go.Scatter(
            y=df["#Passengers"],
            mode='lines',
            name='Context' if is_first else f'Context_{layer_idx}',
            legendgroup='context',
            showlegend=is_first,
            line=dict(color=context_color, width=2),
        ),
        row=row, col=col,
        secondary_y=False,
    )

    # Attention from decoder position t_idx over the encoder inputs, for all
    # heads of this layer at once.
    # NOTE(review): assumes cross_attention_probs[0][layer] is indexed as
    # [batch, head, decoder_pos, encoder_pos] — confirm in attn_lens.
    # .float() converts before .numpy(); the tensors are presumably bfloat16
    # (the pipeline is loaded with torch_dtype=torch.bfloat16), which NumPy
    # cannot represent.
    # NOTE(review): if the tokenized input carries an extra special token
    # (e.g. EOS), each attention curve is one point longer than the Context series.
    layer_probs = cross_attention_probs[0][layer_idx][0][:, t_idx, :].float().cpu().numpy()

    for head_idx in range(n_heads):
        fig.add_trace(
            go.Scatter(
                y=layer_probs[head_idx],
                mode='lines',
                name=f'Head {head_idx}' if is_first else f'Head_{head_idx}_L{layer_idx}',
                legendgroup=f'head{head_idx}',
                showlegend=is_first,
                line=dict(color=head_colors[head_idx % len(head_colors)], width=1.5),
            ),
            row=row, col=col,
            secondary_y=True,
        )

# Update layout
# Figure-wide settings: size, centered title, vertical legend, tight margins.
fig.update_layout(
    height=800,
    width=1800,
    title=dict(
        text="Passenger Count vs Cross Attention Probabilities Across Layers",
        font=dict(size=24),
        x=0.5,
        xref="paper",
        y=0.965,
        yanchor="top",
    ),
    showlegend=True,
    legend=dict(orientation="v", x=1.03, xanchor="right", y=1, yanchor="top"),
    margin=dict(l=50, r=50, t=80, b=50),  # small margins leave more room for plots
)

# Per-subplot axes in the 2x3 grid:
#   - x title "Time" only on the bottom row; rotated, small ticks everywhere;
#   - primary y title only in the left column;
#   - secondary y title only in the right column; every secondary y-axis
#     (attention probabilities) is pinned to [0, 1].
for row in (1, 2):
    x_title = "Time" if row == 2 else ""
    for col in (1, 2, 3):
        fig.update_xaxes(
            title_text=x_title, tickangle=45, tickfont=dict(size=10),
            row=row, col=col,
        )
        fig.update_yaxes(
            title_text="Passenger Count" if col == 1 else "",
            secondary_y=False, row=row, col=col,
        )
        fig.update_yaxes(
            title_text="Cross Attention Probabilities" if col == 3 else "",
            secondary_y=True, range=[0, 1], row=row, col=col,
        )

fig.show()