This repository was archived by the owner on Jun 3, 2025. It is now read-only.
Merged
63 changes: 34 additions & 29 deletions src/sparseml/transformers/utils/trainer.py
@@ -137,46 +137,51 @@ def apply_recipes(self, epoch=0.0):
         Applies all recipes from checkpoint_recipes. Runs architecture changing
         modifiers to prepare model for state dict loading
         """
+        # get state dict before recipe application
+        org_state_dict = self.model.state_dict()
+
         # apply any checkpoint recipes
         for checkpoint_recipe in self.checkpoint_recipes:
             if checkpoint_recipe is not None:
                 ScheduledModifierManager.from_yaml(checkpoint_recipe).apply(self.model)
 
         # init current training recipe
         if self.manager is not None:
-            org_state_dict = self.model.state_dict()
             self.manager.initialize(
                 self.model,
                 epoch=epoch,
                 distillation_teacher=self.teacher,
                 loggers=self.loggers,
             )
-            new_state_dict = self.model.state_dict()
-            new_params = [p for p in new_state_dict.keys() if p not in org_state_dict]
-
-            if os.path.isdir(self.model_name_or_path):
-                if os.path.isfile(os.path.join(self.model_name_or_path, WEIGHTS_NAME)):
-                    archive_file = os.path.join(self.model_name_or_path, WEIGHTS_NAME)
-                    state_dict = torch.load(archive_file, map_location="cpu")
-                    new_params_to_init = [
-                        p for p in new_params if p in state_dict.keys()
-                    ]
-                    if new_params_to_init:
-                        # parameters from dict are dependent on recipe
-                        (
-                            _,
-                            missing_keys,
-                            unexpected_keys,
-                            _,
-                        ) = self.model._load_state_dict_into_model(
-                            self.model,
-                            state_dict,
-                            self.model_name_or_path,
-                            _fast_init=False,
-                        )
-                        if missing_keys or unexpected_keys:
-                            raise RuntimeError(
-                                "Unexpected or missing keys detected when applying "
-                                f"recipes to models\nMissing keys: {missing_keys}\n"
-                                f"Unexpected keys: {unexpected_keys}\n"
-                            )
+
+        # if model structure changed, load in new params from state dict
+        new_state_dict = self.model.state_dict()
+        new_params = [p for p in new_state_dict.keys() if p not in org_state_dict]
+
+        if os.path.isdir(self.model_name_or_path):
+            if os.path.isfile(os.path.join(self.model_name_or_path, WEIGHTS_NAME)):
+                archive_file = os.path.join(self.model_name_or_path, WEIGHTS_NAME)
+                state_dict = torch.load(archive_file, map_location="cpu")
+                new_params_to_init = [p for p in new_params if p in state_dict.keys()]
+                if new_params_to_init:
+                    # parameters from dict are dependent on recipe
+                    (
+                        _,
+                        missing_keys,
+                        unexpected_keys,
+                        _,
+                    ) = self.model._load_state_dict_into_model(
+                        self.model,
+                        state_dict,
+                        self.model_name_or_path,
+                        _fast_init=False,
+                    )
+                    if missing_keys or unexpected_keys:
+                        raise RuntimeError(
+                            "Unexpected or missing keys detected when applying "
+                            f"recipes to models\nMissing keys: {missing_keys}\n"
+                            f"Unexpected keys: {unexpected_keys}\n"
+                        )
 
     def create_optimizer(self):
         """
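Taken together, the hunk moves the org_state_dict snapshot ahead of the checkpoint-recipe loop and hoists the reload logic out of the "if self.manager is not None:" block, so parameters created by any applied recipe are detected and restored from the saved weights, not only those added by the active training recipe. The sketch below illustrates that snapshot-diff-reload pattern in isolation; reload_recipe_params, recipes (objects with an apply method), and checkpoint_dir are hypothetical stand-ins for illustration, not sparseml API:

import os

import torch


def reload_recipe_params(model, recipes, checkpoint_dir,
                         weights_name="pytorch_model.bin"):
    # hypothetical helper: snapshot parameter names BEFORE any recipe
    # can alter the model's architecture
    org_params = set(model.state_dict().keys())

    # recipes may create new parameters (e.g. quantization observers)
    for recipe in recipes:
        recipe.apply(model)

    # diff the snapshots to find parameters the recipes introduced
    new_params = {p for p in model.state_dict().keys() if p not in org_params}

    weights_path = os.path.join(checkpoint_dir, weights_name)
    if new_params and os.path.isfile(weights_path):
        checkpoint = torch.load(weights_path, map_location="cpu")
        # keep only the recipe-created entries the checkpoint actually has
        to_load = {k: v for k, v in checkpoint.items() if k in new_params}
        # strict=False because this is deliberately a subset of the model
        model.load_state_dict(to_load, strict=False)
        return sorted(to_load)
    return []

Capturing the snapshot before the checkpoint recipes run is the point of the reordering: taken any later, parameters those recipes create would look as if they had always existed, and their saved values would never be reloaded.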