In [18]:
import torch
from chronos import BaseChronosPipeline

import numpy as np
import pandas as pd

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

import pickle

In [2]:
def sinusoidal(k, t_1=1, steps=100, amp=1, phase_shift=0):
    """
    Generate a sum of sinusoids sampled on a uniform time grid.

    Args:
        k: Frequency of the sinusoid in Hz or array of frequencies
        t_1: Total time period; samples cover [0, t_1) (the endpoint is excluded)
        steps: Number of time steps
        amp: Amplitude of the sinusoid or array of amplitudes. Any scalar
            (Python number, NumPy scalar or 0-d array) is applied to every
            frequency in k.
        phase_shift: Phase shift in radians, added to every component

    Returns:
        Tuple (t, y): t holds the sample times and y the summed sinusoids;
        both are arrays of length `steps`.

    Raises:
        ValueError: If amp is array-like and its length differs from k's.
    """
    t = np.linspace(0, t_1, steps, endpoint=False)

    # Treat k uniformly as an array so scalar and multi-frequency calls share one path
    k_array = np.atleast_1d(k)

    # np.ndim recognises NumPy scalars (np.float32, np.int64, ...) as well as
    # Python numbers; an isinstance(amp, (int, float)) check would miss those
    # and wrongly raise the length-mismatch error below.
    if np.ndim(amp) == 0:
        amp_array = np.full(k_array.shape, amp, dtype=float)
    else:
        amp_array = np.atleast_1d(amp)
        # Array amplitudes must pair up one-to-one with the frequencies
        if len(k_array) != len(amp_array):
            raise ValueError("k and amp must have the same length when passed as arrays")

    # Accumulate the components one at a time into a float buffer
    y = np.zeros(steps)
    for k_i, amp_i in zip(k_array, amp_array):
        y += amp_i * np.sin(2 * np.pi * k_i * t + phase_shift)

    return t, y

def tokenize_series(series, tokenizer, num_decoder=0):
    """Tokenize a 2-D batch of series and split it into encoder and decoder ids.

    Args:
        series: Batch with exactly two dimensions (asserted).
        tokenizer: Object exposing ``context_input_transform(series)`` that
            returns ``(token_ids, attention_mask, scale)``.
        num_decoder: How many tokens to move from the tail of the encoder
            input into the decoder input. 0 leaves the encoder ids untouched.

    Returns:
        Tuple ``(enc_ids, dec_ids, attention_mask, scale)``. The returned
        attention mask is all ones with the shape of ``enc_ids``; the mask
        produced by the tokenizer is discarded.
    """
    assert len(series.shape) == 2
    token_ids, _, scale = tokenizer.context_input_transform(series)
    batch_size = token_ids.shape[0]

    # Every decoder sequence begins with id 0
    # (presumably the decoder start / pad token of the T5 model -- TODO confirm).
    start_column = torch.zeros((batch_size, 1), dtype=int)

    if num_decoder == 0:
        enc_ids, dec_ids = token_ids, start_column
    else:
        # The final token is dropped from both halves; the `num_decoder`
        # tokens before it move to the decoder side.
        split_at = -(num_decoder + 1)
        dec_ids = torch.cat([start_column, token_ids[:, split_at:-1]], dim=-1)
        # Close the shortened encoder input with id 1
        # (presumably the EOS token -- TODO confirm).
        end_column = torch.ones((batch_size, 1), dtype=int)
        enc_ids = torch.cat([token_ids[:, :split_at], end_column], dim=-1)

    attention_mask = torch.ones_like(enc_ids)

    return enc_ids, dec_ids, attention_mask, scale

