# RUN THIS CELL ONLY ONCE

In [1]:
import os
# Working directory at the time this cell runs (expected to be the notebook's folder).
notebook_path = os.getcwd()

# One level up: the directory that contains the 'fit' package.
parent_path = os.path.join(notebook_path, os.pardir)


In [2]:
%matplotlib inline
import torch

os.chdir(parent_path)

import sys
# Add this parent directory to the system path
sys.path.insert(0, parent_path)
from fit.fit_pnf import PNF as PNFModel
from fit.fit_nf import NF as NFModel
from fit.fit_tnf import TNF as TNFModel
from fit.fit_tf import TF as TFModel
import pytorch_lightning as pl
import torch
from utils.dataloader_jetnet import PointCloudDataloader
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.patches as mpatches
from utils.helpers import get_hists, mass
from utils.dataloader_jetnet import PointCloudDataloader
import matplotlib as mpl
import matplotlib.gridspec as gridspec
from matplotlib.ticker import MaxNLocator, FuncFormatter
from jetnet.evaluation import w1m
from utils.helpers import mass, plotting_thesis,fit_kde,sample_kde,create_mask


def calculate_data_bounds(dataloader, n_dim=3):
    """
    Calculates the per-feature minimum and maximum over all non-masked points of a dataloader.

    Every batch is read as ``batch[0]`` (the points) and ``batch[1]`` (a boolean
    mask where ``True`` marks an entry to ignore); only entries whose mask is
    ``False`` contribute to the bounds.

    :param dataloader: Iterable of batches to process.
    :param n_dim: Number of features per point (size of the last axis of ``batch[0]``).
    :return: Tuple ``(mins, maxs, n)``. ``mins`` and ``maxs`` have shape ``(1, n_dim)``;
        ``n`` is a list holding, for each batch, the number of non-masked points per row.
        If no non-masked point is seen at all, ``mins`` stays ``+inf`` and ``maxs`` ``-inf``.
    """
    # Seed with +/-inf so the bounds are determined by the data alone; seeding
    # with ones would cap the minimum at 1 and floor the maximum at 1.
    mins = torch.full((1, n_dim), float("inf"))
    maxs = torch.full((1, n_dim), float("-inf"))
    n = []
    for batch in dataloader:
        points, mask = batch[0], batch[1]
        n.append((~mask).sum(1))
        valid = points[~mask]
        if valid.shape[0] == 0:
            # A fully masked batch has nothing to reduce over.
            continue
        mins = torch.minimum(mins, valid.min(0, keepdim=True)[0])
        maxs = torch.maximum(maxs, valid.max(0, keepdim=True)[0])
    return mins, maxs, n

def setup_model_with_data(model, data_module):
    """
    Attaches histogram settings, scalers and training-set statistics to the model.

    The model is modified in place and nothing is returned. All tensors are
    moved with ``.cuda()``, so a CUDA device is required.

    :param model: The model to be set up; must expose ``gen_net`` and ``dis_net``.
    :param data_module: Data module providing ``scaler`` (two entries are read),
        ``mins``, ``maxs`` and ``train_dataloader()``.
    """
    # Fixed evaluation settings.
    model.bins = [100, 100, 100, 100]
    model.n_dim = 3
    scalers = data_module.scaler
    model.scaler = scalers[0]
    model.pt_scaler = scalers[1]
    model.w1m_best = 0.01

    # Bounds and per-row counts of non-masked points from the training data.
    mins, maxs, n_counts = calculate_data_bounds(data_module.train_dataloader())

    def mean_unmasked_count():
        # Built afresh on every call so each owner holds its own tensor.
        return torch.cat(n_counts, dim=0).float().cuda().mean()

    model.maxs = maxs.cuda()
    model.mins = mins.cuda()
    model.avg_n = mean_unmasked_count()
    model.gen_net.avg_n = mean_unmasked_count()
    model.dis_net.avg_n = mean_unmasked_count()

    # Move the scaler and the data module's own bounds onto the GPU as well.
    model.scaler = model.scaler.to("cuda")
    model.scaler.std = model.scaler.std.cuda()
    model.scaled_mins = torch.tensor(data_module.mins).cuda()
    model.scaled_maxs = torch.tensor(data_module.maxs).cuda()



