In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import sys
sys.path.insert(0, '../')

In [None]:
import numpy as np
import tensorflow as tf
import functools
from matplotlib import pyplot as plt
import os
os.environ["CUDA_VISIBLE_DEVICES"]="1"

from gantools import utils, plot
from gantools.gansystem import GANsystem
from gantools.model import WGAN, UpscalePatchWGAN

from cosmotools.metric import evaluation
from cosmotools.model import CosmoWGAN
from cosmotools.data import load
from cosmotools.data import fmap
from cosmotools.gansystem import CosmoUpscaleGANsystem as UpscaleGANsystem

In [None]:
# Forward / backward transforms between raw cube values and the space the
# networks work in: `forward` is applied to the real data when it is loaded,
# and `backward` maps cubes back to raw values for the statistics.
from functools import partial

shift = 1
c = 20000
transform_kwargs = dict(shift=shift, c=c)
forward = partial(fmap.stat_forward, **transform_kwargs)
backward = partial(fmap.stat_backward, **transform_kwargs)

# Number of cubes generated at each stage of the pipeline.
num_samples = 30

# Trained models are read from `pathmodel`; generated cubes are written under
# `pathsample`. `exp` is an optional experiment name prefix ('' = none).
pathmodel = '../saved_results/nbody'
exp = ''
pathsample = '../samples/' + exp + 'nbody'


In [None]:
def map_to_plot(img):
    """Map a cube (or slice) from network space to the scale used for display.

    Applies the inverse intensity transform `backward`, then `fmap.log_forward`.
    """
    return fmap.log_forward(backward(img))

def plot_cubes_paper(cubes, slice_num=None, nx=2, ny=2, **kwargs):
    """Show one 2-D slice from each of the first ``nx * ny`` cubes on a grid.

    Parameters
    ----------
    cubes : array
        Stack of cubes indexed as ``cubes[sample, slice, ...]``. Panel ``k``
        shows ``map_to_plot(cubes[k, slice_num])``.
    slice_num : int, optional
        Index along axis 1 to display. Defaults to ``cubes.shape[1] // 2 + 4``.
    nx, ny : int
        Number of grid rows and columns respectively.
    **kwargs
        Forwarded to ``Axes.imshow`` (e.g. ``cmap``, ``clim``).

    Returns
    -------
    matplotlib.figure.Figure
        The created figure (also the current pyplot figure).

    Raises
    ------
    ValueError
        If fewer than ``nx * ny`` cubes are given.
    """
    if len(cubes) < nx * ny:
        raise ValueError("Not enough samples.")
    if slice_num is None:
        # Default: the slice 4 past the centre of axis 1.
        slice_num = cubes.shape[1] // 2 + 4
    # figsize is (width, height). Width must scale with the number of columns (ny)
    # and height with the number of rows (nx).
    # squeeze=False always returns a 2-D array of axes, so 1-row / 1-column grids
    # need no special-casing.
    fig, axes = plt.subplots(nx, ny, sharey=True, squeeze=False,
                             figsize=(11 / 2 * ny, 10.5 / 2 * nx))
    # axes.flat walks the grid row by row, matching sample order 0, 1, 2, ...
    for sn, tax in enumerate(axes.flat):
        tax.imshow(map_to_plot(cubes[sn, slice_num]), interpolation='none', **kwargs)
        tax.axis('off')
    plt.tight_layout()

    return fig

# Generate samples: 0 -> 32 -> 64 -> 256

## 0->32

In [None]:
# name = 'WGAN_{}0_to_32_checkpoints'.format(exp)
# checkpoint32 = 138000 #49000

name = 'WGAN_{}0_to_32_checkpoints'.format(exp)
checkpoint32 = 34000 #None # to be changed if you retrain the network


pathmodel32 = os.path.join(pathmodel, name)
params_32 = utils.load_params(pathmodel32)

In [None]:
wgan_32 = GANsystem(CosmoWGAN,params_32)

In [None]:
gen_samples_32 = wgan_32.generate(N=num_samples, checkpoint=checkpoint32)
gen_samples_32 = np.squeeze(gen_samples_32)

