In [12]:
import numpy as np
import pandas as pd

from scipy.spatial.distance import jensenshannon
from scipy.stats import linregress, kstest

# Parameters

In [13]:
DLL_COLUMNS = ['RichDLLe', 'RichDLLk', 'RichDLLmu', 'RichDLLp', 'RichDLLbt']
PARTICLE = 'pion'
output_dir = 'results'
run_date = '2024-dec-14'
output_dir_base = f'{output_dir}/{run_date}/layer'

# !unzip -qq '/content/drive/MyDrive/cern/data/results/30x30/dp_0.01/2024-oct-04/pion_sample_30x30.zip'
# y_sample = np.load('/content/results/pion_y_real.npy')
# x_sample = np.load('/content/results/pion_x_real.npy')
# t_generated = np.load('/content/results/t_generated.npy')

# Functions

In [14]:
def estimate_distances(y_real, y_generated, uncertainty_scores, uncertainty_type = None, bin_type = 'linear',
                                                 particle_index = 0, metric='JS', n_rows = 2, n_cols = 5, dll_columns=DLL_COLUMNS):
  n_bins = n_rows * n_cols

  targets = np.array(y_real[:, particle_index])
  predictions = np.array(y_generated[:, particle_index])
  uncertainty_scores = np.array(uncertainty_scores)

  if uncertainty_type == 'MCD':
    uncertainty_scores = uncertainty_scores[:, particle_index]

  if bin_type == 'linear':
    bin_edges = np.linspace(uncertainty_scores.min(), uncertainty_scores.max(), n_bins + 1)
  else: # Quantiles
    bin_edges = np.quantile(uncertainty_scores, np.linspace(0, 1, n_bins + 1))

  # Digitize returns sample indices per bin
  bin_indices = np.digitize(uncertainty_scores, bin_edges)
  distances = []

  for i in range(10):
    indices = bin_indices == i + 1

    mins = targets[indices].min(), predictions[indices].min()
    maxs = targets[indices].max(), predictions[indices].max()

    hist_range = min(mins), max(maxs)

    targets_hist = np.histogram(targets[indices], 25, hist_range, True)[0]
    predictions_hist = np.histogram(
        predictions[indices], 25, hist_range, True)[0]

    if metric == 'JS':
      dist = jensenshannon(predictions_hist, targets_hist)
    else:
      dist = kstest(predictions[indices], targets[indices]).statistic

    distances += [dist]

  #print(f"{metric} Distances:\n" + ", ".join([str(dist) for dist in distances]))
  return bin_edges, distances


def estimate_correlation(all_bin_ranges, all_distances, dll_columns=DLL_COLUMNS):

  correlation_coefficient = []
  for i in range(5):
    bin_ranges = np.mean([all_bin_ranges[i][1:], all_bin_ranges[i][:-1]], 0)
    regress = linregress(bin_ranges, all_distances[i])
    correlation_coefficient += [regress.rvalue]
    #print(f'Correlation coefficient for {dll_columns[i]}:', regress.rvalue)

  return correlation_coefficient

In [15]:
def calculate_stats(all_correlations, columns):
    df = pd.DataFrame(all_correlations, columns=columns)

    means = df.mean(axis=0)
    stds = df.std(axis=0)
    df.loc['Mean'] = means
    df.loc['Std'] = stds
    print(df)

In [16]:
def calculate_correlations(metric, uncertainty_type, uncertainty_data, y_sample, t_generated, N = 30):
    all_correlations = []

    for j in range(N):
        all_bin_edges, all_distances = [], []
        for i in range(5):
            bin_edges, distances = estimate_distances(
                y_sample, t_generated, uncertainty_data[j],
                uncertainty_type=uncertainty_type, bin_type='quantiles',
                particle_index=i, metric=metric
            )

            all_bin_edges += [bin_edges]
            all_distances += [distances]

        all_correlations.append(estimate_correlation(all_bin_edges, all_distances))

    return all_correlations

# FD

## Load data

In [17]:
layer = 14
dir = f'{output_dir_base}{layer}/'
fd_uncertainty_all = np.load(f'{dir}{PARTICLE}_fd_uncertainty.npy')
y_sample = np.load(f'{dir}{PARTICLE}_y_real.npy')
x_sample = np.load(f'{dir}{PARTICLE}_x_real.npy')
t_generated = np.load(f'{dir}{PARTICLE}_t_generated.npy')

## Features Densities with JS

In [18]:
all_correlations = calculate_correlations('JS', 'FD', fd_uncertainty_all, y_sample, t_generated, len(fd_uncertainty_all))
calculate_stats(all_correlations, DLL_COLUMNS)

      RichDLLe  RichDLLk  RichDLLmu  RichDLLp  RichDLLbt
0     0.964998  0.828683   0.957796  0.719227   0.924053
1     0.962970  0.833798   0.955700  0.716298   0.927429
2     0.964155  0.827681   0.957481  0.709216   0.910402
3     0.960547  0.827948   0.962187  0.715606   0.897411
4     0.962745  0.826722   0.964088  0.712928   0.903402
5     0.962162  0.823687   0.955628  0.712516   0.900312
6     0.962422  0.826518   0.960789  0.717771   0.908824
7     0.963515  0.830457   0.959503  0.711259   0.917277
8     0.963370  0.828275   0.958104  0.705265   0.912214
9     0.957619  0.830407   0.954176  0.724982   0.928754
10    0.959431  0.816098   0.959749  0.712633   0.899848
11    0.964809  0.827028   0.959149  0.718842   0.914598
12    0.965831  0.815356   0.963090  0.721774   0.890789
13    0.961931  0.824603   0.955755  0.698396   0.905629
14    0.965178  0.823605   0.951506  0.719985   0.904464
15    0.967790  0.831956   0.962054  0.719260   0.924090
16    0.955517  0.825910   0.96

