# Wasserstein super-resolution emulator

### Import relevant packages

In [1]:
import numpy as np
import tensorflow as tf
import tqdm
import scipy as sp



### Load low and high-resolution density fields, and the initial conditions for the latter

In [2]:
hr_resolution = 512
lr_resolution = 256

hr_final = np.load("high_res/df_0_z=0.npy")
hr_final = hr_final.astype(np.float32).reshape((hr_resolution, hr_resolution, hr_resolution))

hr_initial = np.load("high_res/df_0_z=127.npy")
hr_initial = hr_initial.astype(np.float32).reshape((hr_resolution, hr_resolution, hr_resolution))

lr_final = np.load("low_res/df_0_z=0.npy")
lr_final = lr_final.astype(np.float32).reshape((lr_resolution, lr_resolution, lr_resolution))

### Normalize final density fields such that mean density guaranteed to be unity

In [3]:
hr_final /= hr_final.sum()/hr_resolution**3
lr_final /= lr_final.sum()/lr_resolution**3

## Set up the TensorFlow graph

### Data augmentation
We are going to perform data augmentation on the fly to save time. In particular, we will perform right angle rotations, with a total of 24 possibilities, to randomly selected 3D slices of training set. Moreover, we also apply a random flip along the three axes.

In [4]:
def rotations(y):

    x = y[0]
    rand = y[1]
    inv = y[2]
    x = tf.case(
        {tf.equal(rand, 1): lambda: x[:, ::-1, ::-1, :, :],
         tf.equal(rand, 2): lambda: x[:, ::-1, :, ::-1, :],
         tf.equal(rand, 3): lambda: x[:, :, ::-1, ::-1, :],
         tf.equal(rand, 4): lambda: tf.transpose(x, (0, 2, 1, 3, 4))[:, ::-1, :, :, :],
         tf.equal(rand, 5): lambda: tf.transpose(x, (0, 2, 1, 3, 4))[:, ::-1, :, ::-1, :],
         tf.equal(rand, 6): lambda: tf.transpose(x, (0, 2, 1, 3, 4))[:, :, ::-1, :, :],
         tf.equal(rand, 7): lambda: tf.transpose(x, (0, 2, 1, 3, 4))[:, :, ::-1, ::-1, :],
         tf.equal(rand, 8): lambda: tf.transpose(x, (0, 3, 2, 1, 4))[:, ::-1, :, :, :],
         tf.equal(rand, 9): lambda: tf.transpose(x, (0, 3, 2, 1, 4))[:, ::-1, ::-1, :, :],
         tf.equal(rand, 10): lambda: tf.transpose(x, (0, 3, 2, 1, 4))[:, :, :, ::-1, :],
         tf.equal(rand, 11): lambda: tf.transpose(x, (0, 3, 2, 1, 4))[:, :, ::-1, ::-1, :],
         tf.equal(rand, 12): lambda: tf.transpose(x, (0, 1, 3, 2, 4))[:, :, ::-1, :, :],
         tf.equal(rand, 13): lambda: tf.transpose(x, (0, 1, 3, 2, 4))[:, ::-1, ::-1, :, :],
         tf.equal(rand, 14): lambda: tf.transpose(x, (0, 1, 3, 2, 4))[:, :, :, ::-1, :],
         tf.equal(rand, 15): lambda: tf.transpose(x, (0, 1, 3, 2, 4))[:, ::-1, :, ::-1, :],
         tf.equal(rand, 16): lambda: tf.transpose(x, (0, 2, 3, 1, 4))[:, ::-1, ::-1, :, :],
         tf.equal(rand, 17): lambda: tf.transpose(x, (0, 2, 3, 1, 4))[:, :, ::-1, ::-1, :],
         tf.equal(rand, 18): lambda: tf.transpose(x, (0, 2, 3, 1, 4))[:, ::-1, :, ::-1, :],
         tf.equal(rand, 19): lambda: tf.transpose(x, (0, 2, 3, 1, 4))[:, ::-1, ::-1, ::-1, :], 
         tf.equal(rand, 20): lambda: tf.transpose(x, (0, 3, 1, 2, 4))[:, ::-1, ::-1, :, :],
         tf.equal(rand, 21): lambda: tf.transpose(x, (0, 3, 1, 2, 4))[:, ::-1, :, ::-1, :],
         tf.equal(rand, 22): lambda: tf.transpose(x, (0, 3, 1, 2, 4))[:, :, ::-1, ::-1, :],
         tf.equal(rand, 23): lambda: tf.transpose(x, (0, 3, 1, 2, 4))[:, ::-1, ::-1, ::-1, :],
         tf.equal(rand, 24): lambda: x}, default = lambda: x, exclusive = True)
    
    return tf.case({tf.equal(inv, 1): lambda: x[:, ::-1, ::-1, ::-1, :]}, default = lambda: x, exclusive = True)