In [None]:
os.makedirs(os.path.join(pathsample,'{}0_32/'.format(exp)), exist_ok=True)
utils.save_hdf5(gen_samples_32, os.path.join(pathsample,'{}0_32/32_samples_ckpt_latest.h5'.format(exp)))

In [None]:
cmin = 0.1
cmax = 2
clim = (cmin, cmax)
plot_cubes_paper(gen_samples_32, cmap=plt.cm.plasma, clim=clim, nx=4, ny=4);

## 32->64

Load the generated (fake) $32^3$ samples saved above and add a trailing axis before upscaling.

In [None]:
gen_samples_32 = utils.load_hdf5(os.path.join(pathsample,'{}0_32/32_samples_ckpt_latest.h5'.format(exp)))[:num_samples]
print("gen_samples_32 shape=", gen_samples_32.shape)
gen_samples_32 = np.expand_dims(gen_samples_32, axis=4)
print('downsampled shape=', gen_samples_32.shape)

In [None]:
name = 'WGAN_{}32_to_64_checkpoints'.format(exp)
checkpoint64 = 92000 # to be changed if you retrain the network
class CosmoUpscalePatchWGAN(UpscalePatchWGAN, CosmoWGAN):
    pass

pathmodel64 = os.path.join(pathmodel, name)
params_64 = utils.load_params(pathmodel64)

In [None]:
wgan_64 = UpscaleGANsystem(CosmoUpscalePatchWGAN, params_64)
gen_samples_64 = wgan_64.upscale_image(small=gen_samples_32, checkpoint=checkpoint64)

In [None]:
os.makedirs(os.path.join(pathsample,'{}32_64/'.format(exp)), exist_ok=True)
utils.save_hdf5(gen_samples_64, os.path.join(pathsample,'{}32_64/64_samples_ckpt_latest.h5'.format(exp)))

In [None]:
cmin = 0
cmax = 2.5
clim = (cmin, cmax)
plot_cubes_paper(gen_samples_64, cmap=plt.cm.plasma, clim=clim, nx=4, ny=4);


## 64->256

In [None]:
gen_samples_64 = utils.load_hdf5(os.path.join(pathsample,'{}32_64/64_samples_ckpt_latest.h5'.format(exp)))

print('gen_samples_64 shape=', gen_samples_64.shape)
gen_samples_64 = np.expand_dims(gen_samples_64, axis=4)[:num_samples]
print('downsampled shape=', gen_samples_64.shape)

In [None]:
name = 'WGAN_{}64_to_256_checkpoints'.format(exp)
checkpoint256 = 76000 # to be changed if you retrain the network
pathmodel256 = os.path.join(pathmodel, name)

params_256 = utils.load_params(pathmodel256)


In [None]:
wgan_256 = UpscaleGANsystem(CosmoUpscalePatchWGAN, params_256)
gen_samples_256 = wgan_256.upscale_image(small=gen_samples_64, checkpoint=checkpoint256)

In [None]:
cmin = 0
cmax = 3.5
clim = (cmin, cmax)
plot_cubes_paper(gen_samples_256, cmap=plt.cm.plasma, clim=clim, nx=2, ny=2);

In [None]:
# fps=16
# clim = (0,4.5)
# animation = plot.cubes_to_animation(map_to_plot(gen_samples_256), cmap=plt.cm.plasma, clim=clim)
# # animation = plot.cubes_to_animation(real_samples_256, cmap=plt.cm.plasma, clim=clim)
# animation.ipython_display(fps=16, loop=True, autoplay=True)

In [None]:
os.makedirs(os.path.join(pathsample,'{}64_256/'.format(exp)), exist_ok=True)
utils.save_hdf5(gen_samples_256, os.path.join(pathsample,'{}64_256/256_samples_ckpt_latest.h5'.format(exp)))

# Uniscale model
**Note:** this model is currently not working; the cells below are kept for reference.

In [None]:
name = 'WGAN_{}uniscale_checkpoints'.format(exp)
checkpointuniscale = 38000 # to be changed if you retrain the network

pathmodel_uniscale = os.path.join(pathmodel, name)
params_uniscale = utils.load_params(pathmodel_uniscale)
class CosmoUpscalePatchWGAN(UpscalePatchWGAN, CosmoWGAN):
    pass

