In [None]:
from astropy.nddata.utils import Cutout2D
from esutil import wcsutil
from astropy.io import fits
from astropy import table
import numpy as np
import matplotlib.pyplot as plt
from astropy.visualization import (ZScaleInterval, ImageNormalize)
from tensorflow import keras

In [None]:
# All data products for the CFIS.264.282.r tile share this path prefix.
# NOTE(review): absolute, user-specific path — consider making it configurable.
CFIS_TILE_PREFIX = "/home/eyvorch9/projects/rrg-kyi/astro/cfis/W3/CFIS.264.282.r"
r_image = fits.open(f"{CFIS_TILE_PREFIX}.fits", memmap=True)
r_header = r_image[0].header
# fpack-compressed (.fz) weight map; its image is read from HDU 1 later in the notebook.
r_weights = fits.open(f"{CFIS_TILE_PREFIX}.weight.fits.fz", memmap=True)

In [None]:
# World-coordinate solution of the image tile, used to map catalog RA/Dec onto pixels.
w = wcsutil.WCS(r_image[0].header)

In [None]:
# SExtractor detection catalog for the same tile.
image_cat = table.Table.read(
    "/home/eyvorch9/projects/rrg-kyi/astro/cfis/W3/CFIS.264.282.r.cat",
    format="ascii.sextractor",
)

In [None]:
# Inspect the catalog loaded above (columns and rows).
image_cat

In [None]:
CUTOUT_SIZE = 64  # side length (pixels) of the square stamps fed to the autoencoders
# esutil's WCS works in the FITS convention (1-based pixels, like CRPIX), while Cutout2D
# expects 0-based array positions. NOTE(review): confirm against the esutil docs.
FITS_PIXEL_OFFSET = 1


def _minmax_normalize(arr):
    """Rescale ``arr`` to [0, 1] using its own min and max.

    A constant array (zero range, e.g. a stamp that is entirely fill value) maps to
    all zeros instead of producing NaN/inf from a division by zero.
    """
    shifted = arr - np.min(arr)
    span = np.max(shifted)
    if span == 0:
        return np.zeros_like(shifted)
    return shifted / span


# Hoisted out of the loop: the HDU data arrays do not change per source.
image_data = r_image[0].data
weight_data = r_weights[1].data
sources = []  # min-max normalized image cutouts
weights = []  # min-max normalized weight-map cutouts, paired index-by-index with `sources`
for ra, dec in zip(image_cat["ALPHA_J2000"], image_cat["DELTA_J2000"]):  # centers of sources
    x, y = w.sky2image(ra, dec)
    position = (x - FITS_PIXEL_OFFSET, y - FITS_PIXEL_OFFSET)
    cutout = Cutout2D(image_data, position, CUTOUT_SIZE, mode="partial", fill_value=0).data
    weight_cutout = Cutout2D(weight_data, position, CUTOUT_SIZE, mode="partial", fill_value=0).data
    sources.append(_minmax_normalize(cutout))
    # Each weight stamp is normalized by its OWN range (it used to be divided by the image stamp's range).
    weights.append(_minmax_normalize(weight_cutout))
# Add a trailing channel axis: (n_sources, size, size, 1).
sources = np.array(sources)[..., np.newaxis]
weights = np.array(weights)[..., np.newaxis]
r_image.close()
r_weights.close()

In [None]:
# Hold out the last 20% of catalog entries for testing (catalog order is kept, no shuffling).
TRAIN_FRACTION = 0.8
threshold = int(TRAIN_FRACTION * len(sources))
sources_train = sources[:threshold]
sources_test = sources[threshold:]
# Slice the weight maps with the same index so each weight stays paired with its cutout.
# (These were previously sliced from `sources`, so the images were used as sample weights.)
weights_train = weights[:threshold]
weights_test = weights[threshold:]
assert len(weights_train) == len(sources_train) and len(weights_test) == len(sources_test)

