In [None]:
import os
import shutil
import h5py
import numpy as np
from astropy.nddata.utils import Cutout2D
from astropy.io import fits
from astropy.table import Table
import pandas as pd
import matplotlib.pyplot as plt
from astropy.visualization import (ZScaleInterval, ImageNormalize)
import tensorflow as tf
from tensorflow import keras
from sklearn.utils import shuffle
import umap
from sklearn.preprocessing import StandardScaler

# Build HDF5 cutout files: confirmed lenses (positives) and random non-lenses (negatives)

In [None]:
# Input locations. The defaults are the original cluster paths; set CFIS_IMAGE_DIR /
# CFIS_LABEL_DIR to run elsewhere. os.path.join(..., "") guarantees a trailing "/",
# because later cells build paths by plain string concatenation.
image_dir = os.path.join(os.environ.get("CFIS_IMAGE_DIR", "/home/eyvorch9/projects/rrg-kyi/astro/cfis/W3/"), "")
label_dir = os.path.join(os.environ.get("CFIS_LABEL_DIR", "/home/eyvorch9/scratch/labels/"), "")
# HDF5 group (in the positives file) holding the confirmed labelled cutouts.
label_subdir = "stronglensdb_confirmed_unige/"

In [None]:
# If the hdf5 files already exist (both staged in $SLURM_TMPDIR)
cutout_dir = os.path.expandvars("$SLURM_TMPDIR") + "/"
# hf_pos is only read in this notebook, so open it read-only to rule out accidental
# modification. hf_neg stays "r+" because get_negative_cutouts adds datasets to it.
# NOTE(review): later cells also rely on `cutout_dir` and `hf_pos` from this cell, even on
# the "files do not exist" path — confirm the intended run order.
hf_pos = h5py.File(cutout_dir + "labelled_cutouts_alt.h5", "r")
hf_neg = h5py.File(cutout_dir + "random_cutouts_cfis.h5", "r+")

In [None]:
# If the hdf5 files do not exist: create an empty negatives file in $SCRATCH.
src = os.path.expandvars("$SCRATCH") + "/"
# "w-" creates the file but fails if it already exists. "w" would silently truncate a
# previously generated negatives file if this cell were run by mistake.
hf_neg = h5py.File(src + "random_cutouts_cfis.h5", "w-")
hf_neg.close()

In [None]:
# Copy the negatives file from $SCRATCH into $SLURM_TMPDIR (copy2 also preserves file metadata).
# Note: `src` and `dest` are rebound again by the copy-back cell further down.
src = os.path.expandvars("$SCRATCH") + "/random_cutouts_cfis.h5"
dest = os.path.expandvars("$SLURM_TMPDIR") + "/"
shutil.copy2(src, dest)

In [None]:
# Open the staged copy for writing; get_negative_cutouts adds datasets to it.
# NOTE(review): if the "files already exist" cell was also run, hf_neg is already open on this
# same path — close it first.
hf_neg = h5py.File(dest + "random_cutouts_cfis.h5", "r+")

In [None]:
# Read the list of tile file names from the image directory and echo it.
# `with` guarantees the handle is closed. rstrip("\n") removes only a real newline, so a
# final line without a trailing newline keeps its last character (slicing [:-1] would not).
with open(image_dir + "tiles.list", "r") as tile_list:
    tile_files = [line.rstrip("\n") for line in tile_list]
for tile_file in tile_files:
    print(tile_file)

In [None]:
# Tile to draw random (negative) cutouts from.
tile_id = "157.275"
# Stage the CFIS u- and r-band images plus the r-band catalogue in cutout_dir.
# The PS1 g/i/z lines are disabled; those channels stay zero in the negative cutouts.
shutil.copy2(image_dir + f"CFIS.{tile_id}.u.fits", cutout_dir)
#shutil.copy2(image_dir + f"PS1.{tile_id}.g.fits", cutout_dir)
shutil.copy2(image_dir + f"CFIS.{tile_id}.r.fits", cutout_dir)
#shutil.copy2(image_dir + f"PS1.{tile_id}.i.fits", cutout_dir)
#shutil.copy2(image_dir + f"PS1.{tile_id}.z.fits", cutout_dir)
shutil.copy2(image_dir + f"CFIS.{tile_id}.r.cat", cutout_dir)
# Local image paths read by get_negative_cutouts.
u_image = cutout_dir + f"CFIS.{tile_id}.u.fits"
#g_image = cutout_dir + f"PS1.{tile_id}.g.fits"
r_image = cutout_dir + f"CFIS.{tile_id}.r.fits"
#i_image = cutout_dir + f"PS1.{tile_id}.i.fits"
#z_image = cutout_dir + f"PS1.{tile_id}.z.fits"

