
Commit b58b121
change to be exactly like paper
lucidrains committed Apr 6, 2022
1 parent d5b0335 commit b58b121
Showing 2 changed files with 6 additions and 4 deletions.
rela_transformer/rela_transformer.py (8 changes: 5 additions & 3 deletions)
@@ -19,13 +19,13 @@ def __init__(
         super().__init__()
         self.scale = dim ** -0.5
         self.eps = eps
-        self.to_gate = nn.Linear(dim, dim, bias = False)
+        self.w = nn.Parameter(torch.ones(dim))
         self.g = nn.Parameter(torch.ones(dim))
 
     def forward(self, x):
         norm = torch.norm(x, dim = -1, keepdim = True) * self.scale
         normed_x = x / norm.clamp(min = self.eps) * self.g
-        return normed_x * self.to_gate(x).sigmoid()
+        return normed_x * (x * self.w).sigmoid()
 
 def FeedForward(dim, mult = 4):
     return nn.Sequential(
@@ -59,9 +59,9 @@ def __init__(
         self.mem_k = nn.Parameter(torch.randn(num_memory_kv, inner_dim))
         self.mem_v = nn.Parameter(torch.randn(num_memory_kv, inner_dim))
 
+        self.norm_values = GatedRMSNorm(dim_head)
         self.to_out = nn.Sequential(
             nn.Linear(inner_dim, dim),
-            GatedRMSNorm(dim)
         )
 
     def forward(self, x, mask = None):
@@ -95,6 +95,8 @@ def forward(self, x, mask = None):
             attn = attn.masked_fill(causal_mask, 0.)
 
         out = einsum('b h i j, b h j d -> b h i d', attn, v)
+        out = self.norm_values(out)
+
         out = rearrange(out, 'b h n d -> b n (h d)')
         return self.to_out(out)

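For readers who want the changed module in one piece rather than as a diff, below is a minimal sketch of GatedRMSNorm as it reads after this commit, assembled from the hunks above. The eps default and the standalone shape check are assumptions added for illustration; only the lines shown in the diff are confirmed by the commit.

import torch
from torch import nn

class GatedRMSNorm(nn.Module):
    def __init__(self, dim, eps = 1e-8):  # eps default is assumed, not shown in the diff
        super().__init__()
        self.scale = dim ** -0.5
        self.eps = eps
        # the gate is now an elementwise learned vector, replacing the former nn.Linear gate
        self.w = nn.Parameter(torch.ones(dim))
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # RMS norm: divide by the (clamped) root-mean-square over the last dimension
        norm = torch.norm(x, dim = -1, keepdim = True) * self.scale
        normed_x = x / norm.clamp(min = self.eps) * self.g
        # gate the normalized values with a sigmoid of the elementwise-scaled input
        return normed_x * (x * self.w).sigmoid()

# quick shape check on a dummy tensor (illustrative only)
x = torch.randn(2, 16, 64)
print(GatedRMSNorm(64)(x).shape)  # torch.Size([2, 16, 64])

In the attention block, this module is now instantiated per head dimension as self.norm_values and applied to the attention output before the heads are merged and projected out, instead of normalizing after the output projection as before.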
setup.py (2 changes: 1 addition & 1 deletion)
@@ -3,7 +3,7 @@
 setup(
   name = 'rela-transformer',
   packages = find_packages(exclude=[]),
-  version = '0.0.6',
+  version = '0.0.7',
   license='MIT',
   description = 'ReLA Transformer',
   author = 'Phil Wang',
