@@ -50,15 +50,16 @@
 from diffusers.training_utils import EMAModel, compute_snr
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
-from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
 from diffusers.utils.torch_utils import is_compiled_module
 
 
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
 check_min_version("0.28.0.dev0")
 
 logger = get_logger(__name__)
-
+if is_torch_npu_available():
+    torch.npu.config.allow_internal_format = False
 
 DATASET_NAME_MAPPING = {
     "lambdalabs/naruto-blip-captions": ("image", "text"),
@@ -460,6 +461,9 @@ def parse_args(input_args=None):
         ),
     )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+    parser.add_argument(
+        "--enable_npu_flash_attention", action="store_true", help="Whether or not to use npu flash attention."
+    )
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
@@ -716,7 +720,12 @@ def main(args):
             args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
         )
         ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)
-
+    if args.enable_npu_flash_attention:
+        if is_torch_npu_available():
+            logger.info("npu flash attention enabled.")
+            unet.enable_npu_flash_attention()
+        else:
+            raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.")
     if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            import xformers
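
Note: the new block mirrors the existing xformers gate: the flag only takes effect when torch_npu is actually importable, and the run fails fast otherwise. A minimal standalone sketch of that gating pattern (the helper name maybe_enable_npu_flash_attention is illustrative and not part of this change):

# Illustrative sketch of the availability-gated enable pattern used in the diff above.
from diffusers import UNet2DConditionModel
from diffusers.utils.import_utils import is_torch_npu_available


def maybe_enable_npu_flash_attention(unet: UNet2DConditionModel, enable: bool) -> None:
    # Only switch the attention backend when the user asked for it via the CLI flag.
    if not enable:
        return
    # Fail fast on machines without the torch_npu extension, mirroring the xformers check.
    if not is_torch_npu_available():
        raise ValueError("npu flash attention requires torch_npu extensions and an NPU device.")
    unet.enable_npu_flash_attention()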