Skip to content

Commit

Permalink
add dataset augmentation
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jul 2, 2020
1 parent 5d00394 commit 8ef4b20
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 8 deletions.
6 changes: 4 additions & 2 deletions bin/stylegan2_pytorch
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ def train_from_folder(
fq_dict_size = 256,
attn_layers = [],
no_const = False,
aug_prob = 0.
aug_prob = 0.,
dataset_aug_prob = 0.,
):
model = Trainer(
name,
Expand All @@ -53,7 +54,8 @@ def train_from_folder(
fq_dict_size = fq_dict_size,
attn_layers = attn_layers,
no_const = no_const,
aug_prob = aug_prob
aug_prob = aug_prob,
dataset_aug_prob = dataset_aug_prob
)

if not new:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
name = 'stylegan2_pytorch',
packages = find_packages(),
scripts=['bin/stylegan2_pytorch'],
version = '0.17.4',
version = '0.17.5',
license='GPLv3+',
description = 'StyleGan2 in Pytorch',
author = 'Phil Wang',
Expand Down
21 changes: 16 additions & 5 deletions stylegan2_pytorch/stylegan2_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,16 @@ class Flatten(nn.Module):
def forward(self, x):
return x.reshape(x.shape[0], -1)

class RandomApply(nn.Module):
def __init__(self, prob, fn, fn_else = lambda x: x):
super().__init__()
self.fn = fn
self.fn_else = fn_else
self.prob = prob
def forward(self, x):
fn = self.fn if random() < self.prob else self.fn_else
return fn(x)

class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
Expand Down Expand Up @@ -208,7 +218,7 @@ def resize_to_minimum_size(min_size, image):
return image

class Dataset(data.Dataset):
def __init__(self, folder, image_size, transparent = False):
def __init__(self, folder, image_size, transparent = False, aug_prob = 0.):
super().__init__()
self.folder = folder
self.image_size = image_size
Expand All @@ -221,7 +231,7 @@ def __init__(self, folder, image_size, transparent = False):
transforms.Lambda(convert_image_fn),
transforms.Lambda(partial(resize_to_minimum_size, image_size)),
transforms.Resize(image_size),
transforms.CenterCrop(image_size),
RandomApply(aug_prob, transforms.RandomResizedCrop(image_size, scale=(0.5, 1.0), ratio=(0.98, 1.02)), transforms.CenterCrop(image_size)),
transforms.ToTensor(),
transforms.Lambda(expand_greyscale(num_channels))
])
Expand All @@ -246,7 +256,7 @@ def random_crop_and_resize(tensor, scale):
h_delta = int(random() * delta)
w_delta = int(random() * delta)
cropped = tensor[:, :, h_delta:(h_delta + new_width), w_delta:(w_delta + new_width)].clone()
return F.interpolate(cropped, size=(h, h))
return F.interpolate(cropped, size=(h, h), mode='bilinear')

def random_hflip(tensor, prob):
if prob > random():
Expand Down Expand Up @@ -588,7 +598,7 @@ def forward(self, x):
return x

class Trainer():
def __init__(self, name, results_dir, models_dir, image_size, network_capacity, transparent = False, batch_size = 4, mixed_prob = 0.9, gradient_accumulate_every=1, lr = 2e-4, num_workers = None, save_every = 1000, trunc_psi = 0.6, fp16 = False, cl_reg = False, fq_layers = [], fq_dict_size = 256, attn_layers = [], no_const = False, aug_prob = 0., *args, **kwargs):
def __init__(self, name, results_dir, models_dir, image_size, network_capacity, transparent = False, batch_size = 4, mixed_prob = 0.9, gradient_accumulate_every=1, lr = 2e-4, num_workers = None, save_every = 1000, trunc_psi = 0.6, fp16 = False, cl_reg = False, fq_layers = [], fq_dict_size = 256, attn_layers = [], no_const = False, aug_prob = 0., dataset_aug_prob = 0., *args, **kwargs):
self.GAN_params = [args, kwargs]
self.GAN = None

Expand Down Expand Up @@ -638,6 +648,7 @@ def __init__(self, name, results_dir, models_dir, image_size, network_capacity,
self.init_folders()

self.loader = None
self.dataset_aug_prob = dataset_aug_prob

def init_GAN(self):
args, kwargs = self.GAN_params
Expand All @@ -662,7 +673,7 @@ def config(self):
return {'image_size': self.image_size, 'network_capacity': self.network_capacity, 'transparent': self.transparent, 'fq_layers': self.fq_layers, 'fq_dict_size': self.fq_dict_size, 'attn_layers': self.attn_layers, 'no_const': self.no_const}

def set_data_src(self, folder):
self.dataset = Dataset(folder, self.image_size, transparent = self.transparent)
self.dataset = Dataset(folder, self.image_size, transparent = self.transparent, aug_prob = self.dataset_aug_prob)
self.loader = cycle(data.DataLoader(self.dataset, num_workers = default(self.num_workers, num_cores), batch_size = self.batch_size, drop_last = True, shuffle=True, pin_memory=True))

def train(self):
Expand Down

0 comments on commit 8ef4b20

Please sign in to comment.