# LFM Segmentation Example Workflow
This notebook is an example workflow for binary segmentation (crater vs. not crater) on the visible-light (RGB) bands of lunar data.

You can get started with this notebook by downloading it with:

```bash
wget https://raw.githubusercontent.com/nasa-nccs-hpda/lfm/refs/heads/main/notebooks/finetune_dinov3.ipynb
```

**See the README in the [repo](https://github.com/nasa-nccs-hpda/lfm)** for more information on how to use this notebook and on the model training process.

## Imports and LFM Repo Clone

In [None]:
import os
import sys
import torch
import subprocess
from glob import glob

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

%matplotlib inline

In [None]:
# Directory name of the local checkout; it is also the last component of the clone URL.
repo_dir = "lfm"

if not os.path.exists(repo_dir):
    # First run: fetch the repository. check=True makes a failed clone raise
    # CalledProcessError right here instead of surfacing later as an import error.
    subprocess.run(
        ["git", "clone", f"https://github.com/nasa-nccs-hpda/{repo_dir}"],
        check=True,
    )
else:
    # Later runs: refresh the existing checkout. A failed pull (offline, local
    # edits, ...) is reported but not fatal, because the checkout already on disk
    # can still be imported.
    pull_result = subprocess.run(["git", "-C", repo_dir, "pull"])
    if pull_result.returncode != 0:
        print(
            f"Warning: 'git pull' failed in {repo_dir!r} "
            f"(exit code {pull_result.returncode}); using the existing checkout."
        )

In [None]:
# Make the cloned repository importable. The membership check keeps the cell
# idempotent: re-running it no longer appends a duplicate "lfm" entry each time.
if "lfm" not in sys.path:
    sys.path.append("lfm")

from lfm.tasks.segmentation.model import DINOSegmentation, load_dinov3_encoder
from lfm.tasks.segmentation.dataset import get_dataloaders
from lfm.tasks.segmentation.driver import train_model
from lfm.tasks.segmentation.utils import install_termcolor_locally

### Install termcolor package (required by DinoV3)

In [None]:
# Helper imported from lfm.tasks.segmentation.utils; the section heading notes that
# termcolor is required by DinoV3.
# NOTE(review): where and how the package gets installed is not visible from this
# notebook — confirm in utils.py before relying on it in a shared environment.
install_termcolor_locally()

## Main workflow

1. Define the user-configured variables.
2. Create dataloaders from the image chips and labels under `/explore/nobackup/projects/lfm`.
3. Load the DINOv3 encoder and build the encoder/decoder fine-tuning model.
4. Train the model, print training statistics, and visualize the results.

### User Config

In [None]:
# Weights URL (received after registering for DINOv3).
# The download link is a personal, pre-signed URL (it carries Policy, Signature
# and Key-Pair-Id query parameters), so it must not be hard-coded in a shared
# notebook. Provide your own through the DINOV3_WEIGHTS_URL environment variable,
# in the form
#   https://dinov3.llamameta.net/dinov3_vitl16/dinov3_vitl16_pretrain_sat493m-eadcf0ff.pth?Policy=...&Signature=...&Key-Pair-Id=...
# It stays an empty string when the variable is unset; the encoder used by this
# notebook is loaded from weights_local_checkpoint.
weights_URL = os.environ.get("DINOV3_WEIGHTS_URL", "")
weights_local_checkpoint = '/explore/nobackup/projects/ilab/models/dinov3/dinov3_vitl16_pretrain_sat493m-eadcf0ff.pth'

# Data paths
INPUT_DIR = "/explore/nobackup/projects/lfm"
IMAGE_DIR = f"{INPUT_DIR}/vis_chips/chips"
LABEL_DIR = f"{INPUT_DIR}/vis_chips/labels_npy"

# Output dir (this will be created automatically)
OUTPUT_DIR = "./outputs"  # Change this if you want a specific path
os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f"Output directory: {OUTPUT_DIR}")

# Location of cloned dinov3 repo
REPO_DIR = "./dinov3"

# Dataset parameters
MAX_SAMPLES = 500  # Set to None to use all available samples, or an integer to limit
TRAIN_SPLIT = 0.8  # 80% train, 20% validation

# Training hyperparameters
BATCH_SIZE = 16
NUM_EPOCHS = 10
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-4
NUM_WORKERS = 4
LOSS_TYPE = "combined"  # Combined Dice + Binary CE loss

