In [None]:
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import os
import glob
import pandas as pd
import random
import numpy as np
import imageio
from tqdm import tqdm
from PIL import Image
from sklearn.model_selection import train_test_split
from scipy.stats import norm

import tensorflow as tf

from tensorflow import keras
from tensorflow.keras import backend
from keras import layers, Input, Model, ops
from tensorflow.keras.utils import register_keras_serializable

import seaborn as sns
import pickle

# Loading the data

We load the pickle file (`.pkl`) containing the image cutouts (in this notebook, from the A1 tile) and the CSV file (`cosmos_cut.csv`) containing the redshifts and classifications.

In [None]:
# Location of the pickled image cutouts that the loading cell opens.
# Keep exactly one DATA_PATH assignment active, depending on where the file lives.

# if you make pickle files from the data generator
# DATA_PATH = "pickles/A1_compiled_cutouts_3arcsec_30mas.pkl"

# if you download directly from the drive
DATA_PATH = "A1_compiled_cutouts_3arcsec_30mas.pkl"

In [None]:
# Load data
# rows: ID, image info  (data[0] -> object IDs, data[1] -> image cutouts)
# SECURITY NOTE: pickle.load can execute arbitrary code embedded in the file.
# Only open pickles you generated yourself or downloaded from a trusted source.
with open(DATA_PATH, "rb") as f:
    data = pickle.load(f) 

In [None]:
# Unpack the pickle payload: row 0 holds the object IDs, row 1 the image cutouts
ids, cutouts = data[0], data[1]

# Stack the per-object cutouts into a single float32 array
images = np.stack(cutouts).astype("float32")

# Log transform to reduce differences in scale: shift by the global minimum
# and add a small offset (0.007) so that log10 never receives zero
global_min = np.min(images)
images = np.log10(images - global_min + 0.007)

# Normalize by the global maximum of the log-transformed stack
images = images / np.max(images)

Match and merge the datasets based on the object's ID.

In [None]:
# Catalogue with one row per object; the "id" column is the join key and the
# "z" column holds the redshift.
attrs = pd.read_csv("cosmos_cut.csv", sep=",")

# The positional selection below assumes `images` is still the full stack with
# one entry per element of `ids`. That stops being true once this cell has
# filtered `images`, so fail loudly instead of silently picking wrong cutouts.
if len(images) != len(ids):
    raise RuntimeError(
        f"images has {len(images)} entries but ids has {len(ids)}; "
        "re-run the image loading cell before re-running this merge."
    )

df_images = pd.DataFrame({'id': ids}) # from ID numbers

# pd.merge returns a frame with a fresh 0..n-1 index, so df_merged.index does
# NOT say which images survived the inner join. Carry each image's position
# through the merge explicitly and select with it afterwards.
df_merged = pd.merge(
    df_images.assign(_img_pos=np.arange(len(df_images))),
    attrs,
    on="id",
    how="inner",  # only when IDs match
)
matched_pos = df_merged.pop("_img_pos").to_numpy()  # pop: keep df_merged's columns unchanged

images = images[matched_pos]  # row i of images now belongs to row i of df_merged
redshifts = df_merged["z"].to_numpy(dtype="float32")
redshifts = redshifts / np.max(redshifts) # normalize

Augment the AGN and LRD samples (rotations for AGN; rotations and flips for LRD) so that the dataset is more balanced across classes:

In [None]:
classes = df_merged["classification"].to_numpy()
obj_ids = df_merged["id"].to_numpy()


def _augmented_copies(img, cls):
    """
    Return the list of image variants kept for one object.

    - "AGN": the image plus its 90°, 180° and 270° rotations (4 variants).
    - "LRD": those 4 rotations plus left-right / up-down flips of the image
      and of its 90° rotation (8 variants).
    - "Galaxy" and any other class: the original image only.
    """
    if cls not in ("AGN", "LRD"):
        return [img]

    rotations = [img] + [np.rot90(img, k, axes=(0, 1)) for k in (1, 2, 3)]
    if cls == "AGN":
        return rotations

    quarter_turn = rotations[1]
    return rotations + [
        np.fliplr(img),
        np.flipud(img),
        np.fliplr(quarter_turn),  # flipped 90°
        np.flipud(quarter_turn),  # flipped 90°
    ]