### Construct the critic network
The critic encodes a series of four convolutional layers, with a gradual reduction in their kernel sizes from 7x7x7 to 1x1x1, activated with $\texttt{leaky ReLU}$. The critic takes as input the high-resolution initial conditions and real density field. The output of the final convolutional layer is flattened and fed into a fully connected layer with linear activation.
Essentially, the critic reduces the input real and generated 3D high-resolution density fields to a compact representation whose difference is an approximation to the Wasserstein distance, conditional on their respective initial conditions.

In [5]:
def W(density, IC, pad_hr, slice_size_hr, reference=False):

    IC = tf.slice(IC, [0, pad_hr, pad_hr, pad_hr, 0], [-1, slice_size_hr-2*pad_hr, slice_size_hr-2*pad_hr, slice_size_hr-2*pad_hr, -1])
    print("IC")
    print(IC)
    
    if reference:
        density = tf.slice(density, [0, pad_hr, pad_hr, pad_hr, 0], [-1, slice_size_hr-2*pad_hr, slice_size_hr-2*pad_hr, slice_size_hr-2*pad_hr, -1])
        print("density")
        print(density)
    
    data = tf.concat((density, IC), axis=4)
    print("data")
    print(data)

    w1 = tf.get_variable("W_w1", 
                         (7, 7, 7, 2, 8), 
                         dtype=tf.float32, 
                         initializer=tf.random_normal_initializer(0, 0.1))
    b1 = tf.get_variable("W_b1", 
                         (8), 
                         dtype=tf.float32, 
                         initializer=tf.constant_initializer(0.1))
    x1 = tf.nn.leaky_relu(tf.nn.conv3d(data, 
                                       w1, 
                                       strides=[1, 2, 2, 2, 1], 
                                       padding='VALID') + b1, 0.1)
    print(x1)

    w2 = tf.get_variable("W_w2",
                         (5, 5, 5, 8, 16),
                         dtype=tf.float32,
                         initializer=tf.random_normal_initializer(0, 0.1))
    b2 = tf.get_variable("W_b2",
                         (16),
                         dtype=tf.float32,
                         initializer=tf.constant_initializer(0.1))
    x2 = tf.nn.leaky_relu(tf.nn.conv3d(x1, 
                                       w2, 
                                       strides=[1, 1, 1, 1, 1], 
                                       padding='VALID') + b2, 0.1)
    print(x2)    
    w3 = tf.get_variable("W_w3",
                         (3, 3, 3, 16, 32),
                         dtype=tf.float32,
                         initializer=tf.random_normal_initializer(0, 0.1))
    b3 = tf.get_variable("W_b3",
                         (32),
                         dtype=tf.float32,
                         initializer=tf.constant_initializer(0.1))
    x3 = tf.nn.leaky_relu(tf.nn.conv3d(x2, 
                                       w3, 
                                       strides=[1, 2, 2, 2, 1], 
                                       padding='VALID') + b3, 0.1)
    
    w4 = tf.get_variable("W_w4",
                         (1, 1, 1, 32, 64),
                         dtype=tf.float32,
                         initializer=tf.random_normal_initializer(0, 0.1))
    b4 = tf.get_variable("W_b4",
                         (64),
                         dtype=tf.float32,
                         initializer=tf.constant_initializer(0.1))
    x4 = tf.nn.leaky_relu(tf.nn.conv3d(x3, 
                                       w4, 
                                       strides=[1, 1, 1, 1, 1], 
                                       padding='VALID') + b4, 0.1)    
        
    x5 = tf.reshape(x4, (-1, np.product(x4.get_shape().as_list()[1:])))
    w5 = tf.get_variable("W_w5", 
                         (x5.get_shape().as_list()[-1], 1), 
                         dtype=tf.float32, 
                         initializer=tf.random_normal_initializer(0, 0.1))
    b5 = tf.get_variable("W_b5", 
                         (1), 
                         dtype=tf.float32, 
                         initializer=tf.constant_initializer(0.1))
    x_out = tf.matmul(x5, w5) + b5
    
    return x_out

### Residual inception module
We employ Inception blocks, encoding residual connections in a modified variant of the originally proposed Inception module, in the architecture of our super-resolution emulator, as illustrated below. The largest convolutional kernel in the Inception module is $7\times7\times7$.

