@@ -22,6 +22,7 @@
 from comfy.cli_args import args, PerformanceFeature
 import torch
 import sys
+import importlib
 import platform
 import weakref
 import gc
@@ -336,12 +337,13 @@ def amd_min_version(device=None, min_rdna_version=0):
         logging.info("AMD arch: {}".format(arch))
         logging.info("ROCm version: {}".format(rocm_version))
         if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
-            if torch_version_numeric >= (2, 7):  # works on 2.6 but doesn't actually seem to improve much
-                if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]):  # TODO: more arches, TODO: gfx950
-                    ENABLE_PYTORCH_ATTENTION = True
-#            if torch_version_numeric >= (2, 8):
-#                if any((a in arch) for a in ["gfx1201"]):
-#                    ENABLE_PYTORCH_ATTENTION = True
+            if importlib.util.find_spec('triton') is not None:  # AMD efficient attention implementation depends on triton. TODO: better way of detecting if it's compiled in or not.
+                if torch_version_numeric >= (2, 7):  # works on 2.6 but doesn't actually seem to improve much
+                    if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]):  # TODO: more arches, TODO: gfx950
+                        ENABLE_PYTORCH_ATTENTION = True
+#                if torch_version_numeric >= (2, 8):
+#                    if any((a in arch) for a in ["gfx1201"]):
+#                        ENABLE_PYTORCH_ATTENTION = True
         if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
             if any((a in arch) for a in ["gfx1201", "gfx942", "gfx950"]):  # TODO: more arches
                 SUPPORT_FP8_OPS = True