-
Notifications
You must be signed in to change notification settings - Fork 1.2k
support qwen-image fp8 lora training #853
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -17,21 +17,27 @@ def __init__( | |||||||||||||||||||||||||||||||||||
| use_gradient_checkpointing=True, | ||||||||||||||||||||||||||||||||||||
| use_gradient_checkpointing_offload=False, | ||||||||||||||||||||||||||||||||||||
| extra_inputs=None, | ||||||||||||||||||||||||||||||||||||
| enable_fp8_training=False, | ||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||
| super().__init__() | ||||||||||||||||||||||||||||||||||||
| # Load models | ||||||||||||||||||||||||||||||||||||
| offload_dtype = torch.float8_e4m3fn if enable_fp8_training else None | ||||||||||||||||||||||||||||||||||||
| model_configs = [] | ||||||||||||||||||||||||||||||||||||
| if model_paths is not None: | ||||||||||||||||||||||||||||||||||||
| model_paths = json.loads(model_paths) | ||||||||||||||||||||||||||||||||||||
| model_configs += [ModelConfig(path=path) for path in model_paths] | ||||||||||||||||||||||||||||||||||||
| model_configs += [ModelConfig(path=path, offload_dtype=offload_dtype) for path in model_paths] | ||||||||||||||||||||||||||||||||||||
| if model_id_with_origin_paths is not None: | ||||||||||||||||||||||||||||||||||||
| model_id_with_origin_paths = model_id_with_origin_paths.split(",") | ||||||||||||||||||||||||||||||||||||
| model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1]) for i in model_id_with_origin_paths] | ||||||||||||||||||||||||||||||||||||
| model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1], offload_dtype=offload_dtype) for i in model_id_with_origin_paths] | ||||||||||||||||||||||||||||||||||||
|
Comment on lines
26
to
+31
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The logic for parsing This refactoring improves robustness and readability.
Suggested change
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| tokenizer_config = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path) | ||||||||||||||||||||||||||||||||||||
| processor_config = ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/") if processor_path is None else ModelConfig(processor_path) | ||||||||||||||||||||||||||||||||||||
| self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, tokenizer_config=tokenizer_config, processor_config=processor_config) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| # Enable FP8 | ||||||||||||||||||||||||||||||||||||
| if enable_fp8_training: | ||||||||||||||||||||||||||||||||||||
| self.pipe._enable_fp8_lora_training(torch.float8_e4m3fn) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| # Reset training scheduler (do it in each training step) | ||||||||||||||||||||||||||||||||||||
| self.pipe.scheduler.set_timesteps(1000, training=True) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
@@ -43,7 +49,8 @@ def __init__( | |||||||||||||||||||||||||||||||||||
| model = self.add_lora_to_model( | ||||||||||||||||||||||||||||||||||||
| getattr(self.pipe, lora_base_model), | ||||||||||||||||||||||||||||||||||||
| target_modules=lora_target_modules.split(","), | ||||||||||||||||||||||||||||||||||||
| lora_rank=lora_rank | ||||||||||||||||||||||||||||||||||||
| lora_rank=lora_rank, | ||||||||||||||||||||||||||||||||||||
| upcast_dtype=self.pipe.torch_dtype, | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| if lora_checkpoint is not None: | ||||||||||||||||||||||||||||||||||||
| state_dict = load_state_dict(lora_checkpoint) | ||||||||||||||||||||||||||||||||||||
|
|
@@ -126,6 +133,7 @@ def forward(self, data, inputs=None): | |||||||||||||||||||||||||||||||||||
| use_gradient_checkpointing=args.use_gradient_checkpointing, | ||||||||||||||||||||||||||||||||||||
| use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, | ||||||||||||||||||||||||||||||||||||
| extra_inputs=args.extra_inputs, | ||||||||||||||||||||||||||||||||||||
| enable_fp8_training=args.enable_fp8_training, | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| model_logger = ModelLogger(args.output_path, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt) | ||||||||||||||||||||||||||||||||||||
| optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate, weight_decay=args.weight_decay) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The device is hardcoded to
"cuda". It's better to useself.deviceto be consistent with the rest of the class and allow users to specify a different device (e.g.,cuda:1) during pipeline initialization.