In [1]:
import numpy as np
import h5py
import os, sys
import argparse
import matplotlib
matplotlib.rc('xtick', labelsize=20)
matplotlib.rc('ytick', labelsize=20)
matplotlib.rcParams['agg.path.chunksize'] = 10000
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc, roc_auc_score, recall_score
from sklearn.metrics import confusion_matrix
from PlottingFunctions import plot_bin_slices

In [2]:
from sklearn.metrics import confusion_matrix

In [None]:
def find_contours_2D(x_values,y_values,xbins,weights=None,c1=16,c2=84):
    """
    Find the median and lower/upper percentile contours of y in each x bin.

    Receives:
        x_values = array-like, input for hist2d for x axis (typically truth)
        y_values = array-like, input for hist2d for y axis (typically reconstruction)
        xbins = array-like of x bin edges (e.g. the x edges returned by hist2d);
                len(xbins)-1 bins are evaluated. As in hist2d, the last bin
                also includes values equal to the final edge.
        weights = optional array-like of per-event weights (same length as y_values);
                  when given, weighted quantiles from the wquantiles package are used
        c1 = percentage for lower contour bound (16% - 84% means a 68% band, so c1 = 16)
        c2 = percentage for upper contour bound (16% - 84% means a 68% band, so c2 = 84)
    Returns:
        x = bin edges laid out as step-plot vertices, i.e. [e0,e1,e1,e2,e2,e3,...]
        y_median = per-bin median, each value repeated twice (i.e. [40,40,20,20,50,50,...])
        y_lower = per-bin c1 percentile, each value repeated twice (i.e. [30,30,10,10,20,20,...])
        y_upper = per-bin c2 percentile, each value repeated twice (i.e. [50,50,40,40,60,60,...])
        Empty bins give NaN, which matplotlib draws as a gap.
    """
    x_values = np.asarray(x_values)
    y_values = np.asarray(y_values)
    xbins = np.asarray(xbins)
    if weights is not None:
        import wquantiles as wq
        # Boolean-mask indexing below needs an ndarray, not a list.
        weights = np.asarray(weights)

    n_bins = max(len(xbins) - 1, 0)
    indices = np.digitize(x_values, xbins)
    if n_bins > 0:
        # np.digitize places values equal to the last edge past the final bin,
        # while hist2d counts them in the last bin; fold them back in.
        indices[x_values == xbins[-1]] = n_bins

    median = np.full(n_bins, np.nan)
    lower = np.full(n_bins, np.nan)
    upper = np.full(n_bins, np.nan)
    for i in range(1, n_bins + 1):
        mask = indices == i
        if not np.any(mask):
            continue  # empty bin: leave NaN
        if weights is None:
            lower[i-1], median[i-1], upper[i-1] = np.percentile(y_values[mask], [c1, 50, c2])
        else:
            # wquantiles takes quantiles as fractions, not percentages.
            lower[i-1] = wq.quantile(y_values[mask], weights[mask], c1/100.)
            upper[i-1] = wq.quantile(y_values[mask], weights[mask], c2/100.)
            median[i-1] = wq.median(y_values[mask], weights[mask])

    # Interleave left/right edges and duplicate each bin value to draw horizontal steps.
    x = np.column_stack((xbins[:-1], xbins[1:])).ravel().tolist()
    y_median = np.repeat(median, 2).tolist()
    y_lower = np.repeat(lower, 2).tolist()
    y_upper = np.repeat(upper, 2).tolist()

    return x, y_median, y_lower, y_upper

