In [None]:
# Reload edited modules (e.g. gantools) automatically and render figures inline.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
# Uncomment to put the parent directory first on the import path
# (e.g. to use a local gantools checkout).
# import sys
# sys.path.insert(0, '../')
import os
# Expose only GPU 0 to this process; set before TensorFlow is imported below.
os.environ["CUDA_VISIBLE_DEVICES"]="0"


import numpy as np
import tensorflow as tf

from gantools import data
from gantools import utils
from gantools import plot
from gantools.model import WGAN, CosmoWGAN
from gantools.gansystem import GANsystem, PaulinaGANsystem
from gantools.data import fmap
from gantools import evaluation
import functools
import matplotlib.pyplot as plt
from copy import deepcopy

# Parameters

In [None]:
# Image side length in pixels and run-control flags.
ns = 32  # Resolution of the image
try_resume = False  # Try to resume previous simulation
Mpch = 70  # Type of dataset (select 70 or 350)

# Parameters of the pixel-intensity mapping pair (do not change these for now).
shift = 3
c = 40000
forward, backward = (functools.partial(mapping, shift=shift, c=c)
                     for mapping in (fmap.stat_forward, fmap.stat_backward))


def non_lin(x):
    """Element-wise ReLU; handed to the generator via params_generator['non_lin']."""
    return tf.nn.relu(x)

# Data handling

Load the data

In [None]:
# Load the Mpch dataset with image side `ns`; pixel values go through `forward`.
# NOTE(review): nsamples=10 keeps the run small — confirm what one "sample" is in gantools.
dataset = data.load.load_dataset(nsamples=10, spix=ns, Mpch=Mpch, forward_map=forward)

In [None]:
# The dataset can return an iterator; print the shape of the first item it yields
# (10 is presumably the batch size — confirm in gantools).
it = dataset.iter(10)
print(next(it).shape)
# Drop the iterator; the name `it` is reused later for iteration numbers.
del it

In [None]:
# Get all the (forward-mapped) data as a flat 1-D array of pixel values.
X = dataset.get_all_data().flatten()

In [None]:
# Check that the backward map inverts the forward map: the total absolute
# round-trip error summed over all pixels must stay below 1.
roundtrip_error = np.sum(np.abs(forward(backward(X)) - X))
assert roundtrip_error < 1, 'forward(backward(X)) deviates from X (total abs error {})'.format(roundtrip_error)

Display the histogram of the pixel densities after the forward map

In [None]:
# Histogram (log counts, 100 bins) of pixel values after the forward map, plus their range.
plt.hist(X, 100)
plt.yscale('log')
plt.xlabel('Pixel value after forward map')
plt.ylabel('Count')
plt.title('Histogram of forward-mapped pixels')
print('min: {}'.format(np.min(X)))
print('max: {}'.format(np.max(X)))

In [None]:
# Free the flattened data array to save memory; no later cell uses X.
del X

Plot 16 real samples from the dataset.

In [None]:
# Show 16 (forward-mapped) real samples in a 4x4 grid.
plt.figure(figsize=(15,15))
plot.draw_images(dataset.get_samples(N=16),nx=4,ny=4);
plt.title("Real samples")

# Define parameters for the WGAN

In [None]:
# Run identifier; together with global_path it names the summary and checkpoint directories.
time_str = '2D_paulina_tradgan4'
global_path = 'saved_results'

name = 'WGAN{}_{}'.format(ns, time_str)

# Note: this run uses the following discriminator loss:
# self.disc_loss_calc2 = tf.reduce_mean(self.plc_float_r - self.plc_float)

## Parameters

In [None]:
# Toggle batch normalization for all layers of both networks.
bn = False

# Discriminator: one entry per layer in stride / nfilter / shape / batch_norm;
# `full` presumably lists the widths of the dense layers that follow.
params_discriminator = dict(
    stride=[2, 2, 2, 1],
    nfilter=[16, 64, 256, 32],
    shape=[[5, 5], [5, 5], [5, 5], [3, 3]],
    batch_norm=[bn] * 4,
    full=[32],
    minibatch_reg=False,
    summary=True,
    is_3d=False,
)

