In [None]:
import os
import h5py
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from astropy.visualization import (ZScaleInterval, ImageNormalize)
import tensorflow as tf
from tensorflow import keras

In [None]:
# Directory holding cutouts_filtered.h5. On the cluster this is the job's $SLURM_TMPDIR;
# outside a SLURM job fall back to the current directory instead of a literal "$SLURM_TMPDIR" path.
cutout_dir = os.path.join(os.environ.get("SLURM_TMPDIR", "."), "")
# Directory containing tiles.list; override with the CFIS_IMAGE_DIR environment variable.
image_dir = os.path.join(os.environ.get("CFIS_IMAGE_DIR", "/home/eyvorch9/projects/rrg-kyi/astro/cfis/W3/"), "")

In [None]:
os.listdir(cutout_dir)

In [None]:
hf = h5py.File(cutout_dir + "cutouts_filtered.h5", "r")

In [None]:
n_cutouts = 0
for i in range(10):
    n_cutouts += len(hf.get(tile_ids[i] + "/IMAGES"))
print(n_cutouts)

In [None]:
len(hf.get(tile_ids[12] + "/IMAGES"))

# Look at sample cutouts

In [None]:
group = hf.get("180.272/IMAGES")
plot_cutouts = np.array(group.get("c0"))
print(plot_cutouts.shape)

In [None]:
group_weight = hf.get("180.272/WEIGHTS")
plot_weights = np.array(group_weight.get("c0"))
print(plot_weights.shape)

In [None]:
channels = ["u", "g", "r", "i", "z"]
fig, axes = plt.subplots(1,5, figsize=(12,8))
for i in range(5):
    norm = ImageNormalize(plot_cutouts[:,:,i], interval=ZScaleInterval())
    axes[i].imshow(plot_cutouts[:,:,i], norm=norm)
    axes[i].set_title(channels[i])

In [None]:
channels = ["u", "g", "r", "i", "z"]
fig, axes = plt.subplots(1,5, figsize=(12,8))
for i in range(5):
    norm = ImageNormalize(plot_weights[:,:,i], interval=ZScaleInterval())
    axes[i].imshow(plot_weights[:,:,i], norm=norm)
    axes[i].set_title(channels[i])

In [None]:
plot_cutouts[:,:,1]

# Get tile ids

In [None]:
tile_list = open(image_dir + "tiles.list", "r")
for tile in tile_list.readlines():
    print(tile[:-1])
tile_list.close()

In [None]:
# Only use tiles with all five channels

tile_list = open(image_dir + "tiles.list", "r")
tile_ids = []

for tile in tile_list:
    tile = tile[:-1] # Remove new line character
    channels = tile.split(" ")
    if len(channels) == 5: # Order is u,g,r,i,z
        tile_ids.append(channels[0][5:12]) # XXX.XXX id
tile_list.close()

In [None]:
print(len(tile_ids))
print(tile_ids[0])

# Prepare Training

In [None]:
# Training hyper-parameters.
n_epochs = 12
batch_size = 128
cutout_size = 64  # cutouts are cutout_size x cutout_size pixels

# One weight-map buffer per band subset (all 5 bands, 2 CFIS bands, 3 PS1 bands).
# get_cutouts fills the matching buffer in place for each batch; the custom losses read it.
weights_all = np.zeros([batch_size, cutout_size, cutout_size, 5])
weights_cfis = np.zeros([batch_size, cutout_size, cutout_size, 2])
weights_ps1 = np.zeros([batch_size, cutout_size, cutout_size, 3])

