### Imports

In [None]:
import numpy as np
import matplotlib.pyplot as plt 
import pandas as pd
from sklearn import preprocessing
import glob
import sys
import random
sys.path.append('../src')
from utils import load_dataset, plot_spectrum, save_object, load_object, subset_dataset_from_indices, plot_prediction_hist

# Let notebook DataFrame displays show up to 4000 rows.
pd.set_option('display.max_rows', 4000)

# 1. Generate RANDOM OOD data (ID classification + OOD detection scores)

In [None]:
# Reproducible generator shared by the noise cells below.
seed = 40
rng = np.random.default_rng(seed)

# Load the in-distribution (ID) dataset; its feature count fixes the shape
# of the synthetic samples.
use_retention_time = True
data_root_path = "../data"
id_dataset_name = "spectrum_exported"
id_dataset = load_dataset(
    data_dir=f"{data_root_path}/id/",
    dataset_name=id_dataset_name,
    use_retention_time=use_retention_time,
    verbose=True,
)
# 10000 synthetic samples, each with as many features as an ID sample.
sze = (10000, id_dataset['data_points'].shape[1])

### Uniform noise

In [None]:
# Uniform noise on [0, 1), cast to float32.
uniform = rng.uniform(low=0, high=1, size=sze).astype(np.float32)

### Gaussian noise

In [None]:
# Gaussian noise with mean 0.5 and unit standard deviation, cast to float32.
mu, std = 0.5, 1.0
gaussian = rng.normal(loc=mu, scale=std, size=sze).astype(np.float32)

## Compare the ID distribution with the synthetic OOD distributions

In [None]:
# Density histogram (20 bins) of the first ID sample's feature values.
_ = plt.hist(id_dataset['data_points'][0], bins = 20, density = True)

In [None]:
# Density histogram (20 bins) of the first uniform-noise sample.
_ = plt.hist(uniform[0], bins = 20, density = True)

In [None]:
# Density histogram (20 bins) of the first Gaussian-noise sample.
_ = plt.hist(gaussian[0], bins = 20, density = True)

In [None]:
# Global mean / std of the ID data and of both synthetic OOD sets.
print("\n".join(
    f"{name} - mean: {arr.mean():.2f}, std: {arr.std():.2f}"
    for name, arr in (
        ("ID", id_dataset['data_points']),
        ("Uniform", uniform),
        ("Gaussian", gaussian),
    )
))

In [None]:
# Placeholder metadata row so plot_spectrum can label a synthetic sample.
fake_metadata = pd.DataFrame([
    {'system': -1, 'annotator_ID': -1, 'measurement_number': -1, 'compound': 'unknown'},
])
plot_spectrum(uniform[1], fake_metadata.iloc[0])

In [None]:
# Plot one Gaussian-noise sample with the placeholder metadata row.
plot_spectrum(gaussian[1], fake_metadata.iloc[0])

#### Save to disk

In [None]:
# Persist the uniform-noise OOD set; the "_rt" suffix is added when
# use_retention_time is set.
suffix = "_rt" if use_retention_time else ""
save_path = f"{data_root_path}/ood/uniform/"
save_object(uniform, f"{save_path}uniform{suffix}.pkl")

In [None]:
# Persist the Gaussian-noise OOD set (reuses `suffix` from the uniform cell).
save_path = f"{data_root_path}/ood/gaussian/"
save_object(gaussian, f"{save_path}gaussian{suffix}.pkl")

# 2. Generate QUANTITATIVE synthetic OOD data (ID classification + OOD detection scores)
* Quantitative synthetic OOD: pick a subset of compounds to treat as OOD (this reduces the number of ID classes), and use cross-validation to measure performance for different ID/OOD ratios;
* -> split the 70 classes into ID/OOD sets of 65/5, 50/20, 40/30, etc. classes, and run 5-fold cross-validation for each class split
  
#### Splits:
| ratio | ID num_classes | ID samples | OOD num_classes | OOD samples | 
| ----- | -------------- | ---------- | -------------- | ------------ |
| 0.05  |     67         |     22451  |        3       |     979      |
| 0.1   |     63         |     21111  |        7       |     2319     |
| 0.2   |     56         |     18694  |        14      |     4736     |

In [None]:
# ID dataset loaded with use_retention_time=False for the class-split experiments.
data = load_dataset(data_dir='../data/id/', dataset_name='spectrum_exported', use_retention_time=False, verbose=True)
X, y, metadata = (data[key] for key in ('data_points', 'data_labels', 'metadata'))

