In [None]:
# Automatically reload edited library code (gantools / cosmotools) and draw figures inline.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from copy import deepcopy
from moviepy.editor import VideoClip
from moviepy.video.io.bindings import mplfig_to_npimage

os.environ["CUDA_VISIBLE_DEVICES"]=""

import sys
sys.path.insert(0, '../')
from gantools import utils
from gantools import plot
from gantools.gansystem import GANsystem
from gantools.data import Dataset, Dataset_parameters
from gantools.data import transformation

from cosmotools.model import CosmoWGAN
from cosmotools.metric import evaluation
from cosmotools.data import load
from cosmotools.utils import histogram_large

from gantools.model import ConditionalParamWGAN
from gantools.gansystem import GANsystem




# Parameters

In [None]:
ns = 128 # Resolution of the image (images are ns x ns)
try_resume = True # Try to resume previous simulation

def non_lin(x):
    """Generator non-linearity (stored in params_generator['non_lin']): elementwise ReLU."""
    return tf.nn.relu(x)

# Data handling

In [None]:
# HDF5 files with images and their conditioning parameters.
# The shuffled training file is used for training and for random samples; the other two
# are loaded with sorted=True further below for per-parameter evaluation.
dataset_train_shuffled_name = 'kids_train_shuffled.h5'
dataset_train_name = 'kids_train.h5'
dataset_test_name = 'kids_test.h5'

Load the data

In [None]:
# Training set: shuffled file, ns x ns images, with a random-transpose transform
# (presumably data augmentation).
dataset = load.load_params_dataset(
    filename=dataset_train_shuffled_name,
    batch=2000,
    shape=[ns, ns],
    transform=transformation.random_transpose_2d,
)

Display the histogram of the pixel densities after the forward map

In [None]:
# vmin, vmax = utils.find_minmax(dataset)
vmin, vmax = (0.006971482983495468, 1.3119225229919165)

In [None]:
histo, x = histogram_large(dataset, lim=(vmin, vmax))

In [None]:
plot.plot_histogram(x, histo)
print('min: {}'.format(vmin))
print('max: {}'.format(vmax))

Lower the maximum of the colour scale to get better-contrasted image plots

In [None]:
# Upper limit of the colour scale for all image plots below; it is much lower than the
# true maximum found above, which improves contrast (vmin is kept).
vmax = 0.125

Let us plot 16 images

In [None]:
# Draw ONE batch so every image is titled with its own parameters. The previous version
# called get_samples twice, which can pair images and labels from different draws
# (and at best does the work twice).
sample_imgs, sample_params = dataset.get_samples(N=16)
fig, ax = plt.subplots(nrows=4, ncols=4, figsize=(15,15))
for idx, col in enumerate(ax.flat):  # row-major, same order as nested row/col loops
    plot.plot_img(sample_imgs[idx, :, :, 0], vmin=vmin, vmax=vmax, ax=col)
    col.axis('off')
    # Raw strings keep the LaTeX backslashes without invalid-escape warnings.
    col.set_title(r'$\Omega_M: $' + str(sample_params[idx, 0]) + r', $\sigma_8$: ' + str(sample_params[idx, 1]), fontsize=14)

# Define parameters for the WGAN

In [None]:
# Run identifier: it determines the summary and checkpoint folders under global_path.
time_str = '2D'
global_path = '../saved_results'

name = 'KidsConditional{}_smart_{}'.format(ns, time_str)

## Parameters

In [None]:
bn = False  # batch normalisation switch, applied to every layer below

# Discriminator: strided convolution layers followed by fully connected layers.
params_discriminator = dict()
params_discriminator['stride'] = [1, 2, 2, 2, 2]
params_discriminator['nfilter'] = [32, 64, 128, 256, 512]
params_discriminator['shape'] = [[7, 7], [5, 5], [5, 5], [5,5], [3,3]]
params_discriminator['batch_norm'] = [bn, bn, bn, bn, bn]
params_discriminator['full'] = [512, 256, 128]
params_discriminator['minibatch_reg'] = False
params_discriminator['summary'] = True
params_discriminator['data_size'] = 2  # presumably the spatial dimensionality (2D images)

# Generator: fully connected layers ending in 8*8*512 units, then convolution layers.
# The strides presumably upsample 8 -> 8 * 2**4 = 128 = ns.
params_generator = dict()
params_generator['stride'] = [2, 2, 2, 2, 1]
params_generator['latent_dim'] = 64
params_generator['nfilter'] = [256, 128, 64, 32, 1]
params_generator['shape'] = [[3, 3], [5, 5], [5, 5], [5, 5], [7,7]]
params_generator['batch_norm'] = [bn, bn, bn, bn]  # NOTE(review): 4 entries for 5 conv layers, presumably none on the output layer; confirm
params_generator['full'] = [256, 512, 8 * 8 * 512]
params_generator['summary'] = True
params_generator['non_lin'] = non_lin
params_generator['data_size'] = 2

params_optimization = dict()
params_optimization['optimizer'] = 'rmsprop'
params_optimization['batch_size'] = 64
params_optimization['learning_rate'] = 1e-5
params_optimization['epoch'] = 10

# all parameters
params = dict()
params['net'] = dict() # All the parameters for the model
params['net']['generator'] = params_generator
params['net']['discriminator'] = params_discriminator
params['net']['shape'] = [ns, ns, 1] # Shape of the image
params['net']['gamma_gp'] = 10 # Gradient penalty

