# LFM Segmentation Example Workflow
This notebook is an example workflow for binary segmentation on the visible-light (RGB) bands of lunar data.

**See the README in the [repo](https://github.com/nasa-nccs-hpda/lfm)** for more info on how to use this notebook, and more on the process of training the model. 

## Imports, Dino Repo Clone

In [1]:
import os
import sys

import torch

from lfm.tasks.segmentation.model import DINOSegmentation
from lfm.tasks.segmentation.dataset import get_dataloaders
from lfm.tasks.segmentation.driver import train_model
from lfm.tasks.segmentation.utils import install_termcolor_locally

%matplotlib inline

### Install termcolor package (required by DinoV3)

In [2]:
# Install the `termcolor` package through the project helper from
# lfm.tasks.segmentation.utils. The DINOv3 code loaded later needs it; the
# recorded output shows a pip install into the user's site-packages directory.
install_termcolor_locally()

Collecting termcolor
  Using cached termcolor-3.3.0-py3-none-any.whl.metadata (6.5 kB)
Using cached termcolor-3.3.0-py3-none-any.whl (7.7 kB)
Installing collected packages: termcolor
Successfully installed termcolor-3.3.0
Termcolor installed to: /home/ajkerr1/.local/lib/python3.12/site-packages


[0m

In [3]:
# Optional: uncomment to clone the DINOv3 repo into ./dinov3 (the REPO_DIR
# set in the config cell). Not required when torch.hub fetches the repo
# directly from GitHub.
# !git clone https://github.com/facebookresearch/dinov3.git

## Main workflow

### User Config

In [4]:
# Weights URL (received after registering for DINOv3).
# The signed download link is a personal, time-limited credential, so it is
# read from the environment instead of being committed with the notebook.
# Set it before starting Jupyter, e.g.:
#     export DINOV3_WEIGHTS_URL="https://.../dinov3_vitl16_pretrain_sat493m-<hash>.pth?Policy=...&Signature=..."
# A local path to an already-downloaded .pth checkpoint can be supplied the
# same way (the DINOv3 README documents URL-or-path for `weights`).
weights_URL = os.environ.get("DINOV3_WEIGHTS_URL", "").strip()
if not weights_URL:
    # Fail here, before any expensive step, rather than deep inside torch.hub.
    raise RuntimeError(
        "DINOV3_WEIGHTS_URL is not set. Export the signed DINOv3 weights URL "
        "(or a local path to the downloaded .pth file) before running this "
        "notebook."
    )

# Data paths. LFM_INPUT_DIR overrides the default cluster location so the
# notebook can run on other machines without editing this cell.
INPUT_DIR = os.environ.get("LFM_INPUT_DIR", "/explore/nobackup/projects/lfm")
IMAGE_DIR = f"{INPUT_DIR}/vis_chips/chips"
LABEL_DIR = f"{INPUT_DIR}/vis_chips/labels_npy"

# Output dir (this will be created automatically)
OUTPUT_DIR = "./outputs"  # Change this if you want a specific path
os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f"Output directory: {OUTPUT_DIR}")

# Location of cloned dinov3 repo
REPO_DIR = "./dinov3"

# Dataset parameters
MAX_SAMPLES = 500  # Set to None to use all available samples, or an integer to limit
TRAIN_SPLIT = 0.8  # 80% train, 20% validation

# Training hyperparameters
BATCH_SIZE = 16
NUM_EPOCHS = 10
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-4
NUM_WORKERS = 4
LOSS_TYPE = "combined"  # Combined Dice + Binary CE loss

# Model parameters
TARGET_SIZE = (304, 304)  # Input size for DINO model
FREEZE_ENCODER = True

# Visualization and saving
CHECKPOINT_EVERY = 10  # Save checkpoint every N epochs
VISUALIZE_EVERY = 10  # Create visualizations every N epochs

# Device: use the GPU when one is visible to PyTorch, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Output directory: ./outputs
Using device: cuda


### Training code
1. Create dataloaders from the files in the /nobackup space.
2. Load the DINOv3 encoder and build the encoder/decoder fine-tuning model.
3. Train the model, print training stats, and visualize the results.

In [None]:
# ----------------------------------------------------------------------------
# Step 1 -- build the train/validation dataloaders
# ----------------------------------------------------------------------------

print(f"\n{'=' * 60}")
print("STEP 1: Creating dataloaders.")
print("=" * 60)

# get_dataloaders returns four values; the last two are bound to MEAN and STD
# (presumably the dataset normalisation statistics, which the recorded output
# shows being computed -- confirm in lfm.tasks.segmentation.dataset).
train_loader, val_loader, MEAN, STD = get_dataloaders(
    # where the chips and their labels live
    image_dir=IMAGE_DIR,
    label_dir=LABEL_DIR,
    # how the data is sampled and split
    max_samples=MAX_SAMPLES,
    train_split=TRAIN_SPLIT,
    seed=42,
    # how batches are produced
    target_size=TARGET_SIZE,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    # where the helper may write its statistics
    stats_save_dir=OUTPUT_DIR,
)

# Report the number of batches per split.
print(
    f"Train batches: {len(train_loader)}\n"
    f"Val batches: {len(val_loader)}"
)

# ============================================================================
# LOAD ENCODER AND CREATE MODEL
# ============================================================================

print("\n" + "="*60)
print("STEP 2: Loading DINO encoder and creating model.")
print("="*60)

# Prefer the local DINOv3 clone at REPO_DIR when it exists (useful on nodes
# without internet access); otherwise let torch.hub fetch the repo from GitHub.
# A usable hub repo is identified by its hubconf.py.
if os.path.isfile(os.path.join(REPO_DIR, "hubconf.py")):
    hub_repo, hub_source = REPO_DIR, "local"
else:
    hub_repo, hub_source = "facebookresearch/dinov3", "github"

# `weights` is forwarded to the DINOv3 hub entry point, which loads the
# pretrained checkpoint from the configured URL/path.
encoder = torch.hub.load(
    repo_or_dir=hub_repo,
    model='dinov3_vitl16',
    source=hub_source,
    weights=weights_URL
).to(device)

print("Encoder loaded with pretrained weights.")

# Create model with DINO segmentation head, UNet decoder (see model.py)
print("Creating DINO segmentation model with UNet decoder...")
model = DINOSegmentation(
    encoder=encoder,  # Pretrained DINOv3 encoder loaded above
    num_classes=2,  # Binary segmentation (crater, not crater)
    img_size=TARGET_SIZE
).to(device)

# Optionally freeze the encoder so that only the decoder is trained;
# set FREEZE_ENCODER = False in the config cell for full fine-tuning.
if FREEZE_ENCODER:
    for param in model.encoder.parameters():
        param.requires_grad = False
    print("Encoder frozen (only decoder will be trained).")
else:
    print("Encoder unfrozen! Full model will be trained.")

# ----------------------------------------------------------------------------
# Step 3 -- train the model
# ----------------------------------------------------------------------------

print(f"\n{'=' * 60}")
print("Starting training.")
print("=" * 60)

# train_model returns two values, kept as train_losses / val_losses
# (presumably per-epoch loss histories -- confirm in
# lfm.tasks.segmentation.driver).
train_losses, val_losses = train_model(
    # what to train and on which data
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    mode="train",
    # optimisation settings from the config cell
    num_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    loss_type=LOSS_TYPE,
    # checkpoint / visualisation cadence and destination
    output_dir=OUTPUT_DIR,
    checkpoint_every=CHECKPOINT_EVERY,
    visualize_every=VISUALIZE_EVERY,
)


STEP 1: Creating dataloaders.
Computing dataset statistics...


Computing dataset statistics:   0%|          | 0/7530 [00:00<?, ?it/s]

First image shape: (3, 300, 300), dtype: float64
Value range: min=0.0, max=1.0


Computing dataset statistics:  17%|█▋        | 1286/7530 [00:11<00:55, 112.95it/s]