This repository was archived by the owner on Jun 4, 2025. It is now read-only.
Merged
26 changes: 24 additions & 2 deletions examples/pytorch/question-answering/run_qa.py
@@ -28,7 +28,12 @@
from datasets import load_dataset, load_metric

import transformers
-from sparseml_utils import SparseMLQATrainer, export_model
+from sparseml_utils import (
+    SparseMLQATrainer,
+    export_model,
+    preprocess_state_dict,
+    load_recipe
+)
from transformers import (
    AutoConfig,
    AutoModelForQuestionAnswering,
@@ -311,13 +316,20 @@ def main():
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
+
+    # Load and preprocess the state dict if a trained model already exists (i.e., we are
+    # continuing to train, or evaluating, the model). The preprocessing step restores the
+    # names of parameters that were changed by the QAT process.
+    state_dict = preprocess_state_dict(model_args.model_name_or_path)
+
    model = AutoModelForQuestionAnswering.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
+        state_dict=state_dict
    )

    teacher_model = None
@@ -573,9 +585,14 @@ def post_processing_function(examples, features, predictions, stage="eval"):
    def compute_metrics(p: EvalPrediction):
        return metric.compute(predictions=p.predictions, references=p.label_ids)

+    # Load a possibly existing recipe as well as the new one passed in through the command-line argument
+    existing_recipe = load_recipe(model_args.model_name_or_path)

[Review thread on the line above]
Reviewer: is there possible overlap between the recipes?
Author: Possible, but it won't break the current use cases (evaluating and training dense, pruned, and quantized BERT). IMO, this is better addressed in the manager-level code.

+    new_recipe = data_args.recipe
+
    # Initialize our Trainer
    trainer = SparseMLQATrainer(
-        data_args.recipe,
+        model_args.model_name_or_path,
+        [existing_recipe, new_recipe],
        teacher=teacher_model,
        distill_hardness=model_args.distill_hardness,
        distill_temperature=model_args.distill_temperature,
@@ -590,6 +607,11 @@ def compute_metrics(p: EvalPrediction):
        compute_metrics=compute_metrics,
    )

+    # Apply the recipes to the model. This is necessary because sparsification methods
+    # such as QAT modify the model graph with their own learnable parameters, which
+    # are also restored/loaded into the model here.
+    trainer.apply_recipes()
+
    # Training
    if training_args.do_train:
        checkpoint = None
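Note: taken together, the changes to run_qa.py wire up the following flow. This is a condensed sketch for orientation only; the trailing Trainer keyword arguments are abridged assumptions based on the surrounding script, not the literal call from the file.

state_dict = preprocess_state_dict(model_args.model_name_or_path)  # undo QAT renames, if any
model = AutoModelForQuestionAnswering.from_pretrained(
    model_args.model_name_or_path, config=config, state_dict=state_dict
)

existing_recipe = load_recipe(model_args.model_name_or_path)  # recipe.yaml saved with the model, if present
new_recipe = data_args.recipe  # recipe passed on the command line

trainer = SparseMLQATrainer(
    model_args.model_name_or_path,
    [existing_recipe, new_recipe],
    teacher=teacher_model,
    model=model,  # remaining QuestionAnsweringTrainer kwargs omitted
    args=training_args,
)
trainer.apply_recipes()  # re-apply graph modifications (e.g., QAT) and restore their parameters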
114 changes: 104 additions & 10 deletions examples/pytorch/question-answering/sparseml_utils.py
@@ -1,8 +1,8 @@
-import inspect
import collections
+import inspect
import math
import os
-from typing import Any
+from typing import Any, Optional

import numpy
import torch
@@ -13,6 +13,7 @@
from sparseml.pytorch.optim.optimizer import ScheduledOptimizer
from sparseml.pytorch.utils import ModuleExporter, logger
from trainer_qa import QuestionAnsweringTrainer
+from transformers.file_utils import RECIPE_NAME, WEIGHTS_NAME
from transformers.modeling_outputs import QuestionAnsweringModelOutput
from transformers.models.bert.modeling_bert import BertForQuestionAnswering

@@ -28,36 +29,74 @@ class SparseMLQATrainer(QuestionAnsweringTrainer):
    :param args, kwargs: arguments passed into parent class
    """

-    def __init__(self, recipe, teacher=None, distill_hardness=0.5, distill_temperature=2.0, *args, **kwargs):
+    def __init__(
+        self, model_name_or_path, recipes, teacher=None, distill_hardness=0.5, distill_temperature=2.0, *args, **kwargs
+    ):
        super().__init__(*args, **kwargs)
-        self.recipe = recipe
+        self.model_name_or_path = str(model_name_or_path)
+        self.recipes = [recipe for recipe in recipes if recipe]
        self.teacher = teacher
        self.distill_hardness = distill_hardness
        self.distill_temperature = distill_temperature
        self.criterion = torch.nn.CrossEntropyLoss()

-        self.manager = None
+        manager = None
+        modifiers = []
+        for recipe in self.recipes:
+            manager = ScheduledModifierManager.from_yaml(recipe, modifiers)
+            modifiers = manager.modifiers
+        self.manager = manager

[Review thread on the loop above]
Reviewer: is there a good reason you need to pass in an empty list for modifiers just to then set it to manager.modifiers? Could it not be manager = ScheduledModifierManager.from_yaml(recipe, None), and also remove line 44?
Author: manager.modifiers holds the list of modifiers accumulated from the recipes "consumed" so far. We need to pass that list back in when constructing the manager with the next recipe.
Reviewer (on lines +45 to +47): nice

        self.loggers = None
-        if self.recipe is not None:
+        if self.recipes is not None:
            loggers = []
            if "wandb" in self.args.report_to:
                loggers.append(logger.WANDBLogger())
            self.loggers = loggers

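Note: as the review thread above discusses, the constructor folds multiple recipes into a single ScheduledModifierManager by carrying the accumulated modifier list forward. A standalone sketch of that accumulation pattern follows, with hypothetical recipe paths; the import path assumes SparseML's public API, which re-exports the manager from sparseml.pytorch.optim.

from sparseml.pytorch.optim import ScheduledModifierManager

recipes = ["pruning_recipe.yaml", "quantization_recipe.yaml"]  # hypothetical files

manager = None
modifiers = []
for recipe in recipes:
    # from_yaml merges the previously accumulated modifiers into the new manager,
    # so the final manager holds the union of all the recipes' modifiers
    manager = ScheduledModifierManager.from_yaml(recipe, modifiers)
    modifiers = manager.modifiers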
+    def apply_recipes(self, epoch=0.0):
+        """
+        Apply the recipes and sparsification-related parameters to the model
+        """
+        if self.manager is not None:
+            org_state_dict = self.model.state_dict()
+            self.manager.initialize(self.model, epoch=epoch, loggers=self.loggers)
+            new_state_dict = self.model.state_dict()
+            new_params = [p for p in new_state_dict.keys() if p not in org_state_dict]
+
+            if os.path.isdir(self.model_name_or_path):
+                if os.path.isfile(os.path.join(self.model_name_or_path, WEIGHTS_NAME)):
+                    archive_file = os.path.join(self.model_name_or_path, WEIGHTS_NAME)
+                    state_dict = torch.load(archive_file, map_location="cpu")
+                    new_params_to_init = [p for p in new_params if p in state_dict.keys()]
+                    if new_params_to_init:
+                        # If we're here, the assumption is that all the new parameters introduced
+                        # by the recipes are available to be restored from the checkpoint; this is
+                        # the case when evaluating pruned or pruned-quantized models. Otherwise,
+                        # we're in a use case such as quantizing a block-pruned model, in which the
+                        # new parameters need to be initialized and trained during the QAT process.
+                        _, missing_keys, unexpected_keys, _ = BertForQuestionAnswering._load_state_dict_into_model(
+                            self.model, state_dict, self.model_name_or_path, _fast_init=False
+                        )
+                        if missing_keys or unexpected_keys:
+                            raise RuntimeError(
+                                "Unexpected or missing keys detected when applying recipes to models\n"
+                                f"Missing keys: {missing_keys}\n"
+                                f"Unexpected keys: {unexpected_keys}\n"
+                            )

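Note: the restore step above relies on a private transformers helper. For reference, the same partial-restore-plus-sanity-check can be sketched with public PyTorch APIs alone; model and state_dict are assumed to be as in apply_recipes.

# load_state_dict(strict=False) returns the missing and unexpected key lists
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
if missing_keys or unexpected_keys:
    raise RuntimeError(f"Missing keys: {missing_keys}\nUnexpected keys: {unexpected_keys}")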
    def create_optimizer(self):
        """
        Create optimizer customized using SparseML
        """
        super().create_optimizer()
-        if self.recipe is None:
+        if not self.recipes:
            return
        steps_per_epoch = math.ceil(
            len(self.train_dataset) / (self.args.per_device_train_batch_size * self.args._n_gpu)
        )
-        self.manager = ScheduledModifierManager.from_yaml(self.recipe)
        self.args.num_train_epochs = float(self.manager.max_epochs)
        if hasattr(self, "scaler"):
-            self.manager.initialize(self.model, epoch=0.0, loggers=self.loggers)
            self.scaler = self.manager.modify(
                self.model, self.optimizer, steps_per_epoch=steps_per_epoch, wrap_optim=self.scaler
            )
@@ -70,7 +109,7 @@ def compute_loss(self, model, inputs, return_outputs=False):
"""
Computing loss using teacher/student distillation
"""
if self.recipe is None or self.teacher is None:
if not self.recipes or self.teacher is None:
return super().compute_loss(model, inputs, return_outputs=return_outputs)

outputs = model(**inputs)
@@ -114,11 +153,25 @@ def compute_loss(self, model, inputs, return_outputs=False):
        loss = ((1 - self.distill_hardness) * label_loss) + (self.distill_hardness * teacher_loss)
        return (loss, outputs) if return_outputs else loss

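Note: the blended loss above weights the ground-truth loss against a teacher loss by distill_hardness. The teacher-loss computation itself sits inside the collapsed lines; as a reference, here is a generic sketch of a temperature-scaled distillation term (standard Hinton-style knowledge distillation, not necessarily the exact implementation elided in this hunk).

import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, temperature=2.0):
    # Soften both distributions with the temperature and match them via KL
    # divergence; the T^2 factor keeps gradient magnitudes comparable across T.
    log_p_student = F.log_softmax(student_logits / temperature, dim=-1)
    p_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    return F.kl_div(log_p_student, p_teacher, reduction="batchmean") * temperature ** 2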
+    def save_model(self, output_dir: Optional[str] = None):
+        """
+        Save model during or after training. The sparsification recipe will also be saved.
+        """
+        super().save_model(output_dir=output_dir)
+        if self.manager is not None:
+            self._save_recipe(output_dir=output_dir)
+
+    def _save_recipe(self, output_dir: Optional[str] = None):
+        output_dir = output_dir if output_dir is not None else self.args.output_dir
+        output_recipe_file = os.path.join(output_dir, RECIPE_NAME)
+        self.manager.save(output_recipe_file)


class QuestionAnsweringModuleExporter(ModuleExporter):
    """
    Module exporter class for Question Answering
    """

+    @classmethod
    def get_output_names(self, out: Any):
        if not isinstance(out, QuestionAnsweringModelOutput):
@@ -173,3 +226,44 @@ def export_model(model, dataloader, output_dir, num_exported_samples):
        num_samples += 1
        if num_samples >= num_exported_samples:
            return


+def preprocess_state_dict(pretrained_model_name_or_path):
+    """
+    Restore the original parameter names that were changed by the QAT process
+    """
+    state_dict = None
+    if pretrained_model_name_or_path is not None:
+        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+        if os.path.isdir(pretrained_model_name_or_path):
+            if os.path.isfile(os.path.join(pretrained_model_name_or_path, RECIPE_NAME)):
+                recipe = os.path.join(pretrained_model_name_or_path, RECIPE_NAME)
+                manager = ScheduledModifierManager.from_yaml(recipe)
+                modifiers = [m.__class__.__name__ for m in manager.modifiers]
+                is_qat_recipe = "QuantizationModifier" in modifiers
+                if os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
+                    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
+                    state_dict = torch.load(archive_file, map_location="cpu")
+                    removed_keys = (
+                        [key for key in state_dict if (key.endswith(".module.weight") or key.endswith(".module.bias"))]
+                        if is_qat_recipe
+                        else []
+                    )
+                    for key in removed_keys:
+                        new_key = key.replace(".module", "")
+                        state_dict[new_key] = state_dict[key]
+                        state_dict.pop(key)
+    return state_dict
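Note: to make the renaming concrete, QAT wrapping typically nests the original module under a .module attribute, which is where the .module.weight / .module.bias key suffixes come from. A small illustration using torch.quantization.QuantWrapper; the wrapper choice is an assumption, picked because it reproduces the same key shape.

import torch
from torch import nn

wrapped = torch.quantization.QuantWrapper(nn.Linear(4, 4))
print(sorted(wrapped.state_dict()))
# ['module.bias', 'module.weight'] -- preprocess_state_dict maps these keys
# back to 'bias' and 'weight' so that from_pretrained can match them again.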


+def load_recipe(pretrained_model_name_or_path):
+    """
+    Load recipe from the model directory
+    """
+    recipe = None
+    if pretrained_model_name_or_path is not None:
+        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+        if os.path.isdir(pretrained_model_name_or_path):
+            if os.path.isfile(os.path.join(pretrained_model_name_or_path, RECIPE_NAME)):
+                recipe = os.path.join(pretrained_model_name_or_path, RECIPE_NAME)
+    return recipe
1 change: 1 addition & 0 deletions src/transformers/file_utils.py
@@ -220,6 +220,7 @@
CONFIG_NAME = "config.json"
FEATURE_EXTRACTOR_NAME = "preprocessor_config.json"
MODEL_CARD_NAME = "modelcard.json"
RECIPE_NAME = "recipe.yaml"

SENTENCEPIECE_UNDERLINE = "▁"
SPIECE_UNDERLINE = SENTENCEPIECE_UNDERLINE # Kept for backward compatibility
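Note: recipe.yaml refers to a SparseML recipe. As context, a minimal sketch of what such a file can look like; the modifiers and hyperparameters below are illustrative assumptions, not values from this PR.

modifiers:
    - !EpochRangeModifier
        start_epoch: 0.0
        end_epoch: 3.0

    - !GMPruningModifier
        params: __ALL_PRUNABLE__
        init_sparsity: 0.05
        final_sparsity: 0.80
        start_epoch: 0.0
        end_epoch: 2.0
        update_frequency: 0.5

    - !QuantizationModifier
        start_epoch: 2.0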