In [6]:
def inception(x, filters_1x1x1_7x7x7, filters_7x7x7, filters_1x1x1_5x5x5, filters_5x5x5, filters_1x1x1_3x3x3, filters_3x3x3, filters_1x1x1, name):
    input_filters = x.get_shape().as_list()[-1]

    w_1x1x1_7x7x7 = tf.get_variable(name + "_w_1x1x1_7x7x7", (1, 1, 1, input_filters, filters_1x1x1_7x7x7), dtype = tf.float32, initializer = tf.random_normal_initializer(0, 0.1))
    b_1x1x1_7x7x7 = tf.get_variable(name + "_b_1x1x1_7x7x7", (filters_1x1x1_7x7x7), dtype = tf.float32, initializer = tf.constant_initializer(0.1))
    x_1x1x1_7x7x7 = tf.nn.conv3d(x, w_1x1x1_7x7x7, strides = [1, 1, 1, 1, 1], padding = 'VALID') + b_1x1x1_7x7x7
    w_7x7x7 = tf.get_variable(name + "_w_7x7x7", (7, 7, 7, filters_1x1x1_7x7x7, filters_7x7x7), dtype = tf.float32, initializer = tf.random_normal_initializer(0, 0.1))
    b_7x7x7 = tf.get_variable(name + "_b_7x7x7", (filters_7x7x7), dtype = tf.float32, initializer = tf.constant_initializer(0.1))
    x_7x7x7 = tf.nn.conv3d(x_1x1x1_7x7x7, w_7x7x7, strides = [1, 1, 1, 1, 1], padding = 'VALID') + b_7x7x7    
    
    w_1x1x1_5x5x5 = tf.get_variable(name + "_w_1x1x1_5x5x5", (1, 1, 1, input_filters, filters_1x1x1_5x5x5), dtype = tf.float32, initializer = tf.random_normal_initializer(0, 0.1))
    b_1x1x1_5x5x5 = tf.get_variable(name + "_b_1x1x1_5x5x5", (filters_1x1x1_5x5x5), dtype = tf.float32, initializer = tf.constant_initializer(0.1))
    x_1x1x1_5x5x5 = tf.nn.conv3d(x, w_1x1x1_5x5x5, strides = [1, 1, 1, 1, 1], padding = 'VALID') + b_1x1x1_5x5x5
    w_5x5x5 = tf.get_variable(name + "_w_5x5x5", (5, 5, 5, filters_1x1x1_5x5x5, filters_5x5x5), dtype = tf.float32, initializer = tf.random_normal_initializer(0, 0.1))
    b_5x5x5 = tf.get_variable(name + "_b_5x5x5", (filters_5x5x5), dtype = tf.float32, initializer = tf.constant_initializer(0.1))
    x_5x5x5 = tf.nn.conv3d(x_1x1x1_5x5x5, w_5x5x5, strides = [1, 1, 1, 1, 1], padding = 'VALID') + b_5x5x5
    x_5x5x5 = tf.slice(x_5x5x5, [0, 1, 1, 1, 0], [-1, x_7x7x7.get_shape().as_list()[1], x_7x7x7.get_shape().as_list()[2], x_7x7x7.get_shape().as_list()[3], -1])
    
    w_1x1x1_3x3x3 = tf.get_variable(name + "_w_1x1x1_3x3x3", (1, 1, 1, input_filters, filters_1x1x1_3x3x3), dtype = tf.float32, initializer = tf.random_normal_initializer(0, 0.1))
    b_1x1x1_3x3x3 = tf.get_variable(name + "_b_1x1x1_3x3x3", (filters_1x1x1_3x3x3), dtype = tf.float32, initializer = tf.constant_initializer(0.1))
    x_1x1x1_3x3x3 = tf.nn.conv3d(x, w_1x1x1_3x3x3, strides = [1, 1, 1, 1, 1], padding = 'VALID') + b_1x1x1_3x3x3
    w_3x3x3 = tf.get_variable(name + "_w_3x3x3", (3, 3, 3, filters_1x1x1_3x3x3, filters_3x3x3), dtype = tf.float32, initializer = tf.random_normal_initializer(0, 0.1))
    b_3x3x3 = tf.get_variable(name + "_b_3x3x3", (filters_3x3x3), dtype = tf.float32, initializer = tf.constant_initializer(0.1))
    x_3x3x3 = tf.nn.conv3d(x_1x1x1_3x3x3, w_3x3x3, strides = [1, 1, 1, 1, 1], padding = 'VALID') + b_3x3x3
    x_3x3x3 = tf.slice(x_3x3x3, [0, 2, 2, 2, 0], [-1, x_7x7x7.get_shape().as_list()[1], x_7x7x7.get_shape().as_list()[2], x_7x7x7.get_shape().as_list()[3], -1])
    
    w_1x1x1 = tf.get_variable(name + "_w_1x1x1", (1, 1, 1, input_filters, filters_1x1x1), dtype = tf.float32, initializer = tf.random_normal_initializer(0, 0.1))
    b_1x1x1 = tf.get_variable(name + "_b_1x1x1", (filters_1x1x1), dtype = tf.float32, initializer = tf.constant_initializer(0.1))
    x_1x1x1 = tf.nn.conv3d(x, w_1x1x1, strides = [1, 1, 1, 1, 1], padding = 'SAME') + b_1x1x1
    x_1x1x1 = tf.slice(x_1x1x1, [0, 3, 3, 3, 0], [-1, x_7x7x7.get_shape().as_list()[1], x_7x7x7.get_shape().as_list()[2], x_7x7x7.get_shape().as_list()[3], -1])

    x_out = tf.concat((x_7x7x7, x_5x5x5, x_3x3x3, x_1x1x1), axis = 4)
    output_filters = x_out.get_shape().as_list()[-1]
    x_out_ = tf.slice(tf.transpose(tf.stack([x[:, :, :, :, 0] for i in range(output_filters)]), perm = [1, 2, 3, 4, 0]), [0, 2, 2, 2, 0], [-1, x_5x5x5.get_shape().as_list()[1], x_5x5x5.get_shape().as_list()[2], x_5x5x5.get_shape().as_list()[3], -1])
    
    return tf.add(x_out, x_out_)

