
Training with a validation dataset #4368

Closed
hakespear opened this issue Jun 28, 2022 · 4 comments
Labels
documentation Problems about existing documentation or comments

Comments

@hakespear

📚 Documentation Issue

Hi everyone,

I'm struggling to understand how detectron2's DefaultTrainer is supposed to handle a validation set. Since I just want to do basic testing on a custom dataset, I mostly looked for a way to insert a validation set into train_net.py rather than studying hooks or plain_train_net.py. That way I could see when the model starts overfitting from the validation losses and stop training accordingly.
So I found this training script example: https://github.com/facebookresearch/detectron2/blob/main/tools/train_net.py
and eventually came up with this script to add a validation set alongside the training set:

#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
"""
A main training script.

This script reads a given config file and runs the training or evaluation.
It is an entry point that is made to train standard models in detectron2.

In order to let one script support training of many models,
this script contains logic that is specific to these built-in models and therefore
may not be suitable for your own project.
For example, your research project perhaps only needs a single "evaluator".

Therefore, we recommend you use detectron2 as a library and take
this file as an example of how to use the library.
You may want to write your own script with your datasets and other customizations.
"""

import logging
import os
from collections import OrderedDict
import torch

import detectron2.utils.comm as comm
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog, Metadata
from detectron2.data.datasets import register_coco_instances
from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, hooks, launch
from detectron2.evaluation import (
    CityscapesInstanceEvaluator,
    CityscapesSemSegEvaluator,
    COCOEvaluator,
    COCOPanopticEvaluator,
    DatasetEvaluators,
    LVISEvaluator,
    PascalVOCDetectionEvaluator,
    SemSegEvaluator,
    verify_results,
)
from detectron2.modeling import GeneralizedRCNNWithTTA
from detectron2.data import DatasetCatalog
import random
from detectron2.utils.visualizer import Visualizer
import cv2


def build_evaluator(cfg, dataset_name, output_folder=None):
    """
    Create evaluator(s) for a given dataset.
    This uses the special metadata "evaluator_type" associated with each builtin dataset.
    For your own dataset, you can simply create an evaluator manually in your
    script and do not have to worry about the hacky if-else logic here.
    """
    if output_folder is None:
        output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
    evaluator_list = []
    evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
    if evaluator_type in ["sem_seg", "coco_panoptic_seg"]:
        evaluator_list.append(
            SemSegEvaluator(
                dataset_name,
                distributed=True,
                output_dir=output_folder,
            )
        )
    if evaluator_type in ["coco", "coco_panoptic_seg"]:
        evaluator_list.append(COCOEvaluator(dataset_name, output_dir=output_folder))
    if evaluator_type == "coco_panoptic_seg":
        evaluator_list.append(COCOPanopticEvaluator(dataset_name, output_folder))
    if evaluator_type == "cityscapes_instance":
        assert (
            torch.cuda.device_count() >= comm.get_rank()
        ), "CityscapesEvaluator currently does not work with multiple machines."
        return CityscapesInstanceEvaluator(dataset_name)
    if evaluator_type == "cityscapes_sem_seg":
        assert (
            torch.cuda.device_count() >= comm.get_rank()
        ), "CityscapesEvaluator currently does not work with multiple machines."
        return CityscapesSemSegEvaluator(dataset_name)
    elif evaluator_type == "pascal_voc":
        return PascalVOCDetectionEvaluator(dataset_name)
    elif evaluator_type == "lvis":
        return LVISEvaluator(dataset_name, output_dir=output_folder)
    if len(evaluator_list) == 0:
        raise NotImplementedError(
            "no Evaluator for the dataset {} with the type {}".format(dataset_name, evaluator_type)
        )
    elif len(evaluator_list) == 1:
        return evaluator_list[0]
    return DatasetEvaluators(evaluator_list)


class Trainer(DefaultTrainer):
    """
    We use the "DefaultTrainer" which contains pre-defined default logic for
    standard training workflow. It may not work for you, especially if you
    are working on a new research project. In that case you can write your
    own training loop. You can use "tools/plain_train_net.py" as an example.
    """

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        return build_evaluator(cfg, dataset_name, output_folder)

    @classmethod
    def test_with_TTA(cls, cfg, model):
        logger = logging.getLogger("detectron2.trainer")
        # In the end of training, run an evaluation with TTA
        # Only support some R-CNN models.
        logger.info("Running inference with test-time augmentation ...")
        model = GeneralizedRCNNWithTTA(cfg, model)
        evaluators = [
            cls.build_evaluator(
                cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA")
            )
            for name in cfg.DATASETS.TEST
        ]
        res = cls.test(cfg, model, evaluators)
        res = OrderedDict({k + "_TTA": v for k, v in res.items()})
        return res


def setup(args):
    """
    Create configs and perform basic setups.
    """
    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    default_setup(cfg, args)
    return cfg


def main(args):
    cfg = setup(args)

    if args.eval_only:
        model = Trainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume
        )
        res = Trainer.test(cfg, model)
        if cfg.TEST.AUG.ENABLED:
            res.update(Trainer.test_with_TTA(cfg, model))
        if comm.is_main_process():
            verify_results(cfg, res)
        return res

    """
    If you'd like to do anything fancier than the standard training logic,
    consider writing your own training loop (see plain_train_net.py) or
    subclassing the trainer.
    """
    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks(
            [hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))]
        )
    return trainer.train()


if __name__ == "__main__":
    # os.system('export ...') runs in a subshell and does not affect this process;
    # set the environment variable directly instead.
    os.environ["DETECTRON2_DATASETS"] = "/mnt/RAM_disk/detectron2/datasets/hakespear"
    register_coco_instances("coco_macha_train", {},
                            "/mnt/RAM_disk/detectron2/datasets/hakespear/Train.json",
                            "/mnt/RAM_disk/detectron2/datasets/hakespear/Train/")
    register_coco_instances("coco_macha_val", {},
                            "/mnt/RAM_disk/detectron2/datasets/hakespear/Val.json",
                            "/mnt/RAM_disk/detecrton2/datasets/hakespear/Val/")
    coco_macha_metadata_train = MetadataCatalog.get("coco_macha_train").set(thing_classes=["x"])
    coco_macha_metadata_val = MetadataCatalog.get("coco_macha_val").set(thing_classes=["x"])
    args = default_argument_parser().parse_args()

    cfg = get_cfg()
    cfg.merge_from_file(
        #"../configs/COCO-InstanceSegmentation/mask_rcnn_regnety.yaml"
        "../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml" # choiced
        #"../configs/COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml"
        #"../configs/COCO-Keypoints/keypoint_rcnn_R_50_FPN_3x.yaml"
        #"../configs/new_baselines/mask_rcnn_regnety_4gf_dds_FPN_400ep_LSJ.py"
    )

    cfg.MODEL.WEIGHTS = "detectron2://COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x/137849600/model_final_f10217.pkl"  # initialize from model zoo

    cfg.DATASETS.TRAIN = ("coco_macha_train",)
    cfg.DATASETS.TEST = ("coco_macha_val",)  # no metrics implemented for this dataset

    cfg.DATALOADER.NUM_WORKERS = 1

    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
    cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128

    cfg.SOLVER.MAX_ITER = 500
    cfg.SOLVER.BASE_LR = 0.0005
    cfg.SOLVER.CHECKPOINT_PERIOD = 100
    cfg.SOLVER.IMS_PER_BATCH = 2

    cfg.TEST.EVAL_PERIOD = 100

    cfg.OUTPUT_DIR = '/mnt/RAM_disk/detectron2/output/hakespear/'


    print('cfg.OUTPUT_DIR=', cfg.OUTPUT_DIR)
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
    trainer = DefaultTrainer(cfg)
    trainer.resume_or_load(resume=False)
    trainer.train()

    # Saving the model
    model_save_name = 'coco_macha_train.pth'
    path = F"/mnt/RAM_disk/detectron2/output/hakespear/{model_save_name}"
    torch.save(trainer.model.state_dict(), path)

By looking at the configs documentation (https://detectron2.readthedocs.io/en/latest/modules/config.html) I found that cfg.TEST.EVAL_PERIOD needs to be set in order to call the evaluator during training. Since it is a .TEST config, I understood that the validation set has to be put in cfg.DATASETS.TEST rather than cfg.DATASETS.TRAIN. However, that gave me the following warning:

[06/28 11:28:24 d2.data.datasets.coco]: Loaded 25 images in COCO format from /mnt/RAM_disk/detectron2/datasets/hakespear/Val.json
[06/28 11:28:24 d2.data.dataset_mapper]: [DatasetMapper] Augmentations used in inference: [ResizeShortestEdge(short_edge_length=(800, 800), max_size=1333, sample_style='choice')]
[06/28 11:28:24 d2.data.common]: Serializing 25 elements to byte tensors and concatenating them all ...
[06/28 11:28:24 d2.data.common]: Serialized dataset takes 0.01 MiB
WARNING [06/28 11:28:24 d2.engine.defaults]: No evaluator found. Use `DefaultTrainer.test(evaluators=)`, or implement its `build_evaluator` method.

This is where I misunderstand the docs. I looked at other basic training loops and tutorial examples but still could not understand how the validation set is taken into account. Is there anything obvious that I missed? Or did I understand the docs correctly, and something in them is a bit inaccurate?

Thanks for letting me know. Also sorry if I mislabeled this post.

@hakespear hakespear added the documentation Problems about existing documentation or comments label Jun 28, 2022
@crsegerie

Hi, I'm also stuck on this issue.
In the intro Colab, validation is only performed at the end of training, but I would like to perform it regularly during training. I do not understand the syntax.

Here is what I've tried in order to add evaluation during training, but it is currently not working:

evaluator = COCOEvaluator("dataset_val", output_dir=eval_dir)
trainer = DefaultTrainer(cfg)
trainer.resume_or_load(resume=False)
trainer.test(evaluators=evaluator, cfg=cfg, model= ? ) # <- I do not know what I could put here
res = trainer.train()

Thank you
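
One plausible way to fill in the missing argument, as a rough, untested sketch: DefaultTrainer.test is a classmethod that takes the config, a model, and optionally a list of evaluators, so the already-built trainer.model can be passed after training. Note that this runs evaluation only once; periodic evaluation during training needs cfg.TEST.EVAL_PERIOD together with a build_evaluator override, as in the answers below.

from detectron2.engine import DefaultTrainer
from detectron2.evaluation import COCOEvaluator

# Rough sketch (untested); assumes cfg.DATASETS.TEST contains only "dataset_val"
# and that eval_dir is defined as in the snippet above.
evaluator = COCOEvaluator("dataset_val", output_dir=eval_dir)
trainer = DefaultTrainer(cfg)
trainer.resume_or_load(resume=False)
trainer.train()
res = DefaultTrainer.test(cfg, trainer.model, evaluators=[evaluator])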

@Robotatron

Would like to know that too. Example needed :)

@aymanaboghonim

Here is my code; it runs evaluation after a predefined number of iterations:
class MyTrainer(DefaultTrainer):
    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
        return COCOEvaluator(dataset_name, cfg, True, output_folder)

Then instantiate an object from my trainer class and enjoy.

cfg.TEST.EVAL_PERIOD = 5000
Change it to evaluate after your preferred number of iterations.
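
A minimal usage sketch of the steps described above, assuming the datasets and the rest of cfg are already set up as in the original post:

cfg.TEST.EVAL_PERIOD = 5000            # run the evaluator every 5000 iterations
trainer = MyTrainer(cfg)               # MyTrainer, not DefaultTrainer, so build_evaluator is used
trainer.resume_or_load(resume=False)
trainer.train()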

@wiktorowski211

wiktorowski211 commented Jan 11, 2023

Hi @hakespear, all you have to do is define your own Trainer subclass in which you define the list of evaluators that will be used during training:

class MyTrainer(DefaultTrainer):
    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        coco_evaluator = COCOEvaluator(dataset_name, output_dir=output_folder)
        
        evaluator_list = [coco_evaluator]
        
        return DatasetEvaluators(evaluator_list)

Set the evaluation interval and threshold:

cfg.TEST.EVAL_PERIOD =  1000
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.50

Train with:

trainer = MyTrainer(cfg) 
trainer.resume_or_load(resume=False)
trainer.train()

Note the use of MyTrainer instead of DefaultTrainer.
With this setting, the test set will be evaluated every 1000 iterations.

Run a post-training test with:

cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.50

trainer = MyTrainer(cfg) 
trainer.test(cfg, trainer.model)
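
For the original question about watching validation losses (rather than COCO AP) to spot overfitting, here is a rough, untested sketch of a hook-based alternative. ValidationLossHook is a hypothetical name, not part of detectron2, and it reuses a training-style data loader pointed at the validation set, so training augmentations are applied to the validation images:

import torch
import detectron2.utils.comm as comm
from detectron2.data import build_detection_train_loader
from detectron2.engine import HookBase


class ValidationLossHook(HookBase):
    """Every `eval_period` iterations, push one validation batch through the
    model (still in training mode, so it returns a loss dict) and log the sum."""

    def __init__(self, cfg, eval_period):
        self._period = eval_period
        val_cfg = cfg.clone()
        val_cfg.defrost()
        # Reuse the training-style loader, but point it at the validation set.
        val_cfg.DATASETS.TRAIN = cfg.DATASETS.TEST
        self._loader = iter(build_detection_train_loader(val_cfg))

    def after_step(self):
        next_iter = self.trainer.iter + 1
        if self._period > 0 and next_iter % self._period == 0:
            data = next(self._loader)
            with torch.no_grad():
                loss_dict = self.trainer.model(data)
                total_loss = sum(loss_dict.values())
            if comm.is_main_process():
                self.trainer.storage.put_scalar("validation_loss", total_loss.item())


# Usage sketch:
# trainer = MyTrainer(cfg)
# trainer.register_hooks([ValidationLossHook(cfg, cfg.TEST.EVAL_PERIOD)])
# trainer.resume_or_load(resume=False)
# trainer.train()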

@facebookresearch facebookresearch locked and limited conversation to collaborators Feb 11, 2023
@ppwwyyxx ppwwyyxx converted this issue into discussion #4788 Feb 11, 2023

This issue was moved to a discussion.

You can continue the conversation there. Go to discussion →