In [None]:
wgan_uniscale = UpscaleGANsystem(CosmoUpscalePatchWGAN, params_uniscale)
offset_u = 32

In [None]:
gen_samples_uniscale = wgan_uniscale.upscale_image(N=num_samples, resolution=256+offset_u, checkpoint=checkpointuniscale)[:,offset_u:,offset_u:,offset_u:]


In [None]:
os.makedirs(os.path.join(pathsample,'{}uniscale/'.format(exp)), exist_ok=True)
utils.save_hdf5(gen_samples_uniscale, os.path.join(pathsample,'{}uniscale/256_samples_ckpt_latest.h5'.format(exp)))

In [None]:
cmin = 0
cmax = 3.5
clim = (cmin, cmax)
plot_cubes_paper(gen_samples_uniscale, cmap=plt.cm.plasma, clim=clim, nx=2, ny=2);

# Make the figures for the paper

In [None]:
pathfig = 'figures/'
os.makedirs(pathfig, exist_ok=True)

In [None]:
offset = 0

#### Load all samples

In [None]:
def _load_samples(relpath):
    """Load a saved sample file; `relpath` is relative to `pathsample`, `{}` -> `exp`."""
    return utils.load_hdf5(os.path.join(pathsample, relpath.format(exp)))

# Identical to indexing with [:, offset:, offset:, offset:].
crop = (slice(None),) + (slice(offset, None),) * 3
gen_samples_32 = _load_samples('{}0_32/32_samples_ckpt_latest.h5')
gen_samples_64 = _load_samples('{}32_64/64_samples_ckpt_latest.h5')
gen_samples_256 = _load_samples('{}64_256/256_samples_ckpt_latest.h5')[crop]
gen_samples_uniscale = _load_samples('{}uniscale/256_samples_ckpt_latest.h5')[crop]


#### Backward transform

In [None]:
gen_samples_32_raw = backward(gen_samples_32)
gen_samples_64_raw = backward(gen_samples_64)
gen_samples_256_raw = backward(gen_samples_256)


In [None]:
gen_samples_uniscale_raw = backward(gen_samples_uniscale)


#### Load real data

In [None]:
dataset256 = load.load_nbody_dataset(resolution=256,Mpch=350,shuffle=False,forward_map=forward,spix=256,patch=False,is_3d=True,augmentation=False)
real_samples_256 = dataset256.get_all_data()[:,offset:,offset:,offset:]
del dataset256


In [None]:
real_samples_256_raw = backward(real_samples_256)

In [None]:
dataset32 = load.load_nbody_dataset(resolution=256,Mpch=350,shuffle=False,forward_map=forward,spix=32,patch=False,is_3d=True,augmentation=False, scaling=8)
real_samples_32 = dataset32.get_all_data()
del dataset32

In [None]:
real_samples_32_raw = backward(real_samples_32)

In [None]:
dataset64 = load.load_nbody_dataset(resolution=256,Mpch=350,shuffle=False,forward_map=forward,spix=64,patch=False,is_3d=True,augmentation=False, scaling=4)
real_samples_64 = dataset64.get_all_data()
del dataset64

In [None]:
real_samples_64_raw = backward(real_samples_64)

In [None]:
# Fixed colour limits. A data-driven alternative kept from earlier is
# cmin = min(cubes), cmax = max(cubes) / 1.5.
cmin, cmax = 0, 3.5
clim = (cmin, cmax)

fig = plot_cubes_paper(real_samples_256, cmap=plt.cm.plasma, clim=clim);
fig.suptitle('Real $256^3$', y=1.02, fontsize=24)
fig.savefig(pathfig + "real256.pdf", format='pdf', bbox_inches='tight')

fig2 = plot_cubes_paper(gen_samples_256, cmap=plt.cm.plasma, clim=clim);
fig2.suptitle('Fake $256^3$', y=1.02, fontsize=24)
fig2.savefig(pathfig + "fake256.pdf", format='pdf', bbox_inches='tight')



In [None]:
plt.figure(figsize=(6,3))
_ = evaluation.compute_and_plot_mass_hist(real_samples_256_raw, gen_samples_256_raw, confidence='std' )
plt.savefig(pathfig+"256full_hist.pdf", bbox_inches='tight', format='pdf')

