In [None]:
import warnings
warnings.filterwarnings('ignore')
with warnings.catch_warnings():
    warnings.filterwarnings("ignore",category=DeprecationWarning)
%matplotlib inline  
from models.binary_model.binary_model import BinaryModel
from models.ind_model.ind_model import IndModel
from models.multi_model.multi_model import MultiModel
 
from mainmodel.helper_plotting import *

import os

# Root directory holding the saved experiment outputs.
# Set the THEX_EXP_DIR environment variable to point elsewhere instead of editing this path.
# os.path.join(..., "") guarantees a trailing separator in case helpers concatenate onto it.
exp_dir = os.path.join(
    os.environ.get("THEX_EXP_DIR",
                   "/Users/marina/Documents/PhD/research/astro_research/experiments/"),
    "")

# Magnitude columns passed as `cols` to every model constructed in this notebook.
mags = ["g_mag",  "r_mag", "i_mag", "z_mag", "y_mag",
        "W1_mag", "W2_mag",
        "J_mag", "K_mag", "H_mag"]




## Combine curves/rates

In [None]:

# Constructor settings shared by all three saved experiments compared below.
shared_kwargs = dict(folds=40, min_class_size=40, transform_features=True, cols=mags)

# NOTE(review): `load_prev_exp` is called here with three arguments (directory, experiment,
# model) — presumably the version star-imported in the first cell. A later cell redefines
# `load_prev_exp(expnum, model)` with a different signature, so re-running this cell after
# that one will fail.

# Multiclass classifier — experiment 103
model = MultiModel(**shared_kwargs)
multi_model = load_prev_exp(exp_dir, "103", model)

# Binary classifiers — experiment 104
model2 = BinaryModel(**shared_kwargs)
binary_model = load_prev_exp(exp_dir, "104/Binary_Classifiers1", model2)

# Ensemble classifier — experiment 108
model3 = IndModel(**shared_kwargs)
ensemble_model = load_prev_exp(exp_dir, "108/Ensemble_Classifier1", model3)




# Classes that performed well: Ia, Ia-91bg, II, II P, TDE, GRB.
# Possible addition: Ia Pec (for the binary and ensemble models).
wanted_classes = ["Unspecified Ia", "Ia-91bg", "Unspecified II", "II P", "TDE", "GRB"]

# Positions of the wanted classes within the ensemble model's label order.
# NOTE(review): these positions are later passed for the binary and multiclass models too,
# which assumes all three models share the same class_labels ordering — confirm.
indices = [position
           for position, label in enumerate(ensemble_model.class_labels)
           if label in wanted_classes]
indices

In [None]:
# Rates for class indices 0-13 across all three models, seven class indices per call.
for class_block in (list(range(0, 7)), list(range(7, 14))):
    plot_rates_together(binary_model, ensemble_model, multi_model, indices=class_block)

In [None]:
# Rates for only the well-performing classes selected above (`indices`), all three models together.
plot_rates_together(binary_model, ensemble_model, multi_model, indices=indices)

In [None]:
# "pc" curves (presumably purity/completeness) for the well-performing classes selected above.
plot_pc_curves_together(binary_model, ensemble_model, multi_model, indices=indices)

In [None]:
# "pc" curves for class indices 0-13 across all three models, seven class indices per call.
for class_block in (list(range(0, 7)), list(range(7, 14))):
    plot_pc_curves_together(binary_model, ensemble_model, multi_model, indices=class_block)


## Combine purity/completeness plots

In [None]:

