# Chronos-2 -> MASE Graph


In [None]:
import importlib.util
import shutil
import subprocess
import sys

# Importable module name -> pip distribution that provides it.
# NOTE: the `graphviz` pip package is only the Python binding; it does not
# install the Graphviz `dot` executable, which the draw step checks for
# separately via `shutil.which("dot")`.
REQUIRED_PIP_PACKAGES = {
    "chronos": "chronos-forecasting",
    "graphviz": "graphviz",
}

_installed_any = False
for module_name, package_name in REQUIRED_PIP_PACKAGES.items():
    if importlib.util.find_spec(module_name) is None:
        print(f"Installing {package_name}...")
        # Use this kernel's interpreter so the package lands in the same environment.
        subprocess.check_call([sys.executable, "-m", "pip", "install", package_name])
        _installed_any = True
    else:
        print(f"{package_name} already installed.")

# Modules installed after the interpreter started may not be noticed by the
# import system's cached finders until the caches are invalidated.
if _installed_any:
    importlib.invalidate_caches()


In [1]:
from pathlib import Path
import shutil
import sys

import torch

# If a local `src/` checkout exists next to the notebook, put it first on the
# import path so the `chop` imports that follow resolve against it.
SRC_DIR = Path.cwd() / "src"
_src_path = str(SRC_DIR)
if _src_path not in sys.path and SRC_DIR.exists():
    sys.path.insert(0, _src_path)

from chop.ir.graph import MaseGraph
from chop.passes.graph.analysis.init_metadata import init_metadata_analysis_pass


  import pynvml  # type: ignore[import]
  from .autonotebook import tqdm as notebook_tqdm


In [5]:
MODEL_ID_CHOICES = ["amazon/chronos-2"]
LOADER_CHOICES = ["chronos", "transformers"]
RUN_MODE_CHOICES = ["export_only", "export_and_draw", "draw_only"]

MODEL_ID = "amazon/chronos-2"
LOADER = "transformers"
RUN_MODE = "export_and_draw"

OUTPUT_DIR = Path("artifacts")
GRAPH_NAME = "chronos2_mase_graph"
DEVICE = "cpu"  # a falsy value skips the `.to(DEVICE)` move in the build cell
TRUST_REMOTE_CODE = False
# None auto-resolves to ["context"] only when the model was loaded through the
# `chronos` loader; with LOADER = "transformers" it is passed through as None.
HF_INPUT_NAMES = None

# Validate with explicit raises rather than `assert`, which `python -O` strips.
for _name, _value, _choices in (
    ("MODEL_ID", MODEL_ID, MODEL_ID_CHOICES),
    ("LOADER", LOADER, LOADER_CHOICES),
    ("RUN_MODE", RUN_MODE, RUN_MODE_CHOICES),
):
    if _value not in _choices:
        raise ValueError(f"{_name} must be one of {_choices}, got {_value!r}")

OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
OUTPUT_DIR


PosixPath('artifacts')

In [6]:
def _load_with_chronos(model_id: str, trust_remote_code: bool):
    import chronos  

    pipeline_cls = getattr(chronos, "Chronos2Pipeline", None) or getattr(
        chronos, "ChronosPipeline", None
    )
    if pipeline_cls is None:
        raise RuntimeError("`chronos` package does not expose Chronos2Pipeline/ChronosPipeline.")

    kwargs = {"trust_remote_code": True} if trust_remote_code else {}
    try:
        pipeline = pipeline_cls.from_pretrained(model_id, **kwargs)
    except TypeError:
        pipeline = pipeline_cls.from_pretrained(model_id)

    for attr in ("model", "module", "hf_model"):
        value = getattr(pipeline, attr, None)
        if isinstance(value, torch.nn.Module):
            return value, f"chronos::{pipeline_cls.__name__}.{attr}"

    if isinstance(pipeline, torch.nn.Module):
        return pipeline, f"chronos::{pipeline_cls.__name__}"

    raise RuntimeError("Could not extract a torch.nn.Module from the Chronos pipeline.")


def _load_with_transformers(model_id: str, trust_remote_code: bool):
    from transformers import AutoModel, AutoModelForSeq2SeqLM

    kwargs = {"trust_remote_code": True} if trust_remote_code else {}
    errors = []

    for loader in (AutoModelForSeq2SeqLM, AutoModel):
        try:
            model = loader.from_pretrained(model_id, **kwargs)
            return model, f"transformers::{loader.__name__}"
        except Exception as exc:
            errors.append(f"{loader.__name__}: {exc}")

    raise RuntimeError(" | ".join(errors))


def load_chronos_model(loader: str, model_id: str, trust_remote_code: bool):
    """Load a Chronos checkpoint as a ``torch.nn.Module`` using the named loader.

    Args:
        loader: ``"chronos"`` (pipelines from the ``chronos`` package) or
            ``"transformers"`` (Hugging Face auto classes). Matched exactly.
        model_id: Checkpoint identifier forwarded to the selected loader.
        trust_remote_code: Forwarded to the selected loader.

    Returns:
        ``(module, source)`` as returned by the selected loader.

    Raises:
        ValueError: if ``loader`` is not one of the supported names.
        Exception: anything the selected loader raises propagates unchanged.
    """
    # Exact comparisons: a parenthesized single string such as ("chronos") is
    # not a tuple, and `in` on it would do substring matching.
    if loader == "chronos":
        return _load_with_chronos(model_id=model_id, trust_remote_code=trust_remote_code)
    if loader == "transformers":
        return _load_with_transformers(model_id=model_id, trust_remote_code=trust_remote_code)
    raise ValueError(f"Unknown loader {loader!r}; expected 'chronos' or 'transformers'.")


In [7]:
# --- Load -------------------------------------------------------------------
model, source = load_chronos_model(
    loader=LOADER,
    model_id=MODEL_ID,
    trust_remote_code=TRUST_REMOTE_CODE,
)
print(f"Loaded model using {source}")

# Put the module in eval mode, then optionally move it to the requested device.
model.eval()
if DEVICE:
    wants_cuda = DEVICE.startswith("cuda")
    if wants_cuda and not torch.cuda.is_available():
        raise RuntimeError("CUDA requested but torch.cuda.is_available() is False.")
    model = model.to(DEVICE)

# Only models obtained through the `chronos` loader get a default input name;
# anything the user configured explicitly is used as-is.
loaded_via_chronos = source.startswith("chronos::")
effective_hf_input_names = (
    ["context"] if HF_INPUT_NAMES is None and loaded_via_chronos else HF_INPUT_NAMES
)
print(f"Using hf_input_names={effective_hf_input_names}")

# --- Trace into a MaseGraph and annotate metadata -----------------------------
mg = MaseGraph(model=model, hf_input_names=effective_hf_input_names)
mg, _ = init_metadata_analysis_pass(mg)
node_count = len(list(mg.nodes))
print(f"Constructed MaseGraph with {node_count} FX nodes")

# --- Export / draw ---------------------------------------------------------------
base_path = OUTPUT_DIR / GRAPH_NAME

if RUN_MODE in {"export_only", "export_and_draw"}:
    mg.export(str(base_path))
    print(f"Exported: {base_path}.pt and {base_path}.mz")

if RUN_MODE in {"draw_only", "export_and_draw"}:
    svg_path = OUTPUT_DIR / f"{GRAPH_NAME}.svg"
    if shutil.which("dot") is not None:
        try:
            mg.draw(str(svg_path))
            print(f"Rendered: {svg_path}")
        except (ImportError, OSError) as exc:  # FileNotFoundError is an OSError
            print(f"Skipped SVG draw: {exc}")
    else:
        print("Skipped SVG draw: Graphviz binary `dot` not found in PATH.")


Some weights of T5ForConditionalGeneration were not initialized from the model checkpoint at amazon/chronos-2 and are newly initialized: ['decoder.block.0.layer.0.SelfAttention.k.weight', 'decoder.block.0.layer.0.SelfAttention.o.weight', 'decoder.block.0.layer.0.SelfAttention.q.weight', 'decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight', 'decoder.block.0.layer.0.SelfAttention.v.weight', 'decoder.block.0.layer.0.layer_norm.weight', 'decoder.block.0.layer.1.EncDecAttention.k.weight', 'decoder.block.0.layer.1.EncDecAttention.o.weight', 'decoder.block.0.layer.1.EncDecAttention.q.weight', 'decoder.block.0.layer.1.EncDecAttention.v.weight', 'decoder.block.0.layer.1.layer_norm.weight', 'decoder.block.0.layer.2.DenseReluDense.wi.weight', 'decoder.block.0.layer.2.DenseReluDense.wo.weight', 'decoder.block.0.layer.2.layer_norm.weight', 'decoder.block.1.layer.0.SelfAttention.k.weight', 'decoder.block.1.layer.0.SelfAttention.o.weight', 'decoder.block.1.layer.0.SelfAttention.q.we

Loaded model using transformers::AutoModelForSeq2SeqLM
Using hf_input_names=None
Constructed MaseGraph with 2131 FX nodes


[32mINFO    [0m [34mExporting GraphModule to artifacts/chronos2_mase_graph.pt[0m
[32mINFO    [0m [34mSaving full model format[0m
[32mINFO    [0m [34mExporting MaseMetadata to artifacts/chronos2_mase_graph.mz[0m


Exported: artifacts/chronos2_mase_graph.pt and artifacts/chronos2_mase_graph.mz
Rendered: artifacts/chronos2_mase_graph.svg


If we try to run with `LOADER = "chronos"` we currently get a not-implemented error. The graph above was generated with `LOADER = "transformers"`, which most likely resolved the checkpoint to a model class MASE already supports. I think it did this:

- The checkpoint (amazon/chronos-2) contains weights + config.
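- `transformers` loaded those weights into a generic Hugging Face model class it knows how to instantiate and trace. The log above shows this was `T5ForConditionalGeneration`. The same log warns that the decoder weights were *newly initialized*, i.e. not found in the checkpoint, so the decoder part of the traced graph does not come from Chronos-2.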
- MASE then traced that generic HF model class.

**So** the next step is to implement MASE support for the actual Chronos-2 model; not sure how to do that yet.