diff --git a/examples/pytorch/question-answering/run_qa.py b/examples/pytorch/question-answering/run_qa.py
index 2fb4ce3523a0..38124dd8ac84 100755
--- a/examples/pytorch/question-answering/run_qa.py
+++ b/examples/pytorch/question-answering/run_qa.py
@@ -46,6 +46,8 @@
 from transformers.utils import check_min_version
 from utils_qa import postprocess_qa_predictions
 
+import wandb
+wandb.init(project="sparse-transfer-downstream-qa-daniel")
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
 check_min_version("4.7.0.dev0")
@@ -217,7 +219,6 @@ def __post_init__(self):
                 extension = self.test_file.split(".")[-1]
                 assert extension in ["csv", "json"], "`test_file` should be a csv or a json file."
 
-
 def main():
     # See all possible arguments in src/transformers/training_args.py
     # or by passing the --help flag to this script.
diff --git a/examples/pytorch/question-answering/sparseml_utils.py b/examples/pytorch/question-answering/sparseml_utils.py
index 0f84c562e6af..6ea5e1593191 100644
--- a/examples/pytorch/question-answering/sparseml_utils.py
+++ b/examples/pytorch/question-answering/sparseml_utils.py
@@ -20,32 +20,44 @@ class SparseMLQATrainer(SparseMLTrainer, QuestionAnsweringTrainer):
     :param distill_temperature: temperature for distillation
     :param args, kwargs: arguments passed into parent class
     """
 
     def compute_loss(self, model, inputs, return_outputs=False):
         """
         Compute the loss using teacher/student distillation
         """
-        if not self.recipes or self.teacher is None:
+        if not self.recipes and self.teachers is None:
            return super().compute_loss(model, inputs, return_outputs=return_outputs)
 
         outputs = model(**inputs)
-        if self.teacher is None:
+        if self.teachers is None:
             loss = outputs["loss"]
         else:
-            input_device = inputs["input_ids"].device
-            self.teacher = self.teacher.to(input_device)
             start_logits_student = outputs["start_logits"]
             end_logits_student = outputs["end_logits"]
             start_logits_label = inputs["start_positions"]
             end_logits_label = inputs["end_positions"]
-            with torch.no_grad():
-                teacher_output = self.teacher(
-                    input_ids=inputs["input_ids"],
-                    token_type_ids=inputs["token_type_ids"],
-                    attention_mask=inputs["attention_mask"],
-                )
-            start_logits_teacher = teacher_output["start_logits"]
-            end_logits_teacher = teacher_output["end_logits"]
+            if self.multi_gpu:
+                input_ids = torch.chunk(inputs["input_ids"], self.num_gpus)
+                start_logits_teacher = torch.empty((0, inputs["input_ids"].shape[1]), dtype=start_logits_student.dtype, device="cuda")
+                end_logits_teacher = torch.empty((0, inputs["input_ids"].shape[1]), dtype=end_logits_student.dtype, device="cuda")
+                for i in range(len(input_ids)):
+                    with torch.no_grad():
+                        input_device = self.teachers[i].device
+                        teacher_output = self.teachers[i](input_ids[i].to(input_device))
+                    start_logits_teacher = torch.cat((start_logits_teacher, teacher_output["start_logits"].to("cuda")), dim=0)
+                    end_logits_teacher = torch.cat((end_logits_teacher, teacher_output["end_logits"].to("cuda")), dim=0)
+            else:  # CPU or single GPU
+                input_device = inputs["input_ids"].device
+                self.teachers = self.teachers.to(input_device)
+                with torch.no_grad():
+                    teacher_output = self.teachers(
+                        input_ids=inputs["input_ids"],
+                        token_type_ids=inputs["token_type_ids"],
+                        attention_mask=inputs["attention_mask"],
+                    )
+                start_logits_teacher = teacher_output["start_logits"]
+                end_logits_teacher = teacher_output["end_logits"]
+
             loss_start = (
                 F.kl_div(
                     input=F.log_softmax(start_logits_student / self.distill_temperature, dim=-1),
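Note: the distillation loss assembled above (and again in src/transformers/sparse.py below) blends the usual hard-label loss with a KL-divergence term against the teacher's temperature-softened logits, scaled by temperature squared and weighted by distill_hardness. The following is a minimal, self-contained sketch of that blend; the function name, tensor shapes, and numbers are illustrative stand-ins, not part of the patch.

    import torch
    import torch.nn.functional as F

    def distillation_loss(student_logits, teacher_logits, hard_loss, temperature=2.0, hardness=0.5):
        # Soft-target term: KL between the softened student and teacher distributions, scaled by T^2
        soft_loss = F.kl_div(
            input=F.log_softmax(student_logits / temperature, dim=-1),
            target=F.softmax(teacher_logits / temperature, dim=-1),
            reduction="batchmean",
        ) * (temperature ** 2)
        # Blend with the hard-label loss; hardness=1.0 means all teacher, 0.0 means all gold labels
        return (1 - hardness) * hard_loss + hardness * soft_loss

    # Toy QA-style example: a batch of 8 start-position logit vectors over a 384-token sequence
    student = torch.randn(8, 384)
    teacher = torch.randn(8, 384)
    hard = torch.tensor(1.7)  # stand-in for the cross-entropy against the gold start positions
    print(distillation_loss(student, teacher, hard))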
diff --git a/examples/pytorch/text-classification/run_glue.py b/examples/pytorch/text-classification/run_glue.py
index 1b08def9c62f..81d37c7a164a 100755
--- a/examples/pytorch/text-classification/run_glue.py
+++ b/examples/pytorch/text-classification/run_glue.py
@@ -23,6 +23,7 @@
 from dataclasses import dataclass, field
 from typing import Optional
 
+import wandb
 import numpy as np
 from datasets import load_dataset, load_metric
 
@@ -40,13 +41,13 @@
     default_data_collator,
     set_seed,
 )
+
+from sparseml_utils import GLUEModuleExporter
+from transformers.sparse import export_model, SparseMLTrainer, load_recipe, preprocess_state_dict
 from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version
 
 
-# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.7.0.dev0")
-
 task_to_keys = {
     "cola": ("sentence", None),
     "mnli": ("premise", "hypothesis"),
@@ -71,7 +72,17 @@ class DataTrainingArguments:
     into argparse arguments to be able to specify them on the command line.
     """
-
+    recipe: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to a SparseML sparsification recipe, see https://github.com/neuralmagic/sparseml "
+                          "for more information"},
+    )
+    onnx_export_path: Optional[str] = field(
+        default=None, metadata={"help": "Filename and path where the exported ONNX model will be written"}
+    )
+    num_exported_samples: Optional[int] = field(
+        default=20, metadata={"help": "Number of samples to export along with the ONNX model, defaults to 20"}
+    )
     task_name: Optional[str] = field(
         default=None,
         metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())},
@@ -155,6 +166,15 @@ class ModelArguments:
     model_name_or_path: str = field(
         metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
     )
+    distill_teacher: Optional[str] = field(
+        default=None, metadata={"help": "Path or identifier of a trained text classification model to use as the distillation teacher"}
+    )
+    distill_temperature: Optional[float] = field(
+        default=2.0, metadata={"help": "Temperature applied to teacher softmax for distillation."}
+    )
+    distill_hardness: Optional[float] = field(
+        default=1.0, metadata={"help": "Proportion of loss coming from teacher model."}
+    )
     config_name: Optional[str] = field(
         default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
     )
@@ -305,6 +325,13 @@ def main():
     #
     # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
     # download model & vocab.
+
+    # Load and preprocess the state dict if a trained model already exists (i.e., we are continuing
+    # training or evaluating such a model). The preprocessing step restores the names of parameters
+    # that were changed by the QAT process.
+    state_dict = preprocess_state_dict(model_args.model_name_or_path)
+
+
     config = AutoConfig.from_pretrained(
         model_args.config_name if model_args.config_name else model_args.model_name_or_path,
         num_labels=num_labels,
@@ -327,8 +354,19 @@ def main():
         cache_dir=model_args.cache_dir,
         revision=model_args.model_revision,
         use_auth_token=True if model_args.use_auth_token else None,
+        state_dict=state_dict,
     )
+    teacher_model = None
+    if model_args.distill_teacher is not None:
+        teacher_model = AutoModelForSequenceClassification.from_pretrained(
+            model_args.distill_teacher,
+            from_tf=bool(".ckpt" in model_args.distill_teacher),
+            cache_dir=model_args.cache_dir,
+        )
+        teacher_model_parameters = filter(lambda p: p.requires_grad, teacher_model.parameters())
+        params = sum([np.prod(p.size()) for p in teacher_model_parameters])
+        logger.info("Teacher Model has %s parameters", params)
 
     # Preprocessing the datasets
     if data_args.task_name is not None:
         sentence1_key, sentence2_key = task_to_keys[data_args.task_name]
@@ -445,17 +483,31 @@ def compute_metrics(p: EvalPrediction):
     else:
         data_collator = None
 
+    # Load a possible existing recipe plus the new one passed in through the command line
+    existing_recipe = load_recipe(model_args.model_name_or_path)
+    new_recipe = data_args.recipe
+
     # Initialize our Trainer
-    trainer = Trainer(
+    trainer = SparseMLTrainer(
+        model_args.model_name_or_path,
+        [existing_recipe, new_recipe],
+        teacher=teacher_model,
+        distill_hardness=model_args.distill_hardness,
+        distill_temperature=model_args.distill_temperature,
         model=model,
         args=training_args,
         train_dataset=train_dataset if training_args.do_train else None,
         eval_dataset=eval_dataset if training_args.do_eval else None,
-        compute_metrics=compute_metrics,
         tokenizer=tokenizer,
         data_collator=data_collator,
+        compute_metrics=compute_metrics,
     )
+
+    # Apply the recipes to the model. This is necessary because sparsification methods such as QAT
+    # modify the model graph with their own learnable parameters, which also need to be
+    # restored/loaded into the model.
+    trainer.apply_recipes()
 
     # Training
     if training_args.do_train:
         checkpoint = None
@@ -536,6 +588,11 @@ def compute_metrics(p: EvalPrediction):
 
         trainer.push_to_hub(**kwargs)
 
+    if data_args.onnx_export_path:
+        logger.info("*** Export to ONNX ***")
+        eval_dataloader = trainer.get_eval_dataloader(eval_dataset)
+        export_model(model, eval_dataloader, data_args.onnx_export_path, data_args.num_exported_samples)
+
 
 def _mp_fn(index):
     # For xla_spawn (TPUs)
diff --git a/examples/pytorch/text-classification/sparseml_utils.py b/examples/pytorch/text-classification/sparseml_utils.py
new file mode 100644
index 000000000000..60ff80952743
--- /dev/null
+++ b/examples/pytorch/text-classification/sparseml_utils.py
@@ -0,0 +1,23 @@
+from typing import Any
+
+import numpy
+import torch
+import torch.nn.functional as F
+
+from sparseml.pytorch.utils import ModuleExporter
+
+from transformers.modeling_outputs import SequenceClassifierOutput
+
+class GLUEModuleExporter(ModuleExporter):
+    """
+    Module exporter class for sequence classification
+    """
+
+    @classmethod
+    def get_output_names(cls, out: Any):
+        if not isinstance(out, SequenceClassifierOutput):
+            raise ValueError(f"Expected SequenceClassifierOutput, got {type(out)}")
+        expected = ["logits"]
+        if any(name not in out for name in expected):
+            raise ValueError("Expected output names not found in model output")
+        return expected
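Note: with the dataclass fields added above, run_glue.py picks up --recipe, --onnx_export_path, --num_exported_samples, --distill_teacher, --distill_temperature and --distill_hardness as command-line flags through HfArgumentParser, alongside its existing arguments. The snippet below is a small, hedged check of the new GLUEModuleExporter, assuming the file above is importable and that transformers and sparseml are installed; it only demonstrates what get_output_names returns, which the exporter presumably uses to label the ONNX graph outputs.

    import torch
    from transformers.modeling_outputs import SequenceClassifierOutput
    from sparseml_utils import GLUEModuleExporter  # the new file introduced above

    # A fake classifier output with a (batch=2, num_labels=2) logits tensor
    out = SequenceClassifierOutput(logits=torch.randn(2, 2))
    print(GLUEModuleExporter.get_output_names(out))  # -> ["logits"]

    # Anything that is not a SequenceClassifierOutput is rejected
    try:
        GLUEModuleExporter.get_output_names({"logits": torch.randn(2, 2)})
    except ValueError as err:
        print(err)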
diff --git a/src/transformers/sparse.py b/src/transformers/sparse.py
index a2a6c2497d87..8bcc3d0be80c 100644
--- a/src/transformers/sparse.py
+++ b/src/transformers/sparse.py
@@ -6,6 +6,8 @@
 import numpy
 import torch
+import torch.nn.functional as F
+import copy
 import onnxruntime
 
 from sparseml.pytorch.optim.manager import ScheduledModifierManager
@@ -32,7 +33,14 @@ def __init__(
         super().__init__(*args, **kwargs)
         self.model_name_or_path = str(model_name_or_path)
         self.recipes = [recipe for recipe in recipes if recipe]
-        self.teacher = teacher
+        self.teachers = teacher
+        self.multi_gpu = False
+        if torch.cuda.device_count() and teacher is not None:
+            self.multi_gpu = True
+            self.num_gpus = torch.cuda.device_count()
+            self.teachers = [copy.deepcopy(teacher) for _ in range(self.num_gpus)]
+            for i in range(self.num_gpus):
+                self.teachers[i] = self.teachers[i].to(i)
         self.distill_hardness = distill_hardness
         self.distill_temperature = distill_temperature
         self.criterion = torch.nn.CrossEntropyLoss()
@@ -82,6 +90,54 @@ def apply_recipes(self, epoch=0.0):
                 f"Unexpected keys: {unexpected_keys}\n"
             )
 
+    def compute_loss(self, model, inputs, return_outputs=False):
+        """
+        Compute the loss using teacher/student distillation
+        """
+        if not self.recipes and self.teachers is None:
+            return super().compute_loss(model, inputs, return_outputs=return_outputs)
+
+        outputs = model(**inputs)
+        if self.teachers is None:
+            loss = outputs["loss"]
+        else:
+            logits_student = outputs["logits"]
+            if self.multi_gpu:
+                input_ids = torch.chunk(inputs["input_ids"], self.num_gpus)
+                token_type_ids = torch.chunk(inputs["token_type_ids"], self.num_gpus)
+                attention_mask = torch.chunk(inputs["attention_mask"], self.num_gpus)
+                logits_teacher = torch.empty((0, logits_student.shape[1]), dtype=logits_student.dtype, device="cuda")
+                for i in range(len(input_ids)):
+                    with torch.no_grad():
+                        input_device = self.teachers[i].device
+                        teacher_output = self.teachers[i](
+                            input_ids=input_ids[i].to(input_device),
+                            token_type_ids=token_type_ids[i].to(input_device),
+                            attention_mask=attention_mask[i].to(input_device),
+                        )
+                    logits_teacher = torch.cat((logits_teacher, teacher_output["logits"].to("cuda")), dim=0)
+            else:  # CPU or single GPU
+                input_device = inputs["input_ids"].device
+                self.teachers = self.teachers.to(input_device)
+                with torch.no_grad():
+                    teacher_output = self.teachers(
+                        input_ids=inputs["input_ids"],
+                        token_type_ids=inputs["token_type_ids"],
+                        attention_mask=inputs["attention_mask"],
+                    )
+                logits_teacher = teacher_output["logits"]
+
+            teacher_loss = (
+                F.kl_div(
+                    input=F.log_softmax(logits_student / self.distill_temperature, dim=-1),
+                    target=F.softmax(logits_teacher / self.distill_temperature, dim=-1),
+                    reduction="batchmean",
+                )
+                * (self.distill_temperature ** 2)
+            )
+
+            loss = ((1 - self.distill_hardness) * outputs["loss"]) + (self.distill_hardness * teacher_loss)
+        return (loss, outputs) if return_outputs else loss
+
     def create_optimizer(self):
         """
         Create optimizer customized using SparseML
@@ -102,6 +158,7 @@ def create_optimizer(self):
             self.optimizer, self.model, self.manager, steps_per_epoch=steps_per_epoch, loggers=self.loggers
         )
 
+
     def create_scheduler(self, num_training_steps: int):
         """
         Override LR scheduler if the SparseML manager has LR modifiers, otherwise
diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py
index 8e02a1ee0ce5..0e8779816de2 100644
--- a/src/transformers/trainer_utils.py
+++ b/src/transformers/trainer_utils.py
@@ -59,6 +59,7 @@ def set_seed(seed: int):
     if is_torch_available():
         torch.manual_seed(seed)
         torch.cuda.manual_seed_all(seed)
+        torch.backends.cudnn.deterministic = True
         # ^^ safe to call this function even if cuda is not available
     if is_tf_available():
         tf.random.set_seed(seed)
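Note: both compute_loss implementations in this patch use the same multi-GPU pattern for the frozen teacher: keep one replica per GPU, chunk the incoming batch, score each chunk under torch.no_grad() on its replica, and gather the logits back on one device before computing the KL term. Below is a standalone sketch of that pattern; the toy linear "teacher" and random batch are stand-ins for the real model and inputs, and the snippet simply does nothing when no GPU is visible.

    import copy

    import torch

    teacher = torch.nn.Linear(128, 2)  # stand-in for the frozen teacher model
    num_gpus = torch.cuda.device_count()

    if num_gpus:
        # One replica per GPU (deepcopy, since nn.Module.to() moves a module in place)
        replicas = [copy.deepcopy(teacher).to(i) for i in range(num_gpus)]
        batch = torch.randn(16, 128, device="cuda")
        chunks = torch.chunk(batch, num_gpus)  # also copes with batches not divisible by num_gpus
        with torch.no_grad():
            logits = torch.cat(
                [replicas[i](chunk.to(i)).to("cuda") for i, chunk in enumerate(chunks)],
                dim=0,
            )
        assert logits.shape == (batch.shape[0], 2)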