In [None]:
filters = ["CFIS u/", "PS1 g/", "CFIS r/", "PS1 i/", "PS1 z/"]
# Band group-name prefix -> channel index in the 5-channel cutout arrays.
filter_dict = dict(zip(filters, range(len(filters))))

In [None]:
def get_confirmed_cutouts():
    """Stack every cutout stored under ``label_subdir`` in ``hf_pos`` into one array.

    Expected HDF5 layout:
    ``label_subdir/<tile_id>/<filter>/IMAGES/<tile_id><i>``, where ``<filter> + "/"`` is
    a key of ``filter_dict``. The number of cutouts per tile is taken from the tile's
    first filter group.

    Returns
    -------
    np.ndarray
        Shape ``(n_cutouts, cutout_size, cutout_size, 5)``. Channel order follows
        ``filters``; bands with no group for a tile stay zero.

    Raises
    ------
    KeyError
        If a filter group name is not in ``filter_dict`` or an expected dataset is missing.
    """
    tile_ids = list(hf_pos[label_subdir].keys())

    # First pass: per tile, the filter groups present and the cutout count.
    tile_info = []
    for tile_id in tile_ids:
        tile_group = hf_pos[label_subdir + tile_id]
        filt_names = list(tile_group.keys())
        n_images = len(tile_group[filt_names[0] + "/IMAGES"])
        tile_info.append((tile_id, filt_names, n_images))

    n_cutouts = sum(n_images for _, _, n_images in tile_info)
    confirmed_cutouts = np.zeros((n_cutouts, cutout_size, cutout_size, 5))

    # Second pass: fill each band's channel directly in the output array.
    count = 0
    for tile_id, filt_names, n_images in tile_info:
        # Loop-invariant per tile: (channel index, dataset path prefix) per band. Indexing
        # filter_dict (rather than .get) fails clearly on an unexpected band name.
        channels = [
            (filter_dict[name + "/"], label_subdir + tile_id + "/" + name + "/IMAGES/")
            for name in filt_names
        ]
        for i in range(n_images):
            dataset_name = tile_id + str(i)
            for channel, prefix in channels:
                confirmed_cutouts[count, :, :, channel] = hf_pos[prefix + dataset_name]
            count += 1
    return confirmed_cutouts

In [None]:
def create_cutout(fits_file, x, y):
    """Cut a ``cutout_size`` x ``cutout_size`` stamp at (x, y) from ``fits_file[0]`` and rescale it.

    Parameters
    ----------
    fits_file : astropy HDUList
        Opened FITS file; the image is taken from the primary HDU's ``data``.
    x, y : float
        Centre of the stamp, passed straight to ``Cutout2D`` (0-based pixel coordinates).

    Returns
    -------
    np.ndarray or None
        The rescaled 2-D stamp. Returns None when at least 5% of its pixels are NaN, or
        when it is entirely zero (e.g. it lies outside the image; ``mode="partial"`` pads
        with 0). Returns an all-zero stamp when the 1st and 99th percentiles coincide.
    """
    # copy=True: with the default copy=False, a stamp fully inside the image is a view
    # into the (possibly memory-mapped) image array. The NaN-zeroing below would then
    # write into the shared image and change the NaN test for later, overlapping stamps.
    cutout = Cutout2D(fits_file[0].data, (x, y), cutout_size, mode="partial", fill_value=0, copy=True).data
    if np.count_nonzero(np.isnan(cutout)) >= 0.05*cutout_size**2 or np.count_nonzero(cutout) == 0:  # Don't use this cutout
        return None
    cutout[np.isnan(cutout)] = 0
    lower = np.percentile(cutout, 1)
    upper = np.percentile(cutout, 99)
    if lower == upper:
        return np.zeros((cutout_size, cutout_size))
    # NOTE(review): the offset uses the minimum, while the scale uses the 1-99 percentile
    # range; a conventional percentile stretch would subtract `lower`. Left unchanged so
    # these stamps presumably match the existing positives / pretrained encoder — confirm.
    return (cutout - np.min(cutout)) / (upper - lower)

