From e14a17cf58dcfd37e2773569055f9fa1ccac5f69 Mon Sep 17 00:00:00 2001
From: Aishwarya0811
Date: Tue, 7 Oct 2025 22:41:30 +0500
Subject: [PATCH] Fix checkpoint overwrite bug in train_sana_sprint_diffusers

---
 .../research_projects/sana/train_sana_sprint_diffusers.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/examples/research_projects/sana/train_sana_sprint_diffusers.py b/examples/research_projects/sana/train_sana_sprint_diffusers.py
index d127fee5fd0d..2890ce19184c 100644
--- a/examples/research_projects/sana/train_sana_sprint_diffusers.py
+++ b/examples/research_projects/sana/train_sana_sprint_diffusers.py
@@ -1145,6 +1145,11 @@ def save_model_hook(models, weights, output_dir):
                 elif isinstance(unwrapped_model, type(unwrap_model(disc))):
                     # Save only the heads
                     torch.save(unwrapped_model.heads.state_dict(), os.path.join(output_dir, "disc_heads.pt"))
+
+                # Skip frozen pretrained_model
+                elif isinstance(unwrapped_model, type(unwrap_model(transformer))):
+                    pass
+
                 else:
                     raise ValueError(f"unexpected save model: {unwrapped_model.__class__}")
 
@@ -1161,7 +1166,7 @@ def load_model_hook(models, input_dir):
             model = models.pop()
             unwrapped_model = unwrap_model(model)
 
-            if isinstance(unwrapped_model, type(unwrap_model(transformer))):
+            if isinstance(unwrapped_model, type(unwrap_model(transformer))) and getattr(unwrapped_model, 'guidance', False):
                 transformer_ = model  # noqa: F841
             elif isinstance(unwrapped_model, type(unwrap_model(disc))):
                 # Load only the heads
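
For reviewers, a minimal runnable sketch of the failure mode this patch addresses, kept separate from the patch itself. The trained transformer and its frozen pretrained copy share one class, so a bare isinstance() dispatch in the checkpoint hooks cannot tell them apart and one copy can clobber the other's checkpoint; the patch disambiguates via the `guidance` attribute that only the trained model carries. The SanaTransformerStandIn class and the transformer.pt file name below are hypothetical stand-ins, not names from the script.

# Sketch of the attribute-based dispatch applied by this patch; assumes
# hypothetical names, not the real training script.
import os
import tempfile

import torch
import torch.nn as nn


class SanaTransformerStandIn(nn.Module):
    """Stand-in for the transformer class shared by the trained and frozen copies."""

    def __init__(self, guidance: bool = False):
        super().__init__()
        self.proj = nn.Linear(4, 4)
        if guidance:
            # Only the trained student carries this attribute, mirroring
            # how the patch tells the two copies apart.
            self.guidance = True


def save_model_hook(models, output_dir):
    for model in models:
        if isinstance(model, SanaTransformerStandIn) and getattr(model, "guidance", False):
            # Trained student: this checkpoint must survive.
            torch.save(model.state_dict(), os.path.join(output_dir, "transformer.pt"))
        elif isinstance(model, SanaTransformerStandIn):
            # Frozen pretrained copy: skip it so it cannot overwrite the
            # student checkpoint written above.
            pass
        else:
            raise ValueError(f"unexpected save model: {model.__class__}")


if __name__ == "__main__":
    student = SanaTransformerStandIn(guidance=True)
    teacher = SanaTransformerStandIn()  # frozen copy, same class as the student
    with tempfile.TemporaryDirectory() as tmp:
        # Iteration order no longer matters: the teacher branch is a no-op.
        save_model_hook([teacher, student], tmp)
        restored = SanaTransformerStandIn(guidance=True)
        restored.load_state_dict(torch.load(os.path.join(tmp, "transformer.pt"), weights_only=True))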