In [202]:
import torch
from chronos import BaseChronosPipeline

import numpy as np
import pandas as pd

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

import attn_lens as attn

In [238]:
# Load the pretrained Chronos checkpoint once; the cells below reuse its tokenizer and model.
pipeline = BaseChronosPipeline.from_pretrained(
    "amazon/chronos-t5-small",  # use "amazon/chronos-bolt-small" for the corresponding Chronos-Bolt model
    # NOTE(review): the cells below rely on `pipeline.tokenizer` and `pipeline.model.model`;
    # confirm a Chronos-Bolt pipeline exposes both before swapping the checkpoint.
    device_map="cpu",  # CPU inference; NOTE(review): presumably "cuda" selects a GPU — confirm
    torch_dtype=torch.bfloat16,
)

# Direct handles used by later cells: `tokenizer.context_input_transform` /
# `tokenizer.output_transform` and `t5_model.generate`.
tokenizer, t5_model = pipeline.tokenizer, pipeline.model.model

In [266]:
def sinusoidal(k, t_1=1, steps=100, amp=1, phase_shift=0):
    """
    Generate a sum of sinusoids sampled on a uniform time grid.

    Args:
        k: Frequency (cycles per unit of ``t``) or a sequence of frequencies.
        t_1: End of the sampled interval ``[0, t_1)``; the end point is excluded.
        steps: Number of samples.
        amp: Amplitude. Any scalar (Python number, NumPy scalar or 0-d array)
            is applied to every frequency; a sequence must contain exactly
            one amplitude per frequency.
        phase_shift: Phase shift in radians, added to every component.

    Returns:
        Tuple ``(t, y)`` where ``t`` holds the ``steps`` sample times and
        ``y`` the summed sinusoid values at those times.

    Raises:
        ValueError: If ``amp`` is a sequence whose length differs from the
            number of frequencies.
    """
    t = np.linspace(0, t_1, steps, endpoint=False)

    # Treat a scalar frequency as a one-element array so the loop below is uniform.
    k_array = np.atleast_1d(k)

    # np.ndim also recognises NumPy scalars and 0-d arrays, which an
    # isinstance(amp, (int, float)) check would wrongly treat as sequences.
    if np.ndim(amp) == 0:
        amp_array = np.full_like(k_array, amp, dtype=float)
    else:
        amp_array = np.atleast_1d(amp)
        if len(k_array) != len(amp_array):
            raise ValueError("k and amp must have the same length when passed as arrays")

    # Accumulate one component at a time.
    y = np.zeros(steps)
    for k_i, amp_i in zip(k_array, amp_array):
        y += amp_i * np.sin(2 * np.pi * k_i * t + phase_shift)

    return t, y

def tokenize_series(series, tokenizer, num_decoder=0):
    """
    Tokenize a batch of series and optionally move its tail to the decoder.

    Args:
        series: 2-D tensor passed to ``tokenizer.context_input_transform``
            (assumed ``(batch, time)`` — TODO confirm against the tokenizer).
        tokenizer: Object exposing ``context_input_transform(series)`` that
            returns ``(token_ids, attention_mask, scale)``.
        num_decoder: Number of trailing series tokens to move from the
            encoder input into the decoder prompt. ``0`` keeps everything in
            the encoder.

    Returns:
        Tuple ``(enc_ids, dec_ids, attention_mask, scale)``. ``dec_ids``
        always starts with token id 0; ``attention_mask`` is the tokenizer's
        mask, trimmed to match ``enc_ids`` and cast to ``enc_ids.dtype``.

    Raises:
        ValueError: If ``num_decoder`` is negative or exceeds the number of
            tokens preceding the final encoder token.
    """
    assert len(series.shape) == 2
    enc_ids, attention_mask, scale = tokenizer.context_input_transform(series)
    batch_size = enc_ids.shape[0]

    # Keep the tokenizer's mask instead of overwriting it with ones, so positions
    # the tokenizer marked as invalid stay masked for the encoder.
    attention_mask = attention_mask.to(dtype=enc_ids.dtype)

    # Decoder prompt starts with token id 0
    # (presumably the decoder start / pad token — verify against the model config).
    dec_ids = torch.zeros((batch_size, 1), dtype=torch.long, device=enc_ids.device)

    if num_decoder != 0:
        # The last encoder token is assumed to be the EOS token appended by the
        # tokenizer (TODO confirm), so at most shape[1] - 1 tokens can be moved.
        max_decoder = enc_ids.shape[1] - 1
        if not 0 < num_decoder <= max_decoder:
            raise ValueError(
                f"num_decoder must be between 0 and {max_decoder}, got {num_decoder}"
            )

        split = -(num_decoder + 1)
        # Token id 1 re-terminates the shortened encoder input
        # (presumably EOS — verify against the model config).
        eos = torch.ones((batch_size, 1), dtype=torch.long, device=enc_ids.device)

        dec_ids = torch.cat([dec_ids, enc_ids[:, split:-1]], dim=-1)
        enc_ids = torch.cat([enc_ids[:, :split], eos], dim=-1)
        attention_mask = torch.cat(
            [attention_mask[:, :split], torch.ones_like(eos, dtype=attention_mask.dtype)],
            dim=-1,
        )

    return enc_ids, dec_ids, attention_mask, scale

