From af5a56d9255816dd8031f56833ab4a767bf9d6d3 Mon Sep 17 00:00:00 2001 From: Tuan Nguyen Date: Fri, 15 Oct 2021 12:43:52 -0400 Subject: [PATCH 1/2] SparseML integ for NER --- .../pytorch/token-classification/run_ner.py | 57 ++++++++++++++++++- .../token-classification/sparseml_utils.py | 56 ++++++++++++++++++ 2 files changed, 111 insertions(+), 2 deletions(-) create mode 100644 examples/pytorch/token-classification/sparseml_utils.py diff --git a/examples/pytorch/token-classification/run_ner.py b/examples/pytorch/token-classification/run_ner.py index f0f69f9e39b3..4203022b3120 100755 --- a/examples/pytorch/token-classification/run_ner.py +++ b/examples/pytorch/token-classification/run_ner.py @@ -29,6 +29,7 @@ from datasets import ClassLabel, load_dataset, load_metric import transformers +from sparseml_utils import SparseMLTokenClassificationTrainer, TokenClassificationModuleExporter from transformers import ( AutoConfig, AutoModelForTokenClassification, @@ -36,10 +37,10 @@ DataCollatorForTokenClassification, HfArgumentParser, PreTrainedTokenizerFast, - Trainer, TrainingArguments, set_seed, ) +from transformers.sparse import export_model, load_recipe, preprocess_state_dict from transformers.trainer_utils import get_last_checkpoint from transformers.utils import check_min_version @@ -59,6 +60,9 @@ class ModelArguments: model_name_or_path: str = field( metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} ) + distill_teacher: Optional[str] = field( + default=None, metadata={"help": "Teacher model, which needs to be a trained token classification model"} + ) config_name: Optional[str] = field( default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} ) @@ -88,6 +92,19 @@ class DataTrainingArguments: Arguments pertaining to what data we are going to input our model for training and eval. """ + recipe: Optional[str] = field( + default=None, + metadata={ + "help": "Path to a SparseML sparsification recipe, see https://github.com/neuralmagic/sparseml " + "for more information" + }, + ) + onnx_export_path: Optional[str] = field( + default=None, metadata={"help": "The filename and path to which the ONNX model will be exported"} + ) + num_exported_samples: Optional[int] = field( + default=20, metadata={"help": "Number of exported samples; defaults to 20"} + ) task_name: Optional[str] = field(default="ner", metadata={"help": "The name of the task (ner, pos...)."}) dataset_name: Optional[str] = field( default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} ) @@ -278,6 +295,12 @@ def get_label_list(labels): # Distributed training: # The .from_pretrained methods guarantee that only one local process can concurrently # download model & vocab. + + # Load and preprocess the state dict if the model already exists (in this case we continue training or + # evaluating the model). The preprocessing step restores the names of parameters changed by the + # QAT process.
+ state_dict = preprocess_state_dict(model_args.model_name_or_path) + config = AutoConfig.from_pretrained( model_args.config_name if model_args.config_name else model_args.model_name_or_path, num_labels=num_labels, @@ -300,8 +323,20 @@ def get_label_list(labels): cache_dir=model_args.cache_dir, revision=model_args.model_revision, use_auth_token=True if model_args.use_auth_token else None, + state_dict=state_dict, ) + teacher_model = None + if model_args.distill_teacher is not None: + teacher_model = AutoModelForTokenClassification.from_pretrained( + model_args.distill_teacher, + from_tf=bool(".ckpt" in model_args.distill_teacher), + cache_dir=model_args.cache_dir, + ) + teacher_model_parameters = filter(lambda p: p.requires_grad, teacher_model.parameters()) + params = sum([np.prod(p.size()) for p in teacher_model_parameters]) + logger.info("Teacher Model has %s parameters", params) + # Tokenizer check: this script requires a fast tokenizer. if not isinstance(tokenizer, PreTrainedTokenizerFast): raise ValueError( @@ -424,8 +459,15 @@ def compute_metrics(p): "accuracy": results["overall_accuracy"], } + # Load a possible pre-existing recipe and the new one passed in through the command-line argument + existing_recipe = load_recipe(model_args.model_name_or_path) + new_recipe = data_args.recipe + # Initialize our Trainer - trainer = Trainer( + trainer = SparseMLTokenClassificationTrainer( + model_args.model_name_or_path, + [existing_recipe, new_recipe], + teacher=teacher_model, model=model, args=training_args, train_dataset=train_dataset if training_args.do_train else None, @@ -435,6 +477,11 @@ def compute_metrics(p): compute_metrics=compute_metrics, ) + # Apply recipes to the model. This is necessary because + # sparsification methods such as QAT modify the model graph with their own learnable + # parameters, which are also restored/loaded into the model.
+ trainer.apply_recipes() + # Training if training_args.do_train: checkpoint = None @@ -502,6 +549,12 @@ def compute_metrics(p): trainer.push_to_hub(**kwargs) + if data_args.onnx_export_path: + logger.info("*** Export to ONNX ***") + eval_dataloader = trainer.get_eval_dataloader(eval_dataset) + exporter = TokenClassificationModuleExporter(model, output_dir=data_args.onnx_export_path) + export_model(exporter, eval_dataloader, data_args.onnx_export_path, data_args.num_exported_samples) + def _mp_fn(index): # For xla_spawn (TPUs) diff --git a/examples/pytorch/token-classification/sparseml_utils.py b/examples/pytorch/token-classification/sparseml_utils.py new file mode 100644 index 000000000000..4210a88743f5 --- /dev/null +++ b/examples/pytorch/token-classification/sparseml_utils.py @@ -0,0 +1,56 @@ +from typing import Any + +import numpy + +from sparseml.pytorch.utils import ModuleExporter +from transformers.modeling_outputs import TokenClassifierOutput +from transformers.sparse import SparseMLTrainer + + +class SparseMLTokenClassificationTrainer(SparseMLTrainer): + """ + Token Classification trainer with SparseML integration + + :param recipe: recipe for model sparsification + :param teacher: teacher model for distillation + :param distill_hardness: ratio of loss by teacher targets (between 0 and 1) + :param distill_temperature: temperature for distillation + :param args, kwargs: arguments passed into parent class + """ + + def compute_loss(self, model, inputs, return_outputs=False): + """ + Computing loss using teacher/student distillation + """ + if not self.recipes or self.teacher is None: + return super().compute_loss(model, inputs, return_outputs=return_outputs) + student_outputs = model(**inputs) + loss = student_outputs["loss"] + + steps_in_epoch = -1 # Unused + loss = self.manager.loss_update( + loss, + model, + self.optimizer, + self.state.epoch, + steps_in_epoch, + global_step=self.state.global_step, + student_outputs=student_outputs, + teacher_inputs=inputs, + ) + return (loss, student_outputs) if return_outputs else loss + + +class TokenClassificationModuleExporter(ModuleExporter): + """ + Module exporter class for token classification + """ + + @classmethod + def get_output_names(cls, out: Any): + if not isinstance(out, TokenClassifierOutput): + raise ValueError(f"Expected TokenClassifierOutput, got {type(out)}") + expected = ["logits"] + if numpy.any([name for name in expected if name not in out]): + raise ValueError("Expected output names not found in model output") + return expected From 9202624b5dea2edaf79ba262f2368183ce5edf13 Mon Sep 17 00:00:00 2001 From: Tuan Nguyen Date: Wed, 20 Oct 2021 09:41:27 -0400 Subject: [PATCH 2/2] Reuse SparseMLTrainer for token classification --- .../pytorch/token-classification/run_ner.py | 6 ++-- .../token-classification/sparseml_utils.py | 35 ------------------- src/transformers/sparse.py | 22 ++++++++++++ 3 files changed, 25 insertions(+), 38 deletions(-) diff --git a/examples/pytorch/token-classification/run_ner.py b/examples/pytorch/token-classification/run_ner.py index 4203022b3120..22a32e09025c 100755 --- a/examples/pytorch/token-classification/run_ner.py +++ b/examples/pytorch/token-classification/run_ner.py @@ -29,7 +29,7 @@ from datasets import ClassLabel, load_dataset, load_metric import transformers -from sparseml_utils import SparseMLTokenClassificationTrainer, TokenClassificationModuleExporter +from sparseml_utils import TokenClassificationModuleExporter from transformers import ( AutoConfig, AutoModelForTokenClassification, @@
-40,7 +40,7 @@ TrainingArguments, set_seed, ) -from transformers.sparse import export_model, load_recipe, preprocess_state_dict +from transformers.sparse import SparseMLTrainer, export_model, load_recipe, preprocess_state_dict from transformers.trainer_utils import get_last_checkpoint from transformers.utils import check_min_version @@ -464,7 +464,7 @@ def compute_metrics(p): new_recipe = data_args.recipe # Initialize our Trainer - trainer = SparseMLTokenClassificationTrainer( + trainer = SparseMLTrainer( model_args.model_name_or_path, [existing_recipe, new_recipe], teacher=teacher_model, diff --git a/examples/pytorch/token-classification/sparseml_utils.py b/examples/pytorch/token-classification/sparseml_utils.py index 4210a88743f5..200c68d680a2 100644 --- a/examples/pytorch/token-classification/sparseml_utils.py +++ b/examples/pytorch/token-classification/sparseml_utils.py @@ -4,41 +4,6 @@ from sparseml.pytorch.utils import ModuleExporter from transformers.modeling_outputs import TokenClassifierOutput -from transformers.sparse import SparseMLTrainer - - -class SparseMLTokenClassificationTrainer(SparseMLTrainer): - """ - Token Classification trainer with SparseML integration - - :param recipe: recipe for model sparsification - :param teacher: teacher model for distillation - :param distill_hardness: ratio of loss by teacher targets (between 0 and 1) - :param distill_temperature: temperature for distillation - :param args, kwargs: arguments passed into parent class - """ - - def compute_loss(self, model, inputs, return_outputs=False): - """ - Computing loss using teacher/student distillation - """ - if not self.recipes or self.teacher is None: - return super().compute_loss(model, inputs, return_outputs=return_outputs) - student_outputs = model(**inputs) - loss = student_outputs["loss"] - - steps_in_epoch = -1 # Unused - loss = self.manager.loss_update( - loss, - model, - self.optimizer, - self.state.epoch, - steps_in_epoch, - global_step=self.state.global_step, - student_outputs=student_outputs, - teacher_inputs=inputs, - ) - return (loss, student_outputs) if return_outputs else loss class TokenClassificationModuleExporter(ModuleExporter): diff --git a/src/transformers/sparse.py b/src/transformers/sparse.py index 575aa515731b..386bc93de764 100644 --- a/src/transformers/sparse.py +++ b/src/transformers/sparse.py @@ -125,6 +125,28 @@ def qat_active(self, epoch: int): return qat_start < epoch + 1 + def compute_loss(self, model, inputs, return_outputs=False): + """ + Computing loss using teacher/student distillation + """ + if not self.recipes or self.teacher is None: + return super().compute_loss(model, inputs, return_outputs=return_outputs) + student_outputs = model(**inputs) + loss = student_outputs["loss"] + + steps_in_epoch = -1 # Unused + loss = self.manager.loss_update( + loss, + model, + self.optimizer, + self.state.epoch, + steps_in_epoch, + global_step=self.state.global_step, + student_outputs=student_outputs, + teacher_inputs=inputs, + ) + return (loss, student_outputs) if return_outputs else loss + def save_model(self, output_dir: Optional[str] = None): """ Save model during or after training. Modifiers that change the model architecture will also be saved.
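For reviewers, a condensed sketch of how the pieces introduced above fit together once both patches are applied. It mirrors the call sequence in run_ner.py rather than replacing it, and it assumes the SparseML-enabled transformers fork that provides transformers.sparse plus the sparseml_utils.py added here; the checkpoint name, recipe path, output directories, and label count below are placeholder values only.

from transformers import AutoModelForTokenClassification, AutoTokenizer, TrainingArguments
from transformers.sparse import SparseMLTrainer, export_model, load_recipe, preprocess_state_dict

from sparseml_utils import TokenClassificationModuleExporter


def sparse_finetune_and_export(train_dataset, eval_dataset, data_collator,
                               model_name="bert-base-uncased",  # placeholder checkpoint
                               recipe="recipe.yaml",            # placeholder SparseML recipe
                               onnx_export_path="onnx_out",
                               num_labels=9):
    # Restore parameter names changed by QAT if the checkpoint was already sparsified.
    state_dict = preprocess_state_dict(model_name)
    model = AutoModelForTokenClassification.from_pretrained(
        model_name, num_labels=num_labels, state_dict=state_dict
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

    # Combine a recipe bundled with the checkpoint (if any) with the newly supplied one.
    recipes = [load_recipe(model_name), recipe]

    trainer = SparseMLTrainer(
        model_name,
        recipes,
        teacher=None,  # or a trained token classification model to distill from
        model=model,
        args=TrainingArguments(output_dir="sparse_ner_output"),
        train_dataset=train_dataset,  # pre-tokenized NER dataset
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,  # e.g. DataCollatorForTokenClassification(tokenizer)
    )
    # Re-create recipe graph changes (e.g. QAT observers) before training or loading weights.
    trainer.apply_recipes()
    trainer.train()
    trainer.save_model()

    # Mirror the --onnx_export_path branch: export the model and a few eval samples to ONNX.
    exporter = TokenClassificationModuleExporter(model, output_dir=onnx_export_path)
    export_model(exporter, trainer.get_eval_dataloader(), onnx_export_path, 20)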