aug_images, aug_ids, aug_redshifts = [], [], []

for img, obj_id, z, cls in zip(images, obj_ids, redshifts, classes):
    copies = _augmented_copies(img, cls)
    # every variant inherits the ID and redshift of its source object
    aug_images.extend(copies)
    aug_ids.extend([obj_id] * len(copies))
    aug_redshifts.extend([z] * len(copies))

# -------------------------
# final augmented datasets
# -------------------------
aug_images = np.stack(aug_images).astype("float32")
aug_ids = np.array(aug_ids)
aug_redshifts = np.array(aug_redshifts, dtype="float32")

print("Original size:", len(images))
print("Augmented size:", len(aug_images))

In [None]:
# expand metadata to include augmented samples
# Every augmented copy of an object repeats that object's ID in aug_ids, so a
# left merge on "id" gives one metadata row per augmented image, in the same
# order as aug_images.
df_aug = pd.DataFrame({"id": aug_ids})

# validate="many_to_one" makes pandas raise if an ID occurs more than once in
# df_merged; such duplicates would multiply rows and silently misalign df_aug
# with aug_images.
df_aug = df_aug.merge(df_merged, on="id", how="left", validate="many_to_one")

assert len(df_aug) == len(aug_images), (
    f"metadata rows ({len(df_aug)}) do not match augmented images ({len(aug_images)})"
)

# Load the pre-trained autoencoder

In [None]:
# custom sampling layer
@register_keras_serializable(package="Custom") # decorator
class Sampling(layers.Layer):
    """
    Sampling layer for a variational autoencoder.

    Takes a (mean, logvar) pair, draws z = mean + exp(0.5 * logvar) * eps with
    eps sampled from a normal distribution, and registers a KL term scaled by
    `beta` as a layer loss via add_loss.

    Parameters
    ----------
    beta : float
        Weight applied to the mean KL term before it is added as a loss.
    **kwargs
        Forwarded to the base Layer constructor.
    """
    def __init__(self, beta=1.0, **kwargs):
        super().__init__(**kwargs)
        self.beta = beta  # multiplier of the KL loss in call()

    def build(self, input_shape):
        # Seed generator (fixed seed 1337) used for the random draw in call()
        self.rng = keras.random.SeedGenerator(1337)

    def call(self, inputs):
        """Return a sampled latent tensor for inputs = (mean, logvar) and add the KL loss."""
        mean, logvar = inputs
        # Noise with the same shape and dtype as `mean`
        eps = keras.random.normal(
            shape=ops.shape(mean),
            dtype=mean.dtype,
            seed=self.rng
        )
        # exp(0.5 * logvar) converts the log-variance into a standard deviation
        z = mean + ops.exp(0.5 * logvar) * eps

        # Per-sample KL term, summed over the last axis
        kl = -0.5 * ops.sum(
            1 + logvar - ops.square(mean) - ops.exp(logvar),
            axis=-1
        )
        # Mean of the KL term over the remaining axes, weighted by beta
        self.add_loss(self.beta * ops.mean(kl))
        return z

In [None]:
# Load the saved autoencoder. custom_objects lets Keras resolve the custom
# "Sampling" layer to the class defined in this notebook.
# NOTE(review): safe_mode=False relaxes Keras' checks on what may be
# deserialized from the file — only use it with a model file you trust, and
# confirm against the Keras load_model docs whether it is actually required here.
autoencoder = keras.models.load_model(
    "Best_autoencoder.keras",
    custom_objects={"Sampling": Sampling},
    safe_mode=False)

# Visualize samples

