From 3fdee7574caa33018029ec64d8dc2de7afcb8b5e Mon Sep 17 00:00:00 2001
From: Rahul Tuli
Date: Tue, 31 Jan 2023 15:14:06 -0500
Subject: [PATCH 1/2] Bugfix: use label smoothing only when torch version is >= 1.10 (#1352)

* Bugfix: use label smoothing only when torch version is >= 1.10

* Apply suggestions from code review
---
 src/sparseml/pytorch/torchvision/train.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/src/sparseml/pytorch/torchvision/train.py b/src/sparseml/pytorch/torchvision/train.py
index 5627cedcb18..a4dce9f2797 100644
--- a/src/sparseml/pytorch/torchvision/train.py
+++ b/src/sparseml/pytorch/torchvision/train.py
@@ -28,6 +28,7 @@
 import torch
 import torch.utils.data
 import torchvision
+from packaging import version
 from torch import nn
 from torch.utils.data.dataloader import DataLoader, default_collate
 from torchvision.transforms.functional import InterpolationMode
@@ -380,7 +381,15 @@ def collate_fn(batch):
     if args.distributed and args.sync_bn:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
 
-    criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
+    if version.parse(torch.__version__) >= version.parse("1.10"):
+        criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
+    elif args.label_smoothing > 0:
+        raise ValueError(
+            f"`label_smoothing` not supported for torch version {torch.__version__}, "
+            "try upgrading to at least 1.10"
+        )
+    else:
+        criterion = nn.CrossEntropyLoss()
 
     custom_keys_weight_decay = []
     if args.bias_weight_decay is not None:

From c7b0aee4bd76c083959f828155ddfb62184a2bac Mon Sep 17 00:00:00 2001
From: Rahul Tuli
Date: Thu, 2 Feb 2023 11:15:54 -0500
Subject: [PATCH 2/2] [BugFix][Torchvision] update optimizer state dict before transfer learning (#1358)

* Add: an `_update_checkpoint_optimizer(...)` for deleting mismatching params from saved optimizer(s) state_dict

* Remove: _update_checkpoint_optimizer in favor of loading in the optim state_dict only when `args.resume` is set

* Remove: un-needed imports

* Address review comments

* Style
---
 src/sparseml/pytorch/torchvision/train.py | 199 ++++++++++++++++++----
 1 file changed, 165 insertions(+), 34 deletions(-)

diff --git a/src/sparseml/pytorch/torchvision/train.py b/src/sparseml/pytorch/torchvision/train.py
index a4dce9f2797..51b2e9e6117 100644
--- a/src/sparseml/pytorch/torchvision/train.py
+++ b/src/sparseml/pytorch/torchvision/train.py
@@ -23,7 +23,7 @@
 import warnings
 from functools import update_wrapper
 from types import SimpleNamespace
-from typing import Callable
+from typing import Callable, Optional
 
 import torch
 import torch.utils.data
@@ -34,6 +34,7 @@
 from torchvision.transforms.functional import InterpolationMode
 
 import click
+from sparseml.optim.helpers import load_recipe_yaml_str
 from sparseml.pytorch.models.registry import ModelRegistry
 from sparseml.pytorch.optim import ScheduledModifierManager
 from sparseml.pytorch.torchvision import presets, transforms, utils
@@ -65,6 +66,7 @@ def train_one_epoch(
     epoch: int,
     args,
     log_metrics_fn: Callable[[str, utils.MetricLogger, int, int], None],
+    manager=None,
     model_ema=None,
     scaler=None,
 ) -> utils.MetricLogger:
@@ -93,13 +95,24 @@
         start_time = time.time()
         image, target = image.to(device), target.to(device)
         with torch.cuda.amp.autocast(enabled=scaler is not None):
-            output = model(image)
+            outputs = output = model(image)
             if isinstance(output, tuple):
                 # NOTE: sparseml models return two things (logits & probs)
                 output = output[0]
             loss = criterion(output, target)
 
         if steps_accumulated % accum_steps == 0:
+            if manager is not None:
+                loss = manager.loss_update(
+                    loss=loss,
+                    module=model,
+                    optimizer=optimizer,
+                    epoch=epoch,
+                    steps_per_epoch=len(data_loader) / accum_steps,
+                    student_outputs=outputs,
+                    student_inputs=image,
+                )
+
             # first: do training to consume gradients
             if scaler is not None:
                 scaler.scale(loss).backward()
@@ -127,11 +140,17 @@
                 # Reset ema buffer to keep copying weights during warmup period
                 model_ema.n_averaged.fill_(0)
 
-        acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
+        acc1, num_correct_1, acc5, num_correct_5 = utils.accuracy(
+            output, target, topk=(1, 5)
+        )
         batch_size = image.shape[0]
         metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
-        metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
-        metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
+        metric_logger.meters["acc1"].update(
+            acc1.item(), n=batch_size, total=num_correct_1
+        )
+        metric_logger.meters["acc5"].update(
+            acc5.item(), n=batch_size, total=num_correct_5
+        )
         metric_logger.meters["imgs_per_sec"].update(
             batch_size / (time.time() - start_time)
         )
@@ -169,13 +188,19 @@
                 output = output[0]
             loss = criterion(output, target)
 
-            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
+            acc1, num_correct_1, acc5, num_correct_5 = utils.accuracy(
+                output, target, topk=(1, 5)
+            )
             # FIXME need to take into account that the datasets
             # could have been padded in distributed setup
             batch_size = image.shape[0]
             metric_logger.update(loss=loss.item())
-            metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
-            metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
+            metric_logger.meters["acc1"].update(
+                acc1.item(), n=batch_size, total=num_correct_1
+            )
+            metric_logger.meters["acc5"].update(
+                acc5.item(), n=batch_size, total=num_correct_5
+            )
             num_processed_samples += batch_size
 
     # gather the stats from all processes
@@ -356,27 +381,32 @@ def collate_fn(batch):
     )
 
     _LOGGER.info("Creating model")
-    if args.arch_key in ModelRegistry.available_keys():
-        with torch_distributed_zero_first(args.rank if args.distributed else None):
-            model = ModelRegistry.create(
-                key=args.arch_key,
-                pretrained=args.pretrained,
-                pretrained_path=args.checkpoint_path,
-                pretrained_dataset=args.pretrained_dataset,
-                num_classes=num_classes,
-            )
-    elif args.arch_key in torchvision.models.__dict__:
-        # fall back to torchvision
-        model = torchvision.models.__dict__[args.arch_key](
-            pretrained=args.pretrained, num_classes=num_classes
+    local_rank = args.rank if args.distributed else None
+    model, arch_key = _create_model(
+        arch_key=args.arch_key,
+        local_rank=local_rank,
+        pretrained=args.pretrained,
+        checkpoint_path=args.checkpoint_path,
+        pretrained_dataset=args.pretrained_dataset,
+        device=device,
+        num_classes=num_classes,
+    )
+
+    if args.distill_teacher not in ["self", "disable", None]:
+        _LOGGER.info("Instantiating teacher")
+        distill_teacher, _ = _create_model(
+            arch_key=args.teacher_arch_key,
+            local_rank=local_rank,
+            pretrained=True,  # teacher is always pretrained
+            pretrained_dataset=args.pretrained_teacher_dataset,
+            checkpoint_path=args.distill_teacher,
+            device=device,
+            num_classes=num_classes,
         )
         if args.checkpoint_path is not None:
             load_model(args.checkpoint_path, model, strict=True)
     else:
-        raise ValueError(
-            f"Unable to find {args.arch_key} in ModelRegistry or in torchvision.models"
-        )
-    model.to(device)
+        distill_teacher = args.distill_teacher
 
     if args.distributed and args.sync_bn:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
@@ -476,7 +506,7 @@ def collate_fn(batch):
         )
         checkpoint_manager = (
             ScheduledModifierManager.from_yaml(checkpoint["recipe"])
-            if "recipe" in checkpoint
+            if "recipe" in checkpoint and checkpoint["recipe"] is not None
             else None
         )
     elif args.resume:
@@ -504,8 +534,15 @@ def collate_fn(batch):
 
     # load params
    if checkpoint is not None:
-        if "optimizer" in checkpoint:
-            optimizer.load_state_dict(checkpoint["optimizer"])
+        if "optimizer" in checkpoint and not args.test_only:
+            if args.resume:
+                optimizer.load_state_dict(checkpoint["optimizer"])
+            else:
+                warnings.warn(
+                    "Optimizer state dict not loaded from checkpoint. Unless run is "
+                    "resumed with the --resume arg, the optimizer will start from a "
+                    "fresh state"
+                )
         if model_ema and "model_ema" in checkpoint:
             model_ema.load_state_dict(checkpoint["model_ema"])
         if scaler and "scaler" in checkpoint:
@@ -541,13 +578,26 @@ def collate_fn(batch):
             TensorBoardLogger(log_path=args.output_dir),
         ]
         try:
