## Train Mask R-CNN model for detecting boulders
*(Updated: 2023-04-13)*

In [1]:
%%HTML
<!-- Display full window: widen the notebook, menu bar and toolbar containers.
     The %%HTML cell magic has to be the first line of the cell, so this note
     is an HTML comment placed after it rather than a Python comment above it. -->
<style>
    div#notebook-container    { width: 95%; }
    div#menubar-container     { width: 65%; }
    div#maintoolbar-container { width: 99%; }
</style>

In [1]:
# Load Detectron2 and related modules
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()

import numpy as np
import cv2
import random
import os
import matplotlib.pyplot as plt
%matplotlib inline

from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog
from detectron2.data.catalog import DatasetCatalog

In [3]:
# Load training images annotated in COCO format (Note: datasets have already been augmented)
from detectron2.data.datasets import register_coco_instances
BASE_PATH = "Drirectory path of your datasets"

# One row per split: (dataset name, annotation file, image directory).
# Both paths are resolved relative to BASE_PATH.
GRAVEL_SPLITS = (
    ("gravel_train", "train_annotations.json",      "train"),
    ("gravel_valid", "validation_annotations.json", "valid"),
    ("gravel_test",  "testing_annotations.json",    "test"),
)

# Registering a name that is already in the catalog fails, so names that are
# already present are skipped. This keeps the cell safe to re-run on a live
# kernel. After changing BASE_PATH, restart the kernel so the new paths are
# actually registered.
for split_name, annotation_file, image_dir in GRAVEL_SPLITS:
    if split_name in DatasetCatalog.list():
        continue
    register_coco_instances(
        split_name,
        {},
        os.path.join(BASE_PATH, annotation_file),
        os.path.join(BASE_PATH, image_dir),
    )

In [None]:
# Get metadata and the list of records of the training dataset
# (random, plt and Visualizer are imported in the first code cell)
gravel_metadata = MetadataCatalog.get("gravel_train")
dataset_dicts = DatasetCatalog.get("gravel_train")

# Show up to 10 random images with their annotations from the training dataset.
# The sample size is clamped so that a dataset with fewer images does not make
# random.sample() raise ValueError.
NUM_TRAIN_PREVIEW = 10
for d in random.sample(dataset_dicts, min(NUM_TRAIN_PREVIEW, len(dataset_dicts))):
    print(d["file_name"])
    img = cv2.imread(d["file_name"])
    if img is None:
        # cv2.imread() returns None instead of raising when the file is
        # missing or cannot be decoded.
        raise FileNotFoundError("Could not read image: " + str(d["file_name"]))
    # OpenCV loads BGR; the Visualizer expects RGB, hence the channel flip.
    visualizer = Visualizer(img[:, :, ::-1],
                            metadata=gravel_metadata,
                            scale=1.0
                           )
    out = visualizer.draw_dataset_dict(d)
    plt.figure(figsize=(8,8))
    plt.imshow(out.get_image())
    # Render now so each figure appears right after its printed file name.
    plt.show()

In [5]:
# Define trainer
import detectron2.data.transforms as T
from detectron2.data import DatasetMapper, build_detection_train_loader
# Both names are needed to define the class below; importing them here keeps
# the cell runnable on a fresh kernel (Restart & Run All).
from detectron2.engine import DefaultTrainer
from detectron2.evaluation import COCOEvaluator


class NonAugmentationsTrainer(DefaultTrainer):
    """DefaultTrainer that evaluates with a COCOEvaluator.

    NOTE(review): only ``build_evaluator`` is overridden here, so data loading
    is whatever ``DefaultTrainer`` provides by default. Confirm that this
    matches the "non-augmentation" intent suggested by the class name.
    """

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        """Build the evaluator used for the periodic evaluation.

        Args:
            cfg: Detectron2 config node, passed through to the evaluator.
            dataset_name: Name of the registered dataset to evaluate on.
            output_folder: Directory for the evaluation output. When None,
                ``<BASE_PATH>/coco_eval`` is created (if needed) and used.

        Returns:
            A COCOEvaluator for ``dataset_name`` with distributed mode off.
        """
        if output_folder is None:
            output_folder = os.path.join(BASE_PATH, "coco_eval")
            os.makedirs(output_folder, exist_ok=True)

        return COCOEvaluator(dataset_name, cfg, False, output_folder)

In [None]:
# Train
from detectron2.engine import DefaultTrainer
model_path = os.path.join(BASE_PATH,'model')

# Set up config file: start from the COCO Mask R-CNN R50-FPN 3x baseline
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.DATASETS.TRAIN = ("gravel_train",)
cfg.DATASETS.TEST  = ("gravel_valid",)
# Same size limits (2000) for training and testing
cfg.INPUT.MAX_SIZE_TRAIN=2000
cfg.INPUT.MAX_SIZE_TEST =2000
cfg.INPUT.MIN_SIZE_TRAIN=2000
cfg.INPUT.MIN_SIZE_TEST =2000

cfg.DATALOADER.NUM_WORKERS = 4
# Keep images that have no annotations in the training set
cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS = False

cfg.SOLVER.IMS_PER_BATCH = 2
cfg.SOLVER.BASE_LR = 0.00025
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")  # From model zoo
#cfg.MODEL.WEIGHTS = os.path.join(model_path, "model_0199999.pth") # From local

