Skip to content

Commit bcbd788

Browse files
Don't enable pytorch attention on AMD if triton isn't available. (#9747)
1 parent 27a0fcc commit bcbd788

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

comfy/model_management.py

Lines changed: 8 additions & 6 deletions
Original file line number · Diff line number · Diff line change
@@ -22,6 +22,7 @@
2222
from comfy.cli_args import args, PerformanceFeature
2323
import torch
2424
import sys
25+
import importlib
2526
import platform
2627
import weakref
2728
import gc
@@ -336,12 +337,13 @@ def amd_min_version(device=None, min_rdna_version=0):
336337
logging.info("AMD arch: {}".format(arch))
337338
logging.info("ROCm version: {}".format(rocm_version))
338339
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
339-
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
340-
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
341-
ENABLE_PYTORCH_ATTENTION = True
342-
# if torch_version_numeric >= (2, 8):
343-
# if any((a in arch) for a in ["gfx1201"]):
344-
# ENABLE_PYTORCH_ATTENTION = True
340+
if importlib.util.find_spec('triton') is not None: # AMD efficient attention implementation depends on triton. TODO: better way of detecting if it's compiled in or not.
341+
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
342+
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
343+
ENABLE_PYTORCH_ATTENTION = True
344+
# if torch_version_numeric >= (2, 8):
345+
# if any((a in arch) for a in ["gfx1201"]):
346+
# ENABLE_PYTORCH_ATTENTION = True
345347
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
346348
if any((a in arch) for a in ["gfx1201", "gfx942", "gfx950"]): # TODO: more arches
347349
SUPPORT_FP8_OPS = True

0 commit comments

Comments (0)