def make_plots(model_name, disco=False):
    """
    Load a checkpoint, run it on the validation dataloader and plot real-vs-generated histograms.

    The checkpoint ``./ckpts/t_<model_name>.ckpt`` is loaded, the model is configured via
    ``setup_model_with_data``, ``trainer.test`` is run, and the collected ``model.fake`` /
    ``model.batch`` tensors are clamped, histogrammed and passed to
    ``plotting_thesis().plot_ratio``. Requires a CUDA device (``.cuda()`` calls and
    ``accelerator="gpu"``).

    :param model_name: Checkpoint suffix; also selects the model class and labels the plot.
    :param disco: Not used by the active code (only by the commented-out ``plot_corr`` call).
    :return: The loaded model, after ``trainer.test`` has been run on it.
    """

    ckptdir = "./ckpts/"
    ckpt = "t_{}.ckpt".format(model_name)
    # ckpt = "t_{}.ckpt".format(model_name)
    ckpt = ckptdir + ckpt
    print(ckpt)

    # Load state dictionary from checkpoint
    # state_dict = torch.load(ckpt)
    # config = state_dict["hyper_parameters"]
    # config["model_name"] = model_name

    # print(config)
    # Choose the model class based on the model name:
    #   "ipf" / "pf" / "apf"      -> PNFModel
    #   no "t" anywhere in name   -> NFModel
    #   contains "tnf"            -> TNFModel
    #   anything else             -> TFModel
    if model_name in ["ipf", "pf", "apf"]:
        # config["pf"] = True
        # config["adversarial"] = False
        # config["norm"]=False
        # config["fast"]=False

        model_class = PNFModel
    else:
        # config["pf"] = False
        model_class = NFModel if model_name.find("t")==-1 else TNFModel if model_name.find("tnf")>-1 else TFModel

    torch.set_float32_matmul_precision('medium' )
    model = model_class.load_from_checkpoint(ckpt,ema=False)

    # Initialize data module and set up model
    data_module = PointCloudDataloader(parton="t",n_dim=3,n_part=30,batch_size=1024,sampler=False)
    data_module.setup("fit")


    # Attach scalers, bounds and average counts from the data module to the freshly loaded model
    setup_model_with_data(model, data_module)
    train=data_module.train_dataloader().dataset.cuda()
    test=data_module.test_dataloader().dataset.cuda()
    # Undo the scaling on the training set: second-to-last column via pt_scaler,
    # the columns before it via scaler; the last column is carried over unchanged.
    # NOTE(review): `test` is concatenated without an inverse transform — confirm it is already unscaled.
    pt=model.pt_scaler.inverse_transform(train[:,:,-2])
    std=model.scaler.inverse_transform(train[:,:,:-2])
    train=torch.cat((std,pt.unsqueeze(2),train[:,:,-1:]),dim=2)
    data=torch.cat((train,test[:,:,:]),dim=0)
    m=mass(data.cuda()).cpu()
    # Per-row count of entries whose last column is falsy (presumably the padding mask — confirm)
    n=(~(torch.cat((train,test),dim=0)[:,:,-1]).bool()).float().sum(1).cpu()
    n_kde,m_kde=fit_kde(n,m)
    # NOTE: the sampled n and m are not referenced again below in this function.
    n,m=sample_kde(len(data),n_kde,m_kde)
    # Trainer setup; the test loop is run on the *validation* dataloader
    trainer = pl.Trainer(devices=1, accelerator="gpu")
    model.eval_metrics=False
    # Reset the per-run buffers that are read back after trainer.test
    model.batch=[]
    model.masks=[]
    model.fake=[]
    model.conds=[]
    model=model.cuda()
    model.load_datamodule(data_module)
    with torch.no_grad():
        trainer.test(model, data_module.val_dataloader())

    # concatenate all batches
    fake = torch.cat(model.fake)
    true = torch.cat(model.batch)
    # sorted_indices = torch.argsort(fake[:,:,2], dim=1, descending=True)
    # fake = torch.gather(fake, 1, sorted_indices.unsqueeze(-1).expand(-1, -1, fake.shape[2]))

    m_f, m_t = mass(fake), mass(true)

    # Apply clamping based on the 0.1% / 99.9% quantiles of the real data
    mins = torch.quantile(true.reshape(-1, 3), 0.001, dim=0)
    maxs = torch.quantile(true.reshape(-1, 3), 0.999, dim=0)
    fake = torch.clamp(fake, min=mins, max=maxs)
    true = torch.clamp(true, min=mins, max=maxs)
    m_f = torch.clamp(m_f, min=torch.quantile(m_t, 0.001), max=torch.quantile(m_t, 0.999))
    m_t = torch.clamp(m_t, min=torch.quantile(m_t, 0.001), max=torch.quantile(m_t, 0.999))
    # Append the mass bounds as a fourth entry and widen the range slightly
    mins=torch.cat([mins,torch.tensor([torch.quantile(m_t, 0.001)])])-0.01
    maxs=torch.cat([maxs,torch.tensor([torch.quantile(m_t, 0.999)])])*1.01
    # W1 mass metric evaluated five times; each result is only printed, not stored
    for i in range(5):
        w1m_=w1m(fake,true,num_batches=16,num_eval_samples=250000)
        print(w1m_)
    # Prepare histograms (30 bins each; the bounds are scaled by a further factor 1.1)
    hists=get_hists([30,30,30,30],mins*1.1,maxs*1.1,calo=model.name=="calo")

    # NOTE: `masks` is not used below.
    masks=torch.cat(model.masks)
    # Fill histograms, keeping only rows in which all three features are non-zero
    for var in range(3):
        hists["hists_real"][var].fill(true.reshape(-1, 3)[(true.reshape(-1, 3) != 0).all(1)][:, var].cpu().numpy())
        hists["hists_fake"][var].fill(fake.reshape(-1, 3)[(fake.reshape(-1, 3) != 0).all(1)][:, var].cpu().numpy())


    # The fourth histogram holds the (clamped) mass
    hists["hists_real"][3].fill(m_t.cpu().numpy())
    hists["hists_fake"][3].fill(m_f.cpu().numpy())

    # Plotting
    plotter = plotting_thesis()
    plotter.plot_ratio(hists["hists_real"], hists["hists_fake"], weighted=False, leg=2, model_name=model_name)
    # plotter.plot_corr(true.numpy(), fake.numpy(), model_name, disco=disco,leg=-1)
    return model
