add cap to number of channels with fmap_max as in original repo
lucidrains committed Jun 27, 2020
1 parent c74e436 commit d3207f4
Showing 2 changed files with 18 additions and 11 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -4,7 +4,7 @@
   name = 'stylegan2_pytorch',
   packages = find_packages(),
   scripts=['bin/stylegan2_pytorch'],
-  version = '0.16.2',
+  version = '0.17.0',
   license='GPLv3+',
   description = 'StyleGan2 in Pytorch',
   author = 'Phil Wang',
27 changes: 17 additions & 10 deletions stylegan2_pytorch/stylegan2_pytorch.py
@@ -395,22 +395,26 @@ def forward(self, x):
         return x
 
 class Generator(nn.Module):
-    def __init__(self, image_size, latent_dim, network_capacity = 16, transparent = False, attn_layers = [], no_const = False):
+    def __init__(self, image_size, latent_dim, network_capacity = 16, transparent = False, attn_layers = [], no_const = False, fmap_max = 512):
         super().__init__()
         self.image_size = image_size
         self.latent_dim = latent_dim
         self.num_layers = int(log2(image_size) - 1)
 
-        init_channels = 4 * network_capacity
-        filters = [init_channels] + [network_capacity * (2 ** (i + 1)) for i in range(self.num_layers)][::-1]
-        in_out_pairs = zip(filters[0:-1], filters[1:])
+        filters = [network_capacity * (2 ** (i + 1)) for i in range(self.num_layers)][::-1]
+
+        set_fmap_max = partial(min, fmap_max)
+        filters = list(map(set_fmap_max, filters))
+        init_channels = filters[0]
+        filters = [init_channels, *filters]
 
+        in_out_pairs = zip(filters[:-1], filters[1:])
         self.no_const = no_const
 
         if no_const:
             self.to_initial_block = nn.ConvTranspose2d(latent_dim, init_channels, 4, 1, 0, bias=False)
         else:
-            self.initial_block = nn.Parameter(torch.randn((init_channels, 4, 4)))
+            self.initial_block = nn.Parameter(torch.randn((1, init_channels, 4, 4)))
 
         self.blocks = nn.ModuleList([])
         self.attns = nn.ModuleList([])
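
Note on the hunk above: the new fmap_max argument simply clamps each per-resolution channel count before the in/out pairs are built. A minimal standalone sketch of that computation, assuming the illustrative values image_size = 256, network_capacity = 16, and the default fmap_max = 512 (none of these numbers come from this commit):

from functools import partial
from math import log2

image_size, network_capacity, fmap_max = 256, 16, 512  # illustrative values only

num_layers = int(log2(image_size) - 1)                  # 7 generator blocks at 256px
filters = [network_capacity * (2 ** (i + 1)) for i in range(num_layers)][::-1]
print(filters)       # [2048, 1024, 512, 256, 128, 64, 32] -- uncapped channel counts

set_fmap_max = partial(min, fmap_max)
filters = list(map(set_fmap_max, filters))
print(filters)       # [512, 512, 512, 256, 128, 64, 32] -- widest layers clamped to fmap_max

init_channels = filters[0]
filters = [init_channels, *filters]
in_out_pairs = list(zip(filters[:-1], filters[1:]))
print(in_out_pairs)  # [(512, 512), (512, 512), (512, 512), (512, 256), (256, 128), (128, 64), (64, 32)]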
@@ -457,14 +461,17 @@ def forward(self, styles, input_noise):
         return rgb
 
 class Discriminator(nn.Module):
-    def __init__(self, image_size, network_capacity = 16, fq_layers = [], fq_dict_size = 256, attn_layers = [], transparent = False):
+    def __init__(self, image_size, network_capacity = 16, fq_layers = [], fq_dict_size = 256, attn_layers = [], transparent = False, fmap_max = 512):
         super().__init__()
         num_layers = int(log2(image_size) - 1)
         num_init_filters = 3 if not transparent else 4
 
         blocks = []
         filters = [num_init_filters] + [(network_capacity) * (2 ** i) for i in range(num_layers + 1)]
-        chan_in_out = list(zip(filters[0:-1], filters[1:]))
+
+        set_fmap_max = partial(min, fmap_max)
+        filters = list(map(set_fmap_max, filters))
+        chan_in_out = list(zip(filters[:-1], filters[1:]))
 
         blocks = []
         quantize_blocks = []
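
The discriminator builds its filter list in ascending order, so the same clamp only affects the deepest blocks. A short sketch under the same illustrative assumptions (image_size = 256, network_capacity = 16, transparent = False, default fmap_max = 512):

from functools import partial
from math import log2

image_size, network_capacity, fmap_max = 256, 16, 512  # illustrative values only
num_init_filters = 3                                    # transparent = False, so RGB input

num_layers = int(log2(image_size) - 1)
filters = [num_init_filters] + [network_capacity * (2 ** i) for i in range(num_layers + 1)]
print(filters)  # [3, 16, 32, 64, 128, 256, 512, 1024, 2048]

set_fmap_max = partial(min, fmap_max)
filters = list(map(set_fmap_max, filters))
print(filters)  # [3, 16, 32, 64, 128, 256, 512, 512, 512] -- only the two deepest blocks are clamped

chan_in_out = list(zip(filters[:-1], filters[1:]))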
@@ -515,15 +522,15 @@ def forward(self, x):
         return x.squeeze(), quantize_loss
 
 class StyleGAN2(nn.Module):
-    def __init__(self, image_size, latent_dim = 512, style_depth = 8, network_capacity = 16, transparent = False, fp16 = False, cl_reg = False, steps = 1, lr = 1e-4, fq_layers = [], fq_dict_size = 256, attn_layers = [], no_const = False):
+    def __init__(self, image_size, latent_dim = 512, fmap_max = 512, style_depth = 8, network_capacity = 16, transparent = False, fp16 = False, cl_reg = False, steps = 1, lr = 1e-4, fq_layers = [], fq_dict_size = 256, attn_layers = [], no_const = False):
         super().__init__()
         self.lr = lr
         self.steps = steps
         self.ema_updater = EMA(0.995)
 
         self.S = StyleVectorizer(latent_dim, style_depth)
-        self.G = Generator(image_size, latent_dim, network_capacity, transparent = transparent, attn_layers = attn_layers, no_const = no_const)
-        self.D = Discriminator(image_size, network_capacity, fq_layers = fq_layers, fq_dict_size = fq_dict_size, attn_layers = attn_layers, transparent = transparent)
+        self.G = Generator(image_size, latent_dim, network_capacity, transparent = transparent, attn_layers = attn_layers, no_const = no_const, fmap_max = fmap_max)
+        self.D = Discriminator(image_size, network_capacity, fq_layers = fq_layers, fq_dict_size = fq_dict_size, attn_layers = attn_layers, transparent = transparent, fmap_max = fmap_max)
 
         self.SE = StyleVectorizer(latent_dim, style_depth)
         self.GE = Generator(image_size, latent_dim, network_capacity, transparent = transparent, attn_layers = attn_layers, no_const = no_const)
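
With the keyword threaded through StyleGAN2, the cap can now be chosen at construction time. A hypothetical usage sketch; the value 256 is arbitrary rather than a default introduced here, and constructing the model may require a CUDA device in this version of the library:

from stylegan2_pytorch.stylegan2_pytorch import StyleGAN2

# limit every layer of both the generator and the discriminator to at most 256 feature maps
model = StyleGAN2(image_size = 128, fmap_max = 256)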
