export.py (17 changes: 11 additions & 6 deletions)
@@ -432,7 +432,7 @@ def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')):
     except Exception as e:
         LOGGER.info(f'\n{prefix} export failure: {e}')

-def create_checkpoint(epoch, model, optimizer, ema, sparseml_wrapper, **kwargs):
+def create_checkpoint(epoch, final_epoch, model, optimizer, ema, sparseml_wrapper, **kwargs):
     pickle = not sparseml_wrapper.qat_active(math.inf if epoch <0 else epoch) # qat does not support pickled exports
     ckpt_model = deepcopy(model.module if is_parallel(model) else model).float()
     yaml = ckpt_model.yaml
@@ -445,7 +445,7 @@ def create_checkpoint(epoch, model, optimizer, ema, sparseml_wrapper, **kwargs):
             'yaml': yaml,
             'hyp': model.hyp,
             **ema.state_dict(pickle),
-            **sparseml_wrapper.state_dict(),
+            **sparseml_wrapper.state_dict(final_epoch),
             **kwargs}

 def load_checkpoint(
@@ -469,6 +469,10 @@ def load_checkpoint(
     weights = attempt_download(weights) or check_download_sparsezoo_weights(weights)
     ckpt = torch.load(weights[0] if isinstance(weights, list) or isinstance(weights, tuple)
                       else weights, map_location="cpu") # load checkpoint
+
+    # temporary fix until SparseML and ZooModels are updated
+    ckpt['checkpoint_recipe'] = ckpt.get('recipe') or ckpt.get('checkpoint_recipe')
+
     pickled = isinstance(ckpt['model'], nn.Module)
     train_type = type_ == 'train'
     ensemble_type = type_ == 'ensemble'
@@ -500,21 +504,22 @@ def load_checkpoint(
     # load sparseml recipe for applying pruning and quantization
     checkpoint_recipe = train_recipe = None
     if resume:
-        train_recipe = ckpt.get('recipe')
-    elif recipe or ckpt.get('recipe'):
-        train_recipe, checkpoint_recipe = recipe, ckpt.get('recipe')
+        train_recipe, checkpoint_recipe = ckpt.get('train_recipe'), ckpt.get('checkpoint_recipe')
+    elif recipe or ckpt.get('checkpoint_recipe'):
+        train_recipe, checkpoint_recipe = recipe, ckpt.get('checkpoint_recipe')

     sparseml_wrapper = SparseMLWrapper(
         model.model if val_type else model,
         checkpoint_recipe,
         train_recipe,
+        train_mode=train_type,
+        epoch=ckpt['epoch'],
         one_shot=one_shot,
         steps_per_epoch=max_train_steps,
     )
     exclude_anchors = not ensemble_type and (cfg or hyp.get('anchors')) and not resume
     loaded = False

-    sparseml_wrapper.apply_checkpoint_structure()
     if train_type:
         # intialize the recipe for training and restore the weights before if no quantized weights
         quantized_state_dict = any([name.endswith('.zero_point') for name in state_dict.keys()])
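
Note on the load_checkpoint hunk above: the single legacy 'recipe' checkpoint key is split into 'checkpoint_recipe' (sparsification structure already baked into the saved weights) and 'train_recipe' (the recipe still being trained), with a temporary backfill for older checkpoints and SparseZoo models. The resulting branching can be summarized by the following standalone sketch; the helper name and example recipe paths are invented for illustration and are not part of the PR.

# Illustration only: mirrors the recipe resolution in load_checkpoint.
def resolve_recipes(ckpt, recipe, resume):
    checkpoint_recipe = train_recipe = None
    if resume:
        # resuming a run: both recipes come back out of the checkpoint itself
        train_recipe, checkpoint_recipe = ckpt.get('train_recipe'), ckpt.get('checkpoint_recipe')
    elif recipe or ckpt.get('checkpoint_recipe'):
        # fresh run (e.g. sparse transfer learning): train recipe from the CLI,
        # checkpoint recipe from the loaded weights
        train_recipe, checkpoint_recipe = recipe, ckpt.get('checkpoint_recipe')
    return train_recipe, checkpoint_recipe

# Example: transfer learning from a pruned-quantized checkpoint with a new recipe
# resolve_recipes({'checkpoint_recipe': 'prune_quant.yaml'}, 'transfer.yaml', resume=False)
# returns ('transfer.yaml', 'prune_quant.yaml')
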
train.py (5 changes: 3 additions & 2 deletions)
@@ -141,6 +141,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
         model,
         None,
         opt.recipe,
+        train_mode=True,
         steps_per_epoch=opt.max_train_steps,
         one_shot=opt.one_shot,
     )
@@ -314,7 +315,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
"date": datetime.now().isoformat(),
}
ckpt = create_checkpoint(
-1, model, optimizer, ema, sparseml_wrapper, **ckpt_extras
-1, True, model, optimizer, ema, sparseml_wrapper, **ckpt_extras
)
one_shot_checkpoint_name = w / "checkpoint-one-shot.pt"
torch.save(ckpt, one_shot_checkpoint_name)
@@ -486,7 +487,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
                 'best_fitness': best_fitness,
                 'wandb_id': loggers.wandb.wandb_run.id if loggers.wandb else None,
                 'date': datetime.now().isoformat()}
-            ckpt = create_checkpoint(epoch, model, optimizer, ema, sparseml_wrapper, **ckpt_extras)
+            ckpt = create_checkpoint(epoch, final_epoch, model, optimizer, ema, sparseml_wrapper, **ckpt_extras)

             # Save last, best and delete
             torch.save(ckpt, last)
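
The new final_epoch argument to create_checkpoint is True for the one-shot checkpoint above and is assumed to be True only on the last saved epoch in the regular path (its computation is outside these hunks). A minimal sketch of that assumption; the loop and variable names are placeholders, not the fork's actual training loop.

epochs = 300
for epoch in range(epochs):
    final_epoch = (epoch + 1) == epochs  # True only for the last epoch's checkpoint
    # ckpt = create_checkpoint(epoch, final_epoch, model, optimizer, ema, sparseml_wrapper, **ckpt_extras)
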
utils/sparse.py (69 changes: 48 additions & 21 deletions)
@@ -9,6 +9,7 @@
 from sparseml.pytorch.optim import ScheduledModifierManager
 from sparseml.pytorch.utils import SparsificationGroupLogger
 from sparseml.pytorch.utils import GradSampler
+from sparseml.pytorch.sparsification.quantization import QuantizationModifier
 import torchvision.transforms.functional as F

 from utils.torch_utils import is_parallel
@@ -51,7 +52,16 @@ def check_download_sparsezoo_weights(path):


 class SparseMLWrapper(object):
-    def __init__(self, model, checkpoint_recipe, train_recipe, steps_per_epoch=-1, one_shot=False):
+    def __init__(
+        self,
+        model,
+        checkpoint_recipe,
+        train_recipe,
+        train_mode=False,
+        epoch=-1,
+        steps_per_epoch=-1,
+        one_shot=False,
+    ):
         self.enabled = bool(train_recipe)
         self.model = model.module if is_parallel(model) else model
         self.checkpoint_manager = ScheduledModifierManager.from_yaml(checkpoint_recipe) if checkpoint_recipe else None
@@ -62,21 +72,47 @@ def __init__(self, model, checkpoint_recipe, train_recipe, steps_per_epoch=-1, o
         self.one_shot = one_shot
         self.train_recipe = train_recipe

-        if self.one_shot:
-            self._apply_one_shot()
-
-    def state_dict(self):
-        manager = (ScheduledModifierManager.compose_staged(self.checkpoint_manager, self.manager)
-                   if self.checkpoint_manager and self.enabled else self.manager)
+        self.apply_checkpoint_structure(train_mode, epoch, one_shot)

-        return {
-            'recipe': str(manager) if self.enabled else None,
-        }
+    def state_dict(self, final_epoch):
+        if self.enabled or self.checkpoint_manager:
+            compose_recipes = self.checkpoint_manager and self.enabled and final_epoch
+            return {
+                'checkpoint_recipe': str(ScheduledModifierManager.compose_staged(self.checkpoint_manager, self.manager))
+                if compose_recipes else str(self.checkpoint_manager),
+                'train_recipe': str(self.manager) if not final_epoch else None
+            }
+        else:
+            return {
+                'checkpoint_recipe': None,
+                'train_recipe': None
+            }

-    def apply_checkpoint_structure(self):
+    def apply_checkpoint_structure(self, train_mode, epoch, one_shot=False):
         if self.checkpoint_manager:
+            # if checkpoint recipe has a QAT modifier and this is a transfer learning
+            # run then remove the QAT modifier from the manager
+            if train_mode:
+                qat_idx = next((
+                    idx for idx, mod in enumerate(self.checkpoint_manager.modifiers)
+                    if isinstance(mod, QuantizationModifier)), -1
+                )
+                if qat_idx >= 0:
+                    _ = self.checkpoint_manager.modifiers.pop(qat_idx)
+
             self.checkpoint_manager.apply_structure(self.model, math.inf)

+        if train_mode and epoch > 0 and self.enabled:
+            self.manager.apply_structure(self.model, epoch)
+        elif one_shot:
+            if self.enabled:
+                self.manager.apply(self.model)
+                _LOGGER.info(f"Applied recipe {self.train_recipe} in one-shot manner")
+            else:
+                _LOGGER.info(f"Training recipe for one-shot application not recognized by the manager. Got recipe: "
+                             f"{self.train_recipe}"
+                )
+
     def initialize(
         self,
         start_epoch,
@@ -144,9 +180,9 @@ def check_lr_override(self, scheduler, rank):
     def check_epoch_override(self, epochs, rank):
         # Override num epochs if recipe explicitly modifies epoch range
         if self.enabled and self.manager.epoch_modifiers and self.manager.max_epochs:
-            epochs = self.manager.max_epochs or epochs # override num_epochs
             if rank in [0,-1]:
                 self.logger.info(f'Overriding number of epochs from SparseML manager to {epochs}')
+            epochs = self.manager.max_epochs + self.start_epoch or epochs # override num_epochs

         return epochs

@@ -195,15 +231,6 @@ def dataloader():
                     imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
                 yield [imgs], {}, targets
         return dataloader
-
-    def _apply_one_shot(self):
-        if self.manager is not None:
-            self.manager.apply(self.model)
-            _LOGGER.info(f"Applied recipe {self.train_recipe} in one-shot manner")
-        else:
-            _LOGGER.info(f"Training recipe for one-shot application not recognized by the manager. Got recipe: "
-                         f"{self.train_recipe}"
-            )

     def save_sample_inputs_outputs(
         self,
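
Taken together with the export.py and train.py changes, SparseMLWrapper.state_dict(final_epoch) now decides which recipe keys a saved checkpoint carries: recipes are kept separate while training is in progress and composed into a single staged recipe only on the final (or one-shot) save. A rough illustration follows, with placeholder strings standing in for ScheduledModifierManager objects; the function name and return strings are invented for this sketch, and edge cases (e.g. a missing checkpoint manager) are simplified relative to the real method.

def expected_recipe_keys(has_checkpoint_manager, enabled, final_epoch):
    # Mirrors the branching of SparseMLWrapper.state_dict above, illustration only.
    if not (enabled or has_checkpoint_manager):
        return {'checkpoint_recipe': None, 'train_recipe': None}
    compose = has_checkpoint_manager and enabled and final_epoch
    return {
        # only a finished run composes the checkpoint recipe with the training recipe
        'checkpoint_recipe': 'composed(checkpoint + train)' if compose else 'checkpoint recipe',
        # the in-progress training recipe stays separate so a resumed run can restore it
        'train_recipe': None if final_epoch else 'train recipe',
    }

# mid-training save:   expected_recipe_keys(True, True, final_epoch=False)
#   -> {'checkpoint_recipe': 'checkpoint recipe', 'train_recipe': 'train recipe'}
# final/one-shot save: expected_recipe_keys(True, True, final_epoch=True)
#   -> {'checkpoint_recipe': 'composed(checkpoint + train)', 'train_recipe': None}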