From f04ee1d1af2ceda9269b5a5cd4419b72df314d24 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Fri, 1 Nov 2024 00:25:52 +0000 Subject: [PATCH 01/13] update ptxla example --- .../research_projects/pytorch_xla/README.md | 16 +- .../pytorch_xla/train_text_to_image_xla.py | 179 +++++++++--------- src/diffusers/models/attention_processor.py | 24 ++- 3 files changed, 117 insertions(+), 102 deletions(-) diff --git a/examples/research_projects/pytorch_xla/README.md b/examples/research_projects/pytorch_xla/README.md index a6901d5ada9d..2a00fb6598cd 100644 --- a/examples/research_projects/pytorch_xla/README.md +++ b/examples/research_projects/pytorch_xla/README.md @@ -7,13 +7,14 @@ It has been tested on v4 and v5p TPU versions. Training code has been tested on This script implements Distributed Data Parallel using GSPMD feature in XLA compiler where we shard the input batches over the TPU devices. -As of 9-11-2024, these are some expected step times. +As of 10-31-2024, these are some expected step times. | accelerator | global batch size | step time (seconds) | | ----------- | ----------------- | --------- | -| v5p-128 | 1024 | 0.245 | -| v5p-256 | 2048 | 0.234 | -| v5p-512 | 4096 | 0.2498 | +| v5p-512 | 16384 | 1.01 | +| v5p-256 | 8192 | 1.01 | +| v5p-128 | 4096 | 1.0 | +| v5p-64 | 2048 | 1.01 | ## Create TPU @@ -43,8 +44,9 @@ Install PyTorch and PyTorch/XLA nightly versions: gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} --zone=${ZONE} --worker=all \ --command=' -pip3 install --pre torch==2.5.0.dev20240905+cpu torchvision==0.20.0.dev20240905+cpu --index-url https://download.pytorch.org/whl/nightly/cpu -pip3 install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.5.0.dev20240905-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html +pip3 install --pre torch==2.6.0.dev20241031+cpu torchvision --index-url https://download.pytorch.org/whl/nightly/cpu +pip3 install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241031.cxx11-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html +pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html ' ``` @@ -88,7 +90,7 @@ are fixed. gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} --zone=${ZONE} --worker=all \ --command=' -export XLA_DISABLE_FUNCTIONALIZATION=1 +export XLA_DISABLE_FUNCTIONALIZATION=0 export PROFILE_DIR=/tmp/ export CACHE_DIR=/tmp/ export DATASET_NAME=lambdalabs/naruto-blip-captions diff --git a/examples/research_projects/pytorch_xla/train_text_to_image_xla.py b/examples/research_projects/pytorch_xla/train_text_to_image_xla.py index 5d9d8c540f11..751624b9c238 100644 --- a/examples/research_projects/pytorch_xla/train_text_to_image_xla.py +++ b/examples/research_projects/pytorch_xla/train_text_to_image_xla.py @@ -1,6 +1,7 @@ import argparse import os import random + import time from pathlib import Path @@ -28,12 +29,11 @@ from diffusers.utils import is_wandb_available from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card - if is_wandb_available(): pass -PROFILE_DIR = os.environ.get("PROFILE_DIR", None) -CACHE_DIR = os.environ.get("CACHE_DIR", None) +PROFILE_DIR=os.environ.get('PROFILE_DIR', None) +CACHE_DIR = os.environ.get('CACHE_DIR', None) if CACHE_DIR: xr.initialize_cache(CACHE_DIR, readonly=False) xr.use_spmd() @@ -140,39 +140,37 @@ def run_optimizer(self): self.optimizer.step() def start_training(self): - times = [] - last_time = time.time() - step = 0 - while True: - if self.global_step >= self.args.max_train_steps: - xm.mark_step() - break - if step == 4 and PROFILE_DIR is not None: - xm.wait_device_ops() - xp.trace_detached(f"localhost:{PORT}", PROFILE_DIR, duration_ms=args.profile_duration) + dataloader_exception = False + measure_start_step = 10 + assert measure_start_step < self.args.max_train_steps + total_time = 0 + for step in range(0, self.args.max_train_steps): try: batch = next(self.dataloader) except Exception as e: + dataloader_exception = True print(e) break + if step == measure_start_step and PROFILE_DIR is not None: + xm.wait_device_ops() + xp.trace_detached('localhost:9012', PROFILE_DIR, duration_ms=args.profile_duration) + last_time = time.time() loss = self.step_fn(batch["pixel_values"], batch["input_ids"]) - step_time = time.time() - last_time - if step >= 10: - times.append(step_time) - print(f"step: {step}, step_time: {step_time}") - if step % 5 == 0: - print(f"step: {step}, loss: {loss}") - last_time = time.time() self.global_step += 1 - step += 1 - # print(f"Average step time: {sum(times)/len(times)}") - xm.wait_device_ops() + xm.mark_step() + if not dataloader_exception: + xm.wait_device_ops() + total_time = time.time() - last_time + print(f"Average step time: {total_time/(self.args.max_train_steps-measure_start_step)}") + else: + print("dataloader exception happen, skip result") + return def step_fn( self, pixel_values, input_ids, - ): + ): with xp.Trace("model.forward"): self.optimizer.zero_grad() latents = self.vae.encode(pixel_values).latent_dist.sample() @@ -180,7 +178,10 @@ def step_fn( noise = torch.randn_like(latents).to(self.device, dtype=self.weight_dtype) bsz = latents.shape[0] timesteps = torch.randint( - 0, self.noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device + 0, + self.noise_scheduler.config.num_train_timesteps, + (bsz,), + device=latents.device, ) timesteps = timesteps.long() @@ -195,8 +196,12 @@ def step_fn( elif self.noise_scheduler.config.prediction_type == "v_prediction": target = self.noise_scheduler.get_velocity(latents, noise, timesteps) else: - raise ValueError(f"Unknown prediction type {self.noise_scheduler.config.prediction_type}") - model_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0] + raise ValueError( + f"Unknown prediction type {self.noise_scheduler.config.prediction_type}" + ) + model_pred = self.unet( + noisy_latents, timesteps, encoder_hidden_states, return_dict=False + )[0] with xp.Trace("model.backward"): if self.args.snr_gamma is None: loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") @@ -205,9 +210,9 @@ def step_fn( # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(self.noise_scheduler, timesteps) - mse_loss_weights = torch.stack([snr, self.args.snr_gamma * torch.ones_like(timesteps)], dim=1).min( - dim=1 - )[0] + mse_loss_weights = torch.stack( + [snr, self.args.snr_gamma * torch.ones_like(timesteps)], dim=1 + ).min(dim=1)[0] if self.noise_scheduler.config.prediction_type == "epsilon": mse_loss_weights = mse_loss_weights / snr elif self.noise_scheduler.config.prediction_type == "v_prediction": @@ -221,13 +226,11 @@ def step_fn( self.run_optimizer() return loss - def parse_args(): parser = argparse.ArgumentParser(description="Simple example of a training script.") parser.add_argument( - "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1." + "--profile_duration", type=int, default=10000, help="Profile duration in ms" ) - parser.add_argument("--profile_duration", type=int, default=10000, help="Profile duration in ms") parser.add_argument( "--pretrained_model_name_or_path", type=str, @@ -258,12 +261,6 @@ def parse_args(): " or to a folder containing files that 🤗 Datasets can understand." ), ) - parser.add_argument( - "--dataset_config_name", - type=str, - default=None, - help="The config of the Dataset, leave as None if there's only one config.", - ) parser.add_argument( "--train_data_dir", type=str, @@ -283,15 +280,6 @@ def parse_args(): default="text", help="The column of the dataset containing a caption or a list of captions.", ) - parser.add_argument( - "--max_train_samples", - type=int, - default=None, - help=( - "For debugging purposes or quicker training, truncate the number of training examples to this " - "value if set." - ), - ) parser.add_argument( "--output_dir", type=str, @@ -304,7 +292,6 @@ def parse_args(): default=None, help="The directory where the downloaded models and datasets will be stored.", ) - parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") parser.add_argument( "--resolution", type=int, @@ -372,13 +359,17 @@ def parse_args(): "--loader_prefetch_size", type=int, default=1, - help=("Number of subprocesses to use for data loading to cpu."), + help=( + "Number of subprocesses to use for data loading to cpu." + ), ) parser.add_argument( "--device_prefetch_size", type=int, default=1, - help=("Number of subprocesses to use for data loading to tpu from cpu. "), + help=( + "Number of subprocesses to use for data loading to tpu from cpu. " + ), ) parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") @@ -394,12 +385,11 @@ def parse_args(): "--mixed_precision", type=str, default=None, - choices=["no", "fp16", "bf16"], + choices=["no", "bf16"], help=( - "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" - " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" - " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + "Whether to use mixed precision. Bf16 requires PyTorch >= 1.10" ), + ) parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") @@ -418,7 +408,6 @@ def parse_args(): return args - def setup_optimizer(unet, args): optimizer_cls = torch.optim.AdamW return optimizer_cls( @@ -430,13 +419,11 @@ def setup_optimizer(unet, args): foreach=True, ) - def load_dataset(args): if args.dataset_name is not None: # Downloading and loading a dataset from the hub. dataset = datasets.load_dataset( args.dataset_name, - args.dataset_config_name, cache_dir=args.cache_dir, data_dir=args.train_data_dir, ) @@ -451,7 +438,6 @@ def load_dataset(args): ) return dataset - def get_column_names(dataset, args): column_names = dataset["train"].column_names @@ -476,14 +462,13 @@ def get_column_names(dataset, args): def main(args): + args = parse_args() - _ = xp.start_server(PORT) + server = xp.start_server(9012) num_devices = xr.global_runtime_device_count() - device_ids = np.arange(num_devices) - mesh_shape = (num_devices, 1) - mesh = xs.Mesh(device_ids, mesh_shape, ("x", "y")) + mesh = xs.get_1d_mesh('data') xs.set_global_mesh(mesh) text_encoder = CLIPTextModel.from_pretrained( @@ -518,7 +503,6 @@ def main(args): ) from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear - unet = apply_xla_patch_to_nn_linear(unet, xs.xla_patched_nn_linear_forward) vae.requires_grad_(False) @@ -530,15 +514,12 @@ def main(args): # as these weights are only used for inference, keeping weights in full # precision is not required. weight_dtype = torch.float32 - if args.mixed_precision == "fp16": - weight_dtype = torch.float16 - elif args.mixed_precision == "bf16": + if args.mixed_precision == "bf16": weight_dtype = torch.bfloat16 device = xm.xla_device() - print("device: ", device) - print("weight_dtype: ", weight_dtype) + # Move text_encode and vae to device and cast to weight_dtype text_encoder = text_encoder.to(device, dtype=weight_dtype) vae = vae.to(device, dtype=weight_dtype) unet = unet.to(device, dtype=weight_dtype) @@ -573,9 +554,19 @@ def tokenize_captions(examples, is_train=True): train_transforms = transforms.Compose( [ - transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), - (transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)), - (transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x)), + transforms.Resize( + args.resolution, interpolation=transforms.InterpolationMode.BILINEAR + ), + ( + transforms.CenterCrop(args.resolution) + if args.center_crop + else transforms.RandomCrop(args.resolution) + ), + ( + transforms.RandomHorizontalFlip() + if args.random_flip + else transforms.Lambda(lambda x: x) + ), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ] @@ -588,18 +579,22 @@ def preprocess_train(examples): return examples train_dataset = dataset["train"] - train_dataset.set_format("torch") + train_dataset.set_format('torch') train_dataset.set_transform(preprocess_train) def collate_fn(examples): pixel_values = torch.stack([example["pixel_values"] for example in examples]) - pixel_values = pixel_values.to(memory_format=torch.contiguous_format).to(weight_dtype) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).to( + weight_dtype + ) input_ids = torch.stack([example["input_ids"] for example in examples]) return {"pixel_values": pixel_values, "input_ids": input_ids} g = torch.Generator() g.manual_seed(xr.host_index()) - sampler = torch.utils.data.RandomSampler(train_dataset, replacement=True, num_samples=int(1e10), generator=g) + sampler = torch.utils.data.RandomSampler( + train_dataset, replacement=True, num_samples=int(1e10), generator=g + ) train_dataloader = torch.utils.data.DataLoader( train_dataset, sampler=sampler, @@ -612,32 +607,34 @@ def collate_fn(examples): train_dataloader, device, input_sharding={ - "pixel_values": xs.ShardingSpec(mesh, ("x", None, None, None), minibatch=True), - "input_ids": xs.ShardingSpec(mesh, ("x", None), minibatch=True), + "pixel_values": xs.ShardingSpec( + mesh, ("data", None, None, None), minibatch=True + ), + "input_ids": xs.ShardingSpec(mesh, ("data", None), minibatch=True), }, loader_prefetch_size=args.loader_prefetch_size, device_prefetch_size=args.device_prefetch_size, ) + num_hosts = xr.process_count() + num_devices_per_host = num_devices // num_hosts if xm.is_master_ordinal(): print("***** Running training *****") - print(f"Instantaneous batch size per device = {args.train_batch_size}") + print(f"Instantaneous batch size per device = {args.train_batch_size // num_devices_per_host }") print( - f"Total train batch size (w. parallel, distributed & accumulation) = {args.train_batch_size * num_devices}" + f"Total train batch size (w. parallel, distributed & accumulation) = {args.train_batch_size * num_hosts}" ) print(f" Total optimization steps = {args.max_train_steps}") - trainer = TrainSD( - vae=vae, - weight_dtype=weight_dtype, - device=device, - noise_scheduler=noise_scheduler, - unet=unet, - optimizer=optimizer, - text_encoder=text_encoder, - dataloader=train_dataloader, - args=args, - ) + trainer = TrainSD(vae=vae, + weight_dtype=weight_dtype, + device=device, + noise_scheduler=noise_scheduler, + unet=unet, + optimizer=optimizer, + text_encoder=text_encoder, + dataloader=train_dataloader, + args=args) trainer.start_training() unet = trainer.unet.to("cpu") @@ -666,4 +663,4 @@ def collate_fn(examples): if __name__ == "__main__": args = parse_args() - main(args) + main(args) \ No newline at end of file diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index db88ecbbb9d3..6e5854087f2a 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -20,7 +20,7 @@ from torch import nn from ..image_processor import IPAdapterMaskProcessor -from ..utils import deprecate, logging +from ..utils import deprecate, logging, is_torch_xla_available from ..utils.import_utils import is_torch_npu_available, is_xformers_available from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph @@ -36,6 +36,11 @@ else: xformers = None +if is_torch_xla_available(): + from torch_xla.experimental.custom_kernel import flash_attention + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False @maybe_allow_in_graph class Attention(nn.Module): @@ -2474,9 +2479,20 @@ def __call__( # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) + if XLA_AVAILABLE and all(tensor.shape[2] >= 4096 for tensor in [query, key, value]): + if attention_mask is not None: + attention_mask = attention_mask.view(batch_size, 1, 1, attention_mask.shape[-1]) + # Convert mask to float and replace 0s with -inf and 1s with 0 + attention_mask = attention_mask.float().masked_fill(attention_mask == 0, float('-inf')).masked_fill(attention_mask == 1, float(0.0)) + + # Apply attention mask to key + key = key + attention_mask + query /= math.sqrt(query.shape[3]) + hidden_states = flash_attention(query, key, value, causal=False, partition_spec=('data', None, None, None)) + else: + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) From 96af06e84355ac00fa8e051d75916c4d969eea2f Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Fri, 1 Nov 2024 23:36:58 +0000 Subject: [PATCH 02/13] update ptxla example based on Pei's comments. --- examples/research_projects/pytorch_xla/README.md | 2 +- .../pytorch_xla/train_text_to_image_xla.py | 13 +++++++++++-- src/diffusers/models/attention_processor.py | 4 ++-- 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/examples/research_projects/pytorch_xla/README.md b/examples/research_projects/pytorch_xla/README.md index 2a00fb6598cd..c8bbe05e9da4 100644 --- a/examples/research_projects/pytorch_xla/README.md +++ b/examples/research_projects/pytorch_xla/README.md @@ -97,7 +97,7 @@ export DATASET_NAME=lambdalabs/naruto-blip-captions export PER_HOST_BATCH_SIZE=32 # This is known to work on TPU v4. Can set this to 64 for TPU v5p export TRAIN_STEPS=50 export OUTPUT_DIR=/tmp/trained-model/ -python diffusers/examples/research_projects/pytorch_xla/train_text_to_image_xla.py --pretrained_model_name_or_path=stabilityai/stable-diffusion-2-base --dataset_name=$DATASET_NAME --resolution=512 --center_crop --random_flip --train_batch_size=$PER_HOST_BATCH_SIZE --max_train_steps=$TRAIN_STEPS --learning_rate=1e-06 --mixed_precision=bf16 --profile_duration=80000 --output_dir=$OUTPUT_DIR --dataloader_num_workers=4 --loader_prefetch_size=4 --device_prefetch_size=4' +python diffusers/examples/research_projects/pytorch_xla/train_text_to_image_xla.py --pretrained_model_name_or_path=stabilityai/stable-diffusion-2-base --dataset_name=$DATASET_NAME --resolution=512 --center_crop --random_flip --train_batch_size=$PER_HOST_BATCH_SIZE --max_train_steps=$TRAIN_STEPS --learning_rate=1e-06 --mixed_precision=bf16 --profile_duration=80000 --output_dir=$OUTPUT_DIR --dataloader_num_workers=8 --loader_prefetch_size=4 --device_prefetch_size=4' ``` diff --git a/examples/research_projects/pytorch_xla/train_text_to_image_xla.py b/examples/research_projects/pytorch_xla/train_text_to_image_xla.py index 751624b9c238..c4c174814f95 100644 --- a/examples/research_projects/pytorch_xla/train_text_to_image_xla.py +++ b/examples/research_projects/pytorch_xla/train_text_to_image_xla.py @@ -32,7 +32,7 @@ if is_wandb_available(): pass -PROFILE_DIR=os.environ.get('PROFILE_DIR', None) +PROFILE_DIR = os.environ.get('PROFILE_DIR', None) CACHE_DIR = os.environ.get('CACHE_DIR', None) if CACHE_DIR: xr.initialize_cache(CACHE_DIR, readonly=False) @@ -363,6 +363,14 @@ def parse_args(): "Number of subprocesses to use for data loading to cpu." ), ) + parser.add_argument( + "--loader_prefetch_factor", + type=int, + default=2, + help=( + "Number of batches loaded in advance by each worker." + ), + ) parser.add_argument( "--device_prefetch_size", type=int, @@ -579,7 +587,7 @@ def preprocess_train(examples): return examples train_dataset = dataset["train"] - train_dataset.set_format('torch') + train_dataset.set_format("torch") train_dataset.set_transform(preprocess_train) def collate_fn(examples): @@ -601,6 +609,7 @@ def collate_fn(examples): collate_fn=collate_fn, num_workers=args.dataloader_num_workers, batch_size=args.train_batch_size, + prefetch_factor=args.loader_prefetch_factor, ) train_dataloader = pl.MpDeviceLoader( diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 6e5854087f2a..a4c3c870f277 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -20,7 +20,7 @@ from torch import nn from ..image_processor import IPAdapterMaskProcessor -from ..utils import deprecate, logging, is_torch_xla_available +from ..utils import deprecate, is_torch_xla_available, logging from ..utils.import_utils import is_torch_npu_available, is_xformers_available from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph @@ -2484,7 +2484,7 @@ def __call__( attention_mask = attention_mask.view(batch_size, 1, 1, attention_mask.shape[-1]) # Convert mask to float and replace 0s with -inf and 1s with 0 attention_mask = attention_mask.float().masked_fill(attention_mask == 0, float('-inf')).masked_fill(attention_mask == 1, float(0.0)) - + # Apply attention mask to key key = key + attention_mask query /= math.sqrt(query.shape[3]) From 6234a37b6b553c4e0c3b99f38ed2a4ae47c5da56 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Mon, 4 Nov 2024 20:03:29 +0000 Subject: [PATCH 03/13] add print loss cli argument. Run make style and quality. --- .../research_projects/pytorch_xla/README.md | 3 +- .../pytorch_xla/train_text_to_image_xla.py | 127 +++++++++--------- src/diffusers/models/attention_processor.py | 10 +- 3 files changed, 70 insertions(+), 70 deletions(-) diff --git a/examples/research_projects/pytorch_xla/README.md b/examples/research_projects/pytorch_xla/README.md index c8bbe05e9da4..06013b8a61e0 100644 --- a/examples/research_projects/pytorch_xla/README.md +++ b/examples/research_projects/pytorch_xla/README.md @@ -98,9 +98,10 @@ export PER_HOST_BATCH_SIZE=32 # This is known to work on TPU v4. Can set this to export TRAIN_STEPS=50 export OUTPUT_DIR=/tmp/trained-model/ python diffusers/examples/research_projects/pytorch_xla/train_text_to_image_xla.py --pretrained_model_name_or_path=stabilityai/stable-diffusion-2-base --dataset_name=$DATASET_NAME --resolution=512 --center_crop --random_flip --train_batch_size=$PER_HOST_BATCH_SIZE --max_train_steps=$TRAIN_STEPS --learning_rate=1e-06 --mixed_precision=bf16 --profile_duration=80000 --output_dir=$OUTPUT_DIR --dataloader_num_workers=8 --loader_prefetch_size=4 --device_prefetch_size=4' - ``` +Pass `--print_loss` if you would like to see the loss printed at every step. Be aware that printing the loss at every step disrupts the optimized flow execution, thus the step time will be longer. + ### Environment Envs Explained * `XLA_DISABLE_FUNCTIONALIZATION`: To optimize the performance for AdamW optimizer. diff --git a/examples/research_projects/pytorch_xla/train_text_to_image_xla.py b/examples/research_projects/pytorch_xla/train_text_to_image_xla.py index c4c174814f95..78449ef5b26a 100644 --- a/examples/research_projects/pytorch_xla/train_text_to_image_xla.py +++ b/examples/research_projects/pytorch_xla/train_text_to_image_xla.py @@ -1,7 +1,6 @@ import argparse import os import random - import time from pathlib import Path @@ -29,11 +28,12 @@ from diffusers.utils import is_wandb_available from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card + if is_wandb_available(): pass -PROFILE_DIR = os.environ.get('PROFILE_DIR', None) -CACHE_DIR = os.environ.get('CACHE_DIR', None) +PROFILE_DIR = os.environ.get("PROFILE_DIR", None) +CACHE_DIR = os.environ.get("CACHE_DIR", None) if CACHE_DIR: xr.initialize_cache(CACHE_DIR, readonly=False) xr.use_spmd() @@ -151,12 +151,24 @@ def start_training(self): dataloader_exception = True print(e) break - if step == measure_start_step and PROFILE_DIR is not None: + if step == measure_start_step and PROFILE_DIR is not None: xm.wait_device_ops() - xp.trace_detached('localhost:9012', PROFILE_DIR, duration_ms=args.profile_duration) - last_time = time.time() + xp.trace_detached("localhost:9012", PROFILE_DIR, duration_ms=args.profile_duration) + last_time = time.time() loss = self.step_fn(batch["pixel_values"], batch["input_ids"]) self.global_step += 1 + + def print_loss_closure(step, loss): + print(f"Step: {step}, Loss: {loss}") + + if args.print_loss: + xm.add_step_closure( + print_loss_closure, + args=( + self.global_step, + loss, + ), + ) xm.mark_step() if not dataloader_exception: xm.wait_device_ops() @@ -170,7 +182,7 @@ def step_fn( self, pixel_values, input_ids, - ): + ): with xp.Trace("model.forward"): self.optimizer.zero_grad() latents = self.vae.encode(pixel_values).latent_dist.sample() @@ -196,12 +208,8 @@ def step_fn( elif self.noise_scheduler.config.prediction_type == "v_prediction": target = self.noise_scheduler.get_velocity(latents, noise, timesteps) else: - raise ValueError( - f"Unknown prediction type {self.noise_scheduler.config.prediction_type}" - ) - model_pred = self.unet( - noisy_latents, timesteps, encoder_hidden_states, return_dict=False - )[0] + raise ValueError(f"Unknown prediction type {self.noise_scheduler.config.prediction_type}") + model_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0] with xp.Trace("model.backward"): if self.args.snr_gamma is None: loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") @@ -210,9 +218,9 @@ def step_fn( # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(self.noise_scheduler, timesteps) - mse_loss_weights = torch.stack( - [snr, self.args.snr_gamma * torch.ones_like(timesteps)], dim=1 - ).min(dim=1)[0] + mse_loss_weights = torch.stack([snr, self.args.snr_gamma * torch.ones_like(timesteps)], dim=1).min( + dim=1 + )[0] if self.noise_scheduler.config.prediction_type == "epsilon": mse_loss_weights = mse_loss_weights / snr elif self.noise_scheduler.config.prediction_type == "v_prediction": @@ -226,11 +234,10 @@ def step_fn( self.run_optimizer() return loss + def parse_args(): parser = argparse.ArgumentParser(description="Simple example of a training script.") - parser.add_argument( - "--profile_duration", type=int, default=10000, help="Profile duration in ms" - ) + parser.add_argument("--profile_duration", type=int, default=10000, help="Profile duration in ms") parser.add_argument( "--pretrained_model_name_or_path", type=str, @@ -359,25 +366,19 @@ def parse_args(): "--loader_prefetch_size", type=int, default=1, - help=( - "Number of subprocesses to use for data loading to cpu." - ), + help=("Number of subprocesses to use for data loading to cpu."), ) parser.add_argument( "--loader_prefetch_factor", type=int, default=2, - help=( - "Number of batches loaded in advance by each worker." - ), + help=("Number of batches loaded in advance by each worker."), ) parser.add_argument( "--device_prefetch_size", type=int, default=1, - help=( - "Number of subprocesses to use for data loading to tpu from cpu. " - ), + help=("Number of subprocesses to use for data loading to tpu from cpu. "), ) parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") @@ -394,10 +395,7 @@ def parse_args(): type=str, default=None, choices=["no", "bf16"], - help=( - "Whether to use mixed precision. Bf16 requires PyTorch >= 1.10" - ), - + help=("Whether to use mixed precision. Bf16 requires PyTorch >= 1.10"), ) parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") @@ -407,6 +405,12 @@ def parse_args(): default=None, help="The name of the repository to keep in sync with the local `output_dir`.", ) + parser.add_argument( + "--print_loss", + default=False, + action="store_true", + help=("Print loss at every step."), + ) args = parser.parse_args() @@ -416,6 +420,7 @@ def parse_args(): return args + def setup_optimizer(unet, args): optimizer_cls = torch.optim.AdamW return optimizer_cls( @@ -427,6 +432,7 @@ def setup_optimizer(unet, args): foreach=True, ) + def load_dataset(args): if args.dataset_name is not None: # Downloading and loading a dataset from the hub. @@ -446,6 +452,7 @@ def load_dataset(args): ) return dataset + def get_column_names(dataset, args): column_names = dataset["train"].column_names @@ -470,13 +477,12 @@ def get_column_names(dataset, args): def main(args): - args = parse_args() - server = xp.start_server(9012) + _ = xp.start_server(9012) num_devices = xr.global_runtime_device_count() - mesh = xs.get_1d_mesh('data') + mesh = xs.get_1d_mesh("data") xs.set_global_mesh(mesh) text_encoder = CLIPTextModel.from_pretrained( @@ -511,6 +517,7 @@ def main(args): ) from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear + unet = apply_xla_patch_to_nn_linear(unet, xs.xla_patched_nn_linear_forward) vae.requires_grad_(False) @@ -562,19 +569,9 @@ def tokenize_captions(examples, is_train=True): train_transforms = transforms.Compose( [ - transforms.Resize( - args.resolution, interpolation=transforms.InterpolationMode.BILINEAR - ), - ( - transforms.CenterCrop(args.resolution) - if args.center_crop - else transforms.RandomCrop(args.resolution) - ), - ( - transforms.RandomHorizontalFlip() - if args.random_flip - else transforms.Lambda(lambda x: x) - ), + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + (transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)), + (transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x)), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ] @@ -592,17 +589,13 @@ def preprocess_train(examples): def collate_fn(examples): pixel_values = torch.stack([example["pixel_values"] for example in examples]) - pixel_values = pixel_values.to(memory_format=torch.contiguous_format).to( - weight_dtype - ) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).to(weight_dtype) input_ids = torch.stack([example["input_ids"] for example in examples]) return {"pixel_values": pixel_values, "input_ids": input_ids} g = torch.Generator() g.manual_seed(xr.host_index()) - sampler = torch.utils.data.RandomSampler( - train_dataset, replacement=True, num_samples=int(1e10), generator=g - ) + sampler = torch.utils.data.RandomSampler(train_dataset, replacement=True, num_samples=int(1e10), generator=g) train_dataloader = torch.utils.data.DataLoader( train_dataset, sampler=sampler, @@ -616,9 +609,7 @@ def collate_fn(examples): train_dataloader, device, input_sharding={ - "pixel_values": xs.ShardingSpec( - mesh, ("data", None, None, None), minibatch=True - ), + "pixel_values": xs.ShardingSpec(mesh, ("data", None, None, None), minibatch=True), "input_ids": xs.ShardingSpec(mesh, ("data", None), minibatch=True), }, loader_prefetch_size=args.loader_prefetch_size, @@ -635,15 +626,17 @@ def collate_fn(examples): ) print(f" Total optimization steps = {args.max_train_steps}") - trainer = TrainSD(vae=vae, - weight_dtype=weight_dtype, - device=device, - noise_scheduler=noise_scheduler, - unet=unet, - optimizer=optimizer, - text_encoder=text_encoder, - dataloader=train_dataloader, - args=args) + trainer = TrainSD( + vae=vae, + weight_dtype=weight_dtype, + device=device, + noise_scheduler=noise_scheduler, + unet=unet, + optimizer=optimizer, + text_encoder=text_encoder, + dataloader=train_dataloader, + args=args, + ) trainer.start_training() unet = trainer.unet.to("cpu") @@ -672,4 +665,4 @@ def collate_fn(examples): if __name__ == "__main__": args = parse_args() - main(args) \ No newline at end of file + main(args) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index a4c3c870f277..b5247368001c 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -38,10 +38,12 @@ if is_torch_xla_available(): from torch_xla.experimental.custom_kernel import flash_attention + XLA_AVAILABLE = True else: XLA_AVAILABLE = False + @maybe_allow_in_graph class Attention(nn.Module): r""" @@ -2483,12 +2485,16 @@ def __call__( if attention_mask is not None: attention_mask = attention_mask.view(batch_size, 1, 1, attention_mask.shape[-1]) # Convert mask to float and replace 0s with -inf and 1s with 0 - attention_mask = attention_mask.float().masked_fill(attention_mask == 0, float('-inf')).masked_fill(attention_mask == 1, float(0.0)) + attention_mask = ( + attention_mask.float() + .masked_fill(attention_mask == 0, float("-inf")) + .masked_fill(attention_mask == 1, float(0.0)) + ) # Apply attention mask to key key = key + attention_mask query /= math.sqrt(query.shape[3]) - hidden_states = flash_attention(query, key, value, causal=False, partition_spec=('data', None, None, None)) + hidden_states = flash_attention(query, key, value, causal=False, partition_spec=("data", None, None, None)) else: hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False From b51613439974f30768b993f651c6f821b8fa5eeb Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Mon, 4 Nov 2024 21:47:54 +0000 Subject: [PATCH 04/13] make measure_start_step an argument. --- .../research_projects/pytorch_xla/train_text_to_image_xla.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/research_projects/pytorch_xla/train_text_to_image_xla.py b/examples/research_projects/pytorch_xla/train_text_to_image_xla.py index 78449ef5b26a..3bee979d62bf 100644 --- a/examples/research_projects/pytorch_xla/train_text_to_image_xla.py +++ b/examples/research_projects/pytorch_xla/train_text_to_image_xla.py @@ -141,7 +141,7 @@ def run_optimizer(self): def start_training(self): dataloader_exception = False - measure_start_step = 10 + measure_start_step = args.measure_start_step assert measure_start_step < self.args.max_train_steps total_time = 0 for step in range(0, self.args.max_train_steps): @@ -380,6 +380,7 @@ def parse_args(): default=1, help=("Number of subprocesses to use for data loading to tpu from cpu. "), ) + parser.add_argument("--measure_start_step", type=int, default=10, help="Step to start profiling.") parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") From 8c47f35d8e18b1c0244392f8cb7ce87d68ed3d19 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Mon, 4 Nov 2024 21:57:45 +0000 Subject: [PATCH 05/13] use PORT variable across the script. --- .../research_projects/pytorch_xla/train_text_to_image_xla.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/research_projects/pytorch_xla/train_text_to_image_xla.py b/examples/research_projects/pytorch_xla/train_text_to_image_xla.py index 3bee979d62bf..068bc8bdd570 100644 --- a/examples/research_projects/pytorch_xla/train_text_to_image_xla.py +++ b/examples/research_projects/pytorch_xla/train_text_to_image_xla.py @@ -153,7 +153,7 @@ def start_training(self): break if step == measure_start_step and PROFILE_DIR is not None: xm.wait_device_ops() - xp.trace_detached("localhost:9012", PROFILE_DIR, duration_ms=args.profile_duration) + xp.trace_detached(f"localhost:{PORT}", PROFILE_DIR, duration_ms=args.profile_duration) last_time = time.time() loss = self.step_fn(batch["pixel_values"], batch["input_ids"]) self.global_step += 1 @@ -480,7 +480,7 @@ def get_column_names(dataset, args): def main(args): args = parse_args() - _ = xp.start_server(9012) + _ = xp.start_server(PORT) num_devices = xr.global_runtime_device_count() mesh = xs.get_1d_mesh("data") From 8b3cbb165c98354b2e21b0bc2a30d9436337f87c Mon Sep 17 00:00:00 2001 From: Pei Zhang Date: Thu, 28 Nov 2024 15:15:10 -0800 Subject: [PATCH 06/13] split out xla flash attention from base AttnProcessor --- src/diffusers/models/attention_processor.py | 108 +++++++++++++++++++- 1 file changed, 105 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index a66a85b2200a..b4d489a77a0f 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -276,10 +276,16 @@ def __init__( # We use the AttnProcessor2_0 by default when torch 2.x is used which uses # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + # If torch_xla is available, we use pallas flash attention kernel to improve the performance. if processor is None: - processor = ( - AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() - ) + if hasattr(F, "scaled_dot_product_attention") and self.scale_qk: + if is_torch_xla_available: + processor = XLAFlashAttnProcessor2_0() + else: + processor = AttnProcessor2_0() + else: + processor = AttnProcessor() + self.set_processor(processor) def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None: @@ -2644,6 +2650,102 @@ def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class XLAFlashAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using torch_xla). + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + if not is_torch_xla_available: + raise ImportError("XLAFlashAttnProcessor2_0 required torch_xla package.") + def __call__( self, attn: Attention, From 2c00cbd01dc4e4b1cf73cb1f7203ce6c388fd5a3 Mon Sep 17 00:00:00 2001 From: Pei Zhang Date: Fri, 29 Nov 2024 00:23:08 -0800 Subject: [PATCH 07/13] use version check for torch_xla --- src/diffusers/models/attention_processor.py | 34 ++++++++++++--------- src/diffusers/utils/__init__.py | 1 + src/diffusers/utils/import_utils.py | 13 ++++++++ 3 files changed, 34 insertions(+), 14 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 98cac4df574b..31db8dea7147 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -21,7 +21,7 @@ from ..image_processor import IPAdapterMaskProcessor from ..utils import deprecate, is_torch_xla_available, logging -from ..utils.import_utils import is_torch_npu_available, is_xformers_available +from ..utils.import_utils import is_torch_npu_available, is_xformers_available, is_torch_xla_version from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph @@ -37,8 +37,10 @@ xformers = None if is_torch_xla_available(): - from torch_xla.experimental.custom_kernel import flash_attention - + # flash attention pallas kernel is introduced in the torch_xla 2.3 release. + if is_torch_xla_version(">", "2.2"): + from torch_xla.runtime import is_spmd + from torch_xla.experimental.custom_kernel import flash_attention XLA_AVAILABLE = True else: XLA_AVAILABLE = False @@ -276,16 +278,21 @@ def __init__( # We use the AttnProcessor2_0 by default when torch 2.x is used which uses # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 - # If torch_xla is available, we use pallas flash attention kernel to improve the performance. + # If torch_xla is available with the correct version, we use pallas flash attention kernel to improve + # the performance. if processor is None: if hasattr(F, "scaled_dot_product_attention") and self.scale_qk: - if is_torch_xla_available: + if ( + is_torch_xla_available + and is_torch_xla_version('>', '2.2') + and (not is_spmd() or is_torch_xla_version('>', '2.3')) + ): processor = XLAFlashAttnProcessor2_0() else: processor = AttnProcessor2_0() else: processor = AttnProcessor() - + self.set_processor(processor) def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None: @@ -2771,8 +2778,10 @@ class XLAFlashAttnProcessor2_0: def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - if not is_torch_xla_available: - raise ImportError("XLAFlashAttnProcessor2_0 required torch_xla package.") + if is_torch_xla_version("<", "2.3"): + raise ImportError("XLA flash attention requires torch_xla version >= 2.3.") + if is_spmd() and is_torch_xla_version("<", "2.4"): + raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.") def __call__( self, @@ -2784,10 +2793,6 @@ def __call__( *args, **kwargs, ) -> torch.Tensor: - if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." - deprecate("scale", "1.0.0", deprecation_message) - residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) @@ -2836,7 +2841,7 @@ def __call__( # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 - if XLA_AVAILABLE and all(tensor.shape[2] >= 4096 for tensor in [query, key, value]): + if all(tensor.shape[2] >= 4096 for tensor in [query, key, value]): if attention_mask is not None: attention_mask = attention_mask.view(batch_size, 1, 1, attention_mask.shape[-1]) # Convert mask to float and replace 0s with -inf and 1s with 0 @@ -2849,7 +2854,8 @@ def __call__( # Apply attention mask to key key = key + attention_mask query /= math.sqrt(query.shape[3]) - hidden_states = flash_attention(query, key, value, causal=False, partition_spec=("data", None, None, None)) + partition_spec = ("data", None, None, None) if is_spmd() else None + hidden_states = flash_attention(query, key, value, causal=False, partition_spec=partition_spec) else: hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index c8f64adf3e8a..f91cee8113f2 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -86,6 +86,7 @@ is_torch_npu_available, is_torch_version, is_torch_xla_available, + is_torch_xla_version, is_torchsde_available, is_torchvision_available, is_transformers_available, diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index f1323bf00ea4..3d90342b8b30 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -700,6 +700,19 @@ def is_torch_version(operation: str, version: str): return compare_versions(parse(_torch_version), operation, version) +def is_torch_xla_version(operation: str, version: str): + """ + Compares the current torch_xla version to a given reference with an operation. + + Args: + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A string version of torch_xla + """ + return compare_versions(parse(_torch_xla_version), operation, version) + + def is_transformers_version(operation: str, version: str): """ Compares the current Transformers version to a given reference with an operation. From fb29e3724d002d4b07a6b141896bc4f6137e9120 Mon Sep 17 00:00:00 2001 From: Pei Zhang Date: Tue, 3 Dec 2024 14:27:25 -0800 Subject: [PATCH 08/13] setup the option to use xla flash attention or not --- .../pytorch_xla/train_text_to_image_xla.py | 1 + src/diffusers/models/attention_processor.py | 46 ++++++++++++------- src/diffusers/models/modeling_utils.py | 29 ++++++++++++ 3 files changed, 60 insertions(+), 16 deletions(-) diff --git a/examples/research_projects/pytorch_xla/train_text_to_image_xla.py b/examples/research_projects/pytorch_xla/train_text_to_image_xla.py index 068bc8bdd570..4eaaccd3f91d 100644 --- a/examples/research_projects/pytorch_xla/train_text_to_image_xla.py +++ b/examples/research_projects/pytorch_xla/train_text_to_image_xla.py @@ -520,6 +520,7 @@ def main(args): from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear unet = apply_xla_patch_to_nn_linear(unet, xs.xla_patched_nn_linear_forward) + unet.enable_use_xla_flash_attention(partition_spec=("data", None, None, None)) vae.requires_grad_(False) text_encoder.requires_grad_(False) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 31db8dea7147..2af9b5989513 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -278,21 +278,33 @@ def __init__( # We use the AttnProcessor2_0 by default when torch 2.x is used which uses # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 - # If torch_xla is available with the correct version, we use pallas flash attention kernel to improve - # the performance. if processor is None: - if hasattr(F, "scaled_dot_product_attention") and self.scale_qk: - if ( - is_torch_xla_available - and is_torch_xla_version('>', '2.2') - and (not is_spmd() or is_torch_xla_version('>', '2.3')) - ): - processor = XLAFlashAttnProcessor2_0() - else: - processor = AttnProcessor2_0() - else: - processor = AttnProcessor() + processor = ( + AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() + ) + self.set_processor(processor) + + def set_use_xla_flash_attention(self, use_xla_flash_attention: bool, partition_spec: Optional[Tuple[Optional[str], ...]] = None) -> None: + r""" + Set whether to use xla flash attention from `torch_xla` or not. + Args: + use_xla_flash_attention (`bool`): + Whether to use pallas flash attention kernel from `torch_xla` or not. + partition_spec (`Tuple[]`, *optional*): + Specify the partition specification if using SPMD. Otherwise None. + """ + if ( + use_xla_flash_attention + and is_torch_xla_available + and is_torch_xla_version('>', '2.2') + and (not is_spmd() or is_torch_xla_version('>', '2.3')) + ): + processor = XLAFlashAttnProcessor2_0(partition_spec) + else: + processor = ( + AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() + ) self.set_processor(processor) def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None: @@ -2772,16 +2784,17 @@ def __call__( class XLAFlashAttnProcessor2_0: r""" - Processor for implementing scaled dot-product attention (enabled by default if you're using torch_xla). + Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`. """ - def __init__(self): + def __init__(self, partition_spec: Optional[Tuple[Optional[str], ...]] = None): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") if is_torch_xla_version("<", "2.3"): raise ImportError("XLA flash attention requires torch_xla version >= 2.3.") if is_spmd() and is_torch_xla_version("<", "2.4"): raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.") + self.partition_spec=partition_spec def __call__( self, @@ -2854,7 +2867,7 @@ def __call__( # Apply attention mask to key key = key + attention_mask query /= math.sqrt(query.shape[3]) - partition_spec = ("data", None, None, None) if is_spmd() else None + partition_spec = self.partition_spec if is_spmd() else None hidden_states = flash_attention(query, key, value, causal=False, partition_spec=partition_spec) else: hidden_states = F.scaled_dot_product_attention( @@ -5201,6 +5214,7 @@ def __init__(self): FusedCogVideoXAttnProcessor2_0, XFormersAttnAddedKVProcessor, XFormersAttnProcessor, + XLAFlashAttnProcessor2_0, AttnProcessorNPU, AttnProcessor2_0, MochiVaeAttnProcessor2_0, diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 4a486fd4ce40..e864a8493766 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -208,6 +208,35 @@ def disable_npu_flash_attention(self) -> None: """ self.set_use_npu_flash_attention(False) + def set_use_xla_flash_attention( + self, use_xla_flash_attention: bool, partition_spec: Optional[Callable] = None + ) -> None: + # Recursively walk through all the children. + # Any children which exposes the set_use_xla_flash_attention method + # gets the message + def fn_recursive_set_flash_attention(module: torch.nn.Module): + if hasattr(module, "set_use_xla_flash_attention"): + module.set_use_xla_flash_attention(use_xla_flash_attention, partition_spec) + + for child in module.children(): + fn_recursive_set_flash_attention(child) + + for module in self.children(): + if isinstance(module, torch.nn.Module): + fn_recursive_set_flash_attention(module) + + def enable_use_xla_flash_attention(self, partition_spec: Optional[Callable] = None): + r""" + Enable the flash attention pallals kernel for torch_xla. + """ + self.set_use_xla_flash_attention(True, partition_spec) + + def disable_use_xla_flash_attention(self): + r""" + Disable the flash attention pallals kernel for torch_xla. + """ + self.set_use_xla_flash_attention(False) + def set_use_memory_efficient_attention_xformers( self, valid: bool, attention_op: Optional[Callable] = None ) -> None: From df31c9dfa691e229895138cefb4411758e15e021 Mon Sep 17 00:00:00 2001 From: Pei Zhang Date: Tue, 3 Dec 2024 14:40:19 -0800 Subject: [PATCH 09/13] naming nit --- .../research_projects/pytorch_xla/train_text_to_image_xla.py | 2 +- src/diffusers/models/modeling_utils.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/research_projects/pytorch_xla/train_text_to_image_xla.py b/examples/research_projects/pytorch_xla/train_text_to_image_xla.py index 4eaaccd3f91d..9719585d3dfb 100644 --- a/examples/research_projects/pytorch_xla/train_text_to_image_xla.py +++ b/examples/research_projects/pytorch_xla/train_text_to_image_xla.py @@ -520,7 +520,7 @@ def main(args): from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear unet = apply_xla_patch_to_nn_linear(unet, xs.xla_patched_nn_linear_forward) - unet.enable_use_xla_flash_attention(partition_spec=("data", None, None, None)) + unet.enable_xla_flash_attention(partition_spec=("data", None, None, None)) vae.requires_grad_(False) text_encoder.requires_grad_(False) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 872a4458b3bb..b643d7e3bba0 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -225,13 +225,13 @@ def fn_recursive_set_flash_attention(module: torch.nn.Module): if isinstance(module, torch.nn.Module): fn_recursive_set_flash_attention(module) - def enable_use_xla_flash_attention(self, partition_spec: Optional[Callable] = None): + def enable_xla_flash_attention(self, partition_spec: Optional[Callable] = None): r""" Enable the flash attention pallals kernel for torch_xla. """ self.set_use_xla_flash_attention(True, partition_spec) - def disable_use_xla_flash_attention(self): + def disable_xla_flash_attention(self): r""" Disable the flash attention pallals kernel for torch_xla. """ From ff332e686a61c6e20ddca688b0fe62c8d06abfac Mon Sep 17 00:00:00 2001 From: Pei Zhang Date: Tue, 3 Dec 2024 22:29:19 -0800 Subject: [PATCH 10/13] format fix with ruff cmd --- src/diffusers/models/attention_processor.py | 4 ++-- src/diffusers/models/modeling_utils.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 36b0032537cd..dea9d9482712 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -21,7 +21,7 @@ from ..image_processor import IPAdapterMaskProcessor from ..utils import deprecate, is_torch_xla_available, logging -from ..utils.import_utils import is_torch_npu_available, is_xformers_available, is_torch_xla_version +from ..utils.import_utils import is_torch_npu_available, is_torch_xla_version, is_xformers_available from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph @@ -39,8 +39,8 @@ if is_torch_xla_available(): # flash attention pallas kernel is introduced in the torch_xla 2.3 release. if is_torch_xla_version(">", "2.2"): - from torch_xla.runtime import is_spmd from torch_xla.experimental.custom_kernel import flash_attention + from torch_xla.runtime import is_spmd XLA_AVAILABLE = True else: XLA_AVAILABLE = False diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index b643d7e3bba0..60a0e6df0087 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -226,13 +226,13 @@ def fn_recursive_set_flash_attention(module: torch.nn.Module): fn_recursive_set_flash_attention(module) def enable_xla_flash_attention(self, partition_spec: Optional[Callable] = None): - r""" + r""" Enable the flash attention pallals kernel for torch_xla. """ self.set_use_xla_flash_attention(True, partition_spec) def disable_xla_flash_attention(self): - r""" + r""" Disable the flash attention pallals kernel for torch_xla. """ self.set_use_xla_flash_attention(False) From dbe4725f3d2ba4196efe8124b216065b7a7a84b5 Mon Sep 17 00:00:00 2001 From: Pei Zhang Date: Thu, 5 Dec 2024 21:50:46 -0800 Subject: [PATCH 11/13] adding warning message --- src/diffusers/models/attention_processor.py | 17 ++++++++++------- src/diffusers/utils/import_utils.py | 2 ++ 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index dea9d9482712..f8c12affde54 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -294,13 +294,15 @@ def set_use_xla_flash_attention(self, use_xla_flash_attention: bool, partition_s partition_spec (`Tuple[]`, *optional*): Specify the partition specification if using SPMD. Otherwise None. """ - if ( - use_xla_flash_attention - and is_torch_xla_available - and is_torch_xla_version('>', '2.2') - and (not is_spmd() or is_torch_xla_version('>', '2.3')) - ): - processor = XLAFlashAttnProcessor2_0(partition_spec) + if use_xla_flash_attention: + if not is_torch_xla_available: + raise "torch_xla is not available" + elif is_torch_xla_version("<", "2.3"): + raise "flash attention pallas kernel is supported from torch_xla version 2.3" + elif is_spmd() and is_torch_xla_version("<", "2.4"): + raise "flash attention pallas kernel using SPMD is supported from torch_xla version 2.4" + else: + processor = XLAFlashAttnProcessor2_0(partition_spec) else: processor = ( AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() @@ -2871,6 +2873,7 @@ def __call__( partition_spec = self.partition_spec if is_spmd() else None hidden_states = flash_attention(query, key, value, causal=False, partition_spec=partition_spec) else: + logger.warning(f"Unable to use the flash attention pallas kernel API call due to QKV sequence length < 4096.") hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 3d90342b8b30..e3b7655737a8 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -710,6 +710,8 @@ def is_torch_xla_version(operation: str, version: str): version (`str`): A string version of torch_xla """ + if not is_torch_xla_available: + return False return compare_versions(parse(_torch_xla_version), operation, version) From c012fafe91e7a8359705b1eeeae92632be6f7099 Mon Sep 17 00:00:00 2001 From: Pei Zhang Date: Thu, 5 Dec 2024 21:53:04 -0800 Subject: [PATCH 12/13] format fix with ruff cmd --- src/diffusers/models/attention_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index f8c12affde54..69120bb5f09c 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2873,7 +2873,7 @@ def __call__( partition_spec = self.partition_spec if is_spmd() else None hidden_states = flash_attention(query, key, value, causal=False, partition_spec=partition_spec) else: - logger.warning(f"Unable to use the flash attention pallas kernel API call due to QKV sequence length < 4096.") + logger.warning("Unable to use the flash attention pallas kernel API call due to QKV sequence length < 4096.") hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) From 03089f516740b6bad757e68ffb61d6d590b4a337 Mon Sep 17 00:00:00 2001 From: hlky Date: Fri, 6 Dec 2024 14:30:26 +0000 Subject: [PATCH 13/13] make style --- src/diffusers/models/attention_processor.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 7e3ae786b876..444f201f6376 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -284,7 +284,9 @@ def __init__( ) self.set_processor(processor) - def set_use_xla_flash_attention(self, use_xla_flash_attention: bool, partition_spec: Optional[Tuple[Optional[str], ...]] = None) -> None: + def set_use_xla_flash_attention( + self, use_xla_flash_attention: bool, partition_spec: Optional[Tuple[Optional[str], ...]] = None + ) -> None: r""" Set whether to use xla flash attention from `torch_xla` or not. @@ -296,7 +298,7 @@ def set_use_xla_flash_attention(self, use_xla_flash_attention: bool, partition_s """ if use_xla_flash_attention: if not is_torch_xla_available: - raise "torch_xla is not available" + raise "torch_xla is not available" elif is_torch_xla_version("<", "2.3"): raise "flash attention pallas kernel is supported from torch_xla version 2.3" elif is_spmd() and is_torch_xla_version("<", "2.4"): @@ -2794,12 +2796,14 @@ class XLAFlashAttnProcessor2_0: def __init__(self, partition_spec: Optional[Tuple[Optional[str], ...]] = None): if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + raise ImportError( + "XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) if is_torch_xla_version("<", "2.3"): raise ImportError("XLA flash attention requires torch_xla version >= 2.3.") if is_spmd() and is_torch_xla_version("<", "2.4"): raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.") - self.partition_spec=partition_spec + self.partition_spec = partition_spec def __call__( self, @@ -2875,7 +2879,9 @@ def __call__( partition_spec = self.partition_spec if is_spmd() else None hidden_states = flash_attention(query, key, value, causal=False, partition_spec=partition_spec) else: - logger.warning("Unable to use the flash attention pallas kernel API call due to QKV sequence length < 4096.") + logger.warning( + "Unable to use the flash attention pallas kernel API call due to QKV sequence length < 4096." + ) hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False )