In [None]:
import os

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from astropy import table
from astropy.io import fits
from astropy.nddata.utils import Cutout2D
from astropy.visualization import (ZScaleInterval, ImageNormalize)

In [None]:
# Directory holding the CFIS W3 tile images, weight maps, catalogues and tiles.list.
# Override with the CFIS_W3_DIR environment variable instead of editing this path.
# os.path.join(..., "") guarantees a trailing separator, because later cells
# build file paths by plain string concatenation (image_dir + name).
image_dir = os.path.join(
    os.environ.get("CFIS_W3_DIR", "/home/eyvorch9/projects/rrg-kyi/astro/cfis/W3/"), ""
)

In [None]:
# Show the raw contents of tiles.list, one tile per line.
# rstrip("\n") (not [:-1]) so a final line without a newline keeps its last character;
# the with-block closes the file even if printing fails.
with open(image_dir + "tiles.list", "r") as tile_list:
    for tile in tile_list:
        print(tile.rstrip("\n"))

In [None]:
# Only use tiles with all five channels.
# Each line of tiles.list holds whitespace-separated tile basenames in u,g,r,i,z order.
# The u- and r-band weight maps use the ".weight.fits.fz" suffix while g, i and z
# use ".wt.fits"; the source catalogue is taken from the r-band tile.
u_images, u_weights = [], []
g_images, g_weights = [], []
r_images, r_weights = [], []
i_images, i_weights = [], []
z_images, z_weights = [], []
cats = []

with open(image_dir + "tiles.list", "r") as tile_list:
    for line in tile_list:
        # split() with no argument ignores the trailing newline (present or not)
        # and repeated spaces, unlike [:-1] followed by split(" ").
        channels = line.split()
        if len(channels) != 5:  # Order is u,g,r,i,z
            continue
        u_tile, g_tile, r_tile, i_tile, z_tile = channels
        u_images.append(image_dir + u_tile + ".fits")
        u_weights.append(image_dir + u_tile + ".weight.fits.fz")
        g_images.append(image_dir + g_tile + ".fits")
        g_weights.append(image_dir + g_tile + ".wt.fits")
        r_images.append(image_dir + r_tile + ".fits")
        r_weights.append(image_dir + r_tile + ".weight.fits.fz")
        i_images.append(image_dir + i_tile + ".fits")
        i_weights.append(image_dir + i_tile + ".wt.fits")
        z_images.append(image_dir + z_tile + ".fits")
        z_weights.append(image_dir + z_tile + ".wt.fits")
        cats.append(image_dir + r_tile + ".cat")

In [None]:
# Sanity check: one count per file list, each printed on its own line.
# All eleven counts should agree (one entry per five-band tile).
print(
    *(len(paths) for paths in (
        u_images, u_weights,
        g_images, g_weights,
        r_images, r_weights,
        i_images, i_weights,
        z_images, z_weights,
        cats,
    )),
    sep="\n",
)

In [None]:
# Extract up to `lim` normalized (size x size x 5) cutouts centred on the r-band
# SExtractor catalogue positions. Channel order is u,g,r,i,z; the u channel is
# left at zero (u-band normalization is currently disabled).
# Only use 25000 cutouts to reduce tensor size
lim = 25000
size = 32
max_nan_frac = 0.05  # skip a source if more than 5% of any used band's cutout is NaN


def _clean_cutout(data, x, y):
    """Return a NaN-free copy of the ``size`` x ``size`` cutout at catalogue position (x, y).

    Returns None when more than ``max_nan_frac`` of the cutout pixels are NaN.
    Pixels falling outside the image are filled with 0 (``mode="partial"``);
    remaining NaNs are replaced by 0.
    """
    # SExtractor X_IMAGE/Y_IMAGE use the FITS 1-based pixel convention, while
    # Cutout2D expects 0-based pixel coordinates.
    # copy=True: without it a fully-contained cutout is a view into the image
    # array, so zeroing NaNs below would modify the image itself and hide NaNs
    # from neighbouring sources whose cutouts overlap.
    cutout = Cutout2D(data, (x - 1, y - 1), size, mode="partial",
                      fill_value=0, copy=True).data
    nan_mask = np.isnan(cutout)
    if np.count_nonzero(nan_mask) > max_nan_frac * cutout.size:
        return None
    cutout[nan_mask] = 0
    return cutout


