In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
from astropy import table
from astropy.io import fits
from astropy.nddata.utils import Cutout2D
from astropy.visualization import ImageNormalize, ZScaleInterval
from esutil import wcsutil
from tensorflow import keras

# Collect file paths for the r-band images, weight maps and catalogs

In [None]:
# Root directory of the CFIS W3 tiles. Set the CFIS_W3_DIR environment variable
# to point elsewhere; os.path.join(..., "") guarantees a trailing separator
# because later cells build paths by plain string concatenation (image_dir + name).
image_dir = os.path.join(
    os.environ.get("CFIS_W3_DIR", "/home/eyvorch9/projects/rrg-kyi/astro/cfis/W3/"), ""
)

In [None]:
# Show the contents of tiles.list, one tile per line. rstrip("\n") (rather than
# slicing off the last character) keeps the final line intact when the file
# does not end with a newline; `with` closes the file even on error.
with open(image_dir + "tiles.list", "r") as tile_list:
    for tile in tile_list:
        print(tile.rstrip("\n"))

In [None]:
# For every line of tiles.list, take the first channel name ending in "r"
# (the red band) and derive its image, weight-map and catalog paths.
r_images = []
r_weights = []
cats = []

with open(image_dir + "tiles.list", "r") as tile_list:
    for tile in tile_list:
        # split() with no argument drops the trailing newline (including "\r\n")
        # and ignores repeated blanks, so no empty channel names appear and the
        # last character of an unterminated final line is never lost.
        for c in tile.split():
            if c.endswith("r"):  # Tile has red channel
                r_images.append(image_dir + c + ".fits")
                r_weights.append(image_dir + c + ".weight.fits.fz")
                cats.append(image_dir + c + ".cat")
                break

In [None]:
# Each selected tile contributes one image, one weight map and one catalog,
# so these three counts (one per line) should agree.
print(*(len(paths) for paths in (r_images, r_weights, cats)), sep="\n")

# Create and train autoencoders

In [None]:
def generate_cutouts(file_indices, batch_size, cutout_size):
    """Endlessly yield batches of normalised r-band cutouts for Keras training.

    Cycles forever over the tiles selected by ``file_indices`` (indices into the
    module-level ``r_images``, ``r_weights`` and ``cats`` lists). For each
    source in a tile's SExtractor catalog, a ``cutout_size`` x ``cutout_size``
    cutout centred on (X_IMAGE, Y_IMAGE) is taken from the image (HDU 0) and
    from the weight map (HDU 1); pixels falling outside the tile are set to 0.

    Each image cutout is shifted by its minimum and divided by the spread
    between its 1st and 99th percentiles (1 is used when that spread is 0).
    Sources accumulate across tiles; each time ``batch_size`` of them are
    collected, ``(batch, batch)`` is yielded (input and target are the same
    array, as for an autoencoder). A new array is allocated for every batch.

    Side effect: the matching weight cutouts are written in place into the
    module-level ``weights`` array, which must already exist with shape
    ``(batch_size, cutout_size, cutout_size, 1)``.

    Args:
        file_indices: indices of the tiles to draw sources from.
        batch_size: number of cutouts per yielded batch.
        cutout_size: side length of each square cutout, in pixels.

    Yields:
        tuple: ``(batch, batch)`` where ``batch`` has shape
        ``(batch_size, cutout_size, cutout_size, 1)``.

    Raises:
        ValueError: if a full pass over ``file_indices`` finds no sources
            (empty selection or empty catalogs), instead of looping forever.
    """
    n = 0
    sources = np.zeros((batch_size, cutout_size, cutout_size, 1))
    while True:
        found_source = False
        for i in file_indices:
            cat = table.Table.read(cats[i], format="ascii.sextractor")
            r_image = fits.open(r_images[i], memmap=True)
            r_weight = fits.open(r_weights[i], memmap=True)
            try:
                # NOTE(review): SExtractor X_IMAGE/Y_IMAGE follow the 1-based FITS
                # convention while Cutout2D takes 0-based pixel positions, so the
                # cutouts may be offset by one pixel — confirm.
                for (x, y) in zip(cat["X_IMAGE"], cat["Y_IMAGE"]):
                    found_source = True
                    r_cutout = Cutout2D(r_image[0].data, (x, y), cutout_size, mode="partial", fill_value=0).data
                    r_weight_cutout = Cutout2D(r_weight[1].data, (x, y), cutout_size, mode="partial", fill_value=0).data

                    r_lower = np.percentile(r_cutout, 1)
                    r_upper = np.percentile(r_cutout, 99)
                    # A flat cutout (e.g. lying entirely in the zero-filled border)
                    # would otherwise divide by zero and put NaN/inf in the batch.
                    r_range = r_upper - r_lower
                    if r_range == 0:
                        r_range = 1.0
                    sources[n,:,:,0] = (r_cutout - np.min(r_cutout)) / r_range
                    # NOTE(review): the weight map is scaled by the *image*
                    # percentile range, not its own — confirm this is intended.
                    weights[n,:,:,0] = (r_weight_cutout - np.min(r_weight_cutout)) / r_range

                    n += 1
                    if n == batch_size:
                        # Release the FITS files while suspended at the yield; they
                        # are reopened when the next batch is requested.
                        r_image.close()
                        r_weight.close()

                        n = 0
                        batch = sources
                        # Fill a fresh buffer afterwards: the consumer (e.g. a
                        # prefetching input pipeline) may still be reading `batch`.
                        sources = np.zeros((batch_size, cutout_size, cutout_size, 1))
                        yield (batch, batch)

                        r_image = fits.open(r_images[i], memmap=True)
                        r_weight = fits.open(r_weights[i], memmap=True)
            finally:
                # Runs on normal completion, on errors, and when the generator is
                # closed; closing an already-closed HDUList is harmless.
                r_image.close()
                r_weight.close()
        if not found_source:
            raise ValueError("generate_cutouts: the selected catalogs contain no sources")

