# MegaContext Research Console

Interactively prepare datasets, train GistNet with PyTorch Lightning, and capture experiment artifacts. Key docs: [GistNet](https://brandf.github.io/MegaContext/architecture/components/GistNet), [GistNet Training](https://brandf.github.io/MegaContext/architecture/components/GistNet%20Training), [Telemetry](https://brandf.github.io/MegaContext/ops/Telemetry), [Alternating Optimization](https://brandf.github.io/MegaContext/ops/Alternating%20Optimization).


## Quick Start

When running in Google Colab, execute the bootstrap cell below to clone the repo and install dependencies in the current runtime. Local Jupyter environments can skip it.


In [None]:
import importlib
import os
import sys
from pathlib import Path
from functools import lru_cache

# Colab bootstrap: when the `google.colab` module is already loaded, clone or
# update the repository, install dependencies once per workspace, and verify
# that the key packages import. Outside Colab this whole section is a no-op.
COLAB = 'google.colab' in sys.modules
if COLAB:
    repo_url = 'https://github.com/brandf/MegaContext.git'
    workspace = Path('/content/MegaContext')
    # Fresh runtime -> clone; existing checkout -> fast-forward only.
    if not workspace.exists():
        !git clone $repo_url $workspace
    else:
        !git -C $workspace pull --ff-only
    %cd /content/MegaContext
    print('Python executable:', sys.executable)
    # Wheel index for the torch install; override with PYTORCH_CUDA_INDEX.
    cuda_index = os.environ.get('PYTORCH_CUDA_INDEX', 'https://download.pytorch.org/whl/cu121')

    from IPython import get_ipython

    shell = get_ipython()
    if shell is None:
        raise RuntimeError('IPython shell not found; cannot run pip installs.')

    def pip_install(args: str) -> None:
        """Run ``%pip <args>`` through the active IPython shell."""
        shell.run_line_magic('pip', args)

    # Marker file written after a completed install; its presence skips reinstalling
    # unless MEGACONTEXT_FORCE_REINSTALL is set to 1/true/yes.
    ready_flag = workspace / '.megacontext_bootstrapped'
    force_reinstall = os.environ.get('MEGACONTEXT_FORCE_REINSTALL', '').lower() in {'1', 'true', 'yes'}
    needs_install = force_reinstall or not ready_flag.exists()

    def install_dependencies() -> None:
        """Install torch, lightning, the editable repo, and ipywidgets, then write the marker."""
        print('Installing MegaContext dependencies...')
        # Drop already-imported megacontext/lightning modules from the import cache
        # (presumably so the freshly installed versions are picked up on re-import).
        for module_name in list(sys.modules):
            if module_name.startswith('megacontext') or module_name.startswith('lightning'):
                sys.modules.pop(module_name, None)
        pip_install('install --upgrade pip setuptools wheel')
        pip_install(f'install torch torchvision torchaudio --index-url {cuda_index}')
        pip_install('install lightning')
        pip_install('install -e .[dev]')
        pip_install('install ipywidgets==7.7.1')
        # NOTE(review): the marker is written unconditionally after the pip calls;
        # confirm whether a failed `%pip` run raises before reaching this line.
        ready_flag.write_text('ok', encoding='utf-8')

    if needs_install:
        install_dependencies()
    else:
        print(
            f'Using existing environment at {workspace} '
            '(set MEGACONTEXT_FORCE_REINSTALL=1 to rebuild).'
        )

    # Make the repo's `src/` importable in this kernel.
    src_path = workspace / 'src'
    if str(src_path) not in sys.path:
        sys.path.append(str(src_path))
    importlib.invalidate_caches()
    # Smoke-test the imports; a missing module triggers one reinstall and a retry.
    try:
        import torch
        import megacontext  # noqa: F401
        import lightning  # noqa: F401
    except ModuleNotFoundError as exc:
        print(f'Missing dependency ({exc}); reinstalling...')
        install_dependencies()
        import torch
        import megacontext  # noqa: F401
        import lightning  # noqa: F401
    print('Torch version:', torch.__version__)
    print('CUDA available:', torch.cuda.is_available())
    print('Colab environment ready.')
else:
    print('Colab bootstrap skipped (not running in google.colab).')

try:
    import ipywidgets as widgets  # type: ignore
except ImportError:  # pragma: no cover
    widgets = None

@lru_cache(maxsize=1)
def ensure_widgets_ready() -> bool:
    """Report whether ipywidgets can be used, enabling Colab support if needed.

    Returns ``False`` when ``ipywidgets`` could not be imported. Otherwise
    returns ``True``; on Colab the custom widget manager is enabled first,
    provided ``google.colab.output`` is importable. The result is cached by
    ``lru_cache``, so repeated calls reuse the first outcome.
    """
    widgets_available = widgets is not None
    if widgets_available and COLAB:
        try:
            from google.colab import output as colab_output  # type: ignore
        except ImportError:  # pragma: no cover
            pass
        else:
            colab_output.enable_custom_widget_manager()
    return widgets_available



## 0. Setup Console

Configure experiment settings, storage, logging, and reproducibility before running the pipeline.


In [None]:
from pathlib import Path
import os
import random

import numpy as np
import yaml
from IPython.display import Markdown, display

from megacontext.notebook import (
    MetricsTracker,
    build_logger,
    collect_environment_report,
    format_config_markdown,
    format_dataset_summary,
    format_training_config,
    format_training_summary,
    render_environment_report,
    save_experiment_summary,
)

try:
    import ipywidgets as widgets  # type: ignore
except ImportError:  # pragma: no cover
    widgets = None


def _resolve_path(value):
    """Convert a path-like setting into an absolute ``Path``.

    ``None`` and the empty string mean "not configured" and yield ``None``.
    Any other value has ``~`` expanded and is then resolved.
    """
    if value not in (None, ""):
        return Path(value).expanduser().resolve()
    return None


def launch_setup_console() -> None:
    global LOGGER_STATE
    if widgets is None:
        print('ipywidgets not available — falling back to defaults.')
        config_root = Path('configs')
        configs = sorted(config_root.glob('*.yaml'))
        if not configs:
            raise RuntimeError('No experiment configs found under `configs/`.')
        selected_config = configs[0]
        globals()['EXPERIMENT_CONFIG'] = selected_config
        globals()['experiment_cfg'] = yaml.safe_load(selected_config.read_text(encoding='utf-8'))
        env_report = collect_environment_report()
        print(render_environment_report(env_report))
        artifact_root = _resolve_path(os.environ.get('MEGACONTEXT_ARTIFACT_ROOT')) or (Path.cwd() / 'artifacts').resolve()
        artifact_root.mkdir(parents=True, exist_ok=True)
        globals()['ARTIFACT_ROOT'] = artifact_root
        os.environ['MEGACONTEXT_ARTIFACT_ROOT'] = str(artifact_root)
        LOGGER_STATE = {
            'selection': os.environ.get('MEGACONTEXT_LOGGER', 'none'),
            'project': os.environ.get('MEGACONTEXT_LOGGER_PROJECT', 'megacontext-poc'),
            'run_name': os.environ.get('MEGACONTEXT_LOGGER_RUN', ''),
        }
        seed = int(os.environ.get('MEGACONTEXT_SEED', 42))
        globals()['RUN_SEED'] = seed
        print(f'Selected config: {selected_config}')
        print(f'Artifact root: {artifact_root}')
        print(f'Seed: {seed}')
        return

    ensure_widgets_ready()

    config_root = Path('configs')
    configs = sorted(config_root.glob('*.yaml'))
    if not configs:
        raise RuntimeError('No experiment configs found under `configs/`.')

    selected_config = globals().get('EXPERIMENT_CONFIG', configs[0])
    if isinstance(selected_config, str):
        selected_config = Path(selected_config)
    if selected_config not in configs:
        selected_config = configs[0]

    artifact_default = _resolve_path(os.environ.get('MEGACONTEXT_ARTIFACT_ROOT')) or (Path.cwd() / 'artifacts').resolve()
    artifact_root = _resolve_path(globals().get('ARTIFACT_ROOT')) or artifact_default
    artifact_root.mkdir(parents=True, exist_ok=True)
    os.environ['MEGACONTEXT_ARTIFACT_ROOT'] = str(artifact_root)
    globals()['ARTIFACT_ROOT'] = artifact_root

    data_root_initial = _resolve_path(os.environ.get('MEGACONTEXT_DATA_ROOT'))

    LOGGER_STATE = globals().get('LOGGER_STATE', {
        'selection': 'none',
        'project': 'megacontext-poc',
        'run_name': '',
    })
    LOGGER_STATE = {
        'selection': LOGGER_STATE.get('selection', 'none'),
        'project': LOGGER_STATE.get('project', 'megacontext-poc'),
        'run_name': LOGGER_STATE.get('run_name', ''),
    }

    env_report = collect_environment_report()
    env_output = widgets.Output()
    with env_output:
        display(Markdown(render_environment_report(env_report)))

    seed_initial = int(globals().get('RUN_SEED', os.environ.get('MEGACONTEXT_SEED', 42)))
    state = {
        'artifact_root': str(artifact_root),
        'data_root': str(data_root_initial) if data_root_initial else '',
        'resume_checkpoint': globals().get('RESUME_CHECKPOINT'),
        'seed': seed_initial,
    }

    config_dropdown = widgets.Dropdown(
        options=[(cfg.stem, str(cfg)) for cfg in configs],
        value=str(selected_config),
        description='Experiment:',
        layout=widgets.Layout(width='60%'),
    )
    config_preview = widgets.Output()

    def _load_config(path: Path) -> None:
        cfg = yaml.safe_load(path.read_text(encoding='utf-8'))
        globals()['EXPERIMENT_CONFIG'] = path
        globals()['experiment_cfg'] = cfg
        config_preview.clear_output()
        with config_preview:
            display(Markdown(format_config_markdown(cfg)))

    _load_config(selected_config)

    def update_summary(status: str | None = None) -> None:
        summary_lines = [
            f"- **Config:** `{globals().get('EXPERIMENT_CONFIG')}`",
            f"- **Artifact root:** `{state['artifact_root']}`",
        ]
        if state['data_root']:
            summary_lines.append(f"- **Data root:** `{state['data_root']}`")
        else:
            summary_lines.append('- **Data root:** config-relative')
        summary_lines.extend([
            f"- **Logger:** {LOGGER_STATE.get('selection', 'none')}",
            f"- **Resume checkpoint:** `{state['resume_checkpoint']}`" if state['resume_checkpoint'] else '- **Resume checkpoint:** none',
            f"- **WANDB token:** {'set' if os.environ.get('WANDB_API_KEY') else 'not set'}",
            f"- **HF token:** {'set' if os.environ.get('HUGGINGFACE_HUB_TOKEN') else 'not set'}",
            f"- **Seed:** `{state['seed']}`",
        ])
        summary_output.clear_output()
        with summary_output:
            if status:
                display(Markdown(status))
            display(Markdown('\n'.join(summary_lines)))

    def _on_config_change(change):
        if change['name'] != 'value':
            return
        _load_config(Path(change['new']))
        update_summary('Experiment config updated.')

    config_dropdown.observe(_on_config_change, names='value')

    wandb_input = widgets.Text(
        value=os.environ.get('WANDB_API_KEY', ''),
        description='W&B token:',
        layout=widgets.Layout(width='70%'),
        placeholder='Optional',
    )
    hf_input = widgets.Text(
        value=os.environ.get('HUGGINGFACE_HUB_TOKEN', ''),
        description='HF token:',
        layout=widgets.Layout(width='70%'),
        placeholder='Optional',
    )

    def _sync_wandb(change):
        if change['name'] != 'value':
            return
        os.environ['WANDB_API_KEY'] = change['new'].strip()
        update_summary()

    def _sync_hf(change):
        if change['name'] != 'value':
            return
        os.environ['HUGGINGFACE_HUB_TOKEN'] = change['new'].strip()
        update_summary()

    wandb_input.observe(_sync_wandb, names='value')
    hf_input.observe(_sync_hf, names='value')

    artifact_input = widgets.Text(
        value=str(artifact_root),
        description='Artifact root:',
        layout=widgets.Layout(width='70%'),
    )
    artifact_apply = widgets.Button(description='Apply', button_style='info', icon='save')
    artifact_message = widgets.Output()

    data_input = widgets.Text(
        value=str(data_root_initial) if data_root_initial else '',
        description='Data root:',
        layout=widgets.Layout(width='70%'),
        placeholder='Optional — defaults to config-relative paths',
    )
    data_apply = widgets.Button(description='Apply', button_style='info', icon='save')
    data_message = widgets.Output()

    resume_status = widgets.HTML('Starting a new run.')
    resume_dropdown = widgets.Dropdown(
        description='Resume from:',
        layout=widgets.Layout(width='70%'),
    )

    summary_output = widgets.Output()

    def refresh_resume_options() -> None:
        runs_dir = Path(state['artifact_root']) / 'gistnet'
        runs_dir.mkdir(parents=True, exist_ok=True)
        globals()['RUNS_DIR'] = runs_dir
        checkpoints = sorted(
            runs_dir.glob('**/*.ckpt'),
            key=lambda p: p.stat().st_mtime,
            reverse=True,
        )
        options = [('Start fresh', '')]
        for path in checkpoints:
            try:
                label = str(path.relative_to(Path(state['artifact_root'])))
            except ValueError:
                label = str(path)
            options.append((label, str(path)))
        current = str(state['resume_checkpoint']) if state['resume_checkpoint'] else ''
        values = [value for _, value in options]
        if current not in values:
            current = ''
        resume_dropdown.options = options
        resume_dropdown.value = current
        if current:
            resume_status.value = f'Resuming from <code>{current}</code>'
        else:
            resume_status.value = 'Starting a new run.'

    def _apply_artifact_path(_=None) -> None:
        raw = artifact_input.value.strip()
        if not raw:
            artifact_input.value = state['artifact_root']
            update_summary('Artifact root cannot be empty; keeping current path.')
            return
        new_path = Path(raw).expanduser().resolve()
        if str(new_path) == state['artifact_root']:
            return
        new_path.mkdir(parents=True, exist_ok=True)
        state['artifact_root'] = str(new_path)
        globals()['ARTIFACT_ROOT'] = new_path
        os.environ['MEGACONTEXT_ARTIFACT_ROOT'] = str(new_path)
        with artifact_message:
            artifact_message.clear_output()
            display(Markdown(f'Artifacts stored under `{new_path}`.'))
        refresh_resume_options()
        update_summary('Artifact root updated.')

    artifact_apply.on_click(_apply_artifact_path)
    artifact_input.on_submit(lambda _: _apply_artifact_path())

    def _apply_data_path(_=None) -> None:
        raw = data_input.value.strip()
        if not raw:
            os.environ.pop('MEGACONTEXT_DATA_ROOT', None)
            state['data_root'] = ''
            with data_message:
                data_message.clear_output()
                display(Markdown('Using dataset paths relative to the experiment config.'))
            update_summary('Data root cleared.')
            return
        path = Path(raw).expanduser().resolve()
        if str(path) == state['data_root']:
            return
        path.mkdir(parents=True, exist_ok=True)
        os.environ['MEGACONTEXT_DATA_ROOT'] = str(path)
        state['data_root'] = str(path)
        with data_message:
            data_message.clear_output()
            display(Markdown(f'Dataset root override set to `{path}`.'))
        update_summary('Data root updated.')

    data_apply.on_click(_apply_data_path)
    data_input.on_submit(lambda _: _apply_data_path())

    def _on_resume_change(change):
        if change['name'] != 'value':
            return
        selected = change['new']
        state['resume_checkpoint'] = Path(selected) if selected else None
        globals()['RESUME_CHECKPOINT'] = state['resume_checkpoint']
        if selected:
            resume_status.value = f'Resuming from <code>{selected}</code>'
        else:
            resume_status.value = 'Starting a new run.'
        update_summary()

    resume_dropdown.observe(_on_resume_change, names='value')

    logger_dropdown = widgets.Dropdown(
        options=[('Disabled', 'none'), ('Weights & Biases', 'wandb')],
        value=LOGGER_STATE['selection'],
        description='Logger:',
        layout=widgets.Layout(width='50%'),
    )
    project_text = widgets.Text(
        value=LOGGER_STATE['project'],
        description='Project:',
        layout=widgets.Layout(width='50%'),
    )
    run_text = widgets.Text(
        value=LOGGER_STATE['run_name'],
        placeholder='auto',
        description='Run name:',
        layout=widgets.Layout(width='50%'),
    )
    logger_summary = widgets.HTML()

    def update_logger_state(_=None):
        LOGGER_STATE['selection'] = logger_dropdown.value
        LOGGER_STATE['project'] = project_text.value.strip()
        LOGGER_STATE['run_name'] = run_text.value.strip()
        logger_summary.value = (
            f"<b>Logger:</b> {LOGGER_STATE['selection']}<br>"
            f"<b>Project:</b> {LOGGER_STATE['project'] or '—'}<br>"
            f"<b>Run name:</b> {LOGGER_STATE['run_name'] or 'auto'}"
        )
        update_summary()

    for widget in (logger_dropdown, project_text, run_text):
        widget.observe(update_logger_state, names='value')
    update_logger_state()

    seed_input = widgets.IntText(
        value=seed_initial,
        description='Global seed:',
        layout=widgets.Layout(width='30%'),
    )
    apply_button = widgets.Button(description='Apply settings', button_style='success', icon='check')

    def apply_settings(_=None):
        try:
            seed = int(seed_input.value)
        except (TypeError, ValueError):
            seed_input.value = state['seed']
            update_summary('Seed must be an integer; reverting to the previous value.')
            return
        state['seed'] = seed
        globals()['RUN_SEED'] = seed
        os.environ['MEGACONTEXT_SEED'] = str(seed)
        random.seed(seed)
        os.environ['PYTHONHASHSEED'] = str(seed)
        np.random.seed(seed)
        try:
            import torch
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)
                try:
                    torch.backends.cudnn.deterministic = True
                    torch.backends.cudnn.benchmark = False
                except AttributeError:
                    pass
            try:
                torch.use_deterministic_algorithms(True)
            except (AttributeError, RuntimeError):
                pass
        except ImportError:
            pass
        update_summary(f'Seeded runtime with `{seed}`.')

    apply_button.on_click(apply_settings)

    refresh_resume_options()
    update_summary()

    overview_box = widgets.VBox([env_output, summary_output])
    experiment_box = widgets.VBox([config_dropdown, config_preview])
    storage_box = widgets.VBox([
        widgets.HBox([artifact_input, artifact_apply]),
        artifact_message,
        widgets.HBox([data_input, data_apply]),
        data_message,
        resume_dropdown,
        resume_status,
    ])
    logging_box = widgets.VBox([logger_dropdown, project_text, run_text, logger_summary])
    credentials_box = widgets.VBox([wandb_input, hf_input])
    reproducibility_box = widgets.VBox([seed_input, apply_button])

    tabs = widgets.Tab(children=[overview_box, experiment_box, storage_box, logging_box, credentials_box, reproducibility_box])
    titles = ['Overview', 'Experiment', 'Storage', 'Logging', 'Credentials', 'Reproducibility']
    for idx, title in enumerate(titles):
        tabs.set_title(idx, title)

    apply_settings()
    globals()['SETUP_CONSOLE'] = tabs
    return tabs


# Build the console; the widget-less fallback returns None and renders nothing.
if (console_widget := launch_setup_console()) is not None:
    display(console_widget)
    print('Setup console ready — adjust settings via the tabs above.')







## 1. Dataset Preparation

Runs `tools.prepare_dataset.prepare_dataset_from_config` with tqdm progress bars. Skip this if the shard already exists.


In [None]:
from pathlib import Path
import glob
import subprocess
import yaml
from tqdm.auto import tqdm


def ensure_dataset_assets(config_path: Path) -> None:
    """Verify that every dataset split's source glob matches at least one file.

    Split definitions are taken from the config's ``dataset`` section when it
    exists, otherwise from the top level. Each split's ``source`` pattern is
    interpreted relative to the directory containing ``config_path``. When a
    pattern containing ``gutenberg`` matches nothing, the helper script
    ``tools/download_gutenberg.sh`` (relative to the current working directory)
    is run with the pattern's literal directory prefix as its target, and the
    pattern is checked again.

    Args:
        config_path: Path to the experiment/dataset YAML file.

    Raises:
        FileNotFoundError: If the download helper script is missing, or if any
            pattern still has no matches once all splits were checked.
        RuntimeError: If ``bash`` is unavailable or the download script fails.
    """
    import os  # function-scope import: this cell's import block does not include ``os``

    # Explicit UTF-8 keeps parsing independent of the platform's default encoding.
    raw_config = yaml.safe_load(config_path.read_text(encoding='utf-8'))
    dataset_cfg = raw_config.get('dataset', raw_config)
    base_dir = config_path.parent
    splits = dataset_cfg.get('splits', {})
    # ``splits`` may be a mapping (name -> split) or a sequence (index -> split).
    if isinstance(splits, dict):
        iterable = splits.items()
    else:
        iterable = enumerate(splits or [])
    path_separators = tuple(sep for sep in (os.sep, os.altsep) if sep)
    missing_patterns = []
    for split_name, split in tqdm(iterable, desc='Checking dataset assets'):
        if not isinstance(split, dict):
            continue
        pattern = (base_dir / split['source']).expanduser()
        matches = glob.glob(str(pattern), recursive=True)
        if matches:
            continue
        if 'gutenberg' in str(pattern):
            print('Downloading Gutenberg subset (one-time download)...')
            script_path = Path('tools/download_gutenberg.sh')
            if not script_path.exists():
                raise FileNotFoundError(
                    'Expected tools/download_gutenberg.sh to exist. '
                    'Populate the Gutenberg corpus manually or restore the helper script.'
                )
            glob_str = str(pattern)
            # The literal prefix is everything before the first glob metacharacter.
            wildcard_hits = [glob_str.find(token) for token in ('*', '?', '[')]
            stop_idx = min((idx for idx in wildcard_hits if idx != -1), default=len(glob_str))
            literal_prefix = glob_str[:stop_idx]
            target_dir = Path(literal_prefix).resolve()
            # A prefix ending in a separator is unambiguously a directory, even when
            # its name contains a dot (e.g. ``gutenberg.v1/``). Only otherwise is a
            # suffix taken to mean "file name", in which case the parent is used.
            if target_dir.suffix and not literal_prefix.endswith(path_separators):
                target_dir = target_dir.parent
            target_dir.mkdir(parents=True, exist_ok=True)
            try:
                subprocess.check_call(['bash', str(script_path), str(target_dir)])
            except FileNotFoundError as err:
                raise RuntimeError('`bash` not found while running Gutenberg download script.') from err
            except subprocess.CalledProcessError as err:
                raise RuntimeError(
                    'Gutenberg download script failed. Inspect the output above or run "bash tools/download_gutenberg.sh" manually to diagnose.'
                ) from err
            # Re-check after the download; fall through to "missing" if still empty.
            matches = glob.glob(str(pattern), recursive=True)
            if matches:
                continue
        missing_patterns.append((split_name, pattern))
        print(f"No files matched pattern {pattern}")
    if missing_patterns:
        unresolved = '\n - '.join(f"{name}: {path}" for name, path in missing_patterns)
        raise FileNotFoundError(
            'Dataset assets missing. Patterns without matches:\n - ' + unresolved
        )


# EXPERIMENT_CONFIG may be a str or a Path; hand the checker a Path either way.
config_path_obj = EXPERIMENT_CONFIG
if not isinstance(config_path_obj, Path):
    config_path_obj = Path(config_path_obj)
ensure_dataset_assets(config_path_obj)









In [None]:
import os
from pathlib import Path

import yaml
from IPython.display import Markdown, display

from megacontext.notebook import format_dataset_summary
from tools.prepare_dataset import load_dataset_config, prepare_dataset_from_config

try:
    from tqdm.auto import tqdm
except ImportError:  # pragma: no cover
    tqdm = None

# Load the dataset configuration for the selected experiment. The code below
# reads its `splits` mapping and calls its `metadata_path()` method.
CONFIG_MODEL = load_dataset_config(EXPERIMENT_CONFIG)


def _resolve_dataset_path(split_name: str) -> Path:
    """Return the on-disk location expected for ``split_name``'s output.

    An absolute ``output_path`` is returned unchanged. A relative one is
    resolved against the experiment config's directory. When
    ``MEGACONTEXT_DATA_ROOT`` is set, the result is re-rooted under it: the
    part of the path relative to the grandparent of the config's directory is
    kept (or relative to the config directory itself when it has fewer than
    two ancestors). If the output lies outside that anchor, only the file name
    is kept.
    """
    configured = Path(CONFIG_MODEL.splits[split_name].output_path)
    if configured.is_absolute():
        return configured

    config_dir = EXPERIMENT_CONFIG.parent
    in_place = (config_dir / configured).resolve()
    override = os.environ.get('MEGACONTEXT_DATA_ROOT')
    if not override:
        return in_place

    anchor = config_dir.resolve()
    ancestors = anchor.parents
    if len(ancestors) >= 2:
        anchor = ancestors[1]
    try:
        tail = in_place.relative_to(anchor)
    except ValueError:
        tail = in_place.name
    return (Path(override).expanduser().resolve() / tail).resolve()


# Where each split's output is expected, honouring any data-root override.
expected_outputs = {split: _resolve_dataset_path(split) for split in CONFIG_MODEL.splits}

# Metadata is always looked up as YAML, whatever suffix the config model reports.
metadata_path = CONFIG_MODEL.metadata_path()
metadata_path = (
    metadata_path
    if metadata_path.suffix in {'.yaml', '.yml'}
    else metadata_path.with_suffix('.yaml')
).resolve()

force_rebuild = os.environ.get('MEGACONTEXT_FORCE_DATA_REBUILD', '').lower() in {'1', 'true', 'yes'}
missing_outputs = [split for split, shard in expected_outputs.items() if not shard.exists()]
metadata_missing = not metadata_path.exists()

# Explain why a (re)build is about to happen.
if force_rebuild:
    print('Forcing dataset rebuild because MEGACONTEXT_FORCE_DATA_REBUILD is set.')
if missing_outputs:
    print('Dataset outputs missing for splits: ' + ', '.join(missing_outputs))
if metadata_missing and not force_rebuild and not missing_outputs:
    print(f'Metadata not found at {metadata_path}; regenerating dataset summaries.')

run_prepare = any((force_rebuild, missing_outputs, metadata_missing))

DATASET_RESULT = None
if not run_prepare:
    # Everything is on disk: rebuild the result record from the stored metadata.
    print('Dataset shards already present; skipping prepare_dataset_from_config.')
    metadata = yaml.safe_load(metadata_path.read_text(encoding='utf-8')) or {}
    DATASET_RESULT = dict(
        config_path=str(EXPERIMENT_CONFIG),
        metadata_path=str(metadata_path),
        teacher_dtype=metadata.get('teacher_dtype'),
        splits=metadata.get('splits') or {},
    )
else:
    # A single-step bar brackets the whole preparation call.
    if tqdm is None:
        progress = None
        print('Preparing dataset (tqdm unavailable)...')
    else:
        progress = tqdm(total=1, desc='Preparing dataset', leave=False, bar_format='{l_bar}{bar}| {elapsed}')
    try:
        DATASET_RESULT = prepare_dataset_from_config(EXPERIMENT_CONFIG)
    finally:
        if progress is not None:
            progress.update(1)
            progress.close()

if not DATASET_RESULT['splits']:
    print('Warning: dataset summaries are empty. Consider rebuilding the dataset.')

display(Markdown(format_dataset_summary(DATASET_RESULT['splits'])))

# Choose the active split: keep a previous valid choice, else prefer 'train',
# else take the first configured split.
available_splits = list(CONFIG_MODEL.splits.keys())
if 'train' in available_splits:
    default_split = 'train'
else:
    default_split = available_splits[0]
SPLIT_NAME = globals().get('SPLIT_NAME', default_split)
SPLIT_NAME = SPLIT_NAME if SPLIT_NAME in available_splits else default_split


def _set_dataset_path(split_name: str) -> Path:
    """Resolve ``split_name``'s output path and publish it as the active dataset.

    The path is stored in the notebook global ``DATASET_PATH`` and, as a
    string, under ``DATASET_RESULT['dataset_path']``. It is also returned.
    """
    global DATASET_PATH
    resolved = _resolve_dataset_path(split_name)
    DATASET_PATH = resolved
    DATASET_RESULT['dataset_path'] = str(resolved)
    return resolved


# Publish the path of the split chosen above as the active dataset.
DATASET_PATH = _set_dataset_path(SPLIT_NAME)

# Widgets are usable only when ipywidgets imported. If the bootstrap cell that
# defines `ensure_widgets_ready` was never run, assume they are ready.
if widgets is None:
    widgets_ready = False
else:
    try:
        widgets_ready = ensure_widgets_ready()
    except NameError:
        widgets_ready = True

if not widgets_ready:
    print(f"Using split '{SPLIT_NAME}' at {DATASET_PATH}")
else:
    split_dropdown = widgets.Dropdown(
        options=list(zip(available_splits, available_splits)),
        value=SPLIT_NAME,
        description='Split:',
        layout=widgets.Layout(width='40%'),
    )
    split_message = widgets.Output()

    def _sync_split(change):
        """Switch the active split when the dropdown value changes."""
        if change.get('name') == 'value':
            chosen = change['new']
            globals()['SPLIT_NAME'] = chosen
            chosen_path = _set_dataset_path(chosen)
            with split_message:
                split_message.clear_output()
                display(Markdown(f"Using split **{chosen}** → `{chosen_path}`"))

    split_dropdown.observe(_sync_split, names='value')
    display(widgets.VBox([split_dropdown, split_message]))
    # Show the initial selection right away; later changes go through `_sync_split`.
    with split_message:
        split_message.clear_output()
        display(Markdown(f"Using split **{SPLIT_NAME}** → `{DATASET_PATH}`"))





### Sample Example

Peek at the first prepared context to sanity-check tokens and horizons.


In [None]:
import pyarrow.ipc as pa_ipc
from pathlib import Path
from IPython.display import Markdown, display

# Refuse to preview an output file that has not been written yet.
if not Path(DATASET_PATH).exists():
    missing_msg = (
        f"Dataset shard not found at {DATASET_PATH}. "
        "Run the prep cell first."
    )
    raise FileNotFoundError(missing_msg)

with pa_ipc.open_file(DATASET_PATH) as reader:
    if reader.num_record_batches == 0:
        print('Dataset is empty.')
    else:
        batch = reader.get_batch(0)
        context_col = 'context_input_ids'
        future_col = 'future_input_ids'
        # Check the schema instead of materialising the batch just to test for columns.
        available_columns = set(batch.schema.names)
        missing = [col for col in (context_col, future_col) if col not in available_columns]
        if missing:
            display(Markdown(
                f"Dataset preview skipped; missing columns: {', '.join(missing)}"
            ))
        elif batch.num_rows == 0:
            # Indexing row 0 of a zero-row batch would raise IndexError.
            print('First record batch has no rows; nothing to preview.')
        else:
            # Only the first example is shown, so convert just that row to Python
            # objects rather than the entire record batch.
            table_dict = batch.slice(0, 1).to_pydict()
            context_tokens = table_dict[context_col][0][:32]
            future_tokens = table_dict[future_col][0][:16]
            summary_html = (
                f"Context tokens (first 32): {context_tokens}<br><br>"
                f"Future tokens (first 16): {future_tokens}"
            )
            display(Markdown(summary_html))



## 2. Configure GistNet Training

When `gistnet.model.hidden_size` is set to `auto`, the hidden size is taken from the teacher embedding width reported during dataset prep; set `gistnet.model.hidden_size` to an explicit value to override.


In [None]:
import dataclasses
from copy import deepcopy

from megacontext.gistnet import BaseModelSettings, GistNetConfig, GistNetTrainingConfig

# Teacher embedding width recorded for the active split, if any. Dataset
# summaries can legitimately be empty or lack this field, so a failed lookup
# yields None instead of aborting. The value is only required when the config
# asks for `hidden_size: auto`, which raises a descriptive error below.
try:
    dataset_hidden_size = DATASET_RESULT['splits'][SPLIT_NAME]['teacher_hidden_size']
except (KeyError, TypeError):
    dataset_hidden_size = None

gistnet_cfg = experiment_cfg.get('gistnet', {})
# Work on copies so the loaded experiment config is never mutated.
model_dict = deepcopy(gistnet_cfg.get('model', {}))
if model_dict.get('hidden_size') == 'auto':
    if dataset_hidden_size:
        model_dict['hidden_size'] = dataset_hidden_size
    else:
        raise ValueError(
            'Dataset summary did not report a teacher hidden size; '
            'set gistnet.model.hidden_size explicitly.'
        )
MODEL_CONFIG = GistNetConfig(**model_dict)

training_dict = deepcopy(gistnet_cfg.get('training', {}))
TRAINING_CONFIG = GistNetTrainingConfig.from_dict(training_dict)
# Fall back to the experiment-level `base_model` section when the training
# section does not define one itself.
if TRAINING_CONFIG.base_model is None and 'base_model' in experiment_cfg:
    TRAINING_CONFIG = dataclasses.replace(
        TRAINING_CONFIG,
        base_model=BaseModelSettings.from_dict(experiment_cfg['base_model']),
    )
display(Markdown(format_training_config(TRAINING_CONFIG)))


### Optional Overrides

Adjust batch size or per-phase learning rates/steps without editing YAML.


In [None]:
# Interactive overrides for batch size, logging cadence and per-phase settings.
# Nothing changes until "Apply overrides" is clicked, which rebinds the global
# TRAINING_CONFIG to an updated copy.
if widgets is None:
    print(
        'ipywidgets not available; use dataclasses.replace(...) to '
        'override training settings manually.'
    )
else:
    batch_slider = widgets.IntSlider(
        value=TRAINING_CONFIG.batch_size,
        min=1,
        # Widen the range to include the configured value. A fixed upper bound
        # would make the slider clamp larger batch sizes, and applying the
        # overrides would then silently shrink the batch.
        max=max(128, TRAINING_CONFIG.batch_size),
        step=1,
        description='Batch size:',
    )
    log_every_slider = widgets.IntSlider(
        value=max(1, TRAINING_CONFIG.log_every_n_steps),
        min=1,
        max=max(1, TRAINING_CONFIG.log_every_n_steps * 2),
        step=1,
        description='Log every:',
    )
    # One (phase, VBox) pair per training phase; the VBox children are
    # [title, steps slider, LR slider, window slider] in that order.
    phase_widgets = []
    for phase in TRAINING_CONFIG.phases:
        steps_slider = widgets.IntSlider(
            value=phase.max_steps,
            min=1,
            max=max(phase.max_steps, 100),
            step=1,
            description='Steps',
        )
        lr_slider = widgets.FloatLogSlider(
            value=phase.lr,
            base=10,
            min=-6,
            max=0,
            step=0.1,
            description='LR',
        )
        # Window sizes move in whole blocks.
        window_slider = widgets.IntSlider(
            value=phase.window_tokens,
            min=MODEL_CONFIG.block_size,
            max=max(phase.window_tokens, MODEL_CONFIG.block_size * 64),
            step=MODEL_CONFIG.block_size,
            description='Window',
        )
        phase_box = widgets.VBox([
            widgets.HTML(f'<b>{phase.name}</b> ({phase.objective})'),
            steps_slider,
            lr_slider,
            window_slider,
        ])
        phase_widgets.append((phase, phase_box))
    apply_button = widgets.Button(description='Apply overrides', button_style='success')
    overrides_output = widgets.Output()

    def _apply_overrides(_):
        """Rebuild TRAINING_CONFIG from the current slider values and show it."""
        global TRAINING_CONFIG
        phases = []
        for base_phase, phase_box in phase_widgets:
            # Skip the HTML title; the remaining children are the three sliders.
            steps_slider, lr_slider, window_slider = phase_box.children[1:]
            phases.append(
                dataclasses.replace(
                    base_phase,
                    max_steps=int(steps_slider.value),
                    lr=float(lr_slider.value),
                    window_tokens=int(window_slider.value),
                )
            )
        TRAINING_CONFIG = dataclasses.replace(
            TRAINING_CONFIG,
            batch_size=int(batch_slider.value),
            log_every_n_steps=int(log_every_slider.value),
            phases=tuple(phases),
        )
        overrides_output.clear_output()
        with overrides_output:
            display(
                Markdown(
                    format_training_config(
                        TRAINING_CONFIG, heading='Updated Training Config'
                    )
                )
            )

    apply_button.on_click(_apply_overrides)
    controls = widgets.VBox(
        [batch_slider, log_every_slider]
        + [box for _, box in phase_widgets]
        + [apply_button, overrides_output]
    )
    display(controls)


## 3. Build Lightning Components

Adds a rich progress bar, a metrics tracker, and periodic checkpointing; select Weights & Biases in the Setup Console's Logging tab to stream logs.


In [None]:
from datetime import datetime
from pathlib import Path

from lightning.pytorch.callbacks import ModelCheckpoint, RichProgressBar  # type: ignore

from megacontext.gistnet import build_gistnet_experiment

# Record the training losses so they can be plotted once the run finishes.
metrics_callback = MetricsTracker(metric_keys=(
    'train/loss',
    'train/delta_loss',
    'train/gist_loss',
    'train/baseline_loss',
))

# Reuse RUNS_DIR when an earlier cell (or the user) already defined it;
# otherwise fall back to `<ARTIFACT_ROOT>/gistnet`.
RUNS_DIR = globals().get('RUNS_DIR', (ARTIFACT_ROOT / 'gistnet').resolve())
RUNS_DIR.mkdir(parents=True, exist_ok=True)

# Resolve the run name: prefer the user-supplied value, otherwise derive one
# from the experiment config stem plus a timestamp. Spaces are replaced so the
# name is convenient to use as a directory name.
run_name = (LOGGER_STATE.get('run_name') or '').strip()
if not run_name:
    run_name = f"{EXPERIMENT_CONFIG.stem}-{datetime.now():%Y%m%d-%H%M%S}"
run_name = run_name.replace(' ', '-')
RUN_NAME = run_name

# Every run gets its own directory with a `checkpoints/` subfolder.
RUN_DIR = (RUNS_DIR / run_name).resolve()
RUN_DIR.mkdir(parents=True, exist_ok=True)
CHECKPOINT_DIR = (RUN_DIR / 'checkpoints').resolve()
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
print(f'Artifacts for run `{RUN_NAME}` stored under {RUN_DIR}')

# Keep the three checkpoints with the lowest `train/loss` plus the most recent
# one; checkpointing runs on the same cadence as logging (at least every step).
checkpoint_callback = ModelCheckpoint(
    dirpath=str(CHECKPOINT_DIR),
    filename='step-{step:06d}',
    monitor='train/loss',
    mode='min',
    save_last=True,
    save_top_k=3,
    every_n_train_steps=max(1, TRAINING_CONFIG.log_every_n_steps),
    auto_insert_metric_name=False,
)

callbacks = [metrics_callback, RichProgressBar(), checkpoint_callback]

# Logger construction is best-effort: a RuntimeError is reported and training
# proceeds without an external logger.
LOGGER = None
try:
    LOGGER = build_logger(
        selection=LOGGER_STATE.get('selection', 'none'),
        project=LOGGER_STATE.get('project'),
        # Pass the resolved name (timestamp fallback, spaces replaced) so the
        # logger's run name always matches the artifact directory name.
        run_name=RUN_NAME,
        config={'config_path': str(EXPERIMENT_CONFIG)},
    )
except RuntimeError as exc:
    print(exc)

trainer_kwargs = {
    'accelerator': 'auto',
    'devices': 1,
    'default_root_dir': str(RUN_DIR),
}

TRAINER, MODULE, DATA_MODULE = build_gistnet_experiment(
    dataset_path=DATASET_PATH,
    model_config=MODEL_CONFIG,
    training=TRAINING_CONFIG,
    callbacks=callbacks,
    logger=LOGGER,
    trainer_kwargs=trainer_kwargs,
)

# Publish the run locations explicitly; the summary cell reads them back via
# `globals().get(...)`.
globals()['RUN_DIR'] = RUN_DIR
globals()['CHECKPOINT_DIR'] = CHECKPOINT_DIR
globals()['RUN_NAME'] = RUN_NAME
TRAINER


## 4. Launch Training

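Run this cell to start the Lightning loop. Progress appears below and (optionally) in WandB. To resume an earlier run, set `RESUME_CHECKPOINT` to an existing checkpoint path before executing the cell.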


In [None]:
from pathlib import Path

# Normalise the optional RESUME_CHECKPOINT global: an empty string or a
# missing value means "no resume"; anything else is coerced to a Path that
# must exist on disk.
resume_path = globals().get('RESUME_CHECKPOINT')
if resume_path is None or (isinstance(resume_path, str) and not resume_path):
    resume_path = None
else:
    if not isinstance(resume_path, Path):
        resume_path = Path(resume_path)
    if not resume_path.exists():
        raise FileNotFoundError(f'Resume checkpoint not found: {resume_path}')

print(
    f'Resuming from checkpoint: {resume_path}'
    if resume_path
    else 'Starting a new training run.'
)

TRAINER.fit(
    MODULE,
    DATA_MODULE,
    ckpt_path=str(resume_path) if resume_path else None,
)

# After training, capture the best and last checkpoint paths reported by the
# trainer's checkpoint callback; a missing or empty path is stored as None.
best_model_path = getattr(TRAINER.checkpoint_callback, 'best_model_path', None) or ''
best_path = Path(best_model_path) if best_model_path else None
last_model_path = getattr(TRAINER.checkpoint_callback, 'last_model_path', None) or ''
last_path = Path(last_model_path) if last_model_path else None

# LATEST_CHECKPOINT holds the best-scoring checkpoint; LAST_CHECKPOINT holds
# the most recently written one.
globals()['LATEST_CHECKPOINT'] = best_path
globals()['LAST_CHECKPOINT'] = last_path





## 5. Visualise Metrics

Plots the captured metrics (requires matplotlib).


In [None]:
# Plot the metrics gathered by the MetricsTracker callback that was registered
# with the trainer in the "Build Lightning Components" cell.
metrics_callback.plot(figsize=(7, 4))


## 6. Summarise & Save

Captures final metrics and writes a JSON summary under `artifacts/experiments/`.


In [None]:
from pathlib import Path
from IPython.display import Markdown

def _text_or_none(value):
    """Return ``str(value)`` for truthy values and ``None`` otherwise."""
    return str(value) if value else None


# Convert every metric reported by the trainer into a plain float.
final_metrics = {
    metric_name: float(metric_value)
    for metric_name, metric_value in TRAINER.callback_metrics.items()
}
# Attach the seed when one was set, without overriding a logged 'seed' metric.
if 'RUN_SEED' in globals():
    final_metrics.setdefault('seed', RUN_SEED)
display(Markdown(format_training_summary(final_metrics)))

# Collect optional run metadata; each of these may be absent if the
# corresponding earlier cell was not executed.
_notebook_ns = globals()
seed_value = _notebook_ns.get('RUN_SEED')
resume_path = _notebook_ns.get('RESUME_CHECKPOINT')
run_dir = _notebook_ns.get('RUN_DIR')
checkpoint_dir = _notebook_ns.get('CHECKPOINT_DIR')
latest_checkpoint = _notebook_ns.get('LATEST_CHECKPOINT')
last_checkpoint = _notebook_ns.get('LAST_CHECKPOINT')

# Paths are stringified for the summary; unset entries are recorded as None.
artifact_index = {
    'dataset_path': DATASET_RESULT.get('dataset_path'),
    'default_root_dir': _text_or_none(run_dir),
    'checkpoint_dir': _text_or_none(checkpoint_dir),
    'resume_from': _text_or_none(resume_path),
    'latest_checkpoint': _text_or_none(latest_checkpoint),
    'last_checkpoint': _text_or_none(last_checkpoint),
    'seed': seed_value,
}

summary_root = (ARTIFACT_ROOT / 'experiments').resolve()
SUMMARY_PATH = save_experiment_summary(
    output_dir=summary_root,
    config_path=EXPERIMENT_CONFIG,
    dataset_summary=DATASET_RESULT['splits'],
    training_metrics=final_metrics,
    artifacts=artifact_index,
)
SUMMARY_PATH
