In [3]:
class Plot(object):
    """Base class for plot definitions.

    Subclasses are expected to override :meth:`plot`; the base
    implementation only raises.
    """

    def __init__(self):
        """Create the plot definition; the base class holds no state."""

    def plot(self, sh):
        """Produce the plot from ``sh``.

        Args:
            sh: Object handed to the concrete plot; unused here.

        Returns:
            Not fixed by the base class.  ``RejectionPlot.plot`` in this file
            returns a matplotlib ``(figure, axes)`` pair.

        Raises:
            NotImplementedError: Always, because the base class has no
                drawing logic of its own.
        """
        message = "plot() needs to be defined in derived class"
        raise NotImplementedError(message)
        
class RejectionPlot(Plot):
    """Plot of background rejection versus ``xvar`` for one or more scores.

    For every score a ``Flattener`` is fitted on the signal training sample
    and then applied to the background testing sample; the pass/fail decision
    is turned into a binned rejection with ``binned_efficiency_ci`` and drawn
    with error bars.
    """

    def __init__(
        self,
        scores,
        eff,
        xvar,
        working_point=None,
        bins=10,
        scale=1.0,
        label=None,
        ylim=None,
        fraction=1,
        frac_test_used_for_train=0,
        prongness=None,
        do_pt_reweight=False,
        log_x=False,
        log_y=False,
    ):
        """Store the plot configuration.

        Args:
            scores: Name of a score variable, or a list of such names.
            eff: Value handed to ``Flattener`` as its third argument.
            xvar: Name of the variable on the x axis.
            working_point: Only used in the plot title.
            bins: Bin edges in display units.  They are divided by ``scale``
                here and sliced in :meth:`plot`, so an array supporting both
                operations is required.
                NOTE(review): the default ``10`` becomes a float that
                :meth:`plot` cannot slice -- pass explicit edges.
            scale: Factor between the stored variable and the displayed axis
                (edges are divided by it here, bin centres are multiplied by
                it when drawing).
            label: Legend label or list of labels, one per score.  Defaults
                to the last ``/``-separated component of each score name.
            ylim: Optional y-axis limits passed to ``Axes.set_ylim``.
            fraction: If not ``1``, only the leading fraction of the signal
                training sample is used for fitting the flattening.
            frac_test_used_for_train: If not ``0``, this leading fraction of
                the testing samples is skipped in the evaluation.
            prongness: ``"1p"`` or ``"2p"`` select the matching pt binning;
                any other value (including ``None``) selects the 3p binning.
            do_pt_reweight: Use ``pt_reweight`` instead of
                ``beamSpot_reweight`` for the test-sample weights.
            log_x: Draw the x axis on a log scale.
            log_y: Draw the y axis on a log scale.
        """
        super(RejectionPlot, self).__init__()

        if not isinstance(scores, list):
            scores = [scores]
        self.scores = scores

        if label:
            if not isinstance(label, list):
                label = [label]
            assert len(label) == len(scores), \
                "need exactly one label per score"
        else:
            label = [s.split("/")[-1] for s in scores]
        self.label = label

        self.eff = eff
        self.xvar = xvar
        self.bins = bins / scale
        self.scale = scale
        self.ylim = ylim
        self.working_point = working_point
        self.fraction = fraction
        self.frac_test_used_for_train = frac_test_used_for_train
        self.prongness = prongness
        self.do_pt_reweight = do_pt_reweight
        self.log_x = log_x
        self.log_y = log_y

    @staticmethod
    def _head(values, stop):
        """Return ``values[:stop]``, or ``values`` untouched if ``stop`` is None."""
        return values if stop is None else values[:stop]

    @staticmethod
    def _tail(values, start):
        """Return ``values[start:]``, or ``values`` untouched if ``start`` is None."""
        return values if start is None else values[start:]

    def _pt_binning(self):
        """Pick the module-level pt binning that matches ``self.prongness``."""
        if self.prongness == "1p":
            print("using 1p pt binnings")
            return pt_bins_R22_log_1p
        if self.prongness == "2p":
            print("using 2p pt binnings")
            return pt_bins_R22_log_2p
        # Everything else is treated as a 3p dataset.
        print("using 3p pt binnings")
        return pt_bins_R22_log_3p

    def plot(self, sh):
        """Draw the rejection plot.

        Args:
            sh: Object exposing ``sig_train``, ``sig_test`` and ``bkg_test``,
                each with a ``get_variables(*names)`` method whose result is
                indexable by variable name.

        Returns:
            The ``(fig, ax)`` pair created by ``plt.subplots()``.
        """
        # Flattening is determined on the signal training sample.
        sig_train = sh.sig_train.get_variables("TauJets/pt", "TauJets/mu",
                                               *self.scores)

        # Working points are evaluated on the testing samples.
        sig_test = sh.sig_test.get_variables(
            "TauJets/pt", "TauJets/beamSpotWeight", "TauJets/mu")
        bkg_test = sh.bkg_test.get_variables(
            "TauJets/pt", "TauJets/beamSpotWeight", "TauJets/mu",
            self.xvar, *self.scores)

        # Optional sub-sampling.  ``None`` means "use the sample as it is";
        # the helpers then hand the original object through unsliced.
        train_stop = None
        if self.fraction != 1:
            train_stop = int(len(sig_train["TauJets/pt"]) * self.fraction)

        sig_start = None
        bkg_start = None
        if self.frac_test_used_for_train != 0:
            sig_start = int(
                len(sig_test["TauJets/pt"]) * self.frac_test_used_for_train)
            bkg_start = int(
                len(bkg_test["TauJets/pt"]) * self.frac_test_used_for_train)

        pt_binning = self._pt_binning()

        # Determine flattening on training sample for all scores.
        flat_dict = {}
        for s in self.scores:
            flat_dict[s] = Flattener(pt_binning, mu_bins_extended, self.eff)
            flat_dict[s].fit(self._head(sig_train["TauJets/pt"], train_stop),
                             self._head(sig_train["TauJets/mu"], train_stop),
                             self._head(sig_train[s], train_stop))

        # Kinematic reweighting of the evaluated part of the test samples.
        sig_pt = self._tail(sig_test["TauJets/pt"], sig_start)
        bkg_pt = self._tail(bkg_test["TauJets/pt"], bkg_start)
        if self.do_pt_reweight:
            sig_test_weight, bkg_test_weight = pt_reweight(sig_pt, bkg_pt)
        else:
            sig_test_weight, bkg_test_weight = beamSpot_reweight(
                sig_pt, bkg_pt,
                self._tail(sig_test["TauJets/beamSpotWeight"], sig_start),
                self._tail(bkg_test["TauJets/beamSpotWeight"], bkg_start))

        # Check which background events pass the working point for each score.
        bkg_mu = self._tail(bkg_test["TauJets/mu"], bkg_start)
        pass_thr = [
            flat_dict[s].passes_thr(bkg_pt, bkg_mu,
                                    self._tail(bkg_test[s], bkg_start))
            for s in self.scores
        ]

        rejections = binned_efficiency_ci(
            self._tail(bkg_test[self.xvar], bkg_start), pass_thr,
            weight=bkg_test_weight,
            bins=self.bins, return_inverse=True)

        # Plot
        fig, ax = plt.subplots()

        # self.bins is stored divided by scale; multiply back for display.
        bin_center = self.scale * (self.bins[1:] + self.bins[:-1]) / 2.0
        bin_half_width = self.scale * (self.bins[1:] - self.bins[:-1]) / 2.0

        for z, (rej, c, label) in enumerate(
                zip(rejections, colorseq, self.label)):
            ci_lo, ci_hi = rej.ci
            # errorbar wants distances from the central value, not limits.
            yerr = np.vstack([rej.median - ci_lo, ci_hi - rej.median])

            ax.errorbar(bin_center, rej.median,
                        xerr=bin_half_width,
                        yerr=yerr,
                        fmt="o", color=c, label=label, zorder=z)

        if self.ylim is not None:
            # Previously referenced an undefined local ``ylim`` (NameError).
            ax.set_ylim(self.ylim)
        else:
            # Pad the automatic limits a little on both sides.
            y_lo, y_hi = ax.get_ylim()
            d = 0.05 * (y_hi - y_lo) + .03
            ax.set_ylim(y_lo - d, y_hi + d)

        if self.log_x:
            ax.set_xscale("log")
        if self.log_y:
            ax.set_yscale("log")

        xvar_name = self.xvar.split("/")[-1]
        ax.title.set_text("{} {} Rejection {} wp".format(
            self.prongness, xvar_name, self.working_point))
        ax.set_xlabel(xvar_name, x=1, ha="right")
        ax.set_ylabel("Rejection", y=1, ha="right")
        ax.legend(fontsize="small")

        return fig, ax