# Rerun performance visualizations on saved output of model.
# NOTE(review): this redefinition shadows the 3-argument `load_prev_exp(exp_dir, expnum, model)`
# used in the "Combine curves/rates" cell (presumably star-imported from a helper module);
# re-running that cell after this one fails. Consider giving this function its own name.
def load_prev_exp(expnum, model, exp_dir=None):
    """
    Attach the pickled results and labels of a saved experiment to `model`.

    Reads `<exp_dir>/<expnum>/results.pickle` into `model.results` and
    `<exp_dir>/<expnum>/y.pickle` into `model.y`.

    :param expnum: experiment sub-path, e.g. "106/Multiclass_Classifier13".
    :param model: model object to populate; it is mutated in place.
    :param exp_dir: experiment root directory; defaults to the original hardcoded location.
    :return: the same `model` object.
    :raises FileNotFoundError: if either pickle file is missing.
    """
    import os
    import pickle  # imported explicitly rather than relying on a star import to provide it

    if exp_dir is None:
        exp_dir = "/Users/marina/Documents/PhD/research/astro_research/experiments/"
    pickle_dir = os.path.join(exp_dir, expnum)

    # Only load pickles you produced yourself: unpickling can execute arbitrary code.
    with open(os.path.join(pickle_dir, "results.pickle"), "rb") as handle:
        model.results = pickle.load(handle)

    with open(os.path.join(pickle_dir, "y.pickle"), "rb") as handle:
        model.y = pickle.load(handle)
    return model


# Reload the four-class multiclass experiments compared below:
# 106 was configured with explicit class priors, 107 without.
# NOTE(review): these priors sum to 1.025, not 1 — confirm whether MultiModel normalizes them.
prior_study_labels = ['Unspecified Ia', 'Unspecified II', 'Ia-91bg', 'TDE']
prior_model_kwargs = dict(folds=40, min_class_size=40, transform_features=True, cols=mags)

# Each model receives its own copy of the label list, as with the original separate literals.
model = MultiModel(class_labels=list(prior_study_labels),
                   priors=[0.65, 0.36, 0.01, 0.005],
                   **prior_model_kwargs)
multi_w = load_prev_exp(expnum="106/Multiclass_Classifier13", model=model)

model = MultiModel(class_labels=list(prior_study_labels), **prior_model_kwargs)
multi_wo = load_prev_exp(expnum="107/Multiclass_Classifier15", model=model)


In [None]:

from mainmodel.helper_compute import *
from thex_data.data_consts import * 

def get_mets_for_model(model, met_type):
    """
    Collect one metric's per-class values, baselines, and confidence intervals for a model.

    :param model: a loaded model exposing results, y, class_labels, class_counts,
                  class_priors, num_folds, get_pc_per_trial and get_avg_pc.
    :param met_type: "Purity" selects purity; any other value selects completeness.
    :return: (class_labels, values, baselines, intervals). The last three are lists built
             from the helper dicts' values in their iteration order.
             NOTE(review): assumes that order matches class_labels — confirm in the helpers.
    """
    trial_pcs = model.get_pc_per_trial(model.results)
    avg_purity, avg_completeness = model.get_avg_pc(trial_pcs, model.num_folds)
    completeness_base, purity_base = compute_baselines(
        model.class_counts, model.class_labels, model.y,
        len(model.class_labels), model.class_priors)
    purity_ci, completeness_ci = compute_confintvls(trial_pcs, model.class_labels)

    if met_type == "Purity":
        chosen = (avg_purity, purity_base, purity_ci)
    else:
        chosen = (avg_completeness, completeness_base, completeness_ci)

    values, baselines, intervals = (list(mapping.values()) for mapping in chosen)
    return model.class_labels, values, baselines, intervals
         

