In [1]:
import json
import os
import random

import numpy as np
import pandas as pd
import torch
from chronos import BaseChronosPipeline

import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots

import rrt_utils as rrt
import attn_lens as attn

[2025-05-21 20:13:32,209] [INFO] [real_accelerator.py:239:get_accelerator] Setting ds_accelerator to cpu (auto detect)


In [2]:
# Build a batch of repeated random token (RRT) sequences for probing induction:
#   [seq_0 .. seq_{k-1}] * repeat_factor + seq_0 .. seq_{extension-1} + seq_extension[:sub_extension]
# and feed the decoder the token that comes *next* in seq_extension, so the next cell
# can check where the model's cross-attention goes for that token.

# Seed every RNG the (opaque) rrt helpers might draw from, so Restart & Run All
# reproduces the same batch. NOTE(review): confirm which RNG
# rrt.generate_random_token_ids actually uses.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Vocabulary = every token id except the special ids 0 and 1 (presumably PAD / EOS).
# NOTE(review): 4096 is assumed to be the Chronos-T5 vocab size; it could be read
# from the model config instead.
VOCAB_SIZE = 4096
vocab = torch.arange(2, VOCAB_SIZE)

batch_size = 100
num_unique_sequences = 2 # number of unique sequences in the batch
repeat_factor = 2 # number of times to repeat the sequences
extension = 0 # after repeating everything, repeat this many sequences again
sub_extension = 2 # some additional tokens in the next sequence
sequence_length = 6 # length of the sequences

# tokens[extension] must exist, and it must still have a token after the partial copy
# for the decoder to receive; otherwise decoder_input_ids silently loses its probe token.
assert repeat_factor >= 1, "repeat_factor must be at least 1"
assert 0 <= extension < num_unique_sequences, "extension must index one of the unique sequences"
assert 0 <= sub_extension < sequence_length, "sub_extension must leave a next token for the decoder"

tokens = [
    rrt.generate_random_token_ids(vocab, sequence_length, batch_size=batch_size, include_eos=False)
    for _ in range(num_unique_sequences)
]
partial_copy = tokens[extension][:, :sub_extension]  # first sub_extension tokens of seq_extension
token_ids = rrt.stack_sequences(tokens * repeat_factor + tokens[:extension] + [partial_copy])
attention_mask = torch.ones_like(token_ids, dtype=torch.bool)

# Decoder input: start token 0 followed by the token that comes next in seq_extension.
next_token = tokens[extension][:, sub_extension:sub_extension + 1]
decoder_input_ids = torch.cat([torch.zeros((batch_size, 1), dtype=torch.long), next_token], dim=1)

token_ids.shape, token_ids[0,:], decoder_input_ids.shape, decoder_input_ids[0,:]

(torch.Size([100, 27]),
 tensor([1940, 2750,  354, 2932, 2146, 1921, 3803, 2492, 2185, 1380, 2616, 1127,
         1940, 2750,  354, 2932, 2146, 1921, 3803, 2492, 2185, 1380, 2616, 1127,
         1940, 2750,    1]),
 torch.Size([100, 2]),
 tensor([  0, 354]))

In [3]:
# For each Chronos-T5 size: run the attention lens on the RRT batch from the previous
# cell and draw per-(layer, head) heatmaps of the batch-mean cross-attention from the
# decoder's probed position to
#   (a) the earlier occurrence of the current token (encoder position s_idx), and
#   (b) the token right after that occurrence (s_idx + 1), the slot an induction
#       pattern would copy from.
model_names = ["amazon/chronos-t5-mini", "amazon/chronos-t5-small", "amazon/chronos-t5-base", "amazon/chronos-t5-large"]

# Decoder position being probed: index 1 of decoder_input_ids, which holds
# tokens[extension][:, sub_extension].
t_idx = 1
# Encoder position of that token's most recent full occurrence in token_ids: the last
# full repetition of seq `extension`, offset by sub_extension.
s_idx = sequence_length * ((repeat_factor - 1) * num_unique_sequences + extension) + sub_extension
assert s_idx + 1 < token_ids.shape[1], "s_idx + 1 must fall inside the encoder input"

os.makedirs("plots/json", exist_ok=True)