In [None]:
def split_data_distinct_classes(data, labels, ood_ratio, seed):
    """Randomly partition the distinct class labels into disjoint ID and OOD sets.

    Args:
        data: Unused; kept so existing callers keep working.
        labels: Array-like of per-sample class labels.
        ood_ratio: Fraction of the distinct classes assigned to the OOD set.
            The OOD class count is ``int(num_classes * ood_ratio)`` (truncated).
        seed: Seed for the class sampling, so a split can be reproduced.

    Returns:
        Tuple ``(id_labels, ood_labels)`` of numpy arrays. ``id_labels`` is
        sorted (it comes from ``np.unique``); ``ood_labels`` is in sampling order.
    """
    unique_labels = np.unique(labels)
    # Count classes from the `labels` argument (previously read the global `y`).
    num_classes = unique_labels.size

    ood_num_classes = int(num_classes * ood_ratio)
    id_num_classes = num_classes - ood_num_classes
    print(f"Number of classes: ID {id_num_classes}, OOD {ood_num_classes}")

    # A private generator seeded with `seed` draws exactly the same sample as
    # `random.seed(seed); random.sample(...)`, so previously saved splits are
    # reproduced, but the global `random` state is left untouched.
    rng = random.Random(seed)
    ood_labels = np.array(rng.sample(unique_labels.tolist(), ood_num_classes))
    id_labels = unique_labels[~np.isin(unique_labels, ood_labels)]
    return id_labels, ood_labels

## "Cross-validation" for each ID-to-OOD ratio

In [None]:
seed = 100
ood_ratio = 0.1
k = 10  # number of random ID/OOD class splits ("folds")
data_root_path = "../data"
id_save_path = f"{data_root_path}/id/synthetic_spectrum/"
ood_save_path = f"{data_root_path}/ood/synthetic_spectrum/"

# One random class split per fold, seeded with seed + fold; ID and OOD label
# sets are saved under the same file name in the id/ and ood/ directories.
id_labels = []
ood_labels = []
for fold in range(k):  # distinct loop variable: no longer shadows the fold count `k`
    fold_seed = seed + fold
    id_labels_, ood_labels_ = split_data_distinct_classes(X, y, ood_ratio, fold_seed)
    fname = f"synthetic_spectrum_{ood_ratio}_{fold_seed}.pkl"
    save_object(id_labels_, id_save_path + fname)
    save_object(ood_labels_, ood_save_path + fname)
    id_labels.append(id_labels_)
    ood_labels.append(ood_labels_)

## Extension of the Cross-Validation where similar substances are not divided into different splits

In [None]:
labels = sorted(np.unique(metadata['compound']))
# Compounds at positions 16..20 of the alphabetically sorted names, treated
# as one group of similar substances.
similar_classes = labels[16:21]
print(similar_classes)

# Numeric label of each of those compounds, taken from its first sample.
similar_classes_labels = [y[metadata['compound'] == compound][0] for compound in similar_classes]
print(similar_classes_labels)

In [None]:
data_root_path = "../data"
seed = 42
k = 0
num_classes = len(labels)
id_save_path = f"{data_root_path}/id/synthetic_spectrum/"
ood_save_path = f"{data_root_path}/ood/synthetic_spectrum/"
all_labels = np.unique(y)
# Every class that is not one of the similar substances.
remaining_labels = all_labels[~np.isin(all_labels, similar_classes_labels)]

# Split A: the similar substances are OOD, all other classes are ID.
ood_ratio = round(len(similar_classes) / num_classes, 2)
ood_labels = similar_classes_labels
id_labels = remaining_labels
# Report the actual sizes instead of re-deriving them from the rounded ratio.
print(f"Number of classes: ID {len(id_labels)}, OOD {len(ood_labels)}")
save_object(id_labels, id_save_path + f"synthetic_spectrum_{ood_ratio}_{seed+k}.pkl")
save_object(ood_labels, ood_save_path + f"synthetic_spectrum_{ood_ratio}_{seed+k}.pkl")
print(id_labels)
print(ood_labels)

# Split B (inverted): the similar substances are ID, all other classes are OOD.
# Built from `remaining_labels` rather than from the `ood_labels` variable, so
# it does not silently depend on split A having assigned it first.
ood_ratio = round(1 - (len(similar_classes) / num_classes), 2)
id_labels = similar_classes_labels
ood_labels = remaining_labels
print(f"Number of classes: ID {len(id_labels)}, OOD {len(ood_labels)}")
save_object(id_labels, id_save_path + f"synthetic_spectrum_{ood_ratio}_{seed+k}.pkl")
save_object(ood_labels, ood_save_path + f"synthetic_spectrum_{ood_ratio}_{seed+k}.pkl")
print(id_labels)
print(ood_labels)

### Run experiments using the `main.py` script
* evaluate all OOD detectors on each split, save the metrics to a log file, and save the OOD-score PDF curves into a folder
* produce results for a specific `ood_to_id_ratio` as the mean over the folds and save them to a file

