In [5]:
import os
from pathlib import Path
import random
import time
import datetime
import logging
import numpy as np
import tensorflow as tf
import torch
import cv2
import matplotlib.pyplot as plt
from tensorboard.plugins.hparams import api as hp
from detectron2.config import get_cfg
from detectron2.data.datasets import register_coco_instances
from detectron2.data.catalog import DatasetCatalog, MetadataCatalog
from detectron2.utils.visualizer import Visualizer
from detectron2.engine import DefaultTrainer, HookBase
from detectron2.evaluation import COCOEvaluator, inference_context
from detectron2.utils.logger import log_every_n_seconds
from detectron2.data import DatasetMapper, build_detection_test_loader
import detectron2.utils.comm as comm

# Initialize the Detectron2 configuration node via get_cfg().
# This single module-level object is shared state: train_hail_model() overwrites
# its fields in place before every training run.
cfg = get_cfg()

In [11]:
def register_datasets(data_root="/scratch/mlainer/hail_master/crowdsource/cnn"):
    """
    Register the train and validation datasets (COCO format) with Detectron2.

    The datasets are registered under the names "train_crowd" and "val_crowd".
    The function is idempotent: a name that is already present in the
    DatasetCatalog is not registered again, so the cell can be re-run safely.

    Args:
    - data_root (str): Directory containing the `train/` and `val/` folders, each
      with `annotations/instances_default.json` and an `images/` directory.
    """
    for split in ("train", "val"):
        name = split + "_crowd"

        # Registering an existing name raises inside Detectron2, so skip it.
        if name not in DatasetCatalog.list():
            register_coco_instances(
                name,
                {},
                os.path.join(data_root, split, "annotations", "instances_default.json"),
                os.path.join(data_root, split, "images"),
            )

        # Query both catalogs right away, as the original code did;
        # DatasetCatalog.get() invokes the registered loader for this dataset.
        DatasetCatalog.get(name)
        MetadataCatalog.get(name)

In [10]:
def setup_hyperparameters():
    """
    Build the hyperparameter search space used for tuning.

    Every hyperparameter gets a discrete domain; each domain currently holds a
    single candidate value.

    Returns:
    - tuple: HParam objects for 'base_lr', 'gamma' and 'batch_size_per_image',
      in that order.
    """
    # (name, candidate values) pairs, in the order they are returned.
    search_space = (
        ('base_lr', [0.001]),
        ('gamma', [0.5]),
        ('batch_size_per_image', [256]),
    )
    return tuple(
        hp.HParam(name, hp.Discrete(candidates))
        for name, candidates in search_space
    )

def configure_tensorboard(hparams=None, metric_tags=None, logdir='output/logs/hparam_tuning'):
    """
    Write the experiment-level HParams configuration for TensorBoard.

    Args:
    - hparams (list, optional): HParam objects to declare. When omitted, the
      module-level HP_BASE_LR, HP_GAMMA and HP_BATCH_SIZE_PER_IMAGE are used, so
      setup_hyperparameters() must have been unpacked into those names first.
    - metric_tags (sequence of str, optional): Summary tags of the metrics to
      declare; each tag is also used as its display name. Defaults to the six
      bbox/segm AP tags listed below.
    - logdir (str): Directory the configuration summary is written to.
    """
    if hparams is None:
        hparams = [HP_BASE_LR, HP_GAMMA, HP_BATCH_SIZE_PER_IMAGE]

    if metric_tags is None:
        # The original code referenced AP_BBOX-style constants that are not
        # defined in this notebook (NameError); the tags now default to the
        # display names that were already spelled out here.
        # NOTE(review): scalars logged per trial must use these same tags.
        metric_tags = (
            'AP_bbox', 'AP50_bbox', 'AP75_bbox',
            'AP_segm', 'AP50_segm', 'AP75_segm',
        )

    metrics = [hp.Metric(tag, display_name=tag) for tag in metric_tags]

    with tf.summary.create_file_writer(logdir).as_default():
        hp.hparams_config(hparams=list(hparams), metrics=metrics)

