## Uncertainty Quantification

In traditional *supervised learning*, we have input features and target data. The goal is to train a neural network that maps the input data to predicted targets that match the true targets.

This gives us an expected target, but we want to go beyond this and learn the uncertainties associated with it.

<div align="center">
    <img src="images/probs.png" alt="probs" style="width:30%; height:auto;">
</div>

A big theme with this tutorial is: **Do not mistake probability for model confidence**. I will explain this shortly.

Let's say we train a neural network to discriminate between cats and dogs based on images. The network learns features in these images, such as whiskers for cats or floppy ears for dogs. It builds an internal representation of these features and then decides whether the image shows a cat or a dog. The last layer is a `softmax`, which turns the network outputs into probabilities that sum to 1.

<div align="center">
    <img src="images/pet_model.png" alt="probs" style="width:35%; height:auto; object-fit:cover; object-position:top; margin-top:-4%; clip-path:inset(20% 0 0 0);">
</div>

The main challenge in uncertainty quantification and anomaly detection is that models are typically trained on clean, well-curated data. However, real-world data is often messy and unpredictable.

When we deploy our models, it's not enough to simply classify an image—we also need to know how confident the model is in its prediction.

For example, if we train a network to distinguish between cats and dogs, and then test it on an image of a boat (something it has never seen before), the model will still output a probability for "cat" or "dog." However, it won't indicate how uncertain it is about this prediction.

We need a way for a model to say **hmm...I don't know** <img src="images/hmm.png" alt="hmm" style="height:1em; vertical-align:middle;">

<div align="center" style="display: flex; justify-content: center; gap: 2em;">
    <img src="images/reality.png" alt="reality" style="width:35%; height:auto;">
    <img src="images/new.png" alt="new" style="width:35%; height:auto;">
</div>

There are two main types of uncertainty:
 - **Aleatoric Uncertainty**
    - Relates to the inherent noise or randomness in the input data
    - Becomes significant when the data is noisy or ambiguous
    - Cannot be minimized by collecting more data
 - **Epistemic Uncertainty**
    - Reflects uncertainty in the model’s predictions due to limited knowledge
    - Is high when the model has not seen enough diverse training examples
    - Can be reduced by gathering additional data or improving the model

<div align="center">
    <img src="images/uncertainty.png" alt="probs" style="width:30%; height:auto;">
</div>

Aleatoric uncertainty can be learned directly using neural networks, but epistemic uncertainty is much more challenging to estimate. 

Q: *How can a model understand when it doesn't know the answer?*

A: Instead of learning a single fixed value for each weight, we replace each weight with a probability distribution. This is called a Bayesian neural network. However, exact inference over these weight distributions is intractable, so we use approximations.

These approximations involve running $T$ forward passes on the same input data using different weights. For example:
 - **Dropout** (left image): dropout is kept active at inference time, so each forward pass randomly zeros out a different subset of nodes. The spread of the resulting predictions gives an estimate of the epistemic uncertainty.
 - **Ensemble** (right image): $T$ different models are trained independently, and the variance of their predictions gives an estimate of the epistemic uncertainty.

<div align="center" style="display: flex; justify-content: center; gap: 2em;">
    <img src="images/dropout.png" alt="Dropout (left)" style="width:25%; height:auto;">
    <img src="images/ens.png" alt="Ensemble (right)" style="width:25%; height:auto;">
</div>
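The two sampling-based approximations above can be sketched in PyTorch. This is a minimal, hypothetical example: the `make_model` architecture and the helper names are ours, not the tutorial's. MC dropout keeps dropout active at test time and collects $T$ stochastic passes, while the ensemble collects the outputs of separately initialised networks (which, in practice, would also be trained independently). In both cases the per-class variance across the $T$ predictions serves as the epistemic-uncertainty estimate.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def make_model(p_drop: float = 0.2) -> nn.Module:
    # Small hypothetical classifier (4 input features, 2 classes), for illustration only.
    return nn.Sequential(
        nn.Linear(4, 32), nn.ReLU(), nn.Dropout(p_drop),
        nn.Linear(32, 2),
    )

@torch.no_grad()
def mc_dropout_predict(model: nn.Module, x: torch.Tensor, T: int = 50):
    # train() keeps dropout active so every pass uses a different random mask.
    # (With BatchNorm layers you would enable only the Dropout modules instead.)
    model.train()
    probs = torch.stack([model(x).softmax(-1) for _ in range(T)])  # (T, N, K)
    model.eval()
    return probs.mean(0), probs.var(0)

@torch.no_grad()
def ensemble_predict(models, x: torch.Tensor):
    # Each member is deterministic at inference; diversity comes from the members.
    for m in models:
        m.eval()
    probs = torch.stack([m(x).softmax(-1) for m in models])  # (T, N, K)
    return probs.mean(0), probs.var(0)

x = torch.randn(8, 4)
mean_mc, var_mc = mc_dropout_predict(make_model(), x, T=50)
# Untrained, differently initialised members: only the bookkeeping is shown here.
mean_ens, var_ens = ensemble_predict([make_model() for _ in range(5)], x)
print(mean_mc.shape, var_mc.shape)
```

Both helpers need $T$ forward passes per input, and the ensemble also keeps $T$ networks in memory, which is exactly the cost discussed next.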

There are downsides to Bayesian Deep Learning, which include:
 - **Slow**: Requires running the network $T$ times for each input
 - **Memory**: Stores $T$ copies of the network in parallel
 - **Efficiency**: Sampling hinders real-time inference on edge devices
 - **Calibration**: Sensitive to prior and often over-confident

Suppose that each model predicts a mean and a data uncertainty, as in the top-left image. If we have an ensemble of these models (bottom-left image), each ensemble member places one X on the rightmost plot, whose x-axis is the mean $\mu$ and whose y-axis is the data uncertainty $\sigma^2$. The spread of these X's reflects the epistemic uncertainty.

Instead of this brute force sampling, can we instead **directly** learn the parameters defining this underlying likelihood distribution?

<div style="max-width:700px; margin:auto;">
    <div style="display: flex; align-items: flex-start; gap: 2em;">
            <div style="display: flex; flex-direction: column; justify-content: space-between; height: 300px;">
                    <img src="images/im1.png" alt="im1" style="max-height:145px; ">
                    <img src="images/im2.png" alt="im2" style="max-height:145px; margin-top:2%; object-fit:contain; aspect-ratio:1/1;">
            </div>
            <img src="images/im3.png" alt="im3" style="max-height:310px; object-fit:contain;">
    </div>
</div>
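To make the "brute force" picture concrete, here is a small sketch of how the X's from a $T$-member ensemble are commonly combined for a single input. It assumes each member outputs a mean $\mu_t$ and a data variance $\sigma_t^2$; the values are made up. This is the usual law-of-total-variance split used with deep ensembles, not code from this tutorial: the average $\sigma_t^2$ estimates the aleatoric part and the variance of the $\mu_t$ estimates the epistemic part.

```python
import numpy as np

def decompose(mus: np.ndarray, sigma2s: np.ndarray):
    """Split an ensemble's predictive variance for one input.

    mus, sigma2s: shape (T,), each member's predicted mean and data variance
    (one X per member in the (mu, sigma^2) plot).
    """
    aleatoric = sigma2s.mean()   # average predicted data noise
    epistemic = mus.var()        # disagreement between the member means
    return aleatoric, epistemic

# Hypothetical outputs of a 5-member ensemble for a single input.
mus = np.array([1.0, 1.2, 0.8, 1.1, 0.9])
sigma2s = np.array([0.5, 0.4, 0.6, 0.5, 0.5])

alea, epi = decompose(mus, sigma2s)
print(alea, epi)   # total predictive variance is alea + epi
```

Evidential learning aims to output the parameters of a distribution over $(\mu, \sigma^2)$ directly, so that this spread does not have to be estimated by sampling.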

## Evidential Deep Learning

Evidential Deep Learning (EDL) recasts learning as an evidence acquisition process: the more evidence the model collects, the more confident it is. It takes a *Theory of Evidence* perspective. The $\operatorname{softmax}$ output is interpreted as the parameter set of a categorical distribution, and EDL replaces this point estimate with the parameters of a **Dirichlet density**, a distribution over possible softmax outputs (a "factory" of softmax point estimates). In the figure below, from left to right, we have:
 - Low uncertainty, which indicates high confidence
 - High aleatoric (data) uncertainty
 - High epistemic (model) uncertainty

**Goal: train a neural network to learn these types of evidential distributions**

<div align="center">
    <img src="images/dists.png" alt="dists" style="width:50%; height:auto;">
</div>

In this tutorial, we focus on classification, though it is also applicable to regression. More information on EDL for regression is available [here](https://arxiv.org/abs/1910.02600).

**Sampling from an evidential distribution yields individual new distributions over the data.** So you can think of this as a second-order learning problem: we want to learn the parameters $\alpha_k > 0$ of a Dirichlet distribution. The image below shows an example of a [Dirichlet distribution](https://en.wikipedia.org/wiki/Dirichlet_distribution) with three classes, whose support is a triangle; with more classes, the support becomes a higher-dimensional [simplex](https://en.wikipedia.org/wiki/Simplex). The model parameters determine the density inside this classification simplex. We replace the point estimate of the categorical probabilities that we have been using with a Dirichlet distribution over them, and each sample drawn from it is a vector of class probabilities. For example, a sample at the center of the simplex corresponds to equal probabilities for all classes.
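A quick way to see the second-order picture is to sample from a Dirichlet with NumPy. In this sketch the $\alpha$ values are arbitrary illustrative choices: every sample is a valid probability vector (a point on the simplex), the sample mean approaches $\alpha_k / \sum_j \alpha_j$, and $\alpha = (1, 1, 1)$ spreads samples uniformly over the triangle.

```python
import numpy as np

rng = np.random.default_rng(0)

alpha = np.array([10.0, 2.0, 2.0])           # hypothetical Dirichlet parameters, K = 3
samples = rng.dirichlet(alpha, size=20000)   # each row is one probability vector

# Every sample lies on the simplex: non-negative entries that sum to 1.
print(samples[:3])
print(np.allclose(samples.sum(axis=1), 1.0))

# The empirical mean approaches alpha / sum(alpha).
print(samples.mean(axis=0), alpha / alpha.sum())

# alpha = (1, 1, 1) is uniform over the simplex: no class is preferred.
flat = rng.dirichlet(np.ones(3), size=20000)
print(flat.mean(axis=0))   # close to 1/3 for every class
```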

We choose the Dirichlet distribution as our *evidential* distribution because it is the conjugate prior of the categorical (multinomial) distribution in Bayesian inference. The justification for this choice is given in the original [paper](https://papers.nips.cc/paper_files/paper/2018/hash/a981f2b708044d6fb4a71a1463242520-Abstract.html).
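Conjugacy means that a Dirichlet prior updated with observed class counts is again a Dirichlet, with the counts simply added to $\alpha$. The sketch below uses hypothetical counts; it also hints at why the network's outputs can be read as "evidence": they play the role of these counts on top of a uniform $\mathrm{Dir}(1,\dots,1)$ prior.

```python
import numpy as np

prior = np.ones(3)                  # Dir(1, 1, 1): uniform, no evidence yet
counts = np.array([7, 2, 1])        # hypothetical observed class counts
posterior = prior + counts          # conjugacy: Dir(alpha) prior + categorical data -> Dir(alpha + counts)

print(posterior / posterior.sum())  # posterior mean class probabilities
print(len(prior) / posterior.sum()) # K / S shrinks as evidence accumulates
```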

<div align="center">
    <img src="images/dirichlet.png" alt="dists" style="width:30%; height:auto;">
</div>

For our network with model weights $\Theta$, instead of producing probabilities with a `softmax` layer, we use a `ReLU` activation so that the network outputs non-negative evidence, which guarantees $\alpha_k > 0$. We then train with a multi-objective loss, shown in the equation below. A reconstruction (mean squared error) term keeps the accuracy high, and a regularization term penalizes over-confident predictions on uncertain samples. This regularization is the KL divergence between the predicted Dirichlet distribution and the uniform Dirichlet distribution ($\alpha_k=1$), which represents complete uncertainty. The coefficient $\lambda_t$ controls the strength of the regularization term; as in the original paper, we increase $\lambda_t$ over the training epochs.
$$
\begin{align}
    \mathcal{L}(\Theta)=\sum_{i=1}^{N}\mathcal{L}_{MSE}(\Theta)_i+\lambda_t\sum_{i=1}^{N}\mathcal{L}_{KL}(\Theta)_i
\end{align}
$$

<div align="center" style="display: flex; justify-content: center; gap: 2em;">
    <img src="images/model.png" alt="Model (left)" style="width:25%; height:auto;">
    <img src="images/diagram.png" alt="Diagram (right)" style="width:25%; height:auto;">
</div>
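The annealing of $\lambda_t$ can be sketched as a simple capped linear ramp. This is an assumption about the schedule, not a copy of `train.py`: we take a slope of 0.1 per epoch and a maximum of 1.0 (the values named by the `nominal` training configuration later in this notebook), which matches the $\lambda_t = \min(1, t/10)$ annealing used in the original EDL paper. `kl_coefficient` and `total_loss` are hypothetical helper names.

```python
def kl_coefficient(epoch: int, slope: float = 0.1, lam_max: float = 1.0) -> float:
    """Hypothetical linear annealing schedule for lambda_t.

    lambda_t grows linearly with the epoch index t and is capped at lam_max,
    e.g. lambda_t = min(1, 0.1 * t).
    """
    return min(lam_max, slope * epoch)

def total_loss(mse_term: float, kl_term: float, epoch: int) -> float:
    # L(Theta) = L_MSE + lambda_t * L_KL, with the KL weight ramped up over training.
    return mse_term + kl_coefficient(epoch) * kl_term

print([round(kl_coefficient(t), 2) for t in range(0, 15, 2)])
```

Starting with a small $\lambda_t$ lets the network first fit the data before the KL term starts pulling uncertain predictions toward the uniform Dirichlet.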

Specifically, the outputs of the network, denoted $f_k(\mathbf{x}|\Theta)$, directly provide the evidence for the predicted Dirichlet distribution:
$$
\begin{equation*}
e_k = f_k(\mathbf{x}|\Theta) ~~~\textrm{and}~~~ \alpha_k = f_k(\mathbf{x}|\Theta) + 1
\end{equation*}
$$
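In code, the only change to the final layer is the activation. The sketch below is a hypothetical `EvidentialHead` module (not the PFIN code from this repository) that turns backbone features into non-negative evidence $e_k$ and then forms $\alpha_k = e_k + 1$.

```python
import torch
import torch.nn as nn

class EvidentialHead(nn.Module):
    """Hypothetical final layer: maps features to non-negative evidence e_k."""

    def __init__(self, in_features: int, num_classes: int):
        super().__init__()
        self.linear = nn.Linear(in_features, num_classes)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # ReLU instead of softmax: evidence e_k = f_k(x|Theta) >= 0
        return torch.relu(self.linear(h))

head = EvidentialHead(in_features=16, num_classes=2)
features = torch.randn(4, 16)   # stand-in for the backbone's output
evidence = head(features)       # e_k, shape (4, 2)
alpha = evidence + 1.0          # alpha_k = e_k + 1 >= 1 > 0
```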

Once the network has learned the parameters $\alpha_k$, the mean of the Dirichlet distribution can be taken as an estimate of the class probabilities: $$p_k = \mathbb{E}[x_k] = \frac{\alpha_k}{S},$$ where $x_k$ is the $k$-th component of a probability vector drawn from the Dirichlet distribution and $S = \sum_{k=1}^{K}\alpha_k = \sum_{k=1}^{K}(e_k + 1)$ is the Dirichlet strength. The epistemic uncertainty can then be estimated as $$u = \frac{K}{S},$$ where $K$ is the number of classes. More information is available in our paper [here](https://iopscience.iop.org/article/10.1088/2632-2153/ade51b).
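A small worked example of these two formulas, using made-up evidence values for a two-class problem. `dirichlet_summary` is a hypothetical helper name.

```python
import numpy as np

def dirichlet_summary(evidence: np.ndarray):
    """Expected class probabilities and EDL uncertainty from evidence.

    evidence: array of shape (N, K) with e_k >= 0.
    """
    alpha = evidence + 1.0                  # alpha_k = e_k + 1
    S = alpha.sum(axis=1, keepdims=True)    # Dirichlet strength
    probs = alpha / S                       # p_k = alpha_k / S
    u = evidence.shape[1] / S[:, 0]         # u = K / S
    return probs, u

evidence = np.array([
    [0.0, 0.0],    # no evidence at all
    [8.0, 0.0],    # strong evidence for class 0
    [40.0, 40.0],  # plenty of evidence, split evenly between the classes
])
probs, u = dirichlet_summary(evidence)
# probs -> [[0.5, 0.5], [0.9, 0.1], [0.5, 0.5]]
# u     -> [1.0, 0.2, 2/82 (about 0.024)]
```

The first and last rows have identical probabilities but very different $u$: the first is "I don't know" (no evidence), while the last is confidently ambiguous. This is the sense in which we should not mistake probability for model confidence.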


## Getting started

If you're running within Google Colab:
- Go to Runtime -> Change runtime type and select GPU

If you're running locally:
- Run `pip install -r requirements.txt` to install the necessary requirements

#### Import all necessary packages

In [None]:
!rm -rf GDSVirtualTutorials
!git clone https://github.com/butler-julie/GDSVirtualTutorials.git
%cd GDSVirtualTutorials/100325_UncertaintyQuantification/

In [None]:
!pip install torchinfo

In [None]:
import torch
import numpy as np
from sklearn.metrics import accuracy_score, roc_auc_score
import os, json, sys
import matplotlib.pyplot as plt
import seaborn as sns
import h5py
from matplotlib.colors import LogNorm
from EvalTools import *
import gc


#### Downloading and preprocessing data

We apply EDL to jet identification at the LHC using the Particle Flow Interaction Network (PFIN), an explainable-AI-inspired top-quark ($t$-quark) tagger. The first dataset (`TopData`) has 2M simulated jets in two classes: top quark jets (label=1) and QCD jets (label=0).

In [10]:
orig_dir = os.getcwd()

# The shell commands below assume a Unix-like environment (e.g. Google Colab).
try:
    os.chdir('datasets/topdata')
    !rm -rf processed
    !rm -f download
    !rm -rf raw
    !mkdir -p processed
    !wget https://desycloud.desy.de/index.php/s/llbX3zpLhazgPJ6/download
    !unzip download
    !mv v0/ raw/
    !rm download
    !python topdata_preprocess.py
finally:
    os.chdir(orig_dir)


<div align="center">
    <img src="images/top.png" alt="top" style="width:50%; height:auto;">
</div>

We also use another dataset called `JetNet`. This dataset consists of 880k particle jets originating from gluons ($g$), light quarks ($q$), top quarks ($t$), and bosons ($W$ and $Z$).

In [None]:
orig_dir = os.getcwd()

try:
    os.chdir('datasets/jetnet')
    !mkdir -p raw
    !mkdir -p processed
    !wget -O raw/g.hdf5 "https://zenodo.org/records/6975118/files/g.hdf5?download=1"
    !wget -O raw/q.hdf5 "https://zenodo.org/records/6975118/files/q.hdf5?download=1"
    !wget -O raw/t.hdf5 "https://zenodo.org/records/6975118/files/t.hdf5?download=1"
    !wget -O raw/w.hdf5 "https://zenodo.org/records/6975118/files/w.hdf5?download=1"
    !wget -O raw/z.hdf5 "https://zenodo.org/records/6975118/files/z.hdf5?download=1"
    !python jetnet_preprocess.py
finally:
    os.chdir(orig_dir)

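We specify the reconstruction loss below as

$$
\begin{align}
    \mathcal{L}_{MSE}(\Theta)_i
    &= \sum_{k=1}^K\mathbb{E}\left[\left(y_{ik} - \hat{y}_{ik}\right)^2\right]
    = \sum_{k=1}^K\left[\left(y_{ik} - p_{ik}\right)^2 + \frac{p_{ik}\left(1 - p_{ik}\right)}{S_i + 1}\right], \\
    p_{ik} &= \frac{\alpha_{ik}}{S_i} = \frac{f_k(\mathbf{x}_i|\Theta) + 1}{\sum_{j=1}^{K}\left(f_j(\mathbf{x}_i|\Theta) + 1\right)} \nonumber
\end{align}
$$

where the expectation is taken over $\hat{\mathbf{y}}_i \sim \mathrm{Dir}(\boldsymbol{\alpha}_i)$. The closed form splits the loss into a squared-error term and a variance term, which is what `LossMSE` computes.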

In [None]:
def LossMSE(labels, outs):
    # labels size: (Nb, nclasses) [true values]
    # outs size: (Nb, nclasses) [NN predictions]
    alphas = outs + 1
    S = torch.sum(alphas, 1).reshape(-1,1)
    probs = alphas / S
    return ((labels - probs)**2 + probs * (1 - probs) / (1 + S)).sum(1).mean()
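As a sanity check on the closed form, we can compare `LossMSE` with a Monte Carlo estimate of $\mathbb{E}[(y_{ik} - \hat{y}_{ik})^2]$ obtained by sampling $\hat{\mathbf{y}}_i$ from the predicted Dirichlet with `torch.distributions.Dirichlet`. The function is copied from the cell above so the block runs on its own; the evidence values and labels are hypothetical.

```python
import torch
from torch.distributions import Dirichlet

torch.manual_seed(0)

def LossMSE(labels, outs):
    # Copied from the cell above so this check runs on its own.
    alphas = outs + 1
    S = torch.sum(alphas, 1).reshape(-1, 1)
    probs = alphas / S
    return ((labels - probs)**2 + probs * (1 - probs) / (1 + S)).sum(1).mean()

outs = torch.tensor([[3.0, 0.5], [0.0, 0.0], [10.0, 2.0]])    # hypothetical evidence
labels = torch.tensor([[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])   # one-hot targets

# Monte Carlo estimate of E[(y - y_hat)^2] with y_hat ~ Dir(alpha), alpha = outs + 1
y_hat = Dirichlet(outs + 1).sample((200_000,))                # (M, N, K)
mc = ((labels - y_hat) ** 2).sum(-1).mean(0).mean()

closed = LossMSE(labels, outs)
print(float(mc), float(closed))   # the two numbers should agree closely
```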

The second component of the loss function is a KL Divergence term defined as

$$
\begin{align}
    \mathcal{L}_{KL}(\Theta)_i&=KL[D(\mathbf{{\hat{y}}}_i|\mathbf{\tilde{\alpha}_i})\|D(\mathbf{{ \hat{y}}}_i|\left \langle 1,\cdots,1 \right \rangle)] \\
    &=\log \left (\frac{\Gamma(\sum_{k=1}^K\tilde{\alpha}_{ik})}{\Gamma(K)\prod_{k=1}^K\Gamma(\tilde{\alpha}_{ik})} \right ) + \sum_{k=1}^K (\tilde{\alpha}_{ik}-1) \left [ \psi(\tilde{\alpha}_{ik}) - \psi\left(\sum_{j=1}^K \tilde{\alpha}_{ij}\right) \right] \nonumber
\end{align}
$$
where 
$$
\begin{equation*}
    \mathbf{\tilde{\alpha}_i} = \mathbf{y_i} + (1 - \mathbf{y_i}) \odot \mathbf{\alpha_i} { ~~~\textrm{and}~~~ \mathbf{\alpha_i} = f(\mathbf{x}_i|\Theta) + 1}
\end{equation*}
$$
and $\psi(\cdot)$ is the digamma function.

In [None]:
def KLDiv(labels, outs):
    K = torch.tensor(labels.shape[-1]).float()
    alphas = outs + 1
    _alphas = labels + (1-labels)*alphas
    _S = torch.sum(_alphas, 1).reshape(-1,1)
    lognum = torch.lgamma(_S)
    logden = torch.lgamma(K*1.0) + torch.lgamma(_alphas).sum(1).reshape(-1,1)
    t2 = ((_alphas - 1) * (torch.digamma(_alphas) - torch.digamma(_S) )).sum(1).reshape(-1,1)
    return (lognum - logden + t2).mean()
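`KLDiv` can be cross-checked against PyTorch's built-in `kl_divergence` for two `Dirichlet` distributions. Because $\tilde{\boldsymbol{\alpha}}_i$ resets the true-class entry to 1, the term only penalizes evidence placed on the wrong classes. The function is copied from the cell above so the block runs on its own; the inputs are hypothetical.

```python
import torch
from torch.distributions import Dirichlet, kl_divergence

def KLDiv(labels, outs):
    # Copied from the cell above so this check runs on its own.
    K = torch.tensor(labels.shape[-1]).float()
    alphas = outs + 1
    _alphas = labels + (1 - labels) * alphas
    _S = torch.sum(_alphas, 1).reshape(-1, 1)
    lognum = torch.lgamma(_S)
    logden = torch.lgamma(K * 1.0) + torch.lgamma(_alphas).sum(1).reshape(-1, 1)
    t2 = ((_alphas - 1) * (torch.digamma(_alphas) - torch.digamma(_S))).sum(1).reshape(-1, 1)
    return (lognum - logden + t2).mean()

outs = torch.tensor([[3.0, 0.5, 0.0], [0.0, 4.0, 1.0]])     # hypothetical evidence
labels = torch.tensor([[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]])   # one-hot targets

# Reference: KL between Dir(alpha_tilde) and the uniform Dir(1, ..., 1)
alpha_tilde = labels + (1 - labels) * (outs + 1)
ref = kl_divergence(Dirichlet(alpha_tilde), Dirichlet(torch.ones_like(alpha_tilde))).mean()

print(float(KLDiv(labels, outs)), float(ref))
```

The second sample puts most of its evidence on a wrong class, so it contributes far more to the KL term than the first.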

## `TopData` Model

### Training

First, we load the predefined training parameters for our model from a JSON file. The file name encodes the $\lambda_t$ schedule used in the loss: `nominal` means $\lambda_t$ increases over epochs with slope 0.1, and `1.0` is the maximum value of $\lambda_t$. To use a different strength for the KL divergence term, change `file_name` to point to a different JSON file.

In [7]:
file_name = 'UQPFIN_topdata_nominal_1.0_baseline.json'
with open(f'json_files/{file_name}') as f:
    parameters = json.load(f)
parameters

{'outdir': './trained_models/',
 'outdictdir': './trained_model_dicts/',
 'Np': 60,
 'n_phiI': 128,
 'x_mode': 'sum',
 'phi_nodes': '100,100,64',
 'f_nodes': '64,100,100',
 'epochs': 50,
 'label': 'topdata_nominal_1.0_baseline',
 'batch_size': 250,
 'data_loc': '../datasets/',
 'data_type': 'topdata',
 'preload': False,
 'preload_file': '',
 'klcoef': 'nominal',
 'massrange': 'AND:0,10000.',
 'ptrange': 'AND:0,10000',
 'etarange': 'AND:-6,6',
 'skiplabels': '',
 'batchmode': True,
 'use_softmax': False,
 'use_dropout': False,
 'ndata': 0}

Then, we train the model and save checkpoints based on the highest accuracy.

In [9]:
!python train.py --load-json json_files/UQPFIN_topdata_nominal_1.0_baseline.json


### Evaluation

We can evaluate the results on a test set and visualize the uncertainties.

In [None]:
!python evaluate_model.py --data topdata --make-file --type edl

In [None]:
#optional parameters to only evaluate certain models
optional_dataset = 'topdata'
optional_tag = ''
results_dir = 'results/'
saved_model_loc = "./trained_models/"
saved_model_dict_loc = "./trained_model_dicts/"

result_files = sorted([f for f in os.listdir(results_dir) if optional_dataset in f and optional_tag in f
                      and "Ensemble" not in f and "MCDO" not in f and '.h5' in f and 'slope' not in f and 'mask' not in f])
print('\n'.join(result_files))

In [None]:
for modelname in result_files:
    model_results = {}
    mname = modelname[20:-3]
    filename = os.path.join(results_dir, modelname)
    
    f = h5py.File(filename, "r")
    labels, preds, oods, probs, uncs, sums = f['labels'][:], f['preds'][:], f['oods'][:], f['probs'][:], f['uncs'][:], f['sums'][:]
    f.close()
    acc = accuracy_score(labels[~oods], preds[~oods])*100
    if "topdata" in mname:
        probs2=probs
    else:
        skiplabels = np.unique(labels[oods])
        probs2=np.delete(probs, skiplabels, 1)

    if probs2.shape[1] == 2:
        probs2 = probs2[:, 1]
    
    # misclassification detection vs out-of-distribution detection
    if "baseline" in mname:
        oods = labels != preds
    
    this_file = os.path.join(saved_model_loc, modelname[8:-3])
    evaluator = ModelEvaluator(this_file)
    nparams = sum(p.numel() for p in evaluator.model.parameters())
    del evaluator
    gc.collect()
    
    # Sensoy et al. EDL Uncertainty
    auc = roc_auc_score(oods, uncs) * 100
    
    # D-STD Uncertainty
    sums = torch.from_numpy(sums).reshape(-1,1)
    probs = torch.from_numpy(probs)
    uncs = torch.sqrt(((probs*(1 - probs))/(sums + 1)).sum(1)).numpy()
    auc_std = roc_auc_score(oods, uncs) * 100

    print("{} \t\t Params: {}\t Accuracy: {:.2f}% \t AUROC: {:.2f}% \t AUROC-STD: {:.2f}%".format(mname, nparams, acc, auc, auc_std))    
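The AUROC numbers above treat an uncertainty score as a detector for out-of-distribution (or, for `baseline` models, misclassified) jets. The sketch below reproduces that idea on hand-picked, hypothetical $\alpha$ values, scoring with $u = K/S$ and with a Dirichlet-standard-deviation score. For the latter we use the form the evaluation cell ends up computing (square root of the summed per-class variances $p_k(1-p_k)/(S+1)$); note that the plotting cell sums per-class standard deviations instead.

```python
import numpy as np
from sklearn.metrics import roc_auc_score

def edl_uncertainty(alpha):
    # u = K / S
    return alpha.shape[1] / alpha.sum(axis=1)

def dirichlet_std(alpha):
    # sqrt of the summed Dirichlet variances p_k (1 - p_k) / (S + 1)
    S = alpha.sum(axis=1, keepdims=True)
    p = alpha / S
    return np.sqrt((p * (1 - p) / (S + 1)).sum(axis=1))

alpha = np.array([
    [50.0, 1.0], [1.0, 40.0], [30.0, 2.0],   # hypothetical in-distribution jets
    [2.0, 1.5], [1.2, 1.0], [1.0, 1.0],      # hypothetical OOD jets
])
is_ood = np.array([0, 0, 0, 1, 1, 1])

auc_u = roc_auc_score(is_ood, edl_uncertainty(alpha)) * 100
auc_std = roc_auc_score(is_ood, dirichlet_std(alpha)) * 100
print(auc_u, auc_std)
```

In this toy case both scores rank every OOD jet above every in-distribution jet, so both AUROCs are 100%; on real data the two scores can rank jets differently.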


In [None]:
for modelname in result_files:
    key = modelname[20:-3] 
    print(key)
    
    if "jetnet" in key:
        l_max = 5
        fsize=18
        groups = ['QCD', 'QCD', 'Top', 'Bosons', 'Bosons']
    elif "jetclass" in key:
        l_max = 10
        fsize=10
        groups = ['QCD', 'Higgs', 'Higgs', 'Higgs', 'Higgs', 'Higgs', 'Bosons', 'Bosons', 'Top', 'Top']
    elif "JNqgmerged" in key:
        l_max = 4
    else:
        l_max = 2
        fsize=24
        groups = ["background", "signal"]
    groups = np.array(groups)
        
    uniq_names, ind = np.unique(groups, return_index=True)
    uniq_names = uniq_names[np.argsort(ind)]
        
        
    filename = os.path.join(results_dir, modelname)
    
    f = h5py.File(filename, "r")    
    labels, preds, maxprobs, sums, oods, uncs, probs = f['labels'][:], f['preds'][:], \
                                                f['maxprobs'][:], f['sums'][:], \
                                                f['oods'][:], f['uncs'][:], f['probs'][:]
    f.close()
    
    split = key.split('_')
    coeff = "_".join(split[1:-1])
    # Choose to use D-STD Uncertainty or not
    for use_std in [False]:
        high_unc = 0.8
        if use_std:
            sums = torch.from_numpy(sums).reshape(-1,1)
            probs = torch.from_numpy(probs)
            uncs = torch.sqrt(((probs*(1 - probs))/(sums + 1))).sum(1).numpy()
            coeff += "_std"
            high_unc = 0.5
        
        savefolder = "figures/{}/{}/{}/".format(split[0], split[-1], coeff)
        os.makedirs(savefolder, exist_ok=True)
        
        if 'baseline' in key:
            f1 = 'correct'
            f2 = 'incorrect'
            oods = labels != preds
        else:
            f1 = 'id'
            f2 = 'ood'
        
        
        mult_label = "Unc"
        if "_0_" in key:
            if "topdata" in key:
                multiple = 300
                if use_std:
                    multiple = 30
            elif "jetnet" in key:
                if "skiptop" in key:
                    multiple = 5
                elif "skiptwz" in key:
                    multiple = 25
                else:
                    multiple = 5*2.5
                    if "skipwz" in key and not use_std:
                        multiple *= 4
            elif "jetclass" in key:
                if "baseline" in key:
                    multiple = 15
                elif "skipwz" in key:
                    multiple = 5*2.5
                else:
                    multiple = 25
            uncs = uncs * multiple
            mult_label = str(multiple) + r' $\times$ Unc'
        
        for types in ['total', 'separate', 'total_log', 'separate_log']:
            # Uncertainty distribution
            plt.figure(figsize=(10,10))
            ax = plt.gca()
                
            if 'total' in types:
                ax.hist(uncs, bins=np.arange(0.,1.01,0.04), alpha = 0.7, color='blue', weights = (1./len(uncs)) * np.ones_like(uncs))
            else:
                ax.hist(uncs[~oods], bins=np.arange(0.,1.01,0.04), label=f1, alpha = 0.5, histtype = 'step', linewidth = 5, weights = (1./len(uncs[~oods])) * np.ones_like(uncs[~oods]))
                ax.hist(uncs[oods], bins=np.arange(0.,1.01,0.04), label=f2, alpha = 0.5, histtype = 'step', linewidth = 5, weights = (1./len(uncs[oods])) * np.ones_like(uncs[oods]))
                ax.legend(fontsize = 22)
            # ax.set_aspect((ax.get_xlim()[1] - ax.get_xlim()[0]) / (ax.get_ylim()[1] - ax.get_ylim()[0]))
            ax.set_xlabel("Uncertainty".replace("Unc", mult_label), fontsize=30)
            ax.set_ylabel("Fractional Number of Events", fontsize = 24)
            if 'log' in types:
                ax.set_ylabel("Log10(Fractional Number of Events)", fontsize = 24)
                ax.set_yscale('log')
            ax.tick_params(axis='both', which='major', labelsize=24)
            ax.yaxis.get_offset_text().set_fontsize(24)
            plt.tight_layout()
            # plt.savefig("{}/unc_normal_{}.pdf".format(savefolder, types), dpi = 150, bbox_inches='tight')
            
            
        for types in ['total', 'separate', 'total_log', 'separate_log']:
            # Uncertainty distribution
            plt.figure(figsize=(10,10))
            ax = plt.gca()
                
            if 'total' in types:
                ax.hist(uncs, bins=np.arange(0.,1.01,0.04), alpha = 0.7, color='blue', weights = (1./len(uncs)) * np.ones_like(uncs))
            else:
                ax.hist(uncs[~oods], bins=np.arange(0.,1.01,0.04), label=f1, alpha = 0.5, histtype = 'step', linewidth = 5)
                ax.hist(uncs[oods], bins=np.arange(0.,1.01,0.04), label=f2, alpha = 0.5, histtype = 'step', linewidth = 5)
                ax.legend(fontsize = 22)
            # ax.set_aspect((ax.get_xlim()[1] - ax.get_xlim()[0]) / (ax.get_ylim()[1] - ax.get_ylim()[0]))
            ax.set_xlabel("Uncertainty".replace("Unc", mult_label), fontsize=30)
            ax.set_ylabel("Number of Events", fontsize = 24)
            if 'log' in types:
                ax.set_ylabel("Log10(Number of Events)", fontsize = 24)
                ax.set_yscale('log')
            ax.tick_params(axis='both', which='major', labelsize=24)
            ax.yaxis.get_offset_text().set_fontsize(24)
            plt.tight_layout()
            # plt.savefig("{}/unc_{}.pdf".format(savefolder, types), dpi = 150, bbox_inches='tight')
            
        for types in ['normal', 'log']:
            # Uncertainty distribution
            plt.figure(figsize=(10,10))
            ax = plt.gca()
            
            for i in range(len(uniq_names)):
                indices = np.isin(labels, np.where(groups == uniq_names[i])[0])
                if np.any(indices & oods) and "skip" in key:
                    lstyle = '--'
                    lbel = f"{uniq_names[i]} (OOD)"
                else:
                    lstyle = '-'
                    lbel = uniq_names[i]
                ax.hist(uncs[indices], bins=np.arange(0.,1.01,0.04), label=lbel, alpha = 0.5, histtype = 'step', linewidth = 5, linestyle = lstyle, weights = (1./len(uncs[indices])) * np.ones_like(uncs[indices]))
            ax.legend(fontsize = 22)
            # ax.set_aspect((ax.get_xlim()[1] - ax.get_xlim()[0]) / (ax.get_ylim()[1] - ax.get_ylim()[0]))
            ax.set_xlabel("Uncertainty".replace("Unc", mult_label), fontsize=30)
            ax.set_ylabel("Fractional Number of Events", fontsize = 24)
            if 'log' in types:
                ax.set_ylabel("Log10(Fractional Number of Events)", fontsize = 24)
                ax.set_yscale('log')
            ax.tick_params(axis='both', which='major', labelsize=24)
            ax.yaxis.get_offset_text().set_fontsize(24)
            plt.tight_layout()
            # plt.savefig("{}/unc_class_{}.pdf".format(savefolder, types), dpi = 150, bbox_inches='tight')
            
        if "baseline" in key:
            for types in ['normal', 'log']:
                # Uncertainty distribution
                plt.figure(figsize=(10,10))
                ax = plt.gca()

                for i in range(len(uniq_names)):
                    indices = np.isin(labels, np.where(groups == uniq_names[i])[0])
                    if np.any(indices & oods) and "skip" in key:
                        lstyle = '--'
                        lbel = f"{uniq_names[i]} (OOD)"
                    else:
                        lstyle = '-'
                        lbel = uniq_names[i]
                    ax.hist(uncs[(labels == preds) & indices], bins=np.arange(0.,1.01,0.04), label=lbel, alpha = 0.5, histtype = 'step', linewidth = 5, linestyle = lstyle, weights = (1./len(uncs[(labels == preds) & indices])) * np.ones_like(uncs[(labels == preds) & indices]))
                ax.legend(fontsize = 22)
                # ax.set_aspect((ax.get_xlim()[1] - ax.get_xlim()[0]) / (ax.get_ylim()[1] - ax.get_ylim()[0]))
                ax.set_xlabel("Uncertainty".replace("Unc", mult_label), fontsize=30)
                ax.set_ylabel("Fractional Number of Events", fontsize = 24)
                if 'log' in types:
                    ax.set_ylabel("Log10(Fractional Number of Events)", fontsize = 24)
                    ax.set_yscale('log')
                ax.tick_params(axis='both', which='major', labelsize=24)
                ax.yaxis.get_offset_text().set_fontsize(24)
                plt.tight_layout()
                # plt.savefig("{}/unc_correct_class_{}.pdf".format(savefolder, types), dpi = 150, bbox_inches='tight')
            
        
        filetypes = ['total', f1, f2]
        indices = [oods | ~oods, ~oods, oods]
        
        for filetype, idx in zip(filetypes, indices):
            # Max Prob. vs Uncertainty distribution
            hist, _, _ = np.histogram2d(maxprobs[idx], uncs[idx], bins = [np.arange(0.,1.01,0.04), np.arange(0.,1.01,0.04)])
            
            fig, ax  = plt.subplots(1, 2, figsize=(10,10), gridspec_kw={'width_ratios':[1,0.05], 'wspace': 0.1})
            heatmap = sns.heatmap(hist.T, annot=False, cmap='winter', ax=ax[0], cbar_ax=ax[1])
            ax[0].invert_yaxis()
            ax[0].set_xlabel("Max. Prob.", fontsize=30)
            ax[0].set_ylabel("Uncertainty".replace("Unc", mult_label), fontsize=30)
            ax[0].set_yticks(np.arange(0, 26, 5))
            ax[0].set_yticklabels([0, 0.2, 0.4, 0.6, 0.8, 1.0], rotation=0)
            ax[0].set_xticks(np.arange(0, 26, 5))
            ax[0].set_xticklabels([0, 0.2, 0.4, 0.6, 0.8, 1.0])
            ax[0].tick_params(axis='both', which='major', labelsize=20)
            cbar = heatmap.collections[0].colorbar
            cbar.ax.tick_params(labelsize=24)
            cbar.ax.yaxis.get_offset_text().set_fontsize(24)
            if 'topdata' in key or 'jetnet' in key:
                tick_values = cbar.get_ticks()
                cbar.set_ticklabels([f'{int(tick / 1000)}k' for tick in tick_values])
            plt.tight_layout()
            # plt.savefig("{}/unc_prob_{}.pdf".format(savefolder, filetype), dpi = 150, bbox_inches='tight')
        
        for filetype, idx in zip(filetypes, indices):
            # Max Prob. vs Uncertainty distribution (Log Scale)
            hist, _, _ = np.histogram2d(maxprobs[idx], uncs[idx], bins = [np.arange(0.,1.01,0.04), np.arange(0.,1.01,0.04)])
            
            fig, ax  = plt.subplots(1, 2, figsize=(10,10), gridspec_kw={'width_ratios':[1,0.05], 'wspace': 0.1})
            heatmap = sns.heatmap(hist.T+1, annot=False, cmap='winter', ax=ax[0], cbar_ax=ax[1], norm=LogNorm())
            ax[0].invert_yaxis()
            ax[0].set_xlabel("Max. Prob.", fontsize=30)
            ax[0].set_ylabel("Uncertainty".replace("Unc", mult_label), fontsize=30)
            ax[0].set_yticks(np.arange(0, 26, 5))
            ax[0].set_yticklabels([0, 0.2, 0.4, 0.6, 0.8, 1.0], rotation=0)
            ax[0].set_xticks(np.arange(0, 26, 5))
            ax[0].set_xticklabels([0, 0.2, 0.4, 0.6, 0.8, 1.0])
            ax[0].tick_params(axis='both', which='major', labelsize=20)
            cbar = heatmap.collections[0].colorbar
            cbar.ax.tick_params(labelsize=24)
            plt.tight_layout()
            # plt.savefig("{}/unc_prob_{}_log.pdf".format(savefolder, filetype), dpi = 150, bbox_inches='tight')
   
        for filetype, idx in zip(filetypes, indices):
            # Labels vs Preds
            arr = np.zeros((l_max,l_max))
            for i in range(l_max):
                for j in range(l_max):
                    arr[j, i] = np.sum((labels[idx] == i) & (preds[idx] == j))
            
            fig, ax  = plt.subplots(1, 2, figsize=(10,10), gridspec_kw={'width_ratios':[1,0.05], 'wspace': 0.1})
            heatmap = sns.heatmap(arr, annot=True, cmap='winter', annot_kws={"size": fsize, "weight": "bold"}, ax=ax[0], cbar_ax=ax[1])
            ax[0].invert_yaxis()
            ax[0].set_xlabel("Labels", fontsize=30)
            ax[0].set_ylabel("Preds", fontsize=30)
            ax[0].tick_params(axis='both', which='major', labelsize=24)
            ax[0].tick_params(axis='y', labelrotation=0)
            cbar = heatmap.collections[0].colorbar
            cbar.ax.tick_params(labelsize=24)
            cbar.ax.yaxis.get_offset_text().set_fontsize(24)
            if 'topdata' in key or 'jetnet' in key:
                tick_values = cbar.get_ticks()
                cbar.set_ticklabels([f'{int(tick / 1000)}k' for tick in tick_values])
            plt.tight_layout()
            # plt.savefig("{}/labels_preds_{}.pdf".format(savefolder, filetype), dpi = 150, bbox_inches='tight')
            
        for types in ['normal', 'log']:
            for filetype, idx in zip(filetypes, indices):
                # Labels vs Preds + Uncertainty
                uncertainty_bins = np.zeros((5*l_max, l_max))  # 5 bins for uncertainty
                for i in range(l_max):
                    for j in range(l_max):
                        filtered_indices = (labels[idx] == i) & (preds[idx] == j)
                        if np.sum(filtered_indices) > 0:
                            # Create histogram for uncertainty values in 5 bins
                            hist, _ = np.histogram(uncs[idx][filtered_indices], bins=5, range=(0, 1))
                            for k in range(5):
                                uncertainty_bins[5*j+k, i] = hist[k]

                fig, ax  = plt.subplots(1, 2, figsize=(10,10), gridspec_kw={'width_ratios':[1,0.05], 'wspace': 0.1})
                if 'log' in types:
                    norm = LogNorm()
                else:
                    norm = None
                uncertainty_bins_clean = np.where(uncertainty_bins == 0, 1, uncertainty_bins)
                heatmap = sns.heatmap(uncertainty_bins_clean, annot=False, cmap='Spectral', annot_kws={"size": fsize, "weight": "bold"}, ax=ax[0], cbar_ax=ax[1], norm=norm)
                ax[0].invert_yaxis()
                uacm_axsize = 30
                if "jetclass" in key:
                    uacm_axsize = 20
                ax[0].set_xlabel("True Label", fontsize=uacm_axsize)
                ax[0].set_ylabel("Predicted Label + Uncertainty".replace("Unc", mult_label), fontsize=uacm_axsize)
                for jj in np.arange(5,5*l_max, 5):
                    ax[0].axhline(jj, linewidth=3, color='white', zorder=1)
                    ax[0].axvline(jj//5, linewidth=3, color='white', zorder=1)
                for i in range(l_max):
                    ax[0].add_patch(
                        plt.Rectangle((i, 5*i), 1, 5,
                                      fill=False, edgecolor='black', linewidth=3, zorder=3, clip_on=False)
                    )
    
                ax[0].set_yticks(np.arange(0, 5*l_max, 5))
                ax[0].set_yticklabels(np.arange(0, l_max), rotation=0)
                uacm_fsize = 24
                if "jetclass" in key:
                    uacm_fsize = 16
                ax[0].tick_params(axis='both', which='major', labelsize=uacm_fsize)
                cbar = heatmap.collections[0].colorbar
                cbar.ax.tick_params(labelsize=uacm_fsize)
                cbar.ax.yaxis.get_offset_text().set_fontsize(uacm_fsize)
                if ('topdata' in key or 'jetnet' in key) and 'normal' in types:
                    tick_values = cbar.get_ticks()
                    cbar.set_ticklabels([f'{int(tick / 1000)}k' for tick in tick_values])
                plt.tight_layout()
                # plt.savefig("{}/labels_preds_unc_{}.pdf".format(savefolder, filetype), dpi = 150, bbox_inches='tight')
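The last heatmap in the cell above is an uncertainty-aware confusion matrix: each (true label, predicted label) cell is split into 5 uncertainty bins over $[0, 1]$. The counting step can be isolated as a small function. `uncertainty_confusion` is a hypothetical helper written to mirror the `uncertainty_bins` layout (row `5*j + k` holds predictions of class `j` whose uncertainty falls in bin `k`); it is not part of `EvalTools`.

```python
import numpy as np

def uncertainty_confusion(labels, preds, uncs, n_classes, n_bins=5):
    """Counts per (predicted class, uncertainty bin) row and true-class column."""
    out = np.zeros((n_bins * n_classes, n_classes))
    for i in range(n_classes):          # true label
        for j in range(n_classes):      # predicted label
            sel = (labels == i) & (preds == j)
            hist, _ = np.histogram(uncs[sel], bins=n_bins, range=(0, 1))
            out[n_bins * j:n_bins * (j + 1), i] = hist
    return out

# Hypothetical toy inputs: one confident correct, one uncertain mistake, two mid-uncertainty hits.
labels = np.array([0, 0, 1, 1])
preds = np.array([0, 1, 1, 1])
uncs = np.array([0.05, 0.9, 0.3, 0.5])
cm = uncertainty_confusion(labels, preds, uncs, n_classes=2)
```

A well-behaved model should put its off-diagonal (wrong) predictions mostly in the high-uncertainty bins, as the toy mistake above does.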

To do: visualize the results in latent space, compare them, and then repeat the analysis with `JetNet`.

## Training

**References**

https://introtodeeplearning.com/2021/slides/6S191_MIT_DeepLearning_L7.pdf

https://iopscience.iop.org/article/10.1088/2632-2153/ade51b

https://arxiv.org/abs/1905.09638

https://proceedings.mlr.press/v48/gal16.html

https://arxiv.org/abs/1612.01474

https://papers.nips.cc/paper_files/paper/2018/hash/a981f2b708044d6fb4a71a1463242520-Abstract.html