#### Import necessary libraries and set paths

In [None]:
import tensorflow as tf

# Create the TF1-style session up front, before the keras imports below.
# `allow_growth = True` presumably makes TensorFlow claim GPU memory on demand
# instead of reserving it all at start-up -- confirm for the TF version in use.
config_tf = tf.ConfigProto()
config_tf.gpu_options.allow_growth = True
sess = tf.Session(config=config_tf)

import json
import os
import sys
from importlib import reload
from pathlib import Path

import imageio
import matplotlib.animation as animation
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from astropy.visualization import AsinhStretch, LogStretch, make_lupton_rgb
from astropy.visualization.mpl_normalize import ImageNormalize
from keras.models import model_from_json
from keras.utils import multi_gpu_model
from pygifsicle import optimize
from sklearn.ensemble import RandomForestRegressor

In [None]:
# Global matplotlib style used by every figure in this notebook.
# All text-size settings share the same "x-large" value ...
params = {
    name: "x-large"
    for name in (
        "legend.fontsize",
        "axes.labelsize",
        "axes.titlesize",
        "xtick.labelsize",
        "ytick.labelsize",
    )
}
# ... followed by the background, tick and font choices.
params.update(
    {
        "figure.facecolor": "w",
        "xtick.top": True,
        "ytick.right": True,
        "xtick.direction": "in",
        "ytick.direction": "in",
        "font.family": "serif",
        "mathtext.fontset": "dejavuserif",
    }
)
plt.rcParams.update(params)

In [None]:
# Path where your software library is saved.
# Clone the latest version of the morphCaps branch from GitHub.
path_photoz = Path("/home/bid13/code/photozCapsNet")

# Put the checkout on the import path so the project modules imported further
# down can be found. The membership test keeps this cell idempotent: running
# it again does not pile up duplicate sys.path entries.
if str(path_photoz) not in sys.path:
    sys.path.insert(1, str(path_photoz))

#### Import custom modules

In [None]:
from encapzulate.data_loader.data_loader import load_data
from encapzulate.utils import metrics
from encapzulate.utils.fileio import load_config, load_model
from encapzulate.utils.utils import import_model

# Reload `metrics` BEFORE importing names from it. `reload` re-executes the
# module, but names previously bound with `from ... import ...` keep pointing
# at the old objects, so the from-import has to come after the reload for
# edits to the module to be picked up.
reload(metrics)
from encapzulate.utils.metrics import Metrics, bins_to_redshifts, probs_to_redshifts

#### Specify the results to be explored

In [None]:
# Parameters for the exploration
run_name = "paper1_regression_80perc_0"
checkpoint_eval = 100

In [None]:
# Create and set different paths
# path_output = "/data/bid13/photoZ/results"
path_output = "/home/bid13/code/photozCapsNet/results"
path_output = Path(path_output)
path_results = path_output / run_name.split("_")[0] / run_name / "results"
path_config = path_results / "config.yml"

#### Load Config, Model and Data

In [None]:
# Load the run configuration from config.yml. The same mapping is unpacked
# into load_data(**config) further down.
config = load_config(path_config)
# Value stored under the "image_scale" key of the configuration.
scale = config["image_scale"]

In [None]:
# Training log of the run, read from <path_results>/logs/log.csv.
log = pd.read_csv(path_results.joinpath("logs", "log.csv"))

In [None]:
# Row(s) of the training log where `val_decoder_model_loss` is smallest.
# NOTE: the variable name is misleading -- this selects the minimum
# validation decoder loss, not a maximum accuracy.
max_acc = log.loc[
    log["val_decoder_model_loss"].eq(log["val_decoder_model_loss"].min())
]
max_acc

In [None]:
# with tf.device('/cpu:0'):
model = load_model(
    path_results / "eval_model.json",
    path_results / "weights" / f"weights-{checkpoint_eval:02d}.h5",
)
# model = multi_gpu_model(model,gpus=2)
model.summary()

In [None]:
# Load the three data splits (train / dev / test). Each split unpacks into five
# objects. Judging by the names these are inputs (x), labels (y), extra values
# (vals), spectroscopic redshifts (z_spec) and a catalogue (cat) -- TODO confirm
# against load_data.
(
    (x_train, y_train, vals_train, z_spec_train, cat_train),
    (x_dev, y_dev, vals_dev, z_spec_dev, cat_dev),
    (x_test, y_test, vals_test, z_spec_test, cat_test),
) = load_data(load_cat=True, **config)

