22import collections
33import math
44import os
5- from typing import Any
5+ from typing import Any , Optional
6+ import json
67
78import numpy
89import torch
1314from sparseml .pytorch .optim .optimizer import ScheduledOptimizer
1415from sparseml .pytorch .utils import ModuleExporter , logger
1516from trainer_qa import QuestionAnsweringTrainer
17+
18+ from transformers .file_utils import RECIPE_NAME , WEIGHTS_NAME
1619from transformers .modeling_outputs import QuestionAnsweringModelOutput
1720from transformers .models .bert .modeling_bert import BertForQuestionAnswering
1821
@@ -28,36 +31,63 @@ class SparseMLQATrainer(QuestionAnsweringTrainer):
2831 :param args, kwargs: arguments passed into parent class
2932 """
3033
31- def __init__ (self , recipe , teacher = None , distill_hardness = 0.5 , distill_temperature = 2.0 , * args , ** kwargs ):
34+ def __init__ (
35+ self , model_name_or_path , recipes , teacher = None , distill_hardness = 0.5 , distill_temperature = 2.0 , * args , ** kwargs
36+ ):
3237 super ().__init__ (* args , ** kwargs )
33- self .recipe = recipe
38+ self .model_name_or_path = str (model_name_or_path )
39+ self .recipes = [recipe for recipe in recipes if recipe ]
3440 self .teacher = teacher
3541 self .distill_hardness = distill_hardness
3642 self .distill_temperature = distill_temperature
3743 self .criterion = torch .nn .CrossEntropyLoss ()
3844
39- self .manager = None
45+ manager = None
46+ modifiers = []
47+ for recipe in self .recipes :
48+ manager = ScheduledModifierManager .from_yaml (recipe , modifiers )
49+ modifiers = manager .modifiers
50+ self .manager = manager
51+
4052 self .loggers = None
41- if self .recipe is not None :
53+ if self .recipes is not None :
4254 loggers = []
4355 if "wandb" in self .args .report_to :
4456 loggers .append (logger .WANDBLogger ())
4557 self .loggers = loggers
4658
59+ def apply_recipes (self , epoch = 0.0 ):
60+ """
61+ Apply recipes and sparsification related parameters to the model
62+ """
63+ if self .manager is not None :
64+ self .manager .initialize (self .model , epoch = epoch , loggers = self .loggers )
65+ if os .path .isdir (self .model_name_or_path ):
66+ if os .path .isfile (os .path .join (self .model_name_or_path , WEIGHTS_NAME )):
67+ archive_file = os .path .join (self .model_name_or_path , WEIGHTS_NAME )
68+ state_dict = torch .load (archive_file , map_location = "cpu" )
69+ _ , missing_keys , unexpected_keys , _ = BertForQuestionAnswering ._load_state_dict_into_model (
70+ self .model , state_dict , self .model_name_or_path , _fast_init = False
71+ )
72+ if missing_keys or unexpected_keys :
73+ raise RuntimeError (
74+ "Unexpected or missing keys detected when applying recipes to models\n "
75+ f"Missing keys: { missing_keys } \n "
76+ f"Unexpected keys: { unexpected_keys } \n "
77+ )
78+
4779 def create_optimizer (self ):
4880 """
4981 Create optimizer customized using SparseML
5082 """
5183 super ().create_optimizer ()
52- if self .recipe is None :
84+ if not self .recipes :
5385 return
5486 steps_per_epoch = math .ceil (
5587 len (self .train_dataset ) / (self .args .per_device_train_batch_size * self .args ._n_gpu )
5688 )
57- self .manager = ScheduledModifierManager .from_yaml (self .recipe )
5889 self .args .num_train_epochs = float (self .manager .max_epochs )
5990 if hasattr (self , "scaler" ):
60- self .manager .initialize (self .model , epoch = 0.0 , loggers = self .loggers )
6191 self .scaler = self .manager .modify (
6292 self .model , self .optimizer , steps_per_epoch = steps_per_epoch , wrap_optim = self .scaler
6393 )
@@ -70,7 +100,7 @@ def compute_loss(self, model, inputs, return_outputs=False):
70100 """
71101 Computing loss using teacher/student distillation
72102 """
73- if self .recipe is None or self .teacher is None :
103+ if not self .recipes or self .teacher is None :
74104 return super ().compute_loss (model , inputs , return_outputs = return_outputs )
75105
76106 outputs = model (** inputs )
@@ -114,6 +144,22 @@ def compute_loss(self, model, inputs, return_outputs=False):
114144 loss = ((1 - self .distill_hardness ) * label_loss ) + (self .distill_hardness * teacher_loss )
115145 return (loss , outputs ) if return_outputs else loss
116146
147+ def save_model (self , output_dir : Optional [str ] = None ):
148+ """
149+ Save model during or after training. The sparsification recipe will also be saved.
150+ """
151+ super ().save_model (output_dir = output_dir )
152+ self ._save_recipe (output_dir = output_dir )
153+
154+ def _save_recipe (self , output_dir : Optional [str ] = None ):
155+ if output_dir is None :
156+ output_dir = self .args .output_dir
157+ output_dir = output_dir if output_dir is not None else self .args .output_dir
158+ os .makedirs (output_dir , exist_ok = True )
159+ output_recipe_file = os .path .join (output_dir , RECIPE_NAME )
160+ with open (output_recipe_file , "w" ) as fp :
161+ json .dump ({"recipe" : str (self .manager ) if self .manager is not None else None }, fp )
162+
117163
118164class QuestionAnsweringModuleExporter (ModuleExporter ):
119165 """
@@ -173,3 +219,43 @@ def export_model(model, dataloader, output_dir, num_exported_samples):
173219 num_samples += 1
174220 if num_samples >= num_exported_samples :
175221 return
222+
223+
224+ def preprocess_state_dict (pretrained_model_name_or_path ):
225+ """
226+ Restore original parameter names that were changed by QAT process
227+ """
228+ state_dict = None
229+ if pretrained_model_name_or_path is not None :
230+ pretrained_model_name_or_path = str (pretrained_model_name_or_path )
231+ if os .path .isdir (pretrained_model_name_or_path ):
232+ if os .path .isfile (os .path .join (pretrained_model_name_or_path , WEIGHTS_NAME )):
233+ archive_file = os .path .join (pretrained_model_name_or_path , WEIGHTS_NAME )
234+ state_dict = torch .load (archive_file , map_location = "cpu" )
235+ removed_keys = [
236+ key
237+ for key in state_dict
238+ if key .startswith ("bert.encoder.layer." )
239+ and (key .endswith (".module.weight" ) or key .endswith (".module.bias" ))
240+ ]
241+ for key in removed_keys :
242+ new_key = key .replace (".module" , "" )
243+ state_dict [new_key ] = state_dict [key ]
244+ state_dict .pop (key )
245+ return state_dict
246+
247+
248+ def load_recipe (pretrained_model_name_or_path ):
249+ """
250+ Load recipe from the model directory
251+ """
252+ recipe = None
253+ if pretrained_model_name_or_path is not None :
254+ pretrained_model_name_or_path = str (pretrained_model_name_or_path )
255+ if os .path .isdir (pretrained_model_name_or_path ):
256+ if os .path .isfile (os .path .join (pretrained_model_name_or_path , RECIPE_NAME )):
257+ with open (os .path .join (pretrained_model_name_or_path , RECIPE_NAME )) as fp :
258+ recipe = json .load (fp )
259+ recipe = recipe ["recipe" ]
260+ return recipe
261+
0 commit comments