In [None]:
%load_ext autoreload

import logging
# Configure the root logger: INFO threshold, "<timestamp> - <level> - <message>" records.
# Selectable levels: NOTSET, DEBUG, INFO, WARN, ERROR, CRITICAL
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') # NOTSET, DEBUG, INFO, WARN, ERROR, CRITICAL

from JPAS_DA import global_setup
from JPAS_DA.data import loading_tools
from JPAS_DA.data import cleaning_tools
from JPAS_DA.data import crossmatch_tools

import numpy as np
import pandas as pd
import os

from JPAS_DA.utils import plotting_utils
import matplotlib as mpl
import matplotlib.pyplot as plt
# Plot configuration. The order of these statements matters: each step can override
# rcParams set by the previous one.
# 1) Start from matplotlib's default style and drop any figures left open.
plt.style.use('default')
plt.close('all')
# 2) Apply the project's font settings and rcParams overrides (both returned by the
#    plotting_utils helper).
font, rcnew = plotting_utils.matplotlib_default_config()
mpl.rc('font', **font)
plt.rcParams.update(rcnew)
# 3) Apply the colour-blind-friendly style sheet last, so whatever it sets wins over
#    the values applied in step 2.
plt.style.use('tableau-colorblind10')
# 4) Select the interactive widget backend for all figures in this notebook.
%matplotlib widget

In [None]:
# Data location taken from the project-wide setup; it is passed as the root path to
# the dataset loaders in the following cells.
root_path = global_setup.DATA_path

In [None]:
# Load the DESI splits (per the original note, from the file
# 'jpas_idr_classification_xmatch_desi_dr1.fits.gz') and merge them into one catalogue
# exposing the keys "DESI_pd" and "DESI_np".
DESI_split = loading_tools.load_DESI_dsets({}, root_path, global_setup.load_DESI_data)
DESI = loading_tools.concatenate_DESI_splits(
    DESI_split,
    merged_pd_key="DESI_pd",
    merged_np_key="DESI_np",
)

# Same two steps for the DESI-Lilianne splits. The key set of the merged DESI table is
# handed over as `pd_keys` (presumably to select the same columns; confirm in
# loading_tools).
DESI_Lilianne_split = loading_tools.load_DESI_Lilianne_dsets(
    {},
    root_path,
    global_setup.load_DESI_data_Lilianne,
    pd_keys=DESI['DESI_pd'].keys(),
)
DESI_Lilianne = loading_tools.concatenate_DESI_splits(
    DESI_Lilianne_split,
    merged_pd_key="DESI_pd",
    merged_np_key="DESI_np",
)

In [None]:
# Distinct TARGETIDs of each catalogue together with how often each one occurs.
labels_uniques_DESI, counts_uniques_DESI = np.unique(
    DESI['DESI_pd']["TARGETID"],
    return_counts=True,
)
labels_uniques_DESI_Lilianne, counts_uniques_DESI_Lilianne = np.unique(
    DESI_Lilianne['DESI_pd']["TARGETID"],
    return_counts=True,
)

# Report: number of distinct IDs, then number of rows (sum of the multiplicities).
print()
print("Number of unique TARGETIDs DESI:", labels_uniques_DESI.size)
print("Total number of objects with unique TARGETIDs DESI (TARGETIDs might be repeated):", counts_uniques_DESI.sum())
print()
print("Number of unique TARGETIDs DESI-Lilianne:", labels_uniques_DESI_Lilianne.size)
print("Total number of objects with unique TARGETIDs DESI-Lilianne (TARGETIDs might be repeated):", counts_uniques_DESI_Lilianne.sum())

In [None]:
# Cross-match the two catalogues on TARGETID. "1" is DESI, "2" is DESI-Lilianne.
IDs_only_1, IDs_only_2, IDs_both, idxs_only_1, idxs_only_2, idxs_both_1, idxs_both_2 = crossmatch_tools.crossmatch_IDs_two_datasets(
    DESI['DESI_pd']["TARGETID"], DESI_Lilianne['DESI_pd']["TARGETID"]
)

