From cdb6dba63e13db4372454b212a5cdde5702fb897 Mon Sep 17 00:00:00 2001
From: Tuan Nguyen
Date: Tue, 5 Oct 2021 00:47:42 -0400
Subject: [PATCH 1/2] Integrate SparseML with Masked LM training

---
 examples/pytorch/language-modeling/run_mlm.py | 85 +++++++++++++++++--
 .../language-modeling/sparseml_utils.py       | 61 +++++++++++++
 src/transformers/sparse.py                    |  1 +
 3 files changed, 142 insertions(+), 5 deletions(-)
 create mode 100644 examples/pytorch/language-modeling/sparseml_utils.py

diff --git a/examples/pytorch/language-modeling/run_mlm.py b/examples/pytorch/language-modeling/run_mlm.py
index 60d315ef5fca..886573fab5a5 100755
--- a/examples/pytorch/language-modeling/run_mlm.py
+++ b/examples/pytorch/language-modeling/run_mlm.py
@@ -28,9 +28,11 @@
 from dataclasses import dataclass, field
 from typing import Optional

-from datasets import load_dataset
+import numpy
+from datasets import concatenate_datasets, load_dataset

 import transformers
+from sparseml_utils import MaskedLanguageModelingModuleExporter, SparseMLMaskedLanguageModelingTrainer
 from transformers import (
     CONFIG_MAPPING,
     MODEL_FOR_MASKED_LM_MAPPING,
@@ -39,10 +41,10 @@
     AutoTokenizer,
     DataCollatorForLanguageModeling,
     HfArgumentParser,
-    Trainer,
     TrainingArguments,
     set_seed,
 )
+from transformers.sparse import export_model, load_recipe
 from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version

@@ -72,6 +74,9 @@ class ModelArguments:
         default=None,
         metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
     )
+    distill_teacher: Optional[str] = field(
+        default=None, metadata={"help": "Teacher model which needs to be a trained masked language model"}
+    )
     config_name: Optional[str] = field(
         default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
     )