In [None]:
# Hold out 20% of the augmented samples as a validation set (fixed seed 42).
# NOTE(review): the split is made per augmented sample, so rotated/flipped
# copies of one object can land in both sets. Consider a split grouped by
# object ID if these sets are ever used to evaluate a model.
X_train, X_val, z_train, z_val = train_test_split(
    aug_images,
    aug_redshifts,
    test_size=0.2,
    random_state=42,
)

# Redshifts as column vectors of shape (N, 1)
z_train, z_val = z_train.reshape(-1, 1), z_val.reshape(-1, 1)

In [None]:
# helper functions
def _asinh_stretch(x, scale=None):
    """
    Apply an inverse-hyperbolic-sine stretch, arcsinh(x / scale), to an array.

    Parameters
    ----------
    x : array-like
        Values to stretch.
    scale : float or None
        Softening scale. When None it is estimated from the data as the 90th
        percentile of |x| (NaNs ignored); if that percentile is not positive,
        the maximum of |x| plus 1e-8 is used instead.

    Returns
    -------
    numpy.ndarray
        Stretched values with the same shape as x.
    """
    arr = np.asarray(x)
    if scale is None:
        magnitude = np.abs(arr)
        p90 = np.nanpercentile(magnitude, 90)
        if p90 > 0:
            scale = p90
        else:
            scale = np.nanmax(magnitude) + 1e-8
    # the tiny epsilon guards against division by an exactly-zero scale
    return np.arcsinh(arr / (scale + 1e-12))

def _auto_vmin_vmax(x, pct=(1, 99)):
    """
    Compute robust display limits for an image from its percentiles.

    Parameters
    ----------
    x : array-like
        Image (2D) or any array to derive the limits from; NaNs are ignored.
    pct : (low, high)
        Lower and upper percentiles, e.g. (1, 99).

    Returns
    -------
    vmin, vmax : float
        The two percentiles, or mean ± 2 standard deviations when the upper
        percentile does not exceed the lower one (e.g. constant data).
    """
    arr = np.asarray(x)
    low, high = np.nanpercentile(arr, pct)
    if high <= low:
        # degenerate range: fall back to a window around the mean
        center = np.nanmean(arr)
        spread = np.nanstd(arr)
        low = center - 2 * spread
        high = center + 2 * spread
    return float(low), float(high)

