option for relu squared activation
lucidrains committed Feb 22, 2022
1 parent 4a392d8 commit d5b0335
Showing 2 changed files with 9 additions and 3 deletions.
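
The new relu_squared option squares the rectified attention scores before masking, i.e. the attention weights become relu(sim) ** 2 rather than relu(sim). A minimal standalone sketch of just that step (tensor shapes and names below are illustrative, not taken from the file):

import torch
import torch.nn.functional as F

# toy pre-attention similarity scores: (batch, heads, query_len, key_len)
sim = torch.randn(1, 8, 16, 16)

relu_squared = True       # the flag introduced by this commit

attn = F.relu(sim)        # ReLA: rectified linear attention weights
if relu_squared:
    attn = attn ** 2      # squared-ReLU variant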
10 changes: 8 additions & 2 deletions rela_transformer/rela_transformer.py

@@ -43,13 +43,15 @@ def __init__(
         causal = True,
         dim_head = 64,
         heads = 8,
-        num_memory_kv = 0
+        num_memory_kv = 0,
+        relu_squared = False
     ):
         super().__init__()
         self.heads = heads
         inner_dim = dim_head * heads
         self.scale = dim_head ** -0.5
         self.causal = causal
+        self.relu_squared = relu_squared
         self.norm = GatedRMSNorm(dim)

         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
@@ -80,6 +82,9 @@ def forward(self, x, mask = None):

         attn = F.relu(sim)

+        if self.relu_squared:
+            attn = attn ** 2
+
         if exists(mask):
             mask = rearrange(mask, 'b j -> b 1 1 j')
             attn = attn.masked_fill(~mask, 0.)
@@ -107,6 +112,7 @@ def __init__(
         num_memory_kv = 0,
         no_ff = False,
         ff_mult = 4,
+        relu_squared = False
     ):
         super().__init__()
         self.max_seq_len = max_seq_len
@@ -116,7 +122,7 @@ def __init__(
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                ReLA(dim = dim, heads = heads, dim_head = dim_head, num_memory_kv = num_memory_kv, causal = causal),
+                ReLA(dim = dim, relu_squared = relu_squared, heads = heads, dim_head = dim_head, num_memory_kv = num_memory_kv, causal = causal),
                 FeedForward(dim = dim, mult = ff_mult) if not no_ff else None
             ]))

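For reference, a usage sketch with the new flag enabled (the full ReLATransformer signature is not shown in this diff; num_tokens and the token-level forward call are assumptions based on the repository's README and may differ):

import torch
from rela_transformer import ReLATransformer

model = ReLATransformer(
    num_tokens = 20000,     # assumed vocab-size argument, not visible in this diff
    dim = 512,
    depth = 8,
    max_seq_len = 1024,
    causal = True,
    relu_squared = True     # option added in this commit
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x)           # assumed to return per-token logits
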
2 changes: 1 addition & 1 deletion setup.py

@@ -3,7 +3,7 @@
 setup(
   name = 'rela-transformer',
   packages = find_packages(exclude=[]),
-  version = '0.0.5',
+  version = '0.0.6',
   license='MIT',
   description = 'ReLA Transformer',
   author = 'Phil Wang',
