Skip to content
This repository was archived by the owner on Jun 4, 2025. It is now read-only.

Commit 37a75cd

Browse files
committed
Load and save SparseML QAT recipes
1 parent b7df172 commit 37a75cd

File tree

3 files changed

+120
-11
lines changed

3 files changed

+120
-11
lines changed

examples/pytorch/question-answering/run_qa.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,12 @@
2828
from datasets import load_dataset, load_metric
2929

3030
import transformers
31-
from sparseml_utils import SparseMLQATrainer, export_model
31+
from sparseml_utils import (
32+
SparseMLQATrainer,
33+
export_model,
34+
preprocess_state_dict,
35+
load_recipe
36+
)
3237
from transformers import (
3338
AutoConfig,
3439
AutoModelForQuestionAnswering,
@@ -311,13 +316,20 @@ def main():
311316
revision=model_args.model_revision,
312317
use_auth_token=True if model_args.use_auth_token else None,
313318
)
319+
320+
# Load and preprocess the state dict if the model existed (in this case we continue to train or
321+
# evaluate the model). The preprocessing step is to restore names of parameters changed by
322+
# QAT process.
323+
state_dict = preprocess_state_dict(model_args.model_name_or_path)
324+
314325
model = AutoModelForQuestionAnswering.from_pretrained(
315326
model_args.model_name_or_path,
316327
from_tf=bool(".ckpt" in model_args.model_name_or_path),
317328
config=config,
318329
cache_dir=model_args.cache_dir,
319330
revision=model_args.model_revision,
320331
use_auth_token=True if model_args.use_auth_token else None,
332+
state_dict=state_dict
321333
)
322334

323335
teacher_model = None
@@ -573,9 +585,14 @@ def post_processing_function(examples, features, predictions, stage="eval"):
573585
def compute_metrics(p: EvalPrediction):
574586
return metric.compute(predictions=p.predictions, references=p.label_ids)
575587

588+
# Load possible existing recipe and new one passed in through command argument
589+
existing_recipe = load_recipe(model_args.model_name_or_path)
590+
new_recipe = data_args.recipe
591+
576592
# Initialize our Trainer
577593
trainer = SparseMLQATrainer(
578-
data_args.recipe,
594+
model_args.model_name_or_path,
595+
[existing_recipe, new_recipe],
579596
teacher=teacher_model,
580597
distill_hardness=model_args.distill_hardness,
581598
distill_temperature=model_args.distill_temperature,
@@ -590,6 +607,11 @@ def compute_metrics(p: EvalPrediction):
590607
compute_metrics=compute_metrics,
591608
)
592609

610+
# Apply recipes to the model. This is necessary given that
611+
# sparsification methods such as QAT modified the model graph with their own learnable
612+
# parameters. They are also restored/loaded to the model.
613+
trainer.apply_recipes()
614+
593615
# Training
594616
if training_args.do_train:
595617
checkpoint = None

examples/pytorch/question-answering/sparseml_utils.py

Lines changed: 95 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
import collections
33
import math
44
import os
5-
from typing import Any
5+
from typing import Any, Optional
6+
import json
67

78
import numpy
89
import torch
@@ -13,6 +14,8 @@
1314
from sparseml.pytorch.optim.optimizer import ScheduledOptimizer
1415
from sparseml.pytorch.utils import ModuleExporter, logger
1516
from trainer_qa import QuestionAnsweringTrainer
17+
18+
from transformers.file_utils import RECIPE_NAME, WEIGHTS_NAME
1619
from transformers.modeling_outputs import QuestionAnsweringModelOutput
1720
from transformers.models.bert.modeling_bert import BertForQuestionAnswering
1821

