In [7]:
from sklearn.preprocessing import RobustScaler
from sklearn.metrics import roc_auc_score
from sklearn.metrics import roc_curve
import matplotlib.pyplot as plt
import multiprocessing as mp
from scipy import interp
import numpy as np
import logging
import pickle
import glob
import sys

np.seterr(divide="ignore")
sys.path.append("..")




from recnn.preprocessing import sequentialize_by_pt
from recnn.preprocessing import rewrite_content
from recnn.preprocessing import multithreadmap
from recnn.preprocessing import permute_by_pt
from recnn.preprocessing import randomize
from recnn.preprocessing import extract


from recnn.recnn import grnn_transform_simple
from recnn.recnn import grnn_predict_simple
from recnn.recnn import grnn_predict_gated

%matplotlib inline
plt.rcParams["figure.figsize"] = (6, 6)

# Loading functions

In [8]:
def extractcontent(jet):
    """Return the value stored under the ``"content"`` key of ``jet``."""
    content = jet["content"]
    return content

def tftransform(jet,tf):
    """Replace ``jet["content"]`` with ``tf.transform(jet["content"])``.

    The jet is modified in place and the same object is returned, so the
    function can be used directly as a mapping callback.
    """
    scaled_content = tf.transform(jet["content"])
    jet["content"] = scaled_content
    return jet

def load_tf(filename_train, preprocess=None, n_events_train=-1):
    """Fit a ``RobustScaler`` on the jet contents of a training file.

    Parameters
    ----------
    filename_train : str
        Path passed to ``np.load``; the loaded object must unpack into
        ``(X, y)``.
    preprocess : callable, optional
        Extra per-jet step applied through ``multithreadmap`` right after
        ``rewrite_content``. Skipped when falsy.
    n_events_train : int, optional
        When positive, a random subset of at most this many events is kept
        (unseeded ``np.random.permutation``). The default ``-1`` keeps all.

    Returns
    -------
    RobustScaler
        Scaler fitted on the row-wise stack (``np.vstack``) of every jet's
        ``"content"`` entry.
    """
    # Make training data
    print("Loading training data...")

    # The file stores Python objects (X is cast to dict below), which NumPy
    # only loads with allow_pickle=True. Unpickling executes arbitrary code:
    # only point this at trusted files.
    X, y = np.load(filename_train, allow_pickle=True)
    X = np.array(X).astype(dict)
    y = np.array(y).astype(int)

    if n_events_train > 0:
        indices = np.random.permutation(len(X))[:n_events_train]
        X = X[indices]
        y = y[indices]

    print("\tfilename = " + filename_train)
    # len() returns an int; format it rather than concatenating to a str,
    # which raised TypeError.
    print("\tX size = %d" % len(X))
    print("\ty size = %d" % len(y))

    # Preprocessing
    print("Preprocessing...")
    X = multithreadmap(rewrite_content, X)

    if preprocess:
        X = multithreadmap(preprocess, X)

    # NOTE(review): load_test applies permute_by_pt before extract; here the
    # order is reversed (extract first). Confirm which order is intended.
    X = multithreadmap(permute_by_pt, multithreadmap(extract, X))
    Xcontent = multithreadmap(extractcontent, X)
    tf = RobustScaler().fit(np.vstack(Xcontent))

    return tf

def load_test(tf, filename_test, preprocess=None):
    """Load a test file and run it through the preprocessing chain.

    Parameters
    ----------
    tf : object with a ``transform`` method
        Fitted scaler (e.g. the result of ``load_tf``) applied to every
        jet's ``"content"`` through ``tftransform``.
    filename_test : str
        Path passed to ``np.load``; the loaded object must unpack into
        ``(X, y)``.
    preprocess : callable, optional
        Extra per-jet step applied through ``multithreadmap`` right after
        ``rewrite_content``. Skipped when falsy.

    Returns
    -------
    tuple
        ``(X, y)``: the preprocessed jets as returned by ``multithreadmap``
        and the labels as an integer array.
    """
    # Make test data
    print("Loading test data...")

    # The file stores Python objects (X is cast to dict below), which NumPy
    # only loads with allow_pickle=True. Unpickling executes arbitrary code:
    # only point this at trusted files.
    X, y = np.load(filename_test, allow_pickle=True)
    X = np.array(X).astype(dict)
    y = np.array(y).astype(int)

    print("\tfilename = " + filename_test)
    # len() returns an int; format it rather than concatenating to a str,
    # which raised TypeError.
    print("\tX size = %d" % len(X))
    print("\ty size = %d" % len(y))

    # Preprocessing
    print("Preprocessing...")
    X = multithreadmap(rewrite_content, X)

    if preprocess:
        X = multithreadmap(preprocess, X)

    X = multithreadmap(permute_by_pt, X)
    X = multithreadmap(extract, X)

    # Scale each jet's content with the previously fitted transformer.
    X = multithreadmap(tftransform, X, tf=tf)

    return (X, y)