def get_model_outputs(model_name, series):
    """Load a Chronos model, generate one token and collect cross-attention.

    Args:
        model_name: Identifier handed to ``BaseChronosPipeline.from_pretrained``.
        series: 2-D batch of series, forwarded to ``tokenize_series``.

    Returns:
        Tuple ``(preds, attns, num_layers, num_heads)`` where ``preds`` is the
        de-tokenized prediction, ``attns`` maps each decoder layer index to
        that layer's cross-attention slice for the first batch element, and
        the two counts come from the model config.
    """
    # The model is loaded on every call, on CPU and in bfloat16.
    pipeline = BaseChronosPipeline.from_pretrained(
        model_name,
        device_map="cpu",
        torch_dtype=torch.bfloat16,
    )
    tokenizer = pipeline.tokenizer
    t5_model = pipeline.model.model

    enc_ids, dec_ids, attention_mask, scale = tokenize_series(series, tokenizer, num_decoder=0)

    # Greedy, single-step generation that also returns attentions, scores
    # and hidden states.
    generation_kwargs = dict(
        input_ids=enc_ids,
        attention_mask=attention_mask,
        max_new_tokens=1,
        decoder_input_ids=dec_ids,
        num_return_sequences=1,
        do_sample=False,
        use_cache=False,
        output_attentions=True,
        output_scores=True,
        output_hidden_states=True,
        return_dict_in_generate=True,
    )
    outputs = t5_model.generate(**generation_kwargs)

    # Skip the first position of each sequence (the decoder start id supplied
    # as input) before mapping tokens back to values.
    preds = tokenizer.output_transform(outputs.sequences[..., 1:], scale)

    model_config = t5_model.config
    num_layers = model_config.num_decoder_layers
    num_heads = model_config.num_heads

    # NOTE(review): cross_attentions is presumably indexed
    # [generation step][layer] -> (batch, heads, target pos, source pos),
    # as in Hugging Face generate outputs -- confirm. The slice keeps batch
    # element 0, all heads, target position 0, and drops the last source
    # position (presumably the EOS token).
    target_token = 0
    attns = {
        layer: outputs.cross_attentions[0][layer][0, :, target_token, :-1]
        for layer in range(num_layers)
    }

    return preds, attns, num_layers, num_heads

In [61]:
# Load the precomputed per-head scores that get_attention_score reads.
# Layout, as indexed by get_attention_score:
#   attn_scores[rrt_config]['center_scores' | 'right_scores'][model][layer][head]
# NOTE(review): pickle.load can execute arbitrary code while unpickling --
# only load this file if it comes from a trusted source.
with open('variables/mosaic_scores.pkl', 'rb') as file:
    attn_scores = pickle.load(file)

def get_attention_score(layer, head, model, rrt_config='rf2_sl10'):
    """Look up the precomputed score of one attention head, capped at 0.3.

    Args:
        layer: Layer index into the score table.
        head: Head index within that layer.
        model: Model name, with or without the leading 'amazon/' prefix.
        rrt_config: Top-level key of the module-level ``attn_scores`` table.

    Returns:
        The larger of the head's 'center_scores' and 'right_scores' entries,
        clipped from above at 0.3.

    Raises:
        KeyError / IndexError: Propagated from the table lookup when the
        configuration, model, layer or head is not present.
    """
    # Check for the namespace prefix itself: a plain substring test
    # ('amazon' in model) would also match names that merely contain the
    # word somewhere else and would then skip the prefix they still need.
    if not model.startswith('amazon/'):
        model = 'amazon/' + model

    config_scores = attn_scores[rrt_config]
    center_score = config_scores['center_scores'][model][layer][head]
    right_score = config_scores['right_scores'][model][layer][head]
    return min(max(center_score, right_score), 0.3)

def process_attention_data(attns, num_layers, num_heads, model_name):
    """Flatten per-layer attention into one plotting record per (layer, head).

    Args:
        attns: Mapping from layer index to a tensor whose first axis is the head.
        num_layers: Number of layers to visit.
        num_heads: Number of heads to visit in each layer.
        model_name: Model identifier; only the part after the last '/' is
            passed on to ``get_attention_score``.

    Returns:
        List of dicts with keys "layer", "head", "scores" (the head's
        attention as a float NumPy array) and "attention_score", ordered by
        layer and then by head.
    """
    return [
        {
            "layer": layer_idx,
            "head": head_idx,
            # bfloat16 has no NumPy equivalent, hence the cast to float first
            "scores": attns[layer_idx][head_idx].cpu().float().numpy(),
            "attention_score": get_attention_score(layer_idx, head_idx, model_name.split('/')[-1]),
        }
        for layer_idx in range(num_layers)
        for head_idx in range(num_heads)
    ]