#### Run Predictions

In [None]:
# Run the model on the test inputs in batches of 1024. The model yields five
# outputs, unpacked in order. Of these, `y_caps_test` is the input to the UMAP
# embedding below, and `z_phot_test` is later replaced by the values stored in
# z_pred.npz.
y_caps_test, y_caps_all_test, y_prob_test, x_recon_test, z_phot_test = model.predict(
    x_test, batch_size=1024
)

In [None]:
# Drop the train and dev inputs, which are not used again in the cells shown here.
# x_test and x_recon_test are intentionally kept (their deletion stays disabled):
# del x_test
# del x_recon_test
# del x_recon_dev
del x_train, x_dev

In [None]:
# def logistic_trans(x, xmin=0, xmax=0.4):
#     return np.log((x - xmin) / (xmax - x))


# def logistic_trans_inv(x, xmin=0, xmax=0.4):
#     return (np.exp(x) * xmax + xmin) / (np.exp(x) + 1)

In [None]:
# z_spec_test = logistic_trans_inv(z_spec_test)
# z_phot_test = np.squeeze(logistic_trans_inv(z_phot_test))

# Replace the in-memory redshifts with the arrays saved in z_pred.npz.
# NOTE: this overwrites `z_spec_test` (from load_data) and `z_phot_test`
# (from model.predict); everything below uses the saved values.
# The context manager closes the archive's file handle once the arrays have
# been read; `data` must not be indexed after this block.
with np.load("z_pred.npz") as data:
    test_id = data["test_id"]
    z_spec_test = data["z_spec"]
    z_phot_test = data["z_phot"]

In [None]:
import umap

# Project `y_caps_test` (first output of model.predict) onto two dimensions
# with UMAP. `densmap=True` / `dens_lambda=1` switch on UMAP's DensMAP option,
# and `random_state` is fixed to 42.
embedder = umap.UMAP(
    random_state=42, n_components=2, n_neighbors=30, #set_op_mix_ratio=1,
    densmap=True, dens_lambda=1
)
# Coordinates of every test object in the 2-D embedding.
embedding = embedder.fit_transform(y_caps_test)

In [None]:
def compute_nn_redshift_z_loss(embedding, redshift, k=50, eps=1e-12):
    """Mean squared gap between each redshift and its neighbourhood estimate.

    For every point the ``k`` nearest points in ``embedding`` are looked up
    (the first hit, normally the point itself, is discarded). The remaining
    neighbours' redshifts are averaged with inverse-distance weights and the
    squared difference to the point's own redshift is averaged over all points.

    Parameters
    ----------
    embedding : array-like, shape (n_points, n_dims)
        Coordinates used for the neighbour search.
    redshift : array-like, shape (n_points,)
        Redshift of each point, in the same order as ``embedding``.
    k : int, default 50
        Number of neighbours requested from the tree, including the point
        itself. Capped at ``n_points`` so small samples do not index out of
        range.
    eps : float, default 1e-12
        Lower bound applied to neighbour distances so that coincident points
        (distance 0) do not produce inf/NaN weights.

    Returns
    -------
    float
        Mean of ``(redshift - weighted_neighbour_redshift) ** 2``.

    Raises
    ------
    ValueError
        If fewer than two neighbours (self + one other) are available.
    """
    from scipy.spatial import cKDTree

    embedding = np.asarray(embedding)
    redshift = np.asarray(redshift)

    # The tree pads missing neighbours with index n_points, which would make
    # redshift[ii] fail; never ask for more neighbours than there are points.
    k = min(k, embedding.shape[0])
    if k < 2:
        raise ValueError("need k >= 2 and at least two points to form a neighbourhood")

    tree = cKDTree(embedding)
    try:
        dd, ii = tree.query(embedding, k=k, workers=-1)
    except TypeError:
        # SciPy < 1.6 only knows the old name of the parallelism keyword.
        dd, ii = tree.query(embedding, k=k, n_jobs=-1)

    # Column 0 is the zero-distance self match; keep only the real neighbours.
    dd = np.maximum(dd[:, 1:], eps)
    ii = ii[:, 1:]

    # Inverse-distance weighted mean of the neighbours' redshifts.
    #     centroid = np.median(redshift[ii], axis=-1)
    weights = 1.0 / dd
    centroid = np.sum(redshift[ii] * weights, axis=-1) / np.sum(weights, axis=-1)

    return np.mean((redshift - centroid) ** 2)