def show_input_vs_output(
    model,
    X,
    z,
    labels,
    idx=0,
    channel_names=None,
    stretch="asinh",      # 'asinh' or 'linear'
    percentiles=(1, 99),  # for linear scaling
    figsize=(12, 6),
    cmap="gray",
    show_residuals=True,
    print_redshift=True):
    """
    Visualize original vs reconstructed image for a single (image, redshift) pair.

    Parameters
    ----------
    model : object with ``predict``
        Called as ``model.predict([image_batch, redshift_batch], verbose=0)``.
        It may return the reconstruction alone, or a list/tuple whose first
        element is the reconstruction and whose second is a predicted redshift.
    X : sequence of (H, W, C) images
    z : sequence of redshifts (scalars or small vectors), aligned with X
    labels : sequence or None
        Object type per sample, shown in the figure title; None omits it.
    idx : int
        Position of the sample in X / z / labels.
    channel_names : sequence of str, optional
        Names used in panel titles; channels without a name show as ``ch<i>``.
    stretch : str
        ``"asinh"`` applies an asinh stretch with separate display limits for
        input and reconstruction; any other value uses linear scaling with
        limits shared between input and reconstruction.
    percentiles : (low, high)
        Percentiles for the display limits in linear mode.
    figsize, cmap : passed to matplotlib.
    show_residuals : bool
        Add a third row showing ``reconstruction - input``.
    print_redshift : bool
        Print the true (and, if available, predicted) redshift.

    Returns
    -------
    dict
        Keys ``x_true``, ``x_recon``, ``z_true``, ``z_pred``, ``residual``.

    Raises
    ------
    ValueError
        If the selected image is not 3-dimensional (H, W, C).
    """
    # --- Select the sample and add the batch axis ---------------------------
    x_true = np.asarray(X[idx])
    if x_true.ndim != 3:
        raise ValueError(f"expected an (H, W, C) image, got shape {x_true.shape}")

    z_true = np.asarray(z[idx])  # asarray: also accepts plain Python lists/floats
    if z_true.ndim == 0:        # scalar -> (1,)
        z_true = z_true.reshape(1).astype(np.float32)
    if z_true.ndim == 1:        # (k,) -> (1, k) when batching
        z_true_b = z_true.reshape(1, -1)
    else:
        z_true_b = z_true[np.newaxis, ...]

    x_true_b = x_true[np.newaxis, ...]  # (1,H,W,C)

    # --- Forward pass --------------------------------------------------------
    pred = model.predict([x_true_b, z_true_b], verbose=0)
    if isinstance(pred, (list, tuple)):
        x_recon_b = pred[0]
        z_pred_b = pred[1] if len(pred) > 1 else None
    else:
        x_recon_b, z_pred_b = pred, None

    x_recon = np.squeeze(x_recon_b, axis=0)  # (H,W,C)
    residual = x_recon - x_true

    # --- Redshift scalars (for printing and the title) -----------------------
    # Best effort: an unusable redshift is reported as NaN rather than aborting
    # the visualization.
    try:
        zt_scalar = float(z_true.ravel()[0])
    except (TypeError, ValueError, IndexError):
        zt_scalar = np.nan

    zp_scalar = None
    if z_pred_b is not None:
        z_pred_flat = np.ravel(z_pred_b)
        if z_pred_flat.size:
            # first element, mirroring how z_true is reduced to a scalar
            zp_scalar = float(z_pred_flat[0])

    if print_redshift:
        if zp_scalar is not None:
            print(f"z_true = {zt_scalar:.5f} | z_pred = {zp_scalar:.5f}")
        else:
            print(f"z_true = {zt_scalar:.5f} | (model has no z_pred head)")

    # --- Figure layout: one column per channel -------------------------------
    n_channels = x_true.shape[-1]
    n_rows = 3 if show_residuals else 2

    # squeeze=False always yields a 2D axes array, even for a single channel
    fig, axes = plt.subplots(
        n_rows, n_channels,
        figsize=figsize,
        constrained_layout=True,
        squeeze=False,
    )

    # Title: object type (when labels are given) plus true / predicted redshift
    label_part = "" if labels is None else f" Object type is {labels[idx]} | "
    if zp_scalar is not None:
        pred_part = f"z_pred = {zp_scalar:.5f}"
    else:
        pred_part = "(no z_pred)"
    fig.suptitle(f"{label_part}z_true = {zt_scalar:.5f} | {pred_part}", fontsize=14)

    def format_title(base, c):
        if channel_names and c < len(channel_names):
            return f"{base} – {channel_names[c]}"
        return f"{base} – ch{c}"

    for c in range(n_channels):
        if stretch == "asinh":
            a_true = _asinh_stretch(x_true[..., c])
            a_reco = _asinh_stretch(x_recon[..., c])
            # Separate limits for truth vs recon so structure shows in both
            vmin_t, vmax_t = _auto_vmin_vmax(a_true)
            vmin_r, vmax_r = _auto_vmin_vmax(a_reco)
            axes[0, c].imshow(a_true, cmap=cmap, vmin=vmin_t, vmax=vmax_t)
            axes[1, c].imshow(a_reco, cmap=cmap, vmin=vmin_r, vmax=vmax_r)
        else:  # linear: limits from the input, shared with the reconstruction
            vmin, vmax = _auto_vmin_vmax(x_true[..., c], percentiles)
            axes[0, c].imshow(x_true[..., c], cmap=cmap, vmin=vmin, vmax=vmax)
            axes[1, c].imshow(x_recon[..., c], cmap=cmap, vmin=vmin, vmax=vmax)

        axes[0, c].set_title(format_title("Input", c))
        axes[1, c].set_title(format_title("Reconstruction", c))
        axes[0, c].axis("off")
        axes[1, c].axis("off")

        if show_residuals:
            # Residuals in linear scale centered at 0 with symmetric range
            res = residual[..., c]
            m = np.nanmax(np.abs(res)) + 1e-12
            axes[2, c].imshow(res, cmap=cmap, vmin=-m, vmax=m)
            axes[2, c].set_title(format_title("Residual (recon - input)", c))
            axes[2, c].axis("off")

    plt.show()

    return {
        "x_true": x_true,
        "x_recon": x_recon,
        "z_true": z_true,
        "z_pred": zp_scalar,
        "residual": residual,
    }

