In [None]:
# Automatically reload edited Python modules (e.g. the local gantools package) before running each cell
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook
%matplotlib inline

In [None]:
import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from copy import deepcopy

import sys
sys.path.insert(0, '../')
from gantools import data
from gantools import utils
from gantools import plot
from gantools.model import CosmoWGAN
from gantools.gansystem import GANsystem
from gantools import evaluation

In [None]:
# Note: some of the parameters don't make sense for the fake dataset
try_resume = True  # Try to resume a previous simulation (checked by utils.test_resume)
ns = 32  # Resolution of the square (ns x ns) images


def non_lin(x):
    """Non-linearity handed to the generator via ``params_generator['non_lin']``.

    Applies an element-wise sigmoid, squashing values into the (0, 1) range.
    """
    activated = tf.nn.sigmoid(x)
    return activated

# Data handling

Generate the toy data and wrap it in a gantools dataset

In [None]:
# Create fake images with the gantools toy generator
image_shape = [ns, ns]  # every image is ns x ns pixels
nsamples = 5000         # number of images to generate
# NOTE(review): `sigma` and `N` are passed straight to generate_fake_images;
# their exact meaning is defined in data.toy_dataset_generator -- confirm there.
sigma = 0.005
N = 10
fake_image_kwargs = dict(nsamples=nsamples, sigma=sigma, N=N, image_shape=image_shape)
images = data.toy_dataset_generator.generate_fake_images(**fake_image_kwargs)

In [None]:
# Convert to gantools dataset
# Convert to gantools dataset: wrap the generated images in a Dataset object that
# the rest of the notebook uses for batching (`iter`) and sampling (`get_samples`).
dataset = data.Dataset.Dataset(
    images
)

In [None]:
# The dataset can return an iterator; pull one batch just to check its shape.
# NOTE(review): the argument 10 is presumably the batch size.
batch_iter = dataset.iter(10)
first_batch = next(batch_iter)
print(first_batch.shape)
del batch_iter, first_batch

In [None]:
# Get all the data, flattened to a 1-D array of pixel values for the histogram
all_data = dataset.get_all_data()
X = all_data.flatten()
del all_data  # only the flattened copy is needed from here on

Display the histogram of the pixel values (no forward map is applied to this toy dataset, so these are the raw values)

In [None]:
# Histogram of the pixel values of all real images, with log-scaled counts so that
# rare values stay visible. The min/max printed below bound the histogram range.
fig, ax = plt.subplots()
ax.hist(X, 100)
ax.set_yscale('log')
ax.set_title('Pixel value histogram (real data)')
ax.set_xlabel('Pixel value')
ax.set_ylabel('Count (log scale)')
print('min: {}'.format(np.min(X)))
print('max: {}'.format(np.max(X)))

In [None]:
# to free some memory
del X

Let us plot 16 real images from the dataset

In [None]:
# 4 x 4 grid of samples returned by dataset.get_samples, on a fixed [0, 1] colour scale
imgs = dataset.get_samples(N=16)
fig, ax = plt.subplots(nrows=4, ncols=4, figsize=(15,15))
for idx, col in enumerate(ax.flat):
    col.imshow(imgs[idx], vmin=0, vmax=1)
    col.axis('off')

# Define parameters for the WGAN

In [None]:
# Run identification: summaries and checkpoints are written under global_path,
# in folders prefixed with `name`.
global_path = 'saved_results/Fake Dataset/'
time_str = '2D_mac'  # free-form tag to distinguish runs
name = 'Simple_WGAN_fake_{}_{}'.format(ns, time_str)

## Parameters

In [None]:
bn = False  # Batch normalisation switch, shared by generator and discriminator layers

# Spatial size of the map that the generator's fully-connected output is reshaped to.
# Derived from the image resolution `ns` and the generator strides instead of being
# hard-coded, so changing `ns` keeps the generator output consistent with
# params['net']['shape']. With ns = 32 and strides [1, 2, 1] this gives 16,
# i.e. exactly the previously hard-coded values.
# NOTE(review): assumes each generator stride upsamples by that factor -- confirm in gantools.
gen_strides = [1, 2, 1]
gen_upsampling = int(np.prod(gen_strides))
assert ns % gen_upsampling == 0, 'ns must be divisible by the product of the generator strides'
gen_in_size = ns // gen_upsampling
gen_in_channels = 8  # Channels of the reshaped fully-connected output