In [None]:
def train_autoencoder(model, cat_train_indices, cat_val_indices, batch_size, cutout_size, epochs=5):
    """Fit ``model`` on cutouts streamed from the given training/validation tiles.

    Steps per epoch are the total number of catalog sources in the selected
    tiles divided (rounding down) by ``batch_size``, so each epoch visits
    roughly every source once. Data come from ``generate_cutouts``.

    Args:
        model: compiled Keras model taking and reconstructing
            ``(cutout_size, cutout_size, 1)`` images.
        cat_train_indices: tile indices (into ``cats``/``r_images``) used for training.
        cat_val_indices: tile indices used for validation.
        batch_size: cutouts per batch.
        cutout_size: side length of the square cutouts, in pixels.
        epochs: number of training epochs (default 5).

    Returns:
        tuple: ``(model, history)`` — the same, now trained, model object and
        the ``History`` returned by ``model.fit``.

    Raises:
        ValueError: if the training or validation tiles hold fewer than
            ``batch_size`` sources in total, which would give zero steps.
    """
    def count_sources(indices):
        # One SExtractor catalog row per source, hence one cutout per row.
        return sum(len(table.Table.read(cats[i], format="ascii.sextractor")) for i in indices)

    n_cutouts_train = count_sources(cat_train_indices)
    n_cutouts_val = count_sources(cat_val_indices)
    train_steps = n_cutouts_train // batch_size
    val_steps = n_cutouts_val // batch_size
    if train_steps == 0 or val_steps == 0:
        raise ValueError(
            f"Need at least batch_size={batch_size} sources for both training and "
            f"validation; found {n_cutouts_train} training and {n_cutouts_val} validation sources."
        )

    history = model.fit(generate_cutouts(cat_train_indices, batch_size, cutout_size),
                        epochs=epochs, steps_per_epoch=train_steps,
                        validation_data=generate_cutouts(cat_val_indices, batch_size, cutout_size),
                        validation_steps=val_steps)
    return (model, history)

## Autoencoder with pooling layers

