In [9]:
import sys, yaml, h5py, random
import gc
import numpy as np
import os, glob
import time
import pyarrow as pa
import pyarrow.parquet as pq
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import *
from dataset_loder import *
import pandas as pd
import pickle
import wandb
import importlib
from sklearn.metrics import roc_curve, auc

In [10]:
def set_random_seeds(random_seed=0):
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(random_seed)
    random.seed(random_seed)
    
set_random_seeds(42)
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Correct way to specify GPU index
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
torch.backends.cudnn.benchmark = True



cuda:0


In [11]:

load_epoch= 0

lr_init= 1e-3
lr_factor= .1
new_lr= 1e-3
reslayers= [8,16,32,64]
resblocks= 3
channels= [0,1,2,3,4,5,6,7,8,9,10,11,12] 

loss_func= 'mse'
scheduler_= 'cosine'
optimizer_= 'Adam'
patience= 2
scheduler_mode= 'min'
BATCH_SIZE= 128
VAL_BATCH_SIZE= 128
TEST_BATCH_SIZE= 128
indices= [0,1,2,3,4,5,6,7,8,9,10,11,12] 

epochs= 2
n_train= -1
n_valid= -1
n_test= 40

m0_scale = 14
mass_mean= 9.025205 
mass_std= 5.1880417




random_seed=42
w_iter_freq=50
num_data_workers= 4






In [8]:

load_epoch = 30
epoch = load_epoch
BATCH_SIZE=1000
import torch_resnet_concat as networks

out_dir ="/global/cfs/cdirs/m4392/bbbam/jupyter_notebook_new/classification/ResNet_classifier"
model_dir ='13_ch_classifier_ResNet_mapA'




resnet = networks.ResNet_mapA(in_channels=13, nblocks=3, fmaps=[8,16,32,64], alpha=1)
resnet=resnet.to(device)
load_model = glob.glob(f'{out_dir}/{model_dir}/MODELS/model_epoch{load_epoch}*')[0]#loading  model mannually
print('Loading weights from %s'%load_model)
checkpoint = torch.load(load_model, weights_only=False)
resnet.load_state_dict(checkpoint['model_state_dict'])

input_datasets = ['DYToTauTau_M-50_13TeV_valid.h5', 'IMG_H_AATo4Tau_Hadronic_tauDR0p4_M3p7_signal_v2_2.h5', 'IMG_H_AATo4Tau_Hadronic_tauDR0p4_M5_signal_v2_2.h5',
                 'TTToHadronic_valid.h5', 'GGH_TauTau_valid.h5', 'IMG_H_AATo4Tau_Hadronic_tauDR0p4_M4_signal_v2_2.h5', 'QCD_Pt-15to7000_valid.h5', 'WJetsToLNu_valid.h5']

for input_dataset in input_datasets:
    out_tag = input_dataset.split('.')[0]
    # print("out_tag", out_tag)
    test_dir = f'/global/cfs/cdirs/m4392/bbbam/classifier_signal_background_Run2_valid_h5/{input_dataset}'
    test_dset = ClassifierDataset(test_dir , selected_channels=indices, preload_size=32)
    test_indices = list(range(len(test_dset)))
    test_sampler = ChunkedSampler(test_indices, chunk_size=BATCH_SIZE, shuffle=False)
    test_loader = DataLoader(test_dset, batch_size=BATCH_SIZE, sampler=test_sampler, pin_memory=True, num_workers=num_data_workers)
    
    
    
    
    
    
    
    
    loss_, acc_ = 0., 0.
    y_pred_, y_true_ = [], []
    
    with torch.no_grad():
        for i, data in enumerate(test_loader):
            X, y = data[0].to(device), data[1].to(device)
            iphi, ieta = data[2].to(device), data[3].to(device)
    
            iphi = iphi/360.
            ieta = ieta/140.
            logits = resnet([X, iphi, ieta])
            loss=  F.binary_cross_entropy_with_logits(logits, y)
            loss_ += loss.item()
            pred = logits.ge(0.).byte()
            acc_ += pred.eq(y.byte()).float().mean().item()
            y_pred_.append(torch.sigmoid(logits).detach().cpu().numpy())
            y_true_.append(y.detach().cpu().numpy())
    
    
            if i % 50  == 0:
                
    
               print('Validation (%d/%d): Val loss:%f, acc:%f'%(i+1, len(test_loader), loss_/(i+1), acc_/(i+1) ))
    
    
        y_pred_ = np.concatenate(y_pred_)
        y_true_ = np.concatenate(y_true_)
    
    
    
    
    
        print('%d: Val loss:%f, acc%f'%(epoch, loss_/len(test_loader), np.mean(acc_)))
    
    
        # fpr, tpr, _ = roc_curve(y_true_, y_pred_)
        # roc_auc = auc(fpr, tpr)
       
        output_dict = {}
        output_dict["y_true"] = y_true_
        output_dict["y_pred"] = y_pred_
        # output_dict["fpr"] = fpr
        # output_dict["tpr"] = tpr
    
   
        with open(f'{out_dir}/{model_dir}/{out_tag}.pkl', "wb") as outfile:
            pickle.dump(output_dict, outfile, protocol=2) #protocol=2 for compatibility
        print(f">>>>>>>>>>>>>>>>>> Done for {input_dataset} >>>>>>>>>>>>>")
    
    
    
       

