In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
%cd drive/MyDrive/NLP-Project/chronos-forecasting

/content/drive/MyDrive/NLP-Project/chronos-forecasting


In [3]:
import os, sys
sys.path.insert(0, os.path.abspath("/content/drive/MyDrive/NLP-Project/chronos-forecasting/src"))

In [16]:
import math
import sys
import torch
# /content/drive/MyDrive/NLP-Project/chronos-forecasting/
def _import_chronos_bits():
    """
    Adjust these imports if your package/module path is different.
    """
    try:
        from chronos.chronos2.config import Chronos2CoreConfig
        from chronos.chronos2.layers import TimeSelfAttention
        return Chronos2CoreConfig, TimeSelfAttention
    except Exception:
        # fallback: if running from repo root with local package name
        from chronos.chronos2.config import Chronos2CoreConfig
        from chronos.chronos2.layers import TimeSelfAttention
        return Chronos2CoreConfig, TimeSelfAttention

Chronos2CoreConfig, TimeSelfAttention = _import_chronos_bits()

def make_pattern_full_mask(
    pad_mask_2d: torch.Tensor,  # [B, S], 1=keep, 0=pad
    num_heads: int,
    num_output_patches: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """
    Build additive 4D mask [B, H, Q, K] that matches the sparse semantics:
      - context queries (q < future_start): cannot attend to future keys (k >= future_start)
      - future queries (q >= future_start): can attend to all keys
      - padding keys masked everywhere
    """
    B, S = pad_mask_2d.shape
    finfo_min = torch.finfo(dtype).min

    future_start = S - num_output_patches

    # base key padding mask: [B, 1, 1, S]
    keep = pad_mask_2d.to(torch.bool)
    base = torch.zeros((B, 1, 1, S), device=pad_mask_2d.device, dtype=dtype)
    base = base.masked_fill(~keep[:, None, None, :], finfo_min)

    # pattern mask: [1, 1, Q, K]
    pattern = torch.zeros((1, 1, S, S), device=pad_mask_2d.device, dtype=dtype)
    if future_start > 0:
        pattern[:, :, :future_start, future_start:] = finfo_min

    # expand to heads
    mask = (base + pattern).expand(B, num_heads, S, S).contiguous()
    return mask

@torch.no_grad()
def run_all():
    torch.manual_seed(0)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float32  # keep float32 for strict comparisons

    # Small-ish config for fast testing
    d_model = 64
    d_kv = 16
    n_heads = 4

    cfg = Chronos2CoreConfig(
        d_model=d_model,
        d_kv=d_kv,
        num_heads=n_heads,
        num_layers=1,
        dropout_rate=0.0,
        attn_implementation="sdpa",
    )

    B = 2
    S = 64
    num_output_patches = 8
    future_start = S - num_output_patches

    # Place REG just before future (matches "context + REG + future")
    reg_token_index = future_start - 1

    # Inputs
    x = torch.randn(B, S, d_model, device=device, dtype=dtype)
    pos = torch.arange(S, device=device).unsqueeze(0).repeat(B, 1)
    pad = torch.ones(B, S, device=device, dtype=dtype)
    # Add a little padding at the end
    pad[:, -3:] = 0.0

    # ---------- 1) Sparse mode: output_attentions must raise ----------
    cfg.time_attention_type = "windowed_future_global"
    cfg.time_local_radius = 4
    cfg.time_attention_chunk_size = 8
    cfg.time_reg_is_global = False

    attn = TimeSelfAttention(cfg).to(device).eval()

    try:
        _ = attn(x, attention_mask=pad, position_ids=pos,
                 num_output_patches=num_output_patches,
                 reg_token_index=reg_token_index,
                 output_attentions=True)
        raise RuntimeError("Expected sparse mode to raise with output_attentions=True, but it did not.")
    except ValueError:
        pass

    # ---------- 2) No context->future leakage ----------
    out1 = attn(x, attention_mask=pad, position_ids=pos,
               num_output_patches=num_output_patches,
               reg_token_index=reg_token_index,
               output_attentions=False).hidden_states

    x2 = x.clone()
    x2[:, future_start:, :] += torch.randn_like(x2[:, future_start:, :]) * 3.0  # perturb only future tokens

    out2 = attn(x2, attention_mask=pad, position_ids=pos,
               num_output_patches=num_output_patches,
               reg_token_index=reg_token_index,
               output_attentions=False).hidden_states

    ctx_diff = (out1[:, :future_start, :] - out2[:, :future_start, :]).abs().max().item()
    fut_diff = (out1[:, future_start:, :] - out2[:, future_start:, :]).abs().max().item()
    assert ctx_diff < 1e-5, f"Context changed when only future tokens changed (leakage). max diff={ctx_diff}"
    assert fut_diff > 1e-4, f"Future did not change when future tokens changed (unexpected). max diff={fut_diff}"

    # ---------- 3) Window locality for context queries ----------
    cfg.time_local_radius = 2
    cfg.time_attention_chunk_size = 7
    attn = TimeSelfAttention(cfg).to(device).eval()

    outA = attn(x, attention_mask=pad, position_ids=pos,
                num_output_patches=num_output_patches,
                reg_token_index=reg_token_index).hidden_states

    i = 10
    j_far = i + 10  # outside radius=2
    x3 = x.clone()
    x3[:, j_far, :] += torch.randn_like(x3[:, j_far, :]) * 5.0

    outB = attn(x3, attention_mask=pad, position_ids=pos,
                num_output_patches=num_output_patches,
                reg_token_index=reg_token_index).hidden_states

    local_diff = (outA[:, i, :] - outB[:, i, :]).abs().max().item()
    assert local_diff < 1e-5, f"Locality broken: position {i} changed due to far token {j_far}. diff={local_diff}"

    # ---------- 4) Future queries are global ----------
    # Change the earliest context token, future outputs should change.
    x4 = x.clone()
    x4[:, 0, :] += torch.randn_like(x4[:, 0, :]) * 5.0

    outC = attn(x4, attention_mask=pad, position_ids=pos,
                num_output_patches=num_output_patches,
                reg_token_index=reg_token_index).hidden_states

    fut_global_diff = (outA[:, future_start:, :] - outC[:, future_start:, :]).abs().max().item()
    assert fut_global_diff > 1e-4, f"Future outputs did not react to far context change (not global?). diff={fut_global_diff}"

    # ---------- 5) Padding respected ----------
    # Modify padded tokens only; earlier outputs should stay unchanged.
    x5 = x.clone()
    x5[:, -3:, :] += torch.randn_like(x5[:, -3:, :]) * 10.0  # these are padded keys
    outD = attn(x5, attention_mask=pad, position_ids=pos,
                num_output_patches=num_output_patches,
                reg_token_index=reg_token_index).hidden_states

    pad_respect_diff = (outA[:, :-3, :] - outD[:, :-3, :]).abs().max().item()
    assert pad_respect_diff < 1e-5, f"Padding not respected: non-pad outputs changed. diff={pad_respect_diff}"

    # ---------- 6) Chunk-size invariance ----------
    cfg.time_attention_chunk_size = 1
    attn_cs = TimeSelfAttention(cfg).to(device).eval()
    o1 = attn_cs(
        x, attention_mask=pad, position_ids=pos,
        num_output_patches=num_output_patches,
        reg_token_index=reg_token_index
    ).hidden_states

    # change chunk size on the same module/config
    attn_cs.config.time_attention_chunk_size = 16
    o2 = attn_cs(
        x, attention_mask=pad, position_ids=pos,
        num_output_patches=num_output_patches,
        reg_token_index=reg_token_index
    ).hidden_states

    chunk_diff = (o1 - o2).abs().max().item()
    assert chunk_diff < 1e-5, f"Chunk size changed outputs too much. diff={chunk_diff}"


    # ---------- 7) Optional dense equivalence with large radius ----------
    # Build a full 4D mask with the same semantics and compare.
    cfg_full = Chronos2CoreConfig(
        d_model=d_model, d_kv=d_kv, num_heads=n_heads, num_layers=1,
        dropout_rate=0.0, attn_implementation="sdpa",
    )
    cfg_full.time_attention_type = "full"

    cfg_sparse = Chronos2CoreConfig(
        d_model=d_model, d_kv=d_kv, num_heads=n_heads, num_layers=1,
        dropout_rate=0.0, attn_implementation="sdpa",
    )
    cfg_sparse.time_attention_type = "windowed_future_global"
    cfg_sparse.time_local_radius = S  # big enough to include all context keys
    cfg_sparse.time_attention_chunk_size = 8
    cfg_sparse.time_reg_is_global = False

    dense = TimeSelfAttention(cfg_full).to(device).eval()
    sparse = TimeSelfAttention(cfg_sparse).to(device).eval()
    sparse.load_state_dict(dense.state_dict())

    full_mask = make_pattern_full_mask(pad, n_heads, num_output_patches, dtype=dtype)
    od = dense(x, attention_mask=full_mask, position_ids=pos,
               num_output_patches=num_output_patches,
               reg_token_index=reg_token_index).hidden_states
    os = sparse(x, attention_mask=pad, position_ids=pos,
                num_output_patches=num_output_patches,
                reg_token_index=reg_token_index).hidden_states

    eq_diff = (od - os).abs().max().item()
    assert eq_diff < 1e-4, f"Dense vs sparse (large radius) mismatch. diff={eq_diff}"

    print("✅ All sparse time-attention sanity tests passed.")

def grad_test():
    # quick backward test
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float32
    torch.manual_seed(0)

    cfg = Chronos2CoreConfig(d_model=64, d_kv=16, num_heads=4, num_layers=1, dropout_rate=0.0, attn_implementation="sdpa")
    cfg.time_attention_type = "windowed_future_global"
    cfg.time_local_radius = 2
    cfg.time_attention_chunk_size = 8
    cfg.time_reg_is_global = False

    attn = TimeSelfAttention(cfg).to(device).train()

    B, S = 2, 64
    num_output_patches = 8
    x = torch.randn(B, S, 64, device=device, dtype=dtype, requires_grad=True)
    pos = torch.arange(S, device=device).unsqueeze(0).repeat(B, 1)
    pad = torch.ones(B, S, device=device, dtype=dtype)

    out = attn(x, attention_mask=pad, position_ids=pos,
              num_output_patches=num_output_patches,
              reg_token_index=(S - num_output_patches - 1)).hidden_states
    loss = out.mean()
    loss.backward()

    # Ensure at least one parameter has gradients
    grads = [p.grad for p in attn.parameters() if p.requires_grad]
    assert any(g is not None and torch.isfinite(g).all() for g in grads), "No valid gradients found."
    print("✅ Backward/grad test passed.")

if __name__ == "__main__":
    run_all()
    grad_test()


✅ All sparse time-attention sanity tests passed.
✅ Backward/grad test passed.


In [18]:
import math
import sys
import torch
# /content/drive/MyDrive/NLP-Project/chronos-forecasting/
def _import_chronos_bits():
    """
    Adjust these imports if your package/module path is different.
    """
    try:
        from chronos.chronos2.config import Chronos2CoreConfig
        from chronos.chronos2.layers import TimeSelfAttention
        return Chronos2CoreConfig, TimeSelfAttention
    except Exception:
        # fallback: if running from repo root with local package name
        from chronos.chronos2.config import Chronos2CoreConfig
        from chronos.chronos2.layers import TimeSelfAttention
        return Chronos2CoreConfig, TimeSelfAttention

Chronos2CoreConfig, TimeSelfAttention = _import_chronos_bits()

In [7]:
import torch

def _import_chronos():
    """
    Adjust these imports if your paths differ.
    The fallback imports assume you're running from the module directory.
    """
    try:
        from chronos.chronos2.model import Chronos2Encoder
        from chronos.chronos2.config import Chronos2CoreConfig
        return Chronos2Encoder, Chronos2CoreConfig
    except Exception:
        from model import Chronos2Encoder
        from config import Chronos2CoreConfig
        return Chronos2Encoder, Chronos2CoreConfig


Chronos2Encoder, Chronos2CoreConfig = _import_chronos()


@torch.no_grad()
def test_sparse_encoder_does_not_build_4d_mask():
    """
    Ensures Chronos2Encoder.forward() does NOT call the dense mask builder
    when time_attention_type == "windowed_future_global".
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float32

    cfg = Chronos2CoreConfig(
        d_model=64,
        d_kv=16,
        num_heads=4,
        num_layers=2,
        dropout_rate=0.0,
        attn_implementation="sdpa",
    )
    cfg.time_attention_type = "windowed_future_global"
    cfg.time_local_radius = 4
    cfg.time_attention_chunk_size = 8
    cfg.time_reg_is_global = False

    enc = Chronos2Encoder(cfg).to(device).eval()

    # Monkeypatch the dense mask builder: if it's called, we fail.
    called = {"flag": False}
    orig = enc._expand_and_invert_time_attention_mask

    def _trap(*args, **kwargs):
        called["flag"] = True
        raise AssertionError(
            "_expand_and_invert_time_attention_mask was called in sparse mode "
            "(should not happen)."
        )

    enc._expand_and_invert_time_attention_mask = _trap  # patch

    B, S = 2, 64
    H = 8  # num_output_patches
    x = torch.randn(B, S, cfg.d_model, device=device, dtype=dtype)
    pad = torch.ones(B, S, device=device, dtype=dtype)  # 2D padding mask
    pos = torch.arange(S, device=device).unsqueeze(0).repeat(B, 1)
    group_ids = torch.arange(B, device=device, dtype=torch.long)

    # group_time_mask shape depends on your implementation; in your encoder it’s used by GroupSelfAttention
    # This is the common shape used in your repo: [B, S, S] boolean-ish mask.
    group_time_mask = torch.ones(B, S, S, device=device, dtype=torch.bool)

    _ = enc(
        inputs_embeds=x,
        group_ids=group_ids,
        attention_mask=pad,  # MUST be 2D in sparse mode
        position_ids=pos,
        num_output_patches=H,
        reg_token_index=None,
        output_attentions=False,
    )

    assert not called["flag"], "Dense mask builder was called in sparse mode."
    enc._expand_and_invert_time_attention_mask = orig  # restore (optional)
    print("✅ Test 1 passed: sparse encoder did NOT build a 4D time mask.")


@torch.no_grad()
def test_reg_global_no_context_to_future_leak():
    """
    With REG enabled and time_reg_is_global=True, verify:
      - modifying ONLY future tokens does not change context outputs (including REG)
    This tests that context queries still cannot attend to future keys.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float32

    cfg = Chronos2CoreConfig(
        d_model=64,
        d_kv=16,
        num_heads=4,
        num_layers=2,
        dropout_rate=0.0,
        attn_implementation="sdpa",
        use_reg_token=True,  # IMPORTANT
    )
    cfg.time_attention_type = "windowed_future_global"
    cfg.time_local_radius = 4
    cfg.time_attention_chunk_size = 8
    cfg.time_reg_is_global = True  # IMPORTANT

    enc = Chronos2Encoder(cfg).to(device).eval()

    B = 2
    S_ctx = 48
    H_fut = 8
    # Sequence layout: [context tokens] + [REG] + [future tokens]
    S = S_ctx + 1 + H_fut
    reg_token_index = S_ctx
    future_start = reg_token_index + 1

    x = torch.randn(B, S, cfg.d_model, device=device, dtype=dtype)
    pad = torch.ones(B, S, device=device, dtype=dtype)
    pos = torch.arange(S, device=device).unsqueeze(0).repeat(B, 1)
    group_ids = torch.arange(B, device=device, dtype=torch.long)
    group_time_mask = torch.ones(B, S, S, device=device, dtype=torch.bool)

    out1 = enc(
        inputs_embeds=x,
        group_ids=group_ids,
        attention_mask=pad,
        position_ids=pos,
        num_output_patches=H_fut,
        reg_token_index=reg_token_index,
        output_attentions=False,
    ).last_hidden_state  # adjust field name if yours differs

    x2 = x.clone()
    x2[:, future_start:, :] += torch.randn_like(x2[:, future_start:, :]) * 3.0  # perturb ONLY future tokens

    out2 = enc(
        inputs_embeds=x2,
        group_ids=group_ids,
        attention_mask=pad,
        position_ids=pos,
        num_output_patches=H_fut,
        reg_token_index=reg_token_index,
        output_attentions=False,
    ).last_hidden_state

    # Context includes [0 .. reg_token_index] (context + REG)
    ctx1 = out1[:, :future_start, :]
    ctx2 = out2[:, :future_start, :]
    diff = (ctx1 - ctx2).abs().max().item()

    assert diff < 1e-5, f"Context/REG changed when only future tokens changed (leak). diff={diff}"
    print("✅ Test 2 passed: REG-global still has no context→future leakage.")


if __name__ == "__main__":
    test_sparse_encoder_does_not_build_4d_mask()
    test_reg_global_no_context_to_future_leak()


✅ Test 1 passed: sparse encoder did NOT build a 4D time mask.
✅ Test 2 passed: REG-global still has no context→future leakage.


In [14]:
import argparse
import time
import torch


def _import_chronos():
    """
    Adjust these imports if your paths differ.
    Fallback assumes running from the module folder.
    """
    try:
        from chronos.chronos2.model import Chronos2Encoder
        from chronos.chronos2.config import Chronos2CoreConfig
        return Chronos2Encoder, Chronos2CoreConfig
    except Exception:
        from chronos.chronos2.model import Chronos2Encoder
        from chronos.chronos2.config import Chronos2CoreConfig
        return Chronos2Encoder, Chronos2CoreConfig


Chronos2Encoder, Chronos2CoreConfig = _import_chronos()


def _cuda_mem_str(x: int) -> str:
    return f"{x/1024**2:.1f} MiB"


@torch.no_grad()
def long_context_forward_benchmark(seq_len: int, num_output_patches: int, radius: int, chunk: int):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32

    cfg = Chronos2CoreConfig(
        d_model=512,          # closer to realistic
        d_kv=64,
        num_heads=8,
        num_layers=4,         # keep modest for sanity test
        dropout_rate=0.0,
        attn_implementation="sdpa",
        use_reg_token=False,
    )
    cfg.time_attention_type = "windowed_future_global"
    cfg.time_local_radius = radius
    cfg.time_attention_chunk_size = chunk
    cfg.time_reg_is_global = False

    enc = Chronos2Encoder(cfg).to(device).eval()

    B = 1
    x = torch.randn(B, seq_len, cfg.d_model, device=device, dtype=dtype)
    pad = torch.ones(B, seq_len, device=device, dtype=dtype)
    pos = torch.arange(seq_len, device=device).unsqueeze(0)
    group_ids = torch.zeros(B, device=device, dtype=torch.long)

    # Warmup
    _ = enc(
        inputs_embeds=x,
        group_ids=group_ids,
        attention_mask=pad,
        position_ids=pos,
        num_output_patches=num_output_patches,
        reg_token_index=None,
        output_attentions=False,
    )

    if device == "cuda":
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()

    t0 = time.perf_counter()
    out = enc(
        inputs_embeds=x,
        group_ids=group_ids,
        attention_mask=pad,
        position_ids=pos,
        num_output_patches=num_output_patches,
        reg_token_index=None,
        output_attentions=False,
    )
    if device == "cuda":
        torch.cuda.synchronize()
    t1 = time.perf_counter()

    # Try to access output tensor robustly
    y = getattr(out, "last_hidden_state", None)
    if y is None:
        y = getattr(out, "hidden_states", None)
    if y is None:
        # fallback: common naming in some codebases
        y = getattr(out, "final_hidden_state", None)
    if y is None:
        raise RuntimeError("Could not find encoder output tensor on Chronos2EncoderOutput")

    peak = None
    if device == "cuda":
        peak = torch.cuda.max_memory_allocated()

    print(f"✅ Forward OK on device={device}, dtype={dtype}")
    print(f"   seq_len={seq_len}, num_output_patches={num_output_patches}, radius={radius}, chunk={chunk}")
    print(f"   output shape: {tuple(y.shape)}")
    print(f"   time: {(t1 - t0):.3f}s")
    if peak is not None:
        print(f"   peak CUDA allocated: {_cuda_mem_str(peak)}")


def long_context_backward_benchmark(seq_len: int, num_output_patches: int, radius: int, chunk: int):
    """
    Backward test is the real memory killer; run it once at batch=1.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32

    cfg = Chronos2CoreConfig(
        d_model=512,
        d_kv=64,
        num_heads=8,
        num_layers=4,
        dropout_rate=0.0,
        attn_implementation="sdpa",
        use_reg_token=False,
    )
    cfg.time_attention_type = "windowed_future_global"
    cfg.time_local_radius = radius
    cfg.time_attention_chunk_size = chunk
    cfg.time_reg_is_global = False

    enc = Chronos2Encoder(cfg).to(device).train()

    B = 1
    x = torch.randn(B, seq_len, cfg.d_model, device=device, dtype=dtype, requires_grad=True)
    pad = torch.ones(B, seq_len, device=device, dtype=dtype)
    pos = torch.arange(seq_len, device=device).unsqueeze(0)
    group_ids = torch.zeros(B, device=device, dtype=torch.long)

    if device == "cuda":
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()

    t0 = time.perf_counter()
    out = enc(
        inputs_embeds=x,
        group_ids=group_ids,
        attention_mask=pad,
        position_ids=pos,
        num_output_patches=num_output_patches,
        reg_token_index=None,
        output_attentions=False,
    )
    y = getattr(out, "last_hidden_state", None) or getattr(out, "hidden_states", None) or getattr(out, "final_hidden_state", None)
    if y is None:
        raise RuntimeError("Could not find encoder output tensor on Chronos2EncoderOutput")

    loss = y.float().mean()
    loss.backward()

    if device == "cuda":
        torch.cuda.synchronize()
    t1 = time.perf_counter()

    peak = None
    if device == "cuda":
        peak = torch.cuda.max_memory_allocated()

    print(f"✅ Backward OK on device={device}, dtype={dtype}")
    print(f"   seq_len={seq_len}, num_output_patches={num_output_patches}, radius={radius}, chunk={chunk}")
    print(f"   time (fwd+bwd): {(t1 - t0):.3f}s")
    if peak is not None:
        print(f"   peak CUDA allocated: {_cuda_mem_str(peak)}")


def config_propagation_sanity(training_step_fn, model, expected: str = "windowed_future_global", steps: int = 3):
    """
    Generic config propagation check: call this inside any training loop.
    It asserts the config doesn't silently revert to full.
    """
    for step in range(steps):
        tat = None
        # Common locations depending on how you store it
        if hasattr(model, "chronos_config") and hasattr(model.chronos_config, "time_attention_type"):
            tat = model.chronos_config.time_attention_type
        elif hasattr(model, "config") and hasattr(model.config, "chronos_config"):
            # HF-style dict storage
            cc = model.config.chronos_config
            if isinstance(cc, dict):
                tat = cc.get("time_attention_type", None)
            else:
                tat = getattr(cc, "time_attention_type", None)

        print(f"[step {step}] time_attention_type = {tat}")
        assert tat == expected, f"time_attention_type changed/reverted (got {tat}, expected {expected})"

        training_step_fn(step)


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--seq", type=int, default=8192, help="Sequence length (e.g., 8192 or 16384)")
    ap.add_argument("--out", type=int, default=64, help="num_output_patches (future tokens)")
    ap.add_argument("--radius", type=int, default=128, help="time_local_radius")
    ap.add_argument("--chunk", type=int, default=32, help="time_attention_chunk_size")
    ap.add_argument("--no-backward", action="store_true", help="Skip backward test")
    args = ap.parse_args()

    long_context_forward_benchmark(args.seq, args.out, args.radius, args.chunk)
    if not args.no_backward:
        long_context_backward_benchmark(args.seq, args.out, args.radius, args.chunk)

    print("\nConfig propagation sanity: integrate this into your training loop.")
    print("Example usage is printed below.\n")

    print(
        "Example:\n"
        "  # inside your training script\n"
        "  def one_step(step):\n"
        "      ...  # run one optimizer step\n"
        "  config_propagation_sanity(one_step, model, expected='windowed_future_global', steps=3)\n"
    )


if __name__ == "__main__":
    main()


usage: colab_kernel_launcher.py [-h] [--seq SEQ] [--out OUT] [--radius RADIUS]
                                [--chunk CHUNK] [--no-backward]
colab_kernel_launcher.py: error: unrecognized arguments: -f /root/.local/share/jupyter/runtime/kernel-a1294821-efcf-45f8-b2e7-006d0031e89d.json


SystemExit: 2

In [None]:
import inspect, chronos.chronos2.layers as L
print("torch.gather(k_ctx) in live code?",
      "torch.gather(k_ctx" in inspect.getsource(L.TimeSelfAttention._windowed_future_global_attention))


In [None]:
import importlib
import chronos.chronos2.layers as L
importlib.reload(L)

# then re-import the class you test
from chronos.chronos2.layers import TimeSelfAttention

In [4]:
!python test.py --seq 8192 --out 64

2025-12-15 10:15:44.381607: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1765793744.401036    1670 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1765793744.406919    1670 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1765793744.421897    1670 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1765793744.421923    1670 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1765793744.421927    1670 computation_placer.cc:177] computation placer alr