In [1]:
import numpy as np
import pandas as pd

from scipy.spatial.distance import jensenshannon
from scipy.stats import linregress, kstest

# Parameters

In [2]:
# Column labels used for the per-feature correlation tables printed below.
DLL_COLUMNS = [
    'RichDLLe',
    'RichDLLk',
    'RichDLLmu',
    'RichDLLp',
    'RichDLLbt',
]
# Filename prefix of the saved .npy arrays.
PARTICLE = 'pion'

# Per-layer result folders are addressed as f'{output_dir_base}{layer}/'.
output_dir, run_date = 'results', '2024-dec-14'
output_dir_base = '/'.join([output_dir, run_date, 'layer'])

# !unzip -qq '/content/drive/MyDrive/cern/data/results/30x30/dp_0.01/2024-oct-04/pion_sample_30x30.zip'
# y_sample = np.load('/content/results/pion_y_real.npy')
# x_sample = np.load('/content/results/pion_x_real.npy')
# t_generated = np.load('/content/results/t_generated.npy')

# Functions

In [3]:
def estimate_distances(y_real, y_generated, uncertainty_scores, uncertainty_type = None, bin_type = 'linear',
                                                 particle_index = 0, metric='JS', n_rows = 2, n_cols = 5, dll_columns=DLL_COLUMNS):
  """Distance between real and generated values inside each uncertainty bin.

  Samples are grouped into ``n_rows * n_cols`` bins by uncertainty score, and
  for every bin the distribution of column ``particle_index`` of ``y_real`` is
  compared with the same column of ``y_generated``.

  Args:
    y_real: 2-D array; column ``particle_index`` holds the reference values.
    y_generated: 2-D array indexed like ``y_real`` with the generated values.
    uncertainty_scores: one score per sample. When ``uncertainty_type`` is
      ``'MCD'`` it must be 2-D and column ``particle_index`` is used.
    uncertainty_type: ``'MCD'`` selects the per-column score; any other value
      uses ``uncertainty_scores`` unchanged.
    bin_type: ``'linear'`` gives equal-width bins between the min and max
      score; any other value gives quantile bins.
    particle_index: column of ``y_real`` / ``y_generated`` to compare.
    metric: ``'JS'`` computes the Jensen-Shannon distance between 25-bin
      density histograms built on a shared range; any other value computes
      the two-sample Kolmogorov-Smirnov statistic.
    n_rows, n_cols: only their product is used, as the number of bins.
    dll_columns: unused; kept so existing callers keep working.

  Returns:
    ``(bin_edges, distances)``: the ``n_bins + 1`` bin edges and a list with
    one distance per bin. A bin that contains no samples yields ``nan``.
  """
  n_bins = n_rows * n_cols

  targets = np.array(y_real[:, particle_index])
  predictions = np.array(y_generated[:, particle_index])
  uncertainty_scores = np.array(uncertainty_scores)

  if uncertainty_type == 'MCD':
    uncertainty_scores = uncertainty_scores[:, particle_index]

  if bin_type == 'linear':
    bin_edges = np.linspace(uncertainty_scores.min(), uncertainty_scores.max(), n_bins + 1)
  else: # Quantiles
    bin_edges = np.quantile(uncertainty_scores, np.linspace(0, 1, n_bins + 1))

  # digitize returns, for every sample, the index of the bin it falls into.
  # Only the interior edges are passed so the indices run 0 .. n_bins - 1 and
  # the sample(s) equal to the top edge land in the last bin; digitizing
  # against the full edge array would give them index n_bins + 1 and drop them.
  bin_indices = np.digitize(uncertainty_scores, bin_edges[1:-1])
  distances = []

  for bin_index in range(n_bins):
    in_bin = bin_indices == bin_index

    if not in_bin.any():
      # No samples here (possible with linear bins or heavily tied scores):
      # there is nothing to compare, so record NaN instead of crashing in min().
      distances.append(np.nan)
      continue

    bin_targets = targets[in_bin]
    bin_predictions = predictions[in_bin]

    if metric == 'JS':
      # Both histograms share one range so their bins line up.
      hist_range = (min(bin_targets.min(), bin_predictions.min()),
                    max(bin_targets.max(), bin_predictions.max()))
      # density is passed by keyword: the 4th positional slot of np.histogram
      # has not meant the same thing across NumPy versions.
      targets_hist = np.histogram(bin_targets, 25, hist_range, density=True)[0]
      predictions_hist = np.histogram(bin_predictions, 25, hist_range, density=True)[0]
      dist = jensenshannon(predictions_hist, targets_hist)
    else:
      dist = kstest(bin_predictions, bin_targets).statistic

    distances.append(dist)

  return bin_edges, distances


