@ernestchu
Contributor

@ernestchu ernestchu commented Jun 5, 2024

What does this PR do?

In save_motion_modules, the current implementation doesn't pass conv_in_channels to the MotionAdapter constructor. As a result, the saved config.json incorrectly contains "conv_in_channels": null, regardless of the actual conv_in_channels. This commit fixes the issue.
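To make the failure mode concrete, here is a minimal sketch that uses a hypothetical stand-in class (ToyAdapter is not the diffusers MotionAdapter). It assumes only that constructor arguments are recorded in the config and serialized to JSON. The value 9 is purely illustrative.

```python
import json


class ToyAdapter:
    """Hypothetical stand-in for a config-carrying adapter (not the diffusers class)."""

    def __init__(self, block_out_channels=(320, 640), conv_in_channels=None):
        # Every constructor argument is recorded in the config as-is.
        self.config = {
            "block_out_channels": list(block_out_channels),
            "conv_in_channels": conv_in_channels,
        }

    def to_json(self):
        return json.dumps(self.config, sort_keys=True)


# Argument omitted: the default None is serialized as JSON null.
without_arg = ToyAdapter(block_out_channels=(320, 640)).to_json()

# Argument forwarded: the actual value survives the round trip.
with_arg = ToyAdapter(block_out_channels=(320, 640), conv_in_channels=9).to_json()

print(without_arg)
print(with_arg)
```

The first line printed contains "conv_in_channels": null, which mirrors the symptom described above.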

However, more actions may be required.

If an instance is constructed from from_unet2d with a motion_adapter, one can simply copy motion_adapter.config to the UNetMotionModel instance (like this) and use that config to construct the MotionAdapter instance to be saved.

If an instance is NOT constructed from from_unet2d, there has to be a way to identify the correct config that describes the motion_adapter in the instance, but it is beyond my bandwidth for the moment. Maybe @DN6 can help?
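The first case above could be sketched roughly as follows. The classes ToyAdapter and ToyMotionModel, the attribute motion_adapter_config, and the helper adapter_for_saving are all hypothetical names made up for illustration; this is not how diffusers currently stores the adapter config.

```python
import copy


class ToyAdapter:
    """Hypothetical stand-in for MotionAdapter; it only stores its config."""

    def __init__(self, motion_num_attention_heads=8, conv_in_channels=None):
        self.config = {
            "motion_num_attention_heads": motion_num_attention_heads,
            "conv_in_channels": conv_in_channels,
        }


class ToyMotionModel:
    """Hypothetical stand-in for UNetMotionModel."""

    def __init__(self):
        # Hypothetical attribute: remembers the adapter config, if one was given.
        self.motion_adapter_config = None

    @classmethod
    def from_unet2d(cls, motion_adapter=None):
        model = cls()
        if motion_adapter is not None:
            # Copy, so later changes to the adapter do not leak into the model.
            model.motion_adapter_config = copy.deepcopy(motion_adapter.config)
        return model

    def adapter_for_saving(self):
        if self.motion_adapter_config is None:
            # Second case from the comment: there is no recorded config to rely on.
            raise ValueError("no adapter config was recorded for this instance")
        return ToyAdapter(**self.motion_adapter_config)


source = ToyAdapter(conv_in_channels=9)
model = ToyMotionModel.from_unet2d(motion_adapter=source)
rebuilt = model.adapter_for_saving()
print(rebuilt.config)
```

In this sketch the second case (no motion_adapter passed) simply raises, which reflects the open question rather than answering it.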

ernestchu added 2 commits June 6, 2024 00:26
current implementation doesn't pass conv_in_channels to the MotionAdapter constructor. This would incorrectly result into "conv_in_channels": null in the saved config.json, disregarding the actual conv_in_channels.
@sayakpaul sayakpaul requested a review from DN6 June 6, 2024 06:47
@DN6
Collaborator

DN6 commented Jul 10, 2024

Hi @ernestchu, sorry for the delay here. Yeah, this is tricky, since the conv_in channels for models like PIA are loaded directly into the UNet. I assume you're working with a model like that?

We can't save conv_in_channels by default because the conv_in weights of the UNet are not saved with the MotionModules, which is why we have that failing test.
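One way to picture the mismatch, assuming the failing test comes from an adapter that expects conv_in weights that were never saved (that reading is an assumption, and the motion-module key name below is illustrative only):

```python
def expected_adapter_keys(motion_keys, conv_in_channels):
    """Keys a hypothetical adapter would expect for a given conv_in_channels setting."""
    keys = set(motion_keys)
    if conv_in_channels is not None:
        # An adapter configured with conv_in_channels also owns a conv_in layer.
        keys |= {"conv_in.weight", "conv_in.bias"}
    return keys


# Illustrative key name; real checkpoints contain many such entries.
motion_keys = {"down_blocks.0.motion_modules.0.proj_in.weight"}

# Only motion-module weights are written out by default.
saved_keys = set(motion_keys)

missing = expected_adapter_keys(motion_keys, conv_in_channels=9) - saved_keys
print(sorted(missing))
```

With conv_in_channels left as None, the expected and saved key sets match, which is consistent with keeping the current default.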

We could add an arg to save_motion_modules such as save_conv_in_channels and save the config and weights appropriately if it is set to True. WDYT?

    def save_motion_modules(
        self,
        save_directory: str,
        is_main_process: bool = True,
        safe_serialization: bool = True,
        variant: Optional[str] = None,
        push_to_hub: bool = False,
        save_conv_in_channels: bool = False,  # Set to True to also save the conv_in weights of PIA-like models
        **kwargs,
    ) -> None:
        state_dict = self.state_dict()

        # Extract all motion modules
        motion_state_dict = {}
        for k, v in state_dict.items():
            if "motion_modules" in k:
                motion_state_dict[k] = v
            if save_conv_in_channels and k in ("conv_in.weight", "conv_in.bias"):
                motion_state_dict[k] = v

        adapter = MotionAdapter(
            block_out_channels=self.config["block_out_channels"],
            motion_layers_per_block=self.config["layers_per_block"],
            motion_norm_num_groups=self.config["norm_num_groups"],
            motion_num_attention_heads=self.config["motion_num_attention_heads"],
            motion_max_seq_length=self.config["motion_max_seq_length"],
            use_motion_mid_block=self.config["use_motion_mid_block"],
            conv_in_channels=self.config["in_channels"] if save_conv_in_channels else None,
        )
        adapter.load_state_dict(motion_state_dict)
        adapter.save_pretrained(
            save_directory=save_directory,
            is_main_process=is_main_process,
            safe_serialization=safe_serialization,
            variant=variant,
            push_to_hub=push_to_hub,
            **kwargs,
        )
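The key-selection step of the proposal above can be tried in isolation. This is a sketch on a plain dictionary with placeholder values instead of tensors; select_adapter_weights is a hypothetical helper name, and the key names other than conv_in.weight and conv_in.bias are illustrative.

```python
def select_adapter_weights(state_dict, save_conv_in_channels=False):
    """Pick the entries that would be stored with the adapter."""
    conv_in_keys = ("conv_in.weight", "conv_in.bias")
    selected = {}
    for key, value in state_dict.items():
        if "motion_modules" in key:
            selected[key] = value
        elif save_conv_in_channels and key in conv_in_keys:
            # Opt-in path for PIA-like models whose conv_in lives in the UNet.
            selected[key] = value
    return selected


# Placeholder strings stand in for tensors.
toy_state_dict = {
    "conv_in.weight": "w",
    "conv_in.bias": "b",
    "down_blocks.0.motion_modules.0.proj_in.weight": "m",
    "down_blocks.0.resnets.0.conv1.weight": "r",
}

default_keys = sorted(select_adapter_weights(toy_state_dict))
pia_keys = sorted(select_adapter_weights(toy_state_dict, save_conv_in_channels=True))
print(default_keys)
print(pia_keys)
```

Keeping the flag False by default leaves existing callers unaffected, while PIA-like models opt in explicitly.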

@ernestchu
Contributor Author

Yeah, I was referring to PIA. The code should generalize to models like it. Your proposal LGTM.

@github-actions
Contributor

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Sep 14, 2024
@yiyixuxu
Collaborator

yiyixuxu commented Dec 3, 2024

hi @ernestchu
let us know if you'll have time to finish this PR:)

@github-actions github-actions bot removed the stale Issues that haven't received updates label Dec 3, 2024
@github-actions github-actions bot added the stale Issues that haven't received updates label Dec 27, 2024