In [None]:
def plot_2D_prediction(truth, nn_reco,
                       save=False, savefolder=None, weights=None, syst_set="",
                       bins=60, minval=None, maxval=None, switch_axis=False,
                       cut_truth=False, axis_square=False,
                       zmin=None, zmax=None, log=True,
                       variable="Energy", units="(GeV)", epochs=None,
                       flavor="NuMu", sample=None,
                       variable_type="True", reco_name="CNN", new_labels=None,
                       new_units=None, save_name=None, no_contours=False,
                       xline=None, yline=None):
    """
    Plot a 2D histogram of reconstruction vs truth with median / 68% band contours.

    Receives:
        truth = array, Y_test truth
        nn_reco = array, neural network prediction output (same length as truth)
        save = optional, bool to save the plot as a .png
        savefolder = optional, prefix (folder) prepended to the saved file name
        weights = optional array of per-event weights; colorbar is then labelled 'Rate (Hz)'
        syst_set = string, name of the systematic set (for title and saving)
        bins = int, number of bins (used for both the x and y direction)
        minval, maxval = floats, axis limits; only used when axis_square=True
        switch_axis = bool, put the reconstruction on the x axis and truth on the y axis
        cut_truth = currently unused (kept for backward compatibility)
        axis_square = bool, use [minval, maxval] for both axes (both must be given)
        zmin, zmax = optional cmin / cmax passed to hist2d
        log = bool, logarithmic color scale
        variable = string, name of the variable you are plotting
        units = string, units for the variable you are plotting
        epochs = optional int appended to the title
        flavor, sample = optional strings appended to the title
        variable_type, reco_name = labels for the truth and reconstruction axes
        new_labels, new_units = currently unused (kept for backward compatibility)
        save_name = optional file name (without .png) used when save=True
        no_contours = bool, skip drawing the median and 68% band
        xline, yline = optional value or sequence of values marked with vertical / horizontal lines
    Returns:
        None; draws on a new matplotlib figure
    """
    from matplotlib.colors import LogNorm

    truth = np.asarray(truth)
    nn_reco = np.asarray(nn_reco)
    if weights is not None:
        weights = np.asarray(weights)

    # The 1:1 reference line spans the overlap of the truth and reco value ranges.
    maxplotline = min(np.max(nn_reco), np.max(truth))
    minplotline = max(np.min(nn_reco), np.min(truth))

    truth_label = "%s %s %s"%(variable_type,variable,units)
    reco_label = "%s Reconstructed %s %s"%(reco_name,variable,units)
    # Swap the data itself (not just the axis ranges) so the histogram, contours
    # and labels stay consistent when the axes are switched.
    if switch_axis:
        x_data, y_data, xlabel, ylabel = nn_reco, truth, reco_label, truth_label
    else:
        x_data, y_data, xlabel, ylabel = truth, nn_reco, truth_label, reco_label

    if axis_square:
        if minval is None or maxval is None:
            raise ValueError("axis_square=True requires both minval and maxval")
        xmin, xmax, ymin, ymax = minval, maxval, minval, maxval
    else:
        xmin, xmax = np.min(x_data), np.max(x_data)
        ymin, ymax = np.min(y_data), np.max(y_data)

    # hist2d's cmin: bins below this value are not drawn.
    if zmin is not None:
        cmin = zmin
    elif weights is None:
        cmin = 1
    else:
        cmin = 1e-12

    plt.figure(figsize=(10,7))
    cts, xbin, ybin, img = plt.hist2d(x_data, y_data, bins=bins,
                                      range=[[xmin,xmax],[ymin,ymax]],
                                      cmap='viridis_r',
                                      norm=LogNorm() if log else None,
                                      weights=weights, cmax=zmax, cmin=cmin)
    cbar = plt.colorbar()
    cbar.ax.set_ylabel('counts' if weights is None else 'Rate (Hz)', rotation=90)
    plt.xlabel(xlabel,fontsize=20)
    plt.ylabel(ylabel,fontsize=20)

    #NAMING
    flavor_labels = {
        "numu": r' for $\nu_\mu$ ',
        "nue": r' for $\nu_e$ ',
        "nutau": r' for $\nu_\tau$ ',
        "mu": r' for $\mu$ ',
        "nu": r' for $\nu$ ',
    }
    title = "%s vs %s for %s %s"%(reco_name,variable_type,variable,syst_set)
    if flavor is not None:
        title += flavor_labels.get(str(flavor).lower(), str(flavor))
    if sample is not None:
        title += str(sample)
    if epochs:
        title += " at %i Epochs"%epochs
    plt.suptitle(title,fontsize=25)

    #Plot 1:1 line
    if axis_square:
        plt.plot([minval,maxval],[minval,maxval],'k:',label="1:1")
    else:
        plt.plot([minplotline,maxplotline],[minplotline,maxplotline],'k:',label="1:1")

    if not no_contours:
        x, y, y_l, y_u = find_contours_2D(x_data, y_data, xbin, weights=weights)
        plt.plot(x,y,color='r',label='Median')
        plt.plot(x,y_l,color='r',label='68% band',linestyle='dashed')
        plt.plot(x,y_u,color='r',linestyle='dashed')
        plt.legend(fontsize=20)
    if yline is not None:
        # Accept a scalar, list, tuple or array; label only the first line in the legend.
        for i, y_val in enumerate(np.atleast_1d(yline)):
            plt.axhline(y_val,linewidth=3,color='red',label="Cut" if i == 0 else "_nolegend_")
        plt.legend(fontsize=20)
    if xline is not None:
        for i, x_val in enumerate(np.atleast_1d(xline)):
            plt.axvline(x_val,linewidth=3,color='magenta',linestyle="dashed",
                        label="Cut" if i == 0 else "_nolegend_")
        plt.legend(fontsize=20)

    if save:
        if save_name is None:
            # NOTE(review): default file name chosen here (none existed before); override via save_name.
            save_name = "%sVs%s%s%s"%(str(reco_name).replace(" ",""), str(variable_type).replace(" ",""),
                                      str(variable).replace(" ",""), str(syst_set).replace(" ",""))
            if switch_axis:
                save_name += "_SwitchAxis"
            if weights is not None:
                save_name += "Weighted"
        plt.savefig("%s%s.png"%(savefolder or "", save_name), bbox_inches='tight')


