diff --git a/bin/stylegan2_pytorch b/bin/stylegan2_pytorch index 3b3016c..b79ee79 100644 --- a/bin/stylegan2_pytorch +++ b/bin/stylegan2_pytorch @@ -32,7 +32,8 @@ def train_from_folder( fq_dict_size = 256, attn_layers = [], no_const = False, - aug_prob = 0. + aug_prob = 0., + dataset_aug_prob = 0., ): model = Trainer( name, @@ -53,7 +54,8 @@ def train_from_folder( fq_dict_size = fq_dict_size, attn_layers = attn_layers, no_const = no_const, - aug_prob = aug_prob + aug_prob = aug_prob, + dataset_aug_prob = dataset_aug_prob ) if not new: diff --git a/setup.py b/setup.py index 7d594ba..29ca8e5 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ name = 'stylegan2_pytorch', packages = find_packages(), scripts=['bin/stylegan2_pytorch'], - version = '0.17.4', + version = '0.17.5', license='GPLv3+', description = 'StyleGan2 in Pytorch', author = 'Phil Wang', diff --git a/stylegan2_pytorch/stylegan2_pytorch.py b/stylegan2_pytorch/stylegan2_pytorch.py index 1fc7b7c..3096f70 100644 --- a/stylegan2_pytorch/stylegan2_pytorch.py +++ b/stylegan2_pytorch/stylegan2_pytorch.py @@ -62,6 +62,16 @@ class Flatten(nn.Module): def forward(self, x): return x.reshape(x.shape[0], -1) +class RandomApply(nn.Module): + def __init__(self, prob, fn, fn_else = lambda x: x): + super().__init__() + self.fn = fn + self.fn_else = fn_else + self.prob = prob + def forward(self, x): + fn = self.fn if random() < self.prob else self.fn_else + return fn(x) + class Residual(nn.Module): def __init__(self, fn): super().__init__() @@ -208,7 +218,7 @@ def resize_to_minimum_size(min_size, image): return image class Dataset(data.Dataset): - def __init__(self, folder, image_size, transparent = False): + def __init__(self, folder, image_size, transparent = False, aug_prob = 0.): super().__init__() self.folder = folder self.image_size = image_size @@ -221,7 +231,7 @@ def __init__(self, folder, image_size, transparent = False): transforms.Lambda(convert_image_fn), transforms.Lambda(partial(resize_to_minimum_size, image_size)), transforms.Resize(image_size), - transforms.CenterCrop(image_size), + RandomApply(aug_prob, transforms.RandomResizedCrop(image_size, scale=(0.5, 1.0), ratio=(0.98, 1.02)), transforms.CenterCrop(image_size)), transforms.ToTensor(), transforms.Lambda(expand_greyscale(num_channels)) ]) @@ -246,7 +256,7 @@ def random_crop_and_resize(tensor, scale): h_delta = int(random() * delta) w_delta = int(random() * delta) cropped = tensor[:, :, h_delta:(h_delta + new_width), w_delta:(w_delta + new_width)].clone() - return F.interpolate(cropped, size=(h, h)) + return F.interpolate(cropped, size=(h, h), mode='bilinear') def random_hflip(tensor, prob): if prob > random(): @@ -588,7 +598,7 @@ def forward(self, x): return x class Trainer(): - def __init__(self, name, results_dir, models_dir, image_size, network_capacity, transparent = False, batch_size = 4, mixed_prob = 0.9, gradient_accumulate_every=1, lr = 2e-4, num_workers = None, save_every = 1000, trunc_psi = 0.6, fp16 = False, cl_reg = False, fq_layers = [], fq_dict_size = 256, attn_layers = [], no_const = False, aug_prob = 0., *args, **kwargs): + def __init__(self, name, results_dir, models_dir, image_size, network_capacity, transparent = False, batch_size = 4, mixed_prob = 0.9, gradient_accumulate_every=1, lr = 2e-4, num_workers = None, save_every = 1000, trunc_psi = 0.6, fp16 = False, cl_reg = False, fq_layers = [], fq_dict_size = 256, attn_layers = [], no_const = False, aug_prob = 0., dataset_aug_prob = 0., *args, **kwargs): self.GAN_params = [args, kwargs] self.GAN = None @@ -638,6 +648,7 @@ def __init__(self, name, results_dir, models_dir, image_size, network_capacity, self.init_folders() self.loader = None + self.dataset_aug_prob = dataset_aug_prob def init_GAN(self): args, kwargs = self.GAN_params @@ -662,7 +673,7 @@ def config(self): return {'image_size': self.image_size, 'network_capacity': self.network_capacity, 'transparent': self.transparent, 'fq_layers': self.fq_layers, 'fq_dict_size': self.fq_dict_size, 'attn_layers': self.attn_layers, 'no_const': self.no_const} def set_data_src(self, folder): - self.dataset = Dataset(folder, self.image_size, transparent = self.transparent) + self.dataset = Dataset(folder, self.image_size, transparent = self.transparent, aug_prob = self.dataset_aug_prob) self.loader = cycle(data.DataLoader(self.dataset, num_workers = default(self.num_workers, num_cores), batch_size = self.batch_size, drop_last = True, shuffle=True, pin_memory=True)) def train(self):