In [None]:
def load_performance(performance_root_path, model, use_retention_time, ood_to_id_ratio, k):
    """Load the OOD-detection metrics CSV written for one fold of a synthetic split.

    Args:
        performance_root_path: Root log directory (e.g. "../logs").
        model: Classifier name used in the run directory (e.g. "svm").
        use_retention_time: Selects the "_w_rt" directory suffix when True.
        ood_to_id_ratio: Split ratio as it appears in the file names (e.g. 0.1).
        k: Fold index, added to the ratio-derived base seed.

    Returns:
        DataFrame with the columns in ``col_names``; the file's own header row
        is replaced by these names.
    """
    suffix = "_w_rt" if use_retention_time else ""
    # Base seed is ratio * 1000 (0.1 -> 100). round() rather than int():
    # int() truncates, so a float product landing just below an integer
    # would select the wrong seed directory.
    seed = round(ood_to_id_ratio * 1000) + k
    run_name = f"synthetic_spectrum_{ood_to_id_ratio}_{seed}"
    load_path = f"{performance_root_path}/seed-{seed}/{model}{suffix}/{run_name}/{run_name}/"
    col_names = ["detector_name", "auroc", "fpr.95", "dtacc", "auin", "auout",
                 "id_accuracy", "id_precision", "id_recall"]
    return pd.read_csv(load_path + f"ood-{run_name}.csv", names=col_names, header=0)
         
performance_root_path = "../logs"
model = "svm"
num_folds = 10
ood_to_id_ratios = [0.1, 0.9]

# final_perf[<"w_rt"|"wo_rt">][ratio] -> per-detector mean metrics over folds.
final_perf = {}
for use_retention_time in [False, True]:
    suffix = "w_rt" if use_retention_time else "wo_rt"
    final_perf[suffix] = {}
    for ratio in ood_to_id_ratios:
        folds = pd.concat([
            load_performance(performance_root_path, model, use_retention_time=use_retention_time,
                             ood_to_id_ratio=ratio, k=k)
            for k in range(num_folds)
        ])
        folds = folds.drop('dtacc', axis=1)
        # Filter ViM with a boolean mask. After pd.concat the row index repeats
        # (0..n-1 per fold), so dropping ViM's index labels would also remove
        # other detectors' rows whenever ViM is not on the same row in every fold.
        folds = folds[folds['detector_name'] != 'vim']
        final_perf[suffix][ratio] = folds.groupby(["detector_name"]).mean().round(4)

In [None]:
# Per-detector mean metrics (dtacc and ViM excluded) without retention time, ratio 0.1.
final_perf['wo_rt'][0.1]

In [None]:
# After the groupby in the aggregation cell, detector_name is the row index
# (not a column), so select the MSP row by label; double brackets keep a DataFrame.
final_perf['w_rt'][0.1].loc[['msp']]

In [None]:
# Per-detector metric values collected into lists. 'detector_name' resolves to
# the index level here; rows were already averaged per detector in the
# aggregation cell, so each list holds a single value.
final_perf['w_rt'][0.1].groupby(['detector_name']).agg(list)

# 3. Process QUALITATIVE OOD data (ID classification + OOD detection scores)
* ID dataset = `spectrum_exported`
* OOD dataset = `M29_9_system2` (GCxGC data containing 2000 x 460 spectra)

## Evaluation pipeline:
1. export `M29_9_system2` as an OOD dataset with shape (2000 * 460, 801)
2. create classification predictions and compute OOD scores
3. select the OOD-detection threshold so that we have 95% TPR on the ID dataset
4. produce final predictions by running the OOD detector with the found threshold, rejecting the classifier's predictions for samples marked as OOD
5. visualize the predictions on the GCxGC by coloring each spectrum according to its final prediction

### 0. Load and visualize OOD data as a GCxGC diagram 

In [None]:
data_root_path = "../data"
ood_dataset_name = "M29_9_system2"
# GCxGC cube as float32. Later cells treat axis 0 as the spectral axis (it is
# summed for the TIC), axis 1 as the 2nd dimension and axis 2 as the 1st.
gcxgc = np.load(
    f"{data_root_path}/ood/{ood_dataset_name}/{ood_dataset_name}.npy"
).astype(np.float32)
print(gcxgc.shape)

#### TIC GCxGC by summing the first dimension

In [None]:
import matplotlib
import matplotlib.patches as mpatches
from matplotlib.colors import LinearSegmentedColormap