In [1]:
def _binned_resolution_stats(resolution, x_axis, edges, weights=None,
                             lower_pct=15.865, upper_pct=84.135):
    """Median and central-band limits of `resolution` in bins of `x_axis`.

    Receives:
        resolution = 1D array of per-event resolution values
        x_axis = 1D array (same length) used to assign events to bins
        edges = bin edges; bin i holds events with edges[i] <= x_axis < edges[i+1]
        weights = optional per-event weights; if given, weighted quantiles
                  from the wquantiles package are used
        lower_pct, upper_pct = band limits in percent (0-100)
    Returns:
        medians, lower, upper = arrays of length len(edges)-1, NaN for empty bins
    """
    if weights is not None:
        import wquantiles as wq
    n_bins = max(len(edges) - 1, 0)
    medians = np.full(n_bins, np.nan)
    lower = np.full(n_bins, np.nan)
    upper = np.full(n_bins, np.nan)
    for i in range(n_bins):
        cut = (x_axis >= edges[i]) & (x_axis < edges[i+1])
        if not np.any(cut):
            continue  # empty bin: percentiles are undefined, leave NaN
        values = resolution[cut]
        if weights is None:
            # np.percentile expects percentages in [0, 100], not fractions.
            lower[i], medians[i], upper[i] = np.percentile(values, [lower_pct, 50., upper_pct])
        else:
            # wquantiles expects quantiles as fractions in [0, 1].
            lower[i] = wq.quantile(values, weights[cut], lower_pct/100.)
            upper[i] = wq.quantile(values, weights[cut], upper_pct/100.)
            medians[i] = wq.median(values, weights[cut])
    return medians, lower, upper


