In [None]:
# Automatically reload edited modules (e.g. gantools / cosmotools) before running each cell
%load_ext autoreload
%autoreload 2
# Show matplotlib figures inline in the notebook
%matplotlib inline

In [None]:
# Hide all GPUs so TensorFlow runs on the CPU. This must happen before TensorFlow is
# imported, otherwise the CUDA devices may already have been enumerated.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""

import sys
from copy import deepcopy

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from moviepy.editor import VideoClip
from moviepy.video.io.bindings import mplfig_to_npimage

# Make the local project packages (gantools, cosmotools) importable
sys.path.insert(0, '../')
from gantools import utils
from gantools import plot
from gantools.gansystem import GANsystem
from gantools.data import Dataset, Dataset_parameters
from gantools.model import ConditionalParamWGAN

from cosmotools.model import CosmoWGAN
from cosmotools.metric import evaluation, stats
from cosmotools.data import toy_dataset_generator


In [None]:
# Note: some of the parameters don't make sense for the fake dataset
try_resume = True  # attempt to resume a previous run from its checkpoints
ns = 32  # side length, in pixels, of the square images


def non_lin(x):
    """Non-linearity applied to the generator output (a sigmoid, i.e. values in (0, 1))."""
    return tf.nn.sigmoid(x)

# Data handling

Generate the toy dataset and wrap it in a gantools dataset.

In [None]:
# Settings for the toy-dataset generator
nsamples = 5000            # number of training images to generate
image_shape = [ns, ns]     # square images at the chosen resolution
normalise = True
sigma_int = [0.001, 0.01]  # range for sigma (also used as the conditioning init_range)
N_int = [5, 20]            # range for N (also used as the conditioning init_range)

In [None]:
# Generate toy images together with the parameters used for each of them
images, parameters = toy_dataset_generator.generate_fake_dataset(
    nsamples=nsamples,
    sigma_int=sigma_int,
    N_int=N_int,
    image_shape=image_shape,
    normalise=normalise,
)

In [None]:
# Sanity check: shapes of the image array and of the per-image parameter array
print(images.shape, parameters.shape)

In [None]:
# Convert to gantools dataset: wrap the images together with their conditioning parameters
dataset = Dataset_parameters(images, parameters)

In [None]:
# Pull all images (and their parameters) out of the dataset to inspect the pixel range
X, params = dataset.get_all_data()
vmin, vmax = X.min(), X.max()  # shared colour limits for the image grids
X = X.flatten()  # 1-D array of every pixel value, for the histogram

Display the histogram of all pixel values (with log-scaled counts).

In [None]:
# Distribution of all pixel values; log-scaled counts make the tails visible
plt.hist(X, 100)
plt.yscale('log')
plt.xlabel('Pixel value')
plt.ylabel('Count')
plt.title('Histogram of pixel values')
print('min: {}'.format(np.min(X)))
print('max: {}'.format(np.max(X)))

In [None]:
# to free some memory: the flattened pixel array was only needed for the histogram
del X

Plot 16 training images sampled from the dataset, together with their parameters.

In [None]:
# Draw 16 images and their parameters from the dataset for display
imgs, params = dataset.get_samples(N=16)

In [None]:
# Show the 16 sampled training images with their generating parameters.
# Raw string for the LaTeX title: "\s" is an invalid escape in a normal string literal.
# NOTE(review): titles display N + 1, following the convention used throughout this notebook.
fig, ax = plt.subplots(nrows=4, ncols=4, figsize=(15,15))
for idx, col in enumerate(ax.flat):
    col.imshow(imgs[idx], vmin=vmin, vmax=vmax)
    col.set_title(r"$\sigma$: " + str(params[idx, 0])[0:7] + ", $N$: " + str(int(params[idx, 1]) + 1), fontsize=14)
    col.axis('off')

# Define parameters for the WGAN

