From de645987457e06f4d5b258706ef9e2d3e22a877f Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Wed, 1 Feb 2023 12:42:22 -0500 Subject: [PATCH 1/5] Add: an `_update_checkpoint_optimizer(...)` for deleting mismatching params from saved optimizer(s) state_dict --- src/sparseml/pytorch/torchvision/train.py | 35 ++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/src/sparseml/pytorch/torchvision/train.py b/src/sparseml/pytorch/torchvision/train.py index 7c77ee3e08c..86ee09eb20a 100644 --- a/src/sparseml/pytorch/torchvision/train.py +++ b/src/sparseml/pytorch/torchvision/train.py @@ -21,15 +21,17 @@ import sys import time import warnings +from collections import defaultdict from functools import update_wrapper from types import SimpleNamespace -from typing import Callable, Optional +from typing import Any, Callable, Dict, Optional import torch import torch.utils.data import torchvision from packaging import version from torch import nn +from torch.nn import Module from torch.utils.data.dataloader import DataLoader, default_collate from torchvision.transforms.functional import InterpolationMode @@ -533,6 +535,10 @@ def collate_fn(batch): # load params if checkpoint is not None: if "optimizer" in checkpoint and not args.test_only: + checkpoint["optimizer"] = _update_checkpoint_optimizer( + checkpoint_optim=checkpoint["optimizer"], + model=model, + ) optimizer.load_state_dict(checkpoint["optimizer"]) if model_ema and "model_ema" in checkpoint: model_ema.load_state_dict(checkpoint["model_ema"]) @@ -724,6 +730,33 @@ def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int, epoch_step: i _LOGGER.info(f"Training time {total_time_str}") +def _update_checkpoint_optimizer(checkpoint_optim: Dict[Any, Any], model: Module): + # delete params from state dict where the size does not match + # model param size, (this is required for example: in cases when we + # transfer learn onto a dataset where number of output classes is + # different from upstream dataset), these state dict items will be + # re-initialized by the optimizer based on model param shape when + # `optimizer.step()` is called, re-initialization only happens when + # the param are present in optimizer state_dict under "param_groups" + # key + + # torch uses `defaultdict(dict)` as default + checkpoint_state_dict = checkpoint_optim.get("state", defaultdict(dict)) + + for idx, (name, param) in enumerate(model.named_parameters()): + if ( + idx in checkpoint_state_dict + and "momentum_buffer" in checkpoint_state_dict[idx] + and param.size() != checkpoint_state_dict[idx]["momentum_buffer"].size() + ): + del checkpoint_state_dict[idx] + + return { + "state": checkpoint_state_dict, + "param_groups": checkpoint_optim.get("param_groups"), + } + + def _create_model( arch_key: Optional[str] = None, local_rank=None, From 4eb372a9d1a688096b0110e7aee41bf8785283a1 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Thu, 2 Feb 2023 10:48:13 -0500 Subject: [PATCH 2/5] Remove: _update_checkpoint_optimizer in favor of loading in the optim state_dict only when `args.resume` is set --- src/sparseml/pytorch/torchvision/train.py | 39 ++++------------------- 1 file changed, 7 insertions(+), 32 deletions(-) diff --git a/src/sparseml/pytorch/torchvision/train.py b/src/sparseml/pytorch/torchvision/train.py index 86ee09eb20a..bec9e1cfffc 100644 --- a/src/sparseml/pytorch/torchvision/train.py +++ b/src/sparseml/pytorch/torchvision/train.py @@ -535,11 +535,13 @@ def collate_fn(batch): # load params if checkpoint is not None: if "optimizer" 
in checkpoint and not args.test_only: - checkpoint["optimizer"] = _update_checkpoint_optimizer( - checkpoint_optim=checkpoint["optimizer"], - model=model, - ) - optimizer.load_state_dict(checkpoint["optimizer"]) + if args.resume: + optimizer.load_state_dict(checkpoint["optimizer"]) + else: + warnings.warn( + "Optim state dict not loaded cause `--resume` was set to " + f"{args.resume}" + ) if model_ema and "model_ema" in checkpoint: model_ema.load_state_dict(checkpoint["model_ema"]) if scaler and "scaler" in checkpoint: @@ -730,33 +732,6 @@ def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int, epoch_step: i _LOGGER.info(f"Training time {total_time_str}") -def _update_checkpoint_optimizer(checkpoint_optim: Dict[Any, Any], model: Module): - # delete params from state dict where the size does not match - # model param size, (this is required for example: in cases when we - # transfer learn onto a dataset where number of output classes is - # different from upstream dataset), these state dict items will be - # re-initialized by the optimizer based on model param shape when - # `optimizer.step()` is called, re-initialization only happens when - # the param are present in optimizer state_dict under "param_groups" - # key - - # torch uses `defaultdict(dict)` as default - checkpoint_state_dict = checkpoint_optim.get("state", defaultdict(dict)) - - for idx, (name, param) in enumerate(model.named_parameters()): - if ( - idx in checkpoint_state_dict - and "momentum_buffer" in checkpoint_state_dict[idx] - and param.size() != checkpoint_state_dict[idx]["momentum_buffer"].size() - ): - del checkpoint_state_dict[idx] - - return { - "state": checkpoint_state_dict, - "param_groups": checkpoint_optim.get("param_groups"), - } - - def _create_model( arch_key: Optional[str] = None, local_rank=None, From 31e317285efd8f17bb559d28d035bc07e36c04de Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Thu, 2 Feb 2023 10:49:42 -0500 Subject: [PATCH 3/5] Remove: un-needed imports --- src/sparseml/pytorch/torchvision/train.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/sparseml/pytorch/torchvision/train.py b/src/sparseml/pytorch/torchvision/train.py index bec9e1cfffc..e896a0baab1 100644 --- a/src/sparseml/pytorch/torchvision/train.py +++ b/src/sparseml/pytorch/torchvision/train.py @@ -21,17 +21,15 @@ import sys import time import warnings -from collections import defaultdict from functools import update_wrapper from types import SimpleNamespace -from typing import Any, Callable, Dict, Optional +from typing import Callable, Optional import torch import torch.utils.data import torchvision from packaging import version from torch import nn -from torch.nn import Module from torch.utils.data.dataloader import DataLoader, default_collate from torchvision.transforms.functional import InterpolationMode From 91f020f0f4f31e31267c4aa279f1cf428cbc77f5 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Thu, 2 Feb 2023 11:01:18 -0500 Subject: [PATCH 4/5] Address review comments --- src/sparseml/pytorch/torchvision/train.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/sparseml/pytorch/torchvision/train.py b/src/sparseml/pytorch/torchvision/train.py index e896a0baab1..b9b1571002c 100644 --- a/src/sparseml/pytorch/torchvision/train.py +++ b/src/sparseml/pytorch/torchvision/train.py @@ -52,7 +52,6 @@ from sparseml.pytorch.utils.model import load_model from sparsezoo import Model - _LOGGER = logging.getLogger(__name__) @@ -537,8 +536,9 @@ def collate_fn(batch): 
optimizer.load_state_dict(checkpoint["optimizer"]) else: warnings.warn( - "Optim state dict not loaded cause `--resume` was set to " - f"{args.resume}" + "Optimizer state dict not loaded from checkpoint. Unless run is " + "resumed with the --resume arg, the optimizer will start from a " + "fresh state" ) if model_ema and "model_ema" in checkpoint: model_ema.load_state_dict(checkpoint["model_ema"]) @@ -1040,7 +1040,7 @@ def new_func(*args, **kwargs): is_flag=True, default=False, help="Cache the datasets for quicker initialization. " - "It also serializes the transforms", + "It also serializes the transforms", ) @click.option("--sync-bn", is_flag=True, default=False, help="Use sync batch norm") @click.option("--test-only", is_flag=True, default=False, help="Only test the model") @@ -1146,15 +1146,15 @@ def new_func(*args, **kwargs): default=1, type=int, help="Save the best validation result after the given " - "epoch completes until the end of training", + "epoch completes until the end of training", ) @click.option( "--distill-teacher", default=None, type=str, help="Teacher model for distillation (a trained image classification model)" - " can be set to 'self' for self-distillation and 'disable' to switch-off" - " distillation, additionally can also take in a SparseZoo stub", + " can be set to 'self' for self-distillation and 'disable' to switch-off" + " distillation, additionally can also take in a SparseZoo stub", ) @click.option( "--pretrained-teacher-dataset", From 960db507007eb9854eb81f8828f2353e4497742a Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Thu, 2 Feb 2023 11:02:03 -0500 Subject: [PATCH 5/5] Style --- src/sparseml/pytorch/torchvision/train.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/sparseml/pytorch/torchvision/train.py b/src/sparseml/pytorch/torchvision/train.py index b9b1571002c..d80ad837a30 100644 --- a/src/sparseml/pytorch/torchvision/train.py +++ b/src/sparseml/pytorch/torchvision/train.py @@ -52,6 +52,7 @@ from sparseml.pytorch.utils.model import load_model from sparsezoo import Model + _LOGGER = logging.getLogger(__name__) @@ -1040,7 +1041,7 @@ def new_func(*args, **kwargs): is_flag=True, default=False, help="Cache the datasets for quicker initialization. " - "It also serializes the transforms", + "It also serializes the transforms", ) @click.option("--sync-bn", is_flag=True, default=False, help="Use sync batch norm") @click.option("--test-only", is_flag=True, default=False, help="Only test the model") @@ -1146,15 +1147,15 @@ def new_func(*args, **kwargs): default=1, type=int, help="Save the best validation result after the given " - "epoch completes until the end of training", + "epoch completes until the end of training", ) @click.option( "--distill-teacher", default=None, type=str, help="Teacher model for distillation (a trained image classification model)" - " can be set to 'self' for self-distillation and 'disable' to switch-off" - " distillation, additionally can also take in a SparseZoo stub", + " can be set to 'self' for self-distillation and 'disable' to switch-off" + " distillation, additionally can also take in a SparseZoo stub", ) @click.option( "--pretrained-teacher-dataset",