In [None]:
plt.figure(figsize=(6,3))
_ = evaluation.compute_and_plot_peak_count(real_samples_256_raw, gen_samples_256_raw, confidence='std')
plt.savefig(pathfig+"256full_peak.pdf", bbox_inches='tight', format='pdf')

In [None]:
plt.figure(figsize=(6,3))
_ = evaluation.compute_and_plot_psd(real_samples_256_raw, gen_samples_256_raw, confidence='std')
plt.savefig(pathfig+"256full_psd.pdf", bbox_inches='tight', format='pdf')

In [None]:
from cosmotools.metric.score import score_histogram, score_peak_histogram, score_psd
print('PSD score: {}'.format(score_psd(real_samples_256_raw, gen_samples_256_raw)))
print('Mass histogram score: {}'.format(score_histogram(real_samples_256_raw, gen_samples_256_raw)))
print('Peak histogram score: {}'.format(score_peak_histogram(real_samples_256_raw, gen_samples_256_raw)))

In [None]:
from cosmotools.metric.score import score_histogram, score_peak_histogram, score_psd
print('PSD score: {}'.format(score_psd(real_samples_256_raw[15:], real_samples_256_raw[:15])))
print('Mass histogram score: {}'.format(score_histogram(real_samples_256_raw[15:], real_samples_256_raw[:15])))
print('Peak histogram score: {}'.format(score_peak_histogram(real_samples_256_raw[15:], real_samples_256_raw[:15])))

# Scale-by-scale analysis

## 32 Cubes

In [None]:
gen_samples_32 = utils.load_hdf5(os.path.join(pathsample,'{}0_32/32_samples_ckpt_latest.h5'.format(exp)))
gen_samples_32_raw = backward(gen_samples_32)

In [None]:
# cmin = np.min(cubes_32)
# # cmin = 0
# cmax = np.max(cubes_32)
cmin = 0.1
cmax = 2
clim = (cmin, cmax)
fig = plot_cubes_paper(real_samples_32, cmap=plt.cm.plasma, clim=clim, nx=4, ny=4);
fig.suptitle('Real $32^3$', y=1.03, fontsize=48 )
plt.savefig(pathfig+"real32.pdf", bbox_inches='tight', format='pdf')

fig2 = plot_cubes_paper(gen_samples_32, cmap=plt.cm.plasma, clim=clim, nx=4, ny=4);
fig2.suptitle('Fake $32^3$', y=1.03, fontsize=48 )
plt.savefig(pathfig+"fake32.pdf", bbox_inches='tight', format='pdf')



In [None]:
plt.figure(figsize=(6,3))
_ = evaluation.compute_and_plot_mass_hist(real_samples_32_raw, gen_samples_32_raw, confidence='std', lim=None)
plt.savefig(pathfig+"32_hist.pdf", bbox_inches='tight', format='pdf')

In [None]:
plt.figure(figsize=(6,3))
_ = evaluation.compute_and_plot_peak_count(real_samples_32_raw, gen_samples_32_raw, confidence='std', lim=None)
plt.savefig(pathfig+"32_peak.pdf", bbox_inches='tight', format='pdf')

In [None]:
plt.figure(figsize=(6,3))
_ = evaluation.compute_and_plot_psd(real_samples_32_raw, gen_samples_32_raw, confidence='std')
plt.savefig(pathfig+"32_psd.pdf", bbox_inches='tight', format='pdf')

## 64 Cubes

In [None]:
params_64 = utils.load_params(pathmodel64)
wgan_64 = UpscaleGANsystem(CosmoUpscalePatchWGAN, params_64)
gen_samples_64_single = wgan_64.upscale_image(small=np.reshape(real_samples_32, [*real_samples_32.shape,1]), checkpoint=checkpoint64)

In [None]:
gen_samples_64_single_raw = backward(gen_samples_64_single)

In [None]:
# cmin = np.min(cubes_64)
# cmax = np.max(cubes_64)
cmin = 0
cmax = 2.5
clim = (cmin, cmax)
fig = plot_cubes_paper(real_samples_64, cmap=plt.cm.plasma, clim=clim);
fig.suptitle('Real $64^3$', y=1.04, fontsize=36 )
plt.savefig(pathfig+"up_real64.pdf", bbox_inches='tight', format='pdf')

