In [1]:
import os,sys
import numpy as np
from collections import namedtuple
import tqdm
import glob
import math
import random
import inspect
import os.path as osp
from pathlib import Path
import itertools
from itertools import chain
import numpy as np
import pandas as pd
import multiprocessing
import h5py
import matplotlib.pyplot as plt
import numpy as np
import sys, os
from importlib import reload

import torch
import torch.nn as nn
from torch.utils.data import random_split
from torch_geometric.data import Data, DataLoader, DataListLoader
from torch_geometric.nn import EdgeConv, global_mean_pool, DataParallel
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.data import Data,Dataset
from torch_scatter import scatter_mean, scatter
from torch.nn import Sequential as Seq, Linear as Lin, ReLU
from torch_geometric.nn import MetaLayer, EdgeConv, global_mean_pool, DynamicEdgeConv


In [2]:
#Data Samples
DATA_PATH = '/eos/cms/store/group/phys_b2g/CASE/h5_files/full_run2/BB_UL_MC_small_v2/'

TRAIN_NAME = 'BB_batch0.h5'
filename_bg = DATA_PATH + TRAIN_NAME 
batch_size = 128
train_set_size = int((5*10e3//batch_size)*batch_size)
file_bg = h5py.File(filename_bg, 'r') 

In [74]:
deta_jj = 1.4
jPt = 400

def xyze_to_eppt(constituents):
    ''' converts an array [N x 100, 4] of particles
from px, py, pz, E to eta, phi, pt (mass omitted)
    '''
    PX, PY, PZ, E = range(4)
    pt = np.sqrt(np.float_power(constituents[:,PX], 2) + np.float_power(constituents[:,PY], 2), dtype='float32') # numpy.float16 dtype -> float power to avoid overflow
    eta = np.arcsinh(np.divide(constituents[:,PZ], pt, out=np.zeros_like(pt), where=pt!=0.), dtype='float32')
    phi = np.arctan2(constituents[:,PY], constituents[:,PX], dtype='float32')

    return np.stack([pt, eta, phi], axis=1)

side = True
to_train = True

datas = []
for i_e in range(1000):
    if to_train: 
        if file_bg['truth_label'][i_e]!=0 : #train only on QCD
            continue 
    if side :
        if not (file_bg["jet_kinematics"][i_e,1] > deta_jj):
            continue
    else : 
        if not (file_bg["jet_kinematics"][i_e,1] < deta_jj):
            continue
    for i_j in range(2): #each event has 2 jets
        pf_cands = np.array(file_bg["jet{}_PFCands".format(i_j+1)][i_e])
        pf_pt_eta_phi = xyze_to_eppt(pf_cands)
        n_particles = int(np.sum(pf_pt_eta_phi[:,0]!=0)) #if pt!=0
        particles = np.zeros((n_particles, 7)) #px,py,pz,E, pt, eta, phi = 7
        #particles = np.dstack((pf_cands[0:n_particles,:],np.array(pf_pt_eta_phi[0:n_particles,:])))
        particles = np.hstack((pf_cands[0:n_particles,:],np.array(pf_pt_eta_phi[0:n_particles,:])))
        pairs = np.stack([[m, n] for (m, n) in itertools.product(range(n_particles),range(n_particles)) if m!=n])
        edge_index = torch.tensor(pairs, dtype=torch.long)
        edge_index=edge_index.t().contiguous()
        # save particles as node attributes and target
        x = torch.tensor(particles, dtype=torch.float)
        data = Data(x=x, edge_index=edge_index)
        datas.append([data])
datas = sum(datas,[])

In [75]:
"""
    Model definitions.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.data import Data
from torch_scatter import scatter_mean, scatter
from torch.nn import Sequential as Seq, Linear as Lin, ReLU
from torch_geometric.nn import MetaLayer, EdgeConv, global_mean_pool, DynamicEdgeConv


# GNN AE using EdgeConv (mean aggregation graph operation). Basic GAE model.
class EdgeNet(nn.Module):
    def __init__(self, input_dim=7, output_dim=4, big_dim=32, hidden_dim=2, aggr='mean'):
        super(EdgeNet, self).__init__()
        encoder_nn = nn.Sequential(nn.Linear(2*(input_dim), big_dim),
                               nn.ReLU(),
                               nn.Linear(big_dim, big_dim),
                               nn.ReLU(),
                               nn.Linear(big_dim, hidden_dim),
                               nn.ReLU(),
        )
        
        decoder_nn = nn.Sequential(nn.Linear(2*(hidden_dim), big_dim),
                               nn.ReLU(),
                               nn.Linear(big_dim, big_dim),
                               nn.ReLU(),
                               nn.Linear(big_dim, output_dim)
        )
        
        self.batchnorm = nn.BatchNorm1d(input_dim)

        self.encoder = EdgeConv(nn=encoder_nn,aggr=aggr)
        self.decoder = EdgeConv(nn=decoder_nn,aggr=aggr)

    def forward(self, data):
        x = self.batchnorm(data.x)
        x = self.encoder(x,data.edge_index)
        x = self.decoder(x,data.edge_index)
        return x

In [76]:
def train(model, optimizer, loader, total, batch_size, loss_ftn_obj):
    model.train()

    sum_loss = 0.
    t = tqdm.tqdm(enumerate(loader),total=total/batch_size)
    for i,data in t:
        optimizer.zero_grad()

        batch_loss, batch_output = forward_loss(model, data, loss_ftn_obj, device, multi_gpu=False)
        batch_loss.backward()
        optimizer.step()

        batch_loss = batch_loss.item()
        sum_loss += batch_loss
        t.set_description('train loss = %.7f' % batch_loss)
        t.refresh() # to show immediately the update

    return sum_loss / (i+1)


# helper to perform correct loss
def forward_loss(model, data, loss_ftn_obj, device, multi_gpu=False):
    
    if not multi_gpu:
        data = data.to(device)

    if 'emd_loss' in loss_ftn_obj.name or loss_ftn_obj.name == 'chamfer_loss' or loss_ftn_obj.name == 'hungarian_loss':
        batch_output = model(data)
        if multi_gpu:
            data = Batch.from_data_list(data).to(device)
        y = data.x
        batch = data.batch
        batch_loss = loss_ftn_obj.loss_ftn(batch_output, y, batch)

    elif loss_ftn_obj.name == 'emd_in_forward':
        _, batch_loss = model(data)
        batch_loss = batch_loss.mean()

    elif loss_ftn_obj.name == 'vae_loss':
        batch_output, mu, log_var = model(data)
        y = torch.cat([d.x for d in data]).to(device) if multi_gpu else data.x
        y = y.contiguous()
        batch_loss = loss_ftn_obj.loss_ftn(batch_output, y, mu, log_var)

    else:
        batch_output = model(data)
        print(batch_output.shape)
        y = torch.cat([d.x for d in data]).to(device) if multi_gpu else data.x
        print(y.shape)
        y = y.contiguous()
        print(y.shape)
        batch_loss = loss_ftn_obj.loss_ftn(batch_output, y)

    return batch_loss, batch_output

In [77]:
torch.manual_seed(0)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
multi_gpu = False #torch.cuda.device_count()>1

In [7]:
class Standardizer:
    def __init__(self):
        self.mean = None
        self.std = None

    def fit(self, data):
        """
        :param data: torch tensor
        """
        self.mean = torch.mean(data, dim=0)
        self.std = torch.std(data, dim=0)

    def transform(self, data):
        return (data - self.mean) / self.std

    def inverse_transform(self, data, log_pt=False):
        """
        :param data: torch tensor
        :param log_pt: undo log transformation on pt
        """
        inverse = (data * self.std) + self.mean
        if log_pt:
            inverse[:,0] = (10 ** inverse[:,0]) - 1
        return inverse

def standardize(train_dataset,log_pt=False):
    """
    standardize dataset and return scaler for inversion
    :param train_dataset: list of Data objects
    :param valid_dataset: list of Data objects
    :param test_dataset: list of Data objects
    :param log_pt: log pt before standardization
    :return scaler: sklearn StandardScaler
    """
    train_x = torch.cat([d.x for d in train_dataset])
    if log_pt:
        train_x[:,0] = torch.log(train_x[:,0] + 1)

    scaler = Standardizer()
    scaler.fit(train_x)
    for d in train_dataset:
        d.x[:,:] = scaler.transform(d.x)
    return scaler

In [78]:
loader = DataLoader(datas, batch_size=128)
#scaler = standardize(datas)


In [79]:
def xyze_to_ptetaphi_torch(y):
    ''' converts an array [N x 100, 4] of particles
from px, py, pz, E to pt,eta, phi
    '''
    PX, PY, PZ, E = range(4)
    pt = torch.sqrt(torch.pow(y[:,PX], 2) + torch.pow(y[:,PY], 2)) 
    eta = torch.asinh(torch.where(pt < 10e-5, torch.zeros_like(pt), torch.div(y[:,PZ], pt)))
    phi = torch.atan2(y[:,PY], y[:,PX])

    relu =  m = nn.ReLU() #inplace=True
    y_E_trimmed = relu(y[:,-1]) #trimming E
    y_pt_trimmed = relu(pt) #trimming pt
    full_y = torch.stack((y[:,0],y[:,1],y[:,2],y_E_trimmed,y_pt_trimmed,eta,phi), dim=1)

    return full_y


class LossFunction:
    def __init__(self, lossname, device=torch.device('cuda:0')):
        loss = getattr(self, lossname)
        self.name = lossname
        self.loss_ftn = loss
        self.device = device
    def mse(self, x, y):
        return F.mse_loss(x, y, reduction='mean')
    
    def mse_coordinates(self, y,x): #for some reason convension is : out,in
        #From px,py,pz,E get pt, eta, phi (do not predict them)
        #x is px,py,pz,E,pt,eta,phi
        #y is px,py,pz,E
        full_y = xyze_to_ptetaphi_torch(y)
        return self.mse(x,full_y)
        

In [80]:
#loss
loss_ftn_obj = LossFunction('mse_coordinates', device=device)

# model
input_dim = 7
output_dim = 4
big_dim = 32
hidden_dim = 2
model = EdgeNet(input_dim=input_dim,output_dim=output_dim, big_dim=big_dim, hidden_dim=hidden_dim)

optimizer = torch.optim.Adam(model.parameters(), lr = 10e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=4, threshold=1e-6)

model.to(device)


EdgeNet(
  (batchnorm): BatchNorm1d(7, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (encoder): EdgeConv(nn=Sequential(
    (0): Linear(in_features=14, out_features=32, bias=True)
    (1): ReLU()
    (2): Linear(in_features=32, out_features=32, bias=True)
    (3): ReLU()
    (4): Linear(in_features=32, out_features=2, bias=True)
    (5): ReLU()
  ))
  (decoder): EdgeConv(nn=Sequential(
    (0): Linear(in_features=4, out_features=32, bias=True)
    (1): ReLU()
    (2): Linear(in_features=32, out_features=32, bias=True)
    (3): ReLU()
    (4): Linear(in_features=32, out_features=4, bias=True)
  ))
)

In [81]:
# Training loop
n_epochs = 10
stale_epochs = 0
loss = 999999
train_losses = []
for epoch in range(0, n_epochs):
    loss = train(model, optimizer, loader, len(datas), 128, loss_ftn_obj)
    train_losses.append(loss)
    print('Epoch: {:02d}, Training Loss:   {:.4f}'.format(epoch, loss))
















  0%|          | 0/12.421875 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1848.0806885:   0%|          | 0/12.421875 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1848.0806885:   0%|          | 0/12.421875 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1256.9459229:   0%|          | 0/12.421875 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1256.9459229:   0%|          | 0/12.421875 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1724.9876709:   0%|          | 0/12.421875 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1724.9876709:   0%|          | 0/12.421875 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1724.9876709:  24%|██▍       | 3/12.421875 [00:00<00:00, 20.28it/s][A[A[

torch.Size([5075, 4])
torch.Size([5075, 7])
torch.Size([5075, 7])
torch.Size([5542, 4])
torch.Size([5542, 7])
torch.Size([5542, 7])
torch.Size([5388, 4])
torch.Size([5388, 7])
torch.Size([5388, 7])
torch.Size([5178, 4])
torch.Size([5178, 7])
torch.Size([5178, 7])
torch.Size([5572, 4])
torch.Size([5572, 7])
torch.Size([5572, 7])

















train loss = 1580.2336426:  24%|██▍       | 3/12.421875 [00:00<00:00, 20.28it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1580.2336426:  24%|██▍       | 3/12.421875 [00:00<00:00, 20.28it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1946.8026123:  24%|██▍       | 3/12.421875 [00:00<00:00, 20.28it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1946.8026123:  24%|██▍       | 3/12.421875 [00:00<00:00, 20.28it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1946.8026123:  48%|████▊     | 6/12.421875 [00:00<00:00, 20.49it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1421.6152344:  48%|████▊     | 6/12.421875 [00:00<00:00, 20.49it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1421.6152344:  48%|████▊     | 6/12.421875 [00:00<00:00, 20.49it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train

torch.Size([5211, 4])
torch.Size([5211, 7])
torch.Size([5211, 7])
torch.Size([5489, 4])
torch.Size([5489, 7])
torch.Size([5489, 7])
torch.Size([5469, 4])
torch.Size([5469, 7])
torch.Size([5469, 7])
torch.Size([5537, 4])
torch.Size([5537, 7])
torch.Size([5537, 7])

















train loss = 1761.4467773:  64%|██████▍   | 8/12.421875 [00:00<00:00, 20.24it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1761.4467773:  64%|██████▍   | 8/12.421875 [00:00<00:00, 20.24it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1202.4370117:  64%|██████▍   | 8/12.421875 [00:00<00:00, 20.24it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1202.4370117:  64%|██████▍   | 8/12.421875 [00:00<00:00, 20.24it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1202.4370117:  89%|████████▊ | 11/12.421875 [00:00<00:00, 20.09it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1467.0510254:  89%|████████▊ | 11/12.421875 [00:00<00:00, 20.09it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1467.0510254:  89%|████████▊ | 11/12.421875 [00:00<00:00, 20.09it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














tr

torch.Size([5205, 4])
torch.Size([5205, 7])
torch.Size([5205, 7])
torch.Size([5463, 4])
torch.Size([5463, 7])
torch.Size([5463, 7])
torch.Size([5633, 4])
torch.Size([5633, 7])
torch.Size([5633, 7])
torch.Size([2227, 4])
torch.Size([2227, 7])
torch.Size([2227, 7])
Epoch: 00, Training Loss:   1570.3243
torch.Size([5075, 4])
torch.Size([5075, 7])
torch.Size([5075, 7])


train loss = 1656.0998535:   0%|          | 0/12.421875 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1656.0998535:   0%|          | 0/12.421875 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1057.6140137:   0%|          | 0/12.421875 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1057.6140137:   0%|          | 0/12.421875 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1425.3988037:   0%|          | 0/12.421875 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1425.3988037:   0%|          | 0/12.421875 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1425.3988037:  24%|██▍       | 3/12.421875 [00:00<00:00, 21.48it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1308.1002197:  24%|██▍       | 3/12.421875 [00:00<00:00

torch.Size([5542, 4])
torch.Size([5542, 7])
torch.Size([5542, 7])
torch.Size([5388, 4])
torch.Size([5388, 7])
torch.Size([5388, 7])
torch.Size([5178, 4])
torch.Size([5178, 7])
torch.Size([5178, 7])
torch.Size([5572, 4])
torch.Size([5572, 7])
torch.Size([5572, 7])
torch.Size([5211, 4])
torch.Size([5211, 7])
torch.Size([5211, 7])

















train loss = 1573.3800049:  24%|██▍       | 3/12.421875 [00:00<00:00, 21.48it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1573.3800049:  24%|██▍       | 3/12.421875 [00:00<00:00, 21.48it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1573.3800049:  48%|████▊     | 6/12.421875 [00:00<00:00, 21.31it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 957.7961426:  48%|████▊     | 6/12.421875 [00:00<00:00, 21.31it/s] [A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 957.7961426:  48%|████▊     | 6/12.421875 [00:00<00:00, 21.31it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1206.6854248:  48%|████▊     | 6/12.421875 [00:00<00:00, 21.31it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1206.6854248:  48%|████▊     | 6/12.421875 [00:00<00:00, 21.31it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train 

torch.Size([5489, 4])
torch.Size([5489, 7])
torch.Size([5489, 7])
torch.Size([5469, 4])
torch.Size([5469, 7])
torch.Size([5469, 7])
torch.Size([5537, 4])
torch.Size([5537, 7])
torch.Size([5537, 7])
torch.Size([5205, 4])
torch.Size([5205, 7])
torch.Size([5205, 7])
torch.Size([5463, 4])
torch.Size([5463, 7])
torch.Size([5463, 7])


train loss = 786.6390991:  72%|███████▏  | 9/12.421875 [00:00<00:00, 21.26it/s] [A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 786.6390991:  72%|███████▏  | 9/12.421875 [00:00<00:00, 21.26it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1019.4121094:  72%|███████▏  | 9/12.421875 [00:00<00:00, 21.26it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1019.4121094:  72%|███████▏  | 9/12.421875 [00:00<00:00, 21.26it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1019.4121094:  97%|█████████▋| 12/12.421875 [00:00<00:00, 21.52it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 922.4370117:  97%|█████████▋| 12/12.421875 [00:00<00:00, 21.52it/s] [A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 922.4370117:  97%|█████████▋| 12/12.421875 [00:00<00:00, 21.52it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 922.43

torch.Size([5633, 4])
torch.Size([5633, 7])
torch.Size([5633, 7])
torch.Size([2227, 4])
torch.Size([2227, 7])
torch.Size([2227, 7])
Epoch: 01, Training Loss:   1170.9129
torch.Size([5075, 4])
torch.Size([5075, 7])
torch.Size([5075, 7])
torch.Size([5542, 4])
torch.Size([5542, 7])
torch.Size([5542, 7])
torch.Size([5388, 4])
torch.Size([5388, 7])
torch.Size([5388, 7])

















train loss = 923.0444946:  24%|██▍       | 3/12.421875 [00:00<00:00, 23.47it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 923.0444946:  24%|██▍       | 3/12.421875 [00:00<00:00, 23.47it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 708.1799927:  24%|██▍       | 3/12.421875 [00:00<00:00, 23.47it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 708.1799927:  24%|██▍       | 3/12.421875 [00:00<00:00, 23.47it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1260.7274170:  24%|██▍       | 3/12.421875 [00:00<00:00, 23.47it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1260.7274170:  24%|██▍       | 3/12.421875 [00:00<00:00, 23.47it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1260.7274170:  48%|████▊     | 6/12.421875 [00:00<00:00, 23.02it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train los

torch.Size([5178, 4])
torch.Size([5178, 7])
torch.Size([5178, 7])
torch.Size([5572, 4])
torch.Size([5572, 7])
torch.Size([5572, 7])
torch.Size([5211, 4])
torch.Size([5211, 7])
torch.Size([5211, 7])
torch.Size([5489, 4])
torch.Size([5489, 7])
torch.Size([5489, 7])
torch.Size([5469, 4])
torch.Size([5469, 7])
torch.Size([5469, 7])


[A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1014.9131470:  48%|████▊     | 6/12.421875 [00:00<00:00, 23.02it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 882.0192261:  48%|████▊     | 6/12.421875 [00:00<00:00, 23.02it/s] [A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 882.0192261:  48%|████▊     | 6/12.421875 [00:00<00:00, 23.02it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 882.0192261:  72%|███████▏  | 9/12.421875 [00:00<00:00, 22.61it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 874.6762085:  72%|███████▏  | 9/12.421875 [00:00<00:00, 22.61it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 874.6762085:  72%|███████▏  | 9/12.421875 [00:00<00:00, 22.61it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 724.6170654:  72%|███████▏  | 9/12.421875 [00:00<00:00, 22.61it/s][A[A[A[A[A[A[A[A

torch.Size([5537, 4])
torch.Size([5537, 7])
torch.Size([5537, 7])
torch.Size([5205, 4])
torch.Size([5205, 7])
torch.Size([5205, 7])
torch.Size([5463, 4])
torch.Size([5463, 7])
torch.Size([5463, 7])
torch.Size([5633, 4])
torch.Size([5633, 7])
torch.Size([5633, 7])
torch.Size([2227, 4])
torch.Size([2227, 7])
torch.Size([2227, 7])
Epoch: 02, Training Loss:   883.9467

















train loss = 1024.3964844:   0%|          | 0/12.421875 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1024.3964844:   0%|          | 0/12.421875 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 699.8557129:   0%|          | 0/12.421875 [00:00<?, ?it/s] [A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 699.8557129:   0%|          | 0/12.421875 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 859.7333374:   0%|          | 0/12.421875 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 859.7333374:   0%|          | 0/12.421875 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 859.7333374:  24%|██▍       | 3/12.421875 [00:00<00:00, 22.77it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 908.3608398:  24%|██▍       | 3/12.421875 [0

torch.Size([5075, 4])
torch.Size([5075, 7])
torch.Size([5075, 7])
torch.Size([5542, 4])
torch.Size([5542, 7])
torch.Size([5542, 7])
torch.Size([5388, 4])
torch.Size([5388, 7])
torch.Size([5388, 7])
torch.Size([5178, 4])
torch.Size([5178, 7])
torch.Size([5178, 7])
torch.Size([5572, 4])
torch.Size([5572, 7])
torch.Size([5572, 7])

















train loss = 1227.2080078:  24%|██▍       | 3/12.421875 [00:00<00:00, 22.77it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1227.2080078:  24%|██▍       | 3/12.421875 [00:00<00:00, 22.77it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1227.2080078:  48%|████▊     | 6/12.421875 [00:00<00:00, 22.69it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 632.7911987:  48%|████▊     | 6/12.421875 [00:00<00:00, 22.69it/s] [A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 632.7911987:  48%|████▊     | 6/12.421875 [00:00<00:00, 22.69it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 999.5585938:  48%|████▊     | 6/12.421875 [00:00<00:00, 22.69it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 999.5585938:  48%|████▊     | 6/12.421875 [00:00<00:00, 22.69it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train lo

torch.Size([5211, 4])
torch.Size([5211, 7])
torch.Size([5211, 7])
torch.Size([5489, 4])
torch.Size([5489, 7])
torch.Size([5489, 7])
torch.Size([5469, 4])
torch.Size([5469, 7])
torch.Size([5469, 7])
torch.Size([5537, 4])
torch.Size([5537, 7])
torch.Size([5537, 7])
torch.Size([5205, 4])
torch.Size([5205, 7])
torch.Size([5205, 7])

















train loss = 714.5312500:  72%|███████▏  | 9/12.421875 [00:00<00:00, 22.43it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 714.5312500:  72%|███████▏  | 9/12.421875 [00:00<00:00, 22.43it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 977.7271118:  72%|███████▏  | 9/12.421875 [00:00<00:00, 22.43it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 977.7271118:  72%|███████▏  | 9/12.421875 [00:00<00:00, 22.43it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 977.7271118:  97%|█████████▋| 12/12.421875 [00:00<00:00, 22.25it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 854.9405518:  97%|█████████▋| 12/12.421875 [00:00<00:00, 22.25it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 854.9405518:  97%|█████████▋| 12/12.421875 [00:00<00:00, 22.25it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train los

torch.Size([5463, 4])
torch.Size([5463, 7])
torch.Size([5463, 7])
torch.Size([5633, 4])
torch.Size([5633, 7])
torch.Size([5633, 7])
torch.Size([2227, 4])
torch.Size([2227, 7])
torch.Size([2227, 7])
Epoch: 03, Training Loss:   870.1319
torch.Size([5075, 4])
torch.Size([5075, 7])
torch.Size([5075, 7])
torch.Size([5542, 4])
torch.Size([5542, 7])
torch.Size([5542, 7])

















train loss = 849.1717529:   0%|          | 0/12.421875 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 849.1717529:   0%|          | 0/12.421875 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 849.1717529:  24%|██▍       | 3/12.421875 [00:00<00:00, 22.12it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 896.6448364:  24%|██▍       | 3/12.421875 [00:00<00:00, 22.12it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 896.6448364:  24%|██▍       | 3/12.421875 [00:00<00:00, 22.12it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 661.3188477:  24%|██▍       | 3/12.421875 [00:00<00:00, 22.12it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 661.3188477:  24%|██▍       | 3/12.421875 [00:00<00:00, 22.12it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1222.8542480:  

torch.Size([5388, 4])
torch.Size([5388, 7])
torch.Size([5388, 7])
torch.Size([5178, 4])
torch.Size([5178, 7])
torch.Size([5178, 7])
torch.Size([5572, 4])
torch.Size([5572, 7])
torch.Size([5572, 7])
torch.Size([5211, 4])
torch.Size([5211, 7])
torch.Size([5211, 7])
torch.Size([5489, 4])
torch.Size([5489, 7])
torch.Size([5489, 7])

















train loss = 998.5823364:  48%|████▊     | 6/12.421875 [00:00<00:00, 22.31it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 998.5823364:  48%|████▊     | 6/12.421875 [00:00<00:00, 22.31it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 863.6795044:  48%|████▊     | 6/12.421875 [00:00<00:00, 22.31it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 863.6795044:  48%|████▊     | 6/12.421875 [00:00<00:00, 22.31it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 863.6795044:  72%|███████▏  | 9/12.421875 [00:00<00:00, 23.09it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 863.0319214:  72%|███████▏  | 9/12.421875 [00:00<00:00, 23.09it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 863.0319214:  72%|███████▏  | 9/12.421875 [00:00<00:00, 23.09it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss =

torch.Size([5469, 4])
torch.Size([5469, 7])
torch.Size([5469, 7])
torch.Size([5537, 4])
torch.Size([5537, 7])
torch.Size([5537, 7])
torch.Size([5205, 4])
torch.Size([5205, 7])
torch.Size([5205, 7])
torch.Size([5463, 4])
torch.Size([5463, 7])
torch.Size([5463, 7])
torch.Size([5633, 4])
torch.Size([5633, 7])
torch.Size([5633, 7])

















train loss = 850.4773560:  97%|█████████▋| 12/12.421875 [00:00<00:00, 23.03it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 850.4773560:  97%|█████████▋| 12/12.421875 [00:00<00:00, 23.03it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 850.4773560: : 13it [00:00, 23.71it/s]                             [A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














  0%|          | 0/12.421875 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1003.6890869:   0%|          | 0/12.421875 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1003.6890869:   0%|          | 0/12.421875 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 684.5257568:   0%|          | 0/12.421875 [00:00<?, ?it/s] [A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 684.5257568:   0%|          | 0/12.421875 [00:00<?,

torch.Size([2227, 4])
torch.Size([2227, 7])
torch.Size([2227, 7])
Epoch: 04, Training Loss:   862.1106
torch.Size([5075, 4])
torch.Size([5075, 7])
torch.Size([5075, 7])
torch.Size([5542, 4])
torch.Size([5542, 7])
torch.Size([5542, 7])
torch.Size([5388, 4])
torch.Size([5388, 7])
torch.Size([5388, 7])
torch.Size([5178, 4])
torch.Size([5178, 7])
torch.Size([5178, 7])

















train loss = 654.7463989:  24%|██▍       | 3/12.421875 [00:00<00:00, 23.85it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 654.7463989:  24%|██▍       | 3/12.421875 [00:00<00:00, 23.85it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1211.7498779:  24%|██▍       | 3/12.421875 [00:00<00:00, 23.85it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1211.7498779:  24%|██▍       | 3/12.421875 [00:00<00:00, 23.85it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1211.7498779:  48%|████▊     | 6/12.421875 [00:00<00:00, 23.62it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 620.6051636:  48%|████▊     | 6/12.421875 [00:00<00:00, 23.62it/s] [A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 620.6051636:  48%|████▊     | 6/12.421875 [00:00<00:00, 23.62it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train lo

torch.Size([5572, 4])
torch.Size([5572, 7])
torch.Size([5572, 7])
torch.Size([5211, 4])
torch.Size([5211, 7])
torch.Size([5211, 7])
torch.Size([5489, 4])
torch.Size([5489, 7])
torch.Size([5489, 7])
torch.Size([5469, 4])
torch.Size([5469, 7])
torch.Size([5469, 7])
torch.Size([5537, 4])
torch.Size([5537, 7])
torch.Size([5537, 7])

















train loss = 859.1433105:  72%|███████▏  | 9/12.421875 [00:00<00:00, 23.36it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 859.1433105:  72%|███████▏  | 9/12.421875 [00:00<00:00, 23.36it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 704.8620605:  72%|███████▏  | 9/12.421875 [00:00<00:00, 23.36it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 704.8620605:  72%|███████▏  | 9/12.421875 [00:00<00:00, 23.36it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 963.8912964:  72%|███████▏  | 9/12.421875 [00:00<00:00, 23.36it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 963.8912964:  72%|███████▏  | 9/12.421875 [00:00<00:00, 23.36it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 963.8912964:  97%|█████████▋| 12/12.421875 [00:00<00:00, 24.68it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss 

torch.Size([5205, 4])
torch.Size([5205, 7])
torch.Size([5205, 7])
torch.Size([5463, 4])
torch.Size([5463, 7])
torch.Size([5463, 7])
torch.Size([5633, 4])
torch.Size([5633, 7])
torch.Size([5633, 7])
torch.Size([2227, 4])
torch.Size([2227, 7])
torch.Size([2227, 7])
Epoch: 05, Training Loss:   856.4478
torch.Size([5075, 4])
torch.Size([5075, 7])
torch.Size([5075, 7])
torch.Size([5542, 4])
torch.Size([5542, 7])
torch.Size([5542, 7])

















train loss = 675.4831543:   0%|          | 0/12.421875 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 675.4831543:   0%|          | 0/12.421875 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 836.1130371:   0%|          | 0/12.421875 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 836.1130371:   0%|          | 0/12.421875 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 836.1130371:  24%|██▍       | 3/12.421875 [00:00<00:00, 24.97it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 880.2844238:  24%|██▍       | 3/12.421875 [00:00<00:00, 24.97it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 880.2844238:  24%|██▍       | 3/12.421875 [00:00<00:00, 24.97it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 647.9483032:  24%|██▍       | 3

torch.Size([5388, 4])
torch.Size([5388, 7])
torch.Size([5388, 7])
torch.Size([5178, 4])
torch.Size([5178, 7])
torch.Size([5178, 7])
torch.Size([5572, 4])
torch.Size([5572, 7])
torch.Size([5572, 7])
torch.Size([5211, 4])
torch.Size([5211, 7])
torch.Size([5211, 7])
torch.Size([5489, 4])
torch.Size([5489, 7])
torch.Size([5489, 7])

















train loss = 974.7557983:  48%|████▊     | 6/12.421875 [00:00<00:00, 24.55it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 974.7557983:  48%|████▊     | 6/12.421875 [00:00<00:00, 24.55it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 846.1830444:  48%|████▊     | 6/12.421875 [00:00<00:00, 24.55it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 846.1830444:  48%|████▊     | 6/12.421875 [00:00<00:00, 24.55it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 846.1830444:  72%|███████▏  | 9/12.421875 [00:00<00:00, 23.53it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 843.9046021:  72%|███████▏  | 9/12.421875 [00:00<00:00, 23.53it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 843.9046021:  72%|███████▏  | 9/12.421875 [00:00<00:00, 23.53it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss =

torch.Size([5469, 4])
torch.Size([5469, 7])
torch.Size([5469, 7])
torch.Size([5537, 4])
torch.Size([5537, 7])
torch.Size([5537, 7])
torch.Size([5205, 4])
torch.Size([5205, 7])
torch.Size([5205, 7])
torch.Size([5463, 4])
torch.Size([5463, 7])
torch.Size([5463, 7])
torch.Size([5633, 4])
torch.Size([5633, 7])
torch.Size([5633, 7])

















train loss = 815.9417725:  97%|█████████▋| 12/12.421875 [00:00<00:00, 23.63it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 815.9417725:  97%|█████████▋| 12/12.421875 [00:00<00:00, 23.63it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 815.9417725: : 13it [00:00, 24.07it/s]                             [A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














  0%|          | 0/12.421875 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 972.6502075:   0%|          | 0/12.421875 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 972.6502075:   0%|          | 0/12.421875 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 656.4061890:   0%|          | 0/12.421875 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 656.4061890:   0%|          | 0/12.421875 [00:00<?, ?i

torch.Size([2227, 4])
torch.Size([2227, 7])
torch.Size([2227, 7])
Epoch: 06, Training Loss:   841.5685
torch.Size([5075, 4])
torch.Size([5075, 7])
torch.Size([5075, 7])
torch.Size([5542, 4])
torch.Size([5542, 7])
torch.Size([5542, 7])
torch.Size([5388, 4])
torch.Size([5388, 7])
torch.Size([5388, 7])
torch.Size([5178, 4])
torch.Size([5178, 7])
torch.Size([5178, 7])

















train loss = 649.8989868:  24%|██▍       | 3/12.421875 [00:00<00:00, 24.25it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 649.8989868:  24%|██▍       | 3/12.421875 [00:00<00:00, 24.25it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1135.5708008:  24%|██▍       | 3/12.421875 [00:00<00:00, 24.25it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1135.5708008:  24%|██▍       | 3/12.421875 [00:00<00:00, 24.25it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1135.5708008:  48%|████▊     | 6/12.421875 [00:00<00:00, 23.54it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 616.8605347:  48%|████▊     | 6/12.421875 [00:00<00:00, 23.54it/s] [A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 616.8605347:  48%|████▊     | 6/12.421875 [00:00<00:00, 23.54it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train lo

torch.Size([5572, 4])
torch.Size([5572, 7])
torch.Size([5572, 7])
torch.Size([5211, 4])
torch.Size([5211, 7])
torch.Size([5211, 7])
torch.Size([5489, 4])
torch.Size([5489, 7])
torch.Size([5489, 7])
torch.Size([5469, 4])
torch.Size([5469, 7])
torch.Size([5469, 7])
torch.Size([5537, 4])
torch.Size([5537, 7])
torch.Size([5537, 7])


[A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 832.4316406:  48%|████▊     | 6/12.421875 [00:00<00:00, 23.54it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 832.4316406:  72%|███████▏  | 9/12.421875 [00:00<00:00, 22.89it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 837.4322510:  72%|███████▏  | 9/12.421875 [00:00<00:00, 22.89it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 837.4322510:  72%|███████▏  | 9/12.421875 [00:00<00:00, 22.89it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 665.6111450:  72%|███████▏  | 9/12.421875 [00:00<00:00, 22.89it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 665.6111450:  72%|███████▏  | 9/12.421875 [00:00<00:00, 22.89it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 888.2235718:  72%|███████▏  | 9/12.421875 [00:00<00:00, 22.89it/s][A[A[A[A[A[A[A[A[A

torch.Size([5205, 4])
torch.Size([5205, 7])
torch.Size([5205, 7])
torch.Size([5463, 4])
torch.Size([5463, 7])
torch.Size([5463, 7])
torch.Size([5633, 4])
torch.Size([5633, 7])
torch.Size([5633, 7])
torch.Size([2227, 4])
torch.Size([2227, 7])
torch.Size([2227, 7])
Epoch: 07, Training Loss:   819.5740
torch.Size([5075, 4])
torch.Size([5075, 7])
torch.Size([5075, 7])

















train loss = 640.3950806:   0%|          | 0/12.421875 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 640.3950806:   0%|          | 0/12.421875 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 828.4776611:   0%|          | 0/12.421875 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 828.4776611:   0%|          | 0/12.421875 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 828.4776611:  24%|██▍       | 3/12.421875 [00:00<00:00, 22.04it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 836.1513672:  24%|██▍       | 3/12.421875 [00:00<00:00, 22.04it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 836.1513672:  24%|██▍       | 3/12.421875 [00:00<00:00, 22.04it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 686.2998657:  24%|██▍       | 3

torch.Size([5542, 4])
torch.Size([5542, 7])
torch.Size([5542, 7])
torch.Size([5388, 4])
torch.Size([5388, 7])
torch.Size([5388, 7])
torch.Size([5178, 4])
torch.Size([5178, 7])
torch.Size([5178, 7])
torch.Size([5572, 4])
torch.Size([5572, 7])
torch.Size([5572, 7])
torch.Size([5211, 4])
torch.Size([5211, 7])
torch.Size([5211, 7])

















train loss = 1069.1783447:  24%|██▍       | 3/12.421875 [00:00<00:00, 22.04it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1069.1783447:  24%|██▍       | 3/12.421875 [00:00<00:00, 22.04it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1069.1783447:  48%|████▊     | 6/12.421875 [00:00<00:00, 21.70it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 645.4926758:  48%|████▊     | 6/12.421875 [00:00<00:00, 21.70it/s] [A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 645.4926758:  48%|████▊     | 6/12.421875 [00:00<00:00, 21.70it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 896.6610107:  48%|████▊     | 6/12.421875 [00:00<00:00, 21.70it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 896.6610107:  48%|████▊     | 6/12.421875 [00:00<00:00, 21.70it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train lo

torch.Size([5489, 4])
torch.Size([5489, 7])
torch.Size([5489, 7])
torch.Size([5469, 4])
torch.Size([5469, 7])
torch.Size([5469, 7])
torch.Size([5537, 4])
torch.Size([5537, 7])
torch.Size([5537, 7])
torch.Size([5205, 4])
torch.Size([5205, 7])
torch.Size([5205, 7])
torch.Size([5463, 4])
torch.Size([5463, 7])
torch.Size([5463, 7])

















train loss = 650.2779541:  64%|██████▍   | 8/12.421875 [00:00<00:00, 21.08it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 650.2779541:  64%|██████▍   | 8/12.421875 [00:00<00:00, 21.08it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 650.2779541:  89%|████████▊ | 11/12.421875 [00:00<00:00, 21.04it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 848.3253784:  89%|████████▊ | 11/12.421875 [00:00<00:00, 21.04it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 848.3253784:  89%|████████▊ | 11/12.421875 [00:00<00:00, 21.04it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 758.0457764:  89%|████████▊ | 11/12.421875 [00:00<00:00, 21.04it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 758.0457764:  89%|████████▊ | 11/12.421875 [00:00<00:00, 21.04it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train l

torch.Size([5633, 4])
torch.Size([5633, 7])
torch.Size([5633, 7])
torch.Size([2227, 4])
torch.Size([2227, 7])
torch.Size([2227, 7])
Epoch: 08, Training Loss:   807.6007
torch.Size([5075, 4])
torch.Size([5075, 7])
torch.Size([5075, 7])
torch.Size([5542, 4])
torch.Size([5542, 7])
torch.Size([5542, 7])
torch.Size([5388, 4])
torch.Size([5388, 7])
torch.Size([5388, 7])

















train loss = 829.9236450:  24%|██▍       | 3/12.421875 [00:00<00:00, 22.88it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 829.9236450:  24%|██▍       | 3/12.421875 [00:00<00:00, 22.88it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 712.2864990:  24%|██▍       | 3/12.421875 [00:00<00:00, 22.88it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 712.2864990:  24%|██▍       | 3/12.421875 [00:00<00:00, 22.88it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1043.4344482:  24%|██▍       | 3/12.421875 [00:00<00:00, 22.88it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1043.4344482:  24%|██▍       | 3/12.421875 [00:00<00:00, 22.88it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 1043.4344482:  48%|████▊     | 6/12.421875 [00:00<00:00, 22.72it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train los

torch.Size([5178, 4])
torch.Size([5178, 7])
torch.Size([5178, 7])
torch.Size([5572, 4])
torch.Size([5572, 7])
torch.Size([5572, 7])
torch.Size([5211, 4])
torch.Size([5211, 7])
torch.Size([5211, 7])
torch.Size([5489, 4])
torch.Size([5489, 7])
torch.Size([5489, 7])
torch.Size([5469, 4])
torch.Size([5469, 7])
torch.Size([5469, 7])

















train loss = 834.7162476:  48%|████▊     | 6/12.421875 [00:00<00:00, 22.72it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 834.7162476:  48%|████▊     | 6/12.421875 [00:00<00:00, 22.72it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 834.7162476:  72%|███████▏  | 9/12.421875 [00:00<00:00, 22.71it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 861.1276855:  72%|███████▏  | 9/12.421875 [00:00<00:00, 22.71it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 861.1276855:  72%|███████▏  | 9/12.421875 [00:00<00:00, 22.71it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 647.5491333:  72%|███████▏  | 9/12.421875 [00:00<00:00, 22.71it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss = 647.5491333:  72%|███████▏  | 9/12.421875 [00:00<00:00, 22.71it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A














train loss =

torch.Size([5537, 4])
torch.Size([5537, 7])
torch.Size([5537, 7])
torch.Size([5205, 4])
torch.Size([5205, 7])
torch.Size([5205, 7])
torch.Size([5463, 4])
torch.Size([5463, 7])
torch.Size([5463, 7])
torch.Size([5633, 4])
torch.Size([5633, 7])
torch.Size([5633, 7])
torch.Size([2227, 4])
torch.Size([2227, 7])
torch.Size([2227, 7])
Epoch: 09, Training Loss:   806.2351


In [82]:
@torch.no_grad()
def gen_in_out(model, loader, device):
    model.eval()
    input_fts = []
    reco_fts = []

    for t in loader:
        if isinstance(t, list):
            for d in t:
                input_fts.append(d.x)
        else:
            input_fts.append(t.x)
            t.to(device)

        reco_out = model(t)
        if isinstance(reco_out, tuple):
            reco_out = reco_out[0]
        reco_fts.append(reco_out.cpu().detach())

    input_fts = torch.cat(input_fts)
    reco_fts = torch.cat(reco_fts)
    return input_fts, reco_fts

def plot_reco_for_loader(model, loader, device, scaler, inverse_scale, model_fname, save_dir, feature_format):
    input_fts, reco_fts = gen_in_out(model, loader, device)
    if inverse_scale:
        input_fts = scaler.inverse_transform(input_fts)
        reco_fts = scaler.inverse_transform(reco_fts)
    plot_reco_difference(input_fts, reco_fts, model_fname, save_dir, feature_format)

    
def plot_reco_difference(input_fts, reco_fts, model_fname, save_path, feature='hadronic'):
    """
    Plot the difference between the autoencoder's reconstruction and the original input
    Args:
        input_fts (numpy array): the original features of the particles
        reco_fts (numpy array): the reconstructed features
        model_fname (str): name of saved model
    """
    
    if isinstance(input_fts, torch.Tensor):
        input_fts = input_fts.numpy()
    if isinstance(reco_fts, torch.Tensor):
        if feature == 'all':
            reco_fts = xyze_to_ptetaphi_torch(reco_fts)
        reco_fts = reco_fts.numpy()

        
    Path(save_path).mkdir(parents=True, exist_ok=True)
  #  label = ['$p_x~[GeV]$', '$p_y~[GeV]$', '$p_z~[GeV]$']
   # feat = ['px', 'py', 'pz']
    label = ['$p_x~[GeV]$', '$p_y~[GeV]$', '$p_z~[GeV]$']
    feat = ['px', 'py', 'pz']
    if feature == 'hadronic':# or 'standardized':
        label = ['$p_T$', '$eta$', '$phi$']
        feat = ['pt', 'eta', 'phi']
        
    if feature == 'all':# or 'standardized':
        label = ['$p_x~[GeV]$', '$p_y~[GeV]$', '$p_z~[GeV]$', '$E~[GeV]$','$p_T$', '$eta$', '$phi$']
        feat = ['px', 'py', 'pz','E','pt', 'eta', 'phi']
        
    # make a separate plot for each feature
    for i in range(input_fts.shape[1]):
        #plt.style.use(hep.style.CMS)
        plt.figure(figsize=(10,8))
        if feature == 'cartesian':
            bins = np.linspace(-20, 20, 101)
            if i == 3:  # different bin size for E momentum
                bins = np.linspace(-5, 35, 101)
        elif feature == 'hadronic':
            bins = np.linspace(-2, 2, 101)
            if i == 0:  # different bin size for pt rel
                bins = np.linspace(-0.05, 0.1, 101)
        elif feature == 'all':
            bins = np.linspace(-20, 20, 101)
            if i > 3:  # different bin size for hadronic coord
                bins = np.linspace(-2, 2, 101)
            if i == 3:  # different bin size for E momentum
                bins = np.linspace(-5, 35, 101)
            if i == 4:  # different bin size for pt rel
                bins = np.linspace(-2, 10, 101)
        else:
            bins = np.linspace(-1, 1, 101)
        plt.ticklabel_format(useMathText=True)
        plt.hist(input_fts[:,i], bins=bins, alpha=0.5, label='Input', histtype='step', lw=5)
        plt.hist(reco_fts[:,i], bins=bins, alpha=0.5, label='Output', histtype='step', lw=5)
        plt.legend(title='QCD dataset', fontsize='x-large')
        plt.xlabel(label[i], fontsize='x-large')
        plt.ylabel('Particles', fontsize='x-large')
        plt.tight_layout()
        plt.savefig(osp.join(save_path, feat[i] + '.pdf'))
        plt.close()
    

In [83]:
inverse_standardization = False
save_dir = '/eos/user/n/nchernya/MLHEP/AnomalyDetection/ADgvae/output_models/pytroch/'
plot_reco_for_loader(model, loader, device, scaler, inverse_standardization, 'test_train', osp.join(save_dir, 'reconstruction_post_train', 'train'), 'all')
