In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from nflows.distributions.base import Distribution
from sklearn.model_selection import train_test_split
from torch import optim
from torch.utils.data import DataLoader

from helpers import *


class StandardScaler:

    def __init__(self, mean=None, std=None, epsilon=1e-7):
        """Standard Scaler.
        The class can be used to normalize PyTorch Tensors using native functions. The module does not expect the
        tensors to be of any specific shape; as long as the features are the last dimension in the tensor, the module
        will work fine.
        :param mean: The mean of the features. The property will be set after a call to fit.
        :param std: The standard deviation of the features. The property will be set after a call to fit.
        :param epsilon: Used to avoid a Division-By-Zero exception.
        """
        self.mean = mean
        self.std = std
        self.epsilon = epsilon

    def fit(self, values):
        dims = list(range(values.dim() - 1))
        self.mean = torch.mean(values, dim=dims)
        self.std = torch.std(values, dim=dims)

    def transform(self, values):
        return (values - self.mean) / (self.std + self.epsilon)
    def inverse_transform(self,values):
        return (values *self.std)+self.mean
    def fit_transform(self, values):
        self.fit(values)
        return self.transform(values)
    def to(self,dev):
        self.std=self.std.to(dev)
        self.mean=self.mean.to(dev)
        return self
  
import pytorch_lightning as pl
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader


class JetNetDataloader(pl.LightningDataModule):
    '''This is more or less standard boilerplate coded that builds the data loader of the training
       one thing to note is the custom standard scaler that works on tensors
       Currently only jets with 30 particles are used but this maybe changes soon'''
    def __init__(self,config,batch_size): 
        super().__init__()
        self.config=config
        self.n_dim=config["n_dim"]
        self.n_part=config["n_part"]
        self.batch_size=batch_size
    def setup(self,stage):
    # This just sets up the dataloader, nothing particularly important. it reads in a csv, calculates mass and reads out the number particles per jet
    # And adds it to the dataset as variable. The only important thing is that we add noise to zero padded jets
        data_dir=os.environ["HOME"]+"/JetNet_NF/train_{}_jets.csv".format(self.config["parton"])
        data=pd.read_csv(data_dir,sep=" ",header=None)
        jets=[]
        limit=int(self.config["limit"]*1.1)
        for njets in range(1,31):
            masks=np.sum(data.values[:,np.arange(3,120,4)],axis=1)
            df=data.loc[masks==njets,:]
            df=df.drop(np.arange(3,120,4),axis=1)
            df["n"]=njets
            if len(df)>100:
                jets.append(df[:self.config["limit"]])
        #stacking together differnet samples with different number particles per jet
        self.n=torch.empty((0,1))
        self.data=torch.empty((0,self.n_dim*self.n_part))
        for i in range(len(jets)):
            x=torch.tensor(jets[i].values[:,:self.n_dim*self.n_part]).float()
            n=torch.tensor(jets[i]["n"].values).float()
            self.data=torch.vstack((self.data,x))
            self.n=torch.vstack((self.n.reshape(-1,1),n.reshape(-1,1)))        
        
      
        # calculating mass per jet
#         self.m=mass(self.data[:,:self.n_dim]).reshape(-1,1)  
      # Adding noise to zero padded jets.
        for i in torch.unique(self.n):
            i=int(i)
            self.data[self.data[:,-1]==i,3*i:90]=torch.normal(mean=torch.zeros_like(self.data[self.data[:,-1]==i,3*i:90]),std=1).abs()*1e-7
        #standard scaling 
        self.scaler=StandardScaler()
#         self.data=torch.hstack((self.data,self.m))        
        self.scaler.fit(self.data)
        self.data=self.scaler.transform(self.data)
#         self.min_m=self.scaler.transform(torch.zeros((1,self.n_dim+1)))[0,-1]
# #         self.data=torch.hstack((self.data,self.n))
        
#         #calculating mass dist in different bins, this is needed for the testcase where we need to generate the conditoon
#         if self.config["variable"]:
#             self.mdists={}
#             for i in torch.unique(self.n):
#                 self.mdists[int(i)]=F(self.data[self.n[:,0]==i,-2])    
        self.data,self.test_set=train_test_split(self.data.cpu().numpy(),test_size=0.3)
        
#         self.n_train=self.data[:,-1]
#         self.n_test=self.test_set[:,-1]
        
            
        self.test_set=torch.tensor(self.test_set).float()
        self.data=torch.tensor(self.data).float()
        self.num_batches=len(self.data)//self.config["batch_size"]
#         assert self.data.shape[1]==92
        assert (torch.isnan(self.data)).sum()==0
    def train_dataloader(self):
        return DataLoader(self.data, batch_size=self.batch_size,drop_last=True)

    def val_dataloader(self):
        return DataLoader(self.test_set, batch_size=len(self.test_set),drop_last=True)


thx max
good boy


In [58]:

import traceback
import os
import nflows as nf
from nflows.utils.torchutils import create_random_binary_mask
from nflows.transforms.base import CompositeTransform
from nflows.transforms.coupling import *
from nflows.nn import nets
from nflows.flows.base import Flow
from nflows.flows import base
from nflows.transforms.coupling import *
from nflows.transforms.autoregressive import *
from particle_net import ParticleNet
import pytorch_lightning as pl
from torch.optim.lr_scheduler import LambdaLR,ReduceLROnPlateau,ExponentialLR
import torch
from torch import nn
from torch.nn import functional as FF
import numpy as np
from jetnet.evaluation import w1p, w1efp, w1m, cov_mmd,fpnd
import mplhep as hep
import hist
from hist import Hist
from pytorch_lightning.loggers import TensorBoardLogger
from collections import OrderedDict
from ray import tune
from helpers import *
from plotting import *
import pandas as pd
import os
from helpers import CosineWarmupScheduler
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
import pandas as pd
import time
from torch.nn.functional import leaky_relu,sigmoid
from pl_bolts.optimizers import LinearWarmupCosineAnnealingLR
class Gen(nn.Module):
    
    def __init__(self,n_dim=3,l_dim=10,hidden=300,num_layers=3,num_heads=1,n_part=5,fc=False,dropout=0.5,no_hidden=True):
        super().__init__()
        self.hidden_nodes=hidden
        self.n_dim=n_dim
        self.l_dim=l_dim
        self.n_part=n_part
        self.no_hidden=no_hidden
       
        self.fc=fc
        if fc:
            self.l_dim*=n_part 
            self.embbed_flat=nn.Linear(n_dim*n_part,l_dim)
            self.flat_hidden=nn.Linear(l_dim,hidden)
            self.flat_hidden2=nn.Linear(hidden,hidden)
            self.flat_hidden3=nn.Linear(hidden,hidden)
            self.flat_out=nn.Linear(hidden,n_dim*n_part)
        else:
            self.embbed=nn.Linear(n_dim,l_dim)
            self.encoder=nn.TransformerEncoder(nn.TransformerEncoderLayer(
            d_model=l_dim,nhead=num_heads,batch_first=True,norm_first=False
            ,dim_feedforward=hidden,dropout=dropout) ,num_layers=num_layers)
            self.hidden=nn.Linear(l_dim,hidden)
            self.hidden2=nn.Linear(hidden,hidden)
            self.dropout=nn.Dropout(dropout/2)
            self.out=nn.Linear(hidden, n_dim)
            self.out2=nn.Linear(l_dim, n_dim)

            self.out_flat=nn.Linear(hidden,n_dim*n_part )
        
    def forward(self,x):

        if self.fc:
            x=x.reshape(len(x),self.n_part*self.n_dim)
            x=self.embbed_flat(x)
            x=leaky_relu(self.flat_hidden(x))
#             x = self.dropout(x)
            x=self.flat_out(x)
            x=x.reshape(len(x),self.n_part,self.n_dim)
        else:
            x=self.embbed(x)
            x=self.encoder(x)
            if not self.no_hidden:
                
                x=leaky_relu(self.hidden(x))
                x=self.dropout(x)
                x=leaky_relu(self.hidden2(x))
                x=self.dropout(x)
                x=self.out(x)
            else:
                x=leaky_relu(x)
                x=self.out2(x)
        return x

class Disc(nn.Module):
    def __init__(self,n_dim=3,l_dim=10,hidden=300,num_layers=3,num_heads=1,n_part=2,fc=False,dropout=0.5,mass=False,clf=False):
        super().__init__()
        self.hidden_nodes=hidden
        self.n_dim=n_dim
#         l_dim=n_dim
        self.l_dim=l_dim
        self.n_part=n_part
        self.fc=fc
        self.clf=clf
        
        if fc:
            self.l_dim*=n_part 
            self.embbed_flat=nn.Linear(n_dim*n_part,l_dim)
            self.flat_hidden=nn.Linear(l_dim,hidden)
            self.flat_hidden2=nn.Linear(hidden,hidden)
            self.flat_hidden3=nn.Linear(hidden,hidden)
            self.flat_out=nn.Linear(hidden,1)
        else:
            self.embbed=nn.Linear(n_dim,l_dim)
            self.encoder=nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=self.l_dim,nhead=num_heads,dim_feedforward=hidden,dropout=dropout,norm_first=False,
                                       activation=lambda x:leaky_relu(x,0.2),batch_first=True) ,num_layers=num_layers)
            self.hidden=nn.Linear(l_dim+int(mass),2*hidden)
            self.hidden2=nn.Linear(2*hidden,hidden)
            self.out=nn.Linear(hidden,1)

    def forward(self,x,m=None):

        if self.fc==True:
            x=x.reshape(len(x),self.n_dim*self.n_part)
            x=self.embbed_flat(x)
            x=leaky_relu(self.flat_hidden(x),0.2)
            x=leaky_relu(self.flat_hidden2(x),0.2)
            x=self.flat_out(x)
        else:
            x=self.embbed(x)
            if self.clf:
                x=torch.concat((torch.ones_like(x[:,0,:]).reshape(len(x),1,-1),x),axis=1)
                x=self.encoder(x)
                x=x[:,0,:]
            else:
                x=self.encoder(x)
                x=torch.sum(x,axis=1)
            if m is not None:
                x=torch.concat((m.reshape(len(x),1),x),axis=1)

            x=leaky_relu(self.hidden(x),0.2)
            
            x=leaky_relu(self.hidden2(x),0.2)
            
            x=self.out(x)
            x=x
        return x
      
