## Functionality to combine visualizations

### Combine curves/rates

In [None]:
import warnings
# Silence every warning category for the remainder of the session.
warnings.filterwarnings('ignore')
# NOTE(review): a filter added inside catch_warnings() is discarded when the
# `with` block exits, so the statement below has no lasting effect; the
# global 'ignore' filter above already covers DeprecationWarning.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore",category=DeprecationWarning)
%matplotlib inline  
from models.binary_model.binary_model import BinaryModel
from models.ind_model.ind_model import OvAModel
from models.multi_model.multi_model import MultiModel

from mainmodel.helper_plotting import *

# Feature columns handed to every model below through the `cols` argument.
mags = [
    f"{band}_mag"
    for band in ("g", "r", "i", "z", "y", "W1", "W2", "J", "K", "H")
]

# Base directory passed to load_prev_exp when reloading saved experiments.
exp_dir = "/Users/marina/Documents/PhD/research/astro_research/experiments/paper_set/"

In [None]:

  

# Multi
# Build the multi-class model and pass it to load_prev_exp together with the
# saved experiment sub-path; the returned object is kept as `multi_model`.
# NOTE: this uses the three-argument load_prev_exp (presumably provided by the
# star import above). A later cell redefines load_prev_exp(expnum, model), so
# re-running this cell after that one raises a TypeError.
multi_model = load_prev_exp(
    exp_dir,
    "multi/Multiclass_Classifier1/",
    MultiModel(
        folds=10,
        min_class_size=40,
        transform_features=True,
        cols=mags,
    ),
)




In [None]:

# Binary
# Same settings as the multi-class model; `model2` is the freshly built
# instance and `binary_model` is whatever load_prev_exp returns for it.
model2 = BinaryModel(
    folds=10,
    min_class_size=40,
    transform_features=True,
    cols=mags,
)
binary_model = load_prev_exp(
    exp_dir,
    "binary/Binary_Classifiers2/",
    model2,
)


In [None]:
# OvA
import time

# The ensemble is initialised from the binary model loaded in the cell above
# (via `init_from_binary`) and otherwise uses the same settings.
ensemble_model = OvAModel(
    folds=10,
    init_from_binary=binary_model,
    min_class_size=40,
    transform_features=True,
    cols=mags,
)

# Keep whatever run_cfv returns on the model; it receives the current time.
ensemble_model.results = ensemble_model.run_cfv(time.time())
# ensemble_model.range_metrics = ensemble_model.compute_probability_range_metrics(
#         ensemble_model.results, bin_size=0.2)
# ensemble_model.range_metrics_10 = ensemble_model.compute_probability_range_metrics(
#         ensemble_model.results, bin_size=0.1)

In [None]:
# Call the ensemble's own visualize_performance routine (presumably draws its
# metric figures from the results stored in the previous cell -- confirm).
ensemble_model.visualize_performance()

## Plot probability metrics

In [None]:
# import matplotlib
# matplotlib.font_manager.fontManager.ttflist

In [None]:

m = ensemble_model

# Replot metrics
# Average over the number of runs when it is set, otherwise over the folds.
if m.num_runs is None:
    N = m.num_folds
else:
    N = m.num_runs

pc_per_trial = m.get_pc_per_trial(m.results)
ps, cs = m.get_avg_pc(pc_per_trial, N)

m.plot_all_metrics(
    ps,
    cs,
    pc_per_trial,
    m.y,
)

## Plot probability rates

In [None]:
# binary_model.class_prob_rates["Unspecified II"]
# Collapse each class's ten rate entries into five by averaging neighbouring
# pairs: (0, 1), (2, 3), (4, 5), (6, 7), (8, 9).
model = multi_model
for class_name, rates in model.class_prob_rates.items():
    if len(rates) == 5:
        # This entry already holds five values (e.g. the cell was run before).
        # Merging again would index past the end, so leave it untouched and
        # keep the cell safe to re-run.
        continue
    new_rates = [(rates[i] + rates[i + 1]) / 2 for i in range(0, 10, 2)]
    # Only existing keys are reassigned, so the dict size does not change
    # while it is being iterated.
    model.class_prob_rates[class_name] = new_rates

