# 01 — Train

Fine-tune DeepForest on local data.

Inputs
- `data/tiles/` — GeoTIFF tiles (RGB).  
- `data/labels/df_labels_train.csv`  
- `data/labels/df_labels_valid.csv`

Outputs
- `models/deepforest_ft.pt` (fine-tuned model)
- Printed internal metrics (IoU/mAP) after validation
- Optional preview plot of predictions

Steps
1) Load pretrained DeepForest.
2) Set training/validation config (paths, epochs, lr, batch).
3) Train with PyTorch Lightning Trainer.
4) Validate and print metrics.
5) (Optional) Plot predictions on a sample tile.
6) Save model to .pt file.

In [None]:
# Root and dependencies

import sys
from pathlib import Path

# Resolve the repository root: the working directory or its nearest ancestor
# that contains scripts/vis_utils.py (the helper imported further down).
# Falls back to the parent of the working directory, which was the original
# assumption (notebook launched from a folder one level below the repo root).
REPO_ROOT = next(
    (
        candidate
        for candidate in (Path.cwd(), *Path.cwd().parents)
        if (candidate / "scripts" / "vis_utils.py").is_file()
    ),
    Path.cwd().parent,
)
# Make the repo importable so `scripts.vis_utils` resolves.
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import pandas as pd
from deepforest import main
from pytorch_lightning import Trainer

from scripts.vis_utils import plot_predictions

# Project layout (everything is resolved relative to REPO_ROOT)
DATA = REPO_ROOT / "data"
TILES, LABELS = DATA / "tiles", DATA / "labels"
TRAIN_CSV, VALID_CSV = LABELS / "df_labels_train.csv", LABELS / "df_labels_valid.csv"
MODELS = REPO_ROOT / "models"

print("REPO_ROOT:", REPO_ROOT)

### LOAD BASE MODEL AND FINE-TUNE ON LOCAL DATA

In [None]:
# Start from the pretrained `weecology/deepforest-tree` weights.
# NOTE(review): revision="main" tracks the latest upload; pin a specific
# revision if runs need to be reproducible.
m = main.deepforest()
m.load_model(
    model_name="weecology/deepforest-tree",
    revision="main",
)

In [None]:
# Training config: point both splits at the local label CSVs. Each split's
# root_dir is the tiles folder (per DeepForest's config convention, the folder
# in which the CSV's image paths are looked up).
for split_name, labels_csv in (("train", TRAIN_CSV), ("validation", VALID_CSV)):
    m.config[split_name]["csv_file"] = str(labels_csv)
    m.config[split_name]["root_dir"] = str(TILES)

# Short fine-tuning schedule.
m.config["train"]["epochs"] = 3
m.config["train"]["lr"] = 1e-4

# Small batches; no extra data-loading workers (set to 0 for Windows).
m.config["batch_size"] = 2
m.config["workers"] = 0

In [None]:
# Train: fine-tune on a single device. Lightning's "auto" accelerator picks a
# GPU if one is available, otherwise the CPU. The epoch count comes from the
# config cell above.
m.trainer = Trainer(
    devices=1,
    accelerator="auto",
    max_epochs=m.config["train"]["epochs"],
)
# ckpt_path=None: start from the currently loaded weights rather than
# resuming from a saved checkpoint.
m.trainer.fit(m, ckpt_path=None)

### QUICK ASSESSMENT ON VALIDATION DATA (full evaluation is done in 02_evaluate.ipynb)

In [None]:
# Internal metrics: quick check on the validation split, reusing the trainer
# created in the training cell.
m.config["validation"]["root_dir"] = str(TILES)
m.config["validation"]["csv_file"] = str(VALID_CSV)
results = m.trainer.validate(m)

# validate() returns one metrics dict per validation dataloader; show the
# first one, or an empty dict if nothing came back.
if results:
    print(results[0])
else:
    print({})

In [None]:
# Plot image example: predictions on one sample tile
SAMPLE_TILE = "tile03.tif"  # any tile name from data/tiles/

# Close the source file explicitly. Pillow can keep TIFF handles open after
# loading, which locks the file on Windows. convert() returns an in-memory
# copy, so `im` stays usable after the `with` block.
with Image.open(TILES / SAMPLE_TILE) as src:
    im = src.convert("RGB")
arr = np.array(im)  # uint8 RGB, kept for display

# Uses DeepForest's default score thresholds (tuned later in 02_evaluate.ipynb).
# A float32 copy is passed because predict_image expects float32 pixels in
# 0-255 and would otherwise convert the array itself and emit a warning.
pred = m.predict_image(arr.astype("float32"))
print(pred.head() if pred is not None else "No detections")

# NOTE(review): pred may be None when nothing is detected; confirm that
# scripts.vis_utils.plot_predictions accepts None.
plot_predictions(arr, pred, title="Predictions (train defaults)")

In [None]:
# Save the fine-tuned model into models/ (the MODELS path from the setup cell),
# creating the folder first if it does not exist yet.
# NOTE(review): depending on the DeepForest version, save_model may write a
# Lightning checkpoint rather than a bare state_dict despite the .pt suffix.
# Confirm before loading the file elsewhere.
MODELS.mkdir(parents=True, exist_ok=True)
m.save_model(str(MODELS / "deepforest_ft.pt"))