From 1803e4814ba29d28386578c2e328cfb2ce21210e Mon Sep 17 00:00:00 2001 From: spacemanidol Date: Thu, 29 Jul 2021 14:57:59 -0500 Subject: [PATCH 1/2] Setting up Text Classification for SparseML --- examples/pytorch/question-answering/run_qa.py | 3 +- .../pytorch/text-classification/run_glue.py | 58 +++++++++++++++- .../text-classification/sparseml_utils.py | 67 +++++++++++++++++++ 3 files changed, 123 insertions(+), 5 deletions(-) create mode 100644 examples/pytorch/text-classification/sparseml_utils.py diff --git a/examples/pytorch/question-answering/run_qa.py b/examples/pytorch/question-answering/run_qa.py index 2fb4ce3523a0..2fad54fffe12 100755 --- a/examples/pytorch/question-answering/run_qa.py +++ b/examples/pytorch/question-answering/run_qa.py @@ -69,7 +69,7 @@ class ModelArguments: default=2.0, metadata={"help": "Temperature applied to teacher softmax for distillation."} ) distill_hardness: Optional[float] = field( - default=1.0, metadata={"help": "Proportion of loss coming from teacher model."} + default=0.5, metadata={"help": "Proportion of loss coming from teacher model."} ) config_name: Optional[str] = field( default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} @@ -217,7 +217,6 @@ def __post_init__(self): extension = self.test_file.split(".")[-1] assert extension in ["csv", "json"], "`test_file` should be a csv or a json file." - def main(): # See all possible arguments in src/transformers/training_args.py # or by passing the --help flag to this script. diff --git a/examples/pytorch/text-classification/run_glue.py b/examples/pytorch/text-classification/run_glue.py index 1b08def9c62f..356d0fa0b057 100755 --- a/examples/pytorch/text-classification/run_glue.py +++ b/examples/pytorch/text-classification/run_glue.py @@ -27,6 +27,7 @@ from datasets import load_dataset, load_metric import transformers +from sparseml_utils import GLUEModuleExporter, SparseMLGLUETrainer from transformers import ( AutoConfig, AutoModelForSequenceClassification, @@ -35,11 +36,11 @@ EvalPrediction, HfArgumentParser, PretrainedConfig, - Trainer, TrainingArguments, default_data_collator, set_seed, ) +from transformers.sparse import export_model, load_recipe, preprocess_state_dict from transformers.trainer_utils import get_last_checkpoint from transformers.utils import check_min_version @@ -72,6 +73,19 @@ class DataTrainingArguments: the command line. 
""" + recipe: Optional[str] = field( + default=None, + metadata={ + "help": "Path to a SparseML sparsification recipe, see https://github.com/neuralmagic/sparseml " + "for more information" + }, + ) + onnx_export_path: Optional[str] = field( + default=None, metadata={"help": "The filename and path which will be where onnx model is outputed"} + ) + num_exported_samples: Optional[int] = field( + default=20, metadata={"help": "Number of exported samples, default to 20"} + ) task_name: Optional[str] = field( default=None, metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())}, @@ -155,6 +169,9 @@ class ModelArguments: model_name_or_path: str = field( metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} ) + distill_teacher: Optional[str] = field( + default=None, metadata={"help": "Teacher model which needs to be a trained text classification model"} + ) config_name: Optional[str] = field( default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} ) @@ -305,6 +322,12 @@ def main(): # # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently # download model & vocab. + + # Load and preprocess the state dict if the model existed (in this case we continue to train or + # evaluate the model). The preprocessing step is to restore names of parameters changed by + # QAT process + state_dict = preprocess_state_dict(model_args.model_name_or_path) + config = AutoConfig.from_pretrained( model_args.config_name if model_args.config_name else model_args.model_name_or_path, num_labels=num_labels, @@ -327,8 +350,19 @@ def main(): cache_dir=model_args.cache_dir, revision=model_args.model_revision, use_auth_token=True if model_args.use_auth_token else None, + state_dict=state_dict, ) + teacher_model = None + if model_args.distill_teacher is not None: + teacher_model = AutoModelForSequenceClassification.from_pretrained( + model_args.distill_teacher, + from_tf=bool(".ckpt" in model_args.distill_teacher), + cache_dir=model_args.cache_dir, + ) + teacher_model_parameters = filter(lambda p: p.requires_grad, teacher_model.parameters()) + params = sum([np.prod(p.size()) for p in teacher_model_parameters]) + logger.info("Teacher Model has %s parameters", params) # Preprocessing the datasets if data_args.task_name is not None: sentence1_key, sentence2_key = task_to_keys[data_args.task_name] @@ -445,17 +479,29 @@ def compute_metrics(p: EvalPrediction): else: data_collator = None + # Load possible existing recipe and new one passed in through command argument + existing_recipe = load_recipe(model_args.model_name_or_path) + new_recipe = data_args.recipe + # Initialize our Trainer - trainer = Trainer( + trainer = SparseMLGLUETrainer( + model_args.model_name_or_path, + [existing_recipe, new_recipe], + teacher=teacher_model, model=model, args=training_args, train_dataset=train_dataset if training_args.do_train else None, eval_dataset=eval_dataset if training_args.do_eval else None, - compute_metrics=compute_metrics, tokenizer=tokenizer, data_collator=data_collator, + compute_metrics=compute_metrics, ) + # Apply recipes to the model. This is necessary given that + # sparsification methods such as QAT modified the model graph with their own learnable + # parameters. They are also restored/loaded to the model. 
+ trainer.apply_recipes() + # Training if training_args.do_train: checkpoint = None @@ -536,6 +582,12 @@ def compute_metrics(p: EvalPrediction): trainer.push_to_hub(**kwargs) + if data_args.onnx_export_path: + logger.info("*** Export to ONNX ***") + eval_dataloader = trainer.get_eval_dataloader(eval_dataset) + exporter = GLUEModuleExporter(model, output_dir=data_args.onnx_export_path) + export_model(exporter, eval_dataloader, data_args.onnx_export_path, data_args.num_exported_samples) + def _mp_fn(index): # For xla_spawn (TPUs) diff --git a/examples/pytorch/text-classification/sparseml_utils.py b/examples/pytorch/text-classification/sparseml_utils.py new file mode 100644 index 000000000000..3ff9b9d78c0b --- /dev/null +++ b/examples/pytorch/text-classification/sparseml_utils.py @@ -0,0 +1,67 @@ +from typing import Any + +import numpy +import torch + +from sparseml.pytorch.utils import ModuleExporter, device_of + +from transformers.modeling_outputs import SequenceClassifierOutput +from transformers.sparse import SparseMLTrainer + + +class SparseMLGLUETrainer(SparseMLTrainer): + """ + GLUE trainer with SparseML integration + + :param recipe: recipe for model sparsification + :param teacher: teacher model for distillation + :param distill_hardness: ratio of loss by teacher targets (between 0 and 1) + :param distill_temperature: temperature for distillation + :param args, kwargs: arguments passed into parent class + """ + + def compute_loss(self, model, inputs, return_outputs=False): + """ + Computing loss using teacher/student distillation + """ + if not self.recipes or self.teacher is None: + return super().compute_loss(model, inputs, return_outputs=return_outputs) + + student_outputs = model(**inputs) + loss = student_outputs["loss"] + + target_device = device_of(inputs) + self.teacher.to(target_device) + with torch.no_grad(): + teacher_outputs = self.teacher( + input_ids=inputs["input_ids"], + token_type_ids=inputs["token_type_ids"], + attention_mask=inputs["attention_mask"], + ) + steps_in_epoch = -1 # Unused + loss = self.manager.loss_update( + loss, + model, + self.optimizer, + self.state.epoch, + steps_in_epoch, + global_step=self.state.global_step, + student_outputs=student_outputs, + teacher_outputs=teacher_outputs, + ) + return (loss, student_outputs) if return_outputs else loss + + +class GLUEModuleExporter(ModuleExporter): + """ + Module exporter class for Sequence Classification + """ + + @classmethod + def get_output_names(self, out: Any): + if not isinstance(out, SequenceClassifierOutput): + raise ValueError("Expected SequenceClassifierOutput, got {type(out)}") + expected = ["logits"] + if numpy.any([name for name in expected if name not in out]): + raise ValueError("Expected output names not found in model output") + return expected From 9927cf8112103ad6ab9defbf3c7c70ee35f5528a Mon Sep 17 00:00:00 2001 From: Tuan Nguyen Date: Wed, 6 Oct 2021 16:23:28 -0400 Subject: [PATCH 2/2] Move teacher's logic out of compute_loss for GLUE --- .../text-classification/sparseml_utils.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/examples/pytorch/text-classification/sparseml_utils.py b/examples/pytorch/text-classification/sparseml_utils.py index 3ff9b9d78c0b..8d882dd6016c 100644 --- a/examples/pytorch/text-classification/sparseml_utils.py +++ b/examples/pytorch/text-classification/sparseml_utils.py @@ -1,10 +1,8 @@ from typing import Any import numpy -import torch - -from sparseml.pytorch.utils import ModuleExporter, device_of +from 
sparseml.pytorch.utils import ModuleExporter from transformers.modeling_outputs import SequenceClassifierOutput from transformers.sparse import SparseMLTrainer @@ -30,14 +28,9 @@ def compute_loss(self, model, inputs, return_outputs=False): student_outputs = model(**inputs) loss = student_outputs["loss"] - target_device = device_of(inputs) - self.teacher.to(target_device) - with torch.no_grad(): - teacher_outputs = self.teacher( - input_ids=inputs["input_ids"], - token_type_ids=inputs["token_type_ids"], - attention_mask=inputs["attention_mask"], - ) + teacher_input_keys = ["input_ids", "token_type_ids", "attention_mask"] + teacher_inputs = {k: inputs[k] for k in teacher_input_keys} + steps_in_epoch = -1 # Unused loss = self.manager.loss_update( loss, @@ -47,7 +40,7 @@ def compute_loss(self, model, inputs, return_outputs=False): steps_in_epoch, global_step=self.state.global_step, student_outputs=student_outputs, - teacher_outputs=teacher_outputs, + teacher_inputs=teacher_inputs, ) return (loss, student_outputs) if return_outputs else loss @@ -60,7 +53,7 @@ class GLUEModuleExporter(ModuleExporter): @classmethod def get_output_names(self, out: Any): if not isinstance(out, SequenceClassifierOutput): - raise ValueError("Expected SequenceClassifierOutput, got {type(out)}") + raise ValueError(f"Expected SequenceClassifierOutput, got {type(out)}") expected = ["logits"] if numpy.any([name for name in expected if name not in out]): raise ValueError("Expected output names not found in model output")
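
For reference, the distillation arguments that pair with these trainers (distill_temperature, default 2.0, and distill_hardness, default 0.5, as shown in the run_qa.py hunk of the first patch) follow the conventional soft-target distillation objective: the student's cross-entropy on the gold labels is blended with a temperature-scaled KL divergence against the teacher's logits. In these patches the actual blending is delegated to SparseML's manager.loss_update, so the snippet below is only a minimal sketch of that conventional formulation, not the SparseML implementation; the helper name distillation_loss is illustrative.

    # Minimal sketch (assumption: this mirrors the conventional objective that
    # distill_hardness / distill_temperature control; the blending in these patches
    # is actually performed inside SparseML's manager.loss_update).
    import torch.nn.functional as F

    def distillation_loss(student_logits, teacher_logits, labels,
                          distill_hardness=0.5, distill_temperature=2.0):
        # Hard loss: ordinary cross-entropy of the student against the gold labels.
        hard_loss = F.cross_entropy(student_logits, labels)
        # Soft loss: KL divergence between temperature-softened student and teacher
        # distributions, scaled by T^2 to keep gradient magnitudes comparable.
        t = distill_temperature
        soft_loss = F.kl_div(
            F.log_softmax(student_logits / t, dim=-1),
            F.softmax(teacher_logits / t, dim=-1),
            reduction="batchmean",
        ) * (t ** 2)
        # distill_hardness is the proportion of the total loss coming from the teacher.
        return (1.0 - distill_hardness) * hard_loss + distill_hardness * soft_loss

With the defaults above, half of the training signal comes from the softened teacher distribution and half from the gold labels.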
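
Similarly, once --onnx_export_path is passed, run_glue.py exports the fine-tuned model through GLUEModuleExporter and export_model. A hedged sketch of consuming that artifact with ONNX Runtime follows; it assumes the exporter writes a model.onnx file inside the export directory and that the graph inputs mirror the tokenizer features (both paths below are placeholders).

    # Hedged sketch of running the exported model with ONNX Runtime.
    # Assumptions: "model.onnx" is written into the --onnx_export_path directory,
    # and the graph inputs mirror the tokenizer features; paths are placeholders.
    import numpy as np
    import onnxruntime
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("path/to/trained-model")        # placeholder
    session = onnxruntime.InferenceSession("path/to/onnx-export/model.onnx")  # placeholder

    encoded = tokenizer("This movie was surprisingly good.", return_tensors="np")
    # Feed only the inputs the exported graph actually declares.
    feed = {i.name: encoded[i.name].astype(np.int64) for i in session.get_inputs() if i.name in encoded}
    (logits,) = session.run(["logits"], feed)  # "logits" per GLUEModuleExporter.get_output_names
    predicted_label = int(np.argmax(logits, axis=-1)[0])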