In [None]:
# Identifiers and output location of this run (summaries/checkpoints are stored under global_path)
time_str = '2D_mac'
global_path = '../saved_results/Fake Dataset/'

name = 'Simple_WGAN_two_params_sigmoid_' + str(ns) + '_' + time_str

## Parameters

In [None]:
bn = False  # batch normalisation flag shared by the layers below

# Parameters for the generator
params_generator = {
    'latent_dim': 128,
    'stride': [2, 2, 1],
    'nfilter': [16, 32, 1],
    'shape': [[5, 5], [5, 5], [5, 5]],
    # NOTE(review): one entry fewer than the three conv layers, presumably because the
    # output layer has no batch norm -- confirm against gantools
    'batch_norm': [bn, bn],
    'full': [256, 512, 16 * 16 * 8],
    'summary': True,
    'non_lin': non_lin,
    'in_conv_shape': [8, 8],
}

# Parameters for the discriminator
params_discriminator = {
    'stride': [1, 2, 1],
    'nfilter': [32, 16, 8],
    'shape': [[5, 5], [5, 5], [5, 5]],
    'batch_norm': [bn, bn, bn],
    'full': [512, 256, 128],
    'minibatch_reg': False,
    'summary': True,
}

# Optimization parameters: each network gets its own copy of the same optimizer settings
d_opt = {'optimizer': "rmsprop", 'learning_rate': 3e-5}
params_optimization = {
    'discriminator': deepcopy(d_opt),
    'generator': deepcopy(d_opt),
    'n_critic': 5,
    'batch_size': 32,
    'epoch': 100,
}

# all parameters
params = {
    'net': {  # All the parameters for the model
        'generator': params_generator,
        'discriminator': params_discriminator,
        'shape': [ns, ns, 1],  # Shape of the image
        'gamma_gp': 10,  # Gradient penalty
        # Conditional params: two conditioning parameters, initialised over the dataset ranges
        'prior_normalization': False,
        'cond_params': 2,
        'init_range': [sigma_int, N_int],
        'prior_distribution': "gaussian_length",
        # Bounds scale with sqrt(latent_dim) (presumably the range of latent-vector lengths
        # for the "gaussian_length" prior -- confirm in gantools)
        'final_range': [0.1*np.sqrt(params_generator['latent_dim']), 1*np.sqrt(params_generator['latent_dim'])],
    },
    'optimization': params_optimization,
    'summary_every': 2000,  # Tensorboard summaries every ** iterations
    'print_every': 500,  # Console summaries every ** iterations
    'save_every': 10000,  # Save the model every ** iterations
    'summary_dir': os.path.join(global_path, name +'_summary/'),
    'save_dir': os.path.join(global_path, name + '_checkpoints/'),
    'Nstats': 2000,
}

In [None]:
resume, params = utils.test_resume(try_resume, params)
# If a model is reloaded and some parameters have to be changed, then it should be done here.
# For example, setting the number of epochs to 5 would be:
#     params['optimization']['epoch'] = 5
# The output directories are always re-set below, presumably so that a reloaded run keeps
# writing to the paths defined in this notebook rather than the ones stored with it.
params['summary_dir'] = os.path.join(global_path, name +'_summary/')
params['save_dir'] = os.path.join(global_path, name + '_checkpoints/')

# Build the model

In [None]:
# Model class combining the conditional WGAN with CosmoWGAN through multiple inheritance.
# NOTE(review): this class is not used in this notebook -- the GANsystem is built from
# ConditionalParamWGAN. Pass it to GANsystem if the CosmoWGAN behaviour is wanted, or remove it.
class CosmoConditionalParamWGAN(ConditionalParamWGAN, CosmoWGAN):
    pass

In [None]:
# Build the conditional WGAN system from the parameter dictionary
wgan = GANsystem(ConditionalParamWGAN, params)

# Train the model

