In [1]:
import tensorflow as tf
tf.enable_eager_execution()

In [2]:
from scipy.signal import fftconvolve
from astropy.io import fits
import numpy as np
import math
from pathlib import Path
from io import BytesIO
import matplotlib.pyplot as plt
%matplotlib inline

In [13]:
from vacuum.model import create_model
from vacuum.io_ import load_data, preprocess, deprocess, fits_open
from vacuum.util import shift

In [23]:
INPUT_DIR = "/scratch/gijs/datasets/meerkat16"

SEPARABLE_CONV = False   # use separable convolutions in the generator
NGF = 64                 # number of generator filters in first conv layer
NDF = 64                 # number of discriminator filters in first conv laye
BATCH_SIZE = 1           # number of images in batch
CROP_SIZE = 256
EPS = 1e-12
FLIP = False              # flip images horizontally during training
SCALE_SIZE = 256         # scale images to this size before cropping to 256x256
MAX_EPOCHS = 1           # number of training epochs
LR = 0.0002              # initial learning rate for adam
BETA1 = 0.5              # momentum term of adam
L1_WEIGHT = 100.0        # weight on L1 term for generator gradient
GAN_WEIGHT = 1.0         # weight on GAN term for generator gradient
INPUT_MULTIPLY = 1.0     # Use this to scale in the max input fluxes to about 5 Jy 
DATA_START = 1840
DATA_END = 1899

In [24]:
batch, count = load_data(INPUT_DIR, CROP_SIZE, FLIP, SCALE_SIZE, MAX_EPOCHS,
                         BATCH_SIZE, start=DATA_START, end=DATA_END)
steps_per_epoch = int(math.ceil(count / BATCH_SIZE))
iter = batch.make_one_shot_iterator()
index, min_flux, max_flux, psf, dirty, skymodel = iter.get_next()
print("examples count = %d" % count)

examples count = 59


In [25]:
scaled_dirty = preprocess(dirty, min_flux, max_flux)
scaled_skymodel = preprocess(skymodel, min_flux, max_flux)
scaled_psf = preprocess(psf, min_flux, max_flux)

In [33]:
deprocessed_output = deprocess(skymodel, min_flux, max_flux)
deprocessed_dirty = deprocess(dirty, min_flux, max_flux)

In [34]:
shifted = shift(psf, y=0, x=-1)
shifted.shape

TensorShape([Dimension(1), Dimension(256), Dimension(256), Dimension(1)])

In [35]:
filter_ = tf.expand_dims(tf.expand_dims(tf.squeeze(shifted), 2), 3)
filter_.shape

TensorShape([Dimension(256), Dimension(256), Dimension(1), Dimension(1)])

In [37]:
convolved = tf.nn.conv2d(deprocessed_output, filter_, [1, 1, 1, 1], "SAME")
convolved.shape

TensorShape([Dimension(1), Dimension(256), Dimension(256), Dimension(1)])

In [39]:
gen_loss_RES = tf.tensordot(deprocessed_output, (-2 * deprocessed_dirty + convolved ), (1, 2))
gen_loss_RES.shape

TensorShape([Dimension(1), Dimension(256), Dimension(1), Dimension(1), Dimension(256), Dimension(1)])

In [44]:
tf.reduce_sum(tf.multiply(deprocessed_output, convolved -2 * deprocessed_dirty))

<tf.Tensor: id=700, shape=(), dtype=float32, numpy=109013.24>