In [None]:
import os
import shutil
from collections import Counter

import h5py
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from astropy import table
from astropy.coordinates import SkyCoord
from astropy.io import fits
from astropy.nddata.utils import Cutout2D
from astropy.visualization import (ZScaleInterval, ImageNormalize)
from astropy.visualization import make_lupton_rgb
from astropy.wcs import WCS
from astropy.wcs.utils import skycoord_to_pixel
from tensorflow import keras

# Initialize the HDF5 file (stage it on node-local storage and open it)

In [None]:
# Show the node-local scratch directory (presumably set by SLURM for this job);
# expandvars returns the text unchanged if $SLURM_TMPDIR is not set.
os.path.expandvars("$SLURM_TMPDIR")

In [None]:
# Open the HDF5 file in the working directory read/write ("r+" requires it to
# exist already) and close it again immediately.
with h5py.File("labelled_cutouts_alt.h5", "r+") as hf:
    pass

In [None]:
# Copy the HDF5 file from $SCRATCH into $SLURM_TMPDIR; later cells work on that copy.
dest = os.path.expandvars("$SLURM_TMPDIR") + "/"
src = os.path.expandvars("$SCRATCH") + "/labelled_cutouts_alt.h5"
shutil.copy2(src, dest)

In [None]:
# Open the node-local copy read/write; `hf` is the handle used by the later cells.
dest = os.path.expandvars("$SLURM_TMPDIR") + "/"
hf = h5py.File(f"{dest}labelled_cutouts_alt.h5", mode="r+")

In [None]:
# Closing `hf` here would break every later cell that reads or writes it when the
# notebook is run top to bottom (the file is closed in the last cell instead).
# Set CLOSE_HDF5_EARLY = True only when you deliberately want to release it now.
CLOSE_HDF5_EARLY = False
if CLOSE_HDF5_EARLY:
    hf.close()

In [None]:
# Directory holding the CFIS/PS1 tile FITS files. It can be overridden with the
# CFIS_W3_DIR environment variable instead of editing this personal absolute path.
# os.path.join(..., "") guarantees the trailing "/" that later string
# concatenations rely on.
image_dir = os.path.join(
    os.environ.get("CFIS_W3_DIR", "/home/eyvorch9/projects/rrg-kyi/astro/cfis/W3/"), "")
# Root of the per-source, per-filter label CSV directories.
label_dir = "labels/"

In [None]:
# Read the list of candidate tile names (one per line) and echo them.
with open(image_dir + "tiles_cand.list", "r") as tile_file:
    # rstrip only removes real line terminators; the previous [:-1] slice also cut
    # the last character of a final line that had no trailing newline.
    tile_list = [line.rstrip("\r\n") for line in tile_file]
for cand_tile in tile_list:
    print(cand_tile)

In [None]:
# Label sources (sub-directories of label_dir / top-level HDF5 groups).
label_subdirs = [
    "stronglensdb_confirmed_unige/",
    "stronglensdb_candidates_unige/",
    "canameras2020/",
    "huang2020a_grade_A/",
    "huang2020a_grade_B/",
    "huang2020a_grade_C/",
    "huang2020b_grade_A/",
    "huang2020b_grade_B/",
    "huang2020b_grade_C/",
]
# Filters in channel order; filter_dict maps each name to its channel index.
filters = ["CFIS u/", "PS1 g/", "CFIS r/", "PS1 i/", "PS1 z/"]
filter_dict = dict(zip(filters, range(len(filters))))

In [None]:
# Total number of label rows across all per-tile CSVs, for each source/filter pair.
for label_subdir in label_subdirs:
    for f in filters:
        subdir = label_dir + label_subdir + f
        nlabels = sum(len(pd.read_csv(subdir + csv_name)) for csv_name in os.listdir(subdir))
        print("Number of labels in {}: {}".format(subdir, nlabels))

In [None]:
!ls labels/stronglensdb_confirmed_unige/CFIS\ u

In [None]:
!ls labels/stronglensdb_confirmed_unige/CFIS\ r

# Save labelled cutouts to hdf5 file