# Synthetic test signal: four tones (5/10/20/40) sampled at 301 points over t in [0, 3).
_, series = sinusoidal(
    k=[5, 10, 20, 40],
    amp=[1, 0.75, 0.5, 0.5],
    t_1=3,
    steps=301,
)

# Add a leading batch axis, since tokenize_series asserts a 2-D input.
series = torch.tensor(series)[None, :]

enc_ids, dec_ids, attention_mask, scale = tokenize_series(
    series,
    tokenizer,
    num_decoder=0,
)
(enc_ids.shape, dec_ids.shape)

(torch.Size([1, 302]), torch.Size([1, 1]))

In [250]:
# df = pd.read_csv("https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv")
# series = torch.tensor(df["#Passengers"].values).unsqueeze(0)

# enc_ids, attention_mask, scale = tokenizer.context_input_transform(series)
# dec_ids = None

In [267]:
# Greedy-decode a single new token while recording attentions, scores and hidden
# states, so the cells below can inspect the cross-attention over the encoder tokens.
outputs = t5_model.generate(input_ids=enc_ids, 
                            attention_mask=attention_mask, 
                            max_new_tokens=1,
                            decoder_input_ids=dec_ids,
                            num_return_sequences=1,
                            do_sample=False,
                            use_cache=False,
                            output_attentions=True,
                            output_scores=True,
                            output_hidden_states=True,
                            return_dict_in_generate=True
                            )

# Decode only the newly generated tokens. `outputs.sequences` begins with the
# decoder prompt, so skip the whole prompt rather than just its first token;
# otherwise prompt tokens supplied via `num_decoder > 0` would be reported as predictions.
# When no prompt is supplied (`dec_ids is None`) only the single start token is skipped.
# NOTE(review): assumes generate() does not prepend an extra start token because
# `dec_ids` already begins with it — confirm against the model's generation config.
num_prompt_tokens = 1 if dec_ids is None else dec_ids.shape[-1]
preds = tokenizer.output_transform(outputs.sequences[..., num_prompt_tokens:], scale)

In [261]:
# Inspect one cross-attention tensor. The cell below indexes the second level by
# decoder layer; the first level is presumably the generation step — confirm in the HF docs.
outputs.cross_attentions[0][0].shape

torch.Size([1, 8, 2, 301])

In [268]:
num_layers = t5_model.config.num_decoder_layers
num_heads = t5_model.config.num_heads

# Decoder position whose cross-attention profile is analysed.
target_token = 0

# Map each decoder layer to its per-head attention over the encoder tokens.
# Index 0 picks the first batch element; the `:-1` slice drops the final encoder
# position (presumably the EOS token — verify against the tokenizer).
attns = {}
for layer in range(num_layers):
    layer_cross_attn = outputs.cross_attentions[0][layer]
    attns[layer] = layer_cross_attn[0, :, target_token, :-1]

attns[0].shape

torch.Size([8, 301])

In [269]:
# Overlay the input series, the model's prediction and every head's
# cross-attention profile in a single figure with two y-axes.
fig = go.Figure()

# Input series on the primary (left) axis.
fig.add_trace(
    go.Scatter(
        x=np.arange(len(series[0])),
        y=series[0].numpy(),
        mode="lines",
        name="Full Series",
        line={"color": "royalblue"},
    )
)