In [None]:
# Train on the toy dataset; `resume` comes from utils.test_resume (presumably True when a
# previous run was found and should be continued)
wgan.train(dataset, resume=resume)

# Generate new samples


In [None]:
# Checkpoint passed to every generation call below (None presumably selects the latest one --
# confirm in gantools)
checkpoint = None

In [None]:
inter = 4

# Evenly spaced values for each conditioning parameter; sigma (c == 0) starts at 0.002
# instead of the lower end of its init_range
init_range = wgan.net.params['init_range']
grid = [
    np.linspace(0.002 if c == 0 else init_range[c][0], init_range[c][1], inter)
    for c in range(wgan.net.params['cond_params'])
]

# Note: assume 2D grid of parameters -- sigma varies slowest, N fastest
gen_params = np.array([[sigma, N] for sigma in grid[0] for N in grid[1]])

# Produce images
latent = wgan.net.sample_latent(bs=inter * inter, params=gen_params)
gen_images = wgan.generate(N=inter * inter, z=latent, checkpoint=checkpoint)

In [None]:
# Display the generated grid: sigma changes from row to row, N from column to column.
# The grid size follows `inter` (the generating cell) instead of a hard-coded 4, and the title
# uses a raw string because "\s" is an invalid escape in a normal string literal.
fig, ax = plt.subplots(nrows=inter, ncols=inter, figsize=(15,15), squeeze=False)
for idx, col in enumerate(ax.flat):
    col.imshow(gen_images[idx, :, :, 0], vmin=vmin, vmax=vmax)
    col.set_title(r"$\sigma=$" + str(gen_params[idx][0])[0:7] + ", $N=$" + str(int(gen_params[idx][1])+1))
    col.axis('off')

Comparison of generated (fake) and real images for the same parameters

In [None]:
# (sigma, N) pairs to compare: both sigma values at N=10, then both at N=15
grid = [[sigma, N] for N in (10, 15) for sigma in (0.005, 0.01)]

In [None]:
# One generated image per (sigma, N) pair of the comparison grid
gen_params = np.array(grid)
n_pairs = len(gen_params)
latent = wgan.net.sample_latent(bs=n_pairs, params=gen_params)
gen_images = wgan.generate(N=n_pairs, z=latent, checkpoint=checkpoint)

In [None]:
# Top row: generated images; bottom row: real toy images drawn with the same (sigma, N).
# The number of columns follows len(gen_params) instead of a hard-coded 4; squeeze=False keeps
# `ax` 2-D even for a single pair. Raw string: "\s" is an invalid escape in a normal literal.
ncols = len(gen_params)
fig, ax = plt.subplots(nrows=2, ncols=ncols, figsize=(15, 7.5), squeeze=False)
for k in range(ncols):
    sigma, N = gen_params[k][0], gen_params[k][1]
    real = toy_dataset_generator.generate_fake_images(1, sigma=sigma, N=int(N), image_shape=image_shape)
    for col, img in ((ax[0, k], gen_images[k, :, :, 0]), (ax[1, k], real[0])):
        col.imshow(img, vmin=0, vmax=1)
        col.set_title(r"$\sigma$=" + str(sigma)[0:7] + ", $N$=" + str(int(N + 1)), fontsize=14)
        col.axis('off')

# Generate a single image

In [None]:
# Generate and display a single sample for one (sigma, N) pair
latent = wgan.net.sample_latent(params=np.array([[0.005, 4]]))
gen_sample = wgan.generate(N=1, z=latent, checkpoint=checkpoint)
plt.imshow(gen_sample[0, :, :, 0])
plt.axis('off');  # trailing ';' hides the axis-limits tuple returned by plt.axis

# "Category" Morphing

In [None]:
# Number of values per parameter axis; the morphing grid below has inter x inter cells
inter = 4

In [None]:
# Generate grid: evenly spaced values per conditioning parameter (sigma starts at 0.002)
init_range = wgan.net.params['init_range']
grid = []
for c, (low, high) in enumerate(init_range[:wgan.net.params['cond_params']]):
    grid.append(np.linspace(0.002 if c == 0 else low, high, inter))