# A 'jet'-based continuous colormap whose lowest entry is near-white,
# presumably so that empty (lowest-valued) pixels of the TIC plot render
# white. The extreme ends of 'jet' (first/last `start_i` entries) are skipped.
cmap_name = 'jet'
cmap = matplotlib.colormaps[cmap_name]
start_i = 18
colors_lst = [(0.99, 0.99, 0.99)] + [tuple(cmap(i)[:3]) for i in range(start_i, 255 - start_i)]
custom_cmap = LinearSegmentedColormap.from_list('jet+white', colors_lst)
custom_cmap

In [None]:
# Total ion current (TIC) image: sum every pixel's spectrum over axis 0.
gcxgc_normalized = np.sum(gcxgc, axis=0)

# Retention times [s]. First dimension (columns): 6 s steps for the first 200
# columns, 8 s for the next 111, 10 s for the rest. Second dimension (rows):
# 0.005 s steps.
n_rows, n_cols = gcxgc.shape[1], gcxgc.shape[2]
t1 = np.concatenate([
    6 * np.arange(1, 201),
    6 * 200 + 8 * np.arange(1, 112),
    6 * 200 + 8 * 111 + 10 * np.arange(1, n_cols - 310),
])
t2 = 0.005 * np.arange(1, n_rows + 1)
xs1 = [1, 100, 200, 311, 460]
xs2 = [1, 500, 1000, 1500, 2000]
t1_sampled = [t1[j - 1] for j in xs1]
t2_sampled = [t2[j - 1] for j in xs2]

# Recolour empty pixels on a copy: mutating gcxgc_normalized in place (0 -> -1)
# would make the later `gcxgc_normalized != 0` validity mask accept every pixel.
tic_image = np.where(gcxgc_normalized == 0, -1, gcxgc_normalized)

save_path = '../figures'
plt.figure(figsize=(9, 6), dpi=170)
plt.imshow(tic_image, cmap=custom_cmap)
plt.gca().invert_yaxis()
plt.gca().set_aspect(0.2)
plt.xlabel('Retention time in the first dimension [s]')
plt.ylabel('Retention time in the second dimension [s]')
# imshow puts pixel i at coordinate i (0-based). Shift the 1-based sample
# numbers so each tick sits on the pixel whose time is its label (and the
# last tick is not placed beyond the image edge).
plt.xticks([j - 1 for j in xs1], t1_sampled)
plt.yticks([j - 1 for j in xs2], t2_sampled)
plt.colorbar(label="Summed m/z values", orientation="vertical")
plt.savefig(f'{save_path}/GCxGC_tic.png', bbox_inches='tight')
plt.show()

### 1. Export `M29_9_system2` as an OOD dataset with shape (2000 * 460, 801)

In [None]:
# Reload the raw cube with the dtype stored on disk.
data_root_path = "../data"
ood_dataset_name = "M29_9_system2"
gcxgc = np.load(f"{data_root_path}/ood/{ood_dataset_name}/{ood_dataset_name}.npy")
n_channels, n_rows, n_cols = gcxgc.shape
print(f"[INFO] Original dataset size: ({n_rows * n_cols}, {n_channels})")

# TIC per pixel, from the spectral channels only (computed before the
# retention-time channels are appended). Recomputed here so the validity mask
# does not depend on whether another cell modified gcxgc_normalized.
gcxgc_normalized = np.sum(gcxgc, axis=0)

# Retention time of every column (1st dimension: 6 s, then 8 s, then 10 s
# steps) and of every row (2nd dimension: 0.005 s steps).
t1 = np.concatenate([
    6 * np.arange(1, 201),
    6 * 200 + 8 * np.arange(1, 112),
    6 * 200 + 8 * 111 + 10 * np.arange(1, n_cols - 310),
]).reshape(1, -1)
t2 = (0.005 * np.arange(1, n_rows + 1)).reshape(-1, 1)

# Broadcast both times to full (n_rows, n_cols) grids and append them as two
# extra channels -> shape (n_channels + 2, n_rows, n_cols).
t1_2d = np.tile(t1, (n_rows, 1))[np.newaxis]
t2_2d = np.tile(t2, (1, n_cols))[np.newaxis]
gcxgc = np.concatenate((gcxgc, t1_2d, t2_2d), axis=0)

# The instrument recorded nothing at some times; keep only pixels with a
# positive TIC, one row per pixel.
eps = 0
mask_valid_measurements = gcxgc_normalized > eps
qual_ood_dataset = gcxgc[:, mask_valid_measurements].transpose(1, 0)
print(f"[INFO] Testing dataset loaded: {qual_ood_dataset.shape}")

In [None]:
from sklearn.preprocessing import Normalizer

# Scale each sample (row) by its max-norm. Normalizer is stateless, so fitting
# and transforming the same array is equivalent to fit_transform.
# NOTE(review): rows include the two appended retention-time columns, which
# may dominate the row maximum — confirm this matches the ID preprocessing.
normalizer = Normalizer(norm='max')
normalizer.fit(qual_ood_dataset)
qual_ood_dataset = normalizer.transform(qual_ood_dataset)

