Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions flash_dmattn/flash_dmattn_interface.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) 2025, Jingze Shi.

from typing import Optional, Sequence, Tuple, Union, Any

from packaging import version
import torch
import torch.nn as nn
import os
Expand Down Expand Up @@ -46,7 +46,7 @@ def round_multiple(x, m):
# torch.compile() support is only enabled for pytorch >= 2.4
# The reason for this is that we are using the new custom_op and register_fake
# APIs, which support inplace modification of inputs in the function itself
if torch.__version__ >= "2.4.0":
if version.parse(torch.__version__) >= version.parse("2.4.0"):
_torch_custom_op_wrapper = torch.library.custom_op
_torch_register_fake_wrapper = torch.library.register_fake
else:
Expand Down