
Commit 98ba18b

HelloWorldBeginner, mhh001, and sayakpaul authored
Add Ascend NPU support for SDXL. (#7916)
Co-authored-by: mhh001 <mahonghao1@huawei.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 5bb3858 commit 98ba18b

File tree

1 file changed (+12, −3 lines)

examples/text_to_image/train_text_to_image_sdxl.py

Lines changed: 12 additions & 3 deletions
@@ -50,15 +50,16 @@
 from diffusers.training_utils import EMAModel, compute_snr
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
-from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
 from diffusers.utils.torch_utils import is_compiled_module
 
 
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
 check_min_version("0.28.0.dev0")
 
 logger = get_logger(__name__)
-
+if is_torch_npu_available():
+    torch.npu.config.allow_internal_format = False
 
 DATASET_NAME_MAPPING = {
     "lambdalabs/naruto-blip-captions": ("image", "text"),
@@ -460,6 +461,9 @@ def parse_args(input_args=None):
         ),
     )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+    parser.add_argument(
+        "--enable_npu_flash_attention", action="store_true", help="Whether or not to use npu flash attention."
+    )
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
@@ -716,7 +720,12 @@ def main(args):
             args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
         )
         ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)
-
+    if args.enable_npu_flash_attention:
+        if is_torch_npu_available():
+            logger.info("npu flash attention enabled.")
+            unet.enable_npu_flash_attention()
+        else:
+            raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.")
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
             import xformers
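
For illustration, the pattern this commit introduces can be reduced to a small standalone script: the module-level guard turns off NPU internal memory formats, and the new flag enables NPU flash attention only when torch_npu is actually available. This is a minimal sketch, not part of the commit; the argument parser, the bare UNet load, and the SDXL checkpoint id are assumptions made for the example.

# Minimal sketch (not part of this commit) of the NPU gating pattern the diff adds.
# Assumes the installed diffusers exposes enable_npu_flash_attention() on the model,
# as the training script relies on; the checkpoint id below is illustrative only.
import argparse

import torch
from diffusers import UNet2DConditionModel
from diffusers.utils.import_utils import is_torch_npu_available

parser = argparse.ArgumentParser()
parser.add_argument(
    "--enable_npu_flash_attention", action="store_true", help="Whether or not to use npu flash attention."
)
args = parser.parse_args()

# Module-level guard from the diff: disallow NPU internal memory formats on Ascend devices.
if is_torch_npu_available():
    torch.npu.config.allow_internal_format = False

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)

# Same check-then-enable logic as main(): the flag alone is not enough,
# torch_npu must actually be importable on the machine.
if args.enable_npu_flash_attention:
    if is_torch_npu_available():
        unet.enable_npu_flash_attention()
    else:
        raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.")

In the training script itself, the flag is registered in parse_args() and the enable/raise logic runs in main(), immediately before the existing xformers check.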
