In [None]:
import json
import numpy as np
import matplotlib.pyplot as plt
from sklearn import metrics
from sklearn.metrics import confusion_matrix
from sklearn.metrics import auc, roc_curve
import matplotlib.pyplot as plt
import matplotlib

import h5py
import os
from tqdm import tqdm
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from keras.models import Sequential
from tensorflow.keras.layers import Reshape,Conv1D,Flatten,Dense,TimeDistributed, Lambda
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Model
from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, TerminateOnNaN
from tensorflow.keras import mixed_precision
AUTOTUNE = tf.data.AUTOTUNE
from sklearn.preprocessing import StandardScaler
import tqdm
import gc
import random
from tqdm import tqdm
# List of colour strings (hex codes and named colours), presumably a plotting palette;
# it is not referenced in the visible cells.
cols=["#DB4437", "#4285F4", "#F4B400", "#0F9D58", "purple", "goldenrod", "peru", "coral","turquoise",'gray','navy','m','darkgreen','fuchsia','steelblue'] 

## Styling the plots

In [None]:
# Matplotlib rcParams overrides, applied globally by plt.style.use(ROOT) below.
ROOT = {
    # Fonts: sans-serif family, with the math-text faces mapped to the same font.
    "font.sans-serif": ["TeX Gyre Heros", "Helvetica", "Arial"],
    "font.family": "sans-serif",
    "mathtext.fontset": "custom",
    "mathtext.rm": "TeX Gyre Heros",
    "mathtext.bf": "TeX Gyre Heros:bold",
    "mathtext.sf": "TeX Gyre Heros",
    "mathtext.it": "TeX Gyre Heros:italic",
    "mathtext.tt": "TeX Gyre Heros",
    "mathtext.cal": "TeX Gyre Heros",
    "mathtext.default": "regular",
    # Default figure size and text sizes.
    "figure.figsize": (10.0, 10.0),
    "font.size": 26,
    #"text.usetex": True,
    "axes.labelsize": "medium",
    "axes.unicode_minus": False,
    "xtick.labelsize": "small",
    "ytick.labelsize": "small",
    "legend.fontsize": "small",
    "legend.handlelength": 1.5,
    "legend.borderpad": 0.5,
    # x ticks: inward-pointing, drawn on top and bottom, minor ticks visible.
    "xtick.direction": "in",
    "xtick.major.size": 12,
    "xtick.minor.size": 6,
    "xtick.major.pad": 6,
    "xtick.top": True,
    "xtick.major.top": True,
    "xtick.major.bottom": True,
    "xtick.minor.top": True,
    "xtick.minor.bottom": True,
    "xtick.minor.visible": True,
    # y ticks: inward-pointing, drawn on left and right, minor ticks visible.
    "ytick.direction": "in",
    "ytick.major.size": 12,
    "ytick.minor.size": 6.0,
    "ytick.right": True,
    "ytick.major.left": True,
    "ytick.major.right": True,
    "ytick.minor.left": True,
    "ytick.minor.right": True,
    "ytick.minor.visible": True,
    # Grid, axes frame and save options.
    "grid.alpha": 0.8,
    "grid.linestyle": ":",
    "axes.linewidth": 2,
    "savefig.transparent": False,
}
plt.style.use(ROOT)


In [None]:
##configs
# Hyper-parameters read by the dataset, optimizer and training-loop cells.
batch_size=16384  # samples per batch of the train/val/test tf.data datasets
learning_rate=0.001  # learning rate of both Adam optimizers (classifier and critic)
epochs=10  # number of iterations of the outer training loop
critic_epocs = 1  # critic-training rounds run before each classifier step
critic_steps = 12  # NOTE(review): not referenced in the visible cells; the training loop hard-codes 10 critic batches - confirm intent
lambda_info = 1  # multiplier applied to info_loss when forming total_loss in train_step

## Pre processing the Data

In [None]:
"""
This code loads data from an HDF5 file, preprocesses the data, and creates TensorFlow datasets for training, validation, and testing.

The code performs the following steps:
1. Opens an HDF5 file and retrieves the 'jets' (high level variables) and 'images' datasets.
2. Preprocesses the data by filtering based on certain conditions.
3. Splits the data into training, validation, and testing sets.
4. Calculates weights for the training data based on a histogram of a nuisance variable.
5. Creates TensorFlow datasets from the numpy arrays.
6. Prints the element specifications of the created datasets.
7. Performs memory cleanup.

Note: The code assumes that the necessary libraries (h5py, numpy, sklearn, tensorflow) have been imported before this code block.
"""
# Read every second entry of the 'jets' and 'images' datasets into memory.
with h5py.File('/uscms/home/abhijith/nobackup/protopyte/DarkShower/himages_lowres.h5', 'r') as file:
    print(file.keys())
    jet_vars = np.array(file['jets'][::2])
    images = np.array(file['images'][::2])

print(images.shape, jet_vars.shape)

# Boolean class selections built from flag columns 53-57 of the jet variables.
one = np.logical_or(jet_vars[:, 53] > 0, jet_vars[:, 54] > 0)
three = jet_vars[:, 57] > 0
two = np.logical_or(jet_vars[:, 56] > 0, jet_vars[:, 55] > 0)