# Note: assume 2D grid of parameters -- sigma varies slowest, N fastest
gen_params = np.array([[sigma, N] for sigma in grid[0] for N in grid[1]])

imgs = evaluation.generate_samples_same_seed(wgan, gen_params, checkpoint=checkpoint)

In [None]:
# One image per (sigma, N) pair from generate_samples_same_seed (presumably a shared latent seed,
# so only the conditioning parameters change between panels).
# Raw string: "\s" is an invalid escape in a normal literal. N + 1 is displayed to match the
# titles of the other image grids in this notebook (the same grid is titled with N + 1 earlier).
fig, ax = plt.subplots(nrows=inter, ncols=inter, figsize=(15,15), squeeze=False)
for idx, col in enumerate(ax.flat):
    col.imshow(imgs[idx][0, :, :, 0])
    sigma = gen_params[idx][0]
    N = gen_params[idx][1]
    col.set_title(r"$\sigma=$" + str(sigma)[0:7] + ", $N=$" + str(int(N) + 1), fontsize=14)
    col.axis('off')

In [None]:
# Define path of params: (sigma, N) waypoints of the morphing video, each followed by a
# False flag (the third element expected by evaluation.interpolate_between)
waypoints = [(0.01, 10), (0.002, 10), (0.002, 5), (0.005, 5), (0.005, 12), (0.005, 20)]
path = [[sigma, N, False] for sigma, N in waypoints]

In [None]:
# Densify the path between consecutive waypoints (5 presumably being the number of
# interpolation steps per segment -- see evaluation.interpolate_between)
path = evaluation.interpolate_between(path, 5)

In [None]:
# Generate frames, one per path point, from the same checkpoint as every other generation cell
frames = evaluation.generate_samples_same_seed(wgan, path, checkpoint=checkpoint)

In [None]:
fig, ax = plt.subplots()


def make_frame(t):
    """Render the video frame at time ``t`` (seconds); each path point is shown for one second."""
    i = min(int(t), len(path) - 1)  # guard against t reaching the clip duration exactly
    ax.clear()
    ax.imshow(frames[i][0, :, :, 0])
    ax.axis('off')
    # Raw string: "\s" is an invalid escape in a normal string literal
    ax.set_title(r"$\sigma=$" + str(path[i][0])[0:7] + ", $N=$" + str(int(path[i][1]) + 1))
    return mplfig_to_npimage(fig)


animation = VideoClip(make_frame, duration=len(path))
plt.close(fig)  # the figure is only a canvas for the video frames; don't display it
animation.ipython_display(fps=20, loop=True, autoplay=True)

# Evaluation of the sample quality

In [None]:
inter = 4

# Fixed evaluation grid over (sigma, N), inside the dataset ranges sigma_int / N_int.
# Built explicitly for the two axes used below; the former loop over cond_params silently
# re-appended the previous axis for any parameter beyond the second.
sigma_values = np.linspace(0.002, 0.008, inter)
N_values = np.linspace(5, 15, inter)
grid = [sigma_values, N_values]

# Cartesian product with sigma varying slowest: row k is [sigma_values[k // inter], N_values[k % inter]]
gen_params = np.array([[sigma, N] for sigma in sigma_values for N in N_values])

In [None]:
def generate_images_with_params(params, n):
    """Generate ``n`` images that all share the same conditioning parameters.

    ``params`` is broadcast against an ``(n, cond_params)`` array of ones, so a single
    parameter vector (e.g. one row of ``gen_params``) is repeated for every image.
    Relies on the notebook-level ``wgan`` and ``checkpoint``.
    """
    n_cond = wgan.net.params['cond_params']
    repeated = np.ones((n, n_cond)) * params
    z = wgan.net.sample_latent(bs=n, params=repeated)
    return wgan.generate(N=n, z=z, checkpoint=checkpoint)

