# MegaContext Research Console

Interactively prepare datasets, train GistNet with PyTorch Lightning, and capture experiment artifacts. Key docs: [GistNet](https://brandf.github.io/MegaContext/architecture/components/GistNet), [GistNet Training](https://brandf.github.io/MegaContext/architecture/components/GistNet%20Training), [Telemetry](https://brandf.github.io/MegaContext/ops/Telemetry), [Alternating Optimization](https://brandf.github.io/MegaContext/ops/Alternating%20Optimization).


## Quick Start

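When running in Google Colab, execute the bootstrap cell below to clone the repo and install dependencies in the current runtime. In local Jupyter environments the Colab-specific steps are skipped automatically, but still run the cell: it also defines `ensure_widgets_ready`, which later cells use to enable interactive widgets.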


In [None]:
import importlib
import os
import sys
from pathlib import Path
from functools import lru_cache

# Detect Google Colab: the google.colab module is only present in Colab
# kernels. Cloning and package installation below run only there.
COLAB = 'google.colab' in sys.modules
if COLAB:
    repo_url = 'https://github.com/brandf/MegaContext.git'
    workspace = Path('/content/MegaContext')
    # First run clones the repo; later runs fast-forward it (--ff-only fails
    # instead of creating a merge commit if local history has diverged).
    if not workspace.exists():
        !git clone $repo_url $workspace
    else:
        !git -C $workspace pull --ff-only
    %cd /content/MegaContext
    print('Python executable:', sys.executable)
    # PyTorch wheel index; set PYTORCH_CUDA_INDEX to install a different
    # build (the default points at the cu121 wheels).
    cuda_index = os.environ.get('PYTORCH_CUDA_INDEX', 'https://download.pytorch.org/whl/cu121')
    %pip install --upgrade pip setuptools wheel
    %pip install torch torchvision torchaudio --index-url $cuda_index
    # Editable install of this repo plus its [dev] extras.
    %pip install -e .[dev]
    %pip install lightning
    # NOTE(review): ipywidgets is pinned to 7.7.1, presumably for
    # compatibility with Colab's custom widget manager - confirm before bumping.
    %pip install ipywidgets==7.7.1
    # Put src/ on sys.path directly so `megacontext` imports in this
    # already-running kernel (presumably the editable install alone is not
    # visible without a restart).
    src_path = workspace / 'src'
    if str(src_path) not in sys.path:
        sys.path.append(str(src_path))
    importlib.invalidate_caches()
    # Smoke-test the key imports so a broken install fails here, with the
    # error printed, rather than in a later cell.
    try:
        import torch
        import megacontext  # noqa: F401
        import lightning  # noqa: F401
    except Exception as exc:
        print('Import check failed:', exc)
        raise
    else:
        print('Torch version:', torch.__version__)
        print('CUDA available:', torch.cuda.is_available())
        print('Colab environment ready.')
else:
    print('Colab bootstrap skipped (not running in google.colab).')

try:
    import ipywidgets as widgets  # type: ignore
except ImportError:  # pragma: no cover
    widgets = None


@lru_cache(maxsize=1)
def ensure_widgets_ready() -> bool:
    """Prepare the notebook front-end for ipywidgets, at most once per kernel.

    Returns False when ipywidgets is not installed, True otherwise. Inside
    Colab the custom widget manager is enabled first (skipped if
    ``google.colab.output`` cannot be imported). ``lru_cache`` memoises the
    result, so later cells can call this freely.
    """
    if widgets is None:
        return False
    if not COLAB:
        return True
    try:
        from google.colab import output as colab_output  # type: ignore
    except ImportError:  # pragma: no cover
        return True
    colab_output.enable_custom_widget_manager()
    return True


## 0. Environment Snapshot

Verify your runtime (GPU, dependencies, disk space) before launching long jobs.


In [None]:
from pathlib import Path

import yaml
from IPython.display import Markdown, display

from megacontext.notebook import (
    MetricsTracker,
    build_logger,
    collect_environment_report,
    format_config_markdown,
    format_dataset_summary,
    format_training_config,
    format_training_summary,
    render_environment_report,
    save_experiment_summary,
)

# Snapshot the runtime (kept in ENV_REPORT for later inspection) and render it.
ENV_REPORT = collect_environment_report()
env_report_md = render_environment_report(ENV_REPORT)
display(Markdown(env_report_md))


In [None]:
import os
from IPython.display import display

try:
    import ipywidgets as widgets  # type: ignore
except ImportError:
    widgets = None

widgets_ready = False
if widgets is not None:
    try:
        widgets_ready = ensure_widgets_ready()
    except NameError:
        widgets_ready = True

if widgets_ready:
    wandb_input = widgets.Text(
        value=os.environ.get('WANDB_API_KEY', ''),
        description='W&B API key:',
        layout=widgets.Layout(width='70%'),
        placeholder='Optional: paste your Weights & Biases API key'
    )
    hf_input = widgets.Text(
        value=os.environ.get('HUGGINGFACE_HUB_TOKEN', ''),
        description='HF token:',
        layout=widgets.Layout(width='70%'),
        placeholder='Optional: paste your Hugging Face token'
    )

    def _sync_wandb(change):
        os.environ['WANDB_API_KEY'] = change['new'].strip()

    def _sync_hf(change):
        os.environ['HUGGINGFACE_HUB_TOKEN'] = change['new'].strip()

    wandb_input.observe(_sync_wandb, names='value')
    hf_input.observe(_sync_hf, names='value')

    display(widgets.VBox([wandb_input, hf_input]))
else:
    print('Set WANDB_API_KEY and HUGGINGFACE_HUB_TOKEN manually if required.')


In [ ]:
from pathlib import Path

import pandas as pd
from IPython.display import Markdown, display

try:
    import ipywidgets as widgets  # type: ignore
except ImportError:
    widgets = None

CONFIG_ROOT = Path('configs')
AVAILABLE_CONFIGS = sorted(CONFIG_ROOT.glob('*.yaml'))
if not AVAILABLE_CONFIGS:
    raise RuntimeError('No experiment configs found under `configs/`.')

# Tabulate the discovered configs: file stem plus relative path.
config_catalog = pd.DataFrame({
    'name': [config.stem for config in AVAILABLE_CONFIGS],
    'path': [str(config) for config in AVAILABLE_CONFIGS],
})
display(config_catalog)

# Keep a previous selection while it is still one of the discovered configs;
# otherwise fall back to the first config in sorted order.
default_config = AVAILABLE_CONFIGS[0]
if 'EXPERIMENT_CONFIG' in globals():
    current_config = Path(EXPERIMENT_CONFIG)
    if current_config in AVAILABLE_CONFIGS:
        default_config = current_config

EXPERIMENT_CONFIG = default_config


def _describe_selection(config: Path):
    """Markdown banner announcing the chosen experiment config."""
    return Markdown(f"Selected **{config.stem}** (`{config}`)")


widgets_ready = False
if widgets is not None:
    try:
        widgets_ready = ensure_widgets_ready()
    except NameError:
        widgets_ready = True

if not widgets_ready:
    display(Markdown(
        f"Widgets unavailable; using **{EXPERIMENT_CONFIG.stem}** (`{EXPERIMENT_CONFIG}`)"
    ))
else:
    config_dropdown = widgets.Dropdown(
        options=[(config.stem, str(config)) for config in AVAILABLE_CONFIGS],
        value=str(default_config),
        description='Config:',
        layout=widgets.Layout(width='60%'),
    )
    message_box = widgets.Output()

    def _select_config(change):
        if change['name'] != 'value':
            return
        chosen = Path(change['new'])
        globals()['EXPERIMENT_CONFIG'] = chosen
        with message_box:
            message_box.clear_output()
            display(_describe_selection(chosen))

    config_dropdown.observe(_select_config, names='value')
    display(widgets.VBox([config_dropdown, message_box]))
    with message_box:
        display(_describe_selection(default_config))
    globals()['EXPERIMENT_CONFIG'] = Path(config_dropdown.value)


In [None]:
from pathlib import Path
from IPython.display import display, Markdown

try:
    import ipywidgets as widgets  # type: ignore
except ImportError:
    widgets = None

# Fallback when this cell runs before the config picker: use the first config
# in sorted order.
if 'EXPERIMENT_CONFIG' not in globals():
    CONFIG_ROOT = Path('configs')
    configs = sorted(CONFIG_ROOT.glob('*.yaml'))
    if not configs:
        raise RuntimeError('No experiment configs found under `configs/`.')
    EXPERIMENT_CONFIG = configs[0]

widgets_ready = False
if widgets is not None:
    try:
        widgets_ready = ensure_widgets_ready()
    except NameError:
        widgets_ready = True

LOGGER_CHOICES = [('Disabled', 'none'), ('Weights & Biases', 'wandb')]

# Keep settings from an earlier run of this cell (the config, artifact, split
# and resume cells preserve their choices the same way) and fill in defaults
# for missing keys. An unknown selection falls back to 'none' so the dropdown
# below accepts it.
_previous_logger_state = globals().get('LOGGER_STATE')
LOGGER_STATE = {
    'selection': 'none',
    'project': 'megacontext-poc',
    'run_name': '',
}
if isinstance(_previous_logger_state, dict):
    LOGGER_STATE.update(_previous_logger_state)
if LOGGER_STATE['selection'] not in {value for _, value in LOGGER_CHOICES}:
    LOGGER_STATE['selection'] = 'none'

if not widgets_ready:
    logger_label = (
        'disabled' if LOGGER_STATE['selection'] == 'none'
        else LOGGER_STATE['selection']
    )
    print('ipywidgets not available; edit LOGGER_STATE to change logging settings.')
    display(Markdown(
        f"Logger: **{logger_label}**<br>Config: **{EXPERIMENT_CONFIG.stem}** "
        f"(`{EXPERIMENT_CONFIG}`)"))
else:
    summary = widgets.HTML()
    logger_dropdown = widgets.Dropdown(
        options=LOGGER_CHOICES,
        value=LOGGER_STATE['selection'],
        description='Logger:',
        layout=widgets.Layout(width='50%'),
    )
    project_text = widgets.Text(
        value=LOGGER_STATE['project'],
        description='Project:',
        layout=widgets.Layout(width='50%'),
    )
    run_text = widgets.Text(
        value=LOGGER_STATE['run_name'],
        placeholder='auto',
        description='Run name:',
        layout=widgets.Layout(width='50%'),
    )
    wandb_box = widgets.VBox([project_text, run_text])

    def _sync_wandb_visibility() -> None:
        # Project/run fields only matter when W&B is the selected logger.
        wandb_box.layout.display = (
            'flex' if LOGGER_STATE['selection'] == 'wandb' else 'none'
        )

    def refresh_summary() -> None:
        summary.value = (
            f"<b>Logger:</b> {LOGGER_STATE['selection']}<br>"
            f"<b>Config:</b> {EXPERIMENT_CONFIG.stem} "
            f"(<code>{EXPERIMENT_CONFIG}</code>)"
        )

    def _update_logger_state(change):
        LOGGER_STATE['selection'] = logger_dropdown.value
        LOGGER_STATE['project'] = project_text.value
        LOGGER_STATE['run_name'] = run_text.value
        _sync_wandb_visibility()
        refresh_summary()

    logger_dropdown.observe(_update_logger_state, names='value')
    project_text.observe(_update_logger_state, names='value')
    run_text.observe(_update_logger_state, names='value')

    # Initial visibility follows the (possibly preserved) selection.
    _sync_wandb_visibility()
    refresh_summary()
    display(summary)
    display(logger_dropdown)
    display(wandb_box)


## 1. Configure Storage

Select where artifacts, checkpoints, and logs are written.

In [None]:
from pathlib import Path
import os

from IPython.display import Markdown, display

DEFAULT_ARTIFACT_ROOT = Path(
    os.environ.get('MEGACONTEXT_ARTIFACT_ROOT', Path.cwd() / 'artifacts')
).expanduser().resolve()

# Reuse a root chosen in an earlier run of this cell, else the default.
ARTIFACT_ROOT = Path(
    globals().get('ARTIFACT_ROOT', DEFAULT_ARTIFACT_ROOT)
).expanduser().resolve()
ARTIFACT_ROOT.mkdir(parents=True, exist_ok=True)

widgets_ready = False
if 'widgets' in globals() and widgets is not None:
    try:
        widgets_ready = ensure_widgets_ready()
    except NameError:
        widgets_ready = True

if widgets_ready:
    artifact_input = widgets.Text(
        value=str(ARTIFACT_ROOT),
        description='Artifact root:',
        layout=widgets.Layout(width='80%'),
        # Commit only on Enter / focus loss. With per-keystroke updates the
        # callback below would create a directory for every partial path typed.
        continuous_update=False,
    )
    artifact_status = widgets.Output()

    def _sync_artifact(change):
        """Switch ARTIFACT_ROOT to the submitted path, creating it if needed."""
        if change.get('name') != 'value':
            return
        chosen = Path(change['new']).expanduser().resolve()
        with artifact_status:
            artifact_status.clear_output()
            try:
                chosen.mkdir(parents=True, exist_ok=True)
            except OSError as exc:
                # Report and keep the previous root instead of raising inside
                # the widget callback, where the traceback is easy to miss.
                display(Markdown(
                    f'Could not use `{chosen}` ({exc}); keeping `{ARTIFACT_ROOT}`'
                ))
                return
            globals()['ARTIFACT_ROOT'] = chosen
            display(Markdown(f'Using artifact root `{chosen}`'))

    artifact_input.observe(_sync_artifact, names='value')
    display(widgets.VBox([artifact_input, artifact_status]))
    with artifact_status:
        artifact_status.clear_output()
        display(Markdown(f'Using artifact root `{ARTIFACT_ROOT}`'))
else:
    print(f'Artifact root: {ARTIFACT_ROOT}')

globals()['ARTIFACT_ROOT'] = ARTIFACT_ROOT


## 2. Preview Configuration

Review the loaded experiment configuration. To change a field, edit the YAML file and re-run this cell before running later stages.


In [None]:
# Load the selected experiment YAML as it is on disk now and render it.
config_text = EXPERIMENT_CONFIG.read_text(encoding='utf-8')
experiment_cfg = yaml.safe_load(config_text)
display(Markdown(format_config_markdown(experiment_cfg)))


## 3. Dataset Preparation

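The first cell checks that each split's source files exist (downloading the Gutenberg subset if needed). The second runs `tools.prepare_dataset.prepare_dataset_from_config` with a tqdm progress bar and selects the split to train on. If the shard already exists you can skip the asset check, but still run the preparation cell: later cells rely on the `DATASET_RESULT` and `DATASET_PATH` values it defines.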


In [None]:
from pathlib import Path
import glob
import subprocess
import yaml
from tqdm.auto import tqdm


def ensure_dataset_assets(config_path: Path) -> None:
    """Verify that every dataset split's ``source`` glob matches at least one file.

    Source patterns are resolved relative to the config file's directory;
    absolute and ``~``-prefixed sources are honoured. If a pattern mentioning
    ``gutenberg`` matches nothing, ``scripts/download_gutenberg.sh`` is run at
    most once (relative to the current working directory) and the pattern is
    re-checked.

    Args:
        config_path: Experiment YAML file. Splits are read from its
            ``dataset.splits`` section, or from top-level ``splits`` when there
            is no ``dataset`` key. ``splits`` may be a mapping or a list;
            entries that are not mappings are ignored.

    Raises:
        FileNotFoundError: if any split's pattern still matches no files.
        KeyError: if a split mapping has no ``source`` entry.
        subprocess.CalledProcessError: if the download script fails.
    """
    # `or {}` guards against an empty YAML file / empty `dataset:` section.
    raw_config = yaml.safe_load(config_path.read_text(encoding='utf-8')) or {}
    dataset_cfg = raw_config.get('dataset', raw_config) or {}
    base_dir = config_path.parent
    splits = dataset_cfg.get('splits') or {}
    iterable = splits.items() if isinstance(splits, dict) else enumerate(splits)
    gutenberg_downloaded = False
    missing_patterns = []
    for split_name, split in tqdm(iterable, desc='Checking dataset assets'):
        if not isinstance(split, dict):
            continue
        # Expand '~' before joining: once joined onto base_dir it is no longer
        # the leading component and expanduser() would ignore it. Joining an
        # absolute source onto base_dir leaves it unchanged.
        pattern = base_dir / Path(split['source']).expanduser()
        if glob.glob(str(pattern), recursive=True):
            continue
        if 'gutenberg' in str(pattern) and not gutenberg_downloaded:
            print('Downloading Gutenberg subset (one-time download)...')
            subprocess.check_call(['bash', 'scripts/download_gutenberg.sh'])
            gutenberg_downloaded = True
            if glob.glob(str(pattern), recursive=True):
                continue
        # Only record the split as missing once any download attempt has
        # failed to produce matching files.
        print(f"No files matched pattern {pattern}")
        missing_patterns.append((split_name, pattern))
    if missing_patterns:
        unresolved = '\n - '.join(f"{name}: {path}" for name, path in missing_patterns)
        raise FileNotFoundError(
            'Dataset assets missing. Patterns without matches:\n - ' + unresolved
        )


# EXPERIMENT_CONFIG may be a str when set by hand; Path() accepts either form.
config_path_obj = Path(EXPERIMENT_CONFIG)
ensure_dataset_assets(config_path_obj)




In [None]:
import os
from pathlib import Path

from IPython.display import Markdown, display

from tools.prepare_dataset import load_dataset_config, prepare_dataset_from_config

try:
    from tqdm.auto import tqdm
except ImportError:  # pragma: no cover
    tqdm = None

# Single-step bar: it does not track progress inside
# prepare_dataset_from_config, it only shows elapsed time while the call runs.
progress = None
if tqdm is not None:
    progress = tqdm(total=1, desc='Preparing dataset', leave=False, bar_format='{l_bar}{bar}| {elapsed}')
else:
    print('Preparing dataset (tqdm unavailable)...')
try:
    DATASET_RESULT = prepare_dataset_from_config(EXPERIMENT_CONFIG)
    if progress is not None:
        # Mark the step complete only on success, so a failure is not rendered
        # as a finished bar.
        progress.update(1)
finally:
    if progress is not None:
        progress.close()

display(Markdown(format_dataset_summary(DATASET_RESULT['splits'])))

CONFIG_MODEL = load_dataset_config(EXPERIMENT_CONFIG)
available_splits = list(CONFIG_MODEL.splits.keys())
if not available_splits:
    # Fail with a clear message instead of an IndexError below.
    raise ValueError(f'No dataset splits defined in {EXPERIMENT_CONFIG}.')
default_split = 'train' if 'train' in available_splits else available_splits[0]
# Keep a split chosen in an earlier run while it still exists.
SPLIT_NAME = globals().get('SPLIT_NAME', default_split)
if SPLIT_NAME not in available_splits:
    SPLIT_NAME = default_split


def _resolve_dataset_path(split_name: str) -> Path:
    """Return the shard path that ``split_name`` is prepared into.

    The split's ``output_path`` (from ``CONFIG_MODEL``) is returned as-is when
    absolute; otherwise it is resolved against the experiment config's
    directory. If ``MEGACONTEXT_DATA_ROOT`` is set, that location is re-rooted
    under the override directory: the path is kept relative to ``repo_root``
    when it lies inside it, otherwise only its file name is kept.
    """
    split_cfg = CONFIG_MODEL.splits[split_name]
    output_path_cfg = Path(split_cfg.output_path)
    if output_path_cfg.is_absolute():
        return output_path_cfg
    base_dir = EXPERIMENT_CONFIG.parent
    default_output = (base_dir / output_path_cfg).resolve()
    data_root_override = (
        Path(os.environ['MEGACONTEXT_DATA_ROOT']).expanduser().resolve()
        if 'MEGACONTEXT_DATA_ROOT' in os.environ
        else None
    )
    if data_root_override is not None:
        resolved_base = base_dir.resolve()
        # NOTE(review): base_dir is already the config *directory* (for
        # configs picked from configs/*.yaml this is <repo>/configs), so
        # parents[0] would be the repo root and parents[1] is the directory
        # above it. Outputs are therefore re-rooted as
        # <data_root>/<repo dir name>/... Confirm this matches where
        # tools.prepare_dataset writes when MEGACONTEXT_DATA_ROOT is set.
        repo_root = (
            resolved_base.parents[1]
            if len(resolved_base.parents) >= 2
            else resolved_base
        )
        try:
            relative = default_output.relative_to(repo_root)
        except ValueError:
            # Output lies outside repo_root: keep just the file name.
            relative = default_output.name
        return (data_root_override / relative).resolve()
    return default_output


def _set_dataset_path(split_name: str) -> Path:
    """Make ``split_name``'s shard the active dataset and return its path.

    Rebinds the notebook global ``DATASET_PATH`` and mirrors the path, as a
    string, into ``DATASET_RESULT['dataset_path']``, which the summary cell
    later records.
    """
    global DATASET_PATH
    resolved = _resolve_dataset_path(split_name)
    DATASET_PATH = resolved
    DATASET_RESULT['dataset_path'] = str(resolved)
    return resolved


DATASET_PATH = _set_dataset_path(SPLIT_NAME)

widgets_ready = False
if widgets is not None:
    try:
        widgets_ready = ensure_widgets_ready()
    except NameError:
        widgets_ready = True


def _split_banner(split, shard_path):
    """Markdown line naming the active split and its shard path."""
    return Markdown(f"Using split **{split}** → `{shard_path}`")


if not widgets_ready:
    print(f"Using split '{SPLIT_NAME}' at {DATASET_PATH}")
else:
    split_dropdown = widgets.Dropdown(
        options=[(split, split) for split in available_splits],
        value=SPLIT_NAME,
        description='Split:',
        layout=widgets.Layout(width='40%'),
    )
    split_message = widgets.Output()

    def _sync_split(change):
        if change.get('name') != 'value':
            return
        chosen_split = change['new']
        globals()['SPLIT_NAME'] = chosen_split
        chosen_path = _set_dataset_path(chosen_split)
        with split_message:
            split_message.clear_output()
            display(_split_banner(chosen_split, chosen_path))

    split_dropdown.observe(_sync_split, names='value')
    display(widgets.VBox([split_dropdown, split_message]))
    with split_message:
        split_message.clear_output()
        display(_split_banner(SPLIT_NAME, DATASET_PATH))



### Sample Example

Peek at the first prepared context to sanity-check tokens and horizons.


In [None]:
import pyarrow.ipc as pa_ipc
from pathlib import Path
from IPython.display import Markdown, display

if not Path(DATASET_PATH).exists():
    missing_msg = (
        f"Dataset shard not found at {DATASET_PATH}. "
        "Run the prep cell first."
    )
    raise FileNotFoundError(missing_msg)

context_col = 'context_input_ids'
future_col = 'future_input_ids'

with pa_ipc.open_file(DATASET_PATH) as reader:
    if reader.num_record_batches == 0:
        print('Dataset is empty.')
    else:
        batch = reader.get_batch(0)
        missing = [col for col in (context_col, future_col) if col not in batch.schema.names]
        if missing:
            display(Markdown(
                f"Dataset preview skipped; missing columns: {', '.join(missing)}"
            ))
        elif batch.num_rows == 0:
            # A zero-row batch would otherwise fail on the [0] index below.
            print('Dataset is empty.')
        else:
            # Convert only the first row to Python lists; calling to_pydict()
            # on the whole batch would materialise every example to show one.
            first_row = batch.slice(0, 1).to_pydict()
            context_tokens = first_row[context_col][0][:32]
            future_tokens = first_row[future_col][0][:16]
            summary_html = (
                f"Context tokens (first 32): {context_tokens}<br><br>"
                f"Future tokens (first 16): {future_tokens}"
            )
            display(Markdown(summary_html))



## 4. Configure GistNet Training

When `gistnet.model.hidden_size` is `auto`, the hidden size is taken from the teacher embedding width reported during dataset prep; set it to an explicit value to override.


In [None]:
import dataclasses
from copy import deepcopy

from megacontext.gistnet import BaseModelSettings, GistNetConfig, GistNetTrainingConfig

# `or {}` tolerates sections present in the YAML but left empty (null).
gistnet_cfg = experiment_cfg.get('gistnet') or {}
model_dict = deepcopy(gistnet_cfg.get('model') or {})

# Teacher width reported by dataset prep for the active split. It may be
# absent or empty, so it is only required when the config asks for
# hidden_size: auto.
split_summary = DATASET_RESULT['splits'].get(SPLIT_NAME) or {}
dataset_hidden_size = split_summary.get('teacher_hidden_size')
if model_dict.get('hidden_size') == 'auto':
    if not dataset_hidden_size:
        raise ValueError(
            'Dataset summary did not report a teacher hidden size; '
            'set gistnet.model.hidden_size explicitly.'
        )
    model_dict['hidden_size'] = dataset_hidden_size
MODEL_CONFIG = GistNetConfig(**model_dict)

training_dict = deepcopy(gistnet_cfg.get('training') or {})
TRAINING_CONFIG = GistNetTrainingConfig.from_dict(training_dict)
# Fall back to the experiment-level base_model section when the training
# config does not set one.
if TRAINING_CONFIG.base_model is None and 'base_model' in experiment_cfg:
    TRAINING_CONFIG = dataclasses.replace(
        TRAINING_CONFIG,
        base_model=BaseModelSettings.from_dict(experiment_cfg['base_model']),
    )
display(Markdown(format_training_config(TRAINING_CONFIG)))


### Optional Overrides

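Adjust batch size, logging frequency, and per-phase steps, learning rates, and window sizes without editing YAML, then click **Apply overrides** to update `TRAINING_CONFIG`.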


In [None]:
import math

widgets_ready = False
if 'widgets' in globals() and widgets is not None:
    try:
        widgets_ready = ensure_widgets_ready()
    except NameError:
        widgets_ready = True


def _exponent_bounds(value: float, low: int = -6, high: int = 0) -> tuple[int, int]:
    """Return base-10 exponent bounds covering [10**low, 10**high] and ``value``.

    FloatLogSlider clamps values outside its range, so the range is widened
    rather than silently changing the configured learning rate.
    """
    if value <= 0:
        return low, high
    exponent = math.log10(value)
    return min(low, math.floor(exponent)), max(high, math.ceil(exponent))


if not widgets_ready:
    print(
        'ipywidgets not available; use dataclasses.replace(...) to '
        'override training settings manually.'
    )
else:
    # ipywidgets sliders clamp an out-of-range value to [min, max], which would
    # silently rewrite the config on "Apply"; every range below is widened to
    # include the current value.
    batch_slider = widgets.IntSlider(
        value=TRAINING_CONFIG.batch_size,
        min=1,
        max=max(128, TRAINING_CONFIG.batch_size),
        step=1,
        description='Batch size:',
    )
    log_every_slider = widgets.IntSlider(
        value=max(1, TRAINING_CONFIG.log_every_n_steps),
        min=1,
        max=max(1, TRAINING_CONFIG.log_every_n_steps * 2),
        step=1,
        description='Log every:',
    )
    # One entry per phase: (phase, steps, lr, window, box). Keeping direct
    # slider references avoids depending on the order of box.children.
    phase_controls = []
    for phase in TRAINING_CONFIG.phases:
        steps_slider = widgets.IntSlider(
            value=phase.max_steps,
            min=min(1, phase.max_steps),
            max=max(phase.max_steps, 100),
            step=1,
            description='Steps',
        )
        lr_min_exp, lr_max_exp = _exponent_bounds(phase.lr)
        lr_slider = widgets.FloatLogSlider(
            value=phase.lr,
            base=10,
            min=lr_min_exp,
            max=lr_max_exp,
            step=0.1,
            description='LR',
        )
        window_slider = widgets.IntSlider(
            value=phase.window_tokens,
            min=min(MODEL_CONFIG.block_size, phase.window_tokens),
            max=max(phase.window_tokens, MODEL_CONFIG.block_size * 64),
            step=MODEL_CONFIG.block_size,
            description='Window',
        )
        phase_box = widgets.VBox([
            widgets.HTML(f'<b>{phase.name}</b> ({phase.objective})'),
            steps_slider,
            lr_slider,
            window_slider,
        ])
        phase_controls.append((phase, steps_slider, lr_slider, window_slider, phase_box))
    apply_button = widgets.Button(description='Apply overrides', button_style='success')
    overrides_output = widgets.Output()

    def _apply_overrides(_):
        """Rebuild TRAINING_CONFIG from the current slider values."""
        global TRAINING_CONFIG
        phases = tuple(
            dataclasses.replace(
                base_phase,
                max_steps=int(steps.value),
                lr=float(lr.value),
                window_tokens=int(window.value),
            )
            for base_phase, steps, lr, window, _phase_box in phase_controls
        )
        TRAINING_CONFIG = dataclasses.replace(
            TRAINING_CONFIG,
            batch_size=int(batch_slider.value),
            log_every_n_steps=int(log_every_slider.value),
            phases=phases,
        )
        overrides_output.clear_output()
        with overrides_output:
            display(
                Markdown(
                    format_training_config(
                        TRAINING_CONFIG, heading='Updated Training Config'
                    )
                )
            )

    apply_button.on_click(_apply_overrides)
    controls = widgets.VBox(
        [batch_slider, log_every_slider]
        + [controls_row[-1] for controls_row in phase_controls]
        + [apply_button, overrides_output]
    )
    display(controls)


### Reproducibility & Seeding

Ensures deterministic behaviour where possible and records the seed in the run summary.


In [None]:
import os
import random

import numpy as np
import torch
from IPython.display import Markdown, display

if 'RUN_SEED' not in globals():
    # Reuse a seed set earlier in the session; otherwise MEGACONTEXT_SEED or 42.
    RUN_SEED = int(os.environ.get('MEGACONTEXT_SEED', 42))

random.seed(RUN_SEED)
# PYTHONHASHSEED is read at interpreter start-up, so this only affects child
# processes launched from now on, not string hashing in this kernel.
os.environ['PYTHONHASHSEED'] = str(RUN_SEED)
np.random.seed(RUN_SEED)
torch.manual_seed(RUN_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RUN_SEED)
    # Deterministic cuBLAS needs a fixed workspace size. Without it,
    # use_deterministic_algorithms(True) makes CUDA matmuls raise during
    # training. It must be set before cuBLAS is first used; an existing value
    # is respected.
    os.environ.setdefault('CUBLAS_WORKSPACE_CONFIG', ':4096:8')
    try:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    except AttributeError:
        pass
try:
    # warn_only: ops without a deterministic implementation emit a warning
    # instead of aborting the run ("deterministic where possible").
    torch.use_deterministic_algorithms(True, warn_only=True)
except TypeError:
    # Older torch without the warn_only keyword.
    try:
        torch.use_deterministic_algorithms(True)
    except RuntimeError:
        pass
except (AttributeError, RuntimeError):
    pass

display(Markdown(f'Seeding complete with seed `{RUN_SEED}`.'))


### Checkpoints & Resume

Choose whether to restart or continue from an existing checkpoint.

In [None]:
from pathlib import Path
from IPython.display import Markdown, display

# All GistNet runs live under <ARTIFACT_ROOT>/gistnet; this assignment already
# rebinds the notebook-global RUNS_DIR.
RUNS_DIR = ARTIFACT_ROOT.joinpath('gistnet').expanduser().resolve()
RUNS_DIR.mkdir(exist_ok=True, parents=True)


def _discover_checkpoints(root: Path) -> list[Path]:
    """List every ``*.ckpt`` file below ``root`` (recursively), newest first.

    Files are ordered by modification time; a missing ``root`` gives ``[]``.
    """
    if root.exists():
        found = list(root.glob('**/*.ckpt'))
        found.sort(key=lambda ckpt: ckpt.stat().st_mtime, reverse=True)
        return found
    return []


# Checkpoints come back newest-first, so the default is the most recent one.
available_checkpoints = _discover_checkpoints(RUNS_DIR)
default_resume = next(iter(available_checkpoints), None)
# An earlier choice takes precedence; a non-empty string path becomes a Path.
existing_resume = globals().get('RESUME_CHECKPOINT')
if existing_resume and isinstance(existing_resume, str):
    existing_resume = Path(existing_resume)
RESUME_CHECKPOINT = existing_resume if existing_resume else default_resume

widgets_ready = False
if 'widgets' in globals() and widgets is not None:
    try:
        widgets_ready = ensure_widgets_ready()
    except NameError:
        widgets_ready = True


def _checkpoint_label(path: Path) -> str:
    """Human-readable name for a checkpoint.

    Returns the path relative to ``RUNS_DIR`` when the checkpoint lives below
    it, otherwise just the file name.
    """
    try:
        return str(path.relative_to(RUNS_DIR))
    except ValueError:
        return path.name


if widgets_ready and available_checkpoints:
    options = [('Do not resume', '')]
    options.extend((_checkpoint_label(path), str(path)) for path in available_checkpoints)
    initial_value = str(RESUME_CHECKPOINT) if RESUME_CHECKPOINT else ''
    # A Dropdown rejects a value that is not among its options, and a
    # user-supplied RESUME_CHECKPOINT may live outside RUNS_DIR, so add it.
    if initial_value not in {value for _, value in options}:
        options.append((_checkpoint_label(Path(initial_value)), initial_value))
    resume_dropdown = widgets.Dropdown(
        options=options,
        value=initial_value,
        description='Resume from:',
        layout=widgets.Layout(width='70%'),
    )
    resume_status = widgets.Output()

    def _show_resume_choice(checkpoint) -> None:
        """Render the current resume decision inside ``resume_status``."""
        with resume_status:
            resume_status.clear_output()
            if checkpoint:
                display(Markdown(f'Resuming from `{_checkpoint_label(Path(checkpoint))}`'))
            else:
                display(Markdown('Starting a new run.'))

    def _sync_resume(change):
        if change.get('name') != 'value':
            return
        selected = change['new']
        globals()['RESUME_CHECKPOINT'] = Path(selected) if selected else None
        _show_resume_choice(selected)

    resume_dropdown.observe(_sync_resume, names='value')
    display(widgets.VBox([resume_dropdown, resume_status]))
    _show_resume_choice(RESUME_CHECKPOINT)
else:
    # Without widgets the newest checkpoint is selected automatically, so
    # explain how to opt out instead of silently resuming.
    if RESUME_CHECKPOINT:
        print(f'Resume checkpoint: {_checkpoint_label(Path(RESUME_CHECKPOINT))}')
        print('Set RESUME_CHECKPOINT = None before the training cell to start a new run instead.')
    else:
        print('No checkpoints found; starting a new run.')

globals()['RESUME_CHECKPOINT'] = RESUME_CHECKPOINT


## 5. Build Lightning Components

Builds the trainer with a rich progress bar, a metrics tracker, and a checkpoint callback; select Weights & Biases in the logger cell above to stream logs.


In [None]:
from datetime import datetime
from pathlib import Path

from lightning.pytorch.callbacks import ModelCheckpoint, RichProgressBar  # type: ignore

from megacontext.gistnet import build_gistnet_experiment

import re

# Track these training metrics for the plot and summary cells below.
metrics_callback = MetricsTracker(metric_keys=(
    'train/loss',
    'train/delta_loss',
    'train/gist_loss',
    'train/baseline_loss',
))

RUNS_DIR = globals().get('RUNS_DIR', (ARTIFACT_ROOT / 'gistnet').resolve())
RUNS_DIR.mkdir(parents=True, exist_ok=True)
run_name = (LOGGER_STATE.get('run_name') or '').strip()
if not run_name:
    run_name = f"{EXPERIMENT_CONFIG.stem}-{datetime.now():%Y%m%d-%H%M%S}"
# The run name becomes one directory under RUNS_DIR. Replace whitespace, path
# separators and other unusual characters with '-', and strip leading or
# trailing '.'/'-' (which also rules out '.' and '..' escaping RUNS_DIR).
run_name = re.sub(r'[^\w.-]+', '-', run_name).strip('.-') or 'run'
RUN_NAME = run_name
RUN_DIR = (RUNS_DIR / run_name).resolve()
RUN_DIR.mkdir(parents=True, exist_ok=True)
CHECKPOINT_DIR = (RUN_DIR / 'checkpoints').resolve()
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
print(f'Artifacts for run `{RUN_NAME}` stored under {RUN_DIR}')

# Keep the 3 checkpoints with the lowest train/loss plus a "last" checkpoint.
# NOTE(review): the checkpoint interval is tied to log_every_n_steps, so
# logging every step also checkpoints every step.
checkpoint_callback = ModelCheckpoint(
    dirpath=str(CHECKPOINT_DIR),
    filename='step-{step:06d}',
    monitor='train/loss',
    mode='min',
    save_last=True,
    save_top_k=3,
    every_n_train_steps=max(1, TRAINING_CONFIG.log_every_n_steps),
    auto_insert_metric_name=False,
)

callbacks = [metrics_callback, RichProgressBar(), checkpoint_callback]
LOGGER = None
try:
    LOGGER = build_logger(
        selection=LOGGER_STATE.get('selection', 'none'),
        project=LOGGER_STATE.get('project'),
        # Pass the resolved name (auto-generated when left blank) so the
        # logger run lines up with RUN_DIR.
        run_name=RUN_NAME,
        config={'config_path': str(EXPERIMENT_CONFIG)},
    )
except RuntimeError as exc:
    # Presumably raised when the selected logger cannot be set up; training
    # then proceeds without a logger.
    print(exc)

trainer_kwargs = {
    'accelerator': 'auto',
    'devices': 1,
    'default_root_dir': str(RUN_DIR),
}

TRAINER, MODULE, DATA_MODULE = build_gistnet_experiment(
    dataset_path=DATASET_PATH,
    model_config=MODEL_CONFIG,
    training=TRAINING_CONFIG,
    callbacks=callbacks,
    logger=LOGGER,
    trainer_kwargs=trainer_kwargs,
)

globals()['RUN_DIR'] = RUN_DIR
globals()['CHECKPOINT_DIR'] = CHECKPOINT_DIR
globals()['RUN_NAME'] = RUN_NAME
TRAINER


## 6. Launch Training

Run this cell to start the Lightning loop. Progress appears below and (optionally) in WandB.


In [None]:
from pathlib import Path

# Normalise RESUME_CHECKPOINT (Path, str, '' or None) to an optional Path.
_requested_resume = globals().get('RESUME_CHECKPOINT')
if isinstance(_requested_resume, str) and not _requested_resume:
    _requested_resume = None
resume_path = None if _requested_resume is None else Path(_requested_resume)
if resume_path is not None and not resume_path.exists():
    raise FileNotFoundError(f'Resume checkpoint not found: {resume_path}')

print(
    f'Resuming from checkpoint: {resume_path}'
    if resume_path is not None
    else 'Starting a new training run.'
)

TRAINER.fit(
    MODULE,
    DATA_MODULE,
    ckpt_path=None if resume_path is None else str(resume_path),
)

# LATEST_CHECKPOINT holds the checkpoint callback's *best* model (per the
# monitor configured when the trainer was built); LAST_CHECKPOINT holds its
# most recent "last" checkpoint, if any.
checkpoint_cb = TRAINER.checkpoint_callback
best_model_path = checkpoint_cb.best_model_path
last_model_path = getattr(checkpoint_cb, 'last_model_path', None)
best_path = Path(best_model_path) if best_model_path else None
last_path = Path(last_model_path) if last_model_path else None
globals()['LATEST_CHECKPOINT'] = best_path
globals()['LAST_CHECKPOINT'] = last_path




## 7. Visualise Metrics

Plots the captured metrics (requires matplotlib).


In [None]:
# Figure size (width, height), presumably forwarded to matplotlib by plot().
metrics_figsize = (7, 4)
metrics_callback.plot(figsize=metrics_figsize)


## 8. Summarise & Save

Captures final metrics and writes a JSON summary under `ARTIFACT_ROOT/experiments/` (`artifacts/experiments/` with the default artifact root).


In [None]:
from pathlib import Path
from IPython.display import Markdown

final_metrics = {
    metric_name: float(metric_value)
    for metric_name, metric_value in TRAINER.callback_metrics.items()
}
if 'RUN_SEED' in globals():
    final_metrics.setdefault('seed', RUN_SEED)
display(Markdown(format_training_summary(final_metrics)))


def _path_or_none(value):
    """Stringify a truthy path-like value; map unset/empty values to None."""
    return str(value) if value else None


# Read optional run state defensively: any earlier cell may have been skipped.
_notebook_state = globals()
summary_root = (ARTIFACT_ROOT / 'experiments').resolve()
SUMMARY_PATH = save_experiment_summary(
    output_dir=summary_root,
    config_path=EXPERIMENT_CONFIG,
    dataset_summary=DATASET_RESULT['splits'],
    training_metrics=final_metrics,
    artifacts={
        'dataset_path': DATASET_RESULT.get('dataset_path'),
        'default_root_dir': _path_or_none(_notebook_state.get('RUN_DIR')),
        'checkpoint_dir': _path_or_none(_notebook_state.get('CHECKPOINT_DIR')),
        'resume_from': _path_or_none(_notebook_state.get('RESUME_CHECKPOINT')),
        'latest_checkpoint': _path_or_none(_notebook_state.get('LATEST_CHECKPOINT')),
        'last_checkpoint': _path_or_none(_notebook_state.get('LAST_CHECKPOINT')),
        'seed': _notebook_state.get('RUN_SEED'),
    },
)
SUMMARY_PATH
