In [None]:
# Automatically reload edited Python modules (e.g. the local gantools / cosmotools
# packages imported below) before each cell is executed.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook.
%matplotlib inline

In [None]:
import sys
# Put the parent directory first on the module search path so the local packages
# (gantools, cosmotools) are importable from this notebook's folder.
sys.path[0:0] = ['../']

In [None]:
import os
# Select GPU index 1 for this notebook. This runs before tensorflow is imported
# below so that the setting is in place when the CUDA devices are initialised.
os.environ.update(CUDA_VISIBLE_DEVICES="1")


import functools
from copy import deepcopy

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from gantools import utils
from gantools import plot
from gantools.model import WGAN
from gantools.gansystem import GANsystem

from cosmotools.metric import evaluation
from cosmotools.metric.score import score_histogram, score_peak_histogram, score_psd
from cosmotools.model import CosmoWGAN
from cosmotools.data import load
from cosmotools.data import fmap

# Parameters

In [None]:
ns = 64  # Resolution of the (square, ns x ns) images
try_resume = False  # Try to resume previous training step
Mpch = 350  # Type of dataset (select 70 or 350)

# Fail early, with a clear message, on settings the rest of the notebook cannot handle.
assert Mpch in (70, 350), 'Mpch must be 70 or 350, got {}'.format(Mpch)
# The generator starts from an (ns//8 x ns//8) map and has three stride-2 layers,
# so ns must be a multiple of 8 to get back to an ns x ns image.
assert ns % 8 == 0, 'ns must be a multiple of 8, got {}'.format(ns)

# Pixel-value mapping applied to the data when loading (forward) and undone
# before computing the evaluation statistics (backward).
forward = fmap.stat_forward
backward = fmap.stat_backward
def non_lin(x):
    """Apply a ReLU, max(x, 0), element-wise to the tensor ``x``.

    NOTE(review): this helper is not referenced anywhere in the visible notebook;
    the generator is configured with ``params_generator['non_lin'] = None``.
    Confirm whether it was meant to be passed as the generator non-linearity.
    """
    return tf.nn.relu(features=x)

# Data handling

Load the data

In [None]:
# Load 10 N-body simulation cubes at the selected Mpch size, with the forward map
# applied to the pixel values. spix=ns presumably sets the image side length.
dataset = load.load_nbody_dataset(
    spix=ns,
    Mpch=Mpch,
    ncubes=10,
    forward_map=forward,
)

In [None]:
# The dataset can return an iterator: draw one batch from it (the argument 10 is
# presumably the batch size) and print its shape as a sanity check.
print(next(dataset.iter(10)).shape)

In [None]:
# Gather every (forward-mapped) pixel value of the dataset into one flat array.
all_data = dataset.get_all_data()
X = all_data.flatten()
del all_data

In [None]:
# Check that the backward maps invert the forward map.
# Check that the backward map inverts the forward map: mapping X back to (rounded) raw
# values and forward again must reproduce X, up to floating-point round-off.
# (The previous check took abs() before subtracting X, so it did not measure the error.)
np.testing.assert_allclose(forward(np.round(backward(X))), X, rtol=1e-5, atol=1e-6)

Display the histogram of the pixel densities after the forward map

In [None]:
# Distribution of the forward-mapped pixel values (100 bins, log-scaled counts).
plt.hist(X, 100)
plt.yscale('log')
plt.xlabel('Pixel value after the forward map')
plt.ylabel('Number of pixels')
plt.title('Histogram of forward-mapped pixel values')
print('min: {}'.format(np.min(X)))
print('max: {}'.format(np.max(X)))

In [None]:
# to free some memory
# X is a flattened copy of the whole dataset, only needed for the checks above;
# delete it to free some memory before building and training the model.
del X

Let us plot 16 real samples from the dataset

In [None]:
# Show 16 real samples (nx=4 by ny=4) from the dataset.
plt.figure(figsize=(15, 15))
plot.draw_images(dataset.get_samples(N=16), nx=4, ny=4)
# The trailing semicolon suppresses the Text(...) repr that plt.title would otherwise echo.
plt.title("Real samples");

# Define parameters for the WGAN

In [None]:
# Base directory and run name; both are combined below into the summary and
# checkpoint directories of this run.
global_path = 'saved_results'
name = '_'.join(['WGAN{}'.format(ns), '2D_simple'])

## Parameters

In [None]:
bn = False  # Batch-normalization switch shared by all layers below
md = 32     # Base number of filters; the layers use multiples of it

# Discriminator: one entry per layer in each per-layer list (5 layers, 4x4 kernels).
params_discriminator = dict(
    stride=[1, 2, 2, 2, 1],
    nfilter=[md, 2 * md, 4 * md, 2 * md, md],
    shape=[[4, 4] for _ in range(5)],
    batch_norm=[bn] * 5,
    full=[],
    minibatch_reg=False,
    summary=True,
    data_size=2,
    inception=False,
    spectral_norm=False,
    fft_features=False,
    psd_features=False,
)

