Changes from all commits
```diff
@@ -1,8 +1,8 @@
-import inspect
+import collections
+import inspect
 import math
 import os
-from typing import Any
+from typing import Any, Optional
 
 import numpy
 import torch
@@ -13,6 +13,7 @@
 from sparseml.pytorch.optim.optimizer import ScheduledOptimizer
 from sparseml.pytorch.utils import ModuleExporter, logger
 from trainer_qa import QuestionAnsweringTrainer
+from transformers.file_utils import RECIPE_NAME, WEIGHTS_NAME
 from transformers.modeling_outputs import QuestionAnsweringModelOutput
 from transformers.models.bert.modeling_bert import BertForQuestionAnswering
@@ -28,36 +29,74 @@ class SparseMLQATrainer(QuestionAnsweringTrainer):
     :param args, kwargs: arguments passed into parent class
     """
 
-    def __init__(self, recipe, teacher=None, distill_hardness=0.5, distill_temperature=2.0, *args, **kwargs):
+    def __init__(
+        self, model_name_or_path, recipes, teacher=None, distill_hardness=0.5, distill_temperature=2.0, *args, **kwargs
+    ):
         super().__init__(*args, **kwargs)
-        self.recipe = recipe
+        self.model_name_or_path = str(model_name_or_path)
+        self.recipes = [recipe for recipe in recipes if recipe]
         self.teacher = teacher
         self.distill_hardness = distill_hardness
         self.distill_temperature = distill_temperature
         self.criterion = torch.nn.CrossEntropyLoss()
 
-        self.manager = None
+        manager = None
+        modifiers = []
+        for recipe in self.recipes:
+            manager = ScheduledModifierManager.from_yaml(recipe, modifiers)
```
Reviewer: Is there a good reason you need to pass in an empty list for `modifiers` just to then set it to `manager.modifiers`?

Author: `manager.modifiers` holds the list of modifiers accumulated from the recipes that were "consumed" so far. We need to pass that list back in when constructing the manager with the next recipe.
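To spell out what the reply describes, here is a minimal standalone sketch of that composition loop. The import path for `ScheduledModifierManager` is an assumption (the diff does not show it); the `from_yaml(recipe, modifiers)` call matches the usage above.

```python
# Sketch of the recipe-composition pattern from __init__ above, pulled out as a
# standalone function for clarity; not a drop-in replacement for the PR code.
from sparseml.pytorch.optim import ScheduledModifierManager  # assumed import path


def compose_recipes(recipes):
    """Build one manager that schedules the modifiers from every recipe."""
    manager = None
    modifiers = []  # modifiers accumulated from recipes consumed so far
    for recipe in recipes:
        # Hand the previously collected modifiers back in so they are carried
        # over into the manager constructed for the next recipe.
        manager = ScheduledModifierManager.from_yaml(recipe, modifiers)
        # manager.modifiers now holds both the carried-over and the newly
        # parsed modifiers; feed the combined list into the next iteration.
        modifiers = manager.modifiers
    # The final manager therefore covers the union of all recipes' modifiers.
    return manager
```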
```diff
+            modifiers = manager.modifiers
+
+        self.manager = manager
 
         self.loggers = None
-        if self.recipe is not None:
+        if self.recipes is not None:
             loggers = []
             if "wandb" in self.args.report_to:
                 loggers.append(logger.WANDBLogger())
             self.loggers = loggers
 
+    def apply_recipes(self, epoch=0.0):
+        """
+        Apply recipes and sparsification related parameters to the model
+        """
+        if self.manager is not None:
+            org_state_dict = self.model.state_dict()
+            self.manager.initialize(self.model, epoch=epoch, loggers=self.loggers)
+            new_state_dict = self.model.state_dict()
+            new_params = [p for p in new_state_dict.keys() if p not in org_state_dict]
+
+            if os.path.isdir(self.model_name_or_path):
+                if os.path.isfile(os.path.join(self.model_name_or_path, WEIGHTS_NAME)):
+                    archive_file = os.path.join(self.model_name_or_path, WEIGHTS_NAME)
+                    state_dict = torch.load(archive_file, map_location="cpu")
+                    new_params_to_init = [p for p in new_params if p in state_dict.keys()]
+                    if new_params_to_init:
+                        # If we're here, the assumption is that all the new parameters introduced
+                        # by the recipes are available to be restored from the checkpoint---this is
+                        # the case of evaluating pruned or pruned-quantized models.
+                        # Otherwise, we're in use cases such as quantizing a block-pruned model, in
+                        # which new parameters need to be initialized and trained during the QAT process.
+                        _, missing_keys, unexpected_keys, _ = BertForQuestionAnswering._load_state_dict_into_model(
+                            self.model, state_dict, self.model_name_or_path, _fast_init=False
+                        )
+                        if missing_keys or unexpected_keys:
+                            raise RuntimeError(
+                                "Unexpected or missing keys detected when applying recipes to models\n"
+                                f"Missing keys: {missing_keys}\n"
+                                f"Unexpected keys: {unexpected_keys}\n"
+                            )
+
     def create_optimizer(self):
         """
         Create optimizer customized using SparseML
         """
         super().create_optimizer()
-        if self.recipe is None:
+        if not self.recipes:
             return
         steps_per_epoch = math.ceil(
             len(self.train_dataset) / (self.args.per_device_train_batch_size * self.args._n_gpu)
         )
-        self.manager = ScheduledModifierManager.from_yaml(self.recipe)
         self.args.num_train_epochs = float(self.manager.max_epochs)
         if hasattr(self, "scaler"):
             self.manager.initialize(self.model, epoch=0.0, loggers=self.loggers)
             self.scaler = self.manager.modify(
                 self.model, self.optimizer, steps_per_epoch=steps_per_epoch, wrap_optim=self.scaler
             )
@@ -70,7 +109,7 @@ def compute_loss(self, model, inputs, return_outputs=False):
         """
         Computing loss using teacher/student distillation
         """
-        if self.recipe is None or self.teacher is None:
+        if not self.recipes or self.teacher is None:
             return super().compute_loss(model, inputs, return_outputs=return_outputs)
 
         outputs = model(**inputs)
@@ -114,11 +153,25 @@ def compute_loss(self, model, inputs, return_outputs=False):
         loss = ((1 - self.distill_hardness) * label_loss) + (self.distill_hardness * teacher_loss)
         return (loss, outputs) if return_outputs else loss
 
+    def save_model(self, output_dir: Optional[str] = None):
+        """
+        Save model during or after training. The sparsification recipe will also be saved.
+        """
+        super().save_model(output_dir=output_dir)
+        if self.manager is not None:
+            self._save_recipe(output_dir=output_dir)
+
+    def _save_recipe(self, output_dir: Optional[str] = None):
+        output_dir = output_dir if output_dir is not None else self.args.output_dir
+        output_recipe_file = os.path.join(output_dir, RECIPE_NAME)
+        self.manager.save(output_recipe_file)
+
+
 class QuestionAnsweringModuleExporter(ModuleExporter):
     """
     Module exporter class for Question Answering
     """
 
     @classmethod
     def get_output_names(self, out: Any):
         if not isinstance(out, QuestionAnsweringModelOutput):
@@ -173,3 +226,44 @@ def export_model(model, dataloader, output_dir, num_exported_samples):
             num_samples += 1
             if num_samples >= num_exported_samples:
                 return
+
+
+def preprocess_state_dict(pretrained_model_name_or_path):
+    """
+    Restore original parameter names that were changed by the QAT process
+    """
+    state_dict = None
+    if pretrained_model_name_or_path is not None:
+        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+        if os.path.isdir(pretrained_model_name_or_path):
+            if os.path.isfile(os.path.join(pretrained_model_name_or_path, RECIPE_NAME)):
+                recipe = os.path.join(pretrained_model_name_or_path, RECIPE_NAME)
+                manager = ScheduledModifierManager.from_yaml(recipe)
+                modifiers = [m.__class__.__name__ for m in manager.modifiers]
+                is_qat_recipe = "QuantizationModifier" in modifiers
+                if os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
+                    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
+                    state_dict = torch.load(archive_file, map_location="cpu")
+                    removed_keys = (
+                        [key for key in state_dict if (key.endswith(".module.weight") or key.endswith(".module.bias"))]
+                        if is_qat_recipe
+                        else []
+                    )
+                    for key in removed_keys:
+                        new_key = key.replace(".module", "")
+                        state_dict[new_key] = state_dict[key]
+                        state_dict.pop(key)
+    return state_dict
+
+
+def load_recipe(pretrained_model_name_or_path):
+    """
+    Load recipe from the model directory
+    """
+    recipe = None
+    if pretrained_model_name_or_path is not None:
+        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+        if os.path.isdir(pretrained_model_name_or_path):
+            if os.path.isfile(os.path.join(pretrained_model_name_or_path, RECIPE_NAME)):
+                recipe = os.path.join(pretrained_model_name_or_path, RECIPE_NAME)
+    return recipe
```
Reviewer: Is there possible overlap between the recipes?

Author: Possible, but it won't break with the current use cases (eval and train of dense, pruned, and quantized BERT). IMO this would be better addressed in the manager-level code.
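As a follow-up to that exchange, a hypothetical helper like the sketch below (not part of this PR; the `ScheduledModifierManager` import path is assumed) could at least surface modifier types that appear in more than one recipe, reusing the same `m.__class__.__name__` inspection that `preprocess_state_dict` performs above. Whether duplicates should then be merged or rejected would still be a manager-level decision, as the author notes.

```python
# Hypothetical overlap check, for illustration only: report modifier class
# names that occur in more than one recipe.
import collections

from sparseml.pytorch.optim import ScheduledModifierManager  # assumed import path


def find_overlapping_modifier_types(recipes):
    """Return modifier class names that appear in more than one recipe."""
    counts = collections.Counter()
    for recipe in recipes:
        manager = ScheduledModifierManager.from_yaml(recipe)
        # Count each modifier type at most once per recipe.
        counts.update({m.__class__.__name__ for m in manager.modifiers})
    return [name for name, n in counts.items() if n > 1]
```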