-            loggers.append(WANDBLogger())
+            config = vars(args)
+            if manager is not None:
+                config["manager"] = str(manager)
+            loggers.append(WANDBLogger(init_kwargs=dict(config=config)))
         except ImportError:
             warnings.warn("Unable to import wandb for logging")
         logger = LoggerManager(loggers)
     else:
         logger = LoggerManager(log_python=False)
 
+    if args.recipe is not None:
+        base_path = os.path.join(args.output_dir, "original_recipe.yaml")
+        with open(base_path, "w") as fp:
+            fp.write(load_recipe_yaml_str(args.recipe))
+        logger.save(base_path)
+
+        full_path = os.path.join(args.output_dir, "final_recipe.yaml")
+        manager.save(full_path)
+        logger.save(full_path)
+
     steps_per_epoch = len(data_loader) / args.gradient_accum_steps
 
     def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int, epoch_step: int):
@@ -558,10 +608,23 @@ def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int, epoch_step: i
         )
 
     if manager is not None:
-        manager.initialize(model, epoch=args.start_epoch, loggers=logger)
-        optimizer = manager.modify(
-            model, optimizer, steps_per_epoch=steps_per_epoch, epoch=args.start_epoch
+        manager.initialize(
+            model,
+            epoch=args.start_epoch,
+            loggers=logger,
+            distillation_teacher=distill_teacher,
         )
+        step_wrapper = manager.modify(
+            model,
+            optimizer,
+            steps_per_epoch=steps_per_epoch,
+            epoch=args.start_epoch,
+            wrap_optim=scaler,
+        )
+        if scaler is None:
+            optimizer = step_wrapper
+        else:
+            scaler = step_wrapper
 
     lr_scheduler = _get_lr_scheduler(
         args, optimizer, checkpoint=checkpoint, manager=manager
@@ -582,7 +645,8 @@ def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int, epoch_step: i
         if args.distributed:
             train_sampler.set_epoch(epoch)
         if manager is not None and manager.qat_active(epoch=epoch):
-            scaler = None
+            if scaler is not None:
+                scaler._enabled = False
             model_ema = None
 
         train_metrics = train_one_epoch(
@@ -595,6 +659,7 @@ def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int, epoch_step: i
             epoch,
             args,
             log_metrics,
+            manager=manager,
             model_ema=model_ema,
             scaler=scaler,
         )
@@ -625,6 +690,7 @@ def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int, epoch_step: i
                 "state_dict": model_without_ddp.state_dict(),
                 "optimizer": optimizer.state_dict(),
                 "args": args,
+                "arch_key": arch_key,
             }
             if lr_scheduler:
                 checkpoint["lr_scheduler"] = lr_scheduler.state_dict()
@@ -644,7 +710,8 @@ def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int, epoch_step: i
                 )
             else:
                 checkpoint["epoch"] = -1 if epoch == max_epochs - 1 else epoch
