Commit

code review, thanks @DianaML
blefaudeux committed Jan 20, 2022
1 parent 582722e commit ca0a693
Showing 1 changed file with 2 additions and 9 deletions.
xformers/components/attention/compositional.py: 11 changes (2 additions, 9 deletions)
@@ -22,17 +22,13 @@
 import torch.nn.functional as F
 from torch import Tensor, nn
 
-from xformers import _is_triton_available
 from xformers.components.attention import (
     Attention,
     AttentionConfig,
     AttentionMask,
     register_attention,
 )
-
-if _is_triton_available:
-    from xformers.triton.softmax import softmax
-
+from xformers.components.attention.core import _softmax
 from xformers.components.in_proj_container import InProjContainer, InProjParams

@@ -286,10 +282,7 @@ def forward(
         if att_mask_additive is not None:
             attn_weights += att_mask_additive.values
 
-        if _is_triton_available:
-            attn_weights = softmax(attn_weights, causal=self.causal)
-        else:
-            attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
+        attn_weights = _softmax(attn_weights, causal=self.causal)
 
         attn_weights = attn_weights.view(B, self.num_heads, Sq, Sk)
         attn_probs = self.dropout_module(attn_weights)
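The net effect of the change: compositional attention no longer branches on `_is_triton_available` at the call site and instead delegates to the shared `_softmax` helper in `xformers.components.attention.core`. Below is a minimal, hypothetical sketch of what such a dispatcher could look like, assuming it only wraps the same Triton-or-PyTorch choice that the removed inline branch made; the names `_softmax_sketch`, `triton_softmax`, and `_HAS_TRITON` are illustrative, and the real `_softmax` in xformers may handle additional cases (for example sparse or masked inputs).

# Hypothetical sketch, not the actual xformers implementation: a centralized
# softmax dispatcher that hides the "Triton if available, else PyTorch"
# decision that the removed inline branch used to make in forward().
import torch

try:
    # Optional fused kernel; only importable when Triton is installed.
    from xformers.triton.softmax import softmax as triton_softmax
    _HAS_TRITON = True
except ImportError:
    _HAS_TRITON = False


def _softmax_sketch(attn_weights: torch.Tensor, causal: bool = False) -> torch.Tensor:
    """Softmax over the last dimension, preferring the fused Triton kernel."""
    if _HAS_TRITON:
        # The Triton kernel can fold causal masking into the softmax itself,
        # which is why the call site passes causal=self.causal.
        return triton_softmax(attn_weights, causal=causal)
    # Fallback: plain PyTorch softmax over the key dimension; any causal mask
    # must already have been applied to attn_weights, as in the pre-change path.
    return torch.nn.functional.softmax(attn_weights, dim=-1)

With a helper like this, the attention module keeps a single call, `attn_weights = _softmax(attn_weights, causal=self.causal)`, and the backend selection lives in one place instead of being repeated in every attention implementation.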
