Conversation
Summary of Changes
Hello @mi804, I'm Gemini Code Assist[^1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request significantly expands the capabilities of the video generation system by integrating the Wan2.2-S2V model. The core enhancement is the ability to generate videos directly from audio input, with optional guidance from a pose video. This involves adding new deep learning models for audio feature extraction and video generation, updating the main processing pipeline to handle these new input types, and incorporating robust audio-video merging utilities. The changes aim to provide a more versatile and powerful tool for creating dynamic, audio-synchronized video content.
Highlights
- New Model Integration: This PR introduces comprehensive support for the Wan2.2-S2V (Speech-to-Video) model, enabling the generation of videos driven by audio input.
- Audio-Driven Video Generation: The pipeline now includes an audio encoder (`WanS2VAudioEncoder`) based on Wav2Vec2 for extracting features from audio, and a specialized DiT model (`WanS2VModel`) that processes these audio features alongside visual inputs to generate video frames.
- Audio-Video Synchronization: New functionality for merging generated video with audio using FFmpeg has been added, ensuring synchronized output. This includes the `merge_video_audio` and `save_video_with_audio` functions.
- Pose-Guided Video Generation: The system now supports an optional pose video input (`s2v_pose_video`) to guide the generated video's motion, allowing for more controlled and specific video outputs based on a reference pose.
- Performance and Resource Optimization: The pipeline's VRAM management has been enhanced to accommodate the new audio encoder and related modules, optimizing resource usage for efficient inference, including support for low-VRAM environments and Unified Sequence Parallel (USP).
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e. a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.
| Feature | Command | Description |
|---|---|---|
| Code Review | `/gemini review` | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | `/gemini summary` | Provides a summary of the current pull request in its current state. |
| Comment | `@gemini-code-assist` | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | `/gemini help` | Displays a list of available commands. |
Customization
To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a `.gemini/` folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.
You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes

[^1]: Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.
Code Review
This pull request adds support for the Wan2.2-S2V model, including inference, low VRAM usage, and USP. The changes are extensive, adding new models, pipeline units, and utility functions. I've identified a critical issue in the WanS2VModel's forward pass that could affect training with a batch size greater than one. Additionally, there are several medium-severity issues, including documentation inconsistencies, potential runtime errors due to unchecked external dependencies, and stylistic improvements for better code quality and maintainability. Overall, the PR is a significant contribution, and addressing these points will enhance its robustness and clarity.
```python
        use_reentrant=False,
    )
else:
    x = block(x, context, t_mod, seq_len_x, pre_compute_freqs)
```
There seems to be an inconsistency in how pre_compute_freqs is handled between this forward method (likely for training) and the model_fn_wans2v inference function. Here, the entire pre_compute_freqs tensor (with a batch dimension) is passed to the block. However, in model_fn_wans2v, pre_compute_freqs[0] is used, suggesting the underlying attention mechanism expects freqs without a batch dimension. This will likely cause a shape mismatch error during training if the batch size is greater than 1. You should probably loop over the batch or adjust the apply_rotary_pos_emb to be batch-aware.
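To make the suggested fix concrete, here is a minimal pure-Python sketch (standing in for the torch implementation; `apply_rotary` and `apply_rotary_batched` are hypothetical names, not the PR's API) that treats channel pairs as complex numbers and loops over the batch so each sample consumes its own per-sample freqs:

```python
import cmath

def apply_rotary(x, freqs):
    # x: [seq, dim] with consecutive channel pairs treated as complex numbers;
    # freqs: [seq, dim // 2] rotation angles (no batch dimension)
    out = []
    for row, angles in zip(x, freqs):
        pairs = [complex(row[2 * i], row[2 * i + 1]) for i in range(len(row) // 2)]
        rotated = [p * cmath.exp(1j * a) for p, a in zip(pairs, angles)]
        out.append([c for p in rotated for c in (p.real, p.imag)])
    return out

def apply_rotary_batched(x_batch, freqs_batch):
    # Loop over the batch so each sample is rotated with its own freqs,
    # avoiding the shape mismatch that arises when a batched freqs tensor
    # is passed to an attention layer expecting per-sample freqs.
    return [apply_rotary(x, f) for x, f in zip(x_batch, freqs_batch)]
```

A zero angle rotates by `exp(0) = 1`, so positions with zero freqs pass through unchanged, which makes the behavior easy to sanity-check.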
```python
padd_lat[:, -overlap_frame:] = m[:, -overlap_frame:]

if add_last_motion < 2 and self.drop_mode != "drop":
    zero_end_frame = self.zip_frame_buckets[:self.zip_frame_buckets.__len__() - add_last_motion - 1].sum()
```
It's more Pythonic to use the built-in len() function to get the length of an object rather than calling the __len__() dunder method directly.
```diff
- zero_end_frame = self.zip_frame_buckets[:self.zip_frame_buckets.__len__() - add_last_motion - 1].sum()
+ zero_end_frame = self.zip_frame_buckets[:len(self.zip_frame_buckets) - add_last_motion - 1].sum()
```
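An illustrative comparison, with a plain list standing in for `zip_frame_buckets` (the values here are made up for demonstration):

```python
# Hypothetical stand-in values for illustration only
buckets = [1, 2, 3, 4]
add_last_motion = 1

# len() reads better and behaves identically to calling __len__() directly
end = len(buckets) - add_last_motion - 1
assert end == buckets.__len__() - add_last_motion - 1
print(buckets[:end])  # → [1, 2]
```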
| Model ID | Extra Parameters | Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|-|-|-|-|-|-|-|
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./examples/wanvideo/model_inference/Wan2.1-S2V-14B.py)|-|-|-|-|
```python
frame.save(os.path.join(save_path, f"{i}.png"))


def merge_video_audio(video_path: str, audio_path: str):
```
This function relies on ffmpeg being installed and available in the system's PATH. To make this more robust, it's a good practice to check if ffmpeg is available before attempting to use it. You can use shutil.which('ffmpeg') for this check and raise a more informative error if it's not found. Also, consider using the logging module instead of print for error messages to allow for better log management.
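A possible sketch of such a guard (the `output_path` parameter and the exact ffmpeg flags here are illustrative assumptions, not the PR's actual signature):

```python
import logging
import shutil
import subprocess

logger = logging.getLogger(__name__)

def merge_video_audio(video_path: str, audio_path: str, output_path: str = "output.mp4"):
    """Mux an audio track into a video with ffmpeg, failing early if ffmpeg is missing."""
    if shutil.which("ffmpeg") is None:
        raise RuntimeError(
            "ffmpeg not found in PATH; please install it (e.g. `apt install ffmpeg`)."
        )
    cmd = [
        "ffmpeg", "-y",
        "-i", video_path,
        "-i", audio_path,
        "-c:v", "copy",   # keep the video stream as-is
        "-c:a", "aac",    # re-encode audio for mp4 compatibility
        "-shortest",      # stop at the shorter of the two streams
        output_path,
    ]
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        logger.error("ffmpeg failed: %s", result.stderr)
        raise RuntimeError(f"ffmpeg exited with code {result.returncode}")
```

Raising early with an actionable message is friendlier than letting `subprocess` surface a bare `FileNotFoundError`, and routing errors through `logging` lets callers control verbosity.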
```python
b, s, n, c = x.size(0), x.size(1), x.size(2), x.size(3) // 2

# split freqs
if type(freqs) is list:
```
Using type(...) is ... for type checking is generally discouraged in favor of isinstance(). isinstance() is more robust as it correctly handles inheritance. Please use isinstance() here and in other similar checks on lines 39 and 42.
```diff
- if type(freqs) is list:
+ if isinstance(freqs, list):
```
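A quick demonstration of why `isinstance()` is preferred (`FreqList` is a made-up subclass for illustration):

```python
class FreqList(list):
    # A list subclass, e.g. something a wrapper library might return
    pass

freqs = FreqList([0.1, 0.2])

print(type(freqs) is list)       # False: exact-type check misses subclasses
print(isinstance(freqs, list))   # True: isinstance respects inheritance
```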
```python
if required_duration > total_frames / original_fps:
    raise ValueError("required_duration must be less than video length")


if not fixed_start is None and fixed_start >= 0:
```
```python
frame_audio_embed = \
    torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers \
    else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device)
```
The backslashes for line continuation can be avoided by wrapping the conditional expression in parentheses, which is the style recommended by PEP 8. This also applies to a similar case on line 178.

```diff
- frame_audio_embed = \
-     torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers \
-     else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device)
+ frame_audio_embed = (
+     torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers
+     else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device)
+ )
```
| Model ID | Extra Parameters | Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|-|-|-|-|-|-|-|
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./model_inference/Wan2.1-S2V-14B.py)|-|-|-|-|
| Model ID | Extra Parameters | Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|-|-|-|-|-|-|-|
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./model_inference/Wan2.1-S2V-14B.py)|-|-|-|-|
```python
negative_prompt = "画面模糊,最差质量,画面模糊,细节模糊不清,情绪激动剧烈,手快速抖动,字幕,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
input_image = Image.open("data/example_video_dataset/wans2v/pose.png").convert("RGB").resize((width, height))
# s2v audio input, recommend 16kHz sampling rate
audio_path = 'data/example_video_dataset/wans2v/sing.MP3'
```
Using an uppercase file extension .MP3 might cause FileNotFoundError on case-sensitive file systems (like Linux) if the actual file has a lowercase extension. It's better to use lowercase for consistency.
```diff
- audio_path = 'data/example_video_dataset/wans2v/sing.MP3'
+ audio_path = 'data/example_video_dataset/wans2v/sing.mp3'
```
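One defensive option is a small helper that falls back to other extension casings when the exact path is missing (`find_audio` is a hypothetical illustration, not part of the PR):

```python
import os

def find_audio(path: str) -> str:
    """Return `path` if it exists; otherwise try lower/upper-cased extensions."""
    if os.path.exists(path):
        return path
    root, ext = os.path.splitext(path)
    for candidate in (root + ext.lower(), root + ext.upper()):
        if os.path.exists(candidate):
            return candidate
    raise FileNotFoundError(path)
```

On a case-sensitive filesystem this resolves `sing.MP3` to an existing `sing.mp3`, while still failing loudly if the file is genuinely absent.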
Support Wan-S2V
Supports Wan2.2-S2V, including inference, low-VRAM mode, and USP.