# MegaContext Research Console

Interactively prepare datasets, train GistNet with PyTorch Lightning, and capture experiment artifacts. Key docs: [GistNet](https://brandf.github.io/MegaContext/architecture/components/GistNet), [GistNet Training](https://brandf.github.io/MegaContext/architecture/components/GistNet%20Training), [Telemetry](https://brandf.github.io/MegaContext/ops/Telemetry), [Alternating Optimization](https://brandf.github.io/MegaContext/ops/Alternating%20Optimization).


## Quick Start

When running in Google Colab, execute the bootstrap cell below to clone the repo and install dependencies in the current runtime. Local Jupyter environments can skip it.


In [None]:
# Colab bootstrap: clone (or fast-forward) the repo, install dependencies, and
# verify that `megacontext` and `lightning` import. Outside Colab it only prints
# a skip message.
import importlib
import os
import sys
from pathlib import Path

# A Colab runtime has `google.colab` already imported; use that as the detector.
COLAB = 'google.colab' in sys.modules
if COLAB:
    repo_url = 'https://github.com/brandf/MegaContext.git'
    workspace = Path('/content/MegaContext')
    # First run clones; later runs only fast-forward (`--ff-only` refuses to merge).
    if not workspace.exists():
        !git clone $repo_url $workspace
    else:
        !git -C $workspace pull --ff-only
    # NOTE(review): this path is hard-coded separately from `workspace`; keep them in sync.
    %cd /content/MegaContext
    print('Python executable:', sys.executable)
    %pip install --upgrade pip
    %pip install -r requirements.txt
    %pip install -e .
    %pip install pytorch-lightning lightning
    !python -m pip show pytorch-lightning lightning
    # Put `src/` on sys.path directly so the package imports in this kernel;
    # presumably the editable install is not always visible without a restart.
    src_path = workspace / 'src'
    if str(src_path) not in sys.path:
        sys.path.append(str(src_path))
    # Drop stale import-finder caches after installing/adding paths.
    importlib.invalidate_caches()
    # Fail the cell loudly if either import is broken, after printing the cause.
    try:
        import megacontext  # noqa: F401
        import lightning  # noqa: F401
    except Exception as exc:
        print('Import check failed:', exc)
        raise
    else:
        print('Colab environment ready.')
else:
    print('Colab bootstrap skipped (not running in google.colab).')



## 0. Environment Snapshot

Verify your runtime (GPU, dependencies, disk space) before launching long jobs.


In [None]:
from pathlib import Path

import yaml
from IPython.display import Markdown, display

from megacontext.notebook import (
    MetricsTracker,
    build_logger,
    collect_environment_report,
    format_config_markdown,
    format_dataset_summary,
    format_training_config,
    format_training_summary,
    render_environment_report,
    save_experiment_summary,
)

# Gather runtime facts (GPU, dependencies, disk space) and render them as Markdown.
ENV_REPORT = collect_environment_report()
env_report_md = render_environment_report(ENV_REPORT)
display(Markdown(env_report_md))


## 1. Choose Experiment Config

Click one of the buttons below to select a combined config (dataset + base model + training). The first config is selected by default.


In [None]:
from IPython.display import display, Markdown
import pandas as pd

try:
    import ipywidgets as widgets
except ImportError:
    widgets = None

CONFIG_ROOT = Path('configs')
available_configs = sorted(CONFIG_ROOT.glob('*.yaml'))
if not available_configs:
    raise RuntimeError('No experiment configs found under `configs/`.')

# One row per YAML file so the available experiments are visible at a glance.
config_table = pd.DataFrame(
    {
        'name': [cfg.stem for cfg in available_configs],
        'path': [str(cfg) for cfg in available_configs],
    }
)
display(config_table)

if widgets is None:
    print('ipywidgets not available; defaulting to the first config.')
    EXPERIMENT_CONFIG = available_configs[0]
else:
    selection_output = widgets.Output()

    def _show_selection(markdown_text):
        """Replace the selection panel's contents with *markdown_text*."""
        with selection_output:
            selection_output.clear_output()
            display(Markdown(markdown_text))

    def make_handler(path: Path):
        """Return a click callback that makes *path* the active experiment config."""
        def _on_click(_button):
            # Write to the notebook namespace so later cells see the choice.
            globals()['EXPERIMENT_CONFIG'] = path
            _show_selection(f"Selected **{path.stem}** (`{path}`)")
        return _on_click

    config_buttons = [
        widgets.Button(
            description=f"Use {cfg.stem}",
            layout=widgets.Layout(width='180px'),
        )
        for cfg in available_configs
    ]
    for btn, cfg in zip(config_buttons, available_configs):
        btn.on_click(make_handler(cfg))

    display(widgets.HBox(config_buttons), selection_output)
    first_config = available_configs[0]
    globals()['EXPERIMENT_CONFIG'] = first_config
    _show_selection(f"Defaulting to **{first_config.stem}** (`{first_config}`)")



In [None]:
from IPython.display import display

try:
    import ipywidgets as widgets
except ImportError:
    widgets = None

# Logger settings, read later by the Lightning build cell via LOGGER_STATE.get(...).
LOGGER_STATE = {'selection': 'none', 'project': 'megacontext-poc', 'run_name': ''}

# Fall back to the first config if the chooser cell above was not run.
if 'EXPERIMENT_CONFIG' not in globals():
    from pathlib import Path
    CONFIG_ROOT = Path('configs')
    configs = sorted(CONFIG_ROOT.glob('*.yaml'))
    if not configs:
        raise RuntimeError('No experiment configs found under `configs/`.')
    EXPERIMENT_CONFIG = configs[0]

if widgets is None:
    print('ipywidgets not available; logging defaults to disabled.')
else:
    def _half_width():
        """Fresh half-width layout (each widget gets its own Layout instance)."""
        return widgets.Layout(width='50%')

    logger_dropdown = widgets.Dropdown(
        options=[('Disabled', 'none'), ('Weights & Biases', 'wandb')],
        value='none',
        description='Logger:',
        layout=_half_width(),
    )
    project_text = widgets.Text(
        value=LOGGER_STATE['project'],
        description='Project:',
        layout=_half_width(),
    )
    run_text = widgets.Text(
        value=LOGGER_STATE['run_name'],
        placeholder='auto',
        description='Run name:',
        layout=_half_width(),
    )
    wandb_box = widgets.VBox([project_text, run_text])
    wandb_box.layout.display = 'none'

    def _update_logger_state(_change):
        """Mirror widget values into LOGGER_STATE and show WandB fields only when chosen."""
        LOGGER_STATE.update(
            selection=logger_dropdown.value,
            project=project_text.value,
            run_name=run_text.value,
        )
        show_wandb = LOGGER_STATE['selection'] == 'wandb'
        wandb_box.layout.display = 'flex' if show_wandb else 'none'

    for watched in (logger_dropdown, project_text, run_text):
        watched.observe(_update_logger_state, names='value')
    _update_logger_state(None)

    summary = widgets.HTML(
        value=f"<b>Config:</b> {EXPERIMENT_CONFIG.stem} (<code>{EXPERIMENT_CONFIG}</code>)"
    )
    display(summary, widgets.VBox([logger_dropdown, wandb_box]))



## 2. Preview Configuration

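Review the loaded settings before running later stages. To change them, edit the YAML file and re-run this cell. Dataset preparation receives the config *path*, so in-memory edits to `experiment_cfg` are not picked up there.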


In [None]:
# Load the selected experiment YAML. Later cells read its `gistnet` and
# `base_model` sections. Dataset prep is given EXPERIMENT_CONFIG (the path),
# not this dict, so edit the file itself to change settings.
experiment_cfg = yaml.safe_load(EXPERIMENT_CONFIG.read_text(encoding='utf-8'))
if experiment_cfg is None:
    # An empty YAML document parses to None; treat it as an empty config so the
    # `.get(...)` calls in later cells keep working.
    experiment_cfg = {}
if not isinstance(experiment_cfg, dict):
    raise TypeError(
        f'Expected a mapping at the top level of {EXPERIMENT_CONFIG}, '
        f'got {type(experiment_cfg).__name__}.'
    )
display(Markdown(format_config_markdown(experiment_cfg)))


## 3. Dataset Preparation

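Runs `tools.prepare_dataset.prepare_dataset_from_config` with tqdm progress bars and resolves the train shard location into `DATASET_PATH`. Later cells rely on the `DATASET_RESULT` and `DATASET_PATH` variables defined here, so do not skip this cell even if the shard already exists.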


In [None]:
import os

from tools.prepare_dataset import load_dataset_config, prepare_dataset_from_config

# Run dataset preparation for the selected config and show a per-split summary.
# DATASET_RESULT is reused later (teacher hidden size lookup, saved experiment summary).
DATASET_RESULT = prepare_dataset_from_config(EXPERIMENT_CONFIG)
display(Markdown(format_dataset_summary(DATASET_RESULT['splits'])))
# Only this split's shard is handed to training (via DATASET_PATH).
SPLIT_NAME = 'train'
# Resolve where the split's shard lives. Relative output paths are anchored at
# the config file's directory; if MEGACONTEXT_DATA_ROOT is set, the path is
# re-rooted under that directory instead.
# NOTE(review): this re-derives the path rather than reading it from
# DATASET_RESULT; presumably it mirrors tools.prepare_dataset — confirm both agree,
# otherwise DATASET_PATH may point somewhere other than where the shard was written.
dataset_cfg_model = load_dataset_config(EXPERIMENT_CONFIG)
split_cfg = dataset_cfg_model.splits[SPLIT_NAME]
output_path_cfg = Path(split_cfg.output_path)
if output_path_cfg.is_absolute():
    DATASET_PATH = output_path_cfg
else:
    base_dir = EXPERIMENT_CONFIG.parent
    default_output = (base_dir / output_path_cfg).resolve()
    data_root_override = (
        Path(os.environ['MEGACONTEXT_DATA_ROOT']).expanduser().resolve()
        if 'MEGACONTEXT_DATA_ROOT' in os.environ
        else None
    )
    if data_root_override is not None:
        resolved_base = base_dir.resolve()
        # NOTE(review): for a config at <repo>/configs/<name>.yaml (the layout the
        # chooser globs), resolved_base.parents[1] is the directory *containing*
        # the repo, so the re-rooted path keeps the repo folder name as a prefix.
        # Confirm this matches tools.prepare_dataset's resolution.
        repo_root = (
            resolved_base.parents[1]
            if len(resolved_base.parents) >= 2
            else resolved_base
        )
        try:
            relative = default_output.relative_to(repo_root)
        except ValueError:
            # Output lies outside repo_root: keep only the file name.
            relative = default_output.name
        DATASET_PATH = (data_root_override / relative).resolve()
    else:
        DATASET_PATH = default_output
# Record the resolved path so the experiment summary (section 8) can report it.
DATASET_RESULT['dataset_path'] = str(DATASET_PATH)
DATASET_RESULT


### Sample Example

Peek at the first prepared context to sanity-check tokens and horizons.


In [None]:
import pyarrow as pa
import pyarrow.ipc as pa_ipc

# Peek at row 0 of the prepared shard (an Arrow IPC file) without loading all of it.
if not Path(DATASET_PATH).exists():
    missing_msg = (
        f"Dataset shard not found at {DATASET_PATH}. "
        "Run the prep cell first."
    )
    raise FileNotFoundError(missing_msg)

row = None
# Open via a context manager so the file handle is released after the peek.
with pa.OSFile(str(DATASET_PATH), 'rb') as source:
    reader = pa_ipc.open_file(source)
    if reader.num_record_batches > 0:
        batch = reader.get_batch(0)
        # Slice to one row before converting: to_pydict() on the whole batch
        # would materialise every row as Python objects just to show row 0.
        # A zero-row batch is treated the same as an empty file.
        if batch.num_rows > 0:
            row = batch.slice(0, 1).to_pydict()

if row is None:
    print('Dataset is empty.')
else:
    context_tokens = row['context_input_ids'][0][:32]
    future_tokens = row['future_input_ids'][0][:16]
    summary_html = (
        f"Context tokens (first 32): {context_tokens}<br><br>"
        f"Future tokens (first 16): {future_tokens}"
    )
    display(Markdown(summary_html))



## 4. Configure GistNet Training

If `gistnet.model.hidden_size` is set to `auto`, it is replaced with the teacher embedding width reported during dataset prep. Set an explicit integer to override it.


In [None]:
import dataclasses
from copy import deepcopy

from megacontext.gistnet import BaseModelSettings, GistNetConfig, GistNetTrainingConfig

# Teacher embedding width recorded during dataset prep. Use .get so a summary
# without the key reaches the explicit ValueError below instead of a KeyError.
# (Assumes the per-split summary is a dict — TODO confirm.)
split_summary = DATASET_RESULT['splits'][SPLIT_NAME]
dataset_hidden_size = split_summary.get('teacher_hidden_size')

# `or {}` also covers sections that exist in YAML but are empty (parsed as None).
gistnet_cfg = experiment_cfg.get('gistnet') or {}
model_dict = deepcopy(gistnet_cfg.get('model') or {})
if model_dict.get('hidden_size') == 'auto':
    # 'auto' defers to the teacher width reported by dataset prep.
    if dataset_hidden_size:
        model_dict['hidden_size'] = dataset_hidden_size
    else:
        raise ValueError(
            'Dataset summary did not report a teacher hidden size; '
            'set gistnet.model.hidden_size explicitly.'
        )
MODEL_CONFIG = GistNetConfig(**model_dict)

training_dict = deepcopy(gistnet_cfg.get('training') or {})
TRAINING_CONFIG = GistNetTrainingConfig.from_dict(training_dict)
# Inherit the experiment-level base model when the training section sets none;
# a `base_model:` key with no value is treated as absent.
if TRAINING_CONFIG.base_model is None and experiment_cfg.get('base_model') is not None:
    TRAINING_CONFIG = dataclasses.replace(
        TRAINING_CONFIG,
        base_model=BaseModelSettings.from_dict(experiment_cfg['base_model']),
    )
display(Markdown(format_training_config(TRAINING_CONFIG)))


### Optional Overrides

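Adjust the batch size, logging interval, and per-phase steps, learning rate, and window size without editing YAML. Click **Apply overrides** to rebuild `TRAINING_CONFIG` from the sliders.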


In [None]:
import math

# Interactive overrides for TRAINING_CONFIG. Nothing changes until
# "Apply overrides" is clicked, which rebuilds TRAINING_CONFIG with
# dataclasses.replace.
if widgets is None:
    print(
        'ipywidgets not available; use dataclasses.replace(...) to '
        'override training settings manually.'
    )
else:
    # Slider ranges are widened to contain the configured values. ipywidgets
    # clamps out-of-range values, so "Apply overrides" would otherwise silently
    # rewrite untouched settings (e.g. a batch size above 128, an LR outside
    # [1e-6, 1], or max_steps below 1).
    batch_slider = widgets.IntSlider(
        value=TRAINING_CONFIG.batch_size,
        min=1,
        max=max(128, TRAINING_CONFIG.batch_size),
        step=1,
        description='Batch size:',
    )
    log_every_slider = widgets.IntSlider(
        value=max(1, TRAINING_CONFIG.log_every_n_steps),
        min=1,
        max=max(1, TRAINING_CONFIG.log_every_n_steps * 2),
        step=1,
        description='Log every:',
    )
    phase_widgets = []
    for phase in TRAINING_CONFIG.phases:
        steps_slider = widgets.IntSlider(
            value=phase.max_steps,
            min=min(1, phase.max_steps),
            max=max(phase.max_steps, 100),
            step=1,
            description='Steps',
        )
        # float() in case lr arrived as a string (PyYAML parses `3e-4` as str).
        lr_value = float(phase.lr)
        if lr_value > 0:
            # Integer exponent bounds keep the slider grid aligned with the
            # original [-6, 0] range while still covering the configured LR.
            lr_exponent = math.log10(lr_value)
            lr_min = min(-6, math.floor(lr_exponent))
            lr_max = max(0, math.ceil(lr_exponent))
        else:
            lr_min, lr_max = -6, 0
        lr_slider = widgets.FloatLogSlider(
            value=lr_value,
            base=10,
            min=lr_min,
            max=lr_max,
            step=0.1,
            description='LR',
        )
        window_slider = widgets.IntSlider(
            value=phase.window_tokens,
            min=MODEL_CONFIG.block_size,
            max=max(phase.window_tokens, MODEL_CONFIG.block_size * 64),
            step=MODEL_CONFIG.block_size,
            description='Window',
        )
        phase_box = widgets.VBox([
            widgets.HTML(f'<b>{phase.name}</b> ({phase.objective})'),
            steps_slider,
            lr_slider,
            window_slider,
        ])
        # Keep direct slider references rather than unpacking phase_box.children
        # later, so the box layout can change without breaking _apply_overrides.
        phase_widgets.append(
            (phase, phase_box, (steps_slider, lr_slider, window_slider))
        )
    apply_button = widgets.Button(description='Apply overrides', button_style='success')
    overrides_output = widgets.Output()

    def _apply_overrides(_):
        """Rebuild TRAINING_CONFIG from the current slider values and display it."""
        global TRAINING_CONFIG
        phases = [
            dataclasses.replace(
                base_phase,
                max_steps=int(steps.value),
                lr=float(lr.value),
                window_tokens=int(window.value),
            )
            for base_phase, _box, (steps, lr, window) in phase_widgets
        ]
        TRAINING_CONFIG = dataclasses.replace(
            TRAINING_CONFIG,
            batch_size=int(batch_slider.value),
            log_every_n_steps=int(log_every_slider.value),
            phases=tuple(phases),
        )
        overrides_output.clear_output()
        with overrides_output:
            display(
                Markdown(
                    format_training_config(
                        TRAINING_CONFIG, heading='Updated Training Config'
                    )
                )
            )

    apply_button.on_click(_apply_overrides)
    controls = widgets.VBox(
        [batch_slider, log_every_slider]
        + [box for _, box, _sliders in phase_widgets]
        + [apply_button, overrides_output]
    )
    display(controls)


## 5. Build Lightning Components

Adds a rich progress bar and metrics tracker; enable WandB above to stream logs.


In [None]:
from lightning.pytorch.callbacks import RichProgressBar  # type: ignore

from megacontext.gistnet import build_gistnet_experiment

# Metric series captured during training (plotted in section 7).
tracked_metric_keys = (
    'train/loss',
    'train/delta_loss',
    'train/gist_loss',
    'train/baseline_loss',
)
metrics_callback = MetricsTracker(metric_keys=tracked_metric_keys)
callbacks = [metrics_callback, RichProgressBar()]

# Build the logger chosen in the widgets above; a RuntimeError is reported and
# training proceeds without a logger.
LOGGER = None
logger_kwargs = {
    'selection': LOGGER_STATE.get('selection', 'none'),
    'project': LOGGER_STATE.get('project'),
    'run_name': LOGGER_STATE.get('run_name'),
    'config': {'config_path': str(EXPERIMENT_CONFIG)},
}
try:
    LOGGER = build_logger(**logger_kwargs)
except RuntimeError as exc:
    print(exc)

trainer_kwargs = {
    'accelerator': 'auto',
    'devices': 1,
    'default_root_dir': './artifacts/gistnet',
}
TRAINER, MODULE, DATA_MODULE = build_gistnet_experiment(
    dataset_path=DATASET_PATH,
    model_config=MODEL_CONFIG,
    training=TRAINING_CONFIG,
    callbacks=callbacks,
    logger=LOGGER,
    trainer_kwargs=trainer_kwargs,
)
TRAINER


## 6. Launch Training

Run this cell to start the Lightning loop. Progress appears below and (optionally) in WandB.


In [None]:
# Run the Lightning training loop. DATA_MODULE (presumably a LightningDataModule)
# is passed by keyword: Trainer.fit's second positional parameter is
# `train_dataloaders`, and Lightning only accepts a datamodule there by
# reassigning it internally.
TRAINER.fit(MODULE, datamodule=DATA_MODULE)


## 7. Visualise Metrics

Plots the captured metrics (requires matplotlib).


In [None]:
# Plot the metric series collected by the MetricsTracker callback during fit
# (needs matplotlib). Kept as the cell's last expression so its return value
# is displayed.
metrics_callback.plot(figsize=(7, 4))


## 8. Summarise & Save

Captures final metrics and writes a JSON summary under `artifacts/experiments/`.


In [None]:
from IPython.display import Markdown

# Convert Lightning's logged values to plain floats so they can be rendered and saved.
final_metrics = {
    name: float(value) for name, value in TRAINER.callback_metrics.items()
}
display(Markdown(format_training_summary(final_metrics)))

summary_artifacts = {
    'dataset_path': DATASET_RESULT.get('dataset_path'),
    'default_root_dir': TRAINER.default_root_dir,
}
SUMMARY_PATH = save_experiment_summary(
    output_dir=Path('artifacts') / 'experiments',
    config_path=EXPERIMENT_CONFIG,
    dataset_summary=DATASET_RESULT['splits'],
    training_metrics=final_metrics,
    artifacts=summary_artifacts,
)
SUMMARY_PATH