def plot_together(m1, m2, m1_name, m2_name, plot_type = "Purity",
                  output_dir="../output/custom_figures"):
    """
    Plot one metric for two models side by side as interleaved horizontal bars.

    Each class gets one bar per model:
    - Error bars are built by `prep_err_bars` from the model's confidence intervals.
    - A short red dashed line marks that model's per-class baseline.

    The figure is saved as `<output_dir>/combined_<plot_type>.pdf`; the directory is
    created if needed.

    :param m1: first model; drawn in green, and its class names label the y-axis.
    :param m2: second model; drawn in blue. Must have the same class labels as `m1`.
    :param m1_name: legend label for `m1`.
    :param m2_name: legend label for `m2`.
    :param plot_type: "Purity" plots purity; any other value plots completeness.
    :param output_dir: directory the PDF is written to.
    :raises ValueError: if the two models' class labels differ (bars would be mislabeled).
    """
    import os

    m1_class_names, m1_metrics, m1_b, m1_intvls = get_mets_for_model(model=m1, met_type=plot_type)
    m2_class_names, m2_metrics, m2_b, m2_intvls = get_mets_for_model(model=m2, met_type=plot_type)
    # Only m1's names are used as tick labels, so both models must describe the same classes.
    if list(m1_class_names) != list(m2_class_names):
        raise ValueError("Models must share class labels to be plotted together: %s vs %s"
                         % (list(m1_class_names), list(m2_class_names)))

    fig, ax = plt.subplots(figsize=(FIG_WIDTH, FIG_HEIGHT), dpi=200, tight_layout=True)

    BLUE = "#1f77b4"
    GREEN = "#00ffbf"
    bar_width = 0.1
    # Vertical distance between consecutive classes; m2's bar sits one bar-width below m1's.
    group_step = 0.22

    def plot_m(ax, indices, errs, baselines, metrics, name, color):
        """Draw one model's bars (with error bars) and its per-class baseline markers on `ax`."""
        ax.barh(y=indices,
                width=metrics,
                height=bar_width,
                xerr=errs,
                capsize=2,
                edgecolor='black', ecolor='coral',
                color=color,
                label=name)
        for y_val, baseline in zip(indices, baselines):
            ax.vlines(x=baseline,
                      ymin=y_val - (bar_width / 2),
                      ymax=y_val + (bar_width / 2),
                      linestyles='--', colors='red')

    # Bar centres for any number of classes; for 4 classes these are exactly
    # [0.2, 0.42, 0.64, 0.86] and [0.1, 0.32, 0.54, 0.76].
    n_classes = len(m1_class_names)
    m1_indices = [round(0.2 + group_step * i, 10) for i in range(n_classes)]
    m2_indices = [round(0.1 + group_step * i, 10) for i in range(n_classes)]

    m1_errs = prep_err_bars(m1_intvls, m1_metrics)
    plot_m(ax=ax, indices=m1_indices, errs=m1_errs, baselines=m1_b,
           metrics=m1_metrics, color=GREEN, name=m1_name)
    m2_errs = prep_err_bars(m2_intvls, m2_metrics)
    plot_m(ax=ax, indices=m2_indices, errs=m2_errs, baselines=m2_b,
           metrics=m2_metrics, color=BLUE, name=m2_name)

    # Figure formatting
    ax.set_xlim(0, 1)
    ax.legend(fontsize=LAB_S)
    ax.set_xticks(np.linspace(0, 1, 11))
    ax.set_xticklabels([str(tick) + "%" for tick in range(0, 110, 10)], fontsize=TICK_S)
    # Place each class label midway between its two bars.
    ax.set_yticks(np.array(m1_indices) - (bar_width / 2))
    ax.set_yticklabels(clean_class_names(m1_class_names), fontsize=TICK_S,
                       horizontalalignment='right')
    ax.set_ylabel('Transient Class', fontsize=LAB_S)
    ax.set_xlabel(plot_type, fontsize=LAB_S)
    ax.set_title(plot_type, fontsize=TITLE_S)

    os.makedirs(output_dir, exist_ok=True)
    fig.savefig(os.path.join(output_dir, "combined_" + plot_type + ".pdf"))




In [None]:
# Compare the with-priors and without-priors multiclass models on both metrics.
for metric_name in ("Purity", "Completeness"):
    plot_together(m1=multi_w, m2=multi_wo,
                  m1_name="With priors", m2_name="Without priors",
                  plot_type=metric_name)
