diff --git a/examples/research_projects/sana/train_sana_sprint_diffusers.py b/examples/research_projects/sana/train_sana_sprint_diffusers.py
index d127fee5fd0d..2890ce19184c 100644
--- a/examples/research_projects/sana/train_sana_sprint_diffusers.py
+++ b/examples/research_projects/sana/train_sana_sprint_diffusers.py
@@ -1145,6 +1145,11 @@ def save_model_hook(models, weights, output_dir):
                 elif isinstance(unwrapped_model, type(unwrap_model(disc))):
                     # Save only the heads
                     torch.save(unwrapped_model.heads.state_dict(), os.path.join(output_dir, "disc_heads.pt"))
+
+                # Skip frozen pretrained_model
+                elif isinstance(unwrapped_model, type(unwrap_model(transformer))):
+                    pass
+
                 else:
                     raise ValueError(f"unexpected save model: {unwrapped_model.__class__}")
 
@@ -1161,7 +1166,7 @@ def load_model_hook(models, input_dir):
             model = models.pop()
             unwrapped_model = unwrap_model(model)
 
-            if isinstance(unwrapped_model, type(unwrap_model(transformer))):
+            if isinstance(unwrapped_model, type(unwrap_model(transformer))) and getattr(unwrapped_model, "guidance", False):
                 transformer_ = model  # noqa: F841
             elif isinstance(unwrapped_model, type(unwrap_model(disc))):
                 # Load only the heads