
Commit f6ce3bb
fix bug with local attention migration
lucidrains committed Jul 6, 2020
1 parent b8dba1c commit f6ce3bb
Showing 2 changed files with 3 additions and 2 deletions.
routing_transformer/routing_transformer.py (3 changes: 2 additions & 1 deletion)
@@ -393,7 +393,8 @@ def __init__(self, dim, depth, max_seq_len, heads, local_attn_heads, window_size
         num_clusters = max_seq_len // window_size
 
         if self.local_attn_heads > 0:
-            self.local_attn = LocalAttention(local_attn_window_size, causal = True, dropout = attn_dropout, rel_pos_emb_config = (dim // heads, local_attn_heads), shared_qk = shared_qk)
+            rel_pos_emb_config = (dim // heads, local_attn_heads) if rel_pos_emb is not None else None
+            self.local_attn = LocalAttention(local_attn_window_size, causal = True, dropout = attn_dropout, rel_pos_emb_config = rel_pos_emb_config, shared_qk = shared_qk)
 
         if self.global_attn_heads > 0:
             self.global_attn = KmeansAttention(num_clusters, window_size, self.global_attn_heads, head_dim, causal = causal, dropout = attn_dropout, ema_decay = kmeans_ema_decay, commitment = commitment_factor, receives_context = receives_context, num_mem_kv = num_mem_kv, shared_qk = shared_qk)
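For context, the guard this hunk introduces can be sketched on its own. The LocalAttention keyword arguments below mirror the call shown in the diff; the concrete values for dim, heads, local_attn_heads, local_attn_window_size, attn_dropout and shared_qk are hypothetical placeholders, and rel_pos_emb stands in for the optional relative position embedding setting the constructor checks.

from local_attention import LocalAttention

# hypothetical values standing in for the constructor arguments in the hunk above
dim, heads, local_attn_heads = 512, 8, 4
local_attn_window_size, attn_dropout, shared_qk = 128, 0.1, True
rel_pos_emb = None  # e.g. the model was configured without relative position embeddings

# the fix: only build a rel_pos_emb_config when relative position embeddings are requested,
# instead of always passing (dim // heads, local_attn_heads)
rel_pos_emb_config = (dim // heads, local_attn_heads) if rel_pos_emb is not None else None

local_attn = LocalAttention(
    local_attn_window_size,
    causal = True,
    dropout = attn_dropout,
    rel_pos_emb_config = rel_pos_emb_config,
    shared_qk = shared_qk
)

With rel_pos_emb left as None, no relative position config is passed through to LocalAttention, which appears to be the case the previous unconditional tuple mishandled.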
setup.py (2 changes: 1 addition & 1 deletion)
@@ -3,7 +3,7 @@
 setup(
   name = 'routing_transformer',
   packages = find_packages(exclude=['examples']),
-  version = '0.8.6',
+  version = '0.8.7',
   license='MIT',
   description = 'Routing Transformer (Pytorch)',
   author = 'Phil Wang, Aran Komatsuzaki',