In [None]:
# Draw the rate comparison in two batches of seven class indices each
# (0-6, then 7-13).
for index_batch in (list(range(0, 7)), list(range(7, 14))):
    plot_rates_together(
        binary_model,
        ensemble_model,
        multi_model,
        indices=index_batch,
    )

In [None]:
## GOOD PERFORMING
# Ia, Ia-91bg, II, II P, TDE, GRB

wanted_classes = ["Unspecified Ia", "Ia-91bg", "Unspecified II", "II P", "TDE", "GRB"]

# Positions of the wanted classes within the ensemble's label list, in the
# order the labels appear there.
indices = [
    position
    for position, label in enumerate(ensemble_model.class_labels)
    if label in wanted_classes
]

plot_rates_together(
    binary_model,
    ensemble_model,
    multi_model,
    indices=indices,
)

In [None]:
# Same three-model comparison as the cell above, this time through
# plot_pc_curves_together; `indices` is the selection built in that cell.
plot_pc_curves_together(binary_model, ensemble_model, multi_model, indices=indices)

In [None]:
# Draw the curves in two batches of seven class indices each (0-6, then 7-13).
for index_batch in (list(range(0, 7)), list(range(7, 14))):
    plot_pc_curves_together(
        binary_model,
        ensemble_model,
        multi_model,
        indices=index_batch,
    )


## Combine purity/completeness plots

In [None]:

from mainmodel.helper_compute import *
from thex_data.data_consts import * 

def get_mets_for_model(model, met_type):
    """
    Collect the per-class values needed to plot one metric for `model`.

    :param model: object exposing results, y, class_labels, class_counts,
        class_priors, num_folds (optionally num_runs) and the methods
        get_pc_per_trial / get_avg_pc.
    :param met_type: "Purity" selects the purity values; any other value
        selects the completeness values.
    :return: tuple (class_labels, metrics, baselines, intervals). The last
        three are the `.values()` of the computed dicts as lists, so their
        order is the dicts' own order (assumed to match class_labels -- confirm).
    """
    pc_per_trial = model.get_pc_per_trial(model.results)

    # Average over the number of runs when the model defines one, otherwise
    # over the folds -- the same rule the "Replot metrics" cell applies.
    num_runs = getattr(model, "num_runs", None)
    num_trials = num_runs if num_runs is not None else model.num_folds
    ps, cs = model.get_avg_pc(pc_per_trial, num_trials)

    c_baselines, p_baselines = compute_baselines(
        model.class_counts, model.class_labels, model.y,
        len(model.class_labels), model.class_priors)
    p_intvls, c_intvls = compute_confintvls(pc_per_trial, model.class_labels)

    if met_type == "Purity":
        metrics, baselines, intvls = ps, p_baselines, p_intvls
    else:
        metrics, baselines, intvls = cs, c_baselines, c_intvls

    return (model.class_labels,
            list(metrics.values()),
            list(baselines.values()),
            list(intvls.values()))
         

