From c31a35165dfe0e767d40c23bfb00a39245c60497 Mon Sep 17 00:00:00 2001
From: mhh001
Date: Sat, 11 May 2024 15:12:25 +0800
Subject: [PATCH] Add Ascend NPU support for SDXL.

---
 .../text_to_image/train_text_to_image_sdxl.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py
index 90602ad597a9..ee9dfd63cbc4 100644
--- a/examples/text_to_image/train_text_to_image_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_sdxl.py
@@ -50,7 +50,7 @@
 from diffusers.training_utils import EMAModel, compute_snr
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
-from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
 from diffusers.utils.torch_utils import is_compiled_module
 
 
@@ -58,7 +58,8 @@
 check_min_version("0.28.0.dev0")
 
 logger = get_logger(__name__)
-
+if is_torch_npu_available():
+    torch.npu.config.allow_internal_format = False
 
 DATASET_NAME_MAPPING = {
     "lambdalabs/naruto-blip-captions": ("image", "text"),
@@ -460,6 +461,9 @@ def parse_args(input_args=None):
         ),
     )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+    parser.add_argument(
+        "--enable_npu_flash_attention", action="store_true", help="Whether or not to use npu flash attention."
+    )
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
@@ -716,7 +720,12 @@ def main(args):
             args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
         )
         ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)
-
+    if args.enable_npu_flash_attention:
+        if is_torch_npu_available():
+            logger.info("npu flash attention enabled.")
+            unet.enable_npu_flash_attention()
+        else:
+            raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.")
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
             import xformers
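
For context, a minimal standalone sketch of the guard this patch introduces, assuming a torch_npu runtime and a diffusers build that exposes enable_npu_flash_attention() on the UNet (which the patched script relies on); the checkpoint name below is only an illustrative placeholder, since the training script reads it from --pretrained_model_name_or_path:

    import torch
    from diffusers import UNet2DConditionModel
    from diffusers.utils.import_utils import is_torch_npu_available

    if is_torch_npu_available():
        # Mirrors the module-level setting added by the patch: disable NPU
        # internal memory formats before any tensors are allocated.
        torch.npu.config.allow_internal_format = False

    # Illustrative checkpoint only; the script takes --pretrained_model_name_or_path instead.
    unet = UNet2DConditionModel.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
    )

    enable_npu_flash_attention = True  # stands in for args.enable_npu_flash_attention
    if enable_npu_flash_attention:
        if is_torch_npu_available():
            # Switch the UNet's attention processors to the NPU flash-attention path.
            unet.enable_npu_flash_attention()
        else:
            raise ValueError(
                "npu flash attention requires torch_npu extensions and is supported only on npu devices."
            )

On an Ascend device the new flag is passed like any other option of the example, e.g. accelerate launch train_text_to_image_sdxl.py --enable_npu_flash_attention (remaining required arguments omitted).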