In [None]:
# h_grid = [0.01, 0.05, 0.1, 0.5, 1, 2, 5]
# loss =[]
# for h in h_grid:
#     embedding = umap.UMAP(
#         random_state=42,
#         n_components=2,
#         n_neighbors=30,
#         set_op_mix_ratio=1,
#         densmap=True,
#         dens_lambda=h,
#     ).fit_transform(y_caps_test)
#     l = compute_nn_redshift_z_loss(embedding, z_spec_test)
#     print(f"val:{h}    loss:{l}")
#     loss.append(l)


In [None]:
# 2x2 grid of UMAP scatter plots sharing both axes.
fig, ax = plt.subplots(2, 2, figsize=(27, 20), sharex=True, sharey=True)

# Colour map for the redshift panels (a trimmed seaborn "flare" palette was
# tried earlier; colorcet's rainbow is what is used).
import colorcet as cc

cm = cc.cm.rainbow

# Top row: the same embedding coloured by photometric (left) and
# spectroscopic (right) redshift, each with its own colour bar.
for panel_ax, panel_z, panel_label in (
    (ax[0, 0], z_phot_test, r"$z_{\mathrm{phot}}$"),
    (ax[0, 1], z_spec_test, r"$z_{\mathrm{spec}}$"),
):
    sp = panel_ax.scatter(
        embedding[:, 0],
        embedding[:, 1],
        c=panel_z,
        cmap=cm,
        vmin=0,
        vmax=0.3,
        marker=".",
        rasterized=True,
    )
    cbar = fig.colorbar(
        sp,
        ax=panel_ax,
        boundaries=np.linspace(0, 0.4, 200),
        ticks=np.linspace(0, 0.4, 9),
    )
    cbar.ax.tick_params(labelsize=20)
    cbar.set_label(panel_label, fontsize=50)
    for tick_kind in ("major", "minor"):
        panel_ax.tick_params(axis="both", which=tick_kind, labelsize=25)

from scipy.spatial import cKDTree

# Boolean flag per test object: argmax over the last axis of y_test, cast to
# bool (class index 0 -> False, anything else -> True). The colour-bar label
# below treats True as "spiral" -- confirm against the label encoding.
morpho = np.argmax(y_test, axis=-1).astype("bool")

# For every point, the fraction of flagged objects among its 80 nearest
# neighbours in the embedding (the point itself is included in the 80).
tree = cKDTree(embedding)
try:
    dd, ii = tree.query(embedding, k=80, workers=-1)
except TypeError:
    # SciPy < 1.6 only knows the old name of the parallelism keyword.
    dd, ii = tree.query(embedding, k=80, n_jobs=-1)
spir_frac = np.mean(morpho[ii], axis=-1)

# plt.get_cmap is available across matplotlib versions, whereas
# plt.cm.get_cmap has been removed from recent releases.
cm = plt.get_cmap("RdYlBu")
sp = ax[1, 0].scatter(
    embedding[:, 0],
    embedding[:, 1],
    c=spir_frac,
    marker=".",
    cmap=cm,
    rasterized=True,
)
ax[1, 0].tick_params(axis="both", which="major", labelsize=25)
ax[1, 0].tick_params(axis="both", which="minor", labelsize=25)

cbar = fig.colorbar(
    sp,
    ax=ax[1, 0],
)
cbar.ax.tick_params(labelsize=20)
cbar.set_label("Neighbourhood Spiral Fraction", fontsize=40)


cm = cc.cm.rainbow
# Absolute normalised residual |z_spec - z_phot| / (1 + z_spec) per object.
err = np.abs(z_spec_test - z_phot_test) / (1 + z_spec_test)
sp = ax[1, 1].scatter(
    embedding[:, 0],
    embedding[:, 1],
    c=err,
    cmap=cm,
    marker=".",
    rasterized=True,
    # The upper colour limit is set on the norm itself: recent matplotlib
    # rejects `norm=` combined with a separate `vmin`/`vmax`. The lower limit
    # is left unset so it is autoscaled from the data, as before.
    norm=colors.PowerNorm(0.75, vmax=0.03),
)

