In [11]:
import torch
from chronos import BaseChronosPipeline

import numpy as np
import pandas as pd

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

import os
import json

import rrt_utils as rrt
import attn_lens as attn

In [3]:
def generate_rrt_induction_data(
    vocab_range=(1911, 2187),
    batch_size=100,
    num_unique_sequences=1,
    repeat_factor=4,
    extension=0,
    sub_extension=1,
    sequence_length=10
):
    """
    Generate data for RRT induction experiments.
    
    Args:
        vocab_range: Tuple of (min, max) vocab indices to use
        batch_size: Number of sequences in the batch
        num_unique_sequences: Number of unique sequences in the batch
        repeat_factor: Number of times to repeat the sequences
        extension: After repeating everything, repeat this many sequences again
        sub_extension: Some additional tokens in the next sequence
        sequence_length: Length of each sequence
        
    Returns:
        Tuple of (token_ids, attention_mask, decoder_input_ids)
    """
    # Define the vocab as all the tokens within the specified range
    vocab = torch.tensor([i for i in range(4096) if i >= vocab_range[0] and i <= vocab_range[1]])
    
    # Generate random token sequences
    tokens = [rrt.generate_random_token_ids(vocab, sequence_length, batch_size=batch_size, include_eos=False) 
              for _ in range(num_unique_sequences)]
    
    # Stack sequences with repetition pattern
    token_ids = rrt.stack_sequences(tokens * repeat_factor + tokens[:extension] + [tokens[extension][:,:sub_extension]])
    
    # Create attention mask
    attention_mask = torch.ones_like(token_ids, dtype=torch.bool)
    
    # Create decoder input ids
    decoder_input_ids = torch.cat([
        torch.zeros((batch_size, 1), dtype=torch.long), 
        tokens[extension][:,sub_extension:sub_extension+1]
    ], dim=1)
    
    return token_ids, attention_mask, decoder_input_ids


In [4]:
# Define different values to experiment with
repeat_factors = [2, 4, 6, 8, 10]
sequence_lengths = [2, 4, 6, 8, 10]

model_names = ["amazon/chronos-t5-mini", "amazon/chronos-t5-small", "amazon/chronos-t5-base", "amazon/chronos-t5-large"]

In [15]:
all_results = {}

for repeat_factor in repeat_factors:
    for sequence_length in sequence_lengths:
        print(f"Processing repeat_factor={repeat_factor}, sequence_length={sequence_length}")
        
        num_unique_sequences = 1  # number of unique sequences in the batch
        extension = 0  # after repeating everything, repeat this many sequences again
        sub_extension = 1  # some additional tokens in the next sequence
        
        # Generate data with current configuration
        token_ids, attention_mask, decoder_input_ids = generate_rrt_induction_data(
            num_unique_sequences=num_unique_sequences, 
            repeat_factor=repeat_factor, 
            extension=extension, 
            sub_extension=sub_extension, 
            sequence_length=sequence_length
        )
        
        # Display shapes and sample values
        print(f"token_ids.shape: {token_ids.shape}, decoder_input_ids.shape: {decoder_input_ids.shape}")
        
        # Store results for each model
        config_key = f"rf{repeat_factor}_sl{sequence_length}"
        all_results[config_key] = {
            "center_scores": {},
            "right_scores": {}
        }
        
        for model_name in model_names:
            print(f"Processing {model_name} for {config_key}...")
            
            pipeline = BaseChronosPipeline.from_pretrained(
                model_name,  # use "amazon/chronos-bolt-small" for the corresponding Chronos-Bolt model
                device_map="cpu",  # use "cpu" for CPU inference
                torch_dtype=torch.bfloat16,
            )

            if "bolt" in model_name:
                t5_model = pipeline.model
            else:
                t5_model = pipeline.model.model

            outputs = t5_model.generate(
                input_ids=token_ids,
                attention_mask=attention_mask,
                max_new_tokens=1,
                decoder_input_ids=decoder_input_ids,
                num_return_sequences=1,
                do_sample=False,
                use_cache=False,
                output_attentions=True,
                output_scores=True,
                output_hidden_states=True,
                return_dict_in_generate=True
            )
            
            # Extract cross attention probabilities
            # cross_attentions is a list of length layers, each with shape [batch, heads, dec_length, enc_length]
            cross_attn_probs = outputs.cross_attentions
            
            t_idx = 1
            s_idx = sequence_length * ((repeat_factor-1)*num_unique_sequences + extension) + sub_extension

            layers, heads = t5_model.config.num_decoder_layers, t5_model.config.num_heads
            mosaic_center = np.zeros((layers, heads)).tolist()
            mosaic_right = np.zeros((layers, heads)).tolist()

            for layer in range(layers):
                for head in range(heads):
                    # mean over the batch
                    mosaic_center[layer][head] = float(cross_attn_probs[0][layer][:, head, t_idx, s_idx].mean())
                    mosaic_right[layer][head] = float(cross_attn_probs[0][layer][:, head, t_idx, s_idx+1].mean())
            
            # Store the scores for this model and configuration
            all_results[config_key]["center_scores"][model_name] = mosaic_center
            all_results[config_key]["right_scores"][model_name] = mosaic_right

