Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions diffsynth/pipelines/qwen_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,35 @@ def training_loss(self, **inputs):
return loss


def _enable_fp8_lora_training(self, dtype):
    """Put the pipeline's frozen weights under VRAM management in FP8.

    Wraps the text encoder, DiT, and VAE submodules so that their weights are
    stored in ``dtype`` (an FP8 format) while computation is performed in
    ``self.torch_dtype``. Only the layer types listed in the wrap table are
    managed; everything else is left untouched.

    Args:
        dtype: storage dtype for the managed weights (e.g. ``torch.float8_e4m3fn``).
    """
    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLRotaryEmbedding, Qwen2RMSNorm, Qwen2_5_VisionPatchEmbed, Qwen2_5_VisionRotaryEmbedding
    from ..models.qwen_image_dit import RMSNorm
    from ..models.qwen_image_vae import QwenImageRMS_norm

    # Layer types to wrap: linear layers get the LoRA-aware wrapper, all other
    # managed layers get the generic dtype-converting wrapper.
    wrap_table = {
        torch.nn.Linear: AutoWrappedLinear,
        RMSNorm: AutoWrappedModule,
        torch.nn.Conv3d: AutoWrappedModule,
        torch.nn.Conv2d: AutoWrappedModule,
        torch.nn.Embedding: AutoWrappedModule,
        Qwen2_5_VLRotaryEmbedding: AutoWrappedModule,
        Qwen2RMSNorm: AutoWrappedModule,
        Qwen2_5_VisionPatchEmbed: AutoWrappedModule,
        Qwen2_5_VisionRotaryEmbedding: AutoWrappedModule,
        QwenImageRMS_norm: AutoWrappedModule,
    }
    # NOTE(review): the device is hardcoded to "cuda". The training entry point
    # initializes the pipeline with device="cpu" and relies on this method to
    # place managed weights on the GPU, so switching to self.device here would
    # change placement — confirm intent before generalizing to self.device.
    wrap_config = dict(
        offload_dtype=dtype,
        offload_device="cuda",
        onload_dtype=dtype,
        onload_device="cuda",
        computation_dtype=self.torch_dtype,
        computation_device="cuda",
    )
    for component in (self.text_encoder, self.dit, self.vae):
        enable_vram_management(component, module_map=wrap_table, module_config=wrap_config)


def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5, enable_dit_fp8_computation=False):
self.vram_management_enabled = True
if vram_limit is None:
Expand Down
7 changes: 6 additions & 1 deletion diffsynth/trainers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,11 +338,15 @@ def trainable_param_names(self):
return trainable_param_names


def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None, upcast_dtype=None):
    """Inject a LoRA adapter into ``model`` and return the adapted model.

    Args:
        model: the module to receive the LoRA adapter.
        target_modules: module names (or patterns) that get LoRA weights.
        lora_rank: rank ``r`` of the LoRA decomposition.
        lora_alpha: scaling factor; defaults to ``lora_rank`` when ``None``.
        upcast_dtype: if given, trainable (LoRA) parameters are cast to this
            dtype — useful when the base weights are kept in a low-precision
            storage format.
    """
    alpha = lora_rank if lora_alpha is None else lora_alpha
    config = LoraConfig(r=lora_rank, lora_alpha=alpha, target_modules=target_modules)
    model = inject_adapter_in_model(config, model)
    if upcast_dtype is not None:
        # Only the injected adapter weights require grad; cast just those.
        for p in model.parameters():
            if p.requires_grad:
                p.data = p.to(upcast_dtype)
    return model


Expand Down Expand Up @@ -555,4 +559,5 @@ def qwen_image_parser():
parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.")
parser.add_argument("--processor_path", type=str, default=None, help="Path to the processor. If provided, the processor will be used for image editing.")
parser.add_argument("--enable_fp8_training", default=False, action="store_true", help="Whether to enable FP8 training. Only available for LoRA training on a single GPU.")
return parser
14 changes: 11 additions & 3 deletions examples/qwen_image/model_training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,27 @@ def __init__(
use_gradient_checkpointing=True,
use_gradient_checkpointing_offload=False,
extra_inputs=None,
enable_fp8_training=False,
):
super().__init__()
# Load models
offload_dtype = torch.float8_e4m3fn if enable_fp8_training else None
model_configs = []
if model_paths is not None:
model_paths = json.loads(model_paths)
model_configs += [ModelConfig(path=path) for path in model_paths]
model_configs += [ModelConfig(path=path, offload_dtype=offload_dtype) for path in model_paths]
if model_id_with_origin_paths is not None:
model_id_with_origin_paths = model_id_with_origin_paths.split(",")
model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1]) for i in model_id_with_origin_paths]
model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1], offload_dtype=offload_dtype) for i in model_id_with_origin_paths]
Comment on lines 26 to +31
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The logic for parsing model_id_with_origin_paths is not robust. Using item.split(':', 1) is safer than item.split(':') as it prevents errors if the origin_file_pattern contains colons (e.g., in a Windows path). Additionally, explicitly checking that the split results in two parts before unpacking will prevent IndexError if an entry is malformed.

This refactoring improves robustness and readability.

Suggested change
if model_paths is not None:
model_paths = json.loads(model_paths)
model_configs += [ModelConfig(path=path) for path in model_paths]
model_configs += [ModelConfig(path=path, offload_dtype=offload_dtype) for path in model_paths]
if model_id_with_origin_paths is not None:
model_id_with_origin_paths = model_id_with_origin_paths.split(",")
model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1]) for i in model_id_with_origin_paths]
model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1], offload_dtype=offload_dtype) for i in model_id_with_origin_paths]
if model_paths is not None:
model_paths = json.loads(model_paths)
model_configs.extend(ModelConfig(path=path, offload_dtype=offload_dtype) for path in model_paths)
if model_id_with_origin_paths is not None:
model_id_with_origin_paths = model_id_with_origin_paths.split(",")
for item in model_id_with_origin_paths:
parts = item.split(":", 1)
if len(parts) == 2:
model_configs.append(ModelConfig(model_id=parts[0], origin_file_pattern=parts[1], offload_dtype=offload_dtype))


tokenizer_config = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path)
processor_config = ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/") if processor_path is None else ModelConfig(processor_path)
self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, tokenizer_config=tokenizer_config, processor_config=processor_config)

# Enable FP8
if enable_fp8_training:
self.pipe._enable_fp8_lora_training(torch.float8_e4m3fn)

# Reset training scheduler (do it in each training step)
self.pipe.scheduler.set_timesteps(1000, training=True)

Expand All @@ -43,7 +49,8 @@ def __init__(
model = self.add_lora_to_model(
getattr(self.pipe, lora_base_model),
target_modules=lora_target_modules.split(","),
lora_rank=lora_rank
lora_rank=lora_rank,
upcast_dtype=self.pipe.torch_dtype,
)
if lora_checkpoint is not None:
state_dict = load_state_dict(lora_checkpoint)
Expand Down Expand Up @@ -126,6 +133,7 @@ def forward(self, data, inputs=None):
use_gradient_checkpointing=args.use_gradient_checkpointing,
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
extra_inputs=args.extra_inputs,
enable_fp8_training=args.enable_fp8_training,
)
model_logger = ModelLogger(args.output_path, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt)
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate, weight_decay=args.weight_decay)
Expand Down