cbar = fig.colorbar(
    sp,
    ax=ax[1, 1],  # boundaries=np.linspace(0, 0.05, 200), ticks=np.linspace(0, 0.05, 9)
)
cbar.ax.tick_params(labelsize=20)
cbar.set_label(r"$\mid \frac{\Delta z}{1+z_{\mathrm{spec}}}\mid$", fontsize=50)

# Overplot the objects whose residual exceeds 0.05 as large black markers.
mask = err > 0.05
ax[1, 1].scatter(
    embedding[:, 0][mask],
    embedding[:, 1][mask],
    facecolor="k",
    edgecolor="white",
    marker="o",
    label="Outliers",
    rasterized=True,
    s=150,
)
ax[1, 1].legend(
    loc="upper left",
    prop={"size": 25},
    markerscale=1,
    frameon=False,
    handletextpad=0.00001,
)
ax[1, 1].tick_params(axis="both", which="major", labelsize=25)
ax[1, 1].tick_params(axis="both", which="minor", labelsize=25)





# sp = ax[1, 1].scatter(
#     embedding[:, 0][morpho],
#     embedding[:, 1][morpho],
#     c="C0",
#     marker=".",
#     label="Spirals",
#     alpha=0.4,
#     rasterized=True,
# )


# sp = ax[1, 1].scatter(
#     embedding[:, 0][~morpho],
#     embedding[:, 1][~morpho],
#     c="C1",
#     marker=".",
#     label="Ellipticals",
#     alpha=0.4,
#     rasterized=True,
# )
# import matplotlib.lines as mlines

# blue_dot = mlines.Line2D(
#     [], [], color="C0", marker="o", alpha=0.8, label="Spirals", ls=""
# )
# orange_dot = mlines.Line2D(
#     [], [], color="C1", marker="o", alpha=0.8, label="Ellipticals", ls=""
# )
# ax[1, 1].legend(
#     loc="upper left",
#     handles=[blue_dot, orange_dot],
#     ncol=1,
#     prop={"size": 25},
#     frameon=False,
#     handletextpad=0.00001,
#     markerscale=3,
# )


# # ax[1, 1].legend(loc="upper left", markerscale=5, prop={"size": 25})
# ax[1, 1].tick_params(axis="both", which="major", labelsize=25)
# ax[1, 1].tick_params(axis="both", which="minor", labelsize=25)


plt.tight_layout()
# NOTE: if a panel ever needs manual resizing, declare a cax or use GridSpec
# instead of patching axis positions by hand.

# Shared axis labels, placed in figure coordinates just outside the panels.
fig.text(0.37, -0.03, r"UMAP Dimension-1", fontsize=40)
fig.text(-0.03, 0.37, r"UMAP Dimension-2", rotation=90, fontsize=40)

# savefig does not create missing directories, so make sure ./figs exists.
Path("./figs").mkdir(parents=True, exist_ok=True)
fig.savefig("./figs/UMAP_projection_dense_low_res.pdf", dpi=100, bbox_inches="tight")

In [None]:
# Signed normalised residual (z_spec - z_phot) / (1 + z_spec).
# NOTE: this rebinds `err`, which held the absolute value of the same quantity
# in the plotting cell.
err = (z_spec_test - z_phot_test) / (1 + z_spec_test)

# err = err[morpho==0]

In [None]:
# Scatter of the residuals: 1.4826 times the median absolute deviation of
# `err` about its median.
sigma_mad = np.median(np.abs(err - np.median(err))) * 1.4826
print(sigma_mad)

# The same value divided by sqrt(2 N), with N = len(err)
# (presumably an uncertainty estimate for sigma_mad -- confirm).
print(sigma_mad / np.sqrt(len(err) * 2))

In [None]:
# Percentage of objects with |err| above 0.05.
f_outlier = np.sum(np.abs(err) > 0.05) * 100 / len(err)
print(f_outlier)

# Convert to a fraction, then print sqrt(N f (1 - f)) / N as a percentage.
f_outlier = f_outlier / 100
print(np.sqrt(len(err) * f_outlier * (1 - f_outlier)) * 100 / len(err))

In [None]:
# Boolean mask of the objects whose |err| exceeds 0.05.
out = (np.abs(err)>0.05)

In [None]:
# Share of the flagged outliers for which `morpho` is True.
(out & morpho).sum() / out.sum()