In [None]:
def predict(X, filename, func=grnn_predict_simple):
    """Load pickled model parameters and evaluate them on ``X``.

    Parameters
    ----------
    X : object
        Input forwarded unchanged as the second argument of ``func``.
    filename : str
        Path of the pickle file holding the model parameters.
    func : callable, optional
        Prediction function called as ``func(params, X)``. Defaults to
        ``grnn_predict_simple``.

    Returns
    -------
    object
        Whatever ``func(params, X)`` returns.
    """
    # The context manager closes the file even if unpickling fails.
    # NOTE: pickle.load can execute arbitrary code; only load trusted models.
    with open(filename, "rb") as fd:
        params = pickle.load(fd)

    return func(params, X)

def evaluate_models(X, y, filename_list, func=grnn_predict_simple):
    """Compute ROC metrics on ``(X, y)`` for each pickled model.

    Parameters
    ----------
    X : object
        Inputs forwarded to ``predict`` for every model.
    y : array-like
        True labels passed to ``roc_auc_score`` and ``roc_curve``.
    filename_list : iterable of str, or str
        Paths of the pickled models. A single path given as a plain string
        is treated as a one-element list.
    func : callable, optional
        Prediction function forwarded to ``predict``. Defaults to
        ``grnn_predict_simple``.

    Returns
    -------
    tuple
        ``(rocs, fprs, tprs)``: three lists with one entry per model,
        holding the ROC AUC, the false-positive rates and the true-positive
        rates respectively.
    """
    # A bare path would otherwise be iterated character by character.
    if isinstance(filename_list, str):
        filename_list = [filename_list]

    rocs = []
    fprs = []
    tprs = []

    for filename in filename_list:
        # The original had a trailing comma after this call (a Python 2
        # "no newline" idiom) which only built an unused tuple.
        print("Loading " + filename)

        y_pred = predict(X, filename, func=func)

        # Roc
        rocs.append(roc_auc_score(y, y_pred))
        fpr, tpr, _ = roc_curve(y, y_pred)

        fprs.append(fpr)
        tprs.append(tpr)

        print("ROC AUC = %.4f" % rocs[-1])

    print("Mean ROC AUC = %.4f" % np.mean(rocs))

    return (rocs, fprs, tprs)

def build_rocs(prefix_train, prefix_test, model_pattern, preprocess=None, gated=False):
    """Fit the scaler, load the test set and score the anti-kt model.

    Parameters
    ----------
    prefix_train, prefix_test, model_pattern
        Accepted for interface compatibility but currently unused: the data
        and model paths below are hard-coded.
    preprocess : callable, optional
        Forwarded to both ``load_tf`` and ``load_test``.
    gated : bool, optional
        When true the model is evaluated with ``grnn_predict_gated``,
        otherwise with ``grnn_predict_simple``.

    Returns
    -------
    tuple
        ``(rocs, fprs, tprs)`` exactly as returned by ``evaluate_models``.
    """
    # NOTE(review): the scaler is fitted on the same file that is used as the
    # test set, and the prefix/pattern arguments are ignored. Confirm this is
    # intended before relying on the reported numbers.
    data_file = "/data/conda/recnn/data/npyfiles/MoreStat_testtype.npy"
    model_file = "/data/conda/recnn/models/model_anti-kt.pickle"

    tf = load_tf(data_file, preprocess=preprocess)
    # load_test returns exactly two values; unpacking three raised ValueError.
    X, y = load_test(tf, data_file, preprocess=preprocess)

    func = grnn_predict_gated if gated else grnn_predict_simple

    # evaluate_models takes (X, y, filename_list, func): pass the model path
    # as a one-element list and the predictor by keyword, without the extra
    # positional argument that used to shift the path into ``func``.
    rocs, fprs, tprs = evaluate_models(X, y, [model_file], func=func)

    return (rocs, fprs, tprs)