diff --git a/README.md b/README.md
index 1297776..cd35167 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,165 @@
-Yu Deng, Jiaolong Yang, Dong Chen, Fang Wen, Xin Tong. Disentangled and Controllable Face Image Generation via 3D Imitative-Contrastive Learning. CVPR 2020 (Oral) [PDF](https://arxiv.org/pdf/2004.11660.pdf)
-# Code will be released soon, stay tuned!
+## Disentangled and Controllable Face Image Generation via 3D Imitative-Contrastive Learning ##
+

+ +

+
+
+This is a TensorFlow implementation of the following paper:
+
+**Disentangled and Controllable Face Image Generation via 3D Imitative-Contrastive Learning**, CVPR 2020. (**_Oral_**)
+
+Yu Deng, Jiaolong Yang, Dong Chen, Fang Wen, and Xin Tong
+
+Paper: [https://arxiv.org/abs/2004.11660](https://arxiv.org/abs/2004.11660)
+
+Abstract: _We propose an approach for face image generation of virtual people with disentangled, precisely-controllable latent representations for identity of non-existing people, expression, pose, and illumination. We embed 3D priors into adversarial learning and train the network to imitate the image formation of an analytic 3D face deformation and rendering process. To deal with the generation freedom induced by the domain gap between real and rendered faces, we further introduce contrastive learning to promote disentanglement by comparing pairs of generated images. Experiments show that through our imitative-contrastive learning, the factor variations are very well disentangled and the properties of a generated face can be precisely controlled. We also analyze the learned latent space and present several meaningful properties supporting factor disentanglement. Our method can also be used to embed real images into the disentangled latent space. We hope our method could provide new understandings of the relationship between physical properties and deep image synthesis._
+
+## Features
+
+### ● Factor disentanglement
+When generating face images, we can freely change the four factors: identity, expression, lighting, and pose. The factor variations are highly disentangled: changing one factor does not affect the others.
+
+

+ +

+
+
+### ● Reference-based generation
+We achieve reference-based generation: we extract the expression, pose, and lighting from a given image and generate new identities with similar properties.
+
+

+ +

+### ● Real image pose manipulation
+We can use our method to embed a real image into the disentangled latent space and then edit it, for example by manipulating its pose.
+

+ +

+
+### ● Real image lighting editing
+We can edit the lighting of a real image.
+

+ +

+
+### ● Real image expression transfer
+We can also transfer expressions between real images.
+

+ +

+
+## Requirements
+
+- Only Linux is supported.
+- Python 3.6. We recommend Anaconda3 with numpy 1.14.3 or newer.
+- TensorFlow 1.12 with GPU support (currently the only supported version).
+- CUDA toolkit 9.0 or newer, cuDNN 7.3.1 or newer.
+- One or more high-end NVIDIA GPUs. We recommend at least 4 Tesla P100 GPUs for training.
+- [Basel Face Model 2009 (BFM09)](https://faces.dmi.unibas.ch/bfm/main.php?nav=1-0&id=basel_face_model).
+- [Expression Basis](https://github.com/Juyong/3DFace). The original BFM09 model does not handle expression variations, so an extra expression basis is needed.
+- [Facenet](https://github.com/davidsandberg/facenet). We use this open-source face recognition network to extract identity features.
+- [3D face reconstruction network](https://github.com/microsoft/Deep3DFaceReconstruction). We use this network to extract identity, expression, lighting, and pose coefficients.
+
+## Using pre-trained network
+1. Clone the repository:
+
+```
+git clone https://github.com/microsoft/DisentangledFaceGAN.git
+cd DisentangledFaceGAN
+```
+2. Generate images using the pre-trained network:
+
+```
+# Generate face images with random variations of expression, lighting, and pose
+python generate_images.py
+
+# Generate face images with random variations of expression
+python generate_images.py --factor 1
+
+# Generate face images with random variations of lighting
+python generate_images.py --factor 2
+
+# Generate face images with random variations of pose
+python generate_images.py --factor 3
+```
+
+## Training preparation
+
+1. Download the Basel Face Model. Due to the license agreement of the Basel Face Model, you have to submit an application on its [home page](https://faces.dmi.unibas.ch/bfm/main.php?nav=1-2&id=downloads). After getting access to the BFM data, download "01_MorphableModel.mat" and put it in "./renderer/BFM face model".
+2. Download the Expression Basis provided by [Guo et al.](https://github.com/Juyong/3DFace). You can find a link named "CoarseData" in the first row of the Introduction section of their repository. Download and unzip Coarse_Dataset.zip, and put "Exp_Pca.bin" in "./renderer/BFM face model".
+3. Download the [pre-trained weights](https://drive.google.com/file/d/0B5MzpY9kBtDVZ2RpVDYwWmxoSUk/edit) of Facenet, unzip them, and put all files in "./training/pretrained_weights/id_net".
+4. Download the [pre-trained weights](https://drive.google.com/file/d/176LCdUDxAj7T2awQ5knPMPawq5Q2RUWM/view?usp=sharing) of the 3D face reconstruction network, unzip them, and put all files in "./training/pretrained_weights/recon_net".
+5. Download the [pre-trained weights](https://drive.google.com/file/d/1YkvI_B-cPNo1NhTjiEk8O8FVnVpIypNd/view?usp=sharing) of the face parser, unzip them, and put all files in "./training/pretrained_weights/parsing_net".
+
+## Data pre-processing
+1. Download the [FFHQ dataset](https://github.com/NVlabs/ffhq-dataset). Detect 5 facial landmarks for all images; we recommend using [dlib](http://dlib.net/) or [MTCNN](https://github.com/ipazc/mtcnn) (a detection sketch follows at the end of this section). Save all images in `<raw_img_path>` and the corresponding landmarks in `<raw_lm_path>`. Note that an image and its detected landmark file should have the same name.
+2. Align images and extract coefficients for VAE and GAN training:
+
+```
+python preprocess_data.py
+```
+3. Convert the aligned images to multi-resolution TFRecords, similar to [StyleGAN](https://github.com/NVlabs/stylegan), where `<save_path>` is the output directory of the previous step:
+
+```
+python dataset_tool.py create_from_images ./datasets/ffhq_align <save_path>/img
+```
+
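+Below is a minimal sketch of the landmark-detection step using the [MTCNN](https://github.com/ipazc/mtcnn) package. The folder names and the landmark file format (one `x y` pair per line, ordered as left eye, right eye, nose, left mouth corner, right mouth corner) are assumptions for illustration only; check `preprocess_data.py` for the exact format it expects:
+
+```
+# Hypothetical helper: detect 5 facial landmarks with MTCNN and save one
+# landmark file per image. Folder names and output format are assumptions.
+import glob
+import os
+
+import numpy as np
+from PIL import Image
+from mtcnn import MTCNN   # pip install mtcnn
+
+raw_img_path = 'raw_img'  # your <raw_img_path>
+raw_lm_path = 'raw_lm'    # your <raw_lm_path>
+
+detector = MTCNN()
+os.makedirs(raw_lm_path, exist_ok=True)
+for img_file in glob.glob(os.path.join(raw_img_path, '*.png')):
+    img = np.asarray(Image.open(img_file).convert('RGB'))
+    faces = detector.detect_faces(img)
+    if not faces:
+        continue  # no face found; skip this image
+    k = faces[0]['keypoints']
+    lm = np.array([k['left_eye'], k['right_eye'], k['nose'],
+                   k['mouth_left'], k['mouth_right']], dtype=np.float32)
+    name = os.path.splitext(os.path.basename(img_file))[0]
+    np.savetxt(os.path.join(raw_lm_path, name + '.txt'), lm)  # same name as the image
+```
+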
+## Training networks
+1. We provide pre-trained VAEs for the factors of identity, expression, lighting, and pose. To train new models from scratch, run:
+
+```
+cd vae
+
+# train VAE for identity coefficients
+python demo.py --datapath <save_path>/coeff --factor id
+
+# train VAE for expression coefficients
+python demo.py --datapath <save_path>/coeff --factor exp
+
+# train VAE for lighting coefficients
+python demo.py --datapath <save_path>/coeff --factor gamma
+
+# train VAE for pose coefficients
+python demo.py --datapath <save_path>/coeff --factor rot
+```
+2. Train the StyleGAN generator with the imitative-contrastive learning scheme:
+
+```
+# Stage 1 with only imitative losses, training with 15000k images
+python train.py
+
+# Stage 2 with both imitative and contrastive losses, training with another 5000k images
+python train.py --stage 2 --run_id <run_id> --snapshot <snapshot> --kimg <kimg>
+# For example:
+python train.py --stage 2 --run_id 0 --snapshot 14926 --kimg 14926
+```
+
+After training, the network can be used in the same way as the provided pre-trained model:
+```
+# Generate face images with a specific model
+python generate_images.py --model <model_file>
+```
+
+We trained the model on 4 Tesla P100 GPUs; stage 1 takes about 6 days 15 hours and stage 2 about 5 days 8 hours.
+
+## Contact
+If you have any questions, please contact Yu Deng (t-yudeng@microsoft.com) or Jiaolong Yang (jiaoyan@microsoft.com).
+
+## License
+
+Copyright © Microsoft Corporation.
+
+Licensed under the MIT license.
+
+## Citation
+
+Please cite the following paper if this model helps your research:
+
+    @inproceedings{deng2020disentangled,
+      title={Disentangled and Controllable Face Image Generation via 3D Imitative-Contrastive Learning},
+      author={Yu Deng and Jiaolong Yang and Dong Chen and Fang Wen and Xin Tong},
+      booktitle={IEEE Computer Vision and Pattern Recognition},
+      year={2020}
+    }
diff --git a/config.py b/config.py
new file mode 100644
index 0000000..ef146f1
--- /dev/null
+++ b/config.py
@@ -0,0 +1,18 @@
+# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+#
+# This work is licensed under the Creative Commons Attribution-NonCommercial
+# 4.0 International License. To view a copy of this license, visit
+# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
+# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
+
+"""Global configuration."""
+
+#----------------------------------------------------------------------------
+# Paths.
+
+result_dir = 'results'
+data_dir = 'datasets'
+cache_dir = 'cache'
+run_dir_ignore = ['results', 'datasets', 'cache', 'dnnlib', 'metrics', 'vae', 'preprocess', 'training', 'renderer']
+
+#----------------------------------------------------------------------------
diff --git a/dataset_tool.py b/dataset_tool.py
new file mode 100644
index 0000000..4ddfe44
--- /dev/null
+++ b/dataset_tool.py
@@ -0,0 +1,645 @@
+# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+#
+# This work is licensed under the Creative Commons Attribution-NonCommercial
+# 4.0 International License. To view a copy of this license, visit
+# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
+# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
+ +"""Tool for creating multi-resolution TFRecords datasets for StyleGAN and ProGAN.""" + +# pylint: disable=too-many-lines +import os +import sys +import glob +import argparse +import threading +import six.moves.queue as Queue # pylint: disable=import-error +import traceback +import numpy as np +import tensorflow as tf +import PIL.Image +import dnnlib.tflib as tflib + +from training import dataset + +#---------------------------------------------------------------------------- + +def error(msg): + print('Error: ' + msg) + exit(1) + +#---------------------------------------------------------------------------- + +class TFRecordExporter: + def __init__(self, tfrecord_dir, expected_images, print_progress=True, progress_interval=10): + self.tfrecord_dir = tfrecord_dir + self.tfr_prefix = os.path.join(self.tfrecord_dir, os.path.basename(self.tfrecord_dir)) + self.expected_images = expected_images + self.cur_images = 0 + self.shape = None + self.resolution_log2 = None + self.tfr_writers = [] + self.print_progress = print_progress + self.progress_interval = progress_interval + + if self.print_progress: + print('Creating dataset "%s"' % tfrecord_dir) + if not os.path.isdir(self.tfrecord_dir): + os.makedirs(self.tfrecord_dir) + assert os.path.isdir(self.tfrecord_dir) + + def close(self): + if self.print_progress: + print('%-40s\r' % 'Flushing data...', end='', flush=True) + for tfr_writer in self.tfr_writers: + tfr_writer.close() + self.tfr_writers = [] + if self.print_progress: + print('%-40s\r' % '', end='', flush=True) + print('Added %d images.' % self.cur_images) + + def choose_shuffled_order(self): # Note: Images and labels must be added in shuffled order. + order = np.arange(self.expected_images) + np.random.RandomState(123).shuffle(order) + return order + + def add_image(self, img): + if self.print_progress and self.cur_images % self.progress_interval == 0: + print('%d / %d\r' % (self.cur_images, self.expected_images), end='', flush=True) + if self.shape is None: + self.shape = img.shape + self.resolution_log2 = int(np.log2(self.shape[1])) + assert self.shape[0] in [1, 3] + assert self.shape[1] == self.shape[2] + assert self.shape[1] == 2**self.resolution_log2 + tfr_opt = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.NONE) + for lod in range(self.resolution_log2 - 1): + tfr_file = self.tfr_prefix + '-r%02d.tfrecords' % (self.resolution_log2 - lod) + self.tfr_writers.append(tf.python_io.TFRecordWriter(tfr_file, tfr_opt)) + assert img.shape == self.shape + for lod, tfr_writer in enumerate(self.tfr_writers): + if lod: + img = img.astype(np.float32) + img = (img[:, 0::2, 0::2] + img[:, 0::2, 1::2] + img[:, 1::2, 0::2] + img[:, 1::2, 1::2]) * 0.25 + quant = np.rint(img).clip(0, 255).astype(np.uint8) + ex = tf.train.Example(features=tf.train.Features(feature={ + 'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=quant.shape)), + 'data': tf.train.Feature(bytes_list=tf.train.BytesList(value=[quant.tostring()]))})) + tfr_writer.write(ex.SerializeToString()) + self.cur_images += 1 + + def add_labels(self, labels): + if self.print_progress: + print('%-40s\r' % 'Saving labels...', end='', flush=True) + assert labels.shape[0] == self.cur_images + with open(self.tfr_prefix + '-rxx.labels', 'wb') as f: + np.save(f, labels.astype(np.float32)) + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + +#---------------------------------------------------------------------------- + +class ExceptionInfo(object): + def __init__(self): + self.value = 
sys.exc_info()[1] + self.traceback = traceback.format_exc() + +#---------------------------------------------------------------------------- + +class WorkerThread(threading.Thread): + def __init__(self, task_queue): + threading.Thread.__init__(self) + self.task_queue = task_queue + + def run(self): + while True: + func, args, result_queue = self.task_queue.get() + if func is None: + break + try: + result = func(*args) + except: + result = ExceptionInfo() + result_queue.put((result, args)) + +#---------------------------------------------------------------------------- + +class ThreadPool(object): + def __init__(self, num_threads): + assert num_threads >= 1 + self.task_queue = Queue.Queue() + self.result_queues = dict() + self.num_threads = num_threads + for _idx in range(self.num_threads): + thread = WorkerThread(self.task_queue) + thread.daemon = True + thread.start() + + def add_task(self, func, args=()): + assert hasattr(func, '__call__') # must be a function + if func not in self.result_queues: + self.result_queues[func] = Queue.Queue() + self.task_queue.put((func, args, self.result_queues[func])) + + def get_result(self, func): # returns (result, args) + result, args = self.result_queues[func].get() + if isinstance(result, ExceptionInfo): + print('\n\nWorker thread caught an exception:\n' + result.traceback) + raise result.value + return result, args + + def finish(self): + for _idx in range(self.num_threads): + self.task_queue.put((None, (), None)) + + def __enter__(self): # for 'with' statement + return self + + def __exit__(self, *excinfo): + self.finish() + + def process_items_concurrently(self, item_iterator, process_func=lambda x: x, pre_func=lambda x: x, post_func=lambda x: x, max_items_in_flight=None): + if max_items_in_flight is None: max_items_in_flight = self.num_threads * 4 + assert max_items_in_flight >= 1 + results = [] + retire_idx = [0] + + def task_func(prepared, _idx): + return process_func(prepared) + + def retire_result(): + processed, (_prepared, idx) = self.get_result(task_func) + results[idx] = processed + while retire_idx[0] < len(results) and results[retire_idx[0]] is not None: + yield post_func(results[retire_idx[0]]) + results[retire_idx[0]] = None + retire_idx[0] += 1 + + for idx, item in enumerate(item_iterator): + prepared = pre_func(item) + results.append(None) + self.add_task(func=task_func, args=(prepared, idx)) + while retire_idx[0] < idx - max_items_in_flight + 2: + for res in retire_result(): yield res + while retire_idx[0] < len(results): + for res in retire_result(): yield res + +#---------------------------------------------------------------------------- + +def display(tfrecord_dir): + print('Loading dataset "%s"' % tfrecord_dir) + tflib.init_tf({'gpu_options.allow_growth': True}) + dset = dataset.TFRecordDataset(tfrecord_dir, max_label_size='full', repeat=False, shuffle_mb=0) + tflib.init_uninitialized_vars() + import cv2 # pip install opencv-python + + idx = 0 + while True: + try: + images, labels = dset.get_minibatch_np(1) + except tf.errors.OutOfRangeError: + break + if idx == 0: + print('Displaying images') + cv2.namedWindow('dataset_tool') + print('Press SPACE or ENTER to advance, ESC to exit') + print('\nidx = %-8d\nlabel = %s' % (idx, labels[0].tolist())) + cv2.imshow('dataset_tool', images[0].transpose(1, 2, 0)[:, :, ::-1]) # CHW => HWC, RGB => BGR + idx += 1 + if cv2.waitKey() == 27: + break + print('\nDisplayed %d images.' 
% idx) + +#---------------------------------------------------------------------------- + +def extract(tfrecord_dir, output_dir): + print('Loading dataset "%s"' % tfrecord_dir) + tflib.init_tf({'gpu_options.allow_growth': True}) + dset = dataset.TFRecordDataset(tfrecord_dir, max_label_size=0, repeat=False, shuffle_mb=0) + tflib.init_uninitialized_vars() + + print('Extracting images to "%s"' % output_dir) + if not os.path.isdir(output_dir): + os.makedirs(output_dir) + idx = 0 + while True: + if idx % 10 == 0: + print('%d\r' % idx, end='', flush=True) + try: + images, _labels = dset.get_minibatch_np(1) + except tf.errors.OutOfRangeError: + break + if images.shape[1] == 1: + img = PIL.Image.fromarray(images[0][0], 'L') + else: + img = PIL.Image.fromarray(images[0].transpose(1, 2, 0), 'RGB') + img.save(os.path.join(output_dir, 'img%08d.png' % idx)) + idx += 1 + print('Extracted %d images.' % idx) + +#---------------------------------------------------------------------------- + +def compare(tfrecord_dir_a, tfrecord_dir_b, ignore_labels): + max_label_size = 0 if ignore_labels else 'full' + print('Loading dataset "%s"' % tfrecord_dir_a) + tflib.init_tf({'gpu_options.allow_growth': True}) + dset_a = dataset.TFRecordDataset(tfrecord_dir_a, max_label_size=max_label_size, repeat=False, shuffle_mb=0) + print('Loading dataset "%s"' % tfrecord_dir_b) + dset_b = dataset.TFRecordDataset(tfrecord_dir_b, max_label_size=max_label_size, repeat=False, shuffle_mb=0) + tflib.init_uninitialized_vars() + + print('Comparing datasets') + idx = 0 + identical_images = 0 + identical_labels = 0 + while True: + if idx % 100 == 0: + print('%d\r' % idx, end='', flush=True) + try: + images_a, labels_a = dset_a.get_minibatch_np(1) + except tf.errors.OutOfRangeError: + images_a, labels_a = None, None + try: + images_b, labels_b = dset_b.get_minibatch_np(1) + except tf.errors.OutOfRangeError: + images_b, labels_b = None, None + if images_a is None or images_b is None: + if images_a is not None or images_b is not None: + print('Datasets contain different number of images') + break + if images_a.shape == images_b.shape and np.all(images_a == images_b): + identical_images += 1 + else: + print('Image %d is different' % idx) + if labels_a.shape == labels_b.shape and np.all(labels_a == labels_b): + identical_labels += 1 + else: + print('Label %d is different' % idx) + idx += 1 + print('Identical images: %d / %d' % (identical_images, idx)) + if not ignore_labels: + print('Identical labels: %d / %d' % (identical_labels, idx)) + +#---------------------------------------------------------------------------- + +def create_mnist(tfrecord_dir, mnist_dir): + print('Loading MNIST from "%s"' % mnist_dir) + import gzip + with gzip.open(os.path.join(mnist_dir, 'train-images-idx3-ubyte.gz'), 'rb') as file: + images = np.frombuffer(file.read(), np.uint8, offset=16) + with gzip.open(os.path.join(mnist_dir, 'train-labels-idx1-ubyte.gz'), 'rb') as file: + labels = np.frombuffer(file.read(), np.uint8, offset=8) + images = images.reshape(-1, 1, 28, 28) + images = np.pad(images, [(0,0), (0,0), (2,2), (2,2)], 'constant', constant_values=0) + assert images.shape == (60000, 1, 32, 32) and images.dtype == np.uint8 + assert labels.shape == (60000,) and labels.dtype == np.uint8 + assert np.min(images) == 0 and np.max(images) == 255 + assert np.min(labels) == 0 and np.max(labels) == 9 + onehot = np.zeros((labels.size, np.max(labels) + 1), dtype=np.float32) + onehot[np.arange(labels.size), labels] = 1.0 + + with TFRecordExporter(tfrecord_dir, images.shape[0]) 
as tfr: + order = tfr.choose_shuffled_order() + for idx in range(order.size): + tfr.add_image(images[order[idx]]) + tfr.add_labels(onehot[order]) + +#---------------------------------------------------------------------------- + +def create_mnistrgb(tfrecord_dir, mnist_dir, num_images=1000000, random_seed=123): + print('Loading MNIST from "%s"' % mnist_dir) + import gzip + with gzip.open(os.path.join(mnist_dir, 'train-images-idx3-ubyte.gz'), 'rb') as file: + images = np.frombuffer(file.read(), np.uint8, offset=16) + images = images.reshape(-1, 28, 28) + images = np.pad(images, [(0,0), (2,2), (2,2)], 'constant', constant_values=0) + assert images.shape == (60000, 32, 32) and images.dtype == np.uint8 + assert np.min(images) == 0 and np.max(images) == 255 + + with TFRecordExporter(tfrecord_dir, num_images) as tfr: + rnd = np.random.RandomState(random_seed) + for _idx in range(num_images): + tfr.add_image(images[rnd.randint(images.shape[0], size=3)]) + +#---------------------------------------------------------------------------- + +def create_cifar10(tfrecord_dir, cifar10_dir): + print('Loading CIFAR-10 from "%s"' % cifar10_dir) + import pickle + images = [] + labels = [] + for batch in range(1, 6): + with open(os.path.join(cifar10_dir, 'data_batch_%d' % batch), 'rb') as file: + data = pickle.load(file, encoding='latin1') + images.append(data['data'].reshape(-1, 3, 32, 32)) + labels.append(data['labels']) + images = np.concatenate(images) + labels = np.concatenate(labels) + assert images.shape == (50000, 3, 32, 32) and images.dtype == np.uint8 + assert labels.shape == (50000,) and labels.dtype == np.int32 + assert np.min(images) == 0 and np.max(images) == 255 + assert np.min(labels) == 0 and np.max(labels) == 9 + onehot = np.zeros((labels.size, np.max(labels) + 1), dtype=np.float32) + onehot[np.arange(labels.size), labels] = 1.0 + + with TFRecordExporter(tfrecord_dir, images.shape[0]) as tfr: + order = tfr.choose_shuffled_order() + for idx in range(order.size): + tfr.add_image(images[order[idx]]) + tfr.add_labels(onehot[order]) + +#---------------------------------------------------------------------------- + +def create_cifar100(tfrecord_dir, cifar100_dir): + print('Loading CIFAR-100 from "%s"' % cifar100_dir) + import pickle + with open(os.path.join(cifar100_dir, 'train'), 'rb') as file: + data = pickle.load(file, encoding='latin1') + images = data['data'].reshape(-1, 3, 32, 32) + labels = np.array(data['fine_labels']) + assert images.shape == (50000, 3, 32, 32) and images.dtype == np.uint8 + assert labels.shape == (50000,) and labels.dtype == np.int32 + assert np.min(images) == 0 and np.max(images) == 255 + assert np.min(labels) == 0 and np.max(labels) == 99 + onehot = np.zeros((labels.size, np.max(labels) + 1), dtype=np.float32) + onehot[np.arange(labels.size), labels] = 1.0 + + with TFRecordExporter(tfrecord_dir, images.shape[0]) as tfr: + order = tfr.choose_shuffled_order() + for idx in range(order.size): + tfr.add_image(images[order[idx]]) + tfr.add_labels(onehot[order]) + +#---------------------------------------------------------------------------- + +def create_svhn(tfrecord_dir, svhn_dir): + print('Loading SVHN from "%s"' % svhn_dir) + import pickle + images = [] + labels = [] + for batch in range(1, 4): + with open(os.path.join(svhn_dir, 'train_%d.pkl' % batch), 'rb') as file: + data = pickle.load(file, encoding='latin1') + images.append(data[0]) + labels.append(data[1]) + images = np.concatenate(images) + labels = np.concatenate(labels) + assert images.shape == (73257, 3, 32, 32) 
and images.dtype == np.uint8 + assert labels.shape == (73257,) and labels.dtype == np.uint8 + assert np.min(images) == 0 and np.max(images) == 255 + assert np.min(labels) == 0 and np.max(labels) == 9 + onehot = np.zeros((labels.size, np.max(labels) + 1), dtype=np.float32) + onehot[np.arange(labels.size), labels] = 1.0 + + with TFRecordExporter(tfrecord_dir, images.shape[0]) as tfr: + order = tfr.choose_shuffled_order() + for idx in range(order.size): + tfr.add_image(images[order[idx]]) + tfr.add_labels(onehot[order]) + +#---------------------------------------------------------------------------- + +def create_lsun(tfrecord_dir, lmdb_dir, resolution=256, max_images=None): + print('Loading LSUN dataset from "%s"' % lmdb_dir) + import lmdb # pip install lmdb # pylint: disable=import-error + import cv2 # pip install opencv-python + import io + with lmdb.open(lmdb_dir, readonly=True).begin(write=False) as txn: + total_images = txn.stat()['entries'] # pylint: disable=no-value-for-parameter + if max_images is None: + max_images = total_images + with TFRecordExporter(tfrecord_dir, max_images) as tfr: + for _idx, (_key, value) in enumerate(txn.cursor()): + try: + try: + img = cv2.imdecode(np.fromstring(value, dtype=np.uint8), 1) + if img is None: + raise IOError('cv2.imdecode failed') + img = img[:, :, ::-1] # BGR => RGB + except IOError: + img = np.asarray(PIL.Image.open(io.BytesIO(value))) + crop = np.min(img.shape[:2]) + img = img[(img.shape[0] - crop) // 2 : (img.shape[0] + crop) // 2, (img.shape[1] - crop) // 2 : (img.shape[1] + crop) // 2] + img = PIL.Image.fromarray(img, 'RGB') + img = img.resize((resolution, resolution), PIL.Image.ANTIALIAS) + img = np.asarray(img) + img = img.transpose([2, 0, 1]) # HWC => CHW + tfr.add_image(img) + except: + print(sys.exc_info()[1]) + if tfr.cur_images == max_images: + break + +#---------------------------------------------------------------------------- + +def create_lsun_wide(tfrecord_dir, lmdb_dir, width=512, height=384, max_images=None): + assert width == 2 ** int(np.round(np.log2(width))) + assert height <= width + print('Loading LSUN dataset from "%s"' % lmdb_dir) + import lmdb # pip install lmdb # pylint: disable=import-error + import cv2 # pip install opencv-python + import io + with lmdb.open(lmdb_dir, readonly=True).begin(write=False) as txn: + total_images = txn.stat()['entries'] # pylint: disable=no-value-for-parameter + if max_images is None: + max_images = total_images + with TFRecordExporter(tfrecord_dir, max_images, print_progress=False) as tfr: + for idx, (_key, value) in enumerate(txn.cursor()): + try: + try: + img = cv2.imdecode(np.fromstring(value, dtype=np.uint8), 1) + if img is None: + raise IOError('cv2.imdecode failed') + img = img[:, :, ::-1] # BGR => RGB + except IOError: + img = np.asarray(PIL.Image.open(io.BytesIO(value))) + + ch = int(np.round(width * img.shape[0] / img.shape[1])) + if img.shape[1] < width or ch < height: + continue + + img = img[(img.shape[0] - ch) // 2 : (img.shape[0] + ch) // 2] + img = PIL.Image.fromarray(img, 'RGB') + img = img.resize((width, height), PIL.Image.ANTIALIAS) + img = np.asarray(img) + img = img.transpose([2, 0, 1]) # HWC => CHW + + canvas = np.zeros([3, width, width], dtype=np.uint8) + canvas[:, (width - height) // 2 : (width + height) // 2] = img + tfr.add_image(canvas) + print('\r%d / %d => %d ' % (idx + 1, total_images, tfr.cur_images), end='') + + except: + print(sys.exc_info()[1]) + if tfr.cur_images == max_images: + break + print() + 
+#---------------------------------------------------------------------------- + +def create_celeba(tfrecord_dir, celeba_dir, cx=89, cy=121): + print('Loading CelebA from "%s"' % celeba_dir) + glob_pattern = os.path.join(celeba_dir, 'img_align_celeba_png', '*.png') + image_filenames = sorted(glob.glob(glob_pattern)) + expected_images = 202599 + if len(image_filenames) != expected_images: + error('Expected to find %d images' % expected_images) + + with TFRecordExporter(tfrecord_dir, len(image_filenames)) as tfr: + order = tfr.choose_shuffled_order() + for idx in range(order.size): + img = np.asarray(PIL.Image.open(image_filenames[order[idx]])) + assert img.shape == (218, 178, 3) + img = img[cy - 64 : cy + 64, cx - 64 : cx + 64] + img = img.transpose(2, 0, 1) # HWC => CHW + tfr.add_image(img) + +#---------------------------------------------------------------------------- + +def create_from_images(tfrecord_dir, image_dir, shuffle): + print('Loading images from "%s"' % image_dir) + image_filenames = sorted(glob.glob(os.path.join(image_dir, '*'))) + if len(image_filenames) == 0: + error('No input images found') + + img = np.asarray(PIL.Image.open(image_filenames[0])) + resolution = img.shape[0] + channels = img.shape[2] if img.ndim == 3 else 1 + if img.shape[1] != resolution: + error('Input images must have the same width and height') + if resolution != 2 ** int(np.floor(np.log2(resolution))): + error('Input image resolution must be a power-of-two') + if channels not in [1, 3]: + error('Input images must be stored as RGB or grayscale') + + with TFRecordExporter(tfrecord_dir, len(image_filenames)) as tfr: + order = tfr.choose_shuffled_order() if shuffle else np.arange(len(image_filenames)) + for idx in range(order.size): + img = np.asarray(PIL.Image.open(image_filenames[order[idx]])) + if channels == 1: + img = img[np.newaxis, :, :] # HW => CHW + else: + img = img.transpose([2, 0, 1]) # HWC => CHW + tfr.add_image(img) + +#---------------------------------------------------------------------------- + +def create_from_hdf5(tfrecord_dir, hdf5_filename, shuffle): + print('Loading HDF5 archive from "%s"' % hdf5_filename) + import h5py # conda install h5py + with h5py.File(hdf5_filename, 'r') as hdf5_file: + hdf5_data = max([value for key, value in hdf5_file.items() if key.startswith('data')], key=lambda lod: lod.shape[3]) + with TFRecordExporter(tfrecord_dir, hdf5_data.shape[0]) as tfr: + order = tfr.choose_shuffled_order() if shuffle else np.arange(hdf5_data.shape[0]) + for idx in range(order.size): + tfr.add_image(hdf5_data[order[idx]]) + npy_filename = os.path.splitext(hdf5_filename)[0] + '-labels.npy' + if os.path.isfile(npy_filename): + tfr.add_labels(np.load(npy_filename)[order]) + +#---------------------------------------------------------------------------- + +def execute_cmdline(argv): + prog = argv[0] + parser = argparse.ArgumentParser( + prog = prog, + description = 'Tool for creating multi-resolution TFRecords datasets for StyleGAN and ProGAN.', + epilog = 'Type "%s -h" for more information.' 
% prog) + + subparsers = parser.add_subparsers(dest='command') + subparsers.required = True + def add_command(cmd, desc, example=None): + epilog = 'Example: %s %s' % (prog, example) if example is not None else None + return subparsers.add_parser(cmd, description=desc, help=desc, epilog=epilog) + + p = add_command( 'display', 'Display images in dataset.', + 'display datasets/mnist') + p.add_argument( 'tfrecord_dir', help='Directory containing dataset') + + p = add_command( 'extract', 'Extract images from dataset.', + 'extract datasets/mnist mnist-images') + p.add_argument( 'tfrecord_dir', help='Directory containing dataset') + p.add_argument( 'output_dir', help='Directory to extract the images into') + + p = add_command( 'compare', 'Compare two datasets.', + 'compare datasets/mydataset datasets/mnist') + p.add_argument( 'tfrecord_dir_a', help='Directory containing first dataset') + p.add_argument( 'tfrecord_dir_b', help='Directory containing second dataset') + p.add_argument( '--ignore_labels', help='Ignore labels (default: 0)', type=int, default=0) + + p = add_command( 'create_mnist', 'Create dataset for MNIST.', + 'create_mnist datasets/mnist ~/downloads/mnist') + p.add_argument( 'tfrecord_dir', help='New dataset directory to be created') + p.add_argument( 'mnist_dir', help='Directory containing MNIST') + + p = add_command( 'create_mnistrgb', 'Create dataset for MNIST-RGB.', + 'create_mnistrgb datasets/mnistrgb ~/downloads/mnist') + p.add_argument( 'tfrecord_dir', help='New dataset directory to be created') + p.add_argument( 'mnist_dir', help='Directory containing MNIST') + p.add_argument( '--num_images', help='Number of composite images to create (default: 1000000)', type=int, default=1000000) + p.add_argument( '--random_seed', help='Random seed (default: 123)', type=int, default=123) + + p = add_command( 'create_cifar10', 'Create dataset for CIFAR-10.', + 'create_cifar10 datasets/cifar10 ~/downloads/cifar10') + p.add_argument( 'tfrecord_dir', help='New dataset directory to be created') + p.add_argument( 'cifar10_dir', help='Directory containing CIFAR-10') + + p = add_command( 'create_cifar100', 'Create dataset for CIFAR-100.', + 'create_cifar100 datasets/cifar100 ~/downloads/cifar100') + p.add_argument( 'tfrecord_dir', help='New dataset directory to be created') + p.add_argument( 'cifar100_dir', help='Directory containing CIFAR-100') + + p = add_command( 'create_svhn', 'Create dataset for SVHN.', + 'create_svhn datasets/svhn ~/downloads/svhn') + p.add_argument( 'tfrecord_dir', help='New dataset directory to be created') + p.add_argument( 'svhn_dir', help='Directory containing SVHN') + + p = add_command( 'create_lsun', 'Create dataset for single LSUN category.', + 'create_lsun datasets/lsun-car-100k ~/downloads/lsun/car_lmdb --resolution 256 --max_images 100000') + p.add_argument( 'tfrecord_dir', help='New dataset directory to be created') + p.add_argument( 'lmdb_dir', help='Directory containing LMDB database') + p.add_argument( '--resolution', help='Output resolution (default: 256)', type=int, default=256) + p.add_argument( '--max_images', help='Maximum number of images (default: none)', type=int, default=None) + + p = add_command( 'create_lsun_wide', 'Create LSUN dataset with non-square aspect ratio.', + 'create_lsun_wide datasets/lsun-car-512x384 ~/downloads/lsun/car_lmdb --width 512 --height 384') + p.add_argument( 'tfrecord_dir', help='New dataset directory to be created') + p.add_argument( 'lmdb_dir', help='Directory containing LMDB database') + p.add_argument( '--width', 
help='Output width (default: 512)', type=int, default=512) + p.add_argument( '--height', help='Output height (default: 384)', type=int, default=384) + p.add_argument( '--max_images', help='Maximum number of images (default: none)', type=int, default=None) + + p = add_command( 'create_celeba', 'Create dataset for CelebA.', + 'create_celeba datasets/celeba ~/downloads/celeba') + p.add_argument( 'tfrecord_dir', help='New dataset directory to be created') + p.add_argument( 'celeba_dir', help='Directory containing CelebA') + p.add_argument( '--cx', help='Center X coordinate (default: 89)', type=int, default=89) + p.add_argument( '--cy', help='Center Y coordinate (default: 121)', type=int, default=121) + + p = add_command( 'create_from_images', 'Create dataset from a directory full of images.', + 'create_from_images datasets/mydataset myimagedir') + p.add_argument( 'tfrecord_dir', help='New dataset directory to be created') + p.add_argument( 'image_dir', help='Directory containing the images') + p.add_argument( '--shuffle', help='Randomize image order (default: 1)', type=int, default=1) + + p = add_command( 'create_from_hdf5', 'Create dataset from legacy HDF5 archive.', + 'create_from_hdf5 datasets/celebahq ~/downloads/celeba-hq-1024x1024.h5') + p.add_argument( 'tfrecord_dir', help='New dataset directory to be created') + p.add_argument( 'hdf5_filename', help='HDF5 archive containing the images') + p.add_argument( '--shuffle', help='Randomize image order (default: 1)', type=int, default=1) + + args = parser.parse_args(argv[1:] if len(argv) > 1 else ['-h']) + func = globals()[args.command] + del args.command + func(**vars(args)) + +#---------------------------------------------------------------------------- + +if __name__ == "__main__": + execute_cmdline(sys.argv) + +#---------------------------------------------------------------------------- diff --git a/dnnlib/__init__.py b/dnnlib/__init__.py new file mode 100644 index 0000000..ad43827 --- /dev/null +++ b/dnnlib/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. + +from . import submission + +from .submission.run_context import RunContext + +from .submission.submit import SubmitTarget +from .submission.submit import PathType +from .submission.submit import SubmitConfig +from .submission.submit import get_path_from_template +from .submission.submit import submit_run + +from .util import EasyDict + +submit_config: SubmitConfig = None # Package level variable for SubmitConfig which is only valid when inside the run function. diff --git a/dnnlib/submission/__init__.py b/dnnlib/submission/__init__.py new file mode 100644 index 0000000..5385612 --- /dev/null +++ b/dnnlib/submission/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. + +from . import run_context +from . 
import submit diff --git a/dnnlib/submission/_internal/run.py b/dnnlib/submission/_internal/run.py new file mode 100644 index 0000000..18f830d --- /dev/null +++ b/dnnlib/submission/_internal/run.py @@ -0,0 +1,45 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. + +"""Helper for launching run functions in computing clusters. + +During the submit process, this file is copied to the appropriate run dir. +When the job is launched in the cluster, this module is the first thing that +is run inside the docker container. +""" + +import os +import pickle +import sys + +# PYTHONPATH should have been set so that the run_dir/src is in it +import dnnlib + +def main(): + if not len(sys.argv) >= 4: + raise RuntimeError("This script needs three arguments: run_dir, task_name and host_name!") + + run_dir = str(sys.argv[1]) + task_name = str(sys.argv[2]) + host_name = str(sys.argv[3]) + + submit_config_path = os.path.join(run_dir, "submit_config.pkl") + + # SubmitConfig should have been pickled to the run dir + if not os.path.exists(submit_config_path): + raise RuntimeError("SubmitConfig pickle file does not exist!") + + submit_config: dnnlib.SubmitConfig = pickle.load(open(submit_config_path, "rb")) + dnnlib.submission.submit.set_user_name_override(submit_config.user_name) + + submit_config.task_name = task_name + submit_config.host_name = host_name + + dnnlib.submission.submit.run_wrapper(submit_config) + +if __name__ == "__main__": + main() diff --git a/dnnlib/submission/run_context.py b/dnnlib/submission/run_context.py new file mode 100644 index 0000000..932320e --- /dev/null +++ b/dnnlib/submission/run_context.py @@ -0,0 +1,99 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. + +"""Helpers for managing the run/training loop.""" + +import datetime +import json +import os +import pprint +import time +import types + +from typing import Any + +from . import submit + + +class RunContext(object): + """Helper class for managing the run/training loop. + + The context will hide the implementation details of a basic run/training loop. + It will set things up properly, tell if run should be stopped, and then cleans up. + User should call update periodically and use should_stop to determine if run should be stopped. + + Args: + submit_config: The SubmitConfig that is used for the current run. + config_module: The whole config module that is used for the current run. + max_epoch: Optional cached value for the max_epoch variable used in update. 
+ """ + + def __init__(self, submit_config: submit.SubmitConfig, config_module: types.ModuleType = None, max_epoch: Any = None): + self.submit_config = submit_config + self.should_stop_flag = False + self.has_closed = False + self.start_time = time.time() + self.last_update_time = time.time() + self.last_update_interval = 0.0 + self.max_epoch = max_epoch + + # pretty print the all the relevant content of the config module to a text file + if config_module is not None: + with open(os.path.join(submit_config.run_dir, "config.txt"), "w") as f: + filtered_dict = {k: v for k, v in config_module.__dict__.items() if not k.startswith("_") and not isinstance(v, (types.ModuleType, types.FunctionType, types.LambdaType, submit.SubmitConfig, type))} + pprint.pprint(filtered_dict, stream=f, indent=4, width=200, compact=False) + + # write out details about the run to a text file + self.run_txt_data = {"task_name": submit_config.task_name, "host_name": submit_config.host_name, "start_time": datetime.datetime.now().isoformat(sep=" ")} + with open(os.path.join(submit_config.run_dir, "run.txt"), "w") as f: + pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False) + + def __enter__(self) -> "RunContext": + return self + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + self.close() + + def update(self, loss: Any = 0, cur_epoch: Any = 0, max_epoch: Any = None) -> None: + """Do general housekeeping and keep the state of the context up-to-date. + Should be called often enough but not in a tight loop.""" + assert not self.has_closed + + self.last_update_interval = time.time() - self.last_update_time + self.last_update_time = time.time() + + if os.path.exists(os.path.join(self.submit_config.run_dir, "abort.txt")): + self.should_stop_flag = True + + max_epoch_val = self.max_epoch if max_epoch is None else max_epoch + + def should_stop(self) -> bool: + """Tell whether a stopping condition has been triggered one way or another.""" + return self.should_stop_flag + + def get_time_since_start(self) -> float: + """How much time has passed since the creation of the context.""" + return time.time() - self.start_time + + def get_time_since_last_update(self) -> float: + """How much time has passed since the last call to update.""" + return time.time() - self.last_update_time + + def get_last_update_interval(self) -> float: + """How much time passed between the previous two calls to update.""" + return self.last_update_interval + + def close(self) -> None: + """Close the context and clean up. + Should only be called once.""" + if not self.has_closed: + # update the run.txt with stopping time + self.run_txt_data["stop_time"] = datetime.datetime.now().isoformat(sep=" ") + with open(os.path.join(self.submit_config.run_dir, "run.txt"), "w") as f: + pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False) + + self.has_closed = True diff --git a/dnnlib/submission/submit.py b/dnnlib/submission/submit.py new file mode 100644 index 0000000..60ff428 --- /dev/null +++ b/dnnlib/submission/submit.py @@ -0,0 +1,290 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
+ +"""Submit a function to be run either locally or in a computing cluster.""" + +import copy +import io +import os +import pathlib +import pickle +import platform +import pprint +import re +import shutil +import time +import traceback + +import zipfile + +from enum import Enum + +from .. import util +from ..util import EasyDict + + +class SubmitTarget(Enum): + """The target where the function should be run. + + LOCAL: Run it locally. + """ + LOCAL = 1 + + +class PathType(Enum): + """Determines in which format should a path be formatted. + + WINDOWS: Format with Windows style. + LINUX: Format with Linux/Posix style. + AUTO: Use current OS type to select either WINDOWS or LINUX. + """ + WINDOWS = 1 + LINUX = 2 + AUTO = 3 + + +_user_name_override = None + + +class SubmitConfig(util.EasyDict): + """Strongly typed config dict needed to submit runs. + + Attributes: + run_dir_root: Path to the run dir root. Can be optionally templated with tags. Needs to always be run through get_path_from_template. + run_desc: Description of the run. Will be used in the run dir and task name. + run_dir_ignore: List of file patterns used to ignore files when copying files to the run dir. + run_dir_extra_files: List of (abs_path, rel_path) tuples of file paths. rel_path root will be the src directory inside the run dir. + submit_target: Submit target enum value. Used to select where the run is actually launched. + num_gpus: Number of GPUs used/requested for the run. + print_info: Whether to print debug information when submitting. + ask_confirmation: Whether to ask a confirmation before submitting. + run_id: Automatically populated value during submit. + run_name: Automatically populated value during submit. + run_dir: Automatically populated value during submit. + run_func_name: Automatically populated value during submit. + run_func_kwargs: Automatically populated value during submit. + user_name: Automatically populated value during submit. Can be set by the user which will then override the automatic value. + task_name: Automatically populated value during submit. + host_name: Automatically populated value during submit. 
+ """ + + def __init__(self): + super().__init__() + + # run (set these) + self.run_dir_root = "" # should always be passed through get_path_from_template + self.run_desc = "" + self.run_dir_ignore = ["__pycache__", "*.pyproj", "*.sln", "*.suo", ".cache", ".idea", ".vs", ".vscode"] + self.run_dir_extra_files = None + + # submit (set these) + self.submit_target = SubmitTarget.LOCAL + self.num_gpus = 1 + self.print_info = False + self.ask_confirmation = False + + # (automatically populated) + self.run_id = None + self.run_name = None + self.run_dir = None + self.run_func_name = None + self.run_func_kwargs = None + self.user_name = None + self.task_name = None + self.host_name = "localhost" + + +def get_path_from_template(path_template: str, path_type: PathType = PathType.AUTO) -> str: + """Replace tags in the given path template and return either Windows or Linux formatted path.""" + # automatically select path type depending on running OS + if path_type == PathType.AUTO: + if platform.system() == "Windows": + path_type = PathType.WINDOWS + elif platform.system() == "Linux": + path_type = PathType.LINUX + else: + raise RuntimeError("Unknown platform") + + path_template = path_template.replace("", get_user_name()) + + # return correctly formatted path + if path_type == PathType.WINDOWS: + return str(pathlib.PureWindowsPath(path_template)) + elif path_type == PathType.LINUX: + return str(pathlib.PurePosixPath(path_template)) + else: + raise RuntimeError("Unknown platform") + + +def get_template_from_path(path: str) -> str: + """Convert a normal path back to its template representation.""" + # replace all path parts with the template tags + path = path.replace("\\", "/") + return path + + +def convert_path(path: str, path_type: PathType = PathType.AUTO) -> str: + """Convert a normal path to template and the convert it back to a normal path with given path type.""" + path_template = get_template_from_path(path) + path = get_path_from_template(path_template, path_type) + return path + + +def set_user_name_override(name: str) -> None: + """Set the global username override value.""" + global _user_name_override + _user_name_override = name + + +def get_user_name(): + """Get the current user name.""" + if _user_name_override is not None: + return _user_name_override + elif platform.system() == "Windows": + return os.getlogin() + elif platform.system() == "Linux": + try: + import pwd # pylint: disable=import-error + return pwd.getpwuid(os.geteuid()).pw_name # pylint: disable=no-member + except: + return "unknown" + else: + raise RuntimeError("Unknown platform") + + +def _create_run_dir_local(submit_config: SubmitConfig) -> str: + """Create a new run dir with increasing ID number at the start.""" + run_dir_root = get_path_from_template(submit_config.run_dir_root, PathType.AUTO) + + if not os.path.exists(run_dir_root): + print("Creating the run dir root: {}".format(run_dir_root)) + os.makedirs(run_dir_root) + + submit_config.run_id = _get_next_run_id_local(run_dir_root) + submit_config.run_name = "{0:05d}-{1}".format(submit_config.run_id, submit_config.run_desc) + run_dir = os.path.join(run_dir_root, submit_config.run_name) + + if os.path.exists(run_dir): + raise RuntimeError("The run dir already exists! ({0})".format(run_dir)) + + print("Creating the run dir: {}".format(run_dir)) + os.makedirs(run_dir) + + return run_dir + + +def _get_next_run_id_local(run_dir_root: str) -> int: + """Reads all directory names in a given directory (non-recursive) and returns the next (increasing) run id. 
Assumes IDs are numbers at the start of the directory names.""" + dir_names = [d for d in os.listdir(run_dir_root) if os.path.isdir(os.path.join(run_dir_root, d))] + r = re.compile("^\\d+") # match one or more digits at the start of the string + run_id = 0 + + for dir_name in dir_names: + m = r.match(dir_name) + + if m is not None: + i = int(m.group()) + run_id = max(run_id, i + 1) + + return run_id + + +def _populate_run_dir(run_dir: str, submit_config: SubmitConfig) -> None: + """Copy all necessary files into the run dir. Assumes that the dir exists, is local, and is writable.""" + print("Copying files to the run dir") + files = [] + + run_func_module_dir_path = util.get_module_dir_by_obj_name(submit_config.run_func_name) + assert '.' in submit_config.run_func_name + for _idx in range(submit_config.run_func_name.count('.') - 1): + run_func_module_dir_path = os.path.dirname(run_func_module_dir_path) + files += util.list_dir_recursively_with_ignore(run_func_module_dir_path, ignores=submit_config.run_dir_ignore, add_base_to_relative=False) + + dnnlib_module_dir_path = util.get_module_dir_by_obj_name("dnnlib") + files += util.list_dir_recursively_with_ignore(dnnlib_module_dir_path, ignores=submit_config.run_dir_ignore, add_base_to_relative=True) + + if submit_config.run_dir_extra_files is not None: + files += submit_config.run_dir_extra_files + + files = [(f[0], os.path.join(run_dir, "src", f[1])) for f in files] + files += [(os.path.join(dnnlib_module_dir_path, "submission", "_internal", "run.py"), os.path.join(run_dir, "run.py"))] + + util.copy_files_and_create_dirs(files) + + pickle.dump(submit_config, open(os.path.join(run_dir, "submit_config.pkl"), "wb")) + + with open(os.path.join(run_dir, "submit_config.txt"), "w") as f: + pprint.pprint(submit_config, stream=f, indent=4, width=200, compact=False) + + +def run_wrapper(submit_config: SubmitConfig) -> None: + """Wrap the actual run function call for handling logging, exceptions, typing, etc.""" + is_local = submit_config.submit_target == SubmitTarget.LOCAL + + checker = None + + # when running locally, redirect stderr to stdout, log stdout to a file, and force flushing + if is_local: + logger = util.Logger(file_name=os.path.join(submit_config.run_dir, "log.txt"), file_mode="w", should_flush=True) + else: # when running in a cluster, redirect stderr to stdout, and just force flushing (log writing is handled by run.sh) + logger = util.Logger(file_name=None, should_flush=True) + + import dnnlib + dnnlib.submit_config = submit_config + + try: + print("dnnlib: Running {0}() on {1}...".format(submit_config.run_func_name, submit_config.host_name)) + start_time = time.time() + util.call_func_by_name(func_name=submit_config.run_func_name, submit_config=submit_config, **submit_config.run_func_kwargs) + print("dnnlib: Finished {0}() in {1}.".format(submit_config.run_func_name, util.format_time(time.time() - start_time))) + except: + if is_local: + raise + else: + traceback.print_exc() + + log_src = os.path.join(submit_config.run_dir, "log.txt") + log_dst = os.path.join(get_path_from_template(submit_config.run_dir_root), "{0}-error.txt".format(submit_config.run_name)) + shutil.copyfile(log_src, log_dst) + finally: + open(os.path.join(submit_config.run_dir, "_finished.txt"), "w").close() + + dnnlib.submit_config = None + logger.close() + + if checker is not None: + checker.stop() + + +def submit_run(submit_config: SubmitConfig, run_func_name: str, **run_func_kwargs) -> None: + """Create a run dir, gather files related to the run, copy files to the 
run dir, and launch the run in appropriate place.""" + submit_config = copy.copy(submit_config) + + if submit_config.user_name is None: + submit_config.user_name = get_user_name() + + submit_config.run_func_name = run_func_name + submit_config.run_func_kwargs = run_func_kwargs + + assert submit_config.submit_target == SubmitTarget.LOCAL + if submit_config.submit_target in {SubmitTarget.LOCAL}: + run_dir = _create_run_dir_local(submit_config) + + submit_config.task_name = "{0}-{1:05d}-{2}".format(submit_config.user_name, submit_config.run_id, submit_config.run_desc) + submit_config.run_dir = run_dir + _populate_run_dir(run_dir, submit_config) + + if submit_config.print_info: + print("\nSubmit config:\n") + pprint.pprint(submit_config, indent=4, width=200, compact=False) + print() + + if submit_config.ask_confirmation: + if not util.ask_yes_no("Continue submitting the job?"): + return + + run_wrapper(submit_config) diff --git a/dnnlib/tflib/__init__.py b/dnnlib/tflib/__init__.py new file mode 100644 index 0000000..f054a39 --- /dev/null +++ b/dnnlib/tflib/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. + +from . import autosummary +from . import network +from . import optimizer +from . import tfutil + +from .tfutil import * +from .network import Network + +from .optimizer import Optimizer diff --git a/dnnlib/tflib/autosummary.py b/dnnlib/tflib/autosummary.py new file mode 100644 index 0000000..43154f7 --- /dev/null +++ b/dnnlib/tflib/autosummary.py @@ -0,0 +1,184 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. + +"""Helper for adding automatically tracked values to Tensorboard. + +Autosummary creates an identity op that internally keeps track of the input +values and automatically shows up in TensorBoard. The reported value +represents an average over input components. The average is accumulated +constantly over time and flushed when save_summaries() is called. + +Notes: +- The output tensor must be used as an input for something else in the + graph. Otherwise, the autosummary op will not get executed, and the average + value will not get accumulated. +- It is perfectly fine to include autosummaries with the same name in + several places throughout the graph, even if they are executed concurrently. +- It is ok to also pass in a python scalar or numpy array. In this case, it + is added to the average immediately. +""" + +from collections import OrderedDict +import numpy as np +import tensorflow as tf +from tensorboard import summary as summary_lib +from tensorboard.plugins.custom_scalar import layout_pb2 + +from . import tfutil +from .tfutil import TfExpression +from .tfutil import TfExpressionEx + +_dtype = tf.float64 +_vars = OrderedDict() # name => [var, ...] 
+_immediate = OrderedDict() # name => update_op, update_value +_finalized = False +_merge_op = None + + +def _create_var(name: str, value_expr: TfExpression) -> TfExpression: + """Internal helper for creating autosummary accumulators.""" + assert not _finalized + name_id = name.replace("/", "_") + v = tf.cast(value_expr, _dtype) + + if v.shape.is_fully_defined(): + size = np.prod(tfutil.shape_to_list(v.shape)) + size_expr = tf.constant(size, dtype=_dtype) + else: + size = None + size_expr = tf.reduce_prod(tf.cast(tf.shape(v), _dtype)) + + if size == 1: + if v.shape.ndims != 0: + v = tf.reshape(v, []) + v = [size_expr, v, tf.square(v)] + else: + v = [size_expr, tf.reduce_sum(v), tf.reduce_sum(tf.square(v))] + v = tf.cond(tf.is_finite(v[1]), lambda: tf.stack(v), lambda: tf.zeros(3, dtype=_dtype)) + + with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.control_dependencies(None): + var = tf.Variable(tf.zeros(3, dtype=_dtype), trainable=False) # [sum(1), sum(x), sum(x**2)] + update_op = tf.cond(tf.is_variable_initialized(var), lambda: tf.assign_add(var, v), lambda: tf.assign(var, v)) + + if name in _vars: + _vars[name].append(var) + else: + _vars[name] = [var] + return update_op + + +def autosummary(name: str, value: TfExpressionEx, passthru: TfExpressionEx = None) -> TfExpressionEx: + """Create a new autosummary. + + Args: + name: Name to use in TensorBoard + value: TensorFlow expression or python value to track + passthru: Optionally return this TF node without modifications but tack an autosummary update side-effect to this node. + + Example use of the passthru mechanism: + + n = autosummary('l2loss', loss, passthru=n) + + This is a shorthand for the following code: + + with tf.control_dependencies([autosummary('l2loss', loss)]): + n = tf.identity(n) + """ + tfutil.assert_tf_initialized() + name_id = name.replace("/", "_") + + if tfutil.is_tf_expression(value): + with tf.name_scope("summary_" + name_id), tf.device(value.device): + update_op = _create_var(name, value) + with tf.control_dependencies([update_op]): + return tf.identity(value if passthru is None else passthru) + + else: # python scalar or numpy array + if name not in _immediate: + with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.device(None), tf.control_dependencies(None): + update_value = tf.placeholder(_dtype) + update_op = _create_var(name, update_value) + _immediate[name] = update_op, update_value + + update_op, update_value = _immediate[name] + tfutil.run(update_op, {update_value: value}) + return value if passthru is None else passthru + + +def finalize_autosummaries() -> None: + """Create the necessary ops to include autosummaries in TensorBoard report. + Note: This should be done only once per graph. + """ + global _finalized + tfutil.assert_tf_initialized() + + if _finalized: + return None + + _finalized = True + tfutil.init_uninitialized_vars([var for vars_list in _vars.values() for var in vars_list]) + + # Create summary ops. 
+ with tf.device(None), tf.control_dependencies(None): + for name, vars_list in _vars.items(): + name_id = name.replace("/", "_") + with tfutil.absolute_name_scope("Autosummary/" + name_id): + moments = tf.add_n(vars_list) + moments /= moments[0] + with tf.control_dependencies([moments]): # read before resetting + reset_ops = [tf.assign(var, tf.zeros(3, dtype=_dtype)) for var in vars_list] + with tf.name_scope(None), tf.control_dependencies(reset_ops): # reset before reporting + mean = moments[1] + std = tf.sqrt(moments[2] - tf.square(moments[1])) + tf.summary.scalar(name, mean) + tf.summary.scalar("xCustomScalars/" + name + "/margin_lo", mean - std) + tf.summary.scalar("xCustomScalars/" + name + "/margin_hi", mean + std) + + # Group by category and chart name. + cat_dict = OrderedDict() + for series_name in sorted(_vars.keys()): + p = series_name.split("/") + cat = p[0] if len(p) >= 2 else "" + chart = "/".join(p[1:-1]) if len(p) >= 3 else p[-1] + if cat not in cat_dict: + cat_dict[cat] = OrderedDict() + if chart not in cat_dict[cat]: + cat_dict[cat][chart] = [] + cat_dict[cat][chart].append(series_name) + + # Setup custom_scalar layout. + categories = [] + for cat_name, chart_dict in cat_dict.items(): + charts = [] + for chart_name, series_names in chart_dict.items(): + series = [] + for series_name in series_names: + series.append(layout_pb2.MarginChartContent.Series( + value=series_name, + lower="xCustomScalars/" + series_name + "/margin_lo", + upper="xCustomScalars/" + series_name + "/margin_hi")) + margin = layout_pb2.MarginChartContent(series=series) + charts.append(layout_pb2.Chart(title=chart_name, margin=margin)) + categories.append(layout_pb2.Category(title=cat_name, chart=charts)) + layout = summary_lib.custom_scalar_pb(layout_pb2.Layout(category=categories)) + return layout + +def save_summaries(file_writer, global_step=None): + """Call FileWriter.add_summary() with all summaries in the default graph, + automatically finalizing and merging them on the first call. + """ + global _merge_op + tfutil.assert_tf_initialized() + + if _merge_op is None: + layout = finalize_autosummaries() + if layout is not None: + file_writer.add_summary(layout) + with tf.device(None), tf.control_dependencies(None): + _merge_op = tf.summary.merge_all() + + file_writer.add_summary(_merge_op.eval(), global_step) diff --git a/dnnlib/tflib/network.py b/dnnlib/tflib/network.py new file mode 100644 index 0000000..d888a90 --- /dev/null +++ b/dnnlib/tflib/network.py @@ -0,0 +1,591 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. + +"""Helper for managing networks.""" + +import types +import inspect +import re +import uuid +import sys +import numpy as np +import tensorflow as tf + +from collections import OrderedDict +from typing import Any, List, Tuple, Union + +from . import tfutil +from .. import util + +from .tfutil import TfExpression, TfExpressionEx + +_import_handlers = [] # Custom import handlers for dealing with legacy data in pickle import. +_import_module_src = dict() # Source code for temporary modules created during pickle import. 
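+# Illustrative usage of the Network class defined below (editorial sketch; the model
+# URL, latent array, and minibatch size are placeholders, not part of the original code):
+#
+#   dnnlib.tflib.init_tf()
+#   with dnnlib.util.open_url(model_url, cache_dir=config.cache_dir) as f:
+#       G, D, Gs = pickle.load(f)   # Gs = long-term average of the generator
+#   images = Gs.run(latents, None, minibatch_size=8,
+#                   output_transform=dict(func=dnnlib.tflib.convert_images_to_uint8,
+#                                         nchw_to_nhwc=True))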
+ + +def import_handler(handler_func): + """Function decorator for declaring custom import handlers.""" + _import_handlers.append(handler_func) + return handler_func + + +class Network: + """Generic network abstraction. + + Acts as a convenience wrapper for a parameterized network construction + function, providing several utility methods and convenient access to + the inputs/outputs/weights. + + Network objects can be safely pickled and unpickled for long-term + archival purposes. The pickling works reliably as long as the underlying + network construction function is defined in a standalone Python module + that has no side effects or application-specific imports. + + Args: + name: Network name. Used to select TensorFlow name and variable scopes. + func_name: Fully qualified name of the underlying network construction function, or a top-level function object. + static_kwargs: Keyword arguments to be passed in to the network construction function. + + Attributes: + name: User-specified name, defaults to build func name if None. + scope: Unique TensorFlow scope containing template graph and variables, derived from the user-specified name. + static_kwargs: Arguments passed to the user-supplied build func. + components: Container for sub-networks. Passed to the build func, and retained between calls. + num_inputs: Number of input tensors. + num_outputs: Number of output tensors. + input_shapes: Input tensor shapes (NC or NCHW), including minibatch dimension. + output_shapes: Output tensor shapes (NC or NCHW), including minibatch dimension. + input_shape: Short-hand for input_shapes[0]. + output_shape: Short-hand for output_shapes[0]. + input_templates: Input placeholders in the template graph. + output_templates: Output tensors in the template graph. + input_names: Name string for each input. + output_names: Name string for each output. + own_vars: Variables defined by this network (local_name => var), excluding sub-networks. + vars: All variables (local_name => var). + trainables: All trainable variables (local_name => var). + var_global_to_local: Mapping from variable global names to local names. + """ + + def __init__(self, name: str = None, func_name: Any = None, **static_kwargs): + tfutil.assert_tf_initialized() + assert isinstance(name, str) or name is None + assert func_name is not None + assert isinstance(func_name, str) or util.is_top_level_function(func_name) + assert util.is_pickleable(static_kwargs) + + self._init_fields() + self.name = name + self.static_kwargs = util.EasyDict(static_kwargs) + + # Locate the user-specified network build function. + if util.is_top_level_function(func_name): + func_name = util.get_top_level_function_name(func_name) + module, self._build_func_name = util.get_module_from_obj_name(func_name) + self._build_func = util.get_obj_from_module(module, self._build_func_name) + assert callable(self._build_func) + + # Dig up source code for the module containing the build function. + self._build_module_src = _import_module_src.get(module, None) + if self._build_module_src is None: + self._build_module_src = inspect.getsource(module) + + # Init TensorFlow graph. 
+ self._init_graph() + self.reset_own_vars() + + def _init_fields(self) -> None: + self.name = None + self.scope = None + self.static_kwargs = util.EasyDict() + self.components = util.EasyDict() + self.num_inputs = 0 + self.num_outputs = 0 + self.input_shapes = [[]] + self.output_shapes = [[]] + self.input_shape = [] + self.output_shape = [] + self.input_templates = [] + self.output_templates = [] + self.input_names = [] + self.output_names = [] + self.own_vars = OrderedDict() + self.vars = OrderedDict() + self.trainables = OrderedDict() + self.var_global_to_local = OrderedDict() + + self._build_func = None # User-supplied build function that constructs the network. + self._build_func_name = None # Name of the build function. + self._build_module_src = None # Full source code of the module containing the build function. + self._run_cache = dict() # Cached graph data for Network.run(). + + def _init_graph(self) -> None: + # Collect inputs. + self.input_names = [] + + for param in inspect.signature(self._build_func).parameters.values(): + if param.kind == param.POSITIONAL_OR_KEYWORD and param.default is param.empty: + self.input_names.append(param.name) + + self.num_inputs = len(self.input_names) + assert self.num_inputs >= 1 + + # Choose name and scope. + if self.name is None: + self.name = self._build_func_name + assert re.match("^[A-Za-z0-9_.\\-]*$", self.name) + with tf.name_scope(None): + self.scope = tf.get_default_graph().unique_name(self.name, mark_as_used=True) + + # Finalize build func kwargs. + build_kwargs = dict(self.static_kwargs) + build_kwargs["is_template_graph"] = True + build_kwargs["components"] = self.components + + # Build template graph. + with tfutil.absolute_variable_scope(self.scope, reuse=tf.AUTO_REUSE), tfutil.absolute_name_scope(self.scope): # ignore surrounding scopes + assert tf.get_variable_scope().name == self.scope + assert tf.get_default_graph().get_name_scope() == self.scope + with tf.control_dependencies(None): # ignore surrounding control dependencies + self.input_templates = [tf.placeholder(tf.float32, name=name) for name in self.input_names] + out_expr = self._build_func(*self.input_templates, **build_kwargs) + + # Collect outputs. + assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple) + self.output_templates = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr) + self.num_outputs = len(self.output_templates) + assert self.num_outputs >= 1 + assert all(tfutil.is_tf_expression(t) for t in self.output_templates) + + # Perform sanity checks. + if any(t.shape.ndims is None for t in self.input_templates): + raise ValueError("Network input shapes not defined. Please call x.set_shape() for each input.") + if any(t.shape.ndims is None for t in self.output_templates): + raise ValueError("Network output shapes not defined. Please call x.set_shape() where applicable.") + if any(not isinstance(comp, Network) for comp in self.components.values()): + raise ValueError("Components of a Network must be Networks themselves.") + if len(self.components) != len(set(comp.name for comp in self.components.values())): + raise ValueError("Components of a Network must have unique names.") + + # List inputs and outputs. 
+ self.input_shapes = [tfutil.shape_to_list(t.shape) for t in self.input_templates] + self.output_shapes = [tfutil.shape_to_list(t.shape) for t in self.output_templates] + self.input_shape = self.input_shapes[0] + self.output_shape = self.output_shapes[0] + self.output_names = [t.name.split("/")[-1].split(":")[0] for t in self.output_templates] + + # List variables. + self.own_vars = OrderedDict((var.name[len(self.scope) + 1:].split(":")[0], var) for var in tf.global_variables(self.scope + "/")) + self.vars = OrderedDict(self.own_vars) + self.vars.update((comp.name + "/" + name, var) for comp in self.components.values() for name, var in comp.vars.items()) + self.trainables = OrderedDict((name, var) for name, var in self.vars.items() if var.trainable) + self.var_global_to_local = OrderedDict((var.name.split(":")[0], name) for name, var in self.vars.items()) + + def reset_own_vars(self) -> None: + """Re-initialize all variables of this network, excluding sub-networks.""" + tfutil.run([var.initializer for var in self.own_vars.values()]) + + def reset_vars(self) -> None: + """Re-initialize all variables of this network, including sub-networks.""" + tfutil.run([var.initializer for var in self.vars.values()]) + + def reset_trainables(self) -> None: + """Re-initialize all trainable variables of this network, including sub-networks.""" + tfutil.run([var.initializer for var in self.trainables.values()]) + + def get_output_for(self, *in_expr: TfExpression, return_as_list: bool = False, **dynamic_kwargs) -> Union[TfExpression, List[TfExpression]]: + """Construct TensorFlow expression(s) for the output(s) of this network, given the input expression(s).""" + assert len(in_expr) == self.num_inputs + assert not all(expr is None for expr in in_expr) + + # Finalize build func kwargs. + build_kwargs = dict(self.static_kwargs) + build_kwargs.update(dynamic_kwargs) + build_kwargs["is_template_graph"] = False + build_kwargs["components"] = self.components + + # Build TensorFlow graph to evaluate the network. + with tfutil.absolute_variable_scope(self.scope, reuse=True), tf.name_scope(self.name): + assert tf.get_variable_scope().name == self.scope + valid_inputs = [expr for expr in in_expr if expr is not None] + final_inputs = [] + for expr, name, shape in zip(in_expr, self.input_names, self.input_shapes): + if expr is not None: + expr = tf.identity(expr, name=name) + else: + expr = tf.zeros([tf.shape(valid_inputs[0])[0]] + shape[1:], name=name) + final_inputs.append(expr) + out_expr = self._build_func(*final_inputs, **build_kwargs) + + # Propagate input shapes back to the user-specified expressions. + for expr, final in zip(in_expr, final_inputs): + if isinstance(expr, tf.Tensor): + expr.set_shape(final.shape) + + # Express outputs in the desired format. 
+ assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple) + if return_as_list: + out_expr = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr) + return out_expr + + def get_var_local_name(self, var_or_global_name: Union[TfExpression, str]) -> str: + """Get the local name of a given variable, without any surrounding name scopes.""" + assert tfutil.is_tf_expression(var_or_global_name) or isinstance(var_or_global_name, str) + global_name = var_or_global_name if isinstance(var_or_global_name, str) else var_or_global_name.name + return self.var_global_to_local[global_name] + + def find_var(self, var_or_local_name: Union[TfExpression, str]) -> TfExpression: + """Find variable by local or global name.""" + assert tfutil.is_tf_expression(var_or_local_name) or isinstance(var_or_local_name, str) + return self.vars[var_or_local_name] if isinstance(var_or_local_name, str) else var_or_local_name + + def get_var(self, var_or_local_name: Union[TfExpression, str]) -> np.ndarray: + """Get the value of a given variable as NumPy array. + Note: This method is very inefficient -- prefer to use tflib.run(list_of_vars) whenever possible.""" + return self.find_var(var_or_local_name).eval() + + def set_var(self, var_or_local_name: Union[TfExpression, str], new_value: Union[int, float, np.ndarray]) -> None: + """Set the value of a given variable based on the given NumPy array. + Note: This method is very inefficient -- prefer to use tflib.set_vars() whenever possible.""" + tfutil.set_vars({self.find_var(var_or_local_name): new_value}) + + def __getstate__(self) -> dict: + """Pickle export.""" + state = dict() + state["version"] = 3 + state["name"] = self.name + state["static_kwargs"] = dict(self.static_kwargs) + state["components"] = dict(self.components) + state["build_module_src"] = self._build_module_src + state["build_func_name"] = self._build_func_name + state["variables"] = list(zip(self.own_vars.keys(), tfutil.run(list(self.own_vars.values())))) + return state + + def __setstate__(self, state: dict) -> None: + """Pickle import.""" + # pylint: disable=attribute-defined-outside-init + tfutil.assert_tf_initialized() + self._init_fields() + + # Execute custom import handlers. + for handler in _import_handlers: + state = handler(state) + + # Set basic fields. + assert state["version"] in [2, 3] + self.name = state["name"] + self.static_kwargs = util.EasyDict(state["static_kwargs"]) + self.components = util.EasyDict(state.get("components", {})) + self._build_module_src = state["build_module_src"] + self._build_func_name = state["build_func_name"] + + # Create temporary module from the imported source code. + module_name = "_tflib_network_import_" + uuid.uuid4().hex + module = types.ModuleType(module_name) + sys.modules[module_name] = module + _import_module_src[module] = self._build_module_src + exec(self._build_module_src, module.__dict__) # pylint: disable=exec-used + + # Locate network build function in the temporary module. + self._build_func = util.get_obj_from_module(module, self._build_func_name) + assert callable(self._build_func) + + # Init TensorFlow graph. 
+ self._init_graph() + self.reset_own_vars() + tfutil.set_vars({self.find_var(name): value for name, value in state["variables"]}) + + def clone(self, name: str = None, **new_static_kwargs) -> "Network": + """Create a clone of this network with its own copy of the variables.""" + # pylint: disable=protected-access + net = object.__new__(Network) + net._init_fields() + net.name = name if name is not None else self.name + net.static_kwargs = util.EasyDict(self.static_kwargs) + net.static_kwargs.update(new_static_kwargs) + net._build_module_src = self._build_module_src + net._build_func_name = self._build_func_name + net._build_func = self._build_func + net._init_graph() + net.copy_vars_from(self) + return net + + def copy_own_vars_from(self, src_net: "Network") -> None: + """Copy the values of all variables from the given network, excluding sub-networks.""" + names = [name for name in self.own_vars.keys() if name in src_net.own_vars] + tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names})) + + def copy_vars_from(self, src_net: "Network") -> None: + """Copy the values of all variables from the given network, including sub-networks.""" + names = [name for name in self.vars.keys() if name in src_net.vars] + tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names})) + + def copy_trainables_from(self, src_net: "Network") -> None: + """Copy the values of all trainable variables from the given network, including sub-networks.""" + names = [name for name in self.trainables.keys() if name in src_net.trainables] + tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names})) + + def convert(self, new_func_name: str, new_name: str = None, **new_static_kwargs) -> "Network": + """Create new network with the given parameters, and copy all variables from this network.""" + if new_name is None: + new_name = self.name + static_kwargs = dict(self.static_kwargs) + static_kwargs.update(new_static_kwargs) + net = Network(name=new_name, func_name=new_func_name, **static_kwargs) + net.copy_vars_from(self) + return net + + def setup_as_moving_average_of(self, src_net: "Network", beta: TfExpressionEx = 0.99, beta_nontrainable: TfExpressionEx = 0.0) -> tf.Operation: + """Construct a TensorFlow op that updates the variables of this network + to be slightly closer to those of the given network.""" + with tfutil.absolute_name_scope(self.scope + "/_MovingAvg"): + ops = [] + for name, var in self.vars.items(): + if name in src_net.vars: + cur_beta = beta if name in self.trainables else beta_nontrainable + new_value = tfutil.lerp(src_net.vars[name], var, cur_beta) + ops.append(var.assign(new_value)) + return tf.group(*ops) + + def run(self, + *in_arrays: Tuple[Union[np.ndarray, None], ...], + input_transform: dict = None, + output_transform: dict = None, + return_as_list: bool = False, + print_progress: bool = False, + minibatch_size: int = None, + num_gpus: int = 1, + assume_frozen: bool = False, + **dynamic_kwargs) -> Union[np.ndarray, Tuple[np.ndarray, ...], List[np.ndarray]]: + """Run this network for the given NumPy array(s), and return the output(s) as NumPy array(s). + + Args: + input_transform: A dict specifying a custom transformation to be applied to the input tensor(s) before evaluating the network. + The dict must contain a 'func' field that points to a top-level function. The function is called with the input + TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs. 
+ output_transform: A dict specifying a custom transformation to be applied to the output tensor(s) after evaluating the network. + The dict must contain a 'func' field that points to a top-level function. The function is called with the output + TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs. + return_as_list: True = return a list of NumPy arrays, False = return a single NumPy array, or a tuple if there are multiple outputs. + print_progress: Print progress to the console? Useful for very large input arrays. + minibatch_size: Maximum minibatch size to use, None = disable batching. + num_gpus: Number of GPUs to use. + assume_frozen: Improve multi-GPU performance by assuming that the trainable parameters will remain changed between calls. + dynamic_kwargs: Additional keyword arguments to be passed into the network build function. + """ + assert len(in_arrays) == self.num_inputs + assert not all(arr is None for arr in in_arrays) + assert input_transform is None or util.is_top_level_function(input_transform["func"]) + assert output_transform is None or util.is_top_level_function(output_transform["func"]) + output_transform, dynamic_kwargs = _handle_legacy_output_transforms(output_transform, dynamic_kwargs) + num_items = in_arrays[0].shape[0] + if minibatch_size is None: + minibatch_size = num_items + + # Construct unique hash key from all arguments that affect the TensorFlow graph. + key = dict(input_transform=input_transform, output_transform=output_transform, num_gpus=num_gpus, assume_frozen=assume_frozen, dynamic_kwargs=dynamic_kwargs) + def unwind_key(obj): + if isinstance(obj, dict): + return [(key, unwind_key(value)) for key, value in sorted(obj.items())] + if callable(obj): + return util.get_top_level_function_name(obj) + return obj + key = repr(unwind_key(key)) + + # Build graph. + if key not in self._run_cache: + with tfutil.absolute_name_scope(self.scope + "/_Run"), tf.control_dependencies(None): + with tf.device("/cpu:0"): + in_expr = [tf.placeholder(tf.float32, name=name) for name in self.input_names] + in_split = list(zip(*[tf.split(x, num_gpus) for x in in_expr])) + + out_split = [] + for gpu in range(num_gpus): + with tf.device("/gpu:%d" % gpu): + net_gpu = self.clone() if assume_frozen else self + in_gpu = in_split[gpu] + + if input_transform is not None: + in_kwargs = dict(input_transform) + in_gpu = in_kwargs.pop("func")(*in_gpu, **in_kwargs) + in_gpu = [in_gpu] if tfutil.is_tf_expression(in_gpu) else list(in_gpu) + + assert len(in_gpu) == self.num_inputs + out_gpu = net_gpu.get_output_for(*in_gpu, return_as_list=True, **dynamic_kwargs) + + if output_transform is not None: + out_kwargs = dict(output_transform) + out_gpu = out_kwargs.pop("func")(*out_gpu, **out_kwargs) + out_gpu = [out_gpu] if tfutil.is_tf_expression(out_gpu) else list(out_gpu) + + assert len(out_gpu) == self.num_outputs + out_split.append(out_gpu) + + with tf.device("/cpu:0"): + out_expr = [tf.concat(outputs, axis=0) for outputs in zip(*out_split)] + self._run_cache[key] = in_expr, out_expr + + # Run minibatches. 
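+        # Inputs passed as None are substituted with zero arrays of the registered
+        # input shape; outputs are written minibatch by minibatch into the
+        # preallocated NumPy arrays below.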
+ in_expr, out_expr = self._run_cache[key] + out_arrays = [np.empty([num_items] + tfutil.shape_to_list(expr.shape)[1:], expr.dtype.name) for expr in out_expr] + + for mb_begin in range(0, num_items, minibatch_size): + if print_progress: + print("\r%d / %d" % (mb_begin, num_items), end="") + + mb_end = min(mb_begin + minibatch_size, num_items) + mb_num = mb_end - mb_begin + mb_in = [src[mb_begin : mb_end] if src is not None else np.zeros([mb_num] + shape[1:]) for src, shape in zip(in_arrays, self.input_shapes)] + mb_out = tf.get_default_session().run(out_expr, dict(zip(in_expr, mb_in))) + + for dst, src in zip(out_arrays, mb_out): + dst[mb_begin: mb_end] = src + + # Done. + if print_progress: + print("\r%d / %d" % (num_items, num_items)) + + if not return_as_list: + out_arrays = out_arrays[0] if len(out_arrays) == 1 else tuple(out_arrays) + return out_arrays + + def list_ops(self) -> List[TfExpression]: + include_prefix = self.scope + "/" + exclude_prefix = include_prefix + "_" + ops = tf.get_default_graph().get_operations() + ops = [op for op in ops if op.name.startswith(include_prefix)] + ops = [op for op in ops if not op.name.startswith(exclude_prefix)] + return ops + + def list_layers(self) -> List[Tuple[str, TfExpression, List[TfExpression]]]: + """Returns a list of (layer_name, output_expr, trainable_vars) tuples corresponding to + individual layers of the network. Mainly intended to be used for reporting.""" + layers = [] + + def recurse(scope, parent_ops, parent_vars, level): + # Ignore specific patterns. + if any(p in scope for p in ["/Shape", "/strided_slice", "/Cast", "/concat", "/Assign"]): + return + + # Filter ops and vars by scope. + global_prefix = scope + "/" + local_prefix = global_prefix[len(self.scope) + 1:] + cur_ops = [op for op in parent_ops if op.name.startswith(global_prefix) or op.name == global_prefix[:-1]] + cur_vars = [(name, var) for name, var in parent_vars if name.startswith(local_prefix) or name == local_prefix[:-1]] + if not cur_ops and not cur_vars: + return + + # Filter out all ops related to variables. + for var in [op for op in cur_ops if op.type.startswith("Variable")]: + var_prefix = var.name + "/" + cur_ops = [op for op in cur_ops if not op.name.startswith(var_prefix)] + + # Scope does not contain ops as immediate children => recurse deeper. + contains_direct_ops = any("/" not in op.name[len(global_prefix):] and op.type != "Identity" for op in cur_ops) + if (level == 0 or not contains_direct_ops) and (len(cur_ops) + len(cur_vars)) > 1: + visited = set() + for rel_name in [op.name[len(global_prefix):] for op in cur_ops] + [name[len(local_prefix):] for name, _var in cur_vars]: + token = rel_name.split("/")[0] + if token not in visited: + recurse(global_prefix + token, cur_ops, cur_vars, level + 1) + visited.add(token) + return + + # Report layer. 
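+            # A scope is reported as a single layer once it contains ops as immediate
+            # children; its output is the last op in the scope (or the last variable
+            # if the scope holds no ops), and only trainable variables are listed.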
+ layer_name = scope[len(self.scope) + 1:] + layer_output = cur_ops[-1].outputs[0] if cur_ops else cur_vars[-1][1] + layer_trainables = [var for _name, var in cur_vars if var.trainable] + layers.append((layer_name, layer_output, layer_trainables)) + + recurse(self.scope, self.list_ops(), list(self.vars.items()), 0) + return layers + + def print_layers(self, title: str = None, hide_layers_with_no_params: bool = False) -> None: + """Print a summary table of the network structure.""" + rows = [[title if title is not None else self.name, "Params", "OutputShape", "WeightShape"]] + rows += [["---"] * 4] + total_params = 0 + + for layer_name, layer_output, layer_trainables in self.list_layers(): + num_params = sum(np.prod(tfutil.shape_to_list(var.shape)) for var in layer_trainables) + weights = [var for var in layer_trainables if var.name.endswith("/weight:0")] + weights.sort(key=lambda x: len(x.name)) + if len(weights) == 0 and len(layer_trainables) == 1: + weights = layer_trainables + total_params += num_params + + if not hide_layers_with_no_params or num_params != 0: + num_params_str = str(num_params) if num_params > 0 else "-" + output_shape_str = str(layer_output.shape) + weight_shape_str = str(weights[0].shape) if len(weights) >= 1 else "-" + rows += [[layer_name, num_params_str, output_shape_str, weight_shape_str]] + + rows += [["---"] * 4] + rows += [["Total", str(total_params), "", ""]] + + widths = [max(len(cell) for cell in column) for column in zip(*rows)] + print() + for row in rows: + print(" ".join(cell + " " * (width - len(cell)) for cell, width in zip(row, widths))) + print() + + def setup_weight_histograms(self, title: str = None) -> None: + """Construct summary ops to include histograms of all trainable parameters in TensorBoard.""" + if title is None: + title = self.name + + with tf.name_scope(None), tf.device(None), tf.control_dependencies(None): + for local_name, var in self.trainables.items(): + if "/" in local_name: + p = local_name.split("/") + name = title + "_" + p[-1] + "/" + "_".join(p[:-1]) + else: + name = title + "_toplevel/" + local_name + + tf.summary.histogram(name, var) + +#---------------------------------------------------------------------------- +# Backwards-compatible emulation of legacy output transformation in Network.run(). 
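+# Illustrative example (not from the original sources): a legacy call such as
+#   net.run(latents, out_mul=127.5, out_add=127.5, out_dtype=np.uint8)
+# is intercepted below and converted into an output_transform; new code should
+# instead pass output_transform=dict(func=tflib.convert_images_to_uint8) directly.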
+ +_print_legacy_warning = True + +def _handle_legacy_output_transforms(output_transform, dynamic_kwargs): + global _print_legacy_warning + legacy_kwargs = ["out_mul", "out_add", "out_shrink", "out_dtype"] + if not any(kwarg in dynamic_kwargs for kwarg in legacy_kwargs): + return output_transform, dynamic_kwargs + + if _print_legacy_warning: + _print_legacy_warning = False + print() + print("WARNING: Old-style output transformations in Network.run() are deprecated.") + print("Consider using 'output_transform=dict(func=tflib.convert_images_to_uint8)'") + print("instead of 'out_mul=127.5, out_add=127.5, out_dtype=np.uint8'.") + print() + assert output_transform is None + + new_kwargs = dict(dynamic_kwargs) + new_transform = {kwarg: new_kwargs.pop(kwarg) for kwarg in legacy_kwargs if kwarg in dynamic_kwargs} + new_transform["func"] = _legacy_output_transform_func + return new_transform, new_kwargs + +def _legacy_output_transform_func(*expr, out_mul=1.0, out_add=0.0, out_shrink=1, out_dtype=None): + if out_mul != 1.0: + expr = [x * out_mul for x in expr] + + if out_add != 0.0: + expr = [x + out_add for x in expr] + + if out_shrink > 1: + ksize = [1, 1, out_shrink, out_shrink] + expr = [tf.nn.avg_pool(x, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW") for x in expr] + + if out_dtype is not None: + if tf.as_dtype(out_dtype).is_integer: + expr = [tf.round(x) for x in expr] + expr = [tf.saturate_cast(x, out_dtype) for x in expr] + return expr diff --git a/dnnlib/tflib/optimizer.py b/dnnlib/tflib/optimizer.py new file mode 100644 index 0000000..6ed88cb --- /dev/null +++ b/dnnlib/tflib/optimizer.py @@ -0,0 +1,214 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. + +"""Helper wrapper for a Tensorflow optimizer.""" + +import numpy as np +import tensorflow as tf + +from collections import OrderedDict +from typing import List, Union + +from . import autosummary +from . import tfutil +from .. import util + +from .tfutil import TfExpression, TfExpressionEx + +try: + # TensorFlow 1.13 + from tensorflow.python.ops import nccl_ops +except: + # Older TensorFlow versions + import tensorflow.contrib.nccl as nccl_ops + +class Optimizer: + """A Wrapper for tf.train.Optimizer. + + Automatically takes care of: + - Gradient averaging for multi-GPU training. + - Dynamic loss scaling and typecasts for FP16 training. + - Ignoring corrupted gradients that contain NaNs/Infs. + - Reporting statistics. + - Well-chosen default settings. + """ + + def __init__(self, + name: str = "Train", + tf_optimizer: str = "tf.train.AdamOptimizer", + learning_rate: TfExpressionEx = 0.001, + use_loss_scaling: bool = False, + loss_scaling_init: float = 64.0, + loss_scaling_inc: float = 0.0005, + loss_scaling_dec: float = 1.0, + **kwargs): + + # Init fields. 
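+        # Note on the fields below: _dev_ls_var stores, per device, the log2 of the
+        # dynamic loss scaling factor; apply_updates() increases it by loss_scaling_inc
+        # after a successful step and decreases it by loss_scaling_dec on overflow.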
+ self.name = name + self.learning_rate = tf.convert_to_tensor(learning_rate) + self.id = self.name.replace("/", ".") + self.scope = tf.get_default_graph().unique_name(self.id) + self.optimizer_class = util.get_obj_by_name(tf_optimizer) + self.optimizer_kwargs = dict(kwargs) + self.use_loss_scaling = use_loss_scaling + self.loss_scaling_init = loss_scaling_init + self.loss_scaling_inc = loss_scaling_inc + self.loss_scaling_dec = loss_scaling_dec + self._grad_shapes = None # [shape, ...] + self._dev_opt = OrderedDict() # device => optimizer + self._dev_grads = OrderedDict() # device => [[(grad, var), ...], ...] + self._dev_ls_var = OrderedDict() # device => variable (log2 of loss scaling factor) + self._updates_applied = False + + def register_gradients(self, loss: TfExpression, trainable_vars: Union[List, dict]) -> None: + """Register the gradients of the given loss function with respect to the given variables. + Intended to be called once per GPU.""" + assert not self._updates_applied + + # Validate arguments. + if isinstance(trainable_vars, dict): + trainable_vars = list(trainable_vars.values()) # allow passing in Network.trainables as vars + + assert isinstance(trainable_vars, list) and len(trainable_vars) >= 1 + assert all(tfutil.is_tf_expression(expr) for expr in trainable_vars + [loss]) + + if self._grad_shapes is None: + self._grad_shapes = [tfutil.shape_to_list(var.shape) for var in trainable_vars] + + assert len(trainable_vars) == len(self._grad_shapes) + assert all(tfutil.shape_to_list(var.shape) == var_shape for var, var_shape in zip(trainable_vars, self._grad_shapes)) + + dev = loss.device + + assert all(var.device == dev for var in trainable_vars) + + # Register device and compute gradients. + with tf.name_scope(self.id + "_grad"), tf.device(dev): + if dev not in self._dev_opt: + opt_name = self.scope.replace("/", "_") + "_opt%d" % len(self._dev_opt) + assert callable(self.optimizer_class) + self._dev_opt[dev] = self.optimizer_class(name=opt_name, learning_rate=self.learning_rate, **self.optimizer_kwargs) + self._dev_grads[dev] = [] + + loss = self.apply_loss_scaling(tf.cast(loss, tf.float32)) + grads = self._dev_opt[dev].compute_gradients(loss, trainable_vars, gate_gradients=tf.train.Optimizer.GATE_NONE) # disable gating to reduce memory usage + grads = [(g, v) if g is not None else (tf.zeros_like(v), v) for g, v in grads] # replace disconnected gradients with zeros + self._dev_grads[dev].append(grads) + + def apply_updates(self) -> tf.Operation: + """Construct training op to update the registered variables based on their gradients.""" + tfutil.assert_tf_initialized() + assert not self._updates_applied + self._updates_applied = True + devices = list(self._dev_grads.keys()) + total_grads = sum(len(grads) for grads in self._dev_grads.values()) + assert len(devices) >= 1 and total_grads >= 1 + ops = [] + + with tfutil.absolute_name_scope(self.scope): + # Cast gradients to FP32 and calculate partial sum within each device. + dev_grads = OrderedDict() # device => [(grad, var), ...] + + for dev_idx, dev in enumerate(devices): + with tf.name_scope("ProcessGrads%d" % dev_idx), tf.device(dev): + sums = [] + + for gv in zip(*self._dev_grads[dev]): + assert all(v is gv[0][1] for g, v in gv) + g = [tf.cast(g, tf.float32) for g, v in gv] + g = g[0] if len(g) == 1 else tf.add_n(g) + sums.append((g, gv[0][1])) + + dev_grads[dev] = sums + + # Sum gradients across devices. 
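+            # nccl_ops.all_sum() leaves an identical summed gradient on every device;
+            # averaging by 1/total_grads happens later in the per-device Scale block.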
+ if len(devices) > 1: + with tf.name_scope("SumAcrossGPUs"), tf.device(None): + for var_idx, grad_shape in enumerate(self._grad_shapes): + g = [dev_grads[dev][var_idx][0] for dev in devices] + + if np.prod(grad_shape): # nccl does not support zero-sized tensors + g = nccl_ops.all_sum(g) + + for dev, gg in zip(devices, g): + dev_grads[dev][var_idx] = (gg, dev_grads[dev][var_idx][1]) + + # Apply updates separately on each device. + for dev_idx, (dev, grads) in enumerate(dev_grads.items()): + with tf.name_scope("ApplyGrads%d" % dev_idx), tf.device(dev): + # Scale gradients as needed. + if self.use_loss_scaling or total_grads > 1: + with tf.name_scope("Scale"): + coef = tf.constant(np.float32(1.0 / total_grads), name="coef") + coef = self.undo_loss_scaling(coef) + grads = [(g * coef, v) for g, v in grads] + + # Check for overflows. + with tf.name_scope("CheckOverflow"): + grad_ok = tf.reduce_all(tf.stack([tf.reduce_all(tf.is_finite(g)) for g, v in grads])) + + # Update weights and adjust loss scaling. + with tf.name_scope("UpdateWeights"): + # pylint: disable=cell-var-from-loop + opt = self._dev_opt[dev] + ls_var = self.get_loss_scaling_var(dev) + + if not self.use_loss_scaling: + ops.append(tf.cond(grad_ok, lambda: opt.apply_gradients(grads), tf.no_op)) + else: + ops.append(tf.cond(grad_ok, + lambda: tf.group(tf.assign_add(ls_var, self.loss_scaling_inc), opt.apply_gradients(grads)), + lambda: tf.group(tf.assign_sub(ls_var, self.loss_scaling_dec)))) + + # Report statistics on the last device. + if dev == devices[-1]: + with tf.name_scope("Statistics"): + ops.append(autosummary.autosummary(self.id + "/learning_rate", self.learning_rate)) + ops.append(autosummary.autosummary(self.id + "/overflow_frequency", tf.where(grad_ok, 0, 1))) + + if self.use_loss_scaling: + ops.append(autosummary.autosummary(self.id + "/loss_scaling_log2", ls_var)) + + # Initialize variables and group everything into a single op. 
+ self.reset_optimizer_state() + tfutil.init_uninitialized_vars(list(self._dev_ls_var.values())) + + return tf.group(*ops, name="TrainingOp") + + def reset_optimizer_state(self) -> None: + """Reset internal state of the underlying optimizer.""" + tfutil.assert_tf_initialized() + tfutil.run([var.initializer for opt in self._dev_opt.values() for var in opt.variables()]) + + def get_loss_scaling_var(self, device: str) -> Union[tf.Variable, None]: + """Get or create variable representing log2 of the current dynamic loss scaling factor.""" + if not self.use_loss_scaling: + return None + + if device not in self._dev_ls_var: + with tfutil.absolute_name_scope(self.scope + "/LossScalingVars"), tf.control_dependencies(None): + self._dev_ls_var[device] = tf.Variable(np.float32(self.loss_scaling_init), name="loss_scaling_var") + + return self._dev_ls_var[device] + + def apply_loss_scaling(self, value: TfExpression) -> TfExpression: + """Apply dynamic loss scaling for the given expression.""" + assert tfutil.is_tf_expression(value) + + if not self.use_loss_scaling: + return value + + return value * tfutil.exp2(self.get_loss_scaling_var(value.device)) + + def undo_loss_scaling(self, value: TfExpression) -> TfExpression: + """Undo the effect of dynamic loss scaling for the given expression.""" + assert tfutil.is_tf_expression(value) + + if not self.use_loss_scaling: + return value + + return value * tfutil.exp2(-self.get_loss_scaling_var(value.device)) # pylint: disable=invalid-unary-operand-type diff --git a/dnnlib/tflib/tfutil.py b/dnnlib/tflib/tfutil.py new file mode 100644 index 0000000..a431a4d --- /dev/null +++ b/dnnlib/tflib/tfutil.py @@ -0,0 +1,240 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
+ +"""Miscellaneous helper utils for Tensorflow.""" + +import os +import numpy as np +import tensorflow as tf + +from typing import Any, Iterable, List, Union + +TfExpression = Union[tf.Tensor, tf.Variable, tf.Operation] +"""A type that represents a valid Tensorflow expression.""" + +TfExpressionEx = Union[TfExpression, int, float, np.ndarray] +"""A type that can be converted to a valid Tensorflow expression.""" + + +def run(*args, **kwargs) -> Any: + """Run the specified ops in the default session.""" + assert_tf_initialized() + return tf.get_default_session().run(*args, **kwargs) + + +def is_tf_expression(x: Any) -> bool: + """Check whether the input is a valid Tensorflow expression, i.e., Tensorflow Tensor, Variable, or Operation.""" + return isinstance(x, (tf.Tensor, tf.Variable, tf.Operation)) + + +def shape_to_list(shape: Iterable[tf.Dimension]) -> List[Union[int, None]]: + """Convert a Tensorflow shape to a list of ints.""" + return [dim.value for dim in shape] + + +def flatten(x: TfExpressionEx) -> TfExpression: + """Shortcut function for flattening a tensor.""" + with tf.name_scope("Flatten"): + return tf.reshape(x, [-1]) + + +def log2(x: TfExpressionEx) -> TfExpression: + """Logarithm in base 2.""" + with tf.name_scope("Log2"): + return tf.log(x) * np.float32(1.0 / np.log(2.0)) + + +def exp2(x: TfExpressionEx) -> TfExpression: + """Exponent in base 2.""" + with tf.name_scope("Exp2"): + return tf.exp(x * np.float32(np.log(2.0))) + + +def lerp(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpressionEx: + """Linear interpolation.""" + with tf.name_scope("Lerp"): + return a + (b - a) * t + + +def lerp_clip(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpression: + """Linear interpolation with clip.""" + with tf.name_scope("LerpClip"): + return a + (b - a) * tf.clip_by_value(t, 0.0, 1.0) + + +def absolute_name_scope(scope: str) -> tf.name_scope: + """Forcefully enter the specified name scope, ignoring any surrounding scopes.""" + return tf.name_scope(scope + "/") + + +def absolute_variable_scope(scope: str, **kwargs) -> tf.variable_scope: + """Forcefully enter the specified variable scope, ignoring any surrounding scopes.""" + return tf.variable_scope(tf.VariableScope(name=scope, **kwargs), auxiliary_name_scope=False) + + +def _sanitize_tf_config(config_dict: dict = None) -> dict: + # Defaults. + cfg = dict() + cfg["rnd.np_random_seed"] = None # Random seed for NumPy. None = keep as is. + cfg["rnd.tf_random_seed"] = "auto" # Random seed for TensorFlow. 'auto' = derive from NumPy random state. None = keep as is. + cfg["env.TF_CPP_MIN_LOG_LEVEL"] = "1" # 0 = Print all available debug info from TensorFlow. 1 = Print warnings and errors, but disable debug info. + cfg["graph_options.place_pruned_graph"] = True # False = Check that all ops are available on the designated device. True = Skip the check for ops that are not used. + cfg["gpu_options.allow_growth"] = True # False = Allocate all GPU memory at the beginning. True = Allocate only as much GPU memory as needed. + + # User overrides. + if config_dict is not None: + cfg.update(config_dict) + return cfg + + +def init_tf(config_dict: dict = None) -> None: + """Initialize TensorFlow session using good default settings.""" + # Skip if already initialized. + if tf.get_default_session() is not None: + return + + # Setup config dict and random seeds. 
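+    # Defaults come from _sanitize_tf_config(); any dotted key can be overridden by the
+    # caller, e.g. init_tf({"rnd.np_random_seed": 1000, "gpu_options.allow_growth": False}).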
+ cfg = _sanitize_tf_config(config_dict) + np_random_seed = cfg["rnd.np_random_seed"] + if np_random_seed is not None: + np.random.seed(np_random_seed) + tf_random_seed = cfg["rnd.tf_random_seed"] + if tf_random_seed == "auto": + tf_random_seed = np.random.randint(1 << 31) + if tf_random_seed is not None: + tf.set_random_seed(tf_random_seed) + + # Setup environment variables. + for key, value in list(cfg.items()): + fields = key.split(".") + if fields[0] == "env": + assert len(fields) == 2 + os.environ[fields[1]] = str(value) + + # Create default TensorFlow session. + create_session(cfg, force_as_default=True) + + +def assert_tf_initialized(): + """Check that TensorFlow session has been initialized.""" + if tf.get_default_session() is None: + raise RuntimeError("No default TensorFlow session found. Please call dnnlib.tflib.init_tf().") + + +def create_session(config_dict: dict = None, force_as_default: bool = False) -> tf.Session: + """Create tf.Session based on config dict.""" + # Setup TensorFlow config proto. + cfg = _sanitize_tf_config(config_dict) + config_proto = tf.ConfigProto() + for key, value in cfg.items(): + fields = key.split(".") + if fields[0] not in ["rnd", "env"]: + obj = config_proto + for field in fields[:-1]: + obj = getattr(obj, field) + setattr(obj, fields[-1], value) + + # Create session. + session = tf.Session(config=config_proto) + if force_as_default: + # pylint: disable=protected-access + session._default_session = session.as_default() + session._default_session.enforce_nesting = False + session._default_session.__enter__() # pylint: disable=no-member + + return session + + +def init_uninitialized_vars(target_vars: List[tf.Variable] = None) -> None: + """Initialize all tf.Variables that have not already been initialized. + + Equivalent to the following, but more efficient and does not bloat the tf graph: + tf.variables_initializer(tf.report_uninitialized_variables()).run() + """ + assert_tf_initialized() + if target_vars is None: + target_vars = tf.global_variables() + + test_vars = [] + test_ops = [] + + with tf.control_dependencies(None): # ignore surrounding control_dependencies + for var in target_vars: + assert is_tf_expression(var) + + try: + tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/IsVariableInitialized:0")) + except KeyError: + # Op does not exist => variable may be uninitialized. + test_vars.append(var) + + with absolute_name_scope(var.name.split(":")[0]): + test_ops.append(tf.is_variable_initialized(var)) + + init_vars = [var for var, inited in zip(test_vars, run(test_ops)) if not inited] + run([var.initializer for var in init_vars]) + + +def set_vars(var_to_value_dict: dict) -> None: + """Set the values of given tf.Variables. 
+ + Equivalent to the following, but more efficient and does not bloat the tf graph: + tflib.run([tf.assign(var, value) for var, value in var_to_value_dict.items()] + """ + assert_tf_initialized() + ops = [] + feed_dict = {} + + for var, value in var_to_value_dict.items(): + assert is_tf_expression(var) + + try: + setter = tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/setter:0")) # look for existing op + except KeyError: + with absolute_name_scope(var.name.split(":")[0]): + with tf.control_dependencies(None): # ignore surrounding control_dependencies + setter = tf.assign(var, tf.placeholder(var.dtype, var.shape, "new_value"), name="setter") # create new setter + + ops.append(setter) + feed_dict[setter.op.inputs[1]] = value + + run(ops, feed_dict) + + +def create_var_with_large_initial_value(initial_value: np.ndarray, *args, **kwargs): + """Create tf.Variable with large initial value without bloating the tf graph.""" + assert_tf_initialized() + assert isinstance(initial_value, np.ndarray) + zeros = tf.zeros(initial_value.shape, initial_value.dtype) + var = tf.Variable(zeros, *args, **kwargs) + set_vars({var: initial_value}) + return var + + +def convert_images_from_uint8(images, drange=[-1,1], nhwc_to_nchw=False): + """Convert a minibatch of images from uint8 to float32 with configurable dynamic range. + Can be used as an input transformation for Network.run(). + """ + images = tf.cast(images, tf.float32) + if nhwc_to_nchw: + images = tf.transpose(images, [0, 3, 1, 2]) + return (images - drange[0]) * ((drange[1] - drange[0]) / 255) + + +def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False, shrink=1): + """Convert a minibatch of images from float32 to uint8 with configurable dynamic range. + Can be used as an output transformation for Network.run(). + """ + images = tf.cast(images, tf.float32) + if shrink > 1: + ksize = [1, 1, shrink, shrink] + images = tf.nn.avg_pool(images, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW") + if nchw_to_nhwc: + images = tf.transpose(images, [0, 2, 3, 1]) + scale = 255 / (drange[1] - drange[0]) + images = images * scale + (0.5 - drange[0] * scale) + return tf.saturate_cast(images, tf.uint8) diff --git a/dnnlib/util.py b/dnnlib/util.py new file mode 100644 index 0000000..133ef76 --- /dev/null +++ b/dnnlib/util.py @@ -0,0 +1,405 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
+ +"""Miscellaneous utility classes and functions.""" + +import ctypes +import fnmatch +import importlib +import inspect +import numpy as np +import os +import shutil +import sys +import types +import io +import pickle +import re +import requests +import html +import hashlib +import glob +import uuid + +from distutils.util import strtobool +from typing import Any, List, Tuple, Union + + +# Util classes +# ------------------------------------------------------------------------------------------ + + +class EasyDict(dict): + """Convenience class that behaves like a dict but allows access with the attribute syntax.""" + + def __getattr__(self, name: str) -> Any: + try: + return self[name] + except KeyError: + raise AttributeError(name) + + def __setattr__(self, name: str, value: Any) -> None: + self[name] = value + + def __delattr__(self, name: str) -> None: + del self[name] + + +class Logger(object): + """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file.""" + + def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True): + self.file = None + + if file_name is not None: + self.file = open(file_name, file_mode) + + self.should_flush = should_flush + self.stdout = sys.stdout + self.stderr = sys.stderr + + sys.stdout = self + sys.stderr = self + + def __enter__(self) -> "Logger": + return self + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + self.close() + + def write(self, text: str) -> None: + """Write text to stdout (and a file) and optionally flush.""" + if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash + return + + if self.file is not None: + self.file.write(text) + + self.stdout.write(text) + + if self.should_flush: + self.flush() + + def flush(self) -> None: + """Flush written text to both stdout and a file, if open.""" + if self.file is not None: + self.file.flush() + + self.stdout.flush() + + def close(self) -> None: + """Flush, close possible files, and remove stdout/stderr mirroring.""" + self.flush() + + # if using multiple loggers, prevent closing in wrong order + if sys.stdout is self: + sys.stdout = self.stdout + if sys.stderr is self: + sys.stderr = self.stderr + + if self.file is not None: + self.file.close() + + +# Small util functions +# ------------------------------------------------------------------------------------------ + + +def format_time(seconds: Union[int, float]) -> str: + """Convert the seconds to human readable string with days, hours, minutes and seconds.""" + s = int(np.rint(seconds)) + + if s < 60: + return "{0}s".format(s) + elif s < 60 * 60: + return "{0}m {1:02}s".format(s // 60, s % 60) + elif s < 24 * 60 * 60: + return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60) + else: + return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60) + + +def ask_yes_no(question: str) -> bool: + """Ask the user the question until the user inputs a valid answer.""" + while True: + try: + print("{0} [y/n]".format(question)) + return strtobool(input().lower()) + except ValueError: + pass + + +def tuple_product(t: Tuple) -> Any: + """Calculate the product of the tuple elements.""" + result = 1 + + for v in t: + result *= v + + return result + + +_str_to_ctype = { + "uint8": ctypes.c_ubyte, + "uint16": ctypes.c_uint16, + "uint32": ctypes.c_uint32, + "uint64": ctypes.c_uint64, + "int8": ctypes.c_byte, + "int16": 
ctypes.c_int16, + "int32": ctypes.c_int32, + "int64": ctypes.c_int64, + "float32": ctypes.c_float, + "float64": ctypes.c_double +} + + +def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]: + """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes.""" + type_str = None + + if isinstance(type_obj, str): + type_str = type_obj + elif hasattr(type_obj, "__name__"): + type_str = type_obj.__name__ + elif hasattr(type_obj, "name"): + type_str = type_obj.name + else: + raise RuntimeError("Cannot infer type name from input") + + assert type_str in _str_to_ctype.keys() + + my_dtype = np.dtype(type_str) + my_ctype = _str_to_ctype[type_str] + + assert my_dtype.itemsize == ctypes.sizeof(my_ctype) + + return my_dtype, my_ctype + + +def is_pickleable(obj: Any) -> bool: + try: + with io.BytesIO() as stream: + pickle.dump(obj, stream) + return True + except: + return False + + +# Functionality to import modules/objects by name, and call functions by name +# ------------------------------------------------------------------------------------------ + +def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]: + """Searches for the underlying module behind the name to some python object. + Returns the module and the object name (original name with module part removed).""" + + # allow convenience shorthands, substitute them by full names + obj_name = re.sub("^np.", "numpy.", obj_name) + obj_name = re.sub("^tf.", "tensorflow.", obj_name) + + # list alternatives for (module_name, local_obj_name) + parts = obj_name.split(".") + name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)] + + # try each alternative in turn + for module_name, local_obj_name in name_pairs: + try: + module = importlib.import_module(module_name) # may raise ImportError + get_obj_from_module(module, local_obj_name) # may raise AttributeError + return module, local_obj_name + except: + pass + + # maybe some of the modules themselves contain errors? + for module_name, _local_obj_name in name_pairs: + try: + importlib.import_module(module_name) # may raise ImportError + except ImportError: + if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"): + raise + + # maybe the requested attribute is missing? 
+ for module_name, local_obj_name in name_pairs: + try: + module = importlib.import_module(module_name) # may raise ImportError + get_obj_from_module(module, local_obj_name) # may raise AttributeError + except ImportError: + pass + + # we are out of luck, but we have no idea why + raise ImportError(obj_name) + + +def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any: + """Traverses the object name and returns the last (rightmost) python object.""" + if obj_name == '': + return module + obj = module + for part in obj_name.split("."): + obj = getattr(obj, part) + return obj + + +def get_obj_by_name(name: str) -> Any: + """Finds the python object with the given name.""" + module, obj_name = get_module_from_obj_name(name) + return get_obj_from_module(module, obj_name) + + +def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any: + """Finds the python object with the given name and calls it as a function.""" + assert func_name is not None + func_obj = get_obj_by_name(func_name) + assert callable(func_obj) + return func_obj(*args, **kwargs) + + +def get_module_dir_by_obj_name(obj_name: str) -> str: + """Get the directory path of the module containing the given object name.""" + module, _ = get_module_from_obj_name(obj_name) + return os.path.dirname(inspect.getfile(module)) + + +def is_top_level_function(obj: Any) -> bool: + """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'.""" + return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__ + + +def get_top_level_function_name(obj: Any) -> str: + """Return the fully-qualified name of a top-level function.""" + assert is_top_level_function(obj) + return obj.__module__ + "." + obj.__name__ + + +# File system helpers +# ------------------------------------------------------------------------------------------ + +def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]: + """List all files recursively in a given directory while ignoring given file and directory names. + Returns list of tuples containing both absolute and relative paths.""" + assert os.path.isdir(dir_path) + base_name = os.path.basename(os.path.normpath(dir_path)) + + if ignores is None: + ignores = [] + + result = [] + + for root, dirs, files in os.walk(dir_path, topdown=True): + for ignore_ in ignores: + dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)] + + # dirs need to be edited in-place + for d in dirs_to_remove: + dirs.remove(d) + + files = [f for f in files if not fnmatch.fnmatch(f, ignore_)] + + absolute_paths = [os.path.join(root, f) for f in files] + relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths] + + if add_base_to_relative: + relative_paths = [os.path.join(base_name, p) for p in relative_paths] + + assert len(absolute_paths) == len(relative_paths) + result += zip(absolute_paths, relative_paths) + + return result + + +def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None: + """Takes in a list of tuples of (src, dst) paths and copies files. 
+ Will create all necessary directories.""" + for file in files: + target_dir_name = os.path.dirname(file[1]) + + # will create all intermediate-level directories + if not os.path.exists(target_dir_name): + os.makedirs(target_dir_name) + + shutil.copyfile(file[0], file[1]) + + +# URL helpers +# ------------------------------------------------------------------------------------------ + +def is_url(obj: Any) -> bool: + """Determine whether the given object is a valid URL string.""" + if not isinstance(obj, str) or not "://" in obj: + return False + try: + res = requests.compat.urlparse(obj) + if not res.scheme or not res.netloc or not "." in res.netloc: + return False + res = requests.compat.urlparse(requests.compat.urljoin(obj, "/")) + if not res.scheme or not res.netloc or not "." in res.netloc: + return False + except: + return False + return True + + +def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True) -> Any: + """Download the given URL and return a binary-mode file object to access the data.""" + assert is_url(url) + assert num_attempts >= 1 + + # Lookup from cache. + url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest() + if cache_dir is not None: + cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*")) + if len(cache_files) == 1: + return open(cache_files[0], "rb") + + # Download. + url_name = None + url_data = None + with requests.Session() as session: + if verbose: + print("Downloading %s ..." % url, end="", flush=True) + for attempts_left in reversed(range(num_attempts)): + try: + with session.get(url) as res: + res.raise_for_status() + if len(res.content) == 0: + raise IOError("No data received") + + if len(res.content) < 8192: + content_str = res.content.decode("utf-8") + if "download_warning" in res.headers.get("Set-Cookie", ""): + links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link] + if len(links) == 1: + url = requests.compat.urljoin(url, links[0]) + raise IOError("Google Drive virus checker nag") + if "Google Drive - Quota exceeded" in content_str: + raise IOError("Google Drive quota exceeded") + + match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", "")) + url_name = match[1] if match else url + url_data = res.content + if verbose: + print(" done") + break + except: + if not attempts_left: + if verbose: + print(" failed") + raise + if verbose: + print(".", end="", flush=True) + + # Save to cache. + if cache_dir is not None: + safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name) + cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name) + temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name) + os.makedirs(cache_dir, exist_ok=True) + with open(temp_file, "wb") as f: + f.write(url_data) + os.replace(temp_file, cache_file) # atomic + + # Return data as file object. + return io.BytesIO(url_data) diff --git a/generate_images.py b/generate_images.py new file mode 100644 index 0000000..173eac6 --- /dev/null +++ b/generate_images.py @@ -0,0 +1,199 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +"""Script for generating an image using pre-trained generator.""" + +import os +import pickle +import numpy as np +import PIL.Image +from training import misc +import dnnlib +import dnnlib.tflib as tflib +import config +import tensorflow as tf +import argparse + +# define mapping network from z space to lambda space +def CoeffDecoder(z,ch_depth = 3, ch_dim = 512, coeff_length = 128): + with tf.variable_scope('stage1'): + with tf.variable_scope('decoder'): + y = z + for i in range(ch_depth): + y = tf.layers.dense(y, ch_dim, tf.nn.relu, name='fc'+str(i)) + + x_hat = tf.layers.dense(y, coeff_length, name='x_hat') + x_hat = tf.stop_gradient(x_hat) + + return x_hat + +# restore pre-trained weights +def restore_weights_and_initialize(): + var_list = tf.trainable_variables() + g_list = tf.global_variables() + + # add batch normalization params into trainable variables + bn_moving_vars = [g for g in g_list if 'moving_mean' in g.name] + bn_moving_vars += [g for g in g_list if 'moving_variance' in g.name] + var_list +=bn_moving_vars + + var_id_list = [v for v in var_list if 'id' in v.name and 'stage1' in v.name] + var_exp_list = [v for v in var_list if 'exp' in v.name and 'stage1' in v.name] + var_gamma_list = [v for v in var_list if 'gamma' in v.name and 'stage1' in v.name] + var_rot_list = [v for v in var_list if 'rot' in v.name and 'stage1' in v.name] + + saver_id = tf.train.Saver(var_list = var_id_list,max_to_keep = 100) + saver_exp = tf.train.Saver(var_list = var_exp_list,max_to_keep = 100) + saver_gamma = tf.train.Saver(var_list = var_gamma_list,max_to_keep = 100) + saver_rot = tf.train.Saver(var_list = var_rot_list,max_to_keep = 100) + + saver_id.restore(tf.get_default_session(),'./vae/weights/id/stage1_epoch_395.ckpt') + saver_exp.restore(tf.get_default_session(),'./vae/weights/exp/stage1_epoch_395.ckpt') + saver_gamma.restore(tf.get_default_session(),'./vae/weights/gamma/stage1_epoch_395.ckpt') + saver_rot.restore(tf.get_default_session(),'./vae/weights/rot/stage1_epoch_395.ckpt') + +def z_to_lambda_mapping(latents): + with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE): + with tf.variable_scope('id'): + IDcoeff = CoeffDecoder(z = latents[:,:128],coeff_length = 160,ch_dim = 512, ch_depth = 3) + with tf.variable_scope('exp'): + EXPcoeff = CoeffDecoder(z = latents[:,128:128+32],coeff_length = 64,ch_dim = 256, ch_depth = 3) + with tf.variable_scope('gamma'): + GAMMAcoeff = CoeffDecoder(z = latents[:,128+32:128+32+16],coeff_length = 27,ch_dim = 128, ch_depth = 3) + with tf.variable_scope('rot'): + Rotcoeff = CoeffDecoder(z = latents[:,128+32+16:128+32+16+3],coeff_length = 3,ch_dim = 32, ch_depth = 3) + + INPUTcoeff = tf.concat([IDcoeff,EXPcoeff,Rotcoeff,GAMMAcoeff], axis = 1) + + return INPUTcoeff + +# generate images using attribute-preserving truncation trick +def truncate_generation(Gs,inputcoeff,rate=0.7,dlatent_average_id=None): + + if dlatent_average_id is None: + url_pretrained_model_ffhq_average_w_id = 'https://drive.google.com/uc?id=17L6-ENX3NbMsS3MSCshychZETLPtJnbS' + with dnnlib.util.open_url(url_pretrained_model_ffhq_average_w_id, cache_dir=config.cache_dir) as f: + dlatent_average_id = np.loadtxt(f) + dlatent_average_id = np.reshape(dlatent_average_id,[1,14,512]).astype(np.float32) + dlatent_average_id = tf.constant(dlatent_average_id) + + inputcoeff_id = tf.concat([inputcoeff[:,:160],tf.zeros([1,126])],axis=1) + dlatent_out = Gs.components.mapping.get_output_for(inputcoeff, None ,is_training=False, is_validation = True) # original w space output + 
dlatent_out_id = Gs.components.mapping.get_output_for(inputcoeff_id, None ,is_training=False, is_validation = True) + + dlatent_out_trun = dlatent_out + (dlatent_average_id - dlatent_out_id)*(1-rate) + dlatent_out_final = tf.concat([dlatent_out_trun[:,:8,:],dlatent_out[:,8:,:]],axis = 1) # w space latent vector with truncation trick + + fake_images_out = Gs.components.synthesis.get_output_for(dlatent_out_final, randomize_noise = False) + fake_images_out = tf.clip_by_value((fake_images_out+1)*127.5,0,255) + fake_images_out = tf.transpose(fake_images_out,perm = [0,2,3,1]) + + return fake_images_out + +# calculate average w space latent vector with zero expression, lighting, and pose. +def get_model_and_average_w_id(model_name): + G, D, Gs = misc.load_pkl(model_name) + average_w_name = model_name.replace('.pkl','-average_w_id.txt') + if not os.path.isfile(average_w_name): + print('Calculating average w id...\n') + latents = tf.placeholder(tf.float32, name='latents', shape=[1,128+32+16+3]) + noise = tf.placeholder(tf.float32, name='noise', shape=[1,32]) + INPUTcoeff = z_to_lambda_mapping(latents) + INPUTcoeff_id = INPUTcoeff[:,:160] + INPUTcoeff_w_noise = tf.concat([INPUTcoeff_id,tf.zeros([1,64+27+3]),noise],axis = 1) + dlatent_out = Gs.components.mapping.get_output_for(INPUTcoeff_w_noise, None ,is_training=False, is_validation = True) + restore_weights_and_initialize() + np.random.seed(1) + average_w_id = [] + for i in range(50000): + lats = np.random.normal(size=[1,128+32+16+3]) + noise_ = np.random.normal(size=[1,32]) + w_out = tflib.run(dlatent_out,{latents:lats,noise:noise_}) + average_w_id.append(w_out) + + average_w_id = np.concatenate(average_w_id,axis = 0) + average_w_id = np.mean(average_w_id,axis = 0) + np.savetxt(average_w_name,average_w_id) + else: + average_w_id = np.loadtxt(average_w_name) + + return Gs,average_w_id + +def parse_args(): + desc = "Disentangled face image generation" + parser = argparse.ArgumentParser(description=desc) + + parser.add_argument('--factor', type=int, default=0, help='factor variation mode. 0 = all, 1 = expression, 2 = lighting, 3 = pose.') + parser.add_argument('--subject', type=int, default=20, help='how many subjects to generate.') + parser.add_argument('--variation', type=int, default=5, help='how many images to generate per subject.') + parser.add_argument('--model',type=str,default=None,help='pkl file name of the generator. If None, use the default pre-trained model.') + + return parser.parse_args() + + +def load_Gs(url): + with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f: + _G, _D, Gs = pickle.load(f) + return Gs + +def main(): + + args = parse_args() + if args is None: + exit() + + # save path for generated images + save_path = 'generate_images' + if not os.path.exists(save_path): + os.makedirs(save_path) + resume_pkl = '' + + tflib.init_tf() + + with tf.device('/gpu:0'): + + # Use default pre-trained model + if args.model is None: + url_pretrained_model_ffhq = 'https://drive.google.com/uc?id=1nT_cf610q5mxD_jACvV43w4SYBxsPUBq' + Gs = load_Gs(url_pretrained_model_ffhq) + average_w_id = None + + else: + Gs,average_w_id = get_model_and_average_w_id(args.model) + # Gs = Long-term average of the generator. Yields higher-quality results than the instantaneous snapshot. + # average_w_id = average w space latent vector with zero expression, lighting, and pose. + + # Print network details. + Gs.print_layers() + + # Pick latent vector. 
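+        # latents holds the 179-dim z (identity, expression, lighting, pose);
+        # noise is a separate 32-dim vector concatenated after the lambda mapping.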
+ latents = tf.placeholder(tf.float32, name='latents', shape=[1,128+32+16+3]) + noise = tf.placeholder(tf.float32, name='noise', shape=[1,32]) + INPUTcoeff = z_to_lambda_mapping(latents) + INPUTcoeff_w_noise = tf.concat([INPUTcoeff,noise],axis = 1) + + # Generate images + fake_images_out = truncate_generation(Gs,INPUTcoeff_w_noise,dlatent_average_id=average_w_id) + + restore_weights_and_initialize() + + np.random.seed(1) + for i in range(args.subject): + print(i) + lats1 = np.random.normal(size=[1,128+32+16+3]) + noise_ = np.random.normal(size=[1,32]) + for j in range(args.variation): + lats2 = np.random.normal(size=[1,32+16+3]) + if args.factor == 0: # change all factors + lats = np.concatenate([lats1[:,:128],lats2],axis = 1) + elif args.factor == 1: # change expression only + lats = np.concatenate([lats1[:,:128],lats2[:,:32],lats1[:,128+32:]],axis = 1) + elif args.factor == 2: # change lighting only + lats = np.concatenate([lats1[:,:128+32],lats2[:,32:32+16],lats1[:,128+32+16:]],axis = 1) + elif args.factor == 3: # change pose only + lats = np.concatenate([lats1[:,:128+32+16],lats2[:,32+16:32+16+3]],axis = 1) + fake = tflib.run(fake_images_out, {latents:lats,noise:noise_}) + PIL.Image.fromarray(fake[0].astype(np.uint8), 'RGB').save(os.path.join(save_path,'%03d_%02d.png'%(i,j))) + +if __name__ == "__main__": + main() diff --git a/images/disentangled.png b/images/disentangled.png new file mode 100644 index 0000000..2fd1fe4 Binary files /dev/null and b/images/disentangled.png differ diff --git a/images/expression.png b/images/expression.png new file mode 100644 index 0000000..c8ce08c Binary files /dev/null and b/images/expression.png differ diff --git a/images/light.png b/images/light.png new file mode 100644 index 0000000..fc6b353 Binary files /dev/null and b/images/light.png differ diff --git a/images/pose.png b/images/pose.png new file mode 100644 index 0000000..a5b7988 Binary files /dev/null and b/images/pose.png differ diff --git a/images/reference.png b/images/reference.png new file mode 100644 index 0000000..efed086 Binary files /dev/null and b/images/reference.png differ diff --git a/images/teaser.gif b/images/teaser.gif new file mode 100644 index 0000000..da75044 Binary files /dev/null and b/images/teaser.gif differ diff --git a/metrics/__init__.py b/metrics/__init__.py new file mode 100644 index 0000000..db8124b --- /dev/null +++ b/metrics/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. + +# empty diff --git a/metrics/frechet_inception_distance.py b/metrics/frechet_inception_distance.py new file mode 100644 index 0000000..b84ad57 --- /dev/null +++ b/metrics/frechet_inception_distance.py @@ -0,0 +1,112 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
+ +"""Frechet Inception Distance (FID).""" + +import os +import numpy as np +import scipy +import tensorflow as tf +import dnnlib.tflib as tflib + +from metrics import metric_base +from training import misc +from training.networks_stylegan import CoeffDecoder +from training.training_loop import z_to_lambda_mapping + +#---------------------------------------------------------------------------- +# Modified by Deng et al. +def restore_weights_and_initialize(): + var_list = tf.trainable_variables() + g_list = tf.global_variables() + + # add batch normalization params into trainable variables + bn_moving_vars = [g for g in g_list if 'moving_mean' in g.name] + bn_moving_vars += [g for g in g_list if 'moving_variance' in g.name] + var_list +=bn_moving_vars + + var_id_list = [v for v in var_list if 'id' in v.name and 'stage1' in v.name] + var_exp_list = [v for v in var_list if 'exp' in v.name and 'stage1' in v.name] + var_gamma_list = [v for v in var_list if 'gamma' in v.name and 'stage1' in v.name] + var_rot_list = [v for v in var_list if 'rot' in v.name and 'stage1' in v.name] + + saver_id = tf.train.Saver(var_list = var_id_list,max_to_keep = 100) + saver_exp = tf.train.Saver(var_list = var_exp_list,max_to_keep = 100) + saver_gamma = tf.train.Saver(var_list = var_gamma_list,max_to_keep = 100) + saver_rot = tf.train.Saver(var_list = var_rot_list,max_to_keep = 100) + + saver_id.restore(tf.get_default_session(),'./vae/weights/id/stage1_epoch_395.ckpt') + saver_exp.restore(tf.get_default_session(),'./vae/weights/exp/stage1_epoch_395.ckpt') + saver_gamma.restore(tf.get_default_session(),'./vae/weights/gamma/stage1_epoch_395.ckpt') + saver_rot.restore(tf.get_default_session(),'./vae/weights/rot/stage1_epoch_395.ckpt') +#---------------------------------------------------------------------------- + +class FID(metric_base.MetricBase): + def __init__(self, num_images, minibatch_per_gpu, **kwargs): + super().__init__(**kwargs) + self.num_images = num_images + self.minibatch_per_gpu = minibatch_per_gpu + + def _evaluate(self, Gs, num_gpus): + minibatch_size = num_gpus * self.minibatch_per_gpu + inception = misc.load_pkl('https://drive.google.com/uc?id=1MzTY44rLToO5APn8TZmfR7_ENSe5aZUn') # inception_v3_features.pkl + activations = np.empty([self.num_images, inception.output_shape[1]], dtype=np.float32) + + # Calculate statistics for reals. + cache_file = self._get_cache_file_for_reals(num_images=self.num_images) + os.makedirs(os.path.dirname(cache_file), exist_ok=True) + if os.path.isfile(cache_file): + mu_real, sigma_real = misc.load_pkl(cache_file) + else: + for idx, images in enumerate(self._iterate_reals(minibatch_size=minibatch_size)): + begin = idx * minibatch_size + end = min(begin + minibatch_size, self.num_images) + activations[begin:end] = inception.run(images[:end-begin], num_gpus=num_gpus, assume_frozen=True) + if end == self.num_images: + break + mu_real = np.mean(activations, axis=0) + sigma_real = np.cov(activations, rowvar=False) + misc.save_pkl((mu_real, sigma_real), cache_file) + + # Construct TensorFlow graph. + result_expr = [] + for gpu_idx in range(num_gpus): + with tf.device('/gpu:%d' % gpu_idx): + Gs_clone = Gs.clone() + inception_clone = inception.clone() + + #---------------------------------------------------------------------------- + # Modified by Deng et al. 
+ latents = tf.random_normal([self.minibatch_per_gpu,128+32+16+3]) + INPUTcoeff = z_to_lambda_mapping(latents) + + if Gs_clone.input_shape[1] == 254: + INPUTcoeff_w_noise = INPUTcoeff + else: + noise_coeff = tf.random_normal([self.minibatch_per_gpu,Gs_clone.input_shape[1]-254]) + INPUTcoeff_w_noise = tf.concat([INPUTcoeff,noise_coeff], axis = 1) + images = Gs_clone.get_output_for(INPUTcoeff_w_noise, None, is_validation=True, randomize_noise=True) + images = tflib.convert_images_to_uint8(images) + result_expr.append(inception_clone.get_output_for(images)) + + restore_weights_and_initialize() + #---------------------------------------------------------------------------- + + # Calculate statistics for fakes. + for begin in range(0, self.num_images, minibatch_size): + end = min(begin + minibatch_size, self.num_images) + activations[begin:end] = np.concatenate(tflib.run(result_expr), axis=0)[:end-begin] + mu_fake = np.mean(activations, axis=0) + sigma_fake = np.cov(activations, rowvar=False) + + # Calculate FID. + m = np.square(mu_fake - mu_real).sum() + s, _ = scipy.linalg.sqrtm(np.dot(sigma_fake, sigma_real), disp=False) # pylint: disable=no-member + dist = m + np.trace(sigma_fake + sigma_real - 2*s) + self._report_result(np.real(dist)) + +#---------------------------------------------------------------------------- diff --git a/metrics/linear_separability.py b/metrics/linear_separability.py new file mode 100644 index 0000000..e50be5a --- /dev/null +++ b/metrics/linear_separability.py @@ -0,0 +1,177 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
+ +"""Linear Separability (LS).""" + +from collections import defaultdict +import numpy as np +import sklearn.svm +import tensorflow as tf +import dnnlib.tflib as tflib + +from metrics import metric_base +from training import misc + +#---------------------------------------------------------------------------- + +classifier_urls = [ + 'https://drive.google.com/uc?id=1Q5-AI6TwWhCVM7Muu4tBM7rp5nG_gmCX', # celebahq-classifier-00-male.pkl + 'https://drive.google.com/uc?id=1Q5c6HE__ReW2W8qYAXpao68V1ryuisGo', # celebahq-classifier-01-smiling.pkl + 'https://drive.google.com/uc?id=1Q7738mgWTljPOJQrZtSMLxzShEhrvVsU', # celebahq-classifier-02-attractive.pkl + 'https://drive.google.com/uc?id=1QBv2Mxe7ZLvOv1YBTLq-T4DS3HjmXV0o', # celebahq-classifier-03-wavy-hair.pkl + 'https://drive.google.com/uc?id=1QIvKTrkYpUrdA45nf7pspwAqXDwWOLhV', # celebahq-classifier-04-young.pkl + 'https://drive.google.com/uc?id=1QJPH5rW7MbIjFUdZT7vRYfyUjNYDl4_L', # celebahq-classifier-05-5-o-clock-shadow.pkl + 'https://drive.google.com/uc?id=1QPZXSYf6cptQnApWS_T83sqFMun3rULY', # celebahq-classifier-06-arched-eyebrows.pkl + 'https://drive.google.com/uc?id=1QPgoAZRqINXk_PFoQ6NwMmiJfxc5d2Pg', # celebahq-classifier-07-bags-under-eyes.pkl + 'https://drive.google.com/uc?id=1QQPQgxgI6wrMWNyxFyTLSgMVZmRr1oO7', # celebahq-classifier-08-bald.pkl + 'https://drive.google.com/uc?id=1QcSphAmV62UrCIqhMGgcIlZfoe8hfWaF', # celebahq-classifier-09-bangs.pkl + 'https://drive.google.com/uc?id=1QdWTVwljClTFrrrcZnPuPOR4mEuz7jGh', # celebahq-classifier-10-big-lips.pkl + 'https://drive.google.com/uc?id=1QgvEWEtr2mS4yj1b_Y3WKe6cLWL3LYmK', # celebahq-classifier-11-big-nose.pkl + 'https://drive.google.com/uc?id=1QidfMk9FOKgmUUIziTCeo8t-kTGwcT18', # celebahq-classifier-12-black-hair.pkl + 'https://drive.google.com/uc?id=1QthrJt-wY31GPtV8SbnZQZ0_UEdhasHO', # celebahq-classifier-13-blond-hair.pkl + 'https://drive.google.com/uc?id=1QvCAkXxdYT4sIwCzYDnCL9Nb5TDYUxGW', # celebahq-classifier-14-blurry.pkl + 'https://drive.google.com/uc?id=1QvLWuwSuWI9Ln8cpxSGHIciUsnmaw8L0', # celebahq-classifier-15-brown-hair.pkl + 'https://drive.google.com/uc?id=1QxW6THPI2fqDoiFEMaV6pWWHhKI_OoA7', # celebahq-classifier-16-bushy-eyebrows.pkl + 'https://drive.google.com/uc?id=1R71xKw8oTW2IHyqmRDChhTBkW9wq4N9v', # celebahq-classifier-17-chubby.pkl + 'https://drive.google.com/uc?id=1RDn_fiLfEGbTc7JjazRXuAxJpr-4Pl67', # celebahq-classifier-18-double-chin.pkl + 'https://drive.google.com/uc?id=1RGBuwXbaz5052bM4VFvaSJaqNvVM4_cI', # celebahq-classifier-19-eyeglasses.pkl + 'https://drive.google.com/uc?id=1RIxOiWxDpUwhB-9HzDkbkLegkd7euRU9', # celebahq-classifier-20-goatee.pkl + 'https://drive.google.com/uc?id=1RPaNiEnJODdr-fwXhUFdoSQLFFZC7rC-', # celebahq-classifier-21-gray-hair.pkl + 'https://drive.google.com/uc?id=1RQH8lPSwOI2K_9XQCZ2Ktz7xm46o80ep', # celebahq-classifier-22-heavy-makeup.pkl + 'https://drive.google.com/uc?id=1RXZM61xCzlwUZKq-X7QhxOg0D2telPow', # celebahq-classifier-23-high-cheekbones.pkl + 'https://drive.google.com/uc?id=1RgASVHW8EWMyOCiRb5fsUijFu-HfxONM', # celebahq-classifier-24-mouth-slightly-open.pkl + 'https://drive.google.com/uc?id=1RkC8JLqLosWMaRne3DARRgolhbtg_wnr', # celebahq-classifier-25-mustache.pkl + 'https://drive.google.com/uc?id=1RqtbtFT2EuwpGTqsTYJDyXdnDsFCPtLO', # celebahq-classifier-26-narrow-eyes.pkl + 'https://drive.google.com/uc?id=1Rs7hU-re8bBMeRHR-fKgMbjPh-RIbrsh', # celebahq-classifier-27-no-beard.pkl + 'https://drive.google.com/uc?id=1RynDJQWdGOAGffmkPVCrLJqy_fciPF9E', # celebahq-classifier-28-oval-face.pkl + 
'https://drive.google.com/uc?id=1S0TZ_Hdv5cb06NDaCD8NqVfKy7MuXZsN', # celebahq-classifier-29-pale-skin.pkl + 'https://drive.google.com/uc?id=1S3JPhZH2B4gVZZYCWkxoRP11q09PjCkA', # celebahq-classifier-30-pointy-nose.pkl + 'https://drive.google.com/uc?id=1S3pQuUz-Jiywq_euhsfezWfGkfzLZ87W', # celebahq-classifier-31-receding-hairline.pkl + 'https://drive.google.com/uc?id=1S6nyIl_SEI3M4l748xEdTV2vymB_-lrY', # celebahq-classifier-32-rosy-cheeks.pkl + 'https://drive.google.com/uc?id=1S9P5WCi3GYIBPVYiPTWygrYIUSIKGxbU', # celebahq-classifier-33-sideburns.pkl + 'https://drive.google.com/uc?id=1SANviG-pp08n7AFpE9wrARzozPIlbfCH', # celebahq-classifier-34-straight-hair.pkl + 'https://drive.google.com/uc?id=1SArgyMl6_z7P7coAuArqUC2zbmckecEY', # celebahq-classifier-35-wearing-earrings.pkl + 'https://drive.google.com/uc?id=1SC5JjS5J-J4zXFO9Vk2ZU2DT82TZUza_', # celebahq-classifier-36-wearing-hat.pkl + 'https://drive.google.com/uc?id=1SDAQWz03HGiu0MSOKyn7gvrp3wdIGoj-', # celebahq-classifier-37-wearing-lipstick.pkl + 'https://drive.google.com/uc?id=1SEtrVK-TQUC0XeGkBE9y7L8VXfbchyKX', # celebahq-classifier-38-wearing-necklace.pkl + 'https://drive.google.com/uc?id=1SF_mJIdyGINXoV-I6IAxHB_k5dxiF6M-', # celebahq-classifier-39-wearing-necktie.pkl +] + +#---------------------------------------------------------------------------- + +def prob_normalize(p): + p = np.asarray(p).astype(np.float32) + assert len(p.shape) == 2 + return p / np.sum(p) + +def mutual_information(p): + p = prob_normalize(p) + px = np.sum(p, axis=1) + py = np.sum(p, axis=0) + result = 0.0 + for x in range(p.shape[0]): + p_x = px[x] + for y in range(p.shape[1]): + p_xy = p[x][y] + p_y = py[y] + if p_xy > 0.0: + result += p_xy * np.log2(p_xy / (p_x * p_y)) # get bits as output + return result + +def entropy(p): + p = prob_normalize(p) + result = 0.0 + for x in range(p.shape[0]): + for y in range(p.shape[1]): + p_xy = p[x][y] + if p_xy > 0.0: + result -= p_xy * np.log2(p_xy) + return result + +def conditional_entropy(p): + # H(Y|X) where X corresponds to axis 0, Y to axis 1 + # i.e., How many bits of additional information are needed to where we are on axis 1 if we know where we are on axis 0? + p = prob_normalize(p) + y = np.sum(p, axis=0, keepdims=True) # marginalize to calculate H(Y) + return max(0.0, entropy(y) - mutual_information(p)) # can slip just below 0 due to FP inaccuracies, clean those up. + +#---------------------------------------------------------------------------- + +class LS(metric_base.MetricBase): + def __init__(self, num_samples, num_keep, attrib_indices, minibatch_per_gpu, **kwargs): + assert num_keep <= num_samples + super().__init__(**kwargs) + self.num_samples = num_samples + self.num_keep = num_keep + self.attrib_indices = attrib_indices + self.minibatch_per_gpu = minibatch_per_gpu + + def _evaluate(self, Gs, num_gpus): + minibatch_size = num_gpus * self.minibatch_per_gpu + + # Construct TensorFlow graph for each GPU. + result_expr = [] + for gpu_idx in range(num_gpus): + with tf.device('/gpu:%d' % gpu_idx): + Gs_clone = Gs.clone() + + # Generate images. + latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:]) + dlatents = Gs_clone.components.mapping.get_output_for(latents, None, is_validation=True) + images = Gs_clone.components.synthesis.get_output_for(dlatents, is_validation=True, randomize_noise=True) + + # Downsample to 256x256. The attribute classifiers were built for 256x256. 
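+            # Integer-factor box downsampling: reshape so every factor x factor block
+            # becomes two extra axes, then average over them.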
+ if images.shape[2] > 256: + factor = images.shape[2] // 256 + images = tf.reshape(images, [-1, images.shape[1], images.shape[2] // factor, factor, images.shape[3] // factor, factor]) + images = tf.reduce_mean(images, axis=[3, 5]) + + # Run classifier for each attribute. + result_dict = dict(latents=latents, dlatents=dlatents[:,-1]) + for attrib_idx in self.attrib_indices: + classifier = misc.load_pkl(classifier_urls[attrib_idx]) + logits = classifier.get_output_for(images, None) + predictions = tf.nn.softmax(tf.concat([logits, -logits], axis=1)) + result_dict[attrib_idx] = predictions + result_expr.append(result_dict) + + # Sampling loop. + results = [] + for _ in range(0, self.num_samples, minibatch_size): + results += tflib.run(result_expr) + results = {key: np.concatenate([value[key] for value in results], axis=0) for key in results[0].keys()} + + # Calculate conditional entropy for each attribute. + conditional_entropies = defaultdict(list) + for attrib_idx in self.attrib_indices: + # Prune the least confident samples. + pruned_indices = list(range(self.num_samples)) + pruned_indices = sorted(pruned_indices, key=lambda i: -np.max(results[attrib_idx][i])) + pruned_indices = pruned_indices[:self.num_keep] + + # Fit SVM to the remaining samples. + svm_targets = np.argmax(results[attrib_idx][pruned_indices], axis=1) + for space in ['latents', 'dlatents']: + svm_inputs = results[space][pruned_indices] + try: + svm = sklearn.svm.LinearSVC() + svm.fit(svm_inputs, svm_targets) + svm.score(svm_inputs, svm_targets) + svm_outputs = svm.predict(svm_inputs) + except: + svm_outputs = svm_targets # assume perfect prediction + + # Calculate conditional entropy. + p = [[np.mean([case == (row, col) for case in zip(svm_outputs, svm_targets)]) for col in (0, 1)] for row in (0, 1)] + conditional_entropies[space].append(conditional_entropy(p)) + + # Calculate separability scores. + scores = {key: 2**np.sum(values) for key, values in conditional_entropies.items()} + self._report_result(scores['latents'], suffix='_z') + self._report_result(scores['dlatents'], suffix='_w') + +#---------------------------------------------------------------------------- diff --git a/metrics/metric_base.py b/metrics/metric_base.py new file mode 100644 index 0000000..0db82ad --- /dev/null +++ b/metrics/metric_base.py @@ -0,0 +1,142 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. + +"""Common definitions for GAN metrics.""" + +import os +import time +import hashlib +import numpy as np +import tensorflow as tf +import dnnlib +import dnnlib.tflib as tflib + +import config +from training import misc +from training import dataset + +#---------------------------------------------------------------------------- +# Standard metrics. 
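+# Each preset names a metric class via func_name together with its constructor kwargs;
+# MetricGroup instantiates them through dnnlib.util.call_func_by_name (defined above).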
+ +fid50k = dnnlib.EasyDict(func_name='metrics.frechet_inception_distance.FID', name='fid50k', num_images=50000, minibatch_per_gpu=8) +ppl_zfull = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL', name='ppl_zfull', num_samples=100000, epsilon=1e-4, space='z', sampling='full', minibatch_per_gpu=16) +ppl_wfull = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL', name='ppl_wfull', num_samples=100000, epsilon=1e-4, space='w', sampling='full', minibatch_per_gpu=16) +ppl_zend = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL', name='ppl_zend', num_samples=100000, epsilon=1e-4, space='z', sampling='end', minibatch_per_gpu=16) +ppl_wend = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL', name='ppl_wend', num_samples=100000, epsilon=1e-4, space='w', sampling='end', minibatch_per_gpu=16) +ls = dnnlib.EasyDict(func_name='metrics.linear_separability.LS', name='ls', num_samples=200000, num_keep=100000, attrib_indices=range(40), minibatch_per_gpu=4) +dummy = dnnlib.EasyDict(func_name='metrics.metric_base.DummyMetric', name='dummy') # for debugging + +#---------------------------------------------------------------------------- +# Base class for metrics. + +class MetricBase: + def __init__(self, name): + self.name = name + self._network_pkl = None + self._dataset_args = None + self._mirror_augment = None + self._results = [] + self._eval_time = None + + def run(self, network_pkl, run_dir=None, dataset_args=None, mirror_augment=None, num_gpus=1, tf_config=None, log_results=True): + self._network_pkl = network_pkl + self._dataset_args = dataset_args + self._mirror_augment = mirror_augment + self._results = [] + + if (dataset_args is None or mirror_augment is None) and run_dir is not None: + run_config = misc.parse_config_for_previous_run(run_dir) + self._dataset_args = dict(run_config['dataset']) + self._dataset_args['shuffle_mb'] = 0 + self._mirror_augment = run_config['train'].get('mirror_augment', False) + + time_begin = time.time() + with tf.Graph().as_default(), tflib.create_session(tf_config).as_default(): # pylint: disable=not-context-manager + _G, _D, Gs = misc.load_pkl(self._network_pkl) + self._evaluate(Gs, num_gpus=num_gpus) + self._eval_time = time.time() - time_begin + + if log_results: + result_str = self.get_result_str() + if run_dir is not None: + log = os.path.join(run_dir, 'metric-%s.txt' % self.name) + with dnnlib.util.Logger(log, 'a'): + print(result_str) + else: + print(result_str) + + def get_result_str(self): + network_name = os.path.splitext(os.path.basename(self._network_pkl))[0] + if len(network_name) > 29: + network_name = '...' 
+ network_name[-26:] + result_str = '%-30s' % network_name + result_str += ' time %-12s' % dnnlib.util.format_time(self._eval_time) + for res in self._results: + result_str += ' ' + self.name + res.suffix + ' ' + result_str += res.fmt % res.value + return result_str + + def update_autosummaries(self): + for res in self._results: + tflib.autosummary.autosummary('Metrics/' + self.name + res.suffix, res.value) + + def _evaluate(self, Gs, num_gpus): + raise NotImplementedError # to be overridden by subclasses + + def _report_result(self, value, suffix='', fmt='%-10.4f'): + self._results += [dnnlib.EasyDict(value=value, suffix=suffix, fmt=fmt)] + + def _get_cache_file_for_reals(self, extension='pkl', **kwargs): + all_args = dnnlib.EasyDict(metric_name=self.name, mirror_augment=self._mirror_augment) + all_args.update(self._dataset_args) + all_args.update(kwargs) + md5 = hashlib.md5(repr(sorted(all_args.items())).encode('utf-8')) + dataset_name = self._dataset_args['tfrecord_dir'].replace('\\', '/').split('/')[-1] + return os.path.join(config.cache_dir, '%s-%s-%s.%s' % (md5.hexdigest(), self.name, dataset_name, extension)) + + def _iterate_reals(self, minibatch_size): + dataset_obj = dataset.load_dataset(data_dir=config.data_dir, **self._dataset_args) + while True: + images, _labels = dataset_obj.get_minibatch_np(minibatch_size) + if self._mirror_augment: + images = misc.apply_mirror_augment(images) + yield images + + def _iterate_fakes(self, Gs, minibatch_size, num_gpus): + while True: + latents = np.random.randn(minibatch_size, *Gs.input_shape[1:]) + fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True) + images = Gs.run(latents, None, output_transform=fmt, is_validation=True, num_gpus=num_gpus, assume_frozen=True) + yield images + +#---------------------------------------------------------------------------- +# Group of multiple metrics. + +class MetricGroup: + def __init__(self, metric_kwarg_list): + self.metrics = [dnnlib.util.call_func_by_name(**kwargs) for kwargs in metric_kwarg_list] + + def run(self, *args, **kwargs): + for metric in self.metrics: + metric.run(*args, **kwargs) + + def get_result_str(self): + return ' '.join(metric.get_result_str() for metric in self.metrics) + + def update_autosummaries(self): + for metric in self.metrics: + metric.update_autosummaries() + +#---------------------------------------------------------------------------- +# Dummy metric for debugging purposes. + +class DummyMetric(MetricBase): + def _evaluate(self, Gs, num_gpus): + _ = Gs, num_gpus + self._report_result(0.0) + +#---------------------------------------------------------------------------- diff --git a/metrics/perceptual_path_length.py b/metrics/perceptual_path_length.py new file mode 100644 index 0000000..17271cf --- /dev/null +++ b/metrics/perceptual_path_length.py @@ -0,0 +1,108 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. + +"""Perceptual Path Length (PPL).""" + +import numpy as np +import tensorflow as tf +import dnnlib.tflib as tflib + +from metrics import metric_base +from training import misc + +#---------------------------------------------------------------------------- + +# Normalize batch of vectors. 
+def normalize(v): + return v / tf.sqrt(tf.reduce_sum(tf.square(v), axis=-1, keepdims=True)) + +# Spherical interpolation of a batch of vectors. +def slerp(a, b, t): + a = normalize(a) + b = normalize(b) + d = tf.reduce_sum(a * b, axis=-1, keepdims=True) + p = t * tf.math.acos(d) + c = normalize(b - d * a) + d = a * tf.math.cos(p) + c * tf.math.sin(p) + return normalize(d) + +#---------------------------------------------------------------------------- + +class PPL(metric_base.MetricBase): + def __init__(self, num_samples, epsilon, space, sampling, minibatch_per_gpu, **kwargs): + assert space in ['z', 'w'] + assert sampling in ['full', 'end'] + super().__init__(**kwargs) + self.num_samples = num_samples + self.epsilon = epsilon + self.space = space + self.sampling = sampling + self.minibatch_per_gpu = minibatch_per_gpu + + def _evaluate(self, Gs, num_gpus): + minibatch_size = num_gpus * self.minibatch_per_gpu + + # Construct TensorFlow graph. + distance_expr = [] + for gpu_idx in range(num_gpus): + with tf.device('/gpu:%d' % gpu_idx): + Gs_clone = Gs.clone() + noise_vars = [var for name, var in Gs_clone.components.synthesis.vars.items() if name.startswith('noise')] + + # Generate random latents and interpolation t-values. + lat_t01 = tf.random_normal([self.minibatch_per_gpu * 2] + Gs_clone.input_shape[1:]) + lerp_t = tf.random_uniform([self.minibatch_per_gpu], 0.0, 1.0 if self.sampling == 'full' else 0.0) + + # Interpolate in W or Z. + if self.space == 'w': + dlat_t01 = Gs_clone.components.mapping.get_output_for(lat_t01, None, is_validation=True) + dlat_t0, dlat_t1 = dlat_t01[0::2], dlat_t01[1::2] + dlat_e0 = tflib.lerp(dlat_t0, dlat_t1, lerp_t[:, np.newaxis, np.newaxis]) + dlat_e1 = tflib.lerp(dlat_t0, dlat_t1, lerp_t[:, np.newaxis, np.newaxis] + self.epsilon) + dlat_e01 = tf.reshape(tf.stack([dlat_e0, dlat_e1], axis=1), dlat_t01.shape) + else: # space == 'z' + lat_t0, lat_t1 = lat_t01[0::2], lat_t01[1::2] + lat_e0 = slerp(lat_t0, lat_t1, lerp_t[:, np.newaxis]) + lat_e1 = slerp(lat_t0, lat_t1, lerp_t[:, np.newaxis] + self.epsilon) + lat_e01 = tf.reshape(tf.stack([lat_e0, lat_e1], axis=1), lat_t01.shape) + dlat_e01 = Gs_clone.components.mapping.get_output_for(lat_e01, None, is_validation=True) + + # Synthesize images. + with tf.control_dependencies([var.initializer for var in noise_vars]): # use same noise inputs for the entire minibatch + images = Gs_clone.components.synthesis.get_output_for(dlat_e01, is_validation=True, randomize_noise=False) + + # Crop only the face region. + c = int(images.shape[2] // 8) + images = images[:, :, c*3 : c*7, c*2 : c*6] + + # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images. + if images.shape[2] > 256: + factor = images.shape[2] // 256 + images = tf.reshape(images, [-1, images.shape[1], images.shape[2] // factor, factor, images.shape[3] // factor, factor]) + images = tf.reduce_mean(images, axis=[3,5]) + + # Scale dynamic range from [-1,1] to [0,255] for VGG. + images = (images + 1) * (255 / 2) + + # Evaluate perceptual distance. + img_e0, img_e1 = images[0::2], images[1::2] + distance_measure = misc.load_pkl('https://drive.google.com/uc?id=1N2-m9qszOeVC9Tq77WxsLnuWwOedQiD2') # vgg16_zhang_perceptual.pkl + distance_expr.append(distance_measure.get_output_for(img_e0, img_e1) * (1 / self.epsilon**2)) + + # Sampling loop. + all_distances = [] + for _ in range(0, self.num_samples, minibatch_size): + all_distances += tflib.run(distance_expr) + all_distances = np.concatenate(all_distances, axis=0) + + # Reject outliers. 
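+            # Average only the distances between the 1st and 99th percentiles.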
+ lo = np.percentile(all_distances, 1, interpolation='lower') + hi = np.percentile(all_distances, 99, interpolation='higher') + filtered_distances = np.extract(np.logical_and(lo <= all_distances, all_distances <= hi), all_distances) + self._report_result(np.mean(filtered_distances)) + +#---------------------------------------------------------------------------- diff --git a/preprocess/__init__.py b/preprocess/__init__.py new file mode 100644 index 0000000..0eca642 --- /dev/null +++ b/preprocess/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. \ No newline at end of file diff --git a/preprocess/preprocess_utils.py b/preprocess/preprocess_utils.py new file mode 100644 index 0000000..b729761 --- /dev/null +++ b/preprocess/preprocess_utils.py @@ -0,0 +1,211 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import numpy as np +import os +from scipy.io import loadmat,savemat +from PIL import Image,ImageOps +from array import array +import cv2 + + +# Load expression basis provided by Guo et al., +# https://github.com/Juyong/3DFace. +def LoadExpBasis(): + n_vertex = 53215 + Expbin = open('./renderer/BFM face model/Exp_Pca.bin','rb') + exp_dim = array('i') + exp_dim.fromfile(Expbin,1) + expMU = array('f') + expPC = array('f') + expMU.fromfile(Expbin,3*n_vertex) + expPC.fromfile(Expbin,3*exp_dim[0]*n_vertex) + + expPC = np.array(expPC) + expPC = np.reshape(expPC,[exp_dim[0],-1]) + expPC = np.transpose(expPC) + + expEV = np.loadtxt('./renderer/BFM face model/std_exp.txt') + + return expPC,expEV + +# Load BFM09 face model and transfer it to our face model +def transferBFM09(): + original_BFM = loadmat('./renderer/BFM face model/01_MorphableModel.mat') + shapePC = original_BFM['shapePC'] # shape basis + shapeEV = original_BFM['shapeEV'] # corresponding eigen value + shapeMU = original_BFM['shapeMU'] # mean face + texPC = original_BFM['texPC'] # texture basis + texEV = original_BFM['texEV'] # eigen value + texMU = original_BFM['texMU'] # mean texture + + expPC,expEV = LoadExpBasis() # expression basis and eigen value + + idBase = shapePC*np.reshape(shapeEV,[-1,199]) + idBase = idBase/1e5 # unify the scale to decimeter + idBase = idBase[:,:80] # use only first 80 basis + + exBase = expPC*np.reshape(expEV,[-1,79]) + exBase = exBase/1e5 # unify the scale to decimeter + exBase = exBase[:,:64] # use only first 64 basis + + texBase = texPC*np.reshape(texEV,[-1,199]) + texBase = texBase[:,:80] # use only first 80 basis + + # Our face model is cropped along face landmarks which contains only 35709 vertex. + # original BFM09 contains 53490 vertex, and expression basis provided by Guo et al. contains 53215 vertex. + # thus we select corresponding vertex to get our face model. 
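+    # BFM_front_idx selects the cropped 35709 front-face vertices within the 53215-vertex
+    # expression model; BFM_exp_idx maps expression-model vertices to the 53490 BFM09
+    # vertices, so their composition indexes the id/tex bases while exBase uses index_exp directly.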
+ + index_exp = loadmat('./renderer/BFM face model/BFM_front_idx.mat') + index_exp = index_exp['idx'].astype(np.int32) - 1 #starts from 0 (to 53215) + + index_shape = loadmat('./renderer/BFM face model/BFM_exp_idx.mat') + index_shape = index_shape['trimIndex'].astype(np.int32) - 1 #starts from 0 (to 53490) + index_shape = index_shape[index_exp] + + + idBase = np.reshape(idBase,[-1,3,80]) + idBase = idBase[index_shape,:,:] + idBase = np.reshape(idBase,[-1,80]) + + texBase = np.reshape(texBase,[-1,3,80]) + texBase = texBase[index_shape,:,:] + texBase = np.reshape(texBase,[-1,80]) + + exBase = np.reshape(exBase,[-1,3,64]) + exBase = exBase[index_exp,:,:] + exBase = np.reshape(exBase,[-1,64]) + + meanshape = np.reshape(shapeMU,[-1,3])/1e5 + meanshape = meanshape[index_shape,:] + meanshape = np.reshape(meanshape,[1,-1]) + + meantex = np.reshape(texMU,[-1,3]) + meantex = meantex[index_shape,:] + meantex = np.reshape(meantex,[1,-1]) + + # region used for image rendering, and 68 landmarks index etc. + gan_tl = loadmat('./renderer/BFM face model/gan_tl.mat') + gan_tl = gan_tl['f'] + + gan_mask = loadmat('./renderer/BFM face model/gan_mask.mat') + gan_mask = gan_mask['idx'] + + other_info = loadmat('./renderer/BFM face model/facemodel_info.mat') + keypoints = other_info['keypoints'] + point_buf = other_info['point_buf'] + tri = other_info['tri'] + + # save our face model + savemat('./renderer/BFM face model/BFM_model_front_gan.mat',{'meanshape':meanshape,'meantex':meantex,'idBase':idBase,'exBase':exBase,'texBase':texBase,\ + 'tri':tri,'point_buf':point_buf,'keypoints':keypoints,'gan_mask':gan_mask,'gan_tl':gan_tl}) + +#calculating least sqaures problem +def POS(xp,x): + npts = xp.shape[1] + + A = np.zeros([2*npts,8]) + + A[0:2*npts-1:2,0:3] = x.transpose() + A[0:2*npts-1:2,3] = 1 + + A[1:2*npts:2,4:7] = x.transpose() + A[1:2*npts:2,7] = 1; + + b = np.reshape(xp.transpose(),[2*npts,1]) + + k,_,_,_ = np.linalg.lstsq(A,b) + + R1 = k[0:3] + R2 = k[4:7] + sTx = k[3] + sTy = k[7] + s = (np.linalg.norm(R1) + np.linalg.norm(R2))/2 + t = np.stack([sTx,sTy],axis = 0) + + return t,s + +# align image for 3D face reconstruction +def process_img(img,lm,t,s,target_size = 512.): + w0,h0 = img.size + w = (w0/s*102).astype(np.int32) + h = (h0/s*102).astype(np.int32) + img = img.resize((w,h),resample = Image.BICUBIC) + + left = (w/2 - target_size/2 + float((t[0] - w0/2)*102/s)).astype(np.int32) + right = left + target_size + up = (h/2 - target_size/2 + float((h0/2 - t[1])*102/s)).astype(np.int32) + below = up + target_size + + img = img.crop((left,up,right,below)) + lm = np.stack([lm[:,0] - t[0] + w0/2,lm[:,1] - t[1] + h0/2],axis = 1)/s*102 + lm = lm - np.reshape(np.array([(w/2 - target_size/2),(h/2-target_size/2)]),[1,2]) + + return img,lm + +def Preprocess(img,lm,lm3D,target_size = 512.): + + w0,h0 = img.size + + # change from image plane coordinates to 3D sapce coordinates(X-Y plane) + lm = np.stack([lm[:,0],h0 - 1 - lm[:,1]], axis = 1) + + # calculate translation and scale factors using 5 facial landmarks and standard landmarks + t,s = POS(lm.transpose(),lm3D.transpose()) + s = s*224./target_size + + # processing the image + img_new,lm_new = process_img(img,lm,t,s,target_size = target_size) + lm_new = np.stack([lm_new[:,0],target_size - lm_new[:,1]], axis = 1) + trans_params = np.array([w0,h0,102.0/s,t[0],t[1]]) + + return img_new,lm_new,trans_params + + +def load_lm3d(): + + Lm3D = loadmat('preprocess/similarity_Lm3D_all.mat') + Lm3D = Lm3D['lm'] + + # calculate 5 facial landmarks using 68 landmarks + lm_idx = 
np.array([31,37,40,43,46,49,55]) - 1 + Lm3D = np.stack([Lm3D[lm_idx[0],:],np.mean(Lm3D[lm_idx[[1,2]],:],0),np.mean(Lm3D[lm_idx[[3,4]],:],0),Lm3D[lm_idx[5],:],Lm3D[lm_idx[6],:]], axis = 0) + Lm3D = Lm3D[[1,2,0,3,4],:] + + return Lm3D + +# load input images and corresponding 5 landmarks +def load_img(img_path,lm_path): + + image = Image.open(img_path) + lm = np.loadtxt(lm_path) + + return image,lm + + +# Crop and rescale face region for GAN training +def crop_n_rescale_face_region(image,coeff): + tx = coeff[0,254] + ty = coeff[0,255] + tz = coeff[0,256] + f = 1015.*512/224 + cam_pos = 10. + scale = 1.22*224/512 + + # cancel translation and rescale face size + M = np.float32([[1,0,-f*tx/(cam_pos - tz)],[0,1,f*ty/(cam_pos - tz)]]) + (rows, cols) = image.shape[:2] + img_shift = cv2.warpAffine(image,M,(cols,rows)) + + # crop image to 256*256 + scale_ = scale*(cam_pos - tz)/cam_pos + w = int(cols*scale_) + h = int(rows*scale_) + res = cv2.resize(img_shift,(w,h)) + res = Image.fromarray(res.astype(np.uint8),'RGB') + res = ImageOps.expand(res,border=10,fill = 'black') + res = res.crop((round(w/2)-128+10,round(h/2)-128+10,round(w/2)+128+10,round(h/2)+128+10)) + res = np.array(res) + res = res.astype(np.uint8) + + return res diff --git a/preprocess/similarity_Lm3D_all.mat b/preprocess/similarity_Lm3D_all.mat new file mode 100644 index 0000000..a0e2358 Binary files /dev/null and b/preprocess/similarity_Lm3D_all.mat differ diff --git a/preprocess_data.py b/preprocess_data.py new file mode 100644 index 0000000..b8c5f4e --- /dev/null +++ b/preprocess_data.py @@ -0,0 +1,110 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Script for data pre-processing.""" + +import tensorflow as tf +import numpy as np +import cv2 +from PIL import Image +import os +from scipy.io import loadmat,savemat +from renderer import face_decoder +from training.networks_recon import R_Net +from preprocess.preprocess_utils import * + +# Pretrained face reconstruction model from Deng et al. 
19, +# https://github.com/microsoft/Deep3DFaceReconstruction +model_continue_path = 'training/pretrained_weights/recon_net' +R_net_weights = os.path.join(model_continue_path,'FaceReconModel.ckpt') +config = tf.ConfigProto() +config.gpu_options.visible_device_list = '0' + +def parse_args(): + desc = "Data Preprocess of DisentangledFaceGAN" + parser = argparse.ArgumentParser(description=desc) + + parser.add_argument('--image_path', type=str, help='Training image path.') + parser.add_argument('--lm_path', type=str, help='Deteced landmark path.') + parser.add_argument('--save_path', type=str, default='./data' ,help='Save path for aligned images and extracted coefficients.') + + return parser.parse_args() + +def main(): + args = parse_args() + image_path = args.image_path + lm_path = args.lm_path + # lm_path = os.path.join(args.image_path,'lm5p') # detected landmarks for training images should be saved in /lm5p subfolder + + # create save path for aligned images and extracted coefficients + save_path = args.save_path + if not os.path.exists(os.path.join(save_path,'img')): + os.makedirs(os.path.join(save_path,'img')) + if not os.path.exists(os.path.join(save_path,'coeff')): + os.makedirs(os.path.join(save_path,'coeff')) + + # Load BFM09 face model + if not os.path.isfile('./renderer/BFM face model/BFM_model_front_gan.mat'): + transferBFM09() + + # Load standard landmarks for alignment + lm3D = load_lm3d() + + + # Build reconstruction model + with tf.Graph().as_default() as graph: + + images = tf.placeholder(name = 'input_imgs', shape = [None,224,224,3], dtype = tf.float32) + Face3D = face_decoder.Face3D() # analytic 3D face formation process + coeff = R_Net(images,is_training=False) # 3D face reconstruction network + + with tf.Session(config = config) as sess: + + var_list = tf.trainable_variables() + g_list = tf.global_variables() + + # Add batch normalization params into trainable variables + bn_moving_vars = [g for g in g_list if 'moving_mean' in g.name] + bn_moving_vars += [g for g in g_list if 'moving_variance' in g.name] + var_list +=bn_moving_vars + + # Create saver to save and restore weights + resnet_vars = [v for v in var_list if 'resnet_v1_50' in v.name] + res_fc = [v for v in var_list if 'fc-id' in v.name or 'fc-ex' in v.name or 'fc-tex' in v.name or 'fc-angles' in v.name or 'fc-gamma' in v.name or 'fc-XY' in v.name or 'fc-Z' in v.name or 'fc-f' in v.name] + resnet_vars += res_fc + + saver = tf.train.Saver(var_list = var_list,max_to_keep = 100) + saver.restore(sess,R_net_weights) + + for file in os.listdir(os.path.join(image_path)): + if file.endswith('png'): + print(file) + + # load images and landmarks + image = Image.open(os.path.join(image_path,file)) + if not os.path.isfile(os.path.join(lm_path,file.replace('png','txt'))): + continue + lm = np.loadtxt(os.path.join(lm_path,file.replace('png','txt'))) + lm = np.reshape(lm,[5,2]) + + # align image for 3d face reconstruction + align_img,_,_ = Preprocess(image,lm,lm3D) # 512*512*3 RGB image + align_img = np.array(align_img) + + align_img_ = align_img[:,:,::-1] #RGBtoBGR + align_img_ = cv2.resize(align_img_,(224,224)) # input image to reconstruction network should be 224*224 + align_img_ = np.expand_dims(align_img_,0) + coef = sess.run(coeff,feed_dict = {images: align_img_}) + + # align image for GAN training + # eliminate translation and rescale face size to proper scale + rescale_img = crop_n_rescale_face_region(align_img,coef) # 256*256*3 RGB image + coef = np.squeeze(coef,0) + + # save aligned images and extracted coefficients + 
cv2.imwrite(os.path.join(save_path,'img',file),rescale_img[:,:,::-1]) + savemat(os.path.join(save_path,'coeff',file.replace('.png','.mat')),{'coeff':coef}) + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/renderer/BFM face model/BFM_exp_idx.mat b/renderer/BFM face model/BFM_exp_idx.mat new file mode 100644 index 0000000..1146e4e Binary files /dev/null and b/renderer/BFM face model/BFM_exp_idx.mat differ diff --git a/renderer/BFM face model/BFM_front_idx.mat b/renderer/BFM face model/BFM_front_idx.mat new file mode 100644 index 0000000..b9d7b09 Binary files /dev/null and b/renderer/BFM face model/BFM_front_idx.mat differ diff --git a/renderer/BFM face model/facemodel_info.mat b/renderer/BFM face model/facemodel_info.mat new file mode 100644 index 0000000..3e516ec Binary files /dev/null and b/renderer/BFM face model/facemodel_info.mat differ diff --git a/renderer/BFM face model/gan_mask.mat b/renderer/BFM face model/gan_mask.mat new file mode 100644 index 0000000..8cc0054 Binary files /dev/null and b/renderer/BFM face model/gan_mask.mat differ diff --git a/renderer/BFM face model/gan_tl.mat b/renderer/BFM face model/gan_tl.mat new file mode 100644 index 0000000..ffe4652 Binary files /dev/null and b/renderer/BFM face model/gan_tl.mat differ diff --git a/renderer/BFM face model/select_vertex_id.mat b/renderer/BFM face model/select_vertex_id.mat new file mode 100644 index 0000000..5b8b220 Binary files /dev/null and b/renderer/BFM face model/select_vertex_id.mat differ diff --git a/renderer/BFM face model/similarity_Lm3D_all.mat b/renderer/BFM face model/similarity_Lm3D_all.mat new file mode 100644 index 0000000..a0e2358 Binary files /dev/null and b/renderer/BFM face model/similarity_Lm3D_all.mat differ diff --git a/renderer/BFM face model/std_exp.txt b/renderer/BFM face model/std_exp.txt new file mode 100644 index 0000000..767b8de --- /dev/null +++ b/renderer/BFM face model/std_exp.txt @@ -0,0 +1 @@ +453980 257264 263068 211890 135873 184721 47055.6 72732 62787.4 106226 56708.5 51439.8 34887.1 44378.7 51813.4 31030.7 23354.9 23128.1 19400 21827.6 22767.7 22057.4 19894.3 16172.8 17142.7 10035.3 14727.5 12972.5 10763.8 8953.93 8682.62 8941.81 6342.3 5205.3 7065.65 6083.35 6678.88 4666.63 5082.89 5134.76 4908.16 3964.93 3739.95 3180.09 2470.45 1866.62 1624.71 2423.74 1668.53 1471.65 1194.52 782.102 815.044 835.782 834.937 744.496 575.146 633.76 705.685 753.409 620.306 673.326 766.189 619.866 559.93 357.264 396.472 556.849 455.048 460.592 400.735 326.702 279.428 291.535 326.584 305.664 287.816 283.642 276.19 \ No newline at end of file diff --git a/renderer/__init__.py b/renderer/__init__.py new file mode 100644 index 0000000..8be9837 --- /dev/null +++ b/renderer/__init__.py @@ -0,0 +1 @@ +#. \ No newline at end of file diff --git a/renderer/camera_utils.py b/renderer/camera_utils.py new file mode 100644 index 0000000..04c067a --- /dev/null +++ b/renderer/camera_utils.py @@ -0,0 +1,152 @@ +# Copyright 2017 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Collection of TF functions for managing 3D camera matrices.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import tensorflow as tf + + +def perspective(aspect_ratio, fov_y, near_clip, far_clip): + """Computes perspective transformation matrices. + + Functionality mimes gluPerspective (third_party/GL/glu/include/GLU/glu.h). + + Args: + aspect_ratio: float value specifying the image aspect ratio (width/height). + fov_y: 1-D float32 Tensor with shape [batch_size] specifying output vertical + field of views in degrees. + near_clip: 1-D float32 Tensor with shape [batch_size] specifying near + clipping plane distance. + far_clip: 1-D float32 Tensor with shape [batch_size] specifying far clipping + plane distance. + + Returns: + A [batch_size, 4, 4] float tensor that maps from right-handed points in eye + space to left-handed points in clip space. + """ + # The multiplication of fov_y by pi/360.0 simultaneously converts to radians + # and adds the half-angle factor of .5. + focal_lengths_y = 1.0 / tf.tan(fov_y * (math.pi / 360.0)) + depth_range = far_clip - near_clip + p_22 = -(far_clip + near_clip) / depth_range + p_23 = -2.0 * (far_clip * near_clip / depth_range) + + zeros = tf.zeros_like(p_23, dtype=tf.float32) + # pyformat: disable + perspective_transform = tf.concat( + [ + focal_lengths_y / aspect_ratio, zeros, zeros, zeros, + zeros, focal_lengths_y, zeros, zeros, + zeros, zeros, p_22, p_23, + zeros, zeros, -tf.ones_like(p_23, dtype=tf.float32), zeros + ], axis=0) + # pyformat: enable + perspective_transform = tf.reshape(perspective_transform, [4, 4, -1]) + return tf.transpose(perspective_transform, [2, 0, 1]) + + +def look_at(eye, center, world_up): + """Computes camera viewing matrices. + + Functionality mimes gluLookAt (third_party/GL/glu/include/GLU/glu.h). + + Args: + eye: 2-D float32 tensor with shape [batch_size, 3] containing the XYZ world + space position of the camera. + center: 2-D float32 tensor with shape [batch_size, 3] containing a position + along the center of the camera's gaze. + world_up: 2-D float32 tensor with shape [batch_size, 3] specifying the + world's up direction; the output camera will have no tilt with respect + to this direction. + + Returns: + A [batch_size, 4, 4] float tensor containing a right-handed camera + extrinsics matrix that maps points from world space to points in eye space. 
+ """ + batch_size = center.shape[0].value + vector_degeneracy_cutoff = 1e-6 + forward = center - eye + forward_norm = tf.norm(forward, ord='euclidean', axis=1, keep_dims=True) + tf.assert_greater( + forward_norm, + vector_degeneracy_cutoff, + message='Camera matrix is degenerate because eye and center are close.') + forward = tf.divide(forward, forward_norm) + + to_side = tf.cross(forward, world_up) + to_side_norm = tf.norm(to_side, ord='euclidean', axis=1, keep_dims=True) + tf.assert_greater( + to_side_norm, + vector_degeneracy_cutoff, + message='Camera matrix is degenerate because up and gaze are close or' + 'because up is degenerate.') + to_side = tf.divide(to_side, to_side_norm) + cam_up = tf.cross(to_side, forward) + + w_column = tf.constant( + batch_size * [[0., 0., 0., 1.]], dtype=tf.float32) # [batch_size, 4] + w_column = tf.reshape(w_column, [batch_size, 4, 1]) + view_rotation = tf.stack( + [to_side, cam_up, -forward, + tf.zeros_like(to_side, dtype=tf.float32)], + axis=1) # [batch_size, 4, 3] matrix + view_rotation = tf.concat( + [view_rotation, w_column], axis=2) # [batch_size, 4, 4] + + identity_batch = tf.tile(tf.expand_dims(tf.eye(3), 0), [batch_size, 1, 1]) + view_translation = tf.concat([identity_batch, tf.expand_dims(-eye, 2)], 2) + view_translation = tf.concat( + [view_translation, + tf.reshape(w_column, [batch_size, 1, 4])], 1) + camera_matrices = tf.matmul(view_rotation, view_translation) + return camera_matrices + + +def euler_matrices(angles): + """Computes a XYZ Tait-Bryan (improper Euler angle) rotation. + + Returns 4x4 matrices for convenient multiplication with other transformations. + + Args: + angles: a [batch_size, 3] tensor containing X, Y, and Z angles in radians. + + Returns: + a [batch_size, 4, 4] tensor of matrices. + """ + s = tf.sin(angles) + c = tf.cos(angles) + # Rename variables for readability in the matrix definition below. + c0, c1, c2 = (c[:, 0], c[:, 1], c[:, 2]) + s0, s1, s2 = (s[:, 0], s[:, 1], s[:, 2]) + + zeros = tf.zeros_like(s[:, 0]) + ones = tf.ones_like(s[:, 0]) + + # pyformat: disable + flattened = tf.concat( + [ + c2 * c1, c2 * s1 * s0 - c0 * s2, s2 * s0 + c2 * c0 * s1, zeros, + c1 * s2, c2 * c0 + s2 * s1 * s0, c0 * s2 * s1 - c2 * s0, zeros, + -s1, c1 * s0, c1 * c0, zeros, + zeros, zeros, zeros, ones + ], + axis=0) + # pyformat: enable + reshaped = tf.reshape(flattened, [4, 4, -1]) + return tf.transpose(reshaped, [2, 0, 1]) diff --git a/renderer/face_decoder.py b/renderer/face_decoder.py new file mode 100644 index 0000000..d16194e --- /dev/null +++ b/renderer/face_decoder.py @@ -0,0 +1,349 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import tensorflow as tf +import math as m +import numpy as np +from renderer import mesh_renderer +from scipy.io import loadmat + +# Reconstruct 3D face based on output coefficients and facemodel +#----------------------------------------------------------------------------------------- + +# BFM 3D face model +class BFM(): + def __init__(self,model_path = 'renderer/BFM face model/BFM_model_front_gan.mat'): + model = loadmat(model_path) + self.meanshape = tf.constant(model['meanshape']) # mean face shape. [3*N,1] + self.idBase = tf.constant(model['idBase']) # identity basis. [3*N,80] + self.exBase = tf.constant(model['exBase'].astype(np.float32)) # expression basis. [3*N,64] + self.meantex = tf.constant(model['meantex']) # mean face texture. [3*N,1] (0-255) + self.texBase = tf.constant(model['texBase']) # texture basis. 
[3*N,80] + self.point_buf = tf.constant(model['point_buf']) # face indices for each vertex that lies in. starts from 1. [N,8] + self.face_buf = tf.constant(model['tri']) # vertex indices for each face. starts from 1. [F,3] + self.front_mask_render = tf.squeeze(tf.constant(model['gan_mask'])) # vertex indices for small face region for rendering. starts from 1. + self.mask_face_buf = tf.constant(model['gan_tl']) # vertex indices for each face from small face region. starts from 1. [f,3] + self.keypoints = tf.squeeze(tf.constant(model['keypoints'])) # vertex indices for 68 landmarks. starts from 1. [68,1] + +# Analytic 3D face +class Face3D(): + def __init__(self): + facemodel = BFM() + self.facemodel = facemodel + + # analytic 3D face reconstructions with coefficients from R-Net + def Reconstruction_Block(self,coeff,res,batchsize,progressive=True): + #coeff: [batchsize,257] reconstruction coefficients + id_coeff,ex_coeff,tex_coeff,angles,translation,gamma = self.Split_coeff(coeff) + # [batchsize,N,3] canonical face shape in BFM space + face_shape = self.Shape_formation_block(id_coeff,ex_coeff,self.facemodel) + # [batchsize,N,3] vertex texture (in RGB order) + face_texture = self.Texture_formation_block(tex_coeff,self.facemodel) + # [batchsize,3,3] rotation matrix for face shape + rotation = self.Compute_rotation_matrix(angles) + # [batchsize,N,3] vertex normal + face_norm = self.Compute_norm(face_shape,self.facemodel) + norm_r = tf.matmul(face_norm,rotation) + + # do rigid transformation for face shape using predicted rotation and translation + face_shape_t = self.Rigid_transform_block(face_shape,rotation,translation) + # compute 2d landmark projections + # landmark_p: [batchsize,68,2] + face_landmark_t = self.Compute_landmark(face_shape_t,self.facemodel) + landmark_p = self.Projection_block(face_landmark_t) # 256*256 image + + # [batchsize,N,3] vertex color (in RGB order) + face_color = self.Illumination_block(face_texture, norm_r, gamma) + + # reconstruction images and region masks + if progressive: + render_imgs,img_mask = tf.cond(res<=8, lambda:self.Render_block(face_shape_t,norm_r,face_color,self.facemodel,8,64), + lambda: + tf.cond(res<=16, lambda:self.Render_block(face_shape_t,norm_r,face_color,self.facemodel,16,32), + lambda: + tf.cond(res<=32, lambda:self.Render_block(face_shape_t,norm_r,face_color,self.facemodel,32,16), + lambda: + tf.cond(res<=64, lambda:self.Render_block(face_shape_t,norm_r,face_color,self.facemodel,64,8), + lambda: + tf.cond(res<=128, lambda:self.Render_block(face_shape_t,norm_r,face_color,self.facemodel,128,4), + lambda: + self.Render_block(face_shape_t,norm_r,face_color,self.facemodel,256,4) + ))))) + else: + render_imgs,img_mask = self.Render_block(face_shape_t,norm_r,face_color,self.facemodel,res,batchsize) + + + render_imgs = tf.clip_by_value(render_imgs,0,255) + render_imgs = tf.cast(render_imgs,tf.float32) + render_mask = tf.cast(img_mask,tf.float32) + + return render_imgs,render_mask,landmark_p,face_shape_t + + + def Get_landmark(self,coeff): + face_shape_t = self.Get_face_shape(coeff) + # compute 2d landmark projections + # landmark_p: [batchsize,68,2] + face_landmark_t = self.Compute_landmark(face_shape_t,self.facemodel) + landmark_p = self.Projection_block(face_landmark_t,focal=1015.,half_image_width=112.) 
# 224*224 image + + return landmark_p + + def Get_face_shape(self,coeff): + #coeff: [batchsize,257] reconstruction coefficients + + id_coeff,ex_coeff,tex_coeff,angles,translation,gamma = self.Split_coeff(coeff) + # [batchsize,N,3] canonical face shape in BFM space + face_shape = self.Shape_formation_block(id_coeff,ex_coeff,self.facemodel) + # [batchsize,3,3] rotation matrix for face shape + rotation = self.Compute_rotation_matrix(angles) + + # do rigid transformation for face shape using predicted rotation and translation + face_shape_t = self.Rigid_transform_block(face_shape,rotation,translation) + + return face_shape_t + + def Split_coeff(self,coeff): + + id_coeff = coeff[:,:80] + tex_coeff = coeff[:,80:160] + ex_coeff = coeff[:,160:224] + angles = coeff[:,224:227] + gamma = coeff[:,227:254] + translation = coeff[:,254:257] + + return id_coeff,ex_coeff,tex_coeff,angles,translation,gamma + + def Shape_formation_block(self,id_coeff,ex_coeff,facemodel): + face_shape = tf.einsum('ij,aj->ai',facemodel.idBase,id_coeff) + \ + tf.einsum('ij,aj->ai',facemodel.exBase,ex_coeff) + facemodel.meanshape + + # reshape face shape to [batchsize,N,3] + face_shape = tf.reshape(face_shape,[tf.shape(face_shape)[0],-1,3]) + # re-centering the face shape with mean shape + face_shape = face_shape - tf.reshape(tf.reduce_mean(tf.reshape(facemodel.meanshape,[-1,3]),0),[1,1,3]) + + return face_shape + + def Compute_norm(self,face_shape,facemodel): + shape = face_shape + face_id = facemodel.face_buf + point_id = facemodel.point_buf + + # face_id and point_id index starts from 1 + face_id = tf.cast(face_id - 1,tf.int32) + point_id = tf.cast(point_id - 1,tf.int32) + + #compute normal for each face + v1 = tf.gather(shape,face_id[:,0], axis = 1) + v2 = tf.gather(shape,face_id[:,1], axis = 1) + v3 = tf.gather(shape,face_id[:,2], axis = 1) + e1 = v1 - v2 + e2 = v2 - v3 + face_norm = tf.cross(e1,e2) + + face_norm = tf.nn.l2_normalize(face_norm, dim = 2) # normalized face_norm first + face_norm = tf.concat([face_norm,tf.zeros([tf.shape(face_shape)[0],1,3])], axis = 1) + + #compute normal for each vertex using one-ring neighborhood + v_norm = tf.reduce_sum(tf.gather(face_norm, point_id, axis = 1), axis = 2) + v_norm = tf.nn.l2_normalize(v_norm, dim = 2) + + return v_norm + + def Texture_formation_block(self,tex_coeff,facemodel): + face_texture = tf.einsum('ij,aj->ai',facemodel.texBase,tex_coeff) + facemodel.meantex + + # reshape face texture to [batchsize,N,3], note that texture is in RGB order + face_texture = tf.reshape(face_texture,[tf.shape(face_texture)[0],-1,3]) + + return face_texture + + def Compute_rotation_matrix(self,angles): + n_data = tf.shape(angles)[0] + + # compute rotation matrix for X-axis, Y-axis, Z-axis respectively + rotation_X = tf.concat([tf.ones([n_data,1]), + tf.zeros([n_data,3]), + tf.reshape(tf.cos(angles[:,0]),[n_data,1]), + -tf.reshape(tf.sin(angles[:,0]),[n_data,1]), + tf.zeros([n_data,1]), + tf.reshape(tf.sin(angles[:,0]),[n_data,1]), + tf.reshape(tf.cos(angles[:,0]),[n_data,1])], + axis = 1 + ) + + rotation_Y = tf.concat([tf.reshape(tf.cos(angles[:,1]),[n_data,1]), + tf.zeros([n_data,1]), + tf.reshape(tf.sin(angles[:,1]),[n_data,1]), + tf.zeros([n_data,1]), + tf.ones([n_data,1]), + tf.zeros([n_data,1]), + -tf.reshape(tf.sin(angles[:,1]),[n_data,1]), + tf.zeros([n_data,1]), + tf.reshape(tf.cos(angles[:,1]),[n_data,1])], + axis = 1 + ) + + rotation_Z = tf.concat([tf.reshape(tf.cos(angles[:,2]),[n_data,1]), + -tf.reshape(tf.sin(angles[:,2]),[n_data,1]), + tf.zeros([n_data,1]), + 
tf.reshape(tf.sin(angles[:,2]),[n_data,1]), + tf.reshape(tf.cos(angles[:,2]),[n_data,1]), + tf.zeros([n_data,3]), + tf.ones([n_data,1])], + axis = 1 + ) + + rotation_X = tf.reshape(rotation_X,[n_data,3,3]) + rotation_Y = tf.reshape(rotation_Y,[n_data,3,3]) + rotation_Z = tf.reshape(rotation_Z,[n_data,3,3]) + + # R = RzRyRx + rotation = tf.matmul(tf.matmul(rotation_Z,rotation_Y),rotation_X) + + # because our face shape is N*3, so compute the transpose of R, so that rotation shapes can be calculated as face_shape*R + rotation = tf.transpose(rotation, perm = [0,2,1]) + + return rotation + + def Projection_block(self,face_shape,focal=1015.0*1.22,half_image_width=128.): + + # pre-defined camera focal for pespective projection + focal = tf.constant(focal) + # focal = tf.constant(400.0) + focal = tf.reshape(focal,[-1,1]) + batchsize = tf.shape(face_shape)[0] + # center = tf.constant(112.0) + + # define camera position + # camera_pos = tf.reshape(tf.constant([0.0,0.0,10.0]),[1,1,3]) + camera_pos = tf.reshape(tf.constant([0.0,0.0,10.0]),[1,1,3]) + # camera_pos = tf.reshape(tf.constant([0.0,0.0,4.0]),[1,1,3]) + reverse_z = tf.tile(tf.reshape(tf.constant([1.0,0,0,0,1,0,0,0,-1.0]),[1,3,3]),[tf.shape(face_shape)[0],1,1]) + + # compute projection matrix + # p_matrix = tf.concat([[focal],[0.0],[center],[0.0],[focal],[center],[0.0],[0.0],[1.0]],axis = 0) + p_matrix = tf.concat([focal*tf.ones([batchsize,1]),tf.zeros([batchsize,1]),half_image_width*tf.ones([batchsize,1]),tf.zeros([batchsize,1]),\ + focal*tf.ones([batchsize,1]),half_image_width*tf.ones([batchsize,1]),tf.zeros([batchsize,2]),tf.ones([batchsize,1])],axis = 1) + # p_matrix = tf.tile(tf.reshape(p_matrix,[1,3,3]),[tf.shape(face_shape)[0],1,1]) + p_matrix = tf.reshape(p_matrix,[-1,3,3]) + + # convert z in canonical space to the distance to camera + face_shape = tf.matmul(face_shape,reverse_z) + camera_pos + aug_projection = tf.matmul(face_shape,tf.transpose(p_matrix,[0,2,1])) + + # [batchsize, N,2] 2d face projection + face_projection = aug_projection[:,:,0:2]/tf.reshape(aug_projection[:,:,2],[tf.shape(face_shape)[0],tf.shape(aug_projection)[1],1]) + + + return face_projection + + + def Compute_landmark(self,face_shape,facemodel): + + # compute 3D landmark postitions with pre-computed 3D face shape + keypoints_idx = facemodel.keypoints + keypoints_idx = tf.cast(keypoints_idx - 1,tf.int32) + face_landmark = tf.gather(face_shape,keypoints_idx,axis = 1) + + return face_landmark + + def Illumination_block(self,face_texture,norm_r,gamma): + n_data = tf.shape(gamma)[0] + n_point = tf.shape(norm_r)[1] + gamma = tf.reshape(gamma,[n_data,3,9]) + # set initial lighting with an ambient lighting + init_lit = tf.constant([0.8,0,0,0,0,0,0,0,0]) + gamma = gamma + tf.reshape(init_lit,[1,1,9]) + + # compute vertex color using SH function approximation + a0 = m.pi + a1 = 2*m.pi/tf.sqrt(3.0) + a2 = 2*m.pi/tf.sqrt(8.0) + c0 = 1/tf.sqrt(4*m.pi) + c1 = tf.sqrt(3.0)/tf.sqrt(4*m.pi) + c2 = 3*tf.sqrt(5.0)/tf.sqrt(12*m.pi) + + Y = tf.concat([tf.tile(tf.reshape(a0*c0,[1,1,1]),[n_data,n_point,1]), + tf.expand_dims(-a1*c1*norm_r[:,:,1],2), + tf.expand_dims(a1*c1*norm_r[:,:,2],2), + tf.expand_dims(-a1*c1*norm_r[:,:,0],2), + tf.expand_dims(a2*c2*norm_r[:,:,0]*norm_r[:,:,1],2), + tf.expand_dims(-a2*c2*norm_r[:,:,1]*norm_r[:,:,2],2), + tf.expand_dims(a2*c2*0.5/tf.sqrt(3.0)*(3*tf.square(norm_r[:,:,2])-1),2), + tf.expand_dims(-a2*c2*norm_r[:,:,0]*norm_r[:,:,2],2), + tf.expand_dims(a2*c2*0.5*(tf.square(norm_r[:,:,0])-tf.square(norm_r[:,:,1])),2)],axis = 2) + + color_r = 
tf.squeeze(tf.matmul(Y,tf.expand_dims(gamma[:,0,:],2)),axis = 2) + color_g = tf.squeeze(tf.matmul(Y,tf.expand_dims(gamma[:,1,:],2)),axis = 2) + color_b = tf.squeeze(tf.matmul(Y,tf.expand_dims(gamma[:,2,:],2)),axis = 2) + + #[batchsize,N,3] vertex color in RGB order + face_color = tf.stack([color_r*face_texture[:,:,0],color_g*face_texture[:,:,1],color_b*face_texture[:,:,2]],axis = 2) + + return face_color + + def Rigid_transform_block(self,face_shape,rotation,translation): + # do rigid transformation for 3D face shape + face_shape_r = tf.matmul(face_shape,rotation) + face_shape_t = face_shape_r + tf.reshape(translation,[tf.shape(face_shape)[0],1,3]) + + return face_shape_t + + def Render_block(self,face_shape,face_norm,face_color,facemodel,res,batchsize): + # render reconstruction images + n_vex = int(facemodel.idBase.shape[0].value/3) + fov_y = 2*tf.atan(128/(1015.*1.22))*180./m.pi + # full face region + face_shape = tf.reshape(face_shape,[batchsize,n_vex,3]) + face_norm = tf.reshape(face_norm,[batchsize,n_vex,3]) + face_color = tf.reshape(face_color,[batchsize,n_vex,3]) + + # pre-defined cropped face region + mask_face_shape = tf.gather(face_shape,tf.cast(facemodel.front_mask_render-1,tf.int32),axis = 1) + mask_face_norm = tf.gather(face_norm,tf.cast(facemodel.front_mask_render-1,tf.int32),axis = 1) + mask_face_color = tf.gather(face_color,tf.cast(facemodel.front_mask_render-1,tf.int32),axis = 1) + + # setting cammera settings + camera_position = tf.constant([[0,0,10.0]]) + tf.zeros([batchsize,3]) + camera_lookat = tf.constant([[0,0,0.0]]) + tf.zeros([batchsize,3]) + camera_up = tf.constant([[0,1.0,0]]) + tf.zeros([batchsize,3]) + + # setting light source position(intensities are set to 0 because we have computed the vertex color) + light_positions = tf.reshape(tf.constant([0,0,1e5]),[1,1,3]) + tf.zeros([batchsize,1,3]) + light_intensities = tf.reshape(tf.constant([0.0,0.0,0.0]),[1,1,3])+tf.zeros([batchsize,1,3]) + ambient_color = tf.reshape(tf.constant([1.0,1,1]),[1,3])+ tf.zeros([batchsize,3]) + + near_clip = 0.01 + far_clip = 50. + + # using tf_mesh_renderer for rasterization, + # https://github.com/google/tf_mesh_renderer + + # img: [batchsize,224,224,3] images in RGB order (0-255) + # mask:[batchsize,224,224,1] transparency for img ({0,1} value) + with tf.device('/cpu:0'): + rgba_img = mesh_renderer.mesh_renderer(mask_face_shape, + tf.cast(facemodel.mask_face_buf-1,tf.int32), + mask_face_norm, + mask_face_color, + camera_position = camera_position, + camera_lookat = camera_lookat, + camera_up = camera_up, + light_positions = light_positions, + light_intensities = light_intensities, + image_width = res, + image_height = res, + # fov_y = 12.5936, + fov_y = fov_y, + ambient_color = ambient_color, + near_clip = near_clip, + far_clip = far_clip) + + img = rgba_img[:,:,:,:3] + mask = rgba_img[:,:,:,3:] + + return img,mask + + diff --git a/renderer/mesh_renderer.py b/renderer/mesh_renderer.py new file mode 100644 index 0000000..79b89fd --- /dev/null +++ b/renderer/mesh_renderer.py @@ -0,0 +1,404 @@ +# Copyright 2017 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Differentiable 3-D rendering of a triangle mesh.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + + +from renderer import camera_utils +from renderer import rasterize_triangles + + +def phong_shader(normals, + alphas, + pixel_positions, + light_positions, + light_intensities, + diffuse_colors=None, + camera_position=None, + specular_colors=None, + shininess_coefficients=None, + ambient_color=None): + """Computes pixelwise lighting from rasterized buffers with the Phong model. + + Args: + normals: a 4D float32 tensor with shape [batch_size, image_height, + image_width, 3]. The inner dimension is the world space XYZ normal for + the corresponding pixel. Should be already normalized. + alphas: a 3D float32 tensor with shape [batch_size, image_height, + image_width]. The inner dimension is the alpha value (transparency) + for the corresponding pixel. + pixel_positions: a 4D float32 tensor with shape [batch_size, image_height, + image_width, 3]. The inner dimension is the world space XYZ position for + the corresponding pixel. + light_positions: a 3D tensor with shape [batch_size, light_count, 3]. The + XYZ position of each light in the scene. In the same coordinate space as + pixel_positions. + light_intensities: a 3D tensor with shape [batch_size, light_count, 3]. The + RGB intensity values for each light. Intensities may be above one. + diffuse_colors: a 4D float32 tensor with shape [batch_size, image_height, + image_width, 3]. The inner dimension is the diffuse RGB coefficients at + a pixel in the range [0, 1]. + camera_position: a 1D tensor with shape [batch_size, 3]. The XYZ camera + position in the scene. If supplied, specular reflections will be + computed. If not supplied, specular_colors and shininess_coefficients + are expected to be None. In the same coordinate space as + pixel_positions. + specular_colors: a 4D float32 tensor with shape [batch_size, image_height, + image_width, 3]. The inner dimension is the specular RGB coefficients at + a pixel in the range [0, 1]. If None, assumed to be tf.zeros() + shininess_coefficients: A 3D float32 tensor that is broadcasted to shape + [batch_size, image_height, image_width]. The inner dimension is the + shininess coefficient for the object at a pixel. Dimensions that are + constant can be given length 1, so [batch_size, 1, 1] and [1, 1, 1] are + also valid input shapes. + ambient_color: a 2D tensor with shape [batch_size, 3]. The RGB ambient + color, which is added to each pixel before tone mapping. If None, it is + assumed to be tf.zeros(). + Returns: + A 4D float32 tensor of shape [batch_size, image_height, image_width, 4] + containing the lit RGBA color values for each image at each pixel. Colors + are in the range [0,1]. + + Raises: + ValueError: An invalid argument to the method is detected. 
+ """ + batch_size, image_height, image_width = [s.value for s in normals.shape[:-1]] + light_count = light_positions.shape[1].value + pixel_count = image_height * image_width + # Reshape all values to easily do pixelwise computations: + normals = tf.reshape(normals, [batch_size, -1, 3]) + alphas = tf.reshape(alphas, [batch_size, -1, 1]) + diffuse_colors = tf.reshape(diffuse_colors, [batch_size, -1, 3]) + if camera_position is not None: + specular_colors = tf.reshape(specular_colors, [batch_size, -1, 3]) + + # Ambient component + output_colors = tf.zeros([batch_size, image_height * image_width, 3]) + if ambient_color is not None: + ambient_reshaped = tf.expand_dims(ambient_color, axis=1) + output_colors = tf.add(output_colors, ambient_reshaped * diffuse_colors) + + # Diffuse component + pixel_positions = tf.reshape(pixel_positions, [batch_size, -1, 3]) + per_light_pixel_positions = tf.stack( + [pixel_positions] * light_count, + axis=1) # [batch_size, light_count, pixel_count, 3] + directions_to_lights = tf.nn.l2_normalize( + tf.expand_dims(light_positions, axis=2) - per_light_pixel_positions, + dim=3) # [batch_size, light_count, pixel_count, 3] + # The specular component should only contribute when the light and normal + # face one another (i.e. the dot product is nonnegative): + normals_dot_lights = tf.clip_by_value( + tf.reduce_sum( + tf.expand_dims(normals, axis=1) * directions_to_lights, axis=3), 0.0, + 1.0) # [batch_size, light_count, pixel_count] + diffuse_output = tf.expand_dims( + diffuse_colors, axis=1) * tf.expand_dims( + normals_dot_lights, axis=3) * tf.expand_dims( + light_intensities, axis=2) + diffuse_output = tf.reduce_sum( + diffuse_output, axis=1) # [batch_size, pixel_count, 3] + output_colors = tf.add(output_colors, diffuse_output) + + # Specular component + if camera_position is not None: + camera_position = tf.reshape(camera_position, [batch_size, 1, 3]) + mirror_reflection_direction = tf.nn.l2_normalize( + 2.0 * tf.expand_dims(normals_dot_lights, axis=3) * tf.expand_dims( + normals, axis=1) - directions_to_lights, + dim=3) + direction_to_camera = tf.nn.l2_normalize( + camera_position - pixel_positions, dim=2) + reflection_direction_dot_camera_direction = tf.reduce_sum( + tf.expand_dims(direction_to_camera, axis=1) * + mirror_reflection_direction, + axis=3) + # The specular component should only contribute when the reflection is + # external: + reflection_direction_dot_camera_direction = tf.clip_by_value( + tf.nn.l2_normalize(reflection_direction_dot_camera_direction, dim=2), + 0.0, 1.0) + # The specular component should also only contribute when the diffuse + # component contributes: + reflection_direction_dot_camera_direction = tf.where( + normals_dot_lights != 0.0, reflection_direction_dot_camera_direction, + tf.zeros_like( + reflection_direction_dot_camera_direction, dtype=tf.float32)) + # Reshape to support broadcasting the shininess coefficient, which rarely + # varies per-vertex: + reflection_direction_dot_camera_direction = tf.reshape( + reflection_direction_dot_camera_direction, + [batch_size, light_count, image_height, image_width]) + shininess_coefficients = tf.expand_dims(shininess_coefficients, axis=1) + specularity = tf.reshape( + tf.pow(reflection_direction_dot_camera_direction, + shininess_coefficients), + [batch_size, light_count, pixel_count, 1]) + specular_output = tf.expand_dims( + specular_colors, axis=1) * specularity * tf.expand_dims( + light_intensities, axis=2) + specular_output = tf.reduce_sum(specular_output, axis=1) + output_colors = 
tf.add(output_colors, specular_output) + rgb_images = tf.reshape(output_colors, + [batch_size, image_height, image_width, 3]) + alpha_images = tf.reshape(alphas, [batch_size, image_height, image_width, 1]) + valid_rgb_values = tf.concat(3 * [alpha_images > 0.5], axis=3) + rgb_images = tf.where(valid_rgb_values, rgb_images, + tf.zeros_like(rgb_images, dtype=tf.float32)) + return tf.reverse(tf.concat([rgb_images, alpha_images], axis=3), axis=[1]) + + +def tone_mapper(image, gamma): + """Applies gamma correction to the input image. + + Tone maps the input image batch in order to make scenes with a high dynamic + range viewable. The gamma correction factor is computed separately per image, + but is shared between all provided channels. The exact function computed is: + + image_out = A*image_in^gamma, where A is an image-wide constant computed so + that the maximum image value is approximately 1. The correction is applied + to all channels. + + Args: + image: 4-D float32 tensor with shape [batch_size, image_height, + image_width, channel_count]. The batch of images to tone map. + gamma: 0-D float32 nonnegative tensor. Values of gamma below one compress + relative contrast in the image, and values above one increase it. A + value of 1 is equivalent to scaling the image to have a maximum value + of 1. + Returns: + 4-D float32 tensor with shape [batch_size, image_height, image_width, + channel_count]. Contains the gamma-corrected images, clipped to the range + [0, 1]. + """ + batch_size = image.shape[0].value + corrected_image = tf.pow(image, gamma) + image_max = tf.reduce_max( + tf.reshape(corrected_image, [batch_size, -1]), axis=1) + scaled_image = tf.divide(corrected_image, + tf.reshape(image_max, [batch_size, 1, 1, 1])) + return tf.clip_by_value(scaled_image, 0.0, 1.0) + + +def mesh_renderer(vertices, + triangles, + normals, + diffuse_colors, + camera_position, + camera_lookat, + camera_up, + light_positions, + light_intensities, + image_width, + image_height, + specular_colors=None, + shininess_coefficients=None, + ambient_color=None, + fov_y=40.0, + near_clip=0.01, + far_clip=50.0): + """Renders an input scene using phong shading, and returns an output image. + + Args: + vertices: 3-D float32 tensor with shape [batch_size, vertex_count, 3]. Each + triplet is an xyz position in world space. + triangles: 2-D int32 tensor with shape [triangle_count, 3]. Each triplet + should contain vertex indices describing a triangle such that the + triangle's normal points toward the viewer if the forward order of the + triplet defines a clockwise winding of the vertices. Gradients with + respect to this tensor are not available. + normals: 3-D float32 tensor with shape [batch_size, vertex_count, 3]. Each + triplet is the xyz vertex normal for its corresponding vertex. Each + vector is assumed to be already normalized. + diffuse_colors: 3-D float32 tensor with shape [batch_size, + vertex_count, 3]. The RGB diffuse reflection in the range [0,1] for + each vertex. + camera_position: 2-D tensor with shape [batch_size, 3] or 1-D tensor with + shape [3] specifying the XYZ world space camera position. + camera_lookat: 2-D tensor with shape [batch_size, 3] or 1-D tensor with + shape [3] containing an XYZ point along the center of the camera's gaze. + camera_up: 2-D tensor with shape [batch_size, 3] or 1-D tensor with shape + [3] containing the up direction for the camera. The camera will have no + tilt with respect to this direction. + light_positions: a 3-D tensor with shape [batch_size, light_count, 3]. 
The + XYZ position of each light in the scene. In the same coordinate space as + pixel_positions. + light_intensities: a 3-D tensor with shape [batch_size, light_count, 3]. The + RGB intensity values for each light. Intensities may be above one. + image_width: int specifying desired output image width in pixels. + image_height: int specifying desired output image height in pixels. + specular_colors: 3-D float32 tensor with shape [batch_size, + vertex_count, 3]. The RGB specular reflection in the range [0, 1] for + each vertex. If supplied, specular reflections will be computed, and + both specular_colors and shininess_coefficients are expected. + shininess_coefficients: a 0D-2D float32 tensor with maximum shape + [batch_size, vertex_count]. The phong shininess coefficient of each + vertex. A 0D tensor or float gives a constant shininess coefficient + across all batches and images. A 1D tensor must have shape [batch_size], + and a single shininess coefficient per image is used. + ambient_color: a 2D tensor with shape [batch_size, 3]. The RGB ambient + color, which is added to each pixel in the scene. If None, it is + assumed to be black. + fov_y: float, 0D tensor, or 1D tensor with shape [batch_size] specifying + desired output image y field of view in degrees. + near_clip: float, 0D tensor, or 1D tensor with shape [batch_size] specifying + near clipping plane distance. + far_clip: float, 0D tensor, or 1D tensor with shape [batch_size] specifying + far clipping plane distance. + + Returns: + A 4-D float32 tensor of shape [batch_size, image_height, image_width, 4] + containing the lit RGBA color values for each image at each pixel. RGB + colors are the intensity values before tonemapping and can be in the range + [0, infinity]. Clipping to the range [0,1] with tf.clip_by_value is likely + reasonable for both viewing and training most scenes. More complex scenes + with multiple lights should tone map color values for display only. One + simple tonemapping approach is to rescale color values as x/(1+x); gamma + compression is another common techinque. Alpha values are zero for + background pixels and near one for mesh pixels. + Raises: + ValueError: An invalid argument to the method is detected. 
+ """ + if len(vertices.shape) != 3: + raise ValueError('Vertices must have shape [batch_size, vertex_count, 3].') + batch_size = vertices.shape[0].value + # print(batch_size) + if len(normals.shape) != 3: + raise ValueError('Normals must have shape [batch_size, vertex_count, 3].') + if len(light_positions.shape) != 3: + raise ValueError( + 'Light_positions must have shape [batch_size, light_count, 3].') + if len(light_intensities.shape) != 3: + raise ValueError( + 'Light_intensities must have shape [batch_size, light_count, 3].') + if len(diffuse_colors.shape) != 3: + raise ValueError( + 'vertex_diffuse_colors must have shape [batch_size, vertex_count, 3].') + if (ambient_color is not None and + ambient_color.get_shape().as_list() != [batch_size, 3]): + raise ValueError('Ambient_color must have shape [batch_size, 3].') + if camera_position.get_shape().as_list() == [3]: + camera_position = tf.tile( + tf.expand_dims(camera_position, axis=0), [batch_size, 1]) + elif camera_position.get_shape().as_list() != [batch_size, 3]: + raise ValueError('Camera_position must have shape [batch_size, 3]') + if camera_lookat.get_shape().as_list() == [3]: + camera_lookat = tf.tile( + tf.expand_dims(camera_lookat, axis=0), [batch_size, 1]) + elif camera_lookat.get_shape().as_list() != [batch_size, 3]: + raise ValueError('Camera_lookat must have shape [batch_size, 3]') + if camera_up.get_shape().as_list() == [3]: + camera_up = tf.tile(tf.expand_dims(camera_up, axis=0), [batch_size, 1]) + elif camera_up.get_shape().as_list() != [batch_size, 3]: + raise ValueError('Camera_up must have shape [batch_size, 3]') + if isinstance(fov_y, float): + fov_y = tf.constant(batch_size * [fov_y], dtype=tf.float32) + elif not fov_y.get_shape().as_list(): + fov_y = tf.tile(tf.expand_dims(fov_y, 0), [batch_size]) + elif fov_y.get_shape().as_list() != [batch_size]: + raise ValueError('Fov_y must be a float, a 0D tensor, or a 1D tensor with' + 'shape [batch_size]') + if isinstance(near_clip, float): + near_clip = tf.constant(batch_size * [near_clip], dtype=tf.float32) + elif not near_clip.get_shape().as_list(): + near_clip = tf.tile(tf.expand_dims(near_clip, 0), [batch_size]) + elif near_clip.get_shape().as_list() != [batch_size]: + raise ValueError('Near_clip must be a float, a 0D tensor, or a 1D tensor' + 'with shape [batch_size]') + if isinstance(far_clip, float): + far_clip = tf.constant(batch_size * [far_clip], dtype=tf.float32) + elif not far_clip.get_shape().as_list(): + far_clip = tf.tile(tf.expand_dims(far_clip, 0), [batch_size]) + elif far_clip.get_shape().as_list() != [batch_size]: + raise ValueError('Far_clip must be a float, a 0D tensor, or a 1D tensor' + 'with shape [batch_size]') + if specular_colors is not None and shininess_coefficients is None: + raise ValueError( + 'Specular colors were supplied without shininess coefficients.') + if shininess_coefficients is not None and specular_colors is None: + raise ValueError( + 'Shininess coefficients were supplied without specular colors.') + if specular_colors is not None: + # Since a 0-D float32 tensor is accepted, also accept a float. 
+ if isinstance(shininess_coefficients, float): + shininess_coefficients = tf.constant( + shininess_coefficients, dtype=tf.float32) + if len(specular_colors.shape) != 3: + raise ValueError('The specular colors must have shape [batch_size, ' + 'vertex_count, 3].') + if len(shininess_coefficients.shape) > 2: + raise ValueError('The shininess coefficients must have shape at most' + '[batch_size, vertex_count].') + # If we don't have per-vertex coefficients, we can just reshape the + # input shininess to broadcast later, rather than interpolating an + # additional vertex attribute: + if len(shininess_coefficients.shape) < 2: + vertex_attributes = tf.concat( + [normals, vertices, diffuse_colors, specular_colors], axis=2) + else: + vertex_attributes = tf.concat( + [ + normals, vertices, diffuse_colors, specular_colors, + tf.expand_dims(shininess_coefficients, axis=2) + ], + axis=2) + else: + vertex_attributes = tf.concat([normals, vertices, diffuse_colors], axis=2) + + camera_matrices = camera_utils.look_at(camera_position, camera_lookat, + camera_up) + + perspective_transforms = camera_utils.perspective(image_width / image_height, + fov_y, near_clip, far_clip) + + clip_space_transforms = tf.matmul(perspective_transforms, camera_matrices) + + pixel_attributes,alphas = rasterize_triangles.rasterize_triangles( + vertices, vertex_attributes, triangles, clip_space_transforms, + image_width, image_height, [-1] * vertex_attributes.shape[2].value) + + # Extract the interpolated vertex attributes from the pixel buffer and + # supply them to the shader: + pixel_normals = tf.nn.l2_normalize(pixel_attributes[:, :, :, 0:3], dim=3) + pixel_positions = pixel_attributes[:, :, :, 3:6] + diffuse_colors = pixel_attributes[:, :, :, 6:9] + if specular_colors is not None: + specular_colors = pixel_attributes[:, :, :, 9:12] + # Retrieve the interpolated shininess coefficients if necessary, or just + # reshape our input for broadcasting: + if len(shininess_coefficients.shape) == 2: + shininess_coefficients = pixel_attributes[:, :, :, 12] + else: + shininess_coefficients = tf.reshape(shininess_coefficients, [-1, 1, 1]) + + # pixel_mask = tf.cast(tf.reduce_any(diffuse_colors >= 0, axis=3), tf.float32) + + renders = phong_shader( + normals=pixel_normals, + alphas=alphas, + pixel_positions=pixel_positions, + light_positions=light_positions, + light_intensities=light_intensities, + diffuse_colors=diffuse_colors, + camera_position=camera_position if specular_colors is not None else None, + specular_colors=specular_colors, + shininess_coefficients=shininess_coefficients, + ambient_color=ambient_color) + return renders diff --git a/renderer/rasterize_triangles.py b/renderer/rasterize_triangles.py new file mode 100644 index 0000000..d8bcc32 --- /dev/null +++ b/renderer/rasterize_triangles.py @@ -0,0 +1,190 @@ +# Copyright 2017 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Differentiable triangle rasterizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import tensorflow as tf + + +# rasterize_triangles_module = tf.load_op_library( +# os.path.join(os.environ['TEST_SRCDIR'], +# 'tf_mesh_renderer/mesh_renderer/kernels/rasterize_triangles_kernel.so')) + + +rasterize_triangles_module = tf.load_op_library('./renderer/rasterize_triangles_kernel_1.so') + + +# This epsilon should be smaller than any valid barycentric reweighting factor +# (i.e. the per-pixel reweighting factor used to correct for the effects of +# perspective-incorrect barycentric interpolation). It is necessary primarily +# because the reweighting factor will be 0 for factors outside the mesh, and we +# need to ensure the image color and gradient outside the region of the mesh are +# 0. +_MINIMUM_REWEIGHTING_THRESHOLD = 1e-6 + +# This epsilon is the minimum absolute value of a homogenous coordinate before +# it is clipped. It should be sufficiently large such that the output of +# the perspective divide step with this denominator still has good working +# precision with 32 bit arithmetic, and sufficiently small so that in practice +# vertices are almost never close enough to a clipping plane to be thresholded. +_MINIMUM_PERSPECTIVE_DIVIDE_THRESHOLD = 1e-6 + + +def rasterize_triangles(vertices, attributes, triangles, projection_matrices, + image_width, image_height, background_value): + """Rasterizes the input scene and computes interpolated vertex attributes. + + NOTE: the rasterizer does no triangle clipping. Triangles that lie outside the + viewing frustum (esp. behind the camera) may be drawn incorrectly. + + Args: + vertices: 3-D float32 tensor with shape [batch_size, vertex_count, 3]. Each + triplet is an xyz position in model space. + attributes: 3-D float32 tensor with shape [batch_size, vertex_count, + attribute_count]. Each vertex attribute is interpolated + across the triangle using barycentric interpolation. + triangles: 2-D int32 tensor with shape [triangle_count, 3]. Each triplet + should contain vertex indices describing a triangle such that the + triangle's normal points toward the viewer if the forward order of the + triplet defines a clockwise winding of the vertices. Gradients with + respect to this tensor are not available. + projection_matrices: 3-D float tensor with shape [batch_size, 4, 4] + containing model-view-perspective projection matrices. + image_width: int specifying desired output image width in pixels. + image_height: int specifying desired output image height in pixels. + background_value: a 1-D float32 tensor with shape [attribute_count]. Pixels + that lie outside all triangles take this value. + + Returns: + A 4-D float32 tensor with shape [batch_size, image_height, image_width, + attribute_count], containing the interpolated vertex attributes at + each pixel. + + Raises: + ValueError: An invalid argument to the method is detected. + """ + if not image_width > 0: + raise ValueError('Image width must be > 0.') + if not image_height > 0: + raise ValueError('Image height must be > 0.') + if len(vertices.shape) != 3: + raise ValueError('The vertex buffer must be 3D.') + batch_size = vertices.shape[0].value + vertex_count = vertices.shape[1].value + + # We map the coordinates to normalized device coordinates before passing + # the scene to the rendering kernel to keep as many ops in tensorflow as + # possible. 
+ + homogeneous_coord = tf.ones([batch_size, vertex_count, 1], dtype=tf.float32) + vertices_homogeneous = tf.concat([vertices, homogeneous_coord], 2) + + # Vertices are given in row-major order, but the transformation pipeline is + # column major: + clip_space_points = tf.matmul( + vertices_homogeneous, projection_matrices, transpose_b=True) + + # Perspective divide, first thresholding the homogeneous coordinate to avoid + # the possibility of NaNs: + clip_space_points_w = tf.maximum( + tf.abs(clip_space_points[:, :, 3:4]), + _MINIMUM_PERSPECTIVE_DIVIDE_THRESHOLD) * tf.sign( + clip_space_points[:, :, 3:4]) + normalized_device_coordinates = ( + clip_space_points[:, :, 0:3] / clip_space_points_w) + + per_image_uncorrected_barycentric_coordinates = [] + per_image_vertex_ids = [] + for im in range(vertices.shape[0]): + barycentric_coords, triangle_ids, _ = ( + rasterize_triangles_module.rasterize_triangles( + normalized_device_coordinates[im, :, :], triangles, image_width, + image_height)) + per_image_uncorrected_barycentric_coordinates.append( + tf.reshape(barycentric_coords, [-1, 3])) + + # Gathers the vertex indices now because the indices don't contain a batch + # identifier, and reindexes the vertex ids to point to a (batch,vertex_id) + vertex_ids = tf.gather(triangles, tf.reshape(triangle_ids, [-1])) + reindexed_ids = tf.add(vertex_ids, im * vertices.shape[1].value) + per_image_vertex_ids.append(reindexed_ids) + + uncorrected_barycentric_coordinates = tf.concat( + per_image_uncorrected_barycentric_coordinates, axis=0) + vertex_ids = tf.concat(per_image_vertex_ids, axis=0) + + # Indexes with each pixel's clip-space triangle's extrema (the pixel's + # 'corner points') ids to get the relevant properties for deferred shading. + flattened_vertex_attributes = tf.reshape(attributes, + [batch_size * vertex_count, -1]) + corner_attributes = tf.gather(flattened_vertex_attributes, vertex_ids) + + # Barycentric interpolation is linear in the reciprocal of the homogeneous + # W coordinate, so we use these weights to correct for the effects of + # perspective distortion after rasterization. + perspective_distortion_weights = tf.reciprocal( + tf.reshape(clip_space_points_w, [-1])) + corner_distortion_weights = tf.gather(perspective_distortion_weights, + vertex_ids) + + # Apply perspective correction to the barycentric coordinates. This step is + # required since the rasterizer receives normalized-device coordinates (i.e., + # after perspective division), so it can't apply perspective correction to the + # interpolated values. + weighted_barycentric_coordinates = tf.multiply( + uncorrected_barycentric_coordinates, corner_distortion_weights) + barycentric_reweighting_factor = tf.reduce_sum( + weighted_barycentric_coordinates, axis=1) + + corrected_barycentric_coordinates = tf.divide( + weighted_barycentric_coordinates, + tf.expand_dims( + tf.maximum(barycentric_reweighting_factor, + _MINIMUM_REWEIGHTING_THRESHOLD), + axis=1)) + + # Computes the pixel attributes by interpolating the known attributes at the + # corner points of the triangle interpolated with the barycentric coordinates. 
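+ # Editor's note (descriptive comment, not part of the original file): with
+ # screen-space barycentric weights b_i, corner attributes a_i, and clip-space
+ # w components w_i, the perspective-correct interpolated value at a pixel is
+ #     sum_i (b_i / w_i) * a_i  /  sum_i (b_i / w_i).
+ # corrected_barycentric_coordinates computed above already holds
+ # (b_i / w_i) / sum_j (b_j / w_j), so the multiply-and-sum below completes the
+ # formula.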
+ weighted_vertex_attributes = tf.multiply( + corner_attributes, + tf.expand_dims(corrected_barycentric_coordinates, axis=2)) + summed_attributes = tf.reduce_sum(weighted_vertex_attributes, axis=1) + attribute_images = tf.reshape(summed_attributes, + [batch_size, image_height, image_width, -1]) + + # Barycentric coordinates should approximately sum to one where there is + # rendered geometry, but be exactly zero where there is not. + alphas = tf.clip_by_value( + tf.reduce_sum(2.0 * corrected_barycentric_coordinates, axis=1), 0.0, 1.0) + alphas = tf.reshape(alphas, [batch_size, image_height, image_width, 1]) + + attributes_with_background = ( + alphas * attribute_images + (1.0 - alphas) * background_value) + + return attributes_with_background,alphas + + +@tf.RegisterGradient('RasterizeTriangles') +def _rasterize_triangles_grad(op, df_dbarys, df_dids, df_dz): + # Gradients are only supported for barycentric coordinates. Gradients for the + # z-buffer are possible as well but not currently implemented. + del df_dids, df_dz + return rasterize_triangles_module.rasterize_triangles_grad( + op.inputs[0], op.inputs[1], op.outputs[0], op.outputs[1], df_dbarys, + op.get_attr('image_width'), op.get_attr('image_height')), None diff --git a/train.py b/train.py new file mode 100644 index 0000000..35bc0ba --- /dev/null +++ b/train.py @@ -0,0 +1,123 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. + +"""Main entry point for training StyleGAN networks.""" + +import copy +import dnnlib +from dnnlib import EasyDict +import argparse +import config +from metrics import metric_base + +#---------------------------------------------------------------------------- +# Official training configs for StyleGAN, targeted mainly for FFHQ. + +if 1: + desc = 'sgan' + train = EasyDict() # Description string included in result subdir name. + G = EasyDict(func_name='training.networks_stylegan.G_style') # Options for generator network. + D = EasyDict(func_name='training.networks_stylegan.D_basic') # Options for discriminator network. + G_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8) # Options for generator optimizer. + D_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8) # Options for discriminator optimizer. + G_loss = EasyDict(func_name='training.loss.G_logistic_nonsaturating') # Options for generator loss. + D_loss = EasyDict(func_name='training.loss.D_logistic_simplegp', r1_gamma=10.0) # Options for discriminator loss. + dataset = EasyDict() # Options for load_dataset(). + sched = EasyDict() # Options for TrainingSchedule. + grid = EasyDict(size='1080p', layout='random') # Options for setup_snapshot_image_grid(). + metrics = [metric_base.fid50k] # Options for MetricGroup. + submit_config = dnnlib.SubmitConfig() # Options for dnnlib.submit_run(). + tf_config = {'rnd.np_random_seed': 1000} # Options for tflib.init_tf(). + + # Dataset. + desc += '-ffhq256'; dataset = EasyDict(tfrecord_dir='ffhq_align', resolution=256); train.mirror_augment = True + + # Number of GPUs. 
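+ # Editor's note (descriptive comment, not part of the original file): as in
+ # the original StyleGAN schedule, sched.minibatch_dict maps the current
+ # training resolution to the total minibatch size summed over all GPUs, and
+ # sched.minibatch_base is used for resolutions not listed. For example, the
+ # 4-GPU setting below uses 64 images per step at resolution 32, i.e. 16 per GPU.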
+ #desc += '-1gpu'; submit_config.num_gpus = 1; sched.minibatch_base = 4; sched.minibatch_dict = {4: 128, 8: 128, 16: 128, 32: 64, 64: 32, 128: 16, 256: 8, 512: 4} + #desc += '-2gpu'; submit_config.num_gpus = 2; sched.minibatch_base = 8; sched.minibatch_dict = {4: 256, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16, 256: 8} + desc += '-4gpu'; submit_config.num_gpus = 4; sched.minibatch_base = 16; sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16} + # desc += '-8gpu'; submit_config.num_gpus = 8; sched.minibatch_base = 32; sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32} + + # Default options. + train.total_kimg = 25000 + sched.lod_initial_resolution = 8 + sched.G_lrate_dict = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003} + sched.D_lrate_dict = EasyDict(sched.G_lrate_dict) + +#---------------------------------------------------------------------------- +# Main entry point for training. +# Calls the function indicated by 'train' using the selected options. + + +#---------------------------------------------------------------------------- +# Modified by Deng et al. +def parse_args(): + desc = "Tensorflow implementation of DisentangledFaceGAN" + parser = argparse.ArgumentParser(description=desc) + + parser.add_argument('--w_id', type=float, default=3, help='weight for identity perceptual loss') + parser.add_argument('--w_lm', type=float, default=500, help='weight for landmark loss') + parser.add_argument('--w_gamma', type=float, default=10, help='weight for lighting loss') + parser.add_argument('--w_skin', type=float, default=20, help='weight for face region loss') + parser.add_argument('--w_exp_warp', type=float, default=10, help='weight for expression change loss') + parser.add_argument('--w_gamma_change', type=float, default=10, help='weight for lighting change loss') + parser.add_argument('--noise_dim', type=int, default=32, help='dimension of the additional noise factor') + parser.add_argument('--stage', type=int, default=1, help='training stage. 1 = imitative losses only; 2 = imitative losses and contrastive losses') + parser.add_argument('--run_id', type=int, default=0, help='run ID or network pkl to resume training from') + parser.add_argument('--snapshot', type=int, default=0, help='snapshot index to resume training from') + parser.add_argument('--kimg', type=float, default=0, help='assumed training progress at certain number of images') + + return parser.parse_args() +#---------------------------------------------------------------------------- + + +def main(): + + #------------------------------------------------------------------------ + # Modified by Deng et al. + args = parse_args() + if args is None: + exit() + + + weight_args = EasyDict() + weight_args.update(w_id=args.w_id,w_lm=args.w_lm,w_gamma=args.w_gamma,w_skin=args.w_skin, + w_exp_warp=args.w_exp_warp,w_gamma_change=args.w_gamma_change) + + train.update(run_func_name='training.training_loop.training_loop') + kwargs = EasyDict(train) + + # stage 1: training with only imitative losses with 15000k images. + if args.stage == 1: + train_stage = EasyDict(func_name='training.training_utils.training_stage1') + kwargs.update(total_kimg=15000) + + # stage 2: training with imitative losses and contrastive losses. 
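+ # Editor's note (descriptive comment, not part of the original file): stage 2
+ # resumes from a stage-1 snapshot selected by --run_id/--snapshot, continues
+ # the image counter from --kimg, trains to 25000 kimg in total, and lowers the
+ # landmark-loss weight to 100, as set in the branch below. A typical invocation
+ # (with placeholder values) would look like:
+ #   python train.py --stage 2 --run_id 0 --snapshot 15000 --kimg 15000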
+ else: + train_stage = EasyDict(func_name='training.training_utils.training_stage2') + kwargs.update(resume_run_id=args.run_id,resume_snapshot=args.snapshot,resume_kimg=args.kimg) + kwargs.update(total_kimg=25000) + weight_args.update(w_lm=100) + + kwargs.update(train_stage_args=train_stage) + kwargs.update(weight_args = weight_args,noise_dim = args.noise_dim) + #------------------------------------------------------------------------ + + kwargs.update(G_args=G, D_args=D, G_opt_args=G_opt, D_opt_args=D_opt, G_loss_args=G_loss, D_loss_args=D_loss) + kwargs.update(dataset_args=dataset, sched_args=sched, grid_args=grid, metric_arg_list=metrics, tf_config=tf_config) + kwargs.submit_config = copy.deepcopy(submit_config) + kwargs.submit_config.run_dir_root = dnnlib.submission.submit.get_template_from_path(config.result_dir) + kwargs.submit_config.run_dir_ignore += config.run_dir_ignore + kwargs.submit_config.run_desc = desc + dnnlib.submit_run(**kwargs) + +#---------------------------------------------------------------------------- + +if __name__ == "__main__": + main() + +#---------------------------------------------------------------------------- diff --git a/training/__init__.py b/training/__init__.py new file mode 100644 index 0000000..db8124b --- /dev/null +++ b/training/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. + +# empty diff --git a/training/dataset.py b/training/dataset.py new file mode 100644 index 0000000..cf14222 --- /dev/null +++ b/training/dataset.py @@ -0,0 +1,241 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. + +"""Multi-resolution input data pipeline.""" + +import os +import glob +import numpy as np +import tensorflow as tf +import dnnlib +import dnnlib.tflib as tflib + +#---------------------------------------------------------------------------- +# Parse individual image from a tfrecords file. + +def parse_tfrecord_tf(record): + features = tf.parse_single_example(record, features={ + 'shape': tf.FixedLenFeature([3], tf.int64), + 'data': tf.FixedLenFeature([], tf.string)}) + data = tf.decode_raw(features['data'], tf.uint8) + return tf.reshape(data, features['shape']) + +def parse_tfrecord_np(record): + ex = tf.train.Example() + ex.ParseFromString(record) + shape = ex.features.feature['shape'].int64_list.value # temporary pylint workaround # pylint: disable=no-member + data = ex.features.feature['data'].bytes_list.value[0] # temporary pylint workaround # pylint: disable=no-member + return np.fromstring(data, np.uint8).reshape(shape) + +#---------------------------------------------------------------------------- +# Dataset class that loads data from tfrecords files. + +class TFRecordDataset: + def __init__(self, + tfrecord_dir, # Directory containing a collection of tfrecords files. + resolution = None, # Dataset resolution, None = autodetect. + label_file = None, # Relative path of the labels file, None = autodetect. 
+ max_label_size = 0, # 0 = no labels, 'full' = full labels, = N first label components. + repeat = True, # Repeat dataset indefinitely. + shuffle_mb = 4096, # Shuffle data within specified window (megabytes), 0 = disable shuffling. + prefetch_mb = 2048, # Amount of data to prefetch (megabytes), 0 = disable prefetching. + buffer_mb = 256, # Read buffer size (megabytes). + num_threads = 2): # Number of concurrent threads. + + self.tfrecord_dir = tfrecord_dir + self.resolution = None + self.resolution_log2 = None + self.shape = [] # [channel, height, width] + self.dtype = 'uint8' + self.dynamic_range = [0, 255] + self.label_file = label_file + self.label_size = None # [component] + self.label_dtype = None + self._np_labels = None + self._tf_minibatch_in = None + self._tf_labels_var = None + self._tf_labels_dataset = None + self._tf_datasets = dict() + self._tf_iterator = None + self._tf_init_ops = dict() + self._tf_minibatch_np = None + self._cur_minibatch = -1 + self._cur_lod = -1 + + # List tfrecords files and inspect their shapes. + assert os.path.isdir(self.tfrecord_dir) + tfr_files = sorted(glob.glob(os.path.join(self.tfrecord_dir, '*.tfrecords'))) + assert len(tfr_files) >= 1 + tfr_shapes = [] + for tfr_file in tfr_files: + tfr_opt = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.NONE) + for record in tf.python_io.tf_record_iterator(tfr_file, tfr_opt): + tfr_shapes.append(parse_tfrecord_np(record).shape) + break + + # Autodetect label filename. + if self.label_file is None: + guess = sorted(glob.glob(os.path.join(self.tfrecord_dir, '*.labels'))) + if len(guess): + self.label_file = guess[0] + elif not os.path.isfile(self.label_file): + guess = os.path.join(self.tfrecord_dir, self.label_file) + if os.path.isfile(guess): + self.label_file = guess + + # Determine shape and resolution. + max_shape = max(tfr_shapes, key=np.prod) + self.resolution = resolution if resolution is not None else max_shape[1] + self.resolution_log2 = int(np.log2(self.resolution)) + self.shape = [max_shape[0], self.resolution, self.resolution] + tfr_lods = [self.resolution_log2 - int(np.log2(shape[1])) for shape in tfr_shapes] + assert all(shape[0] == max_shape[0] for shape in tfr_shapes) + assert all(shape[1] == shape[2] for shape in tfr_shapes) + assert all(shape[1] == self.resolution // (2**lod) for shape, lod in zip(tfr_shapes, tfr_lods)) + assert all(lod in tfr_lods for lod in range(self.resolution_log2 - 1)) + + # Load labels. + assert max_label_size == 'full' or max_label_size >= 0 + self._np_labels = np.zeros([1<<20, 0], dtype=np.float32) + if self.label_file is not None and max_label_size != 0: + self._np_labels = np.load(self.label_file) + assert self._np_labels.ndim == 2 + if max_label_size != 'full' and self._np_labels.shape[1] > max_label_size: + self._np_labels = self._np_labels[:, :max_label_size] + self.label_size = self._np_labels.shape[1] + self.label_dtype = self._np_labels.dtype.name + + # Build TF expressions. 
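+ # Editor's note (descriptive comment, not part of the original file): one
+ # tf.data pipeline is built per recorded resolution; tfr_lod is the
+ # level-of-detail offset log2(full_res) - log2(record_res), so lod 0 is the
+ # full-resolution records. The shuffle and prefetch windows are converted from
+ # a megabyte budget into an item count, e.g. a 4096 MB shuffle buffer over
+ # 256x256x3 uint8 records holds roughly (4096 << 20) / 196608, i.e. about
+ # 21.8k examples.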
+ with tf.name_scope('Dataset'), tf.device('/cpu:0'): + self._tf_minibatch_in = tf.placeholder(tf.int64, name='minibatch_in', shape=[]) + self._tf_labels_var = tflib.create_var_with_large_initial_value(self._np_labels, name='labels_var') + self._tf_labels_dataset = tf.data.Dataset.from_tensor_slices(self._tf_labels_var) + for tfr_file, tfr_shape, tfr_lod in zip(tfr_files, tfr_shapes, tfr_lods): + if tfr_lod < 0: + continue + dset = tf.data.TFRecordDataset(tfr_file, compression_type='', buffer_size=buffer_mb<<20) + dset = dset.map(parse_tfrecord_tf, num_parallel_calls=num_threads) + dset = tf.data.Dataset.zip((dset, self._tf_labels_dataset)) + bytes_per_item = np.prod(tfr_shape) * np.dtype(self.dtype).itemsize + if shuffle_mb > 0: + dset = dset.shuffle(((shuffle_mb << 20) - 1) // bytes_per_item + 1) + if repeat: + dset = dset.repeat() + if prefetch_mb > 0: + dset = dset.prefetch(((prefetch_mb << 20) - 1) // bytes_per_item + 1) + dset = dset.batch(self._tf_minibatch_in) + self._tf_datasets[tfr_lod] = dset + self._tf_iterator = tf.data.Iterator.from_structure(self._tf_datasets[0].output_types, self._tf_datasets[0].output_shapes) + self._tf_init_ops = {lod: self._tf_iterator.make_initializer(dset) for lod, dset in self._tf_datasets.items()} + + # Use the given minibatch size and level-of-detail for the data returned by get_minibatch_tf(). + def configure(self, minibatch_size, lod=0): + lod = int(np.floor(lod)) + assert minibatch_size >= 1 and lod in self._tf_datasets + if self._cur_minibatch != minibatch_size or self._cur_lod != lod: + self._tf_init_ops[lod].run({self._tf_minibatch_in: minibatch_size}) + self._cur_minibatch = minibatch_size + self._cur_lod = lod + + # Get next minibatch as TensorFlow expressions. + def get_minibatch_tf(self): # => images, labels + return self._tf_iterator.get_next() + + # Get next minibatch as NumPy arrays. + def get_minibatch_np(self, minibatch_size, lod=0): # => images, labels + self.configure(minibatch_size, lod) + if self._tf_minibatch_np is None: + self._tf_minibatch_np = self.get_minibatch_tf() + return tflib.run(self._tf_minibatch_np) + + # Get random labels as TensorFlow expression. + def get_random_labels_tf(self, minibatch_size): # => labels + if self.label_size > 0: + with tf.device('/cpu:0'): + return tf.gather(self._tf_labels_var, tf.random_uniform([minibatch_size], 0, self._np_labels.shape[0], dtype=tf.int32)) + return tf.zeros([minibatch_size, 0], self.label_dtype) + + # Get random labels as NumPy array. + def get_random_labels_np(self, minibatch_size): # => labels + if self.label_size > 0: + return self._np_labels[np.random.randint(self._np_labels.shape[0], size=[minibatch_size])] + return np.zeros([minibatch_size, 0], self.label_dtype) + +#---------------------------------------------------------------------------- +# Base class for datasets that are generated on the fly. 
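+# Editor's note (descriptive comment, not part of the original file): SyntheticDataset
+# below mirrors the TFRecordDataset interface (configure, get_minibatch_tf, and the
+# label helpers) but emits zero-filled images and labels; subclasses are expected to
+# override _generate_images and _generate_labels to produce data on the fly.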
+ +class SyntheticDataset: + def __init__(self, resolution=1024, num_channels=3, dtype='uint8', dynamic_range=[0,255], label_size=0, label_dtype='float32'): + self.resolution = resolution + self.resolution_log2 = int(np.log2(resolution)) + self.shape = [num_channels, resolution, resolution] + self.dtype = dtype + self.dynamic_range = dynamic_range + self.label_size = label_size + self.label_dtype = label_dtype + self._tf_minibatch_var = None + self._tf_lod_var = None + self._tf_minibatch_np = None + self._tf_labels_np = None + + assert self.resolution == 2 ** self.resolution_log2 + with tf.name_scope('Dataset'): + self._tf_minibatch_var = tf.Variable(np.int32(0), name='minibatch_var') + self._tf_lod_var = tf.Variable(np.int32(0), name='lod_var') + + def configure(self, minibatch_size, lod=0): + lod = int(np.floor(lod)) + assert minibatch_size >= 1 and 0 <= lod <= self.resolution_log2 + tflib.set_vars({self._tf_minibatch_var: minibatch_size, self._tf_lod_var: lod}) + + def get_minibatch_tf(self): # => images, labels + with tf.name_scope('SyntheticDataset'): + shrink = tf.cast(2.0 ** tf.cast(self._tf_lod_var, tf.float32), tf.int32) + shape = [self.shape[0], self.shape[1] // shrink, self.shape[2] // shrink] + images = self._generate_images(self._tf_minibatch_var, self._tf_lod_var, shape) + labels = self._generate_labels(self._tf_minibatch_var) + return images, labels + + def get_minibatch_np(self, minibatch_size, lod=0): # => images, labels + self.configure(minibatch_size, lod) + if self._tf_minibatch_np is None: + self._tf_minibatch_np = self.get_minibatch_tf() + return tflib.run(self._tf_minibatch_np) + + def get_random_labels_tf(self, minibatch_size): # => labels + with tf.name_scope('SyntheticDataset'): + return self._generate_labels(minibatch_size) + + def get_random_labels_np(self, minibatch_size): # => labels + self.configure(minibatch_size) + if self._tf_labels_np is None: + self._tf_labels_np = self.get_random_labels_tf(minibatch_size) + return tflib.run(self._tf_labels_np) + + def _generate_images(self, minibatch, lod, shape): # to be overridden by subclasses # pylint: disable=unused-argument + return tf.zeros([minibatch] + shape, self.dtype) + + def _generate_labels(self, minibatch): # to be overridden by subclasses + return tf.zeros([minibatch, self.label_size], self.label_dtype) + +#---------------------------------------------------------------------------- +# Helper func for constructing a dataset object using the given options. + +def load_dataset(class_name='training.dataset.TFRecordDataset', data_dir=None, verbose=False, **kwargs): + adjusted_kwargs = dict(kwargs) + if 'tfrecord_dir' in adjusted_kwargs and data_dir is not None: + adjusted_kwargs['tfrecord_dir'] = os.path.join(data_dir, adjusted_kwargs['tfrecord_dir']) + if verbose: + print('Streaming data using %s...' % class_name) + dataset = dnnlib.util.get_obj_by_name(class_name)(**adjusted_kwargs) + if verbose: + print('Dataset shape =', np.int32(dataset.shape).tolist()) + print('Dynamic range =', dataset.dynamic_range) + print('Label size =', dataset.label_size) + return dataset + +#---------------------------------------------------------------------------- diff --git a/training/inception_resnet_v1.py b/training/inception_resnet_v1.py new file mode 100644 index 0000000..80ec889 --- /dev/null +++ b/training/inception_resnet_v1.py @@ -0,0 +1,247 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Contains the definition of the Inception Resnet V1 architecture. +As described in http://arxiv.org/abs/1602.07261. + Inception-v4, Inception-ResNet and the Impact of Residual Connections + on Learning + Christian Szegedy, Sergey Ioffe, Vincent Vanhoucke, Alex Alemi +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +import tensorflow.contrib.slim as slim + + +# Inception-Resnet-A +def block35(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None): + """Builds the 35x35 resnet block.""" + with tf.variable_scope(scope, 'Block35', [net], reuse=reuse): + with tf.variable_scope('Branch_0'): + tower_conv = slim.conv2d(net, 32, 1, scope='Conv2d_1x1') + with tf.variable_scope('Branch_1'): + tower_conv1_0 = slim.conv2d(net, 32, 1, scope='Conv2d_0a_1x1') + tower_conv1_1 = slim.conv2d(tower_conv1_0, 32, 3, scope='Conv2d_0b_3x3') + with tf.variable_scope('Branch_2'): + tower_conv2_0 = slim.conv2d(net, 32, 1, scope='Conv2d_0a_1x1') + tower_conv2_1 = slim.conv2d(tower_conv2_0, 32, 3, scope='Conv2d_0b_3x3') + tower_conv2_2 = slim.conv2d(tower_conv2_1, 32, 3, scope='Conv2d_0c_3x3') + mixed = tf.concat([tower_conv, tower_conv1_1, tower_conv2_2], 3) + up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None, + activation_fn=None, scope='Conv2d_1x1') + net += scale * up + if activation_fn: + net = activation_fn(net) + return net + +# Inception-Resnet-B +def block17(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None): + """Builds the 17x17 resnet block.""" + with tf.variable_scope(scope, 'Block17', [net], reuse=reuse): + with tf.variable_scope('Branch_0'): + tower_conv = slim.conv2d(net, 128, 1, scope='Conv2d_1x1') + with tf.variable_scope('Branch_1'): + tower_conv1_0 = slim.conv2d(net, 128, 1, scope='Conv2d_0a_1x1') + tower_conv1_1 = slim.conv2d(tower_conv1_0, 128, [1, 7], + scope='Conv2d_0b_1x7') + tower_conv1_2 = slim.conv2d(tower_conv1_1, 128, [7, 1], + scope='Conv2d_0c_7x1') + mixed = tf.concat([tower_conv, tower_conv1_2], 3) + up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None, + activation_fn=None, scope='Conv2d_1x1') + net += scale * up + if activation_fn: + net = activation_fn(net) + return net + + +# Inception-Resnet-C +def block8(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None): + """Builds the 8x8 resnet block.""" + with tf.variable_scope(scope, 'Block8', [net], reuse=reuse): + with tf.variable_scope('Branch_0'): + tower_conv = slim.conv2d(net, 192, 1, scope='Conv2d_1x1') + with tf.variable_scope('Branch_1'): + tower_conv1_0 = slim.conv2d(net, 192, 1, scope='Conv2d_0a_1x1') + tower_conv1_1 = slim.conv2d(tower_conv1_0, 192, [1, 3], + scope='Conv2d_0b_1x3') + tower_conv1_2 = slim.conv2d(tower_conv1_1, 192, [3, 1], + scope='Conv2d_0c_3x1') + mixed = tf.concat([tower_conv, tower_conv1_2], 3) + up = 
slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None, + activation_fn=None, scope='Conv2d_1x1') + net += scale * up + if activation_fn: + net = activation_fn(net) + return net + +def reduction_a(net, k, l, m, n): + with tf.variable_scope('Branch_0'): + tower_conv = slim.conv2d(net, n, 3, stride=2, padding='VALID', + scope='Conv2d_1a_3x3') + with tf.variable_scope('Branch_1'): + tower_conv1_0 = slim.conv2d(net, k, 1, scope='Conv2d_0a_1x1') + tower_conv1_1 = slim.conv2d(tower_conv1_0, l, 3, + scope='Conv2d_0b_3x3') + tower_conv1_2 = slim.conv2d(tower_conv1_1, m, 3, + stride=2, padding='VALID', + scope='Conv2d_1a_3x3') + with tf.variable_scope('Branch_2'): + tower_pool = slim.max_pool2d(net, 3, stride=2, padding='VALID', + scope='MaxPool_1a_3x3') + net = tf.concat([tower_conv, tower_conv1_2, tower_pool], 3) + return net + +def reduction_b(net): + with tf.variable_scope('Branch_0'): + tower_conv = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1') + tower_conv_1 = slim.conv2d(tower_conv, 384, 3, stride=2, + padding='VALID', scope='Conv2d_1a_3x3') + with tf.variable_scope('Branch_1'): + tower_conv1 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1') + tower_conv1_1 = slim.conv2d(tower_conv1, 256, 3, stride=2, + padding='VALID', scope='Conv2d_1a_3x3') + with tf.variable_scope('Branch_2'): + tower_conv2 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1') + tower_conv2_1 = slim.conv2d(tower_conv2, 256, 3, + scope='Conv2d_0b_3x3') + tower_conv2_2 = slim.conv2d(tower_conv2_1, 256, 3, stride=2, + padding='VALID', scope='Conv2d_1a_3x3') + with tf.variable_scope('Branch_3'): + tower_pool = slim.max_pool2d(net, 3, stride=2, padding='VALID', + scope='MaxPool_1a_3x3') + net = tf.concat([tower_conv_1, tower_conv1_1, + tower_conv2_2, tower_pool], 3) + return net + +def inference(images, keep_probability, phase_train=True, + bottleneck_layer_size=128, weight_decay=0.0, reuse=None): + batch_norm_params = { + # Decay for the moving averages. + 'decay': 0.995, + # epsilon to prevent 0s in variance. + 'epsilon': 0.001, + # force in-place updates of mean and variance estimates + 'updates_collections': None, + # Moving averages ends up in the trainable variables collection + 'variables_collections': [ tf.GraphKeys.TRAINABLE_VARIABLES ], + } + + with slim.arg_scope([slim.conv2d, slim.fully_connected], + weights_initializer=slim.initializers.xavier_initializer(), + weights_regularizer=slim.l2_regularizer(weight_decay), + normalizer_fn=slim.batch_norm, + normalizer_params=batch_norm_params): + return inception_resnet_v1(images, is_training=phase_train, + dropout_keep_prob=keep_probability, bottleneck_layer_size=bottleneck_layer_size, reuse=reuse) + + +def inception_resnet_v1(inputs, is_training=True, + dropout_keep_prob=0.8, + bottleneck_layer_size=128, + reuse=None, + scope='InceptionResnetV1'): + """Creates the Inception Resnet V1 model. + Args: + inputs: a 4-D tensor of size [batch_size, height, width, 3]. + num_classes: number of predicted classes. + is_training: whether is training or not. + dropout_keep_prob: float, the fraction to keep before final layer. + reuse: whether or not the network and its variables should be reused. To be + able to reuse 'scope' must be given. + scope: Optional variable_scope. + Returns: + logits: the logits outputs of the model. + end_points: the set of end_points from the inception model. 
+ """ + end_points = {} + + with tf.variable_scope(scope, 'InceptionResnetV1', [inputs], reuse=reuse): + with slim.arg_scope([slim.batch_norm, slim.dropout], + is_training=is_training): + with slim.arg_scope([slim.conv2d, slim.max_pool2d, slim.avg_pool2d], + stride=1, padding='SAME'): + + # 149 x 149 x 32 + net = slim.conv2d(inputs, 32, 3, stride=2, padding='VALID', + scope='Conv2d_1a_3x3') + end_points['Conv2d_1a_3x3'] = net + # 147 x 147 x 32 + net = slim.conv2d(net, 32, 3, padding='VALID', + scope='Conv2d_2a_3x3') + end_points['Conv2d_2a_3x3'] = net + # 147 x 147 x 64 + net = slim.conv2d(net, 64, 3, scope='Conv2d_2b_3x3') + end_points['Conv2d_2b_3x3'] = net + # 73 x 73 x 64 + net = slim.max_pool2d(net, 3, stride=2, padding='VALID', + scope='MaxPool_3a_3x3') + end_points['MaxPool_3a_3x3'] = net + # 73 x 73 x 80 + net = slim.conv2d(net, 80, 1, padding='VALID', + scope='Conv2d_3b_1x1') + end_points['Conv2d_3b_1x1'] = net + # 71 x 71 x 192 + net = slim.conv2d(net, 192, 3, padding='VALID', + scope='Conv2d_4a_3x3') + end_points['Conv2d_4a_3x3'] = net + # 35 x 35 x 256 + net = slim.conv2d(net, 256, 3, stride=2, padding='VALID', + scope='Conv2d_4b_3x3') + end_points['Conv2d_4b_3x3'] = net + + # 5 x Inception-resnet-A + net = slim.repeat(net, 5, block35, scale=0.17) + end_points['Mixed_5a'] = net + + # Reduction-A + with tf.variable_scope('Mixed_6a'): + net = reduction_a(net, 192, 192, 256, 384) + end_points['Mixed_6a'] = net + + # 10 x Inception-Resnet-B + net = slim.repeat(net, 10, block17, scale=0.10) + end_points['Mixed_6b'] = net + + # Reduction-B + with tf.variable_scope('Mixed_7a'): + net = reduction_b(net) + end_points['Mixed_7a'] = net + + # 5 x Inception-Resnet-C + net = slim.repeat(net, 5, block8, scale=0.20) + end_points['Mixed_8a'] = net + + net = block8(net, activation_fn=None) + end_points['Mixed_8b'] = net + + with tf.variable_scope('Logits'): + end_points['PrePool'] = net + #pylint: disable=no-member + net = slim.avg_pool2d(net, net.get_shape()[1:3], padding='VALID', + scope='AvgPool_1a_8x8') + net = slim.flatten(net) + + net = slim.dropout(net, dropout_keep_prob, is_training=is_training, + scope='Dropout') + + end_points['PreLogitsFlatten'] = net + + net = slim.fully_connected(net, bottleneck_layer_size, activation_fn=None, + scope='Bottleneck', reuse=False) + + return net, end_points diff --git a/training/loss.py b/training/loss.py new file mode 100644 index 0000000..478565b --- /dev/null +++ b/training/loss.py @@ -0,0 +1,180 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. + +"""Loss functions.""" + +import tensorflow as tf +import dnnlib.tflib as tflib +from dnnlib.tflib.autosummary import autosummary + +#---------------------------------------------------------------------------- +# Convenience func that casts all of its arguments to tf.float32. + +def fp32(*values): + if len(values) == 1 and isinstance(values[0], tuple): + values = values[0] + values = tuple(tf.cast(v, tf.float32) for v in values) + return values if len(values) >= 2 else values[0] + +#---------------------------------------------------------------------------- +# WGAN & WGAN-GP loss functions. 
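+# The losses below follow the StyleGAN convention: they return un-reduced, per-sample
+# loss tensors, and the D_* variants additionally receive the real images and labels.
+# The training loop is assumed to reduce the result (e.g. with tf.reduce_mean) before
+# registering gradients with the optimizer.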
+ +def G_wgan(G, D, opt, training_set, minibatch_size): # pylint: disable=unused-argument + latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:]) + labels = training_set.get_random_labels_tf(minibatch_size) + fake_images_out = G.get_output_for(latents, labels, is_training=True) + fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True)) + loss = -fake_scores_out + return loss + +def D_wgan(G, D, opt, training_set, minibatch_size, reals, labels, # pylint: disable=unused-argument + wgan_epsilon = 0.001): # Weight for the epsilon term, \epsilon_{drift}. + + latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:]) + fake_images_out = G.get_output_for(latents, labels, is_training=True) + real_scores_out = fp32(D.get_output_for(reals, labels, is_training=True)) + fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True)) + real_scores_out = autosummary('Loss/scores/real', real_scores_out) + fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out) + loss = fake_scores_out - real_scores_out + + with tf.name_scope('EpsilonPenalty'): + epsilon_penalty = autosummary('Loss/epsilon_penalty', tf.square(real_scores_out)) + loss += epsilon_penalty * wgan_epsilon + return loss + +def D_wgan_gp(G, D, opt, training_set, minibatch_size, reals, labels, # pylint: disable=unused-argument + wgan_lambda = 10.0, # Weight for the gradient penalty term. + wgan_epsilon = 0.001, # Weight for the epsilon term, \epsilon_{drift}. + wgan_target = 1.0): # Target value for gradient magnitudes. + + latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:]) + fake_images_out = G.get_output_for(latents, labels, is_training=True) + real_scores_out = fp32(D.get_output_for(reals, labels, is_training=True)) + fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True)) + real_scores_out = autosummary('Loss/scores/real', real_scores_out) + fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out) + loss = fake_scores_out - real_scores_out + + with tf.name_scope('GradientPenalty'): + mixing_factors = tf.random_uniform([minibatch_size, 1, 1, 1], 0.0, 1.0, dtype=fake_images_out.dtype) + mixed_images_out = tflib.lerp(tf.cast(reals, fake_images_out.dtype), fake_images_out, mixing_factors) + mixed_scores_out = fp32(D.get_output_for(mixed_images_out, labels, is_training=True)) + mixed_scores_out = autosummary('Loss/scores/mixed', mixed_scores_out) + mixed_loss = opt.apply_loss_scaling(tf.reduce_sum(mixed_scores_out)) + mixed_grads = opt.undo_loss_scaling(fp32(tf.gradients(mixed_loss, [mixed_images_out])[0])) + mixed_norms = tf.sqrt(tf.reduce_sum(tf.square(mixed_grads), axis=[1,2,3])) + mixed_norms = autosummary('Loss/mixed_norms', mixed_norms) + gradient_penalty = tf.square(mixed_norms - wgan_target) + loss += gradient_penalty * (wgan_lambda / (wgan_target**2)) + + with tf.name_scope('EpsilonPenalty'): + epsilon_penalty = autosummary('Loss/epsilon_penalty', tf.square(real_scores_out)) + loss += epsilon_penalty * wgan_epsilon + return loss + +#---------------------------------------------------------------------------- +# Hinge loss functions. 
(Use G_wgan with these) + +def D_hinge(G, D, opt, training_set, minibatch_size, reals, labels): # pylint: disable=unused-argument + latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:]) + fake_images_out = G.get_output_for(latents, labels, is_training=True) + real_scores_out = fp32(D.get_output_for(reals, labels, is_training=True)) + fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True)) + real_scores_out = autosummary('Loss/scores/real', real_scores_out) + fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out) + loss = tf.maximum(0., 1.+fake_scores_out) + tf.maximum(0., 1.-real_scores_out) + return loss + +def D_hinge_gp(G, D, opt, training_set, minibatch_size, reals, labels, # pylint: disable=unused-argument + wgan_lambda = 10.0, # Weight for the gradient penalty term. + wgan_target = 1.0): # Target value for gradient magnitudes. + + latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:]) + fake_images_out = G.get_output_for(latents, labels, is_training=True) + real_scores_out = fp32(D.get_output_for(reals, labels, is_training=True)) + fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True)) + real_scores_out = autosummary('Loss/scores/real', real_scores_out) + fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out) + loss = tf.maximum(0., 1.+fake_scores_out) + tf.maximum(0., 1.-real_scores_out) + + with tf.name_scope('GradientPenalty'): + mixing_factors = tf.random_uniform([minibatch_size, 1, 1, 1], 0.0, 1.0, dtype=fake_images_out.dtype) + mixed_images_out = tflib.lerp(tf.cast(reals, fake_images_out.dtype), fake_images_out, mixing_factors) + mixed_scores_out = fp32(D.get_output_for(mixed_images_out, labels, is_training=True)) + mixed_scores_out = autosummary('Loss/scores/mixed', mixed_scores_out) + mixed_loss = opt.apply_loss_scaling(tf.reduce_sum(mixed_scores_out)) + mixed_grads = opt.undo_loss_scaling(fp32(tf.gradients(mixed_loss, [mixed_images_out])[0])) + mixed_norms = tf.sqrt(tf.reduce_sum(tf.square(mixed_grads), axis=[1,2,3])) + mixed_norms = autosummary('Loss/mixed_norms', mixed_norms) + gradient_penalty = tf.square(mixed_norms - wgan_target) + loss += gradient_penalty * (wgan_lambda / (wgan_target**2)) + return loss + + +#---------------------------------------------------------------------------- +# Loss functions advocated by the paper +# "Which Training Methods for GANs do actually Converge?" + +def G_logistic_saturating(G, D, opt, training_set, minibatch_size): # pylint: disable=unused-argument + latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:]) + labels = training_set.get_random_labels_tf(minibatch_size) + fake_images_out = G.get_output_for(latents, labels, is_training=True) + fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True)) + loss = -tf.nn.softplus(fake_scores_out) # log(1 - logistic(fake_scores_out)) + return loss + +#--------------------------------------------------------------- +# Modified by Deng et al. 
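+# Compared with the original StyleGAN loss, the latents are passed in explicitly rather
+# than sampled inside the function, and the generated images are returned alongside the
+# loss, presumably so that the same fake images can be reused by the imitative and
+# contrastive terms defined in loss_control.py.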
+def G_logistic_nonsaturating(G, D, latents, opt, training_set, minibatch_size, randomize_noise = True): # pylint: disable=unused-argument + labels = training_set.get_random_labels_tf(minibatch_size) + fake_images_out = G.get_output_for(latents, labels, is_training=True, randomize_noise=randomize_noise) + fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True)) + loss = tf.nn.softplus(-fake_scores_out) # -log(logistic(fake_scores_out)) + return loss,fake_images_out +#--------------------------------------------------------------- + +def D_logistic(G, D, opt, training_set, minibatch_size, reals, labels): # pylint: disable=unused-argument + latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:]) + fake_images_out = G.get_output_for(latents, labels, is_training=True) + real_scores_out = fp32(D.get_output_for(reals, labels, is_training=True)) + fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True)) + real_scores_out = autosummary('Loss/scores/real', real_scores_out) + fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out) + loss = tf.nn.softplus(fake_scores_out) # -log(1 - logistic(fake_scores_out)) + loss += tf.nn.softplus(-real_scores_out) # -log(logistic(real_scores_out)) # temporary pylint workaround # pylint: disable=invalid-unary-operand-type + return loss + +#--------------------------------------------------------------- +# Modified by Deng et al. +def D_logistic_simplegp(G, D,latents, opt, training_set, minibatch_size, reals, labels, r1_gamma=10.0, r2_gamma=0.0,randomize_noise = True): # pylint: disable=unused-argument + fake_images_out = G.get_output_for(latents, labels, is_training=True,randomize_noise=randomize_noise) + real_scores_out = fp32(D.get_output_for(reals, labels, is_training=True)) + fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True)) + real_scores_out = autosummary('Loss/scores/real', real_scores_out) + fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out) + loss = tf.nn.softplus(fake_scores_out) # -log(1 - logistic(fake_scores_out)) + loss += tf.nn.softplus(-real_scores_out) # -log(logistic(real_scores_out)) # temporary pylint workaround # pylint: disable=invalid-unary-operand-type + + + if r1_gamma != 0.0: + with tf.name_scope('R1Penalty'): + real_loss = opt.apply_loss_scaling(tf.reduce_sum(real_scores_out)) + real_grads = opt.undo_loss_scaling(fp32(tf.gradients(real_loss, [reals])[0])) + r1_penalty = tf.reduce_sum(tf.square(real_grads), axis=[1,2,3]) + r1_penalty = autosummary('Loss/r1_penalty', r1_penalty) + loss += r1_penalty * (r1_gamma * 0.5) + + if r2_gamma != 0.0: + with tf.name_scope('R2Penalty'): + fake_loss = opt.apply_loss_scaling(tf.reduce_sum(fake_scores_out)) + fake_grads = opt.undo_loss_scaling(fp32(tf.gradients(fake_loss, [fake_images_out])[0])) + r2_penalty = tf.reduce_sum(tf.square(fake_grads), axis=[1,2,3]) + r2_penalty = autosummary('Loss/r2_penalty', r2_penalty) + loss += r2_penalty * (r2_gamma * 0.5) + return loss +#--------------------------------------------------------------- \ No newline at end of file diff --git a/training/loss_control.py b/training/loss_control.py new file mode 100644 index 0000000..5eabe10 --- /dev/null +++ b/training/loss_control.py @@ -0,0 +1,202 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +# Losses for imitative and contrastive learning +import tensorflow as tf +from dnnlib.tflib.autosummary import autosummary +from training.networks_recon import R_Net +from training.networks_id import Perceptual_Net +from training.networks_parser import Parsing +import numpy as np + +#--------------------------------------------------------------------------- + +def gaussian_kernel(size=5,sigma=2): + x_points = np.arange(-(size-1)//2,(size-1)//2+1,1) + y_points = x_points[::-1] + xs,ys = np.meshgrid(x_points,y_points) + kernel = np.exp(-(xs**2+ys**2)/(2*sigma**2))/(2*np.pi*sigma**2) + kernel = kernel/kernel.sum() + kernel = tf.constant(kernel,dtype=tf.float32) + + return kernel + +def gaussian_blur(image,size=5,sigma=2): + kernel = gaussian_kernel(size=size,sigma=sigma) + kernel = tf.tile(tf.reshape(kernel,[tf.shape(kernel)[0],tf.shape(kernel)[1],1,1]),[1,1,3,1]) + blur_image = tf.nn.depthwise_conv2d(image,kernel,strides=[1,1,1,1],padding='SAME',data_format='NHWC') + + return blur_image + +#---------------------------------------------------------------------------- +# Imitative losses + +# L1 loss between rendered image and fake image +def L1_loss(render_img,fake_images,render_mask): + l1_loss = tf.reduce_sum(tf.sqrt(tf.reduce_sum((render_img - fake_images)**2, axis = 1) + 1e-8 )*render_mask)/tf.reduce_sum(render_mask) + l1_loss = autosummary('Loss/l1_loss', l1_loss) + return l1_loss + +# landmark loss and lighting loss between rendered image and fake image +def Reconstruction_loss(fake_image,landmark_label,coeff_label,FaceRender): + landmark_label = landmark_label*224./256. + + fake_image = (fake_image+1)*127.5 + fake_image = tf.clip_by_value(fake_image,0,255) + fake_image = tf.transpose(fake_image,perm=[0,2,3,1]) + fake_image = tf.reverse(fake_image,[3]) #RGBtoBGR + fake_image = tf.image.resize_images(fake_image,size=[224, 224], method=tf.image.ResizeMethod.BILINEAR) + + # input to R_Net should have a shape of [batchsize,224,224,3], color range from 0-255 in BGR order. + with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE): + coeff = R_Net(fake_image,is_training=False, reuse=tf.AUTO_REUSE) + landmark_p = FaceRender.Get_landmark(coeff) #224*224 + + landmark_weight = tf.ones([1,68]) + landmark_weight = tf.reshape(landmark_weight,[1,68,1]) + lm_loss = tf.reduce_mean(tf.square((landmark_p-landmark_label)/224)*landmark_weight) + + + fake_gamma = coeff[:,227:254] + render_gamma = coeff_label[:,227:254] + + gamma_loss = tf.reduce_mean(tf.abs(fake_gamma - render_gamma)) + + lm_loss = autosummary('Loss/lm_loss', lm_loss) + gamma_loss = autosummary('Loss/gamma_loss', gamma_loss) + + + return lm_loss,gamma_loss + +# identity similarity loss between rendered image and fake image +def ID_loss(render_image,fake_image,render_mask): + + render_image = (render_image+1)*127.5 + render_image = tf.clip_by_value(render_image,0,255) + render_image = tf.transpose(render_image,perm=[0,2,3,1]) + render_image = tf.image.resize_images(render_image,size=[160,160], method=tf.image.ResizeMethod.BILINEAR) + fake_image = (fake_image+1)*127.5 + fake_image = tf.clip_by_value(fake_image,0,255) + fake_image = tf.transpose(fake_image,perm=[0,2,3,1]) + fake_image = fake_image*tf.expand_dims(render_mask,3) + fake_image = tf.image.resize_images(fake_image,size=[160,160], method=tf.image.ResizeMethod.BILINEAR) + + render_image = tf.reshape(render_image,[-1,160,160,3]) + + # input to face recognition network should have a shape of [batchsize,160,160,3], color range from 0-255 in RGB order. 
+ id_fake = Perceptual_Net(fake_image) + id_render = Perceptual_Net(render_image) + + id_fake = tf.nn.l2_normalize(id_fake, dim = 1) + id_render = tf.nn.l2_normalize(id_render, dim = 1) + # cosine similarity + sim = tf.reduce_sum(id_fake*id_render,1) + loss = tf.reduce_mean(tf.maximum(0.3,1.0 - sim)) # need clip! IMPORTANT + + loss = autosummary('Loss/id_loss', loss) + + return loss + +# average skin color loss between rendered image and fake image +def Skin_color_loss(fake,render,mask): + mask = tf.expand_dims(mask,1) + mean_fake = tf.reduce_sum(fake*mask,[2,3])/tf.reduce_sum(mask,[2,3]) + mean_render = tf.reduce_sum(render*mask,[2,3])/tf.reduce_sum(mask,[2,3]) + + loss = tf.reduce_mean(tf.sqrt(tf.reduce_sum((mean_fake - mean_render)**2, axis = 1) + 1e-8 )) + loss = autosummary('Loss/skin_loss', loss) + + return loss + +#---------------------------------------------------------------------------- +# Contrastive losses + +# loss for expression change +def Exp_warp_loss(fake1,fake2,r_shape1,r_shape2,mask1,mask2,FaceRender): + + pos1_2d = FaceRender.Projection_block(r_shape1) + pos2_2d = FaceRender.Projection_block(r_shape2) + pos_diff = pos1_2d - pos2_2d + pos_diff = tf.stack([-pos_diff[:,:,1],pos_diff[:,:,0]],axis = 2) + pos_diff = tf.concat([pos_diff,tf.zeros([tf.shape(pos_diff)[0],tf.shape(pos_diff)[1],1])], axis = 2) + flow_1to2,_ = FaceRender.Render_block(r_shape2,tf.zeros_like(r_shape2),pos_diff,FaceRender.facemodel,256,1) + flow_1to2 = flow_1to2[:,:,:,:2] + fake_1to2 = tf.contrib.image.dense_image_warp(fake1,-flow_1to2) # IMPORTANT! + loss_mask = tf.cast((mask1 - mask2) <= 0, tf.float32) + fake2 = gaussian_blur(fake2,size=5,sigma=2) + fake_1to2 = gaussian_blur(fake_1to2,size=5,sigma=2) + + loss = tf.reduce_sum(tf.sqrt(tf.reduce_sum((fake2 - fake_1to2)**2,axis = 3) + 1e-8)*loss_mask)/tf.reduce_sum(loss_mask) + loss = autosummary('Loss/Exp_warp_loss', loss) + + return loss + +# loss for lighting change +def Gamma_change_loss(fake1,fake2,FaceRender): + fake_image = tf.concat([fake1,fake2],axis = 0) + fake_image = (fake_image+1)*127.5 + fake_image = tf.clip_by_value(fake_image,0,255) + fake_image = tf.reverse(fake_image,[3]) #RGBtoBGR + fake_image = tf.image.resize_images(fake_image,size=[224, 224], method=tf.image.ResizeMethod.BILINEAR) + with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE): + coeff = R_Net(fake_image,is_training=False, reuse=tf.AUTO_REUSE) + landmark_p = FaceRender.Get_landmark(coeff) + landmark_p = landmark_p*256/224. + + lm1 = tf.expand_dims(landmark_p[0],0) + lm2 = tf.expand_dims(landmark_p[1],0) + hair_region_loss = Hair_region_loss(fake1,fake2,lm1,lm2) + id_consistent_loss = ID_consistent_loss(fake1,fake2) + lm_consistent_loss = Lm_consistent_loss(lm1,lm2) + + loss = hair_region_loss + 2*id_consistent_loss + 1000*lm_consistent_loss + return loss + +# hair region consistency between fake image pair +def Hair_region_loss(fake1,fake2,lm1,lm2): + fake1 = (fake1+1)*127.5 + fake1 = tf.clip_by_value(fake1,0,255) + fake2 = (fake2+1)*127.5 + fake2 = tf.clip_by_value(fake2,0,255) + + # input to face parser should have a shape of [batchsize,256,256,3], color range from 0-255 in RGB order. 
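+    # Channel 2 of the parser output is taken as the hair probability map below.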
+ seg_mask1 = Parsing(fake1,lm1) + seg_mask2 = Parsing(fake2,lm2) + + hair_mask1 = seg_mask1[:,:,:,2] + hair_mask2 = seg_mask2[:,:,:,2] + + loss = tf.reduce_mean((hair_mask1-hair_mask2)**2) + loss = autosummary('Loss/Hair_region_loss', loss) + + return loss + + +# identity consistency between fake image pair +def ID_consistent_loss(fake1,fake2): + fake1 = (fake1+1)*127.5 + fake1 = tf.clip_by_value(fake1,0,255) + fake1 = tf.image.resize_images(fake1,size=[160,160], method=tf.image.ResizeMethod.BILINEAR) + + fake2 = (fake2+1)*127.5 + fake2 = tf.clip_by_value(fake2,0,255) + fake2 = tf.image.resize_images(fake2,size=[160,160], method=tf.image.ResizeMethod.BILINEAR) + + id_fake1 = Perceptual_Net(fake1) + id_fake2 = Perceptual_Net(fake2) + + id_fake1 = tf.nn.l2_normalize(id_fake1, dim = 1) + id_fake2 = tf.nn.l2_normalize(id_fake2, dim = 1) + # cosine similarity + sim = tf.reduce_sum(id_fake1*id_fake2,1) + loss = tf.reduce_mean(1.0 - sim) + loss = autosummary('Loss/ID_consistent_loss', loss) + + return loss + +# landmark consistency between fake image pair +def Lm_consistent_loss(landmark_p1,landmark_p2): + loss = tf.reduce_mean(tf.square((landmark_p1-landmark_p2)/224)) + loss = autosummary('Loss/Lm_consistent_loss', loss) + + return loss \ No newline at end of file diff --git a/training/misc.py b/training/misc.py new file mode 100644 index 0000000..9c5ab74 --- /dev/null +++ b/training/misc.py @@ -0,0 +1,245 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. + +"""Miscellaneous utility functions.""" + +import os +import glob +import pickle +import re +import numpy as np +from collections import defaultdict +import PIL.Image +import dnnlib + +import config +from training import dataset + +#---------------------------------------------------------------------------- +# Convenience wrappers for pickle that are able to load data produced by +# older versions of the code, and from external URLs. + +def open_file_or_url(file_or_url): + if dnnlib.util.is_url(file_or_url): + return dnnlib.util.open_url(file_or_url, cache_dir=config.cache_dir) + return open(file_or_url, 'rb') + +def load_pkl(file_or_url): + with open_file_or_url(file_or_url) as file: + return pickle.load(file, encoding='latin1') + +def save_pkl(obj, filename): + with open(filename, 'wb') as file: + pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL) + +#---------------------------------------------------------------------------- +# Image utils. 
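+# For example, adjust_dynamic_range(img, drange_in=[0, 255], drange_out=[-1, 1]) linearly
+# remaps uint8 dataset images into the [-1, 1] range typically used by the networks.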
+ +def adjust_dynamic_range(data, drange_in, drange_out): + if drange_in != drange_out: + scale = (np.float32(drange_out[1]) - np.float32(drange_out[0])) / (np.float32(drange_in[1]) - np.float32(drange_in[0])) + bias = (np.float32(drange_out[0]) - np.float32(drange_in[0]) * scale) + data = data * scale + bias + return data + +def create_image_grid(images, grid_size=None): + assert images.ndim == 3 or images.ndim == 4 + num, img_w, img_h = images.shape[0], images.shape[-1], images.shape[-2] + + if grid_size is not None: + grid_w, grid_h = tuple(grid_size) + else: + grid_w = max(int(np.ceil(np.sqrt(num))), 1) + grid_h = max((num - 1) // grid_w + 1, 1) + + grid = np.zeros(list(images.shape[1:-2]) + [grid_h * img_h, grid_w * img_w], dtype=images.dtype) + for idx in range(num): + x = (idx % grid_w) * img_w + y = (idx // grid_w) * img_h + grid[..., y : y + img_h, x : x + img_w] = images[idx] + return grid + +def convert_to_pil_image(image, drange=[0,1]): + assert image.ndim == 2 or image.ndim == 3 + if image.ndim == 3: + if image.shape[0] == 1: + image = image[0] # grayscale CHW => HW + else: + image = image.transpose(1, 2, 0) # CHW -> HWC + + image = adjust_dynamic_range(image, drange, [0,255]) + image = np.rint(image).clip(0, 255).astype(np.uint8) + fmt = 'RGB' if image.ndim == 3 else 'L' + return PIL.Image.fromarray(image, fmt) + +def save_image(image, filename, drange=[0,1], quality=95): + img = convert_to_pil_image(image, drange) + if '.jpg' in filename: + img.save(filename,"JPEG", quality=quality, optimize=True) + else: + img.save(filename) + +def save_image_grid(images, filename, drange=[0,1], grid_size=None): + convert_to_pil_image(create_image_grid(images, grid_size), drange).save(filename) + +#---------------------------------------------------------------------------- +# Locating results. 
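+# These helpers resolve a results directory or network snapshot from either an explicit
+# path or a numeric run id; for example, locate_network_pkl(25) would return the last
+# 'network-*.pkl' snapshot (or 'network-final.pkl', if present) of run 25 under
+# config.result_dir.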
+ +def locate_run_dir(run_id_or_run_dir): + if isinstance(run_id_or_run_dir, str): + if os.path.isdir(run_id_or_run_dir): + return run_id_or_run_dir + converted = dnnlib.submission.submit.convert_path(run_id_or_run_dir) + if os.path.isdir(converted): + return converted + + run_dir_pattern = re.compile('^0*%s-' % str(run_id_or_run_dir)) + for search_dir in ['']: + full_search_dir = config.result_dir if search_dir == '' else os.path.normpath(os.path.join(config.result_dir, search_dir)) + run_dir = os.path.join(full_search_dir, str(run_id_or_run_dir)) + if os.path.isdir(run_dir): + return run_dir + run_dirs = sorted(glob.glob(os.path.join(full_search_dir, '*'))) + run_dirs = [run_dir for run_dir in run_dirs if run_dir_pattern.match(os.path.basename(run_dir))] + run_dirs = [run_dir for run_dir in run_dirs if os.path.isdir(run_dir)] + if len(run_dirs) == 1: + return run_dirs[0] + raise IOError('Cannot locate result subdir for run', run_id_or_run_dir) + +def list_network_pkls(run_id_or_run_dir, include_final=True): + run_dir = locate_run_dir(run_id_or_run_dir) + pkls = sorted(glob.glob(os.path.join(run_dir, 'network-*.pkl'))) + if len(pkls) >= 1 and os.path.basename(pkls[0]) == 'network-final.pkl': + if include_final: + pkls.append(pkls[0]) + del pkls[0] + return pkls + +def locate_network_pkl(run_id_or_run_dir_or_network_pkl, snapshot_or_network_pkl=None): + for candidate in [snapshot_or_network_pkl, run_id_or_run_dir_or_network_pkl]: + if isinstance(candidate, str): + if os.path.isfile(candidate): + return candidate + converted = dnnlib.submission.submit.convert_path(candidate) + if os.path.isfile(converted): + return converted + + pkls = list_network_pkls(run_id_or_run_dir_or_network_pkl) + if len(pkls) >= 1 and snapshot_or_network_pkl is None: + return pkls[-1] + + for pkl in pkls: + try: + name = os.path.splitext(os.path.basename(pkl))[0] + number = int(name.split('-')[-1]) + if number == snapshot_or_network_pkl: + return pkl + except ValueError: pass + except IndexError: pass + raise IOError('Cannot locate network pkl for snapshot', snapshot_or_network_pkl) + +def get_id_string_for_network_pkl(network_pkl): + p = network_pkl.replace('.pkl', '').replace('\\', '/').split('/') + return '-'.join(p[max(len(p) - 2, 0):]) + +#---------------------------------------------------------------------------- +# Loading data from previous training runs. + +def load_network_pkl(run_id_or_run_dir_or_network_pkl, snapshot_or_network_pkl=None): + return load_pkl(locate_network_pkl(run_id_or_run_dir_or_network_pkl, snapshot_or_network_pkl)) + +def parse_config_for_previous_run(run_id): + run_dir = locate_run_dir(run_id) + + # Parse config.txt. + cfg = defaultdict(dict) + with open(os.path.join(run_dir, 'config.txt'), 'rt') as f: + for line in f: + line = re.sub(r"^{?\s*'(\w+)':\s*{(.*)(},|}})$", r"\1 = {\2}", line.strip()) + if line.startswith('dataset =') or line.startswith('train ='): + exec(line, cfg, cfg) # pylint: disable=exec-used + + # Handle legacy options. 
+ if 'file_pattern' in cfg['dataset']: + cfg['dataset']['tfrecord_dir'] = cfg['dataset'].pop('file_pattern').replace('-r??.tfrecords', '') + if 'mirror_augment' in cfg['dataset']: + cfg['train']['mirror_augment'] = cfg['dataset'].pop('mirror_augment') + if 'max_labels' in cfg['dataset']: + v = cfg['dataset'].pop('max_labels') + if v is None: v = 0 + if v == 'all': v = 'full' + cfg['dataset']['max_label_size'] = v + if 'max_images' in cfg['dataset']: + cfg['dataset'].pop('max_images') + return cfg + +def load_dataset_for_previous_run(run_id, **kwargs): # => dataset_obj, mirror_augment + cfg = parse_config_for_previous_run(run_id) + cfg['dataset'].update(kwargs) + dataset_obj = dataset.load_dataset(data_dir=config.data_dir, **cfg['dataset']) + mirror_augment = cfg['train'].get('mirror_augment', False) + return dataset_obj, mirror_augment + +def apply_mirror_augment(minibatch): + mask = np.random.rand(minibatch.shape[0]) < 0.5 + minibatch = np.array(minibatch) + minibatch[mask] = minibatch[mask, :, :, ::-1] + return minibatch + +#---------------------------------------------------------------------------- +# Size and contents of the image snapshot grids that are exported +# periodically during training. + +def setup_snapshot_image_grid(G, training_set, + size = '1080p', # '1080p' = to be viewed on 1080p display, '4k' = to be viewed on 4k display. + layout = 'random'): # 'random' = grid contents are selected randomly, 'row_per_class' = each row corresponds to one class label. + + # Select size. + gw = 1; gh = 1 + if size == '1080p': + gw = np.clip(1920 // G.output_shape[3], 3, 32) + gh = np.clip(1080 // G.output_shape[2], 2, 32) + if size == '4k': + gw = np.clip(3840 // G.output_shape[3], 7, 32) + gh = np.clip(2160 // G.output_shape[2], 4, 32) + + # Initialize data arrays. + reals = np.zeros([gw * gh] + training_set.shape, dtype=training_set.dtype) + labels = np.zeros([gw * gh, training_set.label_size], dtype=training_set.label_dtype) + # latents = np.random.randn(gw * gh, *G.input_shape[1:]) + + # Random layout. + if layout == 'random': + reals[:], labels[:] = training_set.get_minibatch_np(gw * gh) + + # Class-conditional layouts. + class_layouts = dict(row_per_class=[gw,1], col_per_class=[1,gh], class4x4=[4,4]) + if layout in class_layouts: + bw, bh = class_layouts[layout] + nw = (gw - 1) // bw + 1 + nh = (gh - 1) // bh + 1 + blocks = [[] for _i in range(nw * nh)] + for _iter in range(1000000): + real, label = training_set.get_minibatch_np(1) + idx = np.argmax(label[0]) + while idx < len(blocks) and len(blocks[idx]) >= bw * bh: + idx += training_set.label_size + if idx < len(blocks): + blocks[idx].append((real, label)) + if all(len(block) >= bw * bh for block in blocks): + break + for i, block in enumerate(blocks): + for j, (real, label) in enumerate(block): + x = (i % nw) * bw + j % bw + y = (i // nw) * bh + j // bw + if x < gw and y < gh: + reals[x + y * gw] = real[0] + labels[x + y * gw] = label[0] + + return (gw, gh), reals, labels + +#---------------------------------------------------------------------------- diff --git a/training/networks_id.py b/training/networks_id.py new file mode 100644 index 0000000..4d08e5c --- /dev/null +++ b/training/networks_id.py @@ -0,0 +1,35 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +# Face recognition network proposed by Schroff et al. 
15, +# https://arxiv.org/abs/1503.03832, +# https://github.com/davidsandberg/facenet + +import tensorflow as tf +from training.inception_resnet_v1 import inception_resnet_v1 +slim = tf.contrib.slim + + +def Perceptual_Net(input_imgs): + #input_imgs: [Batchsize,H,W,C], 0-255, BGR image + #meanface: a mean face RGB image for normalization + + input_imgs = tf.cast(input_imgs,tf.float32) + input_imgs = tf.clip_by_value(input_imgs,0,255) + input_imgs = (input_imgs - 127.5)/128.0 + + #standard face-net backbone + batch_norm_params = { + 'decay': 0.995, + 'epsilon': 0.001, + 'updates_collections': None} + + + with slim.arg_scope([slim.conv2d, slim.fully_connected],weights_initializer=slim.initializers.xavier_initializer(), + weights_regularizer=slim.l2_regularizer(0.0), + normalizer_fn=slim.batch_norm, + normalizer_params=batch_norm_params): + feature_128,_ = inception_resnet_v1(input_imgs, bottleneck_layer_size=128, is_training=False, reuse=tf.AUTO_REUSE) + + #output the last FC layer feature(before classification) as identity feature + return feature_128 \ No newline at end of file diff --git a/training/networks_parser.py b/training/networks_parser.py new file mode 100644 index 0000000..33d123c --- /dev/null +++ b/training/networks_parser.py @@ -0,0 +1,131 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +# A tensorflow implementation of face parsing network +# proposed by Lin et al. 19, +# https://arxiv.org/abs/1906.01342. +#-------------------------------------------------------------- +import tensorflow as tf +from scipy.io import loadmat,savemat +import os +import numpy as np +from training.parser_utils import * +from training.resnet_block import * + +def fpn(c1,c2,c3,c4,data_format = 'channels_first'): + with tf.variable_scope('c4'): + h = tf.shape(c4)[2] + w = tf.shape(c4)[3] + f4 = conv2d_fixed_padding(c4,256, 1, 1, data_format,use_bias = True) + f4 = tf.transpose(f4,perm=[0,2,3,1]) + f4 = tf.image.resize_images(f4,[2*h,2*w],align_corners = True) + f4 = tf.transpose(f4,perm=[0,3,1,2]) + + + with tf.variable_scope('c3'): + h = tf.shape(c3)[2] + w = tf.shape(c3)[3] + f3 = conv2d_fixed_padding(c3,256, 1, 1, data_format,use_bias = True) + f3 += f4 + f3 = tf.transpose(f3,perm=[0,2,3,1]) + f3 = tf.image.resize_images(f3,[2*h,2*w],align_corners = True) + f3 = tf.transpose(f3,perm=[0,3,1,2]) + + with tf.variable_scope('c2'): + h = tf.shape(c2)[2] + w = tf.shape(c2)[3] + f2 = conv2d_fixed_padding(c2,256, 1, 1, data_format,use_bias = True) + f2 += f3 + f2 = tf.transpose(f2,perm=[0,2,3,1]) + f2 = tf.image.resize_images(f2,[2*h,2*w],align_corners = True) + f2 = tf.transpose(f2,perm=[0,3,1,2]) + + with tf.variable_scope('c1'): + h = tf.shape(c1)[2] + w = tf.shape(c1)[3] + f1 = conv2d_fixed_padding(c1,256, 1, 1, data_format,use_bias = True) + f1 += f2 + + with tf.variable_scope('convlast'): + x = conv2d_fixed_padding(f1,256, 3, 1, data_format,use_bias = True) + + + return x + +def MaskNet(x,is_training = False,data_format = 'channels_first'): + with tf.variable_scope('neck'): + x = conv2d_fixed_padding(x,256, 3, 1, data_format,use_bias = True) + x = batch_norm_relu(x, is_training, data_format) + x = conv2d_fixed_padding(x,256, 3, 1, data_format,use_bias = True) + x = batch_norm_relu(x, is_training, data_format) + + with tf.variable_scope('convlast'): + x = conv2d_fixed_padding(x,3, 1, 1, data_format,use_bias = True) + x = tf.nn.softmax(x,axis = 1) + x = tf.transpose(x,perm=[0,2,3,1]) + x = tf.image.resize_images(x,[512,512],align_corners = True) + x = 
tf.transpose(x,perm=[0,3,1,2]) + + return x + + + +def FaceParser(inputs, data_format = 'channels_first',is_training = False): + with tf.variable_scope('resnet',reuse = tf.AUTO_REUSE): + with tf.variable_scope('block0'): + inputs = conv2d_fixed_padding( + inputs=inputs, filters=64, kernel_size=7, + strides=2, data_format=data_format) + + inputs = batch_norm_relu(inputs, is_training, data_format) + + inputs = tf.layers.max_pooling2d( + inputs=inputs, pool_size=3, + strides=2, padding='SAME', + data_format=data_format) + + with tf.variable_scope('block1'): + inputs = building_block(inputs, 64, is_training, None, 1, data_format) + c1 = inputs = building_block(inputs, 64, is_training, None, 1, data_format) + + with tf.variable_scope('block2'): + + c2 = inputs = block_layer(inputs, filters = 128, blocks = 2, strides = 2, training = is_training, + data_format = data_format) + + with tf.variable_scope('block3'): + + c3 = inputs = block_layer(inputs, filters = 256, blocks = 2, strides = 2, training = is_training, + data_format = data_format) + + with tf.variable_scope('block4'): + + c4 = inputs = block_layer(inputs, filters = 512, blocks = 2, strides = 2, training = is_training, + data_format = data_format) + + with tf.variable_scope('fpn',reuse = tf.AUTO_REUSE): + + x = fpn(c1,c2,c3,c4) + + with tf.variable_scope('MaskNet',reuse = tf.AUTO_REUSE): + x = MaskNet(x) + + return x + +# Get hair segmentation from input image +def Parsing(inputs,lm): + lm = tf.stack([lm[:,:,0],256 - lm[:,:,1]],axis = 2) + lm5p = transfer_68to5(lm) + lm5p = tf.stop_gradient(lm5p) + + warp_inputs,tinv = preprocess_image_seg(inputs,lm5p) + warp_inputs = normalize_image(warp_inputs) + warp_inputs = tf.transpose(warp_inputs,perm=[0,3,1,2]) + + with tf.variable_scope('FaceParser'): + outputs = FaceParser(warp_inputs) + + outputs = tf.transpose(outputs,[0,2,3,1]) + ori_image = reverse_warp_and_distort(outputs,tinv) + ori_image = tf.transpose(ori_image,perm=[0,2,1,3]) # rotate hair segmentation + return ori_image \ No newline at end of file diff --git a/training/networks_recon.py b/training/networks_recon.py new file mode 100644 index 0000000..8ca34d9 --- /dev/null +++ b/training/networks_recon.py @@ -0,0 +1,67 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import tensorflow as tf +from tensorflow.contrib.slim.nets import resnet_v1 +slim = tf.contrib.slim + +# 3D face reconstruction network using resnet_v1_50 by Deng et al. 
19, +# https://github.com/microsoft/Deep3DFaceReconstruction +#----------------------------------------------------------------------------------------------- + +def R_Net(inputs,is_training=True,reuse = None): + #input: [Batchsize,H,W,C], 0-255, BGR image + inputs = tf.cast(inputs,tf.float32) + # standard ResNet50 backbone (without the last classfication FC layer) + with slim.arg_scope(resnet_v1.resnet_arg_scope()): + net,end_points = resnet_v1.resnet_v1_50(inputs,is_training = is_training ,reuse = reuse) + + # Modified FC layer with 257 channels for reconstruction coefficients + net_id = slim.conv2d(net, 80, [1, 1], + activation_fn=None, + normalizer_fn=None, + weights_initializer = tf.zeros_initializer(), + scope='fc-id') + net_ex = slim.conv2d(net, 64, [1, 1], + activation_fn=None, + normalizer_fn=None, + weights_initializer = tf.zeros_initializer(), + scope='fc-ex') + net_tex = slim.conv2d(net, 80, [1, 1], + activation_fn=None, + normalizer_fn=None, + weights_initializer = tf.zeros_initializer(), + scope='fc-tex') + net_angles = slim.conv2d(net, 3, [1, 1], + activation_fn=None, + normalizer_fn=None, + weights_initializer = tf.zeros_initializer(), + scope='fc-angles') + net_gamma = slim.conv2d(net, 27, [1, 1], + activation_fn=None, + normalizer_fn=None, + weights_initializer = tf.zeros_initializer(), + scope='fc-gamma') + net_t_xy = slim.conv2d(net, 2, [1, 1], + activation_fn=None, + normalizer_fn=None, + weights_initializer = tf.zeros_initializer(), + scope='fc-XY') + net_t_z = slim.conv2d(net, 1, [1, 1], + activation_fn=None, + normalizer_fn=None, + weights_initializer = tf.zeros_initializer(), + scope='fc-Z') + + + net_id = tf.squeeze(net_id, [1,2], name='fc-id/squeezed') + net_ex = tf.squeeze(net_ex, [1,2], name='fc-ex/squeezed') + net_tex = tf.squeeze(net_tex, [1,2],name='fc-tex/squeezed') + net_angles = tf.squeeze(net_angles,[1,2], name='fc-angles/squeezed') + net_gamma = tf.squeeze(net_gamma,[1,2], name='fc-gamma/squeezed') + net_t_xy = tf.squeeze(net_t_xy,[1,2], name='fc-XY/squeezed') + net_t_z = tf.squeeze(net_t_z,[1,2], name='fc-Z/squeezed') + + net_ = tf.concat([net_id,net_tex,net_ex,net_angles,net_gamma,net_t_xy,net_t_z], axis = 1) + + return net_ \ No newline at end of file diff --git a/training/networks_stylegan.py b/training/networks_stylegan.py new file mode 100644 index 0000000..d1a245f --- /dev/null +++ b/training/networks_stylegan.py @@ -0,0 +1,701 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. + +"""Network architectures used in the StyleGAN paper.""" + +import numpy as np +import tensorflow as tf +import dnnlib +import dnnlib.tflib as tflib + +# NOTE: Do not import any application-specific modules here! +# Specify all network parameters as kwargs. + +#---------------------------------------------------------------------------- +# Primitive ops for manipulating 4D activation tensors. +# The gradients of these are not necessary efficient or even meaningful. + +def _blur2d(x, f=[1,2,1], normalize=True, flip=False, stride=1): + assert x.shape.ndims == 4 and all(dim.value is not None for dim in x.shape[1:]) + assert isinstance(stride, int) and stride >= 1 + + # Finalize filter kernel. 
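+    # With the default f=[1,2,1] and normalize=True this builds a 3x3 binomial kernel
+    # that sums to 1 and is applied depthwise to every channel.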
+ f = np.array(f, dtype=np.float32) + if f.ndim == 1: + f = f[:, np.newaxis] * f[np.newaxis, :] + assert f.ndim == 2 + if normalize: + f /= np.sum(f) + if flip: + f = f[::-1, ::-1] + f = f[:, :, np.newaxis, np.newaxis] + f = np.tile(f, [1, 1, int(x.shape[1]), 1]) + + # No-op => early exit. + if f.shape == (1, 1) and f[0,0] == 1: + return x + + # Convolve using depthwise_conv2d. + orig_dtype = x.dtype + x = tf.cast(x, tf.float32) # tf.nn.depthwise_conv2d() doesn't support fp16 + f = tf.constant(f, dtype=x.dtype, name='filter') + strides = [1, 1, stride, stride] + x = tf.nn.depthwise_conv2d(x, f, strides=strides, padding='SAME', data_format='NCHW') + x = tf.cast(x, orig_dtype) + return x + +def _upscale2d(x, factor=2, gain=1): + assert x.shape.ndims == 4 and all(dim.value is not None for dim in x.shape[1:]) + assert isinstance(factor, int) and factor >= 1 + + # Apply gain. + if gain != 1: + x *= gain + + # No-op => early exit. + if factor == 1: + return x + + # Upscale using tf.tile(). + s = x.shape + x = tf.reshape(x, [-1, s[1], s[2], 1, s[3], 1]) + x = tf.tile(x, [1, 1, 1, factor, 1, factor]) + x = tf.reshape(x, [-1, s[1], s[2] * factor, s[3] * factor]) + return x + +def _downscale2d(x, factor=2, gain=1): + assert x.shape.ndims == 4 and all(dim.value is not None for dim in x.shape[1:]) + assert isinstance(factor, int) and factor >= 1 + + # 2x2, float32 => downscale using _blur2d(). + if factor == 2 and x.dtype == tf.float32: + f = [np.sqrt(gain) / factor] * factor + return _blur2d(x, f=f, normalize=False, stride=factor) + + # Apply gain. + if gain != 1: + x *= gain + + # No-op => early exit. + if factor == 1: + return x + + # Large factor => downscale using tf.nn.avg_pool(). + # NOTE: Requires tf_config['graph_options.place_pruned_graph']=True to work. + ksize = [1, 1, factor, factor] + return tf.nn.avg_pool(x, ksize=ksize, strides=ksize, padding='VALID', data_format='NCHW') + +#---------------------------------------------------------------------------- +# High-level ops for manipulating 4D activation tensors. +# The gradients of these are meant to be as efficient as possible. + +def blur2d(x, f=[1,2,1], normalize=True): + with tf.variable_scope('Blur2D'): + @tf.custom_gradient + def func(x): + y = _blur2d(x, f, normalize) + @tf.custom_gradient + def grad(dy): + dx = _blur2d(dy, f, normalize, flip=True) + return dx, lambda ddx: _blur2d(ddx, f, normalize) + return y, grad + return func(x) + +def upscale2d(x, factor=2): + with tf.variable_scope('Upscale2D'): + @tf.custom_gradient + def func(x): + y = _upscale2d(x, factor) + @tf.custom_gradient + def grad(dy): + dx = _downscale2d(dy, factor, gain=factor**2) + return dx, lambda ddx: _upscale2d(ddx, factor) + return y, grad + return func(x) + +def downscale2d(x, factor=2): + with tf.variable_scope('Downscale2D'): + @tf.custom_gradient + def func(x): + y = _downscale2d(x, factor) + @tf.custom_gradient + def grad(dy): + dx = _upscale2d(dy, factor, gain=1/factor**2) + return dx, lambda ddx: _downscale2d(ddx, factor) + return y, grad + return func(x) + +#---------------------------------------------------------------------------- +# Get/create weight tensor for a convolutional or fully-connected layer. + +def get_weight(shape, gain=np.sqrt(2), use_wscale=False, lrmul=1): + fan_in = np.prod(shape[:-1]) # [kernel, kernel, fmaps_in, fmaps_out] or [in, out] + he_std = gain / np.sqrt(fan_in) # He init + + # Equalized learning rate and custom learning rate multiplier. 
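+    # With use_wscale the weights are stored at (roughly) unit variance and rescaled by
+    # he_std at run time, i.e. the equalized learning rate trick of Progressive GAN /
+    # StyleGAN; lrmul additionally adjusts the effective per-layer learning rate.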
+ if use_wscale: + init_std = 1.0 / lrmul + runtime_coef = he_std * lrmul + else: + init_std = he_std / lrmul + runtime_coef = lrmul + + # Create variable. + init = tf.initializers.random_normal(0, init_std) + return tf.get_variable('weight', shape=shape, initializer=init) * runtime_coef + +#---------------------------------------------------------------------------- +# Fully-connected layer. + +def dense(x, fmaps, **kwargs): + if len(x.shape) > 2: + x = tf.reshape(x, [-1, np.prod([d.value for d in x.shape[1:]])]) + w = get_weight([x.shape[1].value, fmaps], **kwargs) + w = tf.cast(w, x.dtype) + return tf.matmul(x, w) + +#---------------------------------------------------------------------------- +# Convolutional layer. + +def conv2d(x, fmaps, kernel, **kwargs): + assert kernel >= 1 and kernel % 2 == 1 + w = get_weight([kernel, kernel, x.shape[1].value, fmaps], **kwargs) + w = tf.cast(w, x.dtype) + return tf.nn.conv2d(x, w, strides=[1,1,1,1], padding='SAME', data_format='NCHW') + +#---------------------------------------------------------------------------- +# Fused convolution + scaling. +# Faster and uses less memory than performing the operations separately. + +def upscale2d_conv2d(x, fmaps, kernel, fused_scale='auto', **kwargs): + assert kernel >= 1 and kernel % 2 == 1 + assert fused_scale in [True, False, 'auto'] + if fused_scale == 'auto': + fused_scale = min(x.shape[2:]) * 2 >= 128 + + # Not fused => call the individual ops directly. + if not fused_scale: + return conv2d(upscale2d(x), fmaps, kernel, **kwargs) + + # Fused => perform both ops simultaneously using tf.nn.conv2d_transpose(). + w = get_weight([kernel, kernel, x.shape[1].value, fmaps], **kwargs) + w = tf.transpose(w, [0, 1, 3, 2]) # [kernel, kernel, fmaps_out, fmaps_in] + w = tf.pad(w, [[1,1], [1,1], [0,0], [0,0]], mode='CONSTANT') + w = tf.add_n([w[1:, 1:], w[:-1, 1:], w[1:, :-1], w[:-1, :-1]]) + w = tf.cast(w, x.dtype) + os = [tf.shape(x)[0], fmaps, x.shape[2] * 2, x.shape[3] * 2] + return tf.nn.conv2d_transpose(x, w, os, strides=[1,1,2,2], padding='SAME', data_format='NCHW') + +def conv2d_downscale2d(x, fmaps, kernel, fused_scale='auto', **kwargs): + assert kernel >= 1 and kernel % 2 == 1 + assert fused_scale in [True, False, 'auto'] + if fused_scale == 'auto': + fused_scale = min(x.shape[2:]) >= 128 + + # Not fused => call the individual ops directly. + if not fused_scale: + return downscale2d(conv2d(x, fmaps, kernel, **kwargs)) + + # Fused => perform both ops simultaneously using tf.nn.conv2d(). + w = get_weight([kernel, kernel, x.shape[1].value, fmaps], **kwargs) + w = tf.pad(w, [[1,1], [1,1], [0,0], [0,0]], mode='CONSTANT') + w = tf.add_n([w[1:, 1:], w[:-1, 1:], w[1:, :-1], w[:-1, :-1]]) * 0.25 + w = tf.cast(w, x.dtype) + return tf.nn.conv2d(x, w, strides=[1,1,2,2], padding='SAME', data_format='NCHW') + +#---------------------------------------------------------------------------- +# Apply bias to the given activation tensor. + +def apply_bias(x, lrmul=1): + b = tf.get_variable('bias', shape=[x.shape[1]], initializer=tf.initializers.zeros()) * lrmul + b = tf.cast(b, x.dtype) + if len(x.shape) == 2: + return x + b + return x + tf.reshape(b, [1, -1, 1, 1]) + +#---------------------------------------------------------------------------- +# Leaky ReLU activation. More efficient than tf.nn.leaky_relu() and supports FP16. 
+ +def leaky_relu(x, alpha=0.2): + with tf.variable_scope('LeakyReLU'): + alpha = tf.constant(alpha, dtype=x.dtype, name='alpha') + @tf.custom_gradient + def func(x): + y = tf.maximum(x, x * alpha) + @tf.custom_gradient + def grad(dy): + dx = tf.where(y >= 0, dy, dy * alpha) + return dx, lambda ddx: tf.where(y >= 0, ddx, ddx * alpha) + return y, grad + return func(x) + +#---------------------------------------------------------------------------- +# Pixelwise feature vector normalization. + +def pixel_norm(x, epsilon=1e-8): + with tf.variable_scope('PixelNorm'): + epsilon = tf.constant(epsilon, dtype=x.dtype, name='epsilon') + return x * tf.rsqrt(tf.reduce_mean(tf.square(x), axis=1, keepdims=True) + epsilon) + +#---------------------------------------------------------------------------- +# Instance normalization. + +def instance_norm(x, epsilon=1e-8): + assert len(x.shape) == 4 # NCHW + with tf.variable_scope('InstanceNorm'): + orig_dtype = x.dtype + x = tf.cast(x, tf.float32) + x -= tf.reduce_mean(x, axis=[2,3], keepdims=True) + epsilon = tf.constant(epsilon, dtype=x.dtype, name='epsilon') + x *= tf.rsqrt(tf.reduce_mean(tf.square(x), axis=[2,3], keepdims=True) + epsilon) + x = tf.cast(x, orig_dtype) + return x + +#---------------------------------------------------------------------------- +# Positional normalization +def position_norm(x,epsilon=1e-8): + assert len(x.shape) == 4 + with tf.variable_scope('PositionNorm'): + orig_dtype = x.dtype + x = tf.cast(x,tf.float32) + x -= tf.reduce_mean(x, axis=[1], keepdims=True) + epsilon = tf.constant(epsilon, dtype=x.dtype, name='epsilon') + x *= tf.rsqrt(tf.reduce_mean(tf.square(x), axis=[1], keepdims=True) + epsilon) + x = tf.cast(x, orig_dtype) + return x + +#---------------------------------------------------------------------------- +# Style modulation. + +def style_mod(x, dlatent, **kwargs): + with tf.variable_scope('StyleMod'): + style = apply_bias(dense(dlatent, fmaps=x.shape[1]*2, gain=1, **kwargs)) + style = tf.reshape(style, [-1, 2, x.shape[1]] + [1] * (len(x.shape) - 2)) + return x * (style[:,0] + 1) + style[:,1] + +#---------------------------------------------------------------------------- +# Noise input. + +def apply_noise(x, noise_var=None, randomize_noise=True): + assert len(x.shape) == 4 # NCHW + with tf.variable_scope('Noise'): + if noise_var is None or randomize_noise: + noise = tf.random_normal([tf.shape(x)[0], 1, x.shape[2], x.shape[3]], dtype=x.dtype) + else: + noise = tf.cast(noise_var, x.dtype) + weight = tf.get_variable('weight', shape=[x.shape[1].value], initializer=tf.initializers.zeros()) + return x + noise * tf.reshape(tf.cast(weight, x.dtype), [1, -1, 1, 1]) + +#---------------------------------------------------------------------------- +# Minibatch standard deviation. + +def minibatch_stddev_layer(x, group_size=4, num_new_features=1): + with tf.variable_scope('MinibatchStddev'): + group_size = tf.minimum(group_size, tf.shape(x)[0]) # Minibatch must be divisible by (or smaller than) group_size. + s = x.shape # [NCHW] Input shape. + y = tf.reshape(x, [group_size, -1, num_new_features, s[1]//num_new_features, s[2], s[3]]) # [GMncHW] Split minibatch into M groups of size G. Split channels into n channel groups c. + y = tf.cast(y, tf.float32) # [GMncHW] Cast to FP32. + y -= tf.reduce_mean(y, axis=0, keepdims=True) # [GMncHW] Subtract mean over group. + y = tf.reduce_mean(tf.square(y), axis=0) # [MncHW] Calc variance over group. + y = tf.sqrt(y + 1e-8) # [MncHW] Calc stddev over group. 
+ y = tf.reduce_mean(y, axis=[2,3,4], keepdims=True) # [Mn111] Take average over fmaps and pixels. + y = tf.reduce_mean(y, axis=[2]) # [Mn11] Split channels into c channel groups + y = tf.cast(y, x.dtype) # [Mn11] Cast back to original data type. + y = tf.tile(y, [group_size, 1, s[2], s[3]]) # [NnHW] Replicate over group and pixels. + return tf.concat([x, y], axis=1) # [NCHW] Append as new fmap. + +#---------------------------------------------------------------------------- +# Style-based generator used in the StyleGAN paper. +# Composed of two sub-networks (G_mapping and G_synthesis) that are defined below. + +def G_style( + latents_in, # First input: Latent vectors (Z) [minibatch, latent_size]. + labels_in, # Second input: Conditioning labels [minibatch, label_size]. + truncation_psi = 0.7, # Style strength multiplier for the truncation trick. None = disable. + truncation_cutoff = 8, # Number of layers for which to apply the truncation trick. None = disable. + truncation_psi_val = None, # Value for truncation_psi to use during validation. + truncation_cutoff_val = None, # Value for truncation_cutoff to use during validation. + dlatent_avg_beta = 0.995, # Decay for tracking the moving average of W during training. None = disable. + style_mixing_prob = 0.9, # Probability of mixing styles during training. None = disable. + is_training = False, # Network is under training? Enables and disables specific features. + is_validation = False, # Network is under validation? Chooses which value to use for truncation_psi. + is_template_graph = False, # True = template graph constructed by the Network class, False = actual evaluation. + components = dnnlib.EasyDict(), # Container for sub-networks. Retained between calls. + **kwargs): # Arguments for sub-networks (G_mapping and G_synthesis). + + # Validate arguments. + assert not is_training or not is_validation + assert isinstance(components, dnnlib.EasyDict) + if is_validation: + truncation_psi = truncation_psi_val + truncation_cutoff = truncation_cutoff_val + if is_training or (truncation_psi is not None and not tflib.is_tf_expression(truncation_psi) and truncation_psi == 1): + truncation_psi = None + if is_training or (truncation_cutoff is not None and not tflib.is_tf_expression(truncation_cutoff) and truncation_cutoff <= 0): + truncation_cutoff = None + # if not is_training or (dlatent_avg_beta is not None and not tflib.is_tf_expression(dlatent_avg_beta) and dlatent_avg_beta == 1): + # dlatent_avg_beta = None + if not is_training or (style_mixing_prob is not None and not tflib.is_tf_expression(style_mixing_prob) and style_mixing_prob <= 0): + style_mixing_prob = None + + # Setup components. + if 'synthesis' not in components: + components.synthesis = tflib.Network('G_synthesis', func_name=G_synthesis, **kwargs) + num_layers = components.synthesis.input_shape[1] + dlatent_size = components.synthesis.input_shape[2] + if 'mapping' not in components: + components.mapping = tflib.Network('G_mapping', func_name=G_mapping, dlatent_broadcast=num_layers, **kwargs) + + # Setup variables. + lod_in = tf.get_variable('lod', initializer=np.float32(0), trainable=False) + dlatent_avg = tf.get_variable('dlatent_avg', shape=[dlatent_size], initializer=tf.initializers.zeros(), trainable=False) + + # Evaluate mapping network. + dlatents = components.mapping.get_output_for(latents_in, labels_in, **kwargs) + + # Update moving average of W. 
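+    # dlatent_avg is still tracked here even though the truncation trick that normally
+    # consumes it is commented out below in this modified version.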
+ # if dlatent_avg_beta is not None: + if not is_validation: + with tf.variable_scope('DlatentAvg'): + batch_avg = tf.reduce_mean(dlatents[:, 0], axis=0) + update_op = tf.assign(dlatent_avg, tflib.lerp(batch_avg, dlatent_avg, dlatent_avg_beta)) + with tf.control_dependencies([update_op]): + dlatents = tf.identity(dlatents) + + + #--------------------------------------------------------------- + # Modified by Deng et al. + + # Perform style mixing regularization. + # if style_mixing_prob is not None: + # with tf.name_scope('StyleMix'): + # latents2 = tf.random_normal(tf.shape(latents_in)) + # dlatents2 = components.mapping.get_output_for(latents2, labels_in, **kwargs) + # layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis] + # cur_layers = num_layers - tf.cast(lod_in, tf.int32) * 2 + # mixing_cutoff = tf.cond( + # tf.random_uniform([], 0.0, 1.0) < style_mixing_prob, + # lambda: tf.random_uniform([], 1, cur_layers, dtype=tf.int32), + # lambda: cur_layers) + # dlatents = tf.where(tf.broadcast_to(layer_idx < mixing_cutoff, tf.shape(dlatents)), dlatents, dlatents2) + + # Apply truncation trick. + # if truncation_psi is not None and truncation_cutoff is not None: + # with tf.variable_scope('Truncation'): + # layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis] + # ones = np.ones(layer_idx.shape, dtype=np.float32) + # coefs = tf.where(layer_idx < truncation_cutoff, truncation_psi * ones, ones) + # dlatents = tflib.lerp(dlatent_avg, dlatents, coefs) + + #--------------------------------------------------------------- + + # Evaluate synthesis network. + with tf.control_dependencies([tf.assign(components.synthesis.find_var('lod'), lod_in)]): + images_out = components.synthesis.get_output_for(dlatents, force_clean_graph=is_template_graph, **kwargs) + return tf.identity(images_out, name='images_out') + +#---------------------------------------------------------------------------- +# Mapping network used in the StyleGAN paper. + +def G_mapping( + latents_in, # First input: Latent vectors (Z) [minibatch, latent_size]. + labels_in, # Second input: Conditioning labels [minibatch, label_size]. + latent_size = 254+32, # Latent vector (Z) dimensionality. + label_size = 0, # Label dimensionality, 0 if no labels. + dlatent_size = 512, # Disentangled latent (W) dimensionality. + dlatent_broadcast = None, # Output disentangled latent (W) as [minibatch, dlatent_size] or [minibatch, dlatent_broadcast, dlatent_size]. + mapping_layers = 8, # Number of mapping layers. + mapping_fmaps = 512, # Number of activations in the mapping layers. + mapping_lrmul = 0.01, # Learning rate multiplier for the mapping layers. + mapping_nonlinearity = 'lrelu', # Activation function: 'relu', 'lrelu'. + use_wscale = True, # Enable equalized learning rate? + normalize_latents = False, # Normalize latent vectors (Z) before feeding them to the mapping layers? + dtype = 'float32', # Data type to use for activations and outputs. + **_kwargs): # Ignore unrecognized keyword args. + + act, gain = {'relu': (tf.nn.relu, np.sqrt(2)), 'lrelu': (leaky_relu, np.sqrt(2))}[mapping_nonlinearity] + + # Inputs. + latents_in.set_shape([None, latent_size]) + labels_in.set_shape([None, label_size]) + latents_in = tf.cast(latents_in, dtype) + labels_in = tf.cast(labels_in, dtype) + x = latents_in + + # Embed labels and concatenate them with latents. 
+ if label_size: + with tf.variable_scope('LabelConcat'): + w = tf.get_variable('weight', shape=[label_size, latent_size], initializer=tf.initializers.random_normal()) + y = tf.matmul(labels_in, tf.cast(w, dtype)) + x = tf.concat([x, y], axis=1) + + #--------------------------------------------------------------- + # Modified by Deng et al. + + # Normalize latents. + # if normalize_latents: + # x = pixel_norm(x) + + #--------------------------------------------------------------- + + # Mapping layers. + for layer_idx in range(mapping_layers): + with tf.variable_scope('Dense%d' % layer_idx): + fmaps = dlatent_size if layer_idx == mapping_layers - 1 else mapping_fmaps + x = dense(x, fmaps=fmaps, gain=gain, use_wscale=use_wscale, lrmul=mapping_lrmul) + x = apply_bias(x, lrmul=mapping_lrmul) + x = act(x) + + # Broadcast. + if dlatent_broadcast is not None: + with tf.variable_scope('Broadcast'): + x = tf.tile(x[:, np.newaxis], [1, dlatent_broadcast, 1]) + + # Output. + assert x.dtype == tf.as_dtype(dtype) + return tf.identity(x, name='dlatents_out') + +#---------------------------------------------------------------------------- +# Synthesis network used in the StyleGAN paper. + +def G_synthesis( + dlatents_in, # Input: Disentangled latents (W) [minibatch, num_layers, dlatent_size]. + dlatent_size = 512, # Disentangled latent (W) dimensionality. + num_channels = 3, # Number of output color channels. + resolution = 1024, # Output resolution. + fmap_base = 8192, # Overall multiplier for the number of feature maps. + fmap_decay = 1.0, # log2 feature map reduction when doubling the resolution. + fmap_max = 512, # Maximum number of feature maps in any layer. + use_styles = True, # Enable style inputs? + const_input_layer = True, # First layer is a learned constant? + use_noise = True, # Enable noise inputs? + randomize_noise = True, # True = randomize noise inputs every time (non-deterministic), False = read noise inputs from variables. + nonlinearity = 'lrelu', # Activation function: 'relu', 'lrelu' + use_wscale = True, # Enable equalized learning rate? + use_pixel_norm = False, # Enable pixelwise feature vector normalization? + use_instance_norm = True, # Enable instance normalization? + dtype = 'float32', # Data type to use for activations and outputs. + fused_scale = 'auto', # True = fused convolution + scaling, False = separate ops, 'auto' = decide automatically. + blur_filter = [1,2,1], # Low-pass filter to apply when resampling activations. None = no filtering. + structure = 'auto', # 'fixed' = no progressive growing, 'linear' = human-readable, 'recursive' = efficient, 'auto' = select automatically. + is_template_graph = False, # True = template graph constructed by the Network class, False = actual evaluation. + force_clean_graph = False, # True = construct a clean graph that looks nice in TensorBoard, False = default behavior. + **_kwargs): # Ignore unrecognized keyword args. 
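+    # Note: the synthesis network follows StyleGAN's progressive design: a learned 4x4
+    # constant input, then one upsampling block per resolution octave, with layer_epilogue()
+    # applying noise, bias, activation, normalization and style modulation after each
+    # convolution. For resolution = 2**k there are num_layers = 2*k - 2 style inputs
+    # (e.g. 18 for 1024x1024), matching the dlatent broadcast set up in G_style above.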
+ + resolution_log2 = int(np.log2(resolution)) + assert resolution == 2**resolution_log2 and resolution >= 4 + def nf(stage): return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max) + def blur(x): return blur2d(x, blur_filter) if blur_filter else x + if is_template_graph: force_clean_graph = True + if force_clean_graph: randomize_noise = False + if structure == 'auto': structure = 'linear' if force_clean_graph else 'recursive' + act, gain = {'relu': (tf.nn.relu, np.sqrt(2)), 'lrelu': (leaky_relu, np.sqrt(2))}[nonlinearity] + num_layers = resolution_log2 * 2 - 2 + num_styles = num_layers if use_styles else 1 + images_out = None + + # Primary inputs. + dlatents_in.set_shape([None, num_styles, dlatent_size]) + dlatents_in = tf.cast(dlatents_in, dtype) + lod_in = tf.cast(tf.get_variable('lod', initializer=np.float32(0), trainable=False), dtype) + + # Noise inputs. + noise_inputs = [] + if use_noise: + for layer_idx in range(num_layers): + res = layer_idx // 2 + 2 + shape = [1, use_noise, 2**res, 2**res] + noise_inputs.append(tf.get_variable('noise%d' % layer_idx, shape=shape, initializer=tf.initializers.random_normal(), trainable=False)) + + # Things to do at the end of each layer. + def layer_epilogue(x, layer_idx): + if use_noise: + x = apply_noise(x, noise_inputs[layer_idx], randomize_noise=randomize_noise) + x = apply_bias(x) + x = act(x) + if use_pixel_norm: + x = pixel_norm(x) + if use_instance_norm: + x = instance_norm(x) + if use_styles: + x = style_mod(x, dlatents_in[:, layer_idx], use_wscale=use_wscale) + return x + + # Early layers. + with tf.variable_scope('4x4'): + if const_input_layer: + with tf.variable_scope('Const'): + x = tf.get_variable('const', shape=[1, nf(1), 4, 4], initializer=tf.initializers.ones()) + x = layer_epilogue(tf.tile(tf.cast(x, dtype), [tf.shape(dlatents_in)[0], 1, 1, 1]), 0) + else: + with tf.variable_scope('Dense'): + x = dense(dlatents_in[:, 0], fmaps=nf(1)*16, gain=gain/4, use_wscale=use_wscale) # tweak gain to match the official implementation of Progressing GAN + x = layer_epilogue(tf.reshape(x, [-1, nf(1), 4, 4]), 0) + with tf.variable_scope('Conv'): + x = layer_epilogue(conv2d(x, fmaps=nf(1), kernel=3, gain=gain, use_wscale=use_wscale), 1) + + # Building blocks for remaining layers. + def block(res, x): # res = 3..resolution_log2 + with tf.variable_scope('%dx%d' % (2**res, 2**res)): + with tf.variable_scope('Conv0_up'): + x = layer_epilogue(blur(upscale2d_conv2d(x, fmaps=nf(res-1), kernel=3, gain=gain, use_wscale=use_wscale, fused_scale=fused_scale)), res*2-4) + with tf.variable_scope('Conv1'): + x = layer_epilogue(conv2d(x, fmaps=nf(res-1), kernel=3, gain=gain, use_wscale=use_wscale), res*2-3) + return x + def torgb(res, x): # res = 2..resolution_log2 + lod = resolution_log2 - res + with tf.variable_scope('ToRGB_lod%d' % lod): + return apply_bias(conv2d(x, fmaps=num_channels, kernel=1, gain=1, use_wscale=use_wscale)) + + # Fixed structure: simple and efficient, but does not support progressive growing. + if structure == 'fixed': + for res in range(3, resolution_log2 + 1): + x = block(res, x) + images_out = torgb(resolution_log2, x) + + # Linear structure: simple but inefficient. + if structure == 'linear': + images_out = torgb(2, x) + for res in range(3, resolution_log2 + 1): + lod = resolution_log2 - res + x = block(res, x) + img = torgb(res, x) + images_out = upscale2d(images_out) + with tf.variable_scope('Grow_lod%d' % lod): + images_out = tflib.lerp_clip(img, images_out, lod_in - lod) + + # Recursive structure: complex but efficient. 
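+    # Note: grow() below nests tf.cond branches keyed on lod_in, so one graph covers every
+    # output resolution while only the branch matching the current level-of-detail is
+    # evaluated at run time.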
+ if structure == 'recursive': + def cset(cur_lambda, new_cond, new_lambda): + return lambda: tf.cond(new_cond, new_lambda, cur_lambda) + def grow(x, res, lod): + y = block(res, x) + img = lambda: upscale2d(torgb(res, y), 2**lod) + img = cset(img, (lod_in > lod), lambda: upscale2d(tflib.lerp(torgb(res, y), upscale2d(torgb(res - 1, x)), lod_in - lod), 2**lod)) + if lod > 0: img = cset(img, (lod_in < lod), lambda: grow(y, res + 1, lod - 1)) + return img() + images_out = grow(x, 3, resolution_log2 - 3) + + assert images_out.dtype == tf.as_dtype(dtype) + return tf.identity(images_out, name='images_out') + +#---------------------------------------------------------------------------- +# Discriminator used in the StyleGAN paper. + +def D_basic( + images_in, # First input: Images [minibatch, channel, height, width]. + labels_in, # Second input: Labels [minibatch, label_size]. + num_channels = 1, # Number of input color channels. Overridden based on dataset. + resolution = 32, # Input resolution. Overridden based on dataset. + label_size = 0, # Dimensionality of the labels, 0 if no labels. Overridden based on dataset. + fmap_base = 8192, # Overall multiplier for the number of feature maps. + fmap_decay = 1.0, # log2 feature map reduction when doubling the resolution. + fmap_max = 512, # Maximum number of feature maps in any layer. + nonlinearity = 'lrelu', # Activation function: 'relu', 'lrelu', + use_wscale = True, # Enable equalized learning rate? + mbstd_group_size = 4, # Group size for the minibatch standard deviation layer, 0 = disable. + mbstd_num_features = 1, # Number of features for the minibatch standard deviation layer. + dtype = 'float32', # Data type to use for activations and outputs. + fused_scale = 'auto', # True = fused convolution + scaling, False = separate ops, 'auto' = decide automatically. + blur_filter = [1,2,1], # Low-pass filter to apply when resampling activations. None = no filtering. + structure = 'auto', # 'fixed' = no progressive growing, 'linear' = human-readable, 'recursive' = efficient, 'auto' = select automatically. + is_template_graph = False, # True = template graph constructed by the Network class, False = actual evaluation. + **_kwargs): # Ignore unrecognized keyword args. + + resolution_log2 = int(np.log2(resolution)) + assert resolution == 2**resolution_log2 and resolution >= 4 + def nf(stage): return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max) + def blur(x): return blur2d(x, blur_filter) if blur_filter else x + if structure == 'auto': structure = 'linear' if is_template_graph else 'recursive' + act, gain = {'relu': (tf.nn.relu, np.sqrt(2)), 'lrelu': (leaky_relu, np.sqrt(2))}[nonlinearity] + + images_in.set_shape([None, num_channels, resolution, resolution]) + labels_in.set_shape([None, label_size]) + images_in = tf.cast(images_in, dtype) + labels_in = tf.cast(labels_in, dtype) + lod_in = tf.cast(tf.get_variable('lod', initializer=np.float32(0.0), trainable=False), dtype) + scores_out = None + + # Building blocks. 
+ def fromrgb(x, res): # res = 2..resolution_log2 + with tf.variable_scope('FromRGB_lod%d' % (resolution_log2 - res)): + return act(apply_bias(conv2d(x, fmaps=nf(res-1), kernel=1, gain=gain, use_wscale=use_wscale))) + def block(x, res): # res = 2..resolution_log2 + with tf.variable_scope('%dx%d' % (2**res, 2**res)): + if res >= 3: # 8x8 and up + with tf.variable_scope('Conv0'): + x = act(apply_bias(conv2d(x, fmaps=nf(res-1), kernel=3, gain=gain, use_wscale=use_wscale))) + with tf.variable_scope('Conv1_down'): + x = act(apply_bias(conv2d_downscale2d(blur(x), fmaps=nf(res-2), kernel=3, gain=gain, use_wscale=use_wscale, fused_scale=fused_scale))) + else: # 4x4 + if mbstd_group_size > 1: + x = minibatch_stddev_layer(x, mbstd_group_size, mbstd_num_features) + with tf.variable_scope('Conv'): + x = act(apply_bias(conv2d(x, fmaps=nf(res-1), kernel=3, gain=gain, use_wscale=use_wscale))) + with tf.variable_scope('Dense0'): + x = act(apply_bias(dense(x, fmaps=nf(res-2), gain=gain, use_wscale=use_wscale))) + with tf.variable_scope('Dense1'): + x = apply_bias(dense(x, fmaps=max(label_size, 1), gain=1, use_wscale=use_wscale)) + return x + + # Fixed structure: simple and efficient, but does not support progressive growing. + if structure == 'fixed': + x = fromrgb(images_in, resolution_log2) + for res in range(resolution_log2, 2, -1): + x = block(x, res) + scores_out = block(x, 2) + + # Linear structure: simple but inefficient. + if structure == 'linear': + img = images_in + x = fromrgb(img, resolution_log2) + for res in range(resolution_log2, 2, -1): + lod = resolution_log2 - res + x = block(x, res) + img = downscale2d(img) + y = fromrgb(img, res - 1) + with tf.variable_scope('Grow_lod%d' % lod): + x = tflib.lerp_clip(x, y, lod_in - lod) + scores_out = block(x, 2) + + # Recursive structure: complex but efficient. + if structure == 'recursive': + def cset(cur_lambda, new_cond, new_lambda): + return lambda: tf.cond(new_cond, new_lambda, cur_lambda) + def grow(res, lod): + x = lambda: fromrgb(downscale2d(images_in, 2**lod), res) + if lod > 0: x = cset(x, (lod_in < lod), lambda: grow(res + 1, lod - 1)) + x = block(x(), res); y = lambda: x + if res > 2: y = cset(y, (lod_in > lod), lambda: tflib.lerp(x, fromrgb(downscale2d(images_in, 2**(lod+1)), res - 1), lod_in - lod)) + return y() + scores_out = grow(2, resolution_log2 - 2) + + # Label conditioning from "Which Training Methods for GANs do actually Converge?" + if label_size: + with tf.variable_scope('LabelSwitch'): + scores_out = tf.reduce_sum(scores_out * labels_in, axis=1, keepdims=True) + + assert scores_out.dtype == tf.as_dtype(dtype) + scores_out = tf.identity(scores_out, name='scores_out') + return scores_out + +#---------------------------------------------------------------------------- +# Modified by Deng et al. + +# Mapping z space variable to lambda space variable +def CoeffDecoder(z,ch_depth = 3, ch_dim = 512, coeff_length = 128): + with tf.variable_scope('stage1'): + with tf.variable_scope('decoder'): + y = z + for i in range(ch_depth): + y = tf.layers.dense(y, ch_dim, tf.nn.relu, name='fc'+str(i)) + + x_hat = tf.layers.dense(y, coeff_length, name='x_hat') + x_hat = tf.stop_gradient(x_hat) + + return x_hat +#---------------------------------------------------------------------------- \ No newline at end of file diff --git a/training/parser_utils.py b/training/parser_utils.py new file mode 100644 index 0000000..d266dbc --- /dev/null +++ b/training/parser_utils.py @@ -0,0 +1,181 @@ +# Copyright (c) Microsoft Corporation. 
+# Licensed under the MIT license. + +# Face parsing network proposed by Lin et al. 19, +# https://arxiv.org/abs/1906.01342, +# transfered to tensorflow version. +import tensorflow as tf +from scipy.io import loadmat +import cv2 +import os +import numpy as np + +def transfer_68to5(points): + # print(points) + p1 = tf.reduce_mean(points[:,36:42,:],1) + p2 = tf.reduce_mean(points[:,42:48,:],1) + p3 = points[:,30,:] + p4 = points[:,48,:] + p5 = points[:,54,:] + + p = tf.stack([p1,p2,p3,p4,p5],axis = 1) + + return p + +def standard_face_pts_512(): + pts = tf.constant([ + 196.0, 226.0, + 316.0, 226.0, + 256.0, 286.0, + 220.0, 360.4, + 292.0, 360.4]) + + pts = tf.reshape(pts,[5,2]) + + return pts + +def normalize_image(image): + _mean = tf.constant([[[[0.485, 0.456, 0.406]]]]) # rgb + _std = tf.constant([[[[0.229, 0.224, 0.225]]]]) + return (image / 255.0 - _mean) / _std + +def affine_transform(points,std_points,batchsize): + + # batchsize = points.shape[0] + p_num = points.shape[1] + x = std_points[:,:,0] + y = std_points[:,:,1] + + u = points[:,:,0] + v = points[:,:,1] + + X1 = tf.stack([x,y,tf.ones([batchsize,p_num]),tf.zeros([batchsize,p_num])],axis = 2) + X2 = tf.stack([y,-x,tf.zeros([batchsize,p_num]),tf.ones([batchsize,p_num])],axis = 2) + X = tf.concat([X1,X2],axis = 1) + + U = tf.expand_dims(tf.concat([u,v],axis = 1),2) + + r = tf.squeeze(tf.matrix_solve_ls(X,U),[2]) + sc = r[:,0] + ss = r[:,1] + tx = r[:,2] + ty = r[:,3] + + transform = tf.stack([sc,ss,tx,-ss,sc,ty,tf.zeros([batchsize]),tf.zeros([batchsize])],axis = 1) + t = tf.stack([sc,-ss,tf.zeros([batchsize]),ss,sc,tf.zeros([batchsize]),tx,ty,tf.ones([batchsize])],axis = 1) + t = tf.reshape(t,[-1,3,3]) + t = t + tf.reshape(tf.eye(3),[-1,3,3])*1e-5 + tinv = tf.matrix_inverse(t) + + return t,tinv + +# similarity transformation for images +def similarity_transform(points,batchsize): + + std_points = standard_face_pts_512() + std_points = tf.tile(tf.expand_dims(tf.reshape(std_points,[5,2]),0),[batchsize,1,1]) + + t,tinv = affine_transform(points,std_points,batchsize) + + return t,tinv + +def warp_and_distort(image,transform_matrix_inv,batchsize): + yy = loadmat('./training/pretrained_weights/parsing_net/yy.mat')['grid'] + xx = loadmat('./training/pretrained_weights/parsing_net/xx.mat')['grid'] + yy = tf.constant(yy) + xx = tf.constant(xx) + yy = tf.tile(tf.expand_dims(yy,0),[batchsize,1,1]) + xx = tf.tile(tf.expand_dims(xx,0),[batchsize,1,1]) + + yy = tf.reshape(yy,[-1,512*512]) + xx = tf.reshape(xx,[-1,512*512]) + xxyy_one = tf.stack([xx,yy,tf.ones_like(xx)], axis = 1) #batchx3x(h*w) + transform_matrix_inv = tf.transpose(transform_matrix_inv,perm=[0,2,1]) + xxyy_one = tf.matmul(transform_matrix_inv,xxyy_one) + + xx = tf.reshape(xxyy_one[:,0,:]/xxyy_one[:,2,:], [-1,512,512]) + yy = tf.reshape(xxyy_one[:,1,:]/xxyy_one[:,2,:], [-1,512,512]) + + + warp_image = tf.contrib.resampler.resampler(image,tf.stack([xx,yy],axis = 3)) + + return warp_image + +def preprocess_image_seg(image,lm5p): + batchsize = 1 + t,tinv = similarity_transform(lm5p,batchsize) + warp_image = warp_and_distort(image,t,batchsize) + return warp_image,tinv + +def _meshgrid(h,w): + yy, xx = tf.meshgrid(np.arange(0,h,dtype = np.float32),np.arange(0,w, dtype = np.float32)) + return yy,xx + +def _safe_arctanh(x): + x = tf.clip_by_value(x,-0.999,0.999) + x = tf.math.atanh(x) + return x + +def _distort(yy,xx,h,w,src_h,src_w,rescale = 1.0,distort_lambda = 1.0): + + def _non_linear(a): + nl_part1 = tf.cast(a > (1.0 - distort_lambda),tf.float32) + nl_part2 = tf.cast(a < (-1.0 + 
distort_lambda),tf.float32) + nl_part3 = tf.cast(a == (-1.0 + distort_lambda),tf.float32) + + a_part1 = _safe_arctanh((a - 1.0 + distort_lambda)/distort_lambda)*distort_lambda + 1.0 - distort_lambda + a_part2 = _safe_arctanh((a + 1.0 - distort_lambda)/distort_lambda)*distort_lambda - 1.0 + distort_lambda + + a = a_part1*nl_part1 + a_part2*nl_part2 + a*nl_part3 + return a + + yy = (yy / (h/2.0) - 1.0)*rescale + yy = (_non_linear(yy) + 1.0) * src_h / 2.0 + xx = (xx / (w/2.0) -1.0)*rescale + xx = (_non_linear(xx) + 1.0) * src_w / 2.0 + + return yy,xx + +def _undistort(yy,xx,h,w,src_h,src_w,rescale = 1.0,distort_lambda = 1.0): + + def _non_linear(a): + nl_part1 = tf.cast(a > (1.0 - distort_lambda),tf.float32) + nl_part2 = tf.cast(a < (-1.0 + distort_lambda),tf.float32) + nl_part3 = tf.cast(a == (-1.0 + distort_lambda),tf.float32) + + a_part1 = tf.math.tanh((a - 1.0 + distort_lambda)/distort_lambda)*distort_lambda + 1.0 - distort_lambda + a_part2 = tf.math.tanh((a + 1.0 - distort_lambda)/distort_lambda)*distort_lambda - 1.0 + distort_lambda + + a = a_part1*nl_part1 + a_part2*nl_part2 + a*nl_part3 + return a + + yy = _non_linear(yy / (h/2.0) -1.0) + yy = (yy / rescale + 1.0) *src_h / 2.0 + xx = _non_linear(xx / (w/2.0) -1.0) + xx = (xx / rescale + 1.0) *src_w / 2.0 + + return yy,xx + +def reverse_warp_and_distort(image,transform_matrix): + batchsize = 1 + + yy,xx = _meshgrid(256,256) + + + yy = tf.tile(tf.expand_dims(yy,0),[batchsize,1,1]) + xx = tf.tile(tf.expand_dims(xx,0),[batchsize,1,1]) + yy = tf.reshape(yy,[-1,256*256]) + xx = tf.reshape(xx,[-1,256*256]) + + xxyy_one = tf.stack([xx,yy,tf.ones_like(xx)], axis = 1) #batchx3x(h*w) + transform_matrix = tf.transpose(transform_matrix,perm=[0,2,1]) + xxyy_one = tf.matmul(transform_matrix,xxyy_one) + + xx = tf.reshape(xxyy_one[:,0,:]/xxyy_one[:,2,:], [-1,256,256]) + yy = tf.reshape(xxyy_one[:,1,:]/xxyy_one[:,2,:], [-1,256,256]) + + yy, xx = _undistort(yy,xx,512,512,512,512) + + warp_image = tf.contrib.resampler.resampler(image,tf.stack([xx,yy],axis = 3)) + + return warp_image \ No newline at end of file diff --git a/training/pretrained_weights/parsing_net/xx.mat b/training/pretrained_weights/parsing_net/xx.mat new file mode 100644 index 0000000..68a1391 Binary files /dev/null and b/training/pretrained_weights/parsing_net/xx.mat differ diff --git a/training/pretrained_weights/parsing_net/yy.mat b/training/pretrained_weights/parsing_net/yy.mat new file mode 100644 index 0000000..91df168 Binary files /dev/null and b/training/pretrained_weights/parsing_net/yy.mat differ diff --git a/training/resnet_block.py b/training/resnet_block.py new file mode 100644 index 0000000..d785813 --- /dev/null +++ b/training/resnet_block.py @@ -0,0 +1,166 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +import tensorflow as tf + +_BATCH_NORM_DECAY = 0.997 +_BATCH_NORM_EPSILON = 1e-5 + +def batch_norm_relu(inputs, training, data_format): + """Performs a batch normalization followed by a ReLU.""" + # We set fused=True for a significant performance boost. See + # https://www.tensorflow.org/performance/performance_guide#common_fused_ops + inputs = tf.layers.batch_normalization( + inputs=inputs, axis=1 if data_format == 'channels_first' else 3, + momentum=_BATCH_NORM_DECAY, epsilon=_BATCH_NORM_EPSILON, center=True, + scale=True, training=training, fused=True) + inputs = tf.nn.relu(inputs) + return inputs + + +def batch_norm(inputs, training, data_format): + """Performs a batch normalization followed by a ReLU.""" + # We set fused=True for a significant performance boost. See + # https://www.tensorflow.org/performance/performance_guide#common_fused_ops + inputs = tf.layers.batch_normalization( + inputs=inputs, axis=1 if data_format == 'channels_first' else 3, + momentum=_BATCH_NORM_DECAY, epsilon=_BATCH_NORM_EPSILON, center=True, + scale=True, training=training, fused=True) + return inputs + + +def fixed_padding(inputs, kernel_size, data_format): + """Pads the input along the spatial dimensions independently of input size. + + Args: + inputs: A tensor of size [batch, channels, height_in, width_in] or + [batch, height_in, width_in, channels] depending on data_format. + kernel_size: The kernel to be used in the conv2d or max_pool2d operation. + Should be a positive integer. + data_format: The input format ('channels_last' or 'channels_first'). + + Returns: + A tensor with the same format as the input with the data either intact + (if kernel_size == 1) or padded (if kernel_size > 1). + """ + pad_total = kernel_size - 1 + pad_beg = pad_total // 2 + pad_end = pad_total - pad_beg + + if data_format == 'channels_first': + padded_inputs = tf.pad(inputs, [[0, 0], [0, 0], + [pad_beg, pad_end], [pad_beg, pad_end]]) + else: + padded_inputs = tf.pad(inputs, [[0, 0], [pad_beg, pad_end], + [pad_beg, pad_end], [0, 0]]) + return padded_inputs + + +def conv2d_fixed_padding(inputs, filters, kernel_size, strides, data_format,use_bias = False): + """Strided 2-D convolution with explicit padding.""" + # The padding is consistent and is based only on `kernel_size`, not on the + # dimensions of `inputs` (as opposed to using `tf.layers.conv2d` alone). + if strides > 1: + inputs = fixed_padding(inputs, kernel_size, data_format) + + return tf.layers.conv2d( + inputs=inputs, filters=filters, kernel_size=kernel_size, strides=strides, + padding=('SAME' if strides == 1 else 'VALID'), use_bias=use_bias, + kernel_initializer=tf.variance_scaling_initializer(), + data_format=data_format) + + + +def building_block(inputs, filters, training, projection_shortcut, strides, + data_format): + """Standard building block for residual networks with BN before convolutions. + + Args: + inputs: A tensor of size [batch, channels, height_in, width_in] or + [batch, height_in, width_in, channels] depending on data_format. + filters: The number of filters for the convolutions. + training: A Boolean for whether the model is in training or inference + mode. Needed for batch normalization. + projection_shortcut: The function to use for projection shortcuts + (typically a 1x1 convolution when downsampling the input). + strides: The block's stride. If greater than 1, this block will ultimately + downsample the input. 
+ data_format: The input format ('channels_last' or 'channels_first'). + + Returns: + The output tensor of the block. + """ + shortcut = inputs + # inputs = batch_norm_relu(inputs, training, data_format) + + # The projection shortcut should come after the first batch norm and ReLU + # since it performs a 1x1 convolution. + if projection_shortcut is not None: + shortcut = projection_shortcut(inputs) + + inputs = conv2d_fixed_padding( + inputs=inputs, filters=filters, kernel_size=3, strides=strides, + data_format=data_format) + + inputs = batch_norm_relu(inputs, training, data_format) + inputs = conv2d_fixed_padding( + inputs=inputs, filters=filters, kernel_size=3, strides=1, + data_format=data_format) + + inputs = batch_norm(inputs, training, data_format) + + return tf.nn.relu(inputs + shortcut) + + +def block_layer(inputs, filters, blocks, strides, training, + data_format): + """Creates one layer of blocks for the ResNet model. + + Args: + inputs: A tensor of size [batch, channels, height_in, width_in] or + [batch, height_in, width_in, channels] depending on data_format. + filters: The number of filters for the first convolution of the layer. + block_fn: The block to use within the model, either `building_block` or + `bottleneck_block`. + blocks: The number of blocks contained in the layer. + strides: The stride to use for the first convolution of the layer. If + greater than 1, this layer will ultimately downsample the input. + training: Either True or False, whether we are currently training the + model. Needed for batch norm. + name: A string name for the tensor output of the block layer. + data_format: The input format ('channels_last' or 'channels_first'). + + Returns: + The output tensor of the block layer. + """ + # Bottleneck blocks end with 4x the number of filters as they start with + filters_out = filters + + def projection_shortcut(inputs): + with tf.variable_scope('downsample'): + inputs = conv2d_fixed_padding( + inputs=inputs, filters=filters_out, kernel_size=1, strides=strides, + data_format=data_format) + inputs = batch_norm(inputs,training = training, data_format = data_format) + + return inputs + + # Only the first block per block_layer uses projection_shortcut and strides + inputs = building_block(inputs, filters, training, projection_shortcut, strides, + data_format) + + for _ in range(1, blocks): + inputs = building_block(inputs, filters, training, None, 1, data_format) + + return inputs \ No newline at end of file diff --git a/training/training_loop.py b/training/training_loop.py new file mode 100644 index 0000000..0e38fcf --- /dev/null +++ b/training/training_loop.py @@ -0,0 +1,301 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. + +"""Main training script.""" + +import os +import numpy as np +import tensorflow as tf +import dnnlib +import dnnlib.tflib as tflib +import cv2 +from dnnlib.tflib.autosummary import autosummary + +import config +import train +from training import dataset +from training import misc +from metrics import metric_base +#--------------------------------------------------------------- +# Modified by Deng et al. 
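+# Note: Face3D (renderer/face_decoder.py) is the differentiable 3D face reconstruction and
+# rendering block whose output the imitative losses compare against, and training_utils
+# supplies the imitative/contrastive training stages plus the z-to-lambda mapping used
+# later in this file.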
+from renderer.face_decoder import Face3D +from training.training_utils import * +#--------------------------------------------------------------- + + +#---------------------------------------------------------------------------- +# Evaluate time-varying training parameters. + +def training_schedule( + cur_nimg, + training_set, + num_gpus, + lod_initial_resolution = 4, # Image resolution used at the beginning. + lod_training_kimg = 600, # Thousands of real images to show before doubling the resolution. + lod_transition_kimg = 600, # Thousands of real images to show when fading in new layers. + minibatch_base = 16, # Maximum minibatch size, divided evenly among GPUs. + minibatch_dict = {}, # Resolution-specific overrides. + max_minibatch_per_gpu = {}, # Resolution-specific maximum minibatch size per GPU. + G_lrate_base = 0.001, # Learning rate for the generator. + G_lrate_dict = {}, # Resolution-specific overrides. + D_lrate_base = 0.001, # Learning rate for the discriminator. + D_lrate_dict = {}, # Resolution-specific overrides. + lrate_rampup_kimg = 0, # Duration of learning rate ramp-up. + tick_kimg_base = 160, # Default interval of progress snapshots. + tick_kimg_dict = {4: 160, 8:140, 16:120, 32:100, 64:80, 128:60, 256:40, 512:30, 1024:20}): # Resolution-specific overrides. + + # Initialize result dict. + s = dnnlib.EasyDict() + s.kimg = cur_nimg / 1000.0 + + # Training phase. + phase_dur = lod_training_kimg + lod_transition_kimg + phase_idx = int(np.floor(s.kimg / phase_dur)) if phase_dur > 0 else 0 + phase_kimg = s.kimg - phase_idx * phase_dur + + # Level-of-detail and resolution. + s.lod = training_set.resolution_log2 + s.lod -= np.floor(np.log2(lod_initial_resolution)) + s.lod -= phase_idx + if lod_transition_kimg > 0: + s.lod -= max(phase_kimg - lod_training_kimg, 0.0) / lod_transition_kimg + s.lod = max(s.lod, 0.0) + s.resolution = 2 ** (training_set.resolution_log2 - int(np.floor(s.lod))) + + # Minibatch size. + s.minibatch = minibatch_dict.get(s.resolution, minibatch_base) + s.minibatch -= s.minibatch % num_gpus + if s.resolution in max_minibatch_per_gpu: + s.minibatch = min(s.minibatch, max_minibatch_per_gpu[s.resolution] * num_gpus) + + # Learning rate. + s.G_lrate = G_lrate_dict.get(s.resolution, G_lrate_base) + s.D_lrate = D_lrate_dict.get(s.resolution, D_lrate_base) + if lrate_rampup_kimg > 0: + rampup = min(s.kimg / lrate_rampup_kimg, 1.0) + s.G_lrate *= rampup + s.D_lrate *= rampup + + # Other parameters. + s.tick_kimg = tick_kimg_dict.get(s.resolution, tick_kimg_base) + return s + +#---------------------------------------------------------------------------- +# Main training script. + +def training_loop( + submit_config, + #--------------------------------------------------------------- + # Modified by Deng et al. + noise_dim = 32, + weight_args = {}, + train_stage_args = {}, + #--------------------------------------------------------------- + G_args = {}, # Options for generator network. + D_args = {}, # Options for discriminator network. + G_opt_args = {}, # Options for generator optimizer. + D_opt_args = {}, # Options for discriminator optimizer. + G_loss_args = {}, # Options for generator loss. + D_loss_args = {}, # Options for discriminator loss. + dataset_args = {}, # Options for dataset.load_dataset(). + sched_args = {}, # Options for train.TrainingSchedule. + grid_args = {}, # Options for train.setup_snapshot_image_grid(). + metric_arg_list = [], # Options for MetricGroup. + tf_config = {}, # Options for tflib.init_tf(). 
+ G_smoothing_kimg = 10.0, # Half-life of the running average of generator weights. + D_repeats = 1, # How many times the discriminator is trained per G iteration. + minibatch_repeats = 4, # Number of minibatches to run before adjusting training parameters. + reset_opt_for_new_lod = True, # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced? + total_kimg = 15000, # Total length of the training, measured in thousands of real images. + mirror_augment = True, # Enable mirror augment? + drange_net = [-1,1], # Dynamic range used when feeding image data to the networks. + image_snapshot_ticks = 1, # How often to export image snapshots? + network_snapshot_ticks = 10, # How often to export network snapshots? + save_tf_graph = True, # Include full TensorFlow computation graph in the tfevents file? + save_weight_histograms = False, # Include weight histograms in the tfevents file? + resume_run_id = 87, # Run ID or network pkl to resume training from, None = start from scratch. + resume_snapshot = 2364, # Snapshot index to resume training from, None = autodetect. + resume_kimg = 2364, # Assumed training progress at the beginning. Affects reporting and training schedule. + resume_time = 0.0, + **_kwargs): # Assumed wallclock time at the beginning. Affects reporting. + + # Initialize dnnlib and TensorFlow. + PI = 3.1415927 + ctx = dnnlib.RunContext(submit_config, train) + tflib.init_tf(tf_config) + + # Load training set. + training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args) + # Create 3d face reconstruction block + FaceRender = Face3D() + + # Construct networks. + with tf.device('/gpu:0'): + if resume_run_id is not None: + network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot) + print('Loading networks from "%s"...' % network_pkl) + G, D, Gs = misc.load_pkl(network_pkl) + else: + print('Constructing networks...') + #--------------------------------------------------------------- + # Modified by Deng et al. 
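+            # Note: latent_size = 254 + noise_dim. The 254 dimensions are the lambda-space
+            # coefficients produced by z_to_lambda_mapping in training_utils.py
+            # (160 identity + 64 expression + 3 rotation + 27 gamma/illumination), and the
+            # extra noise_dim dimensions (32 by default) are free noise appended to the
+            # generator input.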
+ G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, latent_size = 254+noise_dim, **G_args) + #--------------------------------------------------------------- + D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args) + Gs = G.clone('Gs') + G.print_layers(); D.print_layers() + + print('Building TensorFlow graph...') + with tf.name_scope('Inputs'), tf.device('/cpu:0'): + lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[]) + resolution = tf.placeholder(tf.float32, name='resolution', shape=[]) + lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[]) + minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[]) + minibatch_split = minibatch_in // submit_config.num_gpus + Gs_beta = 0.5 ** tf.div(tf.cast(minibatch_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0 + + G_opt = tflib.Optimizer(name='TrainG', learning_rate=lrate_in, **G_opt_args) + D_opt = tflib.Optimizer(name='TrainD', learning_rate=lrate_in, **D_opt_args) + for gpu in range(submit_config.num_gpus): + with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % (gpu)): + G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow') + D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow') + lod_assign_ops = [tf.assign(G_gpu.find_var('lod'), lod_in), tf.assign(D_gpu.find_var('lod'), lod_in)] + reals, labels = training_set.get_minibatch_tf() + reals = process_reals(reals, lod_in, mirror_augment, training_set.dynamic_range, drange_net) + + #--------------------------------------------------------------- + # Modified by Deng et al. + G_loss,D_loss = dnnlib.util.call_func_by_name(FaceRender=FaceRender,noise_dim=noise_dim,weight_args=weight_args,\ + G_gpu=G_gpu,D_gpu=D_gpu,G_opt=G_opt,D_opt=D_opt,training_set=training_set,G_loss_args=G_loss_args,D_loss_args=D_loss_args,\ + lod_assign_ops=lod_assign_ops,reals=reals,labels=labels,minibatch_split=minibatch_split,resolution=resolution,\ + drange_net=drange_net,lod_in=lod_in,**train_stage_args) + #--------------------------------------------------------------- + + G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables) + D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables) + G_train_op = G_opt.apply_updates() + D_train_op = D_opt.apply_updates() + + Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta) + with tf.device('/gpu:0'): + try: + peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse() + except tf.errors.NotFoundError: + peak_gpu_mem_op = tf.constant(0) + + #--------------------------------------------------------------- + # Modified by Deng et al. 
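+    # Note: restore_weights_and_initialize() in training/training_utils.py loads the
+    # pretrained 3D reconstruction network (FaceReconModel.ckpt), the FaceNet identity
+    # network, the four per-factor VAE decoders (id/exp/gamma/rot), and, when training
+    # stage 2 is selected, the face parsing network.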
+ restore_weights_and_initialize(train_stage_args) + + print('Setting up snapshot image grid...') + sched = training_schedule(cur_nimg=total_kimg*1000, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args) + + grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(G, training_set, **grid_args) + grid_latents = tf.random_normal([np.prod(grid_size),128+32+16+3]) + grid_INPUTcoeff = z_to_lambda_mapping(grid_latents) + grid_INPUTcoeff_w_t = tf.concat([grid_INPUTcoeff,tf.zeros([np.prod(grid_size),3])], axis = 1) + with tf.name_scope('FaceRender'): + grid_render_img,_,_,_ = FaceRender.Reconstruction_Block(grid_INPUTcoeff_w_t,256,np.prod(grid_size),progressive=False) + grid_render_img = tf.transpose(grid_render_img,perm=[0,3,1,2]) + grid_render_img = process_reals(grid_render_img, lod_in, False, training_set.dynamic_range, drange_net) + + grid_INPUTcoeff_,grid_renders = tflib.run([grid_INPUTcoeff,grid_render_img],{lod_in:sched.lod}) + grid_noise = np.random.randn(np.prod(grid_size),32) + grid_INPUTcoeff_w_noise = np.concatenate([grid_INPUTcoeff_,grid_noise],axis = 1) + + grid_fakes = Gs.run(grid_INPUTcoeff_w_noise, grid_labels, is_validation=True, minibatch_size=sched.minibatch//submit_config.num_gpus) + grid_fakes = np.concatenate([grid_fakes,grid_renders],axis = 3) + misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % resume_kimg), drange=drange_net, grid_size=grid_size) + misc.save_image_grid(grid_reals, os.path.join(submit_config.run_dir, 'reals.png'), drange=training_set.dynamic_range, grid_size=grid_size) + #--------------------------------------------------------------- + + summary_log = tf.summary.FileWriter(submit_config.run_dir) + if save_tf_graph: + summary_log.add_graph(tf.get_default_graph()) + if save_weight_histograms: + G.setup_weight_histograms(); D.setup_weight_histograms() + metrics = metric_base.MetricGroup(metric_arg_list) + + + print('Training...\n') + ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg) + maintenance_time = ctx.get_last_update_interval() + cur_nimg = int(resume_kimg * 1000) + cur_tick = 0 + tick_start_nimg = cur_nimg + prev_lod = -1.0 + + while cur_nimg < total_kimg * 1000: + if ctx.should_stop(): break + + # Choose training parameters and configure training ops. + sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args) + training_set.configure(sched.minibatch // submit_config.num_gpus, sched.lod) + if reset_opt_for_new_lod: + if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod): + G_opt.reset_optimizer_state(); D_opt.reset_optimizer_state() + prev_lod = sched.lod + + # Run training ops. + for _mb_repeat in range(minibatch_repeats): + for _D_repeat in range(D_repeats): + tflib.run([D_train_op, Gs_update_op], {lod_in: sched.lod, lrate_in: sched.D_lrate, minibatch_in: sched.minibatch, resolution: sched.resolution}) + cur_nimg += sched.minibatch + tflib.run([G_train_op], {lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_in: sched.minibatch, resolution: sched.resolution}) + + # print('iter') + # Perform maintenance tasks once per tick. + done = (cur_nimg >= total_kimg * 1000) + if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done: + cur_tick += 1 + tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0 + tick_start_nimg = cur_nimg + tick_time = ctx.get_time_since_last_update() + total_time = ctx.get_time_since_start() + resume_time + + # Report progress. 
+ print('tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f' % ( + autosummary('Progress/tick', cur_tick), + autosummary('Progress/kimg', cur_nimg / 1000.0), + autosummary('Progress/lod', sched.lod), + autosummary('Progress/minibatch', sched.minibatch), + dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)), + autosummary('Timing/sec_per_tick', tick_time), + autosummary('Timing/sec_per_kimg', tick_time / tick_kimg), + autosummary('Timing/maintenance_sec', maintenance_time), + autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30))) + autosummary('Timing/total_hours', total_time / (60.0 * 60.0)) + autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0)) + + # Save snapshots. + if cur_tick % image_snapshot_ticks == 0 or done: + #--------------------------------------------------------------- + # Modified by Deng et al. + grid_fakes = Gs.run(grid_INPUTcoeff_w_noise, grid_labels, is_validation=True, minibatch_size=sched.minibatch//submit_config.num_gpus) + grid_fakes = np.concatenate([grid_fakes,grid_renders],axis = 3) + misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size) + #--------------------------------------------------------------- + + if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1: + pkl = os.path.join(submit_config.run_dir, 'network-snapshot-%06d.pkl' % (cur_nimg // 1000)) + misc.save_pkl((G, D, Gs), pkl) + metrics.run(pkl, run_dir=submit_config.run_dir, num_gpus=submit_config.num_gpus, tf_config=tf_config) + + # Update summaries and RunContext. + metrics.update_autosummaries() + tflib.autosummary.save_summaries(summary_log, cur_nimg) + ctx.update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg) + maintenance_time = ctx.get_last_update_interval() - tick_time + + # Write final results. + misc.save_pkl((G, D, Gs), os.path.join(submit_config.run_dir, 'network-final.pkl')) + summary_log.close() + + ctx.close() + +#---------------------------------------------------------------------------- \ No newline at end of file diff --git a/training/training_utils.py b/training/training_utils.py new file mode 100644 index 0000000..cc49566 --- /dev/null +++ b/training/training_utils.py @@ -0,0 +1,297 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +# Imitative-contrastive training scheme +import os +import numpy as np +import tensorflow as tf +import dnnlib +import dnnlib.tflib as tflib +from training import misc +from training.loss import * +from training.loss_control import * +from training.networks_stylegan import CoeffDecoder + +def process_reals(x, lod, mirror_augment, drange_data, drange_net): + with tf.name_scope('ProcessReals'): + with tf.name_scope('DynamicRange'): + x = tf.cast(x, tf.float32) + x = misc.adjust_dynamic_range(x, drange_data, drange_net) + if mirror_augment: + with tf.name_scope('MirrorAugment'): + s = tf.shape(x) + mask = tf.random_uniform([s[0], 1, 1, 1], 0.0, 1.0) + mask = tf.tile(mask, [1, s[1], s[2], s[3]]) + x = tf.where(mask < 0.5, x, tf.reverse(x, axis=[3])) + with tf.name_scope('FadeLOD'): # Smooth crossfade between consecutive levels-of-detail. 
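+        # Note: the fade uses the fractional part of lod, e.g. lod = 2.3 blends the reals
+        # 30% toward their 2x2 box-filtered version; UpscaleLOD below then tiles pixels by
+        # 2**floor(lod) so the tensor keeps the network's full input resolution.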
+ s = tf.shape(x) + y = tf.reshape(x, [-1, s[1], s[2]//2, 2, s[3]//2, 2]) + y = tf.reduce_mean(y, axis=[3, 5], keepdims=True) + y = tf.tile(y, [1, 1, 1, 2, 1, 2]) + y = tf.reshape(y, [-1, s[1], s[2], s[3]]) + x = tflib.lerp(x, y, lod - tf.floor(lod)) + with tf.name_scope('UpscaleLOD'): # Upscale to match the expected input/output size of the networks. + s = tf.shape(x) + factor = tf.cast(2 ** tf.floor(lod), tf.int32) + x = tf.reshape(x, [-1, s[1], s[2], 1, s[3], 1]) + x = tf.tile(x, [1, 1, 1, factor, 1, factor]) + x = tf.reshape(x, [-1, s[1], s[2] * factor, s[3] * factor]) + return x + +def restore_weights_and_initialize(train_stage_args): + var_list = tf.trainable_variables() + g_list = tf.global_variables() + + # add batch normalization params into trainable variables + bn_moving_vars = [g for g in g_list if 'moving_mean' in g.name] + bn_moving_vars += [g for g in g_list if 'moving_variance' in g.name] + var_list +=bn_moving_vars + + var_id_list = [v for v in var_list if 'id' in v.name and 'stage1' in v.name] + var_exp_list = [v for v in var_list if 'exp' in v.name and 'stage1' in v.name] + var_gamma_list = [v for v in var_list if 'gamma' in v.name and 'stage1' in v.name] + var_rot_list = [v for v in var_list if 'rot' in v.name and 'stage1' in v.name] + + resnet_vars = [v for v in var_list if 'resnet_v1_50' in v.name] + res_fc = [v for v in var_list if 'fc-id' in v.name or 'fc-ex' in v.name or 'fc-tex' in v.name or 'fc-angles' in v.name or 'fc-gamma' in v.name or 'fc-XY' in v.name or 'fc-Z' in v.name] + resnet_vars += res_fc + + facerec_vars = [v for v in var_list if 'InceptionResnetV1' in v.name] + + + saver_resnet = tf.train.Saver(var_list = resnet_vars) + saver_facerec = tf.train.Saver(var_list = facerec_vars) + saver_id = tf.train.Saver(var_list = var_id_list,max_to_keep = 100) + saver_exp = tf.train.Saver(var_list = var_exp_list,max_to_keep = 100) + saver_gamma = tf.train.Saver(var_list = var_gamma_list,max_to_keep = 100) + saver_rot = tf.train.Saver(var_list = var_rot_list,max_to_keep = 100) + + + saver_resnet.restore(tf.get_default_session(),os.path.join('./training/pretrained_weights/recon_net','FaceReconModel.ckpt')) + saver_facerec.restore(tf.get_default_session(),'./training/pretrained_weights/id_net/model-20170512-110547.ckpt-250000') + saver_id.restore(tf.get_default_session(),'./vae/weights/id/stage1_epoch_395.ckpt') + saver_exp.restore(tf.get_default_session(),'./vae/weights/exp/stage1_epoch_395.ckpt') + saver_gamma.restore(tf.get_default_session(),'./vae/weights/gamma/stage1_epoch_395.ckpt') + saver_rot.restore(tf.get_default_session(),'./vae/weights/rot/stage1_epoch_395.ckpt') + + if train_stage_args.func_name == 'training.training_utils.training_stage2': + parser_vars = [v for v in var_list if 'FaceParser' in v.name] + saver_parser = tf.train.Saver(var_list = parser_vars) + saver_parser.restore(tf.get_default_session(),os.path.join('./training/pretrained_weights/parsing_net','faceparser_public')) + + +#---------------------------------------------------------------------------- +# stage 1: train with imitative losses +def training_stage1( + FaceRender, + noise_dim, + weight_args, + G_gpu, + D_gpu, + G_opt, + D_opt, + training_set, + G_loss_args, + D_loss_args, + lod_assign_ops, + reals, + labels, + minibatch_split, + resolution, + drange_net, + lod_in): + + print('Stage1: Imitative learning...\n') + G_loss,D_loss = imitative_learning(FaceRender,noise_dim,weight_args,G_gpu,D_gpu,G_opt,D_opt,training_set, G_loss_args,\ + 
D_loss_args,lod_assign_ops,reals,labels,minibatch_split,resolution,drange_net,lod_in) + + return G_loss,D_loss + +# stage 2: train with imitative losses and contrastive losses +def training_stage2( + FaceRender, + noise_dim, + weight_args, + G_gpu, + D_gpu, + G_opt, + D_opt, + training_set, + G_loss_args, + D_loss_args, + lod_assign_ops, + reals, + labels, + minibatch_split, + resolution, + drange_net, + lod_in): + + + print('Stage2: Imitative learning and contrastive learning...\n') + G_loss1,D_loss1 = contrastive_learning(FaceRender,noise_dim,weight_args,G_gpu,D_gpu,G_opt,D_opt,training_set, G_loss_args,\ + D_loss_args,lod_assign_ops,reals,labels,minibatch_split,resolution,drange_net,lod_in) + + G_loss2,D_loss2 = imitative_learning(FaceRender,noise_dim,weight_args,G_gpu,D_gpu,G_opt,D_opt,training_set, G_loss_args,\ + D_loss_args,lod_assign_ops,reals,labels,minibatch_split,resolution,drange_net,lod_in) + + + G_loss = G_loss1 + G_loss2 + D_loss = D_loss1 + D_loss2 + + return G_loss,D_loss + + +# Mapping z sampled from normal distribution to lambda space variables with physical meanings +def z_to_lambda_mapping(latents): + with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE): + with tf.variable_scope('id'): + IDcoeff = CoeffDecoder(z = latents[:,:128],coeff_length = 160,ch_dim = 512, ch_depth = 3) + with tf.variable_scope('exp'): + EXPcoeff = CoeffDecoder(z = latents[:,128:128+32],coeff_length = 64,ch_dim = 256, ch_depth = 3) + with tf.variable_scope('gamma'): + GAMMAcoeff = CoeffDecoder(z = latents[:,128+32:128+32+16],coeff_length = 27,ch_dim = 128, ch_depth = 3) + with tf.variable_scope('rot'): + Rotcoeff = CoeffDecoder(z = latents[:,128+32+16:128+32+16+3],coeff_length = 3,ch_dim = 32, ch_depth = 3) + + INPUTcoeff = tf.concat([IDcoeff,EXPcoeff,Rotcoeff,GAMMAcoeff], axis = 1) + + return INPUTcoeff + + +def imitative_learning( + FaceRender, + noise_dim, + weight_args, + G_gpu, + D_gpu, + G_opt, + D_opt, + training_set, + G_loss_args, + D_loss_args, + lod_assign_ops, + reals, + labels, + minibatch_split, + resolution, + drange_net, + lod_in): + + latents = tf.random_normal([minibatch_split,128+32+16+3]) + INPUTcoeff = z_to_lambda_mapping(latents) + + noise_coeff = tf.random_normal([minibatch_split,noise_dim]) + + INPUTcoeff_w_noise = tf.concat([INPUTcoeff,noise_coeff], axis = 1) + INPUTcoeff_w_t = tf.concat([INPUTcoeff,tf.zeros([minibatch_split,3])], axis = 1) + + with tf.name_scope('FaceRender'): + render_img,render_mask,render_landmark,_ = FaceRender.Reconstruction_Block(INPUTcoeff_w_t,resolution,minibatch_split,progressive=True) + render_img = tf.transpose(render_img,perm=[0,3,1,2]) + render_mask = tf.transpose(render_mask,perm=[0,3,1,2]) + render_img = process_reals(render_img, lod_in, False, training_set.dynamic_range, drange_net) + render_mask = process_reals(render_mask, lod_in, False, drange_net, drange_net) + + render_mask = tf.squeeze(render_mask,axis = 1) + + + with tf.name_scope('G_loss'), tf.control_dependencies(lod_assign_ops): + G_loss,fake_images = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, latents = INPUTcoeff_w_noise, opt=G_opt, training_set=training_set, minibatch_size=minibatch_split, **G_loss_args) + l1_loss = L1_loss(render_img,fake_images,render_mask) + skin_color_loss = Skin_color_loss(fake_images,render_img,render_mask) + lm_loss, gamma_loss = Reconstruction_loss(fake_images,render_landmark,INPUTcoeff,FaceRender) + id_loss = ID_loss(render_img,fake_images,render_mask) + + add_loss = tf.cond(resolution<=32, lambda:l1_loss*20., + 
lambda:lm_loss*weight_args.w_lm + gamma_loss*weight_args.w_gamma + id_loss*weight_args.w_id + skin_color_loss*weight_args.w_skin) + + G_loss += add_loss + with tf.name_scope('D_loss'), tf.control_dependencies(lod_assign_ops): + D_loss = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, latents = INPUTcoeff_w_noise, opt=D_opt, training_set=training_set, minibatch_size=minibatch_split, reals=reals, labels=labels, **D_loss_args) + + return G_loss,D_loss + + +def contrastive_learning( + FaceRender, + noise_dim, + weight_args, + G_gpu, + D_gpu, + G_opt, + D_opt, + training_set, + G_loss_args, + D_loss_args, + lod_assign_ops, + reals, + labels, + minibatch_split, + resolution, + drange_net, + lod_in): + + # expression change pair + latents_id = tf.tile(tf.random_normal([1,128]),[2,1]) + latents_exp = tf.random_normal([2,32]) + latents_gamma = tf.tile(tf.random_normal([1,16]),[2,1]) + latents_rot = tf.tile(tf.random_normal([1,3]),[2,1]) + latents_exp_pair = tf.concat([latents_id,latents_exp,latents_gamma,latents_rot], axis = 1) + + # lighting change pair + latents_id = tf.tile(tf.random_normal([1,128]),[2,1]) + latents_exp = tf.tile(tf.random_normal([1,32]),[2,1]) + latents_gamma = tf.random_normal([2,16]) + latents_rot = tf.tile(tf.random_normal([1,3]),[2,1]) + latents_gamma_pair = tf.concat([latents_id,latents_exp,latents_gamma,latents_rot], axis = 1) + + latents = tf.concat([latents_exp_pair,latents_gamma_pair],axis = 0) + INPUTcoeff = z_to_lambda_mapping(latents) + + noise_coeff = tf.random_normal([1,noise_dim]) + noise_coeff1 = tf.tile(noise_coeff,[2,1]) + noise_coeff = tf.random_normal([1,noise_dim]) + noise_coeff2 = tf.tile(noise_coeff,[2,1]) + noise_coeff_ = tf.concat([noise_coeff1,noise_coeff2],axis = 0) + + INPUTcoeff_w_noise = tf.concat([INPUTcoeff,noise_coeff_], axis = 1) + INPUTcoeff_w_t = tf.concat([INPUTcoeff,tf.zeros([4,3])], axis = 1) + + with tf.name_scope('FaceRender'): + render_img,render_mask,render_landmark,render_shape = FaceRender.Reconstruction_Block(INPUTcoeff_w_t,res=256,batchsize=4,progressive=False) + render_img = tf.transpose(render_img,perm=[0,3,1,2]) + render_mask = tf.transpose(render_mask,perm=[0,3,1,2]) + render_img = process_reals(render_img, lod_in, False, training_set.dynamic_range, drange_net) + render_mask = process_reals(render_mask, lod_in, False, drange_net, drange_net) + render_mask = tf.squeeze(render_mask,axis = 1) + + shape1 = tf.expand_dims(render_shape[0],0) + shape2 = tf.expand_dims(render_shape[1],0) + mask1 = tf.expand_dims(render_mask[0],0) + mask2 = tf.expand_dims(render_mask[1],0) + + with tf.name_scope('G_loss'), tf.control_dependencies(lod_assign_ops): + G_loss,fake_images = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, latents = INPUTcoeff_w_noise, opt=G_opt, training_set=training_set, minibatch_size=minibatch_split, **G_loss_args) + + fake1 = tf.expand_dims(fake_images[0],0) + fake1 = tf.transpose(fake1,perm=[0,2,3,1]) + fake2 = tf.expand_dims(fake_images[1],0) + fake2 = tf.transpose(fake2,perm=[0,2,3,1]) + exp_warp_loss = Exp_warp_loss(fake1,fake2,shape1,shape2,mask1,mask2,FaceRender) + + fake3 = tf.expand_dims(fake_images[2],0) + fake3 = tf.transpose(fake3,perm=[0,2,3,1]) + fake4 = tf.expand_dims(fake_images[3],0) + fake4 = tf.transpose(fake4,perm=[0,2,3,1]) + gamma_change_loss = Gamma_change_loss(fake3,fake4,FaceRender) + + add_loss = weight_args.w_exp_warp*exp_warp_loss + weight_args.w_gamma_change*gamma_change_loss + G_loss += add_loss + + with tf.name_scope('D_loss'), tf.control_dependencies(lod_assign_ops): + D_loss = 
dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, latents = INPUTcoeff_w_noise, opt=D_opt, training_set=training_set, minibatch_size=minibatch_split, reals=reals, labels=labels, **D_loss_args) + + return G_loss, D_loss + + diff --git a/vae/data_loader.py b/vae/data_loader.py new file mode 100644 index 0000000..da02661 --- /dev/null +++ b/vae/data_loader.py @@ -0,0 +1,93 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from PIL import Image +import threading +import queue as queue # for python 3.x +import random +import numpy as np +import glob +import time +import os +import cv2 +from scipy.io import loadmat + +class DataFetchWorker: + def __init__(self, factor,path, batch_size, shuffle=True): + self.factor = factor + self.path = path + self.batch_size = batch_size + self.shuffle = shuffle + + self.thread_train = [] + self.thread_val = [] + self.queue_train = queue.Queue(20) + self.queue_val = queue.Queue(20) + self.current_idx_train = 0 + self.current_idx_val = 0 + self.stopped = True + + self.list_train = [] + self.list_train = glob.glob(os.path.join(self.path,'*.mat')) + + + self.total_subj_train = len(self.list_train) + + if self.shuffle: + random.shuffle(self.list_train) + + + def run(self): + self.stopped = False + + self.thread_train = threading.Thread(target=self.fill_train_batch) + self.thread_train.setDaemon(True) + self.thread_train.start() + + def process_data(self,coeff): + if self.factor == 'id': + input_coeff = coeff[0,:160].astype(np.float32) + elif self.factor == 'exp': + input_coeff = coeff[0,160:224].astype(np.float32) + elif self.factor == 'rot': + input_coeff = coeff[0,224:227].astype(np.float32) + elif self.factor == 'gamma': + input_coeff = coeff[0,227:254].astype(np.float32) + else: + raise Exception('invalid factor') + + return input_coeff + + def get_sets(self, subjs): + coeff_sets = [] + for subj in subjs: + + data = loadmat(subj) + coeff = data['coeff'] + coeff_ = self.process_data(coeff) + coeff_sets.append(coeff_) + + return coeff_sets + + def fill_train_batch(self): + while not self.stopped: + indices = np.array(range(self.current_idx_train, self.current_idx_train+self.batch_size)) % self.total_subj_train + subjs = [self.list_train[subj_idx] for subj_idx in indices] + coeff_sets = self.get_sets(subjs) + self.queue_train.put(np.asarray(coeff_sets),\ + block=True, timeout=None) + self.current_idx_train = (self.current_idx_train + self.batch_size) % self.total_subj_train + + + def fetch_train_batch(self): + coeff_sets = self.queue_train.get(block=True, timeout=None) + return coeff_sets + + def stop(self): + self.stopped = True + + while not self.queue_train.empty(): + self.queue_train.get(block=False) + time.sleep(0.1) + while not self.queue_train.empty(): + self.queue_train.get(block=False) \ No newline at end of file diff --git a/vae/demo.py b/vae/demo.py new file mode 100644 index 0000000..74c2c34 --- /dev/null +++ b/vae/demo.py @@ -0,0 +1,111 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +# https://github.com/daib13/TwoStageVAE +import argparse +import os +from two_stage_vae_model import * +import numpy as np +import tensorflow as tf +import math +import time +from data_loader import * + + +def main(): + tf.reset_default_graph() + exp_folder = os.path.join(args.output_path, args.factor) + if not os.path.exists(exp_folder): + os.makedirs(exp_folder) + model_path = os.path.join(exp_folder, 'from_scratch') + if not os.path.exists(model_path): + os.makedirs(model_path) + + # train VAE for different factors + if args.factor == 'id': + coeff_dim = 160 + latent_dim = 128 + ch_dim = 512 + ch_depth = 3 + elif args.factor == 'exp': + coeff_dim = 64 + latent_dim = 32 + ch_dim = 256 + ch_depth = 3 + elif args.factor == 'gamma': + coeff_dim = 27 + latent_dim = 16 + ch_dim = 128 + ch_depth = 3 + else: + coeff_dim = 3 + latent_dim = 3 + ch_dim = 32 + ch_depth = 3 + + # input + input_x = tf.placeholder(tf.float32, [args.batch_size, coeff_dim], 'x') + data_worker = DataFetchWorker(factor=args.factor,path=args.datapath, batch_size=args.batch_size, shuffle=True) + num_sample = data_worker.total_subj_train + print(num_sample) + + # model + with tf.variable_scope(args.factor): + model = MLP(input_x, latent_dim, ch_dim, ch_depth, args.cross_entropy_loss) + + sess = tf.InteractiveSession() + sess.run(tf.global_variables_initializer()) + writer = tf.summary.FileWriter(exp_folder, sess.graph) + saver = tf.train.Saver() + + + + # train model + iteration_per_epoch = np.ceil(num_sample / args.batch_size).astype(np.int32) + if not args.val: + # first stage + data_worker.run() + + for epoch in range(args.epochs): + lr = args.lr if args.lr_epochs <= 0 else args.lr * math.pow(args.lr_fac, math.floor(float(epoch) / float(args.lr_epochs))) + epoch_loss = 0 + for j in range(iteration_per_epoch): + loss = model.step(data_worker, lr, sess, writer, args.write_iteration) + epoch_loss += loss + epoch_loss /= iteration_per_epoch + + print('Date: {date}\t' + 'Epoch: [Stage 1][{0}/{1}]\t' + 'Loss: {2:.4f}.'.format(epoch, args.epochs, epoch_loss, date=time.strftime('%Y-%m-%d %H:%M:%S'))) + + if epoch%5 == 0: + saver.save(sess, os.path.join(model_path, 'stage1_'+'epoch_%d.ckpt'%epoch)) + + data_worker.stop() + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--root-folder', type=str, default='.') + parser.add_argument('--output-path', type=str, default='./weights') + parser.add_argument('--factor', type=str, default='rot') + + parser.add_argument('--datapath', type=str, default='../FFHQ_data/coeff') + parser.add_argument('--gpu', type=int, default=0) + parser.add_argument('--write-iteration', type=int, default=600) + + parser.add_argument('--batch-size', type=int, default=64) + parser.add_argument('--epochs', type=int, default=400) + parser.add_argument('--lr', type=float, default=0.0001) + parser.add_argument('--lr-epochs', type=int, default=150) + parser.add_argument('--lr-fac', type=float, default=0.5) + + parser.add_argument('--cross-entropy-loss', default=False, action='store_true') + parser.add_argument('--val', default=False, action='store_true') + + args = parser.parse_args() + print(args) + + os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' + os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) + + main() \ No newline at end of file diff --git a/vae/two_stage_vae_model.py b/vae/two_stage_vae_model.py new file mode 100644 index 0000000..067abd6 --- /dev/null +++ b/vae/two_stage_vae_model.py @@ -0,0 +1,103 @@ +# Copyright (c) Microsoft Corporation. 
+# Licensed under the MIT license. + +# https://github.com/daib13/TwoStageVAE +import tensorflow as tf +import math +import numpy as np +from tensorflow.python.training.moving_averages import assign_moving_average + +class TwoStageVaeModel(object): + def __init__(self, x, latent_dim=128,ch_dim = 512,ch_depth = 3, cross_entropy_loss=False): + self.x = x + self.batch_size = x.get_shape().as_list()[0] + self.latent_dim = latent_dim + self.ch_dim = ch_dim + self.ch_depth = ch_depth + self.cross_entropy_loss = cross_entropy_loss + + self.is_training = tf.placeholder(tf.bool, [], 'is_training') + + self.__build_network() + self.__build_loss() + self.__build_summary() + self.__build_optimizer() + + def __build_network(self): + with tf.variable_scope('stage1'): + self.build_encoder1() + self.build_decoder1() + + def __build_loss(self): + HALF_LOG_TWO_PI = 0.91893 + + self.kl_loss1 = tf.reduce_sum(tf.square(self.mu_z) + tf.square(self.sd_z) - 2 * self.logsd_z - 1) / 2.0 / float(self.batch_size) + if not self.cross_entropy_loss: + self.gen_loss1 = tf.reduce_sum(tf.square((self.x - self.x_hat) / self.gamma_x) / 2.0 + self.loggamma_x + HALF_LOG_TWO_PI) / float(self.batch_size) + else: + self.gen_loss1 = -tf.reduce_sum(self.x * tf.log(tf.maximum(self.x_hat, 1e-8)) + (1-self.x) * tf.log(tf.maximum(1-self.x_hat, 1e-8))) / float(self.batch_size) + self.loss1 = self.kl_loss1 + self.gen_loss1 + + + def __build_summary(self): + with tf.name_scope('stage1_summary'): + self.summary1 = [] + self.summary1.append(tf.summary.scalar('kl_loss', self.kl_loss1)) + self.summary1.append(tf.summary.scalar('gen_loss', self.gen_loss1)) + self.summary1.append(tf.summary.scalar('loss', self.loss1)) + self.summary1.append(tf.summary.scalar('gamma', self.gamma_x)) + self.summary1 = tf.summary.merge(self.summary1) + + def __build_optimizer(self): + all_variables = tf.global_variables() + variables1 = [var for var in all_variables if 'stage1' in var.name] + self.lr = tf.placeholder(tf.float32, [], 'lr') + self.global_step = tf.get_variable('global_step', [], tf.int32, tf.zeros_initializer(), trainable=False) + self.opt1 = tf.train.AdamOptimizer(self.lr).minimize(self.loss1, self.global_step, var_list=variables1) + + def step(self, data_worker, lr, sess, writer=None, write_iteration=600): + input_batch = data_worker.fetch_train_batch() + loss, summary, _ = sess.run([self.loss1, self.summary1, self.opt1], feed_dict={self.x: input_batch, self.lr: lr, self.is_training: True}) + + global_step = self.global_step.eval(sess) + if global_step % write_iteration == 0 and writer is not None: + writer.add_summary(summary, global_step) + return loss + + def generate(self, sess, num_sample): + num_iter = math.ceil(float(num_sample) / float(self.batch_size)) + gen_samples = [] + for i in range(num_iter): + z = np.random.normal(0, 1, [self.batch_size, self.latent_dim]) + # x = f_1(z) + x = sess.run(self.x_hat, feed_dict={self.z: z, self.is_training: False}) + gen_samples.append(x) + gen_samples = np.concatenate(gen_samples, 0) + return gen_samples[0:num_sample] + + +class MLP(TwoStageVaeModel): + def __init__(self, x, latent_dim=128,ch_dim = 512, ch_depth = 3, cross_entropy_loss=False): + super(MLP, self).__init__(x, latent_dim, ch_dim, ch_depth, cross_entropy_loss) + + def build_encoder1(self): + with tf.variable_scope('encoder'): + y = self.x + for i in range(self.ch_depth): + y = tf.layers.dense(y, self.ch_dim, tf.nn.relu, name='fc'+str(i)) + + self.mu_z = tf.layers.dense(y, self.latent_dim) + self.logsd_z = tf.layers.dense(y, 
self.latent_dim) + self.sd_z = tf.exp(self.logsd_z) + self.z = self.mu_z + tf.random_normal([self.batch_size, self.latent_dim]) * self.sd_z + + def build_decoder1(self): + with tf.variable_scope('decoder'): + y = self.z + self.final_side_length = self.x.get_shape().as_list()[1] + for i in range(self.ch_depth): + y = tf.layers.dense(y, self.ch_dim, tf.nn.relu, name='fc'+str(i)) + + self.x_hat = tf.layers.dense(y, self.final_side_length, name='x_hat') + self.loggamma_x = tf.get_variable('loggamma_x', [], tf.float32, tf.zeros_initializer()) + self.gamma_x = tf.exp(self.loggamma_x) \ No newline at end of file diff --git a/vae/util.py b/vae/util.py new file mode 100644 index 0000000..49f998d --- /dev/null +++ b/vae/util.py @@ -0,0 +1,178 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +# https://github.com/daib13/TwoStageVAE +import tensorflow as tf +from tensorflow.contrib import layers +import math +import numpy as np +from tensorflow.python.training.moving_averages import assign_moving_average + + +def spectral_norm(input_): + """Performs Spectral Normalization on a weight tensor.""" + if len(input_.shape) < 2: + raise ValueError("Spectral norm can only be applied to multi-dimensional tensors") + + # The paper says to flatten convnet kernel weights from (C_out, C_in, KH, KW) + # to (C_out, C_in * KH * KW). But Sonnet's and Compare_gan's Conv2D kernel + # weight shape is (KH, KW, C_in, C_out), so it should be reshaped to + # (KH * KW * C_in, C_out), and similarly for other layers that put output + # channels as last dimension. + # n.b. this means that w here is equivalent to w.T in the paper. + w = tf.reshape(input_, [-1, input_.get_shape().as_list()[-1]]) + + # Persisted approximation of first left singular vector of matrix `w`. + + u_var = tf.get_variable( + input_.name.replace(":", "") + "/u_var", + shape=(w.shape[0], 1), + dtype=w.dtype, + initializer=tf.random_normal_initializer(), + trainable=False) + u = u_var + + # Use power iteration method to approximate spectral norm. + # The authors suggest that "one round of power iteration was sufficient in the + # actual experiment to achieve satisfactory performance". According to + # observation, the spectral norm become very accurate after ~20 steps. + + power_iteration_rounds = 1 + for _ in range(power_iteration_rounds): + # `v` approximates the first right singular vector of matrix `w`. + v = tf.nn.l2_normalize(tf.matmul(tf.transpose(w), u), dim=None, epsilon=1e-12) + u = tf.nn.l2_normalize(tf.matmul(w, v), dim=None, epsilon=1e-12) + + # Update persisted approximation. + with tf.control_dependencies([tf.assign(u_var, u, name="update_u")]): + u = tf.identity(u) + + # The authors of SN-GAN chose to stop gradient propagating through u and v. + # In johnme@'s experiments it wasn't clear that this helps, but it doesn't + # seem to hinder either so it's kept in order to be a faithful implementation. + u = tf.stop_gradient(u) + v = tf.stop_gradient(v) + + # Largest singular value of `w`. + norm_value = tf.matmul(tf.matmul(tf.transpose(u), w), v) + norm_value.shape.assert_is_fully_defined() + norm_value.shape.assert_is_compatible_with([1, 1]) + + w_normalized = w / norm_value + + # Unflatten normalized weights to match the unnormalized tensor. 
+ w_tensor_normalized = tf.reshape(w_normalized, input_.shape) + return w_tensor_normalized + + +def conv2d(input_, output_dim, k_h, k_w, d_h, d_w, stddev=0.02, name="conv2d", + initializer=tf.truncated_normal_initializer, use_sn=False): + with tf.variable_scope(name): + w = tf.get_variable( + "w", [k_h, k_w, input_.get_shape()[-1], output_dim], + initializer=initializer(stddev=stddev)) + if use_sn: + conv = tf.nn.conv2d(input_, spectral_norm(w), strides=[1, d_h, d_w, 1], padding="SAME") + else: + conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding="SAME") + biases = tf.get_variable( + "biases", [output_dim], initializer=tf.constant_initializer(0.0)) + return tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape()) + + +def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, use_sn=False): + shape = input_.get_shape().as_list() + + with tf.variable_scope(scope or "Linear"): + matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32, tf.random_normal_initializer(stddev=stddev)) + bias = tf.get_variable("bias", [output_size], initializer=tf.constant_initializer(bias_start)) + if use_sn: + return tf.matmul(input_, spectral_norm(matrix)) + bias + else: + return tf.matmul(input_, matrix) + bias + + +def lrelu(input_, leak=0.2, name="lrelu"): + return tf.maximum(input_, leak * input_, name=name) + + +def batch_norm(x, is_training, scope, eps=1e-5, decay=0.999, affine=True): + def mean_var_with_update(moving_mean, moving_variance): + if len(x.get_shape().as_list()) == 4: + statistics_axis = [0, 1, 2] + else: + statistics_axis = [0] + mean, variance = tf.nn.moments(x, statistics_axis, name='moments') + with tf.control_dependencies([assign_moving_average(moving_mean, mean, decay), assign_moving_average(moving_variance, variance, decay)]): + return tf.identity(mean), tf.identity(variance) + + with tf.name_scope(scope): + with tf.variable_scope(scope + '_w'): + params_shape = x.get_shape().as_list()[-1:] + moving_mean = tf.get_variable('mean', params_shape, initializer=tf.zeros_initializer(), trainable=False) + moving_variance = tf.get_variable('variance', params_shape, initializer=tf.ones_initializer, trainable=False) + + mean, variance = tf.cond(is_training, lambda: mean_var_with_update(moving_mean, moving_variance), lambda: (moving_mean, moving_variance)) + if affine: + beta = tf.get_variable('beta', params_shape, initializer=tf.zeros_initializer()) + gamma = tf.get_variable('gamma', params_shape, initializer=tf.ones_initializer) + return tf.nn.batch_normalization(x, mean, variance, beta, gamma, eps) + else: + return tf.nn.batch_normalization(x, mean, variance, None, None, eps) + + +def deconv2d(input_, output_shape, k_h, k_w, d_h, d_w, stddev=0.02, name="deconv2d"): + with tf.variable_scope(name): + w = tf.get_variable("w", [k_h, k_w, output_shape[-1], input_.get_shape()[-1]], initializer=tf.random_normal_initializer(stddev=stddev)) + deconv = tf.nn.conv2d_transpose(input_, w, output_shape=output_shape, strides=[1, d_h, d_w, 1]) + biases = tf.get_variable("biases", [output_shape[-1]], initializer=tf.constant_initializer(0.0)) + return tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape()) + + +def downsample(x, out_dim, kernel_size, name): + with tf.variable_scope(name): + input_shape = x.get_shape().as_list() + assert(len(input_shape) == 4) + return tf.layers.conv2d(x, out_dim, kernel_size, 2, 'same') + + +def upsample(x, out_dim, kernel_size, name): + with tf.variable_scope(name): + input_shape = x.get_shape().as_list() + 
assert(len(input_shape) == 4) + return tf.layers.conv2d_transpose(x, out_dim, kernel_size, 2, 'same') + + +def res_block(x, out_dim, is_training, name, depth=2, kernel_size=3): + with tf.variable_scope(name): + y = x + for i in range(depth): + y = tf.nn.relu(batch_norm(y, is_training, 'bn'+str(i))) + y = tf.layers.conv2d(y, out_dim, kernel_size, padding='same', name='layer'+str(i)) + s = tf.layers.conv2d(x, out_dim, kernel_size, padding='same', name='shortcut') + return y + s + + +def res_fc_block(x, out_dim, name, depth=2): + with tf.variable_scope(name): + y = x + for i in range(depth): + y = tf.layers.dense(tf.nn.relu(y), out_dim, name='layer'+str(i)) + s = tf.layers.dense(x, out_dim, name='shortcut') + return y + s + + +def scale_block(x, out_dim, is_training, name, block_per_scale=1, depth_per_block=2, kernel_size=3): + with tf.variable_scope(name): + y = x + for i in range(block_per_scale): + y = res_block(y, out_dim, is_training, 'block'+str(i), depth_per_block, kernel_size) + return y + + +def scale_fc_block(x, out_dim, name, block_per_scale=1, depth_per_block=2): + with tf.variable_scope(name): + y = x + for i in range(block_per_scale): + y = res_fc_block(y, out_dim, 'block'+str(i), depth_per_block) + return y \ No newline at end of file diff --git a/vae/weights/exp/stage1_epoch_395.ckpt.data-00000-of-00001 b/vae/weights/exp/stage1_epoch_395.ckpt.data-00000-of-00001 new file mode 100644 index 0000000..ca3f443 Binary files /dev/null and b/vae/weights/exp/stage1_epoch_395.ckpt.data-00000-of-00001 differ diff --git a/vae/weights/exp/stage1_epoch_395.ckpt.index b/vae/weights/exp/stage1_epoch_395.ckpt.index new file mode 100644 index 0000000..fc70e91 Binary files /dev/null and b/vae/weights/exp/stage1_epoch_395.ckpt.index differ diff --git a/vae/weights/exp/stage1_epoch_395.ckpt.meta b/vae/weights/exp/stage1_epoch_395.ckpt.meta new file mode 100644 index 0000000..5e45294 Binary files /dev/null and b/vae/weights/exp/stage1_epoch_395.ckpt.meta differ diff --git a/vae/weights/gamma/stage1_epoch_395.ckpt.data-00000-of-00001 b/vae/weights/gamma/stage1_epoch_395.ckpt.data-00000-of-00001 new file mode 100644 index 0000000..ae3220d Binary files /dev/null and b/vae/weights/gamma/stage1_epoch_395.ckpt.data-00000-of-00001 differ diff --git a/vae/weights/gamma/stage1_epoch_395.ckpt.index b/vae/weights/gamma/stage1_epoch_395.ckpt.index new file mode 100644 index 0000000..c76e31c Binary files /dev/null and b/vae/weights/gamma/stage1_epoch_395.ckpt.index differ diff --git a/vae/weights/gamma/stage1_epoch_395.ckpt.meta b/vae/weights/gamma/stage1_epoch_395.ckpt.meta new file mode 100644 index 0000000..688cabf Binary files /dev/null and b/vae/weights/gamma/stage1_epoch_395.ckpt.meta differ diff --git a/vae/weights/id/stage1_epoch_395.ckpt.data-00000-of-00001 b/vae/weights/id/stage1_epoch_395.ckpt.data-00000-of-00001 new file mode 100644 index 0000000..ed75555 Binary files /dev/null and b/vae/weights/id/stage1_epoch_395.ckpt.data-00000-of-00001 differ diff --git a/vae/weights/id/stage1_epoch_395.ckpt.index b/vae/weights/id/stage1_epoch_395.ckpt.index new file mode 100644 index 0000000..57b0b44 Binary files /dev/null and b/vae/weights/id/stage1_epoch_395.ckpt.index differ diff --git a/vae/weights/id/stage1_epoch_395.ckpt.meta b/vae/weights/id/stage1_epoch_395.ckpt.meta new file mode 100644 index 0000000..ca61b15 Binary files /dev/null and b/vae/weights/id/stage1_epoch_395.ckpt.meta differ diff --git a/vae/weights/rot/stage1_epoch_395.ckpt.data-00000-of-00001 
b/vae/weights/rot/stage1_epoch_395.ckpt.data-00000-of-00001 new file mode 100644 index 0000000..a0a191f Binary files /dev/null and b/vae/weights/rot/stage1_epoch_395.ckpt.data-00000-of-00001 differ diff --git a/vae/weights/rot/stage1_epoch_395.ckpt.index b/vae/weights/rot/stage1_epoch_395.ckpt.index new file mode 100644 index 0000000..e898dc9 Binary files /dev/null and b/vae/weights/rot/stage1_epoch_395.ckpt.index differ diff --git a/vae/weights/rot/stage1_epoch_395.ckpt.meta b/vae/weights/rot/stage1_epoch_395.ckpt.meta new file mode 100644 index 0000000..8d16eed Binary files /dev/null and b/vae/weights/rot/stage1_epoch_395.ckpt.meta differ
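For reference, below is a minimal sketch of how the per-factor VAEs under `vae/` could be retrained with `vae/demo.py`. The `--datapath` value shown is simply the script's default and assumes the coefficient `.mat` files have already been extracted to `../FFHQ_data/coeff`; adjust the path to your own data.

```
cd vae

# Train a first-stage VAE for each coefficient factor
python demo.py --factor id    --datapath ../FFHQ_data/coeff
python demo.py --factor exp   --datapath ../FFHQ_data/coeff
python demo.py --factor gamma --datapath ../FFHQ_data/coeff
python demo.py --factor rot   --datapath ../FFHQ_data/coeff
```

Each run writes `stage1_epoch_*.ckpt` checkpoints to `vae/weights/<factor>/from_scratch/`; the `stage1_epoch_395.ckpt` files added under `vae/weights/<factor>/` above appear to be the corresponding pre-trained checkpoints shipped with the repository.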