Skip to content
This repository was archived by the owner on Jun 4, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion examples/pytorch/text-classification/run_glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ class DataTrainingArguments:
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
pad_to_max_length: bool = field(
default=True,
metadata={
Expand Down Expand Up @@ -423,7 +427,12 @@ def preprocess_function(examples):
result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]]
return result

datasets = datasets.map(preprocess_function, batched=True, load_from_cache_file=not data_args.overwrite_cache)
datasets = datasets.map(
preprocess_function,
batched=True,
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=not data_args.overwrite_cache,
)
if training_args.do_train:
if "train" not in datasets:
raise ValueError("--do_train requires a train dataset")
Expand Down
57 changes: 55 additions & 2 deletions examples/pytorch/token-classification/run_ner.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,18 @@
from datasets import ClassLabel, load_dataset, load_metric

import transformers
from sparseml_utils import TokenClassificationModuleExporter
from transformers import (
AutoConfig,
AutoModelForTokenClassification,
AutoTokenizer,
DataCollatorForTokenClassification,
HfArgumentParser,
PreTrainedTokenizerFast,
Trainer,
TrainingArguments,
set_seed,
)
from transformers.sparse import SparseMLTrainer, export_model, load_recipe, preprocess_state_dict
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version

Expand All @@ -59,6 +60,9 @@ class ModelArguments:
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
distill_teacher: Optional[str] = field(
default=None, metadata={"help": "Teacher model which needs to be a trained QA model"}

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should say NER model

)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
Expand Down Expand Up @@ -88,6 +92,19 @@ class DataTrainingArguments:
Arguments pertaining to what data we are going to input our model for training and eval.
"""

recipe: Optional[str] = field(
default=None,
metadata={
"help": "Path to a SparseML sparsification recipe, see https://github.com/neuralmagic/sparseml "
"for more information"
},
)
onnx_export_path: Optional[str] = field(
default=None, metadata={"help": "The filename and path which will be where onnx model is outputed"}
)
num_exported_samples: Optional[int] = field(
default=20, metadata={"help": "Number of exported samples, default to 20"}
)
task_name: Optional[str] = field(default="ner", metadata={"help": "The name of the task (ner, pos...)."})
dataset_name: Optional[str] = field(
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
Expand Down Expand Up @@ -278,6 +295,12 @@ def get_label_list(labels):
# Distributed training:
# The .from_pretrained methods guarantee that only one local process can concurrently
# download model & vocab.

# Load and preprocess the state dict if the model already exists (in which case we continue to
# train or evaluate the model). The preprocessing step restores the names of parameters that
# were changed by the QAT process.
state_dict = preprocess_state_dict(model_args.model_name_or_path)

config = AutoConfig.from_pretrained(
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
num_labels=num_labels,
Expand All @@ -300,8 +323,20 @@ def get_label_list(labels):
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
state_dict=state_dict,
)

teacher_model = None
if model_args.distill_teacher is not None:
teacher_model = AutoModelForTokenClassification.from_pretrained(
model_args.distill_teacher,
from_tf=bool(".ckpt" in model_args.distill_teacher),
cache_dir=model_args.cache_dir,
)
teacher_model_parameters = filter(lambda p: p.requires_grad, teacher_model.parameters())
params = sum([np.prod(p.size()) for p in teacher_model_parameters])
logger.info("Teacher Model has %s parameters", params)

# Tokenizer check: this script requires a fast tokenizer.
if not isinstance(tokenizer, PreTrainedTokenizerFast):
raise ValueError(
Expand Down Expand Up @@ -424,8 +459,15 @@ def compute_metrics(p):
"accuracy": results["overall_accuracy"],
}

# Load possible existing recipe and new one passed in through command argument
existing_recipe = load_recipe(model_args.model_name_or_path)
new_recipe = data_args.recipe

# Initialize our Trainer
trainer = Trainer(
trainer = SparseMLTrainer(
model_args.model_name_or_path,
[existing_recipe, new_recipe],
teacher=teacher_model,
model=model,
args=training_args,
train_dataset=train_dataset if training_args.do_train else None,
Expand All @@ -435,6 +477,11 @@ def compute_metrics(p):
compute_metrics=compute_metrics,
)

# Apply recipes to the model. This is necessary because sparsification methods such as QAT
# modify the model graph with their own learnable parameters. They are also restored/loaded
# into the model.
trainer.apply_recipes()

# Training
if training_args.do_train:
checkpoint = None
Expand Down Expand Up @@ -502,6 +549,12 @@ def compute_metrics(p):

trainer.push_to_hub(**kwargs)

if data_args.onnx_export_path:
logger.info("*** Export to ONNX ***")
eval_dataloader = trainer.get_eval_dataloader(eval_dataset)
exporter = TokenClassificationModuleExporter(model, output_dir=data_args.onnx_export_path)
export_model(exporter, eval_dataloader, data_args.onnx_export_path, data_args.num_exported_samples)


def _mp_fn(index):
# For xla_spawn (TPUs)
Expand Down
21 changes: 21 additions & 0 deletions examples/pytorch/token-classification/sparseml_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from typing import Any

import numpy

from sparseml.pytorch.utils import ModuleExporter
from transformers.modeling_outputs import TokenClassifierOutput


class TokenClassificationModuleExporter(ModuleExporter):
    """
    Module exporter class for Token Classification (e.g. named entity recognition)
    """

    @classmethod
    def get_output_names(cls, out: Any):
        """
        Validate a model output and return the output names expected for export.

        :param out: the model output to check; must be a ``TokenClassifierOutput``
        :return: the list of expected output names, ``["logits"]``
        :raises ValueError: if ``out`` is not a ``TokenClassifierOutput`` or does not
            contain every expected output name
        """
        if not isinstance(out, TokenClassifierOutput):
            raise ValueError(f"Expected TokenClassifierOutput, got {type(out)}")
        expected = ["logits"]
        # Plain membership test with the builtin any(): reducing a list of name strings
        # with numpy.any() relies on string-array truthiness, which may raise a TypeError
        # instead of the intended ValueError depending on the NumPy version.
        if any(name not in out for name in expected):
            raise ValueError("Expected output names not found in model output")
        return expected
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@
"black==21.4b0",
"cookiecutter==1.7.2",
"dataclasses",
"datasets<1.13.0",
"datasets",
"deepspeed>=0.3.16",
"docutils==0.16.0",
"fairscale>0.3",
Expand All @@ -99,7 +99,7 @@
"flake8>=3.8.3",
"flax>=0.3.2",
"fugashi>=1.0",
"huggingface-hub==0.0.8",
"huggingface-hub",
"importlib_metadata",
"ipadic>=1.0.0,<2.0",
"isort>=5.5.4",
Expand Down
58 changes: 54 additions & 4 deletions src/transformers/hf_argparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,22 @@

import dataclasses
import json
import os
import re
import sys
from argparse import ArgumentParser, ArgumentTypeError
from enum import Enum
from pathlib import Path
from typing import Any, Iterable, List, NewType, Optional, Tuple, Union

from .utils.logging import get_logger

from sparsezoo import Zoo
from sparsezoo.requests.base import ZOO_STUB_PREFIX


logger = get_logger(__name__)


DataClass = NewType("DataClass", Any)
DataClassType = NewType("DataClassType", Any)
Expand Down Expand Up @@ -190,12 +199,17 @@ def parse_args_into_dataclasses(
# additional namespace.
outputs.append(namespace)
if return_remaining_strings:
return (*outputs, remaining_args)
return tuple(
*[_download_dataclass_zoo_stub_files(output) for output in outputs],
remaining_args,
)
else:
if remaining_args:
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {remaining_args}")

return (*outputs,)
return tuple(
[_download_dataclass_zoo_stub_files(output) for output in outputs]
)

def parse_json_file(self, json_file: str) -> Tuple[DataClass, ...]:
"""
Expand All @@ -209,7 +223,9 @@ def parse_json_file(self, json_file: str) -> Tuple[DataClass, ...]:
inputs = {k: v for k, v in data.items() if k in keys}
obj = dtype(**inputs)
outputs.append(obj)
return (*outputs,)
return tuple(
[_download_dataclass_zoo_stub_files(output) for output in outputs]
)

