# GistNet Training Guide

This notebook consolidates the minimal data and training pipeline for the Phase 2 gist compressor. It replaces the standalone `docs/gistnet.md` so the dataset prep, trainer commands, and logging tips live alongside the executable steps.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/brandf/MegaContext/blob/main/notebooks/gistnet.ipynb)


## Dataset fields

`tools/prepare_dataset.py` emits Arrow shards with the following columns:

- `input_ids` — L0 token ids for the block being compressed (`block_size` tokens).
- `attention_mask` — mask for the block (currently all ones).
- `context_input_ids` — flattened horizon window (`horizon` tokens) used when the teacher model produced cached embeddings.
- `context_attention_mask` — mask for the horizon window.
- `teacher_hidden` — cached teacher embeddings with shape `[block_size, teacher_hidden_size]` stored using the configured dtype (auto → float16 on T4, bfloat16 on bf16-capable GPUs, otherwise float32).
- `gist_target` — pooled target vector (mean of the teacher hidden states) emitted in the same dtype as `teacher_hidden`.

The metadata stored in `data/<dataset>/metadata.yaml` records the tokenizer, block size, horizon, teacher model/dtype, and per-split statistics.
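
The column descriptions above imply a few invariants that are cheap to check on a single row. The following is a minimal sketch, assuming a row has already been loaded into a plain Python dict of lists; `mean_pool` and `check_record` are hypothetical helpers for illustration, not functions from `tools/prepare_dataset.py`.

```python
def mean_pool(teacher_hidden):
    """Average per-token hidden states into one vector (one value per hidden dim)."""
    hidden_size = len(teacher_hidden[0])
    count = len(teacher_hidden)
    return [sum(row[d] for row in teacher_hidden) / count for d in range(hidden_size)]


def check_record(record, block_size, horizon, tol=1e-3):
    """Return a list of human-readable problems found in one dataset row."""
    problems = []
    if len(record["input_ids"]) != block_size:
        problems.append("input_ids length != block_size")
    if len(record["attention_mask"]) != len(record["input_ids"]):
        problems.append("attention_mask length != input_ids length")
    if len(record["context_input_ids"]) != horizon:
        problems.append("context_input_ids length != horizon")
    if len(record["context_attention_mask"]) != len(record["context_input_ids"]):
        problems.append("context_attention_mask length != context_input_ids length")
    if len(record["teacher_hidden"]) != block_size:
        problems.append("teacher_hidden rows != block_size")
    else:
        expected = mean_pool(record["teacher_hidden"])
        target = record["gist_target"]
        if len(target) != len(expected):
            problems.append("gist_target length != teacher_hidden_size")
        elif any(abs(a - b) > tol for a, b in zip(expected, target)):
            problems.append("gist_target is not the mean of teacher_hidden")
    return problems


# Toy row with block_size=2, horizon=4, teacher_hidden_size=3 (made-up values).
row = {
    "input_ids": [11, 12],
    "attention_mask": [1, 1],
    "context_input_ids": [11, 12, 13, 14],
    "context_attention_mask": [1, 1, 1, 1],
    "teacher_hidden": [[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]],
    "gist_target": [2.0, 3.0, 4.0],
}
print(check_record(row, block_size=2, horizon=4))
```

The tolerance is deliberately loose because `gist_target` may be stored in float16 or bfloat16.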


## Minimal sample dataset

If you just need a smoke test, generate the sample Arrow shard with:

```bash
uv run python -m tools.prepare_dataset --config configs/data/sample_text.yaml
```

This uses `sshleifer/tiny-gpt2` as the teacher and produces `data/sample_text/train.arrow`.


## 1. Clone the repository & install dependencies

Run these commands when starting from a fresh Colab runtime; skip them if the repository is already checked out locally.


In [None]:
!git clone https://github.com/brandf/MegaContext.git
%cd MegaContext


In [None]:
!git pull

In [None]:
!pip install -r requirements.txt
!pip install -e .[dev]

## 2. Download the Gutenberg subset

The helper script grabs a curated <1 GB slice of Project Gutenberg titles. Tweak `tools/download_gutenberg.sh` if you want a different reading list before running the cell below.


In [None]:
!bash tools/download_gutenberg.sh data/raw/gutenberg
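
If you edit the reading list, it can be worth confirming that the download still fits the intended size budget. The sketch below is one way to do that with the standard library; `dir_size_bytes` and `is_under_limit` are hypothetical helpers, and the limit of 10**9 bytes (a decimal gigabyte) is an assumption about what "<1 GB" means here.

```python
import os


def dir_size_bytes(root):
    """Sum the sizes of all regular files under root (symlinks are skipped)."""
    total = 0
    for dirpath, _dirnames, filenames in os.walk(root):
        for name in filenames:
            path = os.path.join(dirpath, name)
            if os.path.isfile(path) and not os.path.islink(path):
                total += os.path.getsize(path)
    return total


def is_under_limit(root, limit_bytes=10**9):
    """True when the directory tree is smaller than limit_bytes."""
    return dir_size_bytes(root) < limit_bytes


# Only report when the download directory from the cell above exists.
if os.path.isdir("data/raw/gutenberg"):
    size = dir_size_bytes("data/raw/gutenberg")
    print(f"{size / 1e6:.1f} MB, under limit: {is_under_limit('data/raw/gutenberg')}")
```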

## 3. Prepare the dataset shard

The `configs/data/gutenberg_sample.yaml` configuration reuses the same `sshleifer/tiny-gpt2` teacher with `block_size=32` and `horizon=64`. Adjust `teacher_device`, the horizon length, or dataset paths to match your hardware before executing the next cell.
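
To build intuition for how `block_size` and `horizon` interact, here is a rough sketch of one possible windowing scheme. It assumes non-overlapping `horizon`-token windows that are each cut into consecutive `block_size`-token blocks, with any trailing partial window dropped; the guide does not spell out the real scheme, so `tools/prepare_dataset.py` may window the text differently. `split_into_examples` is a hypothetical name.

```python
def split_into_examples(token_ids, block_size=32, horizon=64):
    """Return (context_window, block) pairs from a flat list of token ids.

    Assumptions of this sketch: windows do not overlap, horizon is a
    multiple of block_size, and a trailing partial window is dropped.
    """
    if horizon % block_size != 0:
        raise ValueError("horizon must be a multiple of block_size in this sketch")
    examples = []
    for start in range(0, len(token_ids) - horizon + 1, horizon):
        window = token_ids[start:start + horizon]
        for offset in range(0, horizon, block_size):
            examples.append((window, window[offset:offset + block_size]))
    return examples


tokens = list(range(200))  # stand-in for real token ids
examples = split_into_examples(tokens)
print(len(examples))  # 3 full windows x 2 blocks each = 6
```

Under these assumptions, doubling `horizon` while keeping `block_size` fixed doubles the number of blocks that share each context window.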


In [None]:
!rm -f data/gutenberg_sample/train.arrow

In [None]:
%run tools/prepare_dataset.py --config configs/data/gutenberg_sample.yaml


### Larger corpus option

For more realistic experiments (still under 1 GB total), rerun the download and prep steps manually:

```bash
bash tools/download_gutenberg.sh data/raw/gutenberg
uv run python -m tools.prepare_dataset --config configs/data/gutenberg_sample.yaml
```

The Gutenberg subset feeds into the same pipeline and produces `data/gutenberg_sample/train.arrow` for training.


## 4. Train the GistNet model

`tools/train_gistnet.py` now supports multi-phase schedules. The default example runs two phases:

- `pooling-pretrain` uses the fast `pooling_mse` objective for 200 steps to match the mean teacher hidden state per block.
- `delta-finetune` switches to `delta_nll` for another 200 steps, measuring how well gists preserve the frozen base model’s loss.
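
As a rough illustration of what the two objectives measure, the sketch below computes both on plain Python lists. The exact loss definitions in `tools/train_gistnet.py` are not shown in this guide, so treat this as an assumption: `pooling_mse` is taken to be the mean squared error against the mean teacher hidden state, and `delta_nll` the increase in mean negative log-likelihood when the base model sees gists instead of the original tokens.

```python
import math


def pooling_mse(gist, teacher_hidden):
    """MSE between a gist vector and the mean of the teacher hidden states."""
    hidden_size = len(gist)
    count = len(teacher_hidden)
    target = [sum(row[d] for row in teacher_hidden) / count for d in range(hidden_size)]
    return sum((g - t) ** 2 for g, t in zip(gist, target)) / hidden_size


def mean_nll(token_probs):
    """Mean negative log-likelihood, given the probability of each target token."""
    return -sum(math.log(p) for p in token_probs) / len(token_probs)


def delta_nll(probs_with_gist, probs_with_tokens):
    """How much worse (in nats per token) the base model does with gists."""
    return mean_nll(probs_with_gist) - mean_nll(probs_with_tokens)


teacher = [[1.0, 2.0], [3.0, 4.0]]          # block_size=2, hidden size 2 (toy values)
print(pooling_mse([3.0, 3.0], teacher))     # gist is off by 1 in one of two dims
print(delta_nll([0.25, 0.25], [0.5, 0.5]))  # gists halve each token probability
```

A ΔNLL of zero would mean the gists preserve the base model's loss exactly; positive values indicate information lost in compression.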

You can add, remove, or reorder phases in `training.phases`, adjusting `max_steps`, learning rates, or window sizes to taste. If you omit `phases`, the script falls back to a single objective selected via the config or the `--objective` flag. After switching machines or recreating the virtualenv, run `uv sync --extra dev` so that `pytest`, `pydantic`, and the other dev dependencies are installed into `.venv` before you invoke training or tests.
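
The phase logic described above can be pictured as a list of phases walked in order. This is a minimal sketch of that idea, not the trainer's actual code: the `Phase` dataclass, `resolve_phases`, and `phase_at` are hypothetical names, and only the phase names, objectives, and step counts come from the default example.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Phase:
    name: str
    objective: str
    max_steps: int


def resolve_phases(phases, fallback_objective, fallback_steps):
    """Return the configured phases, or one fallback phase if none are given."""
    if phases:
        return list(phases)
    return [Phase("single", fallback_objective, fallback_steps)]


def phase_at(phases, global_step):
    """Return the phase that owns a zero-based global step, or None past the end."""
    first_step = 0
    for phase in phases:
        if global_step < first_step + phase.max_steps:
            return phase
        first_step += phase.max_steps
    return None


default = [
    Phase("pooling-pretrain", "pooling_mse", 200),
    Phase("delta-finetune", "delta_nll", 200),
]
print(phase_at(default, 199).objective)  # last step of the first phase
print(phase_at(default, 200).objective)  # first step of the second phase
print(resolve_phases(None, "pooling_mse", 100))
```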

Run the trainer from the repository root:

```bash
uv run python -m tools.train_gistnet \
    --dataset data/sample_text/train.arrow \
    --config configs/runs/gistnet_example.yaml
```

The sample configuration enables sequential pooling and ΔNLL training. It matches the Gutenberg shard defaults (`block_size=32`, hidden size 960) and targets the MobileLLM teacher. For a quicker baseline run, toggle individual phases or override the objective at the CLI (e.g., `--objective pooling_mse`). Reduce `max_steps` or `batch_size` if you need a faster smoke run on smaller GPUs. In this notebook we invoke the same script with extra flags so that metrics and plots persist automatically.



In [None]:
%run tools/train_gistnet.py \
    --dataset data/gutenberg_sample/train.arrow \
    --config configs/runs/gistnet_example.yaml \
    --metrics-path artifacts/gistnet/metrics.json \
    --save-plot artifacts/gistnet/loss.png

### Logging & visualisation

- Progress defaults to a `tqdm` bar in notebook environments; disable it with `--no-tqdm` if you prefer plain logs.
- Add `--metrics-path artifacts/gistnet/metrics.json` to dump raw losses for custom plotting or notebooks.
- Use `--save-plot artifacts/gistnet/loss.png` to emit a ready-made curve (falls back gracefully if `matplotlib` is missing).
- CLI runs on infra like Novita can pass `--use-wandb --wandb-project <name>` to stream metrics without notebook changes.
- When the script detects a notebook runtime (Colab, JupyterLab), it renders the loss curve inline and still saves the PNG so headless runs can inspect it later.
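
For custom plotting, the raw losses from `--metrics-path` are often easier to read after smoothing. The sketch below shows one way to do that. The layout of `metrics.json` is not documented in this guide, so the top-level `"loss"` key used by `load_losses` is purely an assumption; inspect your own file and adjust the key. Both helper names are hypothetical.

```python
import json


def load_losses(path, key="loss"):
    """Read a list of per-step losses from a JSON file (key name is an assumption)."""
    with open(path, encoding="utf-8") as handle:
        payload = json.load(handle)
    return [float(value) for value in payload[key]]


def ema(values, alpha=0.1):
    """Exponential moving average; higher alpha follows the raw curve more closely."""
    smoothed = []
    running = None
    for value in values:
        running = value if running is None else alpha * value + (1 - alpha) * running
        smoothed.append(running)
    return smoothed


# Made-up losses, just to show the smoothing behaviour.
print(ema([0.0, 2.0, 2.0], alpha=0.5))
```

The smoothed list has the same length as the input, so it can be plotted against the same step axis as the raw curve.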


## 5. Next steps

- Run the ΔNLL smoke eval (coming in Task 2.4).
- Push metrics & checkpoints to W&B or Novita storage with `--use-wandb`.
- Swap `configs/runs/gistnet_example.yaml` for a larger hidden size when you have a bigger teacher.
- Benchmark `pooling_mse` vs `delta_nll` objectives to understand substitutability trade-offs.