# Number of distinct TARGETIDs in each category.
print("Number of unique TARGETIDs only in 1:", len(IDs_only_1))
print("Number of unique TARGETIDs only in 2:", len(IDs_only_2))
print("Number of unique TARGETIDs both in 1 & 2:", len(IDs_both))

# Number of catalogue rows in each category. The idxs_* outputs hold one group of row
# indices per unique TARGETID (this notebook iterates idxs_both_1[k] / idxs_both_2[k]
# that way; idxs_only_2 is assumed to follow the same layout; confirm in
# crossmatch_tools). Row totals therefore have to be counted over the concatenated
# index groups, not over the unique IDs. np.concatenate raises on an empty sequence,
# hence the guards.
if len(IDs_only_1) > 0:
    print("Total number of objects from 1, TARGETID only in 1 (TARGETIDs might be repeated):", len(np.concatenate(idxs_only_1)))
if len(IDs_only_2) > 0:
    print("Total number of objects from 2, TARGETID only in 2 (TARGETIDs might be repeated):", len(np.concatenate(idxs_only_2)))
if len(IDs_both) > 0:
    print("Total number of objects from 1, TARGETID both in 1 & 2 (TARGETIDs might be repeated):", len(np.concatenate(idxs_both_1)))
    print("Total number of objects from 2, TARGETID both in 1 & 2 (TARGETIDs might be repeated):", len(np.concatenate(idxs_both_2)))

In [None]:
# Element-wise conversion of the whole DESI-Lilianne array: value -> 10**((value - 22.5) / -2.5).
# NOTE(review): this reads as a magnitude-to-flux conversion with a 22.5 zero point; it
# presumes every entry of 'DESI_np' (along all axes) is a magnitude. Confirm.
DESI_Lilianne_flux = 10**((DESI_Lilianne['DESI_np'] - 22.50)/-2.5)

In [None]:
NN_plot_both = 10     # maximum number of matched TARGETIDs to draw
seed_plot_both = 42   # fixed seed: the same objects are selected on every run

# Sampling without replacement raises if more items are requested than exist, so cap
# the request at the number of matched TARGETIDs.
n_plot_both = min(NN_plot_both, len(idxs_both_1))
rng_plot_both = np.random.default_rng(seed_plot_both)
LoA_indices_plot = rng_plot_both.choice(len(idxs_both_1), size=n_plot_both, replace=False)

fig, ax = plt.subplots(figsize=(9, 6))
ax.set_xlabel(r'$\mathrm{Filter~Index}$', fontsize=20)
ax.set_ylabel(r'Flux [arb. units]', fontsize=20)

# One colour per selected TARGETID, shared by its entries in both catalogues.
colors = plotting_utils.get_N_colors(NN_plot_both, colormap=plt.cm.tab10)
for ii, LoA_idx in enumerate(LoA_indices_plot):

    tmp_TARGETID = IDs_both[LoA_idx]
    print("TARGETID", tmp_TARGETID)

    # Thin solid lines: every DESI row carrying this TARGETID.
    for tmp_idx in idxs_both_1[LoA_idx]:
        assert DESI['DESI_pd']["TARGETID"][tmp_idx] == tmp_TARGETID, "TARGETID mismatch DESI"
        tmp_obs = DESI['DESI_np'][tmp_idx][..., 0]
        ax.plot(np.arange(len(tmp_obs)), tmp_obs, color=colors[ii], ls='-', lw=1, alpha=0.85)

    # Thick dashed lines: every DESI-Lilianne row carrying this TARGETID, taken from
    # the converted array (DESI_Lilianne_flux) rather than the raw 'DESI_np' values.
    for tmp_idx in idxs_both_2[LoA_idx]:
        assert DESI_Lilianne['DESI_pd']["TARGETID"][tmp_idx] == tmp_TARGETID, "TARGETID mismatch DESI-Lillianne"
        tmp_obs = DESI_Lilianne_flux[tmp_idx][..., 0]
        ax.plot(np.arange(len(tmp_obs)), tmp_obs, color=colors[ii], ls='--', lw=3, alpha=0.85)