fig2 = plot_cubes_paper(gen_samples_64_single, cmap=plt.cm.plasma, clim=clim);
fig2.suptitle('Fake $64^3$', y=1.04, fontsize=36 )
plt.savefig(pathfig+"up_fake64.pdf", bbox_inches='tight', format='pdf')

fig = plot_cubes_paper(real_samples_32, cmap=plt.cm.plasma, clim=clim);
fig.suptitle('Real downsampled $32^3$', y=1.04, fontsize=36 )
plt.savefig(pathfig+"up_down32.pdf", bbox_inches='tight', format='pdf')

In [None]:
plt.figure(figsize=(6,3))
_ = evaluation.compute_and_plot_mass_hist(real_samples_64_raw, gen_samples_64_single_raw, confidence='std')
plt.savefig(pathfig+"up_64_hist.pdf", bbox_inches='tight', format='pdf')

In [None]:
plt.figure(figsize=(6,3))
_ = evaluation.compute_and_plot_peak_count(real_samples_64_raw, gen_samples_64_single_raw, confidence='std')
plt.savefig(pathfig+"up_64_peak.pdf", bbox_inches='tight', format='pdf')

In [None]:
plt.figure(figsize=(6,3))
_ = evaluation.compute_and_plot_psd(real_samples_64_raw, gen_samples_64_single_raw, confidence='std')
plt.savefig(pathfig+"up_64_psd.pdf", bbox_inches='tight', format='pdf')

## 256 Cubes

In [None]:
params_256 = utils.load_params(pathmodel256)
# checkpoint256 = None

wgan_256 = UpscaleGANsystem(CosmoUpscalePatchWGAN, params_256)
gen_samples_single_256 = wgan_256.upscale_image(small=np.reshape(real_samples_64, [*real_samples_64.shape, 1]), checkpoint=checkpoint256)[:,offset:,offset:,offset:]

In [None]:
# cmin = np.min(cubes_256)
# cmax = np.max(cubes_256)/2.5
cmin = 0
cmax = 3
clim = (cmin, cmax)
fig = plot_cubes_paper(real_samples_256, cmap=plt.cm.plasma, clim=clim, ny=1);
fig.suptitle('Real $256^3$', y=1.03, fontsize=12 )
plt.savefig(pathfig+"up_real256.pdf", bbox_inches='tight', format='pdf')

fig2 = plot_cubes_paper(gen_samples_single_256, cmap=plt.cm.plasma, clim=clim, ny=1);
fig2.suptitle('Fake $256^3$', y=1.03, fontsize=12 )
plt.savefig(pathfig+"up_fake256.pdf", bbox_inches='tight', format='pdf')

fig = plot_cubes_paper(real_samples_64, cmap=plt.cm.plasma, clim=clim, ny=1);
fig.suptitle('Real downsampled $64^3$', y=1.03, fontsize=12 )
plt.savefig(pathfig+"up_down64.pdf", bbox_inches='tight', format='pdf')

In [None]:
gen_samples_single_256_raw = backward(gen_samples_single_256)

In [None]:
plt.figure(figsize=(6,3))
_ = evaluation.compute_and_plot_mass_hist(real_samples_256_raw, gen_samples_single_256_raw, confidence='std')
plt.savefig(pathfig+"up_256_hist.pdf", bbox_inches='tight', format='pdf')

In [None]:
plt.figure(figsize=(6,3))
_ = evaluation.compute_and_plot_peak_count(real_samples_256_raw, gen_samples_single_256_raw, confidence='std')
plt.savefig(pathfig+"up_256_peak.pdf", bbox_inches='tight', format='pdf')

In [None]:
plt.figure(figsize=(6,3))
_ = evaluation.compute_and_plot_psd(real_samples_256_raw, gen_samples_single_256_raw, confidence='std')
plt.savefig(pathfig+"up_256_psd.pdf", bbox_inches='tight', format='pdf')

