In [None]:
%load_ext autoreload

import logging
# Root-logger setup: every record is rendered as "timestamp - level - message".
# Selectable levels: NOTSET, DEBUG, INFO, WARN, ERROR, CRITICAL
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

from JPAS_DA import global_setup
from JPAS_DA.data import loading_tools
from JPAS_DA.data import cleaning_tools

import numpy as np

from JPAS_DA.utils import plotting_utils
import matplotlib as mpl
import matplotlib.pyplot as plt
# Start from matplotlib's built-in defaults and close any figures that are still open.
plt.style.use('default')
plt.close('all')
# Project-wide look: the helper returns font settings (forwarded as keyword arguments to
# mpl.rc) and a set of rcParams that is merged into the active configuration.
font, rcnew = plotting_utils.matplotlib_default_config()
mpl.rc('font', **font)
plt.rcParams.update(rcnew)
# Applied last, so any rcParams this style sheet defines override the values set above.
plt.style.use('tableau-colorblind10')
# Interactive figures inside the notebook.
# NOTE(review): presumably requires the ipympl package — confirm it is installed.
%matplotlib widget

In [None]:
# Seed handed to the dataset loader as `random_seed`.
random_seed_load = 42

# Loader inputs taken from the project-wide setup module: the data root and the
# selections passed on as `datasets_jpas` / `datasets_desi`.
root_path = global_setup.DATA_path
load_JPAS_data = global_setup.load_JPAS_data
load_DESI_data = global_setup.load_DESI_data

In [None]:
# Load the configured JPAS and DESI datasets; the result is indexed by survey name
# ('JPAS' / 'DESI') in the cells below.
DATA = loading_tools.load_dsets(
    root_path=root_path,
    datasets_jpas=load_JPAS_data,
    datasets_desi=load_DESI_data,
    random_seed=random_seed_load,
)

In [None]:
# Integer-encode the DESI `MORPHTYPE` column, show the label mapping that comes back,
# and display how many entries fall on each integer code.
desi_morphtype = DATA['DESI']['all_pd']['MORPHTYPE']
encoded, labels = cleaning_tools.encode_strings_to_integers(desi_morphtype)
print(labels)
np.unique(encoded, return_counts=True)

In [None]:
# Same encoding applied to the JPAS `MORPHTYPE` column; the raw return value is
# displayed as the cell output.
jpas_morphtype = DATA['JPAS']['all_pd']['MORPHTYPE']
cleaning_tools.encode_strings_to_integers(jpas_morphtype)

In [None]:
# Cleaning/masking options defined in the project-wide setup module; the cleaning cell
# below reads its keyword arguments from this dictionary.
dict_clean_data_options = global_setup.dict_clean_data_options

In [None]:
# Clean and mask the loaded data. Each listed option is looked up in
# `dict_clean_data_options` and forwarded under the same keyword name.
DATA_clean = cleaning_tools.clean_and_mask_data(
    DATA=DATA,
    **{
        option: dict_clean_data_options[option]
        for option in (
            "apply_masks",
            "mask_indices",
            "magic_numbers",
            "i_band_sn_threshold",
            "z_lim_QSO_cut",
        )
    },
)

In [None]:
def make_autopct(counts):
    """Build an ``autopct`` callback showing the count and percentage of each wedge.

    Parameters
    ----------
    counts : array-like of int
        Wedge sizes, as passed to ``Axes.pie``.

    Returns
    -------
    callable
        Takes the percentage supplied by matplotlib and returns a two-line label:
        the absolute count on the first line, the percentage in parentheses on the second.
    """
    # Loop-invariant: computed once per chart rather than once per wedge.
    total = np.sum(counts)

    def my_autopct(pct):
        # matplotlib only hands over the percentage, so the count is recovered from it.
        absolute = int(round(pct * total / 100.0))
        return f"{absolute}\n({pct:.1f}%)"

    return my_autopct


# One pie chart per survey, side by side, showing the distribution of spectral types.
fig, axs = plt.subplots(1, 2, figsize=(12, 6))

for ii, survey_key in enumerate(["JPAS", "DESI"]):

    ax = axs[ii]
    unique_spectype_names, counts = np.unique(DATA_clean[survey_key]['SPECTYPE'], return_counts=True)

    # Nothing to draw for an empty survey: leave a titled, blank panel instead of
    # producing NaN class weights and an empty pie.
    if counts.size == 0:
        logging.warning("%s: no SPECTYPE entries found; skipping pie chart.", survey_key)
        ax.set_title(f"{survey_key}", fontsize=16)
        ax.axis("off")
        continue

    # Class fractions and inverse-frequency weights (normalised to a mean of 1).
    ratios = counts / np.sum(counts)
    print(survey_key, "- ratios:", ratios)
    inverse_ratios = 1.0 / ratios
    class_weights = inverse_ratios / inverse_ratios.mean()
    print(survey_key, "- Class_weights:", class_weights)

    colors = plt.cm.inferno(np.linspace(0., 0.8, len(counts)))
    explode = [0.05] * len(counts)

    # Create pie chart
    wedges, texts, autotexts = ax.pie(
        counts,
        labels=unique_spectype_names,
        autopct=make_autopct(counts),
        startangle=140,
        colors=colors,
        explode=explode,
        wedgeprops={'edgecolor': 'black', 'linewidth': 1},
        textprops={'fontsize': 12}
    )

    # Customize font color inside pie
    for autotext in autotexts:
        autotext.set_color("white")

    ax.set_title(f"{survey_key}", fontsize=16)

# General title; the layout rectangle leaves the top 5% of the figure free for it.
fig.suptitle("Distribution of Spectral Types", fontsize=20)
fig.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()