# Checkpoint name suffixes, grouped by model family.
pointflows = ["ipf", "apf", "pf"]
flows = ["nf", "cnf", "ccnf"]
gans = ["tnf", "tnf_fast", "tf_slow", "tf_slow_ema"]

# Only "tf" is evaluated here; `model` is left bound to the last model returned.
for model_name in ["tf"]:
    model = make_plots(model_name)

  "lr_options": generate_power_seq(LEARNING_RATE_CIFAR, 11),
  contrastive_task: Union[FeatureMapContrastiveTask] = FeatureMapContrastiveTask("01, 02, 11"),
  self.nce_loss = AmdimNCELoss(tclip)


./ckpts/t_tf.ckpt




In [None]:
# Shape of all generated batches collected on the model, concatenated along the first axis.
torch.cat(model.fake).shape

torch.Size([50000, 30, 3])

In [None]:
import torch
import numpy as np
# Re-save the object stored in the torch file as a NumPy .npy file in the same folder.
source_file = "/beegfs/desy/user/kaechben/thesis/eval_jetnet150/generated_data_50000.pt"
target_file = "/beegfs/desy/user/kaechben/thesis/eval_jetnet150/EPiC-FM.npy"

ckpt = torch.load(source_file)
with open(target_file, "wb") as f:
    np.save(f, ckpt)


In [None]:
import pandas as pd
import numpy as np
import ast
from numpy import array
import re
# Convert string representation of lists back to lists
def convert_str_to_tuple(s):
    """
    Parse every bracketed list literal found in ``s`` into a NumPy array.

    Meant for stringified tuples of arrays such as
    ``'(array([1., 2.]), array([0.1, 0.2]))'``, where long arrays are wrapped
    over several lines.

    :param s: Text containing zero or more ``[...]`` list literals.
    :return: Tuple with one ``np.ndarray`` per list, in order of appearance
        (empty tuple when no list is found).
    """
    s = s.replace("\n ", "")
    # DOTALL lets a list span any remaining line break; without it a wrapped
    # list would not match at all and would be dropped silently.
    lists = re.findall(r'\[.*?\]', s, flags=re.DOTALL)
    return tuple(np.array(ast.literal_eval(item)) for item in lists)
# Load the evaluation table; the path is relative to the current working directory.
results=pd.read_csv("bravobenno.csv")
def weighted_mean(x):
    """
    Inverse-variance weighted mean of a ``(values, uncertainties)`` pair.

    :param x: Sequence whose first item holds the values and whose second item
        holds their uncertainties.
    :return: ``sum(values / unc**2) / sum(1 / unc**2)``.
    """
    values = np.array(x[0])
    uncertainties = np.array(x[1])
    weights = 1 / uncertainties ** 2
    weighted_total = np.sum(values * weights)
    return weighted_total / np.sum(weights)
def weighted_std(x):
    """
    Uncertainty of the inverse-variance weighted mean.

    :param x: Sequence whose second item holds the individual uncertainties
        (the first item is not read).
    :return: ``sqrt(1 / sum(1 / unc**2))``.
    """
    inverse_variances = (1 / np.array(x[1])) ** 2
    return np.sqrt(1 / np.sum(inverse_variances))

def format_mean_sd(mean, sd):
    """Format ``mean ± sd`` for LaTeX, rounded to the leading significant digit of ``sd``.

    When both values are at least 1000 after rounding they are printed in
    thousands with a trailing ``k``.
    """
    # Decimal places needed to show the first significant digit of sd ...
    decimals = -int(np.floor(np.log10(sd)))
    # ... one fewer if that digit would round up into the next decade.
    if (sd * 10 ** decimals) >= 9.5:
        decimals -= 1

    if decimals < 0:
        # sd is 10 or larger: truncate both numbers to a multiple of a power of ten.
        step = 10 ** (-decimals)
        if mean > step:
            mean = step * (mean // step)
        else:
            # mean is not above the step: keep only its own leading digit.
            mean_step = 10 ** np.floor(np.log10(mean))
            mean = mean_step * (mean // mean_step)
        sd = step * (sd // step)
        decimals = 0

    if mean >= 1e3 and sd >= 1e3:
        mean = np.round(mean * 1e-3)
        sd = np.round(sd * 1e-3)
        return f"${mean:.{decimals}f}$k $\\pm {sd:.{decimals}f}$k"
    return f"${mean:.{decimals}f} \\pm {sd:.{decimals}f}$"
# NOTE: `df` is an alias of `results` (no copy), so the column rewrites below are
# also visible through `results` until `set_index` rebinds `df` to a new frame.
df = results

# The W1 columns hold a stringified "(values, uncertainties)" pair per row.
# Parse each once, then reduce it to a weighted mean and its uncertainty (pm* columns).
w1m_pairs = df["w1m"].apply(ast.literal_eval)
df["pmm"] = w1m_pairs.apply(weighted_std)
df["w1m"] = w1m_pairs.apply(weighted_mean)
w1p_pairs = df["w1p"].apply(convert_str_to_tuple)
df["pmp"] = w1p_pairs.apply(weighted_std)
df["w1p"] = w1p_pairs.apply(weighted_mean)
w1efp_pairs = df["w1efp"].apply(convert_str_to_tuple)
df["pme"] = w1efp_pairs.apply(weighted_std)
df["w1efp"] = w1efp_pairs.apply(weighted_mean)

# FPD / KPD are stored as pairs too; here each part is reduced with a plain mean.
fpd_pairs = df["fpd"].apply(ast.literal_eval)
df["fpd_std"] = fpd_pairs.apply(lambda pair: np.mean(pair[1]))
df["fpd"] = fpd_pairs.apply(lambda pair: np.mean(pair[0]))
kpd_pairs = df["kpd"].apply(ast.literal_eval)
df["kpd_std"] = kpd_pairs.apply(lambda pair: np.mean(pair[1]))
df["kpd"] = kpd_pairs.apply(lambda pair: np.mean(pair[0]))

# NOTE(review): `cols` is not used below; it lists "w1m" twice and omits "w1p" — confirm intent.
cols=["name","w1m","w1efp","w1m","pmm","pme","pmp","cov","mmd","fpd","kpd","time","parameters"]
replace_dict={"MPGAN":"MPGAN","t_cpflow":"PF","t_ipflow":"IPF","t_apflow":"APF","t_nflow":"NF","t_ccnflow":"NF(cc)","t_cnflow":"NF(c)","t_tnflow":"TNF","IN":"IN"}
# A model name missing from replace_dict raises KeyError rather than passing through silently.
df.loc[:, "model"] = df["model"].apply(lambda x: replace_dict[x])
df = df.set_index("model", drop=True)

# Rescale values and uncertainties to the units quoted in the table header.
# NOTE(review): fpd is scaled by 1e4 but its header below carries no factor — confirm.
scale_factors = {
    "w1m": 1000,
    "w1p": 1000,
    "w1efp": 100000,
    "pmm": 1000,
    "pmp": 1000,
    "pme": 100000,
    "fpd": 10000,
    "fpd_std": 10000,
    "kpd": 10000,
    "kpd_std": 10000,
}
for column, factor in scale_factors.items():
    df.loc[:, column] *= factor

# Replace each numeric value by its "$mean \pm sd$" string. Plain column assignment
# is used so the float column is swapped for a string column instead of being
# written into element-wise.
value_error_columns = (
    ("w1m", "pmm"),
    ("w1p", "pmp"),
    ("w1efp", "pme"),
    ("kpd", "kpd_std"),
    ("fpd", "fpd_std"),
)
for value_col, error_col in value_error_columns:
    df[value_col] = df.apply(
        lambda row: format_mean_sd(float(row[value_col]), float(row[error_col])),
        axis=1,
    )

df.loc[:, "time"] *= 1e6
df.loc[:, "time"] = np.round(df["time"], decimals=1)

# Fixed row order of the final table; every listed label must be present in the index.
order=["PF","IPF","APF","NF","NF(c)","NF(cc)","TNF","MPGAN","IN"]
df = df.loc[order, :]
print(df)
# def count_parameters(model): return sum(p.numel() for p in model.parameters() if p.requires_grad)
# print (count_parameters(model.gen_net))

metric_columns = ["w1m", "w1p", "w1efp", "fpd", "kpd", "time"]
tex = ""
for p in ["t"]:
    # Work on a copy so the bold markup never leaks back into `df`; with an alias,
    # a second pass over `df` would try to parse already-bolded strings as floats.
    temp = df.copy()
    for col in temp.columns:
        if col not in metric_columns:
            continue
        # Recover the central value from the formatted string: drop "$" and "k",
        # keep the text before the first backslash (the "\pm"). regex=False makes
        # "$" a literal character on every pandas version.
        # NOTE(review): stripping "k" compares thousands-formatted entries at the wrong scale.
        central = (
            temp[col].astype(str)
            .str.replace("$", "", regex=False)
            .str.replace("k", "", regex=False)
            .str.split("\\").str[0]
            .astype(float)
        )
        # Best (smallest) entry among all rows except "IN"; rows tied with it are bolded too.
        is_best = central == central.drop("IN").min()
        # Numeric columns ("time") need object dtype before strings are written into them.
        temp[col] = temp[col].astype(object)
        temp.loc[is_best, col] = (
            r"$\mathbf{"
            + temp.loc[is_best, col].astype(str).str.replace("$", "", regex=False)
            + "}$"
        )
    temp = temp.reset_index()[["model","w1m","w1p","w1efp","kpd","fpd","time"]]
    # Raw strings: in a normal literal "\t" of "\times" is a TAB and "\m" is an invalid escape.
    temp.columns = [
        "model",
        r"$W_1^M (\times 10^{3})$",
        r"$W_1^P (\times 10^{3})$",
        r"$W_1^{EFP}(\times 10^{5})$",
        r"$KPD(\times 10^{4})$",
        "$FPD$",
        r"Time $\mu s$",
    ]
    text = temp.to_latex(index=False, escape=False)
    parton = "Gluon" if p=="g" else "Light Quark" if p=="q" else "Top Quark"
    # Keep what follows the "FPD" header cell up to \bottomrule, prefix every row
    # with an extra "&" cell for the multirow label, and strip rule and padding.
    rows = (
        text.split("FPD")[1]
        .split("\\bottomrule")[0]
        .replace("\\\\", "\\\\&")
        .replace("\\midrule", "")
        .replace("  ", "")[:-2]
    )
    tex += r"\multirow{9}{*}{" + parton + "} & " + rows + r"\cline{1-8}"
    tex += "\n"
print(tex)


FileNotFoundError: [Errno 2] No such file or directory: 'bravobenno.csv'

In [None]:
# Look at the first entry of the `w1efp` column of the results table.
results["w1efp"][0]

'(array([0.00025178, 0.00022063, 0.00026026, 0.00043311, 0.00019063]), array([6.04386555e-07, 1.94986961e-06, 2.46654965e-06, 1.54915805e-06,\n       1.93143100e-06]))'