From b9529983d3c9a8e04bc81aa6b670853e618ba6fa Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Wed, 2 Feb 2022 06:25:57 -0700 Subject: [PATCH 01/13] Fix model load bug and add logging to catch potential future issues --- .../transformers/language_modeling.py | 2 +- .../transformers/question_answering.py | 2 +- .../transformers/sparsification/trainer.py | 7 + .../transformers/text_classification.py | 2 +- .../transformers/token_classification.py | 2 +- src/sparseml/transformers/utils/model.py | 210 ++++++++++++------ 6 files changed, 158 insertions(+), 67 deletions(-) diff --git a/src/sparseml/transformers/language_modeling.py b/src/sparseml/transformers/language_modeling.py index 613dd72bd61..12513bbadef 100644 --- a/src/sparseml/transformers/language_modeling.py +++ b/src/sparseml/transformers/language_modeling.py @@ -454,7 +454,7 @@ def main(): "use_auth_token": True if model_args.use_auth_token else None, }, teacher_name_or_path=model_args.distill_teacher, - teacher_kwars={ + teacher_kwargs={ "cache_dir": model_args.cache_dir, "use_auth_token": True if model_args.use_auth_token else None, }, diff --git a/src/sparseml/transformers/question_answering.py b/src/sparseml/transformers/question_answering.py index aee7c536832..45c6ab14fb7 100644 --- a/src/sparseml/transformers/question_answering.py +++ b/src/sparseml/transformers/question_answering.py @@ -437,7 +437,7 @@ def main(): "use_auth_token": True if model_args.use_auth_token else None, }, teacher_name_or_path=model_args.distill_teacher, - teacher_kwars={ + teacher_kwargs={ "cache_dir": model_args.cache_dir, "use_auth_token": True if model_args.use_auth_token else None, }, diff --git a/src/sparseml/transformers/sparsification/trainer.py b/src/sparseml/transformers/sparsification/trainer.py index 4867afeca79..05cdb08221d 100644 --- a/src/sparseml/transformers/sparsification/trainer.py +++ b/src/sparseml/transformers/sparsification/trainer.py @@ -34,6 +34,7 @@ from sparseml.pytorch.optim.manager import ScheduledModifierManager from sparseml.pytorch.utils import WANDBLogger +from sparseml.transformers.utils import SparseAutoModel from sparseml.transformers.utils.helpers import RECIPE_REGEX, RECIPE_TEMPLATE @@ -415,6 +416,12 @@ def _reload_model_state(self, load_path: str, orig_state_dict: Dict[str, Any]): _LOGGER.info( f"Reloaded {total_loaded} model params for SparseML Recipe from {load_path}" ) + SparseAutoModel.log_model_load( + self.model, + self.model_state_path, + model_type="student" if self.teacher else "model", + delayed_load=False, + ) class TrainerInterface(RecipeManagerTrainerInterface): diff --git a/src/sparseml/transformers/text_classification.py b/src/sparseml/transformers/text_classification.py index eecda792e8a..edafb5e2f45 100644 --- a/src/sparseml/transformers/text_classification.py +++ b/src/sparseml/transformers/text_classification.py @@ -444,7 +444,7 @@ def main(): "use_auth_token": True if model_args.use_auth_token else None, }, teacher_name_or_path=model_args.distill_teacher, - teacher_kwars={ + teacher_kwargs={ "cache_dir": model_args.cache_dir, "use_auth_token": True if model_args.use_auth_token else None, }, diff --git a/src/sparseml/transformers/token_classification.py b/src/sparseml/transformers/token_classification.py index 646963ce8ef..d4e6b64d16e 100644 --- a/src/sparseml/transformers/token_classification.py +++ b/src/sparseml/transformers/token_classification.py @@ -398,7 +398,7 @@ def get_label_list(labels): "use_auth_token": True if model_args.use_auth_token else None, }, 
teacher_name_or_path=model_args.distill_teacher, - teacher_kwars={ + teacher_kwargs={ "cache_dir": model_args.cache_dir, "use_auth_token": True if model_args.use_auth_token else None, }, diff --git a/src/sparseml/transformers/utils/model.py b/src/sparseml/transformers/utils/model.py index d224ab99cc5..187aacff847 100644 --- a/src/sparseml/transformers/utils/model.py +++ b/src/sparseml/transformers/utils/model.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import logging import os from typing import Any, Dict, Optional, Tuple, Union -import numpy import torch from torch.nn import Module from transformers import ( @@ -27,6 +27,8 @@ ) from transformers.file_utils import WEIGHTS_NAME +from sparseml.pytorch.utils import get_prunable_layers, tensor_sparsity + __all__ = ["SparseAutoModel"] @@ -42,39 +44,46 @@ class SparseAutoModel: @staticmethod def masked_language_modeling_from_pretrained( model_name_or_path: str, + model_type: str, config: Any, **kwargs, ) -> Module: """ :param model_name_or_path: the name of or path to the model to load + :param model_type: specify the type of model loaded for logging; + ex one of [model, student, teacher] :param config: the config for the model describing pipeline :param kwargs: keyword arguments to pass through to the AutoModel call :return: the created model for masked language modeling """ + delayed = False if not model_name_or_path: _LOGGER.info("Training new model from scratch") - return AutoModelForMaskedLM.from_config(config) - - SparseAutoModel._check_tf(model_name_or_path) - if not kwargs: - kwargs = {} - kwargs["from_tf"] = False - if "state_dict" not in kwargs: - kwargs["state_dict"] = SparseAutoModel._loadable_state_dict( - model_name_or_path + model = AutoModelForMaskedLM.from_config(config) + else: + SparseAutoModel._check_tf(model_name_or_path) + if not kwargs: + kwargs = {} + kwargs["from_tf"] = False + if "state_dict" not in kwargs: + kwargs["state_dict"], delayed = SparseAutoModel._loadable_state_dict( + model_name_or_path + ) + model = AutoModelForSequenceClassification.from_pretrained( + model_name_or_path, + **kwargs, ) - return AutoModelForSequenceClassification.from_pretrained( - model_name_or_path, - **kwargs, - ) + SparseAutoModel.log_model_load(model, model_name_or_path, model_type, delayed) + + return model @staticmethod def masked_language_modeling_from_pretrained_distil( model_name_or_path: str, teacher_name_or_path: Optional[str], model_kwargs: Dict[str, Any], - teacher_kwars: Dict[str, Any], + teacher_kwargs: Dict[str, Any], ) -> Tuple[Module, Optional[Union[Module, str]]]: """ :param model_name_or_path: the name of or path to the model to load @@ -82,32 +91,39 @@ def masked_language_modeling_from_pretrained_distil( None or one of ['self', 'disable'] will not create a teacher and instead return the value passed in :param model_kwargs: the keyword args to pass into the AutoModel for model - :param teacher_kwars: the keyword args to pass into the AutoModel for teacher + :param teacher_kwargs: the keyword args to pass into the AutoModel for teacher :return: a tuple containing the model and distillation teacher (optional) for masked language modeling """ model = SparseAutoModel.masked_language_modeling_from_pretrained( - model_name_or_path, model_kwargs["config"], **model_kwargs + model_name_or_path, + model_type="student" if teacher_name_or_path else "model", + config=model_kwargs["config"], + **model_kwargs, ) teacher = ( 
SparseAutoModel.masked_language_modeling_from_pretrained( - teacher_name_or_path, None, **teacher_kwars + teacher_name_or_path, + model_type="teacher", + config=None, + **teacher_kwargs, ) if teacher_name_or_path and teacher_name_or_path not in ["self", "disable"] else teacher_name_or_path ) - if isinstance(teacher, Module): - SparseAutoModel._log_distillation_teacher_load(teacher) return model, teacher @staticmethod def question_answering_from_pretrained( model_name_or_path: str, + model_type: str, **kwargs, ) -> Module: """ :param model_name_or_path: the name of or path to the model to load + :param model_type: specify the type of model loaded for logging; + ex one of [model, student, teacher] :param kwargs: keyword arguments to pass through to the AutoModel call :return: the created model for question answering """ @@ -115,22 +131,25 @@ def question_answering_from_pretrained( if not kwargs: kwargs = {} kwargs["from_tf"] = False + delayed = False if "state_dict" not in kwargs: - kwargs["state_dict"] = SparseAutoModel._loadable_state_dict( + kwargs["state_dict"], delayed = SparseAutoModel._loadable_state_dict( model_name_or_path ) - - return AutoModelForQuestionAnswering.from_pretrained( + model = AutoModelForQuestionAnswering.from_pretrained( model_name_or_path, **kwargs, ) + SparseAutoModel.log_model_load(model, model_name_or_path, model_type, delayed) + + return model @staticmethod def question_answering_from_pretrained_distil( model_name_or_path: str, teacher_name_or_path: Optional[str], model_kwargs: Dict[str, Any], - teacher_kwars: Dict[str, Any], + teacher_kwargs: Dict[str, Any], ) -> Tuple[Module, Optional[Union[Module, str]]]: """ :param model_name_or_path: the name of or path to the model to load @@ -138,32 +157,35 @@ def question_answering_from_pretrained_distil( None or one of ['self', 'disable'] will not create a teacher and instead return the value passed in :param model_kwargs: the keyword args to pass into the AutoModel for model - :param teacher_kwars: the keyword args to pass into the AutoModel for teacher + :param teacher_kwargs: the keyword args to pass into the AutoModel for teacher :return: a tuple containing the model and distillation teacher (optional) for question answering """ model = SparseAutoModel.question_answering_from_pretrained( - model_name_or_path, **model_kwargs + model_name_or_path, + model_type="student" if teacher_name_or_path else "model", + **model_kwargs, ) teacher = ( SparseAutoModel.question_answering_from_pretrained( - teacher_name_or_path, **teacher_kwars + teacher_name_or_path, model_type="teacher", **teacher_kwargs ) if teacher_name_or_path and teacher_name_or_path not in ["self", "disable"] else teacher_name_or_path ) - if isinstance(teacher, Module): - SparseAutoModel._log_distillation_teacher_load(teacher) return model, teacher @staticmethod def text_classification_from_pretrained( model_name_or_path: str, + model_type: str, **kwargs, ) -> Module: """ :param model_name_or_path: the name of or path to the model to load + :param model_type: specify the type of model loaded for logging; + ex one of [model, student, teacher] :param kwargs: keyword arguments to pass through to the AutoModel call :return: the created model for text classification """ @@ -171,22 +193,25 @@ def text_classification_from_pretrained( if not kwargs: kwargs = {} kwargs["from_tf"] = False + delayed = False if "state_dict" not in kwargs: - kwargs["state_dict"] = SparseAutoModel._loadable_state_dict( + kwargs["state_dict"], delayed = SparseAutoModel._loadable_state_dict( 
model_name_or_path ) - - return AutoModelForSequenceClassification.from_pretrained( + model = AutoModelForSequenceClassification.from_pretrained( model_name_or_path, **kwargs, ) + SparseAutoModel.log_model_load(model, model_name_or_path, model_type, delayed) + + return model @staticmethod def text_classification_from_pretrained_distil( model_name_or_path: str, teacher_name_or_path: Optional[str], model_kwargs: Dict[str, Any], - teacher_kwars: Dict[str, Any], + teacher_kwargs: Dict[str, Any], ) -> Tuple[Module, Optional[Module]]: """ :param model_name_or_path: the name of or path to the model to load @@ -194,32 +219,35 @@ def text_classification_from_pretrained_distil( None or one of ['self', 'disable'] will not create a teacher and instead return the value passed in :param model_kwargs: the keyword args to pass into the AutoModel for model - :param teacher_kwars: the keyword args to pass into the AutoModel for teacher + :param teacher_kwargs: the keyword args to pass into the AutoModel for teacher :return: a tuple containing the model and distillation teacher (optional) for sequence/text classification """ model = SparseAutoModel.text_classification_from_pretrained( - model_name_or_path, **model_kwargs + model_name_or_path, + model_type="student" if teacher_name_or_path else "model", + **model_kwargs, ) teacher = ( SparseAutoModel.text_classification_from_pretrained( - teacher_name_or_path, **teacher_kwars + teacher_name_or_path, model_type="teacher", **teacher_kwargs ) if teacher_name_or_path and teacher_name_or_path not in ["self", "disable"] else teacher_name_or_path ) - if isinstance(teacher, Module): - SparseAutoModel._log_distillation_teacher_load(teacher) return model, teacher @staticmethod def token_classification_from_pretrained( model_name_or_path: str, + model_type: str, **kwargs, ) -> Module: """ :param model_name_or_path: the name of or path to the model to load + :param model_type: specify the type of model loaded for logging; + ex one of [model, student, teacher] :param kwargs: keyword arguments to pass through to the AutoModel call :return: the created model for token classification """ @@ -227,22 +255,25 @@ def token_classification_from_pretrained( if not kwargs: kwargs = {} kwargs["from_tf"] = False + delayed = False if "state_dict" not in kwargs: kwargs["state_dict"] = SparseAutoModel._loadable_state_dict( model_name_or_path ) - - return AutoModelForTokenClassification.from_pretrained( + model = AutoModelForTokenClassification.from_pretrained( model_name_or_path, **kwargs, ) + SparseAutoModel.log_model_load(model, model_name_or_path, model_type, delayed) + + return model @staticmethod def token_classification_from_pretrained_distil( model_name_or_path: str, teacher_name_or_path: Optional[str], model_kwargs: Dict[str, Any], - teacher_kwars: Dict[str, Any], + teacher_kwargs: Dict[str, Any], ) -> Tuple[Module, Optional[Module]]: """ :param model_name_or_path: the name of or path to the model to load @@ -250,31 +281,91 @@ def token_classification_from_pretrained_distil( None or one of ['self', 'disable'] will not create a teacher and instead return the value passed in :param model_kwargs: the keyword args to pass into the AutoModel for model - :param teacher_kwars: the keyword args to pass into the AutoModel for teacher + :param teacher_kwargs: the keyword args to pass into the AutoModel for teacher :return: a tuple containing the model and distillation teacher (optional) for token classification """ model = SparseAutoModel.token_classification_from_pretrained( - 
model_name_or_path, **model_kwargs + model_name_or_path, + model_type="student" if teacher_name_or_path else "model", + **model_kwargs, ) teacher = ( SparseAutoModel.token_classification_from_pretrained( - teacher_name_or_path, **teacher_kwars + teacher_name_or_path, model_type="teacher", **teacher_kwargs ) if teacher_name_or_path and teacher_name_or_path not in ["self", "disable"] else teacher_name_or_path ) - if isinstance(teacher, Module): - SparseAutoModel._log_distillation_teacher_load(teacher) return model, teacher @staticmethod - def _loadable_state_dict(model_name_or_path: str) -> Optional[Dict[str, Any]]: + def log_model_load( + model: Module, model_name_or_path: str, model_type: str, delayed_load: bool + ): + """ + Log the state of a loaded model including sparsity and + prunable params information. + + :param model: the loaded model + :param model_name_or_path: the original name of or path to the model that loaded + :param model_type: specify the type of model loaded for logging; + ex one of [model, student, teacher] + :param delayed_load: True if this model load was delayed until after + recipe instantiation due to QAT or other architectural state changes + """ + if delayed_load: + _LOGGER.info( + f"Delayed load of model {model_name_or_path} detected. " + f"Will print out model information once SparseML recipes have loaded" + ) + return + + model_params = list( + filter(lambda param: param.requires_grad, model.parameters()) + ) + total_params = sum(torch.numel(param) for param in model_params) + params_info = { + f"{name}.weight": { + "sparsity": tensor_sparsity(layer.weight).item(), + "numel": torch.numel(layer.weight), + } + for (name, layer) in get_prunable_layers(model) + } + prunable_sparse_params = sum( + round(param["numel"] * param["sparsity"]) for param in params_info.values() + ) + prunable_total_params = sum( + round(param["numel"]) for param in params_info.values() + ) + avg_prunable_sparsity = float(prunable_sparse_params) / prunable_total_params + + _LOGGER.info( + f"Loaded {model_type} from {model_name_or_path} " + f"with {total_params} total params. " + f"Of those there are {prunable_total_params} prunable params " + f"which have {avg_prunable_sparsity} avg sparsity." 
+ ) + _LOGGER.info( + f"{'sparse' if avg_prunable_sparsity > 0.05 else 'dense'} model detected, " + f"prunable params info: {json.dumps(params_info)}" + ) + + @staticmethod + def _loadable_state_dict( + model_name_or_path: str, + ) -> Tuple[Optional[Dict[str, Any]], bool]: + """ + :param model_name_or_path: name of or path to model + :return: (loaded state dict, True if overriding state dict for delayed load) + delayed load happens when a QAT graph is detected since a recipe + must be applied first + """ if not model_name_or_path or not os.path.isfile( os.path.join(model_name_or_path, WEIGHTS_NAME) ): - return None + return None, False state_dict = torch.load( os.path.join(model_name_or_path, WEIGHTS_NAME), map_location="cpu" @@ -286,13 +377,15 @@ def _loadable_state_dict(model_name_or_path: str) -> Optional[Dict[str, Any]]: ] ) - if is_qat_state: - _LOGGER.warning( - "QAT state detected, ignore any loading errors, weights will reload " - f"after SparseML recipes have been applied {model_name_or_path}" - ) + if not is_qat_state: + return None, False - return {} + _LOGGER.warning( + "QAT state detected, ignore any loading errors, weights will reload " + f"after SparseML recipes have been applied {model_name_or_path}" + ) + + return {}, True @staticmethod def _check_tf(model_name_or_path: str): @@ -303,12 +396,3 @@ def _check_tf(model_name_or_path: str): "Detected a TensorFlow model from model_name_or_path: " f"{model_name_or_path}" ) - - @staticmethod - def _log_distillation_teacher_load(teacher: Module): - if teacher is None or _LOGGER is None: - return - - teacher_params = filter(lambda p: p.requires_grad, teacher.parameters()) - params = sum(numpy.prod(p.size()) for p in teacher_params) - _LOGGER.info("Loaded distillation teacher with %s parameters", params) From 9effb9fa5faef64af9867a7282bfce8c89ad70b7 Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Wed, 2 Feb 2022 07:20:55 -0700 Subject: [PATCH 02/13] initial migration to generalize module sparsification information --- src/sparseml/pytorch/utils/helpers.py | 67 +++++++++++-- src/sparseml/pytorch/utils/sparsification.py | 98 +++++++++++++++++++ .../transformers/sparsification/trainer.py | 5 + 3 files changed, 164 insertions(+), 6 deletions(-) create mode 100644 src/sparseml/pytorch/utils/sparsification.py diff --git a/src/sparseml/pytorch/utils/helpers.py b/src/sparseml/pytorch/utils/helpers.py index 7a93832e973..a5da3f8b99b 100644 --- a/src/sparseml/pytorch/utils/helpers.py +++ b/src/sparseml/pytorch/utils/helpers.py @@ -27,15 +27,19 @@ import torch from torch import Tensor from torch.nn import Linear, Module, Parameter -from torch.nn.modules.conv import _ConvNd +from torch.nn.modules.conv import Conv2d, Conv3d, _ConvNd from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader try: + from torch.nn.qat import QATConv2d, QATConv3d, QATLinear from torch.quantization import QuantWrapper except Exception: QuantWrapper = None + QATLinear = None + QATConv2d = None + QATConv3d = None from sparseml.utils import create_dirs, save_numpy @@ -64,6 +68,7 @@ "get_conv_layers", "get_linear_layers", "get_prunable_layers", + "get_quantizable_layers", "get_named_layers_and_params_by_regex", "any_str_or_regex_matches_param_name", "NamedLayerParam", @@ -751,13 +756,63 @@ def get_prunable_layers(module: Module) -> List[Tuple[str, Module]]: :return: a list containing the names and modules of the prunable layers (Linear, ConvNd) """ - layers = [] + return [ + (name, mod) + for (name, mod) in module.named_modules() + if ( + 
isinstance(mod, Linear) + or isinstance(mod, _ConvNd) + or (QATLinear and isinstance(mod, QATLinear)) + or (QATConv2d and isinstance(mod, QATConv2d)) + or (QATConv3d and isinstance(mod, QATConv3d)) + ) + ] + + +def get_quantizable_layers(module: Module) -> List[Tuple[str, Module]]: + """ + :param module: the module to get the quantizable layers from + :return: a list containing the names and modules of the quantizable layers + (Linear, Conv2d, Conv3d) + """ + if QATLinear is None: + raise ImportError( + "PyTorch version is not setup for Quantization. " + "Please install a QAT compatible version of PyTorch" + ) + + return [ + (name, mod) + for (name, mod) in module.named_modules() + if ( + isinstance(mod, Linear) + or isinstance(mod, Conv2d) + or isinstance(mod, Conv3d) + ) + ] - for name, mod in module.named_modules(): - if isinstance(mod, Linear) or isinstance(mod, _ConvNd): - layers.append((name, mod)) - return layers +def get_quantized_layers(module: Module) -> List[Tuple[str, Module]]: + """ + :param module: the module to get the quantized layers from + :return: a list containing the names and modules of the quantized layers + (Linear, Conv2d, Conv3d) + """ + if QATLinear is None: + raise ImportError( + "PyTorch version is not setup for Quantization. " + "Please install a QAT compatible version of PyTorch" + ) + + return [ + (name, mod) + for (name, mod) in module.named_modules() + if ( + (QATLinear and isinstance(mod, QATLinear)) + or (QATConv2d and isinstance(mod, QATConv2d)) + or (QATConv3d and isinstance(mod, QATConv3d)) + ) + ] def get_layer_param(param: str, layer: str, module: Module) -> Parameter: diff --git a/src/sparseml/pytorch/utils/sparsification.py b/src/sparseml/pytorch/utils/sparsification.py new file mode 100644 index 00000000000..58642517fe0 --- /dev/null +++ b/src/sparseml/pytorch/utils/sparsification.py @@ -0,0 +1,98 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
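The prunable/quantizable/quantized layer helpers added to helpers.py above all return (name, module) pairs drawn from named_modules(). A minimal sketch of how they could be exercised on a toy module, assuming the signatures as defined in this patch and a QAT-capable torch install for the quantization helpers:

# Illustrative only; get_prunable_layers/get_quantizable_layers are the helpers
# added above, tensor_sparsity is an existing sparseml helper imported alongside them.
from torch.nn import Conv2d, Linear, ReLU, Sequential

from sparseml.pytorch.utils.helpers import (
    get_prunable_layers,
    get_quantizable_layers,
    tensor_sparsity,
)

model = Sequential(Conv2d(3, 8, kernel_size=3), ReLU(), Linear(8, 2))

# prunable layers: the Conv2d and the Linear, but not the ReLU
for name, layer in get_prunable_layers(model):
    sparsity = tensor_sparsity(layer.weight).item()
    print(f"{name}: {type(layer).__name__} sparsity={sparsity:.2f}")

# quantizable layers: raises ImportError if torch lacks QAT support
print(f"{len(get_quantizable_layers(model))} quantizable layers")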
+ +import json +from typing import Dict + +import torch +from torch.nn import Module + +from sparseml.pytorch.utils.helpers import ( + get_prunable_layers, + get_quantizable_layers, + get_quantized_layers, + tensor_sparsity, +) + + +class ModuleSparsificationInfo: + def __init__(self, module: Module): + self.module = module + self.trainable_params = list( + filter(lambda param: param.requires_grad, self.module.parameters()) + ) + + def __str__(self): + return json.dumps({}) + + @property + def params_total(self) -> int: + return sum(torch.numel(param) for param in self.trainable_params) + + @property + def params_sparse(self) -> int: + return sum(tensor_sparsity(param) for param in self.trainable_params) + + @property + def params_sparse_percent(self) -> float: + return self.params_sparse / float(self.params_total) * 100 + + @property + def params_prunable_total(self) -> int: + return sum( + torch.numel(layer.weight) + for (name, layer) in get_prunable_layers(self.module) + ) + + @property + def params_pruanble_sparse(self) -> int: + return sum( + tensor_sparsity(layer.weight) + for (name, layer) in get_prunable_layers(self.module) + ) + + @property + def params_prunable_sparse_percent(self) -> float: + return self.params_pruanble_sparse / float(self.params_prunable_total) * 100 + + @property + def params_quantizable(self) -> int: + return sum( + torch.numel(layer.weight) + + (torch.numel(layer.bias) if hasattr(layer, "bias") and layer.bias else 0) + for (name, layer) in get_quantizable_layers(self.module) + ) + + @property + def params_quantized(self) -> int: + return sum( + torch.numel(layer.weight) + + (torch.numel(layer.bias) if hasattr(layer, "bias") and layer.bias else 0) + for (name, layer) in get_quantized_layers(self.module) + ) + + @property + def params_quantized_percent(self) -> float: + return self.params_quantized / self.params_quantizable + + @property + def params_info(self) -> Dict[str, Dict]: + return { + f"{name}.weight": { + "numel": torch.numel(layer.weight), + "sparsity": tensor_sparsity(layer.weight).item(), + "quantized": hasattr(layer, "weight_fake_quant"), + } + for (name, layer) in get_prunable_layers(self.module) + } diff --git a/src/sparseml/transformers/sparsification/trainer.py b/src/sparseml/transformers/sparsification/trainer.py index 05cdb08221d..1b1bad4b74d 100644 --- a/src/sparseml/transformers/sparsification/trainer.py +++ b/src/sparseml/transformers/sparsification/trainer.py @@ -591,6 +591,7 @@ def __init__(self, trainer: RecipeManagerTrainerInterface, *args, **kwargs): super().__init__(*args, **kwargs) self.trainer = trainer self.on_begin_called = False + self.quant_start_epoch = math.inf def check_disable(self, epoch: float, force: bool = False): if ( @@ -612,6 +613,7 @@ def disable_amp(self, epoch: float): if hasattr(self.trainer, "scaler"): self.trainer.scaler._enabled = False + self.quant_start_epoch = epoch _LOGGER.info(f"entering QAT phase at epoch {epoch}, disabling FP16 training") def on_epoch_begin( @@ -627,3 +629,6 @@ def on_epoch_begin( super().on_epoch_begin(args, state, control, **kwargs) self.on_begin_called = True self.check_disable(state.epoch) + + if state.epoch > self.quant_start_epoch: + _LOGGER.info(self.trainer.model) From f3b800fd664959b4d8886e7f48be3bfcd97fc96b Mon Sep 17 00:00:00 2001 From: Benjamin Date: Wed, 2 Feb 2022 09:43:53 -0500 Subject: [PATCH 03/13] propagate ModuleSparsificationInfo --- src/sparseml/pytorch/utils/__init__.py | 1 + src/sparseml/pytorch/utils/sparsification.py | 7 ++++ 
src/sparseml/transformers/utils/model.py | 37 +++++++------------- 3 files changed, 21 insertions(+), 24 deletions(-) diff --git a/src/sparseml/pytorch/utils/__init__.py b/src/sparseml/pytorch/utils/__init__.py index a9d6985a204..e9a8eb656c7 100644 --- a/src/sparseml/pytorch/utils/__init__.py +++ b/src/sparseml/pytorch/utils/__init__.py @@ -27,6 +27,7 @@ from .mfac_helpers import * from .model import * from .module import * +from .sparsification import * from .ssd_helpers import * from .yolo_helpers import * diff --git a/src/sparseml/pytorch/utils/sparsification.py b/src/sparseml/pytorch/utils/sparsification.py index 58642517fe0..a41bd21870e 100644 --- a/src/sparseml/pytorch/utils/sparsification.py +++ b/src/sparseml/pytorch/utils/sparsification.py @@ -12,6 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +Helper functions for retrieving information related to model sparsification +""" + import json from typing import Dict @@ -26,6 +30,9 @@ ) +__all__ = ["ModuleSparsificationInfo"] + + class ModuleSparsificationInfo: def __init__(self, module: Module): self.module = module diff --git a/src/sparseml/transformers/utils/model.py b/src/sparseml/transformers/utils/model.py index 187aacff847..69f6f65c8d7 100644 --- a/src/sparseml/transformers/utils/model.py +++ b/src/sparseml/transformers/utils/model.py @@ -27,7 +27,7 @@ ) from transformers.file_utils import WEIGHTS_NAME -from sparseml.pytorch.utils import get_prunable_layers, tensor_sparsity +from sparseml.pytorch.utils import ModuleSparsificationInfo __all__ = ["SparseAutoModel"] @@ -322,34 +322,23 @@ def log_model_load( ) return - model_params = list( - filter(lambda param: param.requires_grad, model.parameters()) - ) - total_params = sum(torch.numel(param) for param in model_params) - params_info = { - f"{name}.weight": { - "sparsity": tensor_sparsity(layer.weight).item(), - "numel": torch.numel(layer.weight), - } - for (name, layer) in get_prunable_layers(model) - } - prunable_sparse_params = sum( - round(param["numel"] * param["sparsity"]) for param in params_info.values() - ) - prunable_total_params = sum( - round(param["numel"]) for param in params_info.values() - ) - avg_prunable_sparsity = float(prunable_sparse_params) / prunable_total_params + sparsification_info = ModuleSparsificationInfo(model) _LOGGER.info( f"Loaded {model_type} from {model_name_or_path} " - f"with {total_params} total params. " - f"Of those there are {prunable_total_params} prunable params " - f"which have {avg_prunable_sparsity} avg sparsity." + f"with {sparsification_info.params_total} total params. " + f"Of those there are {sparsification_info.params_prunable_total} prunable " + f"params which have {sparsification_info.params_prunable_sparse_percent} " + "avg sparsity." 
+ ) + model_type = ( + "sparse" + if sparsification_info.params_prunable_sparse_percent > 0.05 + else "dense" ) _LOGGER.info( - f"{'sparse' if avg_prunable_sparsity > 0.05 else 'dense'} model detected, " - f"prunable params info: {json.dumps(params_info)}" + f"{model_type} model detected, " + f"prunable params info: {json.dumps(sparsification_info.params_info)}" ) @staticmethod From 060beb2c761dbd98fbb51c40100e6b649335a316 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Wed, 2 Feb 2022 09:52:57 -0500 Subject: [PATCH 04/13] report type of input tensors in export.py --- src/sparseml/transformers/export.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/sparseml/transformers/export.py b/src/sparseml/transformers/export.py index 40f197aa44b..eb176a7a26f 100644 --- a/src/sparseml/transformers/export.py +++ b/src/sparseml/transformers/export.py @@ -177,8 +177,11 @@ def export_transformer_to_onnx( "", return_tensors="pt", padding=PaddingStrategy.MAX_LENGTH.value ).data # Dict[Tensor] inputs_shapes = { - key: f"{type(val)}({val.shape if hasattr(val, 'shape') else 'unknown'})" - for key, val in inputs + key: ( + f"{val.dtype if hasattr(val, 'dtype') else 'unknown'}: " + f"{list(val.shape) if hasattr(val, 'shape') else 'unknown'}" + ) + for key, val in inputs.items() } _LOGGER.info(f"Created sample inputs for the ONNX export process: {inputs_shapes}") From 87d36f3413aa341ecbe8dc12ed72f7ccc1777df1 Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Wed, 2 Feb 2022 08:58:38 -0700 Subject: [PATCH 05/13] minor bug fixes --- src/sparseml/pytorch/utils/sparsification.py | 2 +- src/sparseml/transformers/utils/model.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sparseml/pytorch/utils/sparsification.py b/src/sparseml/pytorch/utils/sparsification.py index a41bd21870e..357a8329524 100644 --- a/src/sparseml/pytorch/utils/sparsification.py +++ b/src/sparseml/pytorch/utils/sparsification.py @@ -91,7 +91,7 @@ def params_quantized(self) -> int: @property def params_quantized_percent(self) -> float: - return self.params_quantized / self.params_quantizable + return self.params_quantized / float(self.params_quantizable) * 100 @property def params_info(self) -> Dict[str, Dict]: diff --git a/src/sparseml/transformers/utils/model.py b/src/sparseml/transformers/utils/model.py index 69f6f65c8d7..09c8732f1c2 100644 --- a/src/sparseml/transformers/utils/model.py +++ b/src/sparseml/transformers/utils/model.py @@ -333,7 +333,7 @@ def log_model_load( ) model_type = ( "sparse" - if sparsification_info.params_prunable_sparse_percent > 0.05 + if sparsification_info.params_prunable_sparse_percent > 5 else "dense" ) _LOGGER.info( From 32c7343b28be26cf75f67faff6937f4e3e64d8ec Mon Sep 17 00:00:00 2001 From: Benjamin Date: Wed, 2 Feb 2022 11:07:04 -0500 Subject: [PATCH 06/13] ModuleSparsificationInfo docs --- src/sparseml/pytorch/utils/sparsification.py | 40 +++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/src/sparseml/pytorch/utils/sparsification.py b/src/sparseml/pytorch/utils/sparsification.py index 357a8329524..0265556631b 100644 --- a/src/sparseml/pytorch/utils/sparsification.py +++ b/src/sparseml/pytorch/utils/sparsification.py @@ -34,6 +34,14 @@ class ModuleSparsificationInfo: + """ + Helper class for providing information related to torch Module parameters + and the amount of sparsification applied. 
Includes information for pruning + and quantization + + :param module: torch Module to analyze + """ + def __init__(self, module: Module): self.module = module self.trainable_params = list( @@ -41,22 +49,34 @@ def __init__(self, module: Module): ) def __str__(self): - return json.dumps({}) + return json.dumps(self.params_info) @property def params_total(self) -> int: + """ + :return: total number of trainable parameters in the model + """ return sum(torch.numel(param) for param in self.trainable_params) @property def params_sparse(self) -> int: + """ + :return: total number of sparse (0) trainable parameters in the model + """ return sum(tensor_sparsity(param) for param in self.trainable_params) @property def params_sparse_percent(self) -> float: + """ + :return: percent of sparsified parameters in the entire model + """ return self.params_sparse / float(self.params_total) * 100 @property def params_prunable_total(self) -> int: + """ + :return: total number of parameters across prunable layers + """ return sum( torch.numel(layer.weight) for (name, layer) in get_prunable_layers(self.module) @@ -64,6 +84,9 @@ def params_prunable_total(self) -> int: @property def params_pruanble_sparse(self) -> int: + """ + :return: total number of sparse (0) parameters across prunable lauyers + """ return sum( tensor_sparsity(layer.weight) for (name, layer) in get_prunable_layers(self.module) @@ -71,10 +94,16 @@ def params_pruanble_sparse(self) -> int: @property def params_prunable_sparse_percent(self) -> float: + """ + :return: percent of prunable parameters that have been pruned + """ return self.params_pruanble_sparse / float(self.params_prunable_total) * 100 @property def params_quantizable(self) -> int: + """ + :return: number of parameters that are included in quantizable layers + """ return sum( torch.numel(layer.weight) + (torch.numel(layer.bias) if hasattr(layer, "bias") and layer.bias else 0) @@ -83,6 +112,9 @@ def params_quantizable(self) -> int: @property def params_quantized(self) -> int: + """ + :return: number of parameters across quantized layers + """ return sum( torch.numel(layer.weight) + (torch.numel(layer.bias) if hasattr(layer, "bias") and layer.bias else 0) @@ -91,10 +123,16 @@ def params_quantized(self) -> int: @property def params_quantized_percent(self) -> float: + """ + :return: percentage of parameters that have been quantized + """ return self.params_quantized / float(self.params_quantizable) * 100 @property def params_info(self) -> Dict[str, Dict]: + """ + :return: dict of parameter name to its sparsification information + """ return { f"{name}.weight": { "numel": torch.numel(layer.weight), From c6d040f039b0c15e9008318634f2f1c1f2428d94 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Wed, 2 Feb 2022 11:45:22 -0500 Subject: [PATCH 07/13] export onnx bugfix --- src/sparseml/transformers/export.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/sparseml/transformers/export.py b/src/sparseml/transformers/export.py index eb176a7a26f..2bc63ef6873 100644 --- a/src/sparseml/transformers/export.py +++ b/src/sparseml/transformers/export.py @@ -80,12 +80,14 @@ def _load_task_model(task: str, model_path: str, config: Any) -> Module: return SparseAutoModel.masked_language_modeling_from_pretrained( model_name_or_path=model_path, config=config, + model_type="model", ) if task == "question-answering" or task == "qa": return SparseAutoModel.question_answering_from_pretrained( model_name_or_path=model_path, config=config, + model_type="model", ) if ( @@ -97,12 +99,14 @@ def 
_load_task_model(task: str, model_path: str, config: Any) -> Module: return SparseAutoModel.text_classification_from_pretrained( model_name_or_path=model_path, config=config, + model_type="model", ) if task == "token-classification" or task == "ner": return SparseAutoModel.token_classification_from_pretrained( model_name_or_path=model_path, config=config, + model_type="model", ) raise ValueError(f"unrecognized task given of {task}") From e2907a4f0af9061b513025bc90b3d845fbce1676 Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Wed, 2 Feb 2022 11:03:23 -0700 Subject: [PATCH 08/13] bug fixes --- src/sparseml/pytorch/utils/helpers.py | 6 +++-- src/sparseml/pytorch/utils/sparsification.py | 27 ++++++++++++++----- .../transformers/sparsification/trainer.py | 25 ++++++++++++++++- src/sparseml/transformers/utils/model.py | 10 +++---- 4 files changed, 52 insertions(+), 16 deletions(-) diff --git a/src/sparseml/pytorch/utils/helpers.py b/src/sparseml/pytorch/utils/helpers.py index a5da3f8b99b..cba713c2792 100644 --- a/src/sparseml/pytorch/utils/helpers.py +++ b/src/sparseml/pytorch/utils/helpers.py @@ -33,9 +33,11 @@ try: - from torch.nn.qat import QATConv2d, QATConv3d, QATLinear + quant_err = None + from torch.nn.qat import Conv2d as QATConv2d, Conv3d as QATConv3d, Linear as QATLinear from torch.quantization import QuantWrapper -except Exception: +except Exception as _err: + quant_err = _err QuantWrapper = None QATLinear = None QATConv2d = None diff --git a/src/sparseml/pytorch/utils/sparsification.py b/src/sparseml/pytorch/utils/sparsification.py index 0265556631b..5ef9dfd3b6c 100644 --- a/src/sparseml/pytorch/utils/sparsification.py +++ b/src/sparseml/pytorch/utils/sparsification.py @@ -49,7 +49,22 @@ def __init__(self, module: Module): ) def __str__(self): - return json.dumps(self.params_info) + return json.dumps( + { + "params_summary": { + "total": self.params_total, + "sparse": self.params_sparse, + "sparsity_percent": self.params_sparse_percent, + "prunable": self.params_prunable_total, + "prunable_sparse": self.params_prunable_sparse, + "prunable_sparsity_percent": self.params_prunable_sparse_percent, + "quantizable": self.params_quantizable, + "quantized": self.params_quantized, + "quantized_percent": self.params_quantized_percent, + }, + "params_info": self.params_info, + } + ) @property def params_total(self) -> int: @@ -63,7 +78,7 @@ def params_sparse(self) -> int: """ :return: total number of sparse (0) trainable parameters in the model """ - return sum(tensor_sparsity(param) for param in self.trainable_params) + return sum(round(tensor_sparsity(param).item() * torch.numel(param)) for param in self.trainable_params) @property def params_sparse_percent(self) -> float: @@ -83,12 +98,12 @@ def params_prunable_total(self) -> int: ) @property - def params_pruanble_sparse(self) -> int: + def params_prunable_sparse(self) -> int: """ :return: total number of sparse (0) parameters across prunable lauyers """ return sum( - tensor_sparsity(layer.weight) + round(tensor_sparsity(layer.weight).item() * torch.numel(layer.weight)) for (name, layer) in get_prunable_layers(self.module) ) @@ -97,7 +112,7 @@ def params_prunable_sparse_percent(self) -> float: """ :return: percent of prunable parameters that have been pruned """ - return self.params_pruanble_sparse / float(self.params_prunable_total) * 100 + return self.params_prunable_sparse / float(self.params_prunable_total) * 100 @property def params_quantizable(self) -> int: @@ -106,7 +121,7 @@ def params_quantizable(self) -> int: """ return sum( 
torch.numel(layer.weight) - + (torch.numel(layer.bias) if hasattr(layer, "bias") and layer.bias else 0) + + (torch.numel(layer.bias) if hasattr(layer, "bias") and layer.bias is not None else 0) for (name, layer) in get_quantizable_layers(self.module) ) diff --git a/src/sparseml/transformers/sparsification/trainer.py b/src/sparseml/transformers/sparsification/trainer.py index 1b1bad4b74d..c0272823a9a 100644 --- a/src/sparseml/transformers/sparsification/trainer.py +++ b/src/sparseml/transformers/sparsification/trainer.py @@ -33,7 +33,7 @@ from transformers.trainer_utils import get_last_checkpoint from sparseml.pytorch.optim.manager import ScheduledModifierManager -from sparseml.pytorch.utils import WANDBLogger +from sparseml.pytorch.utils import ModuleSparsificationInfo, WANDBLogger from sparseml.transformers.utils import SparseAutoModel from sparseml.transformers.utils.helpers import RECIPE_REGEX, RECIPE_TEMPLATE @@ -320,6 +320,29 @@ def save_model(self, output_dir: Optional[str] = None): self.manager.save(recipe_path) _LOGGER.info(f"Saved SparseML recipe with model state to {recipe_path}") + def log_model_sparsification(self): + """ + Log the current model sparsification info including pruned and quantized states + """ + sparsification_info = ModuleSparsificationInfo(self.model) + + _LOGGER.info( + f"Sparsification info for {self.model_state_path}: " + f"{sparsification_info.params_total} total params. " + f"Of those there are {sparsification_info.params_prunable_total} prunable " + f"params which have {sparsification_info.params_prunable_sparse_percent} " + "avg sparsity." + ) + model_type = ( + "sparse" + if sparsification_info.params_prunable_sparse_percent > 5 + else "dense" + ) + _LOGGER.info( + f"{model_type} model detected, " + f"all sparsification info: {sparsification_info}" + ) + def _check_super_defined(self, func: str): if not hasattr(super(), func): raise NotImplementedError( diff --git a/src/sparseml/transformers/utils/model.py b/src/sparseml/transformers/utils/model.py index 09c8732f1c2..28aa8429c58 100644 --- a/src/sparseml/transformers/utils/model.py +++ b/src/sparseml/transformers/utils/model.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
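The log_model_sparsification method added to the trainer above is a thin wrapper around ModuleSparsificationInfo. A hedged, standalone sketch of the same pattern, assuming the class as introduced in sparsification.py earlier in this patch:

# Illustrative sketch; ModuleSparsificationInfo is the helper class added in
# src/sparseml/pytorch/utils/sparsification.py by this patch.
from torch.nn import Linear, Sequential

from sparseml.pytorch.utils import ModuleSparsificationInfo

model = Sequential(Linear(16, 16), Linear(16, 4))
info = ModuleSparsificationInfo(model)

print(f"total params: {info.params_total}")
print(f"prunable params: {info.params_prunable_total}")
print(f"prunable sparsity: {info.params_prunable_sparse_percent:.2f}%")
# str(info) dumps the full params_summary / params_info JSON; it also reports
# quantization stats, so it expects a QAT-capable torch build
print(info)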
-import json import logging import os from typing import Any, Dict, Optional, Tuple, Union @@ -45,20 +44,19 @@ class SparseAutoModel: def masked_language_modeling_from_pretrained( model_name_or_path: str, model_type: str, - config: Any, **kwargs, ) -> Module: """ :param model_name_or_path: the name of or path to the model to load :param model_type: specify the type of model loaded for logging; ex one of [model, student, teacher] - :param config: the config for the model describing pipeline :param kwargs: keyword arguments to pass through to the AutoModel call :return: the created model for masked language modeling """ delayed = False if not model_name_or_path: _LOGGER.info("Training new model from scratch") + config = kwargs["config"] model = AutoModelForMaskedLM.from_config(config) else: SparseAutoModel._check_tf(model_name_or_path) @@ -98,14 +96,12 @@ def masked_language_modeling_from_pretrained_distil( model = SparseAutoModel.masked_language_modeling_from_pretrained( model_name_or_path, model_type="student" if teacher_name_or_path else "model", - config=model_kwargs["config"], **model_kwargs, ) teacher = ( SparseAutoModel.masked_language_modeling_from_pretrained( teacher_name_or_path, model_type="teacher", - config=None, **teacher_kwargs, ) if teacher_name_or_path and teacher_name_or_path not in ["self", "disable"] @@ -257,7 +253,7 @@ def token_classification_from_pretrained( kwargs["from_tf"] = False delayed = False if "state_dict" not in kwargs: - kwargs["state_dict"] = SparseAutoModel._loadable_state_dict( + kwargs["state_dict"], delayed = SparseAutoModel._loadable_state_dict( model_name_or_path ) model = AutoModelForTokenClassification.from_pretrained( @@ -338,7 +334,7 @@ def log_model_load( ) _LOGGER.info( f"{model_type} model detected, " - f"prunable params info: {json.dumps(sparsification_info.params_info)}" + f"all sparsification info: {sparsification_info}" ) @staticmethod From d60ce2cc6095d30f058f4401bb16b71fd83e4e6a Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Wed, 2 Feb 2022 11:04:36 -0700 Subject: [PATCH 09/13] make style --- src/sparseml/pytorch/utils/helpers.py | 4 +++- src/sparseml/pytorch/utils/sparsification.py | 11 +++++++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/sparseml/pytorch/utils/helpers.py b/src/sparseml/pytorch/utils/helpers.py index cba713c2792..ef052e4f4ca 100644 --- a/src/sparseml/pytorch/utils/helpers.py +++ b/src/sparseml/pytorch/utils/helpers.py @@ -34,7 +34,9 @@ try: quant_err = None - from torch.nn.qat import Conv2d as QATConv2d, Conv3d as QATConv3d, Linear as QATLinear + from torch.nn.qat import Conv2d as QATConv2d + from torch.nn.qat import Conv3d as QATConv3d + from torch.nn.qat import Linear as QATLinear from torch.quantization import QuantWrapper except Exception as _err: quant_err = _err diff --git a/src/sparseml/pytorch/utils/sparsification.py b/src/sparseml/pytorch/utils/sparsification.py index 5ef9dfd3b6c..42b4b8f9db6 100644 --- a/src/sparseml/pytorch/utils/sparsification.py +++ b/src/sparseml/pytorch/utils/sparsification.py @@ -78,7 +78,10 @@ def params_sparse(self) -> int: """ :return: total number of sparse (0) trainable parameters in the model """ - return sum(round(tensor_sparsity(param).item() * torch.numel(param)) for param in self.trainable_params) + return sum( + round(tensor_sparsity(param).item() * torch.numel(param)) + for param in self.trainable_params + ) @property def params_sparse_percent(self) -> float: @@ -121,7 +124,11 @@ def params_quantizable(self) -> int: """ return sum( 
torch.numel(layer.weight) - + (torch.numel(layer.bias) if hasattr(layer, "bias") and layer.bias is not None else 0) + + ( + torch.numel(layer.bias) + if hasattr(layer, "bias") and layer.bias is not None + else 0 + ) for (name, layer) in get_quantizable_layers(self.module) ) From e4dbfad7ae596054dde2076e300bb5fc0449a2fd Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Wed, 2 Feb 2022 11:09:42 -0700 Subject: [PATCH 10/13] bug fix for quantization --- src/sparseml/pytorch/utils/sparsification.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/sparseml/pytorch/utils/sparsification.py b/src/sparseml/pytorch/utils/sparsification.py index 42b4b8f9db6..d88a548ad12 100644 --- a/src/sparseml/pytorch/utils/sparsification.py +++ b/src/sparseml/pytorch/utils/sparsification.py @@ -139,7 +139,11 @@ def params_quantized(self) -> int: """ return sum( torch.numel(layer.weight) - + (torch.numel(layer.bias) if hasattr(layer, "bias") and layer.bias else 0) + + ( + torch.numel(layer.bias) + if hasattr(layer, "bias") and layer.bias is not None + else 0 + ) for (name, layer) in get_quantized_layers(self.module) ) From 7fa29a58e5301759ed2c777021bcc7f9fa96e1f0 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Wed, 2 Feb 2022 14:07:07 -0500 Subject: [PATCH 11/13] revert to use ScheduledOptimizer due to bug with torch LambdaLR --- .../transformers/sparsification/trainer.py | 35 +++++++++++++------ 1 file changed, 24 insertions(+), 11 deletions(-) diff --git a/src/sparseml/transformers/sparsification/trainer.py b/src/sparseml/transformers/sparsification/trainer.py index c0272823a9a..ff208eb1fff 100644 --- a/src/sparseml/transformers/sparsification/trainer.py +++ b/src/sparseml/transformers/sparsification/trainer.py @@ -32,7 +32,7 @@ from transformers.trainer_callback import TrainerState from transformers.trainer_utils import get_last_checkpoint -from sparseml.pytorch.optim.manager import ScheduledModifierManager +from sparseml.pytorch.optim import ScheduledModifierManager, ScheduledOptimizer from sparseml.pytorch.utils import ModuleSparsificationInfo, WANDBLogger from sparseml.transformers.utils import SparseAutoModel from sparseml.transformers.utils.helpers import RECIPE_REGEX, RECIPE_TEMPLATE @@ -205,20 +205,33 @@ def create_optimizer(self): self.manager_steps_per_epoch = math.ceil( len(self.train_dataset) / total_batch_size ) - wrap_optim_key = "scaler" if hasattr(self, "scaler") else "optimizer" - setattr( - self, - wrap_optim_key, - self.manager.modify( - module=self.model, - optimizer=self.optimizer, + + if hasattr(self, "scaler"): + wrap_optim_key = "scaler" + self.scaler = self.manager.modify( + self.model, + self.optimizer, steps_per_epoch=self.manager_steps_per_epoch, - wrap_optim=getattr(self, wrap_optim_key), allow_parallel_module=False, + wrap_optim=self.scaler, loggers=self.manager_loggers, distillation_teacher=self.teacher, - ), - ) + ) + else: + wrap_optim_key = "optimizer" + self.optimizer = ScheduledOptimizer( + self.optimizer, + self.model, + self.manager, + steps_per_epoch=self.manager_steps_per_epoch, + loggers=self.manager_loggers, + ) + if not self.manager.initialized: + self.manager.initialize( + self.model, + loggers=self.manager_loggers, + distillation_teacher=self.teacher, + ) self.manager_initialized = True _LOGGER.info( f"Modified the {wrap_optim_key} from the recipe for training with " From e1878cc976165cae295e975c1f4cf5b8151ee3c4 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Wed, 2 Feb 2022 16:05:27 -0500 Subject: [PATCH 12/13] remove language_modeling script --- 
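The create_optimizer revert in the trainer change above falls back to ScheduledOptimizer when no AMP grad scaler is in play. A condensed, hedged sketch of that wrapping pattern outside the Trainer, with a placeholder recipe path and sample/batch counts:

# Hedged sketch of the wrapping shown in the trainer diff above;
# "recipe.yaml" and the sample/batch counts are placeholders.
import math

import torch
from torch.nn import Linear

from sparseml.pytorch.optim import ScheduledModifierManager, ScheduledOptimizer

model = Linear(128, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

manager = ScheduledModifierManager.from_yaml("recipe.yaml")
steps_per_epoch = math.ceil(1000 / 32)  # len(train_dataset) / total batch size

# ScheduledOptimizer runs the recipe's modifiers alongside optimizer.step()
optimizer = ScheduledOptimizer(
    optimizer,
    model,
    manager,
    steps_per_epoch=steps_per_epoch,
)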
setup.py | 1 - .../transformers/language_modeling.py | 680 ------------------ 2 files changed, 681 deletions(-) delete mode 100644 src/sparseml/transformers/language_modeling.py diff --git a/setup.py b/setup.py index e1e64d76ea7..f3d2ec580f4 100644 --- a/setup.py +++ b/setup.py @@ -127,7 +127,6 @@ def _setup_entry_points() -> Dict: "question_answering", "text_classification", "token_classification", - "language_modeling", ]: entry_points["console_scripts"].extend( [ diff --git a/src/sparseml/transformers/language_modeling.py b/src/sparseml/transformers/language_modeling.py deleted file mode 100644 index 12513bbadef..00000000000 --- a/src/sparseml/transformers/language_modeling.py +++ /dev/null @@ -1,680 +0,0 @@ -#!/usr/bin/env python -# coding=utf-8 -# Copyright 2020 The HuggingFace Team All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# Adapted from https://github.com/huggingface/transformers -# neuralmagic: no copyright - -""" -Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) -on a text file or a dataset - -Here is the full list of checkpoints on the hub that can be fine-tuned by this script: -https://huggingface.co/models?filter=masked-lm -""" - -# You can also adapt this script on your own masked language modeling task. -# Pointers for this are left as comments - -import logging -import math -import os -import sys -from dataclasses import dataclass, field -from typing import Optional - -import transformers -from datasets import concatenate_datasets, load_dataset -from transformers import ( - CONFIG_MAPPING, - MODEL_FOR_MASKED_LM_MAPPING, - AutoConfig, - AutoTokenizer, - DataCollatorForLanguageModeling, - HfArgumentParser, - TrainingArguments, - set_seed, -) -from transformers.trainer_utils import get_last_checkpoint -from transformers.utils import check_min_version - -from sparseml.transformers.sparsification import Trainer -from sparseml.transformers.utils import SparseAutoModel - - -# Will error if the minimal version of Transformers is not installed. -# Remove at your own risks -check_min_version("4.7.0.dev0") - -_LOGGER = logging.getLogger(__name__) -MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys()) -MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) - - -@dataclass -class ModelArguments: - """ - Arguments pertaining to which model/config/tokenizer we are going to fine-tune, - or train from scratch - """ - - model_name_or_path: Optional[str] = field( - default=None, - metadata={ - "help": "The model checkpoint for weights initialization." - "Don't set if you want to train a model from scratch." 
- }, - ) - model_type: Optional[str] = field( - default=None, - metadata={ - "help": "If training from scratch, pass a model type from the list: " - + ", ".join(MODEL_TYPES) - }, - ) - distill_teacher: Optional[str] = field( - default=None, - metadata={"help": "Teacher model which needs to be a trained QA model"}, - ) - config_name: Optional[str] = field( - default=None, - metadata={ - "help": "Pretrained config name or path if not the same as model_name" - }, - ) - tokenizer_name: Optional[str] = field( - default=None, - metadata={ - "help": "Pretrained tokenizer name or path if not the same as model_name" - }, - ) - cache_dir: Optional[str] = field( - default=None, - metadata={"help": "Where to store the pretrained models from huggingface.co"}, - ) - use_fast_tokenizer: bool = field( - default=True, - metadata={"help": "Whether to use one of the fast tokenizers. Default True"}, - ) - model_revision: str = field( - default="main", - metadata={ - "help": "The specific model version to use " - "(can be a branch name, tag name or commit id)" - }, - ) - use_auth_token: bool = field( - default=False, - metadata={ - "help": "Will use token generated when running `transformers-cli login` " - "(necessary to use this script with private models)" - }, - ) - - -@dataclass -class DataTrainingArguments: - """ - Arguments pertaining to what data we are going to input our model for - training and eval - """ - - recipe: Optional[str] = field( - default=None, - metadata={ - "help": ( - "Path to a SparseML sparsification recipe, see " - "https://github.com/neuralmagic/sparseml for more information" - ), - }, - ) - recipe_args: Optional[str] = field( - default=None, - metadata={"help": "Recipe arguments to be overwritten"}, - ) - dataset_name: Optional[str] = field( - default=None, - metadata={"help": "The name of the dataset to use (via the datasets library)"}, - ) - dataset_config_name: Optional[str] = field( - default=None, - metadata={ - "help": ("The configuration name of the dataset to use"), - }, - ) - - # An extra second dataset - dataset_name_2: Optional[str] = field( - default=None, - metadata={"help": "The name of the dataset to use (via the datasets library)"}, - ) - dataset_config_name_2: Optional[str] = field( - default=None, - metadata={ - "help": ("The configuration name of the dataset to use"), - }, - ) - - train_file: Optional[str] = field( - default=None, metadata={"help": "The input training data file (a text file)."} - ) - validation_file: Optional[str] = field( - default=None, - metadata={ - "help": ( - "An optional input evaluation data file to evaluate the perplexity on" - "(a text file)." - ), - }, - ) - overwrite_cache: bool = field( - default=False, - metadata={"help": "Overwrite the cached training and evaluation sets"}, - ) - validation_split_percentage: Optional[int] = field( - default=5, - metadata={ - "help": ( - "The percentage of the train set used as validation set in case " - "there's no validation split" - ) - }, - ) - max_seq_length: Optional[int] = field( - default=None, - metadata={ - "help": "The maximum total input sequence length after tokenization. " - "Sequences longer than this will be truncated." 
- }, - ) - preprocessing_num_workers: Optional[int] = field( - default=None, - metadata={"help": "The number of processes to use for the preprocessing."}, - ) - mlm_probability: float = field( - default=0.15, - metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}, - ) - line_by_line: bool = field( - default=False, - metadata={ - "help": ( - "Whether distinct lines of text in the dataset are to be handled as " - "distinct sequences." - ), - }, - ) - pad_to_max_length: bool = field( - default=False, - metadata={ - "help": "Whether to pad all samples to `max_seq_length`. " - "If False, will pad the samples dynamically when batching to " - "the maximum length in the batch." - }, - ) - max_train_samples: Optional[int] = field( - default=None, - metadata={ - "help": "For debugging purposes or quicker training, truncate the number " - "of training examples to this value if set." - }, - ) - max_eval_samples: Optional[int] = field( - default=None, - metadata={ - "help": "For debugging purposes or quicker training, truncate the number " - "of evaluation examples to this value if set." - }, - ) - - def __post_init__(self): - if ( - self.dataset_name is None - and self.train_file is None - and self.validation_file is None - ): - raise ValueError( - "Need either a dataset name or a training/validation file." - ) - else: - if self.train_file is not None: - extension = self.train_file.split(".")[-1] - assert extension in [ - "csv", - "json", - "txt", - ], "`train_file` should be a csv, a json or a txt file." - if self.validation_file is not None: - extension = self.validation_file.split(".")[-1] - assert extension in [ - "csv", - "json", - "txt", - ], "`validation_file` should be a csv, a json or a txt file." - - -def main(): - # See all possible arguments in src/transformers/training_args.py - # or by passing the --help flag to this script. - # We now keep distinct sets of args, for a cleaner separation of concerns. - - parser = HfArgumentParser( - (ModelArguments, DataTrainingArguments, TrainingArguments) - ) - if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): - # If we pass only one argument to the script and it's the path to a json file, - # let's parse it to get our arguments. - model_args, data_args, training_args = parser.parse_json_file( - json_file=os.path.abspath(sys.argv[1]) - ) - else: - model_args, data_args, training_args = parser.parse_args_into_dataclasses() - - # Detecting last checkpoint. - last_checkpoint = None - if ( - os.path.isdir(training_args.output_dir) - and training_args.do_train - and not training_args.overwrite_output_dir - ): - last_checkpoint = get_last_checkpoint(training_args.output_dir) - if last_checkpoint is None and (len(os.listdir(training_args.output_dir)) > 0): - raise ValueError( - f"Output directory ({training_args.output_dir}) already exists and " - "is not empty. Use --overwrite_output_dir to overcome." - ) - elif ( - last_checkpoint is not None and training_args.resume_from_checkpoint is None - ): - _LOGGER.info( - f"Checkpoint detected, resuming training at {last_checkpoint}. " - "To avoid this behavior, change the `--output_dir` or add " - "`--overwrite_output_dir` to train from scratch." 
- ) - - # Setup logging - logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - handlers=[logging.StreamHandler(sys.stdout)], - ) - _LOGGER.setLevel(logging.INFO if training_args.should_log else logging.WARN) - - # Log on each process the small summary: - _LOGGER.warning( - f"Process rank: {training_args.local_rank}, device: {training_args.device}" - f", n_gpu: {training_args.n_gpu} distributed training: " - f"{bool(training_args.local_rank != -1)}, 16-bits training: " - f"{training_args.fp16}" - ) - # Set the verbosity to info of the Transformers _LOGGER (on main process only): - if training_args.should_log: - transformers.utils.logging.set_verbosity_info() - transformers.utils.logging.enable_default_handler() - transformers.utils.logging.enable_explicit_format() - _LOGGER.info(f"Training/evaluation parameters {training_args}") - - # Set seed before initializing model. - set_seed(training_args.seed) - - # Get the datasets: you can either provide your own CSV/JSON/TXT training and - # evaluation files (see below) or just provide the name of one of the public - # datasets available on the hub at https://huggingface.co/datasets/ - # (the dataset will be downloaded automatically from the datasets Hub - # - # For CSV/JSON files, this script will use the column called 'text' or the - # first column. You can easily tweak this behavior (see below) - # - # In distributed training, the load_dataset function guarantee that only one - # local process can concurrently download the dataset. - if data_args.dataset_name is not None: - # Downloading and loading a dataset from the hub. - datasets = load_dataset( - data_args.dataset_name, - data_args.dataset_config_name, - cache_dir=model_args.cache_dir, - ) - if "validation" not in datasets.keys(): - datasets["validation"] = load_dataset( - data_args.dataset_name, - data_args.dataset_config_name, - split=f"train[:{data_args.validation_split_percentage}%]", - cache_dir=model_args.cache_dir, - ) - datasets["train"] = load_dataset( - data_args.dataset_name, - data_args.dataset_config_name, - split=f"train[{data_args.validation_split_percentage}%:]", - cache_dir=model_args.cache_dir, - ) - else: - data_files = {} - if data_args.train_file is not None: - data_files["train"] = data_args.train_file - if data_args.validation_file is not None: - data_files["validation"] = data_args.validation_file - extension = data_args.train_file.split(".")[-1] - if extension == "txt": - extension = "text" - datasets = load_dataset( - extension, data_files=data_files, cache_dir=model_args.cache_dir - ) - # See more about loading any type of standard or custom dataset (from files, - # python dict, pandas DataFrame, etc) at - # https://huggingface.co/docs/datasets/loading_datasets.html. - - # Load extra dataset if specified, and concatenate with the original one - if data_args.dataset_name_2 is not None: - # Downloading and loading a dataset from the hub. 
- datasets_2 = load_dataset( - data_args.dataset_name_2, - data_args.dataset_config_name_2, - cache_dir=model_args.cache_dir, - ) - if "validation" not in datasets_2.keys(): - datasets_2["validation"] = load_dataset( - data_args.dataset_name_2, - data_args.dataset_config_name_2, - split=f"train[:{data_args.validation_split_percentage}%]", - cache_dir=model_args.cache_dir, - ) - datasets_2["train"] = load_dataset( - data_args.dataset_name_2, - data_args.dataset_config_name_2, - split=f"train[{data_args.validation_split_percentage}%:]", - cache_dir=model_args.cache_dir, - ) - # Concatenate two datasets - if datasets is not None: - for split in ["validation", "train"]: - datasets[split] = concatenate_datasets( - [datasets[split], datasets_2[split]] - ) - - # Load pretrained model and tokenizer - # - # Distributed training: - # The .from_pretrained methods guarantee that only one local process can - # concurrently download model & vocab. - config_kwargs = { - "cache_dir": model_args.cache_dir, - "revision": model_args.model_revision, - "use_auth_token": True if model_args.use_auth_token else None, - } - if model_args.config_name: - config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs) - elif model_args.model_name_or_path: - config = AutoConfig.from_pretrained( - model_args.model_name_or_path, **config_kwargs - ) - else: - config = CONFIG_MAPPING[model_args.model_type]() - _LOGGER.warning("You are instantiating a new config instance from scratch.") - - tokenizer_kwargs = { - "cache_dir": model_args.cache_dir, - "use_fast": model_args.use_fast_tokenizer, - "revision": model_args.model_revision, - "use_auth_token": True if model_args.use_auth_token else None, - } - if model_args.tokenizer_name: - tokenizer = AutoTokenizer.from_pretrained( - model_args.tokenizer_name, **tokenizer_kwargs - ) - elif model_args.model_name_or_path: - tokenizer = AutoTokenizer.from_pretrained( - model_args.model_name_or_path, **tokenizer_kwargs - ) - else: - raise ValueError( - "You are instantiating a new tokenizer from scratch. This is not " - "supported by this script. You can do it from another script, save " - "it, and load it from here, using --tokenizer_name." - ) - - model, teacher = SparseAutoModel.masked_language_modeling_from_pretrained_distil( - model_name_or_path=model_args.model_name_or_path, - model_kwargs={ - "config": config, - "cache_dir": model_args.cache_dir, - "revision": model_args.model_revision, - "use_auth_token": True if model_args.use_auth_token else None, - }, - teacher_name_or_path=model_args.distill_teacher, - teacher_kwargs={ - "cache_dir": model_args.cache_dir, - "use_auth_token": True if model_args.use_auth_token else None, - }, - ) - - model.resize_token_embeddings(len(tokenizer)) - - # Preprocessing the datasets. - # First we tokenize all the texts. - if training_args.do_train: - column_names = datasets["train"].column_names - else: - column_names = datasets["validation"].column_names - text_column_name = "text" if "text" in column_names else column_names[0] - - if data_args.max_seq_length is None: - max_seq_length = tokenizer.model_max_length - if max_seq_length > 1024: - _LOGGER.warning( - "The tokenizer picked seems to have a very large `model_max_length`" - f"({tokenizer.model_max_length}). Picking 1024 instead. You can " - "change that default value by passing --max_seq_length xxx." 
- ) - max_seq_length = 1024 - else: - if data_args.max_seq_length > tokenizer.model_max_length: - _LOGGER.warning( - f"The max_seq_length passed ({data_args.max_seq_length}) " - "is larger than the maximum length for the model " - f"({tokenizer.model_max_length}). Using " - f"max_seq_length={tokenizer.model_max_length}." - ) - max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) - - if data_args.line_by_line: - # When using line_by_line, we just tokenize each nonempty line. - padding = "max_length" if data_args.pad_to_max_length else False - - def tokenize_function(examples): - # Remove empty lines - examples["text"] = [ - line - for line in examples["text"] - if len(line) > 0 and not line.isspace() - ] - return tokenizer( - examples["text"], - padding=padding, - truncation=True, - max_length=max_seq_length, - # We use this option because DataCollatorForLanguageModeling (see - # below) is more efficient when it receives the `special_tokens_mask`. - return_special_tokens_mask=True, - ) - - tokenized_datasets = datasets.map( - tokenize_function, - batched=True, - num_proc=data_args.preprocessing_num_workers, - remove_columns=[text_column_name], - load_from_cache_file=not data_args.overwrite_cache, - ) - else: - # Otherwise, we tokenize every text, then concatenate them together before - # splitting them in smaller parts. We use `return_special_tokens_mask=True` - # because DataCollatorForLanguageModeling (see below) is more - # efficient when it receives the `special_tokens_mask`. - def tokenize_function(examples): - return tokenizer( - examples[text_column_name], return_special_tokens_mask=True - ) - - tokenized_datasets = datasets.map( - tokenize_function, - batched=True, - num_proc=data_args.preprocessing_num_workers, - remove_columns=column_names, - load_from_cache_file=not data_args.overwrite_cache, - ) - - # Main data processing function that will concatenate all texts from our - # dataset and generate chunks of max_seq_length. - def group_texts(examples): - # Concatenate all texts. - concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()} - total_length = len(concatenated_examples[list(examples.keys())[0]]) - # We drop the small remainder, we could add padding if the model supported - # it instead of this drop, you can customize this part to your needs - total_length = (total_length // max_seq_length) * max_seq_length - # Split by chunks of max_len. - result = { - k: [ - t[i : i + max_seq_length] - for i in range(0, total_length, max_seq_length) - ] - for k, t in concatenated_examples.items() - } - return result - - # Note that with `batched=True`, this map processes 1,000 texts together, - # so group_texts throws away a remainder for each of those groups of 1,000 - # texts. You can adjust that batch_size here but a higher value - # might be slower to preprocess. - # - # To speed up this part, we use multiprocessing. 
See the documentation of - # the map method for more information: - # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map - tokenized_datasets = tokenized_datasets.map( - group_texts, - batched=True, - num_proc=data_args.preprocessing_num_workers, - load_from_cache_file=not data_args.overwrite_cache, - ) - - if training_args.do_train: - if "train" not in tokenized_datasets: - raise ValueError("--do_train requires a train dataset") - train_dataset = tokenized_datasets["train"] - if data_args.max_train_samples is not None: - train_dataset = train_dataset.select(range(data_args.max_train_samples)) - - if training_args.do_eval: - if "validation" not in tokenized_datasets: - raise ValueError("--do_eval requires a validation dataset") - eval_dataset = tokenized_datasets["validation"] - if data_args.max_eval_samples is not None: - eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) - - # Data collator - # This one will take care of randomly masking the tokens. - pad_to_multiple_of_8 = ( - data_args.line_by_line - and training_args.fp16 - and not data_args.pad_to_max_length - ) - data_collator = DataCollatorForLanguageModeling( - tokenizer=tokenizer, - mlm_probability=data_args.mlm_probability, - pad_to_multiple_of=8 if pad_to_multiple_of_8 else None, - ) - - compute_metrics = None - - # Initialize our Trainer - trainer = Trainer( - model=model, - model_state_path=model_args.model_name_or_path, - recipe=data_args.recipe, - recipe_args=data_args.recipe_args, - teacher=teacher, - args=training_args, - train_dataset=train_dataset if training_args.do_train else None, - eval_dataset=eval_dataset if training_args.do_eval else None, - tokenizer=tokenizer, - data_collator=data_collator, - compute_metrics=compute_metrics, - ) - - # Training - if training_args.do_train: - checkpoint = None - if training_args.resume_from_checkpoint is not None: - checkpoint = training_args.resume_from_checkpoint - elif last_checkpoint is not None: - checkpoint = last_checkpoint - train_result = trainer.train(resume_from_checkpoint=checkpoint) - trainer.save_model() # Saves the tokenizer too for easy upload - metrics = train_result.metrics - - max_train_samples = ( - data_args.max_train_samples - if data_args.max_train_samples is not None - else len(train_dataset) - ) - metrics["train_samples"] = min(max_train_samples, len(train_dataset)) - - trainer.log_metrics("train", metrics) - trainer.save_metrics("train", metrics) - trainer.save_state() - - # Evaluation - if training_args.do_eval: - _LOGGER.info("*** Evaluate ***") - - metrics = trainer.evaluate() - - max_eval_samples = ( - data_args.max_eval_samples - if data_args.max_eval_samples is not None - else len(eval_dataset) - ) - metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) - try: - perplexity = math.exp(metrics["eval_loss"]) - except OverflowError: - perplexity = float("inf") - metrics["perplexity"] = perplexity - - trainer.log_metrics("eval", metrics) - trainer.save_metrics("eval", metrics) - - if training_args.push_to_hub: - kwargs = {"finetuned_from": model_args.model_name_or_path, "tags": "fill-mask"} - if data_args.dataset_name is not None: - kwargs["dataset_tags"] = data_args.dataset_name - if data_args.dataset_config_name is not None: - kwargs["dataset_args"] = data_args.dataset_config_name - kwargs[ - "dataset" - ] = f"{data_args.dataset_name} {data_args.dataset_config_name}" - else: - kwargs["dataset"] = data_args.dataset_name - - trainer.push_to_hub(**kwargs) - - -def _mp_fn(index): - 
# For xla_spawn (TPUs) - main() - - -if __name__ == "__main__": - main() From 664c3cdd795d109b7b360dbc6ffe6741b09ad97c Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Wed, 2 Feb 2022 14:21:48 -0700 Subject: [PATCH 13/13] add end model sparsification log --- src/sparseml/transformers/sparsification/trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/sparseml/transformers/sparsification/trainer.py b/src/sparseml/transformers/sparsification/trainer.py index ff208eb1fff..d912c9808c5 100644 --- a/src/sparseml/transformers/sparsification/trainer.py +++ b/src/sparseml/transformers/sparsification/trainer.py @@ -516,6 +516,7 @@ def train(self, *args, **kwargs): output = super().train(*args, **kwargs) if applied: self.finalize_manager() + self.log_model_sparsification() return output
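
For context on the `log_model_sparsification()` call added in the hunk above, the sketch below shows one way an end-of-training sparsity summary could be assembled from the `get_prunable_layers` and `tensor_sparsity` helpers in `sparseml.pytorch.utils` (the same utilities imported in `utils/model.py` earlier in this series). This is an illustrative sketch only, not the patch's implementation: the real `Trainer.log_model_sparsification` in `sparsification/trainer.py` may compute and format its output differently, and `log_sparsification_summary` is a hypothetical helper name used here for demonstration.

# Illustrative sketch only: approximates the kind of end-of-training
# sparsity report such a hook could produce; not the actual implementation.
import logging

from torch.nn import Module

from sparseml.pytorch.utils import get_prunable_layers, tensor_sparsity

_LOGGER = logging.getLogger(__name__)


def log_sparsification_summary(model: Module):
    # walk the prunable (weighted) layers and log per-layer weight sparsity
    total_params = 0
    total_zero = 0
    for name, layer in get_prunable_layers(model):
        sparsity = tensor_sparsity(layer.weight).item()
        params = layer.weight.numel()
        total_params += params
        total_zero += round(sparsity * params)
        _LOGGER.info(f"{name}.weight: {sparsity:.4f} sparsity over {params} params")
    if total_params:
        # aggregate view across all prunable parameters
        _LOGGER.info(
            f"total prunable sparsity: {total_zero / total_params:.4f} "
            f"({total_zero}/{total_params} zeroed params)"
        )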