In [1]:
import os,sys
import numpy as np
from collections import namedtuple
import tqdm
import glob
import math
import random
import inspect
import os.path as osp
from pathlib import Path
import itertools
from itertools import chain
import numpy as np
import pandas as pd
import multiprocessing
import h5py
import matplotlib.pyplot as plt
import numpy as np
import sys, os
from importlib import reload

import torch
import torch.nn as nn
from torch.utils.data import random_split
from torch_geometric.data import Data, DataLoader, DataListLoader
from torch_geometric.nn import EdgeConv, global_mean_pool, DataParallel
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.data import Data,Dataset
from torch_scatter import scatter_mean, scatter
from torch.nn import Sequential as Seq, Linear as Lin, ReLU
from torch_geometric.nn import MetaLayer, EdgeConv, global_mean_pool, DynamicEdgeConv


In [74]:
# Selection thresholds for the event loop further down in this cell.
# deta_jj is compared against column 1 of file_bg["jet_kinematics"]
# (GraphDataset.jet_kin_names lists that column as 'DeltaEtaJJ').
deta_jj = 1.4
# NOTE(review): jPt is not read anywhere in this cell; GraphDataset keeps its
# own self.jPt = 400 for the j1Pt / j2Pt cut.
jPt = 400

def xyze_to_eppt(constituents):
    """Derive (pt, eta, phi) from Cartesian particle momenta.

    Args:
        constituents: 2-D array whose first three columns are px, py, pz.
            A fourth column (E) may be present; it is not used.

    Returns:
        Array of shape [len(constituents), 3] with the columns ordered
        pt, eta, phi (mass omitted). Rows with pt == 0 get eta = 0.
    """
    px = constituents[:, 0]
    py = constituents[:, 1]
    pz = constituents[:, 2]

    # float_power works in float64, so squaring float16 inputs cannot overflow
    pt_squared = np.float_power(px, 2) + np.float_power(py, 2)
    pt = np.sqrt(pt_squared, dtype='float32')

    # pz / pt, keeping the pre-filled zero wherever pt vanishes
    pz_over_pt = np.zeros_like(pt)
    np.divide(pz, pt, out=pz_over_pt, where=pt != 0.)
    eta = np.arcsinh(pz_over_pt, dtype='float32')

    phi = np.arctan2(py, px, dtype='float32')

    return np.stack([pt, eta, phi], axis=1)

side = True       # True: keep events above the deta_jj cut; False: keep events below it
to_train = True   # True: keep only events whose truth_label is 0

# NOTE(review): `file_bg` is not assigned in any cell visible here; this loop
# only runs if an earlier (possibly deleted) cell left an open HDF5-like object
# under that name. Confirm before relying on Restart & Run All.
datas = []
for i_e in range(1000):
    if to_train and file_bg['truth_label'][i_e]!=0:  # train only on QCD
        continue

    delta_eta = file_bg["jet_kinematics"][i_e,1]
    in_region = (delta_eta > deta_jj) if side else (delta_eta < deta_jj)
    if not in_region:
        continue

    for i_j in range(2):  # each event has 2 jets
        pf_cands = np.array(file_bg["jet{}_PFCands".format(i_j+1)][i_e])
        pf_pt_eta_phi = xyze_to_eppt(pf_cands)
        n_particles = int(np.sum(pf_pt_eta_phi[:,0]!=0))  # constituents with pt != 0

        # node features: px, py, pz, E, pt, eta, phi = 7 columns
        particles = np.hstack((pf_cands[0:n_particles,:], pf_pt_eta_phi[0:n_particles,:]))

        # Fully connected directed graph without self-loops, in the same
        # (m, n) order as itertools.product. Built with tensor ops so that
        # jets with 0 or 1 constituents yield an empty [2, 0] edge_index
        # instead of crashing in np.stack([]).
        node_ids = torch.arange(n_particles)
        senders = node_ids.repeat_interleave(n_particles)
        receivers = node_ids.repeat(n_particles)
        not_self = senders != receivers
        edge_index = torch.stack([senders[not_self], receivers[not_self]])

        # save particles as node attributes and target
        x = torch.tensor(particles, dtype=torch.float)
        datas.append(Data(x=x, edge_index=edge_index))

In [3]:
def get_present_constit(x,n):
    """Return the leading ``n`` rows of the 2-D array ``x``, all columns kept."""
    leading_rows = slice(0, n)
    return x[leading_rows, :]

def concat_features(feats_1,feats_2):
    """Join two feature arrays column-wise: ``feats_1`` columns first, then ``feats_2``."""
    left = feats_1[:, :]
    right = feats_2[:, :]
    return np.concatenate((left, right), axis=1)

