From aa88a89eb8453519aa0055962e6bea83c0a5c049 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Wed, 10 Feb 2021 17:29:43 -0500 Subject: [PATCH 1/8] rwightman/pytorch-image-models integration --- examples/timm-sparseml/README.md | 78 +++ examples/timm-sparseml/main.py | 843 +++++++++++++++++++++++++++++++ 2 files changed, 921 insertions(+) create mode 100644 examples/timm-sparseml/README.md create mode 100755 examples/timm-sparseml/main.py diff --git a/examples/timm-sparseml/README.md b/examples/timm-sparseml/README.md new file mode 100644 index 00000000000..5a5ffc511a7 --- /dev/null +++ b/examples/timm-sparseml/README.md @@ -0,0 +1,78 @@ + + +# SparseML-rwightman/pytorch-image-models integration +This directory provides a training script for the popular +[rwightman/pytorch-image-models](https://github.com/rwightman/pytorch-image-models) +repository also known as [timm](https://pypi.org/project/timm/). + +Using this integration, you will be able to apply SparseML optimizations +to the powerful training flows of the pytorch-image-models repository. + +Some of the tasks you can perform using this integration include, but are not limited to: +* model pruning +* quantization-aware-training +* sparse quantization-aware-training + +## Installation +Both requirements can be installed via `pip install ` or can be cloned +and installed from their respective source repositories. + +```bash +pip install git+https://github.com/rwightman/pytorch-image-models.git +pip install spraseml[torchvision] +``` + + +## Script +`examples/timm-sparseml/main.py` modifies +[`train.py`](https://github.com/rwightman/pytorch-image-models/blob/master/train.py) +from pytorch-image-models to include a `sparseml-recipe-path` argument +to run SparseML optimizations with. Running the script will +follow the normal pytorch-image-models training flow with the given +SparseML optimizations enabled. + +Some considerations: + +* `--sparseml-recipe-path` is a required parameter +* `--epochs` will now be overridden by the epochs set in the SparseML recipe +* All learning rate parameters and schedulers from the original script will be overritten by learning rate modifiers in the SparseML recipe +* Modifiers will log their outputs to the console as well as to a tensorboard file +* After training is complete, the final model will be exported to ONNX using SparseML + +You can learn how to build or download a recipe using the +[SparseML](https://github.com/neuralmagic/sparseml) +or [SparseZoo](https://github.com/neuralmagic/sparsezoo) +documentation, or export one with [Sparsify](https://github.com/neuralmagic/sparsify). + +Documentation on the original script can be found +[here](https://rwightman.github.io/pytorch-image-models/scripts/). +The latest commit hash that `main.py` is included in the docstring. + + +#### Example Command +```bash +python examples/timm-sparseml/main.py \ + /PATH/TO/DATASET/imagenet/ \ + --sparseml-recipe-path /PATH/TO/RECIPE/recipe.yaml \ + --dataset imagenet \ + --batch-size 64 \ + --remode pixel --reprob 0.6 --smoothing 0.1 \ + --output models/optimized \ + --model resnet50 \ + --initial-checkpoint PATH/TO/CHECKPOINT/model.pth \ + --workers 8 \ +``` diff --git a/examples/timm-sparseml/main.py b/examples/timm-sparseml/main.py new file mode 100755 index 00000000000..6d2a44de9fc --- /dev/null +++ b/examples/timm-sparseml/main.py @@ -0,0 +1,843 @@ +#!/usr/bin/env python + +# neuralmagic: no copyright +# flake8: noqa +# fmt: off +# isort: skip_file + +""" +Integration between https://github.com/rwightman/pytorch-image-models and SparseML + +This script is adapted from https://github.com/rwightman/pytorch-image-models/blob/master/train.py +to apply a SparseML recipe from the required `--sparseml-recipe-path` argument. +Integration lines are preceded by commend blocks. Run with `--help` for help printout, +more information can be found in the readme file. + +Latest pytorch-image-models commit this script is based on: aaa715b + +Original doc-string: + +ImageNet Training Script + +This is intended to be a lean and easily modifiable ImageNet training script that reproduces ImageNet +training results with some of the latest networks and training techniques. It favours canonical PyTorch +and standard Python style over trying to be able to 'do it all.' That said, it offers quite a few speed +and training result improvements over the usual PyTorch example scripts. Repurpose as you see fit. + +This script was started from an early version of the PyTorch ImageNet example +(https://github.com/pytorch/examples/tree/master/imagenet) + +NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples +(https://github.com/NVIDIA/apex/tree/master/examples/imagenet) + +Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman) +""" +import argparse +import time +import yaml +import os +import logging +from collections import OrderedDict +from contextlib import suppress +from datetime import datetime + +import torch +import torch.nn as nn +import torchvision.utils +from torch.nn.parallel import DistributedDataParallel as NativeDDP + +from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset +from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model +from timm.utils import * +from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy +from timm.optim import create_optimizer +from timm.scheduler import create_scheduler +from timm.utils import ApexScaler, NativeScaler + +from sparseml.pytorch.optim import ScheduledModifierManager, ScheduledOptimizer +from sparseml.pytorch.utils import ModuleExporter, PythonLogger, TensorBoardLogger + +try: + from apex import amp + from apex.parallel import DistributedDataParallel as ApexDDP + from apex.parallel import convert_syncbn_model + has_apex = True +except ImportError: + has_apex = False + +has_native_amp = False +try: + if getattr(torch.cuda.amp, 'autocast') is not None: + has_native_amp = True +except AttributeError: + pass + +torch.backends.cudnn.benchmark = True +_logger = logging.getLogger('train') + +# The first arg parser parses out only the --config argument, this argument is used to +# load a yaml file containing key-values that override the defaults for the main parser below +config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False) +parser.add_argument('-c', '--config', default='', type=str, metavar='FILE', + help='YAML config file specifying default arguments') + + +parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') + +# SparseML Recipe parameter +parser.add_argument('--sparseml-recipe-path', required=True, type=str, + help='path to a SparseML YAML or markdown recipe file') + +# Dataset / Model parameters +parser.add_argument('data_dir', metavar='DIR', + help='path to dataset') +parser.add_argument('--dataset', '-d', metavar='NAME', default='', + help='dataset type (default: ImageFolder/ImageTar if empty)') +parser.add_argument('--train-split', metavar='NAME', default='train', + help='dataset train split (default: train)') +parser.add_argument('--val-split', metavar='NAME', default='validation', + help='dataset validation split (default: validation)') +parser.add_argument('--model', default='resnet101', type=str, metavar='MODEL', + help='Name of model to train (default: "countception"') +parser.add_argument('--pretrained', action='store_true', default=False, + help='Start with pretrained version of specified network (if avail)') +parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH', + help='Initialize model from this checkpoint (default: none)') +parser.add_argument('--resume', default='', type=str, metavar='PATH', + help='Resume full model and optimizer state from checkpoint (default: none)') +parser.add_argument('--no-resume-opt', action='store_true', default=False, + help='prevent resume of optimizer state when resuming model') +parser.add_argument('--num-classes', type=int, default=None, metavar='N', + help='number of label classes (Model default if None)') +parser.add_argument('--gp', default=None, type=str, metavar='POOL', + help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.') +parser.add_argument('--img-size', type=int, default=None, metavar='N', + help='Image patch size (default: None => model default)') +parser.add_argument('--input-size', default=None, nargs=3, type=int, + metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty') +parser.add_argument('--crop-pct', default=None, type=float, + metavar='N', help='Input image center crop percent (for validation only)') +parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', + help='Override mean pixel value of dataset') +parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', + help='Override std deviation of of dataset') +parser.add_argument('--interpolation', default='', type=str, metavar='NAME', + help='Image resize interpolation type (overrides model)') +parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N', + help='input batch size for training (default: 32)') +parser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N', + help='ratio of validation batch size to training batch size (default: 1)') + +# Optimizer parameters +parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER', + help='Optimizer (default: "sgd"') +parser.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON', + help='Optimizer Epsilon (default: None, use opt default)') +parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', + help='Optimizer Betas (default: None, use opt default)') +parser.add_argument('--momentum', type=float, default=0.9, metavar='M', + help='Optimizer momentum (default: 0.9)') +parser.add_argument('--weight-decay', type=float, default=0.0001, + help='weight decay (default: 0.0001)') +parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM', + help='Clip gradient norm (default: None, no clipping)') + + + +# Learning rate schedule parameters +#################################################################################### +# SparseML Integration, hide lr args, they will be overridden by SparseML +#################################################################################### +# parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER', +# help='LR scheduler (default: "step"') +parser.add_argument('--lr', type=float, default=0.01, metavar='LR', + help=argparse.SUPPRESS) # hide from help text +# parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', +# help='learning rate noise on/off epoch percentages') +# parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', +# help='learning rate noise limit percent (default: 0.67)') +# parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', +# help='learning rate noise std-dev (default: 1.0)') +# parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT', +# help='learning rate cycle len multiplier (default: 1.0)') +# parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N', +# help='learning rate cycle limit') +# parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR', +# help='warmup learning rate (default: 0.0001)') +# parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', +# help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') +# # parser.add_argument('--epochs', type=int, default=200, metavar='N', +# # help='number of epochs to train (default: 2)') +# parser.add_argument('--start-epoch', default=None, type=int, metavar='N', +# help='manual epoch number (useful on restarts)') +# parser.add_argument('--decay-epochs', type=float, default=30, metavar='N', +# help='epoch interval to decay LR') +# parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N', +# help='epochs to warmup LR, if scheduler supports') +# parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', +# help='epochs to cooldown LR at min_lr, after cyclic schedule ends') +# parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', +# help='patience epochs for Plateau LR scheduler (default: 10') +# parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', +# help='LR decay rate (default: 0.1)') + +#################################################################################### +# End SparseML integration hide LR args +#################################################################################### + +# Augmentation & regularization parameters +parser.add_argument('--no-aug', action='store_true', default=False, + help='Disable all training augmentation, override other train aug args') +parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT', + help='Random resize scale (default: 0.08 1.0)') +parser.add_argument('--ratio', type=float, nargs='+', default=[3./4., 4./3.], metavar='RATIO', + help='Random resize aspect ratio (default: 0.75 1.33)') +parser.add_argument('--hflip', type=float, default=0.5, + help='Horizontal flip training aug probability') +parser.add_argument('--vflip', type=float, default=0., + help='Vertical flip training aug probability') +parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', + help='Color jitter factor (default: 0.4)') +parser.add_argument('--aa', type=str, default=None, metavar='NAME', + help='Use AutoAugment policy. "v0" or "original". (default: None)'), +parser.add_argument('--aug-splits', type=int, default=0, + help='Number of augmentation splits (default: 0, valid: 0 or >=2)') +parser.add_argument('--jsd', action='store_true', default=False, + help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.') +parser.add_argument('--reprob', type=float, default=0., metavar='PCT', + help='Random erase prob (default: 0.)') +parser.add_argument('--remode', type=str, default='const', + help='Random erase mode (default: "const")') +parser.add_argument('--recount', type=int, default=1, + help='Random erase count (default: 1)') +parser.add_argument('--resplit', action='store_true', default=False, + help='Do not random erase first (clean) augmentation split') +parser.add_argument('--mixup', type=float, default=0.0, + help='mixup alpha, mixup enabled if > 0. (default: 0.)') +parser.add_argument('--cutmix', type=float, default=0.0, + help='cutmix alpha, cutmix enabled if > 0. (default: 0.)') +parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None, + help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') +parser.add_argument('--mixup-prob', type=float, default=1.0, + help='Probability of performing mixup or cutmix when either/both is enabled') +parser.add_argument('--mixup-switch-prob', type=float, default=0.5, + help='Probability of switching to cutmix when both mixup and cutmix enabled') +parser.add_argument('--mixup-mode', type=str, default='batch', + help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"') +parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N', + help='Turn off mixup after this epoch, disabled if 0 (default: 0)') +parser.add_argument('--smoothing', type=float, default=0.1, + help='Label smoothing (default: 0.1)') +parser.add_argument('--train-interpolation', type=str, default='random', + help='Training interpolation (random, bilinear, bicubic default: "random")') +parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', + help='Dropout rate (default: 0.)') +parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT', + help='Drop connect rate, DEPRECATED, use drop-path (default: None)') +parser.add_argument('--drop-path', type=float, default=None, metavar='PCT', + help='Drop path rate (default: None)') +parser.add_argument('--drop-block', type=float, default=None, metavar='PCT', + help='Drop block rate (default: None)') + +# Batch norm parameters (only works with gen_efficientnet based models currently) +parser.add_argument('--bn-tf', action='store_true', default=False, + help='Use Tensorflow BatchNorm defaults for models that support it (default: False)') +parser.add_argument('--bn-momentum', type=float, default=None, + help='BatchNorm momentum override (if not None)') +parser.add_argument('--bn-eps', type=float, default=None, + help='BatchNorm epsilon override (if not None)') +parser.add_argument('--sync-bn', action='store_true', + help='Enable NVIDIA Apex or Torch synchronized BatchNorm.') +parser.add_argument('--dist-bn', type=str, default='', + help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")') +parser.add_argument('--split-bn', action='store_true', + help='Enable separate BN layers per augmentation split.') + +# Model Exponential Moving Average +parser.add_argument('--model-ema', action='store_true', default=False, + help='Enable tracking moving average of model weights') +parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, + help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.') +parser.add_argument('--model-ema-decay', type=float, default=0.9998, + help='decay factor for model weights moving average (default: 0.9998)') + +# Misc +parser.add_argument('--seed', type=int, default=42, metavar='S', + help='random seed (default: 42)') +parser.add_argument('--log-interval', type=int, default=50, metavar='N', + help='how many batches to wait before logging training status') +parser.add_argument('--recovery-interval', type=int, default=0, metavar='N', + help='how many batches to wait before writing recovery checkpoint') +parser.add_argument('--checkpoint-hist', type=int, default=10, metavar='N', + help='number of checkpoints to keep (default: 10)') +parser.add_argument('-j', '--workers', type=int, default=4, metavar='N', + help='how many training processes to use (default: 1)') +parser.add_argument('--save-images', action='store_true', default=False, + help='save images of input bathes every log interval for debugging') +parser.add_argument('--amp', action='store_true', default=False, + help='use NVIDIA Apex AMP or Native AMP for mixed precision training') +parser.add_argument('--apex-amp', action='store_true', default=False, + help='Use NVIDIA Apex AMP mixed precision') +parser.add_argument('--native-amp', action='store_true', default=False, + help='Use Native Torch AMP mixed precision') +parser.add_argument('--channels-last', action='store_true', default=False, + help='Use channels_last memory layout') +parser.add_argument('--pin-mem', action='store_true', default=False, + help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') +parser.add_argument('--no-prefetcher', action='store_true', default=False, + help='disable fast prefetcher') +parser.add_argument('--output', default='', type=str, metavar='PATH', + help='path to output folder (default: none, current dir)') +parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC', + help='Best metric (default: "top1"') +parser.add_argument('--tta', type=int, default=0, metavar='N', + help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)') +parser.add_argument("--local_rank", default=0, type=int) +parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False, + help='use the multi-epochs-loader to save time at the beginning of every epoch') +parser.add_argument('--torchscript', dest='torchscript', action='store_true', + help='convert model torchscript for inference') + + +def _parse_args(): + # Do we have a config file to parse? + args_config, remaining = config_parser.parse_known_args() + if args_config.config: + with open(args_config.config, 'r') as f: + cfg = yaml.safe_load(f) + parser.set_defaults(**cfg) + + # The main arg parser parses the rest of the args, the usual + # defaults will have been overridden if config file specified. + args = parser.parse_args(remaining) + + # Cache the args as a text string to save them in the output dir later + args_text = yaml.safe_dump(args.__dict__, default_flow_style=False) + return args, args_text + + +def main(): + setup_default_logging() + args, args_text = _parse_args() + + args.prefetcher = not args.no_prefetcher + args.distributed = False + if 'WORLD_SIZE' in os.environ: + args.distributed = int(os.environ['WORLD_SIZE']) > 1 + args.device = 'cuda:0' + args.world_size = 1 + args.rank = 0 # global rank + if args.distributed: + args.device = 'cuda:%d' % args.local_rank + torch.cuda.set_device(args.local_rank) + torch.distributed.init_process_group(backend='nccl', init_method='env://') + args.world_size = torch.distributed.get_world_size() + args.rank = torch.distributed.get_rank() + _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' + % (args.rank, args.world_size)) + else: + _logger.info('Training with a single process on 1 GPUs.') + assert args.rank >= 0 + + # resolve AMP arguments based on PyTorch / Apex availability + use_amp = None + if args.amp: + # for backwards compat, `--amp` arg tries apex before native amp + if has_apex: + args.apex_amp = True + elif has_native_amp: + args.native_amp = True + if args.apex_amp and has_apex: + use_amp = 'apex' + elif args.native_amp and has_native_amp: + use_amp = 'native' + elif args.apex_amp or args.native_amp: + _logger.warning("Neither APEX or native Torch AMP is available, using float32. " + "Install NVIDA apex or upgrade to PyTorch 1.6") + + torch.manual_seed(args.seed + args.rank) + + model = create_model( + args.model, + pretrained=args.pretrained, + num_classes=args.num_classes, + drop_rate=args.drop, + drop_connect_rate=args.drop_connect, # DEPRECATED, use drop_path + drop_path_rate=args.drop_path, + drop_block_rate=args.drop_block, + global_pool=args.gp, + bn_tf=args.bn_tf, + bn_momentum=args.bn_momentum, + bn_eps=args.bn_eps, + scriptable=args.torchscript, + checkpoint_path=args.initial_checkpoint) + if args.num_classes is None: + assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.' + args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly + + if args.local_rank == 0: + _logger.info('Model %s created, param count: %d' % + (args.model, sum([m.numel() for m in model.parameters()]))) + + data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) + + # setup augmentation batch splits for contrastive loss or split bn + num_aug_splits = 0 + if args.aug_splits > 0: + assert args.aug_splits > 1, 'A split of 1 makes no sense' + num_aug_splits = args.aug_splits + + # enable split bn (separate bn stats per batch-portion) + if args.split_bn: + assert num_aug_splits > 1 or args.resplit + model = convert_splitbn_model(model, max(num_aug_splits, 2)) + + # move model to GPU, enable channels last layout if set + model.cuda() + if args.channels_last: + model = model.to(memory_format=torch.channels_last) + + # setup synchronized BatchNorm for distributed training + if args.distributed and args.sync_bn: + assert not args.split_bn + if has_apex and use_amp != 'native': + # Apex SyncBN preferred unless native amp is activated + model = convert_syncbn_model(model) + else: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + if args.local_rank == 0: + _logger.info( + 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using ' + 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.') + + if args.torchscript: + assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model' + assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model' + model = torch.jit.script(model) + + optimizer = create_optimizer(args, model) + + # setup automatic mixed-precision (AMP) loss scaling and op casting + amp_autocast = suppress # do nothing + loss_scaler = None + if use_amp == 'apex': + model, optimizer = amp.initialize(model, optimizer, opt_level='O1') + loss_scaler = ApexScaler() + if args.local_rank == 0: + _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.') + elif use_amp == 'native': + amp_autocast = torch.cuda.amp.autocast + loss_scaler = NativeScaler() + if args.local_rank == 0: + _logger.info('Using native Torch AMP. Training in mixed precision.') + else: + if args.local_rank == 0: + _logger.info('AMP not enabled. Training in float32.') + + # optionally resume from a checkpoint + resume_epoch = None + if args.resume: + resume_epoch = resume_checkpoint( + model, args.resume, + optimizer=None if args.no_resume_opt else optimizer, + loss_scaler=None if args.no_resume_opt else loss_scaler, + log_info=args.local_rank == 0) + + # setup exponential moving average of model weights, SWA could be used here too + model_ema = None + if args.model_ema: + # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper + model_ema = ModelEmaV2( + model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None) + if args.resume: + load_checkpoint(model_ema.module, args.resume, use_ema=True) + + # setup distributed training + if args.distributed: + if has_apex and use_amp != 'native': + # Apex DDP preferred unless native amp is activated + if args.local_rank == 0: + _logger.info("Using NVIDIA APEX DistributedDataParallel.") + model = ApexDDP(model, delay_allreduce=True) + else: + if args.local_rank == 0: + _logger.info("Using native Torch DistributedDataParallel.") + model = NativeDDP(model, device_ids=[args.local_rank]) # can use device str in Torch >= 1.1 + # NOTE: EMA model does not need to be wrapped by DDP + + # setup learning rate schedule and starting epoch + #################################################################################### + # SparseML integration suppress lr_scheduler + # let SparseML recipe handle LR schedule + # set epoch range later using recipe + #################################################################################### + lr_scheduler = None + # lr_scheduler, num_epochs = create_scheduler(args, optimizer) + # start_epoch = 0 + # if args.start_epoch is not None: + # # a specified start_epoch will always override the resume epoch + # start_epoch = args.start_epoch + # elif resume_epoch is not None: + # start_epoch = resume_epoch + # if lr_scheduler is not None and start_epoch > 0: + # lr_scheduler.step(start_epoch) + #################################################################################### + # End SparseML integration suppress lr_scheduler + #################################################################################### + + # create the train and eval datasets + dataset_train = create_dataset( + args.dataset, root=args.data_dir, split=args.train_split, is_training=True, batch_size=args.batch_size) + dataset_eval = create_dataset( + args.dataset, root=args.data_dir, split=args.val_split, is_training=False, batch_size=args.batch_size) + + # setup mixup / cutmix + collate_fn = None + mixup_fn = None + mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None + if mixup_active: + mixup_args = dict( + mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, + prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, + label_smoothing=args.smoothing, num_classes=args.num_classes) + if args.prefetcher: + assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup) + collate_fn = FastCollateMixup(**mixup_args) + else: + mixup_fn = Mixup(**mixup_args) + + # wrap dataset in AugMix helper + if num_aug_splits > 1: + dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits) + + # create data loaders w/ augmentation pipeiine + train_interpolation = args.train_interpolation + if args.no_aug or not train_interpolation: + train_interpolation = data_config['interpolation'] + loader_train = create_loader( + dataset_train, + input_size=data_config['input_size'], + batch_size=args.batch_size, + is_training=True, + use_prefetcher=args.prefetcher, + no_aug=args.no_aug, + re_prob=args.reprob, + re_mode=args.remode, + re_count=args.recount, + re_split=args.resplit, + scale=args.scale, + ratio=args.ratio, + hflip=args.hflip, + vflip=args.vflip, + color_jitter=args.color_jitter, + auto_augment=args.aa, + num_aug_splits=num_aug_splits, + interpolation=train_interpolation, + mean=data_config['mean'], + std=data_config['std'], + num_workers=args.workers, + distributed=args.distributed, + collate_fn=collate_fn, + pin_memory=args.pin_mem, + use_multi_epochs_loader=args.use_multi_epochs_loader + ) + + loader_eval = create_loader( + dataset_eval, + input_size=data_config['input_size'], + batch_size=args.validation_batch_size_multiplier * args.batch_size, + is_training=False, + use_prefetcher=args.prefetcher, + interpolation=data_config['interpolation'], + mean=data_config['mean'], + std=data_config['std'], + num_workers=args.workers, + distributed=args.distributed, + crop_pct=data_config['crop_pct'], + pin_memory=args.pin_mem, + ) + + # setup loss function + if args.jsd: + assert num_aug_splits > 1 # JSD only valid with aug splits set + train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda() + elif mixup_active: + # smoothing is handled with mixup target transform + train_loss_fn = SoftTargetCrossEntropy().cuda() + elif args.smoothing: + train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda() + else: + train_loss_fn = nn.CrossEntropyLoss().cuda() + validate_loss_fn = nn.CrossEntropyLoss().cuda() + + # setup checkpoint saver and eval metric tracking + eval_metric = args.eval_metric + best_metric = None + best_epoch = None + saver = None + output_dir = '' + if args.local_rank == 0: + output_base = args.output if args.output else './output' + exp_name = '-'.join([ + datetime.now().strftime("%Y%m%d-%H%M%S"), + args.model, + str(data_config['input_size'][-1]) + ]) + output_dir = get_outdir(output_base, 'train', exp_name) + decreasing = True if eval_metric == 'loss' else False + saver = CheckpointSaver( + model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler, + checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=args.checkpoint_hist) + with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: + f.write(args_text) + + #################################################################################### + # Start SparseML Integration + #################################################################################### + sparseml_loggers = ( + [PythonLogger(), TensorBoardLogger(log_path=output_dir)] + if output_dir + else None + ) + manager = ScheduledModifierManager.from_yaml(args.sparseml_recipe_path) + optimizer = ScheduledOptimizer( + optimizer, + model, + manager, + steps_per_epoch=len(loader_train), + loggers=sparseml_loggers + ) + start_epoch = manager.min_epochs # override min_epochs + num_epochs = manager.max_epochs or num_epochs # override num_epochs + #################################################################################### + # End SparseML Integration + #################################################################################### + + if args.local_rank == 0: + _logger.info('Scheduled epochs: {}'.format(num_epochs)) + + try: + for epoch in range(start_epoch, num_epochs): + if args.distributed and hasattr(loader_train.sampler, 'set_epoch'): + loader_train.sampler.set_epoch(epoch) + + train_metrics = train_one_epoch( + epoch, model, loader_train, optimizer, train_loss_fn, args, + lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, + amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn) + + if args.distributed and args.dist_bn in ('broadcast', 'reduce'): + if args.local_rank == 0: + _logger.info("Distributing BatchNorm running means and vars") + distribute_bn(model, args.world_size, args.dist_bn == 'reduce') + + eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast) + + if model_ema is not None and not args.model_ema_force_cpu: + if args.distributed and args.dist_bn in ('broadcast', 'reduce'): + distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce') + ema_eval_metrics = validate( + model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)') + eval_metrics = ema_eval_metrics + + if lr_scheduler is not None: + # step LR for next epoch + lr_scheduler.step(epoch + 1, eval_metrics[eval_metric]) + + update_summary( + epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'), + write_header=best_metric is None) + + if saver is not None: + # save proper checkpoint with eval metric + save_metric = eval_metrics[eval_metric] + best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric) + + ################################################################################# + # Start SparseML ONNX Export + ################################################################################# + if output_dir: + _logger.info( + "training complete, exporting ONNX to {}/model.onnx".format(output_dir) + ) + exporter = ModuleExporter(model, output_dir) + exporter.export_onnx(torch.randn((1, *data_config["input_size"]))) + ################################################################################# + # End SparseML ONNX Export + ################################################################################# + + except KeyboardInterrupt: + pass + if best_metric is not None: + _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch)) + + +def train_one_epoch( + epoch, model, loader, optimizer, loss_fn, args, + lr_scheduler=None, saver=None, output_dir='', amp_autocast=suppress, + loss_scaler=None, model_ema=None, mixup_fn=None): + + if args.mixup_off_epoch and epoch >= args.mixup_off_epoch: + if args.prefetcher and loader.mixup_enabled: + loader.mixup_enabled = False + elif mixup_fn is not None: + mixup_fn.mixup_enabled = False + + second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order + batch_time_m = AverageMeter() + data_time_m = AverageMeter() + losses_m = AverageMeter() + + model.train() + + end = time.time() + last_idx = len(loader) - 1 + num_updates = epoch * len(loader) + for batch_idx, (input, target) in enumerate(loader): + last_batch = batch_idx == last_idx + data_time_m.update(time.time() - end) + if not args.prefetcher: + input, target = input.cuda(), target.cuda() + if mixup_fn is not None: + input, target = mixup_fn(input, target) + if args.channels_last: + input = input.contiguous(memory_format=torch.channels_last) + + with amp_autocast(): + output = model(input) + loss = loss_fn(output, target) + + if not args.distributed: + losses_m.update(loss.item(), input.size(0)) + + optimizer.zero_grad() + if loss_scaler is not None: + loss_scaler( + loss, optimizer, clip_grad=args.clip_grad, parameters=model.parameters(), create_graph=second_order) + else: + loss.backward(create_graph=second_order) + if args.clip_grad is not None: + torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad) + optimizer.step() + + if model_ema is not None: + model_ema.update(model) + + torch.cuda.synchronize() + num_updates += 1 + batch_time_m.update(time.time() - end) + if last_batch or batch_idx % args.log_interval == 0: + lrl = [param_group['lr'] for param_group in optimizer.param_groups] + lr = sum(lrl) / len(lrl) + + if args.distributed: + reduced_loss = reduce_tensor(loss.data, args.world_size) + losses_m.update(reduced_loss.item(), input.size(0)) + + if args.local_rank == 0: + _logger.info( + 'Train: {} [{:>4d}/{} ({:>3.0f}%)] ' + 'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) ' + 'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s ' + '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) ' + 'LR: {lr:.3e} ' + 'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format( + epoch, + batch_idx, len(loader), + 100. * batch_idx / last_idx, + loss=losses_m, + batch_time=batch_time_m, + rate=input.size(0) * args.world_size / batch_time_m.val, + rate_avg=input.size(0) * args.world_size / batch_time_m.avg, + lr=lr, + data_time=data_time_m)) + + if args.save_images and output_dir: + torchvision.utils.save_image( + input, + os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx), + padding=0, + normalize=True) + + if saver is not None and args.recovery_interval and ( + last_batch or (batch_idx + 1) % args.recovery_interval == 0): + saver.save_recovery(epoch, batch_idx=batch_idx) + + if lr_scheduler is not None: + lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg) + + end = time.time() + # end for + + if hasattr(optimizer, 'sync_lookahead'): + optimizer.sync_lookahead() + + return OrderedDict([('loss', losses_m.avg)]) + + +def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''): + batch_time_m = AverageMeter() + losses_m = AverageMeter() + top1_m = AverageMeter() + top5_m = AverageMeter() + + model.eval() + + end = time.time() + last_idx = len(loader) - 1 + with torch.no_grad(): + for batch_idx, (input, target) in enumerate(loader): + last_batch = batch_idx == last_idx + if not args.prefetcher: + input = input.cuda() + target = target.cuda() + if args.channels_last: + input = input.contiguous(memory_format=torch.channels_last) + + with amp_autocast(): + output = model(input) + if isinstance(output, (tuple, list)): + output = output[0] + + # augmentation reduction + reduce_factor = args.tta + if reduce_factor > 1: + output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2) + target = target[0:target.size(0):reduce_factor] + + loss = loss_fn(output, target) + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + + if args.distributed: + reduced_loss = reduce_tensor(loss.data, args.world_size) + acc1 = reduce_tensor(acc1, args.world_size) + acc5 = reduce_tensor(acc5, args.world_size) + else: + reduced_loss = loss.data + + torch.cuda.synchronize() + + losses_m.update(reduced_loss.item(), input.size(0)) + top1_m.update(acc1.item(), output.size(0)) + top5_m.update(acc5.item(), output.size(0)) + + batch_time_m.update(time.time() - end) + end = time.time() + if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0): + log_name = 'Test' + log_suffix + _logger.info( + '{0}: [{1:>4d}/{2}] ' + 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) ' + 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) ' + 'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) ' + 'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format( + log_name, batch_idx, last_idx, batch_time=batch_time_m, + loss=losses_m, top1=top1_m, top5=top5_m)) + + metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)]) + + return metrics + + +if __name__ == '__main__': + main() From 587514c36b3231c4a5e5f8b1f547f5a6eb081c24 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Sat, 13 Feb 2021 17:31:27 -0500 Subject: [PATCH 2/8] sparsezoo recipes and weights integration --- examples/timm-sparseml/README.md | 59 +++++++++++++++++++++++--- examples/timm-sparseml/main.py | 72 ++++++++++++++++++++++++++++---- 2 files changed, 118 insertions(+), 13 deletions(-) diff --git a/examples/timm-sparseml/README.md b/examples/timm-sparseml/README.md index 5a5ffc511a7..67270abf25b 100644 --- a/examples/timm-sparseml/README.md +++ b/examples/timm-sparseml/README.md @@ -15,7 +15,7 @@ limitations under the License. --> # SparseML-rwightman/pytorch-image-models integration -This directory provides a training script for the popular +This directory provides a SparseML integrated training script for the popular [rwightman/pytorch-image-models](https://github.com/rwightman/pytorch-image-models) repository also known as [timm](https://pypi.org/project/timm/). @@ -26,14 +26,15 @@ Some of the tasks you can perform using this integration include, but are not li * model pruning * quantization-aware-training * sparse quantization-aware-training +* sparse transfer learning ## Installation -Both requirements can be installed via `pip install ` or can be cloned +Both requirements can be installed via `pip` or can be cloned and installed from their respective source repositories. ```bash pip install git+https://github.com/rwightman/pytorch-image-models.git -pip install spraseml[torchvision] +pip install sparseml[torchvision] ``` @@ -41,7 +42,23 @@ pip install spraseml[torchvision] `examples/timm-sparseml/main.py` modifies [`train.py`](https://github.com/rwightman/pytorch-image-models/blob/master/train.py) from pytorch-image-models to include a `sparseml-recipe-path` argument -to run SparseML optimizations with. Running the script will +to run SparseML optimizations with. This can be a file path to a local +SparseML recipe or a SparseZoo model stub prefixed by `zoo:` such as +`zoo:cv-classification/resnet_v1-50/pytorch-rwightman/imagenet-augmented/pruned_quant-aggressive`. + +Additionally, for sparse transfer learning, the flag `--sparse-transfer-learn` +was added. Running the script with this flag will add modifiers to the given +recipe that will keep the base sparsity constant during training, allowing +the model to learn the new dataset while keeping the same optimized structure. +If a SparseZoo recipe path is provided with sparse transfer learning enabled, +then the the model's specific "transfer" recipe will be loaded instead. + +To load the base weights for a SparseZoo recipe as the initial checkpoint, set +`--initial-checkpoint` to `zoo`. To use the weights of a SparseZoo model as the +initial checkpoint, pass that model's SparseZoo stub prefixed by `zoo:` to the +`--initial-checkpoint` argument. + +Running the script will follow the normal pytorch-image-models training flow with the given SparseML optimizations enabled. @@ -60,19 +77,49 @@ documentation, or export one with [Sparsify](https://github.com/neuralmagic/spar Documentation on the original script can be found [here](https://rwightman.github.io/pytorch-image-models/scripts/). -The latest commit hash that `main.py` is included in the docstring. +The latest commit hash that `main.py` is based on is included in the docstring. #### Example Command +Training from a local recipe and checkpoint ```bash python examples/timm-sparseml/main.py \ /PATH/TO/DATASET/imagenet/ \ --sparseml-recipe-path /PATH/TO/RECIPE/recipe.yaml \ + --initial-checkpoint PATH/TO/CHECKPOINT/model.pth \ + --dataset imagenet \ + --batch-size 64 \ + --remode pixel --reprob 0.6 --smoothing 0.1 \ + --output models/optimized \ + --model resnet50 \ + --workers 8 \ +``` + +Training from a local recipe and SparseZoo checkpoint +```bash +python examples/timm-sparseml/main.py \ + /PATH/TO/DATASET/imagenet/ \ + --sparseml-recipe-path /PATH/TO/RECIPE/recipe.yaml \ + --initial-checkpoint zoo:model/stub/path \ --dataset imagenet \ --batch-size 64 \ --remode pixel --reprob 0.6 --smoothing 0.1 \ --output models/optimized \ --model resnet50 \ - --initial-checkpoint PATH/TO/CHECKPOINT/model.pth \ --workers 8 \ ``` + +Training from a SparseZoo recipe and checkpoint with sparse transfer learning enabled +```bash +python examples/timm-sparseml/main.py \ + /PATH/TO/DATASET/imagenet/ \ + --sparseml-recipe-path zoo:model/stub/path \ + --initial-checkpoint zoo \ + --sparse-transfer-learn \ + --dataset imagenet \ + --batch-size 64 \ + --remode pixel --reprob 0.6 --smoothing 0.1 \ + --output models/optimized \ + --model resnet50 \ + --workers 8 \ +``` \ No newline at end of file diff --git a/examples/timm-sparseml/main.py b/examples/timm-sparseml/main.py index 6d2a44de9fc..fbf40aa52d8 100755 --- a/examples/timm-sparseml/main.py +++ b/examples/timm-sparseml/main.py @@ -10,7 +10,7 @@ This script is adapted from https://github.com/rwightman/pytorch-image-models/blob/master/train.py to apply a SparseML recipe from the required `--sparseml-recipe-path` argument. -Integration lines are preceded by commend blocks. Run with `--help` for help printout, +Integration lines are preceded by comment blocks. Run with `--help` for help printout, more information can be found in the readme file. Latest pytorch-image-models commit this script is based on: aaa715b @@ -56,6 +56,7 @@ from sparseml.pytorch.optim import ScheduledModifierManager, ScheduledOptimizer from sparseml.pytorch.utils import ModuleExporter, PythonLogger, TensorBoardLogger +from sparsezoo import Zoo try: from apex import amp @@ -84,9 +85,29 @@ parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') -# SparseML Recipe parameter -parser.add_argument('--sparseml-recipe-path', required=True, type=str, - help='path to a SparseML YAML or markdown recipe file') + +#################################################################################### +# Start SparseML arguments +#################################################################################### +parser.add_argument( + "--sparseml-recipe-path", + required=True, + type=str, + help="path to a SparseML recipe file or a SparseZoo model stub for a recipe to load. " + "SparseZoo stubs should be preceded by 'zoo:'. i.e. '/path/to/local/recipe.yaml', " + "'zoo:zoo/model/stub'" +) +parser.add_argument( + "--sparse-transfer-learn", + action="store_true", + help="Enable sparse transfer learning modifiers to enforce the sparsity " + "if the recipe comes from a local file, modifiers will be added to the manager " + "to hold already sparse layers at the same sparsity level. If the recipe comes " + "from SparseZoo, the 'transfer' recipe for the model will be loaded instead", +) +#################################################################################### +# End SparseML arguments +#################################################################################### # Dataset / Model parameters parser.add_argument('data_dir', metavar='DIR', @@ -102,7 +123,10 @@ parser.add_argument('--pretrained', action='store_true', default=False, help='Start with pretrained version of specified network (if avail)') parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH', - help='Initialize model from this checkpoint (default: none)') + help='Initialize model from this checkpoint (default: none). ' + 'can pass in "zoo" if using a SparseZoo recipe to load that recipes ' + 'base weights, or pass in a SparseZoo model stub, prefixed with "zoo:" to ' + 'load weights directly from SparseZoo') parser.add_argument('--resume', default='', type=str, metavar='PATH', help='Resume full model and optimizer state from checkpoint (default: none)') parser.add_argument('--no-resume-opt', action='store_true', default=False, @@ -357,6 +381,33 @@ def main(): torch.manual_seed(args.seed + args.rank) + #################################################################################### + # Start - SparseML optional load weights from SparseZoo + #################################################################################### + if args.initial_checkpoint == "zoo": + # Load checkpoint from base weights associated with given SparseZoo recipe + if args.sparseml_recipe_path.startswith("zoo:"): + recipe_type = "transfer" if args.sparse_transfer_learn else "original" + args.initial_checkpoint = Zoo.download_recipe_base_framework_files( + args.sparseml_recipe_path, + recipe_type=recipe_type, + extensions=[".pth.tar", ".pth"] + )[0] + else: + raise ValueError( + "Attempting to load weights from SparseZoo recipe, but not given a " + "SparseZoo recipe stub. When initial-checkpoint is set to 'zoo'. " + "sparseml-recipe-path must start with 'zoo:' and be a SparseZoo model " + f"stub. sparseml-recipe-path was set to {args.sparseml_recipe_path}" + ) + elif args.initial_checkpoint.startswith("zoo:"): + # Load weights from a SparseZoo model stub + zoo_model = Zoo.load_model_from_stub(args.initial_checkpoint) + args.initial_checkpoint = zoo_model.download_framework_files(extensions=[".pth"]) + #################################################################################### + # End - SparseML optional load weights from SparseZoo + #################################################################################### + model = create_model( args.model, pretrained=args.pretrained, @@ -599,7 +650,14 @@ def main(): if output_dir else None ) - manager = ScheduledModifierManager.from_yaml(args.sparseml_recipe_path) + # determine recipe type to be used if loading from SparseZoo + if args.sparseml_recipe_path.startswith("zoo:"): + zoo_recipe_type = "transfer" if args.sparse_transfer_learn else "original" + else: + zoo_recipe_type = None + manager = ScheduledModifierManager.from_yaml( + args.sparseml_recipe_path, zoo_recipe_type=zoo_recipe_type + ) optimizer = ScheduledOptimizer( optimizer, model, @@ -658,7 +716,7 @@ def main(): ################################################################################# if output_dir: _logger.info( - "training complete, exporting ONNX to {}/model.onnx".format(output_dir) + f"training complete, exporting ONNX to {output_dir}/model.onnx" ) exporter = ModuleExporter(model, output_dir) exporter.export_onnx(torch.randn((1, *data_config["input_size"]))) From 908b84160d3a9a5e176a6a01948377067f72e2ff Mon Sep 17 00:00:00 2001 From: Benjamin Date: Fri, 19 Feb 2021 14:59:58 -0500 Subject: [PATCH 3/8] splitbn qat fusing, rebase to main, address comments --- .../timm}/README.md | 19 ++- .../timm}/main.py | 142 ++++++++---------- .../pytorch/optim/modifier_quantization.py | 11 +- .../pytorch/optim/quantization/helpers.py | 39 ++++- 4 files changed, 108 insertions(+), 103 deletions(-) rename {examples/timm-sparseml => integrations/timm}/README.md (88%) rename {examples/timm-sparseml => integrations/timm}/main.py (88%) diff --git a/examples/timm-sparseml/README.md b/integrations/timm/README.md similarity index 88% rename from examples/timm-sparseml/README.md rename to integrations/timm/README.md index 67270abf25b..2c20d888be4 100644 --- a/examples/timm-sparseml/README.md +++ b/integrations/timm/README.md @@ -39,9 +39,9 @@ pip install sparseml[torchvision] ## Script -`examples/timm-sparseml/main.py` modifies +`integrations/timm/main.py` modifies [`train.py`](https://github.com/rwightman/pytorch-image-models/blob/master/train.py) -from pytorch-image-models to include a `sparseml-recipe-path` argument +from pytorch-image-models to include a `sparseml-recipe` argument to run SparseML optimizations with. This can be a file path to a local SparseML recipe or a SparseZoo model stub prefixed by `zoo:` such as `zoo:cv-classification/resnet_v1-50/pytorch-rwightman/imagenet-augmented/pruned_quant-aggressive`. @@ -64,9 +64,8 @@ SparseML optimizations enabled. Some considerations: -* `--sparseml-recipe-path` is a required parameter +* `--sparseml-recipe` is a required parameter * `--epochs` will now be overridden by the epochs set in the SparseML recipe -* All learning rate parameters and schedulers from the original script will be overritten by learning rate modifiers in the SparseML recipe * Modifiers will log their outputs to the console as well as to a tensorboard file * After training is complete, the final model will be exported to ONNX using SparseML @@ -83,9 +82,9 @@ The latest commit hash that `main.py` is based on is included in the docstring. #### Example Command Training from a local recipe and checkpoint ```bash -python examples/timm-sparseml/main.py \ +python integrations/timm/main.py \ /PATH/TO/DATASET/imagenet/ \ - --sparseml-recipe-path /PATH/TO/RECIPE/recipe.yaml \ + --sparseml-recipe /PATH/TO/RECIPE/recipe.yaml \ --initial-checkpoint PATH/TO/CHECKPOINT/model.pth \ --dataset imagenet \ --batch-size 64 \ @@ -97,9 +96,9 @@ python examples/timm-sparseml/main.py \ Training from a local recipe and SparseZoo checkpoint ```bash -python examples/timm-sparseml/main.py \ +python integrations/timm/main.py \ /PATH/TO/DATASET/imagenet/ \ - --sparseml-recipe-path /PATH/TO/RECIPE/recipe.yaml \ + --sparseml-recipe /PATH/TO/RECIPE/recipe.yaml \ --initial-checkpoint zoo:model/stub/path \ --dataset imagenet \ --batch-size 64 \ @@ -111,9 +110,9 @@ python examples/timm-sparseml/main.py \ Training from a SparseZoo recipe and checkpoint with sparse transfer learning enabled ```bash -python examples/timm-sparseml/main.py \ +python integrations/timm/main.py \ /PATH/TO/DATASET/imagenet/ \ - --sparseml-recipe-path zoo:model/stub/path \ + --sparseml-recipe zoo:model/stub/path \ --initial-checkpoint zoo \ --sparse-transfer-learn \ --dataset imagenet \ diff --git a/examples/timm-sparseml/main.py b/integrations/timm/main.py similarity index 88% rename from examples/timm-sparseml/main.py rename to integrations/timm/main.py index fbf40aa52d8..49d91d02c07 100755 --- a/examples/timm-sparseml/main.py +++ b/integrations/timm/main.py @@ -9,7 +9,7 @@ Integration between https://github.com/rwightman/pytorch-image-models and SparseML This script is adapted from https://github.com/rwightman/pytorch-image-models/blob/master/train.py -to apply a SparseML recipe from the required `--sparseml-recipe-path` argument. +to apply a SparseML recipe from the required `--sparseml-recipe` argument. Integration lines are preceded by comment blocks. Run with `--help` for help printout, more information can be found in the readme file. @@ -54,9 +54,10 @@ from timm.scheduler import create_scheduler from timm.utils import ApexScaler, NativeScaler -from sparseml.pytorch.optim import ScheduledModifierManager, ScheduledOptimizer +from sparseml.pytorch.optim import ScheduledModifierManager from sparseml.pytorch.utils import ModuleExporter, PythonLogger, TensorBoardLogger from sparsezoo import Zoo +import warnings try: from apex import amp @@ -90,21 +91,13 @@ # Start SparseML arguments #################################################################################### parser.add_argument( - "--sparseml-recipe-path", + "--sparseml-recipe", required=True, type=str, help="path to a SparseML recipe file or a SparseZoo model stub for a recipe to load. " "SparseZoo stubs should be preceded by 'zoo:'. i.e. '/path/to/local/recipe.yaml', " "'zoo:zoo/model/stub'" ) -parser.add_argument( - "--sparse-transfer-learn", - action="store_true", - help="Enable sparse transfer learning modifiers to enforce the sparsity " - "if the recipe comes from a local file, modifiers will be added to the manager " - "to hold already sparse layers at the same sparsity level. If the recipe comes " - "from SparseZoo, the 'transfer' recipe for the model will be loaded instead", -) #################################################################################### # End SparseML arguments #################################################################################### @@ -169,41 +162,38 @@ # Learning rate schedule parameters -#################################################################################### -# SparseML Integration, hide lr args, they will be overridden by SparseML -#################################################################################### -# parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER', -# help='LR scheduler (default: "step"') +parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER', + help='LR scheduler (default: "step"') parser.add_argument('--lr', type=float, default=0.01, metavar='LR', - help=argparse.SUPPRESS) # hide from help text -# parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', -# help='learning rate noise on/off epoch percentages') -# parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', -# help='learning rate noise limit percent (default: 0.67)') -# parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', -# help='learning rate noise std-dev (default: 1.0)') -# parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT', -# help='learning rate cycle len multiplier (default: 1.0)') -# parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N', -# help='learning rate cycle limit') -# parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR', -# help='warmup learning rate (default: 0.0001)') -# parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', -# help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') -# # parser.add_argument('--epochs', type=int, default=200, metavar='N', -# # help='number of epochs to train (default: 2)') -# parser.add_argument('--start-epoch', default=None, type=int, metavar='N', -# help='manual epoch number (useful on restarts)') -# parser.add_argument('--decay-epochs', type=float, default=30, metavar='N', -# help='epoch interval to decay LR') -# parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N', -# help='epochs to warmup LR, if scheduler supports') -# parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', -# help='epochs to cooldown LR at min_lr, after cyclic schedule ends') -# parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', -# help='patience epochs for Plateau LR scheduler (default: 10') -# parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', -# help='LR decay rate (default: 0.1)') + help='learning rate') +parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', + help='learning rate noise on/off epoch percentages') +parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', + help='learning rate noise limit percent (default: 0.67)') +parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', + help='learning rate noise std-dev (default: 1.0)') +parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT', + help='learning rate cycle len multiplier (default: 1.0)') +parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N', + help='learning rate cycle limit') +parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR', + help='warmup learning rate (default: 0.0001)') +parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', + help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') +parser.add_argument('--epochs', type=int, default=200, metavar='N', + help='number of epochs to train (default: 2)') +parser.add_argument('--start-epoch', default=None, type=int, metavar='N', + help='manual epoch number (useful on restarts)') +parser.add_argument('--decay-epochs', type=float, default=30, metavar='N', + help='epoch interval to decay LR') +parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N', + help='epochs to warmup LR, if scheduler supports') +parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', + help='epochs to cooldown LR at min_lr, after cyclic schedule ends') +parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', + help='patience epochs for Plateau LR scheduler (default: 10') +parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', + help='LR decay rate (default: 0.1)') #################################################################################### # End SparseML integration hide LR args @@ -386,19 +376,17 @@ def main(): #################################################################################### if args.initial_checkpoint == "zoo": # Load checkpoint from base weights associated with given SparseZoo recipe - if args.sparseml_recipe_path.startswith("zoo:"): - recipe_type = "transfer" if args.sparse_transfer_learn else "original" + if args.sparseml_recipe.startswith("zoo:"): args.initial_checkpoint = Zoo.download_recipe_base_framework_files( - args.sparseml_recipe_path, - recipe_type=recipe_type, + args.sparseml_recipe, extensions=[".pth.tar", ".pth"] )[0] else: raise ValueError( "Attempting to load weights from SparseZoo recipe, but not given a " "SparseZoo recipe stub. When initial-checkpoint is set to 'zoo'. " - "sparseml-recipe-path must start with 'zoo:' and be a SparseZoo model " - f"stub. sparseml-recipe-path was set to {args.sparseml_recipe_path}" + "sparseml-recipe must start with 'zoo:' and be a SparseZoo model " + f"stub. sparseml-recipe was set to {args.sparseml_recipe}" ) elif args.initial_checkpoint.startswith("zoo:"): # Load weights from a SparseZoo model stub @@ -517,24 +505,15 @@ def main(): # NOTE: EMA model does not need to be wrapped by DDP # setup learning rate schedule and starting epoch - #################################################################################### - # SparseML integration suppress lr_scheduler - # let SparseML recipe handle LR schedule - # set epoch range later using recipe - #################################################################################### - lr_scheduler = None - # lr_scheduler, num_epochs = create_scheduler(args, optimizer) - # start_epoch = 0 - # if args.start_epoch is not None: - # # a specified start_epoch will always override the resume epoch - # start_epoch = args.start_epoch - # elif resume_epoch is not None: - # start_epoch = resume_epoch - # if lr_scheduler is not None and start_epoch > 0: - # lr_scheduler.step(start_epoch) - #################################################################################### - # End SparseML integration suppress lr_scheduler - #################################################################################### + lr_scheduler, num_epochs = create_scheduler(args, optimizer) + start_epoch = 0 + if args.start_epoch is not None: + # a specified start_epoch will always override the resume epoch + start_epoch = args.start_epoch + elif resume_epoch is not None: + start_epoch = resume_epoch + if lr_scheduler is not None and start_epoch > 0: + lr_scheduler.step(start_epoch) # create the train and eval datasets dataset_train = create_dataset( @@ -650,23 +629,22 @@ def main(): if output_dir else None ) - # determine recipe type to be used if loading from SparseZoo - if args.sparseml_recipe_path.startswith("zoo:"): - zoo_recipe_type = "transfer" if args.sparse_transfer_learn else "original" - else: - zoo_recipe_type = None - manager = ScheduledModifierManager.from_yaml( - args.sparseml_recipe_path, zoo_recipe_type=zoo_recipe_type - ) - optimizer = ScheduledOptimizer( - optimizer, + manager = ScheduledModifierManager.from_yaml(args.sparseml_recipe) + manager.initialize( model, - manager, + optimizer, steps_per_epoch=len(loader_train), loggers=sparseml_loggers ) - start_epoch = manager.min_epochs # override min_epochs - num_epochs = manager.max_epochs or num_epochs # override num_epochs + # override lr scheduler if recipe makes any LR updates + if any("LearningRate" in str(modifier) for modifier in manager.modifiers): + _logger.info("Disabling timm LR scheduler, managing LR using SparseML recipe") + lr_scheduler = None + if manager.max_epochs: + _logger.info( + f"Overriding max_epochs to {manager.max_epochs} from SparseML recipe" + ) + num_epochs = manager.max_epochs or num_epochs #################################################################################### # End SparseML Integration #################################################################################### diff --git a/src/sparseml/pytorch/optim/modifier_quantization.py b/src/sparseml/pytorch/optim/modifier_quantization.py index 93867496ace..e1c980e9dd0 100644 --- a/src/sparseml/pytorch/optim/modifier_quantization.py +++ b/src/sparseml/pytorch/optim/modifier_quantization.py @@ -19,7 +19,7 @@ """ -from typing import List, Union +from typing import Any, Dict, List, Union from torch.nn import Module from torch.optim.optimizer import Optimizer @@ -75,6 +75,8 @@ class QuantizationModifier(ScheduledModifier): None to not stop tracking batch norm stats during QAT. Default is None :param end_epoch: Disabled, setting to anything other than -1 will raise an exception. For compatibility with YAML serialization only. + :param model_fuse_fn_kwargs: dictionary of keyword argument values to be passed + to the model fusing function """ def __init__( @@ -85,6 +87,7 @@ def __init__( disable_quantization_observer_epoch: Union[float, None] = None, freeze_bn_stats_epoch: Union[float, None] = None, end_epoch: float = -1, + model_fuse_fn_kwargs: Dict[str, Any] = None, ): if torch_quantization is None or torch_intrinsic is None: raise RuntimeError( @@ -103,6 +106,7 @@ def __init__( self._start_epoch = start_epoch self._submodules = submodules self._model_fuse_fn_name = model_fuse_fn_name + self._model_fuse_fn_kwargs = model_fuse_fn_kwargs or {} self._disable_quantization_observer_epoch = disable_quantization_observer_epoch self._freeze_bn_stats_epoch = freeze_bn_stats_epoch @@ -254,9 +258,10 @@ def update( self._model_fuse_fn_name ) ) - module_fuse_fn() + module_fuse_fn(**self._model_fuse_fn_kwargs) elif self._model_fuse_fn_name is None: # default auto fn - fuse_module_conv_bn_relus(module, inplace=True) + self._model_fuse_fn_kwargs["inplace"] = True + fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs) # prepare each module / submodule for quantization qconfig = get_qat_qconfig() for quant_module in self._modules_to_quantize: diff --git a/src/sparseml/pytorch/optim/quantization/helpers.py b/src/sparseml/pytorch/optim/quantization/helpers.py index 9172637b6bf..d438fe662d6 100644 --- a/src/sparseml/pytorch/optim/quantization/helpers.py +++ b/src/sparseml/pytorch/optim/quantization/helpers.py @@ -103,7 +103,11 @@ def get_qat_qconfig() -> torch_quantization.QConfig: ) -def fuse_module_conv_bn_relus(module: Module, inplace: bool = True) -> Module: +def fuse_module_conv_bn_relus( + module: Module, + inplace: bool = True, + override_bn_subclasses: bool = True, +) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the given module. To be fused, these layers must appear sequentially in @@ -115,6 +119,10 @@ def fuse_module_conv_bn_relus(module: Module, inplace: bool = True) -> Module: :param module: the module to fuse :param inplace: set True to perform fusions in-place. default is True + :param override_bn_subclasses: if true, modules that are subclasses of + BatchNorm2d will be modified to be BatchNorm2d but with the forward + pass and state variables copied from the subclass. This is so these + BN modules can pass PyTorch type checking when fusing. Default is True :return: the fused module """ if torch_quantization is None: @@ -139,7 +147,15 @@ def fuse_module_conv_bn_relus(module: Module, inplace: bool = True) -> Module: and submodule_name == current_block_submodule_name ): if isinstance(layer, ReLU_nm): - _replace_nm_relu(module, name, layer) + _set_submodule(module, name, ReLU(inplace=layer.inplace)) + if ( + override_bn_subclasses + and isinstance(layer, BatchNorm2d) + and not type(layer) is BatchNorm2d + ): + # swap BN subclass with overwritten BN class that will pass torch + # type checking + _set_submodule(module, name, _wrap_bn_sub_class(layer)) current_block.append(name) else: if current_block: @@ -155,10 +171,17 @@ def fuse_module_conv_bn_relus(module: Module, inplace: bool = True) -> Module: return module -def _replace_nm_relu(root_module, relu_path, nm_relu): +def _set_submodule(root_module, sub_module_path, sub_module): current_module = root_module - relu_path = relu_path.split(".") - for sub_module in relu_path[:-1]: - current_module = getattr(current_module, sub_module) - new_relu = ReLU(inplace=nm_relu.inplace) - setattr(current_module, relu_path[-1], new_relu) + sub_module_path = sub_module_path.split(".") + for child_module in sub_module_path[:-1]: + current_module = getattr(current_module, child_module) + setattr(current_module, sub_module_path[-1], sub_module) + + +def _wrap_bn_sub_class(bn_subclass): + batch_norm = BatchNorm2d(bn_subclass.num_features) + batch_norm.__dict__ = bn_subclass.__dict__ + batch_norm.forward = bn_subclass.forward + del bn_subclass + return batch_norm From 29aecc3d17b6dcd21b55c54987d1f1952b9d2a6c Mon Sep 17 00:00:00 2001 From: Benjamin Date: Fri, 19 Feb 2021 16:11:35 -0500 Subject: [PATCH 4/8] option to override bn subclass during fusing without overriding the parent forward pass --- .../pytorch/optim/quantization/helpers.py | 33 ++++++++++++------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/src/sparseml/pytorch/optim/quantization/helpers.py b/src/sparseml/pytorch/optim/quantization/helpers.py index d438fe662d6..0c8a0529a92 100644 --- a/src/sparseml/pytorch/optim/quantization/helpers.py +++ b/src/sparseml/pytorch/optim/quantization/helpers.py @@ -17,6 +17,7 @@ """ from copy import deepcopy +from typing import Union import torch from torch.nn import BatchNorm2d, Conv2d, Module, ReLU @@ -106,7 +107,7 @@ def get_qat_qconfig() -> torch_quantization.QConfig: def fuse_module_conv_bn_relus( module: Module, inplace: bool = True, - override_bn_subclasses: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, ) -> Module: """ Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the @@ -119,10 +120,12 @@ def fuse_module_conv_bn_relus( :param module: the module to fuse :param inplace: set True to perform fusions in-place. default is True - :param override_bn_subclasses: if true, modules that are subclasses of + :param override_bn_subclasses_forward: if True, modules that are subclasses of BatchNorm2d will be modified to be BatchNorm2d but with the forward pass and state variables copied from the subclass. This is so these - BN modules can pass PyTorch type checking when fusing. Default is True + BN modules can pass PyTorch type checking when fusing. Can set to + "override-only" and only parameters will be overwritten, not the + forward pass. Default is True :return: the fused module """ if torch_quantization is None: @@ -148,14 +151,21 @@ def fuse_module_conv_bn_relus( ): if isinstance(layer, ReLU_nm): _set_submodule(module, name, ReLU(inplace=layer.inplace)) - if ( - override_bn_subclasses - and isinstance(layer, BatchNorm2d) - and not type(layer) is BatchNorm2d - ): + if isinstance(layer, BatchNorm2d) and not type(layer) is BatchNorm2d: + if not override_bn_subclasses_forward: + raise RuntimeError( + "Detected a Conv-BN block that uses a subclass of BatchNorm2d. " + "This will cause a type error when fusing with PyTorch, " + "set override_bn_subclasses_forward to True or 'override-only " + "to modify this BN subclass to be a BatchNorm2d object" + ) # swap BN subclass with overwritten BN class that will pass torch # type checking - _set_submodule(module, name, _wrap_bn_sub_class(layer)) + overwritten_bn = _wrap_bn_sub_class( + layer, + override_forward=override_bn_subclasses_forward != "override-only", + ) + _set_submodule(module, name, overwritten_bn), current_block.append(name) else: if current_block: @@ -179,9 +189,10 @@ def _set_submodule(root_module, sub_module_path, sub_module): setattr(current_module, sub_module_path[-1], sub_module) -def _wrap_bn_sub_class(bn_subclass): +def _wrap_bn_sub_class(bn_subclass, override_forward=True): batch_norm = BatchNorm2d(bn_subclass.num_features) batch_norm.__dict__ = bn_subclass.__dict__ - batch_norm.forward = bn_subclass.forward + if override_forward: + batch_norm.forward = bn_subclass.forward del bn_subclass return batch_norm From 14ca60a04cbc0a85bfaa131ef0542f5cb7fbdbc7 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Fri, 19 Feb 2021 16:26:37 -0500 Subject: [PATCH 5/8] addressing comments --- integrations/timm/README.md | 15 +++++++-------- integrations/timm/main.py | 8 +++----- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/integrations/timm/README.md b/integrations/timm/README.md index 2c20d888be4..b2f762a0dba 100644 --- a/integrations/timm/README.md +++ b/integrations/timm/README.md @@ -46,12 +46,12 @@ to run SparseML optimizations with. This can be a file path to a local SparseML recipe or a SparseZoo model stub prefixed by `zoo:` such as `zoo:cv-classification/resnet_v1-50/pytorch-rwightman/imagenet-augmented/pruned_quant-aggressive`. -Additionally, for sparse transfer learning, the flag `--sparse-transfer-learn` -was added. Running the script with this flag will add modifiers to the given -recipe that will keep the base sparsity constant during training, allowing -the model to learn the new dataset while keeping the same optimized structure. -If a SparseZoo recipe path is provided with sparse transfer learning enabled, -then the the model's specific "transfer" recipe will be loaded instead. +Additionally, to run sparse transfer learning with a SparseZoo model that has +a transfer learning recipe, add `?recipe_type=transfer` as part of the model stub. +i.e. `zoo:cv-classification/resnet_v1-50/pytorch-rwightman/imagenet-augmented/pruned_quant-aggressive?recipe_type=transfer`. +This will run a recipe that holds the optimized sparsity structure the same while allowing +non-zero weights to be updated during training, so pre-learned optimizations can be applied +to different datasets. To load the base weights for a SparseZoo recipe as the initial checkpoint, set `--initial-checkpoint` to `zoo`. To use the weights of a SparseZoo model as the @@ -112,9 +112,8 @@ Training from a SparseZoo recipe and checkpoint with sparse transfer learning en ```bash python integrations/timm/main.py \ /PATH/TO/DATASET/imagenet/ \ - --sparseml-recipe zoo:model/stub/path \ + --sparseml-recipe zoo:model/stub/path?recipe_type=transfer \ --initial-checkpoint zoo \ - --sparse-transfer-learn \ --dataset imagenet \ --batch-size 64 \ --remode pixel --reprob 0.6 --smoothing 0.1 \ diff --git a/integrations/timm/main.py b/integrations/timm/main.py index 49d91d02c07..59832006be5 100755 --- a/integrations/timm/main.py +++ b/integrations/timm/main.py @@ -13,7 +13,9 @@ Integration lines are preceded by comment blocks. Run with `--help` for help printout, more information can be found in the readme file. -Latest pytorch-image-models commit this script is based on: aaa715b +Latest pytorch-image-models commit this script is based on: +https://github.com/rwightman/pytorch-image-models/tree/aaa715b1e94a8d10a2c0ff0f4abef7ddc97b2576 +(commit hash: aaa715b) Original doc-string: @@ -195,10 +197,6 @@ parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', help='LR decay rate (default: 0.1)') -#################################################################################### -# End SparseML integration hide LR args -#################################################################################### - # Augmentation & regularization parameters parser.add_argument('--no-aug', action='store_true', default=False, help='Disable all training augmentation, override other train aug args') From 60ab8844f11dcf0e4b3979940d8a714232eaea70 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Sun, 21 Feb 2021 02:04:44 -0500 Subject: [PATCH 6/8] rebasing on main --- integrations/timm/main.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/integrations/timm/main.py b/integrations/timm/main.py index 59832006be5..669a6ce8536 100755 --- a/integrations/timm/main.py +++ b/integrations/timm/main.py @@ -56,7 +56,7 @@ from timm.scheduler import create_scheduler from timm.utils import ApexScaler, NativeScaler -from sparseml.pytorch.optim import ScheduledModifierManager +from sparseml.pytorch.optim import ScheduledModifierManager, ScheduledOptimizer from sparseml.pytorch.utils import ModuleExporter, PythonLogger, TensorBoardLogger from sparsezoo import Zoo import warnings @@ -628,9 +628,10 @@ def main(): else None ) manager = ScheduledModifierManager.from_yaml(args.sparseml_recipe) - manager.initialize( - model, + optimizer = ScheduledOptimizer( optimizer, + model, + manager, steps_per_epoch=len(loader_train), loggers=sparseml_loggers ) From 75169ffe05747db9a27d11e712e5af328e67b6c2 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Tue, 23 Feb 2021 09:36:14 -0500 Subject: [PATCH 7/8] renaming script to train.py --- integrations/timm/README.md | 10 +++++----- integrations/timm/{main.py => train.py} | 0 2 files changed, 5 insertions(+), 5 deletions(-) rename integrations/timm/{main.py => train.py} (100%) diff --git a/integrations/timm/README.md b/integrations/timm/README.md index b2f762a0dba..22b2a808bb0 100644 --- a/integrations/timm/README.md +++ b/integrations/timm/README.md @@ -39,7 +39,7 @@ pip install sparseml[torchvision] ## Script -`integrations/timm/main.py` modifies +`integrations/timm/train.py` modifies [`train.py`](https://github.com/rwightman/pytorch-image-models/blob/master/train.py) from pytorch-image-models to include a `sparseml-recipe` argument to run SparseML optimizations with. This can be a file path to a local @@ -76,13 +76,13 @@ documentation, or export one with [Sparsify](https://github.com/neuralmagic/spar Documentation on the original script can be found [here](https://rwightman.github.io/pytorch-image-models/scripts/). -The latest commit hash that `main.py` is based on is included in the docstring. +The latest commit hash that `train.py` is based on is included in the docstring. #### Example Command Training from a local recipe and checkpoint ```bash -python integrations/timm/main.py \ +python integrations/timm/train.py \ /PATH/TO/DATASET/imagenet/ \ --sparseml-recipe /PATH/TO/RECIPE/recipe.yaml \ --initial-checkpoint PATH/TO/CHECKPOINT/model.pth \ @@ -96,7 +96,7 @@ python integrations/timm/main.py \ Training from a local recipe and SparseZoo checkpoint ```bash -python integrations/timm/main.py \ +python integrations/timm/train.py \ /PATH/TO/DATASET/imagenet/ \ --sparseml-recipe /PATH/TO/RECIPE/recipe.yaml \ --initial-checkpoint zoo:model/stub/path \ @@ -110,7 +110,7 @@ python integrations/timm/main.py \ Training from a SparseZoo recipe and checkpoint with sparse transfer learning enabled ```bash -python integrations/timm/main.py \ +python integrations/timm/train.py \ /PATH/TO/DATASET/imagenet/ \ --sparseml-recipe zoo:model/stub/path?recipe_type=transfer \ --initial-checkpoint zoo \ diff --git a/integrations/timm/main.py b/integrations/timm/train.py similarity index 100% rename from integrations/timm/main.py rename to integrations/timm/train.py From 2ca2ffb39240d21680654ece802abf4fb9a5be48 Mon Sep 17 00:00:00 2001 From: Benjamin Fineran Date: Tue, 23 Feb 2021 12:51:43 -0500 Subject: [PATCH 8/8] recipe_type=transfer -> transfer_learn --- integrations/timm/README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/integrations/timm/README.md b/integrations/timm/README.md index 22b2a808bb0..a1e96289e10 100644 --- a/integrations/timm/README.md +++ b/integrations/timm/README.md @@ -47,8 +47,8 @@ SparseML recipe or a SparseZoo model stub prefixed by `zoo:` such as `zoo:cv-classification/resnet_v1-50/pytorch-rwightman/imagenet-augmented/pruned_quant-aggressive`. Additionally, to run sparse transfer learning with a SparseZoo model that has -a transfer learning recipe, add `?recipe_type=transfer` as part of the model stub. -i.e. `zoo:cv-classification/resnet_v1-50/pytorch-rwightman/imagenet-augmented/pruned_quant-aggressive?recipe_type=transfer`. +a transfer learning recipe, add `?recipe_type=transfer_learn` as part of the model stub. +i.e. `zoo:cv-classification/resnet_v1-50/pytorch-rwightman/imagenet-augmented/pruned_quant-aggressive?recipe_type=transfer_learn`. This will run a recipe that holds the optimized sparsity structure the same while allowing non-zero weights to be updated during training, so pre-learned optimizations can be applied to different datasets. @@ -112,7 +112,7 @@ Training from a SparseZoo recipe and checkpoint with sparse transfer learning en ```bash python integrations/timm/train.py \ /PATH/TO/DATASET/imagenet/ \ - --sparseml-recipe zoo:model/stub/path?recipe_type=transfer \ + --sparseml-recipe zoo:model/stub/path?recipe_type=transfer_learn \ --initial-checkpoint zoo \ --dataset imagenet \ --batch-size 64 \ @@ -120,4 +120,4 @@ python integrations/timm/train.py \ --output models/optimized \ --model resnet50 \ --workers 8 \ -``` \ No newline at end of file +```