# 3. Model Training

This notebook trains Detectron2 object detection models for microbial colony detection.

**Supports all experimental configurations:**
- Part 1: AGAR dataset (5 subsets × 4 models = 20 runs)
- Part 2: Curated/Roboflow dataset (4 models)
- Transfer learning from pre-trained AGAR models

**Configuration:** Change the parameters in Section 3.1 to select the dataset, model, and training mode.

**Prerequisites:** Run `1_setup.ipynb` and `2_data_exploration.ipynb` first.

## 3.1 Configuration

In [None]:
import os
import json
import datetime

import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()

from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.data.datasets import register_coco_instances
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.utils.env import seed_all_rng

import config
from utils.evaluation import plot_training_curves
from utils.training import MyTrainer

In [None]:
# ===================== CONFIGURE YOUR EXPERIMENT =====================
# Every knob for a single training run lives in this cell; the cells below
# only read these names.

# Which data to train on.
DATASET_SOURCE = "agar"   # one of: 'agar', 'roboflow'
SUBSET = "total"          # AGAR only: 'total', 'bright', 'dark', 'vague', 'lowres'

# Which architecture to train (a key of config.MODELS):
#   'faster_rcnn_R50', 'faster_rcnn_R101',
#   'retinanet_R50',   'retinanet_R101',
#   'mask_rcnn_R50',   'mask_rcnn_R101'
MODEL_KEY = "faster_rcnn_R50"

# Initialisation. When USE_TRANSFER_LEARNING is True, TRANSFER_WEIGHTS must
# point at a .pth checkpoint, for example:
#   TRANSFER_WEIGHTS = config.get_model_weights("total_faster_rcnn_R101", "agar")
#   TRANSFER_WEIGHTS = config.get_model_weights("total_retinanet_R50", "agar")
USE_TRANSFER_LEARNING = False
TRANSFER_WEIGHTS = None

# Optimisation hyperparameters.
BATCH_SIZE = 8
NUM_EPOCHS = 10   # AGAR: 10, Roboflow: 100
BASE_LR = 0.005
NUM_CLASSES = 3   # AGAR: 3, Roboflow: 4
# NOTE(review): Detectron2's NUM_CLASSES excludes background, so the Roboflow
# value of 4 presumably counts an extra category present in the export —
# confirm against the annotation JSON.

# Seed value; section 3.4 copies it into cfg.SEED.
RANDOM_SEED = 42

# ====================================================================

print(
    f"Experiment: {DATASET_SOURCE}/{SUBSET} | Model: {MODEL_KEY}",
    f"Transfer learning: {USE_TRANSFER_LEARNING}",
    f"Epochs: {NUM_EPOCHS}, Batch: {BATCH_SIZE}, LR: {BASE_LR}",
    f"Random seed: {RANDOM_SEED}",
    sep="\n",
)

## 3.2 Dataset Setup

In [None]:
# Resolve annotation files, image directories and Detectron2 catalog names for
# the selected source. Both branches define the same names, which are used
# below and by later cells.
if DATASET_SOURCE == "agar":
    if SUBSET not in config.AGAR_DATASETS:
        raise ValueError(
            f"Unknown AGAR subset {SUBSET!r}; expected one of {list(config.AGAR_DATASETS)}"
        )
    dataset = config.AGAR_DATASETS[SUBSET]
    # All AGAR splits read images from a single shared directory.
    img_dir_train = img_dir_val = img_dir_test = config.AGAR_IMG_DIR
    split_prefix = SUBSET
    dataset_label = f"{SUBSET}_100"

elif DATASET_SOURCE == "roboflow":
    dataset = config.ROBOFLOW_DATASETS["curated"]
    img_dir_train = dataset["train_dir"]
    img_dir_val = dataset["val_dir"]
    img_dir_test = dataset["test_dir"]
    split_prefix = "robo"
    dataset_label = "final_noaugm"

else:
    # Without this branch an unsupported value surfaced later as a NameError
    # on train_name/train_path instead of a clear configuration error.
    raise ValueError(
        f"Unknown DATASET_SOURCE {DATASET_SOURCE!r}; expected 'agar' or 'roboflow'"
    )

# Both sources expose the COCO annotation file of each split under the same keys.
train_path = dataset["train"]
val_path = dataset["val"]
test_path = dataset["test"]
train_name = f"{split_prefix}_train"
val_name = f"{split_prefix}_val"
test_name = f"{split_prefix}_test"

# Register datasets (safe to re-run): any existing dataset or metadata entry
# with the same name is dropped first. Metadata is checked separately so a
# stale entry cannot clash with the values register_coco_instances sets.
for name, ann_path, img_dir in [
    (train_name, train_path, img_dir_train),
    (val_name, val_path, img_dir_val),
    (test_name, test_path, img_dir_test),
]:
    if name in DatasetCatalog.list():
        DatasetCatalog.remove(name)
    if name in MetadataCatalog.list():
        MetadataCatalog.remove(name)
    register_coco_instances(name, {}, ann_path, img_dir)

print(f"Datasets registered: {train_name}, {val_name}, {test_name}")

## 3.3 Compute Training Schedule

In [None]:
# Convert the epoch-based settings from section 3.1 into Detectron2's
# iteration-based schedule, using the number of images in the training split.
with open(train_path, 'r') as f:
    data = json.load(f)

num_train_images = len(data['images'])
if num_train_images == 0:
    raise ValueError(f"No images found in training annotations: {train_path}")

