In [1]:
from astropy.io import fits
import numpy as np
import tensorflow as tf
from pathlib import Path
tf.enable_eager_execution()

In [21]:
SEPARABLE_CONV = False   # use separable convolutions in the generator
NGF = 64  # number of generator filters in first conv layer

In [33]:
types_ = ("psf", "dirty", "skymodel", "image", "model")
fits_open = lambda x: fits.open(str(x))[0].data.squeeze()
def dataset_generator(path="/home/gijs/Work/spiel/runs/first_kat7_2018-05-25/results"):
    p = Path(path)
    sorted_ = zip(*(sorted(p.glob(f"*-{t}.fits")) for t in types_))
    for i in sorted_:
        yield tuple(fits_open(j)[np.newaxis, np.newaxis] for j in i)

In [34]:
ds = tf.data.Dataset.from_generator(dataset_generator, output_types=(tf.float32,) * 5)
#ds = ds.batch(32)

In [35]:
it = ds.make_one_shot_iterator()
psf, dirty, skymodel, image, model = next(it)

In [6]:
def gen_conv(batch_input, out_channels):
    initializer = tf.random_normal_initializer(0, 0.02)
    if SEPARABLE_CONV:
        return tf.layers.separable_conv2d(batch_input, out_channels, kernel_size=4, strides=(2, 2), padding="same",
                                          depthwise_initializer=initializer, pointwise_initializer=initializer)
    else:
        return tf.layers.conv2d(batch_input, out_channels, kernel_size=4, strides=(2, 2), padding="same",
                                kernel_initializer=initializer)


def gen_deconv(batch_input, out_channels):
    initializer = tf.random_normal_initializer(0, 0.02)
    if SEPARABLE_CONV:
        _b, h, w, _c = batch_input.shape
        resized_input = tf.image.resize_images(batch_input, [h * 2, w * 2],
                                               method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
        return tf.layers.separable_conv2d(resized_input, out_channels, kernel_size=4, strides=(1, 1), padding="same",
                                          depthwise_initializer=initializer, pointwise_initializer=initializer)
    else:
        return tf.layers.conv2d_transpose(batch_input, out_channels, kernel_size=4, strides=(2, 2), padding="same",
                                          kernel_initializer=initializer)

In [19]:

def lrelu(x, a):
    with tf.name_scope("lrelu"):
        # adding these together creates the leak part and linear part
        # then cancels them out by subtracting/adding an absolute value term
        # leak: a*x/2 - a*abs(x)/2
        # linear: x/2 + abs(x)/2

        # this block looks like it has 2 inputs on the graph unless we do this
        x = tf.identity(x)
        return (0.5 * (1 + a)) * x + (0.5 * (1 - a)) * tf.abs(x)
    
def create_generator(generator_inputs, generator_outputs_channels):
    layers = []

    # encoder_1: [batch, 512, 512, in_channels] => [batch, 256, 256, ngf]
    with tf.variable_scope("encoder_1"):
        output = gen_conv(generator_inputs, NGF)
        layers.append(output)

    layer_specs = [
        NGF * 2,  # encoder_2: [batch, 256, 256, ngf] => [batch, 128, 128, ngf * 2]
        NGF * 4,  # encoder_2: [batch, 128, 128, ngf] => [batch, 64, 64, ngf * 4]
        NGF * 8,  # encoder_3: [batch, 64, 64, ngf * 2] => [batch, 32, 32, ngf * 8]
        NGF * 8,  # encoder_4: [batch, 32, 32, ngf * 4] => [batch, 16, 16, ngf * 8]
        NGF * 8,  # encoder_5: [batch, 16, 16, ngf * 8] => [batch, 8, 8, ngf * 8]
        NGF * 8,  # encoder_6: [batch, 8, 8, ngf * 8] => [batch, 4, 4, ngf * 8]
        NGF * 8,  # encoder_7: [batch, 4, 4, ngf * 8] => [batch, 2, 2, ngf * 8]
        NGF * 8,  # encoder_8: [batch, 2, 2, ngf * 8] => [batch, 1, 1, ngf * 8]
    ]

    for out_channels in layer_specs:
        with tf.variable_scope("encoder_%d" % (len(layers) + 1)):
            rectified = lrelu(layers[-1], 0.2)
            # [batch, in_height, in_width, in_channels] => [batch, in_height/2, in_width/2, out_channels]
            convolved = gen_conv(rectified, out_channels)
            output = batchnorm(convolved)
            layers.append(output)

    layer_specs = [
        (NGF * 8, 0.5),  # decoder_8: [batch, 1, 1, ngf * 8] => [batch, 2, 2, ngf * 8 * 2]
        (NGF * 8, 0.5),  # decoder_7: [batch, 2, 2, ngf * 8 * 2] => [batch, 4, 4, ngf * 8 * 2]
        (NGF * 8, 0.5),  # decoder_6: [batch, 4, 4, ngf * 8 * 2] => [batch, 8, 8, ngf * 8 * 2]
        (NGF * 8, 0.0),  # decoder_5: [batch, 8, 8, ngf * 8 * 2] => [batch, 16, 16, ngf * 8 * 2]
        (NGF * 8, 0.0),  # decoder_4: [batch, 16, 16, ngf * 8 * 2] => [batch, 32, 32, ngf * 4 * 2]
        (NGF * 4, 0.0),  # decoder_3: [batch, 32, 32, ngf * 4 * 2] => [batch, 64, 64, ngf * 2 * 2]
        (NGF * 2, 0.0),  # decoder_2: [batch, 64, 64, ngf * 2 * 2] => [batch, 128, 128, ngf * 2]
        (NGF * 2, 0.0),  # decoder_2: [batch, 128, 128, ngf * 2 * 2] => [batch, 256, 256, ngf * 2]
    ]

    num_encoder_layers = len(layers)
    for decoder_layer, (out_channels, dropout) in enumerate(layer_specs):
        skip_layer = num_encoder_layers - decoder_layer - 1
        with tf.variable_scope("decoder_%d" % (skip_layer + 1)):
            if decoder_layer == 0:
                # first decoder layer doesn't have skip connections
                # since it is directly connected to the skip_layer
                input = layers[-1]
            else:
                input = tf.concat([layers[-1], layers[skip_layer]], axis=3)

            rectified = tf.nn.relu(input)
            # [batch, in_height, in_width, in_channels] => [batch, in_height*2, in_width*2, out_channels]
            output = gen_deconv(rectified, out_channels)
            output = batchnorm(output)

            if dropout > 0.0:
                output = tf.nn.dropout(output, keep_prob=1 - dropout)

            layers.append(output)

    # decoder_1: [batch, 256, 256, ngf * 2] => [batch, 512, 512, generator_outputs_channels (1)]
    with tf.variable_scope("decoder_1"):
        input = tf.concat([layers[-1], layers[0]], axis=3)
        rectified = tf.nn.relu(input)
        output = gen_deconv(rectified, generator_outputs_channels)
        #output = tf.tanh(output)
        layers.append(output)

    return layers[-1]

def batchnorm(inputs):
    return tf.layers.batch_normalization(inputs, axis=3, epsilon=1e-5, momentum=0.1, training=True,
                                         gamma_initializer=tf.random_normal_initializer(1.0, 0.02))

In [38]:
with tf.variable_scope("generator"):
    out_channels = int(skymodel.get_shape()[-1])
    outputs = create_generator(dirty, out_channels)

InvalidArgumentError: ConcatOp : Dimensions of inputs should match: shape[0] = [1,2,2,512] vs. shape[1] = [1,1,2,512] [Op:ConcatV2] name: concat