# Empty proxy lines so the legend explains the two line styles (colours encode TARGETID).
ax.plot([], [], color='k', ls='-', lw=1, label='DESI')
ax.plot([], [], color='k', ls='--', lw=3, label='DESI-Lilianne (converted)')
ax.legend(fontsize=14)

# Adjust tick label size
ax.tick_params(axis='both', labelsize=14)

# Y scale and plot style
ax.set_yscale("log")

plt.tight_layout()
plt.show()

In [None]:
# Everything needed for the combined JPAS + DESI load comes from the project-wide setup;
# only the seed is fixed here.
root_path = global_setup.DATA_path
load_JPAS_data = global_setup.load_JPAS_data
load_DESI_data = global_setup.load_DESI_data
dict_clean_data_options = global_setup.dict_clean_data_options
random_seed_load = 42

DATA = loading_tools.load_dsets(
    root_path=root_path,
    datasets_jpas=load_JPAS_data,
    datasets_desi=load_DESI_data,
    random_seed=random_seed_load,
)

# Forward exactly these seven cleaning options, each under its own keyword.
DATA_clean = cleaning_tools.clean_and_mask_data(
    DATA=DATA,
    **{
        option: dict_clean_data_options[option]
        for option in (
            "apply_masks",
            "mask_indices",
            "magic_numbers",
            "i_band_sn_threshold",
            "magnitude_flux_key",
            "magnitude_threshold",
            "z_lim_QSO_cut",
        )
    },
)

In [None]:
# Quick structural overview; each group is printed one item per line and closed by an
# empty line.

# Raw JPAS container: keys, array shapes, table keys, TARGETID column shape.
print(
    DATA['JPAS'].keys(),
    DATA['JPAS']["all_observations"].shape,
    DATA['JPAS']["all_errors"].shape,
    DATA['JPAS']["all_pd"].keys(),
    DATA['JPAS']["all_pd"]['TARGETID'].shape,
    "",
    sep="\n",
)

# Raw DESI container: keys, array shape, table keys, TARGETID column shape.
print(
    DATA['DESI'].keys(),
    DATA['DESI']["all_np"].shape,
    DATA['DESI']["all_pd"].keys(),
    DATA['DESI']["all_pd"]['TARGETID'].shape,
    "",
    sep="\n",
)

# Cleaned containers: keys and number of SPECTYPE entries.
print(
    DATA_clean['JPAS'].keys(),
    len(DATA_clean['JPAS']['SPECTYPE']),
    "",
    sep="\n",
)
print(
    DATA_clean['DESI'].keys(),
    len(DATA_clean['DESI']['SPECTYPE']),
    "",
    sep="\n",
)

In [None]:
# Magnitude-binning configuration for the spectral-type plots below.
# NOTE(review): the key is named "...FLUX..." but its values are binned as magnitudes
# (edges 17 ... 22.5). Confirm that the cleaned data stores magnitudes under this key.
magnitude_key = "DESI_FLUX_R"
output_key = "MAG_BIN_ID"
mag_bin_edges = (0, 17, 19, 21, 22, 22.5, 100)

# Pair consecutive edges: [(0, 17), (17, 19), ..., (22.5, 100)].
magnitude_ranges = list(zip(mag_bin_edges[:-1], mag_bin_edges[1:]))

# NOTE(review): four entries each, while there are six magnitude ranges.
colors = ['blue', 'green', 'orange', 'red']
colormaps = [plt.cm.Blues, plt.cm.Greens, plt.cm.YlOrBr, plt.cm.Reds]

# Per-object quantities of the cleaned JPAS sample used by the plotting cells.
spectype = DATA_clean['JPAS']['SPECTYPE_int']
magnitude = DATA_clean['JPAS'][magnitude_key]

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(12, 6))

# Tally the objects per spectral class over the whole sample.
tmp_hist = spectype
unique_spectype_names, counts = np.unique(tmp_hist, return_counts=True)

# One inferno shade per wedge, each wedge pulled out slightly.
colors = plt.cm.inferno(np.linspace(0., 0.8, len(counts)))
explode = [0.05] * counts.size