def plot_bin_slices(truth, nn_reco, energy_truth=None, weights=None,
                    use_fraction=False, old_reco=None, old_reco_truth=None,
                    reco_energy_truth=None, old_reco_weights=None,
                    bins=10, min_val=0., max_val=60., ylim=None,
                    save=False, savefolder=None, vs_predict=False,
                    flavor="NuMu", sample="CC", style="contours",
                    variable="Energy", units="(GeV)", xlog=False,
                    xvariable="Energy", xunits="(GeV)", notebook=False,
                    specific_bins=None, xline=None, xline_name="DeepCore",
                    epochs=None, reco_name="Retro", cnn_name="CNN", legend="upper center"):
    """Plot the resolution of one variable in slices (bins) of a second variable.

    Receives:
        truth = array with truth labels for this one variable
        nn_reco = array with NN predicted reco results for the same variable (same size as truth)
        energy_truth = optional array to bin in instead of truth (same size as truth)
        weights = optional per-event weights for nn_reco (weighted quantiles via wquantiles)
        use_fraction = bool, use fractional resolution (reco - truth)/truth instead of absolute
        old_reco = optional array of a comparison reconstruction for the same variable
        old_reco_truth = optional truth for old_reco (defaults to truth)
        reco_energy_truth = optional binning array for old_reco (defaults to energy_truth)
        old_reco_weights = optional weights for old_reco (defaults to weights)
        bins = number of equal-width bins between min_val and max_val
        min_val, max_val = binning range (default 0. to 60.)
        ylim = [ymin, ymax] for the plot, or None for no ylim
        save, savefolder = save the plot as savefolder + generated name + ".png"
        vs_predict = bool, bin in the CNN prediction instead of truth
        flavor, sample = optional strings appended to the title
        style = "contours" or "errorbars"
        variable, units = name and units of the resolution variable
        xvariable, xunits = name and units of energy_truth (x axis when it is given)
        xlog = bool, logarithmic x axis
        specific_bins = optional explicit bin edges (overrides bins/min_val/max_val)
        xline, xline_name = optional value(s) marked by vertical lines, and their legend label
        reco_name, cnn_name = legend names for old_reco and nn_reco
        legend = matplotlib legend location
        notebook, epochs = currently unused (kept for backward compatibility)
    Returns:
        None; draws per-bin medians with the central 68.27% band (bins with no events are NaN)
    """
    nn_reco = np.asarray(nn_reco)
    truth = np.asarray(truth)
    ## Assume old_reco truth is the same as test sample, option to set it otherwise
    truth2 = truth if old_reco_truth is None else np.asarray(old_reco_truth)
    if energy_truth is not None:
        energy_truth = np.asarray(energy_truth)
    energy_truth2 = energy_truth if reco_energy_truth is None else np.asarray(reco_energy_truth)
    if weights is not None:
        weights = np.asarray(weights)
    # The comparison reco falls back to the CNN weights when none are given for it.
    reco_weights = weights if old_reco_weights is None else np.asarray(old_reco_weights)

    #Check nan
    if old_reco is not None:
        old_reco = np.asarray(old_reco)
        assert not np.isnan(old_reco).any(), "OLD RECO HAS NAN"
    assert not np.isnan(nn_reco).any(), "CNN RECO HAS NAN"

    def resolution_of(reco, true):
        # Absolute (reco - truth) or fractional (reco - truth)/truth resolution.
        return (reco - true)/true if use_fraction else reco - true

    # Central 68.27% band: 15.865th to 84.135th percentile.
    percentile_in_peak = 68.27
    left_tail_percentile = (100. - percentile_in_peak)/2
    right_tail_percentile = 100. - left_tail_percentile

    if specific_bins is None:
        variable_ranges = np.linspace(min_val, max_val, num=bins+1)
    else:
        # ndarray so the xerr arithmetic below works for list input too.
        variable_ranges = np.asarray(specific_bins, dtype=float)
        min_val, max_val = variable_ranges[0], variable_ranges[-1]
    variable_centers = (variable_ranges[1:] + variable_ranges[:-1])/2.

    if vs_predict:
        title = "%s Resolution Dependence"%(variable)
        x_axis_array, x_axis_array2 = nn_reco, old_reco
    elif energy_truth is None:
        title = "%s Resolution Dependence"%(variable)
        x_axis_array, x_axis_array2 = truth, truth2
    else:
        title = "%s Resolution %s Dependence"%(variable,xvariable)
        x_axis_array, x_axis_array2 = energy_truth, energy_truth2

    medians, err_from, err_to = _binned_resolution_stats(
        resolution_of(nn_reco, truth), x_axis_array, variable_ranges, weights,
        left_tail_percentile, right_tail_percentile)
    if old_reco is not None:
        medians_reco, err_from_reco, err_to_reco = _binned_resolution_stats(
            resolution_of(old_reco, truth2), x_axis_array2, variable_ranges, reco_weights,
            left_tail_percentile, right_tail_percentile)

    xerr = [variable_centers - variable_ranges[:-1], variable_ranges[1:] - variable_centers]

    plt.figure(figsize=(10,7))
    plt.plot([min_val,max_val], [0,0], color='k')
    if style == "errorbars":
        if old_reco is not None:
            (_, caps_reco, _) = plt.errorbar(variable_centers, medians_reco,
                                             yerr=[medians_reco-err_from_reco, err_to_reco-medians_reco],
                                             xerr=xerr, capsize=3.0, fmt='o', label="%s"%reco_name)
            for cap in caps_reco:
                cap.set_markeredgewidth(5)
        (_, caps, _) = plt.errorbar(variable_centers, medians,
                                    yerr=[medians-err_from, err_to-medians],
                                    xerr=xerr, capsize=3.0, fmt='o', label=cnn_name)
        for cap in caps:
            cap.set_markeredgewidth(5)
    else: #contours
        alpha = 0.5
        lwid = 3
        cnn_color = plt.get_cmap('Blues')(np.linspace(0, 1, 4))[2]
        reco_color = plt.get_cmap('Oranges')(np.linspace(0, 1, 4))[2]
        ax = plt.gca()
        if old_reco is not None:
            ax.plot(variable_centers, medians_reco, color=reco_color, linestyle='-',
                    label="%s median"%reco_name, linewidth=lwid)
            ax.fill_between(variable_centers, medians_reco, err_from_reco, color=reco_color, alpha=alpha)
            ax.fill_between(variable_centers, medians_reco, err_to_reco, color=reco_color, alpha=alpha,
                            label=reco_name + ' 68%')
        ax.plot(variable_centers, medians, linestyle='-', label="%s median"%(cnn_name),
                color=cnn_color, linewidth=lwid)
        ax.fill_between(variable_centers, medians, err_from, color=cnn_color, alpha=alpha)
        ax.fill_between(variable_centers, medians, err_to, color=cnn_color, alpha=alpha,
                        label=cnn_name + ' 68%')
    if xline is not None:
        # Accept a scalar or any sequence; label only the first line in the legend.
        for j, x_val in enumerate(np.atleast_1d(xline)):
            plt.axvline(x_val, linewidth=3, color='k', linestyle="dashed",
                        label="%s"%xline_name if j == 0 else "_nolegend_")
    plt.legend(loc=legend)

    plt.xlim(min_val,max_val)
    if ylim is not None:
        plt.ylim(ylim)
    if vs_predict:
        plt.xlabel("Reconstructed %s %s"%(variable,units),fontsize=20)
    elif energy_truth is not None:
        plt.xlabel("%s %s"%(xvariable,xunits),fontsize=20)
    else:
        plt.xlabel("True %s %s"%(variable,units),fontsize=20)
    if use_fraction:
        plt.ylabel(r'Fractional Resolution: $\frac{reconstruction - truth}{truth}$',fontsize=20)
    else:
        plt.ylabel("Resolution: \n reconstruction - truth %s"%units,fontsize=20)
    if xlog:
        plt.xscale('log')

    flavor_labels = {
        "numu": r' for $\nu_\mu$ ',
        "nue": r' for $\nu_e$ ',
        "nutau": r' for $\nu_\tau$ ',
        "mu": r' for $\mu$ ',
        "nu": r' for $\nu$ ',
    }
    if flavor is not None:
        title += flavor_labels.get(str(flavor).lower(), str(flavor))
    if sample is not None:
        title += str(sample)
    plt.title(title,fontsize=25)

    savename = "%s%sResolutionSlices"%(str(variable).replace(" ",""),cnn_name)
    if vs_predict:
        savename += "VsPredict"
    if use_fraction:
        savename += "Frac"
    if weights is not None:
        savename += "Weighted"
    if flavor is not None:
        savename += "%s"%str(flavor).replace(" ","")
    if energy_truth is not None:
        savename += "_%sBinned"%str(xvariable).replace(" ","")
    if style == "errorbars":
        savename += "ErrorBars"
    if xlog:
        savename += "_xlog"
    if old_reco is not None:
        savename += "_Compare%sReco"%str(reco_name).replace(" ","")
    if ylim is not None:
        savename += "_ylim"
    if save:
        plt.savefig("%s%s.png"%(savefolder or "",savename),bbox_inches='tight')

