In [None]:
import os
import h5py
import numpy as np
import matplotlib.pyplot as plt
from astropy.visualization import (ZScaleInterval, ImageNormalize)
import tensorflow as tf
from tensorflow import keras

In [None]:
# Directory holding the cutout HDF5 file: the job-local scratch space named by
# $SLURM_TMPDIR (left unexpanded by expandvars if the variable is not set).
cutout_dir = f'{os.path.expandvars("$SLURM_TMPDIR")}/'
# Directory holding the tile listing (tiles.list).
image_dir = "/home/eyvorch9/projects/rrg-kyi/astro/cfis/W3/"

In [None]:
# Read-only handle on the cutout file; kept open at module level because the
# batch generators below read from it lazily.
hf = h5py.File(f"{cutout_dir}cutouts_filtered.h5", "r")

In [None]:
# Show the raw tile listing, one entry per line.
# The context manager closes the file even if printing fails, and
# rstrip("\n") only removes a newline (slicing off the last character would
# eat real data on a final line that has no trailing newline).
with open(image_dir + "tiles.list", "r") as tile_list:
    for tile in tile_list:
        print(tile.rstrip("\n"))

In [None]:
# Only use tiles with all five channels

tile_ids = []

# The context manager guarantees the listing is closed; rstrip("\n") removes
# only the line terminator, so a final line without one keeps its last character.
with open(image_dir + "tiles.list", "r") as tile_list:
    for tile in tile_list:
        channels = tile.rstrip("\n").split(" ")
        if len(channels) == 5:  # Order is u,g,r,i,z
            tile_ids.append(channels[0][5:12])  # XXX.XXX id

In [None]:
# Sanity check: number of tiles kept, then the id of the first one.
print(len(tile_ids))
print(tile_ids[0])

In [None]:
# Number of cutouts per batch, and the side length (in pixels) of each square cutout.
batch_size, cutout_size = 128, 64

In [None]:
def get_cutouts(tile_indices, batch_size, cutout_size):
    """Endlessly yield (input, target) batches of cutouts for autoencoder training.

    Each sample stacks the image data in channels 0-4 and the matching weight
    data in channels 5-9. Input and target are the same array, since the
    network is trained to reproduce its input.

    Parameters
    ----------
    tile_indices : re-iterable of int
        Indices into the module-level ``tile_ids`` list selecting the tiles to
        read from the module-level HDF5 file ``hf``.
    batch_size : int
        Number of cutouts per yielded batch.
    cutout_size : int
        Side length of each (square) cutout.

    Yields
    ------
    tuple of (numpy.ndarray, numpy.ndarray)
        The same ``(batch_size, cutout_size, cutout_size, 10)`` array twice.
        Partially filled batches carry over across tiles and across passes.

    Raises
    ------
    ValueError
        If a complete pass over ``tile_indices`` finds no cutouts at all
        (otherwise the generator would spin forever without yielding).
    """
    b = 0  # position of the next sample inside the batch being filled
    sources = np.zeros((batch_size, cutout_size, cutout_size, 10))
    while True:
        found_any = False
        for i in tile_indices:
            img_group = hf.get(tile_ids[i] + "/IMAGES")
            wt_group = hf.get(tile_ids[i] + "/WEIGHTS")
            n_cutouts = len(img_group)
            found_any = found_any or n_cutouts > 0
            for n in range(n_cutouts):
                sources[b, :, :, :5] = np.array(img_group.get(f"c{n}"))
                sources[b, :, :, 5:10] = np.array(wt_group.get(f"c{n}"))
                b += 1
                if b == batch_size:
                    b = 0
                    # Hand out a copy: the consumer may still hold (or be
                    # prefetching) this batch while the buffer is refilled.
                    batch = sources.copy()
                    yield (batch, batch)
        if not found_any:
            raise ValueError("no cutouts found for the given tile indices")