In [None]:
def get_cutouts(tile_indices, batch_size, cutout_size, bands="all"):
    """Endless generator of autoencoder training batches read from the HDF5 cutout file.

    Cycles forever over the tiles ``tile_ids[i]`` for ``i`` in ``tile_indices``, reading
    cutouts ``c0, c1, ...`` from ``<tile>/IMAGES`` and ``<tile>/WEIGHTS``. A batch may span
    tile boundaries and, after the last tile, wrap around to the first.

    Parameters
    ----------
    tile_indices : iterable of int
        Indices into the global ``tile_ids`` list.
    batch_size, cutout_size : int
        Each batch has shape (batch_size, cutout_size, cutout_size, n_bands).
    bands : {"all", "cfis", other}
        "all" keeps bands [0, 1, 2, 3, 4], "cfis" keeps [0, 2], anything else (PS1) keeps [1, 3, 4].

    Yields
    ------
    (batch, batch)
        A freshly allocated float64 array used as both input and target.

    Side effects
    ------------
    Writes the matching weight maps, in place, into the module-level buffer
    ``weights_all`` / ``weights_cfis`` / ``weights_ps1`` (read by the custom losses).
    NOTE(review): that buffer's batch dimension must be at least ``batch_size``.

    Raises
    ------
    ValueError
        On the first ``next()`` if the selected tiles contain no cutouts (the loop would
        otherwise spin forever without yielding).
    """
    if bands == "all":
        band_indices = [0, 1, 2, 3, 4]
        weights = weights_all
    elif bands == "cfis":
        band_indices = [0, 2]
        weights = weights_cfis
    else: # PS1
        band_indices = [1, 3, 4]
        weights = weights_ps1
    sources = np.zeros((batch_size, cutout_size, cutout_size, len(band_indices)))

    # Materialize the indices: they are iterated once per pass, forever.
    tile_indices = list(tile_indices)
    if not any(len(hf.get(tile_ids[i] + "/IMAGES")) for i in tile_indices):
        raise ValueError("get_cutouts: the selected tiles contain no cutouts")

    b = 0 # position inside the batch being filled; carries over between tiles and passes
    while True:
        for i in tile_indices:
            img_group = hf.get(tile_ids[i] + "/IMAGES")
            wt_group = hf.get(tile_ids[i] + "/WEIGHTS")
            for n in range(len(img_group)):
                sources[b] = np.array(img_group.get(f"c{n}"))[:, :, band_indices]
                weights[b] = np.array(wt_group.get(f"c{n}"))[:, :, band_indices]
                b += 1
                if b == batch_size:
                    b = 0
                    # Yield a copy: `sources` is refilled for the next batch, so yielding it
                    # directly would let later batches overwrite ones the consumer still holds.
                    batch = sources.copy()
                    yield (batch, batch)

In [None]:
def train_autoencoder(model, train_indices, val_indices, n_epochs, batch_size, cutout_size, bands="all"):
    """Fit ``model`` on cutouts from the training tiles, validating on the validation tiles.

    Parameters
    ----------
    model : compiled keras.Model
    train_indices, val_indices : iterable of int
        Indices into the global ``tile_ids`` list.
    n_epochs, batch_size, cutout_size : int
    bands : str
        Band subset passed through to ``get_cutouts`` ("all", "cfis" or PS1 otherwise).

    Returns
    -------
    (model, history)
        The fitted model and the keras History returned by ``model.fit``.

    Raises
    ------
    ValueError
        If either tile set holds fewer than ``batch_size`` cutouts (zero steps per epoch).
    """
    # Materialize the indices: they are used both for counting and by the generators.
    train_indices = list(train_indices)
    val_indices = list(val_indices)

    def _count_cutouts(tile_indices):
        # Number of cutouts stored under <tile_id>/IMAGES, summed over the selected tiles.
        return sum(len(hf.get(tile_ids[i] + "/IMAGES")) for i in tile_indices)

    # Whole batches only; the generators are endless and carry a partial batch over, so
    # epoch boundaries need not line up with complete passes over the tiles.
    train_steps = _count_cutouts(train_indices) // batch_size
    val_steps = _count_cutouts(val_indices) // batch_size
    if train_steps == 0 or val_steps == 0:
        raise ValueError(
            f"train_autoencoder: need at least one batch of {batch_size} cutouts per split "
            f"(got {train_steps} training and {val_steps} validation steps)")

    history = model.fit(get_cutouts(train_indices, batch_size, cutout_size, bands),
                        epochs=n_epochs, steps_per_epoch=train_steps,
                        validation_data=get_cutouts(val_indices, batch_size, cutout_size, bands),
                        validation_steps=val_steps)
    return model, history

## Autoencoder with pooling layers

