Commit efdd6a5

Merge pull request #892 from modelscope/dev2-dzj
refine training framework
2 parents 8c13362 + 42ec7b0 commit efdd6a5

6 files changed: +105, -285 lines


diffsynth/trainers/utils.py

Lines changed: 70 additions & 5 deletions
@@ -1,4 +1,6 @@
 import imageio, os, torch, warnings, torchvision, argparse, json
+from ..utils import ModelConfig
+from ..models.utils import load_state_dict
 from peft import LoraConfig, inject_adapter_in_model
 from PIL import Image
 import pandas as pd
@@ -424,7 +426,53 @@ def transfer_data_to_device(self, data, device):
             if isinstance(data[key], torch.Tensor):
                 data[key] = data[key].to(device)
         return data
-
+
+
+    def parse_model_configs(self, model_paths, model_id_with_origin_paths, enable_fp8_training=False):
+        offload_dtype = torch.float8_e4m3fn if enable_fp8_training else None
+        model_configs = []
+        if model_paths is not None:
+            model_paths = json.loads(model_paths)
+            model_configs += [ModelConfig(path=path, offload_dtype=offload_dtype) for path in model_paths]
+        if model_id_with_origin_paths is not None:
+            model_id_with_origin_paths = model_id_with_origin_paths.split(",")
+            model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1], offload_dtype=offload_dtype) for i in model_id_with_origin_paths]
+        return model_configs
+
+
+    def switch_pipe_to_training_mode(
+        self,
+        pipe,
+        trainable_models,
+        lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=None,
+        enable_fp8_training=False,
+    ):
+        # Scheduler
+        pipe.scheduler.set_timesteps(1000, training=True)
+
+        # Freeze untrainable models
+        pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
+
+        # Enable FP8 if pipeline supports
+        if enable_fp8_training and hasattr(pipe, "_enable_fp8_lora_training"):
+            pipe._enable_fp8_lora_training(torch.float8_e4m3fn)
+
+        # Add LoRA to the base models
+        if lora_base_model is not None:
+            model = self.add_lora_to_model(
+                getattr(pipe, lora_base_model),
+                target_modules=lora_target_modules.split(","),
+                lora_rank=lora_rank,
+                upcast_dtype=pipe.torch_dtype,
+            )
+            if lora_checkpoint is not None:
+                state_dict = load_state_dict(lora_checkpoint)
+                state_dict = self.mapping_lora_state_dict(state_dict)
+                load_result = model.load_state_dict(state_dict, strict=False)
+                print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
+                if len(load_result[1]) > 0:
+                    print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
+            setattr(pipe, lora_base_model, model)


 class ModelLogger:
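
The two helpers above consolidate the model loading and training-mode setup that each example training script previously reimplemented. For reference, a minimal sketch of the two command-line formats parse_model_configs consumes (the file paths below are placeholders; the model IDs are taken from the example scripts). With enable_fp8_training=True, each resulting ModelConfig additionally carries offload_dtype=torch.float8_e4m3fn.

# Sketch of the inputs expected by parse_model_configs (not part of this commit).
import json

# --model_paths: a JSON list of local checkpoint paths (placeholder names).
model_paths = '["models/dit.safetensors", "models/vae.safetensors"]'
print(json.loads(model_paths))

# --model_id_with_origin_paths: comma-separated "model_id:origin_file_pattern" pairs.
model_id_with_origin_paths = (
    "Qwen/Qwen-Image:text_encoder/model*.safetensors,"
    "Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors"
)
for item in model_id_with_origin_paths.split(","):
    model_id, pattern = item.split(":")
    print(model_id, pattern)  # becomes ModelConfig(model_id=..., origin_file_pattern=...)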
@@ -472,14 +520,26 @@ def launch_training_task(
     dataset: torch.utils.data.Dataset,
     model: DiffusionTrainingModule,
     model_logger: ModelLogger,
-    optimizer: torch.optim.Optimizer,
-    scheduler: torch.optim.lr_scheduler.LRScheduler,
+    learning_rate: float = 1e-5,
+    weight_decay: float = 1e-2,
     num_workers: int = 8,
     save_steps: int = None,
     num_epochs: int = 1,
     gradient_accumulation_steps: int = 1,
     find_unused_parameters: bool = False,
+    args = None,
 ):
+    if args is not None:
+        learning_rate = args.learning_rate
+        weight_decay = args.weight_decay
+        num_workers = args.dataset_num_workers
+        save_steps = args.save_steps
+        num_epochs = args.num_epochs
+        gradient_accumulation_steps = args.gradient_accumulation_steps
+        find_unused_parameters = args.find_unused_parameters
+
+    optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
+    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
     dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
     accelerator = Accelerator(
         gradient_accumulation_steps=gradient_accumulation_steps,
@@ -509,8 +569,12 @@ def launch_data_process_task(
     model: DiffusionTrainingModule,
     model_logger: ModelLogger,
     num_workers: int = 8,
+    args = None,
 ):
-    dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
+    if args is not None:
+        num_workers = args.dataset_num_workers
+
+    dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0], num_workers=num_workers)
     accelerator = Accelerator()
     model, dataloader = accelerator.prepare(model, dataloader)

@@ -520,7 +584,7 @@ def launch_data_process_task(
         folder = os.path.join(model_logger.output_path, str(accelerator.process_index))
         os.makedirs(folder, exist_ok=True)
         save_path = os.path.join(model_logger.output_path, str(accelerator.process_index), f"{data_id}.pth")
-        data = model(data)
+        data = model(data, return_inputs=True)
         torch.save(data, save_path)


@@ -623,4 +687,5 @@ def qwen_image_parser():
     parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.")
     parser.add_argument("--processor_path", type=str, default=None, help="Path to the processor. If provided, the processor will be used for image editing.")
     parser.add_argument("--enable_fp8_training", default=False, action="store_true", help="Whether to enable FP8 training. Only available for LoRA training on a single GPU.")
+    parser.add_argument("--task", type=str, default="sft", required=False, help="Task type.")
     return parser
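
With the changes above, launch_training_task now owns optimizer and scheduler construction, and both launchers accept the parsed args namespace directly. A rough usage sketch (dataset, model, and model_logger are assumed to be built as in the example scripts below; the keyword values are placeholders):

# Assumes dataset, model, model_logger are constructed as in the example scripts.
from diffsynth.trainers.utils import launch_training_task, launch_data_process_task

# CLI-driven: hyperparameters are read from the parsed argparse namespace,
# and AdamW + ConstantLR are created inside launch_training_task.
launch_training_task(dataset, model, model_logger, args=args)

# Programmatic: pass keywords directly instead of an args namespace.
launch_training_task(
    dataset, model, model_logger,
    learning_rate=1e-4,   # placeholder values
    weight_decay=1e-2,
    num_epochs=1,
    gradient_accumulation_steps=1,
)

# Caching mode: iterates the dataset in order (shuffle=False) and saves
# model(data, return_inputs=True) to {output_path}/{process_index}/{data_id}.pth.
launch_data_process_task(dataset, model, model_logger, args=args)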

examples/flux/model_training/train.py

Lines changed: 8 additions & 38 deletions
@@ -20,37 +20,16 @@ def __init__(
     ):
         super().__init__()
         # Load models
-        model_configs = []
-        if model_paths is not None:
-            model_paths = json.loads(model_paths)
-            model_configs += [ModelConfig(path=path) for path in model_paths]
-        if model_id_with_origin_paths is not None:
-            model_id_with_origin_paths = model_id_with_origin_paths.split(",")
-            model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1]) for i in model_id_with_origin_paths]
+        model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, enable_fp8_training=False)
         self.pipe = FluxImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs)

-        # Reset training scheduler
-        self.pipe.scheduler.set_timesteps(1000, training=True)
-
-        # Freeze untrainable models
-        self.pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
+        # Training mode
+        self.switch_pipe_to_training_mode(
+            self.pipe, trainable_models,
+            lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint,
+            enable_fp8_training=False,
+        )

-        # Add LoRA to the base models
-        if lora_base_model is not None:
-            model = self.add_lora_to_model(
-                getattr(self.pipe, lora_base_model),
-                target_modules=lora_target_modules.split(","),
-                lora_rank=lora_rank
-            )
-            if lora_checkpoint is not None:
-                state_dict = load_state_dict(lora_checkpoint)
-                state_dict = self.mapping_lora_state_dict(state_dict)
-                load_result = model.load_state_dict(state_dict, strict=False)
-                print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
-                if len(load_result[1]) > 0:
-                    print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
-            setattr(self.pipe, lora_base_model, model)
-
         # Store other configs
         self.use_gradient_checkpointing = use_gradient_checkpointing
         self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
@@ -138,13 +117,4 @@ def forward(self, data, inputs=None):
         remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
         state_dict_converter=FluxLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else lambda x:x,
     )
-    optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate, weight_decay=args.weight_decay)
-    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
-    launch_training_task(
-        dataset, model, model_logger, optimizer, scheduler,
-        num_epochs=args.num_epochs,
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
-        save_steps=args.save_steps,
-        find_unused_parameters=args.find_unused_parameters,
-        num_workers=args.dataset_num_workers,
-    )
+    launch_training_task(dataset, model, model_logger, args=args)

examples/qwen_image/model_training/lora/Qwen-Image-Splited.sh

Lines changed: 3 additions & 2 deletions
@@ -1,11 +1,12 @@
-accelerate launch examples/qwen_image/model_training/train_data_process.py \
+accelerate launch examples/qwen_image/model_training/train.py \
   --dataset_base_path data/example_image_dataset \
   --dataset_metadata_path data/example_image_dataset/metadata.csv \
   --max_pixels 1048576 \
   --model_id_with_origin_paths "Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
   --output_path "./models/train/Qwen-Image_lora_cache" \
   --use_gradient_checkpointing \
-  --dataset_num_workers 8
+  --dataset_num_workers 8 \
+  --task data_process

 accelerate launch examples/qwen_image/model_training/train.py \
   --dataset_base_path models/train/Qwen-Image_lora_cache \
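
The first stage now goes through the unified train.py with --task data_process: it loads only the text encoder and VAE and writes preprocessed samples to ./models/train/Qwen-Image_lora_cache, which the second stage then uses as its dataset base path. An optional way to sanity-check the cache between the two stages, sketched below under the assumption that files follow the {process_index}/{data_id}.pth layout produced by launch_data_process_task:

# Sketch: inspect one sample written by the data_process stage (not part of this commit).
import glob
import torch

cache_files = sorted(glob.glob("models/train/Qwen-Image_lora_cache/*/*.pth"))
print(f"{len(cache_files)} cached samples found")

# Each file stores the dict returned by model(data, return_inputs=True),
# i.e. the preprocessed pipeline inputs for one sample.
sample = torch.load(cache_files[0], map_location="cpu", weights_only=False)
for key, value in sample.items():
    info = tuple(value.shape) if isinstance(value, torch.Tensor) else type(value).__name__
    print(key, info)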

examples/qwen_image/model_training/train.py

Lines changed: 16 additions & 48 deletions
@@ -2,7 +2,7 @@
 from diffsynth import load_state_dict
 from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
 from diffsynth.pipelines.flux_image_new import ControlNetInput
-from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, launch_training_task, qwen_image_parser
+from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, qwen_image_parser, launch_training_task, launch_data_process_task
 from diffsynth.trainers.unified_dataset import UnifiedDataset
 os.environ["TOKENIZERS_PARALLELISM"] = "false"

@@ -22,46 +22,18 @@ def __init__(
     ):
         super().__init__()
         # Load models
-        offload_dtype = torch.float8_e4m3fn if enable_fp8_training else None
-        model_configs = []
-        if model_paths is not None:
-            model_paths = json.loads(model_paths)
-            model_configs += [ModelConfig(path=path, offload_dtype=offload_dtype) for path in model_paths]
-        if model_id_with_origin_paths is not None:
-            model_id_with_origin_paths = model_id_with_origin_paths.split(",")
-            model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1], offload_dtype=offload_dtype) for i in model_id_with_origin_paths]
-
+        model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, enable_fp8_training=enable_fp8_training)
         tokenizer_config = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path)
         processor_config = ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/") if processor_path is None else ModelConfig(processor_path)
         self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, tokenizer_config=tokenizer_config, processor_config=processor_config)
+
+        # Training mode
+        self.switch_pipe_to_training_mode(
+            self.pipe, trainable_models,
+            lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint,
+            enable_fp8_training=enable_fp8_training,
+        )

-        # Enable FP8
-        if enable_fp8_training:
-            self.pipe._enable_fp8_lora_training(torch.float8_e4m3fn)
-
-        # Reset training scheduler (do it in each training step)
-        self.pipe.scheduler.set_timesteps(1000, training=True)
-
-        # Freeze untrainable models
-        self.pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
-
-        # Add LoRA to the base models
-        if lora_base_model is not None:
-            model = self.add_lora_to_model(
-                getattr(self.pipe, lora_base_model),
-                target_modules=lora_target_modules.split(","),
-                lora_rank=lora_rank,
-                upcast_dtype=self.pipe.torch_dtype,
-            )
-            if lora_checkpoint is not None:
-                state_dict = load_state_dict(lora_checkpoint)
-                state_dict = self.mapping_lora_state_dict(state_dict)
-                load_result = model.load_state_dict(state_dict, strict=False)
-                print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
-                if len(load_result[1]) > 0:
-                    print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
-            setattr(self.pipe, lora_base_model, model)
-
         # Store other configs
         self.use_gradient_checkpointing = use_gradient_checkpointing
         self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
@@ -109,9 +81,10 @@ def forward_preprocess(self, data):
         return {**inputs_shared, **inputs_posi}


-    def forward(self, data, inputs=None):
+    def forward(self, data, inputs=None, return_inputs=False):
         if inputs is None: inputs = self.forward_preprocess(data)
         else: inputs = self.transfer_data_to_device(inputs, self.pipe.device)
+        if return_inputs: return inputs
         models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
         loss = self.pipe.training_loss(**models, **inputs)
         return loss
@@ -151,13 +124,8 @@ def forward(self, data, inputs=None):
         enable_fp8_training=args.enable_fp8_training,
     )
     model_logger = ModelLogger(args.output_path, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt)
-    optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate, weight_decay=args.weight_decay)
-    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
-    launch_training_task(
-        dataset, model, model_logger, optimizer, scheduler,
-        num_epochs=args.num_epochs,
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
-        save_steps=args.save_steps,
-        find_unused_parameters=args.find_unused_parameters,
-        num_workers=args.dataset_num_workers,
-    )
+    launcher_map = {
+        "sft": launch_training_task,
+        "data_process": launch_data_process_task
+    }
+    launcher_map[args.task](dataset, model, model_logger, args=args)
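
Because both built-in launchers now share the (dataset, model, model_logger, args=...) calling convention, the dispatch table is straightforward to extend. A hypothetical sketch (launch_dry_run_task is illustrative only, not part of this commit):

from diffsynth.trainers.utils import launch_training_task, launch_data_process_task


def launch_dry_run_task(dataset, model, model_logger, args=None):
    # Hypothetical extra task, using the same signature as the built-in launchers.
    print(f"dataset size: {len(dataset)}")
    print(f"checkpoints would be written to: {model_logger.output_path}")


launcher_map = {
    "sft": launch_training_task,
    "data_process": launch_data_process_task,
    "dry_run": launch_dry_run_task,  # selected via --task dry_run
}
# launcher_map[args.task](dataset, model, model_logger, args=args)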