In [None]:
def create_autoencoder1(shape):
    """Build a convolutional autoencoder with a dense 128-unit bottleneck.

    Encoder: Conv2D(16, ReLU) -> 2x2 max-pool -> Conv2D(32, ReLU) -> 2x2
    max-pool -> Flatten -> Dense(128) -> Dense back up to the size of a
    (height/4, width/4, 4) map (both Dense layers have no activation).
    Decoder: reshape to that map, then two rounds of 2x upsampling +
    Conv2DTranspose (32, then 16 filters, ReLU), then a linear 3x3 Conv2D with
    one output channel.

    Args:
        shape: input shape ``(height, width, channels)``. Height and width must
            be divisible by 4 so the two 2x poolings and two 2x upsamplings
            return to the input size (for 64x64 the latent map is 16x16x4,
            i.e. 1024 units).

    Returns:
        keras.Model mapping an image to a reconstruction of shape
        ``(height, width, 1)``.

    Raises:
        ValueError: if height or width is not a multiple of 4.
    """
    height, width = shape[0], shape[1]
    if height % 4 or width % 4:
        raise ValueError(f"create_autoencoder1 needs height and width divisible by 4, got shape={shape}")
    # Spatial size after two 2x poolings, with 4 channels in the latent map.
    latent_shape = (height // 4, width // 4, 4)

    input_img = keras.Input(shape=shape)
    x = keras.layers.Conv2D(16, kernel_size=3, activation='relu', padding='same')(input_img)
    x = keras.layers.MaxPooling2D((2,2), padding='same')(x)
    x = keras.layers.Conv2D(32, kernel_size=3, activation='relu', padding='same')(x)
    x = keras.layers.MaxPooling2D((2,2), padding='same')(x)
    x = keras.layers.Flatten()(x)
    x = keras.layers.Dense(128)(x)
    encoded = keras.layers.Dense(latent_shape[0] * latent_shape[1] * latent_shape[2])(x)

    x = keras.layers.Reshape(latent_shape)(encoded)
    x = keras.layers.UpSampling2D((2,2))(x)
    x = keras.layers.Conv2DTranspose(32, kernel_size=3, activation='relu', padding='same')(x)
    x = keras.layers.UpSampling2D((2,2))(x)
    x = keras.layers.Conv2DTranspose(16, kernel_size=3, activation='relu', padding='same')(x)
    decoded = keras.layers.Conv2D(1, (3,3), activation='linear', padding='same')(x)

    return keras.Model(input_img, decoded)

In [None]:
def custom_loss(y_true, y_pred):
    """Weighted MSE: compares ``y_true * sqrt(weights)`` with ``y_pred * sqrt(weights)``.

    ``weights`` is the module-level NumPy array that ``generate_cutouts``
    fills in place.

    NOTE(review): when Keras traces this loss into a graph, a NumPy global is
    likely embedded as a constant at trace time, so later in-place updates to
    ``weights`` may never reach the loss — confirm. Yielding the weights from
    the generator as Keras ``sample_weight`` would avoid relying on the global.
    """
    root_weights = np.sqrt(weights)
    weighted_true = y_true * root_weights
    weighted_pred = y_pred * root_weights
    return keras.losses.MSE(weighted_true, weighted_pred)

In [None]:
cutout_size = 64
autoencoder1 = create_autoencoder1((cutout_size, cutout_size, 1))
#opt = keras.optimizers.Adam(learning_rate=0.005)
autoencoder1.compile(optimizer="adam", loss=custom_loss)

In [None]:
autoencoder1.summary()

In [None]:
# Train on tile 100 and validate on tile 0 (indices into r_images / cats).
cat_train_indices = [100]
cat_val_indices = [0]
batch_size = 32
# Per-pixel weight buffer: generate_cutouts writes into it, custom_loss reads it.
weights = np.zeros((batch_size, cutout_size, cutout_size, 1))
autoencoder1, history1 = train_autoencoder(
    model=autoencoder1,
    cat_train_indices=cat_train_indices,
    cat_val_indices=cat_val_indices,
    batch_size=batch_size,
    cutout_size=cutout_size,
)

In [None]:
def plot_loss_curves(history, ax=None):
    """Plot per-epoch training (green) and validation (blue) loss.

    Args:
        history: Keras ``History`` as returned by ``model.fit``. Its
            ``.history`` dict must contain ``"loss"``; ``"val_loss"`` is
            plotted only when present.
        ax: matplotlib Axes to draw on. A new figure and Axes are created when
            omitted, so repeated calls do not draw over each other.
    """
    if ax is None:
        _, ax = plt.subplots()
    ax.plot(history.history["loss"], color="g", label="Training")
    if "val_loss" in history.history:
        ax.plot(history.history["val_loss"], color="b", label="Validation")
    ax.set_title("Loss Curves for Training/Validation Sets")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.legend()

In [None]:
# Loss curves for the pooling/dense autoencoder.
plot_loss_curves(history=history1)

In [None]:
def generate_test_cutouts(index, n_cutouts, cutout_size):
    """Return up to ``n_cutouts`` normalised r-band cutouts from a single tile.

    Cutouts are taken from HDU 0 of ``r_images[index]``, centred on the
    catalog's (X_IMAGE, Y_IMAGE) and zero-filled outside the tile. Each one is
    shifted by its minimum and divided by its 1st–99th percentile spread, with
    1 used when that spread is 0.

    Args:
        index: tile index into ``r_images``/``cats``.
        n_cutouts: maximum number of cutouts to return.
        cutout_size: side length of each square cutout, in pixels.

    Returns:
        ndarray of shape ``(k, cutout_size, cutout_size, 1)`` where
        ``k = min(n_cutouts, number of catalog sources)`` (0 if ``n_cutouts <= 0``).
    """
    sources = np.zeros((max(n_cutouts, 0), cutout_size, cutout_size, 1))
    cat = table.Table.read(cats[index], format="ascii.sextractor")
    n = 0
    # The context manager closes the file even when the catalog holds fewer
    # than n_cutouts sources or a cutout raises.
    with fits.open(r_images[index], memmap=True) as r_image:
        for (x, y) in zip(cat["X_IMAGE"], cat["Y_IMAGE"]):
            if n == len(sources):
                break
            r_cutout = Cutout2D(r_image[0].data, (x, y), cutout_size, mode="partial", fill_value=0).data
            r_lower = np.percentile(r_cutout, 1)
            r_upper = np.percentile(r_cutout, 99)
            # Avoid NaN/inf for flat cutouts (e.g. entirely in the zero-filled border).
            r_range = r_upper - r_lower
            if r_range == 0:
                r_range = 1.0
            sources[n,:,:,0] = (r_cutout - np.min(r_cutout)) / r_range
            n += 1
    return sources[:n]

In [None]:
test_index = 4
sources_test = generate_test_cutouts(test_index, 100, cutout_size)

In [None]:
decoded_imgs1 = autoencoder1.predict(sources_test)

In [None]:
def plot_images(images, n_rows=3, n_cols=3, figsize=(8, 8)):
    """Display the first ``n_rows * n_cols`` images in a grid with a z-scale stretch.

    Images fill the grid row by row. Panels with no corresponding image (when
    there are fewer images than grid cells) are turned off instead of raising
    IndexError.

    Args:
        images: indexable collection of images, each 2-D or ``(H, W, 1)``.
        n_rows: number of grid rows (default 3).
        n_cols: number of grid columns (default 3).
        figsize: matplotlib figure size in inches (default (8, 8)).
    """
    # squeeze=False keeps `axes` 2-D even for a single row/column.
    _, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    for i, ax in enumerate(axes.flat):
        if i >= len(images):
            ax.axis("off")
            continue
        image = np.asarray(images[i])
        if image.ndim == 3 and image.shape[-1] == 1:
            image = image[..., 0]  # drop the singleton channel axis for imshow
        norm = ImageNormalize(image, interval=ZScaleInterval())
        ax.imshow(image, norm=norm)

In [None]:
plot_images(sources_test)

In [None]:
plot_images(decoded_imgs1)

In [None]:
mse = keras.losses.MeanSquaredError()
mse(sources_test, decoded_imgs1).numpy()

## Fully convolutional autoencoder

In [None]:
def create_autoencoder2(shape):
    """Build a small fully convolutional autoencoder (no dense layers).

    Encoder: Conv2D(8, 3x3, ReLU) -> 2x2 max-pool -> Conv2D(16, 3x3, ReLU).
    Decoder: Conv2DTranspose(16, ReLU) -> 2x upsampling ->
    Conv2DTranspose(8, ReLU) -> linear 3x3 Conv2D with one output channel.
    All layers use 'same' padding, so for even height and width the output has
    the same spatial size as the input.

    Args:
        shape: input shape ``(height, width, channels)``.

    Returns:
        keras.Model mapping an image to a single-channel reconstruction.
    """
    layers = keras.layers
    relu_conv = dict(kernel_size=3, activation='relu', padding='same')

    input_img = keras.Input(shape=shape)
    # Encoder: a single 2x downsampling step.
    hidden = layers.Conv2D(8, **relu_conv)(input_img)
    hidden = layers.MaxPooling2D((2, 2), padding='same')(hidden)
    encoded = layers.Conv2D(16, **relu_conv)(hidden)

    # Decoder: mirror the encoder back to full resolution.
    hidden = layers.Conv2DTranspose(16, **relu_conv)(encoded)
    hidden = layers.UpSampling2D((2, 2))(hidden)
    hidden = layers.Conv2DTranspose(8, **relu_conv)(hidden)
    decoded = layers.Conv2D(1, (3, 3), activation='linear', padding='same')(hidden)

    return keras.Model(input_img, decoded)

In [None]:
autoencoder2 = create_autoencoder2((cutout_size, cutout_size, 1))
#opt = keras.optimizers.Adam(learning_rate=0.005)
autoencoder2.compile(optimizer="adam", loss=custom_loss)

In [None]:
autoencoder2.summary()

In [None]:
(autoencoder2, history2) = train_autoencoder(autoencoder2, cat_train_indices, cat_val_indices, batch_size, cutout_size)

In [None]:
plot_loss_curves(history2)

In [None]:
decoded_imgs2 = autoencoder2.predict(sources_test)

In [None]:
plot_images(sources_test)

In [None]:
plot_images(decoded_imgs2)

In [None]:
mse(sources_test, decoded_imgs2).numpy()