def create_fft_figure(model_name, series, attn_colorscale=px.colors.sequential.Bluered, duration=3):
    """Create the 3x1 grid visualization for a model.

    Row 1 overlays the input series, the model's prediction and every head's
    attention trace over token positions. Row 2 shows the FFT magnitude of
    the series and of each attention trace. Row 3 shows the FFT magnitude of
    each attention spectrum ("FFT of FFT").

    Args:
        model_name: Model identifier passed to ``get_model_outputs``.
        series: 2-D torch tensor; only the first row (``series[0]``) is plotted.
        attn_colorscale: Sequence of colours. Each attention trace picks one
            entry from it according to its attention score, and the colour
            bar is drawn with the same sequence.
        duration: Total time spanned by the series. It sets the FFT sample
            spacing ``duration / series.shape[1]`` used for the frequency
            axes of row 2. The default of 3 is the value that was previously
            hard-coded.

    Returns:
        The plotly figure.
    """
    # Upper end of the colour range; get_attention_score clips scores to the same value.
    score_cap = 0.3
    attention_opacity = 0.2
    original_color = 'royalblue'
    prediction_color = 'gray'

    def positive_half(values):
        """Keep the first half of an FFT-ordered array (the non-negative frequencies)."""
        return values[:len(values) // 2]

    def get_color(score):
        """Pick the colour-scale entry for a score in [0, score_cap]."""
        # Clamp first: a negative score would otherwise produce a negative
        # index and silently wrap around to the far end of the scale.
        fraction = min(max(score / score_cap, 0.0), 1.0)
        return attn_colorscale[int(fraction * (len(attn_colorscale) - 1))]

    # Get model outputs and one record per (layer, head)
    preds, attns, num_layers, num_heads = get_model_outputs(model_name, series)
    attn_traces = process_attention_data(attns, num_layers, num_heads, model_name)

    # Create 3x1 subplot figure; rows 1 and 2 carry a secondary y-axis for attention
    fig = make_subplots(
        rows=3, cols=1,
        subplot_titles=["Data, Prediction, and Attention Scores",
                        "FFT of Data and Attention Scores",
                        "FFT of FFT of Attention Scores"],
        vertical_spacing=0.15,
        specs=[[{"secondary_y": True}], [{"secondary_y": True}], [{}]]
    )

    series_values = series[0].numpy()
    pred_values = preds[0, 0].numpy()
    n_samples = len(series[0])
    sample_spacing = duration / series.shape[1]

    # Add original series
    fig.add_trace(go.Scatter(
        x=np.arange(n_samples),
        y=series_values,
        mode='lines',
        name='Original Series',
        line=dict(color=original_color, width=2),
        hoverinfo='y'
    ), row=1, col=1, secondary_y=False)

    # Add predictions, starting from the last observed point so the line is continuous
    fig.add_trace(go.Scatter(
        x=np.arange(n_samples - 1, n_samples + len(pred_values)),
        y=np.concatenate([[series_values[-1]], pred_values]),
        mode='lines',
        name='Prediction',
        line=dict(color=prediction_color, width=3),
        hoverinfo='y'
    ), row=1, col=1, secondary_y=False)

    # Add original series FFT
    series_fft_magnitude = torch.abs(torch.fft.fft(series))[0].numpy()
    series_frequencies = np.fft.fftfreq(n_samples, d=sample_spacing)

    fig.add_trace(go.Scatter(
        x=positive_half(series_frequencies),
        y=positive_half(series_fft_magnitude),
        mode='lines',
        name='Original Series FFT',
        line=dict(width=3, color=original_color),
        hoverinfo='y'
    ), row=2, col=1, secondary_y=False)

    # Add attention traces (rows 1 and 2) and remember each spectrum for row 3
    max_attn_fft = 0
    fft_results = []

    for attn_data in attn_traces:
        layer, head = attn_data["layer"], attn_data["head"]
        # Named head_attention so it does not shadow the module-level attn_scores table
        head_attention = attn_data["scores"]
        score = attn_data["attention_score"]
        color = get_color(score)
        # One 1-based label reused in all three rows, so the same head reads
        # the same everywhere (row 1 used to be 0-based, rows 2 and 3 1-based).
        head_label = f"(l={layer+1},h={head+1})"

        # Add attention scores to first subplot
        fig.add_trace(go.Scatter(
            x=np.arange(len(head_attention)),
            y=head_attention,
            mode='lines',
            line=dict(color=color, width=1.5),
            opacity=attention_opacity,
            showlegend=False,
            hovertemplate=f"{head_label}<br>Position: %{{x}}<br>Score: %{{y:.3f}}<br>Attention Score: {score:.3f}<extra></extra>"
        ), row=1, col=1, secondary_y=True)

        # Compute FFT for attention scores
        attn_fft_magnitude = np.abs(np.fft.fft(head_attention))
        attn_frequencies = np.fft.fftfreq(len(head_attention), d=sample_spacing)

        # Store for FFT of FFT
        fft_results.append({
            "label": head_label,
            "color": color,
            "fft_magnitude": attn_fft_magnitude,
            "attention_score": score
        })

        # Add FFT to second subplot
        fig.add_trace(go.Scatter(
            x=positive_half(attn_frequencies),
            y=positive_half(attn_fft_magnitude),
            mode='lines',
            line=dict(color=color, width=1.5),
            opacity=attention_opacity,
            showlegend=False,
            hovertemplate=f"{head_label}<br>Frequency: %{{x:.3f}}<br>Magnitude: %{{y:.3f}}<br>Attention Score: {score:.3f}<extra></extra>"
        ), row=2, col=1, secondary_y=True)

        max_attn_fft = max(max_attn_fft, np.max(positive_half(attn_fft_magnitude)))

    # Add FFT of FFT traces (kept as a second pass so the trace order in
    # fig.data stays: rows 1/2 for every head first, then row 3)
    max_fft_of_fft = 0
    for fft_data in fft_results:
        head_label = fft_data["label"]
        color = fft_data["color"]
        score = fft_data["attention_score"]
        half_spectrum = positive_half(fft_data["fft_magnitude"])

        # Compute FFT of FFT; frequencies use the default unit sample spacing
        fft_of_fft_magnitude = np.abs(np.fft.fft(half_spectrum))
        fft_of_fft_frequencies = np.fft.fftfreq(len(half_spectrum))

        # Add to third subplot
        fig.add_trace(go.Scatter(
            x=positive_half(fft_of_fft_frequencies),
            y=positive_half(fft_of_fft_magnitude),
            mode='lines',
            line=dict(color=color, width=1.5),
            opacity=attention_opacity,
            showlegend=False,
            hovertemplate=f"{head_label}<br>Frequency: %{{x:.3f}}<br>Magnitude: %{{y:.3f}}<br>Attention Score: {score:.3f}<extra></extra>"
        ), row=3, col=1)

        max_fft_of_fft = max(max_fft_of_fft, np.max(positive_half(fft_of_fft_magnitude)))

    # Add colorbar via an empty marker trace that only carries the colour scale
    fig.add_trace(
        go.Scatter(
            x=[None], y=[None],
            mode='markers',
            marker=dict(
                colorscale=attn_colorscale,
                showscale=True,
                cmin=0, cmax=score_cap,
                colorbar=dict(
                    title='Attention Score',
                    thickness=15,
                    len=0.5,
                    y=0.5,
                    yanchor='middle'
                )
            ),
            showlegend=False,
            hoverinfo='none'
        )
    )

    # Update layout
    fig.update_layout(
        title=f"Model: {model_name}",
        height=1200,
        width=1000,
        template="plotly_white",
        showlegend=True,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        )
    )

    # Update axes
    last_x_value = n_samples + len(pred_values) - 1
    fig.update_xaxes(title_text="Token Position", row=1, col=1, range=[0, last_x_value])
    fig.update_xaxes(title_text="Frequency", row=2, col=1)
    fig.update_xaxes(title_text="Frequency", row=3, col=1)

    fig.update_yaxes(title_text="Amplitude", secondary_y=False, row=1, col=1)
    fig.update_yaxes(title_text="Attention Score", secondary_y=True, row=1, col=1, range=[0, 1])
    fig.update_yaxes(title_text="FFT Magnitude (Series)", secondary_y=False, row=2, col=1)
    # 10% headroom above the tallest attention spectrum of this model
    fig.update_yaxes(title_text="FFT Magnitude (Attention)", secondary_y=True, row=2, col=1, range=[0, max_attn_fft * 1.1])
    fig.update_yaxes(title_text="FFT of FFT Magnitude", row=3, col=1, range=[0, max_fft_of_fft * 1.1])

    return fig