# Selection `two` -> label 1. A trailing channel axis is appended to the images.
two_vars, two_images = jet_vars[two], images[two][:, :, :, None]
two_labels = np.ones(len(two_images), dtype=np.float16)

# Selection `three` -> label 2 (used later as the out-of-distribution sample).
top_vars, top_images = jet_vars[three], images[three][:, :, :, None]
top_labels = np.full(len(top_images), 2, dtype=np.float16)

# Selection `one` -> label 0.
qcd_vars, qcd_images = jet_vars[one], images[one][:, :, :, None]
qcd_labels = np.zeros(len(qcd_images), dtype=np.float16)

print(two_images.shape, top_images.shape, qcd_images.shape)
print(two_labels.shape, top_labels.shape, qcd_labels.shape)


# Split the label-0 and label-1 samples 80/20 into (train+val)/test, then split the
# remainder 80/20 again into train/val. Column 48 of the jet variables is carried
# along as the nuisance variable `z`.
x_train, x_test, vars_train, vars_test, y_train, y_test, z_tain, z_test = train_test_split(
    np.concatenate([qcd_images, two_images]),
    np.concatenate([qcd_vars, two_vars]),
    np.concatenate([qcd_labels, two_labels]),
    np.concatenate([qcd_vars[:, 48], two_vars[:, 48]]),
    test_size=0.2, random_state=42)
x_train, x_val, vars_train, vars_val, y_train, y_val, z_train, z_val = train_test_split(
    x_train, vars_train, y_train, z_tain, test_size=0.2, random_state=42)

# Weights for NN: each training sample gets the inverse of the fraction of all
# training samples that share its class and its nuisance bin (bin width 5).
bins = np.arange(0, np.max(z_train)+5, 5)
hist, bin_edges = np.histogram(z_train[y_train == 0], bins=bins)
hist2, bin_edges = np.histogram(z_train[y_train == 1], bins=bins)
# np.digitize returns len(bins) for a value on/above the last edge and 0 for a
# value below the first edge; clip so the index always addresses a real bin
# instead of raising IndexError or silently wrapping to the last bin.
weight_index = np.clip(np.digitize(z_train, bins=bins) - 1, 0, len(hist) - 1)
total_count = np.sum(hist) + np.sum(hist2)
# Vectorised lookup of the per-sample (class, bin) count.
bin_counts = np.where(y_train == 0, hist[weight_index], hist2[weight_index])
with np.errstate(divide='ignore', invalid='ignore'):
    weights = bin_counts / total_count
    weights = 1/weights
# Samples whose bin count is zero would get an infinite/NaN weight; drop them instead.
weights[~np.isfinite(weights)] = 0

print(x_train.shape, vars_train.shape, y_train.shape, z_train.shape, weights.shape)

# Create a dataset from the numpy arrays
# Training elements are (image, label, nuisance, weight) tuples; validation and
# test elements are (image, label, nuisance) tuples without weights.

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train, z_train, weights))
# NOTE(review): the shuffle buffer (1024) is much smaller than the configured batch_size,
# so batches are only mixed locally between epochs - confirm this is intended.
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size).prefetch(buffer_size=AUTOTUNE)

val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val, z_val))
val_dataset = val_dataset.batch(batch_size).prefetch(buffer_size=AUTOTUNE)

test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test, z_test))
test_dataset = test_dataset.batch(batch_size).prefetch(buffer_size=AUTOTUNE)

print(train_dataset.element_spec)
print(val_dataset.element_spec)
print(test_dataset.element_spec)

# Drop the full input arrays; the per-class selections made earlier remain available.
del jet_vars, images
gc.collect()


## Defining the model

In [None]:
"""
This code defines a neural network model for image classification and a critic model.

The image classification model consists of several convolutional and dense layers. It takes an input image of shape (50, 50, 1) and produces two outputs: 'out' and 'activations'. The 'out' output has a shape of (None, 2) and holds the unnormalised class logits (its Dense layer has no activation; softmax is applied later where probabilities are needed). The 'activations' output has a shape of (None, 20) and is the output of the 20-unit Dense layer that feeds 'out'.

The critic model takes an input of shape (20+2,) and consists of several dense layers. It produces an output 'out_ji' with a shape of (None, 2).

Both models use the Adam optimizer with a specified learning rate.
"""
# Classifier: 50x50 single-channel image -> conv stack -> dense stack.
input_img = keras.Input(shape=(50, 50, 1), name='input')
layer=input_img
layer=layers.Conv2D(10, kernel_size=(3, 3),activation='relu',padding='same')(layer)
layer=layers.Conv2D(10, kernel_size=(3, 3),activation='relu',padding='same')(layer)
layer=layers.AveragePooling2D(pool_size=(2, 2),padding='same')(layer)
layer=layers.Conv2D(10, kernel_size=(3, 3), activation='relu',padding='same')(layer) 
layer=layers.Conv2D(5, kernel_size=(3, 3),activation='relu',padding='same')(layer)
layer=layers.Conv2D(5, kernel_size=(3, 3),activation='relu',padding='same')(layer)
layer=layers.Flatten()(layer)
layer=layers.Dense(400, activation='relu')(layer)
layer=layers.Dense(100, activation='relu')(layer)
# 20-unit linear layer, exposed as the second model output.
activations = layers.Dense(20)(layer)  # , activation='relu'
# Two raw logits (no activation); the losses in this notebook use from_logits=True.
out=layers.Dense(2)(activations)
model = Model(inputs=input_img, outputs=[out,activations])
model.summary()

