In [None]:
from astropy.nddata.utils import Cutout2D
from astropy.io import fits
import fitsio
from astropy import table
import numpy as np
import matplotlib.pyplot as plt
from astropy.visualization import (ZScaleInterval, ImageNormalize)
import tensorflow as tf
from tensorflow import keras

In [None]:
# Load the FITS image tile into an array (presumably the r-band tile, per the
# ".r.fits" filename). NOTE(review): absolute user path; consider a config cell.
r_image_path = "/home/eyvorch9/projects/rrg-kyi/astro/cfis/W3/CFIS.264.282.r.fits"
r_image = fitsio.read(r_image_path)

In [None]:
# Source catalogue for the same tile, stored in SExtractor ASCII format.
catalog_path = "/home/eyvorch9/projects/rrg-kyi/astro/cfis/W3/CFIS.264.282.r.cat"
image_cat = table.Table.read(catalog_path, format="ascii.sextractor")

In [None]:
# Cut a CUTOUT_SIZE x CUTOUT_SIZE stamp centred on every catalogue source.
# Stamps that run off the image edge are zero-padded (mode="partial").
CUTOUT_SIZE = 32
# SExtractor's X_IMAGE/Y_IMAGE follow the FITS convention (first pixel centre
# at 1.0), whereas Cutout2D expects 0-based pixel positions, hence the "- 1".
# The result gets a trailing channel axis -> (n_sources, 32, 32, 1). The
# explicit reshape also keeps the 4-D shape when the catalogue is empty.
sources = np.asarray([
    Cutout2D(r_image, (x - 1, y - 1), CUTOUT_SIZE, mode="partial", fill_value=0).data
    for x, y in zip(image_cat["X_IMAGE"], image_cat["Y_IMAGE"])
]).reshape(-1, CUTOUT_SIZE, CUTOUT_SIZE, 1)

In [None]:
# Sanity check: expect (n_sources, 32, 32, 1).
sources.shape

In [None]:
# Min-max scale every cutout to [0, 1] using a single global min/max shared
# by all cutouts.
sources_min = np.min(sources)
sources_range = np.max(sources) - sources_min
if sources_range > 0:
    sources_norm = (sources - sources_min) / sources_range
else:
    # A constant array would give 0/0 = NaN everywhere (and a NaN training
    # loss); map it to zeros instead.
    sources_norm = np.zeros_like(sources, dtype=np.float32)

In [None]:
def create_autoencoder(shape, num_out=1, num_z=128):
    """Build a dense-bottleneck autoencoder for 32x32 image cutouts.

    Encoder: flatten the input, apply two LeakyReLU(0.1) dense layers
    (512 -> 256), then two linear heads of width ``num_out`` and ``num_z``.
    Their outputs are concatenated into the latent code.

    Decoder: dense layers (num_z + num_out -> 256 -> 512 -> 8192), reshaped to
    2x2x2048, then four unpadded transposed convolutions that grow the map
    2 -> 4 -> 8 -> 16 -> 32 pixels per side.

    Parameters
    ----------
    shape : tuple of int
        Input shape ``(height, width, channels)``. The decoder's transposed
        convolution stack always produces a 32x32 map, so height and width
        must both be 32. The output has the same number of channels.
    num_out : int, optional
        Width of the first latent head (default 1).
    num_z : int, optional
        Width of the second latent head (default 128).

    Returns
    -------
    keras.Model
        Uncompiled model mapping inputs of ``shape`` to outputs of ``shape``.

    Raises
    ------
    ValueError
        If ``shape`` is not 3-dimensional or its spatial size is not 32x32.
    """
    shape = tuple(shape)
    if len(shape) != 3:
        raise ValueError(f"expected shape (height, width, channels), got {shape}")
    if shape[:2] != (32, 32):
        raise ValueError(
            f"decoder is hard-wired to reconstruct 32x32 images, got {shape[:2]}"
        )
    channels = shape[2]

    input_img = keras.Input(shape)
    # Flatten (not GlobalAveragePooling2D): pooling a 1-channel image leaves
    # only its mean value, so the encoder would never see spatial structure.
    features = keras.layers.Flatten()(input_img)
    x = keras.layers.Dense(512, activation=keras.layers.LeakyReLU(alpha=0.1))(features)
    x = keras.layers.Dense(256, activation=keras.layers.LeakyReLU(alpha=0.1))(x)

    x_out = keras.layers.Dense(num_out)(x)
    z_out = keras.layers.Dense(num_z)(x)
    # A Keras layer (rather than tf.concat) keeps the functional graph valid
    # on symbolic Keras tensors.
    encoded = keras.layers.Concatenate(axis=1)([x_out, z_out])

    # TODO: provisional decoder; expected to be redesigned.
    y = keras.layers.Dense(num_z + num_out)(encoded)
    y = keras.layers.Dense(256, activation=keras.layers.LeakyReLU(alpha=0.1))(y)
    y = keras.layers.Dense(512, activation=keras.layers.LeakyReLU(alpha=0.1))(y)
    y = keras.layers.Dense(2 * 2 * 2048)(y)
    y = keras.layers.Reshape([2, 2, 2048])(y)

    # 'valid' transposed convolutions (stride 1) grow each side by
    # kernel_size - 1: 2 -> 4 -> 8 -> 16 -> 32.
    y = keras.layers.Conv2DTranspose(512, 3)(y)
    y = keras.layers.Conv2DTranspose(128, 5)(y)
    y = keras.layers.Conv2DTranspose(64, 9)(y)
    decoded = keras.layers.Conv2DTranspose(channels, 17)(y)

    return keras.Model(input_img, decoded)

In [None]:
# One cutout's (height, width, channels); the channel count is fixed to 1.
input_shape = (sources.shape[1], sources.shape[2], 1)
autoencoder = create_autoencoder(input_shape)
# Reconstruction objective: mean squared error, optimised with Adam.
autoencoder.compile(loss="mse", optimizer="adam")

In [None]:
# Print the layer stack with output shapes and parameter counts.
autoencoder.summary()

In [None]:
# Train the autoencoder to reproduce its own (normalised) input.
# NOTE(review): validation_split takes the LAST 20% of the arrays (no
# shuffling before the split), i.e. the tail of the catalogue order.
history_01 = autoencoder.fit(
    x=sources_norm,
    y=sources_norm,
    epochs=10,
    validation_split=0.2,
)