In [1]:
import tensorflow as tf
tf.enable_eager_execution()

In [2]:
from scipy.signal import fftconvolve
from astropy.io import fits
import numpy as np
import math
from pathlib import Path
from io import BytesIO
import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
from vacuum.model import create_model
from vacuum.io import load_data, preprocess, deprocess, fits_open

In [52]:
# --- Data --------------------------------------------------------------------
# NOTE(review): absolute, machine-specific path; adjust for your environment.
INPUT_DIR = "/scratch/datasets/kat7_2000/"
DATA_START = 1840        # first example index passed to load_data
DATA_END = 1899          # end index passed to load_data (1840..1899 yielded 59 examples, so it looks exclusive)
INPUT_MULTIPLY = 1.0     # input flux multiplier; tune so the max input flux is about 5 Jy

# --- Image geometry & augmentation ---------------------------------------------
SCALE_SIZE = 256         # scale images to this size before cropping to CROP_SIZE x CROP_SIZE
CROP_SIZE = 256          # side length of the square crop
FLIP = False             # flip images horizontally during training

# --- Model ---------------------------------------------------------------------
SEPARABLE_CONV = False   # use separable convolutions in the generator
NGF = 64                 # number of generator filters in the first conv layer
NDF = 64                 # number of discriminator filters in the first conv layer

# --- Training ------------------------------------------------------------------
BATCH_SIZE = 5           # number of images per batch
MAX_EPOCHS = 1           # number of training epochs
LR = 0.0002              # initial learning rate for Adam
BETA1 = 0.5              # Adam momentum term (beta1)
L1_WEIGHT = 100.0        # weight on the L1 term of the generator gradient
GAN_WEIGHT = 1.0         # weight on the GAN term of the generator gradient
EPS = 1e-12              # small epsilon, presumably for numerical stability — TODO confirm where it is used

In [53]:
batch, count = load_data(INPUT_DIR, CROP_SIZE, FLIP, SCALE_SIZE, MAX_EPOCHS,
                         BATCH_SIZE, start=DATA_START, end=DATA_END)
steps_per_epoch = int(math.ceil(count / BATCH_SIZE))
# Named `batch_iterator` rather than `iter` so the builtin `iter()` is not shadowed
# for the rest of the notebook.
batch_iterator = batch.make_one_shot_iterator()
# A single get_next() call: the cells below all work on this one batch.
index, min_flux, max_flux, psf, dirty, skymodel = batch_iterator.get_next()
print("examples count = %d" % count)

examples count = 59


In [54]:
# Per-image flux range of the dirty images, used to scale all three inputs.
# Defined here (and not only in a later cell) so this cell works after
# Restart & Run All instead of raising NameError on `min_` / `max_`.
min_ = tf.reduce_min(dirty, axis=(1, 2, 3))
max_ = tf.reduce_max(dirty, axis=(1, 2, 3))
# NOTE(review): min_/max_ have shape (batch,). Against a 4-D image batch, plain
# elementwise ops broadcast them over the *last* axis, not the batch axis, unless
# preprocess/deprocess reshape them internally. The conv2d error further down
# ("input depth ... 5 vs 257", with 5 == BATCH_SIZE) is consistent with that;
# compare the explicit `max_[:, None, None, None]` used when computing `result`.
# TODO confirm against vacuum.io.preprocess.
scaled_dirty = preprocess(dirty, min_, max_)
scaled_skymodel = preprocess(skymodel, min_, max_)
scaled_psf = preprocess(psf, min_, max_)

In [55]:
# Invert the scaling with the same per-image min/max so the round trip can be checked.
deprocessed_dirty, deprocessed_skymodel, deprocessed_psf = (
    deprocess(scaled, min_, max_)
    for scaled in (scaled_dirty, scaled_skymodel, scaled_psf)
)

In [63]:
# Round-trip check: per-image peak of the raw dirty images vs. after preprocess -> deprocess.
(
    tf.reduce_max(dirty, axis=[1, 2, 3]),
    tf.reduce_max(deprocessed_dirty, axis=[1, 2, 3]),
)

(<tf.Tensor: id=644, shape=(5,), dtype=float32, numpy=
 array([2.3067167, 1.6453992, 1.3210348, 1.47137  , 1.7186923],
       dtype=float32)>, <tf.Tensor: id=646, shape=(5,), dtype=float32, numpy=
 array([2.3067167, 1.6453991, 1.3210349, 1.47137  , 1.7186924],
       dtype=float32)>)

In [35]:
# Per-image flux range of the dirty images: reduce over every axis except the leading batch axis.
min_, max_ = (
    tf.reduce_min(dirty, axis=[1, 2, 3]),
    tf.reduce_max(dirty, axis=[1, 2, 3]),
)

In [37]:
# `max_flux` as returned by the data iterator vs. the peak recomputed from `dirty`.
(max_flux, max_)

(<tf.Tensor: id=335, shape=(5,), dtype=float32, numpy=
 array([2.3067167, 1.6453992, 1.3210348, 1.47137  , 1.7186923],
       dtype=float32)>, <tf.Tensor: id=353, shape=(5,), dtype=float32, numpy=
 array([2.3067167, 1.6453992, 1.3210348, 1.47137  , 1.7186923],
       dtype=float32)>)

In [25]:
# NOTE(review): the output below comes from execution 25, before `max_` was
# recomputed in execution 35. It is stale and differs from the `max_` values
# shown by the previous cell. Re-run this cell to refresh it, or drop it as redundant.
max_

<tf.Tensor: id=167, shape=(5,), dtype=float32, numpy=array([2.34, 1.61, 1.17, 1.4 , 1.69], dtype=float32)>

In [38]:
# Normalise each dirty image by its own peak: reshape max_ from (batch,) to
# (batch, 1, 1, 1) so it broadcasts over the remaining image axes.
result = dirty / tf.reshape(max_, [-1, 1, 1, 1])

In [44]:
# After peak normalisation each image should top out at exactly 1.0; spot-check batch index 4.
np.max(result.numpy()[4])

1.0

In [10]:
def render(a, imgdata, title, cmap='cubehelix'):
    """Draw ``imgdata`` on axes ``a`` as a pseudocolour plot with a colourbar and a title.

    Args:
        a: matplotlib Axes to draw on.
        imgdata: 2-D array of pixel values (``pcolor`` needs 2-D input).
        title: title for the axes.
        cmap: colormap name; defaults to 'cubehelix' as before.

    Returns:
        The mappable returned by ``a.pcolor``. The previous version returned
        ``None``; existing callers ignore the return value.
    """
    mappable = a.pcolor(imgdata, cmap=cmap)
    # Attach the colourbar to the figure that owns `a`. The previous version used
    # a notebook-global `f`, which raised NameError before any figure was created
    # and targeted the wrong figure whenever `a` belonged to another one.
    a.figure.colorbar(mappable, ax=a)
    a.set_title(title)
    return mappable

In [11]:
def shift(i, x=0, y=0):
    """Zero-pad ``i`` so its content is offset by (``y``, ``x``) pixels on a larger canvas.

    A positive offset inserts zero rows/columns before the content (top/left).
    A negative offset keeps the content at the origin and grows the canvas after
    it (bottom/right). Either way the result is |y| rows taller and |x| columns wider.

    Axes 1 and 2 of ``i`` are used as height and width. This assumes a 4-D
    (batch, height, width, channels) tensor with a fully known static shape
    (TODO confirm for all callers).
    """
    dims = i.shape.as_list()
    pad_top = y if y > 0 else 0
    pad_left = x if x > 0 else 0
    return tf.image.pad_to_bounding_box(
        i,
        pad_top,
        pad_left,
        dims[1] + abs(y),
        dims[2] + abs(x),
    )

In [12]:
def shifted_convolve(skymodel, psf, y, x, dirty=None):
    """Convolve ``skymodel`` with ``psf`` shifted by (``y``, ``x``) and return the residual.

    Args:
        skymodel: 4-D input batch passed to ``tf.nn.conv2d``.
        psf: PSF tensor; it is shifted with ``shift`` and then used as the conv2d filter.
        y, x: pixel offsets forwarded to ``shift``.
        dirty: image(s) the convolution is subtracted from. When ``None`` (the default),
            the notebook-global ``deprocessed_dirty`` is used. This matches the previous
            behaviour, which read that global implicitly.

    Returns:
        ``(convolved, residual)`` with ``residual = dirty - convolved``.
    """
    if dirty is None:
        # Backward-compatible fallback: earlier versions always read this global.
        dirty = deprocessed_dirty
    shifted = shift(psf, y=y, x=x)
    # NOTE(review): tf.nn.conv2d expects the filter laid out as
    # [filter_height, filter_width, in_channels, out_channels], but `shifted`
    # presumably still has the PSF's batch-first image layout. That mismatch is what
    # raises "input depth must be evenly divisible by filter depth: 5 vs 257" when
    # this is called. TODO: convolve each image with its own PSF, e.g. reshape a single
    # PSF to [H, W, 1, 1] per image, once the tensor shapes are confirmed.
    # Also note that conv2d computes a cross-correlation. It equals a convolution only
    # when the PSF is symmetric under a 180-degree rotation.
    convolved = tf.nn.conv2d(skymodel, shifted, [1, 1, 1, 1], "SAME")
    residual = dirty - convolved
    return convolved, residual

In [14]:
# Convolve the sky model with the PSF shifted by (-1, -1) and compare against the dirty image.
convolved, residual = shifted_convolve(deprocessed_skymodel, deprocessed_psf, y=-1, x=-1)

# Each panel shows one image: the first element of the batch, first channel.
# The tensors are 4-D (as tf.nn.conv2d requires), and `.squeeze()` alone would
# leave a 3-D stack of BATCH_SIZE images, which pcolor rejects.
f, ((a1, a2), (a3, a4)) = plt.subplots(2, 2, figsize=(16, 14))
# NOTE(review): this panel previously rendered `descaled`, which is not defined in
# any cell shown here (NameError on a fresh kernel). It now shows the deprocessed
# sky model that is fed into the convolution; confirm this was the intent.
render(a1, deprocessed_skymodel.numpy()[0, :, :, 0], 'deprocessed_skymodel')
render(a2, deprocessed_dirty.numpy()[0, :, :, 0], 'deprocessed_dirty')
render(a3, convolved.numpy()[0, :, :, 0], 'convolved')
render(a4, residual.numpy()[0, :, :, 0], 'residual')
plt.show()


InvalidArgumentError: input depth must be evenly divisible by filter depth: 5 vs 257 [Op:Conv2D]

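# Making sure the PSF shift is sound

Compare the residuals (dirty image minus the sky model convolved with the shifted PSF) for the unshifted PSF and for ±1-pixel shifts in each direction.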

In [None]:
# PSF offsets (y, x) to compare; (0, 0) is the unshifted PSF. Presumably the best
# alignment leaves the smallest residual.
psf_shifts = [(-1, -1), (1, 1), (0, 0), (-1, 1), (1, -1)]
residuals = [
    shifted_convolve(deprocessed_skymodel, deprocessed_psf, y=dy, x=dx)[1]
    for dy, dx in psf_shifts
]
# Keep the individual names available for any later cell that refers to them.
residual1, residual2, residual3, residual4, residual5 = residuals

f, axes = plt.subplots(3, 2, figsize=(16, 14))
for n, (ax, res, (dy, dx)) in enumerate(zip(axes.flat, residuals, psf_shifts), start=1):
    # First image in the batch, first channel: pcolor needs a 2-D array.
    render(ax, res.numpy()[0, :, :, 0], 'residual%d (y=%d, x=%d)' % (n, dy, dx))
axes[-1, -1].axis('off')  # 5 shifts on a 3x2 grid: hide the unused panel
plt.show()