### 1x1x1 convolution

To be used as the output layer in the super-resolution emulator.

In [7]:
def kernel_1x1x1_convolutions(x_in, filters_1x1x1, name):
    input_filters = x_in.get_shape().as_list()[-1]
        
    w_1x1x1 = tf.get_variable(name + "_w_1x1x1",
                              (1, 1, 1, input_filters, filters_1x1x1),
                              dtype=tf.float32,
                              initializer=tf.random_normal_initializer(0, 0.1))
    b_1x1x1 = tf.get_variable(name + "_b_1x1x1",
                              (filters_1x1x1),
                              dtype=tf.float32,
                              initializer=tf.constant_initializer(0.1))
    x_1x1x1 = tf.nn.conv3d(x_in, w_1x1x1, strides = [1, 1, 1, 1, 1], padding='SAME') + b_1x1x1
    
    return x_1x1x1

### Construct the super-resolution emulator
Our emulator takes as input a low-resolution density field cube and augments the features to yield the corresponding high-resolution density field.

The residual Inception blocks propagates the information from fairly local regions of the dark matter field, and since we use residual connections in the Inception blocks, we combine structure from distant patches, whilst still retaining a close relation to the density field itself. The subsequent strided 1x1x1 convolution, with no residual connection, leads to the desired larger array size. Since there is an enormous complexity in the low-res density field, we use many filters in each layer to learn the wide variety of possible features. 

The architecture of the emulator is driven by simple physical principles. It is designed to perform a physical mapping that encodes some fundamental symmetries (rotational and translational invariance) and as such, the cosmological principle, and the maximum extent of the receptive field (i.e. size of the largest convolutional kernel in the Inception module) is motivated by causal transport arguments to ensure that the non-local information is captured by the network. The non-linearity involved in this super-resolution procedure is provided by the $\texttt{leaky ReLU}$ activation with a leaky parameter of $\alpha=0.1$. To ensure the positivity of the final high-resolution field, we use $\texttt{ReLU}$ at the output layer.

In [8]:
def G(δ_lr, δ_ic, num_convs, slice_size_hr=40, n_filters=6):
    
    print(δ_lr)
    δ_lr = tf.keras.layers.UpSampling3D(size=(2, 2, 2), data_format="channels_last")(δ_lr)
    print(δ_lr)
    
    x = tf.nn.leaky_relu(inception(δ_lr, n_filters, n_filters, n_filters, n_filters, n_filters, n_filters, n_filters, 'layer_1'), 0.1)
    print(x)

    y = tf.nn.leaky_relu(inception(δ_ic, n_filters, n_filters, n_filters, n_filters, n_filters, n_filters, n_filters, 'layer_3'), 0.1)
    print(y)
        
    x_out = tf.concat((x, y), axis=4)
    print(x_out)
    x_out = tf.nn.relu(inception(x_out, n_filters, n_filters, n_filters, n_filters, n_filters, n_filters, n_filters, 'layer_5'))
    print(x_out)
    
    x_out = kernel_1x1x1_convolutions(x_out, 1, 'layer_out')
    print(x_out)
    
    return x_out

### Training methodology

We choose the input to the emulator to be conveniently larger to eliminate the need for padding. We therefore compute the corresponding sizes of the tensors, which depends on the number of convolutional layers and the desired size of the prediction. If the network has two Inception modules, where the largest convolutional kernel is 7x7x7, then our input must be larger by $2\times(7-1) = 12$ voxels on each side.

In [9]:
num_convs = 2
pad = 0
input_patch = 20
slice_size_lr = int(input_patch + 2*pad)
slice_size_hr = input_patch*2

pad_hr = int((num_convs * 6)/2)
pad_lr = int(pad_hr/2)

During training, we load both the entire low- and high-resolution density fields and initial conditions from the simulation into the TensorFlow graph and select by index sub-volume elements of size $40^3$ and $20^3$, respectively, which massively reduces computation time compared with passing the 3D slices of data at each weight update.

In [10]:
δ_lr_init = tf.placeholder(dtype=tf.float32,
                        shape=(lr_resolution, lr_resolution, lr_resolution),
                        name="initialise_low_res_density")
δ_lr = tf.Variable(δ_lr_init, trainable=False, name="low_res_density")
δ_lr_assign = tf.assign(δ_lr, δ_lr_init)

δ_ic_init = tf.placeholder(dtype=tf.float32,
                        shape=(hr_resolution, hr_resolution, hr_resolution),
                        name="initialise_high_res_IC")
δ_ic = tf.Variable(δ_ic_init, trainable=False, name="high_res_IC")
δ_ic_assign = tf.assign(δ_ic, δ_ic_init)

