In [None]:
# Reload edited Python modules (e.g. the local gantools/cosmotools packages)
# automatically before each cell is executed.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook.
%matplotlib inline

In [None]:
import os

# Hide every GPU from TensorFlow. This has to happen before TensorFlow is
# imported: CUDA reads CUDA_VISIBLE_DEVICES when it is initialised, so setting
# it afterwards is only effective if no GPU context has been created yet.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

import sys
from copy import deepcopy

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

# Make the local packages in the parent of the working directory importable.
sys.path.insert(0, '../')

from gantools import utils
from gantools import plot
from gantools.gansystem import GANsystem
from gantools.data import Dataset
from gantools.data import transformation

from cosmotools.model import CosmoWGAN
from cosmotools.metric import evaluation
from cosmotools.data import load


# Parameters

In [None]:
ns = 128  # side length, in pixels, of the square maps used throughout
try_resume = True  # try to resume a previous run (checked by utils.test_resume)


def non_lin(x):
    """Non-linearity handed to the generator via ``params_generator['non_lin']``.

    Applies TensorFlow's element-wise ReLU, so negative values become 0.
    """
    rectified = tf.nn.relu(x)
    return rectified

# Data handling

The dataset corresponds to the parameter set $\Omega_m = 0.254$ and $\sigma_8 = 0.852$.

In [None]:
# Load kids_test.h5 as a parameter-indexed dataset (queried with
# get_data_for_params in the next cell); maps are requested with shape [ns, ns].
dataset = load.load_params_dataset(
    filename='kids_test.h5',
    batch=12000,
    shape=[ns, ns],
    sorted=True,
)

In [None]:
# Pull every map simulated with (Omega_m, sigma_8) = (0.254, 0.852); the second
# value returned by get_data_for_params is not needed here.
target_params = np.array([0.254, 0.852])
X, _ = dataset.get_data_for_params(target_params)
vmin, vmax = np.min(X), np.max(X)
print(X.shape)

Display the histogram of the pixel densities after the forward map

In [None]:
# Pixel-value distribution of the loaded maps (100 bins, log-scaled counts).
plt.hist(X.ravel(), 100)
for label, value in (('min', vmin), ('max', vmax)):
    print('{}: {}'.format(label, value))
plt.yscale('log')

Rescale the dataset to the interval [0, 1]

In [None]:
# Source interval: the observed pixel range of the loaded maps.
init_inter = [vmin, vmax]
# Target interval for the network.
final_inter = [0, 1]


def rescale(x):
    """Map ``x`` from ``init_inter`` onto ``final_inter`` with gantools' rescale.

    Both intervals are read from the notebook globals at call time.
    """
    source, target = init_inter, final_inter
    return transformation.rescale(x, source, target)

In [None]:
dataset = Dataset(X, transform=rescale)
# Read the data back through the dataset (presumably with `rescale` applied)
# and record the new pixel range.
X = dataset.get_all_data()
vmin, vmax = np.min(X), np.max(X)

In [None]:
# Pixel-value distribution after rescaling (100 bins, log-scaled counts).
flat_pixels = X.ravel()
plt.hist(flat_pixels, 100)
print('min: {}'.format(vmin))
print('max: {}'.format(vmax))
plt.yscale('log')

Augment the dataset with random transpositions

In [None]:
# Rebuild the dataset from the rescaled maps, using random transposition as
# data augmentation (applied by Dataset through its `transform` hook).
dataset = Dataset(
    X,
    transform=transformation.random_transpose_2d,
)

In [None]:
# to free some memory
# Only the notebook-level name is removed: `dataset` was built from X in the
# previous cell and presumably still references the same array, so memory is
# reclaimed only if Dataset keeps its own copy. NOTE(review): confirm.
del X

Plot 16 images from the augmented dataset

In [None]:
# Upper colour limit for the image plots below. This overwrites the data maximum
# computed after rescaling; presumably it is used to make faint structure visible.
vmax = 0.25

In [None]:
fig, ax = plt.subplots(nrows=4, ncols=4, figsize=(15,15))
imgs = dataset.get_samples(N=16)
# Fill the 4x4 grid row by row, one sample per panel.
for k, panel in enumerate(ax.flat):
    plot.plot_img(imgs[k], vmin=vmin, vmax=vmax, ax=panel)
fig.tight_layout()

# Define parameters for the WGAN

In [None]:
# Run identifier and output root used for summaries and checkpoints.
time_str = '2D'
global_path = '../saved_results/'

name = 'Kids{}_more_conv_{}'.format(ns, time_str)

## Parameters

