Merged
examples/text_to_image/train_text_to_image_sdxl.py (15 changes: 12 additions & 3 deletions)
@@ -50,15 +50,16 @@
 from diffusers.training_utils import EMAModel, compute_snr
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
-from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
 from diffusers.utils.torch_utils import is_compiled_module
 
 
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
 check_min_version("0.28.0.dev0")
 
 logger = get_logger(__name__)
-
+if is_torch_npu_available():
+    torch.npu.config.allow_internal_format = False
 
 DATASET_NAME_MAPPING = {
     "lambdalabs/naruto-blip-captions": ("image", "text"),
@@ -460,6 +461,9 @@ def parse_args(input_args=None):
             ),
         )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+    parser.add_argument(
+        "--enable_npu_flash_attention", action="store_true", help="Whether or not to use npu flash attention."
+    )
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
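Because the new option is a plain store_true switch, args.enable_npu_flash_attention defaults to False and flips to True only when the flag is passed on the command line. A standalone sketch (toy parser reproducing just the new flag, not the script's full parser):

import argparse

# Toy parser containing only the flag added in the diff above.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--enable_npu_flash_attention", action="store_true", help="Whether or not to use npu flash attention."
)

print(parser.parse_args([]).enable_npu_flash_attention)  # False
print(parser.parse_args(["--enable_npu_flash_attention"]).enable_npu_flash_attention)  # True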
@@ -716,7 +720,12 @@ def main(args):
             args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
         )
         ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)
-
+    if args.enable_npu_flash_attention:
+        if is_torch_npu_available():
+            logger.info("npu flash attention enabled.")
+            unet.enable_npu_flash_attention()
+        else:
+            raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.")
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
             import xformers
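The new branch mirrors the existing xformers path: probe whether the backend is actually available, enable it on the UNet, and raise otherwise. A condensed sketch of the combined selection logic (the helper name and structure are ours; the diffusers calls are the ones the script uses):

from diffusers import UNet2DConditionModel
from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available


def apply_attention_backend(unet: UNet2DConditionModel, use_npu_flash: bool, use_xformers: bool) -> None:
    # Hypothetical helper: enable at most one memory-efficient attention backend.
    if use_npu_flash:
        if not is_torch_npu_available():
            raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.")
        unet.enable_npu_flash_attention()
    elif use_xformers:
        if not is_xformers_available():
            raise ValueError("xformers is not available. Make sure it is installed correctly.")
        unet.enable_xformers_memory_efficient_attention()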