δ_hr_init = tf.placeholder(dtype=tf.float32,
                        shape=(hr_resolution, hr_resolution, hr_resolution),
                        name="initialise_high_res_density")
δ_hr = tf.Variable(δ_hr_init, trainable=False, name="high_res_density")
δ_hr_assign = tf.assign(δ_hr, δ_hr_init)

We are going to feed into the network a random number for the rotation of the box, a set of x, y and z indices to grab the first element of the slice from the field. We need two sets of indices, for the low- and high-resolution slices, respectively. We also need a random number $\epsilon=[0, 1]$ for the probability of the density field during the gradient penalty, and finally a number for the strength of the coupling of the gradient penalty. We will start with an initial batch size `m` of unity but will progressively double the batch size for every 100k weight updates.

In [11]:
def process_indices(ind, x_size, y_size, z_size):
    size = x_size*y_size*z_size
    all_indices = np.zeros((size*ind.shape[0], 3)).astype(np.int32)
    
    for i in range(ind.shape[0]):
        counter = 0
        for kk in range(z_size):
            for jj in range(y_size):
                for ii in range(x_size):
                    all_indices[counter + i*size] = [ind[i, 0] + ii, ind[i, 1] + jj, ind[i, 2] + kk]
                    counter += 1

    return all_indices

In [12]:
rotation = tf.placeholder(dtype=tf.int32, shape=(None,), name="rotation")
inv = tf.placeholder(dtype=tf.int32, shape=(None,), name="inversion")

process_indices_lr = tf.expand_dims(tf.Variable(process_indices(np.array([[0,0,0]]), slice_size_lr, slice_size_lr, slice_size_lr), trainable=False), 0)
process_indices_hr = tf.expand_dims(tf.Variable(process_indices(np.array([[0,0,0]]), slice_size_hr, slice_size_hr, slice_size_hr), trainable=False), 0)

ind_lr = tf.placeholder(dtype=tf.int32, shape=(None, 3), name="ind_lr")
ind_hr = tf.placeholder(dtype=tf.int32, shape=(None, 3), name="ind_hr")

final_ind_lr = tf.reshape(tf.add(tf.expand_dims(ind_lr,1), process_indices_lr), (-1, 3), name="lr_ind_reshape")
final_ind_hr = tf.reshape(tf.add(tf.expand_dims(ind_hr,1), process_indices_hr), (-1, 3), name="hr_ind_reshape")

ϵ = tf.placeholder(dtype=tf.float32, shape=(None), name="epsilon")
ϵ_ = tf.expand_dims(tf.expand_dims(tf.expand_dims(tf.expand_dims(ϵ, 1), 1), 1), 1)
λ = tf.placeholder(dtype=tf.float32, shape=(), name="lambda")

To be able to pass a cube to the graph, we will define two placeholders; one for passing low-resolution density field and another for the high-resolution initial conditions.

In [13]:
single_δ_lr = tf.placeholder(dtype=tf.float32, 
                          shape=(1, slice_size_lr, slice_size_lr, slice_size_lr, 1),
                          name="single_delta_lr")

single_δ_ic = tf.placeholder(dtype=tf.float32, 
                          shape=(1, slice_size_hr, slice_size_hr, slice_size_hr, 1),
                          name="single_delta_ic")

While we train on smaller boxes, we eventually verify the network's performance on an unseen simulation (*test set*), where we predict simulations larger than the ones used to train the network.

In [14]:
big_size_lr = 128
big_slice_size_lr = int(big_size_lr+2*pad)

big_δ_lr = tf.placeholder(dtype=tf.float32, 
                       shape=(1, big_slice_size_lr, big_slice_size_lr, big_slice_size_lr, 1),
                       name="big_delta_lr")

big_slice_size_ic = big_size_lr*2

big_δ_ic = tf.placeholder(dtype=tf.float32, 
                       shape=(1, big_slice_size_ic, big_slice_size_ic, big_slice_size_ic, 1),
                       name="big_delta_ic")

The low-resolution densities, and corresponding initial conditions, stored in the graph are then sliced using the fed indices and `m` batches are concatenated if we wanted to increase the batch size. These tensors are then rotated and/or mirrored (the same way for both the density and the initial conditions) as data augmentation.

In [15]:
gather_lr = tf.gather_nd(δ_lr, final_ind_lr, name="gather_lr")
gather_ic = tf.gather_nd(δ_ic, final_ind_hr, name="gather_ic")
gather_hr = tf.gather_nd(δ_hr, final_ind_hr, name="gather_hr")

real_δ_lr = tf.reshape(gather_lr, (-1, slice_size_lr, slice_size_lr, slice_size_lr, 1), name="lr_reshape")
real_δ_ic = tf.reshape(gather_ic, (-1, slice_size_hr, slice_size_hr, slice_size_hr, 1), name="ic_reshape")
real_δ_hr = tf.reshape(gather_hr, (-1, slice_size_hr, slice_size_hr, slice_size_hr, 1), name="hr_reshape")