# Critic: input width 20+2 because the training steps feed it the 20 activations
# concatenated with the nuisance value and the class label; output is two raw logits.
input_ji = keras.Input(shape=(20+2,), name='input_ji')
layer=layers.Dense(256, activation='relu')(input_ji)
layer=layers.Dense(128, activation='relu')(layer)
layer=layers.Dense(64, activation='relu')(layer)
out_ji=layers.Dense(2)(layer)
critc = Model(inputs=input_ji, outputs=out_ji)
critc.summary()


# Separate Adam optimizers so the classifier and the critic are updated independently.
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
critic_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)


# Code to Train the model

In [None]:
# Checkpointing for model training: both networks and both optimizers are tracked,
# and the manager keeps at most five checkpoints in checkpoint_path.
checkpoint_path = "./checkpoints_norm_latch/train"

ckpt = tf.train.Checkpoint(
    net=model,
    optimizer=optimizer,
    critc=critc,
    critic_optimizer=critic_optimizer,
)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# Resume from the most recent checkpoint when one is present.
latest_ckpt_path = ckpt_manager.latest_checkpoint
if latest_ckpt_path:
    ckpt.restore(latest_ckpt_path)
    print('Latest checkpoint restored!!')

In [None]:
@tf.function
def train_step(features, labels, nuisance, weight, apply_critic=True):
    """
    Performs a single weighted training step for the classifier `model`.

    The classification loss is the per-sample sparse categorical cross-entropy
    of the classifier logits. When `apply_critic` is true, the critic is
    evaluated (not trained) on [activations, nuisance, labels], and the
    difference of its log-softmax outputs, class 1 minus class 0, is added as
    a penalty scaled by `lambda_info`.

    Args:
        features (tf.Tensor): Input images.
        labels (tf.Tensor): Target class labels, one per sample.
        nuisance (tf.Tensor): Nuisance value, one per sample.
        weight (tf.Tensor): Per-sample weights used to average every loss term.
        apply_critic (bool, optional): Whether to add the critic penalty. Defaults to True.

    Returns:
        tuple: Weighted means (loss, info_loss, total_loss). info_loss is zero
        when `apply_critic` is false.
    """
    with tf.GradientTape() as tape:
        logits, activations = model(features, training=True)
        loss = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE)(labels, logits)
        if apply_critic:
            critc_inputs = tf.concat([activations, nuisance[:, None], labels[:, None]], axis=1)
            # The critic is only evaluated here; its weights are updated in train_critic_step.
            output = critc(critc_inputs, training=False)
            output_sm = tf.nn.log_softmax(output)
            # log p(class 1) - log p(class 0) as estimated by the critic.
            info_loss = output_sm[:, 1] - output_sm[:, 0]
            total_loss = loss + lambda_info * info_loss
        else:
            total_loss = loss
            info_loss = tf.zeros_like(loss)
        # Weighted averages over the batch.
        weight_sum = tf.reduce_sum(weight)
        loss = tf.reduce_sum(loss * weight) / weight_sum
        info_loss = tf.reduce_sum(info_loss * weight) / weight_sum
        total_loss = tf.reduce_sum(total_loss * weight) / weight_sum
    # Only the classifier's variables receive gradients in this step.
    gradients = tape.gradient(total_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss, info_loss, total_loss

@tf.function
def train_critic_step(features, labels, nuisance, critic_sample_labels, critic_weigth):
    """
    Runs one optimisation step of the critic while the classifier is only evaluated.

    Args:
        features (tf.Tensor): Input images passed through the classifier.
        labels (tf.Tensor): Class labels, appended to the critic input.
        nuisance (tf.Tensor): Nuisance values, appended to the critic input.
        critic_sample_labels (tf.Tensor): Targets the critic is trained to predict.
        critic_weigth (tf.Tensor): Per-sample weights for the critic loss.

    Returns:
        tf.Tensor: Weighted mean cross-entropy of the critic on this batch.
    """
    per_sample_ce = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
    with tf.GradientTape() as tape:
        _, activations = model(features, training=False)
        # Critic input: activations followed by one nuisance column and one label column.
        critc_inputs = tf.concat(
            [activations, tf.expand_dims(nuisance, axis=1), tf.expand_dims(labels, axis=1)],
            axis=1)
        critic_logits = critc(critc_inputs, training=True)
        sample_losses = per_sample_ce(critic_sample_labels, critic_logits)
        loss = tf.reduce_sum(sample_losses * critic_weigth) / tf.reduce_sum(critic_weigth)
    # Gradients are taken with respect to the critic's variables only.
    critic_variables = critc.trainable_variables
    critic_gradients = tape.gradient(loss, critic_variables)
    critic_optimizer.apply_gradients(zip(critic_gradients, critic_variables))
    return loss

@tf.function
def test_step(features, labels):
    """Return the mean sparse categorical cross-entropy of the classifier logits on one batch."""
    logits, _ = model(features, training=False)
    cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    return tf.reduce_mean(cross_entropy(labels, logits))

@tf.function
def test_critic_step(features, labels, nuisance, critic_sample_labels=None):
    """
    Evaluates the critic on one batch without updating any weights.

    The critic input is built exactly as in train_critic_step: the classifier
    activations concatenated with the nuisance value and the class label. The
    previous version passed the bare activations, which does not match the
    critic's declared input width, and scored the two critic logits with a
    binary cross-entropy against the class labels.

    Args:
        features (tf.Tensor): Input images.
        labels (tf.Tensor): Class labels, one per sample.
        nuisance (tf.Tensor): Nuisance value, one per sample.
        critic_sample_labels (tf.Tensor, optional): Critic targets. When omitted,
            every sample is given target 1, the value the training loop assigns
            to samples whose nuisance has not been shuffled.

    Returns:
        tf.Tensor: Mean sparse categorical cross-entropy of the critic logits.
    """
    _, activations = model(features, training=False)
    critc_inputs = tf.concat([activations, nuisance[:, None], labels[:, None]], axis=1)
    output = critc(critc_inputs, training=False)
    if critic_sample_labels is None:
        critic_sample_labels = tf.ones_like(labels, dtype=tf.float32)
    loss = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True)(critic_sample_labels, output)
    loss = tf.reduce_mean(loss)
    return loss

