
Commit

add a feedforward after the self-attention
lucidrains committed Jun 29, 2020
1 parent e40eb6f commit 8a24fb6
Showing 2 changed files with 11 additions and 5 deletions.
setup.py (2 changes: 1 addition & 1 deletion)
@@ -4,7 +4,7 @@
   name = 'stylegan2_pytorch',
   packages = find_packages(),
   scripts=['bin/stylegan2_pytorch'],
-  version = '0.17.2',
+  version = '0.17.3',
   license='GPLv3+',
   description = 'StyleGan2 in Pytorch',
   author = 'Phil Wang',
stylegan2_pytorch/stylegan2_pytorch.py (14 changes: 10 additions & 4 deletions)
@@ -87,6 +87,13 @@ def forward(self, x):
         out = out.permute(0, 3, 1, 2)
         return out, loss

+# one layer of self-attention and feedforward, for images
+
+attn_and_ff = lambda chan: nn.Sequential(*[
+    Residual(Rezero(ImageLinearAttention(chan))),
+    Residual(Rezero(nn.Sequential(nn.Conv2d(chan, chan * 2, 1), leaky_relu(), nn.Conv2d(chan * 2, chan, 1))))
+])
+
 # helpers

 def default(value, d):
@@ -424,9 +431,7 @@ def __init__(self, image_size, latent_dim, network_capacity = 16, transparent =
             not_last = ind != (self.num_layers - 1)
             num_layer = self.num_layers - ind

-            attn_fn = nn.Sequential(*[
-                Residual(Rezero(ImageLinearAttention(in_chan))) for _ in range(2)
-            ]) if num_layer in attn_layers else None
+            attn_fn = attn_and_ff(in_chan) if num_layer in attn_layers else None

             self.attns.append(attn_fn)

@@ -485,7 +490,8 @@ def __init__(self, image_size, network_capacity = 16, fq_layers = [], fq_dict_si
             blocks.append(block)

             attn_fn = nn.Sequential(*[
-                Residual(Rezero(ImageLinearAttention(out_chan))) for _ in range(2)
+                Residual(Rezero(ImageLinearAttention(out_chan))),
+                Residual(Rezero(nn.Sequential(nn.Conv2d(out_chan, out_chan, 1), leaky_relu(), nn.Conv2d(out_chan, out_chan, 1))))
             ]) if num_layer in attn_layers else None

             attn_blocks.append(attn_fn)
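For context, the sketch below shows roughly how the new attn_and_ff helper composes and how it could be applied to a feature map. The Residual, Rezero, ImageLinearAttention, and leaky_relu helpers are defined elsewhere in stylegan2_pytorch.py and are not part of this diff, so the attention module is replaced here with a hypothetical 1x1-conv stand-in and the wrappers are minimal assumed versions of the usual residual/rezero pattern.

```python
import torch
from torch import nn

# minimal assumed versions of the residual / rezero wrappers used in the repo
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
    def forward(self, x):
        # add the wrapped block's output back onto its input
        return self.fn(x) + x

class Rezero(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
        # learned gate initialized to zero, so the block starts out as an identity inside Residual
        self.g = nn.Parameter(torch.zeros(1))
    def forward(self, x):
        return self.fn(x) * self.g

def leaky_relu(p = 0.2):
    return nn.LeakyReLU(p)

# hypothetical stand-in for ImageLinearAttention: anything mapping (B, C, H, W) -> (B, C, H, W)
attention_stub = lambda chan: nn.Conv2d(chan, chan, 1)

# same structure as the commit's attn_and_ff: self-attention, then a 1x1-conv feedforward
attn_and_ff = lambda chan: nn.Sequential(
    Residual(Rezero(attention_stub(chan))),
    Residual(Rezero(nn.Sequential(nn.Conv2d(chan, chan * 2, 1), leaky_relu(), nn.Conv2d(chan * 2, chan, 1))))
)

block = attn_and_ff(64)
x = torch.randn(2, 64, 32, 32)   # batch of 64-channel feature maps
out = block(x)                   # channel count and spatial shape are preserved
assert out.shape == x.shape
```

The feedforward mirrors a transformer block's position-wise MLP, implemented with 1x1 convolutions so it acts independently at each spatial location. Note that the generator path uses the shared attn_and_ff helper (hidden width chan * 2), while the discriminator variant added in this commit keeps the hidden width at out_chan.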