@@ -28,36 +31,63 @@ class SparseMLQATrainer(QuestionAnsweringTrainer):
2831
:param args, kwargs: arguments passed into parent class
2932
"""
3033

31-
def __init__(self, recipe, teacher=None, distill_hardness=0.5, distill_temperature=2.0, *args, **kwargs):
34+
def __init__(
35+
self, model_name_or_path, recipes, teacher=None, distill_hardness=0.5, distill_temperature=2.0, *args, **kwargs
36+
):
3237
super().__init__(*args, **kwargs)
33-
self.recipe = recipe
38+
self.model_name_or_path = str(model_name_or_path)
39+
self.recipes = [recipe for recipe in recipes if recipe]
3440
self.teacher = teacher
3541
self.distill_hardness = distill_hardness
3642
self.distill_temperature = distill_temperature
3743
self.criterion = torch.nn.CrossEntropyLoss()
3844

39-
self.manager = None
45+
manager = None
46+
modifiers = []
47+
for recipe in self.recipes:
48+
manager = ScheduledModifierManager.from_yaml(recipe, modifiers)
49+
modifiers = manager.modifiers
50+
self.manager = manager
51+
4052
self.loggers = None
41-
if self.recipe is not None:
53+
if self.recipes is not None:
4254
loggers = []
4355
if "wandb" in self.args.report_to:
4456
loggers.append(logger.WANDBLogger())
4557
self.loggers = loggers
4658

59+
def apply_recipes(self, epoch=0.0):
60+
"""
61+
Apply recipes and sparsification related parameters to the model
62+
"""
63+
if self.manager is not None:
64+
self.manager.initialize(self.model, epoch=epoch, loggers=self.loggers)
65+
if os.path.isdir(self.model_name_or_path):
66+
if os.path.isfile(os.path.join(self.model_name_or_path, WEIGHTS_NAME)):
67+
archive_file = os.path.join(self.model_name_or_path, WEIGHTS_NAME)
68+
state_dict = torch.load(archive_file, map_location="cpu")
69+
_, missing_keys, unexpected_keys, _ = BertForQuestionAnswering._load_state_dict_into_model(
70+
self.model, state_dict, self.model_name_or_path, _fast_init=False
71+
)
72+
if missing_keys or unexpected_keys:
73+
raise RuntimeError(
74+
"Unexpected or missing keys detected when applying recipes to models\n"
75+
f"Missing keys: {missing_keys}\n"
76+
f"Unexpected keys: {unexpected_keys}\n"
77+
)
78+
4779
def create_optimizer(self):
4880
"""
4981
Create optimizer customized using SparseML
5082
"""
5183
super().create_optimizer()
52-
if self.recipe is None:
84+
if not self.recipes:
5385
return
5486
steps_per_epoch = math.ceil(
5587
len(self.train_dataset) / (self.args.per_device_train_batch_size * self.args._n_gpu)
5688
)
57-
self.manager = ScheduledModifierManager.from_yaml(self.recipe)
5889
self.args.num_train_epochs = float(self.manager.max_epochs)
5990
if hasattr(self, "scaler"):
60-
self.manager.initialize(self.model, epoch=0.0, loggers=self.loggers)
6191
self.scaler = self.manager.modify(
6292
self.model, self.optimizer, steps_per_epoch=steps_per_epoch, wrap_optim=self.scaler
6393
)
@@ -70,7 +100,7 @@ def compute_loss(self, model, inputs, return_outputs=False):
70100
"""
71101
Computing loss using teacher/student distillation
72102
"""
73-
if self.recipe is None or self.teacher is None:
103+
if not self.recipes or self.teacher is None:
74104
return super().compute_loss(model, inputs, return_outputs=return_outputs)
75105

76106
outputs = model(**inputs)
@@ -114,6 +144,22 @@ def compute_loss(self, model, inputs, return_outputs=False):
114144
loss = ((1 - self.distill_hardness) * label_loss) + (self.distill_hardness * teacher_loss)
115145
return (loss, outputs) if return_outputs else loss
116146

147+
def save_model(self, output_dir: Optional[str] = None):
148+
"""
149+
Save model during or after training. The sparsification recipe will also be saved.
150+
"""
151+
super().save_model(output_dir=output_dir)
152+
self._save_recipe(output_dir=output_dir)
153+
154+
def _save_recipe(self, output_dir: Optional[str] = None):
155+
if output_dir is None:
156+
output_dir = self.args.output_dir
157+
output_dir = output_dir if output_dir is not None else self.args.output_dir
158+
os.makedirs(output_dir, exist_ok=True)
159+
output_recipe_file = os.path.join(output_dir, RECIPE_NAME)
160+
with open(output_recipe_file, "w") as fp:
161+
json.dump({"recipe": str(self.manager) if self.manager is not None else None}, fp)
162+
117163

118164
class QuestionAnsweringModuleExporter(ModuleExporter):
119165
"""
@@ -173,3 +219,43 @@ def export_model(model, dataloader, output_dir, num_exported_samples):
173219
num_samples += 1
174220
if num_samples >= num_exported_samples:
175221
return
222+
223+
224+
def preprocess_state_dict(pretrained_model_name_or_path):
225+
"""
226+
Restore original parameter names that were changed by QAT process
227+
"""
228+
state_dict = None
229+
if pretrained_model_name_or_path is not None:
230+
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
231+
if os.path.isdir(pretrained_model_name_or_path):
232+
if os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
233+
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
234+
state_dict = torch.load(archive_file, map_location="cpu")
235+
removed_keys = [
236+
key
237+
for key in state_dict
238+
if key.startswith("bert.encoder.layer.")
239+
and (key.endswith(".module.weight") or key.endswith(".module.bias"))
240+
]
241+
for key in removed_keys:
242+
new_key = key.replace(".module", "")
243+
state_dict[new_key] = state_dict[key]
244+
state_dict.pop(key)
245+
return state_dict
246+
247+
248+
def load_recipe(pretrained_model_name_or_path):
249+
"""
250+
Load recipe from the model directory
251+
"""
252+
recipe = None
253+
if pretrained_model_name_or_path is not None:
254+
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
255+
if os.path.isdir(pretrained_model_name_or_path):
256+
if os.path.isfile(os.path.join(pretrained_model_name_or_path, RECIPE_NAME)):
257+
with open(os.path.join(pretrained_model_name_or_path, RECIPE_NAME)) as fp:
258+
recipe = json.load(fp)
259+
recipe = recipe["recipe"]
260+
return recipe
261+

src/transformers/file_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,7 @@
220220
CONFIG_NAME = "config.json"
221221
FEATURE_EXTRACTOR_NAME = "preprocessor_config.json"
222222
MODEL_CARD_NAME = "modelcard.json"
223+
RECIPE_NAME = "recipe.json"
223224

224225
SENTENCEPIECE_UNDERLINE = "▁"
225226
SPIECE_UNDERLINE = SENTENCEPIECE_UNDERLINE # Kept for backward compatibility

0 commit comments

Comments
 (0)