In [1]:
import pandas as pd  # requires: pip install pandas
import torch
from chronos import BaseChronosPipeline
import plotly.express as px
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


[2025-05-09 12:59:20,048] [INFO] [real_accelerator.py:239:get_accelerator] Setting ds_accelerator to cpu (auto detect)


In [2]:
pipeline = BaseChronosPipeline.from_pretrained(
    "amazon/chronos-t5-small",  # use "amazon/chronos-bolt-small" for the corresponding Chronos-Bolt model
    device_map="cpu",  # use "cpu" for CPU inference
    torch_dtype=torch.bfloat16,
)

layer = pipeline.model
layer_name = 'model.decoder.block.0.layer.0'
for name in layer_name.split('.'):
    # print("processing: ", name)
    layer = getattr(layer, name)
layer.SelfAttention.q

Linear(in_features=512, out_features=512, bias=False)

In [3]:
pipeline.model.get_submodule('model.decoder.block.0.layer.0.SelfAttention.q')

Linear(in_features=512, out_features=512, bias=False)

In [23]:
pipeline.model.model.decoder.block[0].layer[0].SelfAttention.n_heads

8

In [35]:
pipeline.model.get_submodule('model.decoder.block.0.layer.1.EncDecAttention')

T5Attention(
  (q): Linear(in_features=512, out_features=512, bias=False)
  (k): Linear(in_features=512, out_features=512, bias=False)
  (v): Linear(in_features=512, out_features=512, bias=False)
  (o): Linear(in_features=512, out_features=512, bias=False)
)

In [4]:
# get activations of layer 'model.encoder.embed_tokens' for the data frame input's token_ids
df = pd.read_csv("https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv")
context = torch.tensor(df["#Passengers"])
context_tensor = pipeline._prepare_and_validate_context(context=context) # does nothing
token_ids, attention_mask, scale = pipeline.tokenizer.context_input_transform(context_tensor)

model = pipeline.model
model.eval();

In [5]:
# Function to get activations of layers for some input on a model
def get_output_activations(layers, input, model, attention_mask=attention_mask, prediction_length=1):
    activations = []
    handles = []

    def hook_fn(module, input, output):
        activations.append(output)

    for layer in layers:
        handle = layer.register_forward_hook(hook_fn)
        handles.append(handle)
    outputs = model(input_ids=input, attention_mask=attention_mask, prediction_length=4)
    for handle in handles:
        handle.remove()
    
    return outputs, activations

def get_input_activations(layers, input, model, attention_mask=attention_mask, prediction_length=1):
    activations = []
    handles = []

    def hook_fn(module, input, output):
        activations.append(input)

    for layer in layers:
        handle = layer.register_forward_hook(hook_fn)
        handles.append(handle)

    # Run the model
    outputs = model(input_ids=input, attention_mask=attention_mask, prediction_length=prediction_length)
    for handle in handles:
        handle.remove()
    return outputs, activations

In [6]:
# get activations of layer 'model.encoder.embed_tokens' for the data frame input's token_ids
# outputs,activations = get_output_activations([model.model.encoder.embed_tokens], token_ids, model, attention_mask=attention_mask, prediction_length=1)

# get input activations of layer 'model.decoder.block' for the data frame input's token_ids
outputs, activations = get_input_activations([pipeline.model.get_submodule('model.decoder.block.0.layer.0.SelfAttention.q')], token_ids, model, attention_mask=attention_mask, prediction_length=4)

In [10]:
token_ids.shape

torch.Size([1, 145])

In [11]:
q_layer = model.get_submodule('model.decoder.block.0.layer.0.SelfAttention.q')

# 3) Prepare a place to record each decoder-input shape
shapes = []
def pre_hook(module, inputs):
    # inputs[0] is the tensor of token IDs fed into embed_tokens
    shapes.append(inputs)

# Register the hook on the decoder’s embed_tokens
# handle = q_layer.register_forward_pre_hook(pre_hook)
handle = model.model.decoder.embed_tokens.register_forward_pre_hook(pre_hook)

# 4) Generate with caching **disabled** and just **1** sequence path
#    (so we see the full prefix each time)
_ = model.model.generate(
    input_ids=token_ids,
    attention_mask=attention_mask,
    max_new_tokens=5,           # e.g. generate 5 steps
    num_return_sequences=1,     # just one path
    do_sample=False,            # deterministic for clarity
    use_cache=False,            # IMPORTANT: force full-prefix decoding
)

# 5) Remove the hook
handle.remove()

# 6) Inspect what we saw
print("Decoder‑input shapes at each step:")
for step, shape in enumerate(shapes):
    print(f"  step {step:>2}: {shape}")