# Prediction segment, anchored at the last observed point so the two lines join.
prediction_x = np.arange(len(series[0]) - 1, len(series[0]) + len(preds[0, 0]))
prediction_y = np.concatenate([[series[0].numpy()[-1]], preds[0, 0].numpy()])
fig.add_trace(
    go.Scatter(
        x=prediction_x,
        y=prediction_y,
        mode="lines",
        name="Predictions",
        line={"color": "firebrick", "width": 2},
    )
)

# Dashed marker where the observed data ends.
fig.add_vline(
    x=len(series[0]) - 1,
    line_dash="dash",
    line_color="gray",
    annotation_text="Prediction Start",
    annotation_position="top right",
)

# One translucent trace per (layer, head) on the secondary axis; the palette is
# cycled because there are more heads than palette entries.
colors = px.colors.qualitative.Plotly
color_idx = 0
for layer in range(num_layers):
    for head in range(num_heads):
        # bfloat16 cannot be converted to NumPy directly, hence the float() cast.
        attn_scores = attns[layer][head].cpu().float().numpy()
        trace_color = colors[color_idx % len(colors)]

        fig.add_trace(
            go.Scatter(
                x=np.arange(len(attn_scores)),
                y=attn_scores,
                mode="lines",
                name=f"Layer {layer+1}, Head {head+1}",
                line={"color": trace_color},
                opacity=0.5,
                yaxis="y2",
            )
        )
        color_idx += 1

# Axis and legend configuration, assembled separately for readability.
x_axis = {"title": "Token Position", "domain": [0, 0.94]}
amplitude_axis = {
    "title": "Amplitude",
    "color": "black",
    "tickfont": {"color": "black"},
}
attention_axis = {
    "title": "Attention Score",
    "color": "grey",
    "tickfont": {"color": "grey"},
    "anchor": "x",
    "overlaying": "y",
    "side": "right",
    "range": [0, 1],  # attention weights lie in [0, 1]
}
legend_above_plot = {
    "orientation": "h",
    "yanchor": "bottom",
    "y": 1.02,
    "xanchor": "right",
    "x": 1,
}

fig.update_layout(
    title="",
    xaxis=x_axis,
    yaxis=amplitude_axis,
    yaxis2=attention_axis,
    template="plotly_white",
    legend=legend_above_plot,
    height=700,
    width=1000,
)

fig.show()


In [272]:
# Sanity check: compare the length of one head's attention profile with the series length;
# the FFT cell below applies the series' sample spacing to both.
attns[0][0].shape, series.shape

(torch.Size([301]), torch.Size([1, 301]))

In [273]:
# Compare the magnitude spectrum of the input series with the spectra of every
# head's cross-attention profile.

# Time span covered by the series; must match the `t_1` passed to `sinusoidal(...)`
# when the series was generated.
series_duration = 3
sample_spacing = series_duration / series.shape[1]

# Reuse the palette of the overlay figure; re-derived here so this cell does not
# depend on a variable defined by an earlier plotting cell.
colors = px.colors.qualitative.Plotly

# Discrete Fourier Transform of the series.
fft_result = torch.fft.fft(series)
fft_magnitude = torch.abs(fft_result)

# Convert to numpy for plotting
frequencies = np.fft.fftfreq(len(series[0]), d=sample_spacing)
fft_magnitude_np = fft_magnitude[0].numpy()

# Select the non-negative frequencies with a mask rather than `[:n // 2]`: for an
# odd-length signal the slice drops the highest positive-frequency bin.
non_negative = frequencies >= 0

fig = go.Figure()

# Spectrum of the original series on the primary axis.
fig.add_trace(go.Scatter(
    x=frequencies[non_negative],
    y=fft_magnitude_np[non_negative],
    mode='lines',
    name='Original Time Series',
    line=dict(width=3, color='black')
))