class GraphDataset(Dataset):  # inherits from the PyTorch Geometric Dataset (not the plain PyTorch one)
    """Fully connected jet-constituent graphs built from one HDF5 file of dijet events.

    All events are read and converted to per-jet numpy arrays once, in
    ``__init__``; ``get`` only builds the ``Data`` object for one jet.

    Graph order is: every selected event's jet 1, then every selected event's
    jet 2. Because the dataset length is capped at ``n_jets`` and at most
    ``n_jets`` events are kept, jet-2 graphs are only reached when fewer than
    ``n_jets`` events pass the selection.
    """

    def __init__(self, root, transform=None, pre_transform=None,
                 n_events=-1,n_jets=10e3, side_reg=1, features='xyzeptep',n_proc=1):
        """
        Initialize parameters of graph dataset
        Args:
            root (str): dir path
            transform, pre_transform: forwarded unchanged to the PyG ``Dataset`` constructor
            n_events (int): how many events to process (-1=all in a file (there is a max))
            n_jets (int) : how many total jets to use (also the cap on selected events)
            side_reg (bool): true -> DeltaEtaJJ side region (used for training),
                             false -> DeltaEtaJJ below the cut (testing on signal)
            n_proc (int): number of processes to split into (only used to derive chunk_size)
            features (str): (px, py, pz) or relative (pt, eta, phi); stored, not read by this class
        """
        max_events = int(1.1e6)
        self.n_events = max_events if n_events==-1 else n_events
        self.n_jets = int(n_jets)
        self.side_reg = side_reg
        self.n_proc = n_proc
        self.chunk_size = self.n_events // self.n_proc
        self.features = features
        self.dEtaJJ = 1.4
        self.jPt = 400
        self.jet_kin_names = ['mJJ', 'DeltaEtaJJ', 'j1Pt', 'j1Eta', 'j1Phi',\
                                        'j1M', 'j2Pt', 'j2Eta', 'j2Phi', 'j2M', 'j3Pt', 'j3Eta', 'j3Phi', 'j3M']
        self.pf_kin_names = ['px','py','pz','E']
        # Read everything up front so get() never touches the file.
        self.pf_cands, self.jet_prop = self.read_events()

        super(GraphDataset, self).__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        """Base names of the ``*.root`` files found in ``self.raw_dir`` (also caches the full paths)."""
        self.input_files = sorted(glob.glob(self.raw_dir+'/*.root'))
        return [osp.basename(f) for f in self.input_files]

    def __len__(self):
        """Number of graphs served: ``n_jets``, capped at the number of jets actually read."""
        # Without the cap, a selection that yields fewer than n_jets jets made
        # return_inmemory_data / indexing run off the end of self.pf_cands.
        return min(self.n_jets, len(self.pf_cands))

    def len(self):
        """PyG's abstract length hook; same value as ``len(self)``."""
        return self.__len__()

    def xyze_to_ptep(self,constituents):
        """Convert per-jet constituents from (px, py, pz, E) to (pt, eta, phi).

        Args:
            constituents: 3-D array indexed [jet, constituent, feature], with
                the feature axis ordered as ``self.pf_kin_names``.

        Returns:
            float32 array with the same leading two axes and a last axis of
            size 3 ordered pt, eta, phi (mass omitted). Entries with pt == 0
            get eta = 0.
        """
        PX = self.pf_kin_names.index('px')
        PY = self.pf_kin_names.index('py')
        PZ = self.pf_kin_names.index('pz')
        # float_power works in float64, so squaring float16 inputs cannot overflow
        pt = np.sqrt(np.float_power(constituents[:,:,PX], 2) + np.float_power(constituents[:,:,PY], 2), dtype='float32')
        eta = np.arcsinh(np.divide(constituents[:,:,PZ], pt, out=np.zeros_like(pt), where=pt!=0.), dtype='float32')
        phi = np.arctan2(constituents[:,:,PY], constituents[:,:,PX], dtype='float32')
        return np.stack([pt, eta, phi], axis=2)

    def read_events(self):
        """Load, select and flatten the events of the training file.

        Returns:
            tuple ``(pf_cands, jet_prop)`` where ``pf_cands`` is a list with one
            [n_constituents, 7] array per jet (px, py, pz, E, pt, eta, phi) and
            ``jet_prop`` is an array with one row per jet holding
            (n_constituents, M, Pt, Eta, Phi, truth label). Both list all
            jet-1 entries first, then all jet-2 entries.
        """
        # Data samples
        DATA_PATH = '/eos/cms/store/group/phys_b2g/CASE/h5_files/full_run2/BB_UL_MC_small_v2/'
        TRAIN_NAME = 'BB_batch0.h5'
        filename_bg = DATA_PATH + TRAIN_NAME

        kin_col = self.jet_kin_names.index
        # Context manager: the handle is closed once the arrays are in memory.
        with h5py.File(filename_bg, 'r') as in_file:
            jet_kin = np.array(in_file["jet_kinematics"])
            truth = np.array(in_file["truth_label"])

            j1Pt_mask = jet_kin[:, kin_col('j1Pt')] > self.jPt
            j2Pt_mask = jet_kin[:, kin_col('j2Pt')] > self.jPt
            full_mask = j1Pt_mask & j2Pt_mask
            delta_eta = jet_kin[:, kin_col('DeltaEtaJJ')]
            if self.side_reg:
                full_mask = full_mask & (delta_eta > self.dEtaJJ)
            else:
                full_mask = full_mask & (delta_eta < self.dEtaJJ)

            # Apply mask on jet kinematics, truth and pf cands
            jet_kin = jet_kin[full_mask][0:self.n_jets]
            truth = truth[full_mask][0:self.n_jets]
            jet_const = [np.array(in_file[name])[full_mask][0:self.n_jets]
                         for name in ("jet1_PFCands", "jet2_PFCands")]

        energy_col = self.pf_kin_names.index('E')
        n_jet_feats = 6
        pf_out_list = []
        jet_prop_list = []

        for i_j in range(2):  # each event has 2 jets
            pf_xyze = jet_const[i_j]
            pf_ptep = self.xyze_to_ptep(pf_xyze)
            # Count constituents with E != 0; the first n rows are then kept,
            # which presumes zero-padding sits at the end of each jet -- TODO confirm.
            n_particles = np.sum(pf_xyze[:,:,energy_col]!=0,axis=1)
            pf_xyze_out = list(map(get_present_constit,pf_xyze,n_particles))
            pf_ptep_out = list(map(get_present_constit,pf_ptep,n_particles))
            pf_tot_out = list(map(concat_features,pf_xyze_out,pf_ptep_out))
            pf_out_list.extend(pf_tot_out)

            jet_prop = np.zeros((len(pf_tot_out),n_jet_feats))
            jet_prop[:,0] = n_particles
            for i_f,f_name in enumerate('M,Pt,Eta,Phi'.split(',')):
                jet_prop[:,i_f+1] = jet_kin[:, kin_col('j{}{}'.format(i_j+1,f_name))]
            jet_prop[:,n_jet_feats-1] = truth[:,0]
            jet_prop_list.append(jet_prop)

        # return list of pf particles, and array of global jet properties
        return pf_out_list, np.vstack(jet_prop_list)

    @staticmethod
    def _fully_connected_edge_index(n_nodes):
        """Directed edges between every ordered pair of distinct nodes, as a [2, E] long tensor.

        Edge order matches ``itertools.product(range(n), range(n))`` with the
        diagonal removed. For ``n_nodes`` < 2 the result is an empty [2, 0]
        tensor.
        """
        node_ids = torch.arange(n_nodes)
        senders = node_ids.repeat_interleave(n_nodes)
        receivers = node_ids.repeat(n_nodes)
        not_self = senders != receivers
        return torch.stack([senders[not_self], receivers[not_self]])

    def get(self,idx):
        """Build the graph for jet ``idx``.

        Returns:
            ``Data`` with ``x`` = constituent features, ``edge_index`` = fully
            connected edges without self-loops, and ``u`` = the jet's row of
            ``self.jet_prop`` with a leading axis of size 1.
        """
        constituents = self.pf_cands[idx]
        edge_index = self._fully_connected_edge_index(constituents.shape[0])
        # save particles as node attributes and target
        x = torch.tensor(constituents, dtype=torch.float)
        u = torch.tensor(self.jet_prop[idx,:], dtype=torch.float)
        return Data(x=x, edge_index=edge_index,u=torch.unsqueeze(u, 0))

    def return_inmemory_data(self):
        """Return a list holding the graph of every jet index in ``range(len(self))``."""
        return [self.get(i_evt) for i_evt in range(len(self))]
        
        
# Root directory handed to GraphDataset (PyG resolves raw_dir relative to it).
# NOTE(review): absolute, user-specific path; the event file itself is
# hard-coded inside GraphDataset.read_events.
data_dir = '/eos/user/n/nchernya/MLHEP/AnomalyDetection/ADgvae/output_models/pytroch/'
dataset = GraphDataset(root=data_dir,n_jets=1000)
    
# True: index the Dataset lazily through DataLoaders over a train/val split.
# False: materialise every graph up front as a plain list (in_memory_datas).
use_generator = False
if use_generator:
    validation_split = 0.2
    dataset_size = len(dataset)
    # NOTE(review): `indices` is not read again in this cell.
    indices = list(range(dataset_size))
    # With 2 or fewer graphs, force one validation graph instead of int(floor(0.2 * size)).
    if dataset_size > 2:
        split = int(np.floor(validation_split * dataset_size))
    else: 
        split = 1
    print(dataset_size,split)
    random_seed= 1001

    # Seeded generator so the same split is reproduced on every run.
    train_subset, val_subset = torch.utils.data.random_split(dataset, [dataset_size - split, split],
                                                             generator=torch.Generator().manual_seed(random_seed))
    print("train subset dim:", len(train_subset))
    print("validation subset dim", len(val_subset))
    dataloaders = {
        'train':  DataLoader(train_subset, batch_size=128, shuffle=True),
        'val':   DataLoader(val_subset, batch_size=128, shuffle=True)
    }
    print("train_dataloader dim:", len(dataloaders['train']))
    print("val dataloader dim:", len(dataloaders['val']))
else : 
    # List of Data objects; a later cell standardizes it in place and wraps it in a DataLoader.
    in_memory_datas = dataset.return_inmemory_data() 

In [4]:
class Standardizer:
    """Per-column z-score scaler for 2-D torch tensors (rows = samples, columns = features)."""

    def __init__(self):
        self.mean = None
        self.std = None

    def fit(self, data):
        """
        Compute the per-column mean and standard deviation.
        :param data: torch tensor
        """
        self.mean = torch.mean(data, dim=0)
        spread = torch.std(data, dim=0)
        # A constant column has std == 0 and a single-row input has std == NaN;
        # either would turn the whole column into NaN/inf in transform().
        # Use a unit scale there so such columns are only centred.
        self.std = torch.where(spread > 0, spread, torch.ones_like(spread))

    def transform(self, data):
        """Return ``(data - mean) / std`` using the statistics from ``fit``."""
        return (data - self.mean) / self.std

    def inverse_transform(self, data, log_pt=False, pt_index=0):
        """
        Undo ``transform``.
        :param data: torch tensor
        :param log_pt: undo log transformation on pt
        :param pt_index: column holding pt (only read when log_pt is True)
        :return: new tensor in the original feature scale
        """
        inverse = (data * self.std) + self.mean
        if log_pt:
            # standardize() applies a natural log(pt + 1); its inverse is
            # exp(x) - 1, not 10**x - 1.
            inverse[:, pt_index] = torch.expm1(inverse[:, pt_index])
        return inverse

def standardize(train_dataset,log_pt=False, pt_index=0):
    """
    standardize dataset in place and return scaler for inversion
    :param train_dataset: list of Data objects (each exposing a 2-D tensor ``x``)
    :param log_pt: log pt before standardization
    :param pt_index: column of ``x`` holding pt (only read when log_pt is True).
        The default 0 is kept for backward compatibility; the graphs built by
        GraphDataset store pt in column 4.
    :return scaler: fitted Standardizer
    """
    def _features(x):
        # Features as they enter the scaler. The log has to be applied both
        # when fitting and when transforming; previously only the fit saw it.
        if not log_pt:
            return x
        logged = x.clone()
        logged[:, pt_index] = torch.log1p(logged[:, pt_index])
        return logged

    train_x = torch.cat([_features(d.x) for d in train_dataset])

    scaler = Standardizer()
    scaler.fit(train_x)
    for d in train_dataset:
        # slice-assign so the tensor held by each Data object is updated in place
        d.x[:,:] = scaler.transform(_features(d.x))
    return scaler

In [5]:
# standardize() writes into d.x of every graph it is given, so it needs the
# materialised list; the author notes it probably does not work on the lazy
# Dataset subsets (train_subset), hence the disabled call:
#scaler = standardize(train_subset) # I dont think that this works for the dataset implementation as it is done now
# Standardize the in-memory graphs in place; keep the scaler to invert the scaling later.
scaler = standardize(in_memory_datas) 

# Rebinds `dataloaders` to a train-only loader over the standardized graphs.
dataloaders = {
    'train':  DataLoader(in_memory_datas, batch_size=128, shuffle=True)
    }
print("train_dataloader dim:", len(dataloaders['train']))

# Debug leftover: inspect the per-jet properties of the first graph.
#dataset.get(0).u

train_dataloader dim: 8


In [6]:
"""
    Model definitions.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.data import Data
from torch_scatter import scatter_mean, scatter
from torch.nn import Sequential as Seq, Linear as Lin, ReLU
from torch_geometric.nn import MetaLayer, EdgeConv, global_mean_pool, DynamicEdgeConv


# GNN AE using EdgeConv (mean aggregation graph operation). Basic GAE model.
class EdgeNet(nn.Module):
    """Graph autoencoder made of two EdgeConv layers (encoder, then decoder).

    Args:
        input_dim: number of features per node in ``data.x``.
        output_dim: number of features per node in the reconstruction.
        big_dim: width of the two hidden Linear layers inside each MLP.
        hidden_dim: number of features per node in the latent representation.
        aggr: aggregation mode forwarded to both EdgeConv layers.
    """
    def __init__(self, input_dim=7, output_dim=4, big_dim=32, hidden_dim=2, aggr='mean'):
        super(EdgeNet, self).__init__()
        # Edge MLP of the encoder. Its input width is 2*input_dim, presumably
        # because EdgeConv feeds it a pair of node feature vectors -- see the
        # PyG EdgeConv documentation. The trailing ReLU keeps the latent
        # features non-negative.
        encoder_nn = nn.Sequential(nn.Linear(2*(input_dim), big_dim),
                               nn.ReLU(),
                               nn.Linear(big_dim, big_dim),
                               nn.ReLU(),
                               nn.Linear(big_dim, hidden_dim),
                               nn.ReLU(),
        )
        
        # Edge MLP of the decoder: 2*hidden_dim -> output_dim, with no
        # activation on the last layer.
        decoder_nn = nn.Sequential(nn.Linear(2*(hidden_dim), big_dim),
                               nn.ReLU(),
                               nn.Linear(big_dim, big_dim),
                               nn.ReLU(),
                               nn.Linear(big_dim, output_dim)
        )
        
        # Normalizes the raw node features before the encoder.
        self.batchnorm = nn.BatchNorm1d(input_dim)

        self.encoder = EdgeConv(nn=encoder_nn,aggr=aggr)
        self.decoder = EdgeConv(nn=decoder_nn,aggr=aggr)

    def forward(self, data):
        """Return the reconstructed node features for ``data``.

        Reads ``data.x`` and ``data.edge_index``; the same edge_index is used
        by the encoder and the decoder.
        """
        x = self.batchnorm(data.x)
        x = self.encoder(x,data.edge_index)
        x = self.decoder(x,data.edge_index)
        return x

In [7]:
def train(model, optimizer, loader, total, batch_size, loss_ftn_obj):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: network to optimise; switched to train mode.
        optimizer: optimizer stepped once per batch.
        loader: iterable yielding batches accepted by ``forward_loss``.
        total: number of samples in the epoch (progress bar only).
        batch_size: samples per batch (progress bar only).
        loss_ftn_obj: loss wrapper forwarded to ``forward_loss``.

    Returns:
        float: sum of the batch losses divided by the number of batches, or
        NaN when ``loader`` yields no batch.

    Note:
        Reads the module-level ``device``.
    """
    model.train()

    # tqdm needs a whole number of batches; total/batch_size produced a
    # fractional total (bars like "3/7.8125") that the last batch overran.
    n_batches = math.ceil(total / batch_size) if batch_size else None

    sum_loss = 0.
    n_done = 0
    t = tqdm.tqdm(loader, total=n_batches)
    for data in t:
        optimizer.zero_grad()

        batch_loss, batch_output = forward_loss(model, data, loss_ftn_obj, device, multi_gpu=False)
        batch_loss.backward()
        optimizer.step()

        batch_loss = batch_loss.item()
        sum_loss += batch_loss
        n_done += 1
        t.set_description('train loss = %.7f' % batch_loss)
        t.refresh() # to show immediately the update

    # An empty loader used to fail with an UnboundLocalError on the loop index.
    if n_done == 0:
        return float('nan')
    return sum_loss / n_done