Decoder‑input shapes at each step:
  step  0: (tensor([[2104, 2106, 2113, 2112, 2108, 2115, 2121, 2121, 2115, 2107, 2100, 2106,
         2105, 2110, 2118, 2115, 2110, 2122, 2132, 2132, 2126, 2114, 2104, 2117,
         2120, 2122, 2136, 2128, 2133, 2136, 2146, 2146, 2139, 2128, 2120, 2130,
         2132, 2137, 2143, 2137, 2138, 2155, 2161, 2167, 2151, 2142, 2133, 2143,
         2144, 2144, 2164, 2163, 2160, 2167, 2177, 2181, 2164, 2152, 2137, 2147,
         2148, 2140, 2163, 2159, 2163, 2177, 2196, 2192, 2175, 2160, 2148, 2160,
         2167, 2162, 2179, 2180, 2180, 2202, 2226, 2218, 2201, 2182, 2164, 2184,
         2187, 2184, 2203, 2201, 2204, 2231, 2250, 2246, 2222, 2198, 2181, 2198,
         2202, 2195, 2222, 2218, 2222, 2254, 2275, 2276, 2246, 2218, 2197, 2213,
         2214, 2204, 2225, 2218, 2226, 2261, 2288, 2295, 2246, 2224, 2200, 2213,
         2224, 2215, 2247, 2242, 2253, 2279, 2316, 2321, 2274, 2247, 2225, 2246,
         2252, 2239, 2253, 2273, 2279, 2309, 2352, 2344, 2296,

In [36]:
encoder_outputs = model.model.encoder(token_ids, attention_mask)
encoder_outputs.last_hidden_state.shape

torch.Size([1, 145, 512])

In [9]:
decoder_input_ids = torch.zeros(size=(20,1), dtype=torch.int)
enc_h = encoder_outputs.last_hidden_state.expand(20, -1, -1).contiguous()
decoder_outputs = model.model.decoder(input_ids=decoder_input_ids, encoder_hidden_states=enc_h,
                                      encoder_attention_mask=attention_mask, past_key_values=None, use_cache=False, return_dict=True)

In [10]:
lm_logits = model.model.lm_head(decoder_outputs.last_hidden_state)
# get the logits for the next token
next_token_logits = lm_logits[:, -1, :]
next_token_logits[0], next_token_logits[19]

(tensor([ -80.5000,  -16.7500, -246.0000,  ..., -184.0000, -200.0000,
           45.7500], dtype=torch.bfloat16, grad_fn=<SelectBackward0>),
 tensor([ -80.5000,  -16.7500, -246.0000,  ..., -184.0000, -200.0000,
           45.7500], dtype=torch.bfloat16, grad_fn=<SelectBackward0>))

In [11]:
lm_logits.shape

torch.Size([20, 1, 4096])

In [16]:
# When the model is given some input, write a function which will return the input token_ids that have been fed into the embedding layer
def get_input_token_ids(model, input, attention_mask):
    activations = []
    handles = []

    def hook_fn(module, input, output):
        activations.append(input)

    handle = model.model.shared.register_forward_hook(hook_fn)
    handles.append(handle)

    outputs = model(input_ids=input, attention_mask=attention_mask, prediction_length=4)
    for handle in handles:
        handle.remove()
    
    return activations

In [17]:
get_input_token_ids(model, token_ids, attention_mask)

[(tensor([[2104, 2106, 2113, 2112, 2108, 2115, 2121, 2121, 2115, 2107, 2100, 2106,
           2105, 2110, 2118, 2115, 2110, 2122, 2132, 2132, 2126, 2114, 2104, 2117,
           2120, 2122, 2136, 2128, 2133, 2136, 2146, 2146, 2139, 2128, 2120, 2130,
           2132, 2137, 2143, 2137, 2138, 2155, 2161, 2167, 2151, 2142, 2133, 2143,
           2144, 2144, 2164, 2163, 2160, 2167, 2177, 2181, 2164, 2152, 2137, 2147,
           2148, 2140, 2163, 2159, 2163, 2177, 2196, 2192, 2175, 2160, 2148, 2160,
           2167, 2162, 2179, 2180, 2180, 2202, 2226, 2218, 2201, 2182, 2164, 2184,
           2187, 2184, 2203, 2201, 2204, 2231, 2250, 2246, 2222, 2198, 2181, 2198,
           2202, 2195, 2222, 2218, 2222, 2254, 2275, 2276, 2246, 2218, 2197, 2213,
           2214, 2204, 2225, 2218, 2226, 2261, 2288, 2295, 2246, 2224, 2200, 2213,
           2224, 2215, 2247, 2242, 2253, 2279, 2316, 2321, 2274, 2247, 2225, 2246,
           2252, 2239, 2253, 2273, 2279, 2309, 2352, 2344, 2296, 2273, 2239, 2259,
    