-            checkpoint["recipe"] = str(manager)
+            if manager is not None:
+                checkpoint["recipe"] = str(manager)
 
             file_names = ["checkpoint.pth"]
             if is_new_best:
@@ -666,6 +733,42 @@ def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int, epoch_step: i
     _LOGGER.info(f"Training time {total_time_str}")
 
 
+def _create_model(
+    arch_key: Optional[str] = None,
+    local_rank=None,
+    pretrained: Optional[bool] = False,
+    checkpoint_path: Optional[str] = None,
+    pretrained_dataset: Optional[str] = None,
+    device=None,
+    num_classes=None,
+):
+    if not arch_key or arch_key in ModelRegistry.available_keys():
+        with torch_distributed_zero_first(local_rank):
+            model = ModelRegistry.create(
+                key=arch_key,
+                pretrained=pretrained,
+                pretrained_path=checkpoint_path,
+                pretrained_dataset=pretrained_dataset,
+                num_classes=num_classes,
+            )
+
+        if isinstance(model, tuple):
+            model, arch_key = model
+    elif arch_key in torchvision.models.__dict__:
+        # fall back to torchvision
+        model = torchvision.models.__dict__[arch_key](
+            pretrained=pretrained, num_classes=num_classes
+        )
+        if checkpoint_path is not None:
+            load_model(checkpoint_path, model, strict=True)
+    else:
+        raise ValueError(
+            f"Unable to find {arch_key} in ModelRegistry or in torchvision.models"
+        )
+    model.to(device)
+    return model, arch_key
+
+
 def _get_lr_scheduler(args, optimizer, checkpoint=None, manager=None):
     lr_scheduler = None
 
@@ -1048,6 +1151,34 @@ def new_func(*args, **kwargs):
     help="Save the best validation result after the given "
     "epoch completes until the end of training",
 )
+@click.option(
+    "--distill-teacher",
+    default=None,
+    type=str,
+    help="Teacher model for distillation (a trained image classification model); "
+    "can be set to 'self' for self-distillation or 'disable' to switch off "
+    "distillation; also accepts a SparseZoo stub",
+)
+@click.option(
+    "--pretrained-teacher-dataset",
+    default=None,
+    type=str,
+    help=(
+        "The dataset to load pretrained weights for the teacher. "
+        "Loads the default dataset for the architecture if set to None. "
+        "Examples: `imagenet`, `cifar10`, etc."
+    ),
+)
+@click.option(
+    "--teacher-arch-key",
+    default=None,
+    type=str,
+    help=(
+        "The architecture key for the teacher image classification model; "
+        "examples: `resnet50`, `mobilenet`. "
+        "Note: will be read from the checkpoint if not specified"
+    ),
+)
 @click.pass_context
 def cli(ctx, **kwargs):
     """