In [1]:
# This notebook was built from this tutorial: https://github.com/bnsreenu/python_for_microscopists/blob/master/330_Detectron2_Instance_3D_EM_Platelet.ipynb
import os

# Set up the detectron2 logger.
# setup_logger() is the last expression in this cell, so the logger object it
# returns is echoed as the cell's output.
from detectron2.utils.logger import setup_logger
setup_logger()

<Logger detectron2 (DEBUG)>

## Load Dataset

In [2]:
from src.segmentation.framework_handlers.detectron2_handler import register_and_split_dataset

# Dataset identifier and release tag handed to register_and_split_dataset.
# segments_dataset_name is also used later to build the "<name>_train" /
# "<name>_test" dataset keys for the detectron2 config.
segments_dataset_name = "etaylor/stigmas_dataset"
release = "v0.2"

# The helper returns metadata and dataset dicts for both splits;
# train_ratio=0.8 keeps 80% of the samples for training and 20% for testing.
(
    train_metadata,
    train_dataset_dicts,
    test_metadata,
    test_dataset_dicts,
) = register_and_split_dataset(
    dataset_name=segments_dataset_name,
    release_version=release,
    train_ratio=0.8,
)

# Report the size of each split as a quick sanity check.
print(f"Train dataset: {len(train_dataset_dicts)} samples")
print(f"Test dataset: {len(test_dataset_dicts)} samples")


Initializing dataset...
Preloading all samples. This may take a while...