In [None]:
# Extract fixed-size image and weight cutouts around every labelled position and
# store them in the open HDF5 file `hf`, laid out as
#   <label_subdir><tile_id>/<filter>/IMAGES/<tile_id><n>   (normalised image)
#   <label_subdir><tile_id>/<filter>/WEIGHTS/<tile_id><n>  (raw weight cutout)
# NOTE(review): only label_subdirs[2:] are processed - presumably the first two
# sources were written in an earlier run; confirm before a fresh run.
cutout_size = 128


def _normalise_cutout(img_cutout, size):
    """Scale a cutout by its 1st-99th percentile range; all zeros if that range is empty."""
    img_lower = np.percentile(img_cutout, 1)
    img_upper = np.percentile(img_cutout, 99)
    if img_lower == img_upper:
        return np.zeros((size, size))
    # NOTE(review): the offset is the cutout minimum while the scale is the
    # 1-99 percentile range; confirm this mix is intended (kept unchanged so the
    # stored data stays consistent with cutouts written earlier).
    return (img_cutout - np.min(img_cutout)) / (img_upper - img_lower)


for label_subdir in label_subdirs[2:]:
    for f in filters:
        subdir = label_dir + label_subdir + f
        for csv in os.listdir(subdir):
            tile_id = csv[:7]  # XXX.XXX id
            # require_group reuses an existing group or creates it. Handling the two
            # groups separately avoids a KeyError when IMAGES exists but WEIGHTS
            # does not (e.g. after an interrupted run).
            img_group = hf.require_group(label_subdir + tile_id + "/" + f + "/IMAGES")
            wt_group = hf.require_group(label_subdir + tile_id + "/" + f + "/WEIGHTS")
            tile_name = f.split(" ")[0] + "." + tile_id + "." + f.split(" ")[1][0]
            if "CFIS" in f:
                wt_name = ".weight.fits.fz"
                wt_index = 1  # weight data is read from HDU 1 of the .fz file
            else:
                wt_name = ".wt.fits"
                wt_index = 0
            # Stage both tiles on node-local storage before reading them.
            shutil.copy2(image_dir + tile_name + ".fits", dest)
            shutil.copy2(image_dir + tile_name + wt_name, dest)

            df = pd.read_csv(subdir + csv)
            # Context managers close both FITS files even if a cutout raises.
            with fits.open(dest + tile_name + ".fits", memmap=True) as img_fits, \
                    fits.open(dest + tile_name + wt_name, memmap=True) as wt_fits:
                img_data = img_fits[0].data
                wt_data = wt_fits[wt_index].data
                for n, (x, y) in enumerate(zip(df["x"], df["y"])):
                    dataset_name = tile_id + str(n)
                    if dataset_name in img_group:
                        continue  # already stored by a previous run; nothing to write
                    img_cutout = Cutout2D(img_data, (x, y), cutout_size, mode="partial", fill_value=0).data
                    wt_cutout = Cutout2D(wt_data, (x, y), cutout_size, mode="partial", fill_value=0).data
                    img_group.create_dataset(dataset_name, data=_normalise_cutout(img_cutout, cutout_size))
                    wt_group.create_dataset(dataset_name, data=wt_cutout)
    print(f"Finished {label_subdir}")

# Look at confirmed cutouts from one CFIS tile (u and r bands)

In [None]:
# Tiles that have confirmed-lens label CSVs in both CFIS u and CFIS r.
# (filters[1] is "PS1 g/", so the r-band list must use "CFIS r/" explicitly.)
subdir = label_dir + label_subdirs[0]
u_tiles = os.listdir(subdir + "CFIS u/")
r_tiles = os.listdir(subdir + "CFIS r/")
tile_intersection = sorted(set(u_tiles) & set(r_tiles))
print(tile_intersection)
print(len(tile_intersection))

In [None]:
# Load the u- and r-band images of one tile for visual inspection. TILE_INDEX picks
# an entry of the sorted intersection above; the first 7 characters of a label CSV
# name are the tile id.
TILE_INDEX = 26
tile = tile_intersection[TILE_INDEX][:7]
# memmap=False reads the data into memory so the arrays stay valid after the
# context managers close the files (previously both files were left open).
with fits.open(image_dir + "CFIS.{}.u.fits".format(tile), memmap=False) as u_file, \
        fits.open(image_dir + "CFIS.{}.r.fits".format(tile), memmap=False) as r_file:
    r_head = r_file[0].header
    u_head = u_file[0].header
    r_data = r_file[0].data
    u_data = u_file[0].data

In [None]:
tile  # id of the tile loaded above