In [None]:
def train_autoencoder(model, train_indices, val_indices, batch_size, cutout_size, epochs=15):
    """Fit ``model`` on cutouts streamed from the selected tiles.

    Parameters
    ----------
    model : compiled Keras model (anything exposing a compatible ``fit``)
    train_indices, val_indices : re-iterable of int
        Indices into the module-level ``tile_ids`` list for the training and
        validation tiles respectively.
    batch_size : int
        Cutouts per batch; also used to derive the number of steps per epoch.
    cutout_size : int
        Side length of each (square) cutout.
    epochs : int, optional
        Number of training epochs (default 15, the value previously hard-coded).

    Returns
    -------
    tuple
        ``(model, history)`` where ``history`` is whatever ``model.fit`` returns.

    Raises
    ------
    ValueError
        If either split holds fewer than ``batch_size`` cutouts, i.e. would
        produce zero steps from the endless batch generator.
    """
    def count_cutouts(indices):
        # Total number of cutouts stored under the IMAGES group of each tile.
        return sum(len(hf.get(tile_ids[i] + "/IMAGES")) for i in indices)

    # Only whole batches are counted; the generators never yield partial ones.
    train_steps = count_cutouts(train_indices) // batch_size
    val_steps = count_cutouts(val_indices) // batch_size
    if train_steps < 1:
        raise ValueError("training tiles contain fewer cutouts than batch_size")
    if val_steps < 1:
        raise ValueError("validation tiles contain fewer cutouts than batch_size")

    history = model.fit(get_cutouts(train_indices, batch_size, cutout_size),
                        epochs=epochs, steps_per_epoch=train_steps,
                        validation_data=get_cutouts(val_indices, batch_size, cutout_size),
                        validation_steps=val_steps)
    return model, history

In [None]:
def create_autoencoder(shape):
    """Build a small convolutional autoencoder for inputs of the given shape.

    The encoder is Conv(8) -> 2x2 max-pool -> Conv(16); the decoder mirrors it
    with ConvTranspose(16) -> 2x2 upsampling -> ConvTranspose(8), followed by a
    linear Conv producing ``shape[2]`` output channels.

    Parameters
    ----------
    shape : tuple
        Input shape ``(height, width, channels)``.

    Returns
    -------
    keras.Model
        Uncompiled model mapping the input to its reconstruction.
    """
    layers = keras.layers
    same_relu = dict(kernel_size=3, activation='relu', padding='same')

    inputs = keras.Input(shape=shape)

    # Encoder
    features = layers.Conv2D(8, **same_relu)(inputs)
    features = layers.MaxPooling2D((2,2), padding='same')(features)
    bottleneck = layers.Conv2D(16, **same_relu)(features)

    # Decoder
    features = layers.Conv2DTranspose(16, **same_relu)(bottleneck)
    features = layers.UpSampling2D((2,2))(features)
    features = layers.Conv2DTranspose(8, **same_relu)(features)
    outputs = layers.Conv2D(shape[2], (3,3), activation='linear', padding='same')(features)

    return keras.Model(inputs, outputs)

In [None]:
def custom_loss(y_true, y_pred):
    """Weighted mean-squared error between true and reconstructed image channels.

    Channels 0-4 (last axis) are compared; channels 5-9 of ``y_true`` supply the
    per-pixel weights. Both sides are scaled by sqrt(weight), so each squared
    residual ends up multiplied by the weight itself.

    Parameters
    ----------
    y_true : tensor, (..., 10) on the last axis
        Target batch: image channels followed by weight channels.
    y_pred : tensor, at least 5 channels on the last axis
        Model output; only its first five channels are used.

    Returns
    -------
    tensor
        ``keras.losses.MSE`` of the weighted images (mean over the channel axis).
    """
    # Slice on the LAST axis: the batches are (batch, height, width, channel),
    # so a plain [:, :, 5:10] would cut the width axis instead of the channels.
    sqrt_weights = tf.math.sqrt(y_true[..., 5:10])
    return keras.losses.MSE(tf.math.multiply(y_true[..., :5], sqrt_weights),
                            tf.math.multiply(y_pred[..., :5], sqrt_weights))

In [None]:
# Positions in tile_ids used for training (the first five tiles) ...
train_indices = range(0, 5)
# ... and the single tile held out for validation.
val_indices = [5]

In [None]:
# Ten input channels: the batches stack five image bands and five weight maps.
autoencoder = create_autoencoder(shape=(cutout_size, cutout_size, 10))
autoencoder.compile(loss=custom_loss, optimizer='adam')

In [None]:
# Print Keras' textual summary of the compiled autoencoder.
autoencoder.summary()