# Floor division keeps the original schedule. max(1, ...) prevents zero
# periods when the training set is smaller than one batch: a zero
# CHECKPOINT_PERIOD breaks the periodic checkpointer (modulo by zero), and a
# zero MAX_ITER would train nothing.
iterations = max(1, NUM_EPOCHS * num_train_images // BATCH_SIZE)
checkpoint_period = max(1, num_train_images // BATCH_SIZE)  # ~once per epoch
lr_decay_step = int(0.7 * iterations)  # single LR drop at 70% of training

print(f"Training images: {num_train_images}")
print(f"Total iterations: {iterations}")
print(f"Checkpoint every: {checkpoint_period} iterations (= 1 epoch)")
print(f"LR decay at iteration: {lr_decay_step}")

## 3.4 Build Detectron2 Config

In [None]:
config_file = config.MODELS[MODEL_KEY]
model_name = os.path.splitext(os.path.basename(config_file))[0]

# Fail fast, before anything is written to disk. Previously, transfer learning
# without a checkpoint silently fell back to model-zoo weights while the output
# directory was still tagged "_transferlearn", mislabelling the run.
if USE_TRANSFER_LEARNING and not TRANSFER_WEIGHTS:
    raise ValueError(
        "USE_TRANSFER_LEARNING is True but TRANSFER_WEIGHTS is not set; "
        "point it at a .pth checkpoint (see section 3.1)."
    )

# Output directory: <dataset label>[_transferlearn]_<model>_<timestamp>
current_time = datetime.datetime.now().strftime("%d-%m-%Y_%H-%M-%S")
transfer_tag = "_transferlearn" if USE_TRANSFER_LEARNING else ""
output_dir = os.path.join(
    config.OUTPUTS_DIR,
    f"{dataset_label}{transfer_tag}_{model_name}_{current_time}"
)
os.makedirs(output_dir, exist_ok=True)

# Start from the model-zoo base config for the chosen architecture.
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file(config_file))

cfg.DATASETS.TRAIN = (train_name,)
cfg.DATASETS.TEST = (val_name,)  # validation split is evaluated during training
cfg.DATALOADER.NUM_WORKERS = 4
cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS = False  # keep images with no annotations

# Record the seed in the config. NOTE: Detectron2 only consumes cfg.SEED inside
# `default_setup`, which this notebook does not call, so the RNGs must also be
# seeded explicitly before the trainer is built.
cfg.SEED = RANDOM_SEED

# Weights: the user-supplied checkpoint for transfer learning (validated above),
# otherwise the model-zoo checkpoint matching config_file.
if USE_TRANSFER_LEARNING:
    cfg.MODEL.WEIGHTS = TRANSFER_WEIGHTS
    print(f"Transfer learning from: {TRANSFER_WEIGHTS}")
else:
    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(config_file)
    print(f"Initializing from model zoo: {config_file}")

# Optimizer: SGD-style settings with one LR step (see section 3.3).
cfg.SOLVER.IMS_PER_BATCH = BATCH_SIZE
cfg.SOLVER.BASE_LR = BASE_LR
cfg.SOLVER.MOMENTUM = 0.9
cfg.SOLVER.WEIGHT_DECAY = 0.0005
cfg.SOLVER.MAX_ITER = iterations
cfg.SOLVER.STEPS = (lr_decay_step,)

# Linear warmup lasting at most one epoch (capped at 1000 iterations).
cfg.SOLVER.WARMUP_FACTOR = 1.0 / 1000
cfg.SOLVER.WARMUP_ITERS = min(1000, checkpoint_period)
cfg.SOLVER.WARMUP_METHOD = "linear"

# Detection head: RetinaNet keeps its class count under MODEL.RETINANET; the
# R-CNN variants use the RoI heads.
if config.is_retinanet(MODEL_KEY):
    cfg.MODEL.RETINANET.NUM_CLASSES = NUM_CLASSES
else:
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = NUM_CLASSES
    cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512

# Checkpoint and evaluate once per (approximate) epoch.
cfg.SOLVER.CHECKPOINT_PERIOD = checkpoint_period
cfg.TEST.EVAL_PERIOD = checkpoint_period
cfg.OUTPUT_DIR = output_dir

# Save full resolved config for reproducibility
config_yaml_path = os.path.join(output_dir, "full_config.yaml")
with open(config_yaml_path, "w") as f:
    f.write(cfg.dump())
print(f"Full config saved to: {config_yaml_path}")

print(f"\nOutput directory: {output_dir}")
print(f"Config file: {config_file}")
print(f"Num classes: {NUM_CLASSES}")
print(f"Seed: {RANDOM_SEED}")

## 3.5 Train

In [None]:
# Seed the Python/NumPy/PyTorch RNGs right before the model and data loader
# are built. Detectron2 reads cfg.SEED only inside `default_setup`, which this
# notebook never calls, so without this line the configured seed had no effect
# (unless MyTrainer seeds internally — not visible here). Negative seeds mean
# "random", following Detectron2's convention. Seeding improves but does not
# guarantee bit-exact reproducibility on GPU.
seed_all_rng(None if cfg.SEED < 0 else cfg.SEED)

trainer = MyTrainer(cfg)
# resume=False: presumably loads cfg.MODEL.WEIGHTS rather than the last
# checkpoint in OUTPUT_DIR (DefaultTrainer semantics) — confirm for MyTrainer.
trainer.resume_or_load(resume=False)
trainer.train()

## 3.6 Training Curves

In [None]:
# Plot the metrics logged during training and save the figure next to the
# run's checkpoints. (plot_training_curves is imported with the other imports
# at the top of the notebook, so a missing module fails before training.)
metrics_path = os.path.join(output_dir, "metrics.json")
plot_save_path = os.path.join(output_dir, "training_curves.png")

plot_training_curves(metrics_path, save_path=plot_save_path)

## 3.7 TensorBoard (Optional)

In [None]:
%load_ext tensorboard
%tensorboard --logdir {output_dir}