@Aishwarya0811 Aishwarya0811 commented Oct 7, 2025

What does this PR do?

Fixes #12444

This PR fixes a critical bug in the train_sana_sprint_diffusers.py script where the save_model_hook incorrectly saves both the trained transformer and the frozen pretrained_model to the same checkpoint location, causing the frozen model to overwrite the trained weights.

The Problem

The original isinstance() check was too broad:
```python
if isinstance(unwrapped_model, type(unwrap_model(transformer))):
    model.save_pretrained(os.path.join(output_dir, "transformer"))
```
Both transformer (trained, with guidance=True) and pretrained_model (frozen reference, with guidance=False) are instances of SanaTrigFlow, so this condition matched both models. The result:

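First matching model in the hook's loop: saves the trained model ✓
Second matching model: overwrites it with the frozen model ✗

The "transformer" directory of each checkpoint therefore ends up holding the frozen weights, and the training progress is lost.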
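
The overwrite can be reproduced without the real models. The following is a minimal sketch, not the script's actual code: `FakeSanaTrigFlow`, its `name` field, the `weights.txt` marker file, and `original_save_hook` are hypothetical stand-ins, `unwrap_model` is omitted, and the order of the models passed to the hook (trained first, frozen second) is an assumption.

```python
import os
import tempfile


class FakeSanaTrigFlow:
    """Hypothetical stand-in for the script's SanaTrigFlow wrapper."""

    def __init__(self, name, guidance):
        self.name = name
        self.guidance = guidance

    def save_pretrained(self, path):
        # Write a marker file so we can see whose "weights" ended up on disk.
        os.makedirs(path, exist_ok=True)
        with open(os.path.join(path, "weights.txt"), "w") as f:
            f.write(self.name)


def original_save_hook(models, reference, output_dir):
    # Mirrors the original check: any model of the same class is saved
    # to the same "transformer" directory.
    for model in models:
        if isinstance(model, type(reference)):
            model.save_pretrained(os.path.join(output_dir, "transformer"))


transformer = FakeSanaTrigFlow("trained", guidance=True)
pretrained_model = FakeSanaTrigFlow("frozen", guidance=False)

with tempfile.TemporaryDirectory() as output_dir:
    original_save_hook([transformer, pretrained_model], transformer, output_dir)
    with open(os.path.join(output_dir, "transformer", "weights.txt")) as f:
        saved = f.read()

print(saved)  # prints "frozen": the last matching model wins
```

Both objects pass the class check, so whichever one is processed last determines what is left in the checkpoint directory.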

The Solution

Added a discriminator to distinguish between the models:
```python
if isinstance(unwrapped_model, type(unwrap_model(transformer))) and getattr(unwrapped_model, 'guidance', False):
    model.save_pretrained(os.path.join(output_dir, "transformer"))
```
The trained model has guidance=True, while the frozen reference has guidance=False. Now only the trained model is saved.
Also added an explicit skip clause for the frozen model to prevent it from triggering the error handler:
```python
elif isinstance(unwrapped_model, type(unwrap_model(transformer))):
    pass  # Skip frozen pretrained_model
```
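
The combined effect of the guard and the skip clause can be sketched in isolation. This is an illustration under assumptions, not the script itself: `FakeSanaTrigFlow`, `fixed_save_hook`, `saved_name`, and the `weights.txt` marker file are hypothetical, `unwrap_model` is omitted, and the final `else` branch raising `ValueError` is assumed to be the error handler mentioned above.

```python
import os
import tempfile


class FakeSanaTrigFlow:
    """Hypothetical stand-in for the script's SanaTrigFlow wrapper."""

    def __init__(self, name, guidance):
        self.name = name
        self.guidance = guidance

    def save_pretrained(self, path):
        # Write a marker file so we can see whose "weights" ended up on disk.
        os.makedirs(path, exist_ok=True)
        with open(os.path.join(path, "weights.txt"), "w") as f:
            f.write(self.name)


def fixed_save_hook(models, reference, output_dir):
    for model in models:
        same_class = isinstance(model, type(reference))
        if same_class and getattr(model, "guidance", False):
            # Only the model with guidance=True (the trained one) is saved.
            model.save_pretrained(os.path.join(output_dir, "transformer"))
        elif same_class:
            pass  # Same class but guidance=False: the frozen model is skipped.
        else:
            # Assumed error handler for anything the hook does not recognize.
            raise ValueError(f"unexpected save model: {model.__class__}")


def saved_name(models, reference):
    # Run the hook in a scratch directory and report which model was written.
    with tempfile.TemporaryDirectory() as output_dir:
        fixed_save_hook(models, reference, output_dir)
        with open(os.path.join(output_dir, "transformer", "weights.txt")) as f:
            return f.read()


transformer = FakeSanaTrigFlow("trained", guidance=True)
pretrained_model = FakeSanaTrigFlow("frozen", guidance=False)

first = saved_name([transformer, pretrained_model], transformer)
second = saved_name([pretrained_model, transformer], transformer)
print(first, second)  # prints "trained trained" regardless of model order
```

The guard only works while the two models differ in their `guidance` flag. Without the `elif` branch, the frozen model would fall through to the `else` and raise.
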
Testing
Created a test script (test.py) that verifies:

✓ Trained transformer is saved correctly
✓ Frozen pretrained_model is skipped
✓ Discriminator heads are saved correctly

Test results:
ALL CHECKS PASSED! The fix is working correctly!

  • Trained transformer: SAVED
  • Frozen model: SKIPPED
  • Discriminator heads: SAVED

@Aishwarya0811 Aishwarya0811 force-pushed the fix-sana-checkpoint-overwrite branch from 2ba5169 to e14a17c Compare October 7, 2025 17:42
Development

Successfully merging this pull request may close these issues.

[train_sana_sprint] Custom save_hook overwrites trained transformer with frozen model during checkpointing