def _percentile_normalize(cutout):
    """Linearly map the 1st/99th percentiles of ``cutout`` to 0/1; a flat cutout maps to zeros."""
    lower, upper = np.percentile(cutout, [1, 99])
    if upper == lower:
        # Avoid 0/0 -> NaN, which would poison the MSE training loss.
        return np.zeros(cutout.shape)
    return (cutout - lower) / (upper - lower)


# float32: Keras trains in float32 anyway, and this halves the array's memory.
sources = np.zeros((lim, size, size, 5), dtype=np.float32)  # Order of channels is u,g,r,i,z. Normalized data.
n = 0
for cat_path, g_path, r_path, i_path, z_path in zip(cats, g_images, r_images, i_images, z_images):
    cat = table.Table.read(cat_path, format="ascii.sextractor")
    # with-blocks close every FITS file even if a cutout raises.
    with fits.open(g_path, memmap=True) as g_image, \
            fits.open(r_path, memmap=True) as r_image, \
            fits.open(i_path, memmap=True) as i_image, \
            fits.open(z_path, memmap=True) as z_image:
        # Bands for channels 1..4 (g, r, i, z); channel 0 (u) stays zero.
        band_data = [hdul[0].data for hdul in (g_image, r_image, i_image, z_image)]
        for x, y in zip(cat["X_IMAGE"], cat["Y_IMAGE"]):
            # Every used band (including r, which was previously unchecked)
            # must pass the NaN filter.
            cutouts = [_clean_cutout(data, x, y) for data in band_data]
            if any(cutout is None for cutout in cutouts):
                continue
            for channel, cutout in enumerate(cutouts, start=1):
                sources[n, :, :, channel] = _percentile_normalize(cutout)
            n += 1
            if n == lim:
                break
    if n == lim:
        break

# Drop unfilled rows so all-zero placeholder "cutouts" never reach training.
sources = sources[:n]
print(f"Extracted {n} cutouts")

In [None]:
def plot_images(images, start=0, n_rows=3):
    """Show consecutive cutouts as a grid: one row per cutout, one column per channel.

    Parameters
    ----------
    images : array indexed as ``images[index, :, :, channel]``.
    start : index of the first cutout to display.
    n_rows : number of consecutive cutouts to display (default 3).

    Each panel gets its own ImageNormalize built from that panel's data.
    """
    n_channels = images.shape[-1]
    # squeeze=False keeps `axes` 2-D even for a single row or column.
    fig, axes = plt.subplots(n_rows, n_channels, figsize=(12, 8), squeeze=False)
    for row in range(n_rows):
        for channel in range(n_channels):
            panel = images[start + row, :, :, channel]
            axes[row, channel].imshow(panel, norm=ImageNormalize(panel))
    # Label columns with band names when the data has the u,g,r,i,z layout.
    for channel in range(n_channels):
        label = "ugriz"[channel] if n_channels == 5 else f"channel {channel}"
        axes[0, channel].set_title(label)

In [None]:
# Preview the first three extracted cutouts, one column per band.
plot_images(images=sources, start=0)

In [None]:
# Ordered (unshuffled) split: the first 80% of cutouts train, the remainder test.
threshold = int(0.8 * len(sources))
sources_train, sources_test = sources[:threshold], sources[threshold:]