### 2. Create classification predictions and compute OOD scores

In [None]:
# An empty save_path made the output path "/GCxGC_...png" (filesystem root);
# save next to the other figures instead.
visualize_prediction(classified_gcxgc, ood_detector_name, custom_cmap, colors_lst, predicted_ood_ratio_post, 0.95, '../figures')

In [None]:
def detect_ood(scores, threshold):
    """Label each score as in-distribution (1) or OOD (0).

    A sample counts as ID when its score is greater than or equal to
    ``threshold``; returns an integer array of the same shape as ``scores``.
    """
    in_distribution = threshold <= scores
    return in_distribution.astype(int)

def find_threshold(in_distribution_scores, target_detection_rate=0.95):
    """Choose an OOD-score threshold from in-distribution (ID) scores.

    With the ID scores sorted in descending order, the threshold is the score
    at position ``int(target_detection_rate * n)`` (clamped to the last
    element). With ``detect_ood``'s ``>=`` rule, at least that many + 1 of the
    ``n`` ID scores lie at or above the threshold and are kept as ID.

    Args:
        in_distribution_scores: Array-like of ID detector scores.
        target_detection_rate: Desired fraction of ID samples kept, in [0, 1].

    Returns:
        The selected threshold (one of the ID scores).

    Raises:
        ValueError: If no scores are given or the rate is outside [0, 1].
    """
    sorted_scores = np.sort(np.asarray(in_distribution_scores))[::-1]
    n = len(sorted_scores)
    if n == 0:
        raise ValueError("in_distribution_scores must not be empty")
    if not 0.0 <= target_detection_rate <= 1.0:
        raise ValueError(f"target_detection_rate must be in [0, 1], got {target_detection_rate}")

    # Clamp so a rate of 1.0 selects the lowest ID score instead of indexing
    # one past the end.
    index = min(int(target_detection_rate * n), n - 1)
    return sorted_scores[index]

def normalize_scores(in_distribution_scores, scores):
    """Min-max scale ``scores`` using the range of the ID scores.

    The smallest ID score maps to 0 and the largest to 1; other scores may fall
    outside [0, 1]. If all ID scores are equal, the shifted scores are returned
    unscaled to avoid a division by zero.
    """
    id_scores = np.asarray(in_distribution_scores, dtype=float)
    lo = id_scores.min()
    span = id_scores.max() - lo
    shifted = np.asarray(scores, dtype=float) - lo
    if span == 0:
        return shifted
    # Divide by the range (max - min), not by max, so the ID max maps to 1.
    return shifted / span

def visualize_prediction(classified, ood_detector_name, custom_cmap, colors_lst, predicted_ood_ratio, fn_detection_rate, save_path,
                         show_legend=False):
    """Plot one detector's per-pixel GCxGC predictions and save them as a PNG.

    Args:
        classified: Dict mapping detector name -> 2-D label image. Values
            0..69 are compound labels and 70 is OOD; invalid (unmeasured)
            pixels are -1 (as built by the prediction cell) or 71.
        ood_detector_name: Key into ``classified`` / ``predicted_ood_ratio``.
        custom_cmap: Colormap with one colour per label value.
        colors_lst: The colours behind ``custom_cmap``; index 70 = OOD,
            index 71 = invalid measurement.
        predicted_ood_ratio: Dict detector name -> percentage of OOD samples.
        fn_detection_rate: Value embedded in the output file name.
        save_path: Output directory ('' means the current directory).
        show_legend: Draw a legend with one patch per detected compound.
    """
    import os

    ood_value, invalid_value = 70, 71
    image = np.asarray(classified[ood_detector_name])
    # Invalid pixels are stored as -1; map them onto the colour reserved for
    # invalid measurements instead of letting them shift imshow's autoscaling.
    image = np.where(image == -1, invalid_value, image)

    plt.figure(figsize=(11, 9), dpi=150)
    # Fixed vmin/vmax so label v is drawn with colors_lst[v] regardless of which
    # labels are present; 'nearest' avoids blending categorical labels.
    plt.imshow(image, cmap=custom_cmap, vmin=0, vmax=len(colors_lst) - 1, interpolation='nearest')

    present = np.unique(image).astype(int)
    id_classes = [val for val in present if 0 <= val < ood_value]
    if show_legend:
        patches = [mpatches.Patch(color=colors_lst[val], label=f"Compound {val}") for val in id_classes]
        patches.append(mpatches.Patch(color=colors_lst[ood_value], label="OOD"))
        patches.append(mpatches.Patch(color=colors_lst[invalid_value], label="Invalid measurement"))
        plt.legend(handles=patches, bbox_to_anchor=(1.1, 1), loc=2, borderaxespad=0., ncol=2)

    plt.title(f"{ood_detector_name}: {predicted_ood_ratio[ood_detector_name]:.3f}% "
              f"samples classified as OOD (total {len(id_classes)} ID classes detected)")
    plt.gca().invert_yaxis()
    plt.gca().set_aspect(0.2)
    plt.axis('off')
    # os.path.join keeps an empty save_path relative instead of writing to "/".
    plt.savefig(os.path.join(save_path, f'GCxGC_{ood_detector_name}_{fn_detection_rate}.png'), bbox_inches='tight')
    plt.show()

