diff --git a/examples/pytorch/question-answering/run_qa.py b/examples/pytorch/question-answering/run_qa.py
index 940e2ddc21ff..c356ffec7dc9 100755
--- a/examples/pytorch/question-answering/run_qa.py
+++ b/examples/pytorch/question-answering/run_qa.py
@@ -28,7 +28,12 @@
 from datasets import load_dataset, load_metric

 import transformers
-from sparseml_utils import SparseMLQATrainer, export_model
+from sparseml_utils import (
+    SparseMLQATrainer,
+    export_model,
+    preprocess_state_dict,
+    load_recipe
+)
 from transformers import (
     AutoConfig,
     AutoModelForQuestionAnswering,
@@ -311,6 +316,12 @@ def main():
         revision=model_args.model_revision,
         use_auth_token=True if model_args.use_auth_token else None,
     )
+
+    # Load and preprocess the state dict if the model already exists (in that case we are
+    # continuing to train or evaluating the model). The preprocessing step restores the original
+    # names of parameters changed by the QAT process.
+    state_dict = preprocess_state_dict(model_args.model_name_or_path)
+
     model = AutoModelForQuestionAnswering.from_pretrained(
         model_args.model_name_or_path,
         from_tf=bool(".ckpt" in model_args.model_name_or_path),
@@ -318,6 +329,7 @@ def main():
         cache_dir=model_args.cache_dir,
         revision=model_args.model_revision,
         use_auth_token=True if model_args.use_auth_token else None,
+        state_dict=state_dict
     )

     teacher_model = None
@@ -573,9 +585,14 @@ def post_processing_function(examples, features, predictions, stage="eval"):
     def compute_metrics(p: EvalPrediction):
         return metric.compute(predictions=p.predictions, references=p.label_ids)

+    # Load a possibly existing recipe and the new one passed in as a command-line argument
+    existing_recipe = load_recipe(model_args.model_name_or_path)
+    new_recipe = data_args.recipe
+
     # Initialize our Trainer
     trainer = SparseMLQATrainer(
-        data_args.recipe,
+        model_args.model_name_or_path,
+        [existing_recipe, new_recipe],
         teacher=teacher_model,
         distill_hardness=model_args.distill_hardness,
         distill_temperature=model_args.distill_temperature,
@@ -590,6 +607,11 @@ def compute_metrics(p: EvalPrediction):
         compute_metrics=compute_metrics,
     )

+    # Apply the recipes to the model. This is necessary because sparsification methods such as
+    # QAT modify the model graph with their own learnable parameters; those parameters are also
+    # restored/loaded into the model here.
+    trainer.apply_recipes()
+
     # Training
     if training_args.do_train:
         checkpoint = None
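
A minimal sketch (an editor's illustration, not part of the patch) of the checkpoint-resume flow the run_qa.py changes above wire up. The checkpoint path, recipe file, and TrainingArguments are hypothetical, and dataset/tokenizer plumbing is elided:

    from sparseml_utils import SparseMLQATrainer, load_recipe, preprocess_state_dict
    from transformers import AutoModelForQuestionAnswering, TrainingArguments

    model_path = "./pruned-bert-squad"  # hypothetical sparsified checkpoint directory

    # Undo the QAT renaming so from_pretrained can match parameter names
    state_dict = preprocess_state_dict(model_path)
    model = AutoModelForQuestionAnswering.from_pretrained(model_path, state_dict=state_dict)

    # Compose the recipe saved with the checkpoint (if any) with a newly supplied one
    recipes = [load_recipe(model_path), "qat_recipe.yaml"]

    trainer = SparseMLQATrainer(
        model_path, recipes, model=model, args=TrainingArguments(output_dir="out")
    )
    # Re-create the sparsification changes to the model graph (e.g. QAT observers)
    # before training or evaluating
    trainer.apply_recipes()
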
diff --git a/examples/pytorch/question-answering/sparseml_utils.py b/examples/pytorch/question-answering/sparseml_utils.py
index cc505174b263..a2fe04010004 100644
--- a/examples/pytorch/question-answering/sparseml_utils.py
+++ b/examples/pytorch/question-answering/sparseml_utils.py
@@ -1,8 +1,8 @@
-import inspect
 import collections
+import inspect
 import math
 import os
-from typing import Any
+from typing import Any, Optional

 import numpy
 import torch
@@ -13,6 +13,7 @@
 from sparseml.pytorch.optim.optimizer import ScheduledOptimizer
 from sparseml.pytorch.utils import ModuleExporter, logger
 from trainer_qa import QuestionAnsweringTrainer
+from transformers.file_utils import RECIPE_NAME, WEIGHTS_NAME
 from transformers.modeling_outputs import QuestionAnsweringModelOutput
 from transformers.models.bert.modeling_bert import BertForQuestionAnswering

@@ -28,36 +29,74 @@ class SparseMLQATrainer(QuestionAnsweringTrainer):
     :param args, kwargs: arguments passed into parent class
     """

-    def __init__(self, recipe, teacher=None, distill_hardness=0.5, distill_temperature=2.0, *args, **kwargs):
+    def __init__(
+        self, model_name_or_path, recipes, teacher=None, distill_hardness=0.5, distill_temperature=2.0, *args, **kwargs
+    ):
         super().__init__(*args, **kwargs)
-        self.recipe = recipe
+        self.model_name_or_path = str(model_name_or_path)
+        self.recipes = [recipe for recipe in recipes if recipe]
         self.teacher = teacher
         self.distill_hardness = distill_hardness
         self.distill_temperature = distill_temperature
         self.criterion = torch.nn.CrossEntropyLoss()
-        self.manager = None
+
+        manager = None
+        modifiers = []
+        for recipe in self.recipes:
+            manager = ScheduledModifierManager.from_yaml(recipe, modifiers)
+            modifiers = manager.modifiers
+        self.manager = manager
+
         self.loggers = None
-        if self.recipe is not None:
+        if self.recipes:
             loggers = []
             if "wandb" in self.args.report_to:
                 loggers.append(logger.WANDBLogger())
             self.loggers = loggers

+    def apply_recipes(self, epoch=0.0):
+        """
+        Apply the recipes and sparsification-related parameters to the model
+        """
+        if self.manager is not None:
+            org_state_dict = self.model.state_dict()
+            self.manager.initialize(self.model, epoch=epoch, loggers=self.loggers)
+            new_state_dict = self.model.state_dict()
+            new_params = [p for p in new_state_dict.keys() if p not in org_state_dict]
+
+            if os.path.isdir(self.model_name_or_path):
+                if os.path.isfile(os.path.join(self.model_name_or_path, WEIGHTS_NAME)):
+                    archive_file = os.path.join(self.model_name_or_path, WEIGHTS_NAME)
+                    state_dict = torch.load(archive_file, map_location="cpu")
+                    new_params_to_init = [p for p in new_params if p in state_dict.keys()]
+                    if new_params_to_init:
+                        # If we're here, the assumption is that all the new parameters introduced
+                        # by the recipes are available to be restored from the checkpoint; this is
+                        # the case when evaluating pruned or pruned-quantized models.
+                        # Otherwise, we're in a use case such as quantizing a block-pruned model,
+                        # where new parameters need to be initialized and trained during QAT.
+                        _, missing_keys, unexpected_keys, _ = BertForQuestionAnswering._load_state_dict_into_model(
+                            self.model, state_dict, self.model_name_or_path, _fast_init=False
+                        )
+                        if missing_keys or unexpected_keys:
+                            raise RuntimeError(
+                                "Unexpected or missing keys detected when applying recipes to models\n"
+                                f"Missing keys: {missing_keys}\n"
+                                f"Unexpected keys: {unexpected_keys}\n"
+                            )
+
     def create_optimizer(self):
         """
         Create an optimizer customized using SparseML
         """
         super().create_optimizer()
-        if self.recipe is None:
+        if not self.recipes:
             return

         steps_per_epoch = math.ceil(
             len(self.train_dataset) / (self.args.per_device_train_batch_size * self.args._n_gpu)
         )
-        self.manager = ScheduledModifierManager.from_yaml(self.recipe)
         self.args.num_train_epochs = float(self.manager.max_epochs)
         if hasattr(self, "scaler"):
-            self.manager.initialize(self.model, epoch=0.0, loggers=self.loggers)
             self.scaler = self.manager.modify(
                 self.model, self.optimizer, steps_per_epoch=steps_per_epoch, wrap_optim=self.scaler
             )
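
The loop in `__init__` above composes several recipes into a single manager by threading each intermediate manager's modifiers into the next `from_yaml` call. The same pattern in isolation (an editor's sketch, not part of the patch, with hypothetical recipe paths):

    from sparseml.pytorch.optim.manager import ScheduledModifierManager

    recipes = ["checkpoint/recipe.yaml", "qat_recipe.yaml"]  # hypothetical files

    manager = None
    modifiers = []
    for recipe in recipes:
        # Each call folds the modifiers collected so far into the next manager, so the
        # final manager schedules the union of all the recipes' modifiers
        manager = ScheduledModifierManager.from_yaml(recipe, modifiers)
        modifiers = manager.modifiers
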
@@ -70,7 +109,7 @@ def compute_loss(self, model, inputs, return_outputs=False):
         """
         Compute loss using teacher/student distillation
         """
-        if self.recipe is None or self.teacher is None:
+        if not self.recipes or self.teacher is None:
             return super().compute_loss(model, inputs, return_outputs=return_outputs)

         outputs = model(**inputs)
@@ -114,11 +153,25 @@ def compute_loss(self, model, inputs, return_outputs=False):
         loss = ((1 - self.distill_hardness) * label_loss) + (self.distill_hardness * teacher_loss)
         return (loss, outputs) if return_outputs else loss

+    def save_model(self, output_dir: Optional[str] = None):
+        """
+        Save the model during or after training. The sparsification recipe will also be saved.
+        """
+        super().save_model(output_dir=output_dir)
+        if self.manager is not None:
+            self._save_recipe(output_dir=output_dir)
+
+    def _save_recipe(self, output_dir: Optional[str] = None):
+        output_dir = output_dir if output_dir is not None else self.args.output_dir
+        output_recipe_file = os.path.join(output_dir, RECIPE_NAME)
+        self.manager.save(output_recipe_file)
+

 class QuestionAnsweringModuleExporter(ModuleExporter):
     """
     Module exporter class for Question Answering
     """
+
     @classmethod
     def get_output_names(self, out: Any):
         if not isinstance(out, QuestionAnsweringModelOutput):
@@ -173,3 +226,44 @@ def export_model(model, dataloader, output_dir, num_exported_samples):
             num_samples += 1
             if num_samples >= num_exported_samples:
                 return
+
+
+def preprocess_state_dict(pretrained_model_name_or_path):
+    """
+    Restore the original parameter names that were changed by the QAT process
+    """
+    state_dict = None
+    if pretrained_model_name_or_path is not None:
+        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+        if os.path.isdir(pretrained_model_name_or_path):
+            if os.path.isfile(os.path.join(pretrained_model_name_or_path, RECIPE_NAME)):
+                recipe = os.path.join(pretrained_model_name_or_path, RECIPE_NAME)
+                manager = ScheduledModifierManager.from_yaml(recipe)
+                modifiers = [m.__class__.__name__ for m in manager.modifiers]
+                is_qat_recipe = "QuantizationModifier" in modifiers
+                if os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
+                    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
+                    state_dict = torch.load(archive_file, map_location="cpu")
+                    removed_keys = (
+                        [key for key in state_dict if (key.endswith(".module.weight") or key.endswith(".module.bias"))]
+                        if is_qat_recipe
+                        else []
+                    )
+                    for key in removed_keys:
+                        new_key = key.replace(".module", "")
+                        state_dict[new_key] = state_dict[key]
+                        state_dict.pop(key)
+    return state_dict
+
+
+def load_recipe(pretrained_model_name_or_path):
+    """
+    Load a recipe from the model directory, if present
+    """
+    recipe = None
+    if pretrained_model_name_or_path is not None:
+        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+        if os.path.isdir(pretrained_model_name_or_path):
+            if os.path.isfile(os.path.join(pretrained_model_name_or_path, RECIPE_NAME)):
+                recipe = os.path.join(pretrained_model_name_or_path, RECIPE_NAME)
+    return recipe
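
What `preprocess_state_dict` undoes, in isolation: after QAT, the wrapped submodules save their parameters under an extra `.module` level, and the helper strips that level so the keys match the dense model again. A toy sketch (editor's illustration, not part of the patch) with made-up keys:

    import torch

    # Checkpoint keys as a QAT run would save them (made up for illustration)
    state_dict = {
        "bert.encoder.layer.0.attention.self.query.module.weight": torch.zeros(1),
        "bert.encoder.layer.0.attention.self.query.module.bias": torch.zeros(1),
    }

    for key in [k for k in state_dict if k.endswith((".module.weight", ".module.bias"))]:
        state_dict[key.replace(".module", "")] = state_dict.pop(key)

    # Keys now end in ".weight"/".bias" with no ".module", matching the dense model
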
diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py
index dc1af32f3b36..75e0f819694c 100644
--- a/src/transformers/file_utils.py
+++ b/src/transformers/file_utils.py
@@ -220,6 +220,7 @@
 CONFIG_NAME = "config.json"
 FEATURE_EXTRACTOR_NAME = "preprocessor_config.json"
 MODEL_CARD_NAME = "modelcard.json"
+RECIPE_NAME = "recipe.yaml"

 SENTENCEPIECE_UNDERLINE = "▁"
 SPIECE_UNDERLINE = SENTENCEPIECE_UNDERLINE  # Kept for backward compatibility
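
With the `RECIPE_NAME` constant in place, a checkpoint directory written by `SparseMLQATrainer.save_model` carries its recipe next to the weights. A quick sanity-check sketch (editor's illustration, not part of the patch; the output directory name is hypothetical):

    import os

    from transformers.file_utils import RECIPE_NAME, WEIGHTS_NAME

    ckpt = "out"  # hypothetical --output_dir written by the trainer
    assert os.path.isfile(os.path.join(ckpt, WEIGHTS_NAME))  # pytorch_model.bin
    assert os.path.isfile(os.path.join(ckpt, RECIPE_NAME))   # recipe.yaml

On the next run, the presence of recipe.yaml is exactly what `load_recipe` and `preprocess_state_dict` above key off to detect a sparsified checkpoint.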