In [None]:
def create_autoencoder1(shape):
    """Build a convolutional autoencoder with a dense bottleneck.

    Parameters
    ----------
    shape : (height, width, channels)
        Input cutout shape. Height and width must be divisible by 4, because the
        encoder pools twice by 2 and the decoder upsamples twice by 2.

    Returns
    -------
    keras.Model mapping inputs to reconstructions with ``channels`` output
    channels and a linear final activation.

    Raises
    ------
    ValueError if height or width is not divisible by 4.
    """
    height, width, channels = shape
    if height % 4 or width % 4:
        raise ValueError(f"height and width must be divisible by 4, got shape {tuple(shape)}")
    # The dense code is reshaped to (height/4, width/4, 4); for 32x32 inputs this
    # is 8x8x4 = 256 units, matching the original hard-coded sizes.
    code_h, code_w, code_c = height // 4, width // 4, 4

    input_img = keras.Input(shape=shape)
    x = keras.layers.Conv2D(16, kernel_size=3, activation='relu', padding='same')(input_img)
    x = keras.layers.MaxPooling2D((2, 2), padding='same')(x)
    x = keras.layers.Conv2D(32, kernel_size=3, activation='relu', padding='same')(x)
    x = keras.layers.MaxPooling2D((2, 2), padding='same')(x)
    x = keras.layers.Flatten()(x)
    # NOTE(review): neither Dense layer has an activation, so together they form
    # one linear map of rank <= 128; the effective bottleneck is 128 wide.
    x = keras.layers.Dense(128)(x)
    encoded = keras.layers.Dense(code_h * code_w * code_c)(x)

    x = keras.layers.Reshape((code_h, code_w, code_c))(encoded)
    x = keras.layers.UpSampling2D((2, 2))(x)
    x = keras.layers.Conv2DTranspose(32, kernel_size=3, activation='relu', padding='same')(x)
    x = keras.layers.UpSampling2D((2, 2))(x)
    x = keras.layers.Conv2DTranspose(16, kernel_size=3, activation='relu', padding='same')(x)
    decoded = keras.layers.Conv2D(channels, (3, 3), activation='linear', padding='same')(x)

    return keras.Model(input_img, decoded)

In [None]:
# Build autoencoder 1 for (height, width, channels) cutouts; train with Adam on MSE.
autoencoder1 = create_autoencoder1(sources.shape[1:])
autoencoder1.compile(loss="mse", optimizer="adam")

In [None]:
# Print the architecture summary of autoencoder 1.
autoencoder1.summary()

In [None]:
# Train autoencoder 1 to reproduce its own input.
history1 = autoencoder1.fit(
    x=sources_train,
    y=sources_train,  # autoencoder: the target is the input itself
    batch_size=128,
    epochs=100,
    validation_split=0.2,
)

In [None]:
def plot_loss_curves(history, ax=None):
    """Plot per-epoch training (and, when present, validation) loss from a Keras fit.

    Parameters
    ----------
    history : object whose ``history`` dict holds a "loss" list and, optionally,
        a "val_loss" list.
    ax : matplotlib Axes to draw on. When None a new figure is created, so the
        curves never land on whichever axes happen to be current.
    """
    if ax is None:
        _, ax = plt.subplots()
    ax.plot(history.history["loss"], color="g", label="Training")
    # "val_loss" may be missing if the model was fit without validation data.
    if "val_loss" in history.history:
        ax.plot(history.history["val_loss"], color="b", label="Validation")
    ax.set_title("Loss Curves for Training/Validation Sets")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.legend()

In [None]:
# Loss curves for autoencoder 1.
plot_loss_curves(history=history1)

In [None]:
# Reconstruct the first 100 held-out cutouts with autoencoder 1.
decoded_imgs1 = autoencoder1.predict(x=sources_test[:100])

In [None]:
# Original held-out cutouts, for side-by-side comparison with the reconstructions.
plot_images(images=sources_test, start=0)

In [None]:
# Autoencoder 1 reconstructions of the same held-out cutouts.
plot_images(images=decoded_imgs1, start=0)

In [None]:
def create_autoencoder2(shape):
    """Build a small fully convolutional autoencoder (no dense bottleneck).

    The encoder applies an 8-filter conv, one 2x2 max-pool and a 16-filter conv.
    The decoder mirrors it with transposed convs and a 2x2 upsample, then projects
    back to ``shape[2]`` channels with a linear 3x3 conv.

    NOTE(review): the code is 16 channels at half resolution, i.e. more values
    than a 5-channel input, so this model does not compress. Odd height/width
    would also give an output larger than the input.
    """
    input_img = keras.Input(shape=shape)
    # Layers are created in the same order they are applied.
    stack = [
        keras.layers.Conv2D(8, kernel_size=3, activation='relu', padding='same'),
        keras.layers.MaxPooling2D((2, 2), padding='same'),
        keras.layers.Conv2D(16, kernel_size=3, activation='relu', padding='same'),  # code
        keras.layers.Conv2DTranspose(16, kernel_size=3, activation='relu', padding='same'),
        keras.layers.UpSampling2D((2, 2)),
        keras.layers.Conv2DTranspose(8, kernel_size=3, activation='relu', padding='same'),
        keras.layers.Conv2D(shape[2], (3, 3), activation='linear', padding='same'),
    ]
    tensor = input_img
    for layer in stack:
        tensor = layer(tensor)
    return keras.Model(input_img, tensor)