def train_hail_model(run_dir, base_lr, gamma, batch_size):
    """
    Configure the shared `cfg` for one trial and train the hail Mask R-CNN model.

    The module-level `cfg` is modified in place, the output directory is
    created, and a CocoTrainer is run from iteration 0 (no resume).

    Args:
    - run_dir (str): Directory used as cfg.OUTPUT_DIR for this trial.
    - base_lr (float): Base learning rate (cfg.SOLVER.BASE_LR).
    - gamma (float): Learning rate decay factor (cfg.SOLVER.GAMMA).
    - batch_size (int): ROI head batch size per image
      (cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE).

    Returns:
    - None. No metrics are returned by this function.
    """
    # Load the baseline config first; everything below overrides it.
    cfg.merge_from_file("/scratch/mlainer/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
    cfg.OUTPUT_DIR = run_dir

    # --- Data -------------------------------------------------------------
    cfg.DATASETS.TRAIN = ("train_crowd",)
    cfg.DATASETS.TEST = ("val_crowd",)
    cfg.DATALOADER.NUM_WORKERS = 1
    cfg.INPUT.MIN_SIZE_TRAIN = (500,)
    # INPUT.RANDOM_FLIP and SOLVER.CHECKPOINT_PERIOD are deliberately not
    # overridden here; they keep whatever the merged config provides.

    # --- Optimisation schedule ---------------------------------------------
    solver = cfg.SOLVER
    solver.IMS_PER_BATCH = 2
    solver.BASE_LR = base_lr
    solver.GAMMA = gamma
    solver.WARMUP_ITERS = 300
    # Raise MAX_ITER if validation AP is still rising, lower it on overfitting.
    solver.MAX_ITER = 3000
    solver.STEPS = (2400, 2700)

    # --- Periodic evaluation -------------------------------------------------
    evaluation = cfg.TEST
    evaluation.EVAL_PERIOD = 100
    evaluation.DETECTIONS_PER_IMAGE = 80

    # --- Model -----------------------------------------------------------------
    model = cfg.MODEL
    # Initialise from the pretrained model-zoo checkpoint.
    model.WEIGHTS = "/scratch/mlainer/data/hail/detectron2/pretrained_models/model_final_f10217.pkl"
    model.MASK_ON = True
    roi_heads = model.ROI_HEADS
    roi_heads.BATCH_SIZE_PER_IMAGE = batch_size
    roi_heads.NUM_CLASSES = 1  # single class: hail
    roi_heads.SCORE_THRESH_TEST = 0.05

    # --- Train -------------------------------------------------------------------
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
    hail_trainer = CocoTrainer(cfg)
    hail_trainer.resume_or_load(resume=False)
    hail_trainer.train()

def run(run_dir, hparams, base_lr, gamma, batch_size):
    """
    Execute a single hyperparameter trial and log it to TensorBoard.

    A summary writer is opened on `run_dir`, the trial's hyperparameter values
    are recorded, and the model is trained with those values.

    Args:
    - run_dir (str): Directory receiving both the summaries and training output.
    - hparams (dict): Mapping of HParam objects to the values used in this trial.
    - base_lr (float): Learning rate.
    - gamma (float): Learning rate decay factor.
    - batch_size (int): Batch size per image.
    """
    trial_writer = tf.summary.create_file_writer(run_dir)
    with trial_writer.as_default():
        # Record the values used in this trial before training starts.
        hp.hparams(hparams)
        train_hail_model(run_dir, base_lr, gamma, batch_size)
        # NOTE: no AP metrics are written as tf.summary scalars here, because
        # train_hail_model() does not return any.

class CocoTrainer(DefaultTrainer):
    """
    DefaultTrainer subclass that evaluates with COCOEvaluator and additionally
    tracks the loss on the first test dataset via LossEvalHook.
    """

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        """
        Create the COCO evaluator for `dataset_name`.

        Args:
        - cfg: Config node of the current run.
        - dataset_name (str): Name of the registered dataset to evaluate.
        - output_folder (str, optional): Directory for evaluator output.
          Defaults to `cfg.OUTPUT_DIR` of the config passed in.

        Returns:
        - COCOEvaluator: Evaluator built with the third positional argument
          set to False.
        """
        # Resolve the default at call time. The previous default expression,
        # `cfg.OUTPUT_DIR`, was evaluated once when the class was defined and
        # therefore pointed at a stale directory instead of the run's own.
        if output_folder is None:
            output_folder = cfg.OUTPUT_DIR

        # NOTE(review): positional arguments are kept exactly as before; newer
        # Detectron2 releases may flag passing `cfg` here as deprecated —
        # confirm against the installed version.
        return COCOEvaluator(dataset_name, cfg, False, output_folder)

    def build_hooks(self):
        """
        Extend the default hook list with a LossEvalHook.

        Returns:
        - list: The hooks from DefaultTrainer.build_hooks() with the
          LossEvalHook inserted directly before the last one.
        """
        hooks = super().build_hooks()

        # Loader over the first test dataset. The mapper is built with True as
        # its second positional argument — presumably is_train, so annotations
        # are kept for loss computation; verify against DatasetMapper.
        validation_loader = build_detection_test_loader(
            self.cfg,
            self.cfg.DATASETS.TEST[0],
            DatasetMapper(self.cfg, True),
        )

        # Read the period from this trainer's own config rather than from the
        # module-level `cfg`, so the hook always matches the run being trained.
        loss_hook = LossEvalHook(
            self.cfg.TEST.EVAL_PERIOD,
            self.model,
            validation_loader,
        )

        # insert(-1, ...) places the hook before the current last hook.
        # NOTE(review): presumably so the final (writer) hook still runs after
        # the validation loss has been stored — confirm hook order.
        hooks.insert(-1, loss_hook)
        return hooks

class LossEvalHook(HookBase):
    """
    Training hook that periodically computes the loss on a held-out data loader.

    Every `eval_period` iterations, and on the final iteration, the model is run
    over `data_loader`; the mean of the per-batch total losses is stored in the
    trainer's event storage under the key 'validation_loss'.

    Args:
    - eval_period (int): Evaluate every this many iterations. Values <= 0
      disable periodic evaluation; the final iteration is still evaluated.
    - model: Callable as `model(batch)` returning a dict of loss values.
      NOTE(review): this presumably requires the model to be in training mode,
      as it is during the training loop — confirm.
    - data_loader: Iterable of batches that also supports `len()`.
    """

    def __init__(self, eval_period, model, data_loader):
        self._model = model
        self._period = eval_period
        self._data_loader = data_loader

    def _do_loss_eval(self):
        """
        Run one pass over the validation loader and record the mean loss.

        Returns:
        - list[float]: Total loss of every batch, in loader order. Empty when
          the loader yields no batches, in which case nothing is stored.
        """
        # Timing / progress logic adapted from inference_on_dataset in evaluator.py.
        total = len(self._data_loader)
        num_warmup = min(5, total - 1)

        start_time = time.perf_counter()
        total_compute_time = 0
        losses = []
        for idx, inputs in enumerate(self._data_loader):
            if idx == num_warmup:
                # Warm-up is over: restart both clocks so the first batches do
                # not distort the throughput and ETA estimates.
                start_time = time.perf_counter()
                total_compute_time = 0

            # The loss computation is inside the timed region, so the logged
            # "s / img" reflects the forward pass and not just the cost of
            # torch.cuda.synchronize().
            start_compute_time = time.perf_counter()
            losses.append(self._get_loss(inputs))
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            total_compute_time += time.perf_counter() - start_compute_time

            iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
            seconds_per_img = total_compute_time / iters_after_start
            if idx >= num_warmup * 2 or seconds_per_img > 5:
                total_seconds_per_img = (time.perf_counter() - start_time) / iters_after_start
                eta = datetime.timedelta(seconds=int(total_seconds_per_img * (total - idx - 1)))
                log_every_n_seconds(
                    logging.INFO,
                    "Loss on Validation  done {}/{}. {:.4f} s / img. ETA={}".format(
                        idx + 1, total, seconds_per_img, str(eta)
                    ),
                    n=5,
                )

        if losses:
            mean_loss = np.mean(losses)
            self.trainer.storage.put_scalar('validation_loss', mean_loss)
        else:
            # np.mean([]) would yield NaN plus a RuntimeWarning; skip instead.
            logging.getLogger(__name__).warning(
                "Validation data loader is empty; 'validation_loss' was not recorded."
            )
        comm.synchronize()

        return losses

    def _get_loss(self, data):
        """
        Compute the summed loss of a single batch.

        Args:
        - data: One batch as yielded by the data loader.

        Returns:
        - float: Sum of all entries of the loss dict returned by the model.
        """
        # Validation only needs the loss values, so no autograd graph is built.
        with torch.no_grad():
            metrics_dict = self._model(data)
        metrics_dict = {
            k: v.detach().cpu().item() if isinstance(v, torch.Tensor) else float(v)
            for k, v in metrics_dict.items()
        }
        return sum(metrics_dict.values())

    def after_step(self):
        """
        Trigger the loss evaluation on period boundaries and on the last iteration.
        """
        next_iter = self.trainer.iter + 1
        is_final = next_iter == self.trainer.max_iter
        if is_final or (self._period > 0 and next_iter % self._period == 0):
            self._do_loss_eval()
        

In [12]:
# Set `register = 1` to force registration. Independently of the flag, the
# datasets are registered whenever "train_crowd" is missing from the catalog,
# so the notebook also works on a fresh kernel (Restart & Run All).
register = 0
if register == 1 or "train_crowd" not in DatasetCatalog.list():
    register_datasets()

HP_BASE_LR, HP_GAMMA, HP_BATCH_SIZE_PER_IMAGE = setup_hyperparameters()
#configure_tensorboard()

# Grid search: one trial per combination of the three hyperparameter domains.
session_num = 0
for base_lr in HP_BASE_LR.domain.values:
    for gamma in HP_GAMMA.domain.values:
        for batch_size_per_image in HP_BATCH_SIZE_PER_IMAGE.domain.values:
            hparams = {
                HP_BASE_LR: base_lr,
                HP_GAMMA: gamma,
                HP_BATCH_SIZE_PER_IMAGE: batch_size_per_image,
            }
            run_name = "run-%d" % session_num
            print('--- Starting trial: %s' % run_name)
            print({h.name: hparams[h] for h in hparams})
            # Each trial gets its own log/output directory named after the run.
            run('/scratch/mlainer/hail_master/crowdsource/detectron2/output/logs/hparam_tuning/' + run_name, hparams, base_lr, gamma, batch_size_per_image)
            session_num += 1

--- Starting trial: run-0
{'base_lr': 0.001, 'gamma': 0.5, 'batch_size_per_image': 256}
[32m[05/08 13:50:36 d2.engine.defaults]: [0mModel:
GeneralizedRCNN(
  (backbone): FPN(
    (fpn_lateral2): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral3): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral4): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral5): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output5): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (top_block): LastLevelMaxPool()
    (bottom_up): ResNet(
      (stem): BasicStem(
        (conv1): Conv2d(
          3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=F

Skip loading parameter 'roi_heads.box_predictor.cls_score.weight' to the model due to incompatible shapes: (81, 1024) in the checkpoint but (2, 1024) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.box_predictor.cls_score.bias' to the model due to incompatible shapes: (81,) in the checkpoint but (2,) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.box_predictor.bbox_pred.weight' to the model due to incompatible shapes: (320, 1024) in the checkpoint but (4, 1024) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.box_predictor.bbox_pred.bias' to the model due to incompatible shapes: (320,) in the checkpoint but (4,) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.mask_head.predictor.weight' to the model due to incompatible shapes: (80, 256, 1, 1) in the checkpoint but (1, 256, 1, 1) in

[32m[05/08 13:51:01 d2.engine.train_loop]: [0mStarting training from iteration 0


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


[32m[05/08 13:52:16 d2.utils.events]: [0m eta: 0:41:50  iter: 19  total_loss: 1.954  loss_cls: 0.6989  loss_box_reg: 0.5378  loss_mask: 0.673  loss_rpn_cls: 0.01535  loss_rpn_loc: 0.01028    time: 0.8421  last_time: 0.6376  data_time: 0.7207  last_data_time: 0.4572   lr: 6.427e-05  max_mem: 1811M
[32m[05/08 13:52:32 d2.utils.events]: [0m eta: 0:42:27  iter: 39  total_loss: 1.366  loss_cls: 0.3799  loss_box_reg: 0.5368  loss_mask: 0.3997  loss_rpn_cls: 0.008305  loss_rpn_loc: 0.00675    time: 0.8273  last_time: 0.9717  data_time: 0.6439  last_data_time: 0.8033   lr: 0.00013087  max_mem: 1811M
[32m[05/08 13:52:48 d2.utils.events]: [0m eta: 0:42:10  iter: 59  total_loss: 1  loss_cls: 0.2268  loss_box_reg: 0.5497  loss_mask: 0.214  loss_rpn_cls: 0.004443  loss_rpn_loc: 0.005939    time: 0.8218  last_time: 1.0443  data_time: 0.6417  last_data_time: 0.8974   lr: 0.00019747  max_mem: 1811M
[32m[05/08 13:53:05 d2.utils.events]: [0m eta: 0:40:30  iter: 79  total_loss: 0.7719  loss_cls: 