90 changes: 85 additions & 5 deletions examples/pytorch/language-modeling/run_mlm.py
@@ -28,9 +28,11 @@
from dataclasses import dataclass, field
from typing import Optional

from datasets import load_dataset
import numpy
from datasets import concatenate_datasets, load_dataset

import transformers
from sparseml_utils import MaskedLanguageModelingModuleExporter, SparseMLMaskedLanguageModelingTrainer
from transformers import (
CONFIG_MAPPING,
MODEL_FOR_MASKED_LM_MAPPING,
@@ -39,10 +41,10 @@
AutoTokenizer,
DataCollatorForLanguageModeling,
HfArgumentParser,
Trainer,
TrainingArguments,
set_seed,
)
from transformers.sparse import export_model, load_recipe
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version

@@ -72,6 +74,9 @@ class ModelArguments:
default=None,
metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
)
distill_teacher: Optional[str] = field(
default=None, metadata={"help": "Teacher model which needs to be a trained QA model"}
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
@@ -105,12 +110,34 @@ class DataTrainingArguments:
Arguments pertaining to what data we are going to input our model for training and eval.
"""

recipe: Optional[str] = field(
default=None,
metadata={
"help": "Path to a SparseML sparsification recipe, see https://github.com/neuralmagic/sparseml "
"for more information"
},
)
onnx_export_path: Optional[str] = field(
default=None, metadata={"help": "The filename and path which will be where onnx model is outputed"}
)
num_exported_samples: Optional[int] = field(
default=20, metadata={"help": "Number of samples to export along with the ONNX model, defaults to 20"}
)
dataset_name: Optional[str] = field(
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
)
dataset_config_name: Optional[str] = field(
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
)

# An optional second dataset to concatenate with the first one
dataset_name_2: Optional[str] = field(
default=None, metadata={"help": "The name of an optional second dataset to use (via the datasets library)."}
)
dataset_config_name_2: Optional[str] = field(
default=None, metadata={"help": "The configuration name of the second dataset to use (via the datasets library)."}
)

train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
validation_file: Optional[str] = field(
default=None,
@@ -266,6 +293,30 @@ def main():
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading_datasets.html.

# Load the extra dataset if specified, and concatenate it with the original one
if data_args.dataset_name_2 is not None:
# Downloading and loading a dataset from the hub.
datasets_2 = load_dataset(
data_args.dataset_name_2, data_args.dataset_config_name_2, cache_dir=model_args.cache_dir
)
if "validation" not in datasets_2.keys():
datasets_2["validation"] = load_dataset(
data_args.dataset_name_2,
data_args.dataset_config_name_2,
split=f"train[:{data_args.validation_split_percentage}%]",
cache_dir=model_args.cache_dir,
)
datasets_2["train"] = load_dataset(
data_args.dataset_name_2,
data_args.dataset_config_name_2,
split=f"train[{data_args.validation_split_percentage}%:]",
cache_dir=model_args.cache_dir,
)
# Concatenate two datasets
if datasets is not None:
for split in ["validation", "train"]:
datasets[split] = concatenate_datasets([datasets[split], datasets_2[split]])

# Load pretrained model and tokenizer
#
# Distributed training:
@@ -299,7 +350,6 @@ def main():
"You are instantiating a new tokenizer from scratch. This is not supported by this script."
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
)

if model_args.model_name_or_path:
model = AutoModelForMaskedLM.from_pretrained(
model_args.model_name_or_path,
@@ -315,6 +365,17 @@

model.resize_token_embeddings(len(tokenizer))

teacher_model = None
if model_args.distill_teacher is not None:
teacher_model = AutoModelForMaskedLM.from_pretrained(
model_args.distill_teacher,
from_tf=bool(".ckpt" in model_args.distill_teacher),
cache_dir=model_args.cache_dir,
)
teacher_model_parameters = filter(lambda p: p.requires_grad, teacher_model.parameters())
params = sum([numpy.prod(p.size()) for p in teacher_model_parameters])
logger.info("Teacher Model has %s parameters", params)

# Preprocessing the datasets.
# First we tokenize all the texts.
if training_args.do_train:
@@ -400,7 +461,6 @@ def group_texts(examples):
#
# To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map

tokenized_datasets = tokenized_datasets.map(
group_texts,
batched=True,
@@ -431,16 +491,30 @@ def group_texts(examples):
pad_to_multiple_of=8 if pad_to_multiple_of_8 else None,
)

# Load a possible existing recipe stored with the model, plus the new one passed in via the command-line argument
existing_recipe = load_recipe(model_args.model_name_or_path)
new_recipe = data_args.recipe

compute_metrics = None
# Initialize our Trainer
trainer = Trainer(
trainer = SparseMLMaskedLanguageModelingTrainer(
model_args.model_name_or_path,
[existing_recipe, new_recipe],
teacher=teacher_model,
model=model,
args=training_args,
train_dataset=train_dataset if training_args.do_train else None,
eval_dataset=eval_dataset if training_args.do_eval else None,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics,
)

# Apply recipes to the model. This is necessary because sparsification
# methods such as QAT modify the model graph with their own learnable
# parameters, which also need to be restored/loaded into the model.
trainer.apply_recipes()

# Training
if training_args.do_train:
checkpoint = None
@@ -490,6 +564,12 @@ def group_texts(examples):

trainer.push_to_hub(**kwargs)

if data_args.onnx_export_path:
logger.info("*** Export to ONNX ***")
eval_dataloader = trainer.get_eval_dataloader(eval_dataset)
exporter = MaskedLanguageModelingModuleExporter(model, output_dir=data_args.onnx_export_path)
export_model(exporter, eval_dataloader, data_args.onnx_export_path, data_args.num_exported_samples)


def _mp_fn(index):
# For xla_spawn (TPUs)
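As context (this sketch is not part of the diff): the options added above (--recipe, --distill_teacher, --onnx_export_path, --num_exported_samples, --dataset_name_2, --dataset_config_name_2) are ordinary dataclass fields that HfArgumentParser exposes as command-line flags. A minimal, self-contained sketch of that pattern follows; the SparsificationArguments dataclass and the example flag values are illustrative stand-ins, not the script's real ModelArguments/DataTrainingArguments.

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class SparsificationArguments:
    # Illustrative stand-in for the fields added to DataTrainingArguments above.
    recipe: Optional[str] = field(default=None, metadata={"help": "Path to a SparseML sparsification recipe"})
    onnx_export_path: Optional[str] = field(default=None, metadata={"help": "Where to write the exported ONNX model"})
    num_exported_samples: int = field(default=20, metadata={"help": "Number of sample inputs/outputs to export"})


if __name__ == "__main__":
    # e.g. python this_sketch.py --recipe recipes/prune90.yaml --onnx_export_path ./onnx_export
    parser = HfArgumentParser(SparsificationArguments)
    (sparse_args,) = parser.parse_args_into_dataclasses()
    print(sparse_args)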
55 changes: 55 additions & 0 deletions examples/pytorch/language-modeling/sparseml_utils.py
@@ -0,0 +1,55 @@
from typing import Any

import numpy

from sparseml.pytorch.utils import ModuleExporter
from transformers.sparse import SparseMLTrainer


class SparseMLMaskedLanguageModelingTrainer(SparseMLTrainer):
"""
Masked language model trainer with SparseML integration

:param recipe: recipe for model sparsification
:param teacher: teacher model for distillation
:param distill_hardness: fraction of the total loss taken from the teacher (distillation) term, between 0 and 1
:param distill_temperature: softmax temperature used for distillation
:param args, kwargs: arguments passed through to the parent class
"""

def compute_loss(self, model, inputs, return_outputs=False):
"""
Compute the loss, using teacher/student distillation when a recipe and a teacher model are set
"""
if not self.recipes or self.teacher is None:
return super().compute_loss(model, inputs, return_outputs=return_outputs)
student_outputs = model(**inputs)
loss = student_outputs["loss"]

steps_in_epoch = -1 # Unused
loss = self.manager.loss_update(
loss,
model,
self.optimizer,
self.state.epoch,
steps_in_epoch,
global_step=self.state.global_step,
student_outputs=student_outputs,
teacher_inputs=inputs,
)
return (loss, student_outputs) if return_outputs else loss


class MaskedLanguageModelingModuleExporter(ModuleExporter):
"""
Module exporter class for Masked Language Modeling
"""

@classmethod
def get_output_names(cls, out: Any):
expected = ["logits"]
if numpy.any([name for name in expected if name not in out]):
raise ValueError(f"Expected output names {expected} not found in model output")
return expected
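For readers unfamiliar with the two distillation knobs documented in the trainer's docstring above: in this PR the actual blending is delegated to SparseML's manager.loss_update, but the conventional soft-target formulation those knobs control looks roughly like the sketch below. It is illustrative only, not the library's implementation; distillation_loss and its default values are made up for this example.

import torch.nn.functional as F


def distillation_loss(student_logits, teacher_logits, student_loss, hardness=0.5, temperature=2.0):
    # Soften both distributions with the temperature, compare them with KL divergence,
    # and rescale by T^2 so gradient magnitudes stay comparable across temperatures.
    soft_student = F.log_softmax(student_logits / temperature, dim=-1)
    soft_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    distill_term = F.kl_div(soft_student, soft_teacher, reduction="batchmean") * (temperature ** 2)
    # hardness weights the teacher (soft-target) term against the ordinary hard-label MLM loss.
    return hardness * distill_term + (1.0 - hardness) * student_loss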
1 change: 1 addition & 0 deletions src/transformers/sparse.py
@@ -157,6 +157,7 @@ def export_model(exporter, dataloader, output_dir, num_exported_samples):
os.makedirs(sample_outputs, exist_ok=True)

for _, sample_batch in enumerate(dataloader):
sample_batch.pop("labels", None)
if sess is None:
forward_args_spec = inspect.getfullargspec(exporter._module.__class__.forward)
one_sample_input = collections.OrderedDict(
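A closing note on the one-line change to export_model above: "labels" is popped from each batch because labels only drive loss computation and are not an inference-time input, while the exporter builds its ordered inputs from the module's forward signature (visible in the surrounding code). A small, self-contained sketch of that filtering pattern follows; the helper name batch_to_forward_inputs is made up for illustration and is not part of the library.

import collections
import inspect


def batch_to_forward_inputs(module, batch):
    # Keep only the keys that the module's forward() actually accepts, in signature order,
    # so extra entries such as "labels" never reach the ONNX export trace.
    forward_args = inspect.getfullargspec(type(module).forward).args[1:]  # drop `self`
    return collections.OrderedDict((key, batch[key]) for key in forward_args if key in batch)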