In [1]:
import numpy as np
import glob
import os
from models.JetPointNet import PointNetSegmentation


def load_data_from_npz(npz_file):
    data = np.load(npz_file)
    feats = data['feats']  # Shape: (num_samples, 859, 6)
    frac_labels = data['frac_labels']  # Shape: (num_samples, 859)
    tot_labels = data['tot_labels']  # Shape: (num_samples, 859)
    tot_truth_e = data['tot_truth_e']  # Shape: (num_samples, 859) (This is the true total energy deposited by particles into this cell)
    return feats, frac_labels, tot_labels, tot_truth_e

# Setup
os.environ['CUDA_VISIBLE_DEVICES'] = ""  # Disable GPU
model_path = "saved_model/PointNetModel.keras"

TEST_DIR = '/data/mjovanovic/jets/processed_files/2000_events_w_fixed_hits/SavedNpz/test'

model = PointNetSegmentation(num_points=278, num_classes=1)
model.load_weights(model_path)

npz_files = glob.glob(os.path.join(TEST_DIR, '*.npz'))



2024-04-24 12:54:01.526518: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9360] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-24 12:54:01.526843: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-24 12:54:01.526967: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1537] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-04-24 12:54:01.537163: I tensorflow/core/platform/cpu_feature_guard.cc:183] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE3 SSE4.1 SSE4.2 AVX, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-04-24 12:54:04.144442: E tensor

In [20]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np

def plot_energies(predicted_energies, true_energies, coords, types, sample_index):
    fig = plt.figure(figsize=(14, 7))
    ax1 = fig.add_subplot(121, projection='3d')
    ax2 = fig.add_subplot(122, projection='3d')

    ax1.set_title('Predicted Classifications')
    ax1.set_xlabel('X')
    ax1.set_ylabel('Y')
    ax1.set_zlabel('Z')
    ax2.set_title('Actual Classifications')
    ax2.set_xlabel('X')
    ax2.set_ylabel('Y')
    ax2.set_zlabel('Z')

    vmax = 1  # Fixed as we now use binary classification

    sm = plt.cm.ScalarMappable(cmap='viridis', norm=plt.Normalize(vmin=0, vmax=vmax))
    sm.set_array([])

    for i, t in enumerate(types):
        if t == -1:
            continue

        color = 'r' if predicted_energies[i] > 0.5 else 'b'
        marker = 'x' if true_energies[i] > 0.5 else 'o'
        
        ax1.scatter(coords[i, 0], coords[i, 1], coords[i, 2], color=color, marker=marker)
        ax2.scatter(coords[i, 0], coords[i, 1], coords[i, 2], color=color, marker=marker)

    plt.suptitle(f'Sample {sample_index}')
    cbar = fig.colorbar(sm, ax=[ax1, ax2], shrink=0.5, aspect=20)
    cbar.set_label('Classification (0 or 1)')

    plt.show()
accs_list = []
# Process each file
for npz_file_idx, npz_file in enumerate(npz_files):
    #if npz_file_idx > 1:  # Only process the first two files for demonstration
    #    break

    feats, frac_labels, tot_labels, tot_truth_e = load_data_from_npz(npz_file)
    segmentation_logits, energies = model.predict(feats)
    segmentation_logits = np.squeeze(segmentation_logits, axis=-1)

    for sample_idx in range(len(feats)):
        valid_types = feats[sample_idx][:, -1]

        energies = tot_labels[sample_idx]
        

        valid_indices = valid_types != 1  # Exclude -1 types for plotting
        if sum(energies[valid_indices]) < 35:
            continue
        energy_types = feats[sample_idx][:, 6]
        coords = feats[sample_idx][valid_indices, :3]

        # Filtering only energy_type == 1
        energy_type_mask = energy_types == 1

        # Apply the same mask to all data arrays
        energies_filtered = energies[energy_type_mask]
        predicted_classes = np.where(segmentation_logits[sample_idx] > 0, 1, 0)[energy_type_mask]
        true_classes = np.where(frac_labels[sample_idx] > 0.5, 1, 0)[energy_type_mask]

        # Calculate energy-weighted accuracy
        correct_predictions = predicted_classes == true_classes
        correct_energy = np.sum(energies_filtered[correct_predictions])
        total_energy = np.sum(energies_filtered)

        energy_weighted_accuracy = correct_energy / (total_energy + 1e-5)

        # Count predictions
        predicted_count = np.sum(predicted_classes)
        true_count = np.sum(true_classes)

        print("Predicted Count: ", predicted_count)
        print("True Count: ", true_count)
        print("Energy Weighted Accuracy: ", energy_weighted_accuracy)
        if energy_weighted_accuracy != np.nan:
            accs_list.append(energy_weighted_accuracy)
        print()

print("Average energy weighted acc: ",np.mean(accs_list))
        # Uncomment below to plot adjusted classes
        # plot_energies(predicted_classes, true_classes, coords, types_for_plotting, sample_idx)


Predicted Count:  17
True Count:  37
Energy Weighted Accuracy:  0.520162050280483

Predicted Count:  105
True Count:  102
Energy Weighted Accuracy:  0.9995713512928054

Predicted Count:  0
True Count:  128
Energy Weighted Accuracy:  0.0

Predicted Count:  99
True Count:  98
Energy Weighted Accuracy:  0.9999998242437083

Predicted Count:  97
True Count:  96
Energy Weighted Accuracy:  0.9999998382278387

Predicted Count:  101
True Count:  86
Energy Weighted Accuracy:  0.9904885937387574

Predicted Count:  1
True Count:  13
Energy Weighted Accuracy:  0.5992764077552647

Predicted Count:  0
True Count:  39
Energy Weighted Accuracy:  0.901479187581768

Predicted Count:  113
True Count:  112
Energy Weighted Accuracy:  0.9999999041395218

Predicted Count:  50
True Count:  71
Energy Weighted Accuracy:  0.7303394679135145

Predicted Count:  98
True Count:  98
Energy Weighted Accuracy:  0.9999998868992003

Predicted Count:  3
True Count:  22
Energy Weighted Accuracy:  0.871671764373442

Predicte