use glu module
lucidrains committed Nov 8, 2020
1 parent 8780449 commit 97d99d0
Showing 2 changed files with 18 additions and 17 deletions.
33 changes: 17 additions & 16 deletions aoa_pytorch/aoa_pytorch.py
@@ -16,7 +16,9 @@ def __init__(
         *,
         dim,
         dim_head = 64,
-        heads = 8
+        heads = 8,
+        dropout = 0.,
+        aoa_dropout = 0.
     ):
         super().__init__()
         inner_dim = dim_head * heads
@@ -26,37 +28,36 @@ def __init__(
         self.to_q = nn.Linear(dim, inner_dim, bias = False)
         self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
 
-        self.q_out = nn.Linear(dim_head, dim_head, bias = False)
-        self.attn_out = nn.Linear(dim_head, dim_head, bias = False)
-        self.out_bias = nn.Parameter(torch.zeros(1, 1, dim_head))
+        self.dropout = nn.Dropout(dropout)
 
-        self.q_gate = nn.Linear(dim_head, dim_head, bias = False)
-        self.attn_gate = nn.Linear(dim_head, dim_head, bias = False)
-        self.gate_bias = nn.Parameter(torch.zeros(1, 1, dim_head))
+        self.aoa = nn.Sequential(
+            nn.Linear(2 * inner_dim, 2 * dim),
+            nn.GLU(),
+            nn.Dropout(aoa_dropout)
+        )
 
     def forward(self, x, context = None):
         h = self.heads
 
-        q = self.to_q(x)
+        q_ = self.to_q(x)
 
         context = default(context, x)
         kv = self.to_kv(context).chunk(2, dim = -1)
 
         # split heads
-        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, *kv))
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q_, *kv))
         dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
 
         # attention
         attn = dots.softmax(dim = -1)
+        attn = self.dropout(attn)
 
         # weighted average of values
         attn_out = einsum('b h i j, b h j d -> b h i d', attn, v)
 
-        # as described in equations (2) and (3) in paper
-        I = self.q_out(q) + self.attn_out(attn_out) + self.out_bias
-        G = self.q_gate(q) + self.attn_gate(attn_out) + self.gate_bias
+        # concat heads
+        out = rearrange(attn_out, 'b h n d -> b n (h d)', h = h)
 
-        # attention on attention
-        out = I * G.sigmoid()
-
-        # merge heads
-        out = rearrange(out, 'b h n d -> b n (h d)', h = h)
+        out = self.aoa(torch.cat((out, q_), dim = -1))
         return out
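
For readers skimming the diff: nn.GLU splits its input in half along the last dimension and gates the first half with the sigmoid of the second, so the new Linear(2 * inner_dim, 2 * dim) followed by nn.GLU() applies the same a * sigmoid(b) pattern that the removed per-head I and G linears plus out = I * G.sigmoid() computed by hand, only once, after the heads are concatenated with the query. A minimal sketch of that gating identity; the shapes below are illustrative and not taken from the repo:

import torch
import torch.nn as nn

# nn.GLU(dim = -1) splits the input into halves (a, b) along the last
# dimension and returns a * sigmoid(b) -- the same gating pattern as the
# removed out = I * G.sigmoid() step
x = torch.randn(2, 16, 2 * 512)    # (batch, seq, 2 * dim), dim = 512 is illustrative
a, b = x.chunk(2, dim = -1)

glu = nn.GLU(dim = -1)
assert torch.allclose(glu(x), a * b.sigmoid())

In the new __init__ the GLU sits behind a Linear that maps the concatenated attention output and query (width 2 * inner_dim) down to 2 * dim, so the gated result comes back out at width dim, and the trailing nn.Dropout(aoa_dropout) regularizes it.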
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'aoa_pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.0.1',
+  version = '0.0.2',
   license='MIT',
   description = 'Attention on Attention - Pytorch',
   author = 'Phil Wang',
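
With the version bump to 0.0.2, the module exposes the two dropout knobs added above. A hypothetical usage sketch: the import path and the class name AttentionOnAttention are assumptions about what aoa_pytorch/aoa_pytorch.py exports, while the keyword arguments come straight from the diff.

import torch
# assumed export name -- check the package's __init__ for the actual symbol
from aoa_pytorch.aoa_pytorch import AttentionOnAttention

attn = AttentionOnAttention(
    dim = 512,           # model dimension
    dim_head = 64,       # dimension per attention head
    heads = 8,           # number of heads
    dropout = 0.1,       # new: dropout on the attention matrix
    aoa_dropout = 0.1    # new: dropout after the GLU in the aoa block
)

x = torch.randn(1, 1024, 512)
out = attn(x)            # (1, 1024, 512)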