import matplotlib.patches as mpatches
import matplotlib
from matplotlib.colors import LinearSegmentedColormap

# 72-entry categorical palette: 70 compound colours sampled from 'hsv'
# (every 3.5th entry), then near-white for OOD (index 70) and black for
# invalid measurements (index 71).
cmap_name = 'hsv'
cmap = matplotlib.colormaps[cmap_name]
colors_lst = [tuple(cmap(int(i * 3.5))[:3]) for i in range(70)]
colors_lst += [(0.99, 0.99, 0.99), (0., 0., 0.)]
custom_cmap = LinearSegmentedColormap.from_list('diverge_72', colors_lst, N=72)
custom_cmap

In [None]:
import os

seed = 1111
root_save_dir_path = "../saved_model_outputs"
ood_dataset_name = "M29_9_system2"
class_model_name = "svm"
use_retention_time = False
suffix = "_w_rt" if use_retention_time else ""
ood_save_dir_path = f"{root_save_dir_path}/seed-{seed}/{class_model_name}{suffix}/{ood_dataset_name}"

# Classifier predictions on the OOD (GCxGC) dataset, needed by the
# prediction-histogram and OOD-rejection cells.
# NOTE(review): this load was previously commented out although later cells use
# `classification_predictions`; confirm the file exists and holds an array
# (an older version converted it with `.numpy()`).
classification_predictions = load_object(f"{ood_save_dir_path}/model_predictions_ood.pt")

# OOD detector scores: detector_scores[<'id'|'ood'>][<detector name>].
detector_scores = {}
for fold in ['id', 'ood']:
    detector_scores[fold] = {}
    for ood_detector_fname in glob.glob(f'{ood_save_dir_path}/*scores_{fold}.pt'):
        # Detector name = base file name before "_scores_<fold>.pt". Parsing only
        # the base name keeps this independent of underscores / separators in the
        # directory path and of underscores in the detector name.
        detector_name = os.path.basename(ood_detector_fname).rsplit('_scores_', 1)[0]
        detector_scores[fold][detector_name] = load_object(ood_detector_fname)

In [None]:
# Display the nested {'id' | 'ood': {detector_name: scores}} dict.
detector_scores

In [None]:
# Histogram of the classifier's predicted labels (log scale, as percentages).
# NOTE(review): requires `classification_predictions` to be defined before this cell runs.
plot_prediction_hist(classification_predictions, use_logspace=True, use_percentage=True)

In [None]:
from scipy.stats import gaussian_kde