# Conditional params
params['net']['prior_normalization'] = False
params['net']['cond_params'] = 2  # number of conditioning parameters (Omega_M, sigma_8)
params['net']['init_range'] = [[0.101, 0.487], [0.487, 1.331]]  # value range of each conditioning parameter
params['net']['prior_distribution'] = "gaussian_length"
# Presumably the target range of the latent-vector length; sqrt(latent_dim) is the typical
# norm of a standard Gaussian vector of that dimension.
params['net']['final_range'] = [0.1*np.sqrt(params_generator['latent_dim']), 1*np.sqrt(params_generator['latent_dim'])]

# Copy the base optimiser settings BEFORE attaching per-network sub-dictionaries.
# params['optimization'] is params_optimization itself, so deep-copying it after adding
# 'discriminator' would nest the discriminator settings inside the generator settings.
base_optimization = deepcopy(params_optimization)
params['optimization'] = params_optimization
params['optimization']['discriminator'] = deepcopy(base_optimization)
params['optimization']['generator'] = deepcopy(base_optimization)
params['summary_every'] = 5000 # Tensorboard summaries every ** iterations
params['print_every'] = 2500 # Console summaries every ** iterations
params['save_every'] = 25000 # Save the model every ** iterations
params['duality_every'] = 5
params['summary_dir'] = os.path.join(global_path, name +'_summary/')
params['save_dir'] = os.path.join(global_path, name + '_checkpoints/')
params['Nstats'] = 2000

In [None]:
# Presumably returns the saved parameters (and resume=True) when a previous run exists.
resume, params = utils.test_resume(try_resume, params)
# If a model is reloaded and some parameters have to be changed, then it should be done here.
# For example, setting the number of epoch to 5 would be:
# params['optimization']['epoch'] = 5
# Point the output folders at this run's name again (overrides any reloaded values).
params['summary_dir'] = os.path.join(global_path, name +'_summary/')
params['save_dir'] = os.path.join(global_path, name + '_checkpoints/')

# Build the model

In [None]:
class CosmoConditionalParamWGAN(ConditionalParamWGAN, CosmoWGAN):
    """Conditional WGAN built by combining ConditionalParamWGAN and CosmoWGAN.

    Defines no members of its own; method lookup goes to ConditionalParamWGAN first,
    then CosmoWGAN (base-class order).
    """
    pass

In [None]:
# Wrap the model class and parameters; wgan.train / wgan.generate are used below.
wgan = GANsystem(CosmoConditionalParamWGAN, params)

# Train the model

In [None]:
# Train on the shuffled training set (presumably continuing from the last checkpoint when resume is True).
wgan.train(dataset, resume=resume)

# Generate new samples

In [None]:
checkpoint = 349163


In [None]:
inter = 4

# Generate grid
# Note: pay attention that parameters should be inside the grid or at least colse to boundaries
grid = []
for c in range(wgan.net.params['cond_params']):
    if c == 0:
        gen_params = np.linspace(0.15, 0.4, inter)
    if c == 1:
        gen_params = np.linspace(0.6, 1.0, inter)
    grid.append(gen_params)

# Note: assume 2D grid of parameters
gen_params = []
for i in range(inter):
    for j in range(inter):
        gen_params.append([grid[0][i], grid[1][j]])
gen_params = np.array(gen_params)

# Produce images
latent = wgan.net.sample_latent(bs=inter * inter, params=gen_params)
gen_images = wgan.generate(N=inter * inter, **{'z': latent}, checkpoint=checkpoint)

Display a few fake samples

In [None]:
# One panel per generated image: the layout follows `inter` so it always matches the
# inter * inter images generated above (squeeze=False keeps `ax` 2D even for inter == 1).
fig, ax = plt.subplots(nrows=inter, ncols=inter, figsize=(15,15), squeeze=False)
for idx, col in enumerate(ax.flat):
    plot.plot_img(gen_images[idx], vmin=vmin, vmax=vmax, ax=col)
    # Raw strings keep the LaTeX backslashes without invalid-escape warnings.
    col.set_title(r"$\Omega_M$: " + str(gen_params[idx][0])[0:7] + r", $\sigma_8$: " + str(gen_params[idx][1])[0:7], fontsize=14)

In [None]:
# Compare real and fake
grid = np.array([[0.137, 1.23], [0.199, 0.87], [0.311, 0.842], [0.487, 0.643]])

In [None]:
dataset = load.load_params_dataset(filename=dataset_test_name, batch=10, sorted=True, shape=[ns, ns])

In [None]:
# Produce images
latent = wgan.net.sample_latent(bs=len(grid), params=grid)
gen_images = wgan.generate(N=len(grid), **{'z': latent}, checkpoint=checkpoint)

# Get real images
real_images = []
for p in grid:
    real_images.append(dataset.get_data_for_params(p, N=10)[0])

In [None]:
fig, ax = plt.subplots(nrows=2, ncols=4, figsize=(15, 7.5))
idx = 0
for row in ax:
    for col in row:
        if idx < 4:
            plot.plot_img(gen_images[idx], vmin=vmin, vmax=vmax, ax=col)
        else:
            plot.plot_img(real_images[idx % 4][np.random.randint(10)], vmin=vmin, vmax=vmax, ax=col)
        col.set_title("$\Omega_M$: " + str(grid[idx%4][0]) + ", $\sigma_8$: " + str(grid[idx%4][1]), fontsize=14)
        idx = idx + 1