In [None]:
# Show the full u and r tiles side by side with a z-scale stretch, then save the figure.
channels = ["u", "r"]
plot_tiles = [u_data, r_data]
fig, axes = plt.subplots(1, len(channels), figsize=(12, 8))
for ax, band, band_data in zip(axes, channels, plot_tiles):
    ax.imshow(band_data, norm=ImageNormalize(band_data, interval=ZScaleInterval()))
    ax.set_title(band)
plt.savefig(f"Plots/{tile}_tiles.png")

In [None]:
# World-coordinate solutions from each band's header.
r_wcs, u_wcs = WCS(r_head), WCS(u_head)
# Label CSVs of the selected tile, one per band.
r_df = pd.read_csv(subdir + f"CFIS r/{tile}_labels.csv")
u_df = pd.read_csv(subdir + f"CFIS u/{tile}_labels.csv")

In [None]:
# Pixel position of one labelled object in the r-band tile. Cutout2D needs a single
# (x, y) position, so pick one label (LABEL_INDEX) instead of passing the arrays
# for every row of the CSV.
LABEL_INDEX = 0
x_all, y_all = skycoord_to_pixel(SkyCoord(r_df["ra"], r_df["dec"], unit="deg"), r_wcs)
if len(x_all) <= LABEL_INDEX:
    raise IndexError(f"only {len(x_all)} labels in r-band CSV for tile {tile}")
x, y = float(x_all[LABEL_INDEX]), float(y_all[LABEL_INDEX])
print((x, y))

In [None]:
# Cut both bands at the same r-band pixel position.
# NOTE(review): this assumes the CFIS u and r tiles share one pixel grid (u_wcs is
# not used here); confirm.
position = (x, y)
r_cutout, u_cutout = (
    Cutout2D(band_data, position, cutout_size, mode="partial", fill_value=0).data
    for band_data in (r_data, u_data)
)

In [None]:
# Show the u and r cutouts side by side with a z-scale stretch, then save the figure.
channels = ["u", "r"]
plot_cutouts = [u_cutout, r_cutout]
fig, axes = plt.subplots(1, len(channels), figsize=(12, 8))
for ax, band, band_cutout in zip(axes, channels, plot_cutouts):
    ax.imshow(band_cutout, norm=ImageNormalize(band_cutout, interval=ZScaleInterval()))
    ax.set_title(band)
plt.savefig(f"Plots/{tile}_labelled_cutouts.png")

# Plot cutouts from each label source

In [None]:
def _select_rgb_channels(cutout, filt_indices, cutout_size):
    """Map the available filter channels of a 5-channel cutout onto (R, G, B) planes.

    Channel indices follow `filters` (0=u, 1=g, 2=r, 3=i, 4=z). Channel 2 is
    preferred for R, 1 for G and 3 for B; otherwise other available channels are
    used. With one filter, only one plane is filled. With two filters, B is empty.
    """
    shape = (cutout_size, cutout_size)
    if len(filt_indices) == 1:
        if 2 in filt_indices:  # red
            return cutout[:, :, 2], np.zeros(shape), np.zeros(shape)
        if 1 in filt_indices:  # green
            return np.zeros(shape), cutout[:, :, 1], np.zeros(shape)
        return np.zeros(shape), np.zeros(shape), cutout[:, :, filt_indices[0]]
    cutout_r = cutout[:, :, 2] if 2 in filt_indices else cutout[:, :, filt_indices[0]]
    cutout_g = cutout[:, :, 1] if 1 in filt_indices else cutout[:, :, filt_indices[1]]
    if len(filt_indices) == 2:
        cutout_b = np.zeros(shape)
    else:
        cutout_b = cutout[:, :, 3] if 3 in filt_indices else cutout[:, :, filt_indices[-1]]
    return cutout_r, cutout_g, cutout_b