# Do the actual training

In [None]:
# Alternating training: before each classifier step the critic is trained on up to
# ten randomly chosen batch positions, then the classifier takes one step with the
# critic penalty switched on.
# NOTE(review): the best epoch is chosen on the *training* classification loss;
# val_dataset is not used in the visible cells - confirm this is intended.
best_loss = 1e9
for epoch in tqdm(range(epochs)):
    print("Epoch: ", epoch)
    loss_epoch = 0
    info_loss_epoch = 0
    total_loss_epoch = 0
    for step, (features, labels, nuisance, weight) in enumerate(train_dataset):
        apply_critic = True
        features = tf.cast(features, tf.float32)
        labels = tf.cast(labels, tf.float32)
        nuisance = tf.cast(nuisance, tf.float32)
        weight = tf.cast(weight, tf.float32)

        # Train the critic every k-th classifier step (k is currently 1, i.e. every step).
        if(step%1==0):
            for cepoch in range(critic_epocs):
                # Pick up to 10 batch positions other than the current one.
                # NOTE(review): train_dataset is shuffled, so a batch position in the
                # inner loop need not hold the same samples as in the outer loop.
                critc_batches = list(set(range(len(train_dataset))) - {step})
                random.shuffle(critc_batches)
                selected_batches = set(critc_batches[:10])
                last_selected = max(selected_batches, default=-1)
                critic_batch_steps = 0
                loss_critic_epoch = 0
                for batch_number,(critic_features, critic_labels, critic_nuisance, critic_weights) in enumerate(train_dataset):

                    # Nothing left to train on once the last selected position has passed.
                    if batch_number > last_selected:
                        break
                    # Training only on 10 random batches for speed up
                    if batch_number not in selected_batches:
                        continue
                    critic_features = tf.cast(critic_features, tf.float32)
                    critic_labels = tf.cast(critic_labels, tf.float32)
                    critic_nuisance = tf.cast(critic_nuisance, tf.float32)
                    critic_weights = tf.cast(critic_weights, tf.float32)

                    # Target 1: samples with their own nuisance value.
                    critic_sample_labels = tf.ones_like(critic_labels, dtype=tf.float32)

                    # Target 0: the same samples with the nuisance values shuffled across the batch.
                    critic_rfeatures = tf.identity(critic_features)
                    critic_rlabels = tf.identity(critic_labels)
                    critic_rnuisance = tf.identity(critic_nuisance)
                    critic_rnuisance = tf.random.shuffle(critic_rnuisance)
                    critic_rweights = tf.identity(critic_weights)
                    critic_rsample_labels = tf.zeros_like(critic_rlabels, dtype=tf.float32)

                    critic_features = tf.concat([critic_features, critic_rfeatures], axis=0)
                    critic_labels = tf.concat([critic_labels, critic_rlabels], axis=0)
                    critic_nuisance = tf.concat([critic_nuisance, critic_rnuisance], axis=0)
                    critic_weights = tf.concat([critic_weights, critic_rweights], axis=0)
                    critic_sample_labels = tf.concat([critic_sample_labels, critic_rsample_labels], axis=0)

                    # Shuffle the doubled batch with one shared permutation.
                    index = tf.range(critic_features.shape[0])
                    index = tf.random.shuffle(index)
                    critic_features = tf.gather(critic_features, index)
                    critic_labels = tf.gather(critic_labels, index)
                    critic_nuisance = tf.gather(critic_nuisance, index)
                    critic_weights = tf.gather(critic_weights, index)
                    critic_sample_labels = tf.gather(critic_sample_labels, index)

                    critic_loss = train_critic_step(
                        critic_features, critic_labels, critic_nuisance, critic_sample_labels, critic_weights)
                    loss_critic_epoch += critic_loss
                    critic_batch_steps += 1
                # With a single-batch dataset no critic batch is selected; skip the report
                # instead of dividing by zero.
                if critic_batch_steps:
                    print("critic_loss at step {}: ".format(step), loss_critic_epoch.numpy()/critic_batch_steps)

        loss, info_loss, total_loss = train_step(features, labels, nuisance, weight, apply_critic)
        loss_epoch += loss
        info_loss_epoch += info_loss
        total_loss_epoch += total_loss

        if(step%30==29):
            print("Batch {} in Epoch {}".format(step,epoch), loss.numpy(), info_loss.numpy(), total_loss.numpy())
    n_batches = len(train_dataset)
    print("Epoch {} Loss: ".format(epoch), loss_epoch/n_batches, info_loss_epoch/n_batches, total_loss_epoch/n_batches)
    mean_epoch_loss = loss_epoch.numpy()/n_batches
    # Save weights and a checkpoint only when the mean classification loss improves.
    if(mean_epoch_loss<best_loss):
        best_loss = mean_epoch_loss
        model.save_weights('model_norm.h5')
        print("Model saved")
        ckpt_save_path = ckpt_manager.save()
        print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
                                                            ckpt_save_path))

    print("")