In [None]:
bn = False  # batch-normalisation switch shared by all layers listed below

# --- Discriminator: 5 conv layers, no fully-connected layers ---
params_discriminator = dict()
params_discriminator['stride'] = [1, 2, 2, 2, 2]
params_discriminator['nfilter'] = [32, 64, 128, 256, 512]
params_discriminator['shape'] = [[7, 7], [5, 5], [5, 5], [5, 5], [3, 3]]
params_discriminator['batch_norm'] = [bn, bn, bn, bn, bn]
params_discriminator['full'] = []
params_discriminator['minibatch_reg'] = False
params_discriminator['summary'] = True
params_discriminator['data_size'] = 2

# --- Generator: fully-connected layer followed by 5 conv layers ---
params_generator = dict()
params_generator['stride'] = [2, 2, 2, 2, 1]
params_generator['latent_dim'] = 64
params_generator['nfilter'] = [256, 128, 64, 32, 1]
params_generator['shape'] = [[3, 3], [5, 5], [5, 5], [5, 5], [7, 7]]
# One entry fewer than the number of conv layers; presumably the output layer
# never gets batch norm. TODO confirm against the generator implementation.
params_generator['batch_norm'] = [bn, bn, bn, bn]
# Size the fully-connected layer from the target resolution. Its output is
# presumably reshaped to gen_init_size x gen_init_size x 512 and then upsampled
# by the product of the strides to reach ns x ns. For ns = 128 this gives
# 8 * 8 * 512, and it stays consistent if ns changes.
gen_upsampling = int(np.prod(params_generator['stride']))
assert ns % gen_upsampling == 0, 'ns must be divisible by the product of the generator strides'
gen_init_size = ns // gen_upsampling
params_generator['full'] = [gen_init_size * gen_init_size * 512]
params_generator['summary'] = True
params_generator['non_lin'] = non_lin
params_generator['data_size'] = 2

params_optimization = dict()
params_optimization['optimizer'] = 'rmsprop'
params_optimization['batch_size'] = 32
params_optimization['learning_rate'] = 5e-5
params_optimization['epoch'] = 100

# all parameters
params = dict()
params['net'] = dict() # All the parameters for the model
params['net']['generator'] = params_generator
params['net']['discriminator'] = params_discriminator
params['net']['prior_distribution'] = 'gaussian'
params['net']['shape'] = [ns, ns, 1] # Shape of the image
params['net']['gamma_gp'] = 10 # Gradient penalty

params['optimization'] = params_optimization
params['summary_every'] = 2000 # Tensorboard summaries every ** iterations
params['print_every'] = 1000 # Console summaries every ** iterations
params['save_every'] = 10000 # Save the model every ** iterations
params['summary_dir'] = os.path.join(global_path, name +'_summary/')
params['save_dir'] = os.path.join(global_path, name + '_checkpoints/')
params['Nstats'] = 2000  # presumably the number of samples used for training statistics

In [None]:
resume, params = utils.test_resume(try_resume, params)
# If a model is reloaded and some parameters have to be changed, do it here.
# The epoch override is applied only when resuming. A fresh run therefore keeps
# the `epoch` value from the parameters cell instead of silently training 1 epoch.
if resume:
    params['optimization']['epoch'] = 1
# Point the summary/checkpoint directories at the current global_path/name
# (presumably because reloaded params may carry the paths of an earlier run).
params['summary_dir'] = os.path.join(global_path, name +'_summary/')
params['save_dir'] = os.path.join(global_path, name + '_checkpoints/')

# Build the model

In [None]:
# Build the GAN system from the CosmoWGAN model class and the parameters above.
wgan = GANsystem(CosmoWGAN, params)

# Train the model

In [None]:
# Train on the augmented dataset. `resume` comes from utils.test_resume and
# presumably selects whether training continues from an existing checkpoint.
wgan.train(dataset, resume=resume)

# Generate new samples
To obtain meaningful statistics, generate enough samples for the image resolution:
* 2000 samples for 32 x 32 images
* 500 samples for 64 x 64 images
* 200 samples for 128 x 128 images


In [None]:
# Checkpoint passed to wgan.generate below; None presumably means the latest saved model.
checkpoint = None

In [None]:
N = 2000  # number of maps to generate
generated = wgan.generate(N=N, checkpoint=checkpoint)
# Drop singleton axes (e.g. a trailing channel axis) so samples index as images.
gen_sample = np.squeeze(generated)

Display the histogram of the generated pixel values