In [None]:
def create_autoencoder1(shape):
    """Build a convolutional autoencoder with max-pooling and a dense bottleneck.

    Encoder: Conv2D(16)+2x2 max-pool, Conv2D(32)+2x2 max-pool, flatten, Dense(128) and a
    Dense layer sized to refill an (H/4, W/4, 4) feature map. Decoder: reshape to that map,
    then twice (2x up-sampling, Conv2DTranspose), and a final linear Conv2D to ``shape[2]`` bands.

    Parameters
    ----------
    shape : tuple
        Input shape (height, width, n_bands). Height and width must be divisible by 4:
        they are halved twice by pooling and doubled twice by up-sampling, so any other size
        would not be reconstructed at the input size. For 64x64 input the bottleneck is
        Dense(1024) reshaped to (16, 16, 4).

    Returns
    -------
    keras.Model

    Raises
    ------
    ValueError
        If height or width is not divisible by 4.
    """
    if shape[0] % 4 or shape[1] % 4:
        raise ValueError(f"create_autoencoder1: height and width must be divisible by 4, got {shape[:2]}")
    # Spatial size of the decoder's starting feature map (input size after two 2x poolings).
    latent_h, latent_w = shape[0] // 4, shape[1] // 4
    latent_channels = 4

    input_img = keras.Input(shape=shape)
    x = keras.layers.Conv2D(16, kernel_size=3, activation='relu', padding='same')(input_img)
    x = keras.layers.MaxPooling2D((2,2), padding='same')(x)
    x = keras.layers.Conv2D(32, kernel_size=3, activation='relu', padding='same')(x)
    x = keras.layers.MaxPooling2D((2,2), padding='same')(x)
    x = keras.layers.Flatten()(x)
    x = keras.layers.Dense(128)(x)  # no activation: linear projection
    encoded = keras.layers.Dense(latent_h * latent_w * latent_channels)(x)

    x = keras.layers.Reshape((latent_h, latent_w, latent_channels))(encoded)
    x = keras.layers.UpSampling2D((2,2))(x)
    x = keras.layers.Conv2DTranspose(32, kernel_size=3, activation='relu', padding='same')(x)
    x = keras.layers.UpSampling2D((2,2))(x)
    x = keras.layers.Conv2DTranspose(16, kernel_size=3, activation='relu', padding='same')(x)
    decoded = keras.layers.Conv2D(shape[2], (3,3), activation='linear', padding='same')(x)

    return keras.Model(input_img, decoded)

In [None]:
def custom_loss_cfis(y_true, y_pred):
    """Weighted MSE for the 2-band CFIS model.

    Scaling both tensors by sqrt(w) makes keras' MSE (mean over the last axis) equal to the
    mean over bands of w * (y_true - y_pred)**2, with w taken from the global weights_cfis.

    NOTE(review): np.sqrt(weights_cfis) is plain NumPy. When Keras compiles the training step
    into a graph (the default, run_eagerly=False), it is most likely evaluated once at trace
    time and baked in as a constant. Later in-place updates by get_cutouts would then not reach
    the loss. Confirm this, e.g. with run_eagerly=True or by passing weights through y_true.
    """
    root_w = np.sqrt(weights_cfis)
    return keras.losses.MSE(y_true * root_w, y_pred * root_w)

In [None]:
def custom_loss_ps1(y_true, y_pred):
    """Weighted MSE for the 3-band PS1 model.

    Scaling both tensors by sqrt(w) makes keras' MSE (mean over the last axis) equal to the
    mean over bands of w * (y_true - y_pred)**2, with w taken from the global weights_ps1.

    NOTE(review): np.sqrt(weights_ps1) is plain NumPy. Under the default graph-compiled
    training step it is most likely evaluated once at trace time, so per-batch updates of
    weights_ps1 may never reach the loss. Confirm this before relying on per-batch weighting.
    """
    root_w = np.sqrt(weights_ps1)
    return keras.losses.MSE(y_true * root_w, y_pred * root_w)

In [None]:
def custom_loss_all(y_true, y_pred):
    """Weighted MSE for the 5-band (u, g, r, i, z) model.

    Scaling both tensors by sqrt(w) makes keras' MSE (mean over the last axis) equal to the
    mean over bands of w * (y_true - y_pred)**2, with w taken from the global weights_all.
    This name is also passed to keras.models.load_model via custom_objects.

    NOTE(review): np.sqrt(weights_all) is plain NumPy. Under the default graph-compiled
    training step it is most likely evaluated once at trace time, so per-batch updates of
    weights_all may never reach the loss. Confirm this before relying on per-batch weighting.
    """
    root_w = np.sqrt(weights_all)
    return keras.losses.MSE(y_true * root_w, y_pred * root_w)