class TransGan(pl.LightningModule):
    def create_resnet(self,in_features, out_features):
        '''This is the network that outputs the parameters of the invertible transformation
        The only arguments can be the in dimension and the out dimenson, the structure
        of the network is defined over the config which is a class attribute
        Context Features: Amount of features used to condition the flow - in our case 
        this is usually the mass
        num_blocks: How many Resnet blocks should be used, one res net block is are 1 input+ 2 layers
        and an additive skip connection from the first to the third'''
        c=self.config["context_features"]
        return nets.ResidualNet(
                in_features,
                out_features,
                hidden_features=self.config["network_nodes_nf"],
                context_features=c,
                num_blocks=self.config["network_layers_nf"],
                activation=self.config["activation"]  if "activation" in self.config.keys() else FF.relu,
                #dropout_probability=self.config["dropout"] if "dropout" in self.config.keys() else 0,
                use_batch_norm=self.config["batchnorm"] if "batchnorm" in self.config.keys() else 0
        )

    def __init__(self,config,hyperopt,num_batches):
        '''This initializes the model and its hyperparameters'''
        super().__init__()
        self.hyperopt=True
        
        self.start=time.time()
        # self.batch_size=batch_size
        # print(batch_size)
        self.config=config
        self.automatic_optimization=False
        self.freq_d=config["freq"]
        
        self.wgan=config["wgan"]
        #Metrics to track during the training
        self.metrics={"val_w1p":[],"val_w1m":[],"val_w1efp":[],"val_cov":[],"val_mmd":[],"val_fpnd":[],"val_logprob":[],"step":[]}
        #Loss function of the Normalizing flows
        self.logprobs=[]
        self.n_part=config["n_part"]
        # self.hparams.update(config)
        self.save_hyperparameters()
        self.flows = []
        self.n_dim=self.config["n_dim"]
        self.n_part=config["n_part"]
        self.add_corr=config["corr"]
        self.alpha=1
        self.num_batches=int(num_batches)
        K=self.config["coupling_layers"]
        for i in range(K):
            '''This creates the masks for the coupling layers, particle masks are masks
            created such that each feature particle (eta,phi,pt) is masked together or not'''
           
            if self.config["autoreg"]:
                self.flows += [MaskedPiecewiseRationalQuadraticAutoregressiveTransform(
#                         random_mask=True,
                        features=self.n_dim,
                        hidden_features=128,
                        use_residual_blocks=True, 
                        tails='linear',
                        tail_bound=self.config["tail_bound"],
                        num_bins=self.config["bins"] )]
            else:
                mask=create_random_binary_mask(self.n_dim*self.n_part)            
                self.flows += [PiecewiseRationalQuadraticCouplingTransform(
                            mask=mask,
                            transform_net_create_fn= self.create_resnet, 
                            tails='linear',
                            tail_bound=self.config["tail_bound"],
                            num_bins=self.config["bins"] )]

        self.q0 = nf.distributions.normal.StandardNormal([self.n_dim*self.n_part])
        self.q_test =nf.distributions.normal.StandardNormal([self.n_dim*self.n_part])
        #Creates working flow model from the list of layer modules
        self.flows=CompositeTransform(self.flows)
        # Construct flow model
        self.flow_test= base.Flow(distribution=self.q_test, transform=self.flows)
        self.flow = base.Flow(distribution=self.q0, transform=self.flows)
        

        self.gen_net = Gen(n_dim=self.n_dim,hidden=config["hidden"],num_layers=config["num_layers"],dropout=config["dropout"],no_hidden=config["no_hidden"],fc= config["fc"],n_part=config["n_part"],l_dim=config["l_dim"],num_heads=config["heads"]).cuda()
        self.dis_net = Disc(n_dim=self.n_dim,hidden=config["hidden"],l_dim=config["l_dim"],num_layers=config["num_layers"],mass=self.config["mass"],  num_heads=config["heads"],fc=config["fc"],n_part=config["n_part"],dropout=config["dropout"],clf=config["clf"]).cuda()
        self.sig=nn.Sigmoid()
        for p in self.dis_net.parameters():
            if p.dim() > 1:
                nn.init.xavier_normal(p)
        self.nf_train=True
        self.train_nf=config["max_epochs"]//40
    

    def load_datamodule(self,data_module):
        '''needed for lightning training to work, it just sets the dataloader for training and validation'''
        self.data_module=data_module
        
    def on_after_backward(self) -> None:
        '''This is a genious little hook, sometimes my model dies, i have no clue why. This saves the training from crashing and continues'''
        valid_gradients = False
        for name, param in self.named_parameters():
            if param.grad is not None:
                valid_gradients = not (torch.isnan(param.grad).any() or torch.isinf(param.grad).any())
                if not valid_gradients:
                    break
        if not valid_gradients:
            print("not valid grads",self.counter)
            self.zero_grad()
            self.counter+=1
            if self.counter>5:
                raise ValueError('5 nangrads in a row')
        else:
            self.counter=0

    def sampleandscale(self,batch,scale=False):
        '''This is a helper function that samples from the flow (i.e. generates a new sample) 
            and reverses the standard scaling that is done in the preprocessing. This allows to calculate the mass
            on the generative sample and to compare to the simulated one, we need to inverse the scaling before calculating the mass
            because calculating the mass is a non linear transformation and does not commute with the mass calculation''' 
        z=self.flow.sample(len(batch)).reshape(len(batch),self.n_part,self.n_dim)
        if self.add_corr: 
            fake=z+self.gen_net(z)#(1-self.alpha)*
            fake=fake.reshape(len(batch),self.n_part,self.n_dim)
        else:
            fake=self.gen_net(z)
        assert batch.device==fake.device

        if scale:
            self.data_module.scaler=self.data_module.scaler.to(batch.device)

            fake_scaled=self.data_module.scaler.inverse_transform(fake.reshape(len(batch),self.n_dim*self.n_part))
            z_scaled=self.data_module.scaler.inverse_transform(z.reshape(len(batch),self.n_dim*self.n_part))
            true=self.data_module.scaler.inverse_transform(batch)
            return fake,batch,z,fake_scaled,true,z_scaled
        else:
            return fake
        
    def configure_optimizers(self):
        
        self.losses=[]
        
        #mlosses are initialized with None during the time it is not turned on, makes it easier to plot
        
        opt_nf = torch.optim.AdamW(self.flow.parameters(), lr=self.config["lr_nf"] )
        if self.config["opt"]=="Adam":

            opt_g = torch.optim.Adam(self.gen_net.parameters(), lr=self.config["lr_g"],betas=(0,.9))
            opt_d = torch.optim.Adam(self.dis_net.parameters(), lr=self.config["lr_d"],betas=(0,.9))
        elif self.config["opt"]=="AdamW":
            opt_g = torch.optim.AdamW(self.gen_net.parameters(), lr=self.config["lr_g"],betas=(0,.9))
            opt_d = torch.optim.AdamW(self.dis_net.parameters(), lr=self.config["lr_d"],betas=(0,.9))
        elif self.config["opt"]=="SGD":
            opt_g = torch.optim.SGD(self.gen_net.parameters(), lr=self.config["lr_g"])
            opt_d = torch.optim.SGD(self.dis_net.parameters(), lr=self.config["lr_d"])
        else:
            opt_g = torch.optim.RMSprop(self.gen_net.parameters(), lr=self.config["lr_g"])
            opt_d = torch.optim.RMSprop(self.dis_net.parameters(), lr=self.config["lr_d"])
        if self.config["sched"]=="cosine":   
            lr_scheduler_nf=CosineWarmupScheduler(opt_nf,warmup=1,max_iters=10000000*self.config["freq"]) 
            lr_scheduler_d=CosineWarmupScheduler(opt_d,warmup=15*self.num_batches,max_iters=(self.config["max_epochs"]-self.train_nf//2)*self.num_batches)
            lr_scheduler_g=CosineWarmupScheduler(opt_g,warmup=15*self.num_batches,max_iters=(self.config["max_epochs"]-self.train_nf)*self.num_batches//self.freq_d)
        elif self.config["sched"]=="cosine2":   
            lr_scheduler_nf=CosineWarmupScheduler(opt_nf,warmup=1,max_iters=10000000*self.config["freq"]) 
            lr_scheduler_d=CosineWarmupScheduler(opt_d,warmup=15*self.num_batches,max_iters=(self.config["max_epochs"]-self.train_nf//2)*self.num_batches//2)
            lr_scheduler_g=CosineWarmupScheduler(opt_g,warmup=15*self.num_batches,max_iters=(self.config["max_epochs"]-self.train_nf)*self.num_batches//self.freq_d//2)
        else:
            lr_scheduler_nf =None 
            lr_scheduler_d =None    
            lr_scheduler_g =None
        if self.config["sched"] !=None:
            return  [opt_nf,opt_d,opt_g],[lr_scheduler_nf,lr_scheduler_d,lr_scheduler_g]
        else:
            return [opt_nf,opt_d,opt_g] 
    
    
    def compute_gradient_penalty(self,D, real_samples, fake_samples, phi):
        """Calculates the gradient penalty loss for WGAN GP"""
        # Random weight term for interpolation between real and fake samples
        alpha = torch.Tensor(np.random.random((real_samples.size(0),1, 1))).to(real_samples.device)
        interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples))
        if self.config["mass"]:
            m=mass(interpolates.reshape(len(real_samples),self.n_part*self.n_dim).detach())
            d_interpolates = D.train()(interpolates.requires_grad_(True),m.requires_grad_(True) )
        else:
            d_interpolates = D.train()(interpolates.requires_grad_(True))
        fake = torch.ones([real_samples.shape[0], 1], requires_grad=False).to(real_samples.device)
        # Get gradient w.r.t. interpolates
        gradients = torch.autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            grad_outputs=fake,
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )[0]
        gradients = gradients.view(gradients.size(0), -1)
        gradient_penalty = ((gradients.norm(2, dim=1) - phi) ** 2).mean()
        return gradient_penalty


    def _summary(self,temp):
        self.summary_path="/beegfs/desy/user/{}/{}/summary.csv".format(os.environ["USER"],self.config["name"])
        if os.path.isfile(self.summary_path):
            
            summary=pd.read_csv(self.summary_path).set_index(["path_index"])
        else:
            print("summary not found")
            summary=pd.DataFrame()

            
        summary.loc[self.logger.log_dir,self.config.keys()]=self.config.values()
        summary.loc[self.logger.log_dir,temp.keys()]=temp.values()
        summary.loc[self.logger.log_dir,"time"]=time.time()-self.start          
        summary.to_csv(self.summary_path,index_label=["path_index"])  
        return summary
    
    # def _results(self,temp):
    #     self.metrics["step"].append(self.global_step)
    #     self.df=pd.DataFrame.from_dict(temp,index=)
    #     self.df.to_csv(self.logger.log_dir+"result.csv",index_label=["index"])
    
   
    
    def training_step(self, batch, batch_idx):
        """training loop of the model, here all the data is passed forward to a gaussian
            This is the important part what is happening here. This is all the training we do """
        
        opt_nf,opt_d,opt_g=self.optimizers()
        if self.config["sched"]:
            sched_nf,sched_d,sched_g=self.lr_schedulers()
            

        nf_loss=0
        d_loss_avg=0
        gradient_penalty=0
        
        
        ### NF PART
        if self.config["sched"]!=None:
            self.log("lr_g",sched_g.get_last_lr()[-1],logger=True)
            self.log("lr_nf",sched_nf.get_last_lr()[-1],logger=True)
            self.log("lr_d",sched_d.get_last_lr()[-1],logger=True)
            
        if self.current_epoch<self.train_nf:
            if self.config["sched"]!=None:
                sched_nf.step()
            nf_loss -=self.flow.to(self.device).log_prob(batch).mean()#c if self.config["context_features"] else None
            nf_loss/=(self.n_dim*self.n_part) 
            opt_nf.zero_grad()
            self.manual_backward(nf_loss)
            opt_nf.step()
            self.log("logprob", nf_loss, on_step=True, on_epoch=False, prog_bar=True, logger=True) 

        ### GAN PART
        if self.current_epoch>=self.train_nf/2 or self.global_step==1:
            
            batch=batch.reshape(len(batch),self.n_part,self.n_dim)
            fake=self.sampleandscale(batch,scale=False)
#             fake=torch.rand((len(batch),self.n_part,self.n_dim)).cuda()+10
            if self.config["mass"]:
                m_t=mass(batch.reshape(len(batch),self.n_part*self.n_dim),self.config["canonical"])
                m_f=mass(fake.reshape(len(batch),self.n_part*self.n_dim),self.config["canonical"])

            # if self.current_epoch>pretrain/2:
                
            pred_real=self.dis_net(batch,None if not self.config["mass"] else m_t)
            pred_fake=self.dis_net(fake.detach(),None if not self.config["mass"] else m_f.detach())
            if self.wgan:
                gradient_penalty =  self.compute_gradient_penalty(self.dis_net, batch, fake.detach(),1)
                self.log("gradient penalty",gradient_penalty,logger=True)
                d_loss=-torch.mean(pred_real.view(-1))+torch.mean(pred_fake.view(-1))+self.config["lambda"]*gradient_penalty
            else:               
                target_real=torch.ones_like(pred_real)
                target_fake=torch.zeros_like(pred_fake)
                pred=torch.vstack((pred_real,pred_fake))
                target=torch.vstack((target_real,target_fake))
                d_loss=nn.MSELoss()(pred,target).mean()
            opt_d.zero_grad()
            self.manual_backward(d_loss)
            opt_d.step() if self.global_step>10 else opt_d.zero_grad()

            self.log("dloss",d_loss,logger=True,prog_bar=True)
            if self.global_step==2:
                print("passed test disc")
            #self.logger.experiment.add_scalars("d_losses",{"train_disc":d_loss_avg},global_step=self.global_step)
            if self.config["sched"]:
                sched_d.step()
            
            
            if (self.current_epoch>self.train_nf and self.global_step%self.freq_d<2) or self.global_step==2 :
                opt_g.zero_grad()
                if self.config["mass"]:
                    m_t=mass(batch.reshape(len(batch),self.n_part*self.n_dim),self.config["canonical"])
                    m_f=mass(fake.reshape(len(batch),self.n_part*self.n_dim),self.config["canonical"])
                pred_fake=self.dis_net(fake,None if not self.config["mass"] else m_f)
                target_real=torch.ones_like(pred_fake)
                if self.wgan:
                    g_loss=-torch.mean(pred_fake.view(-1))
                else:
                    g_loss=nn.MSELoss()((pred_fake.view(-1)),target_real.view(-1))
                self.manual_backward(g_loss)
#                 if self.global_step>10:
                opt_g.step() if self.global_step>10 else opt_g.zero_grad()
#                 else:
#                     opt_g.zero_grad()
                #self.logger.experiment.add_scalars("g_losses",{"train_gen":g_loss},global_step=self.global_step)
                self.log("g_loss",g_loss,logger=True,prog_bar=True)
                if self.config["sched"]:
                    sched_g.step()
                if self.global_step==3:
                    print("passed test gen")

            # Control plot train
            if self.current_epoch%config["val_check"]==0 and self.current_epoch>self.train_nf/2 :
                fig,ax=plt.subplots()
                ax.hist(pred_fake.detach().cpu().numpy(),label="fake",bins=np.linspace(0,1,30) if not self.wgan else 30,histtype='step')
                ax.hist(pred_real.detach().cpu().numpy(),label="real",bins=np.linspace(0,1,30) if not self.wgan else 30,histtype='step')
                ax.legend()
                plt.ylabel("Counts")
                plt.xlabel("Critic Score")
                self.logger.experiment.add_figure("class_train",fig,global_step=self.current_epoch)  
                plt.close()
        
    
    def validation_step(self, batch, batch_idx):
        '''This calculates some important metrics on the hold out set (checking for overtraining)'''

        self.dis_net.train()
        self.gen_net.train()
        self.data_module.scaler.to("cpu")  
        batch=batch.to("cpu")
        self.flow=self.flow.to("cpu")
        self.dis_net=self.dis_net.cpu()
        self.gen_net=self.gen_net.cpu()
        c=None       

        with torch.no_grad():
            logprob=-self.flow.log_prob(batch).mean()/90

            gen,true,z,fake_scaled,true_scaled,z_scaled=self.sampleandscale(batch,scale=True)
#             gen=torch.rand((len(batch),self.n_part,self.n_dim)).cpu()+10
            
            if self.config["mass"]:
                m_t=mass(batch.reshape(len(batch),self.n_part*self.n_dim),self.config["canonical"])
                m_f=mass(gen.reshape(len(batch),self.n_part*self.n_dim),self.config["canonical"])
            scores_fake=self.dis_net(gen,None if not self.config["mass"] else m_f)
            scores_real=self.dis_net(true.reshape(len(batch),self.n_part,self.n_dim),None if not self.config["mass"] else m_t)
            #scores_nf=self.dis_net(z.reshape(len(batch),self.n_part,self.n_dim))
        bins=50
        # if not self.wgan:
        #     bins=np.linspace(0,1,bins)
        #     #scores_nf=nn.Sigmoid()(scores_nf)
        #     scores_real=nn.Sigmoid()(scores_real)
        #     scores_fake=nn.Sigmoid()(scores_fake)
        #     real_loss=nn.BCELoss()(scores_real,torch.ones_like(scores_nf))
        #     fake_loss_gan=nn.BCELoss()(scores_fake,torch.zeros_like(scores_nf))
        #     #fake_loss_nf=nn.BCELoss()(scores_nf,torch.zeros_like(scores_nf))
        #     g_loss=nn.BCELoss()(scores_fake,torch.ones_like(scores_nf))
        #     num=2
        #     d_loss=(real_loss+fake_loss_gan)/num
        # else:
        #     d_loss=-torch.mean(scores_real.view(-1))+torch.mean(scores_fake.view(-1))#+10*gradient_penalty
        #     g_loss=-torch.mean(scores_fake.view(-1))
        
        fig=plt.figure()

        _,bins,_=plt.hist(scores_real.numpy(),bins=bins,label="MC simulated",alpha=0.5)
        plt.hist(scores_fake.numpy(),bins=bins,label="ML generated",alpha=0.5)
        plt.xlabel("Critic Score")
        plt.ylabel("Counts")
        plt.legend()
        self.logger.experiment.add_figure("class_val",fig,global_step=self.current_epoch)  
        plt.close()
        true_scaled,fake_scaled,z_scaled=true_scaled.reshape(-1,90),fake_scaled.reshape(-1,90),z_scaled.reshape(-1,90)
        # Reverse Standard Scaling (this has nothing to do with flows, it is a standard preprocessing step)
        m_t=mass(true_scaled[:,:self.n_dim*self.n_part].to(self.device),self.config["canonical"]).cpu()
        m_gen=mass(z_scaled[:,:self.n_dim*self.n_part],self.config["canonical"]).cpu()
        m_c=mass(fake_scaled[:,:self.n_dim*self.n_part],self.config["canonical"]).cpu()
        for i in range(30):
            i=2+3*i
            # gen[gen[:,i]<0,i]=0
            fake_scaled[fake_scaled[:,i]<0,i]=0
            true_scaled[true_scaled[:,i]<0,i]=0
        #Some metrics we track
        cov,mmd=cov_mmd(fake_scaled.reshape(-1,self.n_part,self.n_dim),true_scaled.reshape(-1,self.n_part,self.n_dim),use_tqdm=False)
        try:
            fpndv=fpnd(fake_scaled.reshape(-1,self.n_part,self.n_dim).numpy(),use_tqdm=False,jet_type=self.config["parton"])
        except:
            fpndv=1000
        self.metrics["val_fpnd"].append(fpndv)
        self.metrics["val_logprob"].append(logprob)
        self.metrics["val_mmd"].append(mmd)
        self.metrics["val_cov"].append(cov)
        self.metrics["val_w1p"].append(w1p(fake_scaled.reshape(len(batch),self.n_part,self.n_dim),true_scaled.reshape(len(batch),self.n_part,self.n_dim)))
        w1m_=w1m(fake_scaled.reshape(len(batch),self.n_part,self.n_dim),true_scaled.reshape(len(batch),self.n_part,self.n_dim))
        if w1m_[0]>0.01 and self.current_epoch>100 and not self.config["sched"]=="cosine2":
            print("no convergence, stop training")
            raise
        self.metrics["val_w1m"].append(w1m_)
        self.metrics["val_w1efp"].append(w1efp(fake_scaled.reshape(len(batch),self.n_part,self.n_dim),true_scaled.reshape(len(batch),self.n_part,self.n_dim)))
        
        
        temp={"val_logprob":float(logprob.numpy()),"val_fpnd":fpndv,"val_mmd":mmd,"val_cov":cov,"val_w1m":self.metrics["val_w1m"][-1][0],"val_w1efp":self.metrics["val_w1efp"][-1][0],"val_w1p":self.metrics["val_w1p"][-1][0],"step":self.global_step}
        
        print("epoch {}: ".format(self.current_epoch),temp)
        if self.hyperopt and self.global_step>3:
            # self._results(temp)
            summary=self._summary(temp)
        self.log("hp_metric",self.metrics["val_w1m"][-1][0],on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log("val_w1m",self.metrics["val_w1m"][-1][0],on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log("val_w1p",self.metrics["val_w1p"][-1][0],on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log("val_w1efp",self.metrics["val_w1efp"][-1][0],on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log("val_logprob",logprob,prog_bar=True,logger=True)
        self.log("val_cov",cov,prog_bar=True,logger=True,on_step=False, on_epoch=True)
        self.log("val_fpnd",fpndv,prog_bar=True,logger=True,on_step=False, on_epoch=True)
        self.log("val_mmd",mmd,prog_bar=True,logger=True,on_step=False, on_epoch=True)

        self.plot=plotting(model=self,gen=z_scaled,gen_corr=fake_scaled,true=true_scaled,config=self.config,step=self.global_step,logger=self.logger.experiment)  
        try:
            self.plot.plot_mass(m=m_gen.cpu().numpy(),m_t=m_t.cpu().numpy(),m_c=m_c.cpu().numpy(),save=True,bins=50,quantile=True,plot_vline=False)
            # self.plot.plot_2d(save=True)
#             self.plot.var_part(true=true[:,:self.n_dim],gen=gen_corr[:,:self.n_dim],true_n=n_true,gen_n=n_gen_corr,m_true=m_t,m_gen=m_test ,save=True)
        except Exception as e:
            traceback.print_exc() 
        self.flow=self.flow.to("cuda")
        self.gen_net=self.gen_net.to("cuda")
        self.dis_net=self.dis_net.to("cuda")
    


In [None]:
import sys  
hyperopt = True  # This sets to run a hyperparameter optimization with ray or just running the training once

config = {
    "autoreg":False,
    "context_features":0,
   "network_layers": 3,  # sets amount hidden layers in transformation networks -scannable
    "network_layers_nf": 2,  # sets amount hidden layers in transformation networks -scannable
    "network_nodes_nf": 256,  # amount nodes in hidden layers in transformation networks -scannable
    "batch_size": 2000,  # sets batch size -scannable #best one 4000
    "coupling_layers": 15,  # amount of invertible transformations to use -scannable
    "lr": 0.001,  # sets learning rate -scannable
    "batchnorm": False,  # use batchnorm or not -scannable
    "bins":5,  # amount of bins to use in rational quadratic splines -scannable
    "tail_bound": 6,  # splines:max value that is transformed, over this value theree is id  -scannable
    "limit": 150000,  # how many data points to use, test_set is 10% of this -scannable in a sense use 10 k for faster training
    "n_dim": 3,  # how many dimensions to use or equivalently /3 gives the amount of particles to use NEVER EVER CHANGE THIS
    "dropout": 0.2,  # use droput proportion, for 0 there is no dropout -scannable
    "canonical": False,  # transform data coordinates to px,py,pz -scannable
    "max_steps": 100000,  # how many steps to use at max - lower for quicker training
    "lambda": 10,  # balance between massloss and nll -scannable
    "name": "Transflow_final",  # name for logging folder
    "disc": False,  # whether to train gan style discriminator that decides whether point is simulated or generated-semi-scannable
    "variable":1, #use variable amount of particles otherwise only use 30, options are true or False 
    "parton":"t", #choose the dataset you want to train options: t for top,q for quark,g for gluon
    "wgan":False,
    "corr":True,
    "num_layers":5,
    "freq":10,
    "n_part":30,
    "fc":False,
    "hidden":80,
    "heads":3,
    "l_dim":63,
    "lr_g":1e-4,
    "lr_d":1e-4,
    "lr_nf":0.000722,
    "sched":None,
    "opt":"Adam",
    "lambda":1,
    "max_epochs":1600,
    "mass":True,
    "no_hidden":True,
    "clf":True,
    "val_check":50
}     

# if len(sys.argv)>2:
#     root="/beegfs/desy/user/"+os.environ["USER"]+"/"+config["name"]+"/run"+sys.argv[1]+"_"+str(sys.argv[2])
# else:
root="/beegfs/desy/user/"+os.environ["USER"]+"/"+config["name"]

            
data_module = JetNetDataloader(config,config["batch_size"]) #this loads the data
data_module.setup("training")
model = TransGan(config,hyperopt,data_module.num_batches) # the sets up the model,  config are hparams we want to optimize
model.data_module=data_module
# Callbacks to use during the training, we  checkpoint our models

callbacks = [ModelCheckpoint(monitor="val_w1m",save_top_k=2, filename='{epoch}-{val_fpnd:.2f}-{val_w1m:.4f}', dirpath=root,every_n_epochs=10) ]

if False:#load_ckpt:
    model = TransGan.load_from_checkpoint("/beegfs/desy/user/kaechben/Transflow_reloaded2/2022_08_08-18_02-08/epoch=239-val_logprob=0.47-val_w1m=0.0014.ckpt")
    model.data_module=data_module

# pl.seed_everything(model.config["seed"], workers=True)
# model.config["freq"]=20
# model.config["lr_g"]=0.00001
# model.config["lr_d"]=0.00001
# model.config = config #config are our hyperparams, we make this a class property now
logger = TensorBoardLogger(root)
#log every n steps could be important as it decides how often it should log to tensorboard
# Also check val every n epochs, as validation checking takes some time

trainer = pl.Trainer(gpus=1, logger=logger,  log_every_n_steps=5,  # auto_scale_batch_size="binsearch",
                      max_epochs=config["max_epochs"], callbacks=callbacks, progress_bar_refresh_rate=True,
                      check_val_every_n_epoch=config["val_check"] ,num_sanity_val_steps=0,#gradient_clip_val=.02, gradient_clip_algorithm="norm",
                     fast_dev_run=False,default_root_dir=root)
# This calls the fit function which trains the model

trainer.fit(model,datamodule=data_module )

  nn.init.xavier_normal(p)
  rank_zero_deprecation(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type               | Params
-------------------------------------------------
0 | q0        | StandardNormal     | 0     
1 | q_test    | StandardNormal     | 0     
2 | flows     | CompositeTransform | 6.6 M 
3 | flow_test | Flow               | 6.6 M 
4 | flow      | Flow               | 6.6 M 
5 | gen_net   | Gen                | 95.7 K
6 | dis_net   | Disc               | 95.3 K
7 | sig       | Sigmoid            | 0     
-------------------------------------------------
6.8 M     Trainable params
0         Non-trainable params
6.8 M     Total params
27.006    Total estimated model params size (MB)
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



epoch 49:  {'val_logprob': 0.5185331106185913, 'val_fpnd': 32.4072516748474, 'val_mmd': 0.08661059372852516, 'val_cov': 0.257, 'val_w1m': 0.015798835202362388, 'val_w1efp': 0.00027832312137666856, 'val_w1p': 0.023252358801423956, 'step': 3763}


Validation: 0it [00:00, ?it/s]



epoch 99:  {'val_logprob': 0.5185331106185913, 'val_fpnd': 19.596721754712007, 'val_mmd': 0.07786583962330493, 'val_cov': 0.391, 'val_w1m': 0.00920265948953107, 'val_w1efp': 0.0001609642511678173, 'val_w1p': 0.010758718508268638, 'step': 6707}


Validation: 0it [00:00, ?it/s]



epoch 149:  {'val_logprob': 0.5185331106185913, 'val_fpnd': 13.972423818117846, 'val_mmd': 0.07885658372865456, 'val_cov': 0.45999999999999996, 'val_w1m': 0.005590532812271268, 'val_w1efp': 0.00015088370942319623, 'val_w1p': 0.014601729460354049, 'step': 9652}


Validation: 0it [00:00, ?it/s]



epoch 199:  {'val_logprob': 0.5185331106185913, 'val_fpnd': 11.485946324436668, 'val_mmd': 0.07700029331935007, 'val_cov': 0.522, 'val_w1m': 0.0047123191112093624, 'val_w1efp': 0.0001761489806831363, 'val_w1p': 0.011896537161695303, 'step': 12596}


## # STOP
 

In [35]:
import datetime
import os
import time
import traceback
from pytorch_lightning.tuner.tuning import Tuner

import pandas as pd
import pytorch_lightning as pl
import ray
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import CometLogger, TensorBoardLogger
from ray import tune
from ray.tune import CLIReporter
from ray.tune.integration.pytorch_lightning import (
    TuneReportCallback, TuneReportCheckpointCallback)
from scipy import stats
from torch.nn import functional as FF

from helpers import *
from jetnet_dataloader import JetNetDataloader

from plotting import plotting

# from comet_ml import Experiment

hyperopt = True  # This sets to run a hyperparameter optimization with ray or just running the training once

config = {
    "autoreg":False,
    "context_features":0,
   "network_layers": 3,  # sets amount hidden layers in transformation networks -scannable
    "network_layers_nf": 2,  # sets amount hidden layers in transformation networks -scannable
    "network_nodes_nf": 256,  # amount nodes in hidden layers in transformation networks -scannable
    "batch_size": 1024,  # sets batch size -scannable #best one 4000
    "coupling_layers": 10,  # amount of invertible transformations to use -scannable
    "lr": 0.001,  # sets learning rate -scannable
    "batchnorm": False,  # use batchnorm or not -scannable
    "bins":5,  # amount of bins to use in rational quadratic splines -scannable
    "tail_bound": 6,  # splines:max value that is transformed, over this value theree is id  -scannable
    "limit": 150000,  # how many data points to use, test_set is 10% of this -scannable in a sense use 10 k for faster training
    "n_dim": 3,  # how many dimensions to use or equivalently /3 gives the amount of particles to use NEVER EVER CHANGE THIS
    "dropout": 0.2,  # use droput proportion, for 0 there is no dropout -scannable
    "canonical": False,  # transform data coordinates to px,py,pz -scannable
    "max_steps": 100000,  # how many steps to use at max - lower for quicker training
    "lambda": 10,  # balance between massloss and nll -scannable
    "name": "Transflow_reloaded2",  # name for logging folder
    "disc": False,  # whether to train gan style discriminator that decides whether point is simulated or generated-semi-scannable
    "variable":1, #use variable amount of particles otherwise only use 30, options are true or False 
    "parton":"t", #choose the dataset you want to train options: t for top,q for quark,g for gluon
    "wgan":True,
    "corr":True,
    "num_layers":5,
    "freq":10,
    "n_part":30,
    "fc":False,
    "hidden":16,
    "heads":3,
    "l_dim":63,
    "lr_g":1e-4,
    "lr_d":1e-4,
    "lr_nf":0.000722,
    "sched":False,
    "pretrain":30,
    "opt":"RMSprop",
    "lambda":1,
    "max_epochs":300,
    "mass":True

}     
print(config["name"])
root="/beegfs/desy/user/"+os.environ["USER"]+"/"+config["name"]+"/"+datetime.datetime.now().strftime("%Y_%m_%d-%H_%M-%S")

hyperopt=True

# This function is a wrapper for the hyperparameter optimization module called ray 
# Its parameters hyperopt and load_ckpt are there for convenience
# Config is the only relevant parameter as it sets the trainings hyperparameters
# hyperopt:whether to optimizer hyper parameters - load_ckpt: path to checkpoint if used
data_module = JetNetDataloader(config,config["batch_size"]) #this loads the data

model = TransGan(config,hyperopt) # the sets up the model,  config are hparams we want to optimize
model.data_module=data_module
# Callbacks to use during the training, we  checkpoint our models

callbacks = [ModelCheckpoint(monitor="val_w1m",save_top_k=2, filename='{epoch}-{val_fpnd:.2f}-{val_w1m:.4f}', dirpath=root,every_n_epochs=10) ]
print(model.config)

# if True:#load_ckpt:
#     model = TransGan.load_from_checkpoint("/beegfs/desy/user/kaechben/Transflow_reloaded2/2022_08_08-18_02-08/epoch=239-val_logprob=0.47-val_w1m=0.0014.ckpt")
#     model.data_module=data_module
#     print("loaded mass",model.config["mass"])

# model.load_datamodule(data_module)

# pl.seed_everything(model.config["seed"], workers=True)


# logger = TensorBoardLogger(root)
# #log every n steps could be important as it decides how often it should log to tensorboard
# # Also check val every n epochs, as validation checking takes some time

# trainer = pl.Trainer(gpus=1, logger=logger,  log_every_n_steps=100,  # auto_scale_batch_size="binsearch",
#                       max_epochs=config["max_epochs"], callbacks=callbacks, progress_bar_refresh_rate=int(not hyperopt)*10,
#                       check_val_every_n_epoch=5 ,num_sanity_val_steps=1,#gradient_clip_val=.02, gradient_clip_algorithm="norm",
#                      fast_dev_run=False,default_root_dir=root)
# # This calls the fit function which trains the model

# trainer.fit(model,datamodule=data_module )  




Transflow_reloaded2
{'autoreg': False, 'context_features': 0, 'network_layers': 3, 'network_layers_nf': 2, 'network_nodes_nf': 256, 'batch_size': 1024, 'coupling_layers': 10, 'lr': 0.001, 'batchnorm': False, 'bins': 5, 'tail_bound': 6, 'limit': 150000, 'n_dim': 3, 'dropout': 0.2, 'canonical': False, 'max_steps': 100000, 'lambda': 1, 'name': 'Transflow_reloaded2', 'disc': False, 'variable': 1, 'parton': 't', 'wgan': True, 'corr': True, 'num_layers': 5, 'freq': 10, 'n_part': 30, 'fc': False, 'hidden': 16, 'heads': 3, 'l_dim': 63, 'lr_g': 0.0001, 'lr_d': 0.0001, 'lr_nf': 0.000722, 'sched': False, 'pretrain': 30, 'opt': 'RMSprop', 'max_epochs': 300, 'mass': True}


  nn.init.xavier_normal(p)


In [36]:
# torch.save(model.flow.state_dict(), "./flow.pt")
base.Flow(distribution=model.q_test, transform=model.flows).load_state_dict(torch.load("./flow.pt"))

RuntimeError: Error(s) in loading state_dict for Flow:
	Unexpected key(s) in state_dict: "_transform._transforms.10.identity_features", "_transform._transforms.10.transform_features", "_transform._transforms.10.transform_net.initial_layer.weight", "_transform._transforms.10.transform_net.initial_layer.bias", "_transform._transforms.10.transform_net.blocks.0.context_layer.weight", "_transform._transforms.10.transform_net.blocks.0.context_layer.bias", "_transform._transforms.10.transform_net.blocks.0.linear_layers.0.weight", "_transform._transforms.10.transform_net.blocks.0.linear_layers.0.bias", "_transform._transforms.10.transform_net.blocks.0.linear_layers.1.weight", "_transform._transforms.10.transform_net.blocks.0.linear_layers.1.bias", "_transform._transforms.10.transform_net.blocks.1.context_layer.weight", "_transform._transforms.10.transform_net.blocks.1.context_layer.bias", "_transform._transforms.10.transform_net.blocks.1.linear_layers.0.weight", "_transform._transforms.10.transform_net.blocks.1.linear_layers.0.bias", "_transform._transforms.10.transform_net.blocks.1.linear_layers.1.weight", "_transform._transforms.10.transform_net.blocks.1.linear_layers.1.bias", "_transform._transforms.10.transform_net.final_layer.weight", "_transform._transforms.10.transform_net.final_layer.bias", "_transform._transforms.11.identity_features", "_transform._transforms.11.transform_features", "_transform._transforms.11.transform_net.initial_layer.weight", "_transform._transforms.11.transform_net.initial_layer.bias", "_transform._transforms.11.transform_net.blocks.0.context_layer.weight", "_transform._transforms.11.transform_net.blocks.0.context_layer.bias", "_transform._transforms.11.transform_net.blocks.0.linear_layers.0.weight", "_transform._transforms.11.transform_net.blocks.0.linear_layers.0.bias", "_transform._transforms.11.transform_net.blocks.0.linear_layers.1.weight", "_transform._transforms.11.transform_net.blocks.0.linear_layers.1.bias", "_transform._transforms.11.transform_net.blocks.1.context_layer.weight", "_transform._transforms.11.transform_net.blocks.1.context_layer.bias", "_transform._transforms.11.transform_net.blocks.1.linear_layers.0.weight", "_transform._transforms.11.transform_net.blocks.1.linear_layers.0.bias", "_transform._transforms.11.transform_net.blocks.1.linear_layers.1.weight", "_transform._transforms.11.transform_net.blocks.1.linear_layers.1.bias", "_transform._transforms.11.transform_net.final_layer.weight", "_transform._transforms.11.transform_net.final_layer.bias", "_transform._transforms.12.identity_features", "_transform._transforms.12.transform_features", "_transform._transforms.12.transform_net.initial_layer.weight", "_transform._transforms.12.transform_net.initial_layer.bias", "_transform._transforms.12.transform_net.blocks.0.context_layer.weight", "_transform._transforms.12.transform_net.blocks.0.context_layer.bias", "_transform._transforms.12.transform_net.blocks.0.linear_layers.0.weight", "_transform._transforms.12.transform_net.blocks.0.linear_layers.0.bias", "_transform._transforms.12.transform_net.blocks.0.linear_layers.1.weight", "_transform._transforms.12.transform_net.blocks.0.linear_layers.1.bias", "_transform._transforms.12.transform_net.blocks.1.context_layer.weight", "_transform._transforms.12.transform_net.blocks.1.context_layer.bias", "_transform._transforms.12.transform_net.blocks.1.linear_layers.0.weight", "_transform._transforms.12.transform_net.blocks.1.linear_layers.0.bias", "_transform._transforms.12.transform_net.blocks.1.linear_layers.1.weight", "_transform._transforms.12.transform_net.blocks.1.linear_layers.1.bias", "_transform._transforms.12.transform_net.final_layer.weight", "_transform._transforms.12.transform_net.final_layer.bias", "_transform._transforms.13.identity_features", "_transform._transforms.13.transform_features", "_transform._transforms.13.transform_net.initial_layer.weight", "_transform._transforms.13.transform_net.initial_layer.bias", "_transform._transforms.13.transform_net.blocks.0.context_layer.weight", "_transform._transforms.13.transform_net.blocks.0.context_layer.bias", "_transform._transforms.13.transform_net.blocks.0.linear_layers.0.weight", "_transform._transforms.13.transform_net.blocks.0.linear_layers.0.bias", "_transform._transforms.13.transform_net.blocks.0.linear_layers.1.weight", "_transform._transforms.13.transform_net.blocks.0.linear_layers.1.bias", "_transform._transforms.13.transform_net.blocks.1.context_layer.weight", "_transform._transforms.13.transform_net.blocks.1.context_layer.bias", "_transform._transforms.13.transform_net.blocks.1.linear_layers.0.weight", "_transform._transforms.13.transform_net.blocks.1.linear_layers.0.bias", "_transform._transforms.13.transform_net.blocks.1.linear_layers.1.weight", "_transform._transforms.13.transform_net.blocks.1.linear_layers.1.bias", "_transform._transforms.13.transform_net.final_layer.weight", "_transform._transforms.13.transform_net.final_layer.bias", "_transform._transforms.14.identity_features", "_transform._transforms.14.transform_features", "_transform._transforms.14.transform_net.initial_layer.weight", "_transform._transforms.14.transform_net.initial_layer.bias", "_transform._transforms.14.transform_net.blocks.0.context_layer.weight", "_transform._transforms.14.transform_net.blocks.0.context_layer.bias", "_transform._transforms.14.transform_net.blocks.0.linear_layers.0.weight", "_transform._transforms.14.transform_net.blocks.0.linear_layers.0.bias", "_transform._transforms.14.transform_net.blocks.0.linear_layers.1.weight", "_transform._transforms.14.transform_net.blocks.0.linear_layers.1.bias", "_transform._transforms.14.transform_net.blocks.1.context_layer.weight", "_transform._transforms.14.transform_net.blocks.1.context_layer.bias", "_transform._transforms.14.transform_net.blocks.1.linear_layers.0.weight", "_transform._transforms.14.transform_net.blocks.1.linear_layers.0.bias", "_transform._transforms.14.transform_net.blocks.1.linear_layers.1.weight", "_transform._transforms.14.transform_net.blocks.1.linear_layers.1.bias", "_transform._transforms.14.transform_net.final_layer.weight", "_transform._transforms.14.transform_net.final_layer.bias". 

In [15]:
model.validation_step(torch.ones(10,30,3).cuda(),1)

AttributeError: 'JetNetDataloader' object has no attribute 'scaler'

In [37]:
from helpers import mass
del model
import gc
torch.cuda.empty_cache()
gc.collect()

64950

In [13]:

encoder_layer = nn.TransformerEncoderLayer(d_model=15, nhead=1).cuda()
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6).cuda()
lin=nn.Sequential(*[nn.Linear(99,99) for i in range(6)]).cuda()
import time
error=[]
for i in range(100):
    start=time.time()
    src = torch.rand(5000, 99).cuda()
    out = lin(src).detach()
    error.append(time.time()-start)
    
print(np.average(error))


0.004051065444946289