# Generate single image

In [None]:
# A single (Omega_M, sigma_8) pair, shaped (1, 2)
single_params = np.array([[0.171, 0.976]])
latent = wgan.net.sample_latent(params=single_params)
gen_sample = wgan.generate(N=1, **{'z': latent}, checkpoint=checkpoint)
plot.plot_img(gen_sample[0], vmin=vmin, vmax=vmax)

# Same seed

In [None]:
inter = 4

In [None]:
# Generate grid
grid = []
for c in range(wgan.net.params['cond_params']):
    if c == 0:
        gen_params = np.linspace(0.2, 0.3, inter)
    if c == 1:
        gen_params = np.linspace(0.60, 0.9, inter)
    grid.append(gen_params)

# Note: assume 2D grid of parameters
gen_params = []
for i in range(inter):
    for j in range(inter):
        gen_params.append([grid[0][i], grid[1][j]])
gen_params = np.array(gen_params)

imgs = evaluation.generate_samples_same_seed(wgan, gen_params, checkpoint=checkpoint)

In [None]:
fig, ax = plt.subplots(nrows=inter, ncols=inter, figsize=(15,15))
idx = 0
for row in ax:
    for col in row:
        plot.plot_img(imgs[idx][0], vmin=vmin, vmax=vmax, ax=col)
        col.set_title("$\Omega_M$: " + str(gen_params[idx][0])[0:7] + ", $\sigma_8$: " + str(gen_params[idx][1])[0:7], fontsize=14)
        idx = idx + 1

Define path of parameters

In [None]:
# Long traversal (back and forth)
path = [[0.189, 0.659],
        [0.212, 0.727],
        [0.233, 0.791],
        [0.254, 0.852],
        [0.273, 0.91 ],
        [0.292, 0.966],
        [0.33,  0.898],
        [0.311, 0.842],
        [0.291, 0.783],
        [0.271, 0.723],
        [0.25, 0.658],
        [0.227, 0.591],
       ]

In [None]:
for p in path:
    p.append(True)

path = evaluation.interpolate_between(path, 7)

