# LFM Toy Model
This notebook fine-tunes a toy DINOv3-based model for crater segmentation on lunar data.

## Prerequisites

**Make sure you are using the ilab-pytorch kernel.** Check the kernel name in the top right (it should say "Python [...]"), click it, and select **Python [conda env: ilab-pytorch]**.

**You should be signed into GitHub.** The dinov3 repo is required to load the DINO model, so **if you aren't signed into GitHub, this step will fail.**
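To confirm the kernel from inside the notebook, the interpreter's prefix can be inspected. This is a minimal sketch that assumes the conda environment lives in a directory named `ilab-pytorch` (conda's usual `.../envs/<name>` layout); `kernel_env_name` and `check_kernel` are hypothetical helpers, not part of this project:

```python
import os
import sys


def kernel_env_name() -> str:
    """Return the last path component of the running interpreter's prefix.

    For a conda env this is usually the env name, e.g.
    ".../envs/ilab-pytorch" -> "ilab-pytorch" (an assumption about layout).
    """
    return os.path.basename(os.path.normpath(sys.prefix))


def check_kernel(expected: str = "ilab-pytorch") -> bool:
    """Print the interpreter path and env name; warn on a mismatch."""
    name = kernel_env_name()
    print(f"Interpreter: {sys.executable}")
    print(f"Environment: {name}")
    if name != expected:
        print(f"Warning: expected the '{expected}' kernel, found '{name}'.")
    return name == expected
```

Running `check_kernel()` as the first cell prints the interpreter path and warns if a different environment is attached.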

## Imports and DINOv3 Repo Clone

In [1]:
import os
import sys

import torch

from model import DINOSegmentation
from dataset import get_dataloaders
from driver import train_model
from utils import install_termcolor_locally

# Notebook settings
%matplotlib inline

### Install termcolor package (required by DinoV3)

In [2]:
install_termcolor_locally()

Collecting termcolor
  Using cached termcolor-3.3.0-py3-none-any.whl.metadata (6.5 kB)
Using cached termcolor-3.3.0-py3-none-any.whl (7.7 kB)
Installing collected packages: termcolor
Successfully installed termcolor-3.3.0
Termcolor installed to: /home/ajkerr1/.local/lib/python3.12/site-packages



In [None]:
# !git clone https://github.com/facebookresearch/dinov3.git
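The training cell loads the hub entry point from GitHub (`source='github'`), whereas this clone would land in `./dinov3`, the `REPO_DIR` set in the config. `torch.hub.load` also accepts `source='local'` for a directory containing a `hubconf.py`, so a sketch for preferring the clone when it exists might look like this (`hub_source_kwargs` is a hypothetical helper):

```python
import os


def hub_source_kwargs(repo_dir: str = "./dinov3") -> dict:
    """Build the repo/source arguments for torch.hub.load.

    Use the local clone if it contains a hubconf.py; otherwise fall back
    to the GitHub repo used in the training cell.
    """
    if os.path.isfile(os.path.join(repo_dir, "hubconf.py")):
        return {"repo_or_dir": repo_dir, "source": "local"}
    return {"repo_or_dir": "facebookresearch/dinov3", "source": "github"}
```

It would replace the repo and source arguments of the existing call, e.g. `torch.hub.load(**hub_source_kwargs(REPO_DIR), model='dinov3_vitl16', weights=weights_URL)`. The weights are still downloaded from `weights_URL` either way.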

## Main workflow

### User Config

In [3]:
# Signed weights URL (received after registering for DINOv3; the link expires)
weights_URL = (
    "https://dinov3.llamameta.net/dinov3_vitl16/"
    "dinov3_vitl16_pretrain_sat493m-eadcf0ff.pth"
    "?Policy=eyJTdGF0ZW1lbnQiOlt7InVuaXF1ZV9oYXNoIjoiNDloYXZtdThkZGh3eGw3aH"
    "JwNjQwa3E3IiwiUmVzb3VyY2UiOiJodHRwczpcL1wvZGlub3YzLmxsYW1hbWV0YS5uZXR"
    "cLyoiLCJDb25kaXRpb24iOnsiRGF0ZUxlc3NUaGFuIjp7IkFXUzpFcG9jaFRpbWUiOjE"
    "3Njc5OTI2Njl9fX1dfQ__"
    "&Signature=neHREO7xc90azhmnF3r9qPztYJ5L2wO-EZkVKh6ECzR5H2YGzdK3dcF"
    "WQISNb6xYo3R5EO39FKJ7bwELXA%7EgoBqDbk-jm-9n9%7EVxtEOmWVx73usrILMwhyi"
    "cP5-448rbnUzOEM0lPkGS829mOBJkaSxxSsbkQ0VpVBcScNEFcpaNOZ--BeHxCHdTFV"
    "NGkhlEaCYPUbYyHYbTgDQntH2AsKYJTWw4NIEZJZLX9wjCOYKQ-YG86d0HJAvsdF79X"
    "vITPgJSA0U-4Z1CzIkQhZb0N-7-XnbZmnJJi42xnNS0DsB6CTedxq0FAfiYklBY77wT"
    "JrYLba%7Epkap23ymoUTxDXA__"
    "&Key-Pair-Id=K15QRJLYKIFSLZ"
    "&Download-Request-ID=1618342689192585"
)

# Data paths
INPUT_DIR = "/explore/nobackup/projects/lfm"
IMAGE_DIR = f"{INPUT_DIR}/vis_chips/chips"
LABEL_DIR = f"{INPUT_DIR}/vis_chips/labels_npy"

# Output directory (created below if it does not already exist)
OUTPUT_DIR = "./outputs"
os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f"Output directory: {OUTPUT_DIR}")

# Location of the cloned dinov3 repo (the torch.hub call below loads from GitHub instead)
REPO_DIR = "./dinov3"

# Dataset parameters
MAX_SAMPLES = 500  # Set to None to use all available samples, or an integer to limit
TRAIN_SPLIT = 0.8  # 80% train, 20% validation

# Training hyperparameters
BATCH_SIZE = 16
NUM_EPOCHS = 50
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-4
NUM_WORKERS = 4
LOSS_TYPE = "combined"  # Combined Dice + Binary CE loss

# Model parameters
TARGET_SIZE = (304, 304)  # Input size for DINO model
FREEZE_ENCODER = True

# Visualization and saving
CHECKPOINT_EVERY = 10  # Save checkpoint every N epochs
VISUALIZE_EVERY = 10  # Create visualizations every N epochs

# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Output directory: ./outputs
Using device: cuda
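`weights_URL` is a time-limited signed link. Assuming its `Policy` parameter follows the Amazon CloudFront signed-URL convention (a JSON policy, base64-encoded with `+`, `=` and `/` replaced by `-`, `_` and `~`), a sketch like the one below can report the expiry before the long training cell is started. `policy_expiry` and `url_is_expired` are hypothetical helpers:

```python
import base64
import json
from datetime import datetime, timezone
from urllib.parse import parse_qs, urlsplit


def policy_expiry(url: str) -> datetime:
    """Return the expiry time encoded in a CloudFront-style signed URL."""
    # parse_qs percent-decodes values and returns a list per key.
    policy = parse_qs(urlsplit(url).query)["Policy"][0]
    # Undo the URL-safe character substitutions, then decode the JSON policy.
    raw = policy.replace("-", "+").replace("_", "=").replace("~", "/")
    doc = json.loads(base64.b64decode(raw))
    epoch = doc["Statement"][0]["Condition"]["DateLessThan"]["AWS:EpochTime"]
    return datetime.fromtimestamp(epoch, tz=timezone.utc)


def url_is_expired(url: str) -> bool:
    """True if the signed URL's expiry time is already in the past."""
    return policy_expiry(url) <= datetime.now(timezone.utc)
```

For the URL in the config above, `policy_expiry(weights_URL)` decodes to 2026-01-09 21:04:29 UTC; after that a new signed URL would be needed.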


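`LOSS_TYPE = "combined"` selects the project's `CombinedLoss` (Dice + binary cross-entropy), whose code is not shown in this notebook. For intuition only, one common way to combine the two terms for single-channel logits is sketched below; the class name, the equal 0.5 weighting and the smoothing constant are assumptions and may differ from `CombinedLoss`, especially since this notebook's model is built with `num_classes=2`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DiceBCELoss(nn.Module):
    """Weighted sum of BCE-with-logits and soft Dice loss for binary masks."""

    def __init__(self, bce_weight: float = 0.5, smooth: float = 1.0):
        super().__init__()
        self.bce_weight = bce_weight
        self.smooth = smooth  # keeps the Dice ratio defined for empty masks

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # logits and target: (N, 1, H, W); target holds 0/1 values.
        target = target.float()
        bce = F.binary_cross_entropy_with_logits(logits, target)
        probs = torch.sigmoid(logits)
        dims = (1, 2, 3)  # sum over channel and spatial dims, per sample
        intersection = (probs * target).sum(dims)
        union = probs.sum(dims) + target.sum(dims)
        dice = (2 * intersection + self.smooth) / (union + self.smooth)
        dice_loss = 1 - dice.mean()
        return self.bce_weight * bce + (1 - self.bce_weight) * dice_loss
```

A lower `bce_weight` puts more emphasis on the Dice term, which is often preferred when the foreground class covers few pixels.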
### Training code
1. Create train/validation dataloaders from the image chips and labels on the /explore/nobackup space.
2. Load the pretrained DINOv3 encoder and build the encoder/decoder model for fine-tuning.
3. Train the model and print per-epoch training statistics.
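The split sizes reported in step 1 can be checked by hand. The sketch below assumes `get_dataloaders` takes `int(n * train_split)` samples for training and keeps the final partial batch (the `DataLoader` default, `drop_last=False`); both are assumptions about `dataset.py`. With `MAX_SAMPLES = 500`, `TRAIN_SPLIT = 0.8` and `BATCH_SIZE = 16` it gives the 400/100 samples and 25/7 batches seen in the log:

```python
import math


def split_counts(n_samples: int, train_split: float, batch_size: int) -> dict:
    """Expected train/val sample and batch counts for a random split.

    Assumes the train size is int(n * train_split) and that partial
    final batches are kept (drop_last=False).
    """
    n_train = int(n_samples * train_split)
    n_val = n_samples - n_train
    return {
        "train_samples": n_train,
        "val_samples": n_val,
        "train_batches": math.ceil(n_train / batch_size),
        "val_batches": math.ceil(n_val / batch_size),
    }


print(split_counts(500, 0.8, 16))
# {'train_samples': 400, 'val_samples': 100, 'train_batches': 25, 'val_batches': 7}
```

With `MAX_SAMPLES = None` the same arithmetic applies to the full count of matched image-label pairs.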

In [4]:
# ============================================================================
# CREATE DATALOADERS
# ============================================================================

print("\n" + "="*60)
print("STEP 1: Creating dataloaders.")
print("="*60)

train_loader, val_loader, MEAN, STD = get_dataloaders(
    image_dir=IMAGE_DIR,
    label_dir=LABEL_DIR,
    batch_size=BATCH_SIZE,
    train_split=TRAIN_SPLIT,
    num_workers=NUM_WORKERS,
    target_size=TARGET_SIZE,
    max_samples=MAX_SAMPLES,
    seed=42,
    stats_save_dir=OUTPUT_DIR
)

print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")

# ============================================================================
# LOAD ENCODER AND CREATE MODEL
# ============================================================================

print("\n" + "="*60)
print("STEP 2: Loading DINO encoder and creating model.")
print("="*60)

encoder = torch.hub.load(
    repo_or_dir='facebookresearch/dinov3',  # GitHub repo
    model='dinov3_vitl16',
    source='github',
    weights=weights_URL
).to(device)

print("Encoder loaded with pretrained weights.")

# Create model with DINO segmentation head, UNet decoder (see model.py)
print("Creating DINO segmentation model with UNet decoder...")
model = DINOSegmentation(
    encoder=encoder,  # Pretrained DINOv3 backbone
    num_classes=2,  # Binary segmentation (crater, not crater)
    img_size=TARGET_SIZE
).to(device)

# Freeze the encoder so only the decoder is trained (set FREEZE_ENCODER = False for full fine-tuning)
if FREEZE_ENCODER:
    for param in model.encoder.parameters():
        param.requires_grad = False
    print("Encoder frozen (only decoder will be trained).")
else:
    print("Encoder unfrozen! Full model will be trained.")

# ============================================================================
# RUN TRAINING
# ============================================================================

print("\n" + "="*60)
print("Starting training.")
print("="*60)

train_losses, val_losses = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    mode="train",
    output_dir=OUTPUT_DIR,
    num_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    checkpoint_every=CHECKPOINT_EVERY,
    visualize_every=VISUALIZE_EVERY,
    loss_type=LOSS_TYPE,
    device=device
)


STEP 1: Creating dataloaders.
Loading existing dataset statistics...
Mean (RGB): [0.32744208 0.32249309 0.302985  ]
Std (RGB): [0.15045737 0.15014801 0.14386101]
Limited to 500 samples
Found 500 matched image-label pairs
Train samples: 400
Val samples: 100
Train batches: 25
Val batches: 7

STEP 2: Loading DINO encoder and creating model.


Using cache found in /home/ajkerr1/.cache/torch/hub/facebookresearch_dinov3_main


Encoder loaded with pretrained weights.
Creating DINO segmentation model with UNet decoder...
Encoder frozen (only decoder will be trained).

Starting training.
Using device: cuda
Using loss function: combined
Loss function: CombinedLoss

Starting training for 50 epochs...
Starting from epoch: 1
Checkpoints will be saved every 10 epochs to: ./outputs/checkpoints
Visualizations will be saved every 10 epochs to: ./outputs/visualizations

MODEL PARAMETER SUMMARY
Encoder:
  Trainable: 0 / 303,156,224 (0.00%)

Decoder:
  Trainable: 5,924,610 / 5,924,610 (100.00%)

Combined Model:
  Trainable: 5,924,610 / 309,080,834 (1.92%)


Epoch 1/50


Epoch 1 Summary:
  Train Loss: 0.6082
  Val Loss:   0.6661
  LR:         0.000100
  Time:       16.01s
Saved checkpoint to: ./outputs/checkpoints/best_model.pt
  Saved best model (val_loss: 0.6661)

Epoch 2/50


Epoch 2 Summary:
  Train Loss: 0.5326
  Val Loss:   0.5452
  LR:         0.000100
  Time:       14.69s
Saved checkpoint to: ./outputs/checkpoints/best_model.pt
  Saved best model (val_loss: 0.5452)

Epoch 3/50


Epoch 3 Summary:
  Train Loss: 0.4950
  Val Loss:   0.5241
  LR:         0.000099
  Time:       14.69s
Saved checkpoint to: ./outputs/checkpoints/best_model.pt
  Saved best model (val_loss: 0.5241)

Epoch 4/50


Epoch 4 Summary:
  Train Loss: 0.4716
  Val Loss:   0.5069
  LR:         0.000098
  Time:       14.77s
Saved checkpoint to: ./outputs/checkpoints/best_model.pt
  Saved best model (val_loss: 0.5069)

Epoch 5/50


Epoch 5 Summary:
  Train Loss: 0.4486
  Val Loss:   0.5159
  LR:         0.000098
  Time:       14.80s

Epoch 6/50


Epoch 6 Summary:
  Train Loss: 0.4352
  Val Loss:   0.5121
  LR:         0.000096
  Time:       14.75s

Epoch 7/50


Epoch 7 Summary:
  Train Loss: 0.4176
  Val Loss:   0.5116
  LR:         0.000095
  Time:       14.79s

Epoch 8/50


Epoch 8 Summary:
  Train Loss: 0.4015
  Val Loss:   0.5183
  LR:         0.000094
  Time:       14.82s

Epoch 9/50


Epoch 9 Summary:
  Train Loss: 0.3752
  Val Loss:   0.5052
  LR:         0.000092
  Time:       14.85s
Saved checkpoint to: ./outputs/checkpoints/best_model.pt
  Saved best model (val_loss: 0.5052)

Epoch 10/50


KeyboardInterrupt: 
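The parameter summary in the log (0 of 303,156,224 encoder parameters trainable; 5,924,610 of 309,080,834 overall, or 1.92%) comes from the freezing loop in step 2. The same freeze-and-count pattern on a tiny stand-in model is sketched below; `ToySegmenter` and its layer sizes are hypothetical and unrelated to the DINOv3 architecture:

```python
import torch.nn as nn


class ToySegmenter(nn.Module):
    """Hypothetical stand-in with an 'encoder' and a 'decoder' submodule."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(16, 8)  # 16*8 weights + 8 biases = 136
        self.decoder = nn.Linear(8, 2)   # 8*2 weights + 2 biases = 18

    def forward(self, x):
        return self.decoder(self.encoder(x))


def count_params(module: nn.Module) -> tuple:
    """Return (trainable, total) parameter counts for a module."""
    total = sum(p.numel() for p in module.parameters())
    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    return trainable, total


model = ToySegmenter()
for param in model.encoder.parameters():  # same loop as FREEZE_ENCODER = True
    param.requires_grad = False

for name, part in [("Encoder", model.encoder), ("Decoder", model.decoder), ("Combined", model)]:
    trainable, total = count_params(part)
    print(f"{name}: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")
```

With `FREEZE_ENCODER = False` the loop is skipped and every parameter stays trainable, so the combined fraction becomes 100%.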