In [None]:
# Band names for 4-band data, used in the panel titles
chan_names = ["F115W", "F150W", "F277W", "F444W"]

# Classification label of every augmented sample
labels = df_aug["classification"].to_numpy()

# Show one sample. `idx` is the row position in aug_images (not the catalogue
# ID); change it as desired.
_ = show_input_vs_output(
    model=autoencoder,
    X=aug_images,
    z=aug_redshifts,
    labels=labels,
    idx=31,
    channel_names=chan_names,
    stretch="linear",
    show_residuals=True,
    cmap="gray",
)

# Visualize the latent space

Since the latent space has 100 dimensions per the model architecture, we plot only the first few latent dimensions (four by default) in the pairwise plots.

In [None]:
# Build an encoder that shares the autoencoder's inputs and stops at the latent layer.
# The autoencoder has two inputs: the image and the redshift.
ae_inputs = autoencoder.inputs
img_in_loaded, z_in_loaded = ae_inputs

# The latent tensor is the output of the layer named "z" (the Sampling layer).
# NOTE(review): Sampling draws random noise on every call, so this encoder
# returns a sample rather than a deterministic embedding — repeated predict
# calls are expected to give different latent vectors.
z_latent_loaded = autoencoder.get_layer("z").output

encoder = keras.Model(
    [img_in_loaded, z_in_loaded],
    z_latent_loaded,
    name="encoder_from_ae",
)

encoder.summary()

In [None]:
def LatentSpace_pairplot(encoder_model, images, labels, n_dims=4):
    """
    Draw a seaborn pairplot of the leading latent dimensions, colored by class.

    Parameters
    ----------
    encoder_model : object with ``predict``
        Called as ``encoder_model.predict(images, batch_size=64)``; must
        return a 2D array of shape (n_samples, latent_dim).
    images : model input
        Passed to ``predict`` unchanged (e.g. ``[images, redshifts]``).
    labels : array-like of length n_samples
        Classification label per sample, matched to the encodings by position.
    n_dims : int
        Number of leading latent dimensions to plot; clipped to the latent
        width when the latent space is narrower.

    Raises
    ------
    ValueError
        If ``n_dims`` is smaller than 1.
    """
    if n_dims < 1:
        raise ValueError("n_dims must be at least 1")

    # Encode all images to latent space
    z = np.asarray(encoder_model.predict(images, batch_size=64))

    # Only the first few dimensions are plotted; never ask for more columns
    # than the latent space provides (the column names below must match).
    n_dims = min(n_dims, z.shape[1])
    z = z[:, :n_dims]

    # Build a dataframe; np.asarray makes the label assignment positional
    df = pd.DataFrame(z, columns=[f"z{i+1}" for i in range(n_dims)])
    df["classification"] = np.asarray(labels)

    # Fixed colors for the known classes, gray for any other class present
    # (seaborn rejects a palette dict that is missing a hue level).
    label_to_color = {"LRD": "red", "Galaxy": "dodgerblue", "AGN": "mediumseagreen"}
    palette = {
        lab: label_to_color.get(lab, "gray")
        for lab in df["classification"].dropna().unique()
    }

    # Plot
    sns.pairplot(df, hue="classification", palette=palette, plot_kws={'alpha':0.4, 's':20})
    plt.suptitle("Latent Space Pairplot by Classification", y=1.02)
    plt.show()

