From c6f1091f0640491d58961e9f8e239de77240aa08 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Thu, 8 Jul 2021 09:53:39 -0700
Subject: [PATCH] move rotary embeddings to separate package

---
 .../long_short_transformer.py    | 13 +++---
 long_short_transformer/rotary.py | 40 -------------------
 setup.py                         |  3 +-
 3 files changed, 9 insertions(+), 47 deletions(-)
 delete mode 100644 long_short_transformer/rotary.py

diff --git a/long_short_transformer/long_short_transformer.py b/long_short_transformer/long_short_transformer.py
index 3bbc9be..85be9d6 100644
--- a/long_short_transformer/long_short_transformer.py
+++ b/long_short_transformer/long_short_transformer.py
@@ -5,7 +5,7 @@
 from torch import nn, einsum
 import torch.nn.functional as F
 
-from long_short_transformer.rotary import RotaryEmbedding, apply_rotary_emb
+from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb
 
 from einops import rearrange, repeat
 
@@ -131,15 +131,16 @@ def forward(self, x, mask = None):
 
         seq_range = torch.arange(padded_len, device = device)
 
+        # split heads
+
+        q, kv = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)
+
         # rotary embedding
 
         if exists(self.pos_emb):
             rotary_emb = self.pos_emb(seq_range, cache_key = padded_len)
-            qkv = map(lambda t: apply_rotary_emb(rotary_emb, t), (qkv))
-
-        # split heads
-
-        q, kv = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)
+            rotary_emb = rearrange(rotary_emb, 'n d -> () n d')
+            q, kv = map(lambda t: apply_rotary_emb(rotary_emb, t), (q, kv))
 
         # scale queries
 
diff --git a/long_short_transformer/rotary.py b/long_short_transformer/rotary.py
deleted file mode 100644
index 51e4e5a..0000000
--- a/long_short_transformer/rotary.py
+++ /dev/null
@@ -1,40 +0,0 @@
-import torch
-from torch import nn, einsum
-
-from einops import rearrange
-
-def exists(val):
-    return val is not None
-
-def rotate_half(x):
-    x = rearrange(x, 'b n (r d) -> b n r d', r = 2)
-    x1, x2 = x.unbind(dim = -2)
-    return torch.cat((-x2, x1), dim = -1)
-
-def apply_rotary_emb(freqs, t):
-    cos, sin = freqs
-    rot_dim = cos.shape[-1]
-    t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
-    t = (t * cos) + (rotate_half(t) * sin)
-    return torch.cat((t, t_pass), dim = -1)
-
-class RotaryEmbedding(nn.Module):
-    def __init__(self, dim):
-        super().__init__()
-        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
-        self.cache = dict()
-        self.register_buffer('inv_freq', inv_freq)
-
-    def forward(self, t, cache_key = None):
-        if exists(cache_key) and cache_key in self.cache:
-            return self.cache[cache_key]
-
-        t = t.type(self.inv_freq.dtype)
-        freqs = torch.einsum('i, j -> i j', t, self.inv_freq)
-        freqs = torch.cat((freqs, freqs), dim = -1)
-        emb = (freqs.cos(), freqs.sin())
-
-        if exists(cache_key):
-            self.cache[cache_key] = emb
-
-        return emb
diff --git a/setup.py b/setup.py
index d92da07..d4681c8 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'long-short-transformer',
   packages = find_packages(),
-  version = '0.0.2',
+  version = '0.0.3',
  license='MIT',
  description = 'Long Short Transformer - Pytorch',
  author = 'Phil Wang',
@@ -17,6 +17,7 @@
   ],
   install_requires=[
     'einops>=0.3',
+    'rotary-embedding-torch',
     'torch>=1.6'
   ],
   classifiers=[
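
Usage sketch for the swapped-in dependency: the deleted long_short_transformer/rotary.py
is now expected to come from the rotary-embedding-torch package added to install_requires.
The snippet below mirrors the calls introduced in the long_short_transformer.py hunk above;
the tensor sizes are illustrative assumptions, and it presumes the package API current at
the time of this patch (a forward that accepts cache_key and a freqs-first apply_rotary_emb).

    import torch
    from einops import rearrange
    from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb

    batch, heads, seq_len, dim_head = 2, 8, 256, 64     # illustrative sizes, not taken from the patch

    pos_emb = RotaryEmbedding(dim_head)                 # stands in for the deleted local RotaryEmbedding

    # queries (and keys) laid out as in the patched forward(): heads folded into the batch dim
    q = torch.randn(batch * heads, seq_len, dim_head)   # shape: ((b h), n, d)

    seq_range = torch.arange(seq_len)
    freqs = pos_emb(seq_range, cache_key = seq_len)     # per-position rotary frequencies, shape (n, d)
    freqs = rearrange(freqs, 'n d -> () n d')           # add a leading dim so it broadcasts over (b h)

    q = apply_rotary_emb(freqs, q)                      # rotated queries; keys get the identical call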