In [None]:
def get_negative_cutouts(n_negative=6370):
    """Write u/r-band cutouts of catalogue sources from tile ``tile_id`` into ``hf_neg``.

    Sources are visited in catalogue order from ``CFIS.<tile_id>.r.cat`` in ``cutout_dir``.
    A source is skipped when any of these holds:

    * FLAGS != 0;
    * MAG_AUTO >= 99;
    * MAGERR_AUTO is outside the open interval (0, 1);
    * either band's stamp is rejected by ``create_cutout``.

    Accepted stamps are stored as datasets ``"cutout0"``, ``"cutout1"``, ... with shape
    ``(cutout_size, cutout_size, 5)``. Only channel 0 (CFIS u) and channel 2 (CFIS r) are
    filled; the PS1 g/i/z channels stay zero.

    Parameters
    ----------
    n_negative : int, optional
        Maximum number of cutouts to write (default 6370). A message is printed if the
        catalogue yields fewer.
    """
    # Read the catalogue from cutout_dir, where the tile-staging cell copied it. The
    # global `dest` is rebound to $SCRATCH/ by the copy-back cell, so it is not reliable.
    cat = Table.read(cutout_dir + f"CFIS.{tile_id}.r.cat", format="ascii.sextractor")
    n = 0
    # `with` closes both images however the loop ends, including when the catalogue is
    # exhausted before n_negative cutouts have been written.
    with fits.open(u_image, memmap=True) as u_fits, fits.open(r_image, memmap=True) as r_fits:
        for i in range(len(cat)):
            if n >= n_negative:
                break
            if cat["FLAGS"][i] != 0 or cat["MAG_AUTO"][i] >= 99.0 or cat["MAGERR_AUTO"][i] <= 0 or cat["MAGERR_AUTO"][i] >= 1:
                continue
            # SExtractor X_IMAGE/Y_IMAGE follow the FITS convention (first pixel centre
            # is 1), while Cutout2D expects 0-based pixel coordinates.
            x = cat["X_IMAGE"][i] - 1
            y = cat["Y_IMAGE"][i] - 1

            u_cutout = create_cutout(u_fits, x, y)
            if u_cutout is None:
                continue
            r_cutout = create_cutout(r_fits, x, y)
            if r_cutout is None:
                continue

            cutout = np.zeros((cutout_size, cutout_size, 5))
            cutout[:, :, 0] = u_cutout  # CFIS u
            cutout[:, :, 2] = r_cutout  # CFIS r
            hf_neg.create_dataset(f"cutout{n}", data=cutout)
            n += 1
    if n < n_negative:
        print(f"Only {n} of {n_negative} negative cutouts written for tile {tile_id}.")

In [None]:
# Cutout side length in pixels; the cutout helpers above read this global.
cutout_size = 128
# All confirmed positives in memory, shape (n, cutout_size, cutout_size, 5).
confirmed_cutouts = get_confirmed_cutouts()

In [None]:
# Populate hf_neg with "cutout<n>" datasets drawn from tile `tile_id`.
get_negative_cutouts()

In [None]:
# Close the negatives file and copy it from $SLURM_TMPDIR back to $SCRATCH
# (the location the staging cell copies it from).
# NOTE: this rebinds the globals `src` and `dest`.
hf_neg.close()
src = os.path.expandvars("$SLURM_TMPDIR") + "/random_cutouts_cfis.h5"
dest = os.path.expandvars("$SCRATCH") + "/"
shutil.copy2(src, dest)

In [None]:
# Reopen the local negatives file for the training/evaluation cells below.
hf_neg = h5py.File(cutout_dir + "random_cutouts_cfis.h5", "r+")