def estimate_correlation(all_bin_ranges, all_distances, dll_columns=DLL_COLUMNS):
  """Correlate bin-centre uncertainty with per-bin distance, one series at a time.

  Args:
    all_bin_ranges: sequence of bin-edge arrays, one per series; each has one
      more entry than the matching distance list.
    all_distances: sequence of per-bin distance lists, one per series.
    dll_columns: unused; kept so existing callers keep working.

  Returns:
    List with the ``rvalue`` of ``scipy.stats.linregress`` (bin centres
    against distances) for every series, in input order.

  Raises:
    ValueError: if the two sequences do not have the same length.
  """
  if len(all_bin_ranges) != len(all_distances):
    raise ValueError(
        f'Got {len(all_bin_ranges)} bin-edge arrays but {len(all_distances)} distance lists')

  correlation_coefficient = []
  # Iterate over whatever was passed instead of assuming exactly five series.
  for bin_edges, distances in zip(all_bin_ranges, all_distances):
    # Midpoint of each bin: average of its upper and lower edge.
    bin_centers = np.mean([bin_edges[1:], bin_edges[:-1]], 0)
    regress = linregress(bin_centers, distances)
    correlation_coefficient.append(regress.rvalue)

  return correlation_coefficient

In [4]:
def calculate_stats(all_correlations, columns):
    """Print the correlation table followed by per-column 'Mean' and 'Std' rows.

    Args:
        all_correlations: row-wise data, one row per repetition.
        columns: column labels for the table.
    """
    table = pd.DataFrame(all_correlations, columns=columns)

    # Both summaries are computed before either row is appended, so the
    # summary rows never feed into each other.
    column_mean, column_std = table.mean(axis=0), table.std(axis=0)
    for label, summary in (('Mean', column_mean), ('Std', column_std)):
        table.loc[label] = summary

    print(table)

In [5]:
def calculate_correlations(metric, uncertainty_type, uncertainty_data, y_sample, t_generated, N = 30):
    """Compute per-column uncertainty/distance correlations for N repetitions.

    For each of the first ``N`` entries of ``uncertainty_data``, distances are
    estimated in quantile bins for column indices 0-4 and then turned into
    correlation coefficients.

    Args:
        metric: forwarded to ``estimate_distances`` as ``metric``.
        uncertainty_type: forwarded to ``estimate_distances``.
        uncertainty_data: indexable collection; entry ``j`` holds the
            uncertainty scores of repetition ``j``.
        y_sample: reference values passed as ``y_real``.
        t_generated: generated values passed as ``y_generated``.
        N: number of repetitions to process; must not exceed
            ``len(uncertainty_data)``.

    Returns:
        List of ``N`` lists, each as returned by ``estimate_correlation``.
    """
    correlations_per_run = []

    for run in range(N):
        per_column = [
            estimate_distances(
                y_sample, t_generated, uncertainty_data[run],
                uncertainty_type=uncertainty_type, bin_type='quantiles',
                particle_index=column, metric=metric
            )
            for column in range(5)
        ]
        # Split the (bin_edges, distances) pairs into two parallel lists.
        edges_per_column = [edges for edges, _ in per_column]
        distances_per_column = [dists for _, dists in per_column]

        correlations_per_run.append(
            estimate_correlation(edges_per_column, distances_per_column))

    return correlations_per_run

# FD

## Load data

In [19]:
layer = 14
# Named `layer_dir` rather than `dir` so the builtin dir() stays usable.
layer_dir = f'{output_dir_base}{layer}/'
fd_uncertainty_all = np.load(f'{layer_dir}{PARTICLE}_fd_uncertainty.npy')
y_sample = np.load(f'{layer_dir}{PARTICLE}_y_real.npy')
x_sample = np.load(f'{layer_dir}{PARTICLE}_x_real.npy')
t_generated = np.load(f'{layer_dir}{PARTICLE}_t_generated.npy')

## Feature Densities with JS

In [20]:
# One row of correlations per stored FD uncertainty repetition.
all_correlations = calculate_correlations(
    metric='JS',
    uncertainty_type='FD',
    uncertainty_data=fd_uncertainty_all,
    y_sample=y_sample,
    t_generated=t_generated,
    N=len(fd_uncertainty_all),
)
calculate_stats(all_correlations, columns=DLL_COLUMNS)

      RichDLLe  RichDLLk  RichDLLmu  RichDLLp  RichDLLbt
0     0.964998  0.828683   0.957796  0.719227   0.924053
1     0.962970  0.833798   0.955700  0.716298   0.927429
2     0.964155  0.827681   0.957481  0.709216   0.910402
3     0.960547  0.827948   0.962187  0.715606   0.897411
4     0.962745  0.826722   0.964088  0.712928   0.903402
5     0.962162  0.823687   0.955628  0.712516   0.900312
6     0.962422  0.826518   0.960789  0.717771   0.908824
7     0.963515  0.830457   0.959503  0.711259   0.917277
8     0.963370  0.828275   0.958104  0.705265   0.912214
9     0.957619  0.830407   0.954176  0.724982   0.928754
10    0.959431  0.816098   0.959749  0.712633   0.899848
11    0.964809  0.827028   0.959149  0.718842   0.914598
12    0.965831  0.815356   0.963090  0.721774   0.890789
13    0.961931  0.824603   0.955755  0.698396   0.905629
14    0.965178  0.823605   0.951506  0.719985   0.904464
15    0.967790  0.831956   0.962054  0.719260   0.924090
16    0.955517  0.825910   0.96

## Feature Densities with KS

In [None]:
# Same as the JS cell, but using the KS statistic as the per-bin distance.
all_correlations = calculate_correlations(
    metric='KS',
    uncertainty_type='FD',
    uncertainty_data=fd_uncertainty_all,
    y_sample=y_sample,
    t_generated=t_generated,
    N=len(fd_uncertainty_all),
)
calculate_stats(all_correlations, columns=DLL_COLUMNS)

# MCD

## Load data

### DROPOUT 0.05

In [None]:
# NOTE(review): absolute Colab Drive path, and `dir` shadows the builtin dir().
# Every "DROPOUT" cell assigns mcd_all_uncertainties, so whichever of them
# ran last decides what the MCD analysis cells use.
dir = '/content/drive/MyDrive/cern/data/results/30x30/dp_0.05/2024-oct-20/'

mcd_all_uncertainties = np.load(f'{dir}{PARTICLE}_mcd_uncertainty.npy')

### DROPOUT 0.01

In [None]:
!unzip -qq '/content/drive/MyDrive/cern/data/results/30x30/dp_0.01/2024-oct-04/pion_uncertainty_30x30_reps.zip'
mcd_all_uncertainties  = np.load('/content/' + f'{PARTICLE}_mcd_uncertainty_30_reps.npy')

replace pion_mcd_uncertainty_30_reps.npy? [y]es, [n]o, [A]ll, [N]one, [r]ename: N


### DROPOUT 0.1