# Generator: five layers, but batch_norm has only four entries — presumably
# gantools applies no batch norm to the output layer (TODO confirm).
# latent_dim = 16*16*32 presumably seeds a 16x16 map with 32 channels (assumed).
params_generator = dict(
    stride=[1, 1, 2, 1, 1],
    latent_dim=16*16*32,
    nfilter=[32, 64, 256, 32, 1],
    shape=[[5, 5] for layer in range(5)],
    batch_norm=[bn] * 4,
    full=[],
    summary=True,
    non_lin=non_lin,
    is_3d=False,
)

# Optimizer settings; each network receives its own independent copy.
d_opt = dict(optimizer="rmsprop", learning_rate=3e-5, kwargs=dict())
params_optimization = dict(
    batch_size=16,
    epoch=10,
    discriminator=deepcopy(d_opt),
    generator=deepcopy(d_opt),
    n_critic=5,  # presumably critic updates per generator update
)

# Maps used for the cosmological summaries.
params_cosmology = dict(forward_map=forward, backward_map=backward)

# Full parameter tree handed to the GAN system.
params = dict(
    net=dict(  # all the parameters for the model
        generator=params_generator,
        discriminator=params_discriminator,
        cosmology=params_cosmology,
        prior_distribution='gaussian',
        shape=[ns, ns, 1],  # image shape
        is_3d=False,
        gamma_gp=10,  # gradient penalty
    ),
    optimization=params_optimization,
    summary_every=200,  # TensorBoard summaries every this many iterations
    print_every=50,     # console summaries every this many iterations
    save_every=1000,    # save the model every this many iterations
    summary_dir=os.path.join(global_path, name + '_summary/'),
    save_dir=os.path.join(global_path, name + '_checkpoints/'),
    Nstats=2000,
)



In [None]:
resume, params = utils.test_resume(try_resume, params)
# If a model is reloaded and some parameters have to be changed, do it here,
# e.g. train the resumed model for 5 more epochs. The override is guarded so a
# fresh run keeps the epoch count from the parameter cell instead of silently
# replacing it.
if resume:
    params['optimization']['epoch'] = 5


# Build the model

In [None]:
# Wrap the cosmological WGAN model in the Paulina GAN system, configured by `params`.
wgan = PaulinaGANsystem(CosmoWGAN, params)

# Train the model

In [None]:
# Train on the dataset; `resume` comes from utils.test_resume above.
wgan.train(dataset, resume=resume)

# Generate new samples
To get meaningful statistics, generate enough samples for the image resolution:
* 2000 : 32 x 32
* 500 : 64 x 64
* 200 : 128 x 128


In [None]:
# NOTE(review): `N` is reassigned by later cells; it only means "number of samples" here.
N = 2000 # Number of samples
# Generate N fake samples and squeeze out singleton axes (the image shape has one channel).
gen_sample = np.squeeze(wgan.generate(N=N))

Display a few fake samples

In [None]:
# Display a 4x4 grid of generated samples.
plt.figure(figsize=(15,15))
plot.draw_images(gen_sample,nx=4,ny=4);
plt.title("Fake samples");

# Evaluation of the sample quality

In [None]:
# Before computing the statistics, we need to invert the mapping:
# bring the real samples (dataset.N of them — presumably the whole dataset)
# and the generated samples back to raw pixel values with `backward`.
raw_images = backward(dataset.get_samples(dataset.N))
gen_sample_raw = backward(gen_sample)

In [None]:
# Power spectral density of real vs. generated raw images.
# NOTE(review): the next two cells reuse these four names, so these scores get overwritten.
logel2, l2, logel1, l1 = evaluation.compute_and_plot_psd(raw_images, gen_sample_raw)

In [None]:
# Peak-count statistics of real vs. generated raw images (overwrites the PSD scores above).
logel2, l2, logel1, l1 = evaluation.compute_and_plot_peak_cout(raw_images, gen_sample_raw)

In [None]:
# Mass (pixel-value) histogram of real vs. generated raw images (overwrites the scores above).
logel2, l2, logel1, l1 = evaluation.compute_and_plot_mass_hist(raw_images, gen_sample_raw)

Compute a single global score summarizing the sample quality

In [None]:
# Build gantools' global score; it is called below on (real, generated) raw images.
# NOTE(review): the meaning of the `False` argument is not visible here — check ganlist.global_score.
from gantools.metric import ganlist
single_metric = ganlist.global_score(False)

In [None]:
# Evaluate the global metric on real vs. generated raw images.
print("The global metric is {}".format(single_metric(raw_images, gen_sample_raw)))

# Export the curves from the summaries

In [None]:
# Directory where the exported PDF figures are written; created if it is missing.
pathfig = 'figures/'
os.makedirs(pathfig, exist_ok=True)

In [None]:
import os
import numpy as np
import tensorflow as tf

from shutil import copy



def get_event_data(event_files, tag, selec_it=None):
    """Collect scalar summaries whose tag contains `tag` from TensorBoard event files.

    Parameters
    ----------
    event_files : iterable of str
        Paths of event files, read in the given order.
    tag : str
        Substring matched against each summary value's tag.
    selec_it : iterable of int, optional
        If given, only events whose step is in this collection are kept.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        The matching scalar values and their steps, in read order.

    A file that cannot be read to the end is reported and abandoned at that
    point; values read from it before the error are kept.
    """
    values = []
    steps = []
    if selec_it is not None:
        selec_it = set(selec_it)  # O(1) membership test per event
    for path_event in event_files:
        try:
            for event in tf.train.summary_iterator(path_event):
                if selec_it is not None and event.step not in selec_it:
                    continue
                for value in event.summary.value:
                    if tag in value.tag:
                        values.append(value.simple_value)
                        steps.append(event.step)
        except Exception as err:
            # Best effort: one truncated/corrupted file should not abort the analysis,
            # but say which file failed and why (no longer swallows KeyboardInterrupt).
            print('Warning: could not read all of {} ({})'.format(path_event, err))
    return np.array(values), np.array(steps)


def get_event_files(summary_dir):
    """Return the paths of the TensorBoard event files in `summary_dir`, sorted by name.

    os.listdir returns entries in arbitrary order. Sorting makes the order
    deterministic, and with TensorFlow's usual `events.out.tfevents.<timestamp>.<host>`
    naming it also reads several files (e.g. from resumed runs) chronologically.

    Raises
    ------
    ValueError
        If the directory contains no event file.
    """
    event_files = sorted(
        os.path.join(summary_dir, filename)
        for filename in os.listdir(summary_dir)
        if 'events.out.tfevents' in filename
    )
    if not event_files:
        raise ValueError('No event files in {}'.format(summary_dir))
    return event_files


In [None]:

# Correlate each selected quality score with the duality-gap, minimax and maxmin
# summaries logged during training, over iterations 0..maxit.
summary_dir = wgan.params['summary_dir']
event_files = get_event_files(summary_dir)
maxit = 30000
selec_it = set(range(maxit+1))

# Indices of the measures (and of their display names) to analyse.
selec = [1,2,3,5,6,7]
measures_list = ['final/global_score_1',
                 'final/mass_histogram_l2_1',
                 'final/peak_histogram_l2_1', 
                 'final/psd_l2log_1', 
                 'cosmology/global_score_1',
                 'cosmology/mass_histogram_l2log_1',
                 'cosmology/peak_histogram_l2log_1',
                 'cosmology/psd_l2log_1',
                 'nomap/mass_histogram_log_l2log_1',
                 'nomap/peak_histogram_log_l2log_1']
measures_list = np.array(measures_list)[selec]

names = ['Sum differences',
         'Mass histogram',
         'Peak histogram',
         'Power spectral density',
         'Sum difference mapped',
         'Mass histogram - Raw',
         'Peak histogram - Raw',
         'Power spectral density - Raw',
         'Mass histogram log - mapped',
         'Peak histogram log - mapped']
names= np.array(names)[selec]

# Getting data: every series must be logged at exactly the same iterations.
duality_gap, it = get_event_data(event_files, 'duality/gap_1', selec_it) 
duality_minmax, itt = get_event_data(event_files, 'duality/minmax_1', selec_it) 
np.testing.assert_almost_equal(it,itt)
duality_maxmin, itt = get_event_data(event_files, 'duality/maxmin_1', selec_it) 
np.testing.assert_almost_equal(it,itt)

corr_dg = []
corr_minmax = []
corr_maxmin = []


for measures in measures_list:
    values, itt = get_event_data(event_files, measures, selec_it) 
    np.testing.assert_almost_equal(it,itt)
    # Pearson correlation = off-diagonal entry of the 2x2 correlation matrix (taken below).
    corr_dg.append(np.corrcoef( values, duality_gap))
    corr_minmax.append(np.corrcoef(values, duality_minmax))
    corr_maxmin.append(np.corrcoef(values, duality_maxmin))

corr_dg = np.array(corr_dg)[:,0,1]
corr_minmax = np.array(corr_minmax)[:,0,1]
corr_maxmin = np.array(corr_maxmin)[:,0,1]
plt.plot(corr_dg,'xb', label='Duality gap value')
plt.plot(corr_minmax,'xr', label='Minimax value')
# Label every marker with its measure; without ticks the points cannot be identified.
plt.xticks(np.arange(len(names)), names, rotation=45, ha='right')
# Keep the [0, 1] range but extend it downward if a correlation is negative,
# instead of silently clipping that point off the plot (NaNs are ignored).
lowest = np.nanmin(np.concatenate([corr_dg, corr_minmax, [0.0]]))
plt.ylim([lowest, 1])
plt.ylabel('Pearson correlation')
plt.legend()
plt.title('Correlation measurements')


In [None]:
"""
========
Barchart
========

A bar plot with errorbars and height labels on individual bars
"""
import numpy as np
import matplotlib.pyplot as plt

N = len(corr_dg)

ind = np.arange(N)  # the x locations for the groups
width = 0.35       # the width of the bars

fig, ax = plt.subplots()

plt.gca().invert_yaxis()
rects1 = ax.barh(ind, corr_dg, width, color='r')

rects2 = ax.barh(ind + width, corr_minmax, width, color='y')
# add some text for labels, title and axes ticks
# ax.set_xlabel('Score')
# ax.set_title('Pearson correlation')
ax.set_xlabel('Pearson correlation')

ax.set_yticks(ind + width / 2)
ax.set_yticklabels(names)
ax.set_xlim(0,1)
ax.legend((rects1[0], rects2[0]), ('Duality gap value', 'Minimax value'), loc=3, framealpha=1)

# for i, v in enumerate(y):
#     ax.text(v + 3, i + .25, str(v), color='blue', fontweight='bold')
    
def autolabel(rects, ax=None):
    """Write each horizontal bar's length, truncated to 2 decimals, near its end.

    Parameters
    ----------
    rects : iterable of bar rectangles (e.g. the container returned by ``barh``)
        Bars to annotate.
    ax : matplotlib Axes, optional
        Axes to write on; defaults to the axes each rectangle belongs to
        (previously this silently relied on a global ``ax``).
    """
    for rect in rects:
        bar_length = rect.get_width()
        target = ax if ax is not None else rect.axes
        # int() truncates toward zero, so e.g. 0.999 is shown as 0.99.
        target.text(1.08*bar_length, rect.get_y()+rect.get_height(),
                    '{}'.format((int(100*bar_length)/100)),
                    ha='center', va='bottom')