In [None]:
train_indices = range(10)
val_indices = [12]

In [None]:
autoencoder1 = create_autoencoder1((cutout_size, cutout_size, 5))
#opt = keras.optimizers.Adam(learning_rate=0.005)
autoencoder1.compile(optimizer="adam", loss=custom_loss)

In [None]:
autoencoder1.summary()

In [None]:
(autoencoder1, history1) = train_autoencoder(autoencoder1, train_indices, val_indices, batch_size, cutout_size)

In [None]:
def plot_loss_curves(history, figname=None):
    """Plot per-epoch training and validation loss on a new figure and optionally save it.

    Parameters
    ----------
    history : keras History, dict or pandas DataFrame
        Must provide "loss" and "val_loss" sequences by key. A keras History (as returned
        by ``model.fit``) is unwrapped to its ``.history`` dict.
    figname : str, optional
        File name written under "../Loss Curves/". When None the figure is only displayed.
    """
    # model.fit returns a History object, which is not subscriptable; its metrics live in .history.
    if isinstance(history, keras.callbacks.History):
        history = history.history
    # A fresh figure, so repeated calls do not draw on top of the previous plot.
    fig, ax = plt.subplots()
    ax.plot(history["loss"], color="g", label="Training")
    ax.plot(history["val_loss"], color="b", label="Validation")
    ax.set_title("Loss Curves for Training/Validation Sets")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.legend()
    if figname is not None:
        fig.savefig("../Loss Curves/" + figname)

In [None]:
# Pass the metrics dict (History.history) and a file name for the pooling autoencoder's curves.
plot_loss_curves(history1.history, figname="TenTrainingTilesPooling.png")

In [None]:
def get_test_cutouts(index, n_cutouts, cutout_size, bands="all", start=0):
    """Load consecutive cutouts and their weight maps from one tile for evaluation.

    Reads cutouts ``c{start}``, ``c{start+1}``, ... from ``<tile_ids[index]>/IMAGES`` and
    ``/WEIGHTS``.

    Parameters
    ----------
    index : int
        Index into the global ``tile_ids`` list.
    n_cutouts : int
        Maximum number of cutouts to read.
    cutout_size : int
    bands : {"all", "cfis", other}
        "all" keeps bands [0, 1, 2, 3, 4], "cfis" keeps [0, 2], anything else (PS1) keeps [1, 3, 4].
    start : int
        Index of the first cutout to read.

    Returns
    -------
    (sources, weights)
        float64 arrays of shape (n_found, cutout_size, cutout_size, n_bands). n_found equals
        n_cutouts unless the tile runs out of cutouts first, in which case only the available
        ones are returned (possibly zero).
    """
    if bands == "all":
        band_indices = [0, 1, 2, 3, 4]
    elif bands == "cfis":
        band_indices = [0, 2]
    else: # PS1
        band_indices = [1, 3, 4]
    img_group = hf.get(tile_ids[index] + "/IMAGES")
    wt_group = hf.get(tile_ids[index] + "/WEIGHTS")

    stop = min(start + n_cutouts, len(img_group))
    n_found = max(stop - start, 0)
    out_shape = (n_found, cutout_size, cutout_size, len(band_indices))
    sources = np.zeros(out_shape)
    weights = np.zeros(out_shape)
    for n, i in enumerate(range(start, stop)):
        sources[n] = np.array(img_group.get(f"c{i}"))[:, :, band_indices]
        weights[n] = np.array(wt_group.get(f"c{i}"))[:, :, band_indices]
    return (sources, weights)

In [None]:
test_index = 13
#sources_test_cfis = get_test_cutouts(test_index, 50, cutout_size, "cfis")
#sources_test_ps1 = get_test_cutouts(test_index, 50, cutout_size, "ps1")
(sources_test_all, weights_test_all) = get_test_cutouts(test_index, 50, cutout_size)

In [None]:
sources_test_all.shape

In [None]:
decoded_imgs1 = autoencoder1.predict(sources_test)