# Create and train the lens / non-lens classifier

In [None]:
def get_cutouts(pos_start, pos_end, neg_start, neg_end, batch_size):
    """Yield an endless stream of standardized, class-interleaved training batches.

    Positives come from ``confirmed_cutouts[pos_start:pos_end]`` and negatives from the
    ``hf_neg`` datasets ``"cutout<neg_start>"`` ... ``"cutout<neg_end - 1>"``. Both ranges
    are cycled indefinitely. Roughly one positive is emitted every
    ``ratio = n_neg // n_pos`` samples; the first sample is always a negative.

    Each batch is standardized feature-wise (per pixel and channel) with a StandardScaler
    fitted on that batch alone.

    Yields
    ------
    (np.ndarray, np.ndarray)
        Cutouts of shape ``(batch_size, cutout_size, cutout_size, 5)`` and 0/1 labels of
        shape ``(batch_size,)``.

    Raises
    ------
    ValueError
        (on the first ``next()``) if either index range is empty.
    """
    n_pos = pos_end - pos_start
    n_neg = neg_end - neg_start
    if n_pos <= 0 or n_neg <= 0:
        raise ValueError(f"empty index range: {n_pos} positives, {n_neg} negatives")
    # At least 1, so `count % ratio` cannot divide by zero when positives outnumber negatives.
    ratio = max(1, n_neg // n_pos)
    n_features = cutout_size*cutout_size*5
    cutouts = np.zeros((batch_size, cutout_size, cutout_size, 5))
    labels = np.zeros(batch_size)
    pos_index = pos_start
    neg_index = neg_start
    b = 0  # position within the current batch
    count = 0  # samples emitted so far
    while True:
        if count > 0 and count % ratio == 0:
            cutouts[b] = confirmed_cutouts[pos_index]
            labels[b] = 1
            pos_index += 1
            if pos_index == pos_end:
                pos_index = pos_start
        else:
            cutouts[b] = np.array(hf_neg.get(f"cutout{neg_index}"))
            labels[b] = 0
            neg_index += 1
            if neg_index == neg_end:
                neg_index = neg_start
        b += 1
        count += 1
        if b == batch_size:
            b = 0
            cutouts_scaled = StandardScaler().fit_transform(cutouts.reshape(batch_size, n_features))
            # Copy the labels: this buffer is refilled for the next batch, so a consumer
            # still holding the yielded array would otherwise see it overwritten.
            yield (cutouts_scaled.reshape(cutouts.shape), labels.copy())

In [None]:
def train_classifier(model, n_epochs, batch_size):
    """Fit ``model`` on streamed, class-weighted batches and validate on a held-out slice.

    Both the negatives (``hf_neg`` datasets) and the positives (``confirmed_cutouts``) are
    split by index into the first 70% for training and the next 20% for validation; the
    last 10% is left for the test cells. The module-level ``callback`` is the only Keras
    callback used.

    Parameters
    ----------
    model : compiled Keras model
    n_epochs : int
    batch_size : int

    Returns
    -------
    (model, history)
        The trained model and the History object returned by ``model.fit``.
    """
    n_neg_total = len(hf_neg)
    n_pos_total = len(confirmed_cutouts)
    neg_train_end, neg_val_end = int(0.7*n_neg_total), int(0.9*n_neg_total)
    pos_train_end, pos_val_end = int(0.7*n_pos_total), int(0.9*n_pos_total)

    n_train = neg_train_end + pos_train_end
    n_val = (neg_val_end - neg_train_end) + (pos_val_end - pos_train_end)

    # Inverse-frequency class weights (training total / class count).
    class_weight = {
        0: n_train / neg_train_end,
        1: n_train / pos_train_end,
    }

    train_batches = get_cutouts(0, pos_train_end, 0, neg_train_end, batch_size)
    val_batches = get_cutouts(pos_train_end, pos_val_end, neg_train_end, neg_val_end, batch_size)
    history = model.fit(
        train_batches,
        epochs=n_epochs,
        steps_per_epoch=n_train // batch_size,
        validation_data=val_batches,
        validation_steps=n_val // batch_size,
        callbacks=[callback],
        class_weight=class_weight,
    )
    return model, history

In [None]:
# Number of top-level datasets in the negatives file (basis of the 70/20/10 split).
len(hf_neg)

In [None]:
def create_classifier(encoder, hidden_activation="relu"):
    """Stack a dense head with a sigmoid output on top of ``encoder``.

    Parameters
    ----------
    encoder : keras.Model
        Feature extractor; its output is flattened before the dense head.
    hidden_activation : str or None, optional
        Activation of the two hidden Dense layers (default "relu"). Without a
        nonlinearity, Dense(128) -> Dense(64) -> Dense(1) collapses to a single linear
        map and the hidden layers add no capacity. Pass None to reproduce the previous
        purely linear head.

    Returns
    -------
    keras.Sequential
        Uncompiled binary classifier producing one sigmoid score per input.
    """
    model = keras.Sequential([encoder])
    model.add(keras.layers.Flatten())
    model.add(keras.layers.Dense(128, activation=hidden_activation))
    model.add(keras.layers.Dense(64, activation=hidden_activation))
    model.add(keras.layers.Dense(1, activation="sigmoid"))
    return model

In [None]:
def custom_loss_all(y_true, y_pred):
    """Weighted MSE: both tensors are scaled by sqrt(weights_all) before ``keras.losses.MSE``.

    Passed to ``load_model`` via ``custom_objects`` so the saved autoencoder can be
    deserialized; presumably the loss the autoencoder was compiled with.
    NOTE(review): ``weights_all`` is not defined anywhere in this notebook, so calling
    this function here (e.g. by fitting/evaluating the autoencoder) would raise NameError.
    """
    sqrt_weights = np.sqrt(weights_all)
    return keras.losses.MSE(y_true * sqrt_weights, y_pred * sqrt_weights)

In [None]:
# Load the pretrained autoencoder; custom_objects lets Keras resolve its custom loss.
autoencoder = keras.models.load_model(
    "../Models/autoencoder_128p",
    custom_objects={"custom_loss_all": custom_loss_all},
)
# Encoder = the sub-graph from the autoencoder input up to the output of layer index 7.
# NOTE(review): index 7 is assumed to be the bottleneck — confirm against autoencoder.summary().
encoder = keras.Model(inputs=autoencoder.input, outputs=autoencoder.layers[7].output)

In [None]:
# Freeze every encoder layer so only the new dense head is trained.
for i in range(len(encoder.layers)):
    encoder.layers[i].trainable = False

In [None]:
def scheduler(epoch, lr):
    """Learning-rate schedule: constant for the first 10 epochs, then multiplied by exp(-0.1) each epoch.

    Parameters
    ----------
    epoch : int
        0-based epoch index supplied by ``LearningRateScheduler``.
    lr : float
        Current learning rate.
    """
    warmup_epochs = 10
    if epoch >= warmup_epochs:
        return lr * tf.math.exp(-0.1)
    return lr

In [None]:
# Same value as in the cutout-creation section; the batch generator reads this global.
cutout_size = 128
classifier = create_classifier(encoder)
# Adam at 1e-4, with the step-decay `scheduler` applied through a callback during fit.
callback = keras.callbacks.LearningRateScheduler(scheduler)
optimizer = keras.optimizers.Adam(learning_rate=1e-4)
classifier.compile(optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"])

In [None]:
#classifier = keras.models.load_model("../Models/binary_classifier")

In [None]:
# Layer-by-layer overview with output shapes and parameter counts.
classifier.summary()

In [None]:
# Save an architecture diagram (needs pydot and graphviz installed).
keras.utils.plot_model(classifier, to_file="../Models/binary_classifier.png", show_shapes=True, show_layer_names=True)

In [None]:
n_epochs = 100
batch_size = 32
(classifier, history) = train_classifier(classifier, n_epochs, batch_size)
classifier.save("../Models/binary_classifier_alt")
hist_df = pd.DataFrame(history.history)

hist_csv_file = '../Histories/history_binary_classifier_alt.csv'
# Append, so repeated runs accumulate, but write the header only when the file is new.
# Otherwise every run inserts another header row mid-file and the columns no longer
# parse as numbers.
hist_df.to_csv(hist_csv_file, mode='a', header=not os.path.exists(hist_csv_file))

In [None]:
def plot_loss_curves(history, figname):
    """Plot training vs. validation loss per epoch on the current axes and save the figure.

    Parameters
    ----------
    history : mapping
        Must contain "loss" and "val_loss" sequences (e.g. ``History.history``).
    figname : str
        File name of the saved figure, placed under "../Loss Curves/".
    """
    ax = plt.gca()
    for key, colour, label in (("loss", "g", "Training"), ("val_loss", "b", "Validation")):
        ax.plot(history[key], color=colour, label=label)
    ax.set_title("Loss Curves for Training/Validation Sets")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.legend()
    plt.savefig("../Loss Curves/" + figname)

In [None]:
# Plot and save the loss curves of the training run above.
plot_loss_curves(history.history, figname="binary_classifier_alt.png")

In [None]:
# Negative test split: the last 10% of hf_neg (train_classifier uses the first 90%).
test_start = int(0.9*len(hf_neg))
test_end = len(hf_neg)
# Size the array from the actual index range: int(0.1*N) can be one short of
# N - int(0.9*N) (e.g. N=15 gives 1 vs 2), which would overflow the array.
num_negative = test_end - test_start
test_negative = np.zeros((num_negative, cutout_size, cutout_size, 5))
for i, n in enumerate(range(test_start, test_end)):
    test_negative[i] = np.array(hf_neg.get(f"cutout{n}"))

In [None]:
# Positive test split: the last 10% of confirmed_cutouts (train_classifier uses the first 90%).
test_positive = confirmed_cutouts[int(0.9*len(confirmed_cutouts)):]

In [None]:
test_cutouts = np.concatenate([test_negative, test_positive])
test_labels = np.concatenate([np.zeros(len(test_negative), dtype=int), np.ones(len(test_positive), dtype=int)])
# The classifier only ever saw inputs standardized per pixel/channel by StandardScaler
# (see get_cutouts). Apply the same kind of scaling here; evaluating on raw cutouts would
# test a different input distribution. get_cutouts fits the scaler per batch; fitting it
# once on the whole test set is the closest batch-independent equivalent.
n_test = len(test_cutouts)
test_cutouts = StandardScaler().fit_transform(test_cutouts.reshape(n_test, -1)).reshape(test_cutouts.shape)
# Fixed seed so the evaluation order is reproducible.
(test_cutouts, test_labels) = shuffle(test_cutouts, test_labels, random_state=0)

In [None]:
def evaluate_model(model, x_test, y_test, threshold=0.5):
    """Report test loss/accuracy, the score distribution and a confusion matrix.

    Parameters
    ----------
    model : compiled Keras model
        Binary classifier with a single sigmoid output.
    x_test : np.ndarray
        Test inputs, preprocessed like the training batches.
    y_test : np.ndarray
        0/1 labels.
    threshold : float, optional
        Scores strictly above this count as class 1 in the confusion matrix (default 0.5).
        This matches the threshold Keras' binary accuracy uses.
    """
    test_loss, test_acc = model.evaluate(x_test, y_test)
    scores = np.asarray(model.predict(x_test)).ravel()
    plt.hist(scores)
    plt.xlabel("Predicted score")
    plt.ylabel("Count")
    sorted_scores = np.sort(scores)
    print("Lowest 10 scores:")
    print(sorted_scores[:10])
    print()
    print("Highest 10 scores:")
    print(sorted_scores[-10:])
    print()
    # tf.math.confusion_matrix casts predictions to integers, so raw sigmoid scores in
    # [0, 1) would all truncate to class 0. Threshold them into hard labels first.
    y_pred = (scores > threshold).astype(int)
    conf = tf.math.confusion_matrix(y_test, y_pred)
    print(f"Confusion Matrix:\n {conf}")
    print("Test loss: %.3f" % test_loss)
    print("Test accuracy: %.3f" % test_acc)

In [None]:
# Evaluate on the held-out 10% test split built above.
evaluate_model(classifier, test_cutouts, test_labels)

In [None]:
# Release the HDF5 file handles.
hf_pos.close()
hf_neg.close()