def plot_distributions(scores, detector_name, id_thresh=None, fn_detection_rate=None, save_plot_dir=None, bw=0.1, nbins=200, use_logspace=False,
                       normalize_density=False, save_plot=False, file_suffix="0.9589"):
    """Plot KDE estimates of a detector's ID and test/OOD score distributions.

    Args:
        scores: Nested dict ``{'id': {name: scores}, 'ood': {name: scores}}``.
        detector_name: Detector key looked up in both sub-dicts.
        id_thresh: Optional thresholds. When given (qualitative mode) each is
            drawn as a vertical line; when None (quantitative mode) the
            crossing points of the two density curves are marked instead.
        fn_detection_rate: Optional values parallel to ``id_thresh``, shown in
            each threshold's legend label.
        save_plot_dir: Directory for the saved PDF (current directory if None).
        bw: Bandwidth passed to ``gaussian_kde`` as ``bw_method``.
        nbins: Number of evaluation points between the min and max score.
        use_logspace: Log-scale the y axis.
        normalize_density: Rescale both curves so their Riemann sum on the
            evaluation grid equals 1.
        save_plot: Save the figure as a PDF before showing it.
        file_suffix: Suffix of the saved file name (default keeps the old name).
    """
    import os

    id_scores = np.asarray(scores['id'][detector_name])
    test_scores = np.asarray(scores['ood'][detector_name])

    # np.min / np.max instead of builtins: vectorised over large score arrays.
    min_val = min(np.min(id_scores), np.min(test_scores))
    max_val = max(np.max(id_scores), np.max(test_scores))
    xs = np.linspace(min_val, max_val, nbins)
    density_id_estimate = gaussian_kde(id_scores, bw)(xs)
    density_test_estimate = gaussian_kde(test_scores, bw)(xs)

    plt.figure(figsize=(6, 4), dpi=300)
    if normalize_density:
        # Make each curve a pdf on this grid: sum(density * bin_width) == 1.
        bin_width = np.diff(xs)[0]
        density_id_estimate = density_id_estimate / (density_id_estimate.sum() * bin_width)
        density_test_estimate = density_test_estimate / (density_test_estimate.sum() * bin_width)
        assert (density_id_estimate * bin_width).sum().round() == 1, "PDF does not integrate to 1!"
        assert (density_test_estimate * bin_width).sum().round() == 1, "PDF does not integrate to 1!"

    if id_thresh is None:  # quantitative experiments
        plt.plot(xs, density_id_estimate, "--", color='blue', label='ID set')
        plt.plot(xs, density_test_estimate, "--", color='orange', label='OOD set')
        # Grid points where the sign of (test - id) flips, i.e. the curves cross.
        idx = np.argwhere(np.diff(np.signbit(density_test_estimate - density_id_estimate))).flatten()
        plt.plot(xs[idx], density_id_estimate[idx], 'rx')
        for intr in idx:
            plt.annotate(f'{xs[intr]:.3f}', (xs[intr] + 2 * (xs[1] - xs[0]), density_id_estimate[intr]))
    else:  # qualitative experiments
        plt.plot(xs, density_id_estimate, "--", color='blue', label='Compounds dataset')
        plt.plot(xs, density_test_estimate, "--", color='orange', label='Test dataset')
        line_colors = ['green', 'red', 'purple']
        for i, th in enumerate(id_thresh):
            if fn_detection_rate is not None:
                label = f'{fn_detection_rate[i]} OOD training samples'
            else:
                label = f'threshold {th:.3f}'
            # Cycle colours so more than three thresholds do not raise IndexError.
            plt.axvline(th, color=line_colors[i % len(line_colors)], linestyle="-", linewidth=1, label=label)

    plt.legend(loc='upper right')
    plt.xlabel(f'{detector_name} scores')
    plt.grid()
    if use_logspace:
        plt.gca().set_yscale("log")
    if save_plot:
        out_dir = save_plot_dir if save_plot_dir is not None else "."
        plt.savefig(os.path.join(out_dir, f"distribution_plot_{detector_name}_{file_suffix}.pdf"),
                    format="pdf", bbox_inches='tight')
    plt.show()

In [None]:
save_plot_dir = '../figures/'
# MSP thresholds, each drawn as a vertical line labelled with its paired count.
msp_thresh = [0.1927411535762995, 0.09413525961366254, 0.07741693337568636]
plot_distributions(
    detector_scores,
    'msp',
    id_thresh=msp_thresh,
    fn_detection_rate=[1200, 480, 240],
    save_plot_dir=save_plot_dir,
    nbins=200,
    use_logspace=False,
    normalize_density=True,
    save_plot=True,
)


In [None]:
save_plot_dir = '../figures/'
# NOTE(review): the last two values equal the MSP thresholds — confirm.
mahalanobis = [0.18167925245956645, 0.09413525961366254, 0.07741693337568636]
# Draw the Mahalanobis thresholds defined above (msp_thresh was passed before).
plot_distributions(detector_scores, 'mahalanobis', id_thresh=mahalanobis, fn_detection_rate=[1200, 480, 240], save_plot_dir=save_plot_dir, nbins=200, use_logspace=False, normalize_density=True, save_plot=True)


In [None]:
save_plot_dir = '../figures/'
# save_plot_dir must be passed by keyword: as the third positional argument it
# binds to `id_thresh`, sending plot_distributions down the threshold-line
# branch with the path's characters as "thresholds".
for detector_name in ['msp', 'maxlogit', 'kl', 'energy']:
    plot_distributions(detector_scores, detector_name, save_plot_dir=save_plot_dir, nbins=200, use_logspace=False, normalize_density=True, save_plot=True)

In [None]:
save_plot_dir = '../figures/'
# save_plot_dir must be passed by keyword: as the third positional argument it
# binds to `id_thresh` and triggers the threshold-line branch.
for detector_name in ['mahalanobis', 'knn', 'ssd', 'nnguide']:
    plot_distributions(detector_scores, detector_name, save_plot_dir=save_plot_dir, nbins=50, use_logspace=False, normalize_density=False, save_plot=True)

### 3. Select threshold for OOD detection so that we have 95% TPR on ID dataset
* We choose the threshold $\tau$ using ID data so that a high fraction of inputs are correctly classified by the OOD detector g(x).

