Fix checkpoint overwrite bug in train_sana_sprint_diffusers #12448
What does this PR do?
Fixes #12444
This PR fixes a critical bug in the `train_sana_sprint_diffusers.py` script where the `save_model_hook` incorrectly saves both the trained transformer and the frozen `pretrained_model` to the same checkpoint location, causing the frozen model to overwrite the trained weights.

The Problem

The original `isinstance()` check was too broad:

```python
if isinstance(unwrapped_model, type(unwrap_model(transformer))):
    model.save_pretrained(os.path.join(output_dir, "transformer"))
```
Both `transformer` (trained, with `guidance=True`) and `pretrained_model` (frozen reference, with `guidance=False`) are instances of `SanaTrigFlow`, so this condition matched both models. The result:

- First iteration: saves the trained model ✓
- Second iteration: overwrites it with the frozen model ✗

Users lose all their training progress during checkpointing.
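To see why, here is a minimal, self-contained sketch (the stub class below is hypothetical and only mimics the `guidance` attribute of the real `SanaTrigFlow`): `isinstance()` tests only the class, so a type-based check alone cannot tell the two models apart.

```python
# Hypothetical stub standing in for the real SanaTrigFlow model class.
class SanaTrigFlow:
    def __init__(self, guidance: bool):
        self.guidance = guidance

transformer = SanaTrigFlow(guidance=True)        # trained model
pretrained_model = SanaTrigFlow(guidance=False)  # frozen reference

for model in (transformer, pretrained_model):
    # isinstance() only checks the class, not how the model is configured,
    # so both models satisfy the original condition.
    print(isinstance(model, type(transformer)))  # prints True twice
```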
The Solution

Added a discriminator to distinguish between the models:

```python
if isinstance(unwrapped_model, type(unwrap_model(transformer))) and getattr(unwrapped_model, 'guidance', False):
    model.save_pretrained(os.path.join(output_dir, "transformer"))
```

The trained model has `guidance=True`, while the frozen reference has `guidance=False`. Now only the trained model is saved.
Also added an explicit skip clause for the frozen model to prevent it from triggering the error handler:

```python
elif isinstance(unwrapped_model, type(unwrap_model(transformer))):
    pass  # Skip frozen pretrained_model
```
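Putting the two branches together, the hook looks roughly like the sketch below. The surrounding loop, the `weights.pop()` call, and the `ValueError` fallback are assumptions based on the usual structure of `diffusers` training-script save hooks, not a verbatim copy; the real hook also handles the discriminator heads mentioned under Testing, omitted here for brevity.

```python
def save_model_hook(models, weights, output_dir):
    for model in models:
        unwrapped_model = unwrap_model(model)
        if isinstance(unwrapped_model, type(unwrap_model(transformer))) and getattr(unwrapped_model, 'guidance', False):
            # Trained transformer (guidance=True): persist its weights.
            model.save_pretrained(os.path.join(output_dir, "transformer"))
        elif isinstance(unwrapped_model, type(unwrap_model(transformer))):
            # Frozen pretrained_model (guidance=False): skip it so it can
            # no longer overwrite the checkpoint written above.
            pass
        else:
            raise ValueError(f"unexpected save model: {model.__class__}")

        # Standard pattern in these hooks: pop each weight so it is not
        # saved a second time by the default mechanism.
        if weights:
            weights.pop()
```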
Testing

Created a test script (`test.py`) that verifies:

- ✓ Trained transformer is saved correctly
- ✓ Frozen `pretrained_model` is skipped
- ✓ Discriminator heads are saved correctly

Test results:

ALL CHECKS PASSED! The fix is working correctly!
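A hypothetical sketch of the kind of check such a test script can perform (this is not the actual `test.py` from the PR):

```python
import torch

def state_dicts_equal(model_a, model_b) -> bool:
    """True if both models hold identical parameter tensors."""
    sd_a, sd_b = model_a.state_dict(), model_b.state_dict()
    return sd_a.keys() == sd_b.keys() and all(
        torch.equal(sd_a[k], sd_b[k]) for k in sd_a
    )

# After running save_model_hook(...) and reloading the saved transformer
# from os.path.join(output_dir, "transformer"):
#   assert state_dicts_equal(reloaded, transformer)           # trained weights kept
#   assert not state_dicts_equal(reloaded, pretrained_model)  # frozen model skipped
```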