Commit d041958: fix tests

lucidrains committed Dec 30, 2022
1 parent: d76fba8

Showing 2 changed files with 6 additions and 6 deletions.
memory_efficient_attention_pytorch/memory_efficient_attention.py (5 additions, 5 deletions)

@@ -50,7 +50,7 @@ def attention(
 
 # memory efficient attention
 
-def summarize_qkv_chunk(q, k, v, mask, attn_bias_chunk, causal, qk_start_indices, dropout=0., training=False):
+def summarize_qkv_chunk(q, k, v, mask, attn_bias_chunk, causal, qk_start_indices, dropout):
     q_start_index, k_start_index, q_chunk_size, k_chunk_size, device = *qk_start_indices, q.shape[-2], k.shape[-2], q.device
 
     weight = einsum('b h i d, b h j d -> b h i j', q, k)
@@ -72,8 +72,9 @@ def summarize_qkv_chunk(q, k, v, mask, attn_bias_chunk, causal, qk_start_indices
     weight = weight - weight_max
 
     exp_weight = weight.exp()
-    if training:
-        exp_weight = F.dropout(exp_weight, p=dropout, training=training)
+
+    exp_weight = F.dropout(exp_weight, p = dropout)
+
     weighted_value = einsum('b h i j, b h j d -> b h i d', exp_weight, v)
 
     return exp_weight.sum(dim = -1), weighted_value, rearrange(weight_max, '... 1 -> ...')
@@ -137,8 +138,7 @@ def memory_efficient_attention(
                 attn_bias_chunk,
                 causal,
                 (q_start_index, k_start_index),
-                dropout = dropout,
-                training = training
+                dropout if training else 0.
             )
 
             exp_weights.append(exp_weight_chunk)
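
Why the dropout gating moved to the call site (a reading of the diff, not stated in the commit message): the chunk summarizer in this repo appears to be invoked through torch.utils.checkpoint.checkpoint, which forwards only positional arguments to the wrapped function and raises on unexpected keyword arguments in the PyTorch releases current at the time of this commit. Folding the training flag into a single positional value works because F.dropout defaults to training = True and is a no-op when p = 0. A minimal sketch of the pattern, with `summarize` as a hypothetical stand-in for the real summarize_qkv_chunk:

from functools import partial

import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

# hypothetical stand-in for summarize_qkv_chunk: dropout is a plain
# positional argument, so it passes cleanly through checkpoint()
def summarize(q, k, dropout):
    weight = q @ k.transpose(-1, -2)
    # F.dropout applies dropout whenever training = True (its default),
    # so p = 0. turns it into a no-op, replacing the old `if training:` gate
    return F.dropout(weight.softmax(dim = -1), p = dropout)

checkpointed_summarize = partial(checkpoint, summarize)

q = torch.randn(1, 4, 8, requires_grad = True)
k = torch.randn(1, 4, 8)
training = True

# resolved at the call site, mirroring `dropout if training else 0.` above
out = checkpointed_summarize(q, k, 0.1 if training else 0.)

With the old keyword-style call, checkpoint(summarize, q, k, dropout = 0.1, training = True) fails with an unexpected-keyword-argument error, which is presumably what the failing tests caught.
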
setup.py (1 addition, 1 deletion)

@@ -3,7 +3,7 @@
 setup(
   name = 'memory-efficient-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.0',
+  version = '0.1.1',
   license='MIT',
   description = 'Memory Efficient Attention - Pytorch',
   long_description_content_type = 'text/markdown',