From 084041fd3a4270869a6cad689a29023ab07bea01 Mon Sep 17 00:00:00 2001 From: Abdul Fatir Date: Sat, 14 Mar 2020 23:25:49 +0800 Subject: [PATCH] Added code --- .gitignore | 1 + README.md | 15 ++- celeba.sh | 37 ++++++ celeba128.sh | 40 +++++++ cifar10.sh | 36 ++++++ data/.gitkeep | 0 download.py | 180 +++++++++++++++++++++++++++++ mnist.sh | 36 ++++++ src/cfgan.py | 180 +++++++++++++++++++++++++++++ src/datasets.py | 150 ++++++++++++++++++++++++ src/ecfd.py | 185 +++++++++++++++++++++++++++++ src/gan.py | 272 +++++++++++++++++++++++++++++++++++++++++++ src/gen_samples.py | 58 ++++++++++ src/main.py | 82 +++++++++++++ src/networks.py | 282 +++++++++++++++++++++++++++++++++++++++++++++ src/resnet.py | 83 +++++++++++++ src/util.py | 49 ++++++++ stl10.sh | 36 ++++++ 18 files changed, 1718 insertions(+), 4 deletions(-) create mode 100644 celeba.sh create mode 100644 celeba128.sh create mode 100644 cifar10.sh create mode 100644 data/.gitkeep create mode 100644 download.py create mode 100644 mnist.sh create mode 100644 src/cfgan.py create mode 100644 src/datasets.py create mode 100644 src/ecfd.py create mode 100644 src/gan.py create mode 100644 src/gen_samples.py create mode 100644 src/main.py create mode 100644 src/networks.py create mode 100644 src/resnet.py create mode 100644 src/util.py create mode 100644 stl10.sh diff --git a/.gitignore b/.gitignore index e43b0f9..6a4dae1 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ .DS_Store +__pycache__ diff --git a/README.md b/README.md index 4c7e12c..6c2465c 100644 --- a/README.md +++ b/README.md @@ -30,10 +30,10 @@ and the ECFD between two distributions is given by ### Generating samples from pre-trained models -* Download the pre-trained models from releases. +* Download the pre-trained generators from releases. 
* Run the following command to generate an 8x8 grid of samples from a model trained on CIFAR10 dataset: ```bash -python gen_samples.py\ +python src/gen_samples.py\ --png\ --imsize 32\ --noise_dim 32\ @@ -46,7 +46,7 @@ python gen_samples.py\ * **Downloading Datasets**: All the datasets will download by themselves when the code is run, except CelebA. CelebA can be downloaded by executing `python download.py celebA`. Rename the directory `./data/img_align_celeba` to `./data/celebA` after the script finishes execution. * Run `python src/main.py --help` to see a description of all the available command-line arguments. -* Run the following command to train OCFGAN-GP on the CIFAR10 dataset: +* **Example**: run the following command to train on the CIFAR10 dataset: ```bash python src/main.py\ --dataset cifar10\ @@ -80,4 +80,11 @@ For any questions regarding the code or the paper, please email me at [abdulfati booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, year={2020} } -``` \ No newline at end of file +``` + +#### Acknowledgements +Parts of the code/network structures in this repository have been adapted from the following repos: + +* [ozanciga/gans-with-pytorch](https://github.com/ozanciga/gans-with-pytorch) +* [OctoberChang/MMD-GAN](https://github.com/OctoberChang/MMD-GAN) +* [mbinkowski/MMD-GAN](https://github.com/mbinkowski/MMD-GAN) \ No newline at end of file diff --git a/celeba.sh b/celeba.sh new file mode 100644 index 0000000..ccefe52 --- /dev/null +++ b/celeba.sh @@ -0,0 +1,37 @@ +#!/bin/bash + +BS=64 +GPU_ID=0 +MAX_GITER=125000 +DATA_PATH=./data +DATASET=celeba +DATAROOT=${DATA_PATH}/celebA +ISIZE=32 +NC=3 +NOISE_DIM=64 +MODEL=cfgangp +DOUT_DIM=32 +NUM_FREQS=8 +WEIGHT=gaussian_ecfd +SIGMA=0. 
+ +cmd="python src/main.py\ + --dataset ${DATASET}\ + --dataroot ${DATAROOT}\ + --model ${MODEL}\ + --batch_size ${BS}\ + --image_size ${ISIZE}\ + --nc ${NC}\ + --noise_dim ${NOISE_DIM}\ + --dout_dim ${DOUT_DIM}\ + --max_giter ${MAX_GITER}\ + --resultsroot ./out + --gpu_device ${GPU_ID}" + +if [ ${MODEL} == 'cfgangp' ]; then + cmd+=" --num_freqs ${NUM_FREQS} --weight ${WEIGHT} --sigmas ${SIGMA}" +fi + +echo $cmd +eval $cmd + diff --git a/celeba128.sh b/celeba128.sh new file mode 100644 index 0000000..b7037b6 --- /dev/null +++ b/celeba128.sh @@ -0,0 +1,40 @@ +#!/bin/bash + +BS=64 +GPU_ID=2 +MAX_GITER=125000 +DATA_PATH=./data +DATASET=celeba128 +DATAROOT=${DATA_PATH}/celebA +ISIZE=128 +NC=3 +NOISE_DIM=100 + +MODEL=cfgangp +DOUT_DIM=1 +NUM_FREQS=8 +WEIGHT=gaussian_ecfd +SIGMA=0. + +cmd="python src/main.py\ + --dataset ${DATASET}\ + --dataroot ${DATAROOT}\ + --model ${MODEL}\ + --gen resnet + --disc dcgan5 + --batch_size ${BS}\ + --image_size ${ISIZE}\ + --nc ${NC}\ + --noise_dim ${NOISE_DIM}\ + --dout_dim ${DOUT_DIM}\ + --max_giter ${MAX_GITER}\ + --resultsroot ./out + --gpu_device ${GPU_ID}" + +if [ ${MODEL} == 'cfgangp' ]; then + cmd+=" --num_freqs ${NUM_FREQS} --weight ${WEIGHT} --sigmas ${SIGMA}" +fi + +echo $cmd +eval $cmd + diff --git a/cifar10.sh b/cifar10.sh new file mode 100644 index 0000000..0f6d587 --- /dev/null +++ b/cifar10.sh @@ -0,0 +1,36 @@ +#!/bin/bash + +BS=64 +GPU_ID=0 +MAX_GITER=125000 +DATA_PATH=./data +DATASET=cifar10 +DATAROOT=${DATA_PATH}/cifar10 +ISIZE=32 +NC=3 +NOISE_DIM=32 +MODEL=cfgangp +DOUT_DIM=${NOISE_DIM} +NUM_FREQS=8 +WEIGHT=gaussian_ecfd +SIGMA=0. 
+ +cmd="python src/main.py\ + --dataset ${DATASET}\ + --dataroot ${DATAROOT}\ + --model ${MODEL}\ + --batch_size ${BS}\ + --image_size ${ISIZE}\ + --nc ${NC}\ + --noise_dim ${NOISE_DIM}\ + --dout_dim ${DOUT_DIM}\ + --max_giter ${MAX_GITER}\ + --resultsroot ./out + --gpu_device ${GPU_ID}" + +if [ ${MODEL} == 'cfgangp' ]; then + cmd+=" --num_freqs ${NUM_FREQS} --weight ${WEIGHT} --sigmas ${SIGMA}" +fi + +echo $cmd +eval $cmd diff --git a/data/.gitkeep b/data/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/download.py b/download.py new file mode 100644 index 0000000..cdd416e --- /dev/null +++ b/download.py @@ -0,0 +1,180 @@ +""" +Modification of https://github.com/stanfordnlp/treelstm/blob/master/scripts/download.py + +Downloads the following: +- Celeb-A dataset +- LSUN dataset +- MNIST dataset +""" + +from __future__ import print_function +import os +import sys +import gzip +import json +import shutil +import zipfile +import argparse +import requests +import subprocess +from tqdm import tqdm +from six.moves import urllib + +parser = argparse.ArgumentParser(description='Download dataset for DCGAN.') +parser.add_argument('datasets', metavar='N', type=str, nargs='+', choices=['celebA', 'lsun', 'mnist'], + help='name of dataset to download [celebA, lsun, mnist]') + +def download(url, dirpath): + filename = url.split('/')[-1] + filepath = os.path.join(dirpath, filename) + u = urllib.request.urlopen(url) + f = open(filepath, 'wb') + filesize = int(u.headers["Content-Length"]) + print("Downloading: %s Bytes: %s" % (filename, filesize)) + + downloaded = 0 + block_sz = 8192 + status_width = 70 + while True: + buf = u.read(block_sz) + if not buf: + print('') + break + else: + print('', end='\r') + downloaded += len(buf) + f.write(buf) + status = (("[%-" + str(status_width + 1) + "s] %3.2f%%") % + ('=' * int(float(downloaded) / filesize * status_width) + '>', downloaded * 100. 
/ filesize)) + print(status, end='') + sys.stdout.flush() + f.close() + return filepath + +def download_file_from_google_drive(id, destination): + URL = "https://docs.google.com/uc?export=download" + session = requests.Session() + + response = session.get(URL, params={ 'id': id }, stream=True) + token = get_confirm_token(response) + + if token: + params = { 'id' : id, 'confirm' : token } + response = session.get(URL, params=params, stream=True) + + save_response_content(response, destination) + +def get_confirm_token(response): + for key, value in response.cookies.items(): + if key.startswith('download_warning'): + return value + return None + +def save_response_content(response, destination, chunk_size=32*1024): + total_size = int(response.headers.get('content-length', 0)) + with open(destination, "wb") as f: + for chunk in tqdm(response.iter_content(chunk_size), total=total_size, + unit='B', unit_scale=True, desc=destination): + if chunk: # filter out keep-alive new chunks + f.write(chunk) + +def unzip(filepath): + print("Extracting: " + filepath) + dirpath = os.path.dirname(filepath) + with zipfile.ZipFile(filepath) as zf: + zf.extractall(dirpath) + os.remove(filepath) + +def download_celeb_a(dirpath): + data_dir = 'celebA' + if os.path.exists(os.path.join(dirpath, data_dir)): + print('Found Celeb-A - skip') + return + + filename, drive_id = "img_align_celeba.zip", "0B7EVK8r0v71pZjFTYXZWM3FlRnM" + save_path = os.path.join(dirpath, filename) + + if os.path.exists(save_path): + print('[*] {} already exists'.format(save_path)) + else: + download_file_from_google_drive(drive_id, save_path) + + zip_dir = '' + with zipfile.ZipFile(save_path) as zf: + zip_dir = zf.namelist()[0] + zf.extractall(dirpath) + os.remove(save_path) + os.rename(os.path.join(dirpath, zip_dir), os.path.join(dirpath, data_dir)) + +def _list_categories(tag): + url = 'http://lsun.cs.princeton.edu/htbin/list.cgi?tag=' + tag + f = urllib.request.urlopen(url) + return json.loads(f.read()) + +def 
_download_lsun(out_dir, category, set_name, tag): + url = 'http://lsun.cs.princeton.edu/htbin/download.cgi?tag={tag}' \ + '&category={category}&set={set_name}'.format(**locals()) + print(url) + if set_name == 'test': + out_name = 'test_lmdb.zip' + else: + out_name = '{category}_{set_name}_lmdb.zip'.format(**locals()) + out_path = os.path.join(out_dir, out_name) + cmd = ['curl', url, '-o', out_path] + print('Downloading', category, set_name, 'set') + subprocess.call(cmd) + +def download_lsun(dirpath): + data_dir = os.path.join(dirpath, 'lsun') + if os.path.exists(data_dir): + print('Found LSUN - skip') + return + else: + os.mkdir(data_dir) + + tag = 'latest' + #categories = _list_categories(tag) + categories = ['bedroom'] + + for category in categories: + _download_lsun(data_dir, category, 'train', tag) + _download_lsun(data_dir, category, 'val', tag) + _download_lsun(data_dir, '', 'test', tag) + +def download_mnist(dirpath): + data_dir = os.path.join(dirpath, 'mnist') + if os.path.exists(data_dir): + print('Found MNIST - skip') + return + else: + os.mkdir(data_dir) + url_base = 'http://yann.lecun.com/exdb/mnist/' + file_names = ['train-images-idx3-ubyte.gz', + 'train-labels-idx1-ubyte.gz', + 't10k-images-idx3-ubyte.gz', + 't10k-labels-idx1-ubyte.gz'] + for file_name in file_names: + url = (url_base+file_name).format(**locals()) + print(url) + out_path = os.path.join(data_dir,file_name) + cmd = ['curl', url, '-o', out_path] + print('Downloading ', file_name) + subprocess.call(cmd) + cmd = ['gzip', '-d', out_path] + print('Decompressing ', file_name) + subprocess.call(cmd) + +def prepare_data_dir(path = './data'): + if not os.path.exists(path): + os.mkdir(path) + +if __name__ == '__main__': + args = parser.parse_args() + prepare_data_dir() + + if any(name in args.datasets for name in ['CelebA', 'celebA', 'celebA']): + download_celeb_a('./data') + if 'lsun' in args.datasets: + download_lsun('./data') + if 'mnist' in args.datasets: + download_mnist('./data') diff --git 
a/mnist.sh b/mnist.sh new file mode 100644 index 0000000..394abe5 --- /dev/null +++ b/mnist.sh @@ -0,0 +1,36 @@ +#!/bin/bash + +BS=64 +GPU_ID=0 +MAX_GITER=50000 +DATA_PATH=./data +DATASET=mnist +DATAROOT=${DATA_PATH}/mnist +ISIZE=32 +NC=1 +NOISE_DIM=10 +MODEL=cfgangp +DOUT_DIM=${NOISE_DIM} +NUM_FREQS=8 +WEIGHT=gaussian_ecfd +SIGMA=0. + +cmd="python src/main.py\ + --dataset ${DATASET}\ + --dataroot ${DATAROOT}\ + --model ${MODEL}\ + --batch_size ${BS}\ + --image_size ${ISIZE}\ + --nc ${NC}\ + --noise_dim ${NOISE_DIM}\ + --dout_dim ${DOUT_DIM}\ + --max_giter ${MAX_GITER}\ + --resultsroot ./out + --gpu_device ${GPU_ID}" + +if [ ${MODEL} == 'cfgangp' ]; then + cmd+=" --num_freqs ${NUM_FREQS} --weight ${WEIGHT} --sigmas ${SIGMA}" +fi + +echo $cmd +eval $cmd diff --git a/src/cfgan.py b/src/cfgan.py new file mode 100644 index 0000000..352a466 --- /dev/null +++ b/src/cfgan.py @@ -0,0 +1,180 @@ +# Provides the CFGAN-GP model. +# +# Copyright (c) 2020 Abdul Fatir Ansari. All rights reserved. +# This work is licensed under the terms of the MIT license. +# For a copy, see . +# +import torch +import ecfd +import os +from gan import GAN + +class CFGANGP(GAN): + def __init__(self, + dset_name, + imsize, + nc, + data_root='./data', + results_root='./results', + noise_dim=100, + dout_dim=1, + batch_size=64, + max_giters=50000, + lr=1e-4, + clip_disc=False, + disc_size=64, + gp_lambda=10., + ecfd_type='gaussian_ecfd', + sigmas=[1.0], + num_freqs=8, + optimize_sigma=False, + disc_net='flexible-dcgan', + gen_net='flexible-dcgan'): + """Intializer for a CFGANGP model. + + Arguments: + dset_name {str} -- Name of the dataset. + imsize {int} -- Size of the image. + nc {int} -- Number of channels. + + Keyword Arguments: + data_root {str} -- Directory where datasets are stored (default: {'./data'}). + results_root {str} -- Directory where results will be saved (default: {'./results'}). + noise_dim {int} -- Dimension of noise input to generator (default: {100}). 
+ dout_dim {int} -- Dimension of output from discriminator (default: {1}). + batch_size {int} -- Batch size (default: {64}). + max_giters {int} -- Maximum number of generator iterations (default: {50000}). + lr {float} -- Learning rate (default: {1e-4}). + clip_disc {bool} -- Whether to clip the parameters of discriminator in [-0.01, 0.01]. + This should be True when gradient penalty is not used (default: {False}). Note that the value forwarded to the base GAN initializer is `gp_lambda == 0.`, not this argument. + disc_size {int} -- Number of filters in the first Conv layer of critic (default: {64}). + gp_lambda {float} -- Trade-off for gradient penalty (default: {10.0}). + ecfd_type {str} -- Weighting distribution for ECFD (default: {'gaussian_ecfd'}). + sigmas {list} -- A list of sigmas (default: {[1.0]}). + num_freqs {int} -- Number of random frequencies for ECFD (default: {8}). + optimize_sigma {bool} -- Whether to optimize sigma (default: {False}). + disc_net {str} -- Discriminator network type (default: {'flexible-dcgan'}). + gen_net {str} -- Generator network type (default: {'flexible-dcgan'}). 
+ """ + GAN.__init__(self, dset_name, imsize, nc, + data_root=data_root, results_root=results_root, + noise_dim=noise_dim, dout_dim=dout_dim, + batch_size=batch_size, clip_disc=gp_lambda == 0., + max_giters=max_giters, lr=lr, disc_size=disc_size, + batch_norm=False, gen_net=gen_net, disc_net=disc_net) + self.ecfd_fn = getattr(ecfd, ecfd_type) + self.optimize_sigma = optimize_sigma + self.num_freqs = num_freqs + self.reg_lambda = 16.0 + cls_name = self.__class__.__name__.lower() + if optimize_sigma: + cls_name = 'o' + cls_name + self.results_root = os.path.join(results_root, dset_name, cls_name) + if optimize_sigma: + self.lg_sigmas = torch.zeros((1, dout_dim)).cuda() + self.lg_sigmas.requires_grad = True + self.d_optim = torch.optim.RMSprop( + list(self.discriminator.parameters()) + [self.lg_sigmas], lr=lr) + self.results_root += '_{:s}'.format(ecfd_type) + else: + self.sigmas = sigmas + self.results_root += '_{:s}_{:s}'.format( + ecfd_type, '_'.join(map(str, sigmas))) + self.gp_lambda = gp_lambda + self.results_root = os.path.join(self.results_root, self.gen_net) + self.ensure_dirs() + + def _reset_grad(self): + """Resets the gradient of discriminator and lg_sigmas + (if sigma is being optimized) to zero. + """ + super()._reset_grad() + if self.optimize_sigma and self.lg_sigmas.grad is not None: + self.lg_sigmas.grad.data.zero_() + + def disc_loss(self, reals, fakes): + """Computes the discriminator loss -ECFD + GP + OneSideErr. + + Arguments: + reals {torch.Tensor} -- A batch of real images. + fakes {torch.Tensor} -- A batch of fake images. + + Returns: + torch.Tensor -- The discriminator loss. 
+ """ + d_real = self.discriminator(reals) + d_fake = self.discriminator(fakes) + if self.optimize_sigma: + sigmas = torch.exp(self.lg_sigmas) + else: + sigmas = self.sigmas + ecfd_loss = self.ecfd_fn( + d_real, d_fake, sigmas, num_freqs=self.num_freqs, + optimize_sigma=self.optimize_sigma) + if self.gp_lambda > 0.0: + gp = self.gradient_penalty(reals, fakes) + else: + gp = 0.0 + reg = self.one_side_error(d_real, d_fake) + loss = -torch.sqrt(ecfd_loss) + self.gp_lambda * gp + self.reg_lambda * reg + return loss + + def gen_loss(self, reals, fakes): + """Computes the generator loss ECFD - OneSideErr. + + Arguments: + reals {torch.Tensor} -- A batch of real images. + fakes {torch.Tensor} -- A batch of fake images. + Returns: + torch.Tensor -- The generator loss. + """ + d_real = self.discriminator(reals) + d_fake = self.discriminator(fakes) + if self.optimize_sigma: + sigmas = torch.exp(self.lg_sigmas) + else: + sigmas = self.sigmas + ecfd_loss = self.ecfd_fn( + d_real, d_fake, sigmas, num_freqs=self.num_freqs, + optimize_sigma=self.optimize_sigma) + reg = self.one_side_error(d_real, d_fake) + return torch.sqrt(ecfd_loss) - self.reg_lambda * reg + + def one_side_error(self, d_real, d_fake): + """Computes one sided penalty. + Adapted from: https://github.com/OctoberChang/MMD-GAN/blob/b15c98/mmd_gan.py#L57 + + Arguments: + reals {torch.Tensor} -- A batch of d(real images). + fakes {torch.Tensor} -- A batch of d(fake images). + + Returns: + torch.Tensor -- The one sided penalty. + """ + diff = d_real.mean(0) - d_fake.mean(0) + err = torch.relu(-diff) + return err.mean() + + def gradient_penalty(self, reals, fakes): + """Computes gradient penalty. + + Arguments: + reals {torch.Tensor} -- A batch of real images. + fakes {torch.Tensor} -- A batch of fake images. + + Returns: + torch.Tensor -- The gradient penalty. 
+ """ + batch_size = reals.size(0) + alpha = torch.rand(batch_size, 1, 1, 1).cuda() + interpolations = alpha * reals + (1 - alpha) * fakes + interpolations.requires_grad = True + d_interpolations = self.discriminator(interpolations) + gradients = torch.autograd.grad( + d_interpolations, interpolations, + grad_outputs=torch.ones(d_interpolations.size()).cuda(), + create_graph=True, + retain_graph=True)[0] + gradients = gradients.view(batch_size, -1) + gradients_norm = torch.sum(gradients ** 2 + 1e-7, 1).sqrt() + return torch.mean((gradients_norm - 1) ** 2) diff --git a/src/datasets.py b/src/datasets.py new file mode 100644 index 0000000..34836ec --- /dev/null +++ b/src/datasets.py @@ -0,0 +1,150 @@ +# Provides torch Datasets. +# +# Copyright (c) 2020 Abdul Fatir Ansari. All rights reserved. +# This work is licensed under the terms of the MIT license. +# For a copy, see . +# +import os +import numpy as np +import torch +import torchvision.datasets as dset +import torchvision.transforms as transforms +from torch.utils.data import Dataset +from glob import glob +from PIL import Image + + +class ImageFolderDataset(Dataset): + # Contructs a dataset from a folder with images + def __init__(self, root, input_transform=None): + self.image_filenames = [x for x in glob(root + '/*') if is_image_file(x.lower())] + + self.input_transform = input_transform + + def __getitem__(self, index): + x = load_img(self.image_filenames[index]) + if self.input_transform: + x = self.input_transform(x) + return x, 0 + + def __len__(self): + return len(self.image_filenames) + +class PKLDataset(Dataset): + # Construct a dataset from a .pkl file + def __init__(self, pkl_file): + print('[*] Loading dataset from %s' % pkl_file) + import pickle + with open(pkl_file, 'rb') as fobj: + self.images = pickle.load(fobj) + print('[*] Dataset loaded') + + def __getitem__(self, index): + x = self.images[index] + return x, 0 + + def __len__(self): + return len(self.images) + +class PTDataset(Dataset): + # 
Construct a dataset from a .pt file + def __init__(self, pt_file): + print('[*] Loading dataset from %s' % pt_file) + self.images = torch.load(pt_file) + print('[*] Dataset loaded') + + def __getitem__(self, index): + x = self.images[index] + return x, 0 + + def __len__(self): + return len(self.images) + + +def is_image_file(filename): + """Checks if a file is an image. + + Arguments: + filename {str} -- File path. + + Returns: + bool -- True if the path is PNG or JPG image. + """ + return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"]) + + +def load_img(filepath): + """Loads an image from file. + + Arguments: + filepath {str} -- Path of the image. + + Returns: + PIL.Image.Image -- A PIL Image object. + """ + img = Image.open(filepath).convert('RGB') + return img + + +def get_dataset(dset_name, data_root='./data', imsize=None, train=True): + """Creates and returns a torch dataset. + + Arguments: + dset_name {str} -- Name of the dataset. + + Keyword Arguments: + data_root {str} -- Directory where datasets are stored (default: {'./data'}). + imsize {int} -- Size of the image (default: {None}). + train {bool} -- Whether to load the train split (default: {True}). 
+ + Returns: + Dataset -- A torch dataset, + """ + sizes = {'mnist': 32, + 'cifar10': 32, + 'stl10': 32, + 'celeba': 32, + 'celeba128': 128} + assert dset_name in sizes.keys(), 'Unknown dataset {0}'.format(dset_name) + if imsize is None: + imsize = sizes[dset_name] + # Resize, Center-crop, and normalize to [-1, 1] + transform = transforms.Compose([ + transforms.Resize(imsize), + transforms.CenterCrop(imsize), + transforms.ToTensor(), + transforms.Normalize( + (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ]) + transform_with_antialias = transforms.Compose([ + transforms.Resize(imsize, Image.ANTIALIAS), + transforms.CenterCrop(imsize), + transforms.ToTensor(), + transforms.Normalize( + (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ]) + if dset_name == 'mnist': + transform = transforms.Compose([ + transforms.Resize(imsize), + transforms.CenterCrop(imsize), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5])]) + dataset = dset.MNIST(root=data_root, download=True, + train=train, transform=transform) + elif dset_name == 'cifar10': + dataset = dset.CIFAR10(root=data_root, download=True, + train=train, transform=transform) + elif dset_name == 'stl10': + dataset = dset.STL10(root=data_root, download=True, + split='unlabeled', transform=transform) + elif dset_name == 'celeba': + dataset = ImageFolderDataset(data_root, input_transform=transform_with_antialias) + path = os.path.join(data_root, 'celeb.pkl') + if os.path.exists(path): + dataset = PKLDataset(path) + elif dset_name == 'celeba128': + dataset = ImageFolderDataset(data_root, input_transform=transform_with_antialias) + path = os.path.join(data_root, 'celeba128.pt') + if os.path.exists(path): + dataset = PTDataset(path) + return dataset diff --git a/src/ecfd.py b/src/ecfd.py new file mode 100644 index 0000000..fd63400 --- /dev/null +++ b/src/ecfd.py @@ -0,0 +1,185 @@ +# Provides functions to compute +# Empirical Characteristic Function Distance (ECFD) +# with different weighting distributions. 
+# +# Copyright (c) 2020 Abdul Fatir Ansari. All rights reserved. +# This work is licensed under the terms of the MIT license. +# For a copy, see . +# +import torch +import numpy as np + +def gaussian_ecfd(X, Y, sigmas, num_freqs=8, optimize_sigma=False): + """Computes ECFD with Gaussian weighting distribution. + + Arguments: + X {torch.Tensor} -- Samples from distribution P of shape [B x D]. + Y {torch.Tensor} -- Samples from distribution Q of shape [B x D]. + sigmas {list} or {torch.Tensor} -- A list of floats or a torch Tensor of + shape [1 x D] if optimize_sigma is True. + + Keyword Arguments: + num_freqs {int} -- Number of random frequencies to use (default: {8}). + optimize_sigma {bool} -- Whether to optimize sigma (default: {False}). + + Returns: + torch.Tensor -- The ECFD. + """ + total_loss = 0.0 + if not optimize_sigma: + for sigma in sigmas: + batch_loss = _gaussian_ecfd(X, Y, sigma, num_freqs=num_freqs) + total_loss += batch_loss + else: + batch_loss = _gaussian_ecfd(X, Y, sigmas, num_freqs=num_freqs) + total_loss += batch_loss / torch.norm(sigmas, p=2) + return total_loss + + +def _gaussian_ecfd(X, Y, sigma, num_freqs=8): + wX, wY = 1.0, 1.0 + X, Y = X.view(X.size(0), -1), Y.view(Y.size(0), -1) + batch_size, dim = X.size() + t = torch.randn((num_freqs, dim)).cuda() * sigma + X_reshaped = X.view((batch_size, dim)) + tX = torch.matmul(t, X_reshaped.t()) + cos_tX = (torch.cos(tX) * wX).mean(1) + sin_tX = (torch.sin(tX) * wX).mean(1) + Y_reshaped = Y.view((batch_size, dim)) + tY = torch.matmul(t, Y_reshaped.t()) + cos_tY = (torch.cos(tY) * wY).mean(1) + sin_tY = (torch.sin(tY) * wY).mean(1) + loss = (cos_tX - cos_tY) ** 2 + (sin_tX - sin_tY) ** 2 + return loss.mean() + + +def laplace_ecfd(X, Y, sigmas, num_freqs=8, optimize_sigma=False): + """Computes ECFD with Laplace weighting distribution. + + Arguments: + X {torch.Tensor} -- Samples from distribution P of shape [B x D]. + Y {torch.Tensor} -- Samples from distribution Q of shape [B x D]. 
+ sigmas {list} or {torch.Tensor} -- A list of floats or a torch Tensor of + shape [1 x D] if optimize_sigma is True. + + Keyword Arguments: + num_freqs {int} -- Number of random frequencies to use (default: {8}). + optimize_sigma {bool} -- Whether to optimize sigma (default: {False}). + + Returns: + torch.Tensor -- The ECFD. + """ + total_loss = 0.0 + if not optimize_sigma: + for sigma in sigmas: + batch_loss = _laplace_ecfd(X, Y, sigma, num_freqs=num_freqs) + total_loss += batch_loss + else: + batch_loss = _laplace_ecfd(X, Y, sigmas, num_freqs=num_freqs) + total_loss += batch_loss / torch.norm(sigmas, p=2) + return total_loss + + +def _laplace_ecfd(X, Y, sigma, num_freqs=8): + X, Y = X.view(X.size(0), -1), Y.view(Y.size(0), -1) + batch_size, dim = X.size() + t = torch.cuda.FloatTensor( + np.random.laplace(size=(num_freqs, dim))) * sigma + X_reshaped = X.view((batch_size, dim)) + tX = torch.matmul(t, X_reshaped.t()) + cos_tX = torch.cos(tX).mean(1) + sin_tX = torch.sin(tX).mean(1) + Y_reshaped = Y.view((batch_size, dim)) + tY = torch.matmul(t, Y_reshaped.t()) + cos_tY = torch.cos(tY).mean(1) + sin_tY = torch.sin(tY).mean(1) + loss = (cos_tX - cos_tY) ** 2 + (sin_tX - sin_tY) ** 2 + return loss.mean() + +def studentT_ecfd(X, Y, sigmas, num_freqs=8, optimize_sigma=False, dof=2.0): + """Computes ECFD with Student's-t weighting distribution with dof = 2. + + Arguments: + X {torch.Tensor} -- Samples from distribution P of shape [B x D]. + Y {torch.Tensor} -- Samples from distribution Q of shape [B x D]. + sigmas {list} or {torch.Tensor} -- A list of floats or a torch Tensor of + shape [1 x D] if optimize_sigma is True. + + Keyword Arguments: + num_freqs {int} -- Number of random frequencies to use (default: {8}). + optimize_sigma {bool} -- Whether to optimize sigma (default: {False}). + dof {float} -- Degrees of freedom. + + Returns: + torch.Tensor -- The ECFD. 
+ """ + total_loss = 0.0 + if not optimize_sigma: + for sigma in sigmas: + batch_loss = _studentT_ecfd( + X, Y, sigma, num_freqs=num_freqs, dof=dof) + total_loss += batch_loss + else: + batch_loss = _studentT_ecfd(X, Y, sigmas, num_freqs=num_freqs, dof=dof) + total_loss += batch_loss / torch.norm(sigmas, p=2) + return total_loss + + +def _studentT_ecfd(X, Y, sigma, num_freqs=8, dof=2.0): + X, Y = X.view(X.size(0), -1), Y.view(Y.size(0), -1) + batch_size, dim = X.size() + t = torch.cuda.FloatTensor( + np.random.standard_t(dof, (num_freqs, dim))) * sigma + X_reshaped = X.view((batch_size, dim)) + tX = torch.matmul(t, X_reshaped.t()) + cos_tX = torch.cos(tX).mean(1) + sin_tX = torch.sin(tX).mean(1) + Y_reshaped = Y.view((batch_size, dim)) + tY = torch.matmul(t, Y_reshaped.t()) + cos_tY = torch.cos(tY).mean(1) + sin_tY = torch.sin(tY).mean(1) + loss = (cos_tX - cos_tY) ** 2 + (sin_tX - sin_tY) ** 2 + return loss.mean() + + +def uniform_ecfd(X, Y, sigmas, num_freqs=8, optimize_sigma=False): + """Computes ECFD with Uniform weighting distribution [-sigma, sigma]. + + Arguments: + X {torch.Tensor} -- Samples from distribution P of shape [B x D]. + Y {torch.Tensor} -- Samples from distribution Q of shape [B x D]. + sigmas {list} or {torch.Tensor} -- A list of floats or a torch Tensor of + shape [1 x D] if optimize_sigma is True. + + Keyword Arguments: + num_freqs {int} -- Number of random frequencies to use (default: {8}). + optimize_sigma {bool} -- Whether to optimize sigma (default: {False}). + + Returns: + torch.Tensor -- The ECFD. 
+ """ + total_loss = 0.0 + if not optimize_sigma: + for sigma in sigmas: + batch_loss = _uniform_ecfd(X, Y, sigma, num_freqs=num_freqs) + total_loss += batch_loss + else: + batch_loss = _uniform_ecfd(X, Y, sigmas, num_freqs=num_freqs) + total_loss += batch_loss / torch.norm(sigmas, p=2) + return total_loss + + +def _uniform_ecfd(X, Y, sigma, num_freqs=8): + X, Y = X.view(X.size(0), -1), Y.view(Y.size(0), -1) + batch_size, dim = X.size() + t = (2 * torch.rand((num_freqs, dim)).cuda() - 1.0) * sigma + X_reshaped = X.view((batch_size, dim)) + tX = torch.matmul(t, X_reshaped.t()) + cos_tX = torch.cos(tX).mean(1) + sin_tX = torch.sin(tX).mean(1) + Y_reshaped = Y.view((batch_size, dim)) + tY = torch.matmul(t, Y_reshaped.t()) + cos_tY = torch.cos(tY).mean(1) + sin_tY = torch.sin(tY).mean(1) + loss = (cos_tX - cos_tY) ** 2 + (sin_tX - sin_tY) ** 2 + return loss.mean() diff --git a/src/gan.py b/src/gan.py new file mode 100644 index 0000000..e08bb30 --- /dev/null +++ b/src/gan.py @@ -0,0 +1,272 @@ +# Abstract class for GAN models. +# +# Copyright (c) 2020 Abdul Fatir Ansari. All rights reserved. +# This work is licensed under the terms of the MIT license. +# For a copy, see . +# +import os +import timeit +import torch +import torch.utils.data as tdata +import torch.backends.cudnn as cudnn +import numpy as np +from abc import ABC, abstractmethod + +import networks +from datasets import get_dataset +from util import im2grid + + +class GAN(ABC): + def __init__(self, + dset_name, + imsize, + nc, + data_root='./data', + results_root='./results', + noise_dim=100, + dout_dim=1, + batch_size=64, + clip_disc=True, + max_giters=50000, + lr=1e-4, + disc_size=64, + batch_norm=True, + disc_net='flexible-dcgan', + gen_net='flexible-dcgan'): + """Intializer for base GAN model. + + Arguments: + dset_name {str} -- Name of the dataset. + imsize {int} -- Size of the image. + nc {int} -- Number of channels. 
+ + Keyword Arguments: + data_root {str} -- Directory where datasets are stored (default: {'./data'}). + results_root {str} -- Directory where results will be saved (default: {'./results'}). + noise_dim {int} -- Dimension of noise input to generator (default: {100}). + dout_dim {int} -- Dimension of output from discriminator (default: {1}). + batch_size {int} -- Batch size (default: {64}). + clip_disc {bool} -- Whether to clip the parameters of discriminator in [-0.01, 0.01]. + This should be True when gradient penalty is not used (default: {True}). + max_giters {int} -- Maximum number of generator iterations (default: {50000}). + lr {float} -- Learning rate (default: {1e-4}). + disc_size {int} -- Number of filters in the first Conv layer of critic (default: {64}). + batch_norm {bool} -- Whether to use batch norm in discriminator. This should be + False when gradient penalty is used (default: {True}). + disc_net {str} -- Discriminator network type (default: {'flexible-dcgan'}). + gen_net {str} -- Generator network type. 
(default: {'flexible-dcgan'}) + """ + self.imsize = imsize + self.nc = nc + self.noise_dim = noise_dim + self.dout_dim = dout_dim + self.disc_size = disc_size + self.batch_norm = batch_norm + self.disc_net = disc_net + self.gen_net = gen_net + self._build_model() + self.g_optim = torch.optim.RMSprop( + self.generator.parameters(), lr=lr) + self.d_optim = torch.optim.RMSprop( + self.discriminator.parameters(), lr=lr) + self.giters = 1 + self.diters = 5 + self.max_giters = max_giters + self.data_root = data_root + suffix = self.__class__.__name__.lower() + suffix += '_' + str(self.disc_size) if self.disc_size != 64 else '' + self.results_root = os.path.join( + results_root, dset_name, suffix) + self.clip_disc = clip_disc + self.model_save_interval = 1000 + self.fixed_im_interval = 100 + self.fixed_noise = torch.cuda.FloatTensor( + batch_size, self.noise_dim, 1, 1).normal_(0, 1) + self.noise_tensor = torch.cuda.FloatTensor(batch_size, self.noise_dim, 1, 1) + train_dataset = get_dataset( + dset_name, data_root=self.data_root, imsize=self.imsize) + self.train_dataloader = tdata.DataLoader( + train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True) + self.real_data = self.get_real_batch() + + def ensure_dirs(self): + """Creates directories for saving results if they don't exist. + """ + dirs = ['checkpoints', 'samples'] + for d in dirs: + path = os.path.join(self.results_root, d) + print(f'[*] {d} will be saved in {path}') + if not os.path.exists(path): + os.makedirs(path) + + def _build_model(self): + """Initializes the generator and discriminator networks. 
def get_real_batch(self):
    """Infinite generator for real images.

    Iterates over ``self.train_dataloader`` and starts a fresh pass every
    time the previous one is exhausted, so the consumer can call
    ``next()`` indefinitely.

    Yields:
        Whatever one iteration step of ``self.train_dataloader`` produces
        (a batch of real images together with anything the loader pairs
        with it).

    Raises:
        ValueError -- If a full pass over the loader yields no batch at
            all (e.g. ``drop_last=True`` with fewer samples than one
            batch); otherwise ``next()`` would spin forever.
    """
    while True:
        produced = False
        # Plain iteration works for any iterable loader; it does not
        # require the iterator object to implement __len__.
        for batch in self.train_dataloader:
            produced = True
            yield batch
        if not produced:
            raise ValueError('train_dataloader did not yield any batch')
