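# Training

Train a detectron2 Mask R-CNN (model-zoo `mask_rcnn_R_50_FPN_3x` config) on the registered `train` split. Evaluation runs on `val` with COCO metrics plus custom validation-loss, prediction-visualisation and confusion-matrix hooks.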

In [1]:
import os
from pathlib import Path


def find_project_root(start: Path, subdirs=("models", "notebooks")) -> Path:
    """Return the repository root given the notebook's working directory.

    Steps out of a trailing ``models`` folder, then out of a trailing
    ``notebooks`` folder, in that order. Whole folder *names* are compared,
    not substrings of the full path. A checkout living under e.g.
    ``/home/models/...`` is therefore not walked past its own root, and
    re-running the cell from the root is a no-op.

    Args:
        start: directory to start from (normally ``Path.cwd()``).
        subdirs: folder names to step out of, checked in order.

    Returns:
        The resulting directory (``start`` itself if no name matched).
    """
    root = start
    for subdir in subdirs:
        if root.name == subdir:
            root = root.parent
    return root


os.chdir(find_project_root(Path.cwd()))
os.getcwd()

'/home/jordi/Documents/GitHub/zebra_fish'

In [2]:
from src.dataset import register_default_datasets
from src.LrFinder import LRFinder
from src.hooks.LossEvalHook import LossEvalHook
from src.hooks.PredictionVisualHook import PredictionVisualHook
from src.hooks.ConfusionHook import ConfusionHook

#detectron
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.utils.logger import setup_logger
from detectron2.data import DatasetCatalog, build_detection_test_loader, build_detection_train_loader, DatasetMapper, detection_utils as utils, transforms as T
from detectron2.utils.events import CommonMetricPrinter, JSONWriter, TensorboardXWriter
from detectron2.engine import DefaultTrainer
from detectron2.evaluation import COCOEvaluator
from detectron2.config import get_cfg
from detectron2 import model_zoo
from detectron2.engine import CallbackHook
from detectron2.engine.hooks import HookBase
from detectron2.utils.events import EventStorage
import detectron2.utils.comm as comm
from detectron2.utils.logger import log_every_n_seconds
from detectron2.utils.visualizer import Visualizer
import matplotlib.pyplot as plt
from pathlib import Path

import time
import datetime
import torch
import numpy as np
import gc

# Register the project's datasets with detectron2's catalogs. This presumably
# includes the "train"/"val" names referenced by the config cell below; confirm
# in src.dataset.
register_default_datasets()

In [3]:
class CustomTrainer(DefaultTrainer):
    """detectron2 ``DefaultTrainer`` customised for this project.

    Compared with the default trainer it:

    * evaluates with ``COCOEvaluator``;
    * registers extra hooks: validation loss (``LossEvalHook``), prediction
      visualisation (``PredictionVisualHook``) and a confusion matrix
      (``ConfusionHook``);
    * trains with resize + random rotation + horizontal-flip augmentation
      (see ``build_train_augmentations``);
    * uses a one-cycle LR schedule peaking at ``cfg.SOLVER.BASE_LR``;
    * offers ``find_lr`` to run an LR range test with ``LRFinder``.
    """

    # Threshold passed to ConfusionHook (presumably a score threshold).
    CONFUSION_THRESHOLD = 0.75

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        """Build a COCO evaluator for ``dataset_name``.

        When ``output_folder`` is omitted, results are written to
        ``<cfg.OUTPUT_DIR>/inference``, which is created if missing.
        """
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
            os.makedirs(output_folder, exist_ok=True)
        return COCOEvaluator(dataset_name, cfg, True, output_folder)

    def build_hooks(self):
        """Return the default hooks plus the project's evaluation hooks.

        Each extra hook is inserted just before the last default hook, keeping
        the extras in the order listed below. The last hook is presumably
        detectron2's PeriodicWriter, so metrics the extras record get written
        out.
        """
        hooks = super().build_hooks()
        extra_hooks = [
            LossEvalHook.create(self.cfg),
            PredictionVisualHook.create(self.cfg),
            ConfusionHook.create(self.cfg, threshold=self.CONFUSION_THRESHOLD),
        ]
        for hook in extra_hooks:
            hooks.insert(-1, hook)
        return hooks

    @classmethod
    def build_train_augmentations(cls, cfg):
        """Return the list of augmentations applied to training images."""
        return [
            T.ResizeShortestEdge(
                short_edge_length=cfg.INPUT.MIN_SIZE_TRAIN,
                max_size=cfg.INPUT.MAX_SIZE_TRAIN,
                sample_style=cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING,
            ),
            # Uniform full-circle rotation. The upper bound is 360 so that no
            # angle is sampled twice (365 would double-weight 0-5 degrees).
            # The centre is sampled from the middle 10% of the image, and
            # expand=False keeps the original canvas size.
            T.RandomRotation(
                [0, 360],
                expand=False,
                center=[[0.45, 0.45], [0.55, 0.55]],
                sample_style="range",
            ),
            T.RandomFlip(prob=0.5, horizontal=True),
        ]

    @classmethod
    def build_train_loader(cls, cfg):
        """Build the training loader using ``build_train_augmentations``."""
        mapper = DatasetMapper(
            is_train=True,
            augmentations=cls.build_train_augmentations(cfg),
            image_format=cfg.INPUT.FORMAT,
            use_instance_mask=cfg.MODEL.MASK_ON,
            instance_mask_format=cfg.INPUT.MASK_FORMAT,
            use_keypoint=cfg.MODEL.KEYPOINT_ON,
            # Re-derive tight boxes from the (rotated) masks. DatasetMapper
            # only allows this when instance masks are loaded, so it is tied
            # to MASK_ON instead of being hard-coded to True.
            recompute_boxes=cfg.MODEL.MASK_ON,
        )
        return build_detection_train_loader(cfg, mapper=mapper)

    @classmethod
    def build_lr_scheduler(cls, cfg, optimizer):
        """One-cycle LR schedule over the whole run.

        The LR peaks at ``cfg.SOLVER.BASE_LR``, with ``total_steps`` set to
        ``cfg.SOLVER.MAX_ITER``. All other ``OneCycleLR`` settings are
        PyTorch's defaults. NOTE(review): this assumes the scheduler is stepped
        once per iteration; OneCycleLR raises if stepped past ``total_steps``.
        """
        return torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=cfg.SOLVER.BASE_LR,
            total_steps=cfg.SOLVER.MAX_ITER,
        )

    def find_lr(self):
        """Run an LR range test on this trainer's model and optimizer.

        A fresh training loader is built for the sweep.

        Returns:
            Whatever ``LRFinder.find`` returns.

        NOTE(review): the sweep presumably updates ``self.model`` and
        ``self.optimizer`` in place. Re-create the trainer (or reload weights)
        before the real training run.
        """
        finder = LRFinder()
        # An active EventStorage is installed (and kept on self.storage),
        # presumably because detectron2's training-mode forward pass logs to
        # the current storage.
        with EventStorage(0) as self.storage:
            return finder.find(
                self.model,
                self.optimizer,
                self.build_train_loader(self.cfg),
            )

In [10]:
# Dataset names; these must be registered in detectron2's catalogs beforehand.
name = "train"
val_name = "val"
meta_dataset = MetadataCatalog.get(name)

# Default config: the model-zoo Mask R-CNN R50-FPN 3x recipe.
# NOTE(review): cfg.MODEL.WEIGHTS is never set here, so it keeps whatever the
# YAML provides. That is presumably an ImageNet-pretrained backbone, not the
# COCO-trained Mask R-CNN. If fine-tuning from COCO is intended, add:
#   cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(
#       "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
# The training split is derived from `name` so that NUM_CLASSES (read from its
# metadata) always matches the dataset actually trained on.
cfg.DATASETS.TRAIN = (name,)
cfg.DATASETS.TEST = (val_name,)
cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(meta_dataset.thing_classes)

# Run length and bookkeeping.
cfg.SOLVER.CHECKPOINT_PERIOD = 500
cfg.SOLVER.MAX_ITER = 100
cfg.TEST.EVAL_PERIOD = 20
cfg.OUTPUT_DIR = "./test2"
cfg.MODEL.BACKBONE.FREEZE_AT = 2

# Hyper-params. BASE_LR is also the peak LR of CustomTrainer's one-cycle
# schedule.
cfg.SOLVER.IMS_PER_BATCH = 4
cfg.SOLVER.BASE_LR = 0.0026738416158399486

# Evaluation must line up with the metric-writing interval. That interval is
# presumably detectron2's default PeriodicWriter period of 20 iterations.
METRIC_WRITE_PERIOD = 20
assert cfg.TEST.EVAL_PERIOD % METRIC_WRITE_PERIOD == 0, "EVAL_PERIOD must be a multiple of 20"

In [11]:
# Set resume=True to continue from the last checkpoint in cfg.OUTPUT_DIR.
resume = False
run_dir = Path(cfg.OUTPUT_DIR)
run_dir.mkdir(parents=True, exist_ok=True)
# Refuse to start a fresh run on top of an existing one, so that checkpoints
# and metrics of different runs are not mixed.
assert resume or not any(run_dir.iterdir()), "Output dir is not empty!"

trainer = CustomTrainer(cfg)
trainer.resume_or_load(resume=resume)

# Save the config the trainer actually holds (trainer.cfg, which its hooks
# read), not the notebook's `cfg`. The trainer may have adjusted its own copy.
(run_dir / "config.yaml").write_text(trainer.cfg.dump())

[32m[08/05 11:30:29 d2.engine.defaults]: [0mModel:
GeneralizedRCNN(
  (backbone): FPN(
    (fpn_lateral2): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral3): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral4): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral5): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output5): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (top_block): LastLevelMaxPool()
    (bottom_up): ResNet(
      (stem): BasicStem(
        (conv1): Conv2d(
          3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
          (norm): FrozenBatchNorm2d(num_features=64, eps=1e-05)
        )
      )
 

In [None]:
# Start training: up to cfg.SOLVER.MAX_ITER iterations, with CustomTrainer's
# hooks, evaluation every cfg.TEST.EVAL_PERIOD iterations, and outputs written
# to cfg.OUTPUT_DIR.
trainer.train()