@@ -105,12 +110,34 @@ class DataTrainingArguments:
     Arguments pertaining to what data we are going to input our model for training and eval.
     """

+    recipe: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Path to a SparseML sparsification recipe, see https://github.com/neuralmagic/sparseml "
+            "for more information"
+        },
+    )
+    onnx_export_path: Optional[str] = field(
+        default=None, metadata={"help": "The directory path where the exported ONNX model will be saved"}
+    )
+    num_exported_samples: Optional[int] = field(
+        default=20, metadata={"help": "Number of samples to export along with the ONNX model, defaults to 20"}
+    )
     dataset_name: Optional[str] = field(
         default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
     )
     dataset_config_name: Optional[str] = field(
         default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
     )
+
+    # An optional second dataset to concatenate with the first one
+    dataset_name_2: Optional[str] = field(
+        default=None, metadata={"help": "The name of the second dataset to use (via the datasets library)."}
+    )
+    dataset_config_name_2: Optional[str] = field(
+        default=None, metadata={"help": "The configuration name of the second dataset to use (via the datasets library)."}
+    )
+
     train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
     validation_file: Optional[str] = field(
         default=None,
@@ -266,6 +293,30 @@ def main():
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.

+    # Load extra dataset if specified, and concatenate with the original one
+    if data_args.dataset_name_2 is not None:
+        # Downloading and loading a dataset from the hub.
+        datasets_2 = load_dataset(
+            data_args.dataset_name_2, data_args.dataset_config_name_2, cache_dir=model_args.cache_dir
+        )
+        if "validation" not in datasets_2.keys():
+            datasets_2["validation"] = load_dataset(
+                data_args.dataset_name_2,
+                data_args.dataset_config_name_2,
+                split=f"train[:{data_args.validation_split_percentage}%]",
+                cache_dir=model_args.cache_dir,
+            )
+            datasets_2["train"] = load_dataset(
+                data_args.dataset_name_2,
+                data_args.dataset_config_name_2,
+                split=f"train[{data_args.validation_split_percentage}%:]",
+                cache_dir=model_args.cache_dir,
+            )
+        # Concatenate two datasets
+        if datasets is not None:
+            for split in ["validation", "train"]:
+                datasets[split] = concatenate_datasets([datasets[split], datasets_2[split]])
+
     # Load pretrained model and tokenizer
     #
     # Distributed training:
@@ -299,7 +350,6 @@ def main():
             "You are instantiating a new tokenizer from scratch. This is not supported by this script."
             "You can do it from another script, save it, and load it from here, using --tokenizer_name."
         )
-
     if model_args.model_name_or_path:
         model = AutoModelForMaskedLM.from_pretrained(
             model_args.model_name_or_path,
@@ -315,6 +365,17 @@ def main():

     model.resize_token_embeddings(len(tokenizer))

+    teacher_model = None
+    if model_args.distill_teacher is not None:
+        teacher_model = AutoModelForMaskedLM.from_pretrained(
+            model_args.distill_teacher,
+            from_tf=bool(".ckpt" in model_args.distill_teacher),
+            cache_dir=model_args.cache_dir,
+        )
+        teacher_model_parameters = filter(lambda p: p.requires_grad, teacher_model.parameters())
+        params = sum([numpy.prod(p.size()) for p in teacher_model_parameters])
+        logger.info("Teacher Model has %s parameters", params)
+
     # Preprocessing the datasets.
     # First we tokenize all the texts.
     if training_args.do_train:
@@ -400,7 +461,6 @@ def group_texts(examples):
         #
         # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
         # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
-
         tokenized_datasets = tokenized_datasets.map(
             group_texts,
             batched=True,
@@ -431,14 +491,23 @@ def group_texts(examples):
         pad_to_multiple_of=8 if pad_to_multiple_of_8 else None,
     )

+    # Load a recipe possibly saved with the model, plus the new one passed in on the command line
+    existing_recipe = load_recipe(model_args.model_name_or_path)
+    new_recipe = data_args.recipe
+
+    compute_metrics = None
     # Initialize our Trainer
-    trainer = Trainer(
+    trainer = SparseMLMaskedLanguageModelingTrainer(
+        model_args.model_name_or_path,
+        [existing_recipe, new_recipe],
+        teacher=teacher_model,
         model=model,
         args=training_args,
         train_dataset=train_dataset if training_args.do_train else None,
         eval_dataset=eval_dataset if training_args.do_eval else None,
         tokenizer=tokenizer,
         data_collator=data_collator,
+        compute_metrics=compute_metrics,
     )

     # Training
@@ -490,6 +559,12 @@ def group_texts(examples):

         trainer.push_to_hub(**kwargs)

+    if data_args.onnx_export_path:
+        logger.info("*** Export to ONNX ***")
+        eval_dataloader = trainer.get_eval_dataloader(eval_dataset)
+        exporter = MaskedLanguageModelingModuleExporter(model, output_dir=data_args.onnx_export_path)
+        export_model(exporter, eval_dataloader, data_args.onnx_export_path, data_args.num_exported_samples)
+

 def _mp_fn(index):
     # For xla_spawn (TPUs)
diff --git a/examples/pytorch/language-modeling/sparseml_utils.py b/examples/pytorch/language-modeling/sparseml_utils.py
new file mode 100644
index 000000000000..b1365265eac9
--- /dev/null
+++ b/examples/pytorch/language-modeling/sparseml_utils.py
@@ -0,0 +1,61 @@
+from typing import Any
+
+import numpy
+import torch
+
+from sparseml.pytorch.utils import ModuleExporter, device_of
+from transformers.sparse import SparseMLTrainer
+
+
+class SparseMLMaskedLanguageModelingTrainer(SparseMLTrainer):
+    """
+    Masked language model trainer with SparseML integration
+
+    :param recipe: recipe for model sparsification
+    :param teacher: teacher model for distillation
+    :param distill_hardness: weight of the distillation (teacher) loss relative to the task loss, between 0 and 1
+    :param distill_temperature: temperature used to soften the logits for distillation
+    :param args, kwargs: remaining arguments passed to the parent class
+    """
+
+    def compute_loss(self, model, inputs, return_outputs=False):
+        """
+        Compute loss using teacher/student distillation
+        """
+        if not self.recipes or self.teacher is None:
+            return super().compute_loss(model, inputs, return_outputs=return_outputs)
+        student_outputs = model(**inputs)
+        loss = student_outputs["loss"]
+
+        target_device = device_of(inputs)
+        self.teacher.to(target_device)
+        with torch.no_grad():
+            teacher_outputs = self.teacher(**inputs)
+
+        steps_in_epoch = -1  # Unused
+        loss = self.manager.loss_update(
+            loss,
+            model,
+            self.optimizer,
+            self.state.epoch,
+            steps_in_epoch,
+            global_step=self.state.global_step,
+            student_outputs=student_outputs,
+            teacher_outputs=teacher_outputs,
+        )
+        return (loss, student_outputs) if return_outputs else loss
+
+
+class MaskedLanguageModelingModuleExporter(ModuleExporter):
+    """
+    Module exporter class for Masked Language Modeling
+    """
+
+    @classmethod
+    def get_output_names(cls, out: Any):
+        # if not isinstance(out, MaskedLMOutput):
+        #     raise ValueError(f"Expected MaskedLMOutput, got {type(out)}")
+        expected = ["logits"]
+        if numpy.any([name for name in expected if name not in out]):
+            raise ValueError("Expected output names not found in model output")
output") + return expected diff --git a/src/transformers/sparse.py b/src/transformers/sparse.py index a2a6c2497d87..3a29b4f63722 100644 --- a/src/transformers/sparse.py +++ b/src/transformers/sparse.py @@ -157,6 +157,7 @@ def export_model(exporter, dataloader, output_dir, num_exported_samples): os.makedirs(sample_outputs, exist_ok=True) for _, sample_batch in enumerate(dataloader): + sample_batch.pop("labels", None) if sess is None: forward_args_spec = inspect.getfullargspec(exporter._module.__class__.forward) one_sample_input = collections.OrderedDict( From f847c46cf0559a396faff75678d625eed66842fd Mon Sep 17 00:00:00 2001 From: Tuan Nguyen Date: Wed, 6 Oct 2021 16:16:30 -0400 Subject: [PATCH 2/2] Move teacher's logits out of compute_loss --- examples/pytorch/language-modeling/run_mlm.py | 5 +++++ examples/pytorch/language-modeling/sparseml_utils.py | 10 ++-------- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/examples/pytorch/language-modeling/run_mlm.py b/examples/pytorch/language-modeling/run_mlm.py index 886573fab5a5..f75935c03521 100755 --- a/examples/pytorch/language-modeling/run_mlm.py +++ b/examples/pytorch/language-modeling/run_mlm.py @@ -510,6 +510,11 @@ def group_texts(examples): compute_metrics=compute_metrics, ) + # Apply recipes to the model. This is necessary given that + # sparsification methods such as QAT modified the model graph with their own learnable + # parameters. They are also restored/loaded to the model. + trainer.apply_recipes() + # Training if training_args.do_train: checkpoint = None diff --git a/examples/pytorch/language-modeling/sparseml_utils.py b/examples/pytorch/language-modeling/sparseml_utils.py index b1365265eac9..e8964657f562 100644 --- a/examples/pytorch/language-modeling/sparseml_utils.py +++ b/examples/pytorch/language-modeling/sparseml_utils.py @@ -1,9 +1,8 @@ from typing import Any import numpy -import torch -from sparseml.pytorch.utils import ModuleExporter, device_of +from sparseml.pytorch.utils import ModuleExporter from transformers.sparse import SparseMLTrainer @@ -27,11 +26,6 @@ def compute_loss(self, model, inputs, return_outputs=False): student_outputs = model(**inputs) loss = student_outputs["loss"] - target_device = device_of(inputs) - self.teacher.to(target_device) - with torch.no_grad(): - teacher_outputs = self.teacher(**inputs) - steps_in_epoch = -1 # Unused loss = self.manager.loss_update( loss, @@ -41,7 +35,7 @@ def compute_loss(self, model, inputs, return_outputs=False): steps_in_epoch, global_step=self.state.global_step, student_outputs=student_outputs, - teacher_outputs=teacher_outputs, + teacher_inputs=inputs, ) return (loss, student_outputs) if return_outputs else loss