def plot_together(m1, m2, m1_name, m2_name, plot_type = "Purity"):
    """
    Plot the metrics of 2 models side by side as horizontal bars and save the
    figure to ../output/custom_figures/prior_comp_combined_<plot_type>.pdf

    :param m1: first model; its bars are drawn in green, and its class names
        label the y axis.
    :param m2: second model; its bars are drawn in blue, just below m1's.
    :param m1_name: legend label for m1
    :param m2_name: legend label for m2
    :param plot_type: metric to plot, forwarded to get_mets_for_model
        ("Purity", anything else selects completeness); also used as the
        x-axis label and in the output file name.
    """
    m1_class_names, m1_metrics, m1_b, m1_intvls  = get_mets_for_model(model=m1, met_type=plot_type)
    m2_class_names, m2_metrics, m2_b, m2_intvls  = get_mets_for_model(model=m2, met_type=plot_type)
    fig, ax = plt.subplots(figsize=(FIG_WIDTH, FIG_HEIGHT), dpi=200,
                           tight_layout=True, sharex=True, sharey=True)

    BLUE = "#1f77b4"
    GREEN = "#00ffbf"
    bar_width = 0.1
    row_spacing = 0.22

    def bar_positions(first, count):
        # One y position per class. Rounded so that four classes reproduce
        # the original hand-written positions (e.g. 0.2, 0.42, 0.64, 0.86).
        return [round(first + row_spacing * row, 2) for row in range(count)]

    def plot_m(ax, indices, errs, baselines, metrics, name, color):
        # Bars with error bars for one model.
        ax.barh(y=indices,
                width=metrics,
                height=bar_width,
                xerr=errs,
                capsize=2,
                edgecolor='black', ecolor='coral',
                color=color,
                label=name)
        # Red dashed marker at each class baseline, spanning that bar's height.
        # Drawn on the passed axes rather than the pyplot "current" axes.
        for y_val, baseline in zip(indices, baselines):
            ax.vlines(x=baseline,
                      ymin=y_val - (bar_width / 2),
                      ymax=y_val + (bar_width / 2),
                      linestyles='--', colors='red')

    # Positions follow the number of classes instead of assuming exactly four.
    m1_indices = bar_positions(0.2, len(m1_metrics))
    m2_indices = bar_positions(0.1, len(m2_metrics))

    m1_errs = prep_err_bars(m1_intvls, m1_metrics)
    plot_m(ax=ax, indices=m1_indices, errs=m1_errs, baselines=m1_b,
           metrics=m1_metrics, color=GREEN, name=m1_name)
    m2_errs = prep_err_bars(m2_intvls, m2_metrics)
    plot_m(ax=ax, indices=m2_indices, errs=m2_errs, baselines=m2_b,
           metrics=m2_metrics, color=BLUE, name=m2_name)

    # Figure formatting
    ax.set_xlim(0, 1)
    plt.legend(fontsize=LAB_S)
    plt.xticks(list(np.linspace(0, 1, 11)), [
                       str(tick) + "%" for tick in list(range(0, 110, 10))], fontsize=TICK_S)
    # Class labels sit midway between each pair of bars (m1 bar minus 0.05).
    plt.yticks(np.array(m1_indices) - 0.05, clean_class_names(m1_class_names),  fontsize=TICK_S+3,
                       horizontalalignment='right')
#     plt.ylabel('Transient Class', fontsize=LAB_S)
    plt.xlabel(plot_type, fontsize=LAB_S+1)
#     plt.title(plot_type, fontsize=TITLE_S)
    plt.savefig("../output/custom_figures/prior_comp_combined_" + plot_type + ".pdf")




In [None]:
def load_prev_exp(expnum, model, exp_dir=None):
    """
    Attach the pickled outputs of a saved experiment to `model`.

    Reads `results.pickle` and `y.pickle` from `<exp_dir>/<expnum>/` and
    stores their contents on `model.results` and `model.y`.

    NOTE: this definition shadows any earlier `load_prev_exp` in the session
    and takes (expnum, model), not (exp_dir, expnum, model).

    :param expnum: experiment sub-path below `exp_dir`; leading/trailing
        slashes are ignored.
    :param model: object to receive the `results` and `y` attributes
    :param exp_dir: directory holding the experiments; defaults to the
        paper_set experiment directory used throughout this notebook.
    :return: the same `model` object that was passed in
    :raises OSError: if either pickle file cannot be opened
    """
    # Imported here so the function works even when the notebook's star
    # imports do not happen to provide these modules.
    import os
    import pickle

    if exp_dir is None:
        exp_dir = "/Users/marina/Documents/PhD/research/astro_research/experiments/paper_set/"
    pickle_dir = os.path.join(exp_dir, expnum.strip("/"))

    # SECURITY: pickle.load can execute arbitrary code -- only point this at
    # experiment directories you produced yourself.
    for attr_name in ("results", "y"):
        with open(os.path.join(pickle_dir, attr_name + ".pickle"), "rb") as handle:
            setattr(model, attr_name, pickle.load(handle))
    return model

# Multi-class model restricted to four classes, with explicit priors, loaded
# from the "multi_priors" experiment.
# NOTE(review): the priors below sum to 1.025, not 1 -- confirm whether the
# model normalises them or one of the values is a typo.
model = MultiModel(folds = 10,
                   min_class_size = 40,  
                   class_labels = ['Unspecified Ia', 'Unspecified II', 'Ia-91bg', 'TDE'],
                   priors = [0.65, 0.36, 0.01, 0.005],
                   transform_features = True,
                   cols = mags)  
