This repository was archived by the owner on Jun 4, 2025. It is now read-only.
Merged
57 changes: 55 additions & 2 deletions examples/pytorch/token-classification/run_ner.py
@@ -29,17 +29,18 @@
from datasets import ClassLabel, load_dataset, load_metric

import transformers
from sparseml_utils import TokenClassificationModuleExporter
from transformers import (
AutoConfig,
AutoModelForTokenClassification,
AutoTokenizer,
DataCollatorForTokenClassification,
HfArgumentParser,
PreTrainedTokenizerFast,
Trainer,
TrainingArguments,
set_seed,
)
from transformers.sparse import SparseMLTrainer, export_model, load_recipe, preprocess_state_dict
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version

@@ -59,6 +60,9 @@ class ModelArguments:
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
distill_teacher: Optional[str] = field(
default=None, metadata={"help": "Teacher model which needs to be a trained QA model"}
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
@@ -88,6 +92,19 @@ class DataTrainingArguments:
Arguments pertaining to what data we are going to input our model for training and eval.
"""

recipe: Optional[str] = field(
default=None,
metadata={
"help": "Path to a SparseML sparsification recipe, see https://github.com/neuralmagic/sparseml "
"for more information"
},
)
onnx_export_path: Optional[str] = field(
default=None, metadata={"help": "The filename and path which will be where onnx model is outputed"}
)
num_exported_samples: Optional[int] = field(
default=20, metadata={"help": "Number of samples to export, defaults to 20"}
)
task_name: Optional[str] = field(default="ner", metadata={"help": "The name of the task (ner, pos...)."})
dataset_name: Optional[str] = field(
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
@@ -278,6 +295,12 @@ def get_label_list(labels):
# Distributed training:
# The .from_pretrained methods guarantee that only one local process can concurrently
# download model & vocab.

# Load and preprocess the state dict if a trained model already exists (in this case we continue
# training or evaluate the model). The preprocessing step restores the names of parameters
# changed by the QAT process.
state_dict = preprocess_state_dict(model_args.model_name_or_path)

config = AutoConfig.from_pretrained(
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
num_labels=num_labels,
@@ -300,8 +323,20 @@ def get_label_list(labels):
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
state_dict=state_dict,
)

teacher_model = None
if model_args.distill_teacher is not None:
teacher_model = AutoModelForTokenClassification.from_pretrained(
model_args.distill_teacher,
from_tf=bool(".ckpt" in model_args.distill_teacher),
cache_dir=model_args.cache_dir,
)
teacher_model_parameters = filter(lambda p: p.requires_grad, teacher_model.parameters())
params = sum([np.prod(p.size()) for p in teacher_model_parameters])
logger.info("Teacher Model has %s parameters", params)

# Tokenizer check: this script requires a fast tokenizer.
if not isinstance(tokenizer, PreTrainedTokenizerFast):
raise ValueError(
@@ -424,8 +459,15 @@ def compute_metrics(p):
"accuracy": results["overall_accuracy"],
}

# Load a possible existing recipe and the new one passed in through the command-line argument
existing_recipe = load_recipe(model_args.model_name_or_path)
new_recipe = data_args.recipe

# Initialize our Trainer
trainer = Trainer(
trainer = SparseMLTrainer(
model_args.model_name_or_path,
[existing_recipe, new_recipe],
teacher=teacher_model,
model=model,
args=training_args,
train_dataset=train_dataset if training_args.do_train else None,
@@ -435,6 +477,11 @@ def compute_metrics(p):
compute_metrics=compute_metrics,
)

# Apply recipes to the model. This is necessary because sparsification methods such as QAT
# modify the model graph with their own learnable parameters, which also need to be
# restored/loaded into the model.
trainer.apply_recipes()

# Training
if training_args.do_train:
checkpoint = None
@@ -502,6 +549,12 @@ def compute_metrics(p):

trainer.push_to_hub(**kwargs)

if data_args.onnx_export_path:
logger.info("*** Export to ONNX ***")
eval_dataloader = trainer.get_eval_dataloader(eval_dataset)
exporter = TokenClassificationModuleExporter(model, output_dir=data_args.onnx_export_path)
export_model(exporter, eval_dataloader, data_args.onnx_export_path, data_args.num_exported_samples)


def _mp_fn(index):
# For xla_spawn (TPUs)
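For reference, once the script has written an ONNX model to --onnx_export_path, a quick sanity check with onnxruntime could look like the sketch below. It is only a sketch: the file name model.onnx, the checkpoint path, and the assumption that the graph's input names mirror the tokenizer's output names (input_ids, attention_mask, ...) are illustrative, not guaranteed by this change.

# Minimal sketch (not part of this PR): run the exported ONNX model with onnxruntime.
# Assumes the exporter wrote "model.onnx" into the export directory and that the graph
# inputs are named like the tokenizer outputs; check session.get_inputs() if they differ.
import numpy as np
import onnxruntime
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path/to/finetuned-checkpoint")  # hypothetical path
session = onnxruntime.InferenceSession("onnx_export/model.onnx")  # hypothetical path

encoded = tokenizer("HuggingFace is based in New York City", return_tensors="np")
feed = {inp.name: encoded[inp.name] for inp in session.get_inputs() if inp.name in encoded}
logits = session.run(None, feed)[0]  # (batch, sequence_length, num_labels)
print(np.argmax(logits, axis=-1))  # predicted label ids per token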
21 changes: 21 additions & 0 deletions examples/pytorch/token-classification/sparseml_utils.py
@@ -0,0 +1,21 @@
from typing import Any

import numpy

from sparseml.pytorch.utils import ModuleExporter
from transformers.modeling_outputs import TokenClassifierOutput


class TokenClassificationModuleExporter(ModuleExporter):
"""
Module exporter class for token classification
"""

@classmethod
def get_output_names(cls, out: Any):
if not isinstance(out, TokenClassifierOutput):
raise ValueError(f"Expected TokenClassifierOutput, got {type(out)}")
expected = ["logits"]
if numpy.any([name not in out for name in expected]):
raise ValueError("Expected output names not found in model output")
return expected
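To illustrate the contract get_output_names enforces, here is a small usage sketch (not part of the PR): it builds a dummy TokenClassifierOutput and asks the exporter class for the ONNX output names.

# Sketch: get_output_names validates the model output type and returns the ONNX output names.
import torch
from transformers.modeling_outputs import TokenClassifierOutput

from sparseml_utils import TokenClassificationModuleExporter

dummy = TokenClassifierOutput(logits=torch.randn(1, 16, 9))  # (batch, seq_len, num_labels)
print(TokenClassificationModuleExporter.get_output_names(dummy))  # -> ["logits"]
# Passing anything other than a TokenClassifierOutput raises a ValueError.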
22 changes: 22 additions & 0 deletions src/transformers/sparse.py
@@ -125,6 +125,28 @@ def qat_active(self, epoch: int):

return qat_start < epoch + 1

def compute_loss(self, model, inputs, return_outputs=False):
"""
Compute the loss using teacher/student distillation
"""
if not self.recipes or self.teacher is None:
return super().compute_loss(model, inputs, return_outputs=return_outputs)
student_outputs = model(**inputs)
loss = student_outputs["loss"]

steps_in_epoch = -1 # Unused
loss = self.manager.loss_update(
loss,
model,
self.optimizer,
self.state.epoch,
steps_in_epoch,
global_step=self.state.global_step,
student_outputs=student_outputs,
teacher_inputs=inputs,
)
return (loss, student_outputs) if return_outputs else loss

def save_model(self, output_dir: Optional[str] = None):
"""
Save model during or after training. Modifiers that change the model architecture will also be saved.
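The distillation arithmetic itself is delegated to self.manager.loss_update, i.e. to whatever distillation modifier the SparseML recipe defines (the teacher forward pass is presumably driven from teacher_inputs). As a rough illustration only, a standard knowledge-distillation loss of the kind such a modifier typically implements combines the student's hard-label loss with a temperature-softened KL term against the teacher's logits. The sketch below is not SparseML's implementation, and the temperature/hardness hyperparameters are assumptions.

# Illustrative sketch of a standard knowledge-distillation loss (not SparseML's implementation).
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, hard_loss, temperature=2.0, hardness=0.5):
    # Softened distributions; the T**2 rescaling keeps soft-loss gradients comparable to the hard loss.
    soft_student = F.log_softmax(student_logits / temperature, dim=-1)
    soft_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    soft_loss = F.kl_div(soft_student, soft_teacher, reduction="batchmean") * temperature ** 2
    return hardness * hard_loss + (1.0 - hardness) * soft_loss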