In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell is executed (edits to the project's .py files are picked up).
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
# import sys
# sys.path.insert(0, '../')

import numpy as np
import tensorflow as tf
import data
import utils
import plot as plots
import evaluation
from model import WGanModel
from gan import CosmoGAN
from data import fmap
import functools
import os
import matplotlib.pyplot as plt
# os.environ["CUDA_VISIBLE_DEVICES"]="0"


# Parameters

In [3]:
# --- Dataset selection -------------------------------------------------------
# `ns` is passed as `spix` to the loader and used as the image side length
# (`image_size = [ns, ns]`); `Mpch` and `res` are forwarded to the loader as-is.
ns = 32
Mpch = 350
res = 256

# --- Checkpoint handling -----------------------------------------------------
# Handed to utils.test_resume before the model is built.
try_resume = False

# --- Forward / backward data maps --------------------------------------------
# Both maps are bound to the same `shift` and `c` so that they stay consistent
# with each other.
shift = 3
c = 40000
forward = functools.partial(
    fmap.stat_forward,
    shift=shift,
    c=c,
)
backward = functools.partial(
    fmap.stat_backward,
    shift=shift,
    c=c,
)
def non_lin(x):
    """Non-linearity handed to the generator: apply ``tf.nn.relu`` to ``x``.

    Args:
        x: value accepted by ``tf.nn.relu`` (presumably a tensor).

    Returns:
        The result of ``tf.nn.relu(x)``.
    """
    rectified = tf.nn.relu(x)
    return rectified

# Data handling

Load the data

In [4]:
# Load the dataset with the configured patch size, resolution and box size.
# `forward` is handed over as the forward map (presumably applied to the
# samples by the loader — confirm in data.load).
dataset = data.load.load_dataset(
    spix=ns,
    resolution=res,
    Mpch=Mpch,
    forward_map=forward,
)

In [None]:
# Flatten all samples of the dataset into one flat array; it is used for the
# round-trip check and the value histogram below, then deleted.
X = dataset.get_all_data().flatten()

In [None]:
# Sanity check: `forward` and `backward` must be (numerically) inverse of each
# other on the loaded data. The tolerance bounds the *summed* absolute
# round-trip error over all values of X, not the per-value error.
roundtrip_error = np.sum(np.abs(forward(backward(X)) - X))
assert roundtrip_error < 5, (
    'forward(backward(X)) deviates from X: '
    'summed absolute error = {}'.format(roundtrip_error)
)
# # For debugging
# np.sum(np.abs(forward(backward(X))-X))
# forward(backward(X))-X
# x = np.arange(1e4)
# plt.plot(x, backward(forward(x))-x)

In [None]:
# Histogram of all data values (100 bins) on a logarithmic count axis, with
# the extreme values printed for reference.
plt.hist(X, bins=100)
plt.yscale('log')
print('min: {}'.format(np.min(X)))
print('max: {}'.format(np.max(X)))

Let us plot 16 images

In [None]:
# The flattened array is not used again below; remove the name so the kernel
# can release it.
del X

In [None]:
# Draw 16 samples from the dataset on a 4 x 4 grid.
plt.figure(figsize=(15,15))
plots.draw_images(dataset.get_samples(N=16),nx=4,ny=4)

# A) The WGAN

In [None]:
# Tag identifying this run's configuration; it becomes part of the summary and
# checkpoint directory names. Despite its name, `time_str` holds no timestamp.
time_str = 'stat_c_{}_shift_{}_laplacian_Mpch_{}_res_{}'.format(c, shift, Mpch, res)

# Root directory for summaries and checkpoints. Defaults to the original
# cluster location; set the COSMOGAN_RESULT_DIR environment variable to run on
# another machine without editing the notebook.
global_path = os.environ.get('COSMOGAN_RESULT_DIR', '/scratch/snx3000/nperraud/saved_result/')

# Model name, suffixed with the image side length.
name = 'WGAN{}'.format(ns)

## Parameters

In [None]:
# Batch-normalisation switch shared by the per-layer lists of both networks.
bn = False

# Discriminator: per-layer lists (four entries each) plus one entry in 'full'.
params_discriminator = {
    'stride': [2, 2, 2, 1],
    'nfilter': [16, 64, 256, 32],
    'shape': [[5, 5], [5, 5], [5, 5], [3, 3]],
    'batch_norm': [bn] * 4,
    'full': [32],
    'minibatch_reg': False,
    'summary': True,
}

# Generator: five entries in 'stride' / 'nfilter' / 'shape'.
# NOTE(review): 'batch_norm' has four entries against five layer entries —
# presumably the output layer takes no batch norm; confirm against the model code.
params_generator = {
    'stride': [1, 1, 2, 1, 1],
    'latent_dim': 16 * 16 * 32,
    'nfilter': [32, 64, 256, 32, 1],
    'shape': [[5, 5], [5, 5], [5, 5], [5, 5], [5, 5]],
    'batch_norm': [bn] * 4,
    'full': [],
    'summary': True,
    'non_lin': non_lin,
}

# Optimisation settings (same optimiser type and learning rate for both networks).
params_optimization = {
    'gamma_gp': 10,
    'batch_size': 16,
    'gen_optimizer': 'adam',   # rmsprop / adam / sgd
    'disc_optimizer': 'adam',  # rmsprop / adam / sgd
    'disc_learning_rate': 1e-5,
    'gen_learning_rate': 1e-5,
    'beta1': 0.5,
    'beta2': 0.99,
    'epsilon': 1e-8,
    'epoch': 5,
}

# Data maps handed to the model together with 'Nstats'.
params_cosmology = {
    'forward_map': forward,
    'backward_map': backward,
    'Nstats': 5000,
}

# Top-level parameter dictionary: the four groups above plus run-level options
# and the output directories derived from the run name and configuration tag.
params = {
    'generator': params_generator,
    'discriminator': params_discriminator,
    'optimization': params_optimization,
    'cosmology': params_cosmology,
    'normalize': False,
    'image_size': [ns, ns],
    'prior_distribution': 'laplacian',
    'sum_every': 500,
    'viz_every': 500,
    'print_every': 100,
    'save_every': 2000,
    'name': name,
    'summary_dir': os.path.join(global_path, name + '_' + time_str + '_summary/'),
    'save_dir': os.path.join(global_path, name + '_' + time_str + '_checkpoints/'),
}



In [None]:
# utils.test_resume receives the try_resume flag and the freshly built params
# and returns a (resume, params) pair; `params` is rebound to the returned value.
# NOTE(review): presumably it restores saved parameters when a checkpoint
# exists — confirm in utils.
resume, params = utils.test_resume(try_resume, params)
# Optional overrides applied after the resume check (uncomment to use):
# params['optimization']['disc_learning_rate'] = 3e-6
# params['optimization']['gen_learning_rate'] = 3e-6
# params['optimization']['epoch'] = 10


## Build the model

In [None]:
# Instantiate the GAN wrapper from the parameter dictionary and the WGAN model class.
wgan = CosmoGAN(params, WGanModel)

## Train the model
Note that the model is trained on the preprocessed (forward-mapped) data rather than on the raw data. TODO: consider whether this should be changed.

In [None]:
# Train on `dataset`; `resume` is the flag returned by utils.test_resume.
wgan.train(dataset, resume=resume)

In [None]:
# Request 5000 generated samples for checkpoint 70000 and remove singleton axes
# from both returned arrays. Further down, `gen_sample` is drawn as images and
# `gen_sample_raw` is compared against the backward-mapped real data.
gen_sample, gen_sample_raw = (
    np.squeeze(batch) for batch in wgan.generate(N=5000, checkpoint=70000)
)

In [None]:
# Draw generated samples on a 4 x 4 grid.
plt.figure(figsize=(15,15))
plots.draw_images(gen_sample,nx=4,ny=4)

In [None]:
# Fetch dataset.N samples and map them through `backward`; the result is the
# real-data reference compared against gen_sample_raw in the evaluation cells.
raw_images = backward(dataset.get_samples(dataset.N))

In [None]:
# Compare real and generated images with evaluation.compute_and_plot_psd.
# NOTE: the same four names are reassigned by every evaluation cell, so only
# the values from the most recently run cell remain available.
logel2, l2, logel1, l1 = evaluation.compute_and_plot_psd(raw_images, gen_sample_raw)

In [None]:
# Compare real and generated images with evaluation.compute_and_plot_peak_cout.
# NOTE: the same four names are reassigned by every evaluation cell, so only
# the values from the most recently run cell remain available.
logel2, l2, logel1, l1 = evaluation.compute_and_plot_peak_cout(raw_images, gen_sample_raw)

In [None]:
# Compare real and generated images with evaluation.compute_and_plot_mass_hist.
# NOTE: the same four names are reassigned by every evaluation cell, so only
# the values from the most recently run cell remain available.
logel2, l2, logel1, l1 = evaluation.compute_and_plot_mass_hist(raw_images, gen_sample_raw)