In [None]:
# Generate frames
frames = evaluation.generate_samples_same_seed(wgan, path[:len(path)//2], checkpoint=checkpoint)

In [None]:
fig, ax = plt.subplots(figsize=(20, 20))
def make_frame(t):
    t = int(t)
    ax.clear()
    plot.plot_img(frames[t][0], vmin=vmin, vmax=vmax, ax=ax)
    ax.axis('off')
    ax.set_title("$\Omega_M$: " + str(path[t][0])[0:7] + ", $\sigma_8$: " + str(path[t][1])[0:7], fontsize=28)
    return mplfig_to_npimage(fig)

animation = VideoClip(make_frame, duration=len(path)//2)
plt.close()
animation.ipython_display(fps=20, loop=True, autoplay=True)

# Evaluation of the sample quality

In [None]:
N = 2000  # number of real / generated samples per parameter set in the evaluations below

Accuracy on training set

In [None]:
dataset = load.load_params_dataset(filename=dataset_train_name, batch=N, sorted=True, shape=[ns, ns])

In [None]:
params = dataset.get_different_params()

In [None]:
print(params)

In [None]:
# Define getter functions for every parameter set
# Note this is needed to save memory, as in this way every subset is loaded only when needed
real_imgs = []
fake_imgs = []
for p in params:
    real_imgs.append(lambda p1=p: dataset.get_data_for_params(p1, N=N)[0])
    fake_imgs.append(lambda p1=p: evaluation.generate_samples_params(wgan, p1, nsamples=N, checkpoint=checkpoint))

In [None]:
# Parameters used for plotting
lenstools = True
def title_func(params):
    """Plot title for an (Omega_M, sigma_8) pair, each value truncated to 7 characters."""
    # Raw strings: the LaTeX backslashes are kept verbatim without invalid-escape warnings.
    return r"$\Omega_M$: " + str(params[0])[0:7] + r", $\sigma_8$: " + str(params[1])[0:7]

# y-limits per statistic; only the first (PSD) limit depends on the lenstools backend
if lenstools:
    ylims = [[(1e-7, 1e-3), (0, 0.5)], [(1e-2, 1e3), (0, 0.5)], [(1e-2, 1e5), (0, 0.5)]]
else:
    ylims = [[(1e-4, 2e0), (0, 0.5)], [(1e-2, 1e3), (0, 0.5)], [(1e-2, 1e5), (0, 0.5)]]
fractional_difference = [True, True, True]

box_l = (5*np.pi/180)  # 5 degrees in radians (presumably the map's angular side length)
bin_k = 50
cut = [200, 6000]
cut_corr = [0, 6000]

In [None]:
fig, score = evaluation.compute_plots_for_params(params, real_imgs, fake_imgs, param_str=title_func, log=False, lim=(0, 0.4), ylims=ylims, confidence='std', fractional_difference=fractional_difference, cut=cut, lenstools=lenstools, box_l=box_l, bin_k=bin_k)

In [None]:
# Score has shape n_params, n_stats, losses
print("PSD score:", score[:, 0, 0])
print("Peak score:", score[:, 1, 0])
print("Mass score:", score[:, 2, 0])
print("PSD diff:", score[:, 0, 1])

print("PSD total:", np.mean(score[:, 0, 0]), " +/- ", np.std(score[:, 0, 0]))
print("Peak total:", np.mean(score[:, 1, 0]), " +/- ", np.std(score[:, 1, 0]))
print("Mass total:", np.mean(score[:, 2, 0]), " +/- ", np.std(score[:, 2, 0]))
print("PSD diff total:", np.mean(score[:, 2, 1]), " +/- ", np.std(score[:, 2, 1]))

In [None]:
# Save scores for heatmap
train_scores = score
train_params = params

Correlation

In [None]:
# Correlation matrices for real and fake images per parameter set (presumably PSD
# correlations; cf. stats.psd_correlation further below)
corr, k = evaluation.compute_correlations(real_imgs, fake_imgs, params, cut=cut_corr, lenstools=lenstools, box_l=box_l, bin_k=bin_k)

In [None]:
# Plot them and get one correlation loss per parameter set
score_c = evaluation.plot_correlations(corr, k, params, tick_every=10, param_str=title_func)

In [None]:
print("Correlation losses:", score_c.flatten())
print("Total correlation loss:", np.mean(score_c), " +/- ", np.std(score_c))

MS-SSIM score

In [None]:
# MS-SSIM of fake and real images for every training parameter set
s_fake, s_real = evaluation.compute_ssim_score(fake_imgs, real_imgs)

In [None]:
print(s_fake)
print(s_real)
print(np.mean(s_fake), " +/- ", np.std(s_fake))
print(np.mean(s_real), " +/- ", np.std(s_real))
print(np.mean(np.abs(s_fake - s_real)), " +/- ", np.std(np.abs(s_fake - s_real)))

In [None]:
# Same score on a random mix of parameter sets (shuffled training file).
# NOTE: rebinds `dataset`, and the next cell rebinds `params`, `real_imgs` and `fake_imgs`.
dataset = load.load_params_dataset(filename=dataset_train_shuffled_name, batch=N, shape=[ns, ns])

In [None]:
real_imgs, params = dataset.get_samples(N)

In [None]:
fake_imgs = evaluation.generate_samples_params(wgan, params, nsamples=N, checkpoint=checkpoint)

In [None]:
# Wrapped in single-element lists, so index 0 below is the score of this one mixed batch
s_fake, s_real = evaluation.compute_ssim_score([fake_imgs], [real_imgs])

In [None]:
print(s_fake[0])
print(s_real[0])
print(np.abs(s_fake[0] - s_real[0]))

Accuracy on test set

In [None]:
# Sorted test set, used for the interpolation and extrapolation evaluations below
dataset = load.load_params_dataset(filename=dataset_test_name, batch=N, sorted=True, shape=[ns, ns])

Interpolations

In [None]:
params_inter = [[0.137, 1.23],
               [0.25, 0.658],
               [0.311, 0.842],
               [0.199, 0.87],
               [0.254, 0.852],
               [0.312, 0.664],
               [0.356, 0.614],
               [0.421, 0.628]]
params_inter = np.array(params_inter)

In [None]:
# Define getter functions for every parameter set
# Note this is needed to save memory, as in this way every subset is loaded only when needed
real_imgs = []
fake_imgs = []
for p in params_inter:
    real_imgs.append(lambda p1=p: dataset.get_data_for_params(p1, N=N)[0])
    fake_imgs.append(lambda p1=p: evaluation.generate_samples_params(wgan, p1, nsamples=N, checkpoint=checkpoint))

In [None]:
fig, score_inter = evaluation.compute_plots_for_params(params_inter, real_imgs, fake_imgs, param_str=title_func, ylims=ylims, log=False, confidence='std', lim=(0, 0.4), fractional_difference=fractional_difference, lenstools=lenstools, cut=cut, box_l=box_l, bin_k=bin_k)

In [None]:
print("PSD score:", score_inter[:, 0, 0])
print("Peak score:", score_inter[:, 1, 0])
print("Mass score:", score_inter[:, 2, 0])
print("PSD diff:", score_inter[:, 0, 1])

print("PSD total:", np.mean(score_inter[:, 0, 0]), " +/- ", np.std(score_inter[:, 0, 0]))
print("Peak total:", np.mean(score_inter[:, 1, 0]), " +/- ", np.std(score_inter[:, 1, 0]))
print("Mass total:", np.mean(score_inter[:, 2, 0]), " +/- ", np.std(score_inter[:, 2, 0]))
print("PSD diff total:", np.mean(score_inter[:, 2, 1]), " +/- ", np.std(score_inter[:, 2, 1]))

In [None]:
corr, k = evaluation.compute_correlations(real_imgs, fake_imgs, params_inter, cut=cut_corr, lenstools=lenstools, box_l=box_l, bin_k=bin_k)

In [None]:
score_c_inter = evaluation.plot_correlations(corr, k, params_inter, tick_every=10, param_str=title_func)

In [None]:
print("Correlation losses:", score_c_inter.flatten())
print("Total correlation loss:", np.mean(score_c_inter), " +/- ", np.std(score_c_inter))

Extrapolations

In [None]:
params_extra = [[0.196, 1.225],
                [0.127, 0.836],
                [0.487, 0.643]]
params_extra = np.array(params_extra)

In [None]:
# Define getter functions for every parameter set
# Note this is needed to save memory, as in this way every subset is loaded only when needed
real_imgs = []
fake_imgs = []
for p in params_extra:
    real_imgs.append(lambda p1=p: dataset.get_data_for_params(p1, N=N)[0])
    fake_imgs.append(lambda p1=p: evaluation.generate_samples_params(wgan, p1, nsamples=N, checkpoint=checkpoint))

In [None]:
fig, score_extra = evaluation.compute_plots_for_params(params_extra, real_imgs, fake_imgs, param_str=title_func, ylims=ylims, log=False, confidence='std', lim=(0, 0.4), fractional_difference=fractional_difference, lenstools=lenstools, cut=cut, box_l=box_l, bin_k=bin_k)

In [None]:
print("PSD score:", score_inter[:, 0, 0])
print("Peak score:", score_inter[:, 1, 0])
print("Mass score:", score_inter[:, 2, 0])
print("PSD diff:", score_inter[:, 0, 1])

print("PSD total:", np.mean(score_inter[:, 0, 0]), " +/- ", np.std(score_inter[:, 0, 0]))
print("Peak total:", np.mean(score_inter[:, 1, 0]), " +/- ", np.std(score_inter[:, 1, 0]))
print("Mass total:", np.mean(score_inter[:, 2, 0]), " +/- ", np.std(score_inter[:, 2, 0]))
print("PSD diff total:", np.mean(score_inter[:, 2, 1]), " +/- ", np.std(score_inter[:, 2, 1]))

In [None]:
corr, k = evaluation.compute_correlations(real_imgs, fake_imgs, params_extra, cut=cut_corr, lenstools=lenstools, box_l=box_l, bin_k=bin_k)

In [None]:
score_c_extra = evaluation.plot_correlations(corr, k, params_extra, tick_every=10, param_str=title_func)

In [None]:
print("Correlation losses:", score_c_extra.flatten())
print("Total correlation loss:", np.mean(score_c_extra), " +/- ", np.std(score_c_extra))

Accuracy heat map

In [None]:
# Stack interpolation and extrapolation results into one test set
test_params = np.vstack([params_inter, params_extra])
test_scores = np.vstack([score_inter, score_extra])

In [None]:
# Heat map of the PSD diff (stat 0, loss 1) over parameter space, test vs training
plot.plot_heatmap(test_scores[:, 0, 1], test_params, train_scores[:, 0, 1], train_params)

In [None]:
# Same map, binned by the given thresholds
thresholds = [0.025, 0.05, 0.10, 0.15]
plot.plot_heatmap(test_scores[:, 0, 1], test_params, train_scores[:, 0, 1], train_params, thresholds=thresholds)

In [None]:
# Correlation losses (test_params recomputed so this cell can run on its own)
test_params = np.vstack([params_inter, params_extra])
test_c_scores = np.vstack([score_c_inter, score_c_extra])

In [None]:
# NOTE: score_c comes from the training-set correlation cell above
plot.plot_heatmap(test_c_scores[:, 0], test_params, score_c[:, 0], train_params, vmax=15)

MS-SSIM

In [None]:
dataset = load.load_params_dataset(filename=dataset_test_name, batch=N, sorted=True, shape=[ns, ns])

In [None]:
# Define getter functions for every parameter set
# Note this is needed to save memory, as in this way every subset is loaded only when needed
real_imgs = []
fake_imgs = []
for p in dataset.get_different_params():
    real_imgs.append(lambda p1=p: dataset.get_data_for_params(p1, N=N)[0])
    fake_imgs.append(lambda p1=p: evaluation.generate_samples_params(wgan, p1, nsamples=N, checkpoint=checkpoint))

In [None]:
s_fake, s_real = evaluation.compute_ssim_score(fake_imgs, real_imgs)

In [None]:
print(s_fake)
print(s_real)
print(np.mean(s_fake), " +/- ", np.std(s_fake))
print(np.mean(s_real), " +/- ", np.std(s_real))
print(np.mean(np.abs(s_fake - s_real)), " +/- ", np.std(np.abs(s_fake - s_real)))

In [None]:
real_imgs, params = dataset.get_random_data(N)

In [None]:
fake_imgs = evaluation.generate_samples_params(wgan, params, nsamples=N, checkpoint=checkpoint)

In [None]:
s_fake, s_real = evaluation.compute_ssim_score([fake_imgs], [real_imgs])

In [None]:
print(s_fake[0])
print(s_real[0])
print(np.abs(s_fake[0] - s_real[0]))

# Video

In [None]:
N = 2000

In [None]:
# Loaded without `shape`; load_real_data below reshapes the images to (N, ns, ns).
dataset_train = load.load_params_dataset(filename=dataset_train_name, batch=N, sorted=True)
dataset_test = load.load_params_dataset(filename=dataset_test_name, batch=N, sorted=True)

In [None]:
def load_real_data(params):
    """Return N real images for `params` as an array of shape (N, ns, ns).

    Looks in the training set first and falls back to the test set.
    """
    try:
        images = dataset_train.get_data_for_params(params, N)[0]
    except Exception:
        # NOTE(review): assumes a failure here means `params` is not in the training set.
        # `except Exception` (not a bare except) keeps KeyboardInterrupt/SystemExit working,
        # and the reshape is outside the try so shape errors are not silently masked.
        images = dataset_test.get_data_for_params(params, N)[0]
    return images.reshape((N, ns, ns))

In [None]:
# Create list of dictionaries for video generation
X = []
for i in range(len(path)):
    X.append({})
    X[i]['params'] = np.array([path[i][0], path[i][1]])
    X[i]['real'] = None
    X[i]['fake'] = lambda p=X[i]['params']: evaluation.generate_samples_params(wgan, p, nsamples=N, checkpoint=checkpoint)[:, :, :, 0]
    if path[i][2]:
        X[i]['real'] = lambda p=X[i]['params']: load_real_data(p)

In [None]:
param_grid = dataset_train.get_different_params()

In [None]:
# Generate frames
frames = evaluation.make_frames(X, title_func=title_func, log=False, confidence='std', lim=(0, 0.4), vmin=vmin, vmax=vmax, params_grid=param_grid, fractional_difference=fractional_difference, ylims=ylims, cut=cut, lenstools=lenstools, save_frames_dir='frames')

In [None]:
# Generate video
d_frame = 0.75
duration = len(X) * d_frame
animation = VideoClip(evaluation.make_frame_func(X, 'frames', duration, frames_stat=3), duration=duration)
animation.ipython_display(fps=10, loop=True, autoplay=True, width=900)

# Sanity checks

In [None]:
nsamples = 100  # batch size for the discriminator sanity checks below

In [None]:
dataset = load.load_params_dataset(filename=dataset_test_name, batch=nsamples, shape=[ns, ns])

Real images and real parameters

In [None]:
images, parameters = dataset.get_samples(nsamples)
dat = Dataset_parameters(images, parameters)
it = dat.iter(nsamples)
batch = next(it)

In [None]:
disc_out = wgan.get_values_at(batch, '_D_real', checkpoint=checkpoint)
print(np.mean(disc_out), np.std(disc_out))

Random images with real parameters

In [None]:
images, parameters = dataset.get_samples(nsamples)
images = np.random.rand(nsamples, ns, ns)
dat = Dataset_parameters(images, parameters)
it = dat.iter(nsamples)
batch = next(it)

In [None]:
disc_out = wgan.get_values_at(batch, '_D_real', checkpoint=checkpoint)
print(np.mean(disc_out), np.std(disc_out))

Real images with fake parameters

In [None]:
images, parameters = dataset.get_samples(nsamples)
for c in range(wgan.net.params['cond_params']):
    parameters[:, c] = utils.scale2range(np.random.rand(nsamples), [0, 1], wgan.net.params['init_range'][c])
dat = Dataset_parameters(images, parameters)
it = dat.iter(nsamples)
batch = next(it)

In [None]:
disc_out = wgan.get_values_at(batch, '_D_real', checkpoint=checkpoint)
print(np.mean(disc_out), np.std(disc_out))

Random images with fake parameters

In [None]:
images, parameters = dataset.get_samples(nsamples)
images = np.random.rand(nsamples, ns, ns)
for c in range(wgan.net.params['cond_params']):
    parameters[:, c] = utils.scale2range(np.random.rand(nsamples), [0, 1], wgan.net.params['init_range'][c])
dat = Dataset_parameters(images, parameters)
it = dat.iter(nsamples)
batch = next(it)

In [None]:
disc_out = wgan.get_values_at(batch, '_D_real', checkpoint=checkpoint)
print(np.mean(disc_out), np.std(disc_out))

# Analyse weights

In [None]:
from cosmotools.metric import feature_analysis

In [None]:
nsamples = 100  # images per parameter set returned by batch_loader

In [None]:
def title_func(params):
    """Plot title for an (Omega_M, sigma_8) pair, each value truncated to 7 characters.

    Redefined here so this section can be run on its own; identical to the earlier definition.
    """
    # Raw strings: the LaTeX backslashes are kept verbatim without invalid-escape warnings.
    return r"$\Omega_M$: " + str(params[0])[0:7] + r", $\sigma_8$: " + str(params[1])[0:7]

In [None]:
dataset_train = load.load_params_dataset(filename=dataset_train_name, batch=N, sorted=True, shape=[ns, ns])
dataset_test = load.load_params_dataset(filename=dataset_test_name, batch=N, sorted=True, shape=[ns, ns])

def batch_loader(params):
    """Return one batch of `nsamples` real images (with their parameters) for `params`.

    Looks in the training set first and falls back to the test set.
    """
    try:
        dat = dataset_train.get_data_for_params(params, nsamples)
    except Exception:
        # NOTE(review): assumes a failure here means `params` is not in the training set.
        # `except Exception` (not a bare except) keeps KeyboardInterrupt/SystemExit working.
        dat = dataset_test.get_data_for_params(params, nsamples)
    images, parameters = dat
    # Local name distinct from the notebook-level `dataset` to avoid confusion.
    batch_dataset = Dataset_parameters(images.reshape((nsamples, ns, ns)), parameters)
    return next(batch_dataset.iter(nsamples))

In [None]:
# Long traversal (back and forth); same waypoints as the interpolation video section
path = [[0.189, 0.659],
        [0.212, 0.727],
        [0.233, 0.791],
        [0.254, 0.852],
        [0.273, 0.91 ],
        [0.292, 0.966],
        [0.33,  0.898],
        [0.311, 0.842],
        [0.291, 0.783],
        [0.271, 0.723],
        [0.25, 0.658],
        [0.227, 0.591],
       ]

In [None]:
# Define params grid: one path point per video second
diff_params = np.array(path)
d_frame = 1
duration = len(diff_params) * d_frame

In [None]:
# Non-generator network (generator=False), presumably the discriminator: feature and weight videos
generator = False
make_frames_feat, make_frames_weig = feature_analysis.make_features_videos(wgan, batch_loader, diff_params, duration, title_func=title_func, checkpoint=checkpoint, generator=generator)

In [None]:
animation = VideoClip(make_frames_feat, duration=duration)
animation.ipython_display(fps=10, loop=True, autoplay=True, width=900)

In [None]:
animation = VideoClip(make_frames_weig, duration=duration)
animation.ipython_display(fps=10, loop=True, autoplay=True, width=900)

In [None]:
# Generator (generator=True): feature and weight videos
generator = True
make_frames_feat, make_frames_weig = feature_analysis.make_features_videos(wgan, batch_loader, diff_params, duration, title_func=title_func, checkpoint=checkpoint, generator=generator)

In [None]:
animation = VideoClip(make_frames_feat, duration=duration)
animation.ipython_display(fps=10, loop=True, autoplay=True, width=900)

In [None]:
animation = VideoClip(make_frames_weig, duration=duration)
animation.ipython_display(fps=10, loop=True, autoplay=True, width=900)

# Results for one cosmology

In [None]:
N = 2000

In [None]:
# The single (Omega_M, sigma_8) pair analysed in this section
p = np.array([0.254, 0.852])

In [None]:
dataset = load.load_params_dataset(filename=dataset_test_name, batch=N, sorted=True, shape=[ns, ns])
raw_images = dataset.get_data_for_params(p, N=N)[0]

In [None]:
gen_sample_raw = evaluation.generate_samples_params(wgan, p, nsamples=N, checkpoint=checkpoint)

In [None]:
# PSD of real vs generated images
evaluation.compute_and_plot_psd(raw_images, gen_sample_raw, multiply=True, confidence='std', fractional_difference=True, lenstools=lenstools, box_l=box_l, bin_k=bin_k)

In [None]:
# The three summary statistics side by side
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(15, 5))
_ = evaluation.plot_stats(ax, gen_sample_raw, raw_images, log=False, lim=(0,0.4), confidence='std', multiply=True, fractional_difference=[True, True, True], lenstools=lenstools, box_l=box_l, bin_k=bin_k)
fig.tight_layout()

In [None]:
# Correlation matrices and the Frobenius norm of their difference
c_r, c_f, _ = evaluation.compute_plot_correlation(raw_images, gen_sample_raw, cut=cut_corr, tick_every=10, lenstools=lenstools, box_l=box_l, bin_k=bin_k)
print(np.linalg.norm(c_r-c_f))

In [None]:
evaluation.compute_plot_psd_mode_hists(raw_images, gen_sample_raw, modes=3, hist_batch=4, confidence='std', lenstools=lenstools, box_l=box_l, bin_k=bin_k)

# Extrapolation

In [None]:
path = np.array([[0.332, 0.724], [0.37, 0.838], [0.425, 1], [0.487, 1.331]])

In [None]:
dataset = load.load_params_dataset(filename=dataset_train_name, batch=N, sorted=True, shape=[ns, ns])

In [None]:
# Define getter functions for every parameter set
# Note this is needed to save memory, as in this way every subset is loaded only when needed
fake_imgs = []
real_imgs = []
for p in path:
    if dataset.has_params(p):
        real_imgs.append(lambda p1=p: dataset.get_data_for_params(p1, N=N)[0])
    else:
        real_imgs.append(None)
    fake_imgs.append(lambda p1=p: evaluation.generate_samples_params(wgan, p1, nsamples=N, checkpoint=checkpoint))

In [None]:
if lenstools:
    ylims = [[(1e-7, 5e-3), (0, 0.5)], [(1e-2, 1e3), (0, 0.5)], [(1e-2, 1e5), (0, 0.5)]]
else:
    ylims = [[(1e-4, 2e0), (0, 0.5)], [(1e-2, 1e3), (0, 0.5)], [(1e-2, 1e5), (0, 0.5)]]
fig, score = evaluation.compute_plots_for_params(path, real_imgs, fake_imgs, param_str=title_func, log=False, lim=(0, 0.8), ylims=ylims, confidence='std', fractional_difference=fractional_difference, cut=cut, lenstools=True)

In [None]:
fig

# Correlation experiments

In [None]:
N = 2000

In [None]:
# Sorted training set, queried per parameter set below
dataset = load.load_params_dataset(filename=dataset_train_name, batch=N, sorted=True, shape=[ns, ns])

Worst correlation score

In [None]:
real_imgs = dataset.get_data_for_params(np.array([0.469, 0.589]), N=N)[0]
fake_imgs = evaluation.generate_samples_params(wgan, np.array([0.469, 0.589]), nsamples=N, checkpoint=checkpoint)

In [None]:
c_r, c_f, _ = evaluation.compute_plot_correlation(real_imgs, fake_imgs, cut=cut_corr, tick_every=10, lenstools=lenstools, box_l=box_l, bin_k=bin_k)
print(np.linalg.norm(c_r - c_f))

In [None]:
evaluation.compute_and_plot_psd(real_imgs, fake_imgs, multiply=True, confidence='std', fractional_difference=True, lenstools=lenstools, box_l=box_l, bin_k=bin_k)

In [None]:
evaluation.compute_plot_psd_mode_hists(real_imgs, fake_imgs, modes=3, lenstools=lenstools, hist_batch=4, confidence='std', box_l=box_l, bin_k=bin_k)

Best correlation score

In [None]:
real_imgs = dataset.get_data_for_params(np.array([0.148, 0.9]), N=N)[0]
fake_imgs = evaluation.generate_samples_params(wgan, np.array([0.148, 0.9]), nsamples=N, checkpoint=checkpoint)

In [None]:
c_r, c_f, _ = evaluation.compute_plot_correlation(real_imgs, fake_imgs, cut=cut_corr, tick_every=10, lenstools=lenstools, box_l=box_l, bin_k=bin_k)
print(np.linalg.norm(c_r - c_f))

In [None]:
evaluation.compute_and_plot_psd(real_imgs, fake_imgs, multiply=True, confidence='std', fractional_difference=True, lenstools=lenstools, box_l=box_l, bin_k=bin_k)

In [None]:
evaluation.compute_plot_psd_mode_hists(real_imgs, fake_imgs, modes=3, lenstools=lenstools, hist_batch=4, confidence='std', box_l=box_l, bin_k=bin_k)

Correlation of a batch

In [None]:
N = 64  # one training-sized batch (NOTE: rebinds the global N)

In [None]:
dataset = load.load_params_dataset(filename=dataset_train_shuffled_name, batch=N, shape=[ns, ns])

In [None]:
# A mixed batch of real images and their parameters (NOTE: rebinds `params`)
real_imgs, params = dataset.get_samples(N)

In [None]:
fake_imgs = evaluation.generate_samples_params(wgan, params, nsamples=N, checkpoint=checkpoint)

In [None]:
c_r, c_f, _ = evaluation.compute_plot_correlation(real_imgs, fake_imgs, cut=cut_corr, tick_every=10, lenstools=lenstools, box_l=box_l, bin_k=bin_k)
print(np.linalg.norm(c_r - c_f))

In [None]:
evaluation.compute_and_plot_psd(real_imgs, fake_imgs, multiply=True, confidence='std', fractional_difference=True, lenstools=lenstools, box_l=box_l, bin_k=bin_k)

Check how smoothly the norm of the real-data PSD correlation varies across parameter space

In [None]:
from cosmotools.metric import stats

In [None]:
N = 2000

In [None]:
dataset = load.load_params_dataset(filename=dataset_test_name, batch=N, sorted=True, shape=[ns, ns])

In [None]:
params_test = dataset.get_different_params()

In [None]:
corr_test = []
for p in params_test:
    if lenstools:
        c, _ = stats.psd_correlation_lenstools(dataset.get_data_for_params(p, N)[0], box_l=box_l, bin_k=bin_k, cut=cut_corr)
    else:
        c, _ = stats.psd_correlation(dataset.get_data_for_params(p, N)[0], box_l=box_l, bin_k=bin_k, log_sampling=False, cut=cut_corr)
    corr_test.append(np.linalg.norm(c))
corr_test = np.array(corr_test)

In [None]:
dataset = load.load_params_dataset(filename=dataset_train_name, batch=N, sorted=True, shape=[ns, ns])
params_train = dataset.get_different_params()

In [None]:
corr_train = []
for p in params_train:
    if lenstools:
        c, _ = stats.psd_correlation_lenstools(dataset.get_data_for_params(p, N)[0], cut=cut_corr, box_l=box_l, bin_k=bin_k)
    else:
        c, _ = stats.psd_correlation(dataset.get_data_for_params(p, N)[0], box_l=box_l, bin_k=bin_k, log_sampling=False, cut=cut_corr)
    corr_train.append(np.linalg.norm(c))
corr_train = np.array(corr_train)

In [None]:
params = np.vstack([params_train, params_test])
corr = np.hstack([corr_train, corr_test])

In [None]:
plot.plot_heatmap(corr, params, vmax=corr.max())

In [None]:
print(corr.mean())

# Fréchet Inception distance

In [None]:
N = 2000

In [None]:
dataset = load.load_params_dataset(filename=dataset_test_name, batch=N, sorted=True, shape=[ns, ns])

In [None]:
# Define parameters
# params = dataset.get_different_params()
# (the list below is the same set as params_inter in the interpolation section)
params = [[0.137, 1.23],
          [0.25, 0.658],
          [0.311, 0.842],
          [0.199, 0.87],
          [0.254, 0.852],
          [0.312, 0.664],
          [0.356, 0.614],
          [0.421, 0.628]]
params = np.array(params)

In [None]:
# Generate images
# NOTE(review): unlike the getter-based cells above, this materialises N real and N fake
# images for every parameter set at once; memory use scales with len(params) * N.
real_imgs = []
fake_imgs = []
for p in params:
    real_imgs.append(dataset.get_data_for_params(p, N=N)[0])
    fake_imgs.append(evaluation.generate_samples_params(wgan, p, nsamples=N, checkpoint=checkpoint))

In [None]:
# Load regressor (checkpoint folder of a separately trained regressor)
regressor_path = '../saved_results/Regressor/Kids_Regressor_128_smart_2D_mac_checkpoints/'

In [None]:
# FID per parameter set, computed with the regressor's checkpoint 140000
fids, fig = evaluation.compute_plot_fid(real_imgs, fake_imgs, params, regressor_path, batch_size=250, checkpoint=140000, lims=[[0.05, 0.5], [0.4, 1.4]], alpha=0.025)

In [None]:
print(fids)
print(np.mean(fids), "+/-", np.std(fids))

In [None]:
fig