From 1003624e392fb2884440d95606724f7cbc0683c3 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Wed, 30 Jul 2025 20:37:32 +0800 Subject: [PATCH] Improves version comparison using packaging library Replaces string-based version comparison with proper semantic version parsing to ensure accurate comparison of PyTorch versions. Uses packaging.version.parse() instead of direct string comparison to handle edge cases and version formats correctly. --- flash_dmattn/flash_dmattn_interface.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flash_dmattn/flash_dmattn_interface.py b/flash_dmattn/flash_dmattn_interface.py index 83d4ee1..95fce09 100644 --- a/flash_dmattn/flash_dmattn_interface.py +++ b/flash_dmattn/flash_dmattn_interface.py @@ -1,7 +1,7 @@ # Copyright (c) 2025, Jingze Shi. from typing import Optional, Sequence, Tuple, Union, Any - +from packaging import version import torch import torch.nn as nn import os @@ -46,7 +46,7 @@ def round_multiple(x, m): # torch.compile() support is only enabled for pytorch >= 2.4 # The reason for this is that we are using the new custom_op and register_fake # APIs, which support inplace modification of inputs in the function itself -if torch.__version__ >= "2.4.0": +if version.parse(torch.__version__) >= version.parse("2.4.0"): _torch_custom_op_wrapper = torch.library.custom_op _torch_register_fake_wrapper = torch.library.register_fake else: