In [1]:
import numpy as np
import torch
import h5py
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as PyG
from torch_geometric.transforms import Distance
from torch_geometric.data import DataLoader
from torch_geometric.data import Data as PyGData
from torch_geometric.data import Data
import sys, os
import subprocess
import csv, yaml
import math
from tqdm import tqdm
from sklearn.metrics import accuracy_score
import torch.optim as optim
import argparse
import pandas as pd


In [2]:
sys.path.append("./python")
from model.allModel import *

In [3]:

parser = argparse.ArgumentParser()
parser.add_argument('--config', action='store', type=str, help='Configration file with sample information')
parser.add_argument('-o', '--output', action='store', type=str, required=True, help='Path to output directory')
parser.add_argument('--device', action='store', type=int, default=0, help='device name')
parser.add_argument('--epoch', action='store', type=int, default=400,help='Number of epochs')
parser.add_argument('--batch', action='store', type=int, default=32, help='Batch size')
parser.add_argument('--lr', action='store', type=float, default=1e-4,help='Learning rate')
parser.add_argument('--seed', action='store', type=int, default=12345,help='random seed')

parser.add_argument('--fea', action='store', type=int, default=6, help='# fea')
parser.add_argument('--cla', action='store', type=int, default=3, help='# class')

#parser.add_argument('--r', action='store', type=float, default=0, help='device name')
#parser.add_argument('--k', action='store', type=int, default=0, help='device name')


models = ['GNN1layer', 'GNN2layer', 'GNN3layer','WF1DCNN3FC1Model']
parser.add_argument('--model', choices=models, default=models[0], help='model name')

_StoreAction(option_strings=['--model'], dest='model', nargs=None, const=None, default='GNN1layer', type=None, choices=['GNN1layer', 'GNN2layer', 'GNN3layer', 'WF1DCNN3FC1Model'], help='model name', metavar=None)

In [4]:
# args = parser.parse_args() ## not jupyter
import easydict
args = easydict.EasyDict({
    "config" : 'config_test.yaml' ,
    "output" : '20210721_test',
    "epoch" : 10,
    "seed" : 12345,
    "lr" : 1e-4,
    "batch" : 32,
    "model" : 'GNN1layer',
    "fea" : 6,
    "cla" : 1,
    "device" : 3

   
})

In [5]:
config = yaml.load(open(args.config).read(), Loader=yaml.FullLoader)
config['training']['learningRate'] = float(config['training']['learningRate'])
if args.seed: config['training']['randomSeed1'] = args.seed
if args.epoch: config['training']['epoch'] = args.epoch
if args.lr: config['training']['learningRate'] = args.lr

In [6]:
torch.set_num_threads(os.cpu_count())
if torch.cuda.is_available() and args.device >= 0: torch.cuda.set_device(args.device)
if not os.path.exists('result/' + args.output): os.makedirs('result/' + args.output)

In [7]:
import time
start = time.time()
##### Define dataset instance #####
from dataset.HEPGNNDataset import *
dset = HEPGNNDataset()
for sampleInfo in config['samples']:
    if 'ignore' in sampleInfo and sampleInfo['ignore']: continue
    name = sampleInfo['name']
    dset.addSample(name, sampleInfo['path'], weight=sampleInfo['xsec']/sampleInfo['ngen'])
    dset.setProcessLabel(name, sampleInfo['label'])
dset.initialize()