### Restore the Best checkpoint

In [None]:
# The training loop saves a checkpoint only when the epoch loss improves, so the
# most recent checkpoint is the best one recorded.
best_ckpt_path = ckpt_manager.latest_checkpoint
if best_ckpt_path:
    ckpt.restore(best_ckpt_path)
    print('Latest checkpoint restored!!')

# Evaluate the Model on the Test Set

In [None]:
# Run the classifier over the test set batch by batch, collecting the raw logits,
# their softmax and the 20-unit activations, then join the batches.
logits, logits_sm, activations = [], [], []
for step, (features, labels, nuisance) in enumerate(test_dataset):
    features = tf.cast(features, tf.float32)
    labels = tf.cast(labels, tf.float32)
    nuisance = tf.cast(nuisance, tf.float32)
    logits_, activations_ = model(features, training=False)
    logits_sm_ = tf.nn.softmax(logits_)
    logits.append(logits_)
    logits_sm.append(logits_sm_)
    activations.append(activations_)
logits, logits_sm, activations = [
    tf.concat(batch_outputs, axis=0) for batch_outputs in (logits, logits_sm, activations)
]

## Evaluate on OOD dataset

In [None]:
# Same evaluation for the held-out label-2 images (never seen in training),
# processed in batches of 1024.
top_dataset = tf.data.Dataset.from_tensor_slices(top_images)
logits_ood, logits_sm_ood, activations_ood = [], [], []
for features in tqdm(top_dataset.batch(1024)):
    features = tf.cast(features, tf.float32)
    logits_, activations_ = model(features, training=False)
    logits_sm_ = tf.nn.softmax(logits_)
    logits_ood.append(logits_)
    logits_sm_ood.append(logits_sm_)
    activations_ood.append(activations_)
logits_ood, logits_sm_ood, activations_ood = [
    tf.concat(batch_outputs, axis=0)
    for batch_outputs in (logits_ood, logits_sm_ood, activations_ood)
]

## Sanity check plots

In [None]:
# Sanity check: softmax score of class 1 for test label 0, test label 1 and the OOD sample.
for sm_scores, hist_label in (
        (logits_sm[y_test==0], '0'),
        (logits_sm[y_test==1], '1'),
        (logits_sm_ood, 'ood')):
    _ = plt.hist(sm_scores[:,1], bins=100, alpha=0.5, label=hist_label)
plt.legend()
plt.show()

In [None]:
# 2D histogram of the first logit against column 48 of the test jet variables,
# for test samples with label 0.
_=plt.hist2d(logits[y_test==0][:,0].numpy(), vars_test[y_test==0][:,48], bins=100, alpha=0.5, label='0')
# Correlation-coefficient matrix of the same two quantities; displayed as the cell output.
np.corrcoef(logits[y_test==0][:,0].numpy(), vars_test[y_test==0][:,48])

## Sanity check for Logits Space

In [None]:
# Four-panel view of the two-dimensional logit space. The third panel used to be
# created and drawn twice on the same grid cell; it is now drawn once, and the
# fourth panel's title states its selection.
fig = plt.figure(figsize=(40, 10))
spec = fig.add_gridspec(ncols=4, nrows=1, )

ax1 = fig.add_subplot(spec[0, 0])
temp = logits[y_test==0]
_ = ax1.hist2d(temp[:, 0], temp[:, 1], bins=100, cmap='Blues')
ax1.set_title('QCD')

ax2 = fig.add_subplot(spec[0, 1])
temp = logits[y_test==1]
_ = ax2.hist2d(temp[:, 0], temp[:, 1], bins=100, cmap='Reds')
ax2.set_title('W/Z')

ax3 = fig.add_subplot(spec[0, 2])
temp = logits_ood
_ = ax3.hist2d(temp[:, 0], temp[:, 1], bins=100, cmap='Greys')
ax3.set_title('Top')