def make_autopct(counts):
    """Return the wedge formatter for ``ax.pie``: absolute count above the percentage."""
    def my_autopct(pct):
        # Recover the absolute number of objects from the percentage of the wedge.
        n_objects = int(round(pct * sum(counts) / 100.0))
        return f"{n_objects}\n({pct:.1f}%)"
    return my_autopct


# Draw the pie
pie_style = dict(
    startangle=140,
    colors=colors,
    explode=explode,
    wedgeprops={'edgecolor': 'black', 'linewidth': 1},
    textprops={'fontsize': 12},
)
wedges, texts, autotexts = ax.pie(
    counts,
    labels=unique_spectype_names,
    autopct=make_autopct(counts),
    **pie_style,
)

# White labels inside the wedges
for autotext in autotexts:
    autotext.set_color("white")

# Figure-level title; leave 5% headroom for it when tightening the layout
fig.suptitle("Distribution of Spectral Types", fontsize=20)
fig.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()

In [None]:
# Wedge formatter factory, defined once for the whole loop instead of once per iteration.
def make_autopct(counts):
    """Return the wedge formatter for ``ax.pie``: absolute count above the percentage."""
    def my_autopct(pct):
        total = sum(counts)
        absolute = int(round(pct * total / 100.0))
        return f"{absolute}\n({pct:.1f}%)"
    return my_autopct


# One pie chart of spectral types per magnitude range; ranges are right-closed (low, high].
for ii in range(len(magnitude_ranges)):

    mask = (magnitude > magnitude_ranges[ii][0]) & (magnitude <= magnitude_ranges[ii][1])
    mask = np.array(mask, dtype=bool)  # ensure dtype=bool
    tmp_hist = spectype[mask]
    unique_spectype_names, counts = np.unique(tmp_hist, return_counts=True)

    # An empty range has nothing to draw: report it instead of opening a blank figure.
    if counts.size == 0:
        print(f"No objects in magnitude range {magnitude_ranges[ii]}; skipping pie chart.")
        continue

    fig, ax = plt.subplots(1, 1, figsize=(12, 6))

    colors = plt.cm.inferno(np.linspace(0., 0.8, len(counts)))
    explode = [0.05] * len(counts)

    # Create pie chart
    wedges, texts, autotexts = ax.pie(
        counts,
        labels=unique_spectype_names,
        autopct=make_autopct(counts),
        startangle=140,
        colors=colors,
        explode=explode,
        wedgeprops={'edgecolor': 'black', 'linewidth': 1},
        textprops={'fontsize': 12}
    )

    # Customize font color inside pie
    for autotext in autotexts:
        autotext.set_color("white")

    ax.set_title("Magnitude Range: " + f"{magnitude_ranges[ii]}", fontsize=16)

    # Keep the same layout rectangle as the overall pie chart
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import PercentFormatter

