From d041958e5b087c745132429c4572d0e299cc637a Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Fri, 30 Dec 2022 10:55:59 -0800 Subject: [PATCH] fix tests --- .../memory_efficient_attention.py | 10 +++++----- setup.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/memory_efficient_attention_pytorch/memory_efficient_attention.py b/memory_efficient_attention_pytorch/memory_efficient_attention.py index 2062f86..a138294 100644 --- a/memory_efficient_attention_pytorch/memory_efficient_attention.py +++ b/memory_efficient_attention_pytorch/memory_efficient_attention.py @@ -50,7 +50,7 @@ def attention( # memory efficient attention -def summarize_qkv_chunk(q, k, v, mask, attn_bias_chunk, causal, qk_start_indices, dropout=0., training=False): +def summarize_qkv_chunk(q, k, v, mask, attn_bias_chunk, causal, qk_start_indices, dropout): q_start_index, k_start_index, q_chunk_size, k_chunk_size, device = *qk_start_indices, q.shape[-2], k.shape[-2], q.device weight = einsum('b h i d, b h j d -> b h i j', q, k) @@ -72,8 +72,9 @@ def summarize_qkv_chunk(q, k, v, mask, attn_bias_chunk, causal, qk_start_indices weight = weight - weight_max exp_weight = weight.exp() - if training: - exp_weight = F.dropout(exp_weight, p=dropout, training=training) + + exp_weight = F.dropout(exp_weight, p = dropout) + weighted_value = einsum('b h i j, b h j d -> b h i d', exp_weight, v) return exp_weight.sum(dim = -1), weighted_value, rearrange(weight_max, '... 1 -> ...') @@ -137,8 +138,7 @@ def memory_efficient_attention( attn_bias_chunk, causal, (q_start_index, k_start_index), - dropout = dropout, - training = training + dropout if training else 0. ) exp_weights.append(exp_weight_chunk) diff --git a/setup.py b/setup.py index 2eb83cb..2effd4a 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'memory-efficient-attention-pytorch', packages = find_packages(exclude=[]), - version = '0.1.0', + version = '0.1.1', license='MIT', description = 'Memory Efficient Attention - Pytorch', long_description_content_type = 'text/markdown',