diff --git a/flash_dmattn/flash_dmattn_interface.py b/flash_dmattn/flash_dmattn_interface.py index 83d4ee1..95fce09 100644 --- a/flash_dmattn/flash_dmattn_interface.py +++ b/flash_dmattn/flash_dmattn_interface.py @@ -1,7 +1,7 @@ # Copyright (c) 2025, Jingze Shi. from typing import Optional, Sequence, Tuple, Union, Any - +from packaging import version import torch import torch.nn as nn import os @@ -46,7 +46,7 @@ def round_multiple(x, m): # torch.compile() support is only enabled for pytorch >= 2.4 # The reason for this is that we are using the new custom_op and register_fake # APIs, which support inplace modification of inputs in the function itself -if torch.__version__ >= "2.4.0": +if version.parse(torch.__version__) >= version.parse("2.4.0"): _torch_custom_op_wrapper = torch.library.custom_op _torch_register_fake_wrapper = torch.library.register_fake else: