@@ -67,6 +67,8 @@
     FlowMatchEulerDiscreteScheduler,
     FluxPipeline,
     FluxTransformer2DModel,
+    ParallelConfig,
+    enable_parallelism,
 )
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import (
@@ -805,6 +807,8 @@ def parse_args(input_args=None):
         ],
         help="The image interpolation method to use for resizing images.",
     )
+    parser.add_argument("--context_parallel_degree", type=int, default=1, help="The degree for context parallelism.")
+    parser.add_argument("--context_parallel_type", type=str, default="ulysses", help="The type of context parallelism to use. Choose between 'ulysses' and 'ring'.")
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -1347,15 +1351,28 @@ def main(args):
 
     logging_dir = Path(args.output_dir, args.logging_dir)
 
+    cp_degree = args.context_parallel_degree
     accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
-    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+    if cp_degree > 1:
+        kwargs = []
+    else:
+        kwargs = [DistributedDataParallelKwargs(find_unused_parameters=True)]
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
         log_with=args.report_to,
         project_config=accelerator_project_config,
-        kwargs_handlers=[kwargs],
-    )
+        kwargs_handlers=kwargs,
+    )
+    if cp_degree > 1 and not torch.distributed.is_initialized():
+        if not torch.cuda.is_available():
+            raise ValueError("Context parallelism is only tested on CUDA devices.")
+        if os.environ.get("WORLD_SIZE", None) is None:
+            raise ValueError("Try launching the program with `torchrun --nproc_per_node <NUM_GPUS>` instead of `accelerate launch <NUM_GPUS>`.")
+        torch.distributed.init_process_group("nccl")
+        rank = torch.distributed.get_rank()
+        rank = accelerator.process_index
+        torch.cuda.set_device(torch.device("cuda", rank % torch.cuda.device_count()))
 
     # Disable AMP for MPS.
     if torch.backends.mps.is_available():
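The new `if cp_degree > 1` block expects a `torchrun` launch so that `WORLD_SIZE`, `RANK`, and `LOCAL_RANK` are set before the script starts. Note that the hunk assigns `rank` twice with values that agree here (`torch.distributed.get_rank()` and `accelerator.process_index`); only one is needed. A hedged, self-contained sketch of the same bootstrap, with a hypothetical helper name, assuming a `torchrun --nproc_per_node <NUM_GPUS>` launch on CUDA:

    # Sketch only (helper name is ours, not part of the patch): the distributed
    # bootstrap performed by the hunk above.
    import os

    import torch


    def init_context_parallel_process_group() -> int:
        """Initialize NCCL if needed, pin this process to a GPU, return the global rank."""
        if not torch.cuda.is_available():
            raise ValueError("Context parallelism is only tested on CUDA devices.")
        if os.environ.get("WORLD_SIZE") is None:
            raise ValueError("Launch with `torchrun --nproc_per_node <NUM_GPUS>` so the env vars are set.")
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group("nccl")
        rank = torch.distributed.get_rank()  # same value as accelerator.process_index here
        torch.cuda.set_device(rank % torch.cuda.device_count())
        return rank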
@@ -1977,6 +1994,14 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
         power=args.lr_power,
     )
 
+    # Enable context parallelism
+    if cp_degree > 1:
+        ring_degree = cp_degree if args.context_parallel_type == "ring" else None
+        ulysses_degree = cp_degree if args.context_parallel_type == "ulysses" else None
+        transformer.parallelize(config=ParallelConfig(ring_degree=ring_degree, ulysses_degree=ulysses_degree))
+        transformer.set_attention_backend("_native_cudnn")
+    parallel_context = enable_parallelism(transformer) if cp_degree > 1 else nullcontext()
+
     # Prepare everything with our `accelerator`.
     if not freeze_text_encoder:
         if args.enable_t5_ti:
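In the hunk above, exactly one of `ring_degree` / `ulysses_degree` is populated, depending on `--context_parallel_type`, and both use the same `cp_degree`. A hedged sketch of that selection as a helper; the function name is ours, and `ParallelConfig` is assumed to be importable from `diffusers` exactly as this patch's import hunk does:

    # Sketch only: helper name is hypothetical; ParallelConfig and its
    # ring_degree/ulysses_degree arguments are taken from the patch itself.
    from diffusers import ParallelConfig  # as added at the top of this patch


    def make_parallel_config(cp_degree: int, cp_type: str):
        """Map --context_parallel_degree / --context_parallel_type to a ParallelConfig."""
        if cp_degree <= 1:
            return None  # context parallelism disabled; caller falls back to nullcontext()
        if cp_type not in ("ulysses", "ring"):
            raise ValueError(f"Unknown context parallel type: {cp_type!r}")
        return ParallelConfig(
            ring_degree=cp_degree if cp_type == "ring" else None,
            ulysses_degree=cp_degree if cp_type == "ulysses" else None,
        )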
@@ -2131,7 +2156,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 logger.info(f"PIVOT TRANSFORMER {epoch}")
                 optimizer.param_groups[0]["lr"] = 0.0
 
-            with accelerator.accumulate(models_to_accumulate):
+            with accelerator.accumulate(models_to_accumulate), parallel_context:
                 prompts = batch["prompts"]
 
                 # encode batch prompts when custom prompts are provided for each image -
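The combined `with` statement works on both paths because `parallel_context` is either the `enable_parallelism(...)` manager or `contextlib.nullcontext()` (which this patch does not import, so the script is assumed to already import it). Since `parallel_context` is created once and re-entered on every training step, this also relies on `enable_parallelism` returning a reusable context manager, as `nullcontext()` is. A minimal illustration of the pattern, independent of the patch's APIs:

    # Sketch only (not part of the patch): the optional-context-manager pattern
    # used for `parallel_context` above. A class-based manager is reusable, so it
    # can be entered once per step, just like nullcontext().
    from contextlib import nullcontext


    class FakeParallelism:
        def __enter__(self):
            print("entering parallel region")
            return self

        def __exit__(self, exc_type, exc, tb):
            print("leaving parallel region")
            return False


    cp_enabled = False  # stands in for `cp_degree > 1`
    parallel_context = FakeParallelism() if cp_enabled else nullcontext()

    for step in range(3):  # stands in for the training loop
        with parallel_context:  # re-entered each step; both managers are reusable
            pass  # step body runs unchanged either way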