## Features Densities with KS

In [19]:
all_correlations = calculate_correlations('KS', 'FD', fd_uncertainty_all, y_sample, t_generated, len(fd_uncertainty_all))
calculate_stats(all_correlations, DLL_COLUMNS)

      RichDLLe  RichDLLk  RichDLLmu  RichDLLp  RichDLLbt
0     0.922542  0.769144   0.877365  0.651841   0.740512
1     0.923546  0.769065   0.871600  0.665569   0.748514
2     0.921210  0.767971   0.867018  0.651421   0.746251
3     0.923725  0.773321   0.873112  0.646618   0.760194
4     0.918703  0.768359   0.871497  0.651562   0.734220
5     0.918830  0.761066   0.872742  0.645511   0.724752
6     0.925732  0.762774   0.865273  0.651702   0.720942
7     0.923746  0.772517   0.869374  0.663329   0.744077
8     0.917762  0.771504   0.871476  0.654244   0.738877
9     0.923704  0.768276   0.869120  0.653000   0.738782
10    0.923630  0.762850   0.861909  0.654097   0.747004
11    0.924658  0.773738   0.870839  0.652349   0.753925
12    0.924652  0.771384   0.875302  0.650035   0.716724
13    0.924116  0.763536   0.869784  0.648006   0.725396
14    0.925129  0.769977   0.878200  0.657403   0.753944
15    0.922737  0.768951   0.884172  0.654313   0.750047
16    0.921081  0.768106   0.86

# MCD

## Load data

In [20]:
output_dir_base = f'{output_dir}/{run_date}/dp'
dp = 0.005
dir = f'{output_dir_base}{dp}/'

mcd_all_uncertainties  = np.load(f'{dir}{PARTICLE}_mcd_uncertainty.npy')
y_sample = np.load(f'{dir}{PARTICLE}_y_real.npy')
x_sample = np.load(f'{dir}{PARTICLE}_x_real.npy')
t_generated = np.load(f'{dir}{PARTICLE}_t_generated.npy')

## MCD with JS

In [21]:
all_correlations = calculate_correlations('JS', 'MCD', mcd_all_uncertainties, y_sample, t_generated, len(mcd_all_uncertainties))
calculate_stats(all_correlations, DLL_COLUMNS)

      RichDLLe  RichDLLk  RichDLLmu  RichDLLp  RichDLLbt
0     0.973253  0.963154   0.791120  0.994416   0.652639
1     0.881841  0.973327   0.688996  0.985435   0.849597
2     0.927235  0.972034   0.695727  0.988791   0.711342
3     0.831457  0.979970   0.454474  0.979293   0.594883
4     0.919326  0.971196   0.262863  0.933239   0.128607
5     0.837797  0.974454   0.448531  0.990109   0.879149
6     0.756633  0.970568   0.590062  0.988438   0.839090
7     0.893355  0.970251  -0.404310  0.991083   0.830400
8     0.756830  0.985200   0.315504  0.993381   0.957548
9     0.917053  0.980872   0.931814  0.994931   0.062460
10    0.968043  0.973316   0.793065  0.989291   0.706634
11    0.903435  0.972637   0.418876  0.987633   0.208458
12    0.787884  0.972268   0.710490  0.983281   0.472637
13    0.880449  0.973242  -0.311953  0.979620   0.917198
14    0.985255  0.985513   0.409213  0.987021   0.809063
15    0.788872  0.970303   0.491362  0.986855   0.869733
16    0.699160  0.974687  -0.73

## MCD with KS

In [22]:
all_correlations = calculate_correlations('KS', 'MCD', mcd_all_uncertainties, y_sample, t_generated, len(mcd_all_uncertainties))
calculate_stats(all_correlations, DLL_COLUMNS)

      RichDLLe  RichDLLk  RichDLLmu  RichDLLp  RichDLLbt
0     0.955755  0.974492   0.787496  0.993640   0.467527
1     0.960322  0.980320   0.815629  0.976826   0.793889
2     0.919801  0.977731   0.770223  0.985655   0.509864
3     0.616106  0.982070   0.548455  0.986697   0.515587
4     0.968864  0.978648   0.694839  0.907667   0.020475
5     0.427455  0.981011   0.670471  0.993651   0.674428
6     0.845947  0.979401   0.582704  0.989381   0.708151
7     0.908087  0.978553   0.690598  0.990659   0.723877
8     0.917245  0.986376   0.378330  0.996272   0.947498
9     0.958040  0.979361   0.958014  0.987374  -0.624392
10    0.964212  0.980245   0.905931  0.991920   0.728027
11    0.905119  0.979500   0.411069  0.981594   0.117017
12    0.760294  0.975994   0.544271  0.988746   0.269091
13    0.740507  0.985134   0.378129  0.986501   0.920456
14    0.981447  0.981425   0.621426  0.990853   0.526248
15    0.150911  0.974731   0.617243  0.989443   0.563450
16    0.817843  0.980399   0.14