Commit dc765fd

add sub-groupnorm in multihead ema, for use in audio modeling
lucidrains committed Aug 26, 2023
1 parent ee349ab commit dc765fd
Showing 2 changed files with 23 additions and 4 deletions.
mega_pytorch/mega_pytorch.py (25 changes: 22 additions & 3 deletions)

@@ -7,6 +7,8 @@
 from torch.fft import rfft, irfft
 
 from einops import rearrange
+from einops.layers.torch import Rearrange
 
 from scipy.fftpack import next_fast_len
 
 # functions
@@ -185,7 +187,7 @@ def __init__(
         dim,
         heads,
         bidirectional = False,
-        dim_head = None
+        norm_mhesa_heads = False
     ):
         super().__init__()
         self.bidirectional = bidirectional
@@ -202,6 +204,19 @@ def __init__(
             self.reverse_alphas = nn.Parameter(torch.randn(heads))
             self.reverse_dampen_factors = nn.Parameter(torch.randn(heads))
 
+        self.heads = heads
+
+        self.norm_heads = nn.Identity()
+
+        if norm_mhesa_heads:
+            # https://arxiv.org/abs/2210.06423 - retnet used sub-ln with some success as groupnorm
+
+            self.norm_heads = nn.Sequential(
+                Rearrange('b n h d -> b (h d) n'),
+                nn.GroupNorm(heads, dim * heads),
+                Rearrange('b (h d) n -> b n h d', h = heads)
+            )
+
     def forward(self, x):
         device, seq_len = x.device, x.shape[1]
 
@@ -233,6 +248,10 @@ def apply_learned_ema_with_damping(x, alphas, dampen_factors):
             x_reversed = torch.flip(x_reversed, dims = (1,))
             x = torch.cat((x, x_reversed), dim = -2)
 
+        # maybe norm heads
+
+        x = self.norm_heads(x)
+
         # combine heads and out
 
         return einsum('... h d, h d -> ... d', x, self.reduction)
@@ -250,7 +269,7 @@ def __init__(
         attn_dim_value = 256,
         laplacian_attn_fn = False,
         causal = True,
-        ema_dim_head = None
+        norm_mhesa_heads = False
     ):
         super().__init__()
 
@@ -266,7 +285,7 @@ def __init__(
             dim = dim,
             heads = ema_heads,
             bidirectional = not causal,
-            dim_head = ema_dim_head
+            norm_mhesa_heads = norm_mhesa_heads
         )
 
         self.to_reset_gate = nn.Sequential(
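
For reference, a minimal standalone sketch of the shape flow through the new norm_heads module (the batch, sequence and head sizes below are illustrative, not from the commit; in MultiHeadedEMA the per-head dimension equals dim):

import torch
from torch import nn
from einops.layers.torch import Rearrange

heads, dim = 4, 32

# one GroupNorm group per EMA head: each head's dim channels are
# normalized together across the sequence, per sample
norm_heads = nn.Sequential(
    Rearrange('b n h d -> b (h d) n'),
    nn.GroupNorm(heads, dim * heads),
    Rearrange('b (h d) n -> b n h d', h = heads)
)

x = torch.randn(2, 1024, heads, dim)   # (batch, seq, heads, dim), as in MultiHeadedEMA.forward
assert norm_heads(x).shape == x.shape  # normalization preserves the shape
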
setup.py (2 changes: 1 addition & 1 deletion)

@@ -3,7 +3,7 @@
 setup(
   name = 'Mega-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.15',
+  version = '0.1.0',
   license='MIT',
   description = 'Mega - Pytorch',
   author = 'Phil Wang',
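
A hypothetical usage sketch for the new flag (MegaLayer and its dim / ema_heads arguments are taken from the repository README; only norm_mhesa_heads is new in this commit):

import torch
from mega_pytorch import MegaLayer

layer = MegaLayer(
    dim = 128,
    ema_heads = 16,
    norm_mhesa_heads = True  # group-normalize the EMA heads, intended for audio modeling
)

x = torch.randn(1, 1024, 128)  # (batch, seq, dim)
out = layer(x)                 # expected to keep the input shape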
