
Commit

dropouts
lucidrains committed Jul 8, 2021
1 parent 13be747 commit 92ebcd3
Showing 2 changed files with 13 additions and 6 deletions.
17 changes: 12 additions & 5 deletions long_short_transformer/long_short_transformer.py
@@ -51,11 +51,12 @@ def forward(self, x, **kwargs):
         return self.fn(x, **kwargs)
 
 class FeedForward(nn.Module):
-    def __init__(self, dim, mult = 4):
+    def __init__(self, dim, mult = 4, dropout = 0.):
         super().__init__()
         self.net = nn.Sequential(
             nn.Linear(dim, dim * mult),
             nn.GELU(),
+            nn.Dropout(dropout),
             nn.Linear(dim * mult, dim)
         )
 
@@ -73,7 +74,8 @@ def __init__(
         window_size = 128,
         pos_emb = None,
         segment_size = 16,
-        r = 1
+        r = 1,
+        dropout = 0.
     ):
         super().__init__()
         assert not (causal and r >= segment_size), 'r should be less than segment size, if autoregressive'
@@ -94,6 +96,8 @@ def __init__(
 
         self.pos_emb = default(pos_emb, RotaryEmbedding(dim_head))
 
+        self.attn_dropout = nn.Dropout(dropout)
+
         self.to_q = nn.Linear(dim, inner_dim, bias = False)
         self.to_kv = nn.Linear(dim, inner_dim, bias = False)
         self.to_out = nn.Linear(inner_dim, dim)
@@ -228,6 +232,7 @@ def forward(self, x, mask = None):
         # attention
 
         attn = sim.softmax(dim = -1)
+        attn = self.attn_dropout(attn)
 
         # aggregate values (same as keys, since tied) and project out
 
@@ -252,7 +257,9 @@ def __init__(
         heads = 8,
         ff_mult = 4,
         segment_size = None,
-        r = None
+        r = None,
+        ff_dropout = 0.,
+        attn_dropout = 0.
     ):
         super().__init__()
         self.max_seq_len = max_seq_len
@@ -271,8 +278,8 @@ def __init__(
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                PreNorm(dim, LongShortAttention(dim = dim, heads = heads, dim_head = dim_head, window_size = window_size, causal = causal, pos_emb = pos_emb, segment_size = segment_size, r = r)),
-                PreNorm(dim, FeedForward(dim = dim, mult = ff_mult))
+                PreNorm(dim, LongShortAttention(dim = dim, heads = heads, dim_head = dim_head, window_size = window_size, causal = causal, pos_emb = pos_emb, segment_size = segment_size, r = r, dropout = attn_dropout)),
+                PreNorm(dim, FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout))
             ]))
 
         self.to_logits = nn.Sequential(
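
For context, a minimal usage sketch of the two new keyword arguments exposed by this commit. This is not part of the diff: num_tokens, dim and the other constructor values below are assumptions based on typical defaults in this repository, only attn_dropout and ff_dropout are introduced here.

    import torch
    from long_short_transformer import LongShortTransformer

    model = LongShortTransformer(
        num_tokens = 256,       # assumed vocabulary size (not shown in this diff)
        dim = 512,
        depth = 6,
        heads = 8,
        max_seq_len = 1024,
        window_size = 128,
        causal = True,
        attn_dropout = 0.1,     # new: dropout on the attention matrix, applied after the softmax
        ff_dropout = 0.1        # new: dropout inside FeedForward, between the GELU and the output projection
    )

    x = torch.randint(0, 256, (1, 1024))
    logits = model(x)           # assumed shape (1, 1024, num_tokens)

Both dropouts default to 0., so existing code that does not pass them is unaffected.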
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'long-short-transformer',
   packages = find_packages(),
-  version = '0.0.1',
+  version = '0.0.2',
   license='MIT',
   description = 'Long Short Transformer - Pytorch',
   author = 'Phil Wang',
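
The version bump in setup.py means the new dropout options ship as a fresh release. Assuming the package is published to PyPI under the name declared above, upgrading would look like:

    pip install -U long-short-transformer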