In [None]:
# Path to the HDF5 file with the CNN test-set output (must be filled in).
input_file_cnn = ""
# Set to True to ignore the comparison reconstruction ("reco_test") even if present.
no_old_reco = False

assert input_file_cnn, "Set input_file_cnn to the HDF5 file with the CNN predictions"

with h5py.File(input_file_cnn, "r") as f:
    truth = f["Y_test_use"][:]
    predict = f["Y_predicted"][:]
    # Optional datasets fall back to None only when the key is absent,
    # rather than a bare except hiding every other error.
    info = f["additional_info"][:] if "additional_info" in f else None
    if no_old_reco or "reco_test" not in f:
        reco = None
    else:
        reco = f["reco_test"][:]
    weights = f["weights_test"][:] if "weights_test" in f else None
del f

In [None]:
# Unpack the CNN prediction columns (column indices as assigned below).
cnn_prob_mu = np.array(predict[:,-1])
cnn_prob_nu = 1-cnn_prob_mu
cnn_zenith = np.array(predict[:,1])
cnn_prob_track = np.array(predict[:,2])
cnn_x = np.array(predict[:,3])
cnn_y = np.array(predict[:,4])
cnn_z = np.array(predict[:,5])
cnn_coszen = np.cos(cnn_zenith)
cnn_energy = np.array(predict[:,0])
# Casting a probability to bool marks every non-zero value as a track; cut on it instead.
# NOTE(review): 0.5 is an assumed working point for the track/cascade split - confirm.
TRACK_PROB_CUT = 0.5
cnn_isTrack = cnn_prob_track > TRACK_PROB_CUT