# Generator: starts from an (ns//8 x ns//8) map, presumably fed by the `full` layer,
# and goes through 5 layers (three with stride 2) down to a single output channel.
params_generator = dict(
    stride=[1, 2, 2, 2, 1],
    latent_dim=ns * 2,
    in_conv_shape=[ns // 8, ns // 8],
    nfilter=[md, 2 * md, 4 * md, 2 * md, 1],
    shape=[[4, 4] for _ in range(5)],
    # One entry fewer than the number of layers.
    # NOTE(review): presumably the output layer has no batch norm; confirm in gantools.
    batch_norm=[bn] * 4,
    full=[(ns // 8) ** 2 * 8],
    summary=True,
    # NOTE(review): the non_lin helper defined earlier is not used; None is passed here.
    non_lin=None,
    data_size=2,
    inception=False,
    spectral_norm=False,
)

params_optimization = dict(
    batch_size=32,
    # NOTE: the resume cell below overwrites this with params['optimization']['epoch'] = 40.
    epoch=(ns ** 2) // 64,
    n_critic=5,
)
# Optimizer settings can also be given explicitly, for example:
# params_optimization['generator'] = dict()
# params_optimization['generator']['optimizer'] = 'adam'
# params_optimization['generator']['kwargs'] = {'beta1':0, 'beta2':0.9}
# params_optimization['generator']['learning_rate'] = 0.0004
# params_optimization['discriminator'] = dict()
# params_optimization['discriminator']['optimizer'] = 'adam'
# params_optimization['discriminator']['kwargs'] = {'beta1':0, 'beta2':0.9}
# params_optimization['discriminator']['learning_rate'] = 0.0001

# Cosmology parameters: the pixel maps used for the cosmological summaries.
params_cosmology = dict(forward_map=forward, backward_map=backward)

# All parameters. The sub-dicts above are referenced (not copied), so later edits to
# params['optimization'] also change params_optimization.
params = dict(
    net=dict(  # All the parameters for the model
        generator=params_generator,
        discriminator=params_discriminator,
        cosmology=params_cosmology,  # Parameters for the cosmological summaries
        prior_distribution='gaussian',
        shape=[ns, ns, 1],  # Shape of the image
        loss_type='wasserstein',  # loss ('hinge' or 'wasserstein')
        gamma_gp=10,  # Gradient penalty
    ),
    optimization=params_optimization,
    summary_every=500,  # Tensorboard summaries every ** iterations
    print_every=50,  # Console summaries every ** iterations
    save_every=2000,  # Save the model every ** iterations
    summary_dir=os.path.join(global_path, name + '_summary/'),
    save_dir=os.path.join(global_path, name + '_checkpoints/'),
    Nstats=(64 * 32 * 32) // ns,
)



In [None]:
# Optionally reload a previously saved run; `resume` tells wgan.train below whether to continue it.
resume, params = utils.test_resume(try_resume, params)
# If a model is reloaded and some parameters have to be changed, this is the place to do it.
# The line below runs unconditionally: it sets the number of epochs to 40, overriding the
# (ns**2)//64 value set in params_optimization above (params['optimization'] is that same
# dict when no model is reloaded).
params['optimization']['epoch'] = 40


# Build the model

In [None]:
# Build the GAN system: the CosmoWGAN model configured by `params`, wrapped in a
# GANsystem object that provides the train() and generate() methods used below.
wgan = GANsystem(CosmoWGAN, params)

# Train the model

In [None]:
# Train on the dataset. When `resume` (from utils.test_resume above) is True, training
# presumably continues from the checkpoints in params['save_dir'].
wgan.train(dataset, resume=resume)

# Generate new samples
To obtain meaningful statistics, generate enough samples. The recommended number of samples depends on the image resolution:
* 32 x 32: 2000 samples
* 64 x 64: 500 samples
* 128 x 128: 200 samples


In [None]:
N = 2000 # Number of samples (see the recommended counts per resolution above)
# np.squeeze drops singleton axes, presumably the trailing channel axis of [ns, ns, 1] images.
gen_sample = np.squeeze(wgan.generate(N=N))

Display a few fake samples

In [None]:
# Show the first generated images in a grid (ny=4 by nx=4), matching the real-sample plot.
plt.figure(figsize=[15, 15])
plot.draw_images(gen_sample, ny=4, nx=4)
plt.title("Fake samples");

# Evaluation of the sample quality

In [None]:
# Before computing the statistics, we need to invert the mapping
raw_images = backward(dataset.get_samples(2*N))
gen_sample_raw = backward(gen_sample)

In [None]:
# Power spectral density (PSD) of N real vs. N generated raw images.
# NOTE(review): confidence='std' presumably draws +/- one standard deviation bands.
# Assigning to `_` keeps the returned values from being echoed as cell output.
_ = evaluation.compute_and_plot_psd(raw_images[:N], gen_sample_raw, confidence='std')

In [None]:
# Peak-count statistics of N real vs. N generated raw images.
# NOTE(review): confidence='std' presumably draws +/- one standard deviation bands.
_ = evaluation.compute_and_plot_peak_count(raw_images[:N], gen_sample_raw, confidence='std')

In [None]:
# Mass histogram of N real vs. N generated raw images.
# NOTE(review): confidence='std' presumably draws +/- one standard deviation bands.
_ = evaluation.compute_and_plot_mass_hist(raw_images[:N], gen_sample_raw, confidence='std')

# Compute the scores

In [None]:
# Scores comparing the statistics of N real and N generated raw images. The score
# functions are imported with the other imports at the top of the notebook.
# NOTE(review): presumably lower means closer; compare with the real-vs-real scores below.
print('PSD score: {}'.format(score_psd(raw_images[:N], gen_sample_raw)))
print('Histogram score: {}'.format(score_histogram(raw_images[:N], gen_sample_raw)))
print('Peak histogram score: {}'.format(score_peak_histogram(raw_images[:N], gen_sample_raw)))

#### For comparison, the scores obtained with real data

In [None]:
# Reference values: the same scores computed between two disjoint sets of N real images
# (raw_images[:N] vs. raw_images[N:2*N]), to put the GAN scores above into perspective.
print('PSD score: {}'.format(score_psd(raw_images[:N],raw_images[N:2*N])))
print('Histogram score: {}'.format(score_histogram(raw_images[:N],raw_images[N:2*N])))
print('Peak histogram score: {}'.format(score_peak_histogram(raw_images[:N],raw_images[N:2*N])))