# Model parameters
TARGET_SIZE = (304, 304)  # Input size for DINO model
FREEZE_ENCODER = False  # True: train only the decoder; False: fine-tune the full model

# Visualization and saving
CHECKPOINT_EVERY = 10  # Save checkpoint every N epochs
VISUALIZE_EVERY = 10  # Create visualizations every N epochs

# Device: use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

### Create dataloaders

In [None]:
# ----------------------------------------------------------------------------
# Step 1: build the training and validation dataloaders
# ----------------------------------------------------------------------------

banner = "=" * 60
print("\n" + banner)
print("STEP 1: Creating dataloaders.")
print(banner)

# Collect every loader setting (taken from the user-config cell) in one place.
dataloader_options = dict(
    image_dir=IMAGE_DIR,
    label_dir=LABEL_DIR,
    batch_size=BATCH_SIZE,
    train_split=TRAIN_SPLIT,
    num_workers=NUM_WORKERS,
    target_size=TARGET_SIZE,
    max_samples=MAX_SAMPLES,
    seed=42,  # fixed seed passed through to the loader factory
    stats_save_dir=OUTPUT_DIR,
)

# NOTE(review): MEAN / STD are presumably dataset normalization statistics that
# are also written under stats_save_dir — confirm in dataset.py.
train_loader, val_loader, MEAN, STD = get_dataloaders(**dataloader_options)

print("Train batches: {}".format(len(train_loader)))
print("Val batches: {}".format(len(val_loader)))

### Load Encoder and Create Model

In [None]:
separator = "=" * 60
print("\n" + separator)
print("STEP 2: Loading DINO encoder and creating model.")
print(separator)

# Build the encoder from the local checkpoint path defined in the user-config
# cell, passing along the device selected there.
encoder = load_dinov3_encoder(weights_local_checkpoint, device)

In [None]:
# Create the segmentation model around the DINOv3 encoder (decoder details: see model.py)
print("Creating DINO segmentation model with UNet decoder...")
model = DINOSegmentation(
    encoder=encoder,  # DINOv3 encoder loaded in the previous cell
    num_classes=2,  # Binary segmentation (crater, not crater)
    img_size=TARGET_SIZE
).to(device)

# Set encoder trainability explicitly in BOTH directions. The `encoder` object is
# shared across re-runs of this cell, so only ever setting requires_grad=False
# would leave it frozen after FREEZE_ENCODER is switched back to False, while the
# message below would still claim that the full model is being trained.
for param in model.encoder.parameters():
    param.requires_grad = not FREEZE_ENCODER

if FREEZE_ENCODER:
    print("Encoder frozen (only decoder will be trained).")
else:
    print("Encoder unfrozen! Full model will be trained.")

### Run Training

In [None]:
rule = "=" * 60
print("\n" + rule)
print("Starting training.")
print(rule)

# Run settings, all taken from the user-config cell.
training_options = dict(
    mode="train",
    output_dir=OUTPUT_DIR,
    num_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    checkpoint_every=CHECKPOINT_EVERY,
    visualize_every=VISUALIZE_EVERY,
    loss_type=LOSS_TYPE,
    device=device,
)

# train_model hands back two values, kept here as the training and validation losses.
train_losses, val_losses = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    **training_options,
)

## Display some of the output visualizations

Training already writes visualizations every `VISUALIZE_EVERY` epochs.
Here we load the PNG files from the `visualizations` folder inside `OUTPUT_DIR` and display them in the notebook.

In [None]:
# Folder inside the run's output directory that is searched for PNG visualizations.
visualization_dir = os.path.join(OUTPUT_DIR, 'visualizations')

# Collect every PNG in that folder, ordered by filename.
png_pattern = os.path.join(visualization_dir, '*.png')
visualization_filenames = glob(png_pattern)
visualization_filenames.sort()

In [None]:
# An empty list would otherwise make this cell finish silently with no output,
# which is indistinguishable from a rendering problem — say so explicitly.
if not visualization_filenames:
    print(
        f"No PNG visualizations found in {visualization_dir!r}. "
        "Check that training has run and produced visualizations."
    )

for vis_filename in visualization_filenames:

    img = mpimg.imread(vis_filename)

    # Explicit Figure/Axes so each image gets its own titled figure.
    fig, ax = plt.subplots(figsize=(16, 14))
    ax.imshow(img)
    ax.set_title(os.path.basename(vis_filename))  # identify which file is shown
    ax.axis("off")
    plt.show()
    plt.close(fig)  # release the figure so memory does not grow with many files