# OOD jets whose column-48 value (treated as the mass in later cells) lies strictly
# between 150 and 200. `mask` is reused by later cells.
ax4 = fig.add_subplot(spec[0, 3])
mask = (top_vars[:,48]>150) * (top_vars[:,48]<200)
temp = logits_ood[mask]
_ = ax4.hist2d(temp[:, 0], temp[:, 1], bins=100, cmap='Greys')
ax4.set_title('Top (150 < mass < 200)')
plt.show()


# Calculate Mahalanobis distance and intermediate variables - sanity check

In [None]:
def _mahalanobis_sq(samples, mean, cov_inv):
    """Return (x - mean) @ cov_inv @ (x - mean) for every row x of `samples`."""
    centred = np.asarray(samples) - mean
    return np.sum((centred @ cov_inv) * centred, axis=1)


# Select each class once; the previous per-sample comprehensions re-applied the
# boolean mask to the whole tensor for every element.
qcd_act = activations[y_test==0].numpy()
wz_act = activations[y_test==1].numpy()
top_act = activations_ood.numpy()

# Mean and inverse covariance of the activations for test labels 0 (QCD) and 1 (W/Z).
qcd_act_cov = np.cov(qcd_act.T)
qcd_act_cov_inv = np.linalg.inv(qcd_act_cov)
qcd_act_mean = np.mean(qcd_act, axis=0)

wz_act_cov = np.cov(wz_act.T)
wz_act_cov_inv = np.linalg.inv(wz_act_cov)
wz_act_mean = np.mean(wz_act, axis=0)

# Squared Mahalanobis distance of every sample to the QCD distribution...
top_qdistance = _mahalanobis_sq(top_act, qcd_act_mean, qcd_act_cov_inv)
qcd_qdistance = _mahalanobis_sq(qcd_act, qcd_act_mean, qcd_act_cov_inv)
wz_qdistance = _mahalanobis_sq(wz_act, qcd_act_mean, qcd_act_cov_inv)

# ...and to the W/Z distribution.
top_wzdistance = _mahalanobis_sq(top_act, wz_act_mean, wz_act_cov_inv)
qcd_wzdistance = _mahalanobis_sq(qcd_act, wz_act_mean, wz_act_cov_inv)
wz_wzdistance = _mahalanobis_sq(wz_act, wz_act_mean, wz_act_cov_inv)

_=plt.hist(top_qdistance, bins=100, alpha=0.5, label='top')
_=plt.hist(qcd_qdistance, bins=100, alpha=0.5, label='qcd')
_=plt.hist(wz_qdistance, bins=100, alpha=0.5, label='wz')
plt.legend()
plt.show()

_=plt.hist(top_wzdistance, bins=100, alpha=0.5, label='top')
_=plt.hist(qcd_wzdistance, bins=100, alpha=0.5, label='qcd')
_=plt.hist(wz_wzdistance, bins=100, alpha=0.5, label='wz')
plt.legend()
plt.show()


In [None]:
# Standardise each distance with the mean/std of the matching in-distribution class
# (QCD distance with QCD statistics, W/Z distance with W/Z statistics), then pair
# them as two columns: [:, 0] is the QCD distance, [:, 1] the W/Z distance.
qd_mean = np.mean(qcd_qdistance)
qd_std = np.std(qcd_qdistance)
wd_mean = np.mean(wz_wzdistance)
wd_std = np.std(wz_wzdistance)

top_distance = np.stack([(top_qdistance - qd_mean) / qd_std, (top_wzdistance - wd_mean) / wd_std], axis=1)
qcd_distance = np.stack([(qcd_qdistance - qd_mean) / qd_std, (qcd_wzdistance - wd_mean) / wd_std], axis=1)
wz_distance = np.stack([(wz_qdistance - qd_mean) / qd_std, (wz_wzdistance - wd_mean) / wd_std], axis=1)

In [None]:
# Larger of the two standardised distances per sample, for all OOD jets...
bins = np.linspace(-4, 10, 100)
for dist_sample, hist_label in ((top_distance, 'top'), (qcd_distance, 'qcd'), (wz_distance, 'wz')):
    _ = plt.hist(np.max(dist_sample, axis=1), bins=bins, alpha=0.5, label=hist_label,
                 density=True, histtype='step', lw=2)
plt.yscale('log')
plt.legend()
plt.show()

# ...and with the OOD jets restricted to `mask`.
bins = np.linspace(-4, 10, 100)
for dist_sample, hist_label in (
        (top_distance[mask], 'Top [OOD Sample]'), (qcd_distance, 'QCD'), (wz_distance, 'WZ')):
    _ = plt.hist(np.max(dist_sample, axis=1), bins=bins, alpha=0.5, label=hist_label,
                 density=True, histtype='step', lw=2)
plt.yscale('log')
plt.legend()
plt.show()

In [None]:
# Second column (standardised distance to the W/Z distribution) for the three samples.
bins = np.linspace(0, 2000, 100)
for dist_sample, hist_label in ((top_distance, 'top'), (qcd_distance, 'qcd'), (wz_distance, 'wz')):
    _ = plt.hist(dist_sample[:,1], bins=bins, alpha=0.5, label=hist_label,
                 density=True, histtype='step', lw=2)