def parse_dict(self, args: dict) -> Tuple[DataClass, ...]:
"""
Expand All @@ -222,4 +238,38 @@ def parse_dict(self, args: dict) -> Tuple[DataClass, ...]:
inputs = {k: v for k, v in args.items() if k in keys}
obj = dtype(**inputs)
outputs.append(obj)
return (*outputs,)
return tuple(
[_download_dataclass_zoo_stub_files(output) for output in outputs]
)


def _download_dataclass_zoo_stub_files(data_class: DataClass):
for name, val in data_class.__dict__.items():
if not isinstance(val, str) or "recipe" in name or not val.startswith("zoo:"):
continue

logger.info(f"Downloading framework files for SparseZoo stub: {val}")

zoo_model = Zoo.load_model_from_stub(val)
framework_file_paths = zoo_model.download_framework_files()
assert framework_file_paths, (
"Unable to download any framework files for SparseZoo stub {val}"
)
framework_file_names = [os.path.basename(path) for path in framework_file_paths]
if "pytorch_model.bin" not in framework_file_names or (
"config.json" not in framework_file_names
):
raise RuntimeError(
"Unable to find 'pytorch_model.bin' and 'config.json' in framework "
f"files downloaded from {val}. Found {framework_file_names}. Check "
"if the given stub is for a transformers repo model"
)
framework_dir_path = Path(framework_file_paths[0]).parent.absolute()

logger.info(
f"Overwriting argument {name} to downloaded {framework_dir_path}"
)

data_class.__dict__[name] = str(framework_dir_path)

return data_class
22 changes: 22 additions & 0 deletions src/transformers/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,28 @@ def qat_active(self, epoch: int):

return qat_start < epoch + 1

def compute_loss(self, model, inputs, return_outputs=False):
    """
    Compute the loss, routing it through teacher/student distillation when configured.

    When there are no recipes or no teacher, the parent implementation is used unchanged.
    Otherwise the student model is run on ``inputs``, its ``"loss"`` entry is handed to
    ``self.manager.loss_update`` together with the student outputs and the same inputs
    (as ``teacher_inputs``), and the value returned by that call becomes the loss.

    :param model: the student model, called as ``model(**inputs)``
    :param inputs: mapping of keyword arguments for the model
    :param return_outputs: if true, return ``(loss, student_outputs)`` instead of the loss
    :return: the loss, or a ``(loss, student_outputs)`` tuple when ``return_outputs`` is set
    """
    if self.recipes and self.teacher is not None:
        outputs = model(**inputs)
        student_loss = outputs["loss"]
        updated_loss = self.manager.loss_update(
            student_loss,
            model,
            self.optimizer,
            self.state.epoch,
            -1,  # steps_in_epoch placeholder; marked as unused
            global_step=self.state.global_step,
            student_outputs=outputs,
            teacher_inputs=inputs,
        )
        if return_outputs:
            return (updated_loss, outputs)
        return updated_loss

    # No recipes or no teacher: defer entirely to the parent implementation.
    return super().compute_loss(model, inputs, return_outputs=return_outputs)

def save_model(self, output_dir: Optional[str] = None):
"""
Save model during or after training. Modifiers that change the model architecture will also be saved.
Expand Down