*  ID samples are considered positive (and have higher scores), so the detector $g$ predicts (matching `detect_ood`, which uses `>=`):

$g(\mathbf{x}, \tau, f) = 0$  if $score(\mathbf{x},f) < \tau$

$g(\mathbf{x}, \tau, f) = 1$  if $score(\mathbf{x},f) \geq \tau$

### 4. Produce final predictions by running the OOD detector with the found threshold; reject the classifier's predictions for samples marked as OOD

In [None]:
from copy import deepcopy

# Post-hoc detectors to evaluate (full set: 'msp', 'kl', 'maxlogit', 'energy').
post_hoc_methods = ['msp']
post_hoc_dr = 0.99  # target fraction of ID samples kept as ID

thresholds = {}
for ood_detector_name in post_hoc_methods:
    thresholds[ood_detector_name] = find_threshold(
        detector_scores['id'][ood_detector_name], target_detection_rate=post_hoc_dr
    )
    print(f"[INFO] {ood_detector_name} threshold: {thresholds[ood_detector_name]}")

print("===================")
# NOTE(review): post-hoc results go into the *_dist dicts (distance-based ones
# into *_post); the names look swapped but later cells read them as-is.
predictions_dist = {}
predicted_ood_ratio_dist = {}
ood_label = 70  # labels 0..69 are compounds; 70 marks OOD
for ood_detector_name in post_hoc_methods:
    is_ood = detect_ood(detector_scores['ood'][ood_detector_name], thresholds[ood_detector_name]) == 0
    relabelled = deepcopy(classification_predictions)
    relabelled[is_ood] = ood_label
    predictions_dist[ood_detector_name] = relabelled
    predicted_ood_ratio_dist[ood_detector_name] = 100 * (np.count_nonzero(is_ood) / len(is_ood))
    print(f"[INFO] {ood_detector_name} predicts {predicted_ood_ratio_dist[ood_detector_name]:.3f}% samples as OOD")

In [None]:
from copy import deepcopy

# Distance-based detectors to evaluate (full set: 'mahalanobis', 'knn', 'ssd', 'nnguide').
distance_methods = ['mahalanobis']
distance_dr = 0.98  # target fraction of ID samples kept as ID

thresholds = {}
for ood_detector_name in distance_methods:
    thresholds[ood_detector_name] = find_threshold(
        detector_scores['id'][ood_detector_name], target_detection_rate=distance_dr
    )
    print(f"[INFO] {ood_detector_name} threshold: {thresholds[ood_detector_name]}")

print("===================")
# NOTE(review): distance-based results go into the *_post dicts; the names look
# swapped with the post-hoc cell but later cells read them as-is.
predictions_post = {}
predicted_ood_ratio_post = {}
ood_label = 70  # labels 0..69 are compounds; 70 marks OOD
for ood_detector_name in distance_methods:
    is_ood = detect_ood(detector_scores['ood'][ood_detector_name], thresholds[ood_detector_name]) == 0
    relabelled = deepcopy(classification_predictions)
    relabelled[is_ood] = ood_label
    predictions_post[ood_detector_name] = relabelled
    predicted_ood_ratio_post[ood_detector_name] = 100 * (np.count_nonzero(is_ood) / len(is_ood))
    print(f"[INFO] {ood_detector_name} predicts {predicted_ood_ratio_post[ood_detector_name]:.3f}% samples as OOD")

### 5. Visualize the predictions on the GCxGC by coloring each spectrum according to its final prediction

In [None]:
d, h, w = gcxgc.shape
# Merge the outputs of both detector families into single lookups.
predictions = {**predictions_post, **predictions_dist}
predicted_ood_ratio = {**predicted_ood_ratio_post, **predicted_ood_ratio_dist}

# Scatter each detector's per-sample labels back onto the (h, w) grid; pixels
# outside mask_valid_measurements keep the value -1.
classified_gcxgc = {}
for ood_detector_name in post_hoc_methods + distance_methods:
    grid = np.full((h, w), -1.0)
    grid[mask_valid_measurements] = predictions[ood_detector_name]
    classified_gcxgc[ood_detector_name] = grid

In [None]:
# visualize_prediction needs the colormap, palette, OOD ratios, the detection
# rate (used in the file name) and an output directory; calling it with only
# two arguments raised TypeError.
for ood_detector_name in post_hoc_methods:
    visualize_prediction(classified_gcxgc, ood_detector_name, custom_cmap, colors_lst, predicted_ood_ratio, post_hoc_dr, '../figures')

In [None]:
# Pass every required argument (the two-argument call raised TypeError).
for ood_detector_name in distance_methods:
    visualize_prediction(classified_gcxgc, ood_detector_name, custom_cmap, colors_lst, predicted_ood_ratio, distance_dr, '../figures')