def get_cutouts(label_subdir, reconstruct=False):
    """Build Lupton RGB images of every cutout stored under one label source.

    Parameters
    ----------
    label_subdir : str
        Top-level HDF5 group / label directory, e.g. "stronglensdb_confirmed_unige/"
        (must end with "/").
    reconstruct : bool, optional
        If True, each 5-channel cutout is passed through the global `autoencoder`
        before compositing.

    Returns
    -------
    tuple
        (rgb_cutouts, types). rgb_cutouts is an int array of shape
        (n_cutouts, 128, 128, 3). types is the list of CSV "type" values for
        sources whose name contains "confirmed", otherwise None.

    NOTE(review): the number of cutouts and the label CSV are taken from the first
    filter stored for each tile; this assumes every filter of a tile holds the same
    labels in the same order - confirm.
    """
    cutout_size = 128
    tile_ids = list(hf.get(label_subdir).keys())

    # First pass: count cutouts so the output array can be preallocated.
    n_cutouts = 0
    for tile_id in tile_ids:
        first_filt = list(hf.get(label_subdir + tile_id).keys())[0]
        n_cutouts += len(hf.get(label_subdir + tile_id + "/" + first_filt + "/IMAGES"))

    rgb_cutouts = np.zeros((n_cutouts, cutout_size, cutout_size, 3), dtype=int)
    is_confirmed = "confirmed" in label_subdir
    types = [] if is_confirmed else None
    n_plots = 0
    for tile_id in tile_ids:
        filt_names = list(hf.get(label_subdir + tile_id).keys())
        first_filt = filt_names[0]
        df = pd.read_csv(label_dir + label_subdir + first_filt + "/" + tile_id + "_labels.csv")
        n_labels = len(hf.get(label_subdir + tile_id + "/" + first_filt + "/IMAGES"))
        # Filters present for this tile and their channel indices (same for every label).
        filts = [name + "/" for name in filt_names]
        filt_indices = [filter_dict.get(name) for name in filts]
        for i in range(n_labels):
            if is_confirmed:
                types.append(df["type"][i])
            cutout = np.zeros((cutout_size, cutout_size, 5))
            dataset_name = tile_id + str(i)
            for filt, ind in zip(filts, filt_indices):
                cutout[:, :, ind] = hf.get(label_subdir + tile_id + "/" + filt + "IMAGES/" + dataset_name)
            if reconstruct:
                cutout = autoencoder.predict(np.expand_dims(cutout, axis=0))[0]
            cutout_r, cutout_g, cutout_b = _select_rgb_channels(cutout, filt_indices, cutout_size)
            rgb_cutouts[n_plots] = make_lupton_rgb(cutout_r, cutout_g, cutout_b, Q=10, stretch=3)
            n_plots += 1
    # Non-confirmed sources previously returned the IPython `_` global here, which
    # is a NameError outside IPython; None is returned explicitly instead.
    return (rgb_cutouts, types)

In [None]:
confirmed_cutouts, confirmed_types = get_cutouts(label_subdir=label_subdirs[0])  # RGB images + lens types

In [None]:
def custom_loss_all(y_true, y_pred):
    """Weighted MSE: scale targets and predictions by sqrt(weights_all) before Keras MSE.

    Because each term is squared, this is the MSE with each squared error
    multiplied by `weights_all`.
    NOTE(review): `weights_all` is not defined in the cells shown here; calling
    this loss raises NameError unless it is defined elsewhere. Presumably it is
    only needed to deserialize the saved model, not for `predict` - confirm.
    """
    scale = np.sqrt(weights_all)
    return keras.losses.MSE(y_true * scale, y_pred * scale)

In [None]:
# Saved autoencoder; its custom loss must be registered for deserialization.
autoencoder = keras.models.load_model(
    "Models/autoencoder_128p",
    custom_objects={"custom_loss_all": custom_loss_all},
)

In [None]:
confirmed_reconstructed, _ = get_cutouts(label_subdir=label_subdirs[0], reconstruct=True)  # autoencoder output

In [None]:
# Number of confirmed cutouts per value of the label CSV "type" column.
Counter(confirmed_types)

