In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from copy import deepcopy
from moviepy.editor import VideoClip
from moviepy.video.io.bindings import mplfig_to_npimage


import sys
sys.path.insert(0, '../')
from gantools import utils
from gantools import plot
from gantools.gansystem import GANsystem
from gantools.data import Dataset, Dataset_parameters

from cosmotools.model import CosmoWGAN
from cosmotools.metric import evaluation
from cosmotools.data import toy_dataset_generator

from gantools.model import ConditionalParamWGAN
from gantools.gansystem import GANsystem


In [None]:
# `ns` is reused below for `image_shape` and for the network input shape.
ns = 32 # Resolution of the image: side length, in pixels, of the square images
try_resume = True # Passed to utils.test_resume to try resuming a previous simulation

def non_lin(x):
    """Return the element-wise sigmoid of `x`, computed with `tf.nn.sigmoid`.

    Stored below as the 'non_lin' entry of the generator parameters.
    """
    squashed = tf.nn.sigmoid(x)
    return squashed

# Data handling

Load the data

In [None]:
# Arguments forwarded to toy_dataset_generator.generate_fake_dataset in the next cell.
nsamples = 5000
# NOTE(review): presumably the [min, max] interval the parameter sigma is drawn from — confirm in toy_dataset_generator.
sigma_int = [0.001, 0.01]
# NOTE(review): presumably the interval for the parameter N — confirm whether the upper bound is inclusive.
N_int = [10, 11]
# Square images with `ns` pixels per side.
image_shape = [ns, ns]

In [None]:
# Build the toy dataset: one array of images and one array of parameters.
images, parameters = toy_dataset_generator.generate_fake_dataset(
    nsamples=nsamples,
    sigma_int=sigma_int,
    N_int=N_int,
    image_shape=image_shape,
)

In [None]:
# Sanity check: show the shapes of the generated image and parameter arrays.
print(images.shape, parameters.shape)

In [None]:
# Convert to gantools dataset: wrap the images and their parameters in a
# Dataset_parameters object, which is what wgan.train receives later on.
dataset = Dataset_parameters(images, parameters)

In [None]:
# The dataset can hand out batches through an iterator; pull a single batch
# (iter is called with 10) and let the temporary iterator be discarded right away.
current = next(dataset.iter(10))
# Shapes of the two components of the first element of the batch.
print(current[0, 0].shape, current[0, 1].shape)

In [None]:
# Get all the data
# NOTE: `params` is only a temporary here; later cells rebind the same name.
X, params = dataset.get_all_data()
# Flatten to 1-D so the values can be passed directly to plt.hist below.
X = X.flatten()

Display the histogram of the pixel densities after the forward map

In [None]:
# Histogram of all pixel values using 100 bins.
plt.hist(X, bins=100)
# Report the extreme pixel values.
for label, reducer in (('min', np.min), ('max', np.max)):
    print('{}: {}'.format(label, reducer(X)))
# Log-scaled counts.
plt.yscale('log')

In [None]:
# Drop the flattened copy of the data to free some memory; X is not used after the histogram.
del X

Let us plot 16 images

In [None]:
# Request 16 samples; the plotting cell below reads params[:, 0] as sigma and params[:, 1] as N.
imgs, params = dataset.get_samples(N=16)

In [None]:
# 4x4 grid of the sampled images, each titled with its parameters.
fig, ax = plt.subplots(nrows=4, ncols=4, figsize=(15,15))
# ax.flat walks the grid row by row, like nested row/column loops.
for idx, panel in enumerate(ax.flat):
    panel.imshow(imgs[idx], vmin=0, vmax=1)
    # sigma is cut to 7 characters; N is displayed as the stored value plus one.
    title = "sigma: {}, N: {}".format(str(params[idx, 0])[0:7], int(params[idx, 1]) + 1)
    panel.set_title(title, fontsize=14)
    panel.axis('off')

# Define parameters for the WGAN

In [None]:
# Tag and root folder used to build the summary and checkpoint paths.
time_str = '2D_mac'
global_path = '../saved_results/Fake_Dataset/'

# Run name: fixed prefix, image resolution, then the tag.
name = 'Simple_WGAN_conditional_{}_{}'.format(ns, time_str)

## Parameters

In [None]:
# Batch-norm switch shared by the generator and discriminator settings.
bn = False

# Parameters for the generator
params_generator = {
    'latent_dim': 128,
    'stride': [2, 2, 1],
    'nfilter': [16, 32, 1],
    'shape': [[5, 5], [5, 5], [5, 5]],
    'batch_norm': [bn, bn],
    'full': [256, 512, 16 * 16 * 8],
    'summary': True,
    'non_lin': non_lin,
    'in_conv_shape': [8, 8],
}

# Parameters for the discriminator
params_discriminator = {
    'stride': [1, 2, 1],
    'nfilter': [32, 16, 8],
    'shape': [[5, 5], [5, 5], [5, 5]],
    'batch_norm': [bn, bn, bn],
    'full': [512, 256, 128],
    'minibatch_reg': False,
    'summary': True,
}

# Optimization parameters; generator and discriminator each get an independent copy of d_opt.
d_opt = {
    'optimizer': "rmsprop",
    'learning_rate': 3e-5,
}
params_optimization = {
    'discriminator': deepcopy(d_opt),
    'generator': deepcopy(d_opt),
    'n_critic': 5,
    'batch_size': 32,
    'epoch': 100,
}

# All parameters, gathered in the single dictionary handed to GANsystem.
params = {
    # All the parameters for the model
    'net': {
        'generator': params_generator,
        'discriminator': params_discriminator,
        'shape': [ns, ns, 1],  # Shape of the image
        'gamma_gp': 10,  # Gradient penalty
        # Conditional params
        'prior_normalization': False,
        'cond_params': 1,
        'init_range': [sigma_int, N_int],
        'prior_distribution': "gaussian_length",
        'final_range': [
            0.1*np.sqrt(params_generator['latent_dim']),
            1*np.sqrt(params_generator['latent_dim']),
        ],
    },
    'optimization': params_optimization,
    'summary_every': 50,  # Tensorboard summaries every ** iterations
    'print_every': 50,  # Console summaries every ** iterations
    'save_every': 10000,  # Save the model every ** iterations
    'summary_dir': os.path.join(global_path, name +'_summary/'),
    'save_dir': os.path.join(global_path, name + '_checkpoints/'),
    'Nstats': 2000,
}

In [None]:
# utils.test_resume returns a resume flag together with the parameters to use
# (NOTE(review): presumably the stored ones when a previous run is found — confirm).
resume, params = utils.test_resume(try_resume, params)
# If a model is reloaded and some parameters have to be changed, then it should be done here.
# For example, setting the number of epoch to 5 would be:
# params['optimization']['epoch'] = 5
# The output directories are set again after test_resume, so the paths built in
# this notebook are the ones used whatever test_resume returned.
params['summary_dir'] = os.path.join(global_path, name +'_summary/')
params['save_dir'] = os.path.join(global_path, name + '_checkpoints/')

# Build the model

In [None]:
class CosmoConditionalParamWGAN(ConditionalParamWGAN, CosmoWGAN):
    """Model class combining ConditionalParamWGAN and CosmoWGAN by multiple inheritance.

    Nothing is added or overridden here. ConditionalParamWGAN is listed first,
    so it takes precedence over CosmoWGAN in the method resolution order.
    """
    pass

In [None]:
# Instantiate the GAN system from the combined model class and the parameter dictionary.
wgan = GANsystem(CosmoConditionalParamWGAN, params)

# Train the model

In [None]:
# Train on the toy dataset; `resume` is the flag returned by utils.test_resume.
wgan.train(dataset, resume=resume)

# Generate new samples


In [None]:
# Passed as `checkpoint=` to every wgan.generate call below.
# NOTE(review): None presumably selects the most recent checkpoint — confirm in GANsystem.generate.
checkpoint = None

In [None]:
# Four (sigma, N) rows: sigma runs from 0.002 up to sigma_int[1], N stays at N_int[0].
gen_params = np.concatenate(
    (
        np.atleast_2d(np.linspace(0.002, sigma_int[1], 4)).T,
        np.ones((4, 1)) * N_int[0],
    ),
    axis=1,
)
# Latent vectors conditioned on those parameters, then the matching images.
latent = wgan.net.sample_latent(bs=4, params=gen_params)
gen_images = wgan.generate(N=4, z=latent, checkpoint=checkpoint)

In [None]:
# Top row: generated images. Bottom row: images drawn from the toy generator
# with the same (sigma, N) pairs, column by column.
fig, ax = plt.subplots(nrows=2, ncols=4, figsize=(15, 7.5))
for idx, panel in enumerate(ax.flat):
    # Panels are visited row by row, so idx % 4 is the column, i.e. the parameter row.
    sigma, n_value = gen_params[idx % 4]
    if idx < 4:
        panel.imshow(gen_images[idx, :, :, 0], vmin=0, vmax=1)
    else:
        img = toy_dataset_generator.generate_fake_images(1, sigma=sigma, N=int(n_value), image_shape=image_shape)
        panel.imshow(img[0], vmin=0, vmax=1)
    # Raw string: "\s" is not a valid escape sequence and triggers a warning
    # in a normal string literal. The resulting text is unchanged.
    panel.set_title(r"$\sigma$=" + str(sigma)[0:7] + ", $N$=" + str(int(n_value + 1)), fontsize=14)
    panel.axis('off')

# Generate a single image

In [None]:
# One latent vector for the parameter row (0.002, 10), then the matching image.
latent = wgan.net.sample_latent(
    params=np.array([[0.002, 10]]),
)
gen_sample = wgan.generate(N=1, z=latent, checkpoint=checkpoint)
# Show channel 0 of the single generated image.
plt.imshow(gen_sample[0, :, :, 0], vmin=0, vmax=1)
plt.axis('off')

# "Category" Morphing

In [None]:
# Sample a latent vector
# Rows 0..3 are overwritten in the loop below, so this call effectively only
# provides the array that holds the four latents (assumes it has 4 rows — TODO confirm).
latent_0 = wgan.net.sample_latent(bs=4, params=np.array([[0.001, 10]]))

# Draw an unnormalised distribution
# The same draw `z` is reused for all four images; only its length is changed below.
z = utils.sample_latent(1, wgan.net.params['generator']['latent_dim'], prior="gaussian")

# Normalise the distribution to the final range
# gen_params = np.linspace(wgan.net.params['final_range'][0], wgan.net.params['final_range'][1], 4)
# Four sigma values from 0.002 up to the upper end of init_range[0].
gen_params = np.linspace(0.002, wgan.net.params['init_range'][0][1], 4)
for i in range(4):
    # Map sigma from init_range[0] to final_range (exact mapping is defined by utils.scale2range).
    scaled_p = utils.scale2range(gen_params[i], wgan.net.params['init_range'][0], wgan.net.params['final_range'])
    # Rescale each row of z so that its Euclidean norm becomes |scaled_p|.
    z_r = (z.T * np.sqrt((scaled_p * scaled_p) / np.sum(z * z, axis=1))).T
    latent_0[i, :] = z_r[0, :]

# Generate images
# Keep only channel 0 of each generated image.
imgs = wgan.generate(N=4, **{'z': latent_0}, checkpoint=checkpoint)[:, :, :, 0]

In [None]:
# One row of four panels, one per sigma value of the morphing experiment.
fig, ax = plt.subplots(nrows=1, ncols=4, figsize=(15,15))
for idx, panel in enumerate(ax):
    panel.imshow(imgs[idx], vmin=0, vmax=1)
    # Raw string: "\s" is not a valid escape sequence and triggers a warning
    # in a normal string literal. The resulting text is unchanged.
    panel.set_title(r"$\sigma=$" + str(gen_params[idx])[0:7] + ", $N=$" + str(int(N_int[1])), fontsize=14)
    panel.axis('off')

In [None]:
# Parameters: 20 sigma values between 0.001 and 0.01, one per row (single column).
gen_params = np.array([[sigma] for sigma in np.linspace(0.001, 0.01, 20)])
# One entry of `frames` per parameter row; the animation cell indexes it by frame number.
frames = evaluation.generate_samples_same_seed(wgan, gen_params)

In [None]:
fig, ax = plt.subplots()
def make_frame(t):
    """Render the animation frame for time `t` and return it as an image array.

    Parameters
    ----------
    t : float
        Time handed over by VideoClip. It is truncated to an integer and used
        as the index into `frames` and `gen_params`.

    Returns
    -------
    The value of `mplfig_to_npimage(fig)` after the frame has been drawn.

    Notes
    -----
    Draws on the notebook-level `fig` and `ax`, and reads `frames` and `gen_params`.
    """
    # Clamp so that a time at the very end of the clip cannot index past the last frame.
    frame_idx = min(int(t), len(frames) - 1)
    ax.clear()
    ax.imshow(frames[frame_idx][0, :, :, 0], vmin=0, vmax=1)
    ax.axis('off')
    # Raw string: "\s" is not a valid escape sequence in a normal string literal.
    ax.set_title(r"$\sigma=$" + str(gen_params[frame_idx][0])[0:7])
    return mplfig_to_npimage(fig)

# The duration equals the number of parameter values, so each value is indexed by int(t).
animation = VideoClip(make_frame, duration=len(gen_params))
# Close the pyplot figure so that it is not also shown as a static cell output.
plt.close()
animation.ipython_display(fps=20, loop=True, autoplay=True)

# Evaluation of the sample quality

In [None]:
# Number of distinct sigma values to evaluate.
diff_params = 4
# (sigma, N) rows: sigma spans 0.002 .. 0.008, N stays at N_int[0].
gen_params = np.concatenate(
    (
        np.atleast_2d(np.linspace(0.002, 0.008, diff_params)).T,
        np.ones((diff_params, 1)) * N_int[0],
    ),
    axis=1,
)

In [None]:
def generate_images_with_params(params, n):
    """Generate `n` images that all share the same pair of conditioning parameters.

    Parameters
    ----------
    params : indexable
        `params[0]` and `params[1]` are each repeated `n` times and placed side
        by side to form the (n, 2) array given to `wgan.net.sample_latent`.
    n : int
        Number of latent vectors to sample and images to generate.

    Returns
    -------
    The output of `wgan.generate` for the sampled latent vectors.

    Notes
    -----
    Uses the notebook-level `wgan` and `checkpoint`.
    """
    conditioning = np.concatenate(
        (np.ones((n, 1)) * params[0], np.ones((n, 1)) * params[1]),
        axis=1,
    )
    z = wgan.net.sample_latent(bs=n, params=conditioning)
    return wgan.generate(N=n, z=z, checkpoint=checkpoint)

In [None]:
# Number of images produced per parameter row, for both real and fake sets.
N = 2000
real_images, fake_images = [], []
for cond in gen_params:
    # Real images come from the toy generator for this (sigma, N) row ...
    real_images.append(
        toy_dataset_generator.generate_fake_images(nsamples=N, sigma=cond[0], N=int(cond[1]), image_shape=[ns, ns])
    )
    # ... and fake images from the GAN; only channel 0 is kept.
    fake_images.append(generate_images_with_params(cond, N)[:, :, :, 0])

In [None]:
# Settings forwarded to the evaluation helpers in the cells below.
lenstools = True
bin_k = 15
box_l = (5*np.pi/180)
cut = [50, 1000]
# Only the first pair of y-limits depends on the `lenstools` flag.
ylims = [
    [(5e-5, 2e0) if lenstools else (1e-3, 2e2), (0, 0.5)],
    [(1e-1, 1e2), (0, 0.5)],
    [(4e-1, 1e3), (0, 0.2)],
]
fractional_difference = [True] * 3
locs = [2, 1, 1]
def param_str(par):
    """Build the plot label for a parameter row.

    Parameters
    ----------
    par : indexable
        Only `par[0]` is used. It is converted with `str` and cut to 7 characters.

    Returns
    -------
    str
        The LaTeX snippet "$\\sigma=$" followed by the truncated value.
    """
    # Raw string: "\s" is not a valid escape sequence and triggers a warning
    # in a normal string literal. The returned text is unchanged.
    return r"$\sigma=$" + str(par[0])[0:7]

In [None]:
# Compare real and fake images for each parameter row; only the score array is kept.
_, score = evaluation.compute_plots_for_params(
    gen_params,
    real_images,
    fake_images,
    log=False,
    lim=(0,1),
    neighborhood_size=2,
    threshold=0.01,
    confidence='std',
    multiply=True,
    ylims=ylims,
    param_str=param_str,
    fractional_difference=fractional_difference,
    bin_k=bin_k,
    box_l=box_l,
    cut=cut,
    locs=locs,
    lenstools=lenstools,
)

In [None]:
# `score` is indexed as [parameter row, statistic, loss].
# Statistics as labelled below: 0 = PSD, 1 = peak, 2 = mass.
# Losses 0-3 are (log_l2, l2, log_l1, l1); index 4 is printed as the fractional difference.
print("PSD log-L1 losses:", score[:, 0, 2])
print("PSD frac diffs:", score[:, 0, 4])
for label, stat in (("Peak", 1), ("Mass", 2)):
    print(label + " log-L1 losses:", score[:, stat, 2])
# Mean and standard deviation over the parameter rows.
for label, stat in (("PSD", 0), ("Peak", 1), ("Mass", 2)):
    print(label + " log-L1 total:", np.mean(score[:, stat, 2]), " +/- ", np.std(score[:, stat, 2]))
print("PSD frac diff:", np.mean(score[:, 0, 4]), " +/- ", np.std(score[:, 0, 4]))

Heat map

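To get nicer plots, use more sigma values, e.g. start from `gen_params = np.atleast_2d(np.linspace(0.0001, 0.002, 20)).T`, append the $N$ column as in the evaluation cells above (the code below reads `gen_params[:, 1]`), and recompute the scores.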

In [None]:
# Represent heat-map of accuracy: one point per parameter row, coloured by
# score[:, 0, 4] (printed above as the PSD fractional difference).
plt.figure(figsize=(10, 2))
plt.scatter(gen_params[:, 0], gen_params[:, 1] + 1, c=score[:, 0, 4], vmin=0, vmax=1, cmap=plt.cm.RdYlGn_r, edgecolor='k')
# Raw string: "\s" is not a valid escape sequence in a normal string literal.
plt.xlabel(r'$\sigma$')
plt.ylabel('$N$')
plt.xlim([-0.001, 0.021])
plt.ylim([10, 12])
# Black rectangle.
# NOTE(review): presumably outlines the parameter range used for training — confirm.
plt.plot(np.array([0.0009, 0.0009, 0.01, 0.01, 0.0009]), np.array([10.5, 11.5, 11.5, 10.5, 10.5]), c='k')
plt.colorbar()

In [None]:
# One panel per threshold: a point is green when its score[:, 0, 4] value
# is at or below the threshold and red otherwise.
thresholds =[0.08, 0.13, 0.18, 0.23]
fig, ax = plt.subplots(nrows=len(thresholds), ncols=1, figsize=(7, len(thresholds) * 2))
for panel, threshold in zip(ax, thresholds):
    # Colour all points first, then draw them with a single scatter call per panel.
    colors = ['g' if value <= threshold else 'r' for value in score[:, 0, 4]]
    panel.scatter(gen_params[:, 0], gen_params[:, 1] + 1, c=colors)
    # Raw string: "\s" is not a valid escape sequence in a normal string literal.
    panel.set_xlabel(r'$\sigma$')
    panel.set_ylabel('$N$')
    panel.set_xlim([-0.001, 0.021])
    panel.set_ylim([10, 12])
    # Same black rectangle as in the heat map.
    panel.plot(np.array([0.0009, 0.0009, 0.01, 0.01, 0.0009]), np.array([10.5, 11.5, 11.5, 10.5, 10.5]), c='k')
    panel.set_title("Threshold: " + str(threshold))
fig.tight_layout()

Correlations

In [None]:
# Correlations between the real and fake sets, with the same binning settings as above.
corr, k = evaluation.compute_correlations(
    real_images,
    fake_images,
    gen_params,
    bin_k=bin_k,
    box_l=box_l,
    cut=cut,
    lenstools=lenstools,
)

In [None]:
# Plot the correlations; the return value is kept and printed below as the correlation losses.
score_c = evaluation.plot_correlations(corr, k, gen_params, param_str=param_str, tick_every=3)

In [None]:
# Correlation losses as returned, then their mean and standard deviation.
print("Correlation losses:", score_c)
print("Total correlation loss:", np.mean(score_c), " +/- ", np.std(score_c))

MS-SSIM score

In [None]:
# One generated image per row of the training `parameters` array.
latent = wgan.net.sample_latent(
    bs=len(parameters),
    params=parameters,
)
gen_images = wgan.generate(
    N=len(parameters),
    z=latent,
    checkpoint=checkpoint,
)

In [None]:
# Score the full generated set against the full training set; each is passed as a one-element list.
s_fake, s_real = evaluation.compute_ssim_score([gen_images], [images])

In [None]:
# Fake score, real score, and their absolute difference for the single pair of sets.
print(s_fake[0])
print(s_real[0])
print(np.abs(s_fake[0] - s_real[0]))

In [None]:
# Same score on the per-parameter lists built in the evaluation loop.
s_fake, s_real = evaluation.compute_ssim_score(fake_images, real_images)

In [None]:
# Raw scores, then mean and standard deviation over the parameter rows.
print(s_fake)
print(s_real)
print(np.mean(s_fake), " +/- ", np.std(s_fake))
print(np.mean(s_real), " +/- ", np.std(s_real))
# NOTE(review): assumes s_fake and s_real support element-wise subtraction (e.g. NumPy arrays) — confirm.
diff = np.abs(s_fake - s_real)
print(np.mean(diff), " +/- ", np.std(diff))

# Extrapolation

In [None]:
# (sigma, N) rows for the extrapolation test: four sigma values, N fixed to N_int[0].
gen_params = np.array([[sigma, N_int[0]] for sigma in (0.005, 0.01, 0.02, 0.03)])

In [None]:
# Rebuild the real and fake image sets for the extrapolation parameters.
# `N` (images per parameter row) is the value set in the evaluation section above.
real_images = []
fake_images = []
for i in range(len(gen_params)):

    # Generate real images
    # `toy_dataset_generator` is imported directly; no name `data` is bound in
    # this notebook, so the former `data.` prefix raised a NameError.
    raw_images = toy_dataset_generator.generate_fake_images(nsamples=N, sigma=gen_params[i, 0], N=int(gen_params[i, 1]), image_shape=[ns, ns])

    # Generate fake images
    gen_images = generate_images_with_params(gen_params[i], N)

    real_images.append(raw_images)
    # Keep only channel 0 of the generated images.
    fake_images.append(gen_images[:, :, :, 0])

In [None]:
# Axis limits for the extrapolation plots; the whole set depends on the `lenstools` flag.
ylims = (
    [[(1e-4, 5e-1), (0, 1)], [(2e-1, 2e2), (0, 0.5)], [(4e-1, 2e3), (0, 0.5)]]
    if lenstools
    else [[(1e-3, 2e1), (0, 1)], [(1e-1, 60), (0, 0.5)], [(4e-1, 1e3), (0, 0.2)]]
)
fractional_difference = [True] * 3
locs = [1] * 3

In [None]:
# Same comparison as in the evaluation section, now on the extrapolation parameters.
_, score = evaluation.compute_plots_for_params(
    gen_params,
    real_images,
    fake_images,
    log=False,
    lim=(0,1),
    neighborhood_size=2,
    threshold=0.01,
    confidence='std',
    multiply=True,
    bin_k=bin_k,
    cut=cut,
    box_l=box_l,
    ylims=ylims,
    param_str=param_str,
    fractional_difference=fractional_difference,
    locs=locs,
    lenstools=lenstools,
)