# Parameters for the generator
params_generator = dict()
params_generator['latent_dim'] = 128  # Dimension of the latent space
params_generator['stride'] = list(gen_strides)
params_generator['nfilter'] = [16, 32, 1]  # Last layer has 1 filter: single-channel images
params_generator['shape'] = [[5, 5], [5, 5], [5, 5]]  # presumably the kernel size per layer
# One entry fewer than the number of generator layers (the discriminator has one per
# layer); presumably the output layer is never batch-normalised -- confirm in gantools.
params_generator['batch_norm'] = [bn, bn]
params_generator['full'] = [gen_in_size * gen_in_size * gen_in_channels]
params_generator['summary'] = True
params_generator['non_lin'] = non_lin
params_generator['in_conv_shape'] = [gen_in_size, gen_in_size]

# Parameters for the discriminator
params_discriminator = dict()
params_discriminator['stride'] = [1, 2, 1]
params_discriminator['nfilter'] = [32, 16, 8]
params_discriminator['shape'] = [[5, 5], [5, 5], [5, 5]]
params_discriminator['batch_norm'] = [bn, bn, bn]
params_discriminator['full'] = []  # No extra fully-connected layers
params_discriminator['minibatch_reg'] = False
params_discriminator['summary'] = True

# Optimization parameters (generator and discriminator get independent copies)
d_opt = dict()
d_opt['optimizer'] = "rmsprop"
d_opt['learning_rate'] = 3e-5
params_optimization = dict()
params_optimization['discriminator'] = deepcopy(d_opt)
params_optimization['generator'] = deepcopy(d_opt)
params_optimization['n_critic'] = 5  # presumably critic updates per generator update (WGAN)
params_optimization['batch_size'] = 32
params_optimization['epoch'] = 75

# Cosmology parameters: no forward/backward map for the toy dataset
params_cosmology = dict()
params_cosmology['forward_map'] = None
params_cosmology['backward_map'] = None

# all parameters
params = dict()
params['net'] = dict() # All the parameters for the model
params['net']['generator'] = params_generator
params['net']['discriminator'] = params_discriminator
params['net']['cosmology'] = params_cosmology # Parameters for the cosmological summaries
params['net']['shape'] = [ns, ns, 1] # Shape of the image
params['net']['gamma_gp'] = 10 # Gradient penalty

params['optimization'] = params_optimization
params['summary_every'] = 1000 # Tensorboard summaries every ** iterations
params['print_every'] = 500 # Console summaries every ** iterations
params['save_every'] = 10000 # Save the model every ** iterations
params['summary_dir'] = os.path.join(global_path, name +'_summary/')
params['save_dir'] = os.path.join(global_path, name + '_checkpoints/')
params['Nstats'] = 2000  # NOTE(review): presumably the number of samples used for statistics

In [None]:
resume, params = utils.test_resume(try_resume, params)
# If a model is reloaded and some parameters have to be changed, then it should be done here.
# For example, setting the number of epochs to 25 for a resumed run. The override is guarded
# so that a fresh run keeps the epoch count configured in params_optimization.
if resume:
    params['optimization']['epoch'] = 25
# Always point the output folders at the current global_path / name
# (presumably a reloaded params dict still carries the paths it was saved with).
params['summary_dir'] = os.path.join(global_path, name +'_summary/')
params['save_dir'] = os.path.join(global_path, name + '_checkpoints/')

# Build the model

In [None]:
# Build the WGAN system around the CosmoWGAN model with the assembled parameters
wgan = GANsystem(
    CosmoWGAN,
    params,
)

# Train the model

In [None]:
# Train on the toy dataset; `resume` (from utils.test_resume) selects whether to continue a saved run
wgan.train(
    dataset,
    resume=resume,
)

# Evaluation of the sample quality

In [None]:
N = 2000  # Number of samples generated (and drawn from the dataset) for the evaluation
generated = wgan.generate(N=N)
# Drop singleton axes (presumably the trailing channel axis of the [ns, ns, 1] images)
gen_sample = np.squeeze(generated)
del generated

In [None]:
# No inversion is needed before computing the statistics: the training images were
# used as-is and params_cosmology sets forward_map/backward_map to None, so both the
# real and the generated samples are already in the raw pixel space.
raw_images = dataset.get_samples(N)
gen_sample_raw = gen_sample

Display a few fake samples

In [None]:
# 4 x 4 grid of generated samples, on the same fixed [0, 1] colour scale as the real ones
fig, ax = plt.subplots(nrows=4, ncols=4, figsize=(15,15))
for idx, col in enumerate(ax.flat):
    col.imshow(gen_sample_raw[idx], vmin=0, vmax=1)
    col.axis('off')

