# LFM Toy Model
This notebook is used to fine-tune a DinoV3 toy model on Lunar data. 

# Model specifications
The encoder is the distilled DinoV3 ViT-L/16 pretrained on SAT-493M (satellite imagery). All encoder parameters were unfrozen for fine-tuning. See the [DinoV3 repo](https://github.com/facebookresearch/dinov3) for more info.
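A ViT-L/16 encoder splits each input into 16x16-pixel patches and returns one token per patch. A decoder such as the UNet in `model.py` typically reshapes these patch tokens back into a 2D feature map. The sketch below shows that reshape in plain PyTorch. It is not taken from `model.py`. The function name `tokens_to_feature_map` is hypothetical. The default of 5 prefix tokens (1 class token plus 4 register tokens) is an assumption, so check it against the loaded encoder.

```python
import torch

def tokens_to_feature_map(tokens: torch.Tensor, img_size: tuple,
                          patch_size: int = 16, n_prefix_tokens: int = 5) -> torch.Tensor:
    """Reshape ViT output tokens (B, N, C) into a (B, C, H/p, W/p) feature map.

    ASSUMPTION: the encoder prepends `n_prefix_tokens` non-patch tokens
    (e.g. class + register tokens) before the patch tokens.
    """
    h, w = img_size[0] // patch_size, img_size[1] // patch_size
    patch_tokens = tokens[:, n_prefix_tokens:, :]  # drop class/register tokens
    b, n, c = patch_tokens.shape
    if n != h * w:
        raise ValueError(f"expected {h * w} patch tokens, got {n}")
    # (B, h*w, C) -> (B, C, h, w); patches are assumed to be in row-major order
    return patch_tokens.transpose(1, 2).reshape(b, c, h, w)

# Random tokens shaped like a ViT-L/16 output (embed dim 1024) for a 304x304 input
tokens = torch.randn(2, 5 + 19 * 19, 1024)
fmap = tokens_to_feature_map(tokens, (304, 304))
print(fmap.shape)  # torch.Size([2, 1024, 19, 19])
```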

## Input data specifications
The vis data (hosted at /explore/nobackup/projects/lfm/rawdata/Lunar/LowRes_MLDataset_v1_bilinear) was preprocessed by extracting bands [643, 566, 415] and normalizing values to the [0, 1] range. The data were saved as (3, 300, 300) .npy files under the LFM project space (explore/nobackup/projects/lfm/vis_chips).
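The band extraction and normalization step can be sketched as follows. This is an illustration under stated assumptions, not the preprocessing script that was actually used. The function name `extract_and_normalize` is hypothetical. The sketch assumes a band-first `(bands, H, W)` cube and a parallel list of band identifiers. It also uses per-band min-max scaling, because the text above says only that values were normalized to [0, 1].

```python
import numpy as np

TARGET_BANDS = [643, 566, 415]  # band identifiers listed in the text above

def extract_and_normalize(cube: np.ndarray, band_ids, targets=TARGET_BANDS) -> np.ndarray:
    """Select target bands from a (bands, H, W) cube and min-max scale each band to [0, 1].

    ASSUMPTIONS: band-first layout, `band_ids` labels each band of `cube`,
    and normalization is per-band min-max.
    """
    band_ids = np.asarray(band_ids)
    idx = [int(np.argmin(np.abs(band_ids - t))) for t in targets]  # nearest matching band
    chip = cube[idx].astype(np.float32)
    mins = chip.min(axis=(1, 2), keepdims=True)
    maxs = chip.max(axis=(1, 2), keepdims=True)
    span = np.where(maxs > mins, maxs - mins, 1.0)  # avoid division by zero on flat bands
    return (chip - mins) / span

# Synthetic 10-band cube with hypothetical band identifiers
rng = np.random.default_rng(0)
cube = rng.uniform(0, 500, size=(10, 300, 300))
band_ids = [415, 450, 500, 566, 600, 643, 700, 750, 800, 900]
chip = extract_and_normalize(cube, band_ids)
print(chip.shape)  # (3, 300, 300)
```

The resulting chip could then be written with `np.save(...)`, matching the (3, 300, 300) .npy layout described above.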

## Label specifications
Labels were processed from the annotations JSON file. Annotations were sorted by their corresponding filename. All labels for a given filename were then merged into a single composite (300, 300) .npy image, saved under the LFM project space (explore/nobackup/projects/lfm/vis_chips).
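The grouping and compositing step can be sketched like this. The annotations JSON schema is not given here, so this sketch uses a hypothetical one: each annotation is a dict with a `"filename"` and a `"pixels"` list of `[row, col]` positions. `build_composite_labels` is also a hypothetical name. Overlapping annotations are combined as a union, with 1 = crater and 0 = background.

```python
import json
from collections import defaultdict

import numpy as np

def build_composite_labels(annotations: list, shape=(300, 300)) -> dict:
    """Merge all annotations for each filename into one binary (H, W) mask.

    ASSUMPTION: each annotation looks like {"filename": str, "pixels": [[row, col], ...]}.
    """
    by_file = defaultdict(list)
    for ann in annotations:
        by_file[ann["filename"]].append(ann)  # group annotations by source file

    masks = {}
    for fname in sorted(by_file):
        mask = np.zeros(shape, dtype=np.uint8)
        for ann in by_file[fname]:
            pix = np.asarray(ann["pixels"], dtype=int).reshape(-1, 2)  # handles empty lists
            mask[pix[:, 0], pix[:, 1]] = 1  # union of all annotations for this file
        masks[fname] = mask
    return masks

annotations = json.loads(
    '[{"filename": "a.npy", "pixels": [[0, 0], [1, 1]]},'
    ' {"filename": "a.npy", "pixels": [[1, 1], [2, 2]]},'
    ' {"filename": "b.npy", "pixels": []}]'
)
masks = build_composite_labels(annotations)
print({k: int(v.sum()) for k, v in masks.items()})  # {'a.npy': 3, 'b.npy': 0}
```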

## Input/label matching
Labels and inputs were matched on asset ID and tile row/column ID.
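Matching on asset ID plus tile row/column amounts to building the same key from both filenames and keeping the keys that appear on both sides. The naming scheme in this sketch, `<asset_id>_r<row>_c<col>.npy`, is hypothetical. The real chip and label filenames may encode these IDs differently, so adapt `KEY_PATTERN` to them. The helper names are also illustrative.

```python
import re
from pathlib import Path

# HYPOTHETICAL naming scheme: "<asset_id>_r<row>_c<col>.npy", e.g. "M123456_r03_c07.npy"
KEY_PATTERN = re.compile(r"(?P<asset>.+)_r(?P<row>\d+)_c(?P<col>\d+)\.npy$")

def match_key(path: str):
    """Return (asset_id, row, col) parsed from a filename, or None if it does not match."""
    m = KEY_PATTERN.search(Path(path).name)
    if m is None:
        return None
    return m["asset"], int(m["row"]), int(m["col"])

def pair_inputs_and_labels(image_paths, label_paths):
    """Pair image and label files that share the same asset ID and tile row/column."""
    labels = {}
    for p in label_paths:
        key = match_key(p)
        if key is not None:
            labels[key] = p
    pairs = []
    for p in sorted(image_paths):
        key = match_key(p)
        if key is not None and key in labels:
            pairs.append((p, labels[key]))
    return pairs

images = ["chips/M1_r00_c01.npy", "chips/M1_r00_c02.npy", "chips/M2_r05_c05.npy"]
labels = ["labels_npy/M1_r00_c01.npy", "labels_npy/M2_r05_c05.npy"]
pairs = pair_inputs_and_labels(images, labels)
print(pairs)
```

Parsing row and column as integers means zero-padded and unpadded IDs (`r03` vs `r3`) still match. Unmatched inputs, such as `M1_r00_c02` above, are dropped.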

## Training specifications
The model was trained on 500 input/label pairs for 50 epochs using a PRISM JupyterHub job on 4 V100 GPUs (a single V100 also works, but is slower). Training used the "combined" loss (Dice loss + binary cross-entropy), a learning rate of 1e-4, the AdamW optimizer, cosine annealing LR scheduling, and an 80/20 train/validation split.
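The loss and optimizer setup described above can be sketched as follows. This is an illustration with assumptions, not the code in `driver.py`. The equal weighting of the Dice and BCE terms, the Dice smoothing term `eps`, and the names `dice_bce_loss` and `make_optimizer` are hypothetical. Because the model has two output classes, the sketch uses `logits[:, 1] - logits[:, 0]` as the crater logit, since a two-way softmax gives p(crater) = sigmoid(l1 - l0).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

def dice_bce_loss(logits: torch.Tensor, target: torch.Tensor, eps: float = 1.0) -> torch.Tensor:
    """Combined soft Dice + binary cross-entropy for 2-class segmentation.

    logits: (N, 2, H, W) raw scores for (background, crater).
    target: (N, H, W) integer mask, 1 = crater, 0 = background.
    ASSUMPTION: the two terms are weighted equally.
    """
    fg_logit = logits[:, 1] - logits[:, 0]  # crater log-odds under a 2-way softmax
    target = target.float()
    bce = F.binary_cross_entropy_with_logits(fg_logit, target)
    prob = torch.sigmoid(fg_logit)
    inter = (prob * target).sum(dim=(1, 2))
    denom = prob.sum(dim=(1, 2)) + target.sum(dim=(1, 2))
    dice = 1 - (2 * inter + eps) / (denom + eps)  # per-image soft Dice loss
    return bce + dice.mean()

def make_optimizer(model: nn.Module, lr: float = 1e-4, weight_decay: float = 1e-4, epochs: int = 50):
    """AdamW over trainable parameters plus cosine annealing over `epochs` scheduler steps."""
    params = [p for p in model.parameters() if p.requires_grad]  # skips a frozen encoder
    optimizer = AdamW(params, lr=lr, weight_decay=weight_decay)
    scheduler = CosineAnnealingLR(optimizer, T_max=epochs)
    return optimizer, scheduler

torch.manual_seed(0)
model = nn.Conv2d(3, 2, kernel_size=1)  # stand-in for the segmentation model
opt, sched = make_optimizer(model, epochs=50)
x = torch.randn(4, 3, 8, 8)
y = torch.randint(0, 2, (4, 8, 8))
loss = dice_bce_loss(model(x), y)
loss.backward()
opt.step()
sched.step()  # one scheduler step per epoch
print(loss.item(), sched.get_last_lr())
```

Stepping the scheduler once per epoch with `T_max` equal to the epoch count takes the learning rate from 1e-4 down to 0 over the run.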

## Imports, Dino Repo Clone

In [None]:
import os
from pathlib import Path
from glob import glob

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from tqdm.notebook import tqdm

from model import DINOSegmentation
from dataset import get_dataloaders
from driver import main

# Notebook settings
%matplotlib inline

In [None]:
# Clone the DINOv3 repo: torch.hub.load(..., source='local') below loads the model code from it (see its README for weight access)
!git clone https://github.com/facebookresearch/dinov3.git

## Main workflow

### User Config

In [None]:
# Weights URL (received after registering for DINOv3). This is a personal, time-limited signed link; replace it with your own.
weights_URL = (
    "https://dinov3.llamameta.net/dinov3_vitl16/"
    "dinov3_vitl16_pretrain_sat493m-eadcf0ff.pth"
    "?Policy=eyJTdGF0ZW1lbnQiOlt7InVuaXF1ZV9oYXNoIjoiNDloYXZtdThkZGh3eGw3aH"
    "JwNjQwa3E3IiwiUmVzb3VyY2UiOiJodHRwczpcL1wvZGlub3YzLmxsYW1hbWV0YS5uZXR"
    "cLyoiLCJDb25kaXRpb24iOnsiRGF0ZUxlc3NUaGFuIjp7IkFXUzpFcG9jaFRpbWUiOjE"
    "3Njc5OTI2Njl9fX1dfQ__"
    "&Signature=neHREO7xc90azhmnF3r9qPztYJ5L2wO-EZkVKh6ECzR5H2YGzdK3dcF"
    "WQISNb6xYo3R5EO39FKJ7bwELXA%7EgoBqDbk-jm-9n9%7EVxtEOmWVx73usrILMwhyi"
    "cP5-448rbnUzOEM0lPkGS829mOBJkaSxxSsbkQ0VpVBcScNEFcpaNOZ--BeHxCHdTFV"
    "NGkhlEaCYPUbYyHYbTgDQntH2AsKYJTWw4NIEZJZLX9wjCOYKQ-YG86d0HJAvsdF79X"
    "vITPgJSA0U-4Z1CzIkQhZb0N-7-XnbZmnJJi42xnNS0DsB6CTedxq0FAfiYklBY77wT"
    "JrYLba%7Epkap23ymoUTxDXA__"
    "&Key-Pair-Id=K15QRJLYKIFSLZ"
    "&Download-Request-ID=1618342689192585"
)

# Data paths
INPUT_DIR = "/explore/nobackup/people/ajkerr1/Lunar_FM"
IMAGE_DIR = f"{INPUT_DIR}/vis_chips/chips"
LABEL_DIR = f"{INPUT_DIR}/vis_chips/labels_npy"

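# Output dir (set this before running; it is created if it does not already exist)
OUTPUT_DIR = ""
if not OUTPUT_DIR:
    raise ValueError("Set OUTPUT_DIR to a writable directory before running this cell.")
os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f"Output directory: {OUTPUT_DIR}")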

# Location of cloned dinov3 repo
REPO_DIR = "./dinov3"

# Dataset parameters
MAX_SAMPLES = 500  # Set to None to use all available samples, or an integer to limit
TRAIN_SPLIT = 0.8  # 80% train, 20% validation

# Training hyperparameters
BATCH_SIZE = 16
NUM_EPOCHS = 50
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-4
NUM_WORKERS = 4
LOSS_TYPE = "combined"  # Combined Dice + Binary CE loss

# Model parameters
TARGET_SIZE = (304, 304)  # Model input size; a multiple of the ViT-L/16 patch size (304 = 19 x 16)
N_CLASSES = 2  # Binary segmentation (background, crater)
FREEZE_ENCODER = False  # False = full fine-tuning of all encoder parameters, as described above

# Visualization and saving
CHECKPOINT_EVERY = 10  # Save checkpoint every N epochs
VISUALIZE_EVERY = 10  # Create visualizations every N epochs

# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
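With MAX_SAMPLES = 500, TRAIN_SPLIT = 0.8, and BATCH_SIZE = 16, the split sizes and batch counts printed in Step 1 can be checked by hand. This sketch assumes `int()` truncation for the split and `drop_last=False` on the loaders, and `get_dataloaders` may round differently. It also checks that TARGET_SIZE divides evenly into 16-pixel patches. A 300x300 chip does not divide evenly, which is presumably why 304x304 is used. Both helper names are hypothetical.

```python
import math

def split_and_batch_counts(n_samples: int, train_split: float, batch_size: int):
    """Train/val sizes and batches per epoch.

    ASSUMPTIONS: n_train = int(n * split); the last partial batch is kept (drop_last=False).
    """
    n_train = int(n_samples * train_split)
    n_val = n_samples - n_train
    return n_train, n_val, math.ceil(n_train / batch_size), math.ceil(n_val / batch_size)

def patch_grid(target_size, patch_size: int = 16):
    """Patch-token grid for a ViT; the input size must be a multiple of the patch size."""
    h, w = target_size
    if h % patch_size or w % patch_size:
        raise ValueError(f"{target_size} is not a multiple of patch size {patch_size}")
    return h // patch_size, w // patch_size

print(split_and_batch_counts(500, 0.8, 16))  # (400, 100, 25, 7)
print(patch_grid((304, 304)))                # (19, 19), i.e. 361 patch tokens per image
```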

### Training code
1. Create dataloaders from files on /nobackup space.
2. Load DinoV3 encoder, create encoder/decoder finetuning model.
3. Train model.
4. Print post-training QA stats.
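Step 3 runs inside `driver.py`, which is not shown here. For a rough picture of a per-epoch train/validate loop with per-epoch LR scheduling and periodic checkpoints, see the generic sketch below. `run_epoch`, `fit`, and `checkpoint_fn` are hypothetical names, not the driver's API.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def run_epoch(model, loader, loss_fn, device, optimizer=None):
    """One pass over `loader`; trains if an optimizer is given, else evaluates. Returns mean loss."""
    training = optimizer is not None
    model.train(training)
    total, n_batches = 0.0, 0
    with torch.set_grad_enabled(training):
        for images, masks in loader:
            images, masks = images.to(device), masks.to(device)
            loss = loss_fn(model(images), masks)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total += loss.item()
            n_batches += 1
    return total / max(n_batches, 1)

def fit(model, train_loader, val_loader, loss_fn, optimizer, scheduler, num_epochs, device,
        checkpoint_every=10, checkpoint_fn=None):
    """Hypothetical driver loop: train, validate, then step the LR schedule once per epoch."""
    train_losses, val_losses = [], []
    for epoch in range(1, num_epochs + 1):
        train_losses.append(run_epoch(model, train_loader, loss_fn, device, optimizer))
        val_losses.append(run_epoch(model, val_loader, loss_fn, device))
        scheduler.step()  # per-epoch cosine annealing (T_max = num_epochs)
        if checkpoint_fn is not None and epoch % checkpoint_every == 0:
            checkpoint_fn(epoch, model)
    return train_losses, val_losses

# Tiny stand-in data and model, just to exercise the loop
torch.manual_seed(0)
x = torch.randn(32, 3, 8, 8)
y = torch.randint(0, 2, (32, 8, 8))
loader = DataLoader(TensorDataset(x, y), batch_size=8)
model = nn.Conv2d(3, 2, kernel_size=1)
opt = torch.optim.AdamW(model.parameters(), lr=1e-2)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=3)
tr, va = fit(model, loader, loader, nn.CrossEntropyLoss(), opt, sched, num_epochs=3, device="cpu")
print(tr, va)
```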

In [None]:
# ============================================================================
# CREATE DATALOADERS
# ============================================================================

print("\n" + "="*60)
print("STEP 1: Creating dataloaders.")
print("="*60)

train_loader, val_loader, MEAN, STD = get_dataloaders(
    image_dir=IMAGE_DIR,
    label_dir=LABEL_DIR,
    batch_size=BATCH_SIZE,
    train_split=TRAIN_SPLIT,
    num_workers=NUM_WORKERS,
    target_size=TARGET_SIZE,
    max_samples=MAX_SAMPLES,
    seed=42,
    stats_save_dir=OUTPUT_DIR
)

print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")

# ============================================================================
# LOAD ENCODER AND CREATE MODEL
# ============================================================================

print("\n" + "="*60)
print("STEP 2: Loading DINO encoder and creating model.")
print("="*60)

encoder = torch.hub.load(
    REPO_DIR,
    'dinov3_vitl16',
    source='local',
    weights=weights_URL
).to(device)

print("Encoder loaded with pretrained weights.")

# Create model with DINO segmentation head, UNet decoder (see model.py)
print("Creating DINO segmentation model with UNet decoder...")
model = DINOSegmentation(
    encoder=encoder,
    num_classes=N_CLASSES,
    img_size=TARGET_SIZE
).to(device)

# Freeze the encoder if FREEZE_ENCODER is True; otherwise the full model is fine-tuned
if FREEZE_ENCODER:
    for param in model.encoder.parameters():
        param.requires_grad = False
    print("Encoder frozen (only decoder will be trained).")
else:
    print("Encoder unfrozen! Full model will be trained.")

# ============================================================================
# RUN TRAINING
# ============================================================================

print("\n" + "="*60)
print("STEP 3: Starting training.")
print("="*60)

# `train` is not defined in this notebook; `main` is imported from driver.py (mode="train")
train_losses, val_losses = main(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    mode="train",
    output_dir=OUTPUT_DIR,
    num_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    checkpoint_every=CHECKPOINT_EVERY,
    visualize_every=VISUALIZE_EVERY,
    loss_type=LOSS_TYPE,
    device=device
)
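Whether the encoder was actually frozen can be confirmed by comparing trainable and total parameter counts. After the model is built, `count_parameters(model)` should report roughly the whole model as trainable when FREEZE_ENCODER is False, and only the decoder when it is True. `count_parameters` is a hypothetical helper, shown here on a toy model.

```python
import torch.nn as nn

def count_parameters(model: nn.Module):
    """Return (trainable, total) parameter counts for a module."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return trainable, total

toy = nn.Sequential(nn.Linear(4, 3), nn.Linear(3, 2))
for p in toy[0].parameters():
    p.requires_grad = False  # mimic FREEZE_ENCODER = True on the first layer
print(count_parameters(toy))  # (8, 23)
```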