<a href="https://colab.research.google.com/github/gileshall/axonet/blob/main/notebooks/axonet_training_tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Axonet End-to-End Training Tutorial

This notebook walks through the full axonet training pipeline:

1. **Install** axonet and dependencies
2. **Download neurons** from NeuroMorpho.Org
3. **Generate dataset** — render multi-view images with segmentation masks and depth maps
4. **Train VAE (Stage 1)** — train a SegVAE2D with variational skip connections
5. **Fine-tune CLIP (Stage 2)** — contrastive learning to align neuron images with text descriptions
6. **Evaluate CLIP** — retrieval metrics, zero-shot classification, and t-SNE visualization

**Runtime requirement:** Set your Colab runtime to **GPU** (T4 recommended):
`Runtime > Change runtime type > T4 GPU`

## 1. Configuration

All configurable variables are defined here. The defaults use a small demo scale (~50 neurons). Increase `N_NEURONS` and epoch counts for real training.

In [None]:
# ---- Scale ----
N_NEURONS = 50            # Number of neurons to download (demo scale)
N_VIEWS = 24              # Views per neuron (matches PCA default: 6 canonical + 12 biased + 6 random)
VAL_RATIO = 0.15          # Fraction held out for validation
IMAGE_SIZE = 512          # Rendered image resolution

# ---- VAE (Stage 1) ----
BATCH_SIZE_VAE = 8
MAX_EPOCHS_VAE = 10
LR_VAE = 1e-4
KLD_WEIGHT = 0.1

# ---- CLIP (Stage 2) ----
BATCH_SIZE_CLIP = 64
MAX_EPOCHS_CLIP = 10
LR_CLIP = 1e-4
CLIP_EMBED_DIM = 512
TEXT_ENCODER = "distilbert-base-uncased"
TEMPERATURE = 0.07

# ---- Paths ----
WORK_DIR       = "/content/axonet_tutorial"
NEURON_DIR     = f"{WORK_DIR}/neurons"
DATASET_DIR    = f"{WORK_DIR}/dataset"
STAGE1_CKPT_DIR = f"{WORK_DIR}/checkpoints/stage1"
STAGE2_CKPT_DIR = f"{WORK_DIR}/checkpoints/clip"
LOG_DIR        = f"{WORK_DIR}/logs"
EVAL_DIR       = f"{WORK_DIR}/eval_results"

## 2. Install

Install system libraries for headless OpenGL rendering (EGL) and the axonet package with CLIP dependencies.

In [None]:
# System deps for headless OpenGL (moderngl EGL backend)
!apt-get -qq install libegl1-mesa-dev libgles2-mesa-dev > /dev/null 2>&1

# Install axonet with CLIP extras
!pip install -q "axonet[clip] @ https://github.com/gileshall/axonet/archive/refs/heads/main.zip"

# Verify import
import axonet
print("axonet imported successfully")

In [None]:
import torch

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("WARNING: No GPU detected. Training will be very slow.")
    print("Go to Runtime > Change runtime type > T4 GPU")

## 3. Download Neurons

Download SWC morphology files from [NeuroMorpho.Org](https://neuromorpho.org). We query for mouse neurons and fetch both the standardized SWC files and morphometry measurements.
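
For orientation, SWC is a plain-text format: lines starting with `#` are comments, and every other line describes one sample point with seven whitespace-separated fields (id, type, x, y, z, radius, parent id). The sketch below parses that layout with the standard library only. The names `SWCNode` and `parse_swc` are illustrative and not part of axonet, and the label table covers only the four standard type codes.

```python
from collections import Counter
from typing import NamedTuple


class SWCNode(NamedTuple):
    node_id: int
    type_id: int
    x: float
    y: float
    z: float
    radius: float
    parent_id: int  # -1 marks the root node


# Standard SWC structure identifiers; other codes are reported as "other"
SWC_TYPES = {1: "soma", 2: "axon", 3: "basal dendrite", 4: "apical dendrite"}


def parse_swc(text):
    """Parse SWC text into a list of SWCNode records (hypothetical helper)."""
    nodes = []
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue  # skip blank lines and header comments
        fields = line.split()
        if len(fields) < 7:
            continue  # skip malformed rows
        nodes.append(SWCNode(
            int(fields[0]), int(fields[1]),
            float(fields[2]), float(fields[3]), float(fields[4]),
            float(fields[5]), int(fields[6]),
        ))
    return nodes


# Toy morphology: one soma sample, two dendrite samples, one axon sample
example = """# toy morphology
1 1 0.0 0.0 0.0 5.0 -1
2 3 5.0 0.0 0.0 1.0 1
3 3 10.0 2.0 0.0 0.8 2
4 2 0.0 -6.0 0.0 0.5 1
"""
nodes = parse_swc(example)
type_counts = Counter(SWC_TYPES.get(n.type_id, "other") for n in nodes)
print(dict(type_counts))
```

The type codes in column two are what the class-ID segmentation masks in the dataset step are built around, so it is worth checking that a downloaded file actually contains labeled compartments.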

In [None]:
!python -m axonet.utils.neuromorpho_bulk \
    --query 'species:mouse' \
    --out {NEURON_DIR} \
    --max-pages 1 \
    --page-size {N_NEURONS} \
    --find \
    --fetch-morphometry \
    --insecure

In [None]:
import json
from pathlib import Path
from collections import Counter

swc_dir = Path(NEURON_DIR) / "swc"
swc_files = sorted(swc_dir.glob("*.swc")) + sorted(swc_dir.glob("*.SWC"))
print(f"Downloaded {len(swc_files)} SWC files")

# Parse metadata for distribution info
metadata_path = Path(NEURON_DIR) / "metadata.jsonl"
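cell_types, regions = [], []


def first_label(value):
    """Return the first item of a list (or the value itself), defaulting to 'unknown'."""
    if isinstance(value, list):
        value = value[0] if value else None
    return value or "unknown"


if metadata_path.exists():
    with open(metadata_path) as f:
        for line in f:
            m = json.loads(line)
            # Fields may be stored as a plain string or as a list of labels
            cell_types.append(first_label(m.get("cell_type")))
            regions.append(first_label(m.get("brain_region")))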

    print(f"\nCell type distribution (top 10):")
    for ct, n in Counter(cell_types).most_common(10):
        print(f"  {ct}: {n}")

    print(f"\nBrain region distribution (top 10):")
    for br, n in Counter(regions).most_common(10):
        print(f"  {br}: {n}")

In [None]:
# Preview first SWC file
if swc_files:
    sample = swc_files[0]
    print(f"Sample: {sample.name}")
    print("-" * 60)
    with open(sample) as f:
        for i, line in enumerate(f):
            if i >= 15:
                print("...")
                break
            print(line.rstrip())

## 4. Generate Dataset

Render multi-view images of each neuron using PCA-guided camera placement:
- **6 canonical views** along principal component axes
- **12 biased views** concentrated near the PC1-PC2 plane (largest projected area)
- **6 random views** for diversity
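
One way to picture the canonical views: compute the principal axes of the neuron's node coordinates and look along each axis in both directions, which gives six camera directions. The following is a minimal, dependency-free sketch of that idea using power iteration. It illustrates the general technique under that assumption and is not axonet's actual implementation; `principal_axes` and `canonical_view_directions` are hypothetical helper names.

```python
import math
from itertools import product


def normalize(v):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def covariance(points):
    """3x3 covariance matrix of a list of (x, y, z) points."""
    n = len(points)
    mean = [sum(p[i] for p in points) / n for i in range(3)]
    cov = [[0.0] * 3 for _ in range(3)]
    for p in points:
        d = [p[i] - mean[i] for i in range(3)]
        for i in range(3):
            for j in range(3):
                cov[i][j] += d[i] * d[j] / n
    return cov


def dominant_axis(mat, iters=200):
    """Power iteration: converges to the eigenvector with the largest eigenvalue."""
    v = normalize([1.0, 0.7, 0.3])  # arbitrary fixed start vector
    for _ in range(iters):
        v = normalize([sum(mat[i][j] * v[j] for j in range(3)) for i in range(3)])
    return v


def principal_axes(points):
    """Return (PC1, PC2, PC3); assumes three distinct, nonzero variances."""
    cov = covariance(points)
    pc1 = dominant_axis(cov)
    # Deflate: subtract the PC1 component so the next dominant axis is PC2
    lam = sum(pc1[i] * cov[i][j] * pc1[j] for i in range(3) for j in range(3))
    deflated = [[cov[i][j] - lam * pc1[i] * pc1[j] for j in range(3)] for i in range(3)]
    pc2 = dominant_axis(deflated)
    # PC3 is orthogonal to both, so take the cross product
    pc3 = [
        pc1[1] * pc2[2] - pc1[2] * pc2[1],
        pc1[2] * pc2[0] - pc1[0] * pc2[2],
        pc1[0] * pc2[1] - pc1[1] * pc2[0],
    ]
    return pc1, pc2, normalize(pc3)


def canonical_view_directions(points):
    """Six view directions: plus and minus each principal axis."""
    views = []
    for axis in principal_axes(points):
        views.append(axis)
        views.append([-x for x in axis])
    return views


# Toy "neuron": a box that is long in x, medium in y, thin in z
points = [(sx * 10.0, sy * 3.0, sz * 1.0) for sx, sy, sz in product((-1, 1), repeat=3)]
views = canonical_view_directions(points)
for direction in views:
    print([round(c, 3) for c in direction])
```

Looking along PC3 (the thinnest axis) shows the PC1-PC2 plane face-on, which is why views near that direction tend to have the largest projected area.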

Each view produces four outputs:
- `mask_bw` — binary silhouette (VAE input)
- `mask` — class-ID segmentation map (VAE target)
- `mask_color` — colorized segmentation (visualization)
- `depth` — depth map (VAE target)

Data is split into train/val at the neuron level.
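
Splitting at the neuron level means every view of a given neuron lands in the same split, so validation images never show a neuron the model has already seen from another angle. A minimal sketch of that idea follows; `split_by_neuron` is a hypothetical helper, not the dataset generator's own code, and the file names are made up for the demo.

```python
import random


def split_by_neuron(neuron_ids, val_ratio=0.15, seed=42):
    """Assign whole neurons (not individual views) to train or val."""
    unique = sorted(set(neuron_ids))
    rng = random.Random(seed)
    rng.shuffle(unique)
    # Hold out at least one neuron whenever there is more than one
    n_val = max(1, int(len(unique) * val_ratio)) if len(unique) > 1 else 0
    return set(unique[n_val:]), set(unique[:n_val])


# Demo scale from the configuration cell: 50 neurons x 24 views each
samples = [(f"neuron_{i:03d}.swc", view) for i in range(50) for view in range(24)]
train_ids, val_ids = split_by_neuron([name for name, _ in samples], val_ratio=0.15)

train = [s for s in samples if s[0] in train_ids]
val = [s for s in samples if s[0] in val_ids]
print(f"train: {len(train)} views from {len(train_ids)} neurons")
print(f"val:   {len(val)} views from {len(val_ids)} neurons")
```

A per-image random split would instead scatter views of one neuron across both sets and make validation metrics look better than they are.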

In [None]:
!python -m axonet.training.dataset_generator \
    --swc-dir {NEURON_DIR}/swc \
    --out {DATASET_DIR} \
    --views {N_VIEWS} \
    --sampling pca \
    --adaptive-framing \
    --width {IMAGE_SIZE} \
    --height {IMAGE_SIZE} \
    --val-ratio {VAL_RATIO} \
    --margin 0.40 \
    --supersample-factor 2

In [None]:
import json
from pathlib import Path

def count_manifest(path):
    """Count samples and unique neurons in a manifest."""
    ids = set()
    n = 0
    with open(path) as f:
        for line in f:
            entry = json.loads(line)
            n += 1
            ids.add(entry.get("swc", ""))
    return n, len(ids)

train_manifest = Path(DATASET_DIR) / "manifest_train.jsonl"
val_manifest = Path(DATASET_DIR) / "manifest_val.jsonl"

# Default to zero so the total still prints if one manifest is missing
n_train = u_train = n_val = u_val = 0

if train_manifest.exists():
    n_train, u_train = count_manifest(train_manifest)
    print(f"Train: {n_train} samples from {u_train} neurons")

if val_manifest.exists():
    n_val, u_val = count_manifest(val_manifest)
    print(f"Val:   {n_val} samples from {u_val} neurons")

print(f"Total: {n_train + n_val} samples from {u_train + u_val} neurons")

In [None]:
import json
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from pathlib import Path

# Show a grid: 3 neurons x 3 columns (mask_bw, mask_color, depth)
manifest_path = Path(DATASET_DIR) / "manifest_train.jsonl"
entries = []
with open(manifest_path) as f:
    for line in f:
        entries.append(json.loads(line))

# Pick one view from 3 different neurons
seen_swc = set()
selected = []
for e in entries:
    swc_name = e.get("swc", "")
    if swc_name not in seen_swc:
        seen_swc.add(swc_name)
        selected.append(e)
    if len(selected) >= 3:
        break

fig, axes = plt.subplots(len(selected), 3, figsize=(12, 4 * len(selected)))
if len(selected) == 1:
    axes = [axes]

columns = ["mask_bw", "mask_color", "depth"]
titles = ["Binary Mask (input)", "Segmentation (color)", "Depth Map"]

for row, entry in enumerate(selected):
    for col, (key, title) in enumerate(zip(columns, titles)):
        img_path = Path(DATASET_DIR) / entry[key]
        img = mpimg.imread(str(img_path))
        ax = axes[row][col]
        cmap = "gray" if key != "mask_color" else None
        ax.imshow(img, cmap=cmap)
        if row == 0:
            ax.set_title(title, fontsize=12)
        ax.set_ylabel(Path(entry["swc"]).stem[:30], fontsize=9)
        ax.set_xticks([])
        ax.set_yticks([])

plt.tight_layout()
plt.show()

## 5. Train VAE — Stage 1

Train the **SegVAE2D** model, a variational U-Net with:
- A global variational bottleneck
- Variational skip connections at each encoder level (preventing information bypass)
- Dual-head output: semantic segmentation + depth prediction

The encoder learned here becomes the image backbone for CLIP Stage 2.
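
The `KLD_WEIGHT` setting scales the KL-divergence term that keeps each latent distribution close to a standard normal. The sketch below shows the closed-form KL for a diagonal Gaussian and the reparameterization trick in plain Python. The way the terms are combined in `total_loss` is the common weighted-sum formulation and is an assumption here; the notebook does not show how axonet's trainer composes its loss, and all function names are illustrative.

```python
import math
import random


def kl_divergence(mu, logvar):
    """KL(N(mu, sigma^2) || N(0, 1)), summed over latent dimensions."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))


def reparameterize(mu, logvar, rng):
    """Sample z = mu + sigma * eps with eps ~ N(0, 1), keeping mu and sigma differentiable."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, logvar)]


def total_loss(seg_loss, depth_loss, kld, kld_weight=0.1):
    """Assumed combination: reconstruction terms plus a weighted KL penalty."""
    return seg_loss + depth_loss + kld_weight * kld


rng = random.Random(0)
mu, logvar = [1.0, -0.5], [0.0, -1.0]
z = reparameterize(mu, logvar, rng)
kld = kl_divergence(mu, logvar)
print(f"z = {[round(v, 3) for v in z]}, KL = {kld:.3f}")
print(f"loss = {total_loss(seg_loss=1.0, depth_loss=0.5, kld=kld):.3f}")
```

A larger weight pushes latents toward the prior at the cost of reconstruction quality; a smaller one does the opposite.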

In [None]:
!python -m axonet.training.trainer \
    --data-dir {DATASET_DIR} \
    --batch-size {BATCH_SIZE_VAE} \
    --lr {LR_VAE} \
    --max-epochs {MAX_EPOCHS_VAE} \
    --kld-weight {KLD_WEIGHT} \
    --skip-mode variational \
    --base-channels 64 \
    --latent-channels 128 \
    --num-classes 6 \
    --precision 32 \
    --save-dir {STAGE1_CKPT_DIR} \
    --log-dir {LOG_DIR}/stage1 \
    --early-stopping \
    --seed 42

In [None]:
import glob

# Find best checkpoint (prefer filename containing 'best', fall back to last)
ckpt_files = sorted(glob.glob(f"{STAGE1_CKPT_DIR}/*.ckpt"))
STAGE1_BEST = None
for f in ckpt_files:
    if "best" in f.lower():
        STAGE1_BEST = f
        break
if STAGE1_BEST is None and ckpt_files:
    STAGE1_BEST = ckpt_files[-1]  # last checkpoint

print(f"Stage 1 checkpoint: {STAGE1_BEST}")
print(f"All checkpoints: {ckpt_files}")

In [None]:
%load_ext tensorboard
%tensorboard --logdir {LOG_DIR}/stage1

In [None]:
import torch
import json
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from pathlib import Path
from axonet.models.d3_swc_vae import load_model

device = "cuda" if torch.cuda.is_available() else "cpu"
model = load_model(Path(STAGE1_BEST), device=device)
model.eval()

# Load a few val samples
val_manifest = Path(DATASET_DIR) / "manifest_val.jsonl"
val_entries = []
with open(val_manifest) as f:
    for line in f:
        val_entries.append(json.loads(line))

n_show = min(3, len(val_entries))
fig, axes = plt.subplots(n_show, 4, figsize=(16, 4 * n_show))
if n_show == 1:
    axes = [axes]

col_titles = ["Input (mask_bw)", "GT Segmentation", "Predicted Seg", "Predicted Depth"]

for row in range(n_show):
    entry = val_entries[row]

    # Load input
    input_path = Path(DATASET_DIR) / entry["mask_bw"]
    input_img = mpimg.imread(str(input_path))
    if input_img.ndim == 3:
        input_img = np.mean(input_img, axis=2)
    input_tensor = torch.from_numpy(input_img.astype(np.float32)).unsqueeze(0).unsqueeze(0).to(device)

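    # Load GT segmentation. mpimg.imread returns PNGs as floats in [0, 1],
    # so rescale to integer class IDs (assumes 8-bit single-channel masks).
    gt_seg_path = Path(DATASET_DIR) / entry["mask"]
    gt_seg = mpimg.imread(str(gt_seg_path))
    if gt_seg.ndim == 3:
        gt_seg = gt_seg[..., 0]
    if np.issubdtype(gt_seg.dtype, np.floating):
        gt_seg = np.round(gt_seg * 255).astype(np.uint8)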

    # Run inference
    with torch.no_grad():
        out = model(input_tensor)

    pred_seg = out["seg_logits"].argmax(dim=1).squeeze().cpu().numpy()
    pred_depth = out["depth"].squeeze().cpu().numpy()

    # Plot
    axes[row][0].imshow(input_img, cmap="gray")
    axes[row][1].imshow(gt_seg, cmap="tab10", vmin=0, vmax=5)
    axes[row][2].imshow(pred_seg, cmap="tab10", vmin=0, vmax=5)
    axes[row][3].imshow(pred_depth, cmap="magma")

    for col in range(4):
        if row == 0:
            axes[row][col].set_title(col_titles[col], fontsize=11)
        axes[row][col].set_xticks([])
        axes[row][col].set_yticks([])

plt.tight_layout()
plt.show()

# Free GPU memory before Stage 2
del model
torch.cuda.empty_cache()

## 6. Fine-tune CLIP — Stage 2

Train a CLIP-style model that aligns the frozen VAE encoder's image embeddings with text descriptions generated from neuron metadata.

- The VAE encoder is frozen; only projection heads are trained
- Text encoder: DistilBERT (from `sentence-transformers`)
- Loss: InfoNCE contrastive loss with learnable temperature
- Text descriptions are auto-generated from metadata (cell type, brain region, species, morphometry)
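
To make the contrastive objective concrete, here is a small pure-Python sketch of a symmetric InfoNCE loss: within a batch, image `i` and text `i` form the positive pair and every other combination is a negative. This illustrates the standard formulation with a fixed temperature; it is not taken from axonet's `clip_trainer`, and the function names are hypothetical.

```python
import math


def l2_normalize(v):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def log_softmax(row):
    """Numerically stable log-softmax of a list of logits."""
    m = max(row)
    log_sum = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_sum for x in row]


def info_nce(image_embs, text_embs, temperature=0.07):
    """Symmetric InfoNCE: image i and text i are the matching pair."""
    imgs = [l2_normalize(v) for v in image_embs]
    txts = [l2_normalize(v) for v in text_embs]
    n = len(imgs)
    # Cosine similarities scaled by 1 / temperature
    logits = [[sum(a * b for a, b in zip(i, t)) / temperature for t in txts] for i in imgs]
    # Image-to-text: each row should concentrate its mass on the diagonal
    i2t = -sum(log_softmax(logits[k])[k] for k in range(n)) / n
    # Text-to-image: the same cross-entropy on the transposed matrix
    columns = [[logits[r][c] for r in range(n)] for c in range(n)]
    t2i = -sum(log_softmax(columns[k])[k] for k in range(n)) / n
    return 0.5 * (i2t + t2i)


aligned = info_nce([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]])
swapped = info_nce([[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]])
print(f"aligned pairs: {aligned:.4f}, swapped pairs: {swapped:.4f}")
```

A lower temperature sharpens the softmax and penalizes near-misses more strongly. With a learnable temperature, that value is optimized together with the projection heads rather than held fixed as it is here.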

In [None]:
!python -m axonet.training.clip_trainer \
    --stage1-checkpoint {STAGE1_BEST} \
    --data-dir {DATASET_DIR} \
    --metadata {NEURON_DIR}/metadata.jsonl \
    --source neuromorpho \
    --batch-size {BATCH_SIZE_CLIP} \
    --clip-embed-dim {CLIP_EMBED_DIM} \
    --temperature {TEMPERATURE} \
    --learnable-temperature \
    --text-encoder {TEXT_ENCODER} \
    --max-epochs {MAX_EPOCHS_CLIP} \
    --lr {LR_CLIP} \
    --save-dir {STAGE2_CKPT_DIR} \
    --log-dir {LOG_DIR}/clip \
    --early-stopping \
    --seed 42

In [None]:
import glob

clip_ckpts = sorted(glob.glob(f"{STAGE2_CKPT_DIR}/*.ckpt"))
CLIP_BEST = None
for f in clip_ckpts:
    if "best" in f.lower():
        CLIP_BEST = f
        break
if CLIP_BEST is None and clip_ckpts:
    CLIP_BEST = clip_ckpts[-1]

print(f"CLIP checkpoint: {CLIP_BEST}")
print(f"All checkpoints: {clip_ckpts}")

In [None]:
%tensorboard --logdir {LOG_DIR}/clip

## 7. Evaluate CLIP

Run a comprehensive evaluation:
- **Retrieval R@k**: how often the correct text/image is in the top-k results
- **Zero-shot classification**: classify neurons by cell type and brain region using text prompts
- **Novel query retrieval**: test with unseen text queries
- **t-SNE visualization**: 2D embedding space colored by cell type and brain region

Multi-pose images are aggregated to per-neuron embeddings via mean pooling.
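
The two ideas above, mean pooling over views and Recall@k, can be sketched in a few lines of plain Python. This is an illustrative version under the assumption that query `i` has exactly one correct candidate, at index `i`; `mean_pool` and `recall_at_k` are hypothetical names, not the evaluator's API, and this version returns a fraction rather than a percentage.

```python
def mean_pool(view_embeddings):
    """Average per-view embeddings into a single per-neuron embedding."""
    n = len(view_embeddings)
    dim = len(view_embeddings[0])
    return [sum(v[d] for v in view_embeddings) / n for d in range(dim)]


def recall_at_k(similarity, k):
    """Fraction of queries whose correct match (index i for query i) ranks in the top k.

    similarity[i][j] is the score between query i and candidate j.
    """
    hits = 0
    for i, row in enumerate(similarity):
        ranked = sorted(range(len(row)), key=lambda j: row[j], reverse=True)
        if i in ranked[:k]:
            hits += 1
    return hits / len(similarity)


# Two views of one neuron pooled into a single embedding
pooled = mean_pool([[1.0, 2.0], [3.0, 4.0]])

# Toy 3x3 similarity matrix: rows are text queries, columns are neurons
similarity = [
    [0.9, 0.1, 0.0],  # correct neuron ranked first
    [0.8, 0.7, 0.1],  # correct neuron ranked second
    [0.2, 0.3, 0.1],  # correct neuron ranked last
]
for k in (1, 2, 3):
    print(f"R@{k} = {recall_at_k(similarity, k):.2f}")
```

Pooling before retrieval means each neuron is scored once, so a neuron with many rendered views does not crowd the top-k list with near-duplicates.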

In [None]:
!python -m axonet.training.clip_evaluator \
    --checkpoint {CLIP_BEST} \
    --data-dir {DATASET_DIR} \
    --metadata {NEURON_DIR}/metadata.jsonl \
    --source neuromorpho \
    --output-dir {EVAL_DIR} \
    --pooling mean

In [None]:
# Print the evaluation report
from pathlib import Path

report_path = Path(EVAL_DIR) / "eval_report.txt"
if report_path.exists():
    print(report_path.read_text())
else:
    print("eval_report.txt not found")

In [None]:
import json
from pathlib import Path

metrics_path = Path(EVAL_DIR) / "metrics.json"
if metrics_path.exists():
    metrics = json.loads(metrics_path.read_text())

    print("=== Retrieval Metrics ===")
    for k, v in metrics.get("retrieval", {}).items():
        print(f"  {k}: {v:.1f}" if isinstance(v, float) else f"  {k}: {v}")

    print("\n=== Zero-Shot Classification ===")
    for k, v in metrics.get("zero_shot", {}).items():
        print(f"  {k}: {v:.1f}" if isinstance(v, float) else f"  {k}: {v}")

    print("\n=== Novel Queries (top-10 precision) ===")
    for q, v in metrics.get("novel_queries", {}).items():
        prec = v.get("top_10_precision", 0)
        print(f"  \"{q}\": {prec:.0f}%")
else:
    print("metrics.json not found")

In [None]:
from pathlib import Path
from IPython.display import Image, display

tsne_cell = Path(EVAL_DIR) / "tsne_cell_type.png"
tsne_region = Path(EVAL_DIR) / "tsne_region.png"

if tsne_cell.exists():
    print("t-SNE by Cell Type:")
    display(Image(filename=str(tsne_cell), width=700))

if tsne_region.exists():
    print("\nt-SNE by Brain Region:")
    display(Image(filename=str(tsne_region), width=700))

if not tsne_cell.exists() and not tsne_region.exists():
    print("No t-SNE plots found. They may have been skipped if there were too few neurons.")

## 8. Interactive Query

Use the trained CLIP model for text-to-image retrieval: type a natural language description and retrieve the most similar neuron renderings.

In [None]:
import json
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from pathlib import Path
from axonet.training.clip_evaluator import load_clip_model

device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model = load_clip_model(Path(CLIP_BEST), device=device)

# Load val manifest and pre-compute image embeddings
val_manifest = Path(DATASET_DIR) / "manifest_val.jsonl"
val_entries = []
with open(val_manifest) as f:
    for line in f:
        val_entries.append(json.loads(line))

print(f"Pre-computing embeddings for {len(val_entries)} images...")
all_embeddings = []
with torch.no_grad():
    for entry in val_entries:
        input_path = Path(DATASET_DIR) / entry["mask_bw"]
        img = mpimg.imread(str(input_path))
        if img.ndim == 3:
            img = img.mean(axis=2)
        tensor = torch.from_numpy(img.astype("float32")).unsqueeze(0).unsqueeze(0).to(device)
        emb = clip_model.image_encoder.encode_for_clip(tensor)
        all_embeddings.append(emb.cpu())

all_embeddings = torch.cat(all_embeddings, dim=0)
all_embeddings = F.normalize(all_embeddings, p=2, dim=-1)
print(f"Done. Embedding shape: {all_embeddings.shape}")


def query_neurons(text: str, top_k: int = 5):
    """Retrieve top-k neuron images matching a text query."""
    with torch.no_grad():
        text_emb = clip_model.text_encoder([text])
        text_emb = F.normalize(text_emb.to(device), p=2, dim=-1).cpu()

    sims = (all_embeddings @ text_emb.T).reshape(-1)
    top_k = min(top_k, len(sims))  # cannot show more results than there are images
    top_idx = torch.argsort(sims, descending=True)[:top_k].numpy()

    fig, axes = plt.subplots(1, top_k, figsize=(4 * top_k, 4))
    fig.suptitle(f'Query: "{text}"', fontsize=13)
    for i, idx in enumerate(top_idx):
        entry = val_entries[idx]
        # Show mask_color if available, else mask_bw
        img_key = "mask_color" if "mask_color" in entry else "mask_bw"
        img_path = Path(DATASET_DIR) / entry[img_key]
        img = mpimg.imread(str(img_path))
        ax = axes[i] if top_k > 1 else axes
        ax.imshow(img, cmap="gray" if img_key == "mask_bw" else None)
        ax.set_title(f"sim={sims[idx]:.3f}\n{Path(entry['swc']).stem[:25]}", fontsize=9)
        ax.set_xticks([])
        ax.set_yticks([])
    plt.tight_layout()
    plt.show()


# Run example queries
query_neurons("pyramidal neuron from hippocampus")
query_neurons("interneuron from neocortex")

## 9. Summary

| Step | What | Output |
|------|------|--------|
| Download | SWC morphologies from NeuroMorpho.Org | `neurons/swc/*.swc`, `metadata.jsonl` |
| Dataset | PCA-guided multi-view rendering | `dataset/` with train/val manifests |
| Stage 1 | SegVAE2D training (seg + depth) | Encoder checkpoint |
| Stage 2 | CLIP fine-tuning (image-text alignment) | CLIP checkpoint |
| Eval | Retrieval, zero-shot, t-SNE | `eval_results/` |

**Scaling up for real training:**
- Increase `N_NEURONS` to 500–5000+ (remove `--max-pages 1` for full query)
- Set `MAX_EPOCHS_VAE=100`, `MAX_EPOCHS_CLIP=100`
- Use larger batch sizes if VRAM allows
- Consider multi-GPU training with the `--devices` flag

In [None]:
# Optional: download all results as a zip
import shutil
from pathlib import Path

archive_path = shutil.make_archive("/content/axonet_results", "zip", WORK_DIR)
print(f"Archive created: {archive_path}")

try:
    from google.colab import files
    files.download(archive_path)
except ImportError:
    print("Not running in Colab; download manually from the file browser.")