In [None]:
def create_autoencoder1(shape):
    """Build a convolutional autoencoder with a dense bottleneck.

    Encoder: Conv2D(16, ReLU) -> 2x2 max-pool -> Conv2D(32, ReLU) -> 2x2 max-pool,
    flatten, then a 128-unit Dense layer (no activation given). Decoder: a Dense layer
    re-expands to an (H/4, W/4, 4) grid, then two (2x upsample -> Conv2DTranspose(ReLU))
    stages restore (H, W), and a linear 1-channel Conv2D produces the reconstruction.

    Args:
        shape: input shape ``(height, width, channels)``. Height and width must be
            divisible by 4 so the two pool/upsample stages reproduce the input size.

    Returns:
        An uncompiled ``keras.Model`` mapping the input image to its reconstruction.

    Raises:
        ValueError: if height or width is not divisible by 4.
    """
    height, width = shape[0], shape[1]
    if height % 4 or width % 4:
        raise ValueError(f"height and width must be divisible by 4, got shape={shape}")
    # Undo two 2x poolings: the decoder starts from a grid 4x smaller per side.
    # For 64x64 inputs this is exactly the former hard-coded (16, 16, 4) / Dense(1024).
    latent_grid = (height // 4, width // 4, 4)

    input_img = keras.Input(shape=shape)
    x = keras.layers.Conv2D(16, kernel_size=3, activation='relu', padding='same')(input_img)
    x = keras.layers.MaxPooling2D((2, 2), padding='same')(x)
    x = keras.layers.Conv2D(32, kernel_size=3, activation='relu', padding='same')(x)
    x = keras.layers.MaxPooling2D((2, 2), padding='same')(x)
    x = keras.layers.Flatten()(x)
    x = keras.layers.Dense(128)(x)
    encoded = keras.layers.Dense(latent_grid[0] * latent_grid[1] * latent_grid[2])(x)

    x = keras.layers.Reshape(latent_grid)(encoded)
    x = keras.layers.UpSampling2D((2, 2))(x)
    x = keras.layers.Conv2DTranspose(32, kernel_size=3, activation='relu', padding='same')(x)
    x = keras.layers.UpSampling2D((2, 2))(x)
    x = keras.layers.Conv2DTranspose(16, kernel_size=3, activation='relu', padding='same')(x)
    decoded = keras.layers.Conv2D(1, (3, 3), activation='linear', padding='same')(x)

    return keras.Model(input_img, decoded)

In [None]:
# Build the dense-bottleneck model for single-channel cutouts and compile it with Adam + MSE.
autoencoder1 = create_autoencoder1(sources.shape[1:3] + (1,))
autoencoder1.compile(loss="mse", optimizer="adam")

In [None]:
# Print the architecture of the dense-bottleneck model.
autoencoder1.summary()

In [None]:
# Train the model to reproduce its input; per-sample weights are the square root of `weights_train`.
history1 = autoencoder1.fit(
    x=sources_train,
    y=sources_train,
    sample_weight=np.sqrt(weights_train),
    batch_size=128,
    epochs=50,
    validation_split=0.2,
)

In [None]:
def plot_loss_curves(history, ax=None):
    """Plot per-epoch training loss, and validation loss when it was recorded.

    Args:
        history: object whose ``.history`` dict holds a ``"loss"`` list and, optionally,
            a ``"val_loss"`` list (e.g. the object returned by ``keras.Model.fit``).
        ax: matplotlib Axes to draw on. A new figure/axes is created when omitted, so
            the function no longer depends on pyplot's implicit "current" figure.
    """
    if ax is None:
        _, ax = plt.subplots()
    ax.plot(history.history["loss"], color="g", label="Training")
    # "val_loss" is only present when fit() had validation data; skip it rather than KeyError.
    if "val_loss" in history.history:
        ax.plot(history.history["val_loss"], color="b", label="Validation")
    ax.set_title("Loss Curves for Training/Validation Sets")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.legend()

In [None]:
# Loss curves for the dense-bottleneck model.
plot_loss_curves(history=history1);

In [None]:
# Reconstruct the first 100 held-out cutouts with the dense-bottleneck model.
decoded_imgs1 = autoencoder1.predict(x=sources_test[0:100])

In [None]:
def plot_images(images, nrows=3, ncols=3):
    """Show the first ``nrows * ncols`` images in a grid, each with its own ZScale stretch.

    Args:
        images: sequence or array of 2-D images, or of (H, W, 1) single-channel images.
        nrows: number of grid rows (default 3).
        ncols: number of grid columns (default 3).

    If there are fewer images than grid cells, the leftover panels are blanked
    instead of raising IndexError.
    """
    # Keep the original 8x8-inch figure for the default 3x3 grid; scale it for other grids.
    fig, axes = plt.subplots(nrows, ncols, figsize=(8 * ncols / 3, 8 * nrows / 3), squeeze=False)
    for ax, image in zip(axes.flat, images):
        # Drop a trailing size-1 channel axis so imshow receives a plain 2-D image.
        if np.ndim(image) == 3 and np.shape(image)[-1] == 1:
            image = image[..., 0]
        norm = ImageNormalize(image, interval=ZScaleInterval())
        ax.imshow(image, norm=norm)
    for ax in axes.ravel()[len(images):]:
        ax.axis("off")

In [None]:
# Held-out input cutouts, for comparison with the dense-bottleneck reconstructions.
plot_images(images=sources_test);

In [None]:
# Reconstructions of those cutouts from the dense-bottleneck model.
plot_images(images=decoded_imgs1);

In [None]:
def create_autoencoder2(shape):
    """Build a small fully-convolutional autoencoder (no dense bottleneck).

    Encoder: Conv2D(8, ReLU) -> 2x2 max-pool -> Conv2D(16, ReLU), giving an
    (H/2, W/2, 16) code for even H, W. That is four times as many values as a
    single-channel input, so this model is not forced to compress.
    Decoder: Conv2DTranspose(16, ReLU) -> 2x upsample -> Conv2DTranspose(8, ReLU)
    -> linear 1-channel Conv2D.

    Args:
        shape: input shape ``(height, width, channels)``.

    Returns:
        An uncompiled ``keras.Model`` mapping the input image to its reconstruction.
    """
    relu_same = dict(kernel_size=3, activation='relu', padding='same')
    encoder_layers = [
        keras.layers.Conv2D(8, **relu_same),
        keras.layers.MaxPooling2D((2, 2), padding='same'),
        keras.layers.Conv2D(16, **relu_same),
    ]
    decoder_layers = [
        keras.layers.Conv2DTranspose(16, **relu_same),
        keras.layers.UpSampling2D((2, 2)),
        keras.layers.Conv2DTranspose(8, **relu_same),
        keras.layers.Conv2D(1, (3, 3), activation='linear', padding='same'),
    ]

    input_img = keras.Input(shape=shape)
    tensor = input_img
    for layer in encoder_layers + decoder_layers:
        tensor = layer(tensor)
    return keras.Model(input_img, tensor)

In [None]:
# Build the fully-convolutional model for single-channel cutouts and compile it with Adam + MSE.
autoencoder2 = create_autoencoder2(sources.shape[1:3] + (1,))
autoencoder2.compile(loss="mse", optimizer="adam")

In [None]:
# Print the architecture of the fully-convolutional model.
autoencoder2.summary()

In [None]:
# Train the convolutional model with the same data, weights and schedule as the first model.
history2 = autoencoder2.fit(
    x=sources_train,
    y=sources_train,
    sample_weight=np.sqrt(weights_train),
    batch_size=128,
    epochs=50,
    validation_split=0.2,
)

In [None]:
# Loss curves for the fully-convolutional model.
plot_loss_curves(history=history2);

In [None]:
# Reconstruct the first 100 held-out cutouts with the fully-convolutional model.
decoded_imgs2 = autoencoder2.predict(x=sources_test[0:100])

In [None]:
# Held-out input cutouts again, next to the fully-convolutional reconstructions below.
plot_images(images=sources_test);

In [None]:
# Reconstructions of those cutouts from the fully-convolutional model.
plot_images(images=decoded_imgs2);