[train_sana_sprint] Custom save_hook overwrites trained transformer with frozen model during checkpointing #12444

@riteshrm

Description

Describe the bug

In the train_sana_sprint_diffusers.py example script, the custom save_model_hook saves the wrong model. The isinstance() check is too broad: both the trained transformer and the frozen pretrained_model are instances of the same class, so both match it. As a result, the hook saves the trained model and then immediately overwrites it in the same checkpoint directory with the frozen, untrained reference model.

Reproduction

# In train_sana_sprint_diffusers.py, this save hook causes the issue:
def save_model_hook(models, weights, output_dir):
    if accelerator.is_main_process:
        for model in models:
            unwrapped_model = unwrap_model(model)
            # This check is too broad and matches both the trained and frozen models
            if isinstance(unwrapped_model, type(unwrap_model(transformer))):
                model = unwrapped_model
                model.save_pretrained(os.path.join(output_dir, "transformer"))
            # ... rest of the function ...
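
For illustration, here is why both models hit the same branch. This is a minimal sketch, not part of the script: the variable names (transformer, pretrained_model, unwrap_model, output_dir) mirror the ones used in train_sana_sprint_diffusers.py, and the standalone loop only demonstrates the overwrite.

trained = unwrap_model(transformer)        # trained student transformer
frozen = unwrap_model(pretrained_model)    # frozen reference transformer (same class)

for model in (trained, frozen):
    # True for BOTH objects, since they are instances of the same class
    if isinstance(model, type(unwrap_model(transformer))):
        model.save_pretrained(os.path.join(output_dir, "transformer"))
# The second save_pretrained() call overwrites output_dir/"transformer"
# with the frozen, untrained weights.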

System Info

  • OS: Windows 11
  • Python: 3.10.18
  • diffusers version: 0.35.1
  • transformers version: 4.57.0
  • torch version: 2.8.0
  • accelerate version: 1.10.1
  • huggingface-hub version: 0.35.3
  • safetensors version: 0.6.2

Suggested Fix

def save_model_hook(models, weights, output_dir):
    if accelerator.is_main_process:
        for model in models:
            unwrapped_model = unwrap_model(model)
            # Handle transformer model
            if isinstance(unwrapped_model, type(unwrap_model(transformer))):
                model = unwrapped_model
                if model.config.guidance_embeds:
                    model.save_pretrained(os.path.join(output_dir, "transformer"))
            # Handle discriminator model (only save heads)
            elif isinstance(unwrapped_model, type(unwrap_model(disc))):
                # Save only the heads
                torch.save(unwrapped_model.heads.state_dict(), os.path.join(output_dir, "disc_heads.pt"))
            else:
                raise ValueError(f"unexpected save model: {unwrapped_model.__class__}")

            # make sure to pop weight so that corresponding model is not saved again
            if weights:
                weights.pop()

def load_model_hook(models, input_dir):
    transformer_ = None
    disc_ = None

    if not accelerator.distributed_type == DistributedType.DEEPSPEED:
        while len(models) > 0:
            model = models.pop()
            unwrapped_model = unwrap_model(model)

            if isinstance(unwrapped_model, type(unwrap_model(transformer))):
                if unwrapped_model.config.guidance_embeds:
                    transformer_ = model  # noqa: F841
            elif isinstance(unwrapped_model, type(unwrap_model(disc))):
                # Load only the heads
                heads_state_dict = torch.load(os.path.join(input_dir, "disc_heads.pt"))
                unwrapped_model.heads.load_state_dict(heads_state_dict)
                disc_ = model  # noqa: F841
            else:
                raise ValueError(f"unexpected save model: {unwrapped_model.__class__}")

    else:
        # DeepSpeed case
        transformer_ = SanaTransformer2DModel.from_pretrained(input_dir, subfolder="transformer")  # noqa: F841
        disc_heads_state_dict = torch.load(os.path.join(input_dir, "disc_heads.pt"))  # noqa: F841
        # You'll need to handle how to load the heads in DeepSpeed case
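
These replacement hooks would be registered with Accelerate the same way the existing ones are. The call site below is a sketch of that registration, assuming an accelerator object as in the script:

accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)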