In [62]:
model_names = ["amazon/chronos-t5-mini", "amazon/chronos-t5-small", "amazon/chronos-t5-base", "amazon/chronos-t5-large"]

# Test signal: four sinusoids (5, 10, 20, 40 Hz) sampled at 301 points over 3 time units
_, signal = sinusoidal(k=[5,10,20,40], amp=[1,0.75,0.5,0.5], t_1=3, steps=301)
series = torch.tensor(signal).unsqueeze(0)  # add the batch dimension the models expect

# Pre-build one figure per model
figs = [create_fft_figure(m, series, attn_colorscale=px.colors.diverging.RdYlGn) for m in model_names]

# Traces of all models live in one figure; each model owns a contiguous slice
group_sizes = [len(f.data) for f in figs]
offsets     = [sum(group_sizes[:i]) for i in range(len(figs))]
total_traces = sum(group_sizes)


def _visibility_mask(model_idx):
    """Boolean list over all traces: True only for the traces of one model."""
    start = offsets[model_idx]
    stop = start + group_sizes[model_idx]
    return [start <= i < stop for i in range(total_traces)]


def _axis_ranges(model_fig):
    """Collect the explicitly set axis ranges of a figure as relayout keys."""
    layout_props = model_fig.layout.to_plotly_json()
    return {
        f"{axis_name}.range": [float(bound) for bound in axis_props["range"]]
        for axis_name, axis_props in layout_props.items()
        if axis_name.startswith(("xaxis", "yaxis")) and axis_props.get("range") is not None
    }