In [None]:
def plot_images(images, figname, bands, start=0):
    """Plot one cutout as a grid of z-scaled images: one row per image set, one column per band.

    Parameters
    ----------
    images : array
        Shape (n_rows, n_cutouts, H, W, n_bands), e.g. inputs / reconstructions / residuals
        stacked along axis 0. A 4-D (n_cutouts, H, W, n_bands) array is plotted as one row.
    figname : str
        File name written under "../Plots/".
    bands : sequence of str
        Column titles, one per band.
    start : int
        Index of the cutout (along the n_cutouts axis) to plot.
    """
    images = np.asarray(images)
    if images.ndim == 4:
        images = images[np.newaxis]
    n_rows = images.shape[0]
    # squeeze=False keeps `axes` 2-D even for a single row or a single band.
    fig, axes = plt.subplots(n_rows, len(bands), figsize=(14,8), squeeze=False)
    fig.subplots_adjust(left=0.02, bottom=0.06, right=0.95, top=0.94, wspace=0.45)
    for row in range(n_rows):
        for col, band_name in enumerate(bands):
            panel = images[row, start, :, :, col]
            ax = axes[row, col]
            norm = ImageNormalize(panel, interval=ZScaleInterval())
            im = ax.imshow(panel, norm=norm)
            fig.colorbar(im, fraction=0.045, ax=ax)
            if row == 0:
                ax.set_title(band_name)
    fig.savefig("../Plots/" + figname)

In [None]:
plot_images(sources_test, start=5)

In [None]:
plot_images(decoded_imgs1, start=5)

## Fully convolutional autoencoder

In [None]:
def create_autoencoder2(shape):
    """Fully convolutional autoencoder without spatial downsampling.

    Encoder: a main branch of two Conv2D+BatchNormalization stages (16 then 32 filters) and a
    shortcut branch of one Conv2D(32)+BatchNormalization, both fed from the input and summed.
    Decoder: Conv2DTranspose(32) and Conv2DTranspose(16) (kernel 4, default stride 1), then a
    linear Conv2D back to ``shape[2]`` bands. Every layer uses 'same' padding, so the output
    keeps the input's height and width.
    """
    def conv_bn(tensor, filters):
        # 3x3 ReLU Conv2D followed by BatchNormalization (layers are created in that order).
        tensor = keras.layers.Conv2D(filters, kernel_size=3, activation='relu', padding='same')(tensor)
        return keras.layers.BatchNormalization()(tensor)

    inputs = keras.Input(shape=shape)
    main_branch = conv_bn(conv_bn(inputs, 16), 32)
    shortcut = conv_bn(inputs, 32)
    encoded = keras.layers.Add()([main_branch, shortcut])

    features = encoded
    for filters in (32, 16):
        features = keras.layers.Conv2DTranspose(filters, kernel_size=4, activation='relu', padding='same')(features)
    decoded = keras.layers.Conv2D(shape[2], kernel_size=3, activation='linear', padding='same')(features)

    return keras.Model(inputs, decoded)

# Train on all filters

In [None]:
autoencoder_all = create_autoencoder2((cutout_size, cutout_size, 5))
autoencoder_all.compile(optimizer='adam', loss=custom_loss_all)

In [None]:
autoencoder_all.summary()

In [None]:
autoencoder_all = keras.models.load_model("../Models/autoencoder_64p", 
                                          custom_objects={'custom_loss_all': custom_loss_all})

In [None]:
keras.utils.plot_model(autoencoder_all, to_file='../Models/autoencoder_64p_plot.png', show_shapes=True, show_layer_names=True)

In [None]:
(autoencoder_all, history_all) = train_autoencoder(autoencoder_all, train_indices, val_indices, n_epochs,
                                                   batch_size, cutout_size)
autoencoder_all.save("../Models/autoencoder_64p")
hist_df = pd.DataFrame(history_all.history) 

hist_csv_file = '../Histories/history_64p.csv'
with open(hist_csv_file, mode='a') as f:
    hist_df.to_csv(f)

In [None]:
history_all = pd.read_csv(hist_csv_file)
plot_loss_curves(history_all, figname="TenTrainingTiles64p.png")

In [None]:
decoded_imgs_all = autoencoder_all.predict(sources_test_all)
residuals_all = sources_test_all - decoded_imgs_all