multi_w = load_prev_exp(expnum="multi_priors/Multiclass_Classifier2/",model=model)


# Same four classes and settings but no `priors` argument, loaded from the
# "multi_no_priors" experiment for comparison.
modelwithout = MultiModel(folds = 10,
                   min_class_size = 40,  
                   class_labels = ['Unspecified Ia', 'Unspecified II', 'Ia-91bg', 'TDE'], 
                   transform_features = True,
                   cols = mags)  
multi_wo = load_prev_exp(expnum="multi_no_priors/Multiclass_Classifier3/",model=modelwithout)


In [None]:
# Produce the with/without-priors comparison once per metric, purity first.
for metric_name in ("Purity", "Completeness"):
    plot_together(
        m1=multi_w,
        m2=multi_wo,
        m1_name="With priors",
        m2_name="Without priors",
        plot_type=metric_name,
    )


In [None]:
# Store the probability-range metrics on both models, each computed from its
# own results with a bin size of 0.2.
for prior_model in (multi_w, multi_wo):
    prior_model.range_metrics = prior_model.compute_probability_range_metrics(
        prior_model.results,
        bin_size=0.2,
    )

In [None]:
# Show how many classes the prior-based model has. Computed directly because
# `num_classes` is only assigned in a later cell, so referring to it here
# fails on a fresh top-to-bottom run.
len(multi_w.class_labels)

In [None]:
# Display the class labels of the prior-based model (notebook cell output).
multi_w.class_labels

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from thex_data.data_consts import ROOT_DIR
from mainmodel.helper_plotting import *

# Grid of rate plots: one row per class, left column = model with priors,
# right column = model without priors. Saved as a single PDF.

# Style settings go first: rc values are picked up when ticks and text are
# created, so setting them after plotting would miss this figure on a first run.
plt.rc('xtick', labelsize=7)
plt.rc('ytick', labelsize=7)
plt.rcParams['font.family'] = 'serif'
# Put Times New Roman first without duplicating it every time the cell runs.
plt.rcParams['font.serif'] = ['Times New Roman'] + [
    font_name for font_name in plt.rcParams['font.serif']
    if font_name != 'Times New Roman'
]

class_labels = multi_w.class_labels
num_classes = len(class_labels)
# squeeze=False keeps `ax` two-dimensional even when there is a single class.
f, ax = plt.subplots(nrows=num_classes,
                     ncols=2,
                     sharex=True, sharey=True,
                     figsize=(FIG_WIDTH, 8),
                     dpi=DPI,
                     squeeze=False)

for plot_index, class_name in enumerate(class_labels):
    if plot_index == 0:
        # Add titles to top of plots
        ax[plot_index][0].set_title("With priors", fontsize=11)
        ax[plot_index][1].set_title("Without priors", fontsize=11)

    print(class_name)
    plot_model_rates(class_name, multi_w, ax[plot_index][0])
    plot_model_rates(class_name, multi_wo, ax[plot_index][1])

    # The class name labels the row on its left-hand plot.
    pretty_class_name = clean_class_name(class_name)
    ax[plot_index][0].set_ylabel(pretty_class_name, fontsize=9)

y_indices = [0.1, 0.3, 0.5, 0.7, 0.9]
y_ticks = ["10%", "30%", "50%", "70%", "90%"]
# x and y tick labels are the same; x positions are 0..4, y positions are
# the values in y_indices. Axes are shared, so this is set once.
plt.xticks(np.arange(5), y_ticks)
plt.yticks(y_indices, y_ticks)

# Figure-level axis captions shared by all subplots.
f.text(0.5, 0.08, 'Assigned Probability' + r' $\pm$10%', fontsize=10, ha='center')
f.text(0.0, .5, 'Empirical Probability',
       fontsize=10, va='center', rotation='vertical')

plt.subplots_adjust(wspace=0, hspace=0.1)

f.savefig(ROOT_DIR + "/output/custom_figures/merged_metrics_priors_comp.pdf", bbox_inches='tight')
plt.show()