
Commit

move rotary embeddings to separate package
lucidrains committed Jul 8, 2021
1 parent 92ebcd3 commit c6f1091
Showing 3 changed files with 9 additions and 47 deletions.
13 changes: 7 additions & 6 deletions long_short_transformer/long_short_transformer.py
@@ -5,7 +5,7 @@
 from torch import nn, einsum
 import torch.nn.functional as F

-from long_short_transformer.rotary import RotaryEmbedding, apply_rotary_emb
+from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb

 from einops import rearrange, repeat

@@ -131,15 +131,16 @@ def forward(self, x, mask = None):

         seq_range = torch.arange(padded_len, device = device)

+        # split heads
+
+        q, kv = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)
+
         # rotary embedding

         if exists(self.pos_emb):
             rotary_emb = self.pos_emb(seq_range, cache_key = padded_len)
-            qkv = map(lambda t: apply_rotary_emb(rotary_emb, t), (qkv))
-
-        # split heads
-
-        q, kv = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)
+            rotary_emb = rearrange(rotary_emb, 'n d -> () n d')
+            q, kv = map(lambda t: apply_rotary_emb(rotary_emb, t), (q, kv))

         # scale queries

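For reference, here is a minimal standalone sketch of how the new rotary-embedding-torch dependency is driven, mirroring the call pattern introduced in the forward pass above (positions to frequencies via cache_key, a rearrange to add a broadcast dimension, then apply_rotary_emb on the merged-head queries and keys/values). The RotaryEmbedding dimension and the tensor shapes below are illustrative assumptions, not values taken from this commit, and the cache_key keyword reflects the package API as used here.

import torch
from einops import rearrange
from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb

# illustrative sizes, not taken from this repository
pos_emb = RotaryEmbedding(dim = 32)                  # rotary frequencies over a 32-dim slice of each head

q  = torch.randn(8, 1024, 64)                        # (batch * heads, seq_len, dim_head)
kv = torch.randn(8, 1024, 64)

seq_range = torch.arange(1024)
rotary_emb = pos_emb(seq_range, cache_key = 1024)    # frequencies for each position, cached by sequence length
rotary_emb = rearrange(rotary_emb, 'n d -> () n d')  # add a leading dim so it broadcasts over (b h)

q, kv = map(lambda t: apply_rotary_emb(rotary_emb, t), (q, kv))
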
40 changes: 0 additions & 40 deletions long_short_transformer/rotary.py

This file was deleted.

3 changes: 2 additions & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'long-short-transformer',
   packages = find_packages(),
-  version = '0.0.2',
+  version = '0.0.3',
   license='MIT',
   description = 'Long Short Transformer - Pytorch',
   author = 'Phil Wang',
@@ -17,6 +17,7 @@
   ],
   install_requires=[
     'einops>=0.3',
+    'rotary-embedding-torch',
     'torch>=1.6'
   ],
   classifiers=[