In [None]:
# Classification label of every augmented sample
labels = df_aug["classification"].to_numpy()

# Pairplot of the latent space; the encoder takes [images, redshifts] as input
LatentSpace_pairplot(
    encoder_model=encoder,
    images=[aug_images, aug_redshifts],
    labels=labels,
)

### UMAP

We use Uniform Manifold Approximation and Projection (UMAP) to reduce the dimensionality of the latent space to 2D and 3D projections while preserving both the global structure (overall shape of the data) and local structure (how nearby points relate to each other).

In [None]:
# if you don't have umap, install it into the kernel's environment
# (%pip targets the running kernel, unlike !pip):
# %pip install umap-learn

There is a known UMAP shadow module in `cuml` which introduces instability and crashes kernels when run on GPU. Run the following block as applicable:

In [None]:
# detect which UMAP package you're actually using
import umap
print("UMAP path:", umap.__file__)

# If the printed path contains "cuml" or "rapids", replace that package with
# umap-learn by running the two commands below, then restart the kernel:
# %pip uninstall -y cuml
# %pip install umap-learn

In [None]:
# ---------- 1) Get latent space from encoder ----------
latent_space = encoder.predict([aug_images, aug_redshifts],
                                batch_size=64)

# ---------- 2) Run 2D UMAP ----------
reducer_2d = umap.UMAP(
    n_neighbors=50,
    min_dist=0.1,
    n_components=2,
    metric='euclidean',
    random_state=42,     # fixed seed so the embedding is reproducible across runs
)

latent_umap_2d = reducer_2d.fit_transform(latent_space)

# ---------- 3) Put into DataFrame ----------
df_latent_2d = pd.DataFrame(latent_umap_2d, columns=['UMAP1', 'UMAP2'])
# .values -> positional assignment: row i of df_aug describes row i of the embedding
df_latent_2d['classification'] = df_aug['classification'].values

# ---------- 4) Plot 2D scatter ----------
# Colors / transparency chosen to highlight LRDs against the other classes
label_to_color = {"LRD": "red", "Galaxy": "dodgerblue", "AGN": "mediumseagreen"}
label_to_alpha = {"LRD": 1.0, "Galaxy": 0.15, "AGN": 0.45}

fig, ax = plt.subplots(figsize=(9, 7))

for cls in df_latent_2d['classification'].dropna().unique():
    subset = df_latent_2d[df_latent_2d['classification'] == cls]
    ax.scatter(
        subset['UMAP1'],
        subset['UMAP2'],
        # .get with a fallback: classes outside the three known ones are
        # drawn in gray instead of raising KeyError
        color=label_to_color.get(cls, "gray"),
        alpha=label_to_alpha.get(cls, 0.3),
        label=cls,
        s=20
    )

ax.set_xlabel('UMAP1')
ax.set_ylabel('UMAP2')
ax.set_title('2D UMAP of latent space')
ax.legend()
fig.tight_layout()
plt.show()

In [None]:
# 3D UMAP configuration
reducer = umap.UMAP(
    n_neighbors=15,      # local neighborhood size
    min_dist=0.1,        # how tightly UMAP packs points
    n_components=3,      # output dimensions (3D here)
    metric='euclidean',  # distance metric
    random_state=42)     # fixed seed so the embedding is reproducible across runs

# Encode every augmented sample; the encoder takes [images, redshifts]
latent_space = encoder.predict([aug_images, aug_redshifts])

In [None]:
# create reduced dimension latent space
latent_umap = reducer.fit_transform(latent_space)
latent_dims = ['UMAP1', 'UMAP2', 'UMAP3']

df_latent = pd.DataFrame(latent_umap, columns=latent_dims)
# .to_numpy() -> positional assignment, independent of df_aug's index
df_latent['classification'] = df_aug['classification'].to_numpy()

