-
Notifications
You must be signed in to change notification settings - Fork 7.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: so it's easier for users to customize the training process Reviewed By: rbgirshick Differential Revision: D18064883 fbshipit-source-id: 6fabed8c7e231f06da2c45111ab509b20906557f
- Loading branch information
1 parent
883b068
commit 1dd6d42
Showing
1 changed file
with
230 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,230 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | ||
""" | ||
Detectron2 training script with a plain training loop. | ||
This scripts reads a given config file and runs the training or evaluation. | ||
It is an entry point that is able to train standard models in detectron2. | ||
In order to let one script support training of many models, | ||
this script contains logic that are specific to these built-in models and therefore | ||
may not be suitable for your own project. | ||
For example, your research project perhaps only needs a single "evaluator". | ||
Therefore, we recommend you to use detectron2 as an library and take | ||
this file as an example of how to use the library. | ||
You may want to write your own script with your datasets and other customizations. | ||
Compared to "train_net.py", this script supports fewer features, and also | ||
includes fewer abstraction. | ||
""" | ||
|
||
import logging | ||
import os | ||
from collections import OrderedDict | ||
import torch | ||
from torch.nn.parallel import DistributedDataParallel | ||
|
||
import detectron2.utils.comm as comm | ||
from detectron2.checkpoint import DetectionCheckpointer, PeriodicCheckpointer | ||
from detectron2.config import get_cfg | ||
from detectron2.data import ( | ||
MetadataCatalog, | ||
build_detection_test_loader, | ||
build_detection_train_loader, | ||
) | ||
from detectron2.engine import default_argument_parser, default_setup, launch | ||
from detectron2.evaluation import ( | ||
CityscapesEvaluator, | ||
COCOEvaluator, | ||
COCOPanopticEvaluator, | ||
DatasetEvaluators, | ||
LVISEvaluator, | ||
PascalVOCDetectionEvaluator, | ||
SemSegEvaluator, | ||
inference_on_dataset, | ||
print_csv_format, | ||
) | ||
from detectron2.modeling import build_model | ||
from detectron2.solver import build_lr_scheduler, build_optimizer | ||
from detectron2.utils.events import ( | ||
CommonMetricPrinter, | ||
EventStorage, | ||
JSONWriter, | ||
TensorboardXWriter, | ||
) | ||
|
||
logger = logging.getLogger("detectron2") | ||
|
||
|
||
def get_evaluator(cfg, dataset_name, output_folder=None): | ||
""" | ||
Create evaluator(s) for a given dataset. | ||
This uses the special metadata "evaluator_type" associated with each builtin dataset. | ||
For your own dataset, you can simply create an evaluator manually in your | ||
script and do not have to worry about the hacky if-else logic here. | ||
""" | ||
if output_folder is None: | ||
output_folder = os.path.join(cfg.OUTPUT_DIR, "inference") | ||
evaluator_list = [] | ||
evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type | ||
if evaluator_type in ["sem_seg", "coco_panoptic_seg"]: | ||
evaluator_list.append( | ||
SemSegEvaluator( | ||
dataset_name, | ||
distributed=True, | ||
num_classes=cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES, | ||
ignore_label=cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE, | ||
output_dir=output_folder, | ||
) | ||
) | ||
if evaluator_type in ["coco", "coco_panoptic_seg"]: | ||
evaluator_list.append(COCOEvaluator(dataset_name, cfg, True, output_folder)) | ||
if evaluator_type == "coco_panoptic_seg": | ||
evaluator_list.append(COCOPanopticEvaluator(dataset_name, output_folder)) | ||
if evaluator_type == "cityscapes": | ||
assert ( | ||
torch.cuda.device_count() >= comm.get_rank() | ||
), "CityscapesEvaluator currently do not work with multiple machines." | ||
return CityscapesEvaluator(dataset_name) | ||
if evaluator_type == "pascal_voc": | ||
return PascalVOCDetectionEvaluator(dataset_name) | ||
if evaluator_type == "lvis": | ||
return LVISEvaluator(dataset_name, cfg, True, output_folder) | ||
if len(evaluator_list) == 0: | ||
raise NotImplementedError( | ||
"no Evaluator for the dataset {} with the type {}".format(dataset_name, evaluator_type) | ||
) | ||
if len(evaluator_list) == 1: | ||
return evaluator_list[0] | ||
return DatasetEvaluators(evaluator_list) | ||
|
||
|
||
def do_test(cfg, model): | ||
results = OrderedDict() | ||
for idx, dataset_name in enumerate(cfg.DATASETS.TEST): | ||
data_loader = build_detection_test_loader(cfg, dataset_name) | ||
evaluator = get_evaluator( | ||
cfg, dataset_name, os.path.join(cfg.OUTPUT_DIR, "inference", dataset_name) | ||
) | ||
results_i = inference_on_dataset(model, data_loader, evaluator) | ||
results[dataset_name] = results_i | ||
if comm.is_main_process(): | ||
logger.info("Evaluation results for {} in csv format:".format(dataset_name)) | ||
print_csv_format(results_i) | ||
if len(results) == 1: | ||
results = list(results.values())[0] | ||
return results | ||
|
||
|
||
def do_train(cfg, model, resume=False): | ||
model.train() | ||
optimizer = build_optimizer(cfg, model) | ||
scheduler = build_lr_scheduler(cfg, optimizer) | ||
|
||
checkpointer = DetectionCheckpointer( | ||
model, cfg.OUTPUT_DIR, optimizer=optimizer, scheduler=scheduler | ||
) | ||
start_iter = ( | ||
checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1 | ||
) | ||
max_iter = cfg.SOLVER.MAX_ITER | ||
|
||
periodic_checkpointer = PeriodicCheckpointer( | ||
checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter | ||
) | ||
|
||
writers = ( | ||
[ | ||
CommonMetricPrinter(max_iter), | ||
JSONWriter(os.path.join(cfg.OUTPUT_DIR, "metrics.json")), | ||
TensorboardXWriter(cfg.OUTPUT_DIR), | ||
] | ||
if comm.is_main_process() | ||
else [] | ||
) | ||
|
||
# compared to "train_net.py", we do not support accurate timing and | ||
# precise BN here, because they are not trivial to implement | ||
data_loader = build_detection_train_loader(cfg) | ||
logger.info("Starting training from iteration {}".format(start_iter)) | ||
with EventStorage(start_iter) as storage: | ||
for data, iteration in zip(data_loader, range(start_iter, max_iter)): | ||
iteration = iteration + 1 | ||
storage.step() | ||
|
||
loss_dict = model(data) | ||
losses = sum(loss for loss in loss_dict.values()) | ||
assert torch.isfinite(losses).all(), loss_dict | ||
|
||
loss_dict_reduced = {k: v.item() for k, v in comm.reduce_dict(loss_dict).items()} | ||
losses_reduced = sum(loss for loss in loss_dict_reduced.values()) | ||
if comm.is_main_process(): | ||
storage.put_scalars(total_loss=losses_reduced, **loss_dict_reduced) | ||
|
||
optimizer.zero_grad() | ||
losses.backward() | ||
optimizer.step() | ||
storage.put_scalar("lr", optimizer.param_groups[0]["lr"], smoothing_hint=False) | ||
scheduler.step() | ||
|
||
if ( | ||
cfg.TEST.EVAL_PERIOD > 0 | ||
and iteration % cfg.TEST.EVAL_PERIOD == 0 | ||
and iteration != max_iter | ||
): | ||
do_test(cfg, model) | ||
# Compared to "train_net.py", the test results are not dumped to EventStorage | ||
comm.synchronize() | ||
|
||
if iteration - start_iter > 5 and (iteration % 20 == 0 or iteration == max_iter): | ||
for writer in writers: | ||
writer.write() | ||
periodic_checkpointer.step(iteration) | ||
|
||
|
||
def setup(args): | ||
""" | ||
Create configs and perform basic setups. | ||
""" | ||
cfg = get_cfg() | ||
cfg.merge_from_file(args.config_file) | ||
cfg.merge_from_list(args.opts) | ||
cfg.freeze() | ||
default_setup( | ||
cfg, args | ||
) # if you don't like any of the default setup, write your own setup code | ||
return cfg | ||
|
||
|
||
def main(args): | ||
cfg = setup(args) | ||
|
||
model = build_model(cfg) | ||
logger.info("Model:\n{}".format(model)) | ||
if args.eval_only: | ||
DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( | ||
cfg.MODEL.WEIGHTS, resume=args.resume | ||
) | ||
return do_test(cfg, model) | ||
|
||
distributed = comm.get_world_size() > 1 | ||
if distributed: | ||
model = DistributedDataParallel( | ||
model, device_ids=[comm.get_local_rank()], broadcast_buffers=False | ||
) | ||
|
||
do_train(cfg, model) | ||
return do_test(cfg, model) | ||
|
||
|
||
if __name__ == "__main__": | ||
args = default_argument_parser().parse_args() | ||
print("Command Line Args:", args) | ||
launch( | ||
main, | ||
args.num_gpus, | ||
num_machines=args.num_machines, | ||
machine_rank=args.machine_rank, | ||
dist_url=args.dist_url, | ||
args=(args,), | ||
) |