diff --git a/examples/pytorch/token-classification/run_ner.py b/examples/pytorch/token-classification/run_ner.py index f0f69f9e39b3..22a32e09025c 100755 --- a/examples/pytorch/token-classification/run_ner.py +++ b/examples/pytorch/token-classification/run_ner.py @@ -29,6 +29,7 @@ from datasets import ClassLabel, load_dataset, load_metric import transformers +from sparseml_utils import TokenClassificationModuleExporter from transformers import ( AutoConfig, AutoModelForTokenClassification, @@ -36,10 +37,10 @@ DataCollatorForTokenClassification, HfArgumentParser, PreTrainedTokenizerFast, - Trainer, TrainingArguments, set_seed, ) +from transformers.sparse import SparseMLTrainer, export_model, load_recipe, preprocess_state_dict from transformers.trainer_utils import get_last_checkpoint from transformers.utils import check_min_version @@ -59,6 +60,9 @@ class ModelArguments: model_name_or_path: str = field( metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} ) + distill_teacher: Optional[str] = field( + default=None, metadata={"help": "Teacher model which needs to be a trained QA model"} + ) config_name: Optional[str] = field( default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} ) @@ -88,6 +92,19 @@ class DataTrainingArguments: Arguments pertaining to what data we are going to input our model for training and eval. 
""" + recipe: Optional[str] = field( + default=None, + metadata={ + "help": "Path to a SparseML sparsification recipe, see https://github.com/neuralmagic/sparseml " + "for more information" + }, + ) + onnx_export_path: Optional[str] = field( + default=None, metadata={"help": "The filename and path which will be where onnx model is outputed"} + ) + num_exported_samples: Optional[int] = field( + default=20, metadata={"help": "Number of exported samples, default to 20"} + ) task_name: Optional[str] = field(default="ner", metadata={"help": "The name of the task (ner, pos...)."}) dataset_name: Optional[str] = field( default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} @@ -278,6 +295,12 @@ def get_label_list(labels): # Distributed training: # The .from_pretrained methods guarantee that only one local process can concurrently # download model & vocab. + + # Load and preprocess the state dict if the model existed (in this case we continue to train or + # evaluate the model). The preprocessing step is to restore names of parameters changed by + # QAT process. 
+ state_dict = preprocess_state_dict(model_args.model_name_or_path) + config = AutoConfig.from_pretrained( model_args.config_name if model_args.config_name else model_args.model_name_or_path, num_labels=num_labels, @@ -300,8 +323,20 @@ def get_label_list(labels): cache_dir=model_args.cache_dir, revision=model_args.model_revision, use_auth_token=True if model_args.use_auth_token else None, + state_dict=state_dict, ) + teacher_model = None + if model_args.distill_teacher is not None: + teacher_model = AutoModelForTokenClassification.from_pretrained( + model_args.distill_teacher, + from_tf=bool(".ckpt" in model_args.distill_teacher), + cache_dir=model_args.cache_dir, + ) + teacher_model_parameters = filter(lambda p: p.requires_grad, teacher_model.parameters()) + params = sum([np.prod(p.size()) for p in teacher_model_parameters]) + logger.info("Teacher Model has %s parameters", params) + # Tokenizer check: this script requires a fast tokenizer. if not isinstance(tokenizer, PreTrainedTokenizerFast): raise ValueError( @@ -424,8 +459,15 @@ def compute_metrics(p): "accuracy": results["overall_accuracy"], } + # Load possible existing recipe and new one passed in through command argument + existing_recipe = load_recipe(model_args.model_name_or_path) + new_recipe = data_args.recipe + # Initialize our Trainer - trainer = Trainer( + trainer = SparseMLTrainer( + model_args.model_name_or_path, + [existing_recipe, new_recipe], + teacher=teacher_model, model=model, args=training_args, train_dataset=train_dataset if training_args.do_train else None, @@ -435,6 +477,11 @@ def compute_metrics(p): compute_metrics=compute_metrics, ) + # Apply recipes to the model. This is necessary given that + # sparsification methods such as QAT modified the model graph with their own learnable + # parameters. They are also restored/loaded to the model. 
class TokenClassificationModuleExporter(ModuleExporter):
    """
    Module exporter class for Token Classification
    """

    @classmethod
    def get_output_names(cls, out: Any):
        """
        Return the output names expected from a token classification model.

        Args:
            out: the object returned by the model's forward pass.

        Returns:
            The list of expected output names (``["logits"]``).

        Raises:
            ValueError: if ``out`` is not a ``TokenClassifierOutput`` or if any
                expected name is not contained in it.
        """
        if not isinstance(out, TokenClassifierOutput):
            raise ValueError(f"Expected TokenClassifierOutput, got {type(out)}")
        expected = ["logits"]
        # A plain emptiness check on the missing names; no array reduction over
        # strings is needed to answer "is anything missing?".
        missing = [name for name in expected if name not in out]
        if missing:
            raise ValueError("Expected output names not found in model output")
        return expected
100644 --- a/src/transformers/hf_argparser.py +++ b/src/transformers/hf_argparser.py @@ -14,6 +14,7 @@ import dataclasses import json +import os import re import sys from argparse import ArgumentParser, ArgumentTypeError @@ -21,6 +22,14 @@ from pathlib import Path from typing import Any, Iterable, List, NewType, Optional, Tuple, Union +from .utils.logging import get_logger + +from sparsezoo import Zoo +from sparsezoo.requests.base import ZOO_STUB_PREFIX + + +logger = get_logger(__name__) + DataClass = NewType("DataClass", Any) DataClassType = NewType("DataClassType", Any) @@ -190,12 +199,17 @@ def parse_args_into_dataclasses( # additional namespace. outputs.append(namespace) if return_remaining_strings: - return (*outputs, remaining_args) + return tuple( + *[_download_dataclass_zoo_stub_files(output) for output in outputs], + remaining_args, + ) else: if remaining_args: raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {remaining_args}") - return (*outputs,) + return tuple( + [_download_dataclass_zoo_stub_files(output) for output in outputs] + ) def parse_json_file(self, json_file: str) -> Tuple[DataClass, ...]: """ @@ -209,7 +223,9 @@ def parse_json_file(self, json_file: str) -> Tuple[DataClass, ...]: inputs = {k: v for k, v in data.items() if k in keys} obj = dtype(**inputs) outputs.append(obj) - return (*outputs,) + return tuple( + [_download_dataclass_zoo_stub_files(output) for output in outputs] + ) def parse_dict(self, args: dict) -> Tuple[DataClass, ...]: """ @@ -222,4 +238,38 @@ def parse_dict(self, args: dict) -> Tuple[DataClass, ...]: inputs = {k: v for k, v in args.items() if k in keys} obj = dtype(**inputs) outputs.append(obj) - return (*outputs,) + return tuple( + [_download_dataclass_zoo_stub_files(output) for output in outputs] + ) + + +def _download_dataclass_zoo_stub_files(data_class: DataClass): + for name, val in data_class.__dict__.items(): + if not isinstance(val, str) or "recipe" in name or not 
val.startswith("zoo:"): + continue + + logger.info(f"Downloading framework files for SparseZoo stub: {val}") + + zoo_model = Zoo.load_model_from_stub(val) + framework_file_paths = zoo_model.download_framework_files() + assert framework_file_paths, ( + "Unable to download any framework files for SparseZoo stub {val}" + ) + framework_file_names = [os.path.basename(path) for path in framework_file_paths] + if "pytorch_model.bin" not in framework_file_names or ( + "config.json" not in framework_file_names + ): + raise RuntimeError( + "Unable to find 'pytorch_model.bin' and 'config.json' in framework " + f"files downloaded from {val}. Found {framework_file_names}. Check " + "if the given stub is for a transformers repo model" + ) + framework_dir_path = Path(framework_file_paths[0]).parent.absolute() + + logger.info( + f"Overwriting argument {name} to downloaded {framework_dir_path}" + ) + + data_class.__dict__[name] = str(framework_dir_path) + + return data_class diff --git a/src/transformers/sparse.py b/src/transformers/sparse.py index 575aa515731b..386bc93de764 100644 --- a/src/transformers/sparse.py +++ b/src/transformers/sparse.py @@ -125,6 +125,28 @@ def qat_active(self, epoch: int): return qat_start < epoch + 1 + def compute_loss(self, model, inputs, return_outputs=False): + """ + Computing loss using teacher/student distillation + """ + if not self.recipes or self.teacher is None: + return super().compute_loss(model, inputs, return_outputs=return_outputs) + student_outputs = model(**inputs) + loss = student_outputs["loss"] + + steps_in_epoch = -1 # Unused + loss = self.manager.loss_update( + loss, + model, + self.optimizer, + self.state.epoch, + steps_in_epoch, + global_step=self.state.global_step, + student_outputs=student_outputs, + teacher_inputs=inputs, + ) + return (loss, student_outputs) if return_outputs else loss + def save_model(self, output_dir: Optional[str] = None): """ Save model during or after training. 
Modifiers that change the model architecture will also be saved.