fig = go.Figure()
for f in figs:
    for tr in f.data:
        fig.add_trace(tr)

# Start from the first model's layout (subplot grid, axis titles, ranges, ...)
fig.update_layout(figs[0].layout)
# Use the short model name from the start, matching what the dropdown shows afterwards
fig.update_layout(title_text=f"Model: {model_names[0].split('/')[-1]}")

# Initially show only the *first* model's traces
init_mask = _visibility_mask(0)
for trace, vis in zip(fig.data, init_mask):
    trace.visible = vis


buttons = []
for idx, full_name in enumerate(model_names):
    label = full_name.split("/")[-1]  # strip "amazon/"
    # create_fft_figure fixes the FFT y-axis ranges to each model's own maxima.
    # Re-apply the selected model's ranges along with the visibility switch;
    # otherwise every model would be drawn against the first model's ranges.
    layout_update = {"title.text": f"Model: {label}"}
    layout_update.update(_axis_ranges(figs[idx]))
    buttons.append(dict(
        label=label,
        method="update",
        args=[
            {"visible": _visibility_mask(idx)},
            layout_update
        ]
    ))


fig.update_layout(
    updatemenus=[dict(
        type="dropdown",
        x=1.02,        # just outside the right edge of subplot 1
        xanchor="left",
        y=1.0,
        yanchor="top",
        active=0,      # default to the first model
        showactive=True,
        buttons=buttons
    )],
    margin=dict(r=150)  # make space for the menu
)

fig.layout.yaxis2.update(showticklabels=True)
fig.show()

In [19]:
# Quick inspection: list the score families stored under the 'rf2_sl10' configuration.
attn_scores['rf2_sl10'].keys()

dict_keys(['center_scores', 'right_scores'])