δ_lr_rotated = tf.squeeze(tf.map_fn(rotations, [tf.expand_dims(real_δ_lr, 1), rotation, inv], dtype=tf.float32), 1)
δ_ic_rotated = tf.squeeze(tf.map_fn(rotations, [tf.expand_dims(real_δ_ic, 1), rotation, inv], dtype=tf.float32), 1)
δ_hr_rotated = tf.squeeze(tf.map_fn(rotations, [tf.expand_dims(real_δ_hr, 1), rotation, inv], dtype=tf.float32), 1)

### Computing the approximate Wasserstein distance

The output of the critic is a single scalar which is used to compute the approximately learned Wasserstein distance between the predicted and true high-resolution density field given a particular generative network. This output can therefore be used to compute the loss function which is minimized to train the emulator.

First we need the critic result of the real high-resolution density field.

In [16]:
with tf.variable_scope("W") as scope:
    W_real = tf.identity(W(δ_hr_rotated, δ_ic_rotated, pad_hr, slice_size_hr, reference=True), name="W_real")

IC
Tensor("W/Slice:0", shape=(?, 28, 28, 28, 1), dtype=float32)
density
Tensor("W/Slice_1:0", shape=(?, 28, 28, 28, 1), dtype=float32)
data
Tensor("W/concat:0", shape=(?, 28, 28, 28, 2), dtype=float32)
Tensor("W/LeakyRelu:0", shape=(?, 11, 11, 11, 8), dtype=float32)
Tensor("W/LeakyRelu_1:0", shape=(?, 7, 7, 7, 16), dtype=float32)


And then we want the output of the emulator (and we will also create another input to the graph provided by the fed density). We also need to grab the result of the critic from this generated output from the emulator.

In [17]:
with tf.variable_scope("G_com") as scope:
    generated_δ_hr = tf.identity(G(δ_lr_rotated, δ_ic_rotated, num_convs), name="generated_delta_hr")
    scope.reuse_variables()
    output = tf.identity(G(single_δ_lr, single_δ_ic, num_convs, slice_size_hr), name="output")
    scope.reuse_variables()
    big_output = tf.identity(G(big_δ_lr, big_δ_ic, num_convs, big_slice_size_ic), name="big_output")
    
with tf.variable_scope("W") as scope:
    scope.reuse_variables()
    W_gen = tf.identity(W(generated_δ_hr, δ_ic_rotated, pad_hr, slice_size_hr), name="W_gen")