def _attention_mosaic(cross_attn_probs, step, query_pos, key_pos, num_layers, num_heads):
    """Batch-mean cross-attention weight from decoder `query_pos` to encoder `key_pos`.

    Returns a nested list ``[layer][head]`` of Python floats, ready for go.Heatmap(z=...).
    Assumes ``cross_attn_probs[step][layer]`` is indexable as
    (batch, head, decoder_pos, encoder_pos) -- TODO confirm against attn_lens.
    Values are upcast to float32 before averaging, because the model is loaded in
    bfloat16 and a bf16 mean would be rounded to ~3 significant digits.
    """
    per_layer = [
        torch.as_tensor(
            cross_attn_probs[step][layer][:, :num_heads, query_pos, key_pos],
            dtype=torch.float32,
        ).mean(dim=0)
        for layer in range(num_layers)
    ]
    return torch.stack(per_layer).tolist()


def _mosaic_figure(mosaic_center, mosaic_right, title):
    """Two side-by-side layer x head heatmaps sharing one [0, 1] colour axis."""
    fig = make_subplots(rows=1, cols=2,
                        subplot_titles=("Current token", "Token to right of current"),
                        shared_yaxes=True)
    for col, z in enumerate((mosaic_center, mosaic_right), start=1):
        fig.add_trace(go.Heatmap(z=z, zmin=0, zmax=1, coloraxis="coloraxis"), row=1, col=col)
    fig.update_layout(
        title_text=title,
        height=500,
        width=1000,
        coloraxis=dict(cmin=0, cmax=1, colorbar=dict(title="Attention Score")),
    )
    # Integer ticks on every subplot axis (no row/col => applies to all x / y axes).
    axis_style = dict(title_font=dict(size=18), tickmode='linear', tick0=0, dtick=1, showticklabels=True)
    fig.update_xaxes(title_text="Head", **axis_style)
    fig.update_yaxes(title_text="Layer", **axis_style)
    return fig


def _to_builtin(obj):
    """json.dump fallback: convert numpy arrays / scalars to plain Python values."""
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


for model_name in model_names:
    print(f"Processing {model_name}...")
    short_name = model_name.split('/')[-1]

    pipeline = BaseChronosPipeline.from_pretrained(
        model_name,
        device_map="cpu",
        torch_dtype=torch.bfloat16,
    )
    # Chronos-Bolt pipelines hold the T5 module as pipeline.model; classic Chronos
    # pipelines wrap it one level deeper.
    t5_model = pipeline.model if "bolt" in model_name else pipeline.model.model

    output, cross_attn_probs = attn.attn_lens(token_ids, attention_mask, t5_model, max_new_tokens=1, decoder_input_ids=decoder_input_ids)

    layers, heads = t5_model.config.num_decoder_layers, t5_model.config.num_heads
    # NOTE(review): the first index into cross_attn_probs reuses t_idx (a decoder
    # position); confirm what that axis means in attn_lens.
    mosaic_center = _attention_mosaic(cross_attn_probs, t_idx, t_idx, s_idx, layers, heads)
    mosaic_right = _attention_mosaic(cross_attn_probs, t_idx, t_idx, s_idx + 1, layers, heads)

    fig = _mosaic_figure(mosaic_center, mosaic_right,
                         f"Attention Mosaics For Induction on RRTs for {short_name}")
    fig.show()

    output_png_filename = f"plots/{short_name}"
    output_json_filename = f"plots/json/{short_name}"
    fig.write_image(f"{output_png_filename}.png")
    # Serialize with the stdlib json module; _to_builtin turns any numpy arrays in the
    # figure into plain nested lists.
    with open(f"{output_json_filename}.json", "w") as f:
        json.dump(fig.to_plotly_json(), f, default=_to_builtin)

    print(f"Saved figure to {output_png_filename}.png")

Processing amazon/chronos-t5-mini...


Saved figure to plots/chronos-t5-mini.png
Processing amazon/chronos-t5-small...


Saved figure to plots/chronos-t5-small.png
Processing amazon/chronos-t5-base...


Saved figure to plots/chronos-t5-base.png
Processing amazon/chronos-t5-large...


Saved figure to plots/chronos-t5-large.png
