# Setup Environment


Import required packages:

In [None]:
# Setup detectron2
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()

import numpy as np
import os, json, random, shutil

import torch

from detectron2 import model_zoo
from detectron2.engine import DefaultTrainer, DefaultPredictor, HookBase, launch
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog, DatasetCatalog, build_detection_train_loader
from detectron2.data.datasets import register_coco_instances
import detectron2.utils.comm as comm

Global variables and settings:


In [None]:
# --- Dataset and output locations -------------------------------------------
path_train = "/thecube/students/jravagli/datasets/train"
path_output_dir = "/thecube/students/jravagli/outputs/detectron"

# Annotation files (COCO-style JSON) for the two splits.
path_train_json = os.path.join(path_train, "train.json")
path_val_json = os.path.join(path_train, "valid.json")

# Image folders for the two splits.
# NOTE(review): the validation split points at the same "image" folder under
# path_train as the training split -- confirm this is intentional.
path_train_images = os.path.join(path_train, "image")
path_val_images = os.path.join(path_train, "image")

# --- Optimisation hyper-parameters ------------------------------------------
lr = 0.02
weight_decay = 1e-5
batch_size = 8

# --- Dataset facts ----------------------------------------------------------
n_train_images = 163_173  # Number of training images
n_classes = 13  # Number of classes of the training set

# --- Training length --------------------------------------------------------
# The iteration budget is chosen so that the model goes through the whole
# training set `epochs` times (one iteration consumes `batch_size` images).
epochs = 12
iterations = epochs * n_train_images // batch_size

# Iterations at which the LR is multiplied by the gamma factor: the ends of
# epoch 8 and epoch 11.
scheduler_steps = tuple(
    milestone * n_train_images // batch_size for milestone in (8, 11)
)

# Set to True to continue a previous run instead of starting from scratch.
resume_training = False

Clear output directory:

In [None]:
# A fresh (non-resumed) run starts from an empty output directory so that
# stale checkpoints and logs from earlier runs are not picked up.
if not resume_training:
    try:
        shutil.rmtree(path_output_dir)
    except FileNotFoundError:
        # Nothing to clear: the directory has not been created yet (e.g. on
        # the very first run). Any other failure (permissions, ...) still
        # propagates.
        pass

# Training


Define a hook to monitor the validation loss during training ([GitHub issue](https://github.com/facebookresearch/detectron2/issues/810#issuecomment-596194293)):

In [None]:
class ValidationLoss(HookBase):
    """Trainer hook that logs the model's losses on validation batches.

    The hook keeps its own copy of the configuration in which the training
    dataset is replaced by ``cfg.DATASETS.VAL``, and draws batches from a
    loader built with ``build_detection_train_loader`` on that copy.
    """

    def __init__(self, cfg):
        """Build the validation batch iterator from a clone of *cfg*.

        Args:
            cfg: configuration node; it must define ``DATASETS.VAL``. The
                caller's object is not modified, only the clone is.
        """
        super().__init__()
        val_cfg = cfg.clone()
        # Point the "train" loader at the validation split.
        val_cfg.DATASETS.TRAIN = cfg.DATASETS.VAL
        self.cfg = val_cfg
        self._loader = iter(build_detection_train_loader(val_cfg))

    def after_step(self):
        """Compute the losses on one validation batch and store them.

        Each loss is stored under its original name prefixed with ``val_``,
        and their sum is stored as ``total_val_loss``.
        """
        batch = next(self._loader)
        # No gradients are needed: the losses are only logged.
        with torch.no_grad():
            raw_losses = self.trainer.model(batch)

            total = sum(raw_losses.values())
            assert torch.isfinite(total).all(), raw_losses

            # reduce_dict is called on every process, before the
            # main-process check below.
            reduced = comm.reduce_dict(raw_losses)
            val_losses = {
                "val_" + name: value.item() for name, value in reduced.items()
            }
            total_val = sum(val_losses.values())

            # Only the main process writes to the trainer's storage.
            if not comm.is_main_process():
                return
            self.trainer.storage.put_scalars(
                total_val_loss=total_val, **val_losses
            )

Register the dataset:

In [None]:
# Register both splits under the names used by the model configuration.
# Each entry is (dataset name, annotation JSON, image folder); the empty dict
# passed to register_coco_instances is the (empty) extra metadata.
for dataset_name, annotations_json, images_dir in (
    ("deepfashion_train", path_train_json, path_train_images),
    ("deepfashion_val", path_val_json, path_val_images),
):
    register_coco_instances(dataset_name, {}, annotations_json, images_dir)

Define model configuration:

In [None]:
# Start from the default configuration and merge in the Mask R-CNN
# (ResNet-50 + FPN, 3x schedule) settings from the model zoo.
cfg = get_cfg()
cfg.merge_from_file(
    model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
)

# Datasets: VAL is an extra key read by the ValidationLoss hook; no test set.
cfg.DATASETS.TRAIN = ("deepfashion_train",)
cfg.DATASETS.VAL = ("deepfashion_val",)
cfg.DATASETS.TEST = ()
cfg.DATALOADER.NUM_WORKERS = 2

# Initial weights: none when resuming a previous run, otherwise the
# model-zoo checkpoint matching the configuration file above.
cfg.MODEL.WEIGHTS = (
    None
    if resume_training
    else model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
)

# Model head settings.
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128  # RoI batch size
cfg.MODEL.ROI_HEADS.NUM_CLASSES = n_classes

# Solver: batch size and number of batch updates ...
cfg.SOLVER.IMS_PER_BATCH = batch_size
cfg.SOLVER.MAX_ITER = iterations
# ... optimiser hyper-parameters ...
cfg.SOLVER.BASE_LR = lr
cfg.SOLVER.MOMENTUM = 0.9
cfg.SOLVER.WEIGHT_DECAY = weight_decay
# ... and LR schedule: the learning rate is decreased by GAMMA at each
# iteration number listed in STEPS.
cfg.SOLVER.GAMMA = 0.1
cfg.SOLVER.STEPS = scheduler_steps

cfg.OUTPUT_DIR = path_output_dir

Create the trainer and attach the validation hook:

In [None]:
# The output directory must exist before the trainer is created.
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
trainer = DefaultTrainer(cfg)

# Attach the validation-loss hook; register_hooks appends it at the end.
val_loss = ValidationLoss(cfg)
trainer.register_hooks([val_loss])

# Swap the last two hooks (PeriodicWriter and ValidationLoss) so that the
# validation hook comes before the writer.
leading_hooks = trainer._hooks[:-2]
last_two_hooks = trainer._hooks[-2:]
trainer._hooks = leading_hooks + list(reversed(last_two_hooks))

Train the model:

In [None]:
# Load the starting state for the trainer; `resume_training` selects between
# continuing a previous run and starting a new one.
# NOTE(review): per the detectron2 docs, resume=True picks up the last
# checkpoint in cfg.OUTPUT_DIR if one exists, otherwise cfg.MODEL.WEIGHTS is
# loaded -- confirm against the installed version.
trainer.resume_or_load(resume=resume_training)
# Run the training loop.
trainer.train()