if reco is not None:
    reco_energy = np.array(reco[:,0])
    reco_zenith = np.array(reco[:,1])
    reco_azimuth = np.array(reco[:,2])
    reco_coszenith = np.cos(reco_zenith)
    reco_x = np.array(reco[:,4])
    reco_y = np.array(reco[:,5])
    reco_z = np.array(reco[:,6])
    reco_CC = np.array(reco[:,11])
    reco_isCC = reco_CC == 1
    reco_track = np.array(reco[:,8])
    reco_isTrack = np.array(reco_track,dtype=bool)
    reco_PID = reco[:,9]
else:
    # No comparison reconstruction loaded; later cells can test these for None.
    reco_energy = reco_zenith = reco_azimuth = reco_coszenith = None
    reco_x = reco_y = reco_z = None
    reco_CC = reco_isCC = reco_track = reco_isTrack = reco_PID = None

# Truth columns; the indices match those used for the reco array.
true_energy, true_zenith, true_azimuth = (np.array(truth[:, col]) for col in (0, 1, 2))
true_coszenith = np.cos(true_zenith)
true_x, true_y, true_z = (np.array(truth[:, col]) for col in (4, 5, 6))
true_CC = np.array(truth[:, 11])
true_isCC = true_CC == 1
true_track = np.array(truth[:, 8])
true_isTrack = true_track.astype(bool)
true_PID = truth[:, 9]

# Per-particle masks from the PID column (13 muon, 14 numu, 12 nue, 16 nutau, per the names).
muon_mask_test, numu_mask_test, nue_mask_test, nutau_mask_test = (
    true_PID == pid for pid in (13, 14, 12, 16)
)
true_isMuon = muon_mask_test.astype(bool)
true_isNuMu = numu_mask_test.astype(bool)
true_isNuE = nue_mask_test.astype(bool)
true_isNuTau = nutau_mask_test.astype(bool)
nu_mask = numu_mask_test | nue_mask_test | nutau_mask_test
true_isNu = nu_mask.astype(bool)

# Number of simulation files per particle type; each type's weights are divided by its count.
numu_files2 = 97
nue_files2 = 91
muon_files2 = 1999
nutau_files2 = 45
# The ndim guard keeps the cell idempotent: once weights is reduced to 1D,
# re-running does not slice or re-normalize it again.
if weights is not None and np.ndim(weights) == 2:
    # Column 8 of the stored weights table is the weight used here (index kept from original).
    weights = np.array(weights[:,8], dtype=float)
    for label, mask, n_files in (("NuMu:", true_isNuMu, numu_files2),
                                 ("NuE:", true_isNuE, nue_files2),
                                 ("Muon:", true_isMuon, muon_files2),
                                 ("NuTau:", true_isNuTau, nutau_files2)):
        if np.any(mask):
            print(label, np.sum(mask), np.sum(weights[mask]))
            weights[mask] /= n_files
            print(np.sum(weights[mask]))

In [None]:
# Recompute the CC flag and per-particle masks from the truth labels
# (same definitions as in the truth-unpacking cell).
true_CC = np.array(truth[:, 11])
true_isCC = (true_CC == 1)

muon_mask_test, numu_mask_test, nue_mask_test, nutau_mask_test = (
    true_PID == pid for pid in (13, 14, 12, 16)
)
true_isMuon = muon_mask_test.astype(bool)
true_isNuMu = numu_mask_test.astype(bool)
true_isNuE = nue_mask_test.astype(bool)
true_isNuTau = nutau_mask_test.astype(bool)
nu_mask = numu_mask_test | nue_mask_test | nutau_mask_test
true_isNu = nu_mask.astype(bool)