In [None]:
nsamples = 2000  # number of real and of generated images per (sigma, N) pair
real_images = []
fake_images = []
for p in gen_params:
    # Real images from the toy generator, then generated images, for the same (sigma, N)
    raw_images = toy_dataset_generator.generate_fake_images(nsamples=nsamples, sigma=p[0], N=int(p[1]), image_shape=[ns, ns])
    gen_images = generate_images_with_params(p, nsamples)
    real_images.append(raw_images)
    fake_images.append(gen_images)

In [None]:
# Label formatters for (sigma, N); raw string because "\s" is an invalid escape in a normal literal.
# NOTE(review): N is shown as-is here, while the image titles and heat maps in this notebook
# display N + 1 -- decide on one convention.
param_titles = (lambda x: r"$\sigma=$" + str(x)[:7], lambda x: "$N=$" + str(int(x)))

In [None]:
# Compute the plots for a set of parameters (assumed arranged as an inter x inter grid with the
# first parameter varying slowest) and a given statistic function.
def compute_plot_for_params_2d(params, real, fake, func, **kwargs):
    """Draw one panel per parameter pair with ``func`` and collect the returned scores.

    ``func(real[k], fake[k], ax=..., display=False, **kwargs)`` is called for every cell.
    The title it sets on the last panel becomes the figure title, and per-panel titles are
    cleared. Only the outer panels keep their axis labels. Column headers come from
    ``params[:, 1]`` and row labels on the right from ``params[::inter, 0]``, both formatted
    with the notebook-level ``param_titles``.

    Returns ``(fig, scores)`` where ``scores`` is ``np.array`` of the per-cell return values.
    """
    inter = int(np.sqrt(len(params)))
    fig, ax = plt.subplots(nrows=inter, ncols=inter, figsize=(20, 20))
    title = ""
    scores = []
    for r, axes_row in enumerate(ax):
        for c, cell in enumerate(axes_row):
            k = r * inter + c
            scores.append(func(real[k], fake[k], ax=cell, display=False, **kwargs))
            title = cell.title.get_text()
            cell.set_title("")
            if c != 0:
                cell.set_ylabel("")
            if r != inter - 1:
                cell.set_xlabel("")

    for top_ax, value in zip(ax[0], params[:, 1]):
        top_ax.set_title(param_titles[1](value), fontsize=16)
    # Fractional-difference panels need more room between the plot and the right-hand label
    pad = 50 if kwargs.get('fractional_difference') else 10
    for right_ax, value in zip(ax[:, -1], params[::inter, 0]):
        twin = right_ax.twinx()
        twin.set_ylabel(param_titles[0](value), labelpad=pad, fontsize=16)
        twin.set_yticks([])
    fig.suptitle(title, fontsize=20)
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    return fig, np.array(scores)

PSD

In [None]:
# PSD settings (passed through to the evaluation/stats PSD helpers)
lenstools = True
bin_k = 15
box_l = (5*np.pi/180)  # 5 degrees expressed in radians
cut = [50, 1000]
if lenstools:
    ylim = [(5e-5, 5e-1), (0, 0.5)]
else:
    ylim = [(1e-3, 1e2), (0, 0.5)]
_, psd_s = compute_plot_for_params_2d(
    gen_params, real_images, fake_images, evaluation.compute_and_plot_psd,
    confidence='std', bin_k=bin_k, box_l=box_l, cut=cut, multiply=True,
    lenstools=lenstools, param_titles=param_titles, fractional_difference=True,
    loc=1, ylim=ylim,
)

In [None]:
# Per-cell PSD scores (column 0) and mean/std of the PSD fractional difference (column 1)
print("PSD scores:", psd_s[:, 0])
print("PSD average frac diff:", np.mean(psd_s[:, 1]), " +/- ", np.std(psd_s[:, 1]))