In [None]:
# Build autoencoder 2 for (height, width, channels) cutouts; train with Adam on MSE.
autoencoder2 = create_autoencoder2(sources.shape[1:])
autoencoder2.compile(loss="mse", optimizer='adam')

In [None]:
# Print the architecture summary of autoencoder 2.
autoencoder2.summary()

In [None]:
# Train autoencoder 2 to reproduce its own input.
history2 = autoencoder2.fit(
    x=sources_train,
    y=sources_train,  # autoencoder: the target is the input itself
    batch_size=128,
    epochs=100,
    validation_split=0.2,
)

In [None]:
# Loss curves for autoencoder 2.
plot_loss_curves(history=history2)

In [None]:
# Reconstruct the first 100 held-out cutouts with autoencoder 2.
decoded_imgs2 = autoencoder2.predict(x=sources_test[:100])

In [None]:
# Original held-out cutouts again, next to autoencoder 2's reconstructions.
plot_images(images=sources_test, start=0)

In [None]:
# Autoencoder 2 reconstructions of the same held-out cutouts.
plot_images(images=decoded_imgs2, start=0)

In [None]:
class AutoEnc(keras.Model):
    """ResNet50 encoder with a regression head, plus a dense/transposed-conv decoder.

    Two functional sub-models are built on ONE shared encoder:

    * ``Enc``: image -> regression output (``num_out`` units), trained with MSE
      against the regression targets.
    * ``Dec``: image -> encoder -> concat(regression output, latent z) -> decoder
      -> reconstructed image, trained with MSE against the input itself.

    Because the encoder layers are shared, training either sub-model updates them.
    The decoder's spatial sizes are hard-coded for 32x32 inputs.
    """

    def __init__(self, input_shape, noise=False):
        """
        Parameters
        ----------
        input_shape : (height, width, channels) of the input images, e.g. (32, 32, 5).
        noise : currently unused; kept for interface compatibility.

        Raises
        ------
        ValueError if the decoder's output shape does not match ``input_shape``.
        """
        super().__init__()

        img_shape = tuple(input_shape)
        self.batch_size = 64
        # tf.keras defines ``input_shape`` as a read-only property on layers/models,
        # so assigning ``self.input_shape`` fails; store the shape under another name.
        self.img_shape = img_shape
        self.noise = noise  # NOTE(review): not used anywhere yet
        self.num_out = 1
        self.num_z = 128
        self.checkpoint = 1  # report averaged losses every `checkpoint` epochs
        self.curr_epoch = 0  # epochs completed across all train() calls

        self.enc_lr = 1e-6
        self.dec_lr = 1e-6
        # `learning_rate` replaces the deprecated `lr` keyword.
        enc_optimizer = keras.optimizers.Adam(learning_rate=self.enc_lr)
        dec_optimizer = keras.optimizers.Adam(learning_rate=self.dec_lr)

        self.inp = keras.layers.Input(img_shape, name='ae_input')
        # Build the encoder ONCE so both sub-models share its weights; calling
        # self.encoder twice would create two independent ResNets.
        x_out, z_out = self.encoder(self.inp)
        # A Keras layer (rather than tf.concat) works on symbolic Keras tensors.
        latent = keras.layers.Concatenate(axis=1)([x_out, z_out])
        decoded = self.decoder(latent)
        if tuple(decoded.shape[1:]) != img_shape:
            raise ValueError(
                f"decoder produces shape {tuple(decoded.shape[1:])}, "
                f"expected {img_shape}; the decoder assumes 32x32 inputs"
            )

        self.Enc = keras.models.Model(self.inp, x_out, name='encoder')
        self.Dec = keras.models.Model(self.inp, decoded, name='decoder')

        self.Enc.compile(loss=keras.losses.MSE, optimizer=enc_optimizer)
        self.Dec.compile(loss=keras.losses.MSE, optimizer=dec_optimizer)

    def encoder(self, y):
        """Map image tensor ``y`` to ``(x_out, z_out)``.

        A randomly initialised ResNet50 without its classification top is followed
        by global average pooling, two LeakyReLU dense layers and two linear heads:
        ``x_out`` (``num_out`` units, regression) and ``z_out`` (``num_z`` units,
        latent code). Each call creates NEW layers with new weights.
        """
        base_model = keras.applications.ResNet50(include_top=False, weights=None,
                                                 input_shape=self.img_shape)
        base_model.trainable = True
        # NOTE(review): training=True keeps BatchNorm in training mode even at inference.
        model_out = base_model(y, training=True)
        model_out = keras.layers.GlobalAveragePooling2D()(model_out)

        x = keras.layers.Dense(512, activation=keras.layers.LeakyReLU(alpha=0.1))(model_out)
        x = keras.layers.Dense(256, activation=keras.layers.LeakyReLU(alpha=0.1))(x)

        #can probably do this in one layer
        x_out = keras.layers.Dense(self.num_out)(x)
        z_out = keras.layers.Dense(self.num_z)(x)

        return x_out, z_out

    def decoder(self, z):
        """Map the concatenated (regression, latent) vector ``z`` back to an image.

        Spatial sizes are hard-coded for 32x32 outputs: the 2x2 seed grows to
        4, 8, 16 and finally 32 through valid-padded transposed convolutions
        (kernel sizes 3, 5, 9, 17). The output channel count follows the input.
        """
        #TODO this decoder was made in a rush and will be changed in future
        y = keras.layers.Dense(self.num_z + self.num_out)(z)
        y = keras.layers.Dense(256, activation=keras.layers.LeakyReLU(alpha=0.1))(y)
        y = keras.layers.Dense(512, activation=keras.layers.LeakyReLU(alpha=0.1))(y)
        y = keras.layers.Dense(8192)(y)
        y = keras.layers.Reshape([2, 2, 2048])(y)

        y = keras.layers.Conv2DTranspose(512, 3)(y)   # 2x2   -> 4x4
        y = keras.layers.Conv2DTranspose(128, 5)(y)   # 4x4   -> 8x8
        y = keras.layers.Conv2DTranspose(64, 9)(y)    # 8x8   -> 16x16
        # Output channels must equal the input's (was hard-coded to 4 for 5-band data).
        y = keras.layers.Conv2DTranspose(self.img_shape[-1], 17)(y)  # 16x16 -> 32x32

        return y

    def train(self, x_train, y_train, x_test, y_test, epochs):
        """Alternate regression and reconstruction updates for ``epochs`` epochs.

        Each epoch shuffles (x_train, y_train) with one shared permutation, then
        for every full batch runs one ``Enc`` step on (x, y) and one ``Dec`` step
        on (x, x). Mean losses over all batches since the previous report are
        printed every ``self.checkpoint`` epochs and after the final epoch.

        x_test, y_test : currently unused (validation not implemented yet).

        Raises
        ------
        ValueError if x_train holds fewer samples than one batch.
        """
        x_train = np.asarray(x_train)
        y_train = np.asarray(y_train)
        n_batches = len(x_train) // self.batch_size
        if n_batches == 0:
            raise ValueError(
                f"need at least batch_size={self.batch_size} samples, got {len(x_train)}"
            )

        reglosses, reconlosses = [], []
        for it in range(epochs):
            # Same permutation for inputs and targets keeps pairs aligned.
            perm = np.random.permutation(len(x_train))
            x_train, y_train = x_train[perm], y_train[perm]

            for j in range(n_batches):
                batch = slice(j * self.batch_size, (j + 1) * self.batch_size)
                x_true, y_true = x_train[batch], y_train[batch]

                #Im working on another version where both are trained simulateously
                enc_loss = self.Enc.train_on_batch(x_true, y_true)
                dec_loss = self.Dec.train_on_batch(x_true, x_true)
                # train_on_batch returns a scalar for a single loss without metrics
                # and a list otherwise; the first entry is the total loss.
                reglosses.append(float(np.ravel(enc_loss)[0]))
                reconlosses.append(float(np.ravel(dec_loss)[0]))

            self.curr_epoch += 1
            if (it + 1) % self.checkpoint == 0 or it == epochs - 1:
                ##TODO add correct saving and validation tests to checkpoints
                print('Iterations %d' % self.curr_epoch)
                print('Regression Loss %f' % np.mean(reglosses))
                print('Reconstruction Loss %f' % np.mean(reconlosses))
                reglosses, reconlosses = [], []