# Count the training images straight from the catalog. The notebook-level
# `dataset_dicts` variable is rebound to the test set by a later cell, so
# relying on it would silently change the schedule when this cell is re-run.
NUM_IMAGES = len(DatasetCatalog.get("gravel_train"))
# Iterations needed to see every training image once
NUM_IM_EPOCH  = round(NUM_IMAGES/cfg.SOLVER.IMS_PER_BATCH)
cfg.SOLVER.MAX_ITER = NUM_IMAGES * 15
cfg.SOLVER.CHECKPOINT_PERIOD = 5000
# Iterations at which the learning rate is multiplied by GAMMA
cfg.SOLVER.STEPS = (NUM_IMAGES * 2, NUM_IMAGES * 4, NUM_IMAGES * 6 , NUM_IMAGES * 8)

cfg.SOLVER.GAMMA = 0.1
cfg.TEST.DETECTIONS_PER_IMAGE = 400
cfg.OUTPUT_DIR = model_path
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 256
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1

# Evaluate on cfg.DATASETS.TEST once per epoch
cfg.TEST.EVAL_PERIOD = NUM_IM_EPOCH * 1

print('Number of images: ' + str(NUM_IMAGES))
print('Max iteration:' + str(cfg.SOLVER.MAX_ITER))

# Apply training (resume=False: start from cfg.MODEL.WEIGHTS rather than
# from the last checkpoint in cfg.OUTPUT_DIR)
trainer = NonAugmentationsTrainer(cfg)
trainer.resume_or_load(resume=False)
trainer.train()

# Note (presumably the object-area ranges, in pixels^2, behind the
# small/medium/large metrics reported by the COCO evaluation):
# small:  [0**2 32**2]
# medium: [32**2 96**2]
# large:  [96**2 100000**2]

In [10]:
# Show the training curves in tensorboard:
%load_ext tensorboard
%tensorboard --logdir /home/kei/ドキュメント/data/Prj_Gravel/Prj_gravel20230405/model

# Show training results

In [7]:
# Build a predictor from the weights written at the end of training
trained_weights_path = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
cfg.MODEL.WEIGHTS = trained_weights_path       # path to the trained model weight
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7    # custom score threshold applied at test time
predictor = DefaultPredictor(cfg)

In [8]:
# Register testing image dataset.
# Guarded so that re-running the cell does not fail on an already-registered name.
# NOTE(review): the evaluation and visualisation cells below read "gravel_test",
# not "gravel_test3"; confirm which test annotations are meant to be used.
if "gravel_test3" not in DatasetCatalog.list():
    register_coco_instances(
        "gravel_test3",
        {},
        os.path.join(BASE_PATH, "test/_annotations.coco.json"),
        os.path.join(BASE_PATH, "test"),
    )

In [None]:
# Evaluation: run the trained model over the test split and report COCO metrics
from detectron2.data import build_detection_test_loader
from detectron2.evaluation import COCOEvaluator, inference_on_dataset

EVAL_DATASET_NAME = "gravel_test"

evaluator = COCOEvaluator(
    EVAL_DATASET_NAME,
    cfg,
    False,
    output_dir="path to output directory",
)
val_loader = build_detection_test_loader(cfg, EVAL_DATASET_NAME)

# The metrics returned by the evaluation are the cell's displayed output.
eval_results = inference_on_dataset(predictor.model, val_loader, evaluator)
eval_results

In [None]:
# Apply the model to some images
from detectron2.utils.visualizer import ColorMode

# Use a dedicated name: rebinding `dataset_dicts` here would replace the
# training records loaded earlier in the notebook.
test_dataset_dicts = DatasetCatalog.get("gravel_test")

# Show predictions on up to 20 random test images. The sample size is clamped
# so that a smaller test set does not make random.sample() raise ValueError.
NUM_PREDICTION_PREVIEW = 20
for d in random.sample(test_dataset_dicts, min(NUM_PREDICTION_PREVIEW, len(test_dataset_dicts))):
    im = cv2.imread(d["file_name"])
    if im is None:
        # cv2.imread() returns None instead of raising when the file is
        # missing or cannot be decoded.
        raise FileNotFoundError("Could not read image: " + str(d["file_name"]))
    outputs = predictor(im)  # format is documented at https://detectron2.readthedocs.io/tutorials/models.html#model-output-format
    # OpenCV loads BGR; the Visualizer expects RGB, hence the channel flip.
    v = Visualizer(im[:, :, ::-1],
                   metadata=gravel_metadata,
                   scale=1,
                   instance_mode=ColorMode.SEGMENTATION
                  )
    out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
    plt.figure(figsize=(20,20))
    plt.imshow(out.get_image())
    plt.show()

In [None]:
# PR plot: one precision-recall panel per entry of df_dict, laid out on a 2 x 3 grid
pr_figure = plt.figure(figsize=(15, 8))

for panel_number, class_key in enumerate(df_dict.keys(), start=1):
    pr_table = df_dict[class_key]
    recall = pr_table['rec']
    precision = pr_table['pre']

    panel = pr_figure.add_subplot(2, 3, panel_number)
    # Curve with markers, plus a translucent fill underneath it
    panel.plot(recall, precision, '-o')
    panel.fill_between(recall, precision, facecolor='b', alpha=0.3)

    panel.set_xlim([0, 1.05])
    panel.set_ylim([0.5, 1.03])
    panel.grid(True)

    panel.set_title(CAT_NAME_JP[class_key])
    panel.set_xlabel('Recall')
    panel.set_ylabel('Precision')

pr_figure.tight_layout()
plt.show()