100%|[38;2;255;153;0m██████████[0m| 115/115 [00:00<00:00, 1574.26it/s]


Initialized dataset with 115 images.
Exporting dataset. This may take a while...


100%|[38;2;255;153;0m██████████[0m| 115/115 [00:52<00:00,  2.20it/s]

Exported to ./export_coco-instance_etaylor_stigmas_dataset_v0.2.json. Images in segments/etaylor_stigmas_dataset/v0.2
Category ids in annotations are not in [1, #categories]! We'll apply a mapping for you.

[32m[01/03 17:56:48 d2.data.datasets.coco]: [0mLoaded 115 images in COCO format from segments/etaylor_stigmas_dataset/annotations/export_coco-instance_etaylor_stigmas_dataset_v0.2.json
Train dataset: 92 samples
Test dataset: 23 samples





In [8]:
import gc

# torch is imported here so this cell also runs on a fresh kernel; without it
# the cell only works if a later cell (which imports torch) was executed first.
import torch

# Collect unreachable Python objects first, then ask PyTorch to release the
# GPU memory it is holding in its cache.
gc.collect()
torch.cuda.empty_cache()


In [9]:
import os
from datetime import datetime
from detectron2.config import get_cfg
from detectron2.engine import DefaultTrainer
from detectron2 import model_zoo
from detectron2.checkpoint import DetectionCheckpointer
import torch
from detectron2.engine import AMPTrainer
import gc


# Root directory for fine-tuned checkpoints; each model gets its own sub-directory.
detectron2_models_path = "/home/etaylor/code_projects/thesis/checkpoints/stigmas_segmentation/detectron2/fine_tuned"

# Shared training hyperparameters.
# NOTE(review): num_epochs is not read by train_model in this cell; the
# training length there is set through cfg.SOLVER.MAX_ITER.
num_epochs = 50
batch_size = 2
learning_rate = 2.5e-4
num_classes = 1

# Models to train: output sub-directory name -> model-zoo config path.
# Insertion order is the training order.
models = dict(
    [
        ("mask_rcnn_R_50_FPN_3x", "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"),
        ("mask_rcnn_R_101_FPN_3x", "COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml"),
        ("mask_rcnn_X_101_32x8d_FPN_3x", "COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml"),
        ("mask_rcnn_R_50_C4_3x", "COCO-InstanceSegmentation/mask_rcnn_R_50_C4_3x.yaml"),
        ("mask_rcnn_R_50_DC5_3x", "COCO-InstanceSegmentation/mask_rcnn_R_50_DC5_3x.yaml"),
    ]
)

# Function to train a model
def train_model(model_name, config_path, max_iter=1000):
    """Fine-tune one model-zoo config on the registered train split and save it.

    Besides its arguments, the function reads these module-level names:
    ``segments_dataset_name``, ``batch_size``, ``learning_rate``,
    ``num_classes`` and ``detectron2_models_path``.

    Args:
        model_name: Name of the sub-directory (under ``detectron2_models_path``)
            that receives the training output.
        config_path: Model-zoo config path, e.g.
            ``"COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"``. It is
            used both for the config file and for the initial weights.
        max_iter: Number of solver iterations (``cfg.SOLVER.MAX_ITER``).
            Defaults to 1000, the value that was previously hard-coded.

    Returns:
        The output directory the model was trained and saved in.
    """
    # Create configuration
    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file(config_path))

    # Dataset keys follow the "<dataset name>_train" / "<dataset name>_test" pattern.
    # NOTE(review): assumes register_and_split_dataset registered the splits
    # under exactly these names — confirm against that helper.
    cfg.DATASETS.TRAIN = (f"{segments_dataset_name}_train",)
    cfg.DATASETS.TEST = (f"{segments_dataset_name}_test",)
    cfg.INPUT.MASK_FORMAT = "bitmask"
    cfg.DATALOADER.NUM_WORKERS = 4

    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(config_path)  # Initialize from model zoo
    cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512  # Default is 512
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = num_classes  # Set number of classes

    cfg.SOLVER.IMS_PER_BATCH = batch_size
    cfg.SOLVER.BASE_LR = learning_rate
    cfg.SOLVER.MAX_ITER = max_iter
    cfg.SOLVER.STEPS = []  # Do not decay learning rate
    cfg.SOLVER.AMP.ENABLED = True  # Mixed-precision training

    # Set output directory
    model_saving_path = os.path.join(detectron2_models_path, model_name)
    cfg.OUTPUT_DIR = model_saving_path
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

    # Train the model; resume=False starts from cfg.MODEL.WEIGHTS instead of
    # picking up a previous run found in the output directory.
    trainer = DefaultTrainer(cfg)
    trainer.resume_or_load(resume=False)
    trainer.train()

    # Save checkpoint after training
    # NOTE(review): DefaultTrainer's own checkpoint hook presumably writes a
    # "model_final" checkpoint as well; this explicit save then overwrites it
    # — confirm whether it is still needed.
    checkpointer = DetectionCheckpointer(trainer.model, save_dir=cfg.OUTPUT_DIR)
    checkpointer.save("model_final")

    return cfg.OUTPUT_DIR

# Fine-tune every entry of `models`, one after another, in dictionary order
for model_name in models:
    config_path = models[model_name]

    print(f"Training {model_name}...")
    train_model(model_name, config_path)

    # Free collected Python objects and PyTorch's cached GPU memory before the
    # next model is built
    gc.collect()
    torch.cuda.empty_cache()

    print(f"Finished training {model_name}. Model saved in {os.path.join(detectron2_models_path, model_name)}")


Training mask_rcnn_R_50_FPN_3x...
[32m[01/03 20:12:07 d2.engine.defaults]: [0mModel:
GeneralizedRCNN(
  (backbone): FPN(
    (fpn_lateral2): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral3): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral4): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral5): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output5): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (top_block): LastLevelMaxPool()
    (bottom_up): ResNet(
      (stem): BasicStem(
        (conv1): Conv2d(
          3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
          (norm): FrozenBatchNorm2d(num_features=

Skip loading parameter 'roi_heads.box_predictor.cls_score.weight' to the model due to incompatible shapes: (81, 1024) in the checkpoint but (2, 1024) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.box_predictor.cls_score.bias' to the model due to incompatible shapes: (81,) in the checkpoint but (2,) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.box_predictor.bbox_pred.weight' to the model due to incompatible shapes: (320, 1024) in the checkpoint but (4, 1024) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.box_predictor.bbox_pred.bias' to the model due to incompatible shapes: (320,) in the checkpoint but (4,) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.mask_head.predictor.weight' to the model due to incompatible shapes: (80, 256, 1, 1) in the checkpoint but (1, 256, 1, 1) in

[32m[01/03 20:12:07 d2.engine.train_loop]: [0mStarting training from iteration 0


  torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks])
  torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks])
  torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks])
  torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks])


[32m[01/03 20:12:26 d2.utils.events]: [0m eta: 0:08:29  iter: 19  total_loss: 2.488  loss_cls: 0.6364  loss_box_reg: 0.2086  loss_mask: 0.6899  loss_rpn_cls: 0.8348  loss_rpn_loc: 0.08507    time: 0.8381  last_time: 1.3471  data_time: 0.6107  last_data_time: 0.9915   lr: 4.9953e-06  max_mem: 9964M
[32m[01/03 20:12:41 d2.utils.events]: [0m eta: 0:09:15  iter: 39  total_loss: 2.151  loss_cls: 0.5511  loss_box_reg: 0.2729  loss_mask: 0.6846  loss_rpn_cls: 0.5684  loss_rpn_loc: 0.06239    time: 0.7950  last_time: 0.6646  data_time: 0.4098  last_data_time: 0.3433   lr: 9.9902e-06  max_mem: 9964M
[32m[01/03 20:12:55 d2.utils.events]: [0m eta: 0:08:16  iter: 59  total_loss: 1.816  loss_cls: 0.4438  loss_box_reg: 0.2936  loss_mask: 0.6736  loss_rpn_cls: 0.2916  loss_rpn_loc: 0.05652    time: 0.7610  last_time: 1.1993  data_time: 0.3559  last_data_time: 0.8587   lr: 1.4985e-05  max_mem: 9964M
[32m[01/03 20:13:10 d2.utils.events]: [0m eta: 0:09:00  iter: 79  total_loss: 1.756  loss_cls: 

Skip loading parameter 'roi_heads.box_predictor.cls_score.weight' to the model due to incompatible shapes: (81, 1024) in the checkpoint but (2, 1024) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.box_predictor.cls_score.bias' to the model due to incompatible shapes: (81,) in the checkpoint but (2,) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.box_predictor.bbox_pred.weight' to the model due to incompatible shapes: (320, 1024) in the checkpoint but (4, 1024) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.box_predictor.bbox_pred.bias' to the model due to incompatible shapes: (320,) in the checkpoint but (4,) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.mask_head.predictor.weight' to the model due to incompatible shapes: (80, 256, 1, 1) in the checkpoint but (1, 256, 1, 1) in

[32m[01/03 20:24:05 d2.engine.train_loop]: [0mStarting training from iteration 0


  torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks])
  torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks])
  torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks])
  torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks])


[32m[01/03 20:24:22 d2.utils.events]: [0m eta: 0:10:16  iter: 19  total_loss: 2.58  loss_cls: 0.582  loss_box_reg: 0.2593  loss_mask: 0.6913  loss_rpn_cls: 0.942  loss_rpn_loc: 0.08704    time: 0.7807  last_time: 1.1767  data_time: 0.3960  last_data_time: 0.6739   lr: 4.9953e-06  max_mem: 9964M
[32m[01/03 20:24:36 d2.utils.events]: [0m eta: 0:09:13  iter: 39  total_loss: 2.045  loss_cls: 0.5044  loss_box_reg: 0.2072  loss_mask: 0.6863  loss_rpn_cls: 0.5813  loss_rpn_loc: 0.05334    time: 0.7206  last_time: 1.8116  data_time: 0.1993  last_data_time: 1.3618   lr: 9.9902e-06  max_mem: 9964M
[32m[01/03 20:24:52 d2.utils.events]: [0m eta: 0:09:01  iter: 59  total_loss: 1.982  loss_cls: 0.4258  loss_box_reg: 0.3035  loss_mask: 0.677  loss_rpn_cls: 0.5371  loss_rpn_loc: 0.07244    time: 0.7529  last_time: 0.4370  data_time: 0.3759  last_data_time: 0.0029   lr: 1.4985e-05  max_mem: 9964M
[32m[01/03 20:25:06 d2.utils.events]: [0m eta: 0:08:45  iter: 79  total_loss: 1.649  loss_cls: 0.36

Skip loading parameter 'roi_heads.box_predictor.cls_score.weight' to the model due to incompatible shapes: (81, 1024) in the checkpoint but (2, 1024) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.box_predictor.cls_score.bias' to the model due to incompatible shapes: (81,) in the checkpoint but (2,) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.box_predictor.bbox_pred.weight' to the model due to incompatible shapes: (320, 1024) in the checkpoint but (4, 1024) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.box_predictor.bbox_pred.bias' to the model due to incompatible shapes: (320,) in the checkpoint but (4,) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.mask_head.predictor.weight' to the model due to incompatible shapes: (80, 256, 1, 1) in the checkpoint but (1, 256, 1, 1) in

[32m[01/03 20:36:15 d2.engine.train_loop]: [0mStarting training from iteration 0


  torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks])
  torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks])
  torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks])
  torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks])


[32m[01/03 20:36:43 d2.utils.events]: [0m eta: 0:22:12  iter: 19  total_loss: 3.098  loss_cls: 0.7635  loss_box_reg: 0.173  loss_mask: 0.6899  loss_rpn_cls: 1.36  loss_rpn_loc: 0.1042    time: 1.3558  last_time: 1.3052  data_time: 0.0920  last_data_time: 0.0042   lr: 4.9953e-06  max_mem: 9964M
[32m[01/03 20:37:10 d2.utils.events]: [0m eta: 0:21:45  iter: 39  total_loss: 2.167  loss_cls: 0.6211  loss_box_reg: 0.1766  loss_mask: 0.6845  loss_rpn_cls: 0.5902  loss_rpn_loc: 0.05966    time: 1.3477  last_time: 1.4934  data_time: 0.0052  last_data_time: 0.0046   lr: 9.9902e-06  max_mem: 9964M
[32m[01/03 20:37:38 d2.utils.events]: [0m eta: 0:21:21  iter: 59  total_loss: 1.758  loss_cls: 0.4639  loss_box_reg: 0.3179  loss_mask: 0.6737  loss_rpn_cls: 0.2908  loss_rpn_loc: 0.06431    time: 1.3544  last_time: 1.3812  data_time: 0.0082  last_data_time: 0.0048   lr: 1.4985e-05  max_mem: 9964M
[32m[01/03 20:38:04 d2.utils.events]: [0m eta: 0:20:52  iter: 79  total_loss: 1.604  loss_cls: 0.38

Skip loading parameter 'roi_heads.box_predictor.cls_score.weight' to the model due to incompatible shapes: (81, 2048) in the checkpoint but (2, 2048) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.box_predictor.cls_score.bias' to the model due to incompatible shapes: (81,) in the checkpoint but (2,) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.box_predictor.bbox_pred.weight' to the model due to incompatible shapes: (320, 2048) in the checkpoint but (4, 2048) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.box_predictor.bbox_pred.bias' to the model due to incompatible shapes: (320,) in the checkpoint but (4,) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.mask_head.predictor.weight' to the model due to incompatible shapes: (80, 256, 1, 1) in the checkpoint but (1, 256, 1, 1) in

[32m[01/03 20:59:27 d2.engine.train_loop]: [0mStarting training from iteration 0


  torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks])
  torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks])
  torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks])
  torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks])


[32m[01/03 20:59:48 d2.utils.events]: [0m eta: 0:14:58  iter: 19  total_loss: 2.528  loss_cls: 0.731  loss_box_reg: 0.4247  loss_mask: 0.6913  loss_rpn_cls: 0.6863  loss_rpn_loc: 0.09599    time: 0.9382  last_time: 0.9118  data_time: 0.1506  last_data_time: 0.0045   lr: 4.9953e-06  max_mem: 9964M
[32m[01/03 21:00:07 d2.utils.events]: [0m eta: 0:14:44  iter: 39  total_loss: 2.482  loss_cls: 0.6229  loss_box_reg: 0.3478  loss_mask: 0.6845  loss_rpn_cls: 0.7469  loss_rpn_loc: 0.09691    time: 0.9429  last_time: 0.9296  data_time: 0.0365  last_data_time: 0.0040   lr: 9.9902e-06  max_mem: 9964M
[32m[01/03 21:00:26 d2.utils.events]: [0m eta: 0:14:30  iter: 59  total_loss: 2.254  loss_cls: 0.4737  loss_box_reg: 0.3626  loss_mask: 0.6722  loss_rpn_cls: 0.5972  loss_rpn_loc: 0.09126    time: 0.9473  last_time: 0.8912  data_time: 0.0382  last_data_time: 0.0048   lr: 1.4985e-05  max_mem: 9964M
[32m[01/03 21:00:46 d2.utils.events]: [0m eta: 0:14:19  iter: 79  total_loss: 2.222  loss_cls: 0

Skip loading parameter 'roi_heads.box_predictor.cls_score.weight' to the model due to incompatible shapes: (81, 1024) in the checkpoint but (2, 1024) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.box_predictor.cls_score.bias' to the model due to incompatible shapes: (81,) in the checkpoint but (2,) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.box_predictor.bbox_pred.weight' to the model due to incompatible shapes: (320, 1024) in the checkpoint but (4, 1024) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.box_predictor.bbox_pred.bias' to the model due to incompatible shapes: (320,) in the checkpoint but (4,) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.mask_head.predictor.weight' to the model due to incompatible shapes: (80, 256, 1, 1) in the checkpoint but (1, 256, 1, 1) in

[32m[01/03 21:15:27 d2.engine.train_loop]: [0mStarting training from iteration 0


  torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks])
  torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks])
  torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks])
  torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks])


[32m[01/03 21:15:47 d2.utils.events]: [0m eta: 0:14:24  iter: 19  total_loss: 3.21  loss_cls: 0.5947  loss_box_reg: 0.2539  loss_mask: 0.6919  loss_rpn_cls: 1.58  loss_rpn_loc: 0.1678    time: 0.9393  last_time: 0.7666  data_time: 0.2165  last_data_time: 0.0149   lr: 4.9953e-06  max_mem: 9964M
[32m[01/03 21:16:04 d2.utils.events]: [0m eta: 0:13:32  iter: 39  total_loss: 2.682  loss_cls: 0.5324  loss_box_reg: 0.2477  loss_mask: 0.6855  loss_rpn_cls: 1.093  loss_rpn_loc: 0.1007    time: 0.8879  last_time: 0.8666  data_time: 0.0602  last_data_time: 0.0994   lr: 9.9902e-06  max_mem: 9964M
[32m[01/03 21:16:21 d2.utils.events]: [0m eta: 0:13:25  iter: 59  total_loss: 2.264  loss_cls: 0.459  loss_box_reg: 0.4099  loss_mask: 0.6698  loss_rpn_cls: 0.5898  loss_rpn_loc: 0.09267    time: 0.8807  last_time: 0.8479  data_time: 0.0290  last_data_time: 0.0653   lr: 1.4985e-05  max_mem: 9964M
[32m[01/03 21:16:39 d2.utils.events]: [0m eta: 0:13:18  iter: 79  total_loss: 2.22  loss_cls: 0.4223  