## Functionality to combine visualizations

In [None]:
import warnings
# Silence all warnings for the rest of the notebook.
# NOTE(review): a blanket 'ignore' also hides useful warnings (e.g. numpy
# divide-by-zero in the empirical-probability rates below); narrow it when debugging.
# A former `with warnings.catch_warnings(): warnings.filterwarnings(...)` block was
# dropped: catch_warnings restores the filter list on exit, so that filter never
# applied to any later code.
warnings.filterwarnings('ignore')
%matplotlib inline  

import numpy as np  
from matplotlib import rc
import matplotlib as mpl
import matplotlib.pyplot as plt
import pickle

from models.binary_model.binary_model import BinaryModel
from models.ind_model.ind_model import OvAModel
from models.multi_model.multi_model import MultiModel

from thex_data.data_consts import *
from mainmodel.helper_compute import *
from mainmodel.helper_plotting import *
from utilities import utilities as thex_utils


# Magnitude columns used as model features (passed as `cols=` to every model below).
mags = ["g_mag", "r_mag", "i_mag", "z_mag", "y_mag",
        "W1_mag", "W2_mag",
        "J_mag", "K_mag", "H_mag"]

# Directory of the previously-run experiments that are loaded (not retrained) below.
EXPS_DIR = ROOT_DIR + "/../../exps/v8_db_runs/reg_runs/"

# Case codes forwarded to every model as `case_code`.
codes = ["A1", "F1", "B1", "G1"]

# Serif (Times) fonts for all figures; LaTeX text rendering is switched on per figure.
# rc('text', usetex=True)
mpl.rcParams.update({
    'font.serif': ['times', 'times new roman'],
    'font.family': 'serif',
})

In [None]:
# Multi 
# Constructor arguments shared by the multiclass, binary and OvA models.
_REG_MODEL_KWARGS = dict(cols=mags,
                         folds=10,
                         transform_features=True,
                         case_code=codes,
                         balanced_purity=True)

# Multiclass KDE model
multi_model = MultiModel(**_REG_MODEL_KWARGS)
multi_model = load_prev_exp(EXPS_DIR, "Multiclass_Classifier1/", multi_model)

# Binary classifiers
model2 = BinaryModel(**_REG_MODEL_KWARGS)
binary_model = load_prev_exp(EXPS_DIR, "Binary_Classifiers2/", model2)

# One-vs-All model
model = OvAModel(**_REG_MODEL_KWARGS)
ova_model = load_prev_exp(EXPS_DIR, "OVA_Classifier1/", model)

## Merge purity/comp average plots

Below is the code to merge the OVA and multiclass KDE purity/completeness plots into one figure, and, separately, to visualize the binary classifiers' average balanced purity and completeness in the same style.

In [None]:
def get_metrics_ax(model, pur_ax, comp_ax):
    """
    Draw a model's per-class balanced purity and completeness bars onto two axes.

    :param model: loaded model; its results, class counts/labels/priors and
        balanced_purity flag feed the metric and baseline computations
    :param pur_ax: matplotlib axis that receives the balanced-purity bars
    :param comp_ax: matplotlib axis that receives the completeness bars
    :return: (y_indices, class_names) from the completeness call to
        model.plot_metrics_ax, used by callers to label the shared y axis
    """
    c_baselines, p_baselines = compute_baselines(model.class_counts,
                                                 model.class_labels,
                                                 model.get_num_classes(),
                                                 model.balanced_purity,
                                                 model.class_priors)
    trial_pc = model.get_pc_per_trial(model.results)
    purities, comps = model.get_pc_performance(trial_pc)
    p_intvls, c_intvls = compute_confintvls(trial_pc,
                                            model.class_labels,
                                            model.balanced_purity)

    # The purity call's tick data is discarded; only the completeness call's is returned.
    _, _ = model.plot_metrics_ax(pur_ax, purities, "Balanced Purity", p_baselines, p_intvls)
    y_indices, class_names = model.plot_metrics_ax(comp_ax, comps, "Completeness",
                                                   c_baselines, c_intvls)
    return y_indices, class_names


In [None]:
# Plot Multi & OVA in one 2x2 figure: top row OVA, bottom row multiclass KDE;
# left column balanced purity, right column completeness.
f, ax = plt.subplots(nrows=2, ncols=2,
                     sharex=True, sharey=True,
                     figsize=(8,8),  dpi=600)
rc('text', usetex=True)
mpl.rcParams['font.serif'] = ['times', 'times new roman']
mpl.rcParams['font.family'] = 'serif'
print("\n --------------------------- ova_model--------------------------- \n\n")
get_metrics_ax(ova_model, ax[0][0], ax[0][1])

print("\n --------------------------- multi_model--------------------------- \n\n")
# The multi model's tick positions/labels are reused for both rows
# (assumes both models share the same class order).
y_indices, class_names = get_metrics_ax(multi_model, ax[1][0], ax[1][1])

ax[0][0].tick_params(direction="in")
ax[0][1].tick_params(direction="in")
ax[0][1].text(1.05,2.1, "OVA", fontsize=20)
ax[1][1].text(1.05,2.1, "Multiclass\nKDE", fontsize=20)
plt.subplots_adjust(wspace=0, hspace=0)
# Class labels go on the left column of each row.
for row_ax in (ax[0][0], ax[1][0]):
    row_ax.set_yticks(y_indices)
    row_ax.set_yticklabels(clean_class_names(class_names),
                           fontsize=16, horizontalalignment='right')
# pyplot call on the current axes; sharex=True propagates the ticks to every panel.
plt.xticks(np.linspace(0,1,10, endpoint=False))

# Raw strings: "\%" in a normal literal is an invalid escape sequence
# (SyntaxWarning on Python 3.12+); LaTeX needs the literal backslash.
ax[1][0].set_xlabel(r"Balanced Purity (\%)", fontsize=TICK_S)
ax[1][1].set_xlabel(r"Completeness (\%)", fontsize=TICK_S)
plt.savefig("../output/custom_figures/multis_metrics.pdf", bbox_inches='tight')
plt.show()
#  Binary: balanced purity (left) and completeness (right) for the binary classifiers.
f, ax = plt.subplots(nrows=1, ncols=2,
                     sharex=True, sharey=True,
                     figsize=(6,3),  dpi=600)
rc('text', usetex=True)
mpl.rcParams['font.serif'] = ['times', 'times new roman']
mpl.rcParams['font.family'] = 'serif'

y_indices, class_names = get_metrics_ax(model=binary_model, pur_ax=ax[0], comp_ax=ax[1])

ax[0].tick_params(direction="in")
ax[1].tick_params(direction="in")

plt.subplots_adjust(wspace=0, hspace=0)

ax[0].set_yticks(y_indices)
ax[0].set_yticklabels(clean_class_names(class_names),
                      fontsize=14, horizontalalignment='right')
plt.xticks(np.linspace(0,1,10, endpoint=False))
# Raw strings avoid the invalid "\%" escape sequence; LaTeX needs the backslash.
ax[0].set_xlabel(r"Balanced Purity (\%)", fontsize=14)
ax[1].set_xlabel(r"Completeness (\%)", fontsize=14)
plt.savefig("../output/custom_figures/binary_metrics.pdf", bbox_inches='tight')
plt.show()

## Plot probability plots together

### Empirical Probabilities

In [None]:
# These calls are for the old way of generating empirical prob plots - when they were not balanced.
# Now, I call the new functions from within plot_rates_together

# ova_model.range_metrics = ova_model.compute_probability_range_metrics(
#         ova_model.results, bin_size=0.2)
# binary_model.range_metrics = binary_model.compute_probability_range_metrics(
#         binary_model.results, bin_size=0.2)
# multi_model.range_metrics = multi_model.compute_probability_range_metrics(
#         multi_model.results, bin_size=0.2)

In [None]:
# Empirical-probability plots for indices 0-5 (presumably class indices — confirm).
plot_rates_together(binary_model, ova_model, multi_model, indices=list(range(0, 6)))

In [None]:
# Empirical-probability plots for indices 6-11 (presumably class indices — confirm).
plot_rates_together(binary_model, ova_model, multi_model, indices=list(range(6, 12)))

The code below has been commented out because we no longer use this empirical-probability callout figure in the paper. Since these plots were updated to be 'balanced', they no longer convey calibration.

In [None]:
# cur_model.class_labels
# # ia, ic, iib, iin, tde
# multi_model_preds = np.concatenate(multi_model.results)
# multi_model.class_prob_rates = get_multi_emp_prob_rates(multi_model_preds,
#                                                             multi_model.class_labels,
#                                                             0.2,
#                                                             multi_model.class_counts)
# cur_model = multi_model
# # cur_model.range_metrics = cur_model.compute_probability_range_metrics(
# #         cur_model.results, bin_size=0.2)

# # call outs for multi only
# indices = [0, 4, 7, 10,11]
# rc('text', usetex=True)
# cur_model = multi_model


# class_labels = cur_model.class_labels 

# num_classes = len(indices)

# f, ax = plt.subplots(nrows=1,
#                      ncols=len(indices),
#                      sharex=True, sharey=True,
#                      figsize=(8, 1.9),
#                      dpi=280)
# row_index = 0
# for index in indices: 
#     cn = cur_model.class_labels[index]
#     plot_model_rates(cn, cur_model, ax[row_index])  
#     ax[row_index].text(-0.45, 0.81, clean_class_name(cn), fontsize=14)
# #     plot_model_rates(cn, ova_model, ax[row_index][1])
#     row_index+=1
    
# y_indices = [0.1, 0.3, 0.5, 0.7, 0.9]
# y_ticks = ["10", "30", "50", "70", "90"]
# # x and y indices/ticks are the same
# plt.xticks(np.arange(5), y_ticks)
# plt.yticks(y_indices, y_ticks)
# plt.rc('xtick', labelsize=10)
# plt.rc('ytick', labelsize=10)

# mpl.rcParams['font.serif'] = ['times', 'times new roman']
# mpl.rcParams['font.family'] = 'serif'

# f.text(0.5, -0.07, 'Assigned Probability ' + r' $\pm10\%$', fontsize=14, ha='center')
# f.text(0.05, .5, r'Empirical Prob. ($\%$)',
#        fontsize=14, va='center', rotation='vertical')
# plt.subplots_adjust(wspace=0, hspace=0)
# f.savefig(ROOT_DIR + "/output/custom_figures/prob_callouts.pdf", bbox_inches='tight')
# plt.show()

### Purity /completeness curves vs probability threshold

In [None]:
# Purity/completeness-vs-threshold curves, in two figures of six indices each.
for idx_group in (list(range(0, 6)), list(range(6, 12))):
    plot_pc_curves_together(binary_model, ova_model, multi_model, indices=idx_group)


In [None]:
# Purity/completeness-vs-threshold curves for a hand-picked subset of indices.
# NOTE(review): which classes these select depends on the models' class order — confirm.
plot_pc_curves_together(binary_model, ova_model, multi_model, indices=[1,2,6,8,9,11])

# Combining models with vs. without priors


### For Trials of K-Fold Cross Validation
Combine X trials with and without priors to get average measures and confidence intervals.
Currently, 6-fold cross validation is run 10 times, and the results and y of each experiment are saved in a list (agg_results.pickle).

**Note:** the code below is not fully working yet.

In [None]:
# Directory of the LSST-test experiments (with / without priors) loaded in the cells below.
GET_DIR = ROOT_DIR + "/../../exps/v8_db_runs/new_lsst_tests/"

In [None]:
# Multiclass model with uniform priors (run stored under Multiclass_Classifier1/).
multi_no_priors = MultiModel(cols=mags,
                             folds=6,
                             transform_features=True,
                             case_code=codes,
                             balanced_purity=True,
                             lsst_test=True)
multi_no_priors = load_prev_exp(GET_DIR, "Multiclass_Classifier1/", multi_no_priors)

In [None]:
# Multiclass model with frequency-based priors (run stored under Multiclass_Classifier2/).
multi_priors = MultiModel(cols=mags,
                          folds=6,
                          transform_features=True,
                          case_code=codes,
                          balanced_purity=True,
                          lsst_test=True,
                          priors=True)
# Bug fix: this previously passed `multi_no_priors`, so the priors-configured model
# built above was discarded and the with-priors run was loaded into the no-priors model.
multi_priors = load_prev_exp(GET_DIR, "Multiclass_Classifier2/", multi_priors)

In [None]:
# NOTE(review): plot_EPs_Priors is defined in the *next* cell; on a fresh kernel
# (Restart & Run All) this line raises NameError — run the definition cell first.
plot_EPs_Priors(Model_NP=multi_no_priors, Model_WP=multi_priors)

In [None]:
def plot_EPs_Priors(Model_NP, Model_WP):
    """
    Plot Empirical Probability Plots (EPs) for a no-priors vs with-priors model pair.

    Each row is one LSST class. The left column shows Model_NP (uniform priors) and
    the right column shows Model_WP (frequency-based priors). The figure is saved to
    ROOT_DIR/output/custom_figures/merged_metrics_priors_comp_AGG.pdf.
    NOTE(review): plot_WP_NP_EPs writes the same file, so whichever runs last wins.

    :param Model_NP: loaded model without priors; `results`, `class_labels` and
        `class_counts` are read, and `class_prob_rates` is (re)assigned
    :param Model_WP: loaded model with priors; same attributes are used
    """
    rc('text', usetex=True)
    mpl.rcParams['font.serif'] = ['times', 'times new roman']
    mpl.rcParams['font.family'] = 'serif'

    # The rates depend only on the model, not on the plotted class, so compute them
    # once per model. Bug fix: the old code assigned them to the undefined names
    # ModelMets_NP / ModelMets_WP (NameError, or the wrong global objects if defined).
    for cur_model in (Model_NP, Model_WP):
        cur_model.class_prob_rates = get_multi_emp_prob_rates(np.concatenate(cur_model.results),
                                                              cur_model.class_labels,
                                                              0.2,
                                                              cur_model.class_counts)

    class_labels = ['Unspecified Ia', 'Ia-91bg', 'Ibc', 'II', 'SLSN-I', 'TDE']
    f, ax = plt.subplots(nrows=len(class_labels),
                         ncols=2,
                         sharex=True, sharey=True,
                         figsize=(FIG_WIDTH, 9),
                         dpi=DPI)
    # Column titles on the top row only.
    ax[0][0].set_title("Uniform Priors", fontsize=14)
    ax[0][1].set_title("Frequency-based Priors", fontsize=14)

    for plot_index, class_name in enumerate(class_labels):
        plot_model_rates(class_name, Model_NP, ax[plot_index][0])
        plot_model_rates(class_name, Model_WP, ax[plot_index][1])
        ax[plot_index][0].text(-0.2, 0.8, clean_class_name(class_name), fontsize=16)

    y_indices = [0.1, 0.3, 0.5, 0.7, 0.9]
    y_ticks = ["10", "30", "50", "70", "90"]
    # x and y indices/ticks are the same
    plt.xticks(np.arange(5), y_ticks)
    plt.yticks(y_indices, y_ticks)
    plt.rc('xtick', labelsize=14)
    plt.rc('ytick', labelsize=14)

    f.text(0.5, 0.06, 'Assigned Probability' + r' $\pm$10\%', fontsize=14, ha='center')
    f.text(0.02, .5, r'Empirical Probability $\equiv$ P/Total ($\%$)',
               fontsize=14, va='center', rotation='vertical')

    plt.subplots_adjust(wspace=0, hspace=0)

    f.savefig(ROOT_DIR + "/output/custom_figures/merged_metrics_priors_comp_AGG.pdf", bbox_inches='tight')
    plt.show()



In [None]:


# Define class to keep track of all relevant model mets and details

class ModelsMets:
    """
    Aggregate metrics over repeated runs of k-fold cross validation for one prior setting.

    The pickled `results` / `y` files under GET_DIR hold one entry per repeated CV run.
    They are paired with a MultiModel configured like the runs, so that the model's
    metric helpers can be reused on each run.
    """

    def __init__(self, with_priors):
        """
        Load the aggregated data and build a matching model.

        :param with_priors: True for frequency-based priors (Multiclass_Classifier2/),
            False for uniform priors (Multiclass_Classifier1/)
        """
        # Lists with one element per repeated CV run: that run's results / y.
        self.all_results, self.all_y = self.get_agg_data(WP=with_priors)

        self.model = MultiModel(cols=mags,
                                folds=6,
                                min_class_size=3,
                                max_class_size=4800,
                                transform_features=True,
                                case_code=codes,
                                lsst_test=True,
                                priors=with_priors)

        # Disabled per-trial purity/completeness aggregation. get_plotting_data reads
        # self.class_purities / self.class_comps, which stay unset while this is off.
        # # Lists of length 10, containing maps of purity/comp for each run
        # self.all_purities, self.all_comps = self.get_all_measures()
        # # Map from class name to list of length for avg purity for each trial
        # self.class_purities = self.collect_mets(self.all_purities)
        # self.class_comps = self.collect_mets(self.all_comps)

    def get_agg_data(self, WP):
        """
        Load the aggregated results and y pickles for one prior setting.

        :param WP: True -> Multiclass_Classifier2/, False -> Multiclass_Classifier1/
        :return: (results, y) exactly as unpickled
        """
        # NOTE: pickle.load can execute arbitrary code; only load trusted experiment files.
        exp_dir = "Multiclass_Classifier2/" if WP else "Multiclass_Classifier1/"
        with open(GET_DIR + exp_dir + "results.pickle", 'rb') as handle:
            results = pickle.load(handle)
        with open(GET_DIR + exp_dir + "y.pickle", 'rb') as handle:
            y = pickle.load(handle)
        return results, y

    def get_all_measures(self):
        """
        Compute average purity and completeness per class for every CV run.

        :return: (trialPs, trialCs), lists with one per-class map per run
        """
        trialPs = []
        trialCs = []
        for results in self.all_results:
            # Compute performance for this k-fold set.
            pc_per_trial = self.model.get_pc_per_trial(results)
            ps, cs = self.model.get_pc_performance(pc_per_trial)
            trialPs.append(ps)
            trialCs.append(cs)
        return trialPs, trialCs

    def collect_mets(self, metSet):
        """
        Convert a list of per-trial maps into one map of per-class lists.

        :param metSet: list of maps (one per trial), each keyed by class name
        :return: {class_name: [value from trial 0, value from trial 1, ...]}
        """
        return {class_name: [cur_mets[class_name] for cur_mets in metSet]
                for class_name in self.model.class_labels}

    def get_all_range_metrics(self, bin_size=0.2):
        """
        Compute the probability-range metrics of every CV run.

        :param bin_size: width of each assigned-probability bin
        :return: (class_RMs, class_Ps). class_RMs maps class -> per-run
            [tp_range_counts, total_range_counts]; class_Ps maps class -> per-run
            positives per range
        """
        RMs = []
        all_class_positives = []
        for results in self.all_results:
            # curRM is map {class_name: [tp_range_counts, total_range_counts]}
            curRM = self.model.compute_probability_range_metrics(results, bin_size=bin_size, concat=True)
            # The call above leaves {class_name: pos_class_per_range} on the model.
            # NOTE(review): if the model mutates one dict in place on every call, all
            # entries here alias the last run — confirm it rebinds a new dict.
            all_class_positives.append(self.model.class_positives)
            RMs.append(curRM)

        class_RMs = self.collect_mets(RMs)
        class_Ps = self.collect_mets(all_class_positives)
        return class_RMs, class_Ps

    def get_avg_range_metrics(self, bin_size=0.2):
        """
        Average the per-run positives and totals in each probability range.

        :param bin_size: width of each assigned-probability bin (1/bin_size bins)
        :return: (avg_range_metrics, class_prob_rates). avg_range_metrics maps
            class -> [avg positives per range, avg totals per range], both truncated
            to int. class_prob_rates maps class -> positives / totals per range
            (nan/inf where a range has no samples).
        :raises ValueError: if no CV runs are loaded
        """
        if len(self.all_results) == 0:
            raise ValueError("No CV runs loaded; cannot average range metrics.")
        class_RMs, class_Ps = self.get_all_range_metrics(bin_size=bin_size)
        num_bins = int(round(1 / bin_size))
        avg_range_metrics = {}
        class_prob_rates = {}
        for class_name in self.model.class_labels:
            trial_RMs = class_RMs[class_name]
            trial_Ps = class_Ps[class_name]
            # Average over the runs actually loaded (previously hard-coded to 10).
            n_trials = len(trial_RMs)
            class_Positives = np.zeros(num_bins, dtype=int)
            class_Totals = np.zeros(num_bins, dtype=int)
            for trial_RM, trial_P in zip(trial_RMs, trial_Ps):
                # trial_RM is [tp_range_counts, total_range_counts]
                class_Totals = np.add(class_Totals, trial_RM[1])
                class_Positives = np.add(class_Positives, trial_P)

            pos = (class_Positives / n_trials).astype(int).tolist()
            tots = (class_Totals / n_trials).astype(int).tolist()
            avg_range_metrics[class_name] = [pos, tots]
            class_prob_rates[class_name] = np.divide(pos, tots)
        return avg_range_metrics, class_prob_rates

    



Average the purities and completenesses across trials and compute confidence intervals.


In [None]:

# Display / plot order of the LSST classes in the priors comparison figures.
ORDERED_LSST_CLASSES = ['Unspecified Ia', 'Ia-91bg', 'Ibc', 'II', 'SLSN-I', 'TDE']

def get_avg(values):
    """Arithmetic mean of `values`, after converting them to a numpy array."""
    arr = np.array(values)
    total = sum(arr)
    count = len(arr)
    return total / count

def get_cis(values, z=1.96):
    """
    Normal-approximation confidence interval for the mean of `values`, clipped to [0, 1].

    Interval is [µ − z*SEM, µ + z*SEM] where
        SEM = s / sqrt(N)
        s   = sqrt( 1/(N − 1) ∑_i (a_i − µ)^2 )   (sample standard deviation)
    The default z = 1.96 gives a 95% interval.

    :param values: sequence of scalar metric values (e.g. one per trial), expected in [0, 1]
    :param z: critical value of the standard normal for the desired coverage
    :return: [low_CI, high_CI]; [nan, nan] for a single value, where the SEM is undefined
    :raises ValueError: if `values` is empty
    """
    values = np.asarray(values, dtype=float)
    N = len(values)
    if N == 0:
        raise ValueError("get_cis requires at least one value")
    if N == 1:
        # Sample std needs N >= 2; the old code reached the same nan via a 0/0 division.
        return [float("nan"), float("nan")]
    mean = values.mean()
    SEM = values.std(ddof=1) / np.sqrt(N)
    # Metrics are fractions, so clip the interval to the valid range.
    low_CI = max(mean - z * SEM, 0)
    high_CI = min(mean + z * SEM, 1)
    return [low_CI, high_CI]

def get_plotting_data(metModel, met_type):
    """
    Format one metric of a ModelsMets object into plotting-ready lists ordered by
    ORDERED_LSST_CLASSES.

    :param metModel: ModelsMets-like object; reads `.model` and either
        `.class_purities` or `.class_comps` (class name -> per-trial values)
    :param met_type: "Purity" or "Completeness"
    :return: (plot_Means, plot_Errs, plot_Baselines), one entry per class
    :raises ValueError: if met_type is not "Purity" or "Completeness"
        (previously this surfaced as an UnboundLocalError)
    """
    if met_type not in ("Purity", "Completeness"):
        raise ValueError("met_type must be 'Purity' or 'Completeness', got %r" % (met_type,))
    m = metModel.model
    c_baselines, p_baselines = compute_baselines(class_counts=m.class_counts,
                                                 class_labels=ORDERED_LSST_CLASSES,
                                                 N=m.get_num_classes(),
                                                 balanced_purity=False,
                                                 class_priors=m.class_priors)
    # NOTE(review): ModelsMets only sets class_purities / class_comps when the
    # commented-out aggregation in its __init__ is re-enabled.
    if met_type == "Purity":
        METS = metModel.class_purities
        BASELINES = p_baselines
    else:
        METS = metModel.class_comps
        BASELINES = c_baselines
    plot_Means = []
    plot_Baselines = []
    CIs = []
    for cn in ORDERED_LSST_CLASSES:
        plot_Means.append(get_avg(METS[cn]))
        CIs.append(get_cis(METS[cn]))
        plot_Baselines.append(BASELINES[cn])
    plot_Errs = prep_err_bars(CIs, plot_Means)
    return plot_Means, plot_Errs, plot_Baselines

def plot_on_ax(ax, color, indices, metModel, met_type):
    """
    Draw one ModelsMets object's metric on `ax` and print the per-class means.

    Draws horizontal bars with CI error bars, plus a short dotted vertical baseline
    marker per bar.

    :param ax: matplotlib axis to draw on
    :param color: bar color; "blue" is labelled "Freq-based Priors", anything else "Uniform Priors"
    :param indices: y positions of the bars, one per class in ORDERED_LSST_CLASSES
    :param metModel: ModelsMets object passed to get_plotting_data
    :param met_type: "Purity" or "Completeness"
    """
    bar_width = 1 / (len(indices) * 2)
    plot_Means, plot_Errs, plot_Baselines = get_plotting_data(metModel, met_type)
    label = "Freq-based Priors" if color == "blue" else "Uniform Priors"
    ax.barh(y=indices,
            width=plot_Means,
            height=bar_width,
            xerr=plot_Errs,
            capsize=7,
            linewidth=0.5,
            edgecolor=BAR_EDGE_COLOR,
            ecolor=INTVL_COLOR,
            color=color,
            label=label)
    # Baseline marker spans the height of its bar.
    for y_val, baseline in zip(indices, plot_Baselines):
        ax.vlines(x=baseline,
                  ymin=y_val - (bar_width / 2),
                  ymax=y_val + (bar_width / 2),
                  linewidth=2,
                  linestyles=(0, (1, 1)), colors=BSLN_COLOR)

    print("\n" + label + " stats:")
    for index, cn in enumerate(ORDERED_LSST_CLASSES):
        print(cn + " " + str(round(plot_Means[index], 4)))

    ax.set_xlim(0, 1)
    ax.set_xticks([0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9, 1])
    ax.set_xticklabels(["0", "", "20", "", "40", "", "60", "", "80", "", ""], fontsize=14)
    # Raw string: "\%" in a normal literal is an invalid escape sequence.
    ax.set_xlabel(met_type + r" (\%)", fontsize=16)
    

def plot_WP_NP_avg_performance(ModelMets_NP, ModelMets_WP):
    """
    Plot average purity (left) and completeness (right) per class for the
    with-priors vs no-priors ModelsMets objects, and save the figure.

    :param ModelMets_NP: ModelsMets for uniform priors (light bars)
    :param ModelMets_WP: ModelsMets for frequency-based priors (blue bars)
    """
    f, ax = plt.subplots(nrows=1, ncols=2, sharex=True, sharey=True, figsize=(6,4),  dpi=500)
    rc('text', usetex=True)
    mpl.rcParams['font.serif'] = ['times', 'times new roman']
    mpl.rcParams['font.family'] = 'serif'

    n_classes = len(ORDERED_LSST_CLASSES)
    bar_width = 1 / (n_classes * 2)
    # n_classes + 1 evenly spaced points in [0, 1], dropping the last one. Uniform-prior
    # bars sit on these slots; frequency-prior bars are shifted down by one bar width.
    slots = np.linspace(0, 1, n_classes + 1)[:-1]
    m2_indices = slots
    indices = slots + bar_width

    for axis, met_type in ((ax[0], "Purity"), (ax[1], "Completeness")):
        print("\nData for " + met_type)
        plot_on_ax(axis, "blue", indices, metModel=ModelMets_WP, met_type=met_type)
        plot_on_ax(axis, "#EAE7E0", m2_indices, metModel=ModelMets_NP, met_type=met_type)

    ax[0].set_yticks(indices - (bar_width / 2))
    ax[0].set_yticklabels(clean_class_names(ORDERED_LSST_CLASSES), fontsize=16, horizontalalignment='right')
    ax[1].legend(fontsize=11, loc="best", labelspacing=.2, handlelength=1)

    plt.gca().invert_yaxis()
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.savefig("../output/custom_figures/prior_comp_combined_AGG.pdf", bbox_inches='tight')
    plt.show()

def plot_WP_NP_EPs(ModelMets_NP, ModelMets_WP):
    """
    Plot Empirical Probability Plots (EPs) for no-priors vs with-priors ModelsMets objects.

    Rows follow ORDERED_LSST_CLASSES. The left column is ModelMets_NP (uniform priors)
    and the right column is ModelMets_WP (frequency-based priors). The figure is saved
    to ROOT_DIR/output/custom_figures/merged_metrics_priors_comp_AGG.pdf.

    Side effect: overwrites `class_prob_rates` on both objects with rates computed from
    all of their runs concatenated.
    NOTE(review): this replaces any rates set by get_avg_range_metrics(); confirm which
    rates the figure is meant to show.

    :param ModelMets_NP: ModelsMets object for uniform priors
    :param ModelMets_WP: ModelsMets object for frequency-based priors
    """
    rc('text', usetex=True)
    mpl.rcParams['font.serif'] = ['times', 'times new roman']
    mpl.rcParams['font.family'] = 'serif'

    # The rates depend only on the object, so compute them once rather than once per row.
    # Bug fix: the WP rates were previously computed from ModelMets_NP.all_results.
    for met_obj in (ModelMets_NP, ModelMets_WP):
        met_obj.class_prob_rates = get_multi_emp_prob_rates(np.concatenate(met_obj.all_results),
                                                            met_obj.model.class_labels,
                                                            0.2,
                                                            met_obj.model.class_counts)

    class_labels = ORDERED_LSST_CLASSES
    f, ax = plt.subplots(nrows=len(class_labels),
                         ncols=2,
                         sharex=True, sharey=True,
                         figsize=(FIG_WIDTH, 9),
                         dpi=DPI)
    # Column titles on the top row only.
    ax[0][0].set_title("Uniform Priors", fontsize=14)
    ax[0][1].set_title("Frequency-based Priors", fontsize=14)

    for plot_index, class_name in enumerate(class_labels):
        plot_model_rates(class_name, ModelMets_NP, ax[plot_index][0])
        plot_model_rates(class_name, ModelMets_WP, ax[plot_index][1])
        ax[plot_index][0].text(-0.2, 0.8, clean_class_name(class_name), fontsize=16)

    y_indices = [0.1, 0.3, 0.5, 0.7, 0.9]
    y_ticks = ["10", "30", "50", "70", "90"]
    # x and y indices/ticks are the same
    plt.xticks(np.arange(5), y_ticks)
    plt.yticks(y_indices, y_ticks)
    plt.rc('xtick', labelsize=14)
    plt.rc('ytick', labelsize=14)

    f.text(0.5, 0.06, 'Assigned Probability' + r' $\pm$10\%', fontsize=14, ha='center')
    f.text(0.02, .5, r'Empirical Probability $\equiv$ P/Total ($\%$)',
               fontsize=14, va='center', rotation='vertical')

    plt.subplots_adjust(wspace=0, hspace=0)

    f.savefig(ROOT_DIR + "/output/custom_figures/merged_metrics_priors_comp_AGG.pdf", bbox_inches='tight')
    plt.show()

In [None]:
# Load the aggregated repeated-CV results for both prior settings (reads pickles under GET_DIR).
ModelMets_WP = ModelsMets(True)
ModelMets_NP = ModelsMets(False)

### Combine Empirical Probability Plots (averaged over trials)
Average the empirical probability plots over trials.

In [None]:
# NOTE(review): `range_metrics` is only assigned by the get_avg_range_metrics() cell
# below; on a fresh kernel this raises AttributeError — run that cell first.
ModelMets_NP.range_metrics

In [None]:
# NOTE(review): `class_prob_rates` is only assigned by later cells (get_avg_range_metrics /
# plot_WP_NP_EPs); on a fresh kernel this raises AttributeError.
ModelMets_WP.class_prob_rates

In [None]:
# Average the per-run range metrics and empirical probability rates for each prior setting.
ModelMets_WP.range_metrics, ModelMets_WP.class_prob_rates  = ModelMets_WP.get_avg_range_metrics()
ModelMets_NP.range_metrics, ModelMets_NP.class_prob_rates = ModelMets_NP.get_avg_range_metrics()

In [None]:
# ModelMets_WP.range_metrics   Map from class name to true_positives, totals
# ModelMets_WP.class_prob_rates   Map from class name to pos_class_per_range / total_range_counts (TP/T)
# NOTE(review): plot_WP_NP_EPs reassigns class_prob_rates on both objects before plotting.



plot_WP_NP_EPs(ModelMets_NP, ModelMets_WP)

In [None]:
# Inspect the with-priors empirical probability rates used by the plot above.
ModelMets_WP.class_prob_rates

In [None]:
# Inspect the no-priors empirical probability rates used by the plot above.
ModelMets_NP.class_prob_rates

### Combining single run (of cross fold validation) priors and non priors
If only combining a single run of k-fold cross validation (with and without priors) use the following code instead.

In [None]:
# EXPS_DIR = ROOT_DIR + "/../../exps/v8_db_runs/new_lsst_tests/"

# NOTE: this rebinds EXPS_DIR (first set in the setup cell); re-run the setup cell
# before reloading the regular-run models from earlier sections.
EXPS_DIR = ROOT_DIR + "/../../experiments/v8_db/w_wout_priors_4/"

In [None]:
NUM_FOLDS = 6
MIN_CLASS_SIZE = 3
MAX_CLASS_SIZE = 4220

# Constructor arguments shared by the with- and without-priors models.
_LSST_MODEL_KWARGS = dict(cols=mags,
                          folds=NUM_FOLDS,
                          min_class_size=MIN_CLASS_SIZE,
                          max_class_size=MAX_CLASS_SIZE,
                          transform_features=True,
                          case_code=codes)

# Uniform priors
modelwithout = MultiModel(**_LSST_MODEL_KWARGS, lsst_test=True)
multi_wo = load_prev_exp(EXPS_DIR,
                         "Multiclass_Classifier1/",
                         model=modelwithout)

# Frequency-based priors
multiwith = MultiModel(**_LSST_MODEL_KWARGS, priors=True, lsst_test=True)
multi_w = load_prev_exp(EXPS_DIR,
                        "Multiclass_Classifier2/",
                        model=multiwith)

from mainmodel.helper_compute import *
from thex_data.data_consts import * 

def get_model_stats(model):
    """
    Collect ordered per-class purity and completeness statistics for a loaded model.

    :param model: loaded model; results, class counts/labels/priors and the
        balanced_purity flag are read
    :return: {"Purity": [class_names, metrics, baselines, intervals],
              "Completeness": [class_names, metrics, baselines, intervals]},
             each list unpacked from get_ordered_metrics
    """
    pc_per_trial = model.get_pc_per_trial(model.results)
    ps, cs = model.get_pc_performance(pc_per_trial)
    c_baselines, p_baselines = compute_baselines(
        model.class_counts,
        model.class_labels,
        model.get_num_classes(),
        model.balanced_purity,
        model.class_priors)
    p_intvls, c_intvls = compute_confintvls(pc_per_trial, model.class_labels, model.balanced_purity)

    p_class_names, p_metrics, p_b, p_intvls = get_ordered_metrics(ps, p_baselines, p_intvls)
    c_class_names, c_metrics, c_b, c_intvls = get_ordered_metrics(cs, c_baselines, c_intvls)

    return {"Purity": [p_class_names, p_metrics, p_b, p_intvls],
            "Completeness": [c_class_names, c_metrics, c_b, c_intvls]}
         
def plot_m(ax, indices, errs, baselines, metrics, name, color):
    """
    Draw one model's horizontal metric bars with error bars and dotted baseline markers.

    :param ax: matplotlib axis to draw on
    :param indices: y position of each bar
    :param errs: xerr passed to barh
    :param baselines: x position of the baseline marker for each bar
    :param metrics: bar lengths
    :param name: legend label
    :param color: bar fill color
    """
    bar_width = 1 / (len(indices) * 2)
    # Smaller error-bar caps once there are more than six bars.
    capsize = 4 if len(indices) > 6 else 7
    ax.barh(y=indices,
            width=metrics,
            height=bar_width,
            xerr=errs,
            capsize=capsize,
            linewidth=0.5,
            edgecolor=BAR_EDGE_COLOR,
            ecolor=INTVL_COLOR,
            color=color,
            label=name)
    half_height = bar_width / 2
    for i, baseline in enumerate(baselines):
        center = indices[i]
        ax.vlines(x=baseline,
                  ymin=center - half_height,
                  ymax=center + half_height,
                  linewidth=2,
                  linestyles=(0, (1, 1)),
                  colors=BSLN_COLOR)
            
def _print_if_above_baseline(tag, metric, intvl, baseline):
    """Print `metric` ± half the width of `intvl` as LaTeX percentages if its lower bound beats `baseline`."""
    half_width = (intvl[1] - intvl[0]) / 2
    if metric - half_width > baseline:
        print(tag + "{:.1f}".format(100 * metric) + r"\% \pm " +
              "{:.1f}".format(100 * half_width) + r"\%")


def plot_priors_together(axis, WP_model, WP_model_stats, NP_model, NP_model_stats, plot_type):
    """
    Plot one metric for the with-priors and no-priors models side by side on `axis`.

    Also prints each class whose value ± half its CI width beats the baseline.

    :param axis: matplotlib axis
    :param WP_model: with-priors model; `class_labels` sets the number of bar slots
        and `name` is its legend label
    :param WP_model_stats: get_model_stats(WP_model)
    :param NP_model: no-priors model; `name` is its legend label
    :param NP_model_stats: get_model_stats(NP_model)
    :param plot_type: "Purity" (also sets y tick labels) or "Completeness" (also draws legend)
    """
    m1_class_names, m1_metrics, m1_b, m1_intvls = WP_model_stats[plot_type]
    m2_class_names, m2_metrics, m2_b, m2_intvls = NP_model_stats[plot_type]
    # NOTE(review): both stat sets are assumed to list classes in the same order;
    # the WP class names label both sets of bars.
    WITHOUT_COLOR = "#808080"  # without priors
    WITH_COLOR = "#80b3ff"  # with priors

    # L slots: L + 1 evenly spaced points in [0, 1] with the last one dropped; the second
    # set of bars is offset by one bar width. Bug fix: the slot count previously came from
    # the global `multi_w` instead of the WP_model argument.
    L = len(WP_model.class_labels)
    bar_width = 1 / (L * 2)
    m2_indices = np.linspace(0, 1, L + 1)[:-1]
    m1_indices = m2_indices + bar_width

    # Print performance per class plus/minus the 95% confidence half-width.
    print("\n" + plot_type + " above-random stats:")
    for index, cn in enumerate(m1_class_names):
        print("Class: " + cn)
        _print_if_above_baseline(" WP : ", m1_metrics[index], m1_intvls[index], m1_b[index])
        # Bug fix: the NP value was compared against the WP baseline (m1_b).
        _print_if_above_baseline(" NP : ", m2_metrics[index], m2_intvls[index], m2_b[index])

    m1_errs = prep_err_bars(m1_intvls, m1_metrics)
    m2_errs = prep_err_bars(m2_intvls, m2_metrics)

    # NP bars go on the offset positions, WP bars on the slot positions.
    plot_m(ax=axis, indices=m1_indices, errs=m2_errs, baselines=m2_b,
           metrics=m2_metrics, color=WITHOUT_COLOR, name=NP_model.name)
    plot_m(ax=axis, indices=m2_indices, errs=m1_errs, baselines=m1_b,
           metrics=m1_metrics, color=WITH_COLOR, name=WP_model.name)
    # Figure formatting
    axis.set_xlim(0, 1)
    axis.set_xticks([0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9, 1])
    axis.set_xticklabels(["0", "", "20", "", "40", "", "60", "", "80", "", ""], fontsize=14)
    # Raw string: "\%" in a normal literal is an invalid escape sequence.
    axis.set_xlabel(plot_type + r" (\%)", fontsize=16)
    if plot_type == "Purity":
        # Tick labels sit between each pair of bars.
        axis.set_yticks(m1_indices - (bar_width / 2))
        axis.set_yticklabels(clean_class_names(m1_class_names), fontsize=16,
                             horizontalalignment='right')
    else:
        axis.legend(fontsize=11, loc="best", labelspacing=.2, handlelength=1)


# Single-run comparison: purity (left) and completeness (right) for the
# with-priors (multi_w) vs without-priors (multi_wo) models.
f, ax = plt.subplots(nrows=1, ncols=2,
                     sharex=True, sharey=True,
                     figsize=(6,4),  dpi=500)

rc('text', usetex=True)
mpl.rcParams['font.serif'] = ['times', 'times new roman']
mpl.rcParams['font.family'] = 'serif'

print("\n\n --------------- WITH Priors -------------------- \n")
WP_stats =  get_model_stats(multi_w) # WP = with priors

print("\n\n --------------- WITHOUT Priors -------------------- \n")
NP_stats = get_model_stats(multi_wo) # NP = no priors
# `.name` becomes each model's legend label in plot_priors_together.
multi_w.name = "Freq-based priors"
multi_wo.name = "Uniform priors"


plot_priors_together(axis=ax[0],
                     WP_model=multi_w, 
                     WP_model_stats=WP_stats, 
                     NP_model=multi_wo, 
                     NP_model_stats=NP_stats, 
                     plot_type="Purity")
plot_priors_together(axis=ax[1],
                     WP_model=multi_w, 
                     WP_model_stats=WP_stats, 
                     NP_model=multi_wo, 
                     NP_model_stats=NP_stats, 
                     plot_type="Completeness")

plt.subplots_adjust(wspace=0, hspace=0)
plt.savefig("../output/custom_figures/prior_comp_combined.pdf", bbox_inches='tight')
plt.show() 

### Combine Empirical Probability Plots 
Compute the empirical probability plots for a single run of k-fold cross validation and plot them together.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from thex_data.data_consts import ROOT_DIR
from mainmodel.helper_plotting import *

# Recompute probability-range metrics with 0.2-wide bins for both prior settings.
multi_w.range_metrics = multi_w.compute_probability_range_metrics(
        multi_w.results, bin_size=0.2)
multi_wo.range_metrics = multi_wo.compute_probability_range_metrics(
        multi_wo.results, bin_size=0.2)


rc('text', usetex=True)
mpl.rcParams['font.serif'] = ['times', 'times new roman']
mpl.rcParams['font.family'] = 'serif'

# One figure row per class, in the ORDERED_LSST_CLASSES ordering.
class_labels = ORDERED_LSST_CLASSES
# Size the grid from the list that is actually iterated below. Sizing it from
# multi_w.class_labels breaks when the two lists differ in length
# (too few rows -> IndexError on `ax`, too many -> blank rows).
num_classes = len(class_labels)
# squeeze=False keeps `ax` 2-D even when there is a single class.
f, ax = plt.subplots(nrows=num_classes,
                     ncols=2,
                     sharex=True, sharey=True,
                     figsize=(FIG_WIDTH, 9),
                     dpi=DPI,
                     squeeze=False)
for plot_index, class_name in enumerate(class_labels):
    if plot_index == 0:
        # Column titles on the top row only.
        ax[plot_index][0].set_title("Uniform Priors", fontsize=14)
        ax[plot_index][1].set_title("Frequency-based Priors", fontsize=14)

    # Left column: uniform priors (multi_wo); right column: frequency-based priors (multi_w).
    plot_model_rates(class_name, multi_wo, ax[plot_index][0])
    plot_model_rates(class_name, multi_w, ax[plot_index][1])

    pretty_class_name = clean_class_name(class_name)
    ax[plot_index][0].text(-0.2, 0.8, pretty_class_name, fontsize=16)

y_indices = [0.1, 0.3, 0.5, 0.7, 0.9]
y_ticks = ["10", "30", "50", "70", "90"]
# x and y axes share the same tick labels (x places them at positions 0..4).
plt.xticks(np.arange(5), y_ticks)
plt.yticks(y_indices, y_ticks)
# NOTE(review): these rc calls change global defaults for later figures; they
# may not restyle ticks that already exist on this figure.
plt.rc('xtick', labelsize=14)
plt.rc('ytick', labelsize=14)

f.text(0.5, 0.06, 'Assigned Probability' + r' $\pm$10\%', fontsize=14, ha='center')
f.text(0.02, .5, r'Empirical Probability $\equiv$ TP/Total ($\%$)',
       fontsize=14, va='center', rotation='vertical')

plt.subplots_adjust(wspace=0, hspace=0)

f.savefig(ROOT_DIR + "/output/custom_figures/merged_metrics_priors_comp.pdf", bbox_inches='tight')
plt.show()

In [None]:
# Single-panel look at the TDE probability curves of the frequency-based-prior
# model, using its default probability-range binning.
# NOTE(review): this calls a four-argument plot_model_curves (presumably the one
# star-imported from mainmodel.helper_plotting). The archived three-argument
# redefinition further down shadows it once that cell has been run.
f, ax = plt.subplots(nrows=1, ncols=1, figsize=(4, 4), dpi=280)

multi_range_metrics = multi_w.compute_probability_range_metrics(multi_w.results)
mirror_ax = plot_model_curves("TDE", multi_w, multi_range_metrics, ax)

In [None]:
# For each 10%-wide probability bin, count how often `class_name` is the
# top-probability class (total_counts). TP_counts holds the subset where
# multi_w.is_class reports that the sample belongs to `class_name`.
# Change `class_name` to inspect another class.
class_name = "TDE"
results = np.concatenate(multi_w.results)
# Each row holds the per-class probabilities followed by the label column.
label_index = len(multi_w.class_labels)
# Loop-invariant bin edges: 11 edges -> 10 bins.
bins = np.arange(0, 1.01, 0.1)
TP_counts = np.zeros(len(bins) - 1, dtype=int)
total_counts = np.zeros(len(bins) - 1, dtype=int)
for row in results:
    labels = row[label_index]
    is_class = multi_w.is_class(class_name, labels)

    # Class with the highest probability; exclude the last (label) column.
    class_probs = row[: len(row) - 1]
    max_class_prob = np.max(class_probs)
    max_class_index = np.argmax(class_probs)
    max_class_name = multi_w.class_labels[max_class_index]

    counts, ranges = np.histogram([max_class_prob], bins=bins)

    # Compare against class_name instead of a second hard-coded "TDE".
    if max_class_name == class_name:
        total_counts = np.add(counts, total_counts)
        if is_class:
            TP_counts = np.add(counts, TP_counts)

## Example output plots

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mainmodel.helper_compute import *
from thex_data.data_consts import *
import utilities.utilities as thex_utils

# MultiModel settings for the example-output plots below.
lsst_example_params = dict(folds=3,
                           min_class_size=40,
                           priors=True,
                           transform_features=True,
                           cols=mags,
                           lsst_test=True)
model = MultiModel(**lsst_example_params)
model.run_model()

Goal: find 16 high-quality examples:

- **8** good, correct examples with a wide probability distribution, e.g. correct examples for II P, IIn, TDE, GRB, and Ib (unspec.).
- **4** correct but not "good" examples, e.g. with high probability on Ia or a similar class.
- **4** incorrect examples that benefit from the empirical probability plots. For instance, Ia samples with Ia probability of 20–49% where Ia *is* the correct answer but another class won. Or four GRB plots with probabilities >80% where half are correct and half are not, which illustrates the ~50% empirical probability in that range.



In [None]:
# NOTE(review): rebuilds the same MultiModel configuration as the earlier
# example cell; the next cell runs it again.
lsst_example_params = dict(folds=3,
                           min_class_size=40,
                           priors=True,
                           transform_features=True,
                           cols=mags,
                           lsst_test=True)
model = MultiModel(**lsst_example_params)


In [None]:
# Run the model; later cells read `model.results`, `model.class_labels` and `model.dir`.
model.run_model()

In [None]:
def plot_on_ax(axis, model, do_index, preds):
    """
    Plot one sample's predicted class probabilities as a bar chart on `axis`.

    Bars follow ``model.class_labels`` order. The bar of the sample's true
    class is light blue and all other bars are dark blue. If none of the
    model's classes appears in the sample's labels, every bar is dark blue
    instead of raising.

    :param axis: matplotlib Axes to draw on
    :param model: model whose ``class_labels`` give the bar order and names
    :param do_index: row index into ``preds`` of the sample to plot
    :param preds: rows of per-class probabilities whose last column holds the
        sample's label(s), parsed with ``thex_utils.convert_str_to_list``
    """
    row = preds[do_index]
    labels = row[len(row) - 1]
    # Parse the label string once instead of once per class.
    label_list = thex_utils.convert_str_to_list(labels)

    # Index of the true class; the last matching class wins. None when no class
    # in model.class_labels is among the sample's labels.
    true_class_index = None
    for class_index, class_name in enumerate(model.class_labels):
        if class_name in label_list:
            true_class_index = class_index

    ACC = "#ccccff"  # actual class color, light blue
    DCC = "#0000cc"  # default class color, dark blue

    colors = [DCC] * len(model.class_labels)
    if true_class_index is not None:
        # Guard: indexing a list with None would raise TypeError.
        colors[true_class_index] = ACC
    probabilities = row[0:len(row) - 1]
    bar_size = 0.05
    x_indices = np.linspace(0,
                            len(model.class_labels) * bar_size,
                            len(model.class_labels))
    axis.bar(x=x_indices, height=probabilities,
             width=bar_size,
             color=colors, edgecolor='black')
    axis.tick_params(axis="y", direction="in")
    axis.set_ylim([0, 1])


In [None]:
# Stack model.results (presumably one array per fold) into one array of rows;
# each row is the per-class probabilities followed by a label column.
preds = np.concatenate(model.results)

In [None]:
# Find candidate example samples: those labelled USE_CLASS whose predicted
# USE_CLASS probability lies strictly between 50% and 65%. Save one bar-chart
# figure per candidate for manual inspection.
import os

use_indices = []
USE_CLASS = "Unspecified Ia"
# Probability columns follow model.class_labels order (the last column holds
# labels). Look up USE_CLASS's column instead of assuming it is column 0.
use_class_col = list(model.class_labels).index(USE_CLASS)
for index, row in enumerate(preds):
    labels = row[len(row) - 1]
    if USE_CLASS in labels and 0.5 < row[use_class_col] < 0.65:
        use_indices.append(index)

# Loop-invariant tick settings; bar_size must match the 0.05 used in plot_on_ax.
bar_size = 0.05
x_indices = np.linspace(0, len(model.class_labels) * bar_size, len(model.class_labels))
xticksize = 11
pretty_class_names = clean_class_names(model.class_labels)
# savefig does not create missing directories.
examples_dir = model.dir + "/examples/" + USE_CLASS
os.makedirs(examples_dir, exist_ok=True)

for do_index in use_indices:
    f, ax = plt.subplots(nrows=1, ncols=1, figsize=(4,2), dpi=DPI)
    # plot_on_ax takes (axis, model, do_index, preds); passing do_index twice raised TypeError.
    plot_on_ax(ax, model, do_index, preds)
    ax.set_xticks(ticks=x_indices)
    ax.set_xticklabels(labels=pretty_class_names, fontsize=xticksize, rotation = -90)

    plt.savefig(examples_dir + "/sample_" + str(do_index) + ".pdf",
                bbox_inches='tight')

In [None]:
# Hand-picked rows of `preds`; each inner list becomes one figure row.
indices = [[2964, 2965, 6032, 9136],  # Good , 9140
          [2400, 2394, 9103, 9140], # good too
          [2974, 3031, 2378, 2461]]  # wrong, low probs
# TODO: add a row of wrong predictions where the empirical probabilities would help a lot.
rows, cols = np.array(indices).shape

# squeeze=False keeps `ax` 2-D even for a single row or column.
f, ax = plt.subplots(nrows=rows,
                     ncols=cols,
                     sharex=True, sharey=True,
                     figsize=(6,5),
                     dpi=DPI,
                     squeeze=False)
yticks = np.arange(0, 1.2, .2)
example_num = 1
for row_index, index_set in enumerate(indices):
    cur_ax = ax[row_index]
    for plot_index, do_index in enumerate(index_set):
        plot_on_ax(cur_ax[plot_index], model, do_index, preds)
        # Italic panel number; panel 1 is placed slightly further left than the rest.
        text_x = .17 if example_num == 1 else .23
        cur_ax[plot_index].text(text_x, .8, str(example_num), style='italic', fontsize=15)
        example_num += 1

        # y tick labels (in %) on the first column only.
        if plot_index == 0:
            cur_ax[plot_index].set_yticks(ticks=yticks)
            cur_ax[plot_index].set_yticklabels(labels=[str(int(i*100)) for i in yticks],
                                               fontsize=12)


# Raw string: "\%" is the LaTeX escape for %, and not a valid Python escape sequence.
f.text(.04, 0.4, r'Probability (\%)', fontsize=14, ha='center', rotation =90)
bar_size = 0.05  # must match the bar width used in plot_on_ax
x_indices = np.linspace(0, len(model.class_labels) * bar_size,
                        len(model.class_labels))


pretty_class_names = clean_class_names(model.class_labels)
xticksize = 13
# Class-name tick labels on the bottom row only (x axes are shared).
for i in range(cols):
    ax[rows-1][i].set_xticks(ticks=x_indices)
    ax[rows-1][i].set_xticklabels(labels=pretty_class_names, fontsize=xticksize, rotation = -90)
plt.subplots_adjust(wspace=0, hspace=0, left=0.1)

plt.savefig("../output/custom_figures/lsst_examples.pdf", bbox_inches='tight')
plt.show()




In [None]:
import pickle

# Save the trained example model to disk. Only unpickle this file from a
# trusted location: unpickling can execute arbitrary code.
model_pickle_path = '../output/custom_figures/full_model_data.pickle'
with open(model_pickle_path, 'wb') as handle:
    pickle.dump(model, handle, protocol=pickle.HIGHEST_PROTOCOL)

## Archive
### Probability purity/completeness plots 
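Older version of these plots, using standard (rather than balanced) purity and completeness. Note: running the next cell redefines `plot_model_curves` with a three-argument signature. This shadows the four-argument version called earlier in this notebook.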

In [None]:


def plot_model_curves(class_name, model, ax):
    """
    Plots completeness and purity per probability range for this model/class.

    Completeness is drawn on `ax` in C_BAR_COLOR and purity on a twin axis
    (sharing the x-axis) in P_BAR_COLOR; both y-axes span [0, 1]. Ranges that
    contain no samples are skipped, so a true 0 is distinguishable from "no data".

    :param class_name: class whose curves are plotted
    :param model: model with ``range_metrics`` already computed; element [1] of
        ``model.range_metrics[class_name]`` holds the per-range sample totals
    :param ax: matplotlib Axes to draw on
    :return: the twin Axes holding the purity curve
    """
    purities, comps = get_pc_per_range(model, class_name)
    # Per-range totals; a 0 means no samples fell in that range.
    total_range_counts = model.range_metrics[class_name][1]
    # One x position per range, evenly spaced on [0, 1) (0, 0.1, ..., 0.9 for
    # ten ranges). Derived from the number of totals instead of assuming ten,
    # so metrics computed with another bin_size are placed correctly.
    # Assumes one total per range.
    x_indices = np.linspace(0, 1, len(total_range_counts) + 1)[:-1]

    def plot_axis(axis, data, color):
        """
        Plot `data` on `axis` in `color`, skipping ranges that contain no samples.
        """
        print("Data: " + str(data))
        keep_indices = []
        keep_data = []
        for i, total in enumerate(total_range_counts):
            if total != 0:
                keep_indices.append(x_indices[i])
                keep_data.append(data[i])

        axis.scatter(keep_indices, keep_data, color=color, s=4)
        axis.plot(keep_indices, keep_data, color=color, linewidth=2)
        axis.set_yticks([])  # no y ticks here; callers may add their own
        axis.set_ylim([0, 1])

    print("\n\n P-C metrics for : " + class_name)
    plot_axis(ax, comps, C_BAR_COLOR)
    ax2 = ax.twinx()  # second y-axis sharing the same x-axis
    ax2.set_ylim([0, 1])
    plot_axis(ax2, purities, P_BAR_COLOR)
    for side in ['top', 'bottom', 'left', 'right']:
        ax.spines[side].set_linewidth(1.5)
    return ax2


def plot_pc_curves_together(binary_model, ova_model, multi_model, indices=None):
    """
    Plot class versus probability rates of all three classifiers together.

    There is one row per selected class (in ``ova_model.class_labels`` order) and
    one column per classifier: Binary, OVA, Multi. Each model must already have
    ``range_metrics`` computed. The figure is saved to
    ``../output/custom_figures/merged_pc_curves_<indices>.pdf`` and shown.

    :param binary_model: binary classifiers model
    :param ova_model: one-vs-all model; its ``class_labels`` define the classes
    :param multi_model: multiclass model
    :param indices: class indices to plot; None plots every class
    """
    class_labels = ova_model.class_labels
    # Classes actually plotted. The row count comes from this list, so duplicate
    # or out-of-range entries in `indices` cannot create blank rows.
    selected_classes = [class_name
                        for class_index, class_name in enumerate(class_labels)
                        if indices is None or class_index in indices]
    num_classes = len(selected_classes)
    # squeeze=False keeps `ax` 2-D even when a single class is selected.
    f, ax = plt.subplots(nrows=num_classes,
                         ncols=3,
                         sharex=True, sharey=True,
                         figsize=(FIG_WIDTH, 10),
                         dpi=DPI,
                         squeeze=False)

    # Set once rather than on every loop iteration.
    mpl.rcParams['font.serif'] = ['times', 'times new roman']
    mpl.rcParams['font.family'] = 'serif'

    y_indices = [0, 0.2, 0.4, 0.6, 0.8, 1]
    y_ticks = ["0", "20", "40", "60", "80", ""]

    for plot_index, class_name in enumerate(selected_classes):
        row_axes = ax[plot_index]
        if plot_index == 0:
            # Add titles to top of plots
            row_axes[0].set_title("Binary", fontsize=TICK_S)
            row_axes[1].set_title("OVA", fontsize=TICK_S)
            row_axes[2].set_title("Multi", fontsize=TICK_S)

        print("Binary model")
        plot_model_curves(class_name, binary_model, row_axes[0])
        print("OVA model")
        plot_model_curves(class_name, ova_model, row_axes[1])
        print("KDE Multi model")
        mirror_ax = plot_model_curves(class_name, multi_model, row_axes[2])

        # Left-hand tick labels on the first column; right-hand labels on the
        # twin axis returned for the last column.
        row_axes[0].set_yticks(ticks=y_indices)
        row_axes[0].set_yticklabels(labels=y_ticks, color=P_BAR_COLOR)
        mirror_ax.set_yticks(ticks=y_indices)
        mirror_ax.set_yticklabels(labels=y_ticks, color=C_BAR_COLOR)
        row_axes[0].tick_params(axis='both', direction='in', labelsize=10)
        row_axes[1].tick_params(axis='both', direction='in')
        row_axes[2].tick_params(axis='both', direction='in', labelsize=10)

        pretty_class_name = clean_class_name(class_name)
        row_axes[0].text(0, 0.85, pretty_class_name, fontsize=14)

    x_indices = np.linspace(0, 1, 11)[:-1]

    plt.xticks(x_indices, ["", "10", "", "30", "", "50", "", "70", "", "90"])
    rc('text', usetex=True)
    # Raw strings: "\%" is the LaTeX escape for %, not a valid Python escape.
    f.text(0.5, 0.08, r'Probability $\geq$X\%', fontsize=TICK_S, ha='center')
    f.text(0.03, .5, r'Purity (\%)',
           fontsize=TICK_S, va='center', rotation='vertical', color=P_BAR_COLOR)
    f.text(0.98, .5, r'Completeness (\%)',
           fontsize=TICK_S, va='center', rotation='vertical', color=C_BAR_COLOR)

    plt.subplots_adjust(wspace=0, hspace=0)

    f.savefig("../output/custom_figures/merged_pc_curves_" +
              str(indices) + ".pdf", bbox_inches='tight')
    plt.show()
    
    


# Recompute per-range metrics with 0.1-wide bins (same model order as before).
for pc_model in (ova_model, binary_model, multi_model):
    pc_model.range_metrics = pc_model.compute_probability_range_metrics(
        pc_model.results, bin_size=0.1)

# Three figures: classes 0-5, classes 6-11, and class 0 together with classes 7-11.
# Lists (not tuples) are passed because str(indices) becomes part of the file name.
for class_subset in ([0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11], [0, 7, 8, 9, 10, 11]):
    plot_pc_curves_together(binary_model, ova_model, multi_model, indices=class_subset)