Tensor("Squeeze:0", shape=(?, 20, 20, 20, 1), dtype=float32)
Tensor("G_com/up_sampling3d/concat_2:0", shape=(?, 40, 40, 40, 1), dtype=float32)
Tensor("G_com/LeakyRelu:0", shape=(?, 34, 34, 34, 24), dtype=float32)
Tensor("G_com/LeakyRelu_1:0", shape=(?, 34, 34, 34, 48), dtype=float32)
Tensor("G_com/concat_2:0", shape=(?, 34, 34, 34, 72), dtype=float32)
Tensor("G_com/Relu:0", shape=(?, 28, 28, 28, 48), dtype=float32)
Tensor("G_com/add_24:0", shape=(?, 28, 28, 28, 1), dtype=float32)
Tensor("single_delta_lr:0", shape=(1, 20, 20, 20, 1), dtype=float32)
Tensor("G_com/up_sampling3d_1/concat_2:0", shape=(1, 40, 40, 40, 1), dtype=float32)
Tensor("G_com/LeakyRelu_2:0", shape=(1, 34, 34, 34, 24), dtype=float32)
Tensor("G_com/LeakyRelu_3:0", shape=(1, 34, 34, 34, 48), dtype=float32)
Tensor("G_com/concat_6:0", shape=(1, 34, 34, 34, 72), dtype=float32)
Tensor("G_com/Relu_1:0", shape=(1, 28, 28, 28, 48), dtype=float32)
Tensor("G_com/add_49:0", shape=(1, 28, 28, 28, 1), dtype=float32)
Tensor("big_delt

For improved training performance, we implement the gradient penalty method via the addition of a penalty term in the critic loss, as an alternative to the standard weight clipping, to enforce the Lipschitz-1 constraint on the critic. This is a requirement for computing the approximate Wasserstein distance. The Lipschitz constraint is enforced by penalizing the gradient norm for random samples $\hat{\mathbf{x}} \sim \mathbb{P}_{\hat{\mathbf{x}}}$, where $\hat{\mathbf{x}} = \epsilon \mathbf{x} + (1 - \epsilon) \tilde{\mathbf{x}}$ and $\epsilon$ is sampled randomly and uniformly, $\epsilon \in [0,1]$.

We calculate this set of random samples, pass it through the critic and then calculate the outputs gradient with respect to the input.

In [18]:
δ_hr_rotated = tf.slice(δ_hr_rotated, [0, pad_hr, pad_hr, pad_hr, 0], [-1, slice_size_hr-2*pad_hr, slice_size_hr-2*pad_hr, slice_size_hr-2*pad_hr, -1])

hat_δ_hr = ϵ_ * δ_hr_rotated + (1. - ϵ_) * generated_δ_hr

with tf.variable_scope("W") as scope:
    scope.reuse_variables()
    W_hat = tf.identity(W(hat_δ_hr, δ_ic_rotated, pad_hr, slice_size_hr), name="W_hat")
    
W_grad = tf.gradients(W_hat, hat_δ_hr, name="W_grad")

IC
Tensor("W_2/Slice:0", shape=(?, 28, 28, 28, 1), dtype=float32)
data
Tensor("W_2/concat:0", shape=(?, 28, 28, 28, ?), dtype=float32)
Tensor("W_2/LeakyRelu:0", shape=(?, 11, 11, 11, 8), dtype=float32)
Tensor("W_2/LeakyRelu_1:0", shape=(?, 7, 7, 7, 16), dtype=float32)


The loss function for the critic and the emulator can now be defined. We define a critic loss with the gradient penalty to be optimised, and a pure measure of the Wasserstein distance for tracking the distance. The emulator loss does not need the gradient penalty.

In [19]:
W_loss = W_gen - W_real
W_loss_g = tf.reduce_mean(W_loss + λ * (tf.norm(W_grad) - 1.)**2.)
W_loss_n = tf.identity(-tf.reduce_mean(W_loss), name="W_loss_n")
G_loss = tf.reduce_mean(- W_loss, name="G_loss")

### Use the Adam optimizer to train the super-resolution emulator and critic networks
Standard choice of hyperparameters for the Adam optimizer.

Since we want to only update the weights of the critic for the critic update and only update the weights of the emulator for the emulator update, we cycle through the gradient calculations and remove the superfluous weight update operations from the list of operations to be given to the optimizer.

In [20]:
α = 1e-4
β1 = 0.5
β2 = 0.999
W_opt = tf.train.AdamOptimizer(α, β1, β2)
W_grad = W_opt.compute_gradients(W_loss_g)
W_grad_ = []
for i in W_grad:
    if "W" in i[1].name:
        W_grad_.append(i)
W_train = W_opt.apply_gradients(W_grad_, name="W_train")
        
G_opt = tf.train.AdamOptimizer(α, β1, β2)
G_grad = G_opt.compute_gradients(G_loss)
G_grad_ = []
for i in G_grad:
    if "G" in i[1].name:
        G_grad_.append(i)
G_train = G_opt.apply_gradients(G_grad_, name="G_train")

Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


#### Create a session to launch the graph 

We now feed in the low- and high-resolution density fields and initial conditions to be stored in the graph so that they do not have to be passed during training.

In [21]:
config = tf.ConfigProto()
config.gpu_options.allow_growth=True
sess = tf.Session(config=config)
sess.run(tf.global_variables_initializer(), 
         feed_dict={"initialise_low_res_density:0": lr_final,
                    "initialise_high_res_density:0": hr_final,
                    "initialise_high_res_IC:0": hr_initial})

### Pre-training preparations
We choose a standard value for the arbitrary penalty coefficient, which has been shown to work well for a range of network architectures and data sets, and update the weights of the critic five times per single emulator update.

In [22]:
λ_value = 10.
n_critic = 10
epochs = 500000
W_loss_training = []
G_loss_training = []
W_loss_validation = []
G_loss_validation = []
m = 1
i_epoch = 1
saver = tf.train.Saver()

### Training routine
The training proceeds for a set number of weight updates, until an overall convergence of the emulator is achieved. The training rationale is to reduce the Wasserstein distance between the true and generated high-resolution density fields, conditional on the initial conditions, such that the emulator learns the correct mapping from low- to high-resolution density fields.

The training steps are as follows:
- The input to the emulator is randomly chosen and the corresponding true high-resolution density volume is selected;
- To encode some further symmetries through our training set, we also perform a rotation of the selected patches, thereby extracting the input 3D slice from a randomly oriented region, and/or randomly mirrored along the three axes;
- The initial training step involves the optimization of the weights of the critic network to minimize the augmented loss function (including the gradient penalty), while concurrently freezing the parameters of the emulator;
- The weights of the critic must be updated $n_{\mathrm{critic}}$ times, where $n_\mathrm{critic}$ is sufficient for the critic to converge;
- In the subsequent step, the critic weights are temporarily anchored, and the emulator parameters are adjusted;
- The emulator employs the gradient of the Wasserstein loss function w.r.t its parameters for training;
- The training routine then proceeds in iterative fashion, until an overall convergence of the emulator is achieved.

We use a $512^3$ simulation box for training, where we use a large portion of the box for training and the remaining section for validation, such that we utilize non-mutual parts of the box for the training and validation set.

In [None]:
tq = tqdm.trange(epochs, leave = True, desc = "Epochs")
for e in tq:
    for t in range(n_critic):
        i = np.random.randint(0, 25, m)
        z = np.random.randint(0, 2, m)
        ind_value_1 = np.random.randint(pad, lr_resolution - (slice_size_lr - pad), m)
        ind_value_2 = np.random.randint(pad, lr_resolution - (slice_size_lr - pad), m)
        ind_value_3 = np.random.randint(pad, lr_resolution - slice_size_lr - (slice_size_lr - pad), m)
        ϵ_value = np.random.uniform(0, 1, m)
        sess.run("W_train",
                 feed_dict={"ind_lr:0": np.array([ind_value_1, ind_value_2, ind_value_3]).T,
                            "ind_hr:0": np.array([2*ind_value_1, 2*ind_value_2, 2*ind_value_3]).T,
                            "epsilon:0": ϵ_value,
                            "lambda:0": λ_value,
                            "rotation:0": i,
                            "inversion:0": z})

    # W training loss
    i = np.random.randint(0, 25, m)
    z = np.random.randint(0, 2, m)
    ind_value_1 = np.random.randint(pad, lr_resolution - (slice_size_lr - pad), m)
    ind_value_2 = np.random.randint(pad, lr_resolution - (slice_size_lr - pad), m)
    ind_value_3 = np.random.randint(pad, lr_resolution - slice_size_lr - (slice_size_lr - pad), m)
    ϵ_value = np.random.uniform(0, 1, m)
    W_loss_training.append(sess.run("W_loss_n:0", 
                                    feed_dict={"ind_lr:0": np.array([ind_value_1, ind_value_2, ind_value_3]).T,
                                               "ind_hr:0": np.array([2*ind_value_1, 2*ind_value_2, 2*ind_value_3]).T,
                                               "epsilon:0": ϵ_value,
                                               "lambda:0": λ_value,
                                               "rotation:0": i,
                                               "inversion:0": z}))
   
    # W validation loss
    i = np.random.randint(0, 25, m)
    z = np.random.randint(0, 2, m)
    ind_value_1 = np.random.randint(pad, lr_resolution - (slice_size_lr - pad), m)
    ind_value_2 = np.random.randint(pad, lr_resolution - (slice_size_lr - pad), m)
    ind_value_3 = np.array([lr_resolution - (slice_size_lr - pad) for i in range(m)])
    ϵ_value = np.random.uniform(0, 1, m)
    W_loss_validation.append(sess.run("W_loss_n:0", 
                                      feed_dict={"ind_lr:0": np.array([ind_value_1, ind_value_2, ind_value_3]).T,
                                                 "ind_hr:0": np.array([2*ind_value_1, 2*ind_value_2, 2*ind_value_3]).T,
                                                 "epsilon:0": ϵ_value,
                                                 "lambda:0": λ_value,
                                                 "rotation:0": i,
                                                 "inversion:0": z}))

    # G training loss
    i = np.random.randint(0, 25, m)
    z = np.random.randint(0, 2, m)
    ind_value_1 = np.random.randint(pad, lr_resolution - (slice_size_lr - pad), m)
    ind_value_2 = np.random.randint(pad, lr_resolution - (slice_size_lr - pad), m)
    ind_value_3 = np.random.randint(pad, lr_resolution - slice_size_lr - (slice_size_lr - pad), m)
    _, G_loss_temp = sess.run(["G_train", "G_loss:0"],
                              feed_dict={"ind_lr:0": np.array([ind_value_1, ind_value_2, ind_value_3]).T,
                                         "ind_hr:0": np.array([2*ind_value_1, 2*ind_value_2, 2*ind_value_3]).T,
                                         "rotation:0": i,
                                         "inversion:0": z})
    G_loss_training.append(G_loss_temp)
   
    # G validation loss
    i = np.random.randint(0, 25, m)
    z = np.random.randint(0, 2, m)
    ind_value_1 = np.random.randint(pad, lr_resolution - (slice_size_lr - pad), m)
    ind_value_2 = np.random.randint(pad, lr_resolution - (slice_size_lr - pad), m)
    ind_value_3 = np.array([lr_resolution - (slice_size_lr - pad) for i in range(m)])
    G_loss_validation.append(sess.run("G_loss:0", 
                                      feed_dict={"ind_lr:0": np.array([ind_value_1, ind_value_2, ind_value_3]).T,
                                                 "ind_hr:0": np.array([2*ind_value_1, 2*ind_value_2, 2*ind_value_3]).T, 
                                                 "rotation:0": i,
                                                 "inversion:0": z}))
    
    if i_epoch%100000 == 0:
        m *= 2
    i_epoch += 1
    
    tq.set_postfix(W_training_loss = W_loss_training[-1], 
                   G_training_loss = G_loss_training[-1], 
                   W_validation_loss = W_loss_validation[-1], 
                   G_validation_loss = G_loss_validation[-1])

### Save losses

In [None]:
np.savez("predictions/training_validation_loss.npz", W_loss_training=W_loss_training, G_loss_training=G_loss_training, W_loss_validation=W_loss_validation, G_loss_validation=G_loss_validation)