In [None]:
def plot_cutouts(cutouts, figname, ncols=8, types=None):
    """Plot cutouts on a grid and save the figure to ``Plots/<figname>``.

    Parameters
    ----------
    cutouts : sequence of arrays
        Images to show (the RGB arrays from `get_cutouts` in this notebook).
    figname : str
        File name, inside ``Plots/``, for the saved figure.
    ncols : int, optional
        Number of grid columns; as many rows as needed are created.
    types : sequence, optional
        Per-cutout titles (same length as `cutouts`).
    """
    n_cutouts = len(cutouts)
    if n_cutouts == 0:
        # plt.subplots(0, ncols) would raise; nothing to draw.
        print(f"No cutouts to plot for {figname}")
        return
    nrows = int(np.ceil(n_cutouts / ncols))
    # squeeze=False keeps `axes` 2-D even when there is a single row or column;
    # indexing axes[row][col] failed for short lists before.
    fig, axes = plt.subplots(nrows, ncols, figsize=(3.1*ncols, 3.4*nrows), squeeze=False)
    flat_axes = axes.flatten()  # row-major, matching the fill order of the grid
    for n, cutout_rgb in enumerate(cutouts):
        norm = ImageNormalize(cutout_rgb, interval=ZScaleInterval())
        flat_axes[n].imshow(cutout_rgb, norm=norm)
        if types is not None:
            flat_axes[n].set_title(types[n])

    # delete the unused axes at the end of the last row
    for ax in flat_axes[n_cutouts:]:
        fig.delaxes(ax)
    os.makedirs("Plots", exist_ok=True)
    plt.savefig(f"Plots/{figname}")

In [None]:
plot_cutouts(cutouts=confirmed_cutouts, figname="confirmed_cutouts_rgb_alt.png", types=confirmed_types)

In [None]:
plot_cutouts(cutouts=confirmed_reconstructed, figname="confirmed_reconstructed_rgb.png", types=confirmed_types)

In [None]:
# Hand-picked examples (indices into the full confirmed set) for the report figure.
plot_indices = [14, 18, 23, 37, 80, 134, 137, 154, 184, 236]
confirmed_cutouts_report, confirmed_types_report = (
    [confirmed_cutouts[k] for k in plot_indices],
    [confirmed_types[k] for k in plot_indices],
)
plot_cutouts(
    confirmed_cutouts_report,
    "confirmed_cutouts_rgb_report.png",
    ncols=5,
    types=confirmed_types_report,
)

In [None]:
candidate_cutouts, _ = get_cutouts(label_subdir=label_subdirs[1])  # second element unused

In [None]:
plot_cutouts(cutouts=candidate_cutouts, figname="candidate_cutouts_rgb.png")

In [None]:
# Hand-picked candidate examples (indices into candidate_cutouts) for the report figure.
plot_indices = [20, 43, 63, 64, 88, 109, 211, 215, 368, 383]
candidate_cutouts_report = list(map(candidate_cutouts.__getitem__, plot_indices))
plot_cutouts(cutouts=candidate_cutouts_report, figname="candidate_cutouts_rgb_report.png", ncols=5)

In [None]:
canameras_cutouts, _ = get_cutouts(label_subdir=label_subdirs[2])  # second element unused

In [None]:
plot_cutouts(cutouts=canameras_cutouts, figname="canameras_cutouts_rgb.png")

In [None]:
# One figure per huang2020* label source, named after its directory (without the trailing "/").
for label_subdir in label_subdirs[3:]:
    huang_cutouts, _ = get_cutouts(label_subdir)
    plot_cutouts(huang_cutouts, label_subdir[:-1] + ".png")

In [None]:
def count_cutouts(label_subdir):
    """Print the number of stored cutouts for one label source as "<label_subdir>:<count>".

    For every tile group, the IMAGES datasets of the tile's first filter are counted.
    """
    total = 0
    for tile_id in list(hf.get(label_subdir).keys()):
        tile_path = label_subdir + tile_id
        first_filter = list(hf.get(tile_path).keys())[0]
        total += len(hf.get(tile_path + "/" + first_filter + "/IMAGES"))
    print(f"{label_subdir}:{total}")

In [None]:
# Cutout totals for every label source.
for label_subdir in label_subdirs:
    count_cutouts(label_subdir)

In [None]:
# Release the HDF5 handle opened on the $SLURM_TMPDIR copy.
hf.close()
# Cutouts written above only exist in the node-local copy, which is presumably
# discarded when the job ends. Set SYNC_TO_SCRATCH = True to copy the updated file
# back to $SCRATCH (this overwrites the file there). Separate names are used so
# the global `dest` used by earlier cells is not clobbered.
SYNC_TO_SCRATCH = False
if SYNC_TO_SCRATCH:
    sync_src = os.path.expandvars("$SLURM_TMPDIR") + "/labelled_cutouts_alt.h5"
    sync_dest = os.path.expandvars("$SCRATCH") + "/"
    shutil.copy2(sync_src, sync_dest)