# save the results
os.makedirs("variables", exist_ok=True)
with open("variables/results.json", "w") as f:
    json.dump(all_results, f)

Processing repeat_factor=2, sequence_length=2


token_ids.shape: torch.Size([100, 6]), decoder_input_ids.shape: torch.Size([100, 2])
Processing amazon/chronos-t5-mini for rf2_sl2...
Processing amazon/chronos-t5-small for rf2_sl2...
Processing amazon/chronos-t5-base for rf2_sl2...
Processing amazon/chronos-t5-large for rf2_sl2...
Processing repeat_factor=2, sequence_length=4
token_ids.shape: torch.Size([100, 10]), decoder_input_ids.shape: torch.Size([100, 2])
Processing amazon/chronos-t5-mini for rf2_sl4...
Processing amazon/chronos-t5-small for rf2_sl4...
Processing amazon/chronos-t5-base for rf2_sl4...
Processing amazon/chronos-t5-large for rf2_sl4...
Processing repeat_factor=2, sequence_length=6
token_ids.shape: torch.Size([100, 14]), decoder_input_ids.shape: torch.Size([100, 2])
Processing amazon/chronos-t5-mini for rf2_sl6...
Processing amazon/chronos-t5-small for rf2_sl6...
Processing amazon/chronos-t5-base for rf2_sl6...
Processing amazon/chronos-t5-large for rf2_sl6...
Processing repeat_factor=2, sequence_length=8
token_ids.s

In [7]:
# load the results
with open("variables/results.json", "r") as f:
    all_results = json.load(f)

In [9]:
rf = [2, 4, 6, 8, 10]
sl = [2, 4, 6, 8, 10]

import ipywidgets as widgets
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from IPython.display import display
import numpy as np

# Create widget for model selection
model_dropdown = widgets.Dropdown(
    options=model_names,
    value=model_names[0],
    description='Model:'
)

# Create ipywidgets for RF and SL selection
rf_dropdown = widgets.Dropdown(
    options=[(f"RF: {r}", r) for r in rf],
    value=4,  # Initial value
    description='RF:'
)

sl_dropdown = widgets.Dropdown(
    options=[(f"SL: {s}", s) for s in sl],
    value=10,  # Initial value
    description='SL:'
)

# Create output widget to display the plots
output = widgets.Output()

def create_figure(current_rf, current_sl, current_model):
    # Generate config key
    config_key = f"rf{current_rf}_sl{current_sl}"
    
    # Check if config exists in results
    if config_key not in all_results:
        return go.Figure().update_layout(
            annotations=[dict(
                text=f"No data available for configuration: {config_key}",
                showarrow=False,
                xref="paper",
                yref="paper",
                x=0.5,
                y=0.5
            )]
        )
    
    # Get scores for current configuration
    all_center_scores = all_results[config_key]["center_scores"]
    all_right_scores = all_results[config_key]["right_scores"]
    
    # Check if model exists in scores
    if current_model not in all_center_scores or current_model not in all_right_scores:
        return go.Figure().update_layout(
            annotations=[dict(
                text=f"No data available for model {current_model} with configuration {config_key}",
                showarrow=False,
                xref="paper",
                yref="paper",
                x=0.5,
                y=0.5
            )]
        )
    
    # Get data for current model
    mosaic_center = all_center_scores[current_model]
    mosaic_right = all_right_scores[current_model]
    
    # Create a subplot with 2 side-by-side heatmaps
    fig = make_subplots(rows=1, cols=2, 
                        subplot_titles=("Current token", "Token to right of current"),
                        shared_yaxes=True)

    # Add heatmaps to the subplots
    fig.add_trace(
        go.Heatmap(z=mosaic_center, zmin=0, zmax=1, coloraxis="coloraxis"),
        row=1, col=1
    )
    fig.add_trace(
        go.Heatmap(z=mosaic_right, zmin=0, zmax=1, coloraxis="coloraxis"),
        row=1, col=2
    )

    # Update layout
    fig.update_layout(
        title_text=f"Attention Mosaics For Induction on RRTs - RF: {current_rf}, SL: {current_sl}, Model: {current_model.split('/')[-1]}",
        height=500,
        width=1000,
        coloraxis=dict(cmin=0, cmax=1, colorbar=dict(title="Attention Score"))
    )

    # Add axes labels with integer ticks
    fig.update_xaxes(title_text="Head", row=1, col=1, title_font=dict(size=18), 
                    tickmode='linear', tick0=0, dtick=1, showticklabels=True)
    fig.update_xaxes(title_text="Head", row=1, col=2, title_font=dict(size=18), 
                    tickmode='linear', tick0=0, dtick=1, showticklabels=True)
    fig.update_yaxes(title_text="Layer", row=1, col=1, title_font=dict(size=18), 
                    tickmode='linear', tick0=0, dtick=1, showticklabels=True)
    fig.update_yaxes(title_text="Layer", row=1, col=2, title_font=dict(size=18), 
                    tickmode='linear', tick0=0, dtick=1, showticklabels=True)
    
    return fig