In [None]:
imgs_all = np.zeros((3, *sources_test_all.shape))
imgs_all[0] = sources_test_all
imgs_all[1] = decoded_imgs_all
imgs_all[2] = residuals_all
imgs_all.shape

In [None]:
tile_ids[test_index]

In [None]:
bands=["CFIS u", "PS1 g", "CFIS r", "PS1 i", "PS1 z"]

In [None]:
plot_images(imgs_all, "Cutouts 185.270 c12 64p colorbar.png", bands=bands, start=15)

In [None]:
fig, axes = plt.subplots(1,5, figsize=(14,8))
fig.subplots_adjust(left=0.02, bottom=0.06, right=0.95, top=0.94, wspace=0.6)
start = 14
for i in range(5):
    norm = ImageNormalize(weights_test_all[start,:,:,i], interval=ZScaleInterval())
    im = axes[i].imshow(weights_test_all[start,:,:,i], norm=norm)
    fig.colorbar(im, fraction=0.045, ax=axes[i])
    axes[i].set_title(bands[i])
plt.savefig(f"../Plots/Weights 185.270 c{start} 64p.png")

In [None]:
weights_test_all[6,:,:,0]

In [None]:
plot_images(decoded_imgs_all, "Reconstructed Cutouts 185.270 128p.png", start=5)

In [None]:
plot_images(residuals_all, "Residual Cutouts 185.270 128p.png", start=5)

In [None]:
def min_max_pixels(images, bands, start=0):
    """Print, band by band, the minimum and maximum pixel of cutout ``images[start]``.

    images is indexed as images[cutout, y, x, band]; bands supplies one label per band,
    printed above that band's two lines and followed by a blank line.
    """
    cutout = images[start]
    for band_idx, band_name in enumerate(bands):
        plane = cutout[:, :, band_idx]
        print(band_name)
        print("Min pixel value: " + str(np.min(plane)))
        print("Max pixel value: " + str(np.max(plane)))
        print()

In [None]:
# Per-band pixel ranges of test cutout 5: inputs, then reconstructions, then residuals.
for image_set in (sources_test_all, decoded_imgs_all, residuals_all):
    min_max_pixels(image_set, bands=bands, start=5)

In [None]:
def plot_hist(images, wts, figname, bands, start=0, ylim_top=4000):
    """Pixel-value histograms of one cutout: one row per image set, one column per band.

    Parameters
    ----------
    images : array
        Shape (n_rows, n_cutouts, H, W, n_bands). Rows other than 2 histogram
        ``images[row, start]`` directly. Row 2, when present, is NOT taken from ``images[2]``:
        it shows the weighted residual sqrt(w) * (images[0] - images[1]).
    wts : array
        Weight maps, shape (n_cutouts, H, W, n_bands); only used for row 2.
    figname : str
        File name written under "../Histograms/".
    bands : sequence of str
        Column titles, one per band.
    start : int
        Index of the cutout to histogram.
    ylim_top : float
        Upper y-axis limit applied to every panel (default 4000).
    """
    images = np.asarray(images)
    n_rows = images.shape[0]
    # squeeze=False keeps `axes` 2-D even for a single row or a single band.
    fig, axes = plt.subplots(n_rows, len(bands), figsize=(20,8), squeeze=False)
    for row in range(n_rows):
        for col, band_name in enumerate(bands):
            ax = axes[row, col]
            if row == 2:
                x = images[0, start, :, :, col]
                xr = images[1, start, :, :, col]
                values = np.sqrt(wts[start, :, :, col]) * (x - xr)
            else:
                values = images[row, start, :, :, col]
            ax.hist(values.ravel())
            ax.set_ylim(top=ylim_top)
            if row == 0:
                ax.set_title(band_name)
    fig.savefig("../Histograms/" + figname)

In [None]:
plot_hist(imgs_all, weights_test_all, "Cutouts 185.270 c46 64p weight_res.png", bands=bands, start=46)

In [None]:
plot_hist(decoded_imgs_all, "Reconstructed Cutouts 185.270 128p.png", start=5)

In [None]:
plot_hist(residuals_all, "Residual Cutouts 185.270 128p.png", start=5)

# Train on CFIS filters