plt.yscale('log')
plt.legend()
plt.show()

In [None]:
# ROC of the max standardised distance: OOD jets are the positive class, QCD the negative.
top_score = np.max(top_distance, axis=1)
qcd_score = np.max(qcd_distance, axis=1)
pred = np.concatenate([top_score, qcd_score])
y = np.concatenate([np.ones(len(top_distance)), np.zeros(len(qcd_distance))])
fpr, tpr, thresholds = roc_curve(y, pred)
auc_roc = auc(fpr, tpr)
roc_label = 'ROC curve (area = %0.2f)' % auc_roc
plt.plot(fpr, tpr, label=roc_label)
plt.legend(loc='lower right')
plt.show()

In [None]:
# Same ROC with the OOD jets restricted to `mask`.
top_masked = top_distance[mask]
top_score = np.max(top_masked, axis=1)
qcd_score = np.max(qcd_distance, axis=1)
pred = np.concatenate([top_score, qcd_score])
y = np.concatenate([np.ones(len(top_masked)), np.zeros(len(qcd_distance))])
fpr, tpr, thresholds = roc_curve(y, pred)
auc_roc = auc(fpr, tpr)
roc_label = 'ROC curve (area = %0.2f)' % auc_roc
plt.plot(fpr, tpr, label=roc_label)
plt.legend(loc='lower right')
plt.show()

In [None]:
# Column-48 distribution of the label-0 test jets before and after keeping only those
# whose max standardised distance exceeds its 70% quantile.
qcd_mass = vars_test[y_test==0][:,48]
qcd_distance_max = np.max(qcd_distance, axis=1)
th = np.quantile(qcd_distance_max, 0.7)
passes_cut = qcd_distance_max > th
step_hist_style = dict(bins=100, alpha=0.5, density=True, histtype='step', lw=2)
plt.hist(qcd_mass, label='qcd', **step_hist_style)
plt.hist(qcd_mass[passes_cut], label='qcd after cut', **step_hist_style)
plt.legend()
plt.show()

In [None]:
def _logit_mahalanobis_sq(samples, mean, cov_inv):
    """Return (x - mean) @ cov_inv @ (x - mean) for every row x of `samples`."""
    centred = np.asarray(samples) - mean
    return np.sum((centred @ cov_inv) * centred, axis=1)


# Select each class once; the previous per-sample comprehensions re-applied the
# boolean mask to the whole tensor for every element.
qcd_lg = logits[y_test==0].numpy()
wz_lg = logits[y_test==1].numpy()
top_lg = logits_ood.numpy()

# Mean and inverse covariance of the logits for test labels 0 (QCD) and 1 (W/Z).
qcd_lg_cov = np.cov(qcd_lg.T)
qcd_lg_cov_inv = np.linalg.inv(qcd_lg_cov)
qcd_lg_mean = np.mean(qcd_lg, axis=0)

wz_lg_cov = np.cov(wz_lg.T)
wz_lg_cov_inv = np.linalg.inv(wz_lg_cov)
wz_lg_mean = np.mean(wz_lg, axis=0)

# Squared Mahalanobis distance in logit space to the QCD distribution...
top_qdistance_lg = _logit_mahalanobis_sq(top_lg, qcd_lg_mean, qcd_lg_cov_inv)
qcd_qdistance_lg = _logit_mahalanobis_sq(qcd_lg, qcd_lg_mean, qcd_lg_cov_inv)
wz_qdistance_lg = _logit_mahalanobis_sq(wz_lg, qcd_lg_mean, qcd_lg_cov_inv)

# ...and to the W/Z distribution.
top_wzdistance_lg = _logit_mahalanobis_sq(top_lg, wz_lg_mean, wz_lg_cov_inv)
qcd_wzdistance_lg = _logit_mahalanobis_sq(qcd_lg, wz_lg_mean, wz_lg_cov_inv)
wz_wzdistance_lg = _logit_mahalanobis_sq(wz_lg, wz_lg_mean, wz_lg_cov_inv)

_=plt.hist(top_qdistance_lg, bins=100, alpha=0.5, label='top')
_=plt.hist(qcd_qdistance_lg, bins=100, alpha=0.5, label='qcd')
_=plt.hist(wz_qdistance_lg, bins=100, alpha=0.5, label='wz')
plt.legend()
plt.show()

_=plt.hist(top_wzdistance_lg, bins=100, alpha=0.5, label='top')
_=plt.hist(qcd_wzdistance_lg, bins=100, alpha=0.5, label='qcd')
_=plt.hist(wz_wzdistance_lg, bins=100, alpha=0.5, label='wz')
plt.legend()
plt.show()


# Explore Logit Distance

In [None]:
# Pair the logit-space distances as two columns: [:, 0] to QCD, [:, 1] to W/Z.
top_distance_lg = np.stack([top_qdistance_lg, top_wzdistance_lg], axis=1)
qcd_distance_lg = np.stack([qcd_qdistance_lg, qcd_wzdistance_lg], axis=1)
wz_distance_lg = np.stack([wz_qdistance_lg, wz_wzdistance_lg], axis=1)

