Skip to content

Commit

Permalink
allow for setting in max feature maps, after learning @aydao raised h…
Browse files Browse the repository at this point in the history
…is to 1024 for superb results
  • Loading branch information
lucidrains committed Nov 15, 2020
1 parent 20c2d68 commit e225255
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 2 deletions.
2 changes: 2 additions & 0 deletions stylegan2_pytorch/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def train_from_folder(
load_from = -1,
image_size = 128,
network_capacity = 16,
fmap_max = 512,
transparent = False,
batch_size = 5,
gradient_accumulate_every = 6,
Expand Down Expand Up @@ -117,6 +118,7 @@ def train_from_folder(
gradient_accumulate_every = gradient_accumulate_every,
image_size = image_size,
network_capacity = network_capacity,
fmap_max = fmap_max,
transparent = transparent,
lr = learning_rate,
lr_mlp = lr_mlp,
Expand Down
5 changes: 4 additions & 1 deletion stylegan2_pytorch/stylegan2_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,7 @@ def __init__(
base_dir = './',
image_size = 128,
network_capacity = 16,
fmap_max = 512,
transparent = False,
batch_size = 4,
mixed_prob = 0.9,
Expand Down Expand Up @@ -742,6 +743,7 @@ def __init__(
assert log2(image_size).is_integer(), 'image size must be a power of 2 (64, 128, 256, 512, 1024)'
self.image_size = image_size
self.network_capacity = network_capacity
self.fmap_max = fmap_max
self.transparent = transparent

self.fq_layers = cast_list(fq_layers)
Expand Down Expand Up @@ -819,7 +821,7 @@ def hparams(self):

def init_GAN(self):
args, kwargs = self.GAN_params
self.GAN = StyleGAN2(lr = self.lr, lr_mlp = self.lr_mlp, ttur_mult = self.ttur_mult, image_size = self.image_size, network_capacity = self.network_capacity, transparent = self.transparent, fq_layers = self.fq_layers, fq_dict_size = self.fq_dict_size, attn_layers = self.attn_layers, fp16 = self.fp16, cl_reg = self.cl_reg, no_const = self.no_const, rank = self.rank, *args, **kwargs)
self.GAN = StyleGAN2(lr = self.lr, lr_mlp = self.lr_mlp, ttur_mult = self.ttur_mult, image_size = self.image_size, network_capacity = self.network_capacity, fmap_max = self.fmap_max, transparent = self.transparent, fq_layers = self.fq_layers, fq_dict_size = self.fq_dict_size, attn_layers = self.attn_layers, fp16 = self.fp16, cl_reg = self.cl_reg, no_const = self.no_const, rank = self.rank, *args, **kwargs)

if self.is_ddp:
ddp_kwargs = {'device_ids': [self.rank]}
Expand All @@ -841,6 +843,7 @@ def load_config(self):
self.transparent = config['transparent']
self.fq_layers = config['fq_layers']
self.fq_dict_size = config['fq_dict_size']
self.fmap_max = config.pop('fmap_max', 512)
self.attn_layers = config.pop('attn_layers', [])
self.no_const = config.pop('no_const', False)
del self.GAN
Expand Down
2 changes: 1 addition & 1 deletion stylegan2_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.5.0'
__version__ = '1.5.1'

0 comments on commit e225255

Please sign in to comment.