In [None]:
autoencoder_cfis = create_autoencoder2((cutout_size, cutout_size, 2))
autoencoder_cfis.compile(optimizer="adam", loss=custom_loss_cfis)

In [None]:
autoencoder_cfis.summary()

In [None]:
(autoencoder_cfis, history_cfis) = train_autoencoder(autoencoder_cfis, train_indices, 
                                                     val_indices, batch_size, cutout_size, bands="cfis")

In [None]:
plot_loss_curves(history_cfis, figname="TenTrainingTilesCFIS.png")

In [None]:
decoded_imgs_cfis = autoencoder_cfis.predict(sources_test_cfis)
residuals_cfis = sources_test_cfis - decoded_imgs_cfis

In [None]:
plot_images(sources_test_cfis, "Test Cutouts 185.270 CFIS.png", bands=["CFIS u", "CFIS r"], start=2)

In [None]:
plot_images(decoded_imgs_cfis, "Reconstructed Cutouts 185.270 CFIS.png", bands=["CFIS u", "CFIS r"], start=2)

In [None]:
plot_images(residuals_cfis, "Residual Cutouts 185.270 CFIS.png", bands=["CFIS u", "CFIS r"], start=2)

In [None]:
min_max_pixels(sources_test_cfis, bands=["CFIS u", "CFIS r"], start=2)

In [None]:
min_max_pixels(decoded_imgs_cfis, bands=["CFIS u", "CFIS r"], start=2)

In [None]:
min_max_pixels(residuals_cfis, bands=["CFIS u", "CFIS r"], start=2)

In [None]:
plot_hist(sources_test_cfis, "Test Cutouts 185.270 CFIS.png", bands=["CFIS u", "CFIS r"], start=2)

In [None]:
plot_hist(decoded_imgs_cfis, "Reconstructed Cutouts 185.270 CFIS.png", bands=["CFIS u", "PS1 r"], start=2)

In [None]:
plot_hist(residuals_cfis, "Residual Cutouts 185.270 CFIS.png", bands=["CFIS u", "CFIS r"], start=2)

# Train on PS1 filters

In [None]:
autoencoder_ps1 = create_autoencoder2((cutout_size, cutout_size, 3))
autoencoder_ps1.compile(optimizer="adam", loss=custom_loss_ps1)

In [None]:
autoencoder_ps1.summary()

In [None]:
(autoencoder_ps1, history_ps1) = train_autoencoder(autoencoder_ps1, train_indices, 
                                                     val_indices, batch_size, cutout_size, bands="ps1")

In [None]:
plot_loss_curves(history_ps1, figname="TenTrainingTilesPS1.png")

In [None]:
decoded_imgs_ps1 = autoencoder_ps1.predict(sources_test_ps1)
residuals_ps1 = sources_test_ps1 - decoded_imgs_ps1

In [None]:
plot_images(sources_test_ps1, "Test Cutouts 185.270 PS1.png", bands=["PS1 g", "PS1 i", "PS1 z"], start=5)

In [None]:
plot_images(decoded_imgs_ps1, "Reconstructed Cutouts 185.270 PS1.png", bands=["PS1 g", "PS1 i", "PS1 z"], start=5)

In [None]:
plot_images(residuals_ps1, "Residual Cutouts 185.270 PS1.png", bands=["PS1 g", "PS1 i", "PS1 z"], start=5)

In [None]:
min_max_pixels(sources_test_ps1, bands=["PS1 g", "PS1 i", "PS1 z"], start=5)

In [None]:
min_max_pixels(decoded_imgs_ps1, bands=["PS1 g", "PS1 i", "PS1 z"], start=5)

In [None]:
min_max_pixels(residuals_ps1, bands=["PS1 g", "PS1 i", "PS1 z"], start=5)

In [None]:
plot_hist(sources_test_ps1, "Test Cutouts 185.270 PS1.png", bands=["PS1 g", "PS1 i", "PS1 z"], start=5)

In [None]:
plot_hist(decoded_imgs_ps1, "Reconstructed Cutouts 185.270 PS1.png", bands=["PS1 g", "PS1 i", "PS1 z"], start=5)

In [None]:
plot_hist(residuals_ps1, "Residual Cutouts 185.270 PS1.png", bands=["PS1 g", "PS1 i", "PS1 z"], start=5)

In [None]:
hf.close()