In [None]:
# NOTE(review): absolute Colab Drive path, and `dir` shadows the builtin dir().
# Every "DROPOUT" cell assigns mcd_all_uncertainties, so whichever of them
# ran last decides what the MCD analysis cells use.
dir = '/content/drive/MyDrive/cern/data/results/30x30/dp_0.1/2024-oct-18/'

mcd_all_uncertainties = np.load(f'{dir}{PARTICLE}_mcd_uncertainty.npy')

## MCD with JS

In [None]:
# calculate_correlations requires the real and generated samples explicitly;
# without them this call raises TypeError.
# NOTE(review): y_sample and t_generated are whatever was loaded most recently
# (the FD "Load data" cell) - confirm they are the samples these MCD
# uncertainties were computed on.
all_correlations = calculate_correlations(
    'JS', 'MCD', mcd_all_uncertainties, y_sample, t_generated, len(mcd_all_uncertainties))
calculate_stats(all_correlations, DLL_COLUMNS)

      RichDLLe  RichDLLk  RichDLLmu  RichDLLp  RichDLLbt
0     0.931712  0.892428   0.927162  0.981013   0.752428
1     0.932801  0.913138   0.942275  0.985214   0.747845
2     0.915555  0.909380   0.965178  0.971393   0.769795
3     0.923141  0.897347   0.977387  0.991940   0.742450
4     0.922550  0.913731   0.960804  0.968792   0.809786
5     0.908788  0.906154   0.946164  0.971405   0.776485
6     0.926276  0.917019   0.972417  0.986013   0.809021
7     0.930464  0.901918   0.980704  0.984468   0.736220
8     0.926601  0.899143   0.927712  0.976093   0.775351
9     0.932433  0.907352   0.971200  0.986352   0.768836
10    0.930012  0.897800   0.954017  0.978606   0.804160
11    0.933869  0.906648   0.933330  0.987977   0.795520
12    0.948315  0.905112   0.848564  0.987393   0.743769
13    0.952713  0.906916   0.927898  0.976741   0.862072
14    0.943223  0.911184   0.976087  0.992929   0.723222
15    0.932740  0.910928   0.898221  0.984948   0.801496
16    0.949656  0.903927   0.97

## MCD with KS

In [None]:
# calculate_correlations requires the real and generated samples explicitly;
# without them this call raises TypeError.
# NOTE(review): y_sample and t_generated are whatever was loaded most recently
# (the FD "Load data" cell) - confirm they are the samples these MCD
# uncertainties were computed on.
all_correlations = calculate_correlations(
    'KS', 'MCD', mcd_all_uncertainties, y_sample, t_generated, len(mcd_all_uncertainties))
calculate_stats(all_correlations, DLL_COLUMNS)

      RichDLLe  RichDLLk  RichDLLmu  RichDLLp  RichDLLbt
0     0.900017  0.894222   0.897238  0.967978   0.616516
1     0.897940  0.917895   0.990227  0.978043   0.658763
2     0.878441  0.907950   0.983142  0.956923   0.631530
3     0.884357  0.901341   0.950946  0.979883   0.606102
4     0.895306  0.915214   0.995029  0.948828   0.681325
5     0.843513  0.907032   0.988540  0.956222   0.641202
6     0.878221  0.922935   0.970313  0.976142   0.672499
7     0.870327  0.908059   0.985194  0.972894   0.638253
8     0.829179  0.903771   0.969962  0.963994   0.664239
9     0.842508  0.907928   0.994250  0.973793   0.667476
10    0.893678  0.898061   0.901402  0.960095   0.657557
11    0.842354  0.914866   0.900312  0.978566   0.640179
12    0.938817  0.911381   0.888018  0.973989   0.605006
13    0.922875  0.908213   0.990911  0.959045   0.767105
14    0.862864  0.915870   0.967733  0.983107   0.597293
15    0.890918  0.913653   0.942021  0.970654   0.667908
16    0.929453  0.906547   0.94