In [None]:
import os
import h5py
import shutil
from astropy.nddata.utils import Cutout2D
from astropy.io import fits
from astropy import table
import numpy as np
import matplotlib.pyplot as plt
from astropy.visualization import (ZScaleInterval, ImageNormalize)
import tensorflow as tf
from tensorflow import keras

In [None]:
# Directory the tiles are staged into. expandvars substitutes $SLURM_TMPDIR from the
# environment and leaves the text untouched when the variable is not set.
dest = "{}/".format(os.path.expandvars("$SLURM_TMPDIR"))
# Directory holding the tile list, the per-band images, the weight maps and the catalogues.
image_dir = "/home/eyvorch9/projects/rrg-kyi/astro/cfis/W3/"

In [None]:
# Only use tiles with all five channels.
# Each line of tiles.list holds the whitespace-separated base names of one tile,
# in the order u, g, r, i, z. Lines with any other number of fields are skipped.

u_images = []
u_weights = []
g_images = []
g_weights = []
r_images = []
r_weights = []
i_images = []
i_weights = []
z_images = []
z_weights = []
cats = []

# The context manager closes the list file even if parsing raises.
with open(image_dir + "tiles.list", "r") as tile_list:
    for tile in tile_list:
        # split() with no argument drops the trailing newline (if any) and tolerates
        # repeated or trailing blanks, so a final line without "\n" keeps its last character.
        channels = tile.split()
        if len(channels) != 5:
            continue
        u_name, g_name, r_name, i_name, z_name = channels
        u_images.append(image_dir + u_name + ".fits")
        u_weights.append(image_dir + u_name + ".weight.fits.fz")
        g_images.append(image_dir + g_name + ".fits")
        g_weights.append(image_dir + g_name + ".wt.fits")
        r_images.append(image_dir + r_name + ".fits")
        r_weights.append(image_dir + r_name + ".weight.fits.fz")
        i_images.append(image_dir + i_name + ".fits")
        i_weights.append(image_dir + i_name + ".wt.fits")
        z_images.append(image_dir + z_name + ".fits")
        z_weights.append(image_dir + z_name + ".wt.fits")
        # The catalogue file is named after the r-band entry.
        cats.append(image_dir + r_name + ".cat")

In [None]:
# Print the number of entries in each path list, one per line, in the order
# u, g, r, i, z (image then weight for each band) followed by the catalogues.
for path_list in (u_images, u_weights,
                  g_images, g_weights,
                  r_images, r_weights,
                  i_images, i_weights,
                  z_images, z_weights,
                  cats):
    print(len(path_list))

In [None]:
# Copy the first n_tiles tiles to $SLURM_TMPDIR and repoint the path lists at the copies.
n_tiles = 5
for n in range(n_tiles):
    for file_list in (u_images, u_weights, g_images, g_weights, r_images, r_weights,
                      i_images, i_weights, z_images, z_weights, cats):
        local_path = os.path.abspath(os.path.join(dest, os.path.basename(file_list[n])))
        # When this cell is re-run the entry already points at the staged copy;
        # shutil.copy2 raises SameFileError if source and destination are the same file,
        # so only copy entries that still live somewhere else.
        if os.path.abspath(file_list[n]) != local_path:
            shutil.copy2(file_list[n], dest)
        file_list[n] = local_path

In [None]:
batch_size = 128   # number of cutouts per batch
cutout_size = 64   # side length of each square cutout, in pixels
# Module-level buffer with one slot per batch entry, pixel and channel (5 channels).
# generate_cutouts writes into it and custom_loss reads it.
weights = np.zeros(shape=(batch_size, cutout_size, cutout_size, 5))

In [None]:
def _extract_cutout(data, position, cutout_size):
    """Return an independent cutout_size x cutout_size array centred on ``position``.

    ``mode="partial"`` pads the part of the cutout that falls outside ``data`` with 0.
    ``copy=True`` makes the result a private copy, so editing it (e.g. zeroing NaNs)
    never writes back into the image array that later, overlapping cutouts read from.
    """
    return Cutout2D(data, position, cutout_size, mode="partial", fill_value=0, copy=True).data


def _normalise_band(image_cutout, weight_cutout):
    """Zero NaNs in both cutouts and scale them by the image's 1st-99th percentile range.

    Returns ``(scaled_image, scaled_weight)``. Each array has its own minimum subtracted
    and is divided by ``percentile(image, 99) - percentile(image, 1)``. When that range
    is 0 both results are all zeros, which avoids a division by zero.

    NOTE(review): the weight map is divided by the *image* range, exactly as the code
    did before this refactor -- confirm that this is the intended weight scaling.
    """
    image_cutout[np.isnan(image_cutout)] = 0
    weight_cutout[np.isnan(weight_cutout)] = 0
    lower = np.percentile(image_cutout, 1)
    upper = np.percentile(image_cutout, 99)
    if upper == lower:  # Avoid division by 0
        return np.zeros(image_cutout.shape), np.zeros(weight_cutout.shape)
    scale = upper - lower
    return ((image_cutout - np.min(image_cutout)) / scale,
            (weight_cutout - np.min(weight_cutout)) / scale)


def generate_cutouts(tile_indices, batch_size, cutout_size):
    """Endlessly yield ``(inputs, targets)`` batches of five-band cutouts for an autoencoder.

    For every tile in ``tile_indices`` the ten FITS files (image + weight for u, g, r, i, z)
    and the SExtractor catalogue are opened, and one cutout per accepted catalogue row is
    written into the batch. A row is skipped when FLAGS != 0, MAG_AUTO >= 99.0, or
    MAGERR_AUTO is outside (0, 1), or when more than 5% of the pixels of its g, i or z
    cutout are NaN. After the last tile the loop starts over, so the generator never ends.

    Parameters
    ----------
    tile_indices : iterable of int
        Positions in the module-level path lists (``u_images`` ... ``cats``).
    batch_size : int
        Number of cutouts per yielded batch. The module-level ``weights`` buffer is
        indexed with the same slot number, so it must have at least this many slots.
    cutout_size : int
        Side length of the square cutouts in pixels.

    Yields
    ------
    tuple of (ndarray, ndarray)
        The same ``(batch_size, cutout_size, cutout_size, 5)`` array twice (input and
        reconstruction target). It is a copy of the internal buffer, so a consumer that
        holds on to a batch does not see it change when the next batch is filled.

    Side effects
    ------------
    Writes the normalised weight cutouts into the module-level ``weights`` array.

    Raises
    ------
    ValueError
        If a complete pass over ``tile_indices`` accepts no source at all (previously the
        generator would loop forever without yielding).
    """
    # (image paths, weight paths, HDU index holding the weight map, NaN-screened?)
    # listed in channel order u, g, r, i, z.
    bands = (
        (u_images, u_weights, 1, False),
        (g_images, g_weights, 0, True),
        (r_images, r_weights, 1, False),
        (i_images, i_weights, 0, True),
        (z_images, z_weights, 0, True),
    )
    max_nan_pixels = 0.05 * cutout_size**2
    sources = np.zeros((batch_size, cutout_size, cutout_size, len(bands)))
    b = 0  # next free slot in the batch; carries over from one tile to the next
    while True:
        accepted_this_pass = 0
        for i in tile_indices:
            open_files = []
            try:
                tile = []
                for image_paths, weight_paths, weight_ext, screened in bands:
                    # Register each handle as soon as it exists so the finally block
                    # closes it even if a later open fails.
                    image_hdul = fits.open(image_paths[i], memmap=True)
                    open_files.append(image_hdul)
                    weight_hdul = fits.open(weight_paths[i], memmap=True)
                    open_files.append(weight_hdul)
                    tile.append((image_hdul, weight_hdul, weight_ext, screened))

                cat = table.Table.read(cats[i], format="ascii.sextractor")
                for j in range(len(cat)):
                    if (cat["FLAGS"][j] != 0 or cat["MAG_AUTO"][j] >= 99.0
                            or cat["MAGERR_AUTO"][j] <= 0 or cat["MAGERR_AUTO"][j] >= 1):
                        continue
                    position = (cat["X_IMAGE"][j], cat["Y_IMAGE"][j])

                    # Screen the g, i, z images first so a rejected source costs
                    # as few cutouts as possible.
                    image_cutouts = {}
                    too_many_nans = False
                    for channel, (image_hdul, _, _, screened) in enumerate(tile):
                        if not screened:
                            continue
                        candidate = _extract_cutout(image_hdul[0].data, position, cutout_size)
                        if np.count_nonzero(np.isnan(candidate)) > max_nan_pixels:
                            too_many_nans = True
                            break
                        image_cutouts[channel] = candidate
                    if too_many_nans:
                        continue

                    for channel, (image_hdul, weight_hdul, weight_ext, _) in enumerate(tile):
                        image_cutout = image_cutouts.get(channel)
                        if image_cutout is None:
                            image_cutout = _extract_cutout(image_hdul[0].data, position, cutout_size)
                        weight_cutout = _extract_cutout(weight_hdul[weight_ext].data, position, cutout_size)
                        scaled_image, scaled_weight = _normalise_band(image_cutout, weight_cutout)
                        sources[b, :, :, channel] = scaled_image
                        weights[b, :, :, channel] = scaled_weight

                    accepted_this_pass += 1
                    b += 1
                    if b == batch_size:
                        b = 0
                        batch = sources.copy()
                        yield (batch, batch)
            finally:
                # Runs on normal completion, on an exception, and when the consumer
                # closes or drops the generator while it is suspended at the yield.
                for hdul in open_files:
                    hdul.close()
        if accepted_this_pass == 0:
            raise ValueError(
                "generate_cutouts: no catalogue source passed the selection cuts for tiles "
                + repr(list(tile_indices)) + "; no batch can be produced"
            )

In [None]:
def train_autoencoder(model, train_indices, val_indices, batch_size, cutout_size, epochs=5):
    """Fit ``model`` on cutouts streamed from the training tiles, validating on other tiles.

    The number of steps per epoch is the total number of catalogue rows of the
    respective tiles divided (floor) by ``batch_size``. generate_cutouts skips rows that
    fail its selection cuts and loops over its tiles indefinitely, so this count is an
    upper bound on one pass over the data rather than an exact figure.

    Parameters
    ----------
    model : compiled Keras model
        Trained in place via ``model.fit``.
    train_indices, val_indices : iterable of int
        Positions in the module-level ``cats`` list (and the per-band path lists).
    batch_size : int
        Cutouts per batch.
    cutout_size : int
        Side length of the square cutouts in pixels.
    epochs : int, optional
        Number of training epochs. Defaults to 5, the value that used to be hard-coded.

    Returns
    -------
    tuple
        ``(model, history)`` where ``history`` is the object returned by ``model.fit``.

    Raises
    ------
    ValueError
        If the training or validation catalogues hold fewer than ``batch_size`` rows,
        i.e. not even one full batch could be requested.
    """
    def count_catalogue_rows(tile_indices):
        # Total number of rows over the catalogues of the given tiles.
        return sum(len(table.Table.read(cats[i], format="ascii.sextractor"))
                   for i in tile_indices)

    train_steps = count_catalogue_rows(train_indices) // batch_size
    val_steps = count_catalogue_rows(val_indices) // batch_size
    if train_steps == 0 or val_steps == 0:
        raise ValueError(
            "Not enough catalogue rows for one batch of %d (train steps: %d, validation steps: %d)"
            % (batch_size, train_steps, val_steps)
        )

    history = model.fit(generate_cutouts(train_indices, batch_size, cutout_size),
                        epochs=epochs, steps_per_epoch=train_steps,
                        validation_data=generate_cutouts(val_indices, batch_size, cutout_size),
                        validation_steps=val_steps)
    return model, history

In [None]:
def custom_loss(y_true, y_pred):
    """Mean squared error between target and prediction after scaling both by sqrt(weights).

    ``weights`` is the module-level NumPy array that generate_cutouts overwrites in place.

    NOTE(review): ``np.sqrt(weights)`` runs as plain Python/NumPy code. If Keras traces
    this function into a graph, the array is presumably captured once with whatever
    values it holds at that moment rather than re-read for every batch -- confirm that
    the per-batch weights actually reach the loss.
    """
    root_weights = np.sqrt(weights)
    weighted_target = y_true * root_weights
    weighted_prediction = y_pred * root_weights
    return keras.losses.MSE(weighted_target, weighted_prediction)

In [None]:
# Tile positions (indices into the per-band path lists) used for training and validation.
train_indices, val_indices = [1, 2], [0]

In [None]:
def create_autoencoder(shape):
    """Build a small convolutional autoencoder for inputs of the given ``shape``.

    Encoder: Conv2D(8) -> MaxPooling2D(2x2) -> Conv2D(16).
    Decoder: Conv2DTranspose(16) -> UpSampling2D(2x2) -> Conv2DTranspose(8) -> Conv2D.
    All hidden layers use 3x3 kernels, ReLU and 'same' padding; the final Conv2D is
    linear and has ``shape[2]`` filters, so the output has as many channels as the input.

    Returns the uncompiled ``keras.Model`` mapping the input to its reconstruction.
    """
    layers = keras.layers
    hidden_conv = dict(kernel_size=3, activation='relu', padding='same')

    inputs = keras.Input(shape=shape)

    # Encoder
    features = layers.Conv2D(8, **hidden_conv)(inputs)
    features = layers.MaxPooling2D((2, 2), padding='same')(features)
    bottleneck = layers.Conv2D(16, **hidden_conv)(features)

    # Decoder
    features = layers.Conv2DTranspose(16, **hidden_conv)(bottleneck)
    features = layers.UpSampling2D((2, 2))(features)
    features = layers.Conv2DTranspose(8, **hidden_conv)(features)
    outputs = layers.Conv2D(shape[2], (3, 3), activation='linear', padding='same')(features)

    return keras.Model(inputs, outputs)

In [None]:
# Build the five-channel autoencoder and compile it with Adam and the weighted loss.
autoencoder = create_autoencoder(shape=(cutout_size, cutout_size, 5))
autoencoder.compile(loss=custom_loss, optimizer='adam')

In [None]:
# Print the layer-by-layer summary of the compiled autoencoder.
autoencoder.summary()

In [None]:
# Train on the training tiles; keep the fitted model and its training history.
autoencoder, history = train_autoencoder(
    model=autoencoder,
    train_indices=train_indices,
    val_indices=val_indices,
    batch_size=batch_size,
    cutout_size=cutout_size,
)

In [None]:
def plot_loss_curves(history):
    """Plot the per-epoch training and validation loss stored in ``history.history``.

    Reads the "loss" and "val_loss" entries and draws them on the current
    matplotlib figure with a title, axis labels and a legend.
    """
    curves = (
        ("loss", "g", "Training"),
        ("val_loss", "b", "Validation"),
    )
    for key, colour, legend_label in curves:
        plt.plot(history.history[key], color=colour, label=legend_label)
    plt.title("Loss Curves for Training/Validation Sets")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()

In [None]:
# Show the training and validation loss recorded during the fit above.
plot_loss_curves(history)

In [None]:
def generate_test_cutouts(index, n_cutouts, cutout_size, start=0):
    """Return up to ``n_cutouts`` normalised five-band cutouts from one tile.

    Catalogue rows are visited in order. A row is skipped when FLAGS != 0,
    MAG_AUTO >= 99.0, or MAGERR_AUTO is outside (0, 1), or when more than 5% of the
    pixels of its g, i or z cutout are NaN. Remaining NaNs are set to 0 and every band
    is scaled as ``(cutout - min(cutout)) / (p99 - p1)``; a band whose 1st and 99th
    percentiles coincide is returned as all zeros to avoid a division by zero.

    Parameters
    ----------
    index : int
        Position of the tile in the module-level path lists (``u_images`` ... ``cats``).
    n_cutouts : int
        Maximum number of cutouts to return.
    cutout_size : int
        Side length of the square cutouts in pixels.
    start : int, optional
        Currently unused; kept so existing calls that pass it keep working.

    Returns
    -------
    ndarray
        Shape ``(n, cutout_size, cutout_size, 5)`` with channels in the order
        u, g, r, i, z, where ``n == n_cutouts`` unless the catalogue has fewer
        acceptable sources, in which case only the cutouts that were found are returned.
    """
    # (image path list, NaN-screened?) in channel order u, g, r, i, z.
    bands = (
        (u_images, False),
        (g_images, True),
        (r_images, False),
        (i_images, True),
        (z_images, True),
    )
    max_nan_pixels = 0.05 * cutout_size**2
    sources = np.zeros((n_cutouts, cutout_size, cutout_size, len(bands)))
    if n_cutouts == 0:
        return sources

    def extract(hdul, position):
        # copy=True yields a private array, so zeroing NaNs below never writes back
        # into the image data that later, overlapping cutouts are screened against.
        return Cutout2D(hdul[0].data, position, cutout_size,
                        mode="partial", fill_value=0, copy=True).data

    n = 0
    open_files = []
    try:
        for paths, _ in bands:
            open_files.append(fits.open(paths[index], memmap=True))

        cat = table.Table.read(cats[index], format="ascii.sextractor")
        for j in range(len(cat)):
            if (cat["FLAGS"][j] != 0 or cat["MAG_AUTO"][j] >= 99.0
                    or cat["MAGERR_AUTO"][j] <= 0 or cat["MAGERR_AUTO"][j] >= 1):
                continue
            position = (cat["X_IMAGE"][j], cat["Y_IMAGE"][j])

            # Count NaN pixels (len() of the mask would only give its row count).
            cutouts = {}
            too_many_nans = False
            for channel, (_, screened) in enumerate(bands):
                if not screened:
                    continue
                candidate = extract(open_files[channel], position)
                if np.count_nonzero(np.isnan(candidate)) > max_nan_pixels:
                    too_many_nans = True
                    break
                cutouts[channel] = candidate
            if too_many_nans:
                continue

            for channel, hdul in enumerate(open_files):
                cutout = cutouts.get(channel)
                if cutout is None:
                    cutout = extract(hdul, position)
                cutout[np.isnan(cutout)] = 0
                lower = np.percentile(cutout, 1)
                upper = np.percentile(cutout, 99)
                if upper == lower:  # Avoid division by 0
                    sources[n, :, :, channel] = 0
                else:
                    sources[n, :, :, channel] = (cutout - np.min(cutout)) / (upper - lower)

            n += 1
            if n == n_cutouts:
                break
    finally:
        # Close the FITS files on every exit path, including a short catalogue or an error.
        for hdul in open_files:
            hdul.close()

    return sources[:n]

In [None]:
# Draw 50 cutouts from the tile at position 4 of the path lists for a visual check.
test_index = 4
sources_test = generate_test_cutouts(index=test_index, n_cutouts=50, cutout_size=cutout_size)

In [None]:
# Run the test cutouts through the autoencoder to obtain their reconstructions.
decoded_imgs = autoencoder.predict(sources_test)

In [None]:
def plot_images(images, figname, start=0):
    """Show the five channels of one cutout side by side and save the figure.

    Parameters
    ----------
    images : array indexed as ``images[n, :, :, channel]``
        Stack of cutouts with five channels on the last axis.
    figname : str
        File name of the saved figure; it is written into ``../Plots``.
    start : int, optional
        Index along the first axis of the cutout to display. Defaults to 0.

    Each panel is drawn with a ZScale normalisation computed from its own pixels.
    The output directory is created if it does not exist yet.
    """
    channels = ["CFIS u", "PS1 g", "CFIS r", "PS1 i", "PS1 z"]
    fig, axes = plt.subplots(1, len(channels), figsize=(14, 8))
    for col, (ax, title) in enumerate(zip(axes, channels)):
        panel = images[start, :, :, col]
        ax.imshow(panel, norm=ImageNormalize(panel, interval=ZScaleInterval()))
        ax.set_title(title)

    out_dir = "../Plots"
    # savefig does not create missing directories, so make sure the target exists.
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(os.path.join(out_dir, figname))

In [None]:
# Plot and save the five channels of the input cutout at index 40.
plot_images(sources_test, figname="Test cutouts182.271.png", start=40)

In [None]:
# Plot and save the autoencoder's reconstruction of the same cutout (index 40).
plot_images(decoded_imgs, figname="Reconstructed cutouts182.271.png", start=40)

In [None]:
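# Uncomment to save the trained model (the notebook-level variable is `autoencoder`, not `model`):
# autoencoder.save("../Models")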