diff --git a/requirements.txt b/requirements.txt index fd187eb56cfe..9fa43f2e1d7f 100755 --- a/requirements.txt +++ b/requirements.txt @@ -27,3 +27,4 @@ pandas # extras -------------------------------------- thop # FLOPS computation pycocotools>=2.0 # COCO mAP +sparseml~=0.2 diff --git a/train.py b/train.py index acfc9ef5527b..221b8a8baa07 100644 --- a/train.py +++ b/train.py @@ -21,6 +21,11 @@ from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm +from sparseml.pytorch.nn import replace_activations +from sparseml.pytorch.optim import ScheduledModifierManager, ScheduledOptimizer +from sparseml.pytorch.utils import PythonLogger, TensorBoardLogger, ModuleExporter +from sparseml.pytorch.utils.quantization import skip_onnx_input_quantize + import test # import test.py to get mAP after each epoch from models.experimental import attempt_load from models.yolo import Model @@ -59,6 +64,7 @@ def train(hyp, opt, device, tb_writer=None): # Configure plots = not opt.evolve # create plots cuda = device.type != 'cpu' + half_precision = cuda and not opt.disable_amp init_seeds(2 + rank) with open(opt.data) as f: data_dict = yaml.safe_load(f) # data dict @@ -87,7 +93,7 @@ def train(hyp, opt, device, tb_writer=None): ckpt = torch.load(weights, map_location=device) # load checkpoint model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create exclude = ['anchor'] if (opt.cfg or hyp.get('anchors')) and not opt.resume else [] # exclude keys - state_dict = ckpt['model'].float().state_dict() # to FP32 + state_dict = ckpt['model'].float().state_dict() if isinstance(ckpt['model'], nn.Module) else ckpt['model'] state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude) # intersect model.load_state_dict(state_dict, strict=False) # load logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)) # report @@ -141,7 +147,7 @@ def train(hyp, opt, device, tb_writer=None): # 
plot_lr_scheduler(optimizer, scheduler, epochs) # EMA - ema = ModelEMA(model) if rank in [-1, 0] else None + ema = ModelEMA(model, enabled=not opt.disable_ema) if rank in [-1, 0] else None # Resume start_epoch, best_fitness = 0, 0.0 @@ -153,8 +159,7 @@ def train(hyp, opt, device, tb_writer=None): # EMA if ema and ckpt.get('ema'): - ema.ema.load_state_dict(ckpt['ema'].float().state_dict()) - ema.updates = ckpt['updates'] + ema.load_state_dict(ckpt) # Results if ckpt.get('training_results') is not None: @@ -214,7 +219,8 @@ def train(hyp, opt, device, tb_writer=None): # Anchors if not opt.noautoanchor: check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz) - model.half().float() # pre-reduce anchor precision + if half_precision: + model.half().float() # pre-reduce anchor precision # DDP mode if cuda and rank != -1: @@ -233,14 +239,50 @@ def train(hyp, opt, device, tb_writer=None): model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # attach class weights model.names = names + # SparseML Integration + if opt.use_leaky_relu: # use LeakyReLU activations + model = replace_activations(model, 'lrelu', inplace=True) + + qat = False + if opt.sparseml_recipe: + manager = ScheduledModifierManager.from_yaml(opt.sparseml_recipe) + optimizer = ScheduledOptimizer( + optimizer, + model if not is_parallel(model) else model.module, + manager, + steps_per_epoch=len(dataloader), + loggers=[PythonLogger(), TensorBoardLogger(writer=tb_writer)] + ) + # override lr scheduler if recipe makes any LR updates + if manager.learning_rate_modifiers: + logger.info('Disabling LR scheduler, managing LR using SparseML recipe') + scheduler = None + # override num epochs if recipe explicitly modifies epoch range + if manager.epoch_modifiers and manager.max_epochs: + epochs = manager.max_epochs or epochs # override num_epochs + logger.info(f'Overriding number of epochs from SparseML manager to {manager.max_epochs}') + # mark that QAT will be applied, pickled 
QAT exports currently not supported + if manager.quantization_modifiers: + logger.info('Disabling pickling for model exports, QAT scheduled to run') + if not opt.use_leaky_relu: + logger.warning( + 'QAT detected in sparsification recipe, but --use-leaky-relu not set ' + 'quantized model may not run well with default activations' + ) + qat = True + # make sure that sparsity structure is held during EMA updates + if ema and manager.pruning_modifiers: + ema.pruning_manager = manager + # Start training t0 = time.time() nw = max(round(hyp['warmup_epochs'] * nb), 1000) # number of warmup iterations, max(3 epochs, 1k iterations) # nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warmup to < 1/2 of training maps = np.zeros(nc) # mAP per class results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls) - scheduler.last_epoch = start_epoch - 1 # do not move - scaler = amp.GradScaler(enabled=cuda) + if scheduler: + scheduler.last_epoch = start_epoch - 1 # do not move + scaler = amp.GradScaler(enabled=half_precision) compute_loss = ComputeLoss(model) # init loss class logger.info(f'Image sizes {imgsz} train, {imgsz_test} test\n' f'Using {dataloader.num_workers} dataloader workers\n' @@ -286,7 +328,8 @@ def train(hyp, opt, device, tb_writer=None): accumulate = max(1, np.interp(ni, xi, [1, nbs / total_batch_size]).round()) for j, x in enumerate(optimizer.param_groups): # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0 - x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)]) + if scheduler: + x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)]) if 'momentum' in x: x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']]) @@ -299,7 +342,7 @@ def train(hyp, opt, device, tb_writer=None): imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False) # Forward - with amp.autocast(enabled=cuda): + with 
amp.autocast(enabled=half_precision): pred = model(imgs) # forward loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size if rank != -1: @@ -342,7 +385,8 @@ def train(hyp, opt, device, tb_writer=None): # Scheduler lr = [x['lr'] for x in optimizer.param_groups] # for tensorboard - scheduler.step() + if scheduler: + scheduler.step() # DDP process 0 or single-GPU if rank in [-1, 0]: @@ -354,7 +398,7 @@ def train(hyp, opt, device, tb_writer=None): results, maps, times = test.test(data_dict, batch_size=batch_size * 2, imgsz=imgsz_test, - model=ema.ema, + model=ema.ema if ema.enabled else model, single_cls=opt.single_cls, dataloader=testloader, save_dir=save_dir, @@ -362,7 +406,8 @@ def train(hyp, opt, device, tb_writer=None): plots=plots and final_epoch, wandb_logger=wandb_logger, compute_loss=compute_loss, - is_coco=is_coco) + is_coco=is_coco, + half_precision=half_precision) # Write with open(results_file, 'a') as f: @@ -389,14 +434,16 @@ def train(hyp, opt, device, tb_writer=None): # Save model if (not opt.nosave) or (final_epoch and not opt.evolve): # if save + ckpt_model = deepcopy(model.module if is_parallel(model) else model) + if qat: + ckpt_model = ckpt_model.state_dict() # pickled QAT exports not currently supported ckpt = {'epoch': epoch, 'best_fitness': best_fitness, 'training_results': results_file.read_text(), - 'model': deepcopy(model.module if is_parallel(model) else model).half(), - 'ema': deepcopy(ema.ema).half(), - 'updates': ema.updates, + 'model': ckpt_model.half() if half_precision and isinstance(ckpt_model, nn.Module) else ckpt_model, 'optimizer': optimizer.state_dict(), 'wandb_id': wandb_logger.wandb_run.id if wandb_logger.wandb else None} + ckpt.update(ema.state_dict(half_precision=half_precision)) # add EMA model and updates if enabled # Save last, best and delete torch.save(ckpt, last) @@ -422,23 +469,39 @@ def train(hyp, opt, device, tb_writer=None): logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 
3600)) if opt.data.endswith('coco.yaml') and nc == 80: # if COCO for m in (last, best) if best.exists() else (last): # speed, mAP tests + test_model = attempt_load(m, device) if not qat else model results, _, _ = test.test(opt.data, batch_size=batch_size * 2, imgsz=imgsz_test, conf_thres=0.001, iou_thres=0.7, - model=attempt_load(m, device).half(), + model=test_model.half() if half_precision else test_model, single_cls=opt.single_cls, dataloader=testloader, save_dir=save_dir, save_json=True, plots=False, - is_coco=is_coco) + is_coco=is_coco, + half_precision=half_precision) + + # ONNX export + if opt.export_onnx: + try: + onnx_path = f'{save_dir}/model.onnx' + logger.info(f'training complete, exporting ONNX to {onnx_path}') + export_model = model.module if is_parallel(model) else model + export_model.model[-1].export = True # do not export grid post-processing + exporter = ModuleExporter(export_model, save_dir) + exporter.export_onnx(torch.randn(1, 3, imgsz, imgsz), convert_qat=True) + if qat: + skip_onnx_input_quantize(onnx_path, onnx_path) + except Exception as e: + logger.warning(f'exception occurred during ONNX export, model not exported to ONNX. 
error message {e}') # Strip optimizers final = best if best.exists() else last # final model for f in last, best: - if f.exists(): + if f.exists() and not qat: # qat state dict incompatible strip_optimizer(f) # strip optimizers if opt.bucket: os.system(f'gsutil cp {final} gs://{opt.bucket}/weights') # upload @@ -489,6 +552,11 @@ def train(hyp, opt, device, tb_writer=None): parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval for W&B') parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch') parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used') + parser.add_argument('--sparseml-recipe', type=str, default=None, help='Path to a SparseML sparsification recipe, see for more information') + parser.add_argument('--use-leaky-relu', action='store_true', help='Override default SiLU activation with LeakyReLU') + parser.add_argument('--export-onnx', action='store_true', help='export final model to ONNX') + parser.add_argument('--disable-amp', action='store_true', help='Disable FP16 half precision (enabled by default)') + parser.add_argument('--disable-ema', action='store_true', help='Disable EMA model updates (enabled by default)') opt = parser.parse_args() # Set DDP variables diff --git a/utils/torch_utils.py b/utils/torch_utils.py index 9991e5ec87d8..d5d7a50fb312 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -276,17 +276,38 @@ class ModelEMA: GPU assignment and distributed training wrappers. 
""" - def __init__(self, model, decay=0.9999, updates=0): + def __init__(self, model, decay=0.9999, updates=0, enabled=True): # Create EMA self.ema = deepcopy(model.module if is_parallel(model) else model).eval() # FP32 EMA # if next(model.parameters()).device.type != 'cpu': # self.ema.half() # FP16 EMA self.updates = updates # number of EMA updates self.decay = lambda x: decay * (1 - math.exp(-x / 2000)) # decay exponential ramp (to help early epochs) + self.enabled = enabled + self.pruning_manager = None # type: sparseml.pytorch.optim.ScheduledModifierManager for p in self.ema.parameters(): p.requires_grad_(False) + def state_dict(self, half_precision=True): + if not self.enabled: + return {} + + ema = deepcopy(self.ema) + return { + 'ema': ema.half() if half_precision else ema, + 'updates': self.updates, + } + + def load_state_dict(self, state_dict): + if 'ema' in state_dict: + self.ema.load_state_dict(state_dict['ema'].float().state_dict()) + if 'updates' in state_dict: + self.updates = state_dict['updates'] + def update(self, model): + if not self.enabled: + return + # Update EMA parameters with torch.no_grad(): self.updates += 1 @@ -299,5 +320,17 @@ def update(self, model): v += (1. - d) * msd[k].detach() def update_attr(self, model, include=(), exclude=('process_group', 'reducer')): + if not self.enabled: + return + + # store pre-ema sparsity masks + if self.pruning_manager is not None: + pruning_dict = self.pruning_manager.state_dict() + # Update EMA attributes copy_attr(self.ema, model, include, exclude) + + # restore sparsity structure post-ema + if self.pruning_manager is not None: + self.pruning_manager.load_state_dict(pruning_dict) + del pruning_dict