In [None]:
import torch
from chronos import BaseChronosPipeline

import numpy as np
import pandas as pd

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

import rrt_utils as rrt

In [2]:
# Load the Chronos (T5) pipeline. This notebook reads `pipeline.tokenizer` and
# `pipeline.model.model` below, so it needs a pipeline exposing both attributes.
# NOTE(review): a Chronos-Bolt checkpoint (e.g. "amazon/chronos-bolt-small") is
# presumably not a drop-in replacement here — confirm it exposes these attributes
# before switching.
pipeline = BaseChronosPipeline.from_pretrained(
    "amazon/chronos-t5-small",
    device_map="cpu",  # e.g. "cuda" for GPU inference
    torch_dtype=torch.bfloat16,
)

# Tokenizer (maps series <-> token ids via context_input_transform / output_transform)
# and the underlying seq2seq model, used directly for generate() further down.
tokenizer, t5_model = pipeline.tokenizer, pipeline.model.model

In [90]:
def sinusoidal(k, t_1=1, steps=100, amp=1, phase_shift=0):
    """
    Generate a sum of sinusoids (a multisine) sampled on a uniform time grid.

    Args:
        k: Frequency of the sinusoid in Hz, or an array-like of frequencies.
        t_1: Total duration of the time grid; samples cover [0, t_1).
        steps: Number of time samples.
        amp: Amplitude, or an array-like of amplitudes (one per frequency).
             A scalar amplitude (Python or NumPy scalar) is used for every frequency.
        phase_shift: Phase shift in radians, applied identically to every component.

    Returns:
        tuple (t, y): sample times of shape (steps,) and the summed sinusoid
        values of shape (steps,).

    Raises:
        ValueError: if ``amp`` is array-like and its length differs from ``k``.
    """
    t = np.linspace(0, t_1, steps, endpoint=False)

    k_array = np.atleast_1d(k)

    # Broadcast any 0-d amplitude across all frequencies. np.ndim also catches
    # NumPy scalars (e.g. np.float32), which are not Python int/float instances.
    if np.ndim(amp) == 0:
        amp_array = np.full(len(k_array), amp, dtype=float)
    else:
        amp_array = np.atleast_1d(amp)
        if len(k_array) != len(amp_array):
            raise ValueError("k and amp must have the same length when passed as arrays")

    # Accumulate each component into the output signal
    y = np.zeros(steps)
    for k_i, amp_i in zip(k_array, amp_array):
        y += amp_i * np.sin(2 * np.pi * k_i * t + phase_shift)

    return t, y

def get_period(k, max_denominator=10**6):
    """
    Get the fundamental period of a multisine function with frequencies k.

    A sum of sinusoids repeats with the period of its fundamental frequency,
    i.e. the greatest common divisor of the component frequencies. For
    example, k = [2, 3] Hz repeats every 1 s, not every 1/min(k) = 0.5 s.
    Frequencies are converted to fractions (denominator at most
    ``max_denominator``), so non-integer frequencies such as 0.5 Hz work too.
    Truly incommensurate frequencies (e.g. 1 and sqrt(2)) are approximated,
    which yields a very long period.

    Args:
        k: Frequency of the sinusoid in Hz or array of frequencies. Signs are
           ignored and zero (DC) components do not affect the period.
        max_denominator: Largest denominator used when converting each
           frequency to a fraction.

    Returns:
        Period of the multisine function (float, in 1/Hz = seconds).

    Raises:
        ValueError: if ``k`` contains no non-zero frequency.
    """
    from fractions import Fraction
    from functools import reduce
    from math import gcd

    freqs = [
        Fraction(abs(float(f))).limit_denominator(max_denominator)
        for f in np.ravel(k)
    ]
    freqs = [f for f in freqs if f != 0]
    if not freqs:
        raise ValueError("k must contain at least one non-zero frequency")

    # Express every frequency over a common denominator; the gcd of the
    # resulting numerators (over that denominator) is the fundamental frequency.
    common_den = reduce(lambda a, b: a * b // gcd(a, b), (f.denominator for f in freqs))
    fundamental_num = reduce(gcd, (f.numerator * (common_den // f.denominator) for f in freqs))

    # period = 1 / (fundamental_num / common_den)
    return common_den / fundamental_num


def tokenize_series(series, tokenizer, num_decoder=0):
    """
    Tokenize a time series for use with Chronos models.

    Converts a continuous series into encoder/decoder token ids. Optionally the
    last ``num_decoder`` context tokens (those just before the final encoder
    token) are moved from the encoder input into the decoder prompt, so that
    generation continues autoregressively from them.

    Args:
        series: Tensor of shape (batch_size, sequence_length) containing time series data
        tokenizer: Chronos tokenizer instance; its ``context_input_transform``
                   must return ``(token_ids, attention_mask, scale)``
        num_decoder: Number of trailing context tokens to move into the decoder
                     prompt (default: 0). Must satisfy
                     0 <= num_decoder < number of encoder tokens.

    Returns:
        tuple of (encoder_input_ids, decoder_input_ids, attention_mask, scale) where:
            - encoder_input_ids: encoder token ids; when num_decoder > 0 the
              moved tokens are removed and token id 1 is re-appended at the end
            - decoder_input_ids: token id 0 followed by the moved tokens,
              shape (batch_size, 1 + num_decoder)
            - attention_mask: the tokenizer's mask aligned with
              encoder_input_ids (1 = attend), same dtype as the ids
            - scale: Scaling factor returned by the tokenizer

    Raises:
        AssertionError: if ``series`` is not 2-D.
        ValueError: if ``num_decoder`` is out of range.
    """
    assert len(series.shape) == 2, "series must have shape (batch_size, sequence_length)"
    enc_ids, tok_mask, scale = tokenizer.context_input_transform(series)
    batch_size, n_tokens = enc_ids.shape

    if not 0 <= num_decoder < n_tokens:
        raise ValueError(
            f"num_decoder must be in [0, {n_tokens - 1}] for {n_tokens} encoder tokens, got {num_decoder}"
        )

    # Keep the tokenizer's mask so masked positions (e.g. missing values) stay
    # masked. Fall back to all-ones if it is absent or not aligned with the ids.
    if isinstance(tok_mask, torch.Tensor) and tok_mask.shape == enc_ids.shape:
        attention_mask = tok_mask.to(enc_ids.dtype)
    else:
        attention_mask = torch.ones_like(enc_ids)

    # Decoder prompt starts with token id 0.
    # NOTE(review): presumably Chronos' pad/decoder-start id — confirm against tokenizer.config.
    dec_ids = torch.zeros((batch_size, 1), dtype=enc_ids.dtype, device=enc_ids.device)

    if num_decoder > 0:
        # Assumes the final encoder token is the end-of-sequence marker (id 1);
        # the tokens right before it are handed to the decoder.
        cut = -(num_decoder + 1)
        dec_ids = torch.cat([dec_ids, enc_ids[:, cut:-1]], dim=-1)
        end_token = torch.ones((batch_size, 1), dtype=enc_ids.dtype, device=enc_ids.device)
        enc_ids = torch.cat([enc_ids[:, :cut], end_token], dim=-1)
        # Slice the mask the same way; the re-appended end token is always attended.
        attention_mask = torch.cat([attention_mask[:, :cut], torch.ones_like(end_token)], dim=-1)

    return enc_ids, dec_ids, attention_mask, scale

def random_tokens(num_unique_sequences=1, repeats=10, extension=0, sub_extension=0, sequence_length=10, batch_size=1, decode=True):
    """
    Generate random repeating tokens for testing attention mechanisms.

    Draws ``num_unique_sequences`` random token sequences and lays them out as a
    repeating cycle:
    - the full cycle ``repeats`` times,
    - then the next ``extension`` whole sequences of the cycle,
    - then the first ``sub_extension`` tokens of the sequence after that.

    Args:
        num_unique_sequences: Number of unique sequences to generate (>= 1)
        repeats: Number of times to repeat the full cycle of sequences
        extension: Number of additional whole sequences appended after the main
                   repetitions, continuing the cycle (wraps around when it
                   reaches num_unique_sequences); must be >= 0
        sub_extension: Number of tokens to include from the next sequence in the cycle
        sequence_length: Length of each unique sequence
        batch_size: number of batches
        decode: If True, map the tokens back to real values with the global
                ``tokenizer.output_transform`` (scale 1). If False, return the
                token ids stacked with ``include_eos=True``.

    Returns:
        decode=True: the decoded time series as a tensor of shape (batch_size, T)
        decode=False: the stacked token ids from ``rrt.stack_sequences``

    Raises:
        ValueError: if num_unique_sequences < 1 or extension < 0.
    """
    if num_unique_sequences < 1:
        raise ValueError("num_unique_sequences must be at least 1")
    if extension < 0:
        raise ValueError("extension must be non-negative")

    # Candidate token ids: the contiguous id range [1911, 2187].
    # NOTE(review): presumably a band of value bins around zero that avoids the
    # special tokens — confirm against the tokenizer's bin layout.
    vocab = torch.arange(1911, 2188)

    tokens = [
        rrt.generate_random_token_ids(vocab, sequence_length, batch_size=batch_size, include_eos=False)
        for _ in range(num_unique_sequences)
    ]

    # Continue the cycle after the full repeats. Modulo indexing makes
    # extension == num_unique_sequences (or larger) valid instead of an IndexError.
    tail = [tokens[i % num_unique_sequences] for i in range(extension)]
    partial = tokens[extension % num_unique_sequences][:, :sub_extension]
    chunks = tokens * repeats + tail + [partial]

    if decode:
        enc_ids = rrt.stack_sequences(chunks, include_eos=False)
        return tokenizer.output_transform(enc_ids, scale=torch.tensor(1))
    return rrt.stack_sequences(chunks, include_eos=True)


In [113]:
# Experiment configuration: a multisine sampled at `steps` points over `t` seconds
SEED = 0
t = 1
steps = 400
k = [5, 10, 20, 40]
amp = [1, 0.75, 0.5, 0.5]

# Seed the RNGs so Restart & Run All reproduces the RRT series and the mixing masks.
# NOTE(review): assumes rrt.generate_random_token_ids draws from torch's (or NumPy's)
# global RNG — confirm.
torch.manual_seed(SEED)
np.random.seed(SEED)

# Samples per fundamental period; round() guards against float error (e.g. 79.9999 -> 79)
period = int(round(get_period(k) * steps / t))

# Use the config values above (not re-typed literals) so changing k/amp stays consistent
_, sine_series = sinusoidal(k=k, amp=amp, t_1=t, steps=steps)
sine_series = torch.tensor(sine_series).unsqueeze(0)

# One random token sequence, one period long, repeated (plus a partial tail) to
# cover exactly `steps` samples, then rescaled to the sine's peak magnitude.
rrt_series = random_tokens(
    num_unique_sequences=1,
    repeats=steps // period,
    sub_extension=steps % period,
    sequence_length=period,
    decode=True,
) * sine_series.abs().max()

# The mixing step below replaces elements position-wise, so the shapes must match
assert rrt_series.shape == sine_series.shape, (rrt_series.shape, sine_series.shape)

In [114]:
# Overlay the clean multisine and the amplitude-matched random repeating-token (RRT)
# series; each trace shows the first batch item of its series.
fig = go.Figure()
for trace_name, series in (("Sine series", sine_series), ("RRT series", rrt_series)):
    fig.add_trace(go.Scatter(
        y=series[0],
        mode='lines',
        name=trace_name,  # distinct names so the legend tells the traces apart
    ))

fig.update_layout(
    title='Sine series vs. RRT series',
    xaxis_title='Time Step',
    yaxis_title='Value',
    template='plotly_white'
)

fig.show()


In [124]:
def create_mixed_series(sine_series, rrt_series, p, generator=None):
    """
    Create a mixed series in which each element of ``sine_series`` is
    independently replaced by the matching element of ``rrt_series`` with
    probability ``p``.

    Args:
        sine_series: Original sine series (floating-point tensor)
        rrt_series: Random token series; must have the same shape as sine_series
        p: Probability of replacement (0.0 keeps sine_series, 1.0 returns rrt_series)
        generator: Optional ``torch.Generator`` for reproducible masks; the
                   global torch RNG is used when None.

    Returns:
        New tensor with the shape, dtype and device of ``sine_series``.
        Neither input is modified.

    Raises:
        ValueError: if the two series have different shapes.
    """
    if rrt_series.shape != sine_series.shape:
        raise ValueError(
            f"sine_series and rrt_series must have the same shape, "
            f"got {tuple(sine_series.shape)} and {tuple(rrt_series.shape)}"
        )

    # Match dtype/device so torch.where neither promotes the result nor fails across devices
    rrt_series = rrt_series.to(dtype=sine_series.dtype, device=sine_series.device)

    # Same draw as torch.rand_like(sine_series) when generator is None
    mask = torch.rand(
        sine_series.shape,
        generator=generator,
        dtype=sine_series.dtype,
        device=sine_series.device,
    ) < p
    return torch.where(mask, rrt_series, sine_series)

def compute_fft(series, t_1=1, steps=400):
    """
    Compute the one-sided FFT magnitude spectrum of a time series.

    Args:
        series: Time series as a torch tensor or NumPy array-like. Singleton
                dimensions (e.g. a batch of 1) are squeezed away. The FFT runs
                along the last axis, so a (batch, T) input yields one spectrum per row.
        t_1: Total time period spanned by ``steps`` samples
        steps: Number of time steps in ``t_1`` (sample spacing is t_1 / steps)

    Returns:
        (fft_freq, fft_mag):
        - fft_freq: frequencies in cycles per unit of t_1, shape (T // 2 + 1,),
          where T is the length of the last axis
        - fft_mag: magnitudes, shape (..., T // 2 + 1)
    """
    if isinstance(series, torch.Tensor):
        series = series.detach().cpu()
        if series.dtype == torch.bfloat16:  # NumPy has no bfloat16
            series = series.float()
        series = series.numpy()

    # Remove singleton (e.g. batch) dimensions; atleast_1d keeps a length-1
    # series from collapsing to a 0-d scalar.
    series = np.atleast_1d(np.squeeze(np.asarray(series)))

    # Use the length of the axis the FFT actually runs on (not len(), which is
    # the batch size for 2-D input).
    n_samples = series.shape[-1]
    fft_mag = np.abs(np.fft.rfft(series, axis=-1))
    fft_freq = np.fft.rfftfreq(n_samples, d=t_1 / steps)

    return fft_freq, fft_mag

def compute_attention_fft(attention_matrix):
    """
    Return the one-sided FFT magnitude of attention weights along their last axis.

    Args:
        attention_matrix: Array-like of attention weights. The transform is taken
                          over the last axis (the source positions of each
                          attention row).

    Returns:
        np.ndarray of |rfft| values. An input whose last axis has length n gives
        an output whose last axis has length n // 2 + 1.
    """
    spectrum = np.fft.rfft(attention_matrix, axis=-1)
    return np.abs(spectrum)


# Sweep the replacement probability p: build and tokenize one mixed series per value
p_values = np.linspace(0.0, 1.0, 11)  # 0.0, 0.1, 0.2, ..., 1.0
mixed_series_dict = {}
tokenized_series_dict = {}
attention_scores_dict = {}

for p in p_values:
    mixed_series = create_mixed_series(sine_series, rrt_series, p)
    mixed_series_dict[p] = mixed_series
    # num_decoder=1: the last context token is moved into the decoder prompt
    tokenized_series_dict[p] = tokenize_series(mixed_series, tokenizer, num_decoder=1)

# Loop-invariant model configuration
num_layers = t5_model.config.num_decoder_layers
num_heads = t5_model.config.num_heads
# Read cross-attention from the last decoder position: the query whose output
# predicts the generated token. -1 equals index 1 for num_decoder=1 and stays
# correct if the decoder prompt length changes.
target_token = -1

# Generate one token for each mixed series and record its cross-attention
for p in p_values:
    enc_ids, dec_ids, attention_mask, scale = tokenized_series_dict[p]

    # Greedy (deterministic) generation of a single new token
    outputs = t5_model.generate(
        input_ids=enc_ids,
        attention_mask=attention_mask,
        max_new_tokens=1,
        decoder_input_ids=dec_ids,
        num_return_sequences=1,
        do_sample=False,
        use_cache=False,
        output_attentions=True,
        output_scores=True,
        output_hidden_states=True,
        return_dict_in_generate=True
    )

    # cross_attentions[step][layer] is presumably (batch, heads, decoder_len, encoder_len)
    # per the HF generate output docs. Keep a [num_heads, src_tokens] tensor per layer
    # for batch item 0, dropping the final encoder position (the trailing token
    # that tokenize_series places at the end).
    attention_scores_dict[p] = [
        outputs.cross_attentions[0][layer][0, :, target_token, :-1]
        for layer in range(num_layers)
    ]

In [128]:
# Reference spectra for the pure sine and pure RRT series. The sampling parameters
# from the config cell are passed explicitly rather than relying on compute_fft's defaults.
sine_freq, sine_mag = compute_fft(sine_series, t_1=t, steps=steps)
rrt_freq, rrt_mag = compute_fft(rrt_series, t_1=t, steps=steps)

# Spectrum of each mixed series, keyed by p
fft_results = {
    p: compute_fft(series, t_1=t, steps=steps)
    for p, series in mixed_series_dict.items()
}

GRID_COLS = 3
ROW_HEIGHT_PX = 200
FIG_WIDTH_PX = 1200
MAX_FFT_BINS = 50  # lowest-frequency bins shown in each series-FFT panel


def _make_panel_grid(titles, n_cols=GRID_COLS):
    """Create a subplot grid with one panel per title (row-major); return (fig, n_rows)."""
    n_rows = max(1, -(-len(titles) // n_cols))  # ceiling division
    return make_subplots(rows=n_rows, cols=n_cols, subplot_titles=titles), n_rows


def _panel_position(index, n_cols=GRID_COLS):
    """Map a 0-based panel index to the 1-based (row, col) plotly expects."""
    row, col = divmod(index, n_cols)
    return row + 1, col + 1


def _head_attention_spectra(layer_attentions):
    """
    Yield (frequency axis, FFT magnitude) for every layer/head attention row.

    ``layer_attentions`` is a list (or dict) of per-layer tensors whose first
    axis indexes heads. Frequencies are in cycles per encoder position: the
    rfft bins of a row of length n span 0 ... 0.5. Rows that are empty or
    contain NaN are skipped.
    """
    layers = layer_attentions.values() if isinstance(layer_attentions, dict) else layer_attentions
    for layer_attn in layers:
        # Attentions may be bfloat16 (model loaded with torch_dtype=torch.bfloat16),
        # which NumPy cannot represent, so upcast first.
        heads = np.atleast_2d(layer_attn.detach().cpu().float().numpy())
        for head_attn in heads:
            if head_attn.size == 0 or np.isnan(head_attn).any():
                continue
            yield np.fft.rfftfreq(head_attn.shape[-1]), compute_attention_fft(head_attn)


# --- Mixed time series, one panel per p ---
fig_series, n_rows = _make_panel_grid([f"p = {p:.1f}" for p in p_values])
for i, p in enumerate(p_values):
    row, col = _panel_position(i)
    fig_series.add_trace(
        go.Scatter(
            y=mixed_series_dict[p].squeeze().numpy(),
            mode='lines',
            name=f'p = {p:.1f}',
            showlegend=False
        ),
        row=row, col=col
    )
fig_series.update_layout(
    height=ROW_HEIGHT_PX * n_rows,
    width=FIG_WIDTH_PX,
    title_text="Mixed Time Series (p = probability of using RRT values)"
)
fig_series.show()

# --- FFT magnitude of each mixed series ---
fig_fft, n_rows = _make_panel_grid([f"FFT (p = {p:.1f})" for p in p_values])
for i, p in enumerate(p_values):
    freq, mag = fft_results[p]
    row, col = _panel_position(i)
    fig_fft.add_trace(
        go.Scatter(
            x=freq[:MAX_FFT_BINS],
            y=mag[:MAX_FFT_BINS],
            mode='lines',
            name=f'p = {p:.1f}',
            showlegend=False
        ),
        row=row, col=col
    )
fig_fft.update_layout(
    height=ROW_HEIGHT_PX * n_rows,
    width=FIG_WIDTH_PX,
    title_text="FFT Analysis of Mixed Time Series"
)
fig_fft.show()

# --- FFT of the cross-attention rows: one line per (layer, head), one panel per p ---
fig_attn_fft, n_rows = _make_panel_grid([f"Attention FFT (p = {p:.1f})" for p in p_values])
for i, p in enumerate(p_values):
    row, col = _panel_position(i)
    for freq_axis, fft_mag in _head_attention_spectra(attention_scores_dict[p]):
        fig_attn_fft.add_trace(
            go.Scatter(
                x=freq_axis,
                y=fft_mag,
                mode='lines',
                opacity=0.7,
                showlegend=False,
                line=dict(width=1)
            ),
            row=row, col=col
        )
fig_attn_fft.update_layout(
    height=ROW_HEIGHT_PX * n_rows,
    width=FIG_WIDTH_PX,
    title_text="FFT of Attention Patterns Across Layers and Heads (x: cycles per encoder position)"
)
fig_attn_fft.show()