In [None]:
fig = plot_cubes_paper(real_samples_256, cmap=plt.cm.plasma, clim=clim, ny=1, nx=1);
fig.suptitle('Real $256^3$', y=1.03, fontsize=12 )

fig2 = plot_cubes_paper(gen_samples_single_256, cmap=plt.cm.plasma, clim=clim, ny=1, nx=1);
fig2.suptitle('Fake $256^3$', y=1.03, fontsize=12 )


In [None]:
# fps=16
# clim = (0,3.5)
# animation = plot.cubes_to_animation(gen_samples_single_256, cmap=plt.cm.plasma, clim=clim)
# animation.ipython_display(fps=16, loop=True, autoplay=True)

## Uniscale

In [None]:
cmin = 0
cmax = 3
clim = (cmin, cmax)
fig = plot_cubes_paper(real_samples_256, cmap=plt.cm.plasma, clim=clim);
fig.suptitle('Real $256^3$', y=1.02, fontsize=24 )
plt.savefig(pathfig+"uniscalereal256.pdf", bbox_inches='tight', format='pdf')

fig2 = plot_cubes_paper(gen_samples_uniscale, cmap=plt.cm.plasma, clim=clim);
fig2.suptitle('Fake $256^3 - uniscale$', y=1.02, fontsize=24 )
plt.savefig(pathfig+"uniscalefake256.pdf", bbox_inches='tight', format='pdf')

In [None]:
plt.figure(figsize=(6,3))
_ = evaluation.compute_and_plot_mass_hist(real_samples_256_raw, gen_samples_uniscale_raw, confidence='std')
plt.savefig(pathfig+"256uniscale_hist.pdf", bbox_inches='tight', format='pdf')

In [None]:
plt.figure(figsize=(6,3))
_ = evaluation.compute_and_plot_peak_count(real_samples_256_raw, gen_samples_uniscale_raw, confidence='std')
plt.savefig(pathfig+"256uniscale_peak.pdf", bbox_inches='tight', format='pdf')

In [None]:
plt.figure(figsize=(6,3))
_ = evaluation.compute_and_plot_psd(real_samples_256_raw, gen_samples_uniscale_raw, confidence='std')
plt.savefig(pathfig+"256uniscale_psd.pdf", bbox_inches='tight', format='pdf')

# Make videos
All videos are saved in HD as MP4 files in the figure directory.

In [None]:
cmin = 0
cmax = 3
clim = (cmin, cmax)

In [None]:
plot.save_animation(real_samples_32[0:16], gen_samples_32[0:16], figsize=(20, 11.25), fontsize=40, fps=8, format='mp4', output_file_name=pathfig+'cubes32.mp4', clim=clim, names=['Real', 'Fake'])


In [None]:
plot.save_animation(real_samples_64[0:16], gen_samples_64[0:16], figsize=(20, 11.25), fontsize=40, fps=8, format='mp4', output_file_name=pathfig+'cubes64.mp4', clim=clim, names=['Real', 'Fake'])


In [None]:
plot.save_animation(real_samples_256[0:16], gen_samples_256[0:16], figsize=(20, 11.25), fontsize=40, fps=16, format='mp4', output_file_name=pathfig+'cubes256.mp4', clim=clim, names=['Real', 'Fake'])


In [None]:
plot.save_animation(real_samples_256[0:16], gen_samples_single_256[0:16], real_downsampled=real_samples_64[0:16], figsize=(20, 11.25), fontsize=40, fps=16, format='mp4', output_file_name=pathfig+'single_scale_256.mp4', clim=clim, names=['Real', 'Downsampled', 'Fake'])


In [None]:
plot.save_animation(real_samples_64[0:16], gen_samples_64_single[0:16], real_downsampled=real_samples_32[0:16], figsize=(20, 11.25), fontsize=40, fps=8, format='mp4', output_file_name=pathfig+'single_scale_64.mp4', clim=clim, names=['Real', 'Downsampled', 'Fake'])


In [None]:
plot.save_animation(real_samples_256[0:16], gen_samples_uniscale[0:16], figsize=(20, 11.25), fontsize=40, fps=16, format='mp4', output_file_name=pathfig+'uniscale256.mp4', clim=clim, names=['Real', 'Fake'])
