From 09c2c8910cfc2ebf38a532f8d02ee6c4fd3f7c92 Mon Sep 17 00:00:00 2001 From: Konstantin Date: Wed, 29 Jun 2022 13:46:26 -0400 Subject: [PATCH] Fix checkpoint and train recipe loading --- utils/sparse.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/utils/sparse.py b/utils/sparse.py index b03d44af1d41..d5cfb77b67e5 100644 --- a/utils/sparse.py +++ b/utils/sparse.py @@ -75,18 +75,20 @@ def __init__( self.apply_checkpoint_structure(train_mode, epoch, one_shot) def state_dict(self, final_epoch): - if self.enabled or self.checkpoint_manager: - compose_recipes = self.checkpoint_manager and self.enabled and final_epoch - return { - 'checkpoint_recipe': str(ScheduledModifierManager.compose_staged(self.checkpoint_manager, self.manager)) - if compose_recipes else str(self.checkpoint_manager), - 'train_recipe': str(self.manager) if not final_epoch else None - } + if self.enabled and final_epoch: + checkpoint_recipe = ( + str(ScheduledModifierManager.compose_staged(self.checkpoint_manager, self.manager)) + if self.checkpoint_manager else str(self.manager) + ) + train_recipe = None else: - return { - 'checkpoint_recipe': None, - 'train_recipe': None - } + checkpoint_recipe = str(self.checkpoint_manager) if self.checkpoint_manager else None + train_recipe = str(self.manager) if self.manager else None + + return { + 'checkpoint_recipe': checkpoint_recipe, + 'train_recipe': train_recipe + } def apply_checkpoint_structure(self, train_mode, epoch, one_shot=False): if self.checkpoint_manager: