In [None]:
import numpy as np
import glob
import os
from models.JetPointNet import PointNetSegmentation


def load_data_from_npz(npz_file):
    data = np.load(npz_file)
    feats = data['feats']  # Shape: (num_samples, 859, 6)
    frac_labels = data['frac_labels']  # Shape: (num_samples, 859)
    tot_labels = data['tot_labels']  # Shape: (num_samples, 859)
    tot_truth_e = data['tot_truth_e']  # Shape: (num_samples, 859) (This is the true total energy deposited by particles into this cell)
    return feats, frac_labels, tot_labels, tot_truth_e

# Setup
os.environ['CUDA_VISIBLE_DEVICES'] = ""  # Disable GPU
model_path = "saved_model/PointNetModel.keras"

TEST_DIR = '/data/mjovanovic/jets/processed_files/2000_events_w_fixed_hits/SavedNpz/test'

model = PointNetSegmentation(num_points=859, num_classes=1)
model.load_weights(model_path)

npz_files = glob.glob(os.path.join(TEST_DIR, '*.npz'))



In [11]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np

def plot_energies(predicted_energies, true_energies, coords, types, sample_index):
    fig = plt.figure(figsize=(14, 7))
    ax1 = fig.add_subplot(121, projection='3d')
    ax2 = fig.add_subplot(122, projection='3d')

    ax1.set_title('Predicted Energies')
    ax1.set_xlabel('X')
    ax1.set_ylabel('Y')
    ax1.set_zlabel('Z')
    ax2.set_title('Actual Energies')
    ax2.set_xlabel('X')
    ax2.set_ylabel('Y')
    ax2.set_zlabel('Z')

    vmax = np.max([predicted_energies, true_energies])

    sm = plt.cm.ScalarMappable(cmap='viridis', norm=plt.Normalize(vmin=0, vmax=vmax))
    sm.set_array([])

    for i, t in enumerate(types):
        color = 'r' if t == 2 else 'b'
        marker = 'x' if t == 2 else 'o'
        if t == -1:
            continue
        scatter1 = ax1.scatter(coords[i, 0], coords[i, 1], coords[i, 2], c=[predicted_energies[i]], cmap='viridis', vmin=0, vmax=vmax, marker=marker)
        scatter2 = ax2.scatter(coords[i, 0], coords[i, 1], coords[i, 2], c=[true_energies[i]], cmap='viridis', vmin=0, vmax=vmax, marker=marker)

    plt.suptitle(f'Sample {sample_index}')
    cbar = fig.colorbar(sm, ax=[ax1, ax2], shrink=0.5, aspect=20)
    cbar.set_label('Energy (MeV)')
    plt.show()

for npz_file_idx, npz_file in enumerate(npz_files):
    if npz_file_idx > 1:
        break
    feats, frac_labels, tot_labels, tot_truth_e = load_data_from_npz(npz_file)
    predictions = model.predict(feats)
    reconstruction_output, segmentation_output = predictions[0], predictions[1]
    
    abs_predicted = np.abs(np.squeeze(reconstruction_output, axis=-1))
    

    for sample_idx in range(min(10, len(feats))):
        types = feats[sample_idx][:, -1]
        valid_indices = types != -1

        segmentation_mask = (segmentation_output[sample_idx].squeeze() > np.median(segmentation_output[sample_idx][valid_indices])).astype(int)

        coords = feats[sample_idx][valid_indices, :3]
        types_for_plotting = types[valid_indices]

        energy_indices = types == 0
        energy_mask = segmentation_mask * energy_indices
        
        predicted_energies_eval = abs_predicted[sample_idx] * energy_mask
        true_energies_eval = tot_labels[sample_idx] * energy_indices

        tot_predicted = np.sum(predicted_energies_eval)
        tot_true = np.sum(true_energies_eval)
        
        print("Predicted Energy: ", tot_predicted)
        print("True Energy: ", tot_true)
        print("Percentage Difference: ", 100 * np.abs(tot_predicted - tot_true) / tot_true, "%")

        predicted_energies_plot = abs_predicted[sample_idx][valid_indices] * segmentation_mask[valid_indices]
        true_energies_plot = tot_labels[sample_idx][valid_indices]

        print(np.median(segmentation_output[sample_idx][valid_indices]))
        #print(segmentation_output[sample_idx][valid_indices])
        #plot_energies(predicted_energies_plot, true_energies_plot, coords, types_for_plotting, sample_idx)


Predicted Energy:  0.0
True Energy:  7583.123
Percentage Difference:  100.0 %
0.13583921
Predicted Energy:  7.528671369887888
True Energy:  6367.1475
Percentage Difference:  99.88175754659247 %
0.034045845
Predicted Energy:  65.21424984931946
True Energy:  29393.334
Percentage Difference:  99.77813251846835 %
0.4574094
Predicted Energy:  0.0
True Energy:  15073.28
Percentage Difference:  100.0 %
0.19780605
Predicted Energy:  0.0
True Energy:  65.23077
Percentage Difference:  100.0 %
0.044190902
Predicted Energy:  14.41098044347018
True Energy:  4411.5977
Percentage Difference:  99.67333874105555 %
0.33921048
Predicted Energy:  0.0
True Energy:  817.6396
Percentage Difference:  100.0 %
0.44343713
Predicted Energy:  0.0
True Energy:  507.77164
Percentage Difference:  100.0 %
0.08759536
Predicted Energy:  58.9533970952034
True Energy:  396.6472
Percentage Difference:  85.13706912982065 %
0.06875372
Predicted Energy:  6.263466137461364
True Energy:  1142.1328
Percentage Difference:  99.451

  print("Percentage Difference: ", 100 * np.abs(tot_predicted - tot_true) / tot_true, "%")