Loading weights from /global/cfs/cdirs/m4392/bbbam/jupyter_notebook_new/classification/ResNet_classifier/13_ch_classifier_ResNet_mapA/MODELS/model_epoch30_auc0.9476.pkl
Validation (1/85): Val loss:0.388102, acc:0.855000
Validation (51/85): Val loss:0.383801, acc:0.844235
30: Val loss:0.382681, acc71.868004
>>>>>>>>>>>>>>>>>> Done for DYToTauTau_M-50_13TeV_valid.h5 >>>>>>>>>>>>>




Validation (1/709): Val loss:0.382357, acc:0.817000
Validation (51/709): Val loss:0.358794, acc:0.832177
Validation (101/709): Val loss:0.358927, acc:0.832832
Validation (151/709): Val loss:0.358925, acc:0.833291
Validation (201/709): Val loss:0.358191, acc:0.833478
Validation (251/709): Val loss:0.357948, acc:0.833737
Validation (301/709): Val loss:0.357960, acc:0.833581
Validation (351/709): Val loss:0.357647, acc:0.833972
Validation (401/709): Val loss:0.357959, acc:0.833830
Validation (451/709): Val loss:0.357671, acc:0.834093
Validation (501/709): Val loss:0.357916, acc:0.833994
Validation (551/709): Val loss:0.358205, acc:0.833784
Validation (601/709): Val loss:0.358072, acc:0.834025
Validation (651/709): Val loss:0.357750, acc:0.834264
Validation (701/709): Val loss:0.357538, acc:0.834298
30: Val loss:0.357459, acc591.562839




>>>>>>>>>>>>>>>>>> Done for IMG_H_AATo4Tau_Hadronic_tauDR0p4_M3p7_signal_v2_2.h5 >>>>>>>>>>>>>
Validation (1/521): Val loss:0.263100, acc:0.887000
Validation (51/521): Val loss:0.297944, acc:0.865333
Validation (101/521): Val loss:0.298160, acc:0.865178
Validation (151/521): Val loss:0.298552, acc:0.864702
Validation (201/521): Val loss:0.299789, acc:0.863761
Validation (251/521): Val loss:0.299309, acc:0.863677
Validation (301/521): Val loss:0.299532, acc:0.863488
Validation (351/521): Val loss:0.300260, acc:0.863051
Validation (401/521): Val loss:0.299952, acc:0.863105
Validation (451/521): Val loss:0.299931, acc:0.863142
Validation (501/521): Val loss:0.299875, acc:0.863066
30: Val loss:0.299579, acc449.749909
>>>>>>>>>>>>>>>>>> Done for IMG_H_AATo4Tau_Hadronic_tauDR0p4_M5_signal_v2_2.h5 >>>>>>>>>>>>>




Validation (1/85): Val loss:0.320898, acc:0.853000
Validation (51/85): Val loss:0.374212, acc:0.828157
30: Val loss:0.376644, acc70.335003
>>>>>>>>>>>>>>>>>> Done for TTToHadronic_valid.h5 >>>>>>>>>>>>>




Validation (1/85): Val loss:0.391073, acc:0.838000
Validation (51/85): Val loss:0.377137, acc:0.848549
30: Val loss:0.375663, acc72.132004
>>>>>>>>>>>>>>>>>> Done for GGH_TauTau_valid.h5 >>>>>>>>>>>>>




Validation (1/679): Val loss:0.339238, acc:0.857000
Validation (51/679): Val loss:0.339915, acc:0.845059
Validation (101/679): Val loss:0.338835, acc:0.844525
Validation (151/679): Val loss:0.338592, acc:0.843947
Validation (201/679): Val loss:0.338382, acc:0.844124
Validation (251/679): Val loss:0.338723, acc:0.843582
Validation (301/679): Val loss:0.339849, acc:0.842947
Validation (351/679): Val loss:0.339828, acc:0.842969
Validation (401/679): Val loss:0.339367, acc:0.843476
Validation (451/679): Val loss:0.339670, acc:0.843100
Validation (501/679): Val loss:0.340199, acc:0.842920
Validation (551/679): Val loss:0.340367, acc:0.843004
Validation (601/679): Val loss:0.340410, acc:0.842995
Validation (651/679): Val loss:0.340519, acc:0.843068
30: Val loss:0.340613, acc572.410694




>>>>>>>>>>>>>>>>>> Done for IMG_H_AATo4Tau_Hadronic_tauDR0p4_M4_signal_v2_2.h5 >>>>>>>>>>>>>
Validation (1/85): Val loss:0.167179, acc:0.930000
Validation (51/85): Val loss:0.143603, acc:0.934314
30: Val loss:0.144922, acc79.379004
>>>>>>>>>>>>>>>>>> Done for QCD_Pt-15to7000_valid.h5 >>>>>>>>>>>>>




Validation (1/85): Val loss:0.083031, acc:0.976000
Validation (51/85): Val loss:0.103703, acc:0.960255
30: Val loss:0.103204, acc81.656004
>>>>>>>>>>>>>>>>>> Done for WJetsToLNu_valid.h5 >>>>>>>>>>>>>