In [None]:
# Pixel-value distribution of the generated maps (100 bins, log-scaled counts).
flat_gen = gen_sample.ravel()
plt.hist(flat_gen, 100)
print('min: {}'.format(flat_gen.min()))
print('max: {}'.format(flat_gen.max()))
plt.yscale('log')

Display a few fake samples

In [None]:
fig, ax = plt.subplots(nrows=4, ncols=4, figsize=(15,15))
# Show the first 16 generated maps, filling the grid row by row.
for k, panel in enumerate(ax.flat):
    plot.plot_img(gen_sample[k], vmin=vmin, vmax=vmax, ax=panel)
fig.tight_layout()

In [None]:
# Side-by-side comparison: top row = first 4 generated maps, bottom row = 4 real maps.
real = dataset.get_samples(N=4)
fig, ax = plt.subplots(nrows=2, ncols=4, figsize=(20,10))
for k, panel in enumerate(ax.flat):
    if k < 4:
        img = gen_sample[k]
    else:
        img = real[k % 4]
    plot.plot_img(img, vmin=vmin, vmax=vmax, ax=panel)
fig.tight_layout()

# Evaluation of the sample quality

In [None]:
# Evaluation inputs. No inverse mapping is applied: the real maps come from
# `dataset`, which holds the rescaled [0, 1] data, and the generator was trained
# on that same data. Both sets are therefore compared in the rescaled domain;
# the "_raw" names do not mean un-rescaled values. The histogram limits used
# below, lim=(0, 0.8), appear to be chosen for this domain as well.
raw_images = dataset.get_samples(N)
gen_sample_raw = gen_sample[:N]

In [None]:
# Settings shared by the evaluation cells below.
lenstools = True         # forwarded as `lenstools=` to the PSD/correlation helpers
cut = [200, 6000]        # forwarded as `cut=`; presumably the spectral range kept
bin_k = 50               # forwarded as `bin_k=`; presumably the number of spectral bins
box_l = 5 * np.pi / 180  # 5 degrees in radians; forwarded as `box_l=` (presumably the map side)

In [None]:
# Power spectral density: real vs generated maps, with confidence='std' and the
# fractional difference between the two.
evaluation.compute_and_plot_psd(
    raw_images,
    gen_sample_raw,
    multiply=True,
    confidence='std',
    fractional_difference=True,
    cut=cut,
    bin_k=bin_k,
    box_l=box_l,
    lenstools=lenstools,
)

In [None]:
# Per-mode PSD histograms (modes=3) for real vs generated maps.
evaluation.compute_plot_psd_mode_hists(
    raw_images,
    gen_sample_raw,
    modes=3,
    cut=cut,
    hist_batch=4,
    confidence='std',
    bin_k=bin_k,
    box_l=box_l,
    lenstools=lenstools,
)

In [None]:
# Peak-count statistic on a linear axis restricted to (0, 0.8).
evaluation.compute_and_plot_peak_count(
    raw_images,
    gen_sample_raw,
    log=False,
    lim=(0, 0.8),
    confidence='std',
    fractional_difference=True,
)

In [None]:
# Pixel ("mass") histogram on a linear axis restricted to (0, 0.8).
evaluation.compute_and_plot_mass_hist(
    raw_images,
    gen_sample_raw,
    log=False,
    lim=(0, 0.8),
    confidence='std',
    fractional_difference=True,
)

In [None]:
c_r, c_f, _ = evaluation.compute_plot_correlation(
    raw_images, gen_sample_raw,
    cut=[0, 6000], tick_every=10,
    lenstools=lenstools, bin_k=bin_k, box_l=box_l,
)
# Single-number gap between the real and generated correlations
# (np.linalg.norm's default norm, i.e. Frobenius if these are 2-D matrices).
corr_gap = np.linalg.norm(c_r - c_f)
print(corr_gap)

In [None]:
# Summary statistics drawn on three side-by-side panels.
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(15, 5))
# NOTE(review): unlike the other evaluation calls, generated maps are passed
# before real ones here. Confirm this matches plot_stats' argument order.
evaluation.plot_stats(
    ax, gen_sample_raw, raw_images,
    log=False, lim=(0,0.8), confidence='std', multiply=True,
    fractional_difference=[True, True, True], cut=cut,
    lenstools=lenstools, bin_k=bin_k, box_l=box_l,
)
fig.tight_layout()

MS-SSIM score

In [None]:
# MS-SSIM scores; each argument is a list holding one set of images.
s_fake, s_real = evaluation.compute_ssim_score(
    [gen_sample_raw],
    [raw_images],
)

In [None]:
# First score of each returned sequence (one per input list above): generated set, then real set.
print(s_fake[0], s_real[0])