-
Notifications
You must be signed in to change notification settings - Fork 17
/
train_net.py
executable file
·454 lines (381 loc) · 18 KB
/
train_net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
"""
hierarchialdet Training Script. Adapted from DiffusionDet.
"""
import os
import itertools
import weakref
from typing import Any, Dict, List, Set
import logging
from collections import OrderedDict
import random
import torch
from fvcore.nn.precise_bn import get_bn_modules
import detectron2.utils.comm as comm
from detectron2.utils.logger import setup_logger
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import build_detection_train_loader
from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, launch, create_ddp_model, \
AMPTrainer, SimpleTrainer, hooks
from detectron2.evaluation import LVISEvaluator, verify_results, print_csv_format
from hierarchialdet.util.coco_3class_eval import COCOEvaluator
from detectron2.solver.build import maybe_add_gradient_clipping
from detectron2.modeling import build_model
from evaluator import DatasetEvaluator, inference_on_dataset
from hierarchialdet import DiffusionDetDatasetMapper, add_diffusiondet_config, DiffusionDetWithTTA
from hierarchialdet.util.model_ema import add_model_ema_configs, may_build_model_ema, may_get_ema_checkpointer, EMAHook, \
apply_model_ema_and_restore, EMADetectionCheckpointer
class Trainer(DefaultTrainer):
    """Extension of the Trainer class adapted to DiffusionDet.

    Compared to ``DefaultTrainer`` this class wires in the optional model-EMA
    checkpointer/hook, a custom optimizer builder with a backbone learning-rate
    multiplier and optional full-model gradient clipping, and an evaluation
    loop that forwards an extra index ``k`` to ``inference_on_dataset``.
    """

    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode): the experiment configuration.
        """
        # Call the grandparent's `__init__` while deliberately skipping
        # DefaultTrainer's own `__init__`; its setup is re-done below.
        super(DefaultTrainer, self).__init__()
        logger = logging.getLogger("detectron2")
        if not logger.isEnabledFor(logging.INFO):  # setup_logger is not called for d2
            setup_logger()
        cfg = DefaultTrainer.auto_scale_workers(cfg, comm.get_world_size())

        # Assume these objects must be constructed in this order.
        model = self.build_model(cfg)
        optimizer = self.build_optimizer(cfg, model)
        data_loader = self.build_train_loader(cfg)

        model = create_ddp_model(model, broadcast_buffers=False)
        self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
            model, data_loader, optimizer
        )

        self.scheduler = self.build_lr_scheduler(cfg, optimizer)

        ########## EMA ############
        # Extra checkpointables: the trainer itself plus whatever
        # `may_get_ema_checkpointer` contributes for this cfg/model.
        kwargs = {
            'trainer': weakref.proxy(self),
        }
        kwargs.update(may_get_ema_checkpointer(cfg, model))
        self.checkpointer = DetectionCheckpointer(
            # Assume you want to save checkpoints together with logs/statistics
            model,
            cfg.OUTPUT_DIR,
            **kwargs,
        )
        self.start_iter = 0
        self.max_iter = cfg.SOLVER.MAX_ITER
        self.cfg = cfg

        self.register_hooks(self.build_hooks())

    @classmethod
    def build_model(cls, cfg):
        """
        Build the model, log its structure and set up EMA for it.

        Returns:
            torch.nn.Module:

        It calls :func:`detectron2.modeling.build_model` and then
        ``may_build_model_ema(cfg, model)``.
        Overwrite it if you'd like a different model.
        """
        model = build_model(cfg)
        logger = logging.getLogger(__name__)
        logger.info("Model:\n{}".format(model))
        # setup EMA
        may_build_model_ema(cfg, model)
        return model

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        """
        Create the evaluator for a given dataset.

        Dataset names containing ``'lvis'`` get an ``LVISEvaluator``; every
        other dataset gets the project's ``COCOEvaluator``.

        Args:
            cfg (CfgNode):
            dataset_name (str): name of the registered dataset.
            output_folder (str or None): where the evaluator writes its
                output; defaults to ``<cfg.OUTPUT_DIR>/inference``.
        """
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
        if 'lvis' in dataset_name:
            return LVISEvaluator(dataset_name, cfg, True, output_folder)
        return COCOEvaluator(dataset_name, cfg, True, output_folder)

    @classmethod
    def build_train_loader(cls, cfg):
        """Build the training data loader using ``DiffusionDetDatasetMapper``."""
        mapper = DiffusionDetDatasetMapper(cfg, is_train=True)
        return build_detection_train_loader(cfg, mapper=mapper)

    @classmethod
    def build_optimizer(cls, cfg, model):
        """
        Build an SGD or AdamW optimizer with one parameter group per parameter.

        Parameters whose name contains ``"backbone"`` have their learning rate
        scaled by ``cfg.SOLVER.BACKBONE_MULTIPLIER``. When gradient clipping is
        enabled with ``CLIP_TYPE == "full_model"`` the optimizer class is
        wrapped so the global gradient norm is clipped before every step;
        for any other clip type detectron2's ``maybe_add_gradient_clipping``
        is applied instead.

        Raises:
            NotImplementedError: if ``cfg.SOLVER.OPTIMIZER`` is neither
                ``"SGD"`` nor ``"ADAMW"``.
        """
        params: List[Dict[str, Any]] = []
        memo: Set[torch.nn.parameter.Parameter] = set()
        for key, value in model.named_parameters(recurse=True):
            if not value.requires_grad:
                continue
            # Avoid duplicating parameters
            if value in memo:
                continue
            memo.add(value)
            lr = cfg.SOLVER.BASE_LR
            weight_decay = cfg.SOLVER.WEIGHT_DECAY
            if "backbone" in key:
                lr = lr * cfg.SOLVER.BACKBONE_MULTIPLIER
            params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]

        def maybe_add_full_model_gradient_clipping(optim):  # optim: the optimizer class
            # detectron2 doesn't have full model gradient clipping now
            clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
            enable = (
                cfg.SOLVER.CLIP_GRADIENTS.ENABLED
                and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model"
                and clip_norm_val > 0.0
            )

            class FullModelGradientClippingOptimizer(optim):
                def step(self, closure=None):
                    all_params = itertools.chain(*[x["params"] for x in self.param_groups])
                    torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
                    # Propagate whatever the wrapped optimizer's step returns
                    # instead of silently dropping it.
                    return super().step(closure=closure)

            return FullModelGradientClippingOptimizer if enable else optim

        optimizer_type = cfg.SOLVER.OPTIMIZER
        if optimizer_type == "SGD":
            optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)(
                params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM
            )
        elif optimizer_type == "ADAMW":
            optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)(
                params, cfg.SOLVER.BASE_LR
            )
        else:
            raise NotImplementedError(f"no optimizer type {optimizer_type}")
        if cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE != "full_model":
            optimizer = maybe_add_gradient_clipping(cfg, optimizer)
        return optimizer

    @classmethod
    def test(cls, cfg, model, k=0, evaluators=None):
        """
        Evaluate the given model. The given model is expected to already contain
        weights to evaluate.

        Args:
            cfg (CfgNode):
            model (nn.Module):
            k (int): index forwarded to ``inference_on_dataset``; the log line
                labels the results as class ``k + 1``.
            evaluators (list[DatasetEvaluator] or None): if None, will call
                :meth:`build_evaluator`. Otherwise, must have the same length as
                ``cfg.DATASETS.TEST``.

        Returns:
            The value returned by ``inference_on_dataset`` when exactly one
            test dataset is configured, otherwise an ``OrderedDict`` mapping
            each dataset name to that value.
        """
        logger = logging.getLogger(__name__)
        if isinstance(evaluators, DatasetEvaluator):
            evaluators = [evaluators]
        if evaluators is not None:
            assert len(cfg.DATASETS.TEST) == len(evaluators), "{} != {}".format(
                len(cfg.DATASETS.TEST), len(evaluators)
            )

        results = OrderedDict()
        for idx, dataset_name in enumerate(cfg.DATASETS.TEST):
            data_loader = cls.build_test_loader(cfg, dataset_name)
            # When evaluators are passed in as arguments,
            # implicitly assume that evaluators can be created before data_loader.
            if evaluators is not None:
                evaluator = evaluators[idx]
            else:
                try:
                    evaluator = cls.build_evaluator(cfg, dataset_name)
                except NotImplementedError:
                    logger.warning(
                        "No evaluator found. Use `DefaultTrainer.test(evaluators=)`, "
                        "or implement its `build_evaluator` method."
                    )
                    results[dataset_name] = {}
                    continue
            results_i = inference_on_dataset(model, data_loader, k, evaluator)
            results[dataset_name] = results_i
            if comm.is_main_process():
                # The project's `inference_on_dataset` is expected to return an
                # indexable whose first element is the metrics dict.
                assert isinstance(
                    results_i[0], dict
                ), "Evaluator must return a dict on the main process. Got {} instead.".format(
                    results_i
                )
                logger.info("Evaluation results for {} in csv format:".format(dataset_name+"for class"+str(k+1)))
                print_csv_format(results_i)

        # With a single test dataset, return its results directly.
        if len(results) == 1:
            results = next(iter(results.values()))
        return results

    @classmethod
    def ema_test(cls, cfg, model, evaluators=None):
        """
        Run :meth:`test` with ``k=2``, using the EMA weights when
        ``cfg.MODEL_EMA.ENABLED`` is set (applied via
        ``apply_model_ema_and_restore`` for the duration of the evaluation).
        """
        # model with ema weights
        logger = logging.getLogger("detectron2.trainer")
        if cfg.MODEL_EMA.ENABLED:
            logger.info("Run evaluation with EMA.")
            with apply_model_ema_and_restore(model):
                results = cls.test(cfg, model, evaluators=evaluators, k=2)
        else:
            results = cls.test(cfg, model, evaluators=evaluators, k=2)
        return results

    @classmethod
    def test_with_TTA(cls, cfg, model):
        """
        Evaluate ``model`` wrapped in ``DiffusionDetWithTTA`` and return the
        results with ``"_TTA"`` appended to every key.

        Evaluators write to ``<cfg.OUTPUT_DIR>/inference_TTA``.
        """
        logger = logging.getLogger("detectron2.trainer")
        logger.info("Running inference with test-time augmentation ...")
        model = DiffusionDetWithTTA(cfg, model)
        evaluators = [
            cls.build_evaluator(
                cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA")
            )
            for name in cfg.DATASETS.TEST
        ]
        # The previous code referenced an undefined name `i`, passed the
        # evaluators into the `k` slot of `test`, and never stored the EMA
        # result. `k=2` mirrors what `ema_test` uses.
        # NOTE(review): confirm k=2 is the intended index for TTA evaluation.
        if cfg.MODEL_EMA.ENABLED:
            res = cls.ema_test(cfg, model, evaluators)
        else:
            res = cls.test(cfg, model, k=2, evaluators=evaluators)
        # NOTE(review): this assumes `res` is a mapping; verify against what
        # the project's `inference_on_dataset` returns.
        res = OrderedDict({k + "_TTA": v for k, v in res.items()})
        return res

    def build_hooks(self):
        """
        Build a list of default hooks, including timing, EMA, lr scheduling,
        precise BN, checkpointing, evaluation and writing events.

        Returns:
            list[HookBase]: may contain ``None`` entries for disabled hooks.
        """
        cfg = self.cfg.clone()
        cfg.defrost()
        cfg.DATALOADER.NUM_WORKERS = 0  # save some memory and time for PreciseBN

        ret = [
            hooks.IterationTimer(),
            EMAHook(self.cfg, self.model) if cfg.MODEL_EMA.ENABLED else None,  # EMA hook
            hooks.LRScheduler(),
            hooks.PreciseBN(
                # Run at the same freq as (but before) evaluation.
                cfg.TEST.EVAL_PERIOD,
                self.model,
                # Build a new data loader to not affect training
                self.build_train_loader(cfg),
                cfg.TEST.PRECISE_BN.NUM_ITER,
            )
            if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
            else None,
        ]

        # Do PreciseBN before checkpointer, because it updates the model and need to
        # be saved by checkpointer.
        # This is not always the best: if checkpointing has a different frequency,
        # some checkpoints may have more precise statistics than others.
        if comm.is_main_process():
            ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD))

        def test_and_save_results(k=0):
            # `k` defaults to 0 so the callback also works when the hook
            # invokes it without arguments.
            self._last_eval_results = self.test(self.cfg, self.model, k)
            return self._last_eval_results

        # Do evaluation after checkpointer, because then if it fails,
        # we can use the saved checkpoint to debug.
        # NOTE(review): the third positional argument is kept as in the
        # original; confirm its meaning against the EvalHook signature of the
        # detectron2 build this project uses.
        ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results, 0))

        if comm.is_main_process():
            # Writers are appended last, after the evaluation hook.
            ret.append(hooks.PeriodicWriter(self.build_writers(), period=20))
        return ret
def setup(args):
    """
    Assemble the frozen experiment config from defaults, file and CLI overrides.

    Starts from ``get_cfg()``, registers the DiffusionDet and model-EMA config
    keys, merges ``args.config_file`` and then ``args.opts`` on top, freezes
    the result and runs ``default_setup`` with it.

    Args:
        args: parsed command-line arguments; ``config_file`` and ``opts`` are
            read here and the whole object is passed on to ``default_setup``.

    Returns:
        The frozen config node.
    """
    config = get_cfg()

    # Register project-specific keys before merging user-supplied values.
    add_diffusiondet_config(config)
    add_model_ema_configs(config)

    # The config file is applied first, then the command-line overrides.
    config.merge_from_file(args.config_file)
    config.merge_from_list(args.opts)
    config.freeze()

    default_setup(config, args)
    return config
def main(args):
    """
    Per-process entry point: either evaluate a checkpoint or run training.

    Args:
        args: parsed command-line arguments; ``eval_only`` and ``resume`` are
            read here and the whole object is passed on to ``setup``.

    Returns:
        In eval-only mode, the evaluation results (extended with the TTA
        results when ``cfg.TEST.AUG.ENABLED``); otherwise whatever
        ``Trainer.train()`` returns.
    """
    cfg = setup(args)

    # Training path: build the trainer, optionally resume, and train.
    if not args.eval_only:
        trainer = Trainer(cfg)
        trainer.resume_or_load(resume=args.resume)
        return trainer.train()

    # Evaluation path: build the model and load weights into it.
    model = Trainer.build_model(cfg)
    ckpt_kwargs = may_get_ema_checkpointer(cfg, model)
    # Choose the EMA-aware checkpointer when EMA is enabled.
    checkpointer_cls = (
        EMADetectionCheckpointer if cfg.MODEL_EMA.ENABLED else DetectionCheckpointer
    )
    checkpointer_cls(model, save_dir=cfg.OUTPUT_DIR, **ckpt_kwargs).resume_or_load(
        cfg.MODEL.WEIGHTS, resume=args.resume
    )

    results = Trainer.ema_test(cfg, model)
    if cfg.TEST.AUG.ENABLED:
        results.update(Trainer.test_with_TTA(cfg, model))
    if comm.is_main_process():
        verify_results(cfg, results)
    return results
if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    print("Command Line Args:", args)

    from detectron2.data.datasets import register_coco_instances

    # (dataset name, COCO annotation json, image directory); the paths are
    # relative, so they resolve against the current working directory.
    coco_datasets = (
        (
            "custom_train_class",
            "../sorted/challenge/train_merged_disease_coco3class_onlyd_fixed.json",
            "../sorted/challenge/for_coco_disease_train",
        ),
        (
            "custom_validation_class",
            "../sorted/challenge/test_merged_disease_coco3class.json",
            "../sorted/challenge/for_coco_disease_test",
        ),
    )
    # Register the train split first, then the validation split, each with
    # empty metadata.
    for dataset_name, annotation_json, image_dir in coco_datasets:
        register_coco_instances(dataset_name, {}, annotation_json, image_dir)

    # Hand `main` to detectron2's launcher with the distributed settings
    # taken from the command line.
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )