In [None]:
# Re-import edited project modules automatically before each cell runs,
# and render matplotlib figures inline in the notebook.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import sys
# Put the parent directory first on the module search path, presumably so the
# project packages imported below (cosmotools, gantools) resolve from the
# repository checkout rather than an installed copy.
sys.path.insert(0, '../')

In [None]:
from matplotlib import pyplot as plt
import os
# Expose only GPU index 1 to this process. The variable must be set before
# TensorFlow initialises CUDA, so it is set here, before `import tensorflow`.
os.environ["CUDA_VISIBLE_DEVICES"]="1"

import numpy as np
import tensorflow as tf

# Project packages (see the sys.path tweak in the previous cell).
# NOTE(review): WGAN, functools and deepcopy are not referenced in any of the
# cells of this notebook; `partial` is re-imported directly where it is used.
from cosmotools.data import load
from gantools import utils
from gantools.model import WGAN
from gantools.gansystem import GANsystem
from cosmotools.data import fmap
import functools
from copy import deepcopy

from cosmotools.metric import evaluation
from cosmotools.model import CosmoWGAN

In [None]:
# Setting the forward and backward transform
from functools import partial

# Parameters of the "stat" map. `forward` and `backward` are built from the
# very same values; `backward` is used below ("Invert the transform") to map
# the forward-transformed samples back to raw values.
shift = 1
c = 20000
_stat_kwargs = dict(shift=shift, c=c)
forward = partial(fmap.stat_forward, **_stat_kwargs)
backward = partial(fmap.stat_backward, **_stat_kwargs)

In [None]:
def map_to_plot(img):
    """Return ``img`` rescaled for display.

    The image is first passed through ``backward`` (presumably undoing the
    forward transform), and the result is then fed to ``fmap.log_forward``.
    """
    raw_img = backward(img)
    return fmap.log_forward(raw_img)

def plot_cubes_paper(images, nx=2, ny=2, **kwargs):
    """Plot the first ``nx * ny`` images on an ``nx``-by-``ny`` grid of panels.

    Each image is converted with ``map_to_plot`` before display, and every
    panel has its axes hidden. Images fill the grid row by row.

    Parameters
    ----------
    images : sequence of 2D arrays
        Images in the forward-mapped space; at least ``nx * ny`` are required.
    nx : int
        Number of grid rows.
    ny : int
        Number of grid columns.
    **kwargs
        Forwarded to ``Axes.imshow`` (e.g. ``cmap``, ``clim``).

    Returns
    -------
    matplotlib.figure.Figure
        The figure containing the grid.

    Raises
    ------
    ValueError
        If fewer than ``nx * ny`` images are given.
    """
    if len(images) < nx * ny:
        raise ValueError("Not enough samples.")
    # plt.subplots(nrows, ncols): figure width must scale with the number of
    # columns (ny) and height with the number of rows (nx).
    # squeeze=False always yields a 2D array of axes, so single-row and
    # single-column grids need no special-casing.
    fig, axes = plt.subplots(nx, ny, sharey=True, squeeze=False,
                             figsize=(11 / 2 * ny, 10.5 / 2 * nx))
    # axes.flat walks the grid in row-major order, matching sample index sn.
    for sn, tax in enumerate(axes.flat):
        tax.imshow(map_to_plot(images[sn]), interpolation='none', **kwargs)
        tax.axis('off')
    fig.tight_layout()

    return fig

In [None]:
# Root directory of the saved 2D N-body checkpoints, relative to this notebook.
pathcheckpoints = '../saved_results/nbody-2d'

In [None]:
# Output directory for all figures. The trailing slash matters: file names are
# appended to it with plain string concatenation (pathfig + "...") below.
pathfig = 'figures_2d/'
os.makedirs(pathfig, exist_ok=True)

In [None]:
# Select here the size of the image
ns = 256 # 32, 64, 128, 256

# number of sample generated
num_samples = 256*256*256//(ns*ns)

In [None]:
name = 'WGAN{}test_full_2D_checkpoints'.format(ns)

# Training iteration of the checkpoint to load for each image size.
# Number to be changed if you retrain the network. Any other size gives
# checkpoint = None, which is passed unchanged to wgan.generate below
# (presumably meaning "latest checkpoint" — TODO confirm in GANsystem).
checkpoint = {
    32: 210000,
    64: 148000,
    128: 216000,
    256: 96000,
}.get(ns)

pathmodel = os.path.join(pathcheckpoints, name)
params = utils.load_params(pathmodel)

In [None]:
# Build the GAN system from the CosmoWGAN model class and the parameters
# loaded from `pathmodel` in the previous cell.
wgan = GANsystem(CosmoWGAN,params)

In [None]:
# Draw num_samples images from the selected checkpoint and drop singleton
# axes (presumably a trailing channel axis — TODO confirm).
gen_samples = np.squeeze(wgan.generate(N=num_samples, checkpoint=checkpoint))

In [None]:
# Load real N-body images of the same size, already in the forward-mapped
# space, and keep only the squeezed array; the dataset object is not needed.
real_samples = np.squeeze(
    load.load_nbody_dataset(ncubes=30, spix=ns, forward_map=forward).get_all_data()
)

# Plot samples

In [None]:
# Colour limits, applied after map_to_plot, shared by both figures so that
# real and generated grids are directly comparable.
cmin = 0
cmax = 4
clim = (cmin, cmax)

for label, tag, samples in [('Real', 'real', real_samples),
                            ('Fake', 'fake', gen_samples)]:
    fig = plot_cubes_paper(samples, cmap=plt.cm.plasma, clim=clim)
    fig.suptitle('{} ${}^2$'.format(label, ns), y=1.04, fontsize=36)
    # Save through the figure handle instead of plt.savefig, which targets
    # whichever figure happens to be "current".
    fig.savefig(pathfig + "2d-{}{}.png".format(tag, ns), bbox_inches='tight', format='png')

# Invert the transform


In [None]:
# Map both sample sets back through `backward` so the statistics below are
# computed on untransformed values (generated first, then real).
gen_samples_raw, real_samples_raw = (backward(s) for s in (gen_samples, real_samples))

# Compute stats and scores

In [None]:
# Compare real vs generated samples (both truncated to num_samples) with
# evaluation.compute_and_plot_mass_hist, then save the figure as a PDF.
# NOTE(review): plt.savefig writes whichever figure is current after the call;
# the figsize above only applies if the helper draws into this figure — TODO confirm.
plt.figure(figsize=(6,3))
_ = evaluation.compute_and_plot_mass_hist(real_samples_raw[:num_samples], gen_samples_raw[:num_samples], confidence='std', lim=None)
plt.savefig(pathfig+"{}_hist.pdf".format(ns), bbox_inches='tight', format='pdf')

In [None]:
# Compare real vs generated samples (both truncated to num_samples) with
# evaluation.compute_and_plot_peak_count, then save the figure as a PDF.
# NOTE(review): plt.savefig writes whichever figure is current after the call;
# the figsize above only applies if the helper draws into this figure — TODO confirm.
plt.figure(figsize=(6,3))
_ = evaluation.compute_and_plot_peak_count(real_samples_raw[:num_samples], gen_samples_raw[:num_samples], confidence='std', lim=None)
plt.savefig(pathfig+"{}_peak.pdf".format(ns), bbox_inches='tight', format='pdf')

In [None]:
# Compare real vs generated samples (both truncated to num_samples) with
# evaluation.compute_and_plot_psd, then save the figure as a PDF.
# NOTE(review): plt.savefig writes whichever figure is current after the call;
# the figsize above only applies if the helper draws into this figure — TODO confirm.
plt.figure(figsize=(6,3))
_ = evaluation.compute_and_plot_psd(real_samples_raw[:num_samples], gen_samples_raw[:num_samples], confidence='std')
plt.savefig(pathfig+"{}_psd.pdf".format(ns), bbox_inches='tight', format='pdf')

In [None]:
from cosmotools.metric.score import score_histogram, score_peak_histogram, score_psd

# Score the same truncated real/generated sets with each metric, in a fixed order.
_real_subset = real_samples_raw[:num_samples]
_gen_subset = gen_samples_raw[:num_samples]
for _template, _score_fn in [('PSD score: {}', score_psd),
                             ('Mass histogram score: {}', score_histogram),
                             ('Peak histogram score: {}', score_peak_histogram)]:
    print(_template.format(_score_fn(_real_subset, _gen_subset)))

# Scores obtained with the checkpoints above

### Size: 32x32
- PSD score: 9.24
- Mass histogram score: 7.44
- Peak histogram score: 3.25

### Size: 64x64
- PSD score: 5.08
- Mass histogram score: 5.56
- Peak histogram score: 1.09

### Size: 128x128
- PSD score: 5.27
- Mass histogram score: 4.37
- Peak histogram score: 0.89

### Size: 256x256
- PSD score: 3.36
- Mass histogram score: 5.66
- Peak histogram score: 1.22