# color + alpha mapping to highlight LRDs
label_to_color = {"LRD": "red", "Galaxy": "dodgerblue", "AGN": "mediumseagreen"}
label_to_alpha = {"LRD": 1.0, "Galaxy": 0.15, "AGN": 0.45}

# pairwise plots: one figure per pair of UMAP axes
plot_pairs = [('UMAP1', 'UMAP2'), ('UMAP2', 'UMAP3'), ('UMAP1', 'UMAP3')]

for x_dim, y_dim in plot_pairs:
    fig, ax = plt.subplots(figsize=(9, 7))

    for cls in df_latent['classification'].dropna().unique():
        subset = df_latent[df_latent['classification'] == cls]
        ax.scatter(subset[x_dim], subset[y_dim],
                   # fallback style: classes outside the three known ones are
                   # drawn in gray instead of raising KeyError
                   color=label_to_color.get(cls, "gray"),
                   alpha=label_to_alpha.get(cls, 0.3),
                   label=cls,
                   s=20)

    ax.set_xlabel(x_dim)
    ax.set_ylabel(y_dim)
    ax.legend()
    ax.set_title(f"{x_dim} vs {y_dim}")
    plt.show()

In [None]:
# 3D plot to see latent dimensions at once
from mpl_toolkits.mplot3d import Axes3D

# create figure and 3D axes
fig = plt.figure(figsize=(12, 10))
ax = fig.add_subplot(111, projection='3d')

# plot each classification separately so each keeps its own color and alpha
for cls in df_latent['classification'].dropna().unique():
    subset = df_latent[df_latent['classification'] == cls]
    ax.scatter(subset['UMAP1'], subset['UMAP2'], subset['UMAP3'],
               # fallback style: classes outside the three known ones are
               # drawn in gray instead of raising KeyError
               color=label_to_color.get(cls, "gray"),
               alpha=label_to_alpha.get(cls, 0.3),
               label=cls,
               s=20)

ax.set_xlabel('UMAP1')
ax.set_ylabel('UMAP2')
ax.set_zlabel('UMAP3')
ax.legend(title = 'Classification')
ax.set_title('3D UMAP of Latent Space')

plt.show()

## Compare latent space distances

In [None]:
from scipy.spatial.distance import cdist

# Coordinates of each class in the 3D UMAP embedding.
# NOTE(review): the distances below are measured in UMAP space, not in the
# original latent space of the encoder.
umap_cols = ['UMAP1', 'UMAP2', 'UMAP3']
is_lrd = df_latent['classification'] == 'LRD'
is_galaxy = df_latent['classification'] == 'Galaxy'
is_agn = df_latent['classification'] == 'AGN'

latent_LRD = df_latent.loc[is_lrd, umap_cols].values
latent_Galaxy = df_latent.loc[is_galaxy, umap_cols].values
latent_AGN = df_latent.loc[is_agn, umap_cols].values

# Pairwise Euclidean distances: rows are LRDs, columns are members of the other class
dist_LRD_Galaxy = cdist(latent_LRD, latent_Galaxy)
dist_LRD_AGN = cdist(latent_LRD, latent_AGN)

# One row per LRD with the distance to its nearest galaxy and nearest AGN
df_LRD = df_latent[is_lrd].assign(
    nearest_galaxy=dist_LRD_Galaxy.min(axis=1),
    nearest_agn=dist_LRD_AGN.min(axis=1),
)

# Summary statistics, one row per distance column
stat_names = {'25%': 'Q1', '50%': 'Median', '75%': 'Q3'}
stat_order = ['min', 'Q1', 'Median', 'Q3', 'max', 'mean', 'std']
summary_panel = (
    df_LRD[['nearest_galaxy', 'nearest_agn']]
    .describe(percentiles=[.25, .5, .75])
    .T
    .rename(columns=stat_names)
)
summary_panel = summary_panel[stat_order]

print(summary_panel)