# helper to perform correct loss
def forward_loss(model, data, loss_ftn_obj, device, multi_gpu=False):
    """Run ``model`` on ``data`` and evaluate the loss selected by ``loss_ftn_obj.name``.

    Args:
        model: callable network.
        data: a batch object exposing ``.to(device)`` and ``.x`` or, when
            ``multi_gpu`` is True, a list of such objects.
        loss_ftn_obj: object with ``name`` (str) and ``loss_ftn`` (callable).
        device: device the batch / target is moved to.
        multi_gpu: if True, ``data`` is passed to the model unmoved, as a list.

    Returns:
        tuple ``(batch_loss, batch_output)``.

    Dispatch on ``loss_ftn_obj.name``:
        * contains 'emd_loss', or equals 'chamfer_loss' / 'hungarian_loss':
          ``loss_ftn(output, data.x, data.batch)``
        * 'emd_in_forward': the model returns ``(output, loss)``; the loss is averaged
        * 'vae_loss': the model returns ``(output, mu, log_var)``
        * anything else: ``loss_ftn(output, target)``
    """
    if not multi_gpu:
        data = data.to(device)

    def _target():
        # Reconstruction target: the input node features, concatenated over
        # the list in the multi-GPU case.
        y = torch.cat([d.x for d in data]).to(device) if multi_gpu else data.x
        return y.contiguous()

    loss_name = loss_ftn_obj.name
    if 'emd_loss' in loss_name or loss_name == 'chamfer_loss' or loss_name == 'hungarian_loss':
        batch_output = model(data)
        if multi_gpu:
            # Batch is not imported at file level; import it where it is needed
            # so this path no longer dies with a NameError.
            from torch_geometric.data import Batch
            data = Batch.from_data_list(data).to(device)
        batch_loss = loss_ftn_obj.loss_ftn(batch_output, data.x, data.batch)

    elif loss_name == 'emd_in_forward':
        # Keep the model output too: it was discarded before, so the return
        # statement failed on an unassigned batch_output.
        batch_output, batch_loss = model(data)
        batch_loss = batch_loss.mean()

    elif loss_name == 'vae_loss':
        batch_output, mu, log_var = model(data)
        batch_loss = loss_ftn_obj.loss_ftn(batch_output, _target(), mu, log_var)

    else:
        batch_output = model(data)
        batch_loss = loss_ftn_obj.loss_ftn(batch_output, _target())

    return batch_loss, batch_output

In [8]:
# Fix the torch RNG seed before the model is built in a later cell.
torch.manual_seed(0)
# First CUDA device when available, otherwise CPU; train() reads this global.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Multi-GPU mode is switched off; the automatic check is kept in the trailing comment.
multi_gpu = False #torch.cuda.device_count()>1

In [9]:
def xyze_to_ptetaphi_torch(y):
    """Extend predicted four-momenta with derived pt, eta, phi.

    Args:
        y: tensor of shape [N, 4] with columns px, py, pz, E.

    Returns:
        Tensor of shape [N, 7] with columns px, py, pz, relu(E), relu(pt),
        eta, phi. eta is 0 wherever pt < 10e-5.

    The forward values match the plain formulas. Rows with px == py == 0 are
    routed around sqrt / division / atan2, whose derivatives are infinite or
    0/0 there, so back-propagation yields zeros instead of NaN for them.
    """
    PX, PY, PZ, E = range(4)
    px, py, pz = y[:, PX], y[:, PY], y[:, PZ]

    zeros = torch.zeros_like(px)
    ones = torch.ones_like(px)

    pt_sq = px * px + py * py
    on_axis = pt_sq == 0
    # Inner where(): feed a harmless 1 to sqrt on masked rows so that no
    # inf/NaN local gradient exists for the outer mask to multiply by zero.
    pt = torch.where(on_axis, zeros, torch.sqrt(torch.where(on_axis, ones, pt_sq)))

    soft = pt < 10e-5
    eta = torch.asinh(torch.where(soft, zeros, pz / torch.where(soft, ones, pt)))
    phi = torch.where(on_axis, zeros, torch.atan2(py, torch.where(on_axis, ones, px)))

    energy_trimmed = torch.relu(y[:, E])  # trimming E
    pt_trimmed = torch.relu(pt)           # trimming pt
    return torch.stack((px, py, pz, energy_trimmed, pt_trimmed, eta, phi), dim=1)


class LossFunction:
    """Binds one of this class's loss methods, chosen by name.

    Attributes:
        name: the requested loss name.
        loss_ftn: the bound method registered under that name.
        device: the device handed to the constructor (stored only).
    """

    def __init__(self, lossname, device=torch.device('cuda:0')):
        # Resolve the method first: an unknown name raises AttributeError
        # before any attribute is set.
        self.loss_ftn = getattr(self, lossname)
        self.name = lossname
        self.device = device

    def mse(self, x, y):
        """Mean squared error between ``x`` and ``y``, averaged over all elements."""
        return F.mse_loss(input=x, target=y, reduction='mean')

    def mse_coordinates(self, y,x): # argument order follows the (output, input) convention
        """MSE between the input features and the prediction extended with pt, eta, phi.

        ``y`` is the prediction (px, py, pz, E); ``x`` is the input
        (px, py, pz, E, pt, eta, phi). pt, eta and phi are derived from ``y``
        rather than predicted.
        """
        return self.mse(x, xyze_to_ptetaphi_torch(y))
        

In [10]:
# Loss: plain MSE between the reconstruction and all input features.
# The disabled alternative ('mse_coordinates') derives pt/eta/phi from a
# 4-feature prediction instead.
#loss_ftn_obj = LossFunction('mse_coordinates', device=device)
loss_ftn_obj = LossFunction('mse', device=device)

# model
input_dim = 7
# 7 so the output has one column per input feature; the trailing "#4" is the
# setting that goes with 'mse_coordinates'.
output_dim = 7#4
big_dim = 32
hidden_dim = 2
model = EdgeNet(input_dim=input_dim,output_dim=output_dim, big_dim=big_dim, hidden_dim=hidden_dim)

# Note: 10e-3 is 0.01.
optimizer = torch.optim.Adam(model.parameters(), lr = 10e-3)
# NOTE(review): no scheduler.step(...) call appears in the visible training
# loop, so this scheduler currently has no effect -- confirm whether intended.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=4, threshold=1e-6)

model.to(device)


EdgeNet(
  (batchnorm): BatchNorm1d(7, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (encoder): EdgeConv(nn=Sequential(
    (0): Linear(in_features=14, out_features=32, bias=True)
    (1): ReLU()
    (2): Linear(in_features=32, out_features=32, bias=True)
    (3): ReLU()
    (4): Linear(in_features=32, out_features=2, bias=True)
    (5): ReLU()
  ))
  (decoder): EdgeConv(nn=Sequential(
    (0): Linear(in_features=4, out_features=32, bias=True)
    (1): ReLU()
    (2): Linear(in_features=32, out_features=32, bias=True)
    (3): ReLU()
    (4): Linear(in_features=32, out_features=7, bias=True)
  ))
)

In [11]:
# Training loop: one call to train() per epoch, collecting the mean batch loss.
n_epochs = 80
# NOTE(review): stale_epochs is never updated or read in this cell, and the
# initial `loss` value is overwritten by the first epoch.
stale_epochs = 0
loss = 999999
train_losses = []
for epoch in range(0, n_epochs):
    #loss = train(model, optimizer, loader, len(datas), 128, loss_ftn_obj)
    # total = number of graphs in the loader's dataset; used only to size the progress bar
    loss = train(model, optimizer, dataloaders['train'], len(dataloaders['train'].dataset), dataloaders['train'].batch_size, loss_ftn_obj)
    train_losses.append(loss)
    print('Epoch: {:02d}, Training Loss:   {:.4f}'.format(epoch, loss))

train loss = 0.7475386: : 8it [00:01,  6.43it/s]                          
train loss = 0.7425943:  38%|███▊      | 3/7.8125 [00:00<00:00, 21.96it/s]

Epoch: 00, Training Loss:   0.9179


train loss = 0.7787465: : 8it [00:00, 21.64it/s]                          
train loss = 0.7150579:  38%|███▊      | 3/7.8125 [00:00<00:00, 23.34it/s]

Epoch: 01, Training Loss:   0.7511


train loss = 0.5570742: : 8it [00:00, 24.15it/s]                          
train loss = 0.4899089:  38%|███▊      | 3/7.8125 [00:00<00:00, 23.60it/s]

Epoch: 02, Training Loss:   0.6515


train loss = 0.4190300: : 8it [00:00, 23.99it/s]                          
train loss = 0.3580547:  38%|███▊      | 3/7.8125 [00:00<00:00, 24.01it/s]

Epoch: 03, Training Loss:   0.4946


train loss = 0.2555426: : 8it [00:00, 23.93it/s]                          
train loss = 0.3751278:  38%|███▊      | 3/7.8125 [00:00<00:00, 22.82it/s]

Epoch: 04, Training Loss:   0.3957


train loss = 0.3283700: : 8it [00:00, 23.50it/s]                          
train loss = 0.3020281:  38%|███▊      | 3/7.8125 [00:00<00:00, 23.09it/s]

Epoch: 05, Training Loss:   0.3287


train loss = 0.2315729: : 8it [00:00, 22.49it/s]                          
train loss = 0.1960796:  38%|███▊      | 3/7.8125 [00:00<00:00, 24.07it/s]

Epoch: 06, Training Loss:   0.2926


train loss = 0.2699398: : 8it [00:00, 23.99it/s]                          
train loss = 0.1949106:  38%|███▊      | 3/7.8125 [00:00<00:00, 22.15it/s]

Epoch: 07, Training Loss:   0.2488


train loss = 0.2019417: : 8it [00:00, 22.41it/s]                          
train loss = 0.2105487:  38%|███▊      | 3/7.8125 [00:00<00:00, 23.74it/s]

Epoch: 08, Training Loss:   0.2023


train loss = 0.1574364: : 8it [00:00, 23.71it/s]                          
train loss = 0.1479566:  38%|███▊      | 3/7.8125 [00:00<00:00, 24.21it/s]

Epoch: 09, Training Loss:   0.1830


train loss = 0.1403333: : 8it [00:00, 24.18it/s]                          
train loss = 0.0840258:  38%|███▊      | 3/7.8125 [00:00<00:00, 21.24it/s]

Epoch: 10, Training Loss:   0.1557


train loss = 0.1533939: : 8it [00:00, 22.16it/s]                          
train loss = 0.1158177:  38%|███▊      | 3/7.8125 [00:00<00:00, 23.63it/s]

Epoch: 11, Training Loss:   0.1320


train loss = 0.0904480: : 8it [00:00, 23.95it/s]                          
train loss = 0.1025347:  38%|███▊      | 3/7.8125 [00:00<00:00, 21.55it/s]

Epoch: 12, Training Loss:   0.1196


train loss = 0.1160549: : 8it [00:00, 22.89it/s]                          
train loss = 0.1129461:  38%|███▊      | 3/7.8125 [00:00<00:00, 24.93it/s]

Epoch: 13, Training Loss:   0.1148


train loss = 0.1016909: : 8it [00:00, 24.41it/s]                          
train loss = 0.1011642:  38%|███▊      | 3/7.8125 [00:00<00:00, 23.76it/s]

Epoch: 14, Training Loss:   0.1150


train loss = 0.0907081: : 8it [00:00, 23.06it/s]                          
train loss = 0.0983043:  38%|███▊      | 3/7.8125 [00:00<00:00, 22.69it/s]

Epoch: 15, Training Loss:   0.1068


train loss = 0.0970521: : 8it [00:00, 23.65it/s]                          
train loss = 0.0857799:  38%|███▊      | 3/7.8125 [00:00<00:00, 23.69it/s]

Epoch: 16, Training Loss:   0.1041


train loss = 0.1048030: : 8it [00:00, 23.42it/s]                          
train loss = 0.0797524:  38%|███▊      | 3/7.8125 [00:00<00:00, 25.79it/s]

Epoch: 17, Training Loss:   0.0953


train loss = 0.0893670: : 8it [00:00, 25.19it/s]                          
train loss = 0.0710639:  38%|███▊      | 3/7.8125 [00:00<00:00, 24.20it/s]

Epoch: 18, Training Loss:   0.0882


train loss = 0.1018090: : 8it [00:00, 24.25it/s]                          
train loss = 0.0858047:  38%|███▊      | 3/7.8125 [00:00<00:00, 24.70it/s]

Epoch: 19, Training Loss:   0.0828


train loss = 0.0585775: : 8it [00:00, 25.28it/s]                          
train loss = 0.0708721:  38%|███▊      | 3/7.8125 [00:00<00:00, 24.51it/s]

Epoch: 20, Training Loss:   0.0765


train loss = 0.0677299: : 8it [00:00, 25.21it/s]                          
train loss = 0.0800082:  38%|███▊      | 3/7.8125 [00:00<00:00, 23.92it/s]

Epoch: 21, Training Loss:   0.0786


train loss = 0.0624596: : 8it [00:00, 24.86it/s]                          
train loss = 0.0855809:  38%|███▊      | 3/7.8125 [00:00<00:00, 23.79it/s]

Epoch: 22, Training Loss:   0.0751


train loss = 0.0916473: : 8it [00:00, 25.03it/s]                          
train loss = 0.0642789:  38%|███▊      | 3/7.8125 [00:00<00:00, 23.52it/s]

Epoch: 23, Training Loss:   0.0837


train loss = 0.1198972: : 8it [00:00, 23.96it/s]                          
train loss = 0.0844235:  38%|███▊      | 3/7.8125 [00:00<00:00, 22.43it/s]

Epoch: 24, Training Loss:   0.0838


train loss = 0.0559634: : 8it [00:00, 22.84it/s]                          
train loss = 0.0830481:  38%|███▊      | 3/7.8125 [00:00<00:00, 25.01it/s]

Epoch: 25, Training Loss:   0.0753


train loss = 0.0779901: : 8it [00:00, 25.14it/s]                          
train loss = 0.0699527:  38%|███▊      | 3/7.8125 [00:00<00:00, 21.53it/s]

Epoch: 26, Training Loss:   0.0750


train loss = 0.0624757: : 8it [00:00, 22.68it/s]                          
train loss = 0.0514325:  38%|███▊      | 3/7.8125 [00:00<00:00, 25.31it/s]

Epoch: 27, Training Loss:   0.0766


train loss = 0.0624259: : 8it [00:00, 25.20it/s]                          
train loss = 0.0502170:  38%|███▊      | 3/7.8125 [00:00<00:00, 24.22it/s]

Epoch: 28, Training Loss:   0.0740


train loss = 0.0683164: : 8it [00:00, 23.32it/s]                          
train loss = 0.0808447:  38%|███▊      | 3/7.8125 [00:00<00:00, 26.11it/s]

Epoch: 29, Training Loss:   0.0758


train loss = 0.0935838: : 8it [00:00, 25.16it/s]                          
train loss = 0.0694738:  38%|███▊      | 3/7.8125 [00:00<00:00, 25.59it/s]

Epoch: 30, Training Loss:   0.0810


train loss = 0.0901861: : 8it [00:00, 25.21it/s]                          
train loss = 0.0768575:  38%|███▊      | 3/7.8125 [00:00<00:00, 25.35it/s]

Epoch: 31, Training Loss:   0.0768


train loss = 0.0932597: : 8it [00:00, 25.18it/s]                          
train loss = 0.0678616:  38%|███▊      | 3/7.8125 [00:00<00:00, 24.69it/s]

Epoch: 32, Training Loss:   0.0800


train loss = 0.0943939: : 8it [00:00, 25.18it/s]                          
train loss = 0.0754122:  38%|███▊      | 3/7.8125 [00:00<00:00, 24.38it/s]

Epoch: 33, Training Loss:   0.0753


train loss = 0.0963378: : 8it [00:00, 25.10it/s]                          
train loss = 0.0537468:  38%|███▊      | 3/7.8125 [00:00<00:00, 25.48it/s]

Epoch: 34, Training Loss:   0.0704


train loss = 0.0722131: : 8it [00:00, 25.11it/s]                          
train loss = 0.0610189:  38%|███▊      | 3/7.8125 [00:00<00:00, 24.07it/s]

Epoch: 35, Training Loss:   0.0629


train loss = 0.0501374: : 8it [00:00, 25.02it/s]                          
train loss = 0.0785490:  38%|███▊      | 3/7.8125 [00:00<00:00, 24.15it/s]

Epoch: 36, Training Loss:   0.0637


train loss = 0.0858953: : 8it [00:00, 24.46it/s]                          
train loss = 0.0989793:  38%|███▊      | 3/7.8125 [00:00<00:00, 23.44it/s]

Epoch: 37, Training Loss:   0.0733


train loss = 0.0605864: : 8it [00:00, 21.74it/s]                          
train loss = 0.0528818:  38%|███▊      | 3/7.8125 [00:00<00:00, 21.79it/s]

Epoch: 38, Training Loss:   0.0789


train loss = 0.0605728: : 8it [00:00, 21.91it/s]                          
train loss = 0.0548394:  38%|███▊      | 3/7.8125 [00:00<00:00, 21.40it/s]

Epoch: 39, Training Loss:   0.0670


train loss = 0.0859439: : 8it [00:00, 21.91it/s]                          
train loss = 0.0734868:  38%|███▊      | 3/7.8125 [00:00<00:00, 21.71it/s]

Epoch: 40, Training Loss:   0.0611


train loss = 0.0515118: : 8it [00:00, 22.03it/s]                          
train loss = 0.0520447:  38%|███▊      | 3/7.8125 [00:00<00:00, 22.41it/s]

Epoch: 41, Training Loss:   0.0760


train loss = 0.1043358: : 8it [00:00, 21.97it/s]                          
train loss = 0.0478440:  38%|███▊      | 3/7.8125 [00:00<00:00, 24.91it/s]

Epoch: 42, Training Loss:   0.0717


train loss = 0.0788044: : 8it [00:00, 25.14it/s]                          
train loss = 0.0610387:  38%|███▊      | 3/7.8125 [00:00<00:00, 23.86it/s]

Epoch: 43, Training Loss:   0.0641


train loss = 0.0675502: : 8it [00:00, 25.16it/s]                          
train loss = 0.0473323:  38%|███▊      | 3/7.8125 [00:00<00:00, 25.41it/s]

Epoch: 44, Training Loss:   0.0636


train loss = 0.0469295: : 8it [00:00, 25.03it/s]                          
train loss = 0.0462295:  38%|███▊      | 3/7.8125 [00:00<00:00, 24.73it/s]

Epoch: 45, Training Loss:   0.0533


train loss = 0.0437600: : 8it [00:00, 24.25it/s]                          
train loss = 0.0517046:  38%|███▊      | 3/7.8125 [00:00<00:00, 25.41it/s]

Epoch: 46, Training Loss:   0.0521


train loss = 0.0571039: : 8it [00:00, 22.81it/s]                          
train loss = 0.0823893:  38%|███▊      | 3/7.8125 [00:00<00:00, 22.77it/s]

Epoch: 47, Training Loss:   0.0528


train loss = 0.0577339: : 8it [00:00, 24.16it/s]                          
train loss = 0.0588257:  38%|███▊      | 3/7.8125 [00:00<00:00, 25.12it/s]

Epoch: 48, Training Loss:   0.0571


train loss = 0.0697363: : 8it [00:00, 24.26it/s]                          
train loss = 0.0493295:  38%|███▊      | 3/7.8125 [00:00<00:00, 23.91it/s]

Epoch: 49, Training Loss:   0.0612


train loss = 0.0552601: : 8it [00:00, 22.21it/s]                          
train loss = 0.0676070:  38%|███▊      | 3/7.8125 [00:00<00:00, 25.55it/s]

Epoch: 50, Training Loss:   0.0603


train loss = 0.0514778: : 8it [00:00, 25.12it/s]                          
train loss = 0.0388598:  38%|███▊      | 3/7.8125 [00:00<00:00, 24.77it/s]

Epoch: 51, Training Loss:   0.0621


train loss = 0.0512487: : 8it [00:00, 25.28it/s]                          
train loss = 0.0445996:  38%|███▊      | 3/7.8125 [00:00<00:00, 24.67it/s]

Epoch: 52, Training Loss:   0.0628


train loss = 0.0500326: : 8it [00:00, 25.24it/s]                          
train loss = 0.0370677:  38%|███▊      | 3/7.8125 [00:00<00:00, 23.78it/s]

Epoch: 53, Training Loss:   0.0555


train loss = 0.0513884: : 8it [00:00, 25.32it/s]                          
train loss = 0.0626720:  38%|███▊      | 3/7.8125 [00:00<00:00, 24.37it/s]

Epoch: 54, Training Loss:   0.0465


train loss = 0.0465760: : 8it [00:00, 25.32it/s]                          
train loss = 0.0378897:  38%|███▊      | 3/7.8125 [00:00<00:00, 22.69it/s]

Epoch: 55, Training Loss:   0.0594


train loss = 0.0751434: : 8it [00:00, 23.18it/s]                          
train loss = 0.0597783:  38%|███▊      | 3/7.8125 [00:00<00:00, 23.91it/s]

Epoch: 56, Training Loss:   0.0628


train loss = 0.0511368: : 8it [00:00, 24.51it/s]                          
train loss = 0.0522224:  38%|███▊      | 3/7.8125 [00:00<00:00, 24.42it/s]

Epoch: 57, Training Loss:   0.0567


train loss = 0.0572367: : 8it [00:00, 24.60it/s]                          
train loss = 0.0576026:  38%|███▊      | 3/7.8125 [00:00<00:00, 22.52it/s]

Epoch: 58, Training Loss:   0.0518


train loss = 0.0721417: : 8it [00:00, 23.52it/s]                          
train loss = 0.0508808:  38%|███▊      | 3/7.8125 [00:00<00:00, 24.49it/s]

Epoch: 59, Training Loss:   0.0520


train loss = 0.0688027: : 8it [00:00, 24.26it/s]                          
train loss = 0.0388669:  38%|███▊      | 3/7.8125 [00:00<00:00, 23.89it/s]

Epoch: 60, Training Loss:   0.0521


train loss = 0.0372685: : 8it [00:00, 24.47it/s]                          
train loss = 0.0414363:  38%|███▊      | 3/7.8125 [00:00<00:00, 25.04it/s]

Epoch: 61, Training Loss:   0.0417


train loss = 0.0362583: : 8it [00:00, 25.14it/s]                          
train loss = 0.0717653:  38%|███▊      | 3/7.8125 [00:00<00:00, 25.55it/s]

Epoch: 62, Training Loss:   0.0514


train loss = 0.0442414: : 8it [00:00, 25.28it/s]                          
train loss = 0.0474826:  38%|███▊      | 3/7.8125 [00:00<00:00, 25.15it/s]

Epoch: 63, Training Loss:   0.0492


train loss = 0.0471001: : 8it [00:00, 25.27it/s]                          
train loss = 0.0422874:  38%|███▊      | 3/7.8125 [00:00<00:00, 24.77it/s]

Epoch: 64, Training Loss:   0.0477


train loss = 0.0549597: : 8it [00:00, 25.32it/s]                          
train loss = 0.0499493:  38%|███▊      | 3/7.8125 [00:00<00:00, 25.24it/s]

Epoch: 65, Training Loss:   0.0481


train loss = 0.1526406: : 8it [00:00, 25.31it/s]                          
train loss = 0.1190098:  38%|███▊      | 3/7.8125 [00:00<00:00, 25.37it/s]

Epoch: 66, Training Loss:   0.0784


train loss = 0.0679528: : 8it [00:00, 25.27it/s]                          
train loss = 0.0481311:  38%|███▊      | 3/7.8125 [00:00<00:00, 24.90it/s]

Epoch: 67, Training Loss:   0.0639


train loss = 0.0472118: : 8it [00:00, 25.35it/s]                          
train loss = 0.0410708:  38%|███▊      | 3/7.8125 [00:00<00:00, 24.83it/s]

Epoch: 68, Training Loss:   0.0630


train loss = 0.0416030: : 8it [00:00, 25.31it/s]                          
train loss = 0.0450964:  38%|███▊      | 3/7.8125 [00:00<00:00, 25.00it/s]

Epoch: 69, Training Loss:   0.0445


train loss = 0.0320823: : 8it [00:00, 25.30it/s]                          
train loss = 0.0451299:  38%|███▊      | 3/7.8125 [00:00<00:00, 25.14it/s]

Epoch: 70, Training Loss:   0.0461


train loss = 0.0461104: : 8it [00:00, 25.43it/s]                          
train loss = 0.0544953:  38%|███▊      | 3/7.8125 [00:00<00:00, 25.36it/s]

Epoch: 71, Training Loss:   0.0454


train loss = 0.0361238: : 8it [00:00, 25.35it/s]                          
train loss = 0.0383699:  38%|███▊      | 3/7.8125 [00:00<00:00, 24.91it/s]

Epoch: 72, Training Loss:   0.0493


train loss = 0.0359280: : 8it [00:00, 25.36it/s]                          
train loss = 0.0463924:  38%|███▊      | 3/7.8125 [00:00<00:00, 25.59it/s]

Epoch: 73, Training Loss:   0.0432


train loss = 0.0551582: : 8it [00:00, 25.32it/s]                          
train loss = 0.0575165:  38%|███▊      | 3/7.8125 [00:00<00:00, 24.64it/s]

Epoch: 74, Training Loss:   0.0418


train loss = 0.0424429: : 8it [00:00, 25.33it/s]                          
train loss = 0.0484901:  38%|███▊      | 3/7.8125 [00:00<00:00, 24.93it/s]

Epoch: 75, Training Loss:   0.0379


train loss = 0.0347292: : 8it [00:00, 25.26it/s]                          
train loss = 0.0275179:  38%|███▊      | 3/7.8125 [00:00<00:00, 24.41it/s]

Epoch: 76, Training Loss:   0.0407


train loss = 0.0486112: : 8it [00:00, 24.60it/s]                          
train loss = 0.0529840:  38%|███▊      | 3/7.8125 [00:00<00:00, 22.54it/s]

Epoch: 77, Training Loss:   0.0412


train loss = 0.0308493: : 8it [00:00, 24.09it/s]                          
train loss = 0.0658534:  38%|███▊      | 3/7.8125 [00:00<00:00, 24.50it/s]

Epoch: 78, Training Loss:   0.0455


train loss = 0.0401981: : 8it [00:00, 25.08it/s]                          

Epoch: 79, Training Loss:   0.0486





8000

In [12]:
@torch.no_grad()
def gen_in_out(model, loader, device):
    """
    Run the model over every batch of a loader and collect inputs and outputs.

    Args:
        model: callable network; it is switched to eval mode. If a call returns
            a tuple, only its first element is kept as the reconstruction.
        loader: iterable yielding either a single batch object exposing ``.x``
            and ``.to(device)``, or a list of objects exposing ``.x`` (the list
            is passed to the model as-is, without being moved).
        device: device that non-list batches are moved to before the forward pass.

    Returns:
        tuple: ``(input_fts, reco_fts)``, both CPU tensors obtained by
        concatenating the per-batch features along dim 0.
    """
    model.eval()
    input_fts = []
    reco_fts = []

    for batch in loader:
        if isinstance(batch, list):
            # List batches are handed to the model untouched; only record inputs.
            input_fts.extend(d.x.detach().cpu() for d in batch)
        else:
            # Store a CPU copy of the inputs so the later torch.cat / .numpy()
            # work regardless of which device the loader yields batches on.
            input_fts.append(batch.x.detach().cpu())
            # Keep the object returned by .to(): objects such as torch.Tensor
            # return a moved copy instead of moving in place. Fall back to the
            # original object if an in-place implementation returns None.
            moved = batch.to(device)
            if moved is not None:
                batch = moved

        reco_out = model(batch)
        if isinstance(reco_out, tuple):
            reco_out = reco_out[0]
        reco_fts.append(reco_out.detach().cpu())

    input_fts = torch.cat(input_fts)
    reco_fts = torch.cat(reco_fts)
    return input_fts, reco_fts

def plot_reco_for_loader(model, loader, device, scaler, inverse_scale, model_fname, save_dir, feature_format):
    """
    Produce input-vs-reconstruction histograms for every batch in a loader.

    The features and reconstructions are gathered with ``gen_in_out``; when
    ``inverse_scale`` is truthy both are passed through
    ``scaler.inverse_transform`` first. The result is forwarded to
    ``plot_reco_difference`` together with ``model_fname``, ``save_dir`` and
    ``feature_format``.
    """
    original, reconstructed = gen_in_out(model, loader, device)
    if inverse_scale:
        # Inputs are transformed first, then reconstructions.
        original, reconstructed = (
            scaler.inverse_transform(fts) for fts in (original, reconstructed)
        )
    plot_reco_difference(original, reconstructed, model_fname, save_dir, feature_format)

    
def _reco_axis_labels(feature):
    """Return (axis labels, file-name stems) for the columns of a feature format."""
    if feature == 'hadronic':
        return [r'$p_T$', r'$\eta$', r'$\phi$'], ['pt', 'eta', 'phi']
    if feature == 'all':
        return (
            ['$p_x~[GeV]$', '$p_y~[GeV]$', '$p_z~[GeV]$', '$E~[GeV]$',
             r'$p_T$', r'$\eta$', r'$\phi$'],
            ['px', 'py', 'pz', 'E', 'pt', 'eta', 'phi'],
        )
    labels = ['$p_x~[GeV]$', '$p_y~[GeV]$', '$p_z~[GeV]$']
    names = ['px', 'py', 'pz']
    if feature == 'cartesian':
        # The cartesian binning below treats column 3 as the energy, so it
        # needs a label and a file name as well.
        labels.append('$E~[GeV]$')
        names.append('E')
    return labels, names


def _reco_hist_bins(feature, i):
    """Return the histogram bin edges for column ``i`` of a feature format."""
    if feature == 'cartesian':
        if i == 3:  # energy
            return np.linspace(-5, 35, 101)
        return np.linspace(-20, 20, 101)
    if feature == 'hadronic':
        if i == 0:  # pt rel
            return np.linspace(-0.05, 0.1, 101)
        return np.linspace(-2, 2, 101)
    if feature == 'all':
        if i == 3:  # energy
            return np.linspace(-5, 35, 101)
        if i == 4:  # pt rel
            return np.linspace(-2, 10, 101)
        if i > 4:  # remaining hadronic coordinates
            return np.linspace(-2, 2, 101)
        return np.linspace(-20, 20, 101)
    return np.linspace(-1, 1, 101)


def plot_reco_difference(input_fts, reco_fts, model_fname, save_path, feature='hadronic'):
    """
    Plot the difference between the autoencoder's reconstruction and the original input.

    One histogram figure is written per input column, overlaying the input
    and reconstructed distributions, and saved as ``<name>.png`` in ``save_path``.

    Args:
        input_fts (numpy array or torch.Tensor): the original features of the
            particles, one column per feature
        reco_fts (numpy array or torch.Tensor): the reconstructed features,
            indexed with the same column indices as ``input_fts``
        model_fname (str): name of saved model (currently unused by this function)
        save_path (str): output directory; created, with parents, if missing
        feature (str): feature format selecting axis labels and bin ranges.
            ``'hadronic'``, ``'cartesian'`` and ``'all'`` have dedicated
            settings; any other value uses px/py/pz labels and bins in [-1, 1].
    """
    if isinstance(input_fts, torch.Tensor):
        input_fts = input_fts.numpy()
    if isinstance(reco_fts, torch.Tensor):
        reco_fts = reco_fts.numpy()

    Path(save_path).mkdir(parents=True, exist_ok=True)
    label, feat = _reco_axis_labels(feature)

    # make a separate plot for each feature
    for i in range(input_fts.shape[1]):
        # Columns beyond the known labels get a generic name instead of
        # raising IndexError after the figure has already been created.
        xlabel = label[i] if i < len(label) else 'feature {}'.format(i)
        fname = feat[i] if i < len(feat) else 'feature_{}'.format(i)
        bins = _reco_hist_bins(feature, i)

        fig, ax = plt.subplots(figsize=(10, 8))
        try:
            ax.ticklabel_format(useMathText=True)
            ax.hist(input_fts[:, i], bins=bins, alpha=0.5, label='Input', histtype='step', lw=5)
            ax.hist(reco_fts[:, i], bins=bins, alpha=0.5, label='Output', histtype='step', lw=5)
            ax.legend(title='QCD dataset', fontsize='x-large')
            ax.set_xlabel(xlabel, fontsize='x-large')
            ax.set_ylabel('Particles', fontsize='x-large')
            fig.tight_layout()
            fig.savefig(osp.join(save_path, fname + '.png'))
        finally:
            # Always release the figure, even if plotting or saving fails.
            plt.close(fig)
    

In [13]:
# When True, plot_reco_for_loader applies scaler.inverse_transform to both the
# inputs and the reconstructions before plotting.
inverse_standardization = True
# NOTE(review): 'pytroch' looks like a typo of 'pytorch', but the string must
# match the directory that actually exists on disk - confirm before changing.
save_dir = '/eos/user/n/nchernya/MLHEP/AnomalyDetection/ADgvae/output_models/pytroch/'
plot_reco_for_loader(
    model=model,
    loader=dataloaders['train'],
    device=device,
    scaler=scaler,
    inverse_scale=inverse_standardization,
    model_fname='test_train',
    save_dir=osp.join(save_dir, 'reconstruction_post_train', 'train_reco_all_std'),
    feature_format='all',
)