def _disc_iter(self):
    """An iteration of discriminator update.

    Re-enables gradients for the discriminator parameters (``_gen_iter``
    switches them off), optionally clips them to [-0.01, 0.01], then
    performs one optimizer step on ``self.disc_loss`` computed from a
    fresh real batch and a fake batch generated without tracking
    gradients through the generator.

    Returns:
        torch.Tensor -- Discriminator loss, detached from the graph.
    """
    for p in self.discriminator.parameters():
        p.requires_grad = True
        if self.clip_disc:
            p.data.clamp_(-0.01, 0.01)
    self._reset_grad()
    real_data, _ = next(self.real_data)
    real_data = real_data.cuda()
    # The pre-allocated noise buffer is re-sampled in place.
    noise = self.noise_tensor.normal_(0, 1)
    with torch.no_grad():
        fake_data = self.generator(noise)
    err_disc = self.disc_loss(real_data, fake_data)
    err_disc.backward()
    self.d_optim.step()
    return err_disc.detach()
def generate_samples(self, num_samples=50000, batch_size=100):
    """Generates random samples from the generator and saves them as a
    npy file.

    The array is written to ``<results_root>/samples/generated.npy`` and
    contains exactly ``num_samples`` images: the last batch is trimmed
    when ``num_samples`` is not a multiple of ``batch_size``.

    Keyword Arguments:
        num_samples {int} -- Number of samples to generate (default: {50000}).
        batch_size {int} -- Batch size (default: {100}).

    Raises:
        ValueError -- If ``num_samples`` or ``batch_size`` is not positive.
    """
    if num_samples <= 0 or batch_size <= 0:
        raise ValueError('num_samples and batch_size must be positive')
    out_dir = os.path.join(self.results_root, 'samples')
    os.makedirs(out_dir, exist_ok=True)
    # Sample the noise on whatever device the generator lives on; fall
    # back to CUDA (the original behaviour) for parameter-free generators.
    first_param = next(iter(self.generator.parameters()), None)
    device = first_param.device if first_param is not None else torch.device('cuda')
    noise = torch.empty(batch_size, self.noise_dim, 1, 1,
                        dtype=torch.float32, device=device)
    n_batches = -(-num_samples // batch_size)  # ceiling division
    generated_images = []
    with torch.no_grad():
        for _ in range(n_batches):
            noise.normal_(0, 1)
            generated_images.append(self.generator(noise).cpu().numpy())
    generated_images = np.vstack(generated_images)[:num_samples]
    np.save(os.path.join(out_dir, 'generated.npy'), generated_images)
if __name__ == '__main__':
    # Command-line entry point: load a generator checkpoint and either
    # render one 8x8 PNG grid or dump `--n` samples to samples.npy.
    # Both outputs are written inside the `--o/--out_dir` directory.
    parser = argparse.ArgumentParser()
    parser.add_argument('ckpt', type=str,
                        help='Generator weights checkpoint.')
    parser.add_argument('--gen', type=str, default='flexible-dcgan',
                        choices=['flexible-dcgan', 'dcgan32', 'dcgan5', 'resnet'],
                        help='Generator network type.')
    parser.add_argument('--n', '--num_samples', type=int,
                        default=50000, help='Number of samples to generate.')
    parser.add_argument('--k', '--noise_dim', type=int,
                        default=32, help='Dimension of noise input to generator.')
    parser.add_argument('--imsize', type=int,
                        default=32, help='Size of the image.')
    parser.add_argument('--png', action='store_true', help='Whether to generate a png '
                        '(overrides num_samples and generates a 8x8 grid of images).')
    parser.add_argument('--o', '--out_dir', type=str,
                        default='./', help='Output directory.')

    args = parser.parse_args()

    generator, _ = networks.build_networks(gen=args.gen, imsize=args.imsize, k=args.k)
    # Load onto the CPU first so a checkpoint saved from another GPU index
    # still loads; the weights are moved to the current GPU just below.
    generator.load_state_dict(torch.load(args.ckpt, map_location='cpu'))
    print('[*] Generator loaded')
    os.makedirs(args.o, exist_ok=True)
    generator.cuda()
    batch_size = 64 if args.png else 256
    noise = torch.cuda.FloatTensor(batch_size, args.k, 1, 1)
    # Ceiling division: just enough batches to cover args.n samples
    # (at least one, so an empty request still produces an empty array).
    n_loops = 1 if args.png else max(1, -(-args.n // batch_size))
    samples = []
    for i in tqdm(range(n_loops)):
        noise.normal_(0, 1)
        with torch.no_grad():
            images = generator(noise).detach().cpu().numpy()
        if args.png:  # Generate 8x8 if png
            path = os.path.join(args.o, 'samples.png')
            im2grid(images, path, shuffle=False)
        else:
            samples.append(images)
    if not args.png:  # Save images to samples.npy inside the output directory
        samples = np.vstack(samples)[:args.n]
        np.save(os.path.join(args.o, 'samples.npy'), samples)
def set_args(parser):
    """Register the training command-line options on ``parser``.

    Arguments:
        parser {argparse.ArgumentParser} -- Parser to add the options to.

    Returns:
        argparse.ArgumentParser -- The same parser, for chaining.
    """
    parser.add_argument('--dataset', required=True,
                        choices=['mnist', 'cifar10', 'celeba', 'stl10', 'celeba128'],
                        help='Dataset name.')
    parser.add_argument('--model', required=True,
                        choices=['cfgangp'], help='GAN Model')
    parser.add_argument('--dataroot', default='./data', help='Path to dataset.')
    parser.add_argument('--disc', default='flexible-dcgan',
                        choices=['flexible-dcgan', 'dcgan32', 'dcgan5'],
                        help='Discriminator network type.')
    parser.add_argument('--gen', default='flexible-dcgan',
                        choices=['flexible-dcgan', 'dcgan32', 'dcgan5', 'resnet'],
                        help='Generator network type.')
    parser.add_argument('--resultsroot', default='./results',
                        help='Path where results will be saved.')
    parser.add_argument('--batch_size', type=int,
                        default=64, help='Input batch size.')
    parser.add_argument('--image_size', type=int, default=32,
                        help='The size of the input image to network.')
    parser.add_argument('--nc', type=int, default=3, help='Number of channels.')
    parser.add_argument('--dout_dim', type=int, default=1,
                        help='Output dim of discriminator.')
    parser.add_argument('--noise_dim', type=int, default=100,
                        help='Dim of noise input to generator.')
    parser.add_argument('--disc_size', default=64, type=int,
                        help='Number of filters in first Conv layer of critic.')
    parser.add_argument('--max_giters', type=int, default=50000,
                        help='Number of generator iterations to train for.')
    parser.add_argument('--lr', type=float, default=0.00005,
                        help='Learning rate, default=0.00005.')
    parser.add_argument('--gpu_device', type=int,
                        default=0, help='GPU device id.')
    # CFGAN specific arguments
    parser.add_argument('--num_freqs', type=int, default=8,
                        help='Number of random frequencies.')
    # Without a default, omitting --sigmas left args.sigmas as None and the
    # entry point crashed on args.sigmas[0]; default to "optimize sigma".
    parser.add_argument('--sigmas', type=float, nargs='+', default=[0.],
                        help='Value of sigma for gaussian, student-t weight, '
                        '0 means that sigma will be optimized (default: 0).')
    parser.add_argument('--weight', default='gaussian_ecfd',
                        type=str, choices=['gaussian_ecfd', 'studentT_ecfd',
                                           'laplace_ecfd', 'uniform_ecfd'],
                        help='Weighting distribution'
                        ' for ECFD.')
    return parser


if __name__ == '__main__':
    # Entry point: parse options, pick the GPU, build the model, train it
    # and finally dump generated samples.
    parser = argparse.ArgumentParser()
    set_args(parser)
    args = parser.parse_args()
    assert torch.cuda.is_available(), 'Error: Requires CUDA'
    torch.cuda.set_device(args.gpu_device)
    print("Using GPU device", torch.cuda.current_device())
    if args.model == 'cfgangp':
        # Imported lazily so that only the selected model module is loaded.
        from cfgan import CFGANGP
        gan = CFGANGP(args.dataset, args.image_size, args.nc,
                      data_root=args.dataroot,
                      results_root=args.resultsroot,
                      dout_dim=args.dout_dim,
                      noise_dim=args.noise_dim,
                      batch_size=args.batch_size,
                      max_giters=args.max_giters,
                      lr=args.lr,
                      ecfd_type=args.weight,
                      num_freqs=args.num_freqs,
                      sigmas=args.sigmas,
                      # A leading sigma of 0 requests optimized sigmas.
                      optimize_sigma=args.sigmas[0] == 0.,
                      disc_size=args.disc_size,
                      disc_net=args.disc,
                      gen_net=args.gen)

    gan.train()
    gan.generate_samples()
# DCGAN-like Discriminator
class Encoder(nn.Module):
    """Strided-convolution pyramid mapping an image to a ``k``-channel map.

    Arguments:
        isize {int} -- Input image side length; must be a multiple of 16.
        nc {int} -- Number of input channels.

    Keyword Arguments:
        k {int} -- Number of output channels of the final conv (default: {100}).
        ndf {int} -- Number of filters in the first conv; doubled at every
            pyramid stage (default: {64}).
        bn {bool} -- Whether to insert batch norm in the pyramid stages
            (default: {True}).
    """
    # Source: https://github.com/OctoberChang/MMD-GAN/blob/master/base_module.py
    def __init__(self, isize, nc, k=100, ndf=64, bn=True):
        super(Encoder, self).__init__()
        assert isize % 16 == 0, "isize has to be a multiple of 16"
        main = nn.Sequential()
        main.add_module('initial-conv-{0}-{1}'.format(nc, ndf),
                        nn.Conv2d(nc, ndf, 4, 2, 1, bias=False))
        main.add_module('initial-relu-{0}'.format(ndf),
                        nn.LeakyReLU(0.2, inplace=True))
        # Integer division mirrors the floor in the conv output-size formula.
        csize, cndf = isize // 2, ndf

        while csize > 4:
            in_feat = cndf
            out_feat = cndf * 2
            main.add_module('pyramid-{0}-{1}-conv'.format(in_feat, out_feat),
                            nn.Conv2d(in_feat, out_feat, 4, 2, 1, bias=False))
            if bn:
                main.add_module('pyramid-{0}-batchnorm'.format(out_feat),
                                nn.BatchNorm2d(out_feat))
            main.add_module('pyramid-{0}-relu'.format(out_feat),
                            nn.LeakyReLU(0.2, inplace=True))
            cndf = cndf * 2
            csize = csize // 2

        main.add_module('final-{0}-{1}-conv'.format(cndf, k),
                        nn.Conv2d(cndf, k, 4, 1, 0, bias=False))

        self.main = main

    def forward(self, input, return_layers=False):
        """Run the network.

        Arguments:
            input {torch.Tensor} -- Batch of images.

        Keyword Arguments:
            return_layers {bool} -- If True, return the output of every
                stage instead of only the final one (default: {False}).

        Returns:
            torch.Tensor or list -- The final output, or a list with the
            activation after each stage (a stage is a conv together with
            the norm/activation modules that follow it); the last entry
            equals the plain forward output.
        """
        if not return_layers:
            return self.main(input)
        # A stage ends right before the next conv. Splitting on conv
        # boundaries works both with and without batch norm, whereas a
        # fixed stride of three modules breaks when bn=False.
        features = []
        out = input
        for idx, layer in enumerate(self.main):
            if idx > 0 and isinstance(layer, nn.Conv2d):
                features.append(out)
            out = layer(out)
        features.append(out)
        return features
class Decoder(nn.Module):
    """Transposed-convolution pyramid mapping a noise map to an image.

    The input is expected to have spatial size 1x1 so that the first
    transposed conv produces a 4x4 map; every following layer doubles the
    resolution until ``isize`` is reached, and the output goes through tanh.

    Arguments:
        isize {int} -- Output image side length; must be a multiple of 16
            and a power of two.
        nc {int} -- Number of output channels.

    Keyword Arguments:
        k {int} -- Number of input (noise) channels (default: {100}).
        ngf {int} -- Number of filters feeding the last layer (default: {64}).
    """
    # Source: https://github.com/OctoberChang/MMD-GAN/blob/master/base_module.py
    def __init__(self, isize, nc, k=100, ngf=64):
        super(Decoder, self).__init__()
        assert isize % 16 == 0, "isize has to be a multiple of 16"
        # Doubling from 4 only ever reaches powers of two; without this
        # check a size such as 48 made the loop below spin forever.
        assert isize & (isize - 1) == 0, "isize has to be a power of 2"

        # Width of the first layer: ngf doubled once per upsampling stage.
        cngf, tisize = ngf // 2, 4
        while tisize < isize:
            cngf = cngf * 2
            tisize = tisize * 2

        main = nn.Sequential()
        main.add_module('initial-{0}-{1}-convt'.format(k, cngf),
                        nn.ConvTranspose2d(k, cngf, 4, 1, 0, bias=False))
        main.add_module(
            'initial-{0}-batchnorm'.format(cngf), nn.BatchNorm2d(cngf))
        main.add_module('initial-{0}-relu'.format(cngf), nn.ReLU(True))

        csize = 4
        while csize < isize // 2:
            main.add_module('pyramid-{0}-{1}-convt'.format(cngf, cngf // 2),
                            nn.ConvTranspose2d(cngf, cngf // 2, 4, 2, 1, bias=False))
            main.add_module('pyramid-{0}-batchnorm'.format(cngf // 2),
                            nn.BatchNorm2d(cngf // 2))
            main.add_module('pyramid-{0}-relu'.format(cngf // 2),
                            nn.ReLU(True))
            cngf = cngf // 2
            csize = csize * 2

        main.add_module('final-{0}-{1}-convt'.format(cngf, nc),
                        nn.ConvTranspose2d(cngf, nc, 4, 2, 1, bias=False))
        main.add_module('final-{0}-tanh'.format(nc),
                        nn.Tanh())

        self.main = main

    def forward(self, input):
        """Map a batch of noise maps to a batch of images."""
        return self.main(input)
class DCGAN5Generator(nn.Module):
    """Generator made of a linear projection and five transposed convs.

    The noise vector is projected to a ``(ngf * 16, s32, s32)`` feature
    map, where ``s32`` is ``imsize`` floor-divided by two five times, and
    then upsampled by five stride-2 transposed convolutions followed by
    tanh.

    Keyword Arguments:
        imsize {int} -- Target image side length (default: {128}).
        ngf {int} -- Base number of generator filters (default: {64}).
        nz {int} -- Dimension of the noise input (default: {10}).
        nc {int} -- Number of output channels (default: {3}).
    """

    def __init__(self, imsize=128, ngf=64, nz=10, nc=3):
        super().__init__()
        self.ngf = ngf
        self.nz = nz
        # Only the smallest spatial size is needed to size the projection.
        self.s32 = conv_sizes(imsize, layers=5, stride=2)[-1]
        self.linear1 = nn.Linear(nz, ngf * 16 * self.s32 * self.s32)
        self.relu = nn.ReLU(True)
        self.bn0 = nn.BatchNorm2d(ngf * 16)

        stages = []
        for mult in (16, 8, 4, 2):
            width_in = ngf * mult
            width_out = ngf * (mult // 2)
            stages.append(nn.ConvTranspose2d(width_in, width_out, 4, 2, 1, bias=False))
            stages.append(nn.BatchNorm2d(width_out))
            stages.append(nn.ReLU(True))
        stages.append(nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False))
        stages.append(nn.Tanh())
        self.main = nn.Sequential(*stages)

    def forward(self, x):
        """Map noise (any shape that flattens to ``(-1, nz)``) to images."""
        flat = x.view(-1, self.nz)
        feat = self.linear1(flat).view(-1, self.ngf * 16, self.s32, self.s32)
        feat = self.relu(self.bn0(feat))
        return self.main(feat)
class ResNetGenerator(nn.Module):
    """Generator built from four upsampling residual blocks.

    The noise vector is projected to a ``(ngf * 16, s32, s32)`` feature
    map (``s32`` is ``imsize`` floor-divided by two five times), passed
    through the residual blocks, and finished with batch norm, ReLU, a
    stride-2 transposed conv and tanh.

    Keyword Arguments:
        imsize {int} -- Target image side length (default: {128}).
        ngf {int} -- Base number of generator filters (default: {64}).
        nz {int} -- Dimension of the noise input (default: {10}).
        nc {int} -- Number of output channels (default: {3}).
    """

    def __init__(self, imsize=128, ngf=64, nz=10, nc=3):
        super().__init__()
        self.ngf = ngf
        self.nz = nz
        self.s32 = conv_sizes(imsize, layers=5, stride=2)[-1]
        self.linear1 = nn.Linear(nz, ngf * 16 * self.s32 * self.s32)
        # NOTE: relu and bn0 are registered on the module but forward()
        # does not apply them.
        self.relu = nn.ReLU(True)
        self.bn0 = nn.BatchNorm2d(ngf * 16)

        widths = [ngf * 16, ngf * 8, ngf * 4, ngf * 2, ngf * 1]
        layers = [ResidualBlock(w_in, w_out, 3, resample='up')
                  for w_in, w_out in zip(widths, widths[1:])]
        layers.append(nn.BatchNorm2d(ngf))
        layers.append(nn.ReLU(inplace=True))
        layers.append(nn.ConvTranspose2d(ngf, nc, 4, stride=2, padding=(3-1)//2))
        layers.append(nn.Tanh())
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Map noise (any shape that flattens to ``(-1, nz)``) to images."""
        flat = x.view(-1, self.nz)
        feat = self.linear1(flat).view(-1, self.ngf * 16, self.s32, self.s32)
        return self.model(feat)
def conv_sizes(imsize, layers=5, stride=2):
    """Return the spatial sizes seen by a stack of strided layers.

    Arguments:
        imsize {int} -- Starting side length.

    Keyword Arguments:
        layers {int} -- Number of strided layers (default: {5}).
        stride {int} -- Stride of every layer (default: {2}).

    Returns:
        list -- ``layers + 1`` sizes, starting with ``imsize``; each next
        entry is the previous one floor-divided by ``stride``.
    """
    sizes = [imsize]
    for _ in range(layers):
        sizes.append(sizes[-1] // stride)
    return sizes


def weights_init(m):
    """Initialize one module in place; meant for ``nn.Module.apply``.

    Modules are dispatched on their class name: names containing "Conv"
    (except the ``UpsampleConv`` wrapper) get N(0, 0.02) weights,
    "BatchNorm" gets N(1, 0.02) weights and zero bias, "Linear" gets
    N(0, 0.1) weights and zero bias. Modules that lack the parameter
    (wrapper modules whose name merely contains "Conv", affine-less batch
    norm, bias-less linear layers) are left untouched instead of raising.

    Arguments:
        m {nn.Module} -- Module to initialize.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if classname.find('Conv') != -1 and classname != 'UpsampleConv':
        if weight is not None:
            weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        if weight is not None:
            weight.data.normal_(1.0, 0.02)
        if bias is not None:
            bias.data.fill_(0)
    elif classname.find('Linear') != -1:
        if weight is not None:
            weight.data.normal_(0.0, 0.1)
        if bias is not None:
            bias.data.fill_(0)


def build_networks(gen='flexible-dcgan', disc='flexible-dcgan',
                   ngf=64,
                   ndf=64,
                   imsize=32,  # Image dim imsize x imsize
                   nc=3,       # Number of channels
                   k=100,      # Dim. of noise input to generator
                   z=10,       # Out dim. of disc
                   bn=True):   # Use batch-norm in disc
    """Build a generator/discriminator pair by name.

    Keyword Arguments:
        gen {str} -- Generator type: 'flexible-dcgan', 'dcgan32', 'dcgan5'
            or 'resnet' (default: {'flexible-dcgan'}).
        disc {str} -- Discriminator type: 'flexible-dcgan', 'dcgan32' or
            'dcgan5' (default: {'flexible-dcgan'}).
        ngf {int} -- Base number of generator filters (default: {64}).
        ndf {int} -- Base number of discriminator filters (default: {64}).
        imsize {int} -- Image side length (default: {32}).
        nc {int} -- Number of image channels (default: {3}).
        k {int} -- Dimension of the generator's noise input (default: {100}).
        z {int} -- Output dimension of the discriminator (default: {10}).
        bn {bool} -- Batch norm in the 'flexible-dcgan' discriminator
            (default: {True}).

    Returns:
        tuple -- ``(generator, discriminator)``.

    Raises:
        ValueError -- If ``gen`` or ``disc`` is not a known network type.
    """
    if gen == 'flexible-dcgan':
        generator = Decoder(imsize, nc, k=k, ngf=ngf)
    elif gen == 'dcgan32':
        generator = DCGANGenerator(nc=nc, nz=k, ngf=ngf)
    elif gen == 'dcgan5':
        generator = DCGAN5Generator(imsize=imsize, ngf=ngf, nz=k, nc=nc)
    elif gen == 'resnet':
        generator = ResNetGenerator(imsize=imsize, ngf=ngf, nz=k, nc=nc)
    else:
        raise ValueError('Unknown generator')

    if disc == 'flexible-dcgan':
        discriminator = Encoder(imsize, nc, k=z, ndf=ndf, bn=bn)
    elif disc == 'dcgan32':
        discriminator = DCGANDiscriminator(nz=z, ndf=ndf, nc=nc)
    elif disc == 'dcgan5':
        discriminator = DCGAN5Discriminator(imsize=imsize, nz=z, ndf=ndf, nc=nc)
    else:
        raise ValueError('Unknown discriminator')

    return generator, discriminator
def _mean_pool_2x2(t):
    """Average each non-overlapping 2x2 window of the last two dimensions."""
    return (t[:, :, ::2, ::2] + t[:, :, 1::2, ::2]
            + t[:, :, ::2, 1::2] + t[:, :, 1::2, 1::2]) / 4.0


class MeanPoolConv(nn.Module):
    """2x2 mean pooling followed by a size-preserving convolution."""

    def __init__(self, n_input, n_output, k_size):
        super(MeanPoolConv, self).__init__()
        same_pad = (k_size - 1) // 2
        self.model = nn.Sequential(
            nn.Conv2d(n_input, n_output, k_size, stride=1, padding=same_pad, bias=True)
        )

    def forward(self, x):
        return self.model(_mean_pool_2x2(x))


class ConvMeanPool(nn.Module):
    """Size-preserving convolution followed by 2x2 mean pooling."""

    def __init__(self, n_input, n_output, k_size):
        super(ConvMeanPool, self).__init__()
        same_pad = (k_size - 1) // 2
        self.model = nn.Sequential(
            nn.Conv2d(n_input, n_output, k_size, stride=1, padding=same_pad, bias=True)
        )

    def forward(self, x):
        return _mean_pool_2x2(self.model(x))


class UpsampleConv(nn.Module):
    """2x upsampling followed by a size-preserving convolution.

    The channels are repeated four times and rearranged by
    ``PixelShuffle(2)``, which doubles the height and width while keeping
    the channel count, before the convolution is applied.
    """

    def __init__(self, n_input, n_output, k_size):
        super(UpsampleConv, self).__init__()
        same_pad = (k_size - 1) // 2
        shuffle = nn.PixelShuffle(2)
        conv = nn.Conv2d(n_input, n_output, k_size, stride=1, padding=same_pad, bias=True)
        self.model = nn.Sequential(shuffle, conv)

    def forward(self, x):
        repeated = x.repeat((1, 4, 1, 1))
        return self.model(repeated)
def sanitize_images(imgs):
    """Convert a batch of images to channels-last ``uint8``.

    The steps, in order: a 3-D array gets a trailing channel axis; an
    array whose last axis is larger than 3 is treated as channels-first
    and transposed with ``(0, 2, 3, 1)``; if the first image has negative
    values the batch is mapped with ``(x + 1) / 2``; if the first image's
    maximum is at most 1 the batch is scaled by 255. The input array is
    never modified.

    Arguments:
        imgs {np.ndarray} -- Batch of images (3-D or 4-D).

    Returns:
        np.ndarray -- ``uint8`` array with the channel axis last.
    """
    if len(imgs.shape) == 3:
        imgs = imgs[:, :, :, None]
    if imgs.shape[-1] > 3:
        imgs = imgs.transpose((0, 2, 3, 1))
    if imgs[0].min() < -0.0001:
        imgs = (imgs + 1) / 2.0
    if imgs[0].max() <= 1.0:
        # Out-of-place on purpose: `imgs` may still be a view of the
        # caller's array, which an in-place `*=` would overwrite.
        imgs = imgs * 255.0
    return imgs.astype(np.uint8)


def gallery(array, ncols=3):
    """Tile a batch of images into one grid image.

    Adapted from: https://stackoverflow.com/questions/42040747/more-idomatic-way-to-display-images-in-a-grid-with-numpy

    Arguments:
        array {np.ndarray} -- Images of shape (count, height, width, channels);
            ``count`` must be a multiple of ``ncols``.

    Keyword Arguments:
        ncols {int} -- Number of images per grid row (default: {3}).

    Returns:
        np.ndarray -- Grid of shape (height * nrows, width * ncols, channels),
        filled row by row.
    """
    nindex, height, width, intensity = array.shape
    nrows = nindex // ncols
    assert nindex == nrows * ncols
    # want result.shape = (height*nrows, width*ncols, intensity)
    result = (array.reshape(nrows, ncols, height, width, intensity)
              .swapaxes(1, 2)
              .reshape(height * nrows, width * ncols, intensity))
    return result


def im2grid(imgs, out_file='image.png', shuffle=True, num_imgs=None):
    """Write a grid of images to ``out_file``.

    The grid has ``int(sqrt(num_imgs))`` columns; images that do not fill
    a complete row are dropped so the grid is always rectangular.

    Arguments:
        imgs {np.ndarray} -- Batch of images accepted by sanitize_images.

    Keyword Arguments:
        out_file {str} -- Output image path (default: {'image.png'}).
        shuffle {bool} -- Whether to randomly permute the images first
            (default: {True}).
        num_imgs {int} -- Number of images to use; all of them when None
            (default: {None}).
    """
    if num_imgs is None:
        num_imgs = imgs.shape[0]
    imgs = sanitize_images(imgs)
    if shuffle:
        imgs = imgs[np.random.permutation(imgs.shape[0])]
    imgs = imgs[:num_imgs]
    ncols = int(np.sqrt(num_imgs))
    # Keep only complete rows; gallery() requires count % ncols == 0.
    imgs = imgs[:(imgs.shape[0] // ncols) * ncols]
    grid_image = gallery(imgs, ncols=ncols)
    if grid_image.shape[-1] == 1:
        cv2.imwrite(out_file, grid_image)
    else:
        # OpenCV writes BGR, the images are RGB.
        cv2.imwrite(out_file, cv2.cvtColor(grid_image, cv2.COLOR_RGB2BGR))