add conv_in_channels in def save_motion_modules #8411
The current implementation doesn't pass conv_in_channels to the MotionAdapter constructor. This incorrectly results in "conv_in_channels": null in the saved config.json, disregarding the actual conv_in_channels.
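To illustrate the failure mode, here is a minimal, self-contained sketch. `ToyAdapter` and its `to_json` method are hypothetical stand-ins, not the diffusers classes; the sketch only assumes that constructor arguments are recorded into the config and that an argument that is not passed falls back to its default of `None`.

```python
import json
from typing import Optional, Sequence


class ToyAdapter:
    """Hypothetical stand-in for an adapter that records its init args as config."""

    def __init__(
        self,
        block_out_channels: Sequence[int] = (320, 640),
        conv_in_channels: Optional[int] = None,
    ) -> None:
        # Every constructor argument is stored and later serialized as-is.
        self.config = {
            "block_out_channels": list(block_out_channels),
            "conv_in_channels": conv_in_channels,
        }

    def to_json(self) -> str:
        return json.dumps(self.config, sort_keys=True)


# Argument omitted: the default None is serialized as JSON null.
buggy_json = ToyAdapter().to_json()

# Argument forwarded (9 is an arbitrary example value): the value is preserved.
fixed_json = ToyAdapter(conv_in_channels=9).to_json()

print(buggy_json)
print(fixed_json)
```

The point is only that a value that never reaches the constructor cannot appear in the saved config, whatever the weights actually contain.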
|
Hi @ernestchu, sorry for the delay here. Yeah, this is tricky since the [...] We can't save [...] We could add an arg to save_motion_modules:

```python
def save_motion_modules(
    self,
    save_directory: str,
    is_main_process: bool = True,
    safe_serialization: bool = True,
    variant: Optional[str] = None,
    push_to_hub: bool = False,
    save_conv_in_channels=False,  # Used to save conv_in weights of PIA-like models
    **kwargs,
) -> None:
    state_dict = self.state_dict()

    # Extract all motion modules
    motion_state_dict = {}
    for k, v in state_dict.items():
        if "motion_modules" in k:
            motion_state_dict[k] = v
        if save_conv_in_channels and ((k == "conv_in.weight") or (k == "conv_in.bias")):
            motion_state_dict[k] = v

    adapter = MotionAdapter(
        block_out_channels=self.config["block_out_channels"],
        motion_layers_per_block=self.config["layers_per_block"],
        motion_norm_num_groups=self.config["norm_num_groups"],
        motion_num_attention_heads=self.config["motion_num_attention_heads"],
        motion_max_seq_length=self.config["motion_max_seq_length"],
        use_motion_mid_block=self.config["use_motion_mid_block"],
        conv_in_channels=self.config["in_channels"] if save_conv_in_channels else None,
    )
    adapter.load_state_dict(motion_state_dict)
    adapter.save_pretrained(
        save_directory=save_directory,
        is_main_process=is_main_process,
        safe_serialization=safe_serialization,
        variant=variant,
        push_to_hub=push_to_hub,
        **kwargs,
    )
```
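For illustration, the key-filtering step of the proposal can be sketched on its own, without diffusers. The helper name `extract_motion_state_dict` and the toy key names are assumptions made for this example, and placeholder strings stand in for tensors.

```python
from typing import Any, Dict

CONV_IN_KEYS = ("conv_in.weight", "conv_in.bias")


def extract_motion_state_dict(
    state_dict: Dict[str, Any], save_conv_in_channels: bool = False
) -> Dict[str, Any]:
    """Keep motion-module entries, plus the conv_in entries when requested."""
    motion_state_dict = {}
    for key, value in state_dict.items():
        if "motion_modules" in key:
            motion_state_dict[key] = value
        if save_conv_in_channels and key in CONV_IN_KEYS:
            motion_state_dict[key] = value
    return motion_state_dict


# Toy state dict; the key names are illustrative, not taken from a real checkpoint.
toy_state_dict = {
    "conv_in.weight": "w",
    "conv_in.bias": "b",
    "down_blocks.0.motion_modules.0.proj_in.weight": "m",
    "down_blocks.0.resnets.0.conv1.weight": "r",
}

without_conv_in = extract_motion_state_dict(toy_state_dict)
with_conv_in = extract_motion_state_dict(toy_state_dict, save_conv_in_channels=True)

print(sorted(without_conv_in))
print(sorted(with_conv_in))
```

With the flag left at its default, only the motion-module entry survives; with the flag set, the two conv_in entries are kept as well, which is what a PIA-like model would need.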
|
Yeah, I was referring to PIA. The code should generalize to models like it. Your proposal LGTM. |
|
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored. |
|
hi @ernestchu |
What does this PR do?
In save_motion_modules, the current implementation doesn't pass conv_in_channels to the MotionAdapter constructor. This incorrectly results in "conv_in_channels": null in the saved config.json, disregarding the actual conv_in_channels. This commit fixes the issue.
However, more actions may be required.
If an instance is constructed from from_unet2d with a motion_adapter, one can simply copy motion_adapter.config to the UNetMotionModel instance (like this) and use the config to construct the MotionAdapter instance to be saved.
If an instance is NOT constructed from from_unet2d, there has to be a way to identify the correct config that describes the motion_adapter in the instance, but it is beyond my bandwidth for the moment. Maybe @DN6 can help?
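The first case above (an instance built from from_unet2d with a motion_adapter) could look roughly like the sketch below. It uses toy stand-in classes rather than the diffusers ones, and the `motion_adapter_config` attribute and `rebuild_adapter` method are hypothetical names introduced only for this example.

```python
from typing import Optional


class ToyMotionAdapter:
    """Hypothetical stand-in that records its init args as config."""

    def __init__(
        self, motion_max_seq_length: int = 32, conv_in_channels: Optional[int] = None
    ) -> None:
        self.config = {
            "motion_max_seq_length": motion_max_seq_length,
            "conv_in_channels": conv_in_channels,
        }


class ToyUNetMotionModel:
    """Hypothetical stand-in that remembers the adapter config it was built with."""

    def __init__(self) -> None:
        self.motion_adapter_config: Optional[dict] = None

    @classmethod
    def from_unet2d(cls, motion_adapter: Optional[ToyMotionAdapter] = None):
        model = cls()
        if motion_adapter is not None:
            # Copy the adapter config so it can be reused when saving later.
            model.motion_adapter_config = dict(motion_adapter.config)
        return model

    def rebuild_adapter(self) -> ToyMotionAdapter:
        if self.motion_adapter_config is None:
            # The second case above: no recorded config to rebuild from.
            raise ValueError("no motion adapter config was recorded")
        return ToyMotionAdapter(**self.motion_adapter_config)


original = ToyMotionAdapter(conv_in_channels=9)  # 9 is an arbitrary example value
model = ToyUNetMotionModel.from_unet2d(original)
rebuilt = model.rebuild_adapter()
print(rebuilt.config)
```

The second case is left unresolved here on purpose: when no adapter config was recorded, the sketch raises an error instead of guessing one.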
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.