QCD700 /store/hep/users/yewzzang/4top_QCD_ttbar/data_graph/pt/QCD_weight_210709/HT700*/*.pt
QCD1000 /store/hep/users/yewzzang/4top_QCD_ttbar/data_graph/pt/QCD_weight_210709/HT1000*/*.pt
QCD1500 /store/hep/users/yewzzang/4top_QCD_ttbar/data_graph/pt/QCD_weight_210709/HT1500*/*.pt
QCD2000 /store/hep/users/yewzzang/4top_QCD_ttbar/data_graph/pt/QCD_weight_210709/HT2000*/*.pt
ttbar /store/hep/users/yewzzang/4top_QCD_ttbar/data_graph/pt/ttbar_weight_210719/*.pt
     procName                                           fileName    weight  \
0      QCD700  /store/hep/users/yewzzang/4top_QCD_ttbar/data_...  0.000131   
1      QCD700  /store/hep/users/yewzzang/4top_QCD_ttbar/data_...  0.000131   
2      QCD700  /store/hep/users/yewzzang/4top_QCD_ttbar/data_...  0.000131   
3      QCD700  /store/hep/users/yewzzang/4top_QCD_ttbar/data_...  0.000131   
4      QCD700  /store/hep/users/yewzzang/4top_QCD_ttbar/data_...  0.000131   
...       ...                                                ...       .

In [8]:
lengths = [int(x*len(dset)) for x in config['training']['splitFractions']]
lengths.append(len(dset)-sum(lengths))
torch.manual_seed(config['training']['randomSeed1'])
trnDset, valDset, testDset = torch.utils.data.random_split(dset, lengths)


kwargs = {'num_workers':min(config['training']['nDataLoaders'],os.cpu_count()), 'pin_memory':False}

trnLoader = DataLoader(trnDset, batch_size=args.batch, shuffle=True, **kwargs)
valLoader = DataLoader(valDset, batch_size=args.batch, shuffle=False, **kwargs)
torch.manual_seed(torch.initial_seed())


<torch._C.Generator at 0x7f29d0098430>

In [9]:
##### Define model instance #####
exec('model = '+args.model+'(fea=args.fea, cla=args.cla)')
torch.save(model, os.path.join('result/' + args.output, 'model.pth'))

device = 'cpu'
if args.device >= 0 and torch.cuda.is_available():
    model = model.cuda()
    device = 'cuda'

##### Define optimizer instance #####
optm = optim.Adam(model.parameters(), lr=config['training']['learningRate'])

In [10]:

##### Start training #####
with open('result/' + args.output+'/summary.txt', 'w') as fout:
    fout.write(str(args))
    fout.write('\n\n')
    fout.write(str(model))
    fout.close()

In [11]:
from sklearn.metrics import accuracy_score
from tqdm import tqdm
bestState, bestLoss = {}, 1e9
train = {'loss':[], 'acc':[], 'val_loss':[], 'val_acc':[]}
nEpoch = config['training']['epoch']
for epoch in range(nEpoch):
    model.train()
    trn_loss, trn_acc = 0., 0.
    nProcessed = 0
    optm.zero_grad()
    test = torch.zeros(0).to(device)
    test_l = torch.zeros(0).to(device)
    for i, data in enumerate(tqdm(trnLoader, desc='epoch %d/%d' % (epoch+1, nEpoch))):
        data = data.to(device)
        
        label = data.y.float().to(device=device)
      
            
        scale = data.ss.float().to(device)
        weight = data.ww.float().to(device)
        scaledweight = weight*scale
        scaledweight = torch.abs(scaledweight)
        
#         print(scale, scale.shape,'scale')
#         print(weight, weight.shape, 'weight')
#         print(scaledweight,scaledweight.shape, 'sc')
     
        test = torch.cat((test, scaledweight),0)
        test_l = torch.cat((test_l,label),0)
        pred = model(data)
      
        if args.cla ==3:
            crit = torch.nn.CrossEntropyLoss(reduction='none')
            loss = crit(pred, label)
            loss = loss * scaledweight
            loss.mean().backward()

            optm.step()
            optm.zero_grad()


            ibatch = len(label)
            nProcessed += ibatch

            pred = torch.argmax(pred, 1)
            trn_loss += loss.mean().item()*ibatch
            trn_acc += accuracy_score(label.to('cpu'), pred.to('cpu'), 
                                      sample_weight=scaledweight.to('cpu'))*ibatch
        else:
            crit = torch.nn.BCEWithLogitsLoss(weight=scaledweight) ### sacledweight np.abs()
      
            loss = crit(pred.view(-1), label)
            loss.backward()

            optm.step()
            optm.zero_grad()

            label = label.reshape(-1)
            ibatch = len(label)
            nProcessed += ibatch
            trn_loss += loss.item()*ibatch
            
            trn_acc += accuracy_score(label.to('cpu'), np.where(pred.to('cpu') > 0.5, 1, 0), 
                                      sample_weight=scaledweight.to('cpu'))*ibatch
        
        
        
        
        
    trn_loss /= nProcessed 
    trn_acc  /= nProcessed
    print(trn_loss,'trn_loss')
    print(trn_acc,'trn_acc')
    model.eval()
    val_loss, val_acc = 0., 0.
    nProcessed = 0
    for i, data in enumerate(tqdm(valLoader)):
        
        data = data.to(device)

        label = data.y.float().to(device=device)
        scale = data.ss.float().to(device)
        weight = data.ww.float().to(device)
        scaledweight = weight*scale
        scaledweight = torch.abs(scaledweight)
        test = torch.cat((test, scaledweight),0)
        test_l = torch.cat((test_l,label),0)   
        
        pred = model(data)
        if args.cla == 3:
            crit = nn.CrossEntropyLoss(reduction='none')
            loss = crit(pred, label)
            loss = loss * scaledweight




            ibatch = len(label)
            nProcessed += ibatch

            pred=torch.argmax(pred,1)
            val_loss += loss.mean().item()*ibatch
            val_acc += accuracy_score(label.to('cpu'), pred.to('cpu'), 
                                      sample_weight=scaledweight.to('cpu'))*ibatch
        else:
            crit = torch.nn.BCEWithLogitsLoss(weight=scaledweight)
            loss = crit(pred.view(-1), label)

            label = label.reshape(-1)
            ibatch = len(label)
            nProcessed += ibatch
            val_loss += loss.item()*ibatch
       
            val_acc += accuracy_score(label.to('cpu'), np.where(pred.to('cpu') > 0.5, 1, 0), 
                                      sample_weight=scaledweight.to('cpu'))*ibatch
            
            
            
            
    val_loss /= nProcessed
    val_acc  /= nProcessed
    print(val_loss,'val_loss')
    print(val_acc,'val_acc')
    if bestLoss > val_loss:
        bestState = model.to('cpu').state_dict()
        bestLoss = val_loss
        torch.save(bestState, os.path.join('result/' + args.output, 'weight.pth'))

        model.to(device)

    train['loss'].append(trn_loss)
    train['acc'].append(trn_acc)
    train['val_loss'].append(val_loss)
    train['val_acc'].append(val_acc)

    with open(os.path.join('result/' + args.output, 'train.csv'), 'w') as f:
        writer = csv.writer(f)
        keys = train.keys()
        writer.writerow(keys)
        for row in zip(*[train[key] for key in keys]):
            writer.writerow(row)

bestState = model.to('cpu').state_dict()
torch.save(bestState, os.path.join('result/' + args.output, 'weightFinal.pth'))



epoch 1/10: 100%|██████████| 49740/49740 [06:35<00:00, 125.77it/s]
  0%|          | 0/33160 [00:00<?, ?it/s]

0.7071277023328433 trn_loss
0.796510107851673 trn_acc


 23%|██▎       | 7733/33160 [00:31<01:45, 241.92it/s]


KeyboardInterrupt: 

In [None]:
print(time.time()-start)

In [None]:
scale.shape

In [None]:
weight.shape


In [None]:
test[test_l==1].shape

In [None]:
test[test_l==0].shape

In [None]:
test_l.shape

In [None]:
test[test_l==0].sum()

In [None]:
test[test_l==1].sum()