Commit 768d0ea (parent cca5381)

try to make dreambooth script work; accelerator backward not playing well

1 file changed: 29 additions, 4 deletions

examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py

@@ -67,6 +67,8 @@
     FlowMatchEulerDiscreteScheduler,
     FluxPipeline,
     FluxTransformer2DModel,
+    ParallelConfig,
+    enable_parallelism,
 )
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import (
@@ -805,6 +807,8 @@ def parse_args(input_args=None):
         ],
         help="The image interpolation method to use for resizing images.",
     )
+    parser.add_argument("--context_parallel_degree", type=int, default=1, help="The degree for context parallelism.")
+    parser.add_argument("--context_parallel_type", type=str, default="ulysses", help="The type of context parallelism to use. Choose between 'ulysses' and 'ring'.")
 
     if input_args is not None:
         args = parser.parse_args(input_args)
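
The help text above says `--context_parallel_type` should be 'ulysses' or 'ring', but nothing rejects other values at parse time; an unrecognized value would simply leave both ring_degree and ulysses_degree as None further down. A minimal sketch (not part of the commit) of the same two flags declared with argparse `choices` so a typo fails immediately:

# Sketch only: flag names and defaults match the commit; the `choices` constraint is added for illustration.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--context_parallel_degree",
    type=int,
    default=1,
    help="The degree for context parallelism.",
)
parser.add_argument(
    "--context_parallel_type",
    type=str,
    default="ulysses",
    choices=["ulysses", "ring"],
    help="The type of context parallelism to use. Choose between 'ulysses' and 'ring'.",
)
args = parser.parse_args(["--context_parallel_degree", "2", "--context_parallel_type", "ring"])
print(args.context_parallel_degree, args.context_parallel_type)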
@@ -1347,15 +1351,28 @@ def main(args):
 
     logging_dir = Path(args.output_dir, args.logging_dir)
 
+    cp_degree = args.context_parallel_degree
     accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
-    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+    if cp_degree > 1:
+        kwargs = []
+    else:
+        kwargs = [DistributedDataParallelKwargs(find_unused_parameters=True)]
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
         log_with=args.report_to,
         project_config=accelerator_project_config,
-        kwargs_handlers=[kwargs],
-    )
+        kwargs_handlers=kwargs,
+    )
+    if cp_degree > 1 and not torch.distributed.is_initialized():
+        if not torch.cuda.is_available():
+            raise ValueError("Context parallelism is only tested on CUDA devices.")
+        if os.environ.get("WORLD_SIZE", None) is None:
+            raise ValueError("Try launching the program with `torchrun --nproc_per_node <NUM_GPUS>` instead of `accelerate launch <NUM_GPUS>`.")
+        torch.distributed.init_process_group("nccl")
+        rank = torch.distributed.get_rank()
+        rank = accelerator.process_index
+        torch.cuda.set_device(torch.device("cuda", rank % torch.cuda.device_count()))
 
     # Disable AMP for MPS.
     if torch.backends.mps.is_available():
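
The bootstrap added above assumes a `torchrun` launch (torchrun exports WORLD_SIZE, RANK and LOCAL_RANK) rather than `accelerate launch`, and it initializes NCCL itself when context parallelism is requested. Note that `rank` is assigned twice in the commit; the second assignment from `accelerator.process_index` is the one that takes effect for device placement. A minimal standalone sketch of the same bootstrap, assuming CUDA devices and a torchrun launch:

# Sketch of the torchrun/NCCL bootstrap added above; assumes torchrun has set WORLD_SIZE.
import os

import torch


def init_context_parallel_process_group() -> int:
    """Initialize the NCCL process group and bind this process to one GPU."""
    if not torch.cuda.is_available():
        raise ValueError("Context parallelism is only tested on CUDA devices.")
    if os.environ.get("WORLD_SIZE") is None:
        raise ValueError("Launch with `torchrun --nproc_per_node <NUM_GPUS>` so the process-group env vars are set.")
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group("nccl")
    rank = torch.distributed.get_rank()
    # Wrap around in case there are more ranks than visible devices.
    torch.cuda.set_device(torch.device("cuda", rank % torch.cuda.device_count()))
    return rank

With that in place, a launch along the lines of `torchrun --nproc_per_node <NUM_GPUS> train_dreambooth_lora_flux_advanced.py --context_parallel_degree <NUM_GPUS> --context_parallel_type ulysses ...` is what the error message above expects; how the degree has to relate to the world size is not spelled out in the commit.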
@@ -1977,6 +1994,14 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
         power=args.lr_power,
     )
 
+    # Enable context parallelism
+    if cp_degree > 1:
+        ring_degree = cp_degree if args.context_parallel_type == "ring" else None
+        ulysses_degree = cp_degree if args.context_parallel_type == "ulysses" else None
+        transformer.parallelize(config=ParallelConfig(ring_degree=ring_degree, ulysses_degree=ulysses_degree))
+        transformer.set_attention_backend("_native_cudnn")
+    parallel_context = enable_parallelism(transformer) if cp_degree > 1 else nullcontext()
+
     # Prepare everything with our `accelerator`.
     if not freeze_text_encoder:
         if args.enable_t5_ti:
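
Taken together with the CLI flags, the block above maps `--context_parallel_type` onto either the `ring_degree` or the `ulysses_degree` field of `ParallelConfig`, applies it via `transformer.parallelize(...)`, pins the `_native_cudnn` attention backend, and later wraps the training step in `enable_parallelism(transformer)`, falling back to `contextlib.nullcontext()` when the degree is 1. A condensed sketch of that wiring, mirroring the imports and calls used in the diff (the helper name is made up for illustration):

# Sketch only: `setup_context_parallelism` is a hypothetical helper; the calls mirror the diff.
from contextlib import nullcontext

from diffusers import ParallelConfig, enable_parallelism


def setup_context_parallelism(transformer, cp_degree: int, cp_type: str):
    """Return the context manager the training step should run under."""
    if cp_degree <= 1:
        return nullcontext()
    config = ParallelConfig(
        ring_degree=cp_degree if cp_type == "ring" else None,
        ulysses_degree=cp_degree if cp_type == "ulysses" else None,
    )
    transformer.parallelize(config=config)
    # The commit forces this attention backend whenever context parallelism is on.
    transformer.set_attention_backend("_native_cudnn")
    return enable_parallelism(transformer)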
@@ -2131,7 +2156,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 logger.info(f"PIVOT TRANSFORMER {epoch}")
                 optimizer.param_groups[0]["lr"] = 0.0
 
-            with accelerator.accumulate(models_to_accumulate):
+            with accelerator.accumulate(models_to_accumulate), parallel_context:
                 prompts = batch["prompts"]
 
                 # encode batch prompts when custom prompts are provided for each image -
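
Per the commit message, `accelerator.backward` is not yet playing well with this arrangement; the intent of the one-line change above is that the context-parallel context wraps the entire accumulate/forward/backward step. A bare-bones skeleton of that intent, with the loss computation elided (the `compute_loss` name is a placeholder, everything else is named as in the script):

# Illustrative skeleton: shows where `parallel_context` sits relative to gradient
# accumulation and `accelerator.backward` inside the training loop.
for step, batch in enumerate(train_dataloader):
    with accelerator.accumulate(models_to_accumulate), parallel_context:
        loss = compute_loss(transformer, batch)  # placeholder for the flow-matching loss computation
        accelerator.backward(loss)               # the call the commit message reports as problematic
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()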