Heat map of the PSD fractional difference over the (sigma, N) evaluation grid

In [None]:
# Represent heat-map of accuracy: PSD fractional difference of every (sigma, N) evaluation point.
# Raw string for "\sigma": "\s" is an invalid escape in a normal string literal.
fig, ax = plt.subplots()
sc = ax.scatter(gen_params[:, 0], gen_params[:, 1] + 1, c=psd_s[:, 1], vmin=0, vmax=1, cmap=plt.cm.RdYlGn_r, edgecolor='k')
ax.set_xlabel(r'$\sigma$')
ax.set_ylabel('$N$')
ax.set_xlim([-0.001, 0.021])
# Black rectangle: presumably the parameter region covered by the training data
# (sigma in sigma_int; TODO confirm the N bounds 6..22)
ax.plot(np.array([0.001, 0.001, 0.01, 0.01, 0.001]), np.array([6, 22, 22, 6, 6]), c='k')
fig.colorbar(sc, ax=ax, label='PSD fractional difference');

In [None]:
# Pass/fail map of the PSD fractional difference for several thresholds (green: <= threshold).
# One scatter call per panel instead of one per point; squeeze=False keeps `ax` 2-D for a single
# threshold. Raw string: "\s" is an invalid escape in a normal literal.
thresholds = [0.05, 0.10, 0.15, 0.20]
fig, ax = plt.subplots(nrows=1, ncols=len(thresholds), figsize=(len(thresholds) * 5, 5), squeeze=False)
for a, thr in zip(ax[0], thresholds):
    colors = ['g' if v <= thr else 'r' for v in psd_s[:, 1]]
    a.scatter(gen_params[:, 0], gen_params[:, 1] + 1, c=colors)
    a.set_xlabel(r'$\sigma$')
    a.set_ylabel('$N$')
    a.set_xlim([-0.001, 0.021])
    a.plot(np.array([0.001, 0.001, 0.01, 0.01, 0.001]), np.array([6, 22, 22, 6, 6]), c='k')
    a.set_title(thr)

Mass density histogram and peak density histogram

In [None]:
# Mass (pixel-value) histogram per (sigma, N) pair, real vs generated
_, mas_s = compute_plot_for_params_2d(
    gen_params, real_images, fake_images, evaluation.compute_and_plot_mass_hist,
    log=False, lim=(0,1), confidence='std', param_titles=param_titles,
    ylim=[(1e-1, 1e3), (0, 0.5)], fractional_difference=True,
)

In [None]:
# Per-cell mass-histogram scores (column 0) and their mean/std over the grid
print("Mass scores:", mas_s[:, 0])
print("Mass mean score:", np.mean(mas_s[:, 0]), " +/- ", np.std(mas_s[:, 0]))

In [None]:
# Peak-count histogram per (sigma, N) pair, real vs generated
_, pea_s = compute_plot_for_params_2d(
    gen_params, real_images, fake_images, evaluation.compute_and_plot_peak_count,
    log=False, neighborhood_size=2, threshold=0.01, confidence='std',
    param_titles=param_titles, ylim=[(3e-1, 6e1), (0, 0.5)], fractional_difference=True,
)

In [None]:
# Per-cell peak-count scores (column 0) and their mean/std over the grid
print("Peak scores:", pea_s[:, 0])
print("Peak mean scores:", np.mean(pea_s[:, 0]), " +/- ", np.std(pea_s[:, 0]))

Plot correlations