# Update function for all widget changes
def update_plot(*args):
    with output:
        output.clear_output(wait=True)
        current_model = model_dropdown.value
        current_rf = rf_dropdown.value
        current_sl = sl_dropdown.value
        
        fig = create_figure(current_rf, current_sl, current_model)
        fig.show()

# Register the update function with all dropdowns
model_dropdown.observe(update_plot, names='value')
rf_dropdown.observe(update_plot, names='value')
sl_dropdown.observe(update_plot, names='value')

# Create control panel
controls = widgets.HBox([model_dropdown, rf_dropdown, sl_dropdown])

# Display widgets and initial plot
display(controls)
display(output)

# Show initial plot
with output:
    initial_fig = create_figure(rf_dropdown.value, sl_dropdown.value, model_dropdown.value)
    initial_fig.show()

# Function to save current figure
def save_current_figure():
    current_model = model_dropdown.value
    current_rf = rf_dropdown.value
    current_sl = sl_dropdown.value
    
    with output:
        config_key = f"rf{current_rf}_sl{current_sl}"
        
        if config_key not in all_results:
            print(f"No data available for configuration: {config_key}")
            return
        
        all_center_scores = all_results[config_key]["center_scores"]
        all_right_scores = all_results[config_key]["right_scores"]
        
        if current_model not in all_center_scores or current_model not in all_right_scores:
            print(f"No data available for model {current_model} with configuration {config_key}")
            return
        
        mosaic_center = all_center_scores[current_model]
        mosaic_right = all_right_scores[current_model]
        
        fig = make_subplots(rows=1, cols=2, 
                          subplot_titles=("Current token", "Token to right of current"),
                          shared_yaxes=True)

        fig.add_trace(
            go.Heatmap(z=mosaic_center, zmin=0, zmax=1, coloraxis="coloraxis"),
            row=1, col=1
        )
        fig.add_trace(
            go.Heatmap(z=mosaic_right, zmin=0, zmax=1, coloraxis="coloraxis"),
            row=1, col=2
        )

        fig.update_layout(
            title_text=f"Attention Mosaics For Induction on RRTs - RF: {current_rf}, SL: {current_sl}, Model: {current_model.split('/')[-1]}",
            height=500,
            width=1000,
            coloraxis=dict(cmin=0, cmax=1, colorbar=dict(title="Attention Score"))
        )

        fig.update_xaxes(title_text="Head", row=1, col=1, title_font=dict(size=18), 
                       tickmode='linear', tick0=0, dtick=1, showticklabels=True)
        fig.update_xaxes(title_text="Head", row=1, col=2, title_font=dict(size=18), 
                       tickmode='linear', tick0=0, dtick=1, showticklabels=True)
        fig.update_yaxes(title_text="Layer", row=1, col=1, title_font=dict(size=18), 
                       tickmode='linear', tick0=0, dtick=1, showticklabels=True)
        fig.update_yaxes(title_text="Layer", row=1, col=2, title_font=dict(size=18), 
                       tickmode='linear', tick0=0, dtick=1, showticklabels=True)
        
        # manually convert any numpy arrays to lists
        for trace in fig.data:
            if isinstance(trace.z, np.ndarray):
                trace.z = trace.z.tolist()
            if hasattr(trace, "x") and isinstance(trace.x, np.ndarray):
                trace.x = trace.x.tolist()
            if hasattr(trace, "y") and isinstance(trace.y, np.ndarray):
                trace.y = trace.y.tolist()
        
        os.makedirs("plots/json", exist_ok=True)
        output_png_filename = f"plots/rf{current_rf}_sl{current_sl}_{current_model.split('/')[-1]}"
        output_json_filename = f"plots/json/rf{current_rf}_sl{current_sl}_{current_model.split('/')[-1]}"
        
        fig.write_image(f"{output_png_filename}.png")
        print(f"Saved figure to {output_png_filename}.png")

# Create save button
save_button = widgets.Button(
    description='Save Current Figure',
    button_style='success'
)
save_button.on_click(lambda b: save_current_figure())
display(save_button)

HBox(children=(Dropdown(description='Model:', options=('amazon/chronos-t5-mini', 'amazon/chronos-t5-small', 'a…

Output()

Button(button_style='success', description='Save Current Figure', style=ButtonStyle())