Skip to content

Commit

Permalink
Apply rotary embeddings to only part of each head's dimensions (default: half of dim_head), with the rotary dimension clamped to a minimum of 32
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Apr 19, 2021
1 parent 8723359 commit 4b395ab
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'x-transformers',
packages = find_packages(exclude=['examples']),
version = '0.11.2',
version = '0.11.4',
license='MIT',
description = 'X-Transformers - Pytorch',
author = 'Phil Wang',
Expand Down
10 changes: 8 additions & 2 deletions x_transformers/x_transformers.py
Expand Up @@ -358,7 +358,11 @@ def forward(
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

if exists(rotary_pos_emb):
q, k = apply_rotary_pos_emb(q, k, rotary_pos_emb)
l = rotary_pos_emb.shape[-1]
(ql, qr), (kl, kr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k))
ql, kl = apply_rotary_pos_emb(ql, kl, rotary_pos_emb)
q = torch.cat((ql, qr), dim = -1)
k = torch.cat((kl, kr), dim = -1)

input_mask = None
if any(map(exists, (mask, context_mask))):
Expand Down Expand Up @@ -444,6 +448,7 @@ def __init__(
rel_pos_max_distance = 128,
position_infused_attn = False,
rotary_pos_emb = False,
rotary_emb_dim = None,
custom_layers = None,
sandwich_coef = None,
par_ratio = None,
Expand All @@ -467,7 +472,8 @@ def __init__(
self.has_pos_emb = position_infused_attn or rel_pos_bias or rotary_pos_emb
self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None

self.rotary_pos_emb = RotaryEmbedding(dim_head) if rotary_pos_emb else always(None)
rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32)
self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim) if rotary_pos_emb else always(None)

assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
self.rel_pos = RelativePositionBias(causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance) if rel_pos_bias else None
Expand Down

0 comments on commit 4b395ab

Please sign in to comment.