In [None]:
def plot_correlations(params, X, title, tick_every=3):
    """Plot the PSD correlation matrix of every image set on an inter x inter grid of panels.

    Parameters
    ----------
    params : array, one row per image set
        Conditioning parameters of each set; assumed to be (sigma, N) pairs forming a square
        grid with sigma varying slowest.
    X : sequence
        One set of images per row of ``params``.
    title : str
        Figure title (previously accepted but never shown).
    tick_every : int
        Label every ``tick_every``-th wavenumber bin on both axes.

    Uses the notebook-level ``lenstools``, ``bin_k``, ``box_l``, ``cut`` and ``param_titles``.
    """
    inter = int(np.sqrt(len(params)))
    # squeeze=False keeps `ax` 2-D even for a 1x1 grid
    fig, ax = plt.subplots(nrows=inter, ncols=inter, figsize=(16, 15), squeeze=False)
    for idx, col in enumerate(ax.flat):
        # Compute
        if lenstools:
            corr, k = stats.psd_correlation_lenstools(X[idx], bin_k=bin_k, box_l=box_l, cut=cut)
        else:
            corr, k = stats.psd_correlation(X[idx], bin_k=bin_k, cut=cut, box_l=box_l, log_sampling=False)

        # Show
        col.imshow(corr, vmin=0, vmax=1, cmap=plt.cm.plasma)

        # Define axes: matrix entry i is drawn at pixel position i, so put a tick on every
        # tick_every-th bin and label it with that bin's k. (The old version prepended a spurious
        # "0" label, shifting all labels by one group. Its linspace positions also landed between
        # bins or beyond the image whenever len(k) was not a multiple of tick_every.)
        ticks = np.arange(0, len(k), tick_every)
        ticklabels = [str(round(k[i], 2)) for i in ticks]
        col.set_xticks(ticks)
        col.set_xticklabels(ticklabels)
        col.set_yticks(ticks)
        col.set_yticklabels(ticklabels)
    for a, param in zip(ax[0], params[:, 1]):
        a.set_title(param_titles[1](param), fontsize=14)
    for a, param in zip(ax[:, -1], params[range(0, len(params), inter), 0]):
        ar = a.twinx()
        ar.set_ylabel(param_titles[0](param), labelpad=10, fontsize=14)
        ar.set_yticks([])
    fig.suptitle(title, fontsize=20)
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])

In [None]:
# Plot real: PSD correlation matrices of the real toy images, one panel per (sigma, N) pair
plot_correlations(gen_params, real_images, "PSD correlation of real images")

In [None]:
# Plot fake: PSD correlation matrices of the generated images, one panel per (sigma, N) pair
plot_correlations(gen_params, fake_images, "PSD correlation of generated images")

MS-SSIM score across subsets

In [None]:
# MS-SSIM scores of the generated and of the real image sets (presumably one score per
# (sigma, N) subset -- see evaluation.compute_ssim_score)
s_fake, s_real = evaluation.compute_ssim_score(fake_images, real_images)

In [None]:
# Raw MS-SSIM scores, then their mean/std, for the generated and the real sets
print(s_fake)
print(s_real)
print(np.mean(s_fake), " +/- ", np.std(s_fake))
print(np.mean(s_real), " +/- ", np.std(s_real))

# Generate video

In [None]:
# Define path of params: (sigma, N) waypoints of the evaluation video; the trailing True flag
# is read back as path[i][2] when deciding whether to generate real images for a point
waypoints = [(0.002, 10), (0.002, 5), (0.005, 5), (0.005, 12), (0.005, 20)]
path = [[sigma, N, True] for sigma, N in waypoints]

In [None]:
# Densify the path between consecutive waypoints (5 presumably being the steps per segment;
# see evaluation.interpolate_between for the meaning of the third argument)
path = evaluation.interpolate_between(path, 5, True)

In [None]:
# Generate fake images: one set of nsamples images per path point
fake_imgs = evaluation.generate_samples_same_seed(wgan, path, nsamples=nsamples, checkpoint=checkpoint)

In [None]:
# Convert to list of dictionaries: one entry per path point with its (sigma, N), the generated
# images and, when the point's flag is set, real images for the same parameters (else None)
X = []
for i, point in enumerate(path):
    entry = {
        'params': np.array([point[0], point[1]]),
        'real': None,
        'fake': fake_imgs[i][:, :, :, 0],
    }
    X.append(entry)
    if point[2]:
        entry['real'] = toy_dataset_generator.generate_fake_images(nsamples=nsamples, sigma=point[0], N=int(point[1]), image_shape=[ns, ns])