In [None]:
# Train on the selected tiles; keep the fit history for the loss-curve plot.
autoencoder, history = train_autoencoder(
    model=autoencoder,
    train_indices=train_indices,
    val_indices=val_indices,
    batch_size=batch_size,
    cutout_size=cutout_size,
)

In [None]:
def plot_loss_curves(history):
    """Draw training and validation loss per epoch on the current pyplot figure.

    Parameters
    ----------
    history : object with a ``history`` mapping
        Must provide the ``"loss"`` and ``"val_loss"`` series, as returned by
        the fit call in ``train_autoencoder``.
    """
    curves = (
        ("loss", "g", "Training"),
        ("val_loss", "b", "Validation"),
    )
    for key, colour, legend_label in curves:
        plt.plot(history.history[key], color=colour, label=legend_label)

    plt.title("Loss Curves for Training/Validation Sets")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()

In [None]:
# Plot the training and validation loss recorded in the fit history.
plot_loss_curves(history)

In [None]:
def get_test_cutouts(index, n_cutouts, cutout_size, start=0):
    """Load a contiguous run of cutouts (images + weights) from a single tile.

    Parameters
    ----------
    index : int
        Position of the tile in the module-level ``tile_ids`` list.
    n_cutouts : int
        Maximum number of cutouts to load.
    cutout_size : int
        Side length of each (square) cutout.
    start : int, optional
        Number of the first cutout to read (datasets are named ``c{number}``).

    Returns
    -------
    numpy.ndarray
        Array of shape ``(k, cutout_size, cutout_size, 10)`` with image data in
        channels 0-4 and weight data in channels 5-9, matching the layout the
        autoencoder is trained on. ``k`` equals ``n_cutouts`` unless the tile
        holds fewer cutouts from ``start`` onwards, in which case the shorter
        array is returned instead of ``None``.
    """
    img_group = hf.get(tile_ids[index] + "/IMAGES")
    wt_group = hf.get(tile_ids[index] + "/WEIGHTS")

    # Clamp to what the tile actually contains so the result is never padded.
    stop = min(len(img_group), start + n_cutouts)
    n_found = max(stop - start, 0)

    # Ten channels are required: five image bands followed by five weight maps.
    sources = np.zeros((n_found, cutout_size, cutout_size, 10))
    for n, i in enumerate(range(start, stop)):
        sources[n, :, :, :5] = np.array(img_group.get(f"c{i}"))
        sources[n, :, :, 5:10] = np.array(wt_group.get(f"c{i}"))
    return sources

In [None]:
# Tile reserved for testing (not part of the training or validation indices).
test_index = 6
# Load the first 50 cutouts of that tile.
sources_test = get_test_cutouts(index=test_index, n_cutouts=50, cutout_size=cutout_size)

In [None]:
# Run the trained autoencoder on the test cutouts to obtain their reconstructions.
decoded_imgs = autoencoder.predict(sources_test)

In [None]:
def plot_images(images, figname=None, start=0):
    """Show the five image channels of one cutout side by side.

    Parameters
    ----------
    images : array-like, indexable as ``images[k, :, :, c]``
        Batch of cutouts; only channels 0-4 are displayed.
    figname : str or None, optional
        If given, the figure is also written to this path. It defaults to
        ``None`` so the function can be called for display only.
    start : int, optional
        Index of the cutout within ``images`` to display.
    """
    fig, axes = plt.subplots(1, 5, figsize=(14, 8))
    channels = ["CFIS u", "PS1 g", "CFIS r", "PS1 i", "PS1 z"]
    for col, (ax, band) in enumerate(zip(axes, channels)):
        channel_img = images[start, :, :, col]
        # Z-scale stretch, computed per channel, so faint structure stays visible.
        norm = ImageNormalize(channel_img, interval=ZScaleInterval())
        ax.imshow(channel_img, norm=norm)
        ax.set_title(band)
    if figname is not None:
        fig.savefig(figname)

In [None]:
# Show the original test cutout number 5; figname is passed explicitly (None)
# because the function's signature lists it before `start`.
plot_images(sources_test, None, start=5)

In [None]:
# Show the autoencoder's reconstruction of test cutout number 5; figname is
# passed explicitly (None) because the signature lists it before `start`.
plot_images(decoded_imgs, None, start=5)