diff --git a/xformers/components/attention/compositional.py b/xformers/components/attention/compositional.py
index e7f871e4f..a68dec942 100644
--- a/xformers/components/attention/compositional.py
+++ b/xformers/components/attention/compositional.py
@@ -22,17 +22,13 @@
 import torch.nn.functional as F
 from torch import Tensor, nn
 
-from xformers import _is_triton_available
 from xformers.components.attention import (
     Attention,
     AttentionConfig,
     AttentionMask,
     register_attention,
 )
-
-if _is_triton_available:
-    from xformers.triton.softmax import softmax
-
+from xformers.components.attention.core import _softmax
 from xformers.components.in_proj_container import InProjContainer, InProjParams
 
 
@@ -286,10 +282,7 @@ def forward(
         if att_mask_additive is not None:
             attn_weights += att_mask_additive.values
 
-        if _is_triton_available:
-            attn_weights = softmax(attn_weights, causal=self.causal)
-        else:
-            attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
+        attn_weights = _softmax(attn_weights, causal=self.causal)
 
         attn_weights = attn_weights.view(B, self.num_heads, Sq, Sk)
         attn_probs = self.dropout_module(attn_weights)
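
Context for the diff above: the Triton-versus-PyTorch softmax dispatch that was previously inlined in compositional.py now sits behind `_softmax` from `xformers.components.attention.core`, so the attention module no longer needs its own `_is_triton_available` branch. The code below is only a minimal sketch of what such a centralized helper can look like, assuming the same `xformers.triton.softmax.softmax` entry point the removed branch imported; it is not the library's actual `core._softmax` implementation.

# Illustrative sketch only (assumption, not the real core._softmax): put the
# "use the fused Triton kernel if available, else plain PyTorch" decision in
# one helper so callers such as compositional attention stay backend-agnostic.
import torch
from torch import Tensor

try:
    # Same kernel the removed branch imported directly.
    from xformers.triton.softmax import softmax as triton_softmax

    _is_triton_available = True
except ImportError:
    _is_triton_available = False


def _softmax(a: Tensor, causal: bool = False) -> Tensor:
    if _is_triton_available:
        # Fused Triton kernel; takes the causal flag directly.
        return triton_softmax(a, causal=causal)
    # Reference fallback, mirroring the branch removed in the diff; note that
    # it does not apply causal masking itself.
    return torch.nn.functional.softmax(a, dim=-1)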