In [None]:
def title_func(params):
    """Frame title for a (sigma, N) pair; N is displayed as N + 1, as in the other titles.

    Raw string literal: "\\s" is an invalid escape sequence in a normal string (SyntaxWarning
    on recent Python versions); the produced text is unchanged.
    """
    return r"$\sigma=$" + str(params[0])[0:7] + ", $N=$" + str(int(params[1]) + 1)

In [None]:
# Generate frames. ylims holds one pair of limits per statistic panel; judging by the values,
# the order is PSD, peak counts, mass histogram (only the PSD limits depend on lenstools).
if lenstools:
    psd_ylim = [(5e-5, 1e-2), (0, 0.5)]
else:
    psd_ylim = [(1e-3, 2e1), (0, 0.5)]
ylims = [psd_ylim, [(3e-1, 6e1), (0, 0.5)], [(1e-1, 1e3), (0, 0.5)]]
frames = evaluation.make_frames(
    X, title_func=title_func, log=False, confidence='std', neighborhood_size=2,
    threshold=0.01, lim=(0,1), ylims=ylims, multiply=True, bin_k=bin_k, box_l=box_l,
    cut=cut, fractional_difference=[True, True, True], lenstools=lenstools,
)

In [None]:
# Make video: each path point is shown for d_frame seconds
d_frame = 0.5
duration = d_frame * len(X)
render_frame = evaluation.make_frame_func(X, frames, duration)
animation = VideoClip(render_frame, duration=duration)
animation.ipython_display(fps=10, loop=True, autoplay=True, width=900)

# Extrapolation

In [None]:
# Extrapolation path: starts inside the dataset ranges and leaves them (sigma up to 0.02,
# N up to 25); every point is flagged True so real images are generated for comparison
waypoints = [(0.005, 20), (0.01, 20), (0.02, 20), (0.02, 25)]
path = [[sigma, N, True] for sigma, N in waypoints]

In [None]:
# Densify the extrapolation path between consecutive waypoints (see evaluation.interpolate_between)
path = evaluation.interpolate_between(path, 5, True)

In [None]:
# Generate fake images: one set of nsamples images per extrapolation path point
fake_imgs = evaluation.generate_samples_same_seed(wgan, path, nsamples=nsamples, checkpoint=checkpoint)

In [None]:
# Convert to list of dictionaries: per path point, its (sigma, N), the generated images and,
# when flagged, real toy images for the same parameters (None otherwise)
X = []
for i, point in enumerate(path):
    entry = {
        'params': np.array([point[0], point[1]]),
        'real': None,
        'fake': fake_imgs[i][:, :, :, 0],
    }
    X.append(entry)
    if point[2]:
        entry['real'] = toy_dataset_generator.generate_fake_images(nsamples=nsamples, sigma=point[0], N=int(point[1]), image_shape=[ns, ns])

In [None]:
# Frames for the extrapolation video, reusing the plot limits of the previous video
frames = evaluation.make_frames(
    X, title_func=title_func, log=False, confidence='std', neighborhood_size=2,
    threshold=0.01, lim=(0,1), ylims=ylims, multiply=True, bin_k=bin_k, box_l=box_l,
    cut=cut, fractional_difference=[True, True, True], lenstools=lenstools,
    locs=[1, 1, 1],
)

In [None]:
# Make video: each extrapolation path point is shown for d_frame seconds
d_frame = 0.5
duration = d_frame * len(X)
render_frame = evaluation.make_frame_func(X, frames, duration)
animation = VideoClip(render_frame, duration=duration)
animation.ipython_display(fps=10, loop=True, autoplay=True, width=900)