From 712b59f4cb0147754c4c8cfaa9d85833f3072c47 Mon Sep 17 00:00:00 2001 From: spacemanidol Date: Thu, 29 Jul 2021 14:57:59 -0500 Subject: [PATCH 01/18] Setting up Text Classification for SparseML --- examples/pytorch/question-answering/run_qa.py | 29 +- .../pytorch/text-classification/run_glue.py | 72 ++++- .../text-classification/sparseml_utils.py | 247 ++++++++++++++++++ 3 files changed, 341 insertions(+), 7 deletions(-) create mode 100644 examples/pytorch/text-classification/sparseml_utils.py diff --git a/examples/pytorch/question-answering/run_qa.py b/examples/pytorch/question-answering/run_qa.py index c356ffec7dc9..a1494651ffbc 100755 --- a/examples/pytorch/question-answering/run_qa.py +++ b/examples/pytorch/question-answering/run_qa.py @@ -72,7 +72,7 @@ class ModelArguments: default=2.0, metadata={"help": "Temperature applied to teacher softmax for distillation."} ) distill_hardness: Optional[float] = field( - default=1.0, metadata={"help": "Proportion of loss coming from teacher model."} + default=0.5, metadata={"help": "Proportion of loss coming from teacher model."} ) config_name: Optional[str] = field( default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} @@ -218,6 +218,33 @@ def __post_init__(self): extension = self.test_file.split(".")[-1] assert extension in ["csv", "json"], "`test_file` should be a csv or a json file." +def convert_example_to_features(example, tokenizer, max_seq_length, sentence1_key, sentence2_key): + tokens = [] + segment_ids = [] + tokens.append("[CLS]") + segment_ids.append(0) + for t in tokenizer.tokenize(example[sentence1_key])[:int(max_seq_length/2)]: + tokens.append(t) + segment_ids.append(0) + tokens.append("[SEP]") + segment_ids.append(0) + if sentence1_key != None: + for t in tokenizer.tokenize(example[sentence2_key])[:int(max_seq_length/2)]: + tokens.append(t) + segment_ids.append(0) + tokens.append("[SEP]") + segment_ids.append(1) + input_ids = tokenizer.convert_tokens_to_ids(tokens) + input_mask = [1] * len(input_ids) + while len(input_ids) < max_seq_length: + input_ids.append(0) + input_mask.append(0) + segment_ids.append(0) + return ( + torch.from_numpy(np.array([np.array(input_ids, dtype=np.int64)])), + torch.from_numpy(np.array([np.array(input_mask, dtype=np.int64)])), + torch.from_numpy(np.array([np.array(segment_ids, dtype=np.int64)])), + ) def main(): # See all possible arguments in src/transformers/training_args.py diff --git a/examples/pytorch/text-classification/run_glue.py b/examples/pytorch/text-classification/run_glue.py index 1b08def9c62f..92b8a05138fa 100755 --- a/examples/pytorch/text-classification/run_glue.py +++ b/examples/pytorch/text-classification/run_glue.py @@ -40,13 +40,16 @@ default_data_collator, set_seed, ) +from sparseml_utils import ( + SparseMLGLUETrainer, + export_model, + preprocess_state_dict, + load_recipe +) from transformers.trainer_utils import get_last_checkpoint from transformers.utils import check_min_version -# Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.7.0.dev0") - task_to_keys = { "cola": ("sentence", None), "mnli": ("premise", "hypothesis"), @@ -71,7 +74,17 @@ class DataTrainingArguments: into argparse arguments to be able to specify them on the command line. 
""" - + recipe: Optional[str] = field( + default=None, + metadata={"help": "Path to a SparseML sparsification recipe, see https://github.com/neuralmagic/sparseml " + "for more information"}, + ) + onnx_export_path: Optional[str] = field( + default=None, metadata={"help": "The filename and path which will be where onnx model is outputed"} + ) + num_exported_samples: Optional[int] = field( + default=20, metadata={"help": "Number of exported samples, default to 20"} + ) task_name: Optional[str] = field( default=None, metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())}, @@ -155,6 +168,15 @@ class ModelArguments: model_name_or_path: str = field( metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} ) + distill_teacher: Optional[str] = field( + default=None, metadata={"help": "Teacher model which needs to be a trained QA model"} + ) + distill_temperature: Optional[float] = field( + default=2.0, metadata={"help": "Temperature applied to teacher softmax for distillation."} + ) + distill_hardness: Optional[float] = field( + default=1.0, metadata={"help": "Proportion of loss coming from teacher model."} + ) config_name: Optional[str] = field( default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} ) @@ -305,6 +327,13 @@ def main(): # # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently # download model & vocab. + + # Load and preprocess the state dict if the model existed (in this case we continue to train or + # evaluate the model). The preprocessing step is to restore names of parameters changed by + # QAT process + state_dict = preprocess_state_dict(model_args.model_name_or_path) + + config = AutoConfig.from_pretrained( model_args.config_name if model_args.config_name else model_args.model_name_or_path, num_labels=num_labels, @@ -327,8 +356,19 @@ def main(): cache_dir=model_args.cache_dir, revision=model_args.model_revision, use_auth_token=True if model_args.use_auth_token else None, + state_dict=state_dict, ) + teacher_model = None + if model_args.distill_teacher is not None: + teacher_model = AutoModelForSequenceClassification.from_pretrained( + model_args.distill_teacher, + from_tf=bool(".ckpt" in model_args.distill_teacher), + cache_dir=model_args.cache_dir, + ) + teacher_model_parameters = filter(lambda p: p.requires_grad, teacher_model.parameters()) + params = sum([numpy.prod(p.size()) for p in teacher_model_parameters]) + logger.info("Teacher Model has %s parameters", params) # Preprocessing the datasets if data_args.task_name is not None: sentence1_key, sentence2_key = task_to_keys[data_args.task_name] @@ -445,17 +485,32 @@ def compute_metrics(p: EvalPrediction): else: data_collator = None + # Load possible existing recipe and new one passed in through command argument + existing_recipe = load_recipe(model_args.model_name_or_path) + new_recipe = data_args.recipe + # Initialize our Trainer - trainer = Trainer( + trainer = SparseMLGLUETrainer( + model_args.model_name_or_path, + [existing_recipe, new_recipe], + teacher=teacher_model, + distill_hardness=model_args.distill_hardness, + distill_temperature=model_args.distill_temperature, model=model, args=training_args, train_dataset=train_dataset if training_args.do_train else None, eval_dataset=eval_dataset if training_args.do_eval else None, - compute_metrics=compute_metrics, tokenizer=tokenizer, data_collator=data_collator, + 
+        post_process_function=post_processing_function,
+        compute_metrics=compute_metrics,
     )
 
+    # Apply recipes to the model. This is necessary because sparsification
+    # methods such as QAT modify the model graph with their own learnable
+    # parameters, which are also restored/loaded into the model.
+    trainer.apply_recipes()
+
     # Training
     if training_args.do_train:
         checkpoint = None
@@ -536,6 +591,11 @@ def compute_metrics(p: EvalPrediction):
 
         trainer.push_to_hub(**kwargs)
 
+    if data_args.onnx_export_path:
+        logger.info("*** Export to ONNX ***")
+        eval_dataloader = trainer.get_eval_dataloader(eval_dataset)
+        export_model(model, eval_dataloader, data_args.onnx_export_path, data_args.num_exported_samples)
+
 
 def _mp_fn(index):
     # For xla_spawn (TPUs)
diff --git a/examples/pytorch/text-classification/sparseml_utils.py b/examples/pytorch/text-classification/sparseml_utils.py
new file mode 100644
index 000000000000..5ce56f4c1dfd
--- /dev/null
+++ b/examples/pytorch/text-classification/sparseml_utils.py
@@ -0,0 +1,247 @@
+import collections
+import inspect
+import math
+import os
+from typing import Any, Optional
+
+import numpy
+import torch
+import torch.nn.functional as F
+
+import onnxruntime
+from sparseml.pytorch.optim.manager import ScheduledModifierManager
+from sparseml.pytorch.optim.optimizer import ScheduledOptimizer
+from sparseml.pytorch.utils import ModuleExporter, logger
+from transformers import Trainer
+from transformers.file_utils import RECIPE_NAME, WEIGHTS_NAME
+from transformers.modeling_outputs import SequenceClassifierOutput
+from transformers.models.bert.modeling_bert import BertForSequenceClassification
+
+
+class SparseMLGLUErainer(Trainer):
+    """
+    GLUE trainer with SparseML integration
+
+    :param recipe: recipe for model sparsification
+    :param teacher: teacher model for distillation
+    :param distill_hardness: ratio of loss by teacher targets (between 0 and 1)
+    :param distill_temperature: temperature for distillation
+    :param args, kwargs: arguments passed into parent class
+    """
+
+    def __init__(
+        self, model_name_or_path, recipes, teacher=None, distill_hardness=0.5, distill_temperature=2.0, *args, **kwargs
+    ):
+        super().__init__(*args, **kwargs)
+        self.model_name_or_path = str(model_name_or_path)
+        self.recipes = [recipe for recipe in recipes if recipe]
+        self.teacher = teacher
+        self.distill_hardness = distill_hardness
+        self.distill_temperature = distill_temperature
+        self.criterion = torch.nn.CrossEntropyLoss()
+
+        manager = None
+        modifiers = []
+        for recipe in self.recipes:
+            manager = ScheduledModifierManager.from_yaml(recipe, modifiers)
+            modifiers = manager.modifiers
+        self.manager = manager
+
+        self.loggers = None
+        if self.recipes is not None:
+            loggers = []
+            if "wandb" in self.args.report_to:
+                loggers.append(logger.WANDBLogger())
+            self.loggers = loggers
+
+    def apply_recipes(self, epoch=0.0):
+        """
+        Apply recipes and sparsification related parameters to the model
+        """
+        if self.manager is not None:
+            org_state_dict = self.model.state_dict()
+            self.manager.initialize(self.model, epoch=epoch, loggers=self.loggers)
+            new_state_dict = self.model.state_dict()
+            new_params = [p for p in new_state_dict.keys() if p not in org_state_dict]
+
+            if os.path.isdir(self.model_name_or_path):
+                if os.path.isfile(os.path.join(self.model_name_or_path, WEIGHTS_NAME)):
+                    archive_file = os.path.join(self.model_name_or_path, WEIGHTS_NAME)
+                    state_dict = torch.load(archive_file, map_location="cpu")
+                    new_params_to_init = [p for p in new_params if p in state_dict.keys()]
+                    if
new_params_to_init: + # If we're here, the assumption is that all the new parameters introduced + # by the recipes are available to be restore from the checkpoint---this is + # case of evaluating pruned or pruned quantized models + # Otherwise, we're in use cases such as quantizing a block pruned model in which + # new parameters need to be initialized and trained during the QAT process + _, missing_keys, unexpected_keys, _ = BertForSequenceClassification._load_state_dict_into_model( + self.model, state_dict, self.model_name_or_path, _fast_init=False + ) + if missing_keys or unexpected_keys: + raise RuntimeError( + "Unexpected or missing keys detected when applying recipes to models\n" + f"Missing keys: {missing_keys}\n" + f"Unexpected keys: {unexpected_keys}\n" + ) + + def create_optimizer(self): + """ + Create optimizer customized using SparseML + """ + super().create_optimizer() + if not self.recipes: + return + steps_per_epoch = math.ceil( + len(self.train_dataset) / (self.args.per_device_train_batch_size * self.args._n_gpu) + ) + self.args.num_train_epochs = float(self.manager.max_epochs) + if hasattr(self, "scaler"): + self.scaler = self.manager.modify( + self.model, self.optimizer, steps_per_epoch=steps_per_epoch, wrap_optim=self.scaler + ) + else: + self.optimizer = ScheduledOptimizer( + self.optimizer, self.model, self.manager, steps_per_epoch=steps_per_epoch, loggers=self.loggers + ) + + def compute_loss(self, model, inputs, return_outputs=False): + """ + Computing loss using teacher/student distillation + """ + if not self.recipes or self.teacher is None: + return super().compute_loss(model, inputs, return_outputs=return_outputs) + + outputs = model(**inputs) + if self.teacher is None: + loss = outputs["loss"] + else: + input_device = inputs["input_ids"].device + label_loss = outputs["loss"] + self.teacher = self.teacher.to(input_device) + logits_student = outputs["logits"] + with torch.no_grad(): + teacher_output = self.teacher( + input_ids=inputs["input_ids"], + token_type_ids=inputs["token_type_ids"], + attention_mask=inputs["attention_mask"], + ) + logits_teacher = teacher_outputs["logits"] + loss_distill = F.kl_div( input=logits_student, target=logits_teacher, reduction="batchmean",) * (self.temperature ** 2) + loss = ((1 - self.distill_hardness) * label_loss) + (self.distill_hardness * teacher_loss) + return (loss, outputs) if return_outputs else loss + + def save_model(self, output_dir: Optional[str] = None): + """ + Save model during or after training. The sparsification recipe will also be saved. 
+ """ + super().save_model(output_dir=output_dir) + if self.manager is not None: + self._save_recipe(output_dir=output_dir) + + def _save_recipe(self, output_dir: Optional[str] = None): + output_dir = output_dir if output_dir is not None else self.args.output_dir + output_recipe_file = os.path.join(output_dir, RECIPE_NAME) + self.manager.save(output_recipe_file) + + +class GLUEModuleExporter(ModuleExporter): + """ + Module exporter class for Sequence Classification + """ + + @classmethod + def get_output_names(self, out: Any): + if not isinstance(out, SequenceClassifierOutput): + raise ValueError("Expected SequenceClassifierOutput, got {type(out)}") + expected = ["logits"] + if numpy.any([name for name in expected if name not in out]): + raise ValueError("Expected output names not found in model output") + return expected + + +def export_model(model, dataloader, output_dir, num_exported_samples): + """ + Export a trained model to ONNX + :param model: trained model + :param dataloader: dataloader to get sample batch + :param output_dir: output directory for ONNX model + """ + exporter = GLUEModuleExporter(model, output_dir=output_dir) + + sess = None + num_samples = 0 + + sample_inputs = os.path.join(output_dir, "sample-inputs") + sample_outputs = os.path.join(output_dir, "sample-outputs") + os.makedirs(sample_inputs, exist_ok=True) + os.makedirs(sample_outputs, exist_ok=True) + + forward_args_spec = inspect.getfullargspec(BertForSequenceClassification.forward) + for _, sample_batch in enumerate(dataloader): + if sess is None: + one_sample_input = collections.OrderedDict( + [(f, sample_batch[f][0].reshape(1, -1)) for f in forward_args_spec.args if f in sample_batch] + ) + + try: + exporter.export_onnx(sample_batch=one_sample_input, convert_qat=True) + onnx_file = os.path.join(output_dir, "model.onnx") + except Exception: + raise RuntimeError("Error exporting ONNX models and/or inputs/outputs") + + sess = onnxruntime.InferenceSession(onnx_file) + + input_names = list(sample_batch.keys()) + output_names = [o.name for o in sess.get_outputs()] + for input_vals in zip(*sample_batch.values()): + input_feed = {k: v.numpy() for k, v in zip(input_names, input_vals)} + output_vals = sess.run(output_names, {k: input_feed[k].reshape(1, -1) for k in input_feed}) + output_dict = {name: numpy.squeeze(val) for name, val in zip(output_names, output_vals)} + file_idx = f"{num_samples}".zfill(4) + numpy.savez(f"{sample_inputs}/inp-{file_idx}.npz", **input_feed) + numpy.savez(f"{sample_outputs}/out-{file_idx}.npz", **output_dict) + num_samples += 1 + if num_samples >= num_exported_samples: + return + + +def preprocess_state_dict(pretrained_model_name_or_path): + """ + Restore original parameter names that were changed by QAT process + """ + state_dict = None + if pretrained_model_name_or_path is not None: + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + if os.path.isdir(pretrained_model_name_or_path): + if os.path.isfile(os.path.join(pretrained_model_name_or_path, RECIPE_NAME)): + recipe = os.path.join(pretrained_model_name_or_path, RECIPE_NAME) + manager = ScheduledModifierManager.from_yaml(recipe) + modifiers = [m.__class__.__name__ for m in manager.modifiers] + is_qat_recipe = "QuantizationModifier" in modifiers + if os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)): + archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) + state_dict = torch.load(archive_file, map_location="cpu") + removed_keys = ( + [key for key in state_dict if 
(key.endswith(".module.weight") or key.endswith(".module.bias"))] + if is_qat_recipe + else [] + ) + for key in removed_keys: + new_key = key.replace(".module", "") + state_dict[new_key] = state_dict[key] + state_dict.pop(key) + return state_dict + + +def load_recipe(pretrained_model_name_or_path): + """ + Load recipe from the model directory + """ + recipe = None + if pretrained_model_name_or_path is not None: + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + if os.path.isdir(pretrained_model_name_or_path): + if os.path.isfile(os.path.join(pretrained_model_name_or_path, RECIPE_NAME)): + recipe = os.path.join(pretrained_model_name_or_path, RECIPE_NAME) + return recipe From 64579a359896af8cfa3cc44c7517c70101e25cdd Mon Sep 17 00:00:00 2001 From: spacemanidol Date: Thu, 29 Jul 2021 15:30:37 -0500 Subject: [PATCH 02/18] removing unneeded function --- examples/pytorch/question-answering/run_qa.py | 28 ------------------- 1 file changed, 28 deletions(-) diff --git a/examples/pytorch/question-answering/run_qa.py b/examples/pytorch/question-answering/run_qa.py index a1494651ffbc..f0633f3778df 100755 --- a/examples/pytorch/question-answering/run_qa.py +++ b/examples/pytorch/question-answering/run_qa.py @@ -218,34 +218,6 @@ def __post_init__(self): extension = self.test_file.split(".")[-1] assert extension in ["csv", "json"], "`test_file` should be a csv or a json file." -def convert_example_to_features(example, tokenizer, max_seq_length, sentence1_key, sentence2_key): - tokens = [] - segment_ids = [] - tokens.append("[CLS]") - segment_ids.append(0) - for t in tokenizer.tokenize(example[sentence1_key])[:int(max_seq_length/2)]: - tokens.append(t) - segment_ids.append(0) - tokens.append("[SEP]") - segment_ids.append(0) - if sentence1_key != None: - for t in tokenizer.tokenize(example[sentence2_key])[:int(max_seq_length/2)]: - tokens.append(t) - segment_ids.append(0) - tokens.append("[SEP]") - segment_ids.append(1) - input_ids = tokenizer.convert_tokens_to_ids(tokens) - input_mask = [1] * len(input_ids) - while len(input_ids) < max_seq_length: - input_ids.append(0) - input_mask.append(0) - segment_ids.append(0) - return ( - torch.from_numpy(np.array([np.array(input_ids, dtype=np.int64)])), - torch.from_numpy(np.array([np.array(input_mask, dtype=np.int64)])), - torch.from_numpy(np.array([np.array(segment_ids, dtype=np.int64)])), - ) - def main(): # See all possible arguments in src/transformers/training_args.py # or by passing the --help flag to this script. 
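For reference, the distillation objective these trainers wire up blends the hard-label loss with a temperature-softened KL-divergence term against the teacher, which the later patches in this series implement. A minimal sketch of that computation, with illustrative names (student_logits, teacher_logits, and label_loss are placeholders, not identifiers from these patches):

import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, label_loss, hardness=0.5, temperature=2.0):
    # F.kl_div expects log-probabilities for the input and probabilities for
    # the target, so both sets of logits are softened by the temperature first.
    soft_loss = F.kl_div(
        input=F.log_softmax(student_logits / temperature, dim=-1),
        target=F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * (temperature ** 2)
    # hardness controls how much of the final loss comes from the teacher.
    return (1.0 - hardness) * label_loss + hardness * soft_loss

Scaling the soft term by temperature squared keeps its gradient magnitude comparable to the hard-label term as the temperature changes.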
From 3f1dd25ff52d92915fa485b517d3d325f73fc570 Mon Sep 17 00:00:00 2001
From: spacemanidol
Date: Mon, 2 Aug 2021 13:07:11 -0500
Subject: [PATCH 03/18] Update sparseml_utils.py

---
 examples/pytorch/text-classification/sparseml_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/pytorch/text-classification/sparseml_utils.py b/examples/pytorch/text-classification/sparseml_utils.py
index 5ce56f4c1dfd..2e81e07e463e 100644
--- a/examples/pytorch/text-classification/sparseml_utils.py
+++ b/examples/pytorch/text-classification/sparseml_utils.py
@@ -18,7 +18,7 @@
 from transformers.models.bert.modeling_bert import BertForSequenceClassification
 
 
-class SparseMLGLUErainer(Trainer):
+class SparseMLGLUETrainer(Trainer):
     """
     GLUE trainer with SparseML integration
 
From 0fbb74fbda5495e0921701339fa1953533c4d1bc Mon Sep 17 00:00:00 2001
From: spacemanidol
Date: Mon, 2 Aug 2021 13:09:27 -0500
Subject: [PATCH 04/18] Update run_glue.py

---
 examples/pytorch/text-classification/run_glue.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/examples/pytorch/text-classification/run_glue.py b/examples/pytorch/text-classification/run_glue.py
index 92b8a05138fa..503942bdf0ed 100755
--- a/examples/pytorch/text-classification/run_glue.py
+++ b/examples/pytorch/text-classification/run_glue.py
@@ -502,7 +502,6 @@ def compute_metrics(p: EvalPrediction):
         eval_dataset=eval_dataset if training_args.do_eval else None,
         tokenizer=tokenizer,
         data_collator=data_collator,
-        post_process_function=post_processing_function,
         compute_metrics=compute_metrics,
     )
 
From 0987ce2f15943f845a6df51f749da48078139360 Mon Sep 17 00:00:00 2001
From: spacemanidol
Date: Mon, 2 Aug 2021 15:19:06 -0500
Subject: [PATCH 05/18] Update run_glue.py adding wandb and changing numpy

---
 examples/pytorch/text-classification/run_glue.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/examples/pytorch/text-classification/run_glue.py b/examples/pytorch/text-classification/run_glue.py
index 503942bdf0ed..1c74bcb8e8ac 100755
--- a/examples/pytorch/text-classification/run_glue.py
+++ b/examples/pytorch/text-classification/run_glue.py
@@ -23,6 +23,7 @@
 from dataclasses import dataclass, field
 from typing import Optional
 
+import wandb
 import numpy as np
 from datasets import load_dataset, load_metric
 
@@ -367,7 +368,7 @@ def main():
             cache_dir=model_args.cache_dir,
         )
         teacher_model_parameters = filter(lambda p: p.requires_grad, teacher_model.parameters())
-        params = sum([numpy.prod(p.size()) for p in teacher_model_parameters])
+        params = sum([np.prod(p.size()) for p in teacher_model_parameters])
         logger.info("Teacher Model has %s parameters", params)
     # Preprocessing the datasets
     if data_args.task_name is not None:
From 6af8e1dd4c5c16ba72bb5a7e6ead03582d0012a2 Mon Sep 17 00:00:00 2001
From: spacemanidol
Date: Mon, 2 Aug 2021 17:24:10 -0500
Subject: [PATCH 06/18] Update sparseml_utils.py

---
 examples/pytorch/text-classification/sparseml_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/pytorch/text-classification/sparseml_utils.py b/examples/pytorch/text-classification/sparseml_utils.py
index 2e81e07e463e..e35c7d6c1eb4 100644
--- a/examples/pytorch/text-classification/sparseml_utils.py
+++ b/examples/pytorch/text-classification/sparseml_utils.py
@@ -121,7 +121,7 @@ def compute_loss(self, model, inputs, return_outputs=False):
             self.teacher = self.teacher.to(input_device)
             logits_student = outputs["logits"]
             with torch.no_grad():
-                teacher_output = self.teacher(
+                teacher_outputs = self.teacher(
input_ids=inputs["input_ids"], token_type_ids=inputs["token_type_ids"], attention_mask=inputs["attention_mask"], From d841e40f264f4b88d80c67be56bac90beb6d3cc8 Mon Sep 17 00:00:00 2001 From: spacemanidol Date: Mon, 2 Aug 2021 17:25:26 -0500 Subject: [PATCH 07/18] Update run_glue.py fixed mention of QA --- examples/pytorch/text-classification/run_glue.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/pytorch/text-classification/run_glue.py b/examples/pytorch/text-classification/run_glue.py index 1c74bcb8e8ac..adecfb8617f1 100755 --- a/examples/pytorch/text-classification/run_glue.py +++ b/examples/pytorch/text-classification/run_glue.py @@ -170,7 +170,7 @@ class ModelArguments: metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} ) distill_teacher: Optional[str] = field( - default=None, metadata={"help": "Teacher model which needs to be a trained QA model"} + default=None, metadata={"help": "Teacher model which needs to be a trained text classification model"} ) distill_temperature: Optional[float] = field( default=2.0, metadata={"help": "Temperature applied to teacher softmax for distillation."} From d90900948ee45ef8290c59c3c341ddb2945246bb Mon Sep 17 00:00:00 2001 From: spacemanidol Date: Fri, 6 Aug 2021 14:52:14 -0500 Subject: [PATCH 08/18] Update sparseml_utils.py updating to match refactored code --- .../text-classification/sparseml_utils.py | 196 +----------------- 1 file changed, 6 insertions(+), 190 deletions(-) diff --git a/examples/pytorch/text-classification/sparseml_utils.py b/examples/pytorch/text-classification/sparseml_utils.py index e35c7d6c1eb4..17b128a96527 100644 --- a/examples/pytorch/text-classification/sparseml_utils.py +++ b/examples/pytorch/text-classification/sparseml_utils.py @@ -1,24 +1,16 @@ -import collections -import inspect -import math -import os -from typing import Any, Optional +from typing import Any import numpy import torch import torch.nn.functional as F -import onnxruntime -from sparseml.pytorch.optim.manager import ScheduledModifierManager -from sparseml.pytorch.optim.optimizer import ScheduledOptimizer -from sparseml.pytorch.utils import ModuleExporter, logger -from transformers import Trainer -from transformers.file_utils import RECIPE_NAME, WEIGHTS_NAME +from sparseml.pytorch.utils import ModuleExporter + from transformers.modeling_outputs import SequenceClassifierOutput -from transformers.models.bert.modeling_bert import BertForSequenceClassification +from transformers.sparse import SparseMLTrainer -class SparseMLGLUETrainer(Trainer): +class SparseMLGLUETrainer(SparseMLTrainer): """ GLUE trainer with SparseML integration @@ -29,82 +21,6 @@ class SparseMLGLUETrainer(Trainer): :param args, kwargs: arguments passed into parent class """ - def __init__( - self, model_name_or_path, recipes, teacher=None, distill_hardness=0.5, distill_temperature=2.0, *args, **kwargs - ): - super().__init__(*args, **kwargs) - self.model_name_or_path = str(model_name_or_path) - self.recipes = [recipe for recipe in recipes if recipe] - self.teacher = teacher - self.distill_hardness = distill_hardness - self.distill_temperature = distill_temperature - self.criterion = torch.nn.CrossEntropyLoss() - - manager = None - modifiers = [] - for recipe in self.recipes: - manager = ScheduledModifierManager.from_yaml(recipe, modifiers) - modifiers = manager.modifiers - self.manager = manager - - self.loggers = None - if self.recipes is not None: - loggers = [] - if "wandb" in self.args.report_to: - 
loggers.append(logger.WANDBLogger()) - self.loggers = loggers - - def apply_recipes(self, epoch=0.0): - """ - Apply recipes and sparsification related parameters to the model - """ - if self.manager is not None: - org_state_dict = self.model.state_dict() - self.manager.initialize(self.model, epoch=epoch, loggers=self.loggers) - new_state_dict = self.model.state_dict() - new_params = [p for p in new_state_dict.keys() if p not in org_state_dict] - - if os.path.isdir(self.model_name_or_path): - if os.path.isfile(os.path.join(self.model_name_or_path, WEIGHTS_NAME)): - archive_file = os.path.join(self.model_name_or_path, WEIGHTS_NAME) - state_dict = torch.load(archive_file, map_location="cpu") - new_params_to_init = [p for p in new_params if p in state_dict.keys()] - if new_params_to_init: - # If we're here, the assumption is that all the new parameters introduced - # by the recipes are available to be restore from the checkpoint---this is - # case of evaluating pruned or pruned quantized models - # Otherwise, we're in use cases such as quantizing a block pruned model in which - # new parameters need to be initialized and trained during the QAT process - _, missing_keys, unexpected_keys, _ = BertForSequenceClassification._load_state_dict_into_model( - self.model, state_dict, self.model_name_or_path, _fast_init=False - ) - if missing_keys or unexpected_keys: - raise RuntimeError( - "Unexpected or missing keys detected when applying recipes to models\n" - f"Missing keys: {missing_keys}\n" - f"Unexpected keys: {unexpected_keys}\n" - ) - - def create_optimizer(self): - """ - Create optimizer customized using SparseML - """ - super().create_optimizer() - if not self.recipes: - return - steps_per_epoch = math.ceil( - len(self.train_dataset) / (self.args.per_device_train_batch_size * self.args._n_gpu) - ) - self.args.num_train_epochs = float(self.manager.max_epochs) - if hasattr(self, "scaler"): - self.scaler = self.manager.modify( - self.model, self.optimizer, steps_per_epoch=steps_per_epoch, wrap_optim=self.scaler - ) - else: - self.optimizer = ScheduledOptimizer( - self.optimizer, self.model, self.manager, steps_per_epoch=steps_per_epoch, loggers=self.loggers - ) - def compute_loss(self, model, inputs, return_outputs=False): """ Computing loss using teacher/student distillation @@ -127,23 +43,10 @@ def compute_loss(self, model, inputs, return_outputs=False): attention_mask=inputs["attention_mask"], ) logits_teacher = teacher_outputs["logits"] - loss_distill = F.kl_div( input=logits_student, target=logits_teacher, reduction="batchmean",) * (self.temperature ** 2) + teacher_loss = F.kl_div( input=logits_student, target=logits_teacher, reduction="batchmean",) * (self.distill_temperature ** 2) loss = ((1 - self.distill_hardness) * label_loss) + (self.distill_hardness * teacher_loss) return (loss, outputs) if return_outputs else loss - def save_model(self, output_dir: Optional[str] = None): - """ - Save model during or after training. The sparsification recipe will also be saved. 
- """ - super().save_model(output_dir=output_dir) - if self.manager is not None: - self._save_recipe(output_dir=output_dir) - - def _save_recipe(self, output_dir: Optional[str] = None): - output_dir = output_dir if output_dir is not None else self.args.output_dir - output_recipe_file = os.path.join(output_dir, RECIPE_NAME) - self.manager.save(output_recipe_file) - class GLUEModuleExporter(ModuleExporter): """ @@ -158,90 +61,3 @@ def get_output_names(self, out: Any): if numpy.any([name for name in expected if name not in out]): raise ValueError("Expected output names not found in model output") return expected - - -def export_model(model, dataloader, output_dir, num_exported_samples): - """ - Export a trained model to ONNX - :param model: trained model - :param dataloader: dataloader to get sample batch - :param output_dir: output directory for ONNX model - """ - exporter = GLUEModuleExporter(model, output_dir=output_dir) - - sess = None - num_samples = 0 - - sample_inputs = os.path.join(output_dir, "sample-inputs") - sample_outputs = os.path.join(output_dir, "sample-outputs") - os.makedirs(sample_inputs, exist_ok=True) - os.makedirs(sample_outputs, exist_ok=True) - - forward_args_spec = inspect.getfullargspec(BertForSequenceClassification.forward) - for _, sample_batch in enumerate(dataloader): - if sess is None: - one_sample_input = collections.OrderedDict( - [(f, sample_batch[f][0].reshape(1, -1)) for f in forward_args_spec.args if f in sample_batch] - ) - - try: - exporter.export_onnx(sample_batch=one_sample_input, convert_qat=True) - onnx_file = os.path.join(output_dir, "model.onnx") - except Exception: - raise RuntimeError("Error exporting ONNX models and/or inputs/outputs") - - sess = onnxruntime.InferenceSession(onnx_file) - - input_names = list(sample_batch.keys()) - output_names = [o.name for o in sess.get_outputs()] - for input_vals in zip(*sample_batch.values()): - input_feed = {k: v.numpy() for k, v in zip(input_names, input_vals)} - output_vals = sess.run(output_names, {k: input_feed[k].reshape(1, -1) for k in input_feed}) - output_dict = {name: numpy.squeeze(val) for name, val in zip(output_names, output_vals)} - file_idx = f"{num_samples}".zfill(4) - numpy.savez(f"{sample_inputs}/inp-{file_idx}.npz", **input_feed) - numpy.savez(f"{sample_outputs}/out-{file_idx}.npz", **output_dict) - num_samples += 1 - if num_samples >= num_exported_samples: - return - - -def preprocess_state_dict(pretrained_model_name_or_path): - """ - Restore original parameter names that were changed by QAT process - """ - state_dict = None - if pretrained_model_name_or_path is not None: - pretrained_model_name_or_path = str(pretrained_model_name_or_path) - if os.path.isdir(pretrained_model_name_or_path): - if os.path.isfile(os.path.join(pretrained_model_name_or_path, RECIPE_NAME)): - recipe = os.path.join(pretrained_model_name_or_path, RECIPE_NAME) - manager = ScheduledModifierManager.from_yaml(recipe) - modifiers = [m.__class__.__name__ for m in manager.modifiers] - is_qat_recipe = "QuantizationModifier" in modifiers - if os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)): - archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) - state_dict = torch.load(archive_file, map_location="cpu") - removed_keys = ( - [key for key in state_dict if (key.endswith(".module.weight") or key.endswith(".module.bias"))] - if is_qat_recipe - else [] - ) - for key in removed_keys: - new_key = key.replace(".module", "") - state_dict[new_key] = state_dict[key] - state_dict.pop(key) - 
return state_dict
-
-
-def load_recipe(pretrained_model_name_or_path):
-    """
-    Load recipe from the model directory
-    """
-    recipe = None
-    if pretrained_model_name_or_path is not None:
-        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
-        if os.path.isdir(pretrained_model_name_or_path):
-            if os.path.isfile(os.path.join(pretrained_model_name_or_path, RECIPE_NAME)):
-                recipe = os.path.join(pretrained_model_name_or_path, RECIPE_NAME)
-    return recipe
From 1d78eb4d987e52b1c399d1c58733ad5a9f62806c Mon Sep 17 00:00:00 2001
From: spacemanidol
Date: Fri, 6 Aug 2021 14:52:59 -0500
Subject: [PATCH 09/18] Update run_glue.py updating to match refactored
 sparseml trainer

---
 examples/pytorch/text-classification/run_glue.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/examples/pytorch/text-classification/run_glue.py b/examples/pytorch/text-classification/run_glue.py
index adecfb8617f1..4ebd39326a2f 100755
--- a/examples/pytorch/text-classification/run_glue.py
+++ b/examples/pytorch/text-classification/run_glue.py
@@ -41,12 +41,9 @@
     default_data_collator,
     set_seed,
 )
-from sparseml_utils import (
-    SparseMLGLUETrainer,
-    export_model,
-    preprocess_state_dict,
-    load_recipe
-)
+
+from sparseml_utils import GLUEModuleExporter, SparseMLGLUETrainer
+from transformers.sparse import export_model, load_recipe, preprocess_state_dict
 from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version
 
From fd4348a1a812f5cf00b4461b57793134bfab30dc Mon Sep 17 00:00:00 2001
From: Daniel Campos
Date: Tue, 14 Sep 2021 16:03:03 -0500
Subject: [PATCH 10/18] adding code for multi gpu distillation

---
 .../question-answering/sparseml_utils.py     |  45 ++-
 src/transformers/sparse.py                   | 258 ++++++++++++++++++
 src/transformers/trainer_utils.py            |   1 +
 3 files changed, 291 insertions(+), 13 deletions(-)
 create mode 100644 src/transformers/sparse.py

diff --git a/examples/pytorch/question-answering/sparseml_utils.py b/examples/pytorch/question-answering/sparseml_utils.py
index eb32511d85f8..b03ca39cf86a 100644
--- a/examples/pytorch/question-answering/sparseml_utils.py
+++ b/examples/pytorch/question-answering/sparseml_utils.py
@@ -35,7 +35,14 @@ def __init__(
         super().__init__(*args, **kwargs)
         self.model_name_or_path = str(model_name_or_path)
         self.recipes = [recipe for recipe in recipes if recipe]
-        self.teacher = teacher
+        self.teachers = teacher
+        self.multi_gpu = False
+        if torch.cuda.device_count():
+            self.multi_gpu = True
+            self.num_gpus = torch.cuda.device_count()
+            self.teachers = [teacher for i in range(self.num_gpus)]
+            for i in range(self.num_gpus):
+                self.teachers[i] = self.teachers[i].to(i)
         self.distill_hardness = distill_hardness
         self.distill_temperature = distill_temperature
         self.criterion = torch.nn.CrossEntropyLoss()
@@ -109,27 +116,39 @@ def compute_loss(self, model, inputs, return_outputs=False):
         """
         Computing loss using teacher/student distillation
         """
-        if not self.recipes or self.teacher is None:
+        if not self.recipes and self.teachers is None:
             return super().compute_loss(model, inputs, return_outputs=return_outputs)
 
         outputs = model(**inputs)
-        if self.teacher is None:
+        if self.teachers is None:
             loss = outputs["loss"]
         else:
-            input_device = inputs["input_ids"].device
-            self.teacher = self.teacher.to(input_device)
             start_logits_student = outputs["start_logits"]
             end_logits_student = outputs["end_logits"]
             start_logits_label = inputs["start_positions"]
             end_logits_label = inputs["end_positions"]
-            with torch.no_grad():
-                teacher_output = self.teacher(
-                    input_ids=inputs["input_ids"],
-                    token_type_ids=inputs["token_type_ids"],
-                    attention_mask=inputs["attention_mask"],
-                )
-            start_logits_teacher = teacher_output["start_logits"]
-            end_logits_teacher = teacher_output["end_logits"]
+            if self.multi_gpu:
+                input_ids = torch.split(inputs['input_ids'], int(inputs['input_ids'].shape[0]/self.num_gpus))
+                start_logits_teacher = torch.empty((0,inputs['input_ids'].shape[1]), dtype=torch.float32, device='cuda')
+                end_logits_teacher = torch.empty((0,inputs['input_ids'].shape[1]), dtype=torch.float32, device='cuda')
+                for i in range(self.num_gpus):
+                    with torch.no_grad():
+                        input_device = self.teachers[i].device
+                        teacher_output = self.teachers[i](input_ids[i].to(input_device))
+                        start_logits_teacher = torch.cat((start_logits_teacher, teacher_output["start_logits"].to('cuda')), dim=0)
+                        end_logits_teacher = torch.cat((end_logits_teacher, teacher_output["end_logits"].to('cuda')), dim=0)
+            else: # CPU or single GPU
+                input_device = inputs["input_ids"].device
+                self.teachers = self.teachers.to(input_device)
+                with torch.no_grad():
+                    teacher_output = self.teachers(
+                        input_ids=inputs["input_ids"],
+                        token_type_ids=inputs["token_type_ids"],
+                        attention_mask=inputs["attention_mask"],
+                    )
+                start_logits_teacher = teacher_output["start_logits"]
+                end_logits_teacher = teacher_output["end_logits"]
+
             loss_start = (
                 F.kl_div(
                     input=F.log_softmax(start_logits_student / self.distill_temperature, dim=-1),
diff --git a/src/transformers/sparse.py b/src/transformers/sparse.py
new file mode 100644
index 000000000000..9abfb148be58
--- /dev/null
+++ b/src/transformers/sparse.py
@@ -0,0 +1,258 @@
+import collections
+import inspect
+import math
+import os
+from typing import Optional
+
+import numpy
+import torch
+import torch.nn.functional as F
+
+import onnxruntime
+from sparseml.pytorch.optim.manager import ScheduledModifierManager
+from sparseml.pytorch.optim.optimizer import ScheduledOptimizer
+from sparseml.pytorch.utils import logger
+from transformers import Trainer
+from transformers.file_utils import RECIPE_NAME, WEIGHTS_NAME
+
+
+class SparseMLTrainer(Trainer):
+    """
+    Trainer with SparseML integration
+
+    :param recipes: recipes for model sparsification
+    :param teacher: teacher model for distillation
+    :param distill_hardness: ratio of the loss contributed by the teacher targets (between 0 and 1)
+    :param distill_temperature: temperature for distillation
+    :param args, kwargs: arguments passed into parent class
+    """
+
+    def __init__(
+        self, model_name_or_path, recipes, teacher=None, distill_hardness=0.5, distill_temperature=2.0, *args, **kwargs
+    ):
+        super().__init__(*args, **kwargs)
+        self.model_name_or_path = str(model_name_or_path)
+        self.recipes = [recipe for recipe in recipes if recipe]
+        self.teachers = teacher
+        self.multi_gpu = False
+        if torch.cuda.device_count():
+            self.multi_gpu = True
+            self.num_gpus = torch.cuda.device_count()
+            self.teachers = [teacher for i in range(self.num_gpus)]
+            for i in range(self.num_gpus):
+                self.teachers[i] = self.teachers[i].to(i)
+        self.distill_hardness = distill_hardness
+        self.distill_temperature = distill_temperature
+        self.criterion = torch.nn.CrossEntropyLoss()
+
+        manager = None
+        modifiers = []
+        for recipe in self.recipes:
+            manager = ScheduledModifierManager.from_yaml(recipe, modifiers)
+            modifiers = manager.modifiers
+        self.manager = manager
+
+        self.loggers = None
+        if self.recipes is not None:
+            loggers = []
+            if "wandb" in self.args.report_to:
+                loggers.append(logger.WANDBLogger())
+            self.loggers = loggers
+
+    def apply_recipes(self, epoch=0.0):
+        """
+        Apply recipes and sparsification related parameters to the model
+        """
+        if self.manager is not None:
+            org_state_dict = self.model.state_dict()
+            self.manager.initialize(self.model, epoch=epoch, loggers=self.loggers)
+            new_state_dict = self.model.state_dict()
+            new_params = [p for p in new_state_dict.keys() if p not in org_state_dict]
+
+            if os.path.isdir(self.model_name_or_path):
+                if os.path.isfile(os.path.join(self.model_name_or_path, WEIGHTS_NAME)):
+                    archive_file = os.path.join(self.model_name_or_path, WEIGHTS_NAME)
+                    state_dict = torch.load(archive_file, map_location="cpu")
+                    new_params_to_init = [p for p in new_params if p in state_dict.keys()]
+                    if new_params_to_init:
+                        # If we're here, the assumption is that all the new parameters introduced
+                        # by the recipes are available to be restored from the checkpoint---this is the
+                        # case of evaluating pruned or pruned quantized models
+                        # Otherwise, we're in use cases such as quantizing a block pruned model in which
+                        # new parameters need to be initialized and trained during the QAT process
+                        _, missing_keys, unexpected_keys, _ = self.model._load_state_dict_into_model(
+                            self.model, state_dict, self.model_name_or_path, _fast_init=False
+                        )
+                        if missing_keys or unexpected_keys:
+                            raise RuntimeError(
+                                "Unexpected or missing keys detected when applying recipes to models\n"
+                                f"Missing keys: {missing_keys}\n"
+                                f"Unexpected keys: {unexpected_keys}\n"
+                            )
+
+    def compute_loss(self, model, inputs, return_outputs=False):
+        """
+        Computing loss using teacher/student distillation
+        """
+        if not self.recipes and self.teachers is None:
+            return super().compute_loss(model, inputs, return_outputs=return_outputs)
+
+        outputs = model(**inputs)
+        if self.teachers is None:
+            loss = outputs["loss"]
+        else:
+            logits_student = outputs["logits"]
+            if self.multi_gpu:
+                input_ids = torch.split(inputs['input_ids'], int(inputs['input_ids'].shape[0]/self.num_gpus))
+                token_type_ids = torch.split(inputs['token_type_ids'], int(inputs['token_type_ids'].shape[0]/self.num_gpus))
+                attention_mask = torch.split(inputs['attention_mask'], int(inputs['attention_mask'].shape[0]/self.num_gpus))
+                logits_teacher = torch.empty((0, logits_student.shape[1]), dtype=torch.float32, device='cuda')
+                for i in range(self.num_gpus):
+                    with torch.no_grad():
+                        input_device = self.teachers[i].device
+                        teacher_output = self.teachers[i](
+                            input_ids=input_ids[i].to(input_device),
+                            token_type_ids=token_type_ids[i].to(input_device),
+                            attention_mask=attention_mask[i].to(input_device)
+                        )
+                        logits_teacher = torch.cat((logits_teacher, teacher_output["logits"].to('cuda')), dim=0)
+            else: # CPU or single GPU
+                input_device = inputs["input_ids"].device
+                self.teachers = self.teachers.to(input_device)
+                with torch.no_grad():
+                    teacher_output = self.teachers(
+                        input_ids=inputs["input_ids"],
+                        token_type_ids=inputs["token_type_ids"],
+                        attention_mask=inputs["attention_mask"],
+                    )
+                logits_teacher = teacher_output["logits"]
+
+            teacher_loss = (F.kl_div(
+                input=F.log_softmax(logits_student / self.distill_temperature, dim=-1),
+                target=F.softmax(logits_teacher / self.distill_temperature, dim=-1),
+                reduction="batchmean",
+            )
+            * (self.distill_temperature ** 2)
+            )
+
+            loss = ((1 - self.distill_hardness) * outputs["loss"]) + (self.distill_hardness * teacher_loss)
+        return (loss, outputs) if return_outputs else loss
+
+    def create_optimizer(self):
+        """
+        Create optimizer customized using SparseML
+        """
+        super().create_optimizer()
+        if not
self.recipes: + return + steps_per_epoch = math.ceil( + len(self.train_dataset) / (self.args.per_device_train_batch_size * self.args._n_gpu) + ) + self.args.num_train_epochs = float(self.manager.max_epochs) + if hasattr(self, "scaler"): + self.scaler = self.manager.modify( + self.model, self.optimizer, steps_per_epoch=steps_per_epoch, wrap_optim=self.scaler + ) + else: + self.optimizer = ScheduledOptimizer( + self.optimizer, self.model, self.manager, steps_per_epoch=steps_per_epoch, loggers=self.loggers + ) + + def save_model(self, output_dir: Optional[str] = None): + """ + Save model during or after training. The sparsification recipe will also be saved. + """ + super().save_model(output_dir=output_dir) + if self.manager is not None: + self._save_recipe(output_dir=output_dir) + + def _save_recipe(self, output_dir: Optional[str] = None): + output_dir = output_dir if output_dir is not None else self.args.output_dir + output_recipe_file = os.path.join(output_dir, RECIPE_NAME) + self.manager.save(output_recipe_file) + + +def export_model(exporter, dataloader, output_dir, num_exported_samples): + """ + Export a trained model to ONNX + :param exporter: a model exporter created from a trained model + :param dataloader: dataloader to get sample batch + :param output_dir: output directory for ONNX model + """ + + sess = None + num_samples = 0 + + sample_inputs = os.path.join(output_dir, "sample-inputs") + sample_outputs = os.path.join(output_dir, "sample-outputs") + os.makedirs(sample_inputs, exist_ok=True) + os.makedirs(sample_outputs, exist_ok=True) + + for _, sample_batch in enumerate(dataloader): + if sess is None: + forward_args_spec = inspect.getfullargspec(exporter._module.__class__.forward) + one_sample_input = collections.OrderedDict( + [(f, sample_batch[f][0].reshape(1, -1)) for f in forward_args_spec.args if f in sample_batch] + ) + + try: + exporter.export_onnx(sample_batch=one_sample_input, convert_qat=True) + onnx_file = os.path.join(output_dir, "model.onnx") + except Exception: + raise RuntimeError("Error exporting ONNX models and/or inputs/outputs") + + sess = onnxruntime.InferenceSession(onnx_file) + + input_names = list(sample_batch.keys()) + output_names = [o.name for o in sess.get_outputs()] + for input_vals in zip(*sample_batch.values()): + input_feed = {k: v.numpy() for k, v in zip(input_names, input_vals)} + output_vals = sess.run(output_names, {k: input_feed[k].reshape(1, -1) for k in input_feed}) + output_dict = {name: numpy.squeeze(val) for name, val in zip(output_names, output_vals)} + file_idx = f"{num_samples}".zfill(4) + numpy.savez(f"{sample_inputs}/inp-{file_idx}.npz", **input_feed) + numpy.savez(f"{sample_outputs}/out-{file_idx}.npz", **output_dict) + num_samples += 1 + if num_samples >= num_exported_samples: + return + + +def preprocess_state_dict(pretrained_model_name_or_path): + """ + Restore original parameter names that were changed by QAT process + """ + state_dict = None + if pretrained_model_name_or_path is not None: + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + if os.path.isdir(pretrained_model_name_or_path): + if os.path.isfile(os.path.join(pretrained_model_name_or_path, RECIPE_NAME)): + recipe = os.path.join(pretrained_model_name_or_path, RECIPE_NAME) + manager = ScheduledModifierManager.from_yaml(recipe) + modifiers = [m.__class__.__name__ for m in manager.modifiers] + is_qat_recipe = "QuantizationModifier" in modifiers + if os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)): + archive_file = 
os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) + state_dict = torch.load(archive_file, map_location="cpu") + removed_keys = ( + [key for key in state_dict if (key.endswith(".module.weight") or key.endswith(".module.bias"))] + if is_qat_recipe + else [] + ) + for key in removed_keys: + new_key = key.replace(".module", "") + state_dict[new_key] = state_dict[key] + state_dict.pop(key) + return state_dict + + +def load_recipe(pretrained_model_name_or_path): + """ + Load recipe from the model directory + """ + recipe = None + if pretrained_model_name_or_path is not None: + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + if os.path.isdir(pretrained_model_name_or_path): + if os.path.isfile(os.path.join(pretrained_model_name_or_path, RECIPE_NAME)): + recipe = os.path.join(pretrained_model_name_or_path, RECIPE_NAME) + return recipe \ No newline at end of file diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index 8e02a1ee0ce5..0e8779816de2 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -59,6 +59,7 @@ def set_seed(seed: int): if is_torch_available(): torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True # ^^ safe to call this function even if cuda is not available if is_tf_available(): tf.random.set_seed(seed) From e0ab13aa1afab3ffa662561e735b142c5a59f175 Mon Sep 17 00:00:00 2001 From: spacemanidol Date: Tue, 14 Sep 2021 16:07:09 -0500 Subject: [PATCH 11/18] Update run_qa.py --- examples/pytorch/question-answering/run_qa.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/pytorch/question-answering/run_qa.py b/examples/pytorch/question-answering/run_qa.py index 2fad54fffe12..178cb9bbf280 100755 --- a/examples/pytorch/question-answering/run_qa.py +++ b/examples/pytorch/question-answering/run_qa.py @@ -69,7 +69,7 @@ class ModelArguments: default=2.0, metadata={"help": "Temperature applied to teacher softmax for distillation."} ) distill_hardness: Optional[float] = field( - default=0.5, metadata={"help": "Proportion of loss coming from teacher model."} + default=1.0, metadata={"help": "Proportion of loss coming from teacher model."} ) config_name: Optional[str] = field( default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} From ac9e6536b564cb211c1d7dac4605ed9b91e0998b Mon Sep 17 00:00:00 2001 From: spacemanidol Date: Tue, 14 Sep 2021 16:08:32 -0500 Subject: [PATCH 12/18] Update sparseml_utils.py --- .../text-classification/sparseml_utils.py | 40 ------------------- 1 file changed, 40 deletions(-) diff --git a/examples/pytorch/text-classification/sparseml_utils.py b/examples/pytorch/text-classification/sparseml_utils.py index 17b128a96527..60ff80952743 100644 --- a/examples/pytorch/text-classification/sparseml_utils.py +++ b/examples/pytorch/text-classification/sparseml_utils.py @@ -7,46 +7,6 @@ from sparseml.pytorch.utils import ModuleExporter from transformers.modeling_outputs import SequenceClassifierOutput -from transformers.sparse import SparseMLTrainer - - -class SparseMLGLUETrainer(SparseMLTrainer): - """ - GLUE trainer with SparseML integration - - :param recipe: recipe for model sparsification - :param teacher: teacher model for distillation - :param distill_hardness: ratio of loss by teacher targets (between 0 and 1) - :param distill_temperature: temperature for distillation - :param args, kwargs: arguments passed into parent class - """ - - def compute_loss(self, model, 
inputs, return_outputs=False): - """ - Computing loss using teacher/student distillation - """ - if not self.recipes or self.teacher is None: - return super().compute_loss(model, inputs, return_outputs=return_outputs) - - outputs = model(**inputs) - if self.teacher is None: - loss = outputs["loss"] - else: - input_device = inputs["input_ids"].device - label_loss = outputs["loss"] - self.teacher = self.teacher.to(input_device) - logits_student = outputs["logits"] - with torch.no_grad(): - teacher_outputs = self.teacher( - input_ids=inputs["input_ids"], - token_type_ids=inputs["token_type_ids"], - attention_mask=inputs["attention_mask"], - ) - logits_teacher = teacher_outputs["logits"] - teacher_loss = F.kl_div( input=logits_student, target=logits_teacher, reduction="batchmean",) * (self.distill_temperature ** 2) - loss = ((1 - self.distill_hardness) * label_loss) + (self.distill_hardness * teacher_loss) - return (loss, outputs) if return_outputs else loss - class GLUEModuleExporter(ModuleExporter): """ From ca45eec5fc280582f6622f48a6437d456f29b3ac Mon Sep 17 00:00:00 2001 From: spacemanidol Date: Tue, 14 Sep 2021 16:09:22 -0500 Subject: [PATCH 13/18] Update run_glue.py --- examples/pytorch/text-classification/run_glue.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/pytorch/text-classification/run_glue.py b/examples/pytorch/text-classification/run_glue.py index 4ebd39326a2f..eb99090c1fde 100755 --- a/examples/pytorch/text-classification/run_glue.py +++ b/examples/pytorch/text-classification/run_glue.py @@ -488,7 +488,7 @@ def compute_metrics(p: EvalPrediction): new_recipe = data_args.recipe # Initialize our Trainer - trainer = SparseMLGLUETrainer( + trainer = SparseMLTrainer( model_args.model_name_or_path, [existing_recipe, new_recipe], teacher=teacher_model, From 6f798792243fbfe569f8605709b275ffd0198dfd Mon Sep 17 00:00:00 2001 From: spacemanidol Date: Tue, 14 Sep 2021 16:10:00 -0500 Subject: [PATCH 14/18] Update run_glue.py --- examples/pytorch/text-classification/run_glue.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/pytorch/text-classification/run_glue.py b/examples/pytorch/text-classification/run_glue.py index eb99090c1fde..204f454115d9 100755 --- a/examples/pytorch/text-classification/run_glue.py +++ b/examples/pytorch/text-classification/run_glue.py @@ -42,8 +42,8 @@ set_seed, ) -from sparseml_utils import GLUEModuleExporter, SparseMLGLUETrainer -from transformers.sparse import export_model, load_recipe, preprocess_state_dict +from sparseml_utils import GLUEModuleExpo +from transformers.sparse import export_model, SparseMLTrainer, load_recipe, preprocess_state_dict from transformers.trainer_utils import get_last_checkpoint from transformers.utils import check_min_version From 3e91e9aa07595a4ef2b984feb1826eb9c91a2cef Mon Sep 17 00:00:00 2001 From: Daniel Campos Date: Wed, 15 Sep 2021 11:30:04 -0500 Subject: [PATCH 15/18] minor updates --- examples/pytorch/question-answering/run_qa.py | 2 ++ examples/pytorch/question-answering/sparseml_utils.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/pytorch/question-answering/run_qa.py b/examples/pytorch/question-answering/run_qa.py index f0633f3778df..95fcd3f59ae0 100755 --- a/examples/pytorch/question-answering/run_qa.py +++ b/examples/pytorch/question-answering/run_qa.py @@ -50,6 +50,8 @@ from transformers.utils import check_min_version from utils_qa import postprocess_qa_predictions +import wandb 
+wandb.init(project="sparse-transfer-downstream-qa-daniel") # Will error if the minimal version of Transformers is not installed. Remove at your own risks. check_min_version("4.7.0.dev0") diff --git a/examples/pytorch/question-answering/sparseml_utils.py b/examples/pytorch/question-answering/sparseml_utils.py index b03ca39cf86a..55357fb61b91 100644 --- a/examples/pytorch/question-answering/sparseml_utils.py +++ b/examples/pytorch/question-answering/sparseml_utils.py @@ -37,7 +37,7 @@ def __init__( self.recipes = [recipe for recipe in recipes if recipe] self.teachers = teacher self.multi_gpu = False - if torch.cuda.device_count(): + if torch.cuda.device_count() and teacher != None: self.multi_gpu = True self.num_gpus = torch.cuda.device_count() self.teachers = [teacher for i in range(self.num_gpus)] From 519edb82d217726567e5cd637e714b1e15750de1 Mon Sep 17 00:00:00 2001 From: Daniel Campos Date: Wed, 15 Sep 2021 11:49:24 -0500 Subject: [PATCH 16/18] fixing merge --- .../question-answering/sparseml_utils.py | 87 ------------------- 1 file changed, 87 deletions(-) diff --git a/examples/pytorch/question-answering/sparseml_utils.py b/examples/pytorch/question-answering/sparseml_utils.py index 71d9eee5b14e..6ea5e1593191 100644 --- a/examples/pytorch/question-answering/sparseml_utils.py +++ b/examples/pytorch/question-answering/sparseml_utils.py @@ -20,94 +20,7 @@ class SparseMLQATrainer(SparseMLTrainer, QuestionAnsweringTrainer): :param distill_temperature: temperature for distillation :param args, kwargs: arguments passed into parent class """ -<<<<<<< HEAD - - def __init__( - self, model_name_or_path, recipes, teacher=None, distill_hardness=0.5, distill_temperature=2.0, *args, **kwargs - ): - super().__init__(*args, **kwargs) - self.model_name_or_path = str(model_name_or_path) - self.recipes = [recipe for recipe in recipes if recipe] - self.teachers = teacher - self.multi_gpu = False - if torch.cuda.device_count() and teacher != None: - self.multi_gpu = True - self.num_gpus = torch.cuda.device_count() - self.teachers = [teacher for i in range(self.num_gpus)] - for i in range(self.num_gpus): - self.teachers[i] = self.teachers[i].to(i) - self.distill_hardness = distill_hardness - self.distill_temperature = distill_temperature - self.criterion = torch.nn.CrossEntropyLoss() - - manager = None - modifiers = [] - for recipe in self.recipes: - manager = ScheduledModifierManager.from_yaml(recipe, modifiers) - modifiers = manager.modifiers - self.manager = manager - - self.loggers = None - if self.recipes is not None: - loggers = [] - if "wandb" in self.args.report_to: - loggers.append(logger.WANDBLogger()) - self.loggers = loggers - - def apply_recipes(self, epoch=0.0): - """ - Apply recipes and sparsification related parameters to the model - """ - if self.manager is not None: - org_state_dict = self.model.state_dict() - self.manager.initialize(self.model, epoch=epoch, loggers=self.loggers) - new_state_dict = self.model.state_dict() - new_params = [p for p in new_state_dict.keys() if p not in org_state_dict] - - if os.path.isdir(self.model_name_or_path): - if os.path.isfile(os.path.join(self.model_name_or_path, WEIGHTS_NAME)): - archive_file = os.path.join(self.model_name_or_path, WEIGHTS_NAME) - state_dict = torch.load(archive_file, map_location="cpu") - new_params_to_init = [p for p in new_params if p in state_dict.keys()] - if new_params_to_init: - # If we're here, the assumption is that all the new parameters introduced - # by the recipes are available to be restore from the checkpoint---this is - 
# case of evaluating pruned or pruned quantized models
-                        # Otherwise, we're in use cases such as quantizing a block pruned model in which
-                        # new parameters need to be initialized and trained during the QAT process
-                        _, missing_keys, unexpected_keys, _ = BertForQuestionAnswering._load_state_dict_into_model(
-                            self.model, state_dict, self.model_name_or_path, _fast_init=False
-                        )
-                        if missing_keys or unexpected_keys:
-                            raise RuntimeError(
-                                "Unexpected or missing keys detected when applying recipes to models\n"
-                                f"Missing keys: {missing_keys}\n"
-                                f"Unexpected keys: {unexpected_keys}\n"
-                            )
-
-    def create_optimizer(self):
-        """
-        Create optimizer customized using SparseML
-        """
-        super().create_optimizer()
-        if not self.recipes:
-            return
-        steps_per_epoch = math.ceil(
-            len(self.train_dataset) / (self.args.per_device_train_batch_size * self.args._n_gpu)
-        )
-        self.args.num_train_epochs = float(self.manager.max_epochs)
-        if hasattr(self, "scaler"):
-            self.scaler = self.manager.modify(
-                self.model, self.optimizer, steps_per_epoch=steps_per_epoch, wrap_optim=self.scaler
-            )
-        else:
-            self.optimizer = ScheduledOptimizer(
-                self.optimizer, self.model, self.manager, steps_per_epoch=steps_per_epoch, loggers=self.loggers
-            )
-
-=======
->>>>>>> 6f798792243fbfe569f8605709b275ffd0198dfd
     def compute_loss(self, model, inputs, return_outputs=False):
         """
         Computing loss using teacher/student distillation
From 001aba71317510cdfe02158df78457a15c0d0160 Mon Sep 17 00:00:00 2001
From: Daniel Campos
Date: Wed, 15 Sep 2021 11:52:50 -0500
Subject: [PATCH 17/18] fix error where training would try to go multi-GPU
 without a distill teacher

---
 src/transformers/sparse.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/sparse.py b/src/transformers/sparse.py
index a642e30a0165..8bcc3d0be80c 100644
--- a/src/transformers/sparse.py
+++ b/src/transformers/sparse.py
@@ -35,7 +35,7 @@ def __init__(
         self.recipes = [recipe for recipe in recipes if recipe]
         self.teachers = teacher
         self.multi_gpu = False
-        if torch.cuda.device_count():
+        if torch.cuda.device_count() and teacher is not None:
             self.multi_gpu = True
             self.num_gpus = torch.cuda.device_count()
             self.teachers = [teacher for i in range(self.num_gpus)]
From 547300b692eaed4e4a7dfe5662bd8f015ecb8c2e Mon Sep 17 00:00:00 2001
From: Daniel Campos
Date: Thu, 16 Sep 2021 15:23:12 -0500
Subject: [PATCH 18/18] mnli and qqp fix

---
 examples/pytorch/text-classification/run_glue.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/pytorch/text-classification/run_glue.py b/examples/pytorch/text-classification/run_glue.py
index 204f454115d9..81d37c7a164a 100755
--- a/examples/pytorch/text-classification/run_glue.py
+++ b/examples/pytorch/text-classification/run_glue.py
@@ -42,7 +42,7 @@
     set_seed,
 )
 
-from sparseml_utils import GLUEModuleExpo
+from sparseml_utils import GLUEModuleExporter
 from transformers.sparse import export_model, SparseMLTrainer, load_recipe, preprocess_state_dict
 from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version
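The multi-GPU path added in PATCH 10 replicates the teacher across devices and splits each batch among the copies. A minimal sketch of that pattern, assuming copy.deepcopy is used so every GPU holds an independent teacher instance (the helper names are illustrative, not code from these patches):

import copy
import torch

def replicate_teacher(teacher, num_gpus):
    # nn.Module.to() moves a module in place, so each replica must be a
    # separate deep copy rather than another reference to the same module.
    return [copy.deepcopy(teacher).to(i) for i in range(num_gpus)]

def gather_teacher_logits(teachers, input_ids):
    # Split the batch across the replicas and collect the logits on one device.
    chunks = torch.chunk(input_ids, len(teachers), dim=0)
    outputs = []
    with torch.no_grad():
        for teacher, chunk in zip(teachers, chunks):
            device = next(teacher.parameters()).device
            outputs.append(teacher(input_ids=chunk.to(device))["logits"].to("cuda:0"))
    return torch.cat(outputs, dim=0)

Keeping the teacher forward passes under torch.no_grad() avoids building an autograd graph for the frozen teacher, so only the student's parameters receive gradients from the blended loss.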