Add Wan2.2-S2V-14B training #1006
New file (17 lines): full fine-tuning launch command for Wan2.2-S2V-14B.

```bash
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
  --dataset_base_path data/example_video_dataset/wans2v \
  --dataset_metadata_path data/example_video_dataset/wans2v/metadata.csv \
  --data_file_keys "video,input_audio,s2v_pose_video" \
  --height 448 \
  --width 832 \
  --num_frames 81 \
  --dataset_repeat 100 \
  --model_id_with_origin_paths "Wan-AI/Wan2.2-S2V-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-S2V-14B:wav2vec2-large-xlsr-53-english/model.safetensors,Wan-AI/Wan2.2-S2V-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-S2V-14B:Wan2.1_VAE.pth" \
  --audio_processor_config "Wan-AI/Wan2.2-S2V-14B:wav2vec2-large-xlsr-53-english/" \
  --learning_rate 1e-5 \
  --num_epochs 1 \
  --trainable_models "dit" \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path "./models/train/Wan2.2-S2V-14B_full" \
  --extra_inputs "input_image,input_audio,s2v_pose_video" \
  --use_gradient_checkpointing_offload
```
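The command reads `metadata.csv` from the dataset directory, with per-sample file columns named by `--data_file_keys` and `--extra_inputs`. As a rough, hypothetical sketch only (the exact column layout, relative paths, and any prompt column are assumptions, not taken from this PR), such a file could be generated like this:

```python
# Hypothetical sketch of a metadata.csv for this dataset layout. The column names mirror
# the --data_file_keys / --extra_inputs flags; the "prompt" column and the relative paths
# are assumptions, not taken from this PR.
import csv, os

rows = [
    {
        "video": "videos/sample_000.mp4",             # target video clip
        "input_image": "images/sample_000.png",       # reference image
        "input_audio": "audios/sample_000.wav",       # driving speech audio
        "s2v_pose_video": "poses/sample_000.mp4",     # pose-guidance video
        "prompt": "a person speaking to the camera",  # assumed text column
    },
]

os.makedirs("data/example_video_dataset/wans2v", exist_ok=True)
with open("data/example_video_dataset/wans2v/metadata.csv", "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
    writer.writeheader()
    writer.writerows(rows)
```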
New file (19 lines): LoRA fine-tuning launch command for Wan2.2-S2V-14B.

```bash
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
  --dataset_base_path data/example_video_dataset/wans2v \
  --dataset_metadata_path data/example_video_dataset/wans2v/metadata.csv \
  --data_file_keys "video,input_audio,s2v_pose_video" \
  --height 448 \
  --width 832 \
  --num_frames 81 \
  --dataset_repeat 100 \
  --model_id_with_origin_paths "Wan-AI/Wan2.2-S2V-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-S2V-14B:wav2vec2-large-xlsr-53-english/model.safetensors,Wan-AI/Wan2.2-S2V-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-S2V-14B:Wan2.1_VAE.pth" \
  --audio_processor_config "Wan-AI/Wan2.2-S2V-14B:wav2vec2-large-xlsr-53-english/" \
  --learning_rate 1e-4 \
  --num_epochs 5 \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path "./models/train/Wan2.2-S2V-14B_lora" \
  --lora_base_model "dit" \
  --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
  --lora_rank 32 \
  --extra_inputs "input_image,input_audio,s2v_pose_video" \
  --use_gradient_checkpointing_offload
```
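Compared with the full run, the LoRA script raises the learning rate to 1e-4, trains for 5 epochs, and injects rank-32 adapters into the listed attention and FFN projections instead of making `dit` fully trainable. Both scripts save checkpoints with the `pipe.dit.` prefix stripped via `--remove_prefix_in_ckpt`; below is a quick, hypothetical sanity check on a saved file (the checkpoint name under `--output_path` is a placeholder, not taken from this PR):

```python
# Hypothetical sanity check: confirm that keys in a saved checkpoint no longer carry the
# "pipe.dit." prefix removed by --remove_prefix_in_ckpt. The file name is a placeholder.
from safetensors.torch import load_file

state_dict = load_file("./models/train/Wan2.2-S2V-14B_lora/epoch-0.safetensors")  # placeholder name
assert not any(key.startswith("pipe.dit.") for key in state_dict), "prefix was not stripped"
for key, tensor in list(state_dict.items())[:5]:
    print(key, tuple(tensor.shape))
```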
Changes to the `train.py` training script (the entry point used by both launch commands above):

```diff
@@ -2,15 +2,15 @@
 from diffsynth import load_state_dict
 from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
 from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, launch_training_task, wan_parser
-from diffsynth.trainers.unified_dataset import UnifiedDataset, LoadVideo, ImageCropAndResize, ToAbsolutePath
+from diffsynth.trainers.unified_dataset import UnifiedDataset, LoadVideo, LoadAudio, ImageCropAndResize, ToAbsolutePath
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
 
 class WanTrainingModule(DiffusionTrainingModule):
     def __init__(
         self,
-        model_paths=None, model_id_with_origin_paths=None,
+        model_paths=None, model_id_with_origin_paths=None, audio_processor_config=None,
         trainable_models=None,
         lora_base_model=None, lora_target_modules="q,k,v,o,ffn.0,ffn.2", lora_rank=32, lora_checkpoint=None,
         use_gradient_checkpointing=True,
@@ -22,7 +22,9 @@ def __init__(
         super().__init__()
         # Load models
         model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, enable_fp8_training=False)
-        self.pipe = WanVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs)
+        if audio_processor_config is not None:
+            audio_processor_config = ModelConfig(model_id=audio_processor_config.split(":")[0], origin_file_pattern=audio_processor_config.split(":")[1])
+        self.pipe = WanVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, audio_processor_config=audio_processor_config)
 
         # Training mode
         self.switch_pipe_to_training_mode(
@@ -109,12 +111,14 @@ def forward(self, data, inputs=None):
             time_division_remainder=1,
         ),
         special_operator_map={
-            "animate_face_video": ToAbsolutePath(args.dataset_base_path) >> LoadVideo(args.num_frames, 4, 1, frame_processor=ImageCropAndResize(512, 512, None, 16, 16))
+            "animate_face_video": ToAbsolutePath(args.dataset_base_path) >> LoadVideo(args.num_frames, 4, 1, frame_processor=ImageCropAndResize(512, 512, None, 16, 16)),
+            "input_audio": ToAbsolutePath(args.dataset_base_path) >> LoadAudio(sr=16000),
         }
     )
     model = WanTrainingModule(
         model_paths=args.model_paths,
         model_id_with_origin_paths=args.model_id_with_origin_paths,
+        audio_processor_config=args.audio_processor_config,
         trainable_models=args.trainable_models,
         lora_base_model=args.lora_base_model,
         lora_target_modules=args.lora_target_modules,
```
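The new `--audio_processor_config` value follows the same `model_id:origin_file_pattern` convention as `--model_id_with_origin_paths`. The snippet below is only a standalone illustration (not part of the PR) of the `split(":")` step in `WanTrainingModule.__init__` above, using the value from the launch commands:

```python
# Standalone illustration of how the --audio_processor_config string is split into the
# two fields passed to ModelConfig in the diff above.
value = "Wan-AI/Wan2.2-S2V-14B:wav2vec2-large-xlsr-53-english/"
model_id, origin_file_pattern = value.split(":")[0], value.split(":")[1]
print(model_id)             # Wan-AI/Wan2.2-S2V-14B
print(origin_file_pattern)  # wav2vec2-large-xlsr-53-english/
```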
Review comment:

For better code organization and to avoid repeated import overhead, it's recommended to move the `import librosa` statement to the top of the file. This makes dependencies explicit and easier to manage. If `librosa` is an optional dependency, you can wrap the top-level import in a `try...except ImportError` block.
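A minimal sketch of the pattern described above (treating `librosa` as optional is an assumption here, not something this PR decides):

```python
# Minimal sketch of an optional top-level import, as the comment above suggests.
# Treating librosa as optional is an assumption, not part of this PR.
try:
    import librosa
except ImportError:
    librosa = None


def load_audio(path, sr=16000):
    """Load a mono waveform, failing with a clear message only when audio is actually needed."""
    if librosa is None:
        raise ImportError("librosa is required for audio inputs: pip install librosa")
    waveform, _ = librosa.load(path, sr=sr)
    return waveform
```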