add wan2.2-S2V-14B training#1006
Conversation
Summary of ChangesHello @lzws, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request integrates comprehensive training support for the Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request adds training support for the wan2.2-S2V-14B model, including new training scripts, validation scripts, and necessary code modifications to handle audio data. The changes are well-structured, covering documentation, code, and examples. I've provided a few suggestions to improve code robustness, style, and portability.
| def __init__(self, sr=16000): | ||
| self.sr = sr | ||
| def __call__(self, data: str): | ||
| import librosa |
There was a problem hiding this comment.
For better code organization and to avoid repeated import overhead, it's recommended to move the import librosa statement to the top of the file. This makes dependencies explicit and easier to manage. If librosa is an optional dependency, you can wrap the top-level import in a try...except ImportError block.
| --remove_prefix_in_ckpt "pipe.dit." \ | ||
| --output_path "./models/train/Wan2.2-S2V-14B_full" \ | ||
| --extra_inputs "input_image,input_audio,s2v_pose_video" \ | ||
| --use_gradient_checkpointing_offload No newline at end of file |
| --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ | ||
| --lora_rank 32 \ | ||
| --extra_inputs "input_image,input_audio,s2v_pose_video" \ | ||
| --use_gradient_checkpointing_offload No newline at end of file |
| model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, enable_fp8_training=False) | ||
| self.pipe = WanVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs) | ||
| if audio_processor_config is not None: | ||
| audio_processor_config = ModelConfig(model_id=audio_processor_config.split(":")[0], origin_file_pattern=audio_processor_config.split(":")[1]) |
There was a problem hiding this comment.
The code audio_processor_config.split(":") assumes that the audio_processor_config string will always contain a colon. If it doesn't, this will raise an IndexError at runtime. It would be more robust to validate the format before splitting to prevent unexpected crashes.
| audio_processor_config = ModelConfig(model_id=audio_processor_config.split(":")[0], origin_file_pattern=audio_processor_config.split(":")[1]) | |
| if ":" not in audio_processor_config: | |
| raise ValueError(f"Invalid audio_processor_config format: {audio_processor_config}") | |
| model_id, origin_file_pattern = audio_processor_config.split(":", 1) | |
| audio_processor_config = ModelConfig(model_id=model_id, origin_file_pattern=origin_file_pattern) |
| width = 832 | ||
|
|
||
| prompt = "a person is singing" | ||
| negative_prompt = "画面模糊,最差质量,画面模糊,细节模糊不清,情绪激动剧烈,手快速抖动,字幕,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" |
There was a problem hiding this comment.
The negative prompt contains a duplicated phrase '画面模糊'. Removing the duplication would make it cleaner.
| negative_prompt = "画面模糊,最差质量,画面模糊,细节模糊不清,情绪激动剧烈,手快速抖动,字幕,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" | |
| negative_prompt = "最差质量,画面模糊,细节模糊不清,情绪激动剧烈,手快速抖动,字幕,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" |
|
|
||
| pipe = WanVideoPipeline.from_pretrained( | ||
| torch_dtype=torch.bfloat16, | ||
| device="cuda:0", |
There was a problem hiding this comment.
| width = 832 | ||
|
|
||
| prompt = "a person is singing" | ||
| negative_prompt = "画面模糊,最差质量,画面模糊,细节模糊不清,情绪激动剧烈,手快速抖动,字幕,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" |
There was a problem hiding this comment.
The negative prompt contains a duplicated phrase '画面模糊'. Removing the duplication would make it cleaner.
| negative_prompt = "画面模糊,最差质量,画面模糊,细节模糊不清,情绪激动剧烈,手快速抖动,字幕,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" | |
| negative_prompt = "最差质量,画面模糊,细节模糊不清,情绪激动剧烈,手快速抖动,字幕,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" |
add wan2.2-S2V-14B training
add wan2.2-S2V-14B training