# Histograms of the larger of the two distances: first all OOD jets, then only `mask`.
for top_lg_sample in (top_distance_lg, top_distance_lg[mask]):
    bins = np.linspace(0, 2000, 100)
    for lg_sample, hist_label in ((top_lg_sample, 'top'), (qcd_distance_lg, 'qcd'), (wz_distance_lg, 'wz')):
        _ = plt.hist(np.max(lg_sample, axis=1), bins=bins, alpha=0.5, label=hist_label,
                     density=True, histtype='step', lw=2)
    plt.yscale('log')
    plt.legend()
    plt.show()

# ROC curves with the OOD jets as positives and QCD as negatives, for the same two selections.
for top_lg_sample in (top_distance_lg, top_distance_lg[mask]):
    pred = np.concatenate([np.max(top_lg_sample, axis=1), np.max(qcd_distance_lg, axis=1)])
    y = np.concatenate([np.ones(len(top_lg_sample)), np.zeros(len(qcd_distance_lg))])
    fpr, tpr, thresholds = roc_curve(y, pred)
    auc_roc = auc(fpr, tpr)
    plt.plot(fpr, tpr, label='ROC curve (area = %0.2f)' % auc_roc)
    plt.legend(loc='lower right')
    plt.show()

In [None]:
# Largest logit per sample for the two test classes and the OOD sample.
qcd_max_lg = np.max(logits[y_test==0], axis=1)
wz_max_lg = np.max(logits[y_test==1], axis=1)
top_max_lg = np.max(logits_ood, axis=1)

# Histograms: first all OOD jets, then only those selected by `mask`.
for top_max_sample in (top_max_lg, top_max_lg[mask]):
    bins = np.linspace(0, 400, 100)
    for max_scores, hist_label in ((qcd_max_lg, 'qcd'), (wz_max_lg, 'wz'), (top_max_sample, 'top')):
        _ = plt.hist(max_scores, bins=bins, alpha=0.5, label=hist_label,
                     density=True, histtype='step', lw=2)
    plt.legend()
    plt.show()

# ROC curves; here QCD is the positive class (1) and the OOD jets the negative class (0).
for top_max_sample in (top_max_lg, top_max_lg[mask]):
    pred = np.concatenate([top_max_sample, qcd_max_lg])
    y = np.concatenate([np.zeros(len(top_max_sample)), np.ones(len(qcd_max_lg))])
    fpr, tpr, thresholds = roc_curve(y, pred)
    auc_roc = auc(fpr, tpr)
    plt.plot(fpr, tpr, label='ROC curve (area = %0.2f)' % auc_roc)
    plt.legend(loc='lower right')
    plt.show()

In [None]:
# Second logit (index 1) for the two test classes and the OOD sample.
qcd_wz_lg = logits[y_test==0][:,1]
wz_wz_lg = logits[y_test==1][:,1]
top_wz_lg = logits_ood[:,1]

# Histograms: first all OOD jets, then only those selected by `mask`.
for top_wz_sample in (top_wz_lg, top_wz_lg[mask]):
    bins = np.linspace(0, 400, 100)
    for wz_scores, hist_label in ((qcd_wz_lg, 'qcd'), (wz_wz_lg, 'wz'), (top_wz_sample, 'top')):
        _ = plt.hist(wz_scores, bins=bins, alpha=0.5, label=hist_label,
                     density=True, histtype='step', lw=2)
    plt.legend()
    plt.show()

# ROC curves; here QCD is the positive class (1) and the OOD jets the negative class (0).
for top_wz_sample in (top_wz_lg, top_wz_lg[mask]):
    pred = np.concatenate([top_wz_sample, qcd_wz_lg])
    y = np.concatenate([np.zeros(len(top_wz_sample)), np.ones(len(qcd_wz_lg))])
    fpr, tpr, thresholds = roc_curve(y, pred)
    auc_roc = auc(fpr, tpr)
    plt.plot(fpr, tpr, label='ROC curve (area = %0.2f)' % auc_roc)
    plt.legend(loc='lower right')
    plt.show()

# Save the results

In [None]:
import h5py as h5

# Everything written to the results file, in write order: dataset name -> array.
results_to_save = {
    'vars_test': vars_test,
    'labels': y_test,
    'activations': activations.numpy(),
    'activations_ood': activations_ood.numpy(),
    'top_ml_distance': top_distance,
    'qcd_ml_distance': qcd_distance,
    'wz_ml_distance': wz_distance,
    'logits': logits.numpy(),
    'logits_ood': logits_ood.numpy(),
    'logits_sm': logits_sm.numpy(),
    'logits_sm_ood': logits_sm_ood.numpy(),
    'images': x_test,
    'images_ood': top_images,
    'top_distance_lg': top_distance_lg,
    'qcd_distance_lg': qcd_distance_lg,
    'wz_distance_lg': wz_distance_lg,
}

# Opening with 'w' replaces any existing file; every dataset is gzip-compressed.
with h5.File('results_ji_norm_50epochs.h5', 'w') as f:
    for dataset_name, dataset_values in results_to_save.items():
        f.create_dataset(dataset_name, data=dataset_values, compression='gzip')