# Compute each head's spectrum exactly once, tracking the largest magnitude so
# the secondary axis can be scaled to the attention spectra.
attn_spectra = []  # entries: (layer, head, frequencies, magnitudes)
max_attn_fft = 0
for layer in attns.keys():
    layer_attns = attns[layer]  # one attention profile per head
    for head in range(layer_attns.shape[0]):
        # bfloat16 cannot be converted to NumPy directly, hence the float() cast.
        attn_scores = layer_attns[head].cpu().float().numpy()

        attn_fft = np.fft.fft(attn_scores)
        attn_fft_magnitude = np.abs(attn_fft)

        # NOTE(review): reusing the series' sample spacing assumes one token per
        # time step — confirm against the tokenizer.
        attn_frequencies = np.fft.fftfreq(len(attn_scores), d=sample_spacing)
        attn_non_negative = attn_frequencies >= 0

        attn_spectra.append((
            layer,
            head,
            attn_frequencies[attn_non_negative],
            attn_fft_magnitude[attn_non_negative],
        ))
        max_attn_fft = max(max_attn_fft, np.max(attn_fft_magnitude[attn_non_negative]))

# Add one faint trace per head on the secondary axis.
color_idx = 0
for layer, head, head_frequencies, head_magnitude in attn_spectra:
    fig.add_trace(go.Scatter(
        x=head_frequencies,
        y=head_magnitude,
        mode='lines',
        name=f'Layer {layer+1}, Head {head+1}',
        line=dict(color=colors[color_idx % len(colors)]),
        opacity=0.1,
        yaxis='y2'  # Use secondary y-axis
    ))
    color_idx += 1

# Customize the layout with two y-axes
fig.update_layout(
    title="",
    xaxis=dict(
        title="Frequency",
        domain=[0, 0.94]
    ),
    yaxis=dict(
        title="Magnitude (Original Series)",
        color="black",
        tickfont=dict(color="black")
    ),
    yaxis2=dict(
        title="Magnitude (Attention Scores)",
        color="grey",
        tickfont=dict(color="grey"),
        anchor="x",
        overlaying="y",
        side="right",
        range=[0, max_attn_fft * 1.1]  # leave 10% headroom above the largest peak
    ),
    template="plotly_white",
    legend=dict(
        orientation="h",
        yanchor="bottom",
        y=1.02,
        xanchor="right",
        x=1
    ),
    height=700,
    width=1000
)

# Display the plot
fig.show()

# Plot the "FFT of FFT" for every head: the magnitude spectrum of each attention
# profile's magnitude spectrum.
double_fft_fig = go.Figure()

# Reuse the palette of the earlier figures; re-derived here so this section does
# not depend on a variable defined by an earlier plotting cell.
colors = px.colors.qualitative.Plotly

color_idx = 0
# Largest magnitude over the first half of each double spectrum. The figure
# below does not use it (its y-axis autoscales); it is kept in case later cells rely on it.
max_double_fft = 0

# Single pass: each transform is computed once and plotted immediately.
for layer in attns.keys():
    layer_attns = attns[layer]
    for head in range(layer_attns.shape[0]):
        # bfloat16 cannot be converted to NumPy directly, hence the float() cast.
        attn_scores = layer_attns[head].cpu().float().numpy()

        # First FFT on the attention scores.
        attn_fft = np.fft.fft(attn_scores)

        # Second FFT on the magnitude of the first.
        double_fft = np.fft.fft(np.abs(attn_fft))
        double_fft_magnitude = np.abs(double_fft)

        max_double_fft = max(max_double_fft, np.max(double_fft_magnitude[:len(double_fft_magnitude)//2]))

        # Unit sample spacing: the x-axis is in cycles per frequency bin.
        double_fft_frequencies = np.fft.fftfreq(len(double_fft), d=1)

        # Keep strictly positive frequencies only, which also drops the zero-frequency term.
        positive_freq_indices = double_fft_frequencies > 0

        double_fft_fig.add_trace(go.Scatter(
            x=double_fft_frequencies[positive_freq_indices],
            y=double_fft_magnitude[positive_freq_indices],
            mode='lines',
            name=f'Layer {layer+1}, Head {head+1}',
            line=dict(color=colors[color_idx % len(colors)]),
            opacity=0.3  # translucent so overlapping heads stay distinguishable
        ))
        color_idx += 1

# Customize the layout for double FFT plot
double_fft_fig.update_layout(
    title="FFT of FFT",
    xaxis=dict(
        title="Frequency",
    ),
    yaxis=dict(
        title="Magnitude",
    ),
    template="plotly_white",
    legend=dict(
        orientation="h",
        yanchor="bottom",
        y=1.02,
        xanchor="right",
        x=1
    ),
    height=700,
    width=1000
)

# Display the double FFT plot
double_fft_fig.show()
