From ca0a693ac1757f2ad42877a243e9ea61255ab41f Mon Sep 17 00:00:00 2001
From: Benjamin Lefaudeux
Date: Wed, 19 Jan 2022 19:55:04 -0800
Subject: [PATCH] code review, thanks @dianaml

---
 xformers/components/attention/compositional.py | 11 ++---------
 1 file changed, 2 insertions(+), 9 deletions(-)

diff --git a/xformers/components/attention/compositional.py b/xformers/components/attention/compositional.py
index e7f871e4f..a68dec942 100644
--- a/xformers/components/attention/compositional.py
+++ b/xformers/components/attention/compositional.py
@@ -22,17 +22,13 @@
 import torch.nn.functional as F
 from torch import Tensor, nn
 
-from xformers import _is_triton_available
 from xformers.components.attention import (
     Attention,
     AttentionConfig,
     AttentionMask,
     register_attention,
 )
-
-if _is_triton_available:
-    from xformers.triton.softmax import softmax
-
+from xformers.components.attention.core import _softmax
 from xformers.components.in_proj_container import InProjContainer, InProjParams
 
 
@@ -286,10 +282,7 @@ def forward(
         if att_mask_additive is not None:
             attn_weights += att_mask_additive.values
 
-        if _is_triton_available:
-            attn_weights = softmax(attn_weights, causal=self.causal)
-        else:
-            attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
+        attn_weights = _softmax(attn_weights, causal=self.causal)
 
         attn_weights = attn_weights.view(B, self.num_heads, Sq, Sk)
         attn_probs = self.dropout_module(attn_weights)
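
For context, the branch removed in the second hunk shows what the shared helper is expected to centralize: use the fused Triton softmax when Triton is available, and fall back to torch.nn.functional.softmax otherwise. Below is a minimal, hedged sketch of that dispatch under the same _softmax(a, causal=...) call signature used in the patch; it is not the actual xformers.components.attention.core._softmax implementation, and the explicit causal-mask fallback is an assumption added so the sketch stays self-consistent when the fused kernel is unavailable.

import torch

# Grounded in the removed code: xformers exposes a module-level flag telling
# whether the Triton kernels can be used.
from xformers import _is_triton_available

if _is_triton_available:
    # Fused kernel; the removed branch called it as softmax(a, causal=...).
    from xformers.triton.softmax import softmax as triton_softmax


def _softmax(a: torch.Tensor, causal: bool = False) -> torch.Tensor:
    """Softmax over the last dimension, preferring the fused Triton path (sketch)."""
    if _is_triton_available:
        # The Triton kernel folds the optional causal masking into the softmax,
        # exactly as in the branch the patch removes.
        return triton_softmax(a, causal=causal)

    if causal:
        # Assumption for this sketch: emulate the causal behaviour explicitly
        # by masking out future positions before the plain softmax
        # (square attention assumed for brevity).
        seq = a.shape[-1]
        mask = torch.triu(
            torch.full((seq, seq), float("-inf"), device=a.device, dtype=a.dtype),
            diagonal=1,
        )
        a = a + mask

    return torch.nn.functional.softmax(a, dim=-1)

Routing every call site through one helper like this keeps the Triton-availability check in a single place, which is what the patch achieves by importing _softmax instead of repeating the if/else inside compositional.py.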