# Annotate both bar series with their values and export the figure as PDF.
autolabel(rects1)
autolabel(rects2)

# plt.show()
plt.savefig(pathfig+"correlation.pdf", bbox_inches='tight', format='pdf')


In [None]:
# For the first three selected measures, overlay the first N logged values with the
# duality gap (each divided by its maximum over the whole run) and save one PDF each.
N = 150
fontsize=12
# Loop variable is `metric_name`, not `name`: `name` holds the run name used for the
# summary/checkpoint paths and must not be clobbered if earlier cells are re-run.
for num, (measures, metric_name) in enumerate(zip(measures_list[:3], names[:3])):
    values, itt = get_event_data(event_files, measures, selec_it)
    np.testing.assert_almost_equal(it,itt)
    plt.figure(figsize=(10,2))
    plt.plot(it[:N], values[:N]/np.max(values), label='Score')
    plt.plot(it[:N], duality_gap[:N]/np.max(duality_gap), label='Duality Gap')
    plt.title(metric_name)
    plt.legend(loc=1, framealpha=1, fontsize=fontsize-2)
    plt.xlabel('Iterations', fontsize=fontsize )
    plt.ylabel('Normalized metric', fontsize=fontsize)
    plt.savefig(pathfig+"stat"+str(num)+".pdf", bbox_inches='tight', format='pdf')

In [None]:
# Full duality-gap history (no iteration filter). Stored under separate names so the
# filtered `duality_gap` / `it` used by the analysis cells stay aligned on re-runs.
duality_gap_all, it_all = get_event_data(event_files, 'duality/gap_1')
plt.plot(it_all, duality_gap_all)
plt.xlabel('Iterations')
plt.ylabel('Duality gap')
plt.title('Duality gap during training')

In [None]:
def plot_imgs_paper(imgs, nx=2, ny=2,**kwargs):
    """Draw the first nx*ny images on an nx-row by ny-column grid, without axes.

    Parameters
    ----------
    imgs : sequence of 2-D arrays
        Images to draw, filled row by row.
    nx, ny : int
        Number of rows and of columns of the grid.
    **kwargs
        Forwarded to ``imshow`` (e.g. ``cmap``, ``clim``).

    Returns
    -------
    matplotlib.figure.Figure
        The new figure.

    Raises
    ------
    ValueError
        If fewer than nx*ny images are given.
    """
    if len(imgs)<nx*ny:
        raise ValueError("Not enough samples.")
    # Width scales with the number of columns (ny), height with the number of rows (nx).
    # squeeze=False always yields a 2-D array of axes, so no 1x1 / 1xN special cases.
    fig, axes = plt.subplots(nx, ny, sharey=True, squeeze=False,
                             figsize=(11/2*ny, 10.5/2*nx))
    for sn in range(nx*ny):
        tax = axes[sn // ny, sn % ny]
        tax.imshow(imgs[sn], interpolation='none', **kwargs)
        tax.axis('off')
    plt.tight_layout()
    return fig

In [None]:
# Paper figures: 16 generated and 16 real samples on one shared color scale.
real_sample = dataset.get_samples(N=16)

cmap = plt.cm.plasma
# Use the real-sample range for both grids so colors are directly comparable.
clim = (0, np.max(real_sample))
# plot_imgs_paper creates its own figure, so no extra plt.figure() (that only added a
# blank figure); each figure is saved explicitly rather than via the "current" figure.
for samples, caption, filename in [(gen_sample, 'Fake samples', 'fakecosmo.pdf'),
                                   (real_sample, 'Real samples', 'realcosmo.pdf')]:
    fig = plot_imgs_paper(samples, nx=4, ny=4, cmap=cmap, clim=clim)
    fig.suptitle('{} ${}x{}$'.format(caption, ns, ns), y=1.03, fontsize=48)
    fig.savefig(os.path.join(pathfig, filename), bbox_inches='tight', format='pdf')