diff --git a/examples/pytorch/question-answering/run_qa.py b/examples/pytorch/question-answering/run_qa.py
index 27155208be5f..0cdef4787b4c 100755
--- a/examples/pytorch/question-answering/run_qa.py
+++ b/examples/pytorch/question-answering/run_qa.py
@@ -24,10 +24,11 @@
 from dataclasses import dataclass, field
 from typing import Optional
 
+import numpy
 from datasets import load_dataset, load_metric
 
 import transformers
-from trainer_qa import QuestionAnsweringTrainer
+from sparseml_utils import SparseMLQATrainer, export_model
 from transformers import (
     AutoConfig,
     AutoModelForQuestionAnswering,
@@ -56,10 +57,18 @@ class ModelArguments:
     """
     Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
     """
-
     model_name_or_path: str = field(
         metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
     )
+    distill_teacher: Optional[str] = field(
+        default=None, metadata={"help": "Teacher model for distillation; must be a trained QA model."}
+    )
+    distill_temperature: Optional[float] = field(
+        default=2.0, metadata={"help": "Temperature applied to teacher softmax for distillation."}
+    )
+    distill_hardness: Optional[float] = field(
+        default=1.0, metadata={"help": "Proportion of the loss coming from the teacher model."}
+    )
     config_name: Optional[str] = field(
         default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
     )
@@ -89,6 +98,14 @@ class DataTrainingArguments:
     Arguments pertaining to what data we are going to input our model for training and eval.
     """
 
+    recipe: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to a SparseML sparsification recipe, see https://github.com/neuralmagic/sparseml "
+        "for more information"},
+    )
+    onnx_export_path: Optional[str] = field(
+        default=None, metadata={"help": "The directory to which the ONNX model will be exported."}
+    )
     dataset_name: Optional[str] = field(
         default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
     )
@@ -300,6 +317,18 @@ def main():
         use_auth_token=True if model_args.use_auth_token else None,
     )
 
+    teacher_model = None
+    if model_args.distill_teacher is not None:
+        teacher_model = AutoModelForQuestionAnswering.from_pretrained(
+            model_args.distill_teacher,
+            from_tf=bool(".ckpt" in model_args.distill_teacher),
+            config=config,
+            cache_dir=model_args.cache_dir,
+        )
+        teacher_model_parameters = filter(lambda p: p.requires_grad, teacher_model.parameters())
+        params = sum([numpy.prod(p.size()) for p in teacher_model_parameters])
+        logger.info("Teacher model has %s parameters", params)
+
     # Tokenizer check: this script requires a fast tokenizer.
     if not isinstance(tokenizer, PreTrainedTokenizerFast):
         raise ValueError(
@@ -543,7 +572,11 @@ def compute_metrics(p: EvalPrediction):
         return metric.compute(predictions=p.predictions, references=p.label_ids)
 
     # Initialize our Trainer
-    trainer = QuestionAnsweringTrainer(
+    trainer = SparseMLQATrainer(
+        data_args.recipe,
+        teacher=teacher_model,
+        distill_hardness=model_args.distill_hardness,
+        distill_temperature=model_args.distill_temperature,
         model=model,
         args=training_args,
         train_dataset=train_dataset if training_args.do_train else None,
@@ -612,6 +645,11 @@ def compute_metrics(p: EvalPrediction):
         trainer.push_to_hub(**kwargs)
 
+    if data_args.onnx_export_path:
+        logger.info("*** Export to ONNX ***")
+        eval_dataloader = trainer.get_eval_dataloader(eval_dataset)
+        export_model(model, eval_dataloader, data_args.onnx_export_path)
+
 
 def _mp_fn(index):
     # For xla_spawn (TPUs)
     main()
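# Example invocation (illustrative, not part of the patch): HfArgumentParser turns
# each new dataclass field into a CLI flag. The model names, recipe path, and export
# directory below are placeholders; other standard run_qa.py flags are omitted.
#
#   python run_qa.py \
#     --model_name_or_path bert-base-uncased \
#     --dataset_name squad \
#     --do_train \
#     --do_eval \
#     --output_dir ./sparse_qa \
#     --recipe recipes/bert_squad_pruning.yaml \
#     --distill_teacher bert-large-uncased-whole-word-masking-finetuned-squad \
#     --distill_hardness 0.8 \
#     --distill_temperature 2.0 \
#     --onnx_export_path ./sparse_qa/onnx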
diff --git a/examples/pytorch/question-answering/sparseml_utils.py b/examples/pytorch/question-answering/sparseml_utils.py
new file mode 100644
index 000000000000..ca30f7c61954
--- /dev/null
+++ b/examples/pytorch/question-answering/sparseml_utils.py
@@ -0,0 +1,121 @@
+import math
+
+import torch
+import torch.nn.functional as F
+
+from sparseml.pytorch.optim.manager import ScheduledModifierManager
+from sparseml.pytorch.optim.optimizer import ScheduledOptimizer
+from sparseml.pytorch.utils import ModuleExporter, logger
+from trainer_qa import QuestionAnsweringTrainer
+
+
+class SparseMLQATrainer(QuestionAnsweringTrainer):
+    """
+    Question answering trainer with SparseML integration.
+
+    :param recipe: recipe for model sparsification
+    :param teacher: teacher model for distillation
+    :param distill_hardness: proportion of the loss taken from the teacher outputs (between 0 and 1)
+    :param distill_temperature: temperature for distillation
+    :param args, kwargs: arguments passed into parent class
+    """
+
+    def __init__(self, recipe, teacher=None, distill_hardness=0.5, distill_temperature=2.0, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.recipe = recipe
+        self.teacher = teacher
+        self.distill_hardness = distill_hardness
+        self.distill_temperature = distill_temperature
+        self.criterion = torch.nn.CrossEntropyLoss()
+
+        self.manager = None
+        self.loggers = None
+        if self.recipe is not None:
+            loggers = []
+            if "wandb" in self.args.report_to:
+                loggers.append(logger.WANDBLogger())
+            self.loggers = loggers
+
+    def create_optimizer(self):
+        """
+        Create an optimizer wrapped by SparseML's scheduled modifiers.
+        """
+        super().create_optimizer()
+        if self.recipe is None:
+            return
+        steps_per_epoch = math.ceil(
+            # max(1, ...) guards against division by zero on CPU-only runs
+            len(self.train_dataset) / (self.args.per_device_train_batch_size * max(1, self.args._n_gpu))
+        )
+        self.manager = ScheduledModifierManager.from_yaml(self.recipe)
+        self.args.num_train_epochs = float(self.manager.max_epochs)
+        if hasattr(self, "scaler"):
+            self.manager.initialize(self.model, epoch=0.0, loggers=self.loggers)
+            self.scaler = self.manager.modify(
+                self.model, self.optimizer, steps_per_epoch=steps_per_epoch, wrap_optim=self.scaler
+            )
+        else:
+            self.optimizer = ScheduledOptimizer(
+                self.optimizer, self.model, self.manager, steps_per_epoch=steps_per_epoch, loggers=self.loggers
+            )
+
+    def compute_loss(self, model, inputs, return_outputs=False):
+        """
+        Compute the loss, blending in teacher/student distillation when a teacher is set.
+        """
+        if self.recipe is None and self.teacher is None:
+            return super().compute_loss(model, inputs, return_outputs=return_outputs)
+
+        outputs = model(**inputs)
+        if self.teacher is None:
+            loss = outputs["loss"]
+        else:
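+            # Distillation: blend the hard-label loss with a soft-target loss from the
+            # teacher, averaging the start- and end-position heads. With temperature T,
+            # hardness h, student logits s, teacher logits t, and labels y:
+            #     loss = (1 - h) * CE(s, y) + h * T^2 * KL(softmax(t/T) || softmax(s/T))
+            # (F.kl_div(input=log q, target=p) computes KL(p || q); the T^2 factor keeps
+            # gradients at a comparable scale across temperatures, Hinton et al. 2015).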
inputs["input_ids"].device + self.teacher = self.teacher.to(input_device) + start_logits_student = outputs["start_logits"] + end_logits_student = outputs["end_logits"] + start_logits_label = inputs["start_positions"] + end_logits_label = inputs["end_positions"] + with torch.no_grad(): + teacher_output = self.teacher( + input_ids=inputs["input_ids"], + token_type_ids=inputs["token_type_ids"], + attention_mask=inputs["attention_mask"], + ) + start_logits_teacher = teacher_output["start_logits"] + end_logits_teacher = teacher_output["end_logits"] + loss_start = ( + F.kl_div( + input=F.log_softmax(start_logits_student / self.distill_temperature, dim=-1), + target=F.softmax(start_logits_teacher / self.distill_temperature, dim=-1), + reduction="batchmean", + ) + * (self.distill_temperature ** 2) + ) + loss_end = ( + F.kl_div( + input=F.log_softmax(end_logits_student / self.distill_temperature, dim=-1), + target=F.softmax(end_logits_teacher / self.distill_temperature, dim=-1), + reduction="batchmean", + ) + * (self.distill_temperature ** 2) + ) + teacher_loss = (loss_start + loss_end) / 2.0 + loss_start = self.criterion(start_logits_student, start_logits_label) + loss_end = self.criterion(end_logits_student, end_logits_label) + label_loss = (loss_start + loss_end) / 2.0 + loss = ((1 - self.distill_hardness) * label_loss) + (self.distill_hardness * teacher_loss) + return (loss, outputs) if return_outputs else loss + + +def export_model(model, dataloader, output_dir): + """ + Export a trained model to ONNX + :param model: trained model + :param dataloader: dataloader to get sample batch + :param output_dir: output directory for ONNX model + """ + exporter = ModuleExporter(model, output_dir=output_dir) + for _, sample_batch in enumerate(dataloader): + sample_input = (sample_batch["input_ids"], sample_batch["attention_mask"], sample_batch["token_type_ids"]) + exporter.export_onnx(sample_batch=sample_input, convert_qat=True) + break