Display generated samples (top row) next to real samples (bottom row)

In [None]:
# Top row: the first 4 generated samples; bottom row: 4 real samples.
# Each panel is titled so the two rows cannot be confused, and both rows share the
# same [0, 1] colour scale so they can be compared directly.
fig, ax = plt.subplots(nrows=2, ncols=4, figsize=(20,10))
real_imgs = dataset.get_samples(4)
rows = [('Generated', gen_sample_raw), ('Real', real_imgs)]
for row_axes, (label, row_imgs) in zip(ax, rows):
    for j, col in enumerate(row_axes):
        col.imshow(row_imgs[j], vmin=0, vmax=1)
        col.set_title('{} #{}'.format(label, j + 1))
        col.axis('off')

In [None]:
# Settings shared by the statistics evaluations below
# (NOTE(review): several of them are cosmology-specific and only nominal for this toy data)
lenstools = True  # forwarded as `lenstools=` to the evaluation helpers
bin_k = 15  # forwarded as `bin_k=` (presumably the number of spectral bins)
box_l = 5 * np.pi / 180  # 5 degrees in radians (presumably the angular size of one map)
cut = [50, 1000]  # forwarded as `cut=` (presumably the spectral range that is kept)

In [None]:
# Power spectral density of real vs generated samples, with their fractional difference
logel2, l2, logel1, l1, fd = evaluation.compute_and_plot_psd(
    raw_images, gen_sample_raw,
    multiply=True, confidence='std', fractional_difference=True,
    bin_k=bin_k, box_l=box_l, cut=cut, lenstools=lenstools, loc=1,
)

In [None]:
# Peak-count statistic of real vs generated samples, with their fractional difference
logel2, l2, logel1, l1, _ = evaluation.compute_and_plot_peak_count(
    raw_images, gen_sample_raw,
    log=False, neighborhood_size=2, threshold=0.01,
    confidence='std', fractional_difference=True, loc=3,
)

In [None]:
# Pixel ("mass") histogram of real vs generated samples over the [0, 1] range
logel2, l2, logel1, l1, _ = evaluation.compute_and_plot_mass_hist(
    raw_images, gen_sample_raw,
    log=False, confidence='std', lim=(0,1), fractional_difference=True,
)

In [None]:
# Summary figure: three panels, each with (main, fractional-difference) y-limits.
# Only the first panel's limits depend on `lenstools` (presumably the power spectrum panel).
first_ylim = (1e-5, 1e-1) if lenstools else (1e-3, 1e1)
ylims = [
    [first_ylim, (0, 0.1)],
    [(1e-2, 3e2), (0, 0.35)],
    [(1e-1, 1e3), (0, 0.25)],
]
locs = [1] * 3
fractional_difference = [True] * 3
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(15, 5))
# NOTE(review): generated samples are passed before real ones here, the reverse of the
# compute_and_plot_* calls above -- confirm against plot_stats' signature.
_ = evaluation.plot_stats(
    ax, gen_sample_raw, raw_images,
    log=False, lim=(0,1), neighborhood_size=2, threshold=0.01, confidence='std',
    multiply=True, bin_k=bin_k, box_l=box_l, cut=cut, lenstools=lenstools,
    fractional_difference=fractional_difference, locs=locs, ylims=ylims,
)
fig.tight_layout()

In [None]:
# Correlation statistics for real (corr_r) and generated (corr_f) samples, compared below
corr_r, corr_f, k = evaluation.compute_plot_correlation(
    raw_images, gen_sample_raw,
    bin_k=bin_k, box_l=box_l, cut=cut, lenstools=lenstools,
)

In [None]:
# Frobenius / l2 norm of the difference between real and generated correlations
corr_l2_loss = np.linalg.norm(corr_r - corr_f)
print("Correlation l2 loss:", corr_l2_loss)

MS-SSIM score

In [None]:
# MS-SSIM scores; each argument is a list of sample sets (here a single set each)
fake_sets = [gen_sample_raw]
real_sets = [raw_images]
s_fake, s_real = evaluation.compute_ssim_score(fake_sets, real_sets)

In [None]:
# Scores of the single evaluated set, then the absolute gap between them
ssim_fake, ssim_real = s_fake[0], s_real[0]
print(ssim_fake, ssim_real)
print(np.abs(ssim_fake - ssim_real))