def plot_spectype_fractions(magnitude, spectype, magnitude_ranges, *,
                            title="Spectral-type fractions vs. magnitude",
                            ymin=0.0, ymax=1.0, normalize=False):
    """
    Plot how each spectral type populates a set of magnitude bins.

    One line is drawn per spectral type, with one point per magnitude bin.
    By default the points are the raw number of objects (logarithmic y-axis).
    With ``normalize=True`` they are the fraction of the bin's objects that
    belong to the type (linear percentage axis limited to ``[ymin, ymax]``).

    Parameters
    ----------
    magnitude : array-like of shape (N,)
        Magnitudes for all objects (numeric).
    spectype : array-like of shape (N,)
        Spectral type labels (strings or ints). Lists are OK.
    magnitude_ranges : list of (low, high) tuples
        Bins are interpreted as (low, high] (i.e., right-closed).
    title : str or None
        Plot title; no title is set when it is None or empty.
    ymin, ymax : float
        y-limits of the fraction axis. Only used when ``normalize=True``.
    normalize : bool
        False (default): plot counts on a log axis. True: plot per-bin fractions.

    Returns
    -------
    None
        The figure is shown with ``plt.show()``.

    Raises
    ------
    ValueError
        If ``magnitude`` and ``spectype`` do not have the same shape.
    """

    # Ensure arrays (lists cannot be indexed with a boolean mask)
    magnitude = np.asarray(magnitude)
    spectype = np.asarray(spectype)
    if magnitude.shape != spectype.shape:
        raise ValueError(
            f"magnitude and spectype must have the same shape, "
            f"got {magnitude.shape} and {spectype.shape}"
        )

    # Encode every object by the index of its class, once for the whole sample.
    all_classes, class_codes, all_counts = np.unique(
        spectype, return_inverse=True, return_counts=True
    )
    # Make sure the codes line up with the boolean masks built from `magnitude`.
    class_codes = np.reshape(class_codes, spectype.shape)

    # Order the classes by overall frequency, most frequent first.
    sort_idx = np.argsort(-all_counts)
    classes = all_classes[sort_idx]

    n_bins = len(magnitude_ranges)
    n_classes = len(classes)

    # Count matrix: rows=classes (frequency order), cols=bins
    counts = np.zeros((n_classes, n_bins), dtype=int)
    bin_totals = np.zeros(n_bins, dtype=int)

    for j, (low, high) in enumerate(magnitude_ranges):
        # Right-closed bin (low, high]
        mask = (magnitude > low) & (magnitude <= high)
        # bincount yields the per-class tally in `all_classes` order (zeros for classes
        # absent from this bin); reorder it to the frequency-sorted rows.
        per_class = np.bincount(class_codes[mask], minlength=n_classes)
        counts[:, j] = per_class[sort_idx]
        bin_totals[j] = per_class.sum()

    if normalize:
        # Empty bins have total 0; dividing by max(total, 1) maps them to fraction 0.
        values = counts / np.maximum(bin_totals, 1)
    else:
        values = counts

    # X positions and tick labels
    x = np.arange(n_bins)
    tick_labels = [f"({lo:.2f}, {hi:.2f}]" for (lo, hi) in magnitude_ranges]

    # Plot
    fig, ax = plt.subplots(figsize=(10, 6))

    # Choose a qualitative colormap with enough distinct colors
    cmap = plt.get_cmap("tab20" if n_classes > 10 else "tab10")

    for i, cls in enumerate(classes):
        ax.plot(x, values[i], marker='o', lw=2, label=str(cls), color=cmap(i % cmap.N))

    ax.set_xticks(x)
    ax.set_xticklabels(tick_labels, rotation=30, ha='right')
    ax.set_xlabel("Magnitude range")
    if normalize:
        ax.set_ylabel("Fraction per bin")
        ax.set_ylim(ymin, ymax)
        ax.yaxis.set_major_formatter(PercentFormatter(1.0))
    else:
        ax.set_ylabel("Counts")
        ax.set_yscale("log")
    if title:
        ax.set_title(title)

    # Legend outside the axes; two columns if there are many classes
    ncol = 1 if n_classes <= 8 else 2
    ax.legend(title="Spectral type", bbox_to_anchor=(1.02, 1), loc="upper left", ncol=ncol)
    plt.tight_layout()
    plt.show()


In [None]:
# Draw the per-bin spectral-type curves for the cleaned JPAS sample, using the magnitude
# ranges built from `mag_bin_edges`; title=None replaces the function's default title.
# Example of an alternative binning: magnitude_ranges = [(16,18), (18,19), (19,20), (20,21), (21,22)]
plot_spectype_fractions(magnitude, spectype, magnitude_ranges, title=None)


In [None]:
    # Reloads the raw datasets: this is the same load_dsets call, with the same four
    # arguments, as in the load-and-clean cell earlier in the notebook, and it rebinds
    # DATA (DATA_clean is not recomputed here).
    # NOTE(review): the statement is indented although it sits at the top level of its
    # cell; it looks pasted from inside a function. Confirm whether this cell is
    # still needed.
    DATA = loading_tools.load_dsets(
        root_path=root_path,
        datasets_jpas=load_JPAS_data,
        datasets_desi=load_DESI_data,
        random_seed=random_seed_load
    )