@@ -1,4 +1,6 @@
 import imageio, os, torch, warnings, torchvision, argparse, json
+from ..utils import ModelConfig
+from ..models.utils import load_state_dict
 from peft import LoraConfig, inject_adapter_in_model
 from PIL import Image
 import pandas as pd
@@ -424,7 +426,53 @@ def transfer_data_to_device(self, data, device):
             if isinstance(data[key], torch.Tensor):
                 data[key] = data[key].to(device)
         return data
-
+
+
+    def parse_model_configs(self, model_paths, model_id_with_origin_paths, enable_fp8_training=False):
+        offload_dtype = torch.float8_e4m3fn if enable_fp8_training else None
+        model_configs = []
+        if model_paths is not None:
+            model_paths = json.loads(model_paths)
+            model_configs += [ModelConfig(path=path, offload_dtype=offload_dtype) for path in model_paths]
+        if model_id_with_origin_paths is not None:
+            model_id_with_origin_paths = model_id_with_origin_paths.split(",")
+            model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1], offload_dtype=offload_dtype) for i in model_id_with_origin_paths]
+        return model_configs
+
+
+    def switch_pipe_to_training_mode(
+        self,
+        pipe,
+        trainable_models,
+        lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=None,
+        enable_fp8_training=False,
+    ):
+        # Scheduler
+        pipe.scheduler.set_timesteps(1000, training=True)
+
+        # Freeze untrainable models
+        pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
+
+        # Enable FP8 if the pipeline supports it
+        if enable_fp8_training and hasattr(pipe, "_enable_fp8_lora_training"):
+            pipe._enable_fp8_lora_training(torch.float8_e4m3fn)
+
+        # Add LoRA to the base models
+        if lora_base_model is not None:
+            model = self.add_lora_to_model(
+                getattr(pipe, lora_base_model),
+                target_modules=lora_target_modules.split(","),
+                lora_rank=lora_rank,
+                upcast_dtype=pipe.torch_dtype,
+            )
+            if lora_checkpoint is not None:
+                state_dict = load_state_dict(lora_checkpoint)
+                state_dict = self.mapping_lora_state_dict(state_dict)
+                load_result = model.load_state_dict(state_dict, strict=False)
+                print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
+                if len(load_result[1]) > 0:
+                    print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
+            setattr(pipe, lora_base_model, model)
 
 
 class ModelLogger:
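
For context, a minimal sketch of how the new parse_model_configs helper might be called from inside a DiffusionTrainingModule subclass. It accepts two independent inputs: model_paths, a JSON-encoded list of local checkpoint paths, and model_id_with_origin_paths, a comma-separated list of model_id:origin_file_pattern pairs. The concrete paths and model ID below are placeholders, not values taken from this commit.

# Hypothetical call site (placeholder paths/IDs); `self` is a DiffusionTrainingModule subclass.
model_configs = self.parse_model_configs(
    model_paths='["models/dit.safetensors", "models/text_encoder.safetensors"]',
    model_id_with_origin_paths="Qwen/Qwen-Image:transformer/*.safetensors",
    enable_fp8_training=False,  # True switches offload_dtype to torch.float8_e4m3fn
)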
@@ -472,14 +520,26 @@ def launch_training_task(
     dataset: torch.utils.data.Dataset,
     model: DiffusionTrainingModule,
     model_logger: ModelLogger,
-    optimizer: torch.optim.Optimizer,
-    scheduler: torch.optim.lr_scheduler.LRScheduler,
+    learning_rate: float = 1e-5,
+    weight_decay: float = 1e-2,
     num_workers: int = 8,
     save_steps: int = None,
     num_epochs: int = 1,
     gradient_accumulation_steps: int = 1,
     find_unused_parameters: bool = False,
+    args=None,
 ):
+    if args is not None:
+        learning_rate = args.learning_rate
+        weight_decay = args.weight_decay
+        num_workers = args.dataset_num_workers
+        save_steps = args.save_steps
+        num_epochs = args.num_epochs
+        gradient_accumulation_steps = args.gradient_accumulation_steps
+        find_unused_parameters = args.find_unused_parameters
+
+    optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
+    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
     dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
     accelerator = Accelerator(
         gradient_accumulation_steps=gradient_accumulation_steps,
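
Because launch_training_task now builds its own AdamW optimizer and constant LR scheduler, a caller either passes the hyperparameters directly or forwards the parsed CLI namespace through args, which overrides the keyword defaults. A hedged sketch, assuming dataset, model, and model_logger have been constructed elsewhere:

# Sketch only: `dataset`, `model`, and `model_logger` are assumed to exist already.
args = qwen_image_parser().parse_args()
launch_training_task(
    dataset, model, model_logger,
    args=args,  # learning_rate, weight_decay, num_workers, etc. are then taken from args
)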
@@ -509,8 +569,12 @@ def launch_data_process_task(
     model: DiffusionTrainingModule,
     model_logger: ModelLogger,
     num_workers: int = 8,
+    args=None,
 ):
-    dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
+    if args is not None:
+        num_workers = args.dataset_num_workers
+
+    dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0], num_workers=num_workers)
     accelerator = Accelerator()
     model, dataloader = accelerator.prepare(model, dataloader)
 
@@ -520,7 +584,7 @@ def launch_data_process_task(
         folder = os.path.join(model_logger.output_path, str(accelerator.process_index))
         os.makedirs(folder, exist_ok=True)
         save_path = os.path.join(model_logger.output_path, str(accelerator.process_index), f"{data_id}.pth")
-        data = model(data)
+        data = model(data, return_inputs=True)
         torch.save(data, save_path)
 
 
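launch_data_process_task now iterates the dataset in order (shuffle=False) and asks the model for its preprocessed inputs (return_inputs=True), writing one .pth file per sample under <output_path>/<process_index>/<data_id>.pth. A small sketch of loading such a cached sample back for inspection; the concrete path is illustrative:

import torch

# Process rank 0, first sample of the dataset (illustrative path).
cached = torch.load("output/0/0.pth", map_location="cpu")
print(type(cached))  # the object returned by the training module for this sample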
@@ -623,4 +687,5 @@ def qwen_image_parser():
     parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.")
     parser.add_argument("--processor_path", type=str, default=None, help="Path to the processor. If provided, the processor will be used for image editing.")
     parser.add_argument("--enable_fp8_training", default=False, action="store_true", help="Whether to enable FP8 training. Only available for LoRA training on a single GPU.")
+    parser.add_argument("--task", type=str, default="sft", required=False, help="Task type.")
     return parser
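
The parser gains a --task option defaulting to "sft". A quick sketch of how the new flag surfaces on the parsed namespace:

args = qwen_image_parser().parse_args(["--task", "sft"])
print(args.task)  # -> "sft"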