Imports and File Setup

In [46]:
import numpy as np
import sys
import awkward as ak
import uproot as ur
import matplotlib.pyplot as plt
from numpy import genfromtxt
import math
import matplotlib as mpl
import vector
from tqdm import tqdm

# Custom utility functions
from utils.particle_data_visualize_plot_utils import *
from utils.track_metadata import *
from utils.data_processing_util import *

def setup_environment(
    module_dirs=('~/Work/LCStudies', '~/Work/PointNet_Segmentation/'),
    delta_dir="/data/mjovanovic/cell_particle_deposit_learning/delta/delta_processed_test_files/",
):
    """Register project source directories on ``sys.path`` and return the delta data directory.

    Args:
        module_dirs: directories appended to ``sys.path``; a leading ``~`` is expanded to the
            user's home directory.
        delta_dir: base directory of the preprocessed delta files. It is later string-concatenated
            with file names, so it should end with ``/``.

    Returns:
        str: ``delta_dir`` unchanged.

    Note:
        The ``utils`` star imports at the top of the notebook execute before this function is
        called, so these ``sys.path`` entries cannot be what makes those imports resolve.
    """
    import os

    for module_dir in module_dirs:
        # Python never expands '~' in sys.path entries; left as-is they match no directory.
        resolved_dir = os.path.expanduser(module_dir)
        # Skip duplicates so re-running the cell does not keep growing sys.path.
        if resolved_dir not in sys.path:
            sys.path.append(resolved_dir)
    return delta_dir


# Usage
delta_dir = setup_environment()



In [47]:
def initialize_parameters():
    """Return a fresh configuration dict for the loading and event-selection cells.

    Keys:
        files_name, len_file, i_low, i_high: build the preprocessed file names; the file index
            runs from i_low to i_high inclusive.
        include_delta_*: which delta decay channels the event selection keeps.
        niche_case: "1_track_1_n_3_pi0" is the only special case; any other value (e.g. "None")
            disables it.
        BATCH_SIZE: batch size used to trim samples to whole batches (80 for two files,
            100 for other tests).
        LOG_ENERGY_MEAN: offset subtracted from the log energy (unrounded mean is ~ -0.93).
        LOG_MEAN_TRACK_MOMETUM: offset subtracted from the log track momentum.
    """
    return dict(
        files_name="delta_full",
        include_delta_p_pi0=True,
        include_delta_n_pi0=False,
        include_delta_p_pipm=False,
        include_delta_n_pipm=True,
        niche_case="1_track_1_n_3_pi0",
        len_file=6000,
        i_low=0,
        i_high=4,
        BATCH_SIZE=100,
        LOG_ENERGY_MEAN=-1,
        LOG_MEAN_TRACK_MOMETUM=2,
    )

# Usage
params = initialize_parameters()


In [48]:
# LOAD PREDICTIONS
# Saved model outputs for one training epoch, read from the run's "tests" folder.
epoch = 99
model_file_path = "/data/mjovanovic/cell_particle_deposit_learning/delta_train/tr_100_val_10_tst_5_delta_1_track_1_n_3_pi0_lr_1e-2_BS_100_no_tnets_add_min_dist"

labels_path = model_file_path + "/tests/labels.npy"
preds_path = f"{model_file_path}/tests/preds_{epoch}.npy"
labels_unmasked = ak.Array(np.load(labels_path))
preds_unmasked = ak.Array(np.load(preds_path))

# Keep only entries whose first label feature is not -1 (presumably padding -- TODO confirm);
# the same mask is applied to labels and predictions so they stay aligned.
is_labelled = labels_unmasked[:, :, 0] != -1
labels = labels_unmasked[is_labelled]
preds = preds_unmasked[is_labelled]

In [None]:
def load_dataset_and_geometry(delta_dir, params,
                              cell_geo_tree_file_path="/data/atlas/data/rho_delta/rho_small.root"):
    """List the preprocessed test files and load the calorimeter cell geometry.

    Args:
        delta_dir: directory prefix, string-concatenated with each file name (so it should end
            with a separator).
        params: dict providing 'files_name', 'len_file', 'i_low' and 'i_high'; file indices run
            from i_low to i_high inclusive.
        cell_geo_tree_file_path: ROOT file containing the ``CellGeo`` tree.

    Returns:
        tuple: (file_names, node_feature_names, cell_geo_data, sorter) where
            file_names -- list of ``.npy`` paths,
            node_feature_names -- branch names at positions 1..6 of the CellGeo tree,
            cell_geo_data -- dict of NumPy arrays, one per CellGeo branch,
            sorter -- ``np.argsort`` of the cell IDs, for ``np.searchsorted(..., sorter=sorter)``.
    """
    file_names = [
        delta_dir + params['files_name'] + "_len_" + str(params['len_file']) + "_i_" + str(i) + ".npy"
        for i in range(params['i_low'], params['i_high'] + 1)
    ]

    # The context manager closes the ROOT file once the geometry has been read.
    with ur.open(cell_geo_tree_file_path) as file:
        cell_geo_tree = file["CellGeo"]
        # Positional slice of branch names (e.g. 'cell_geo_sampling', 'cell_geo_eta', ...).
        # NOTE(review): relies on the tree's branch order -- confirm if the file layout changes.
        node_feature_names = cell_geo_tree.keys()[1:7]
        # library='np' materialises every branch as NumPy arrays, so they outlive the file handle.
        cell_geo_data = cell_geo_tree.arrays(library='np')

    # First tree entry holds the full cell-ID array.
    cell_geo_ID = cell_geo_data['cell_geo_ID'][0]
    sorter = np.argsort(cell_geo_ID)

    return file_names, node_feature_names, cell_geo_data, sorter

# Usage
file_names, node_feature_names, cell_geo_data, sorter = load_dataset_and_geometry(delta_dir, params)

: 

: 

In [None]:
def should_process_event(event_data, params, event_idx, dataset, decay_groups=None):
    """Decide whether event ``event_idx`` passes the track-count / decay-channel selection.

    Args:
        event_data: mapping with per-event sequences "nTrack" and "decay_group".
        params: dict with the boolean flags 'include_delta_p_pi0', 'include_delta_n_pi0',
            'include_delta_p_pipm' and 'include_delta_n_pipm'.
        event_idx: index of the event inside ``event_data``.
        dataset: "delta" or "rho"; any other value selects nothing.
        decay_groups: mapping from decay-channel name ("delta+_p", "delta+_n", "delta-",
            "delta0_n", "delta0_p", "delta++") to the id stored in event_data["decay_group"].
            Defaults to the module-level ``decay_group`` mapping.

    Returns:
        bool: True if the event should be kept.
    """
    if decay_groups is None:
        decay_groups = decay_group

    num_tracks = event_data["nTrack"][event_idx]
    # Named event_group (not decay_group) so it does not shadow the global name -> id mapping.
    event_group = event_data["decay_group"][event_idx]

    def is_group(*names):
        # True if the event's decay-group id matches any of the named channels.
        return any(event_group == decay_groups[name] for name in names)

    if dataset == "rho":
        return bool(num_tracks == 1)
    if dataset != "delta":
        return False

    if num_tracks == 1:
        return bool(
            (params['include_delta_p_pi0'] and is_group("delta+_p"))
            or (params['include_delta_n_pipm'] and is_group("delta+_n", "delta-"))
        )
    if num_tracks == 0:
        return bool(params['include_delta_n_pi0'] and is_group("delta0_n"))
    if num_tracks == 2:
        return bool(params['include_delta_p_pipm'] and is_group("delta0_p", "delta++"))
    return False


def post_process_events(processed_event_data, processed_event_track_data, params, num_events_saved):
    """Truncate every feature list to a whole number of batches.

    Keeps a multiple of ``params['BATCH_SIZE']`` samples so the set lines up with predictions
    from a model evaluated in batches. When ``num_events_saved`` is already an exact multiple,
    one full batch is still dropped (mirrors the BATCH_SIZE trimming in the loading cell).

    Args:
        processed_event_data: dict of per-sample sequences; values are replaced by their prefix.
        processed_event_track_data: same, for the track features.
        params: dict providing 'BATCH_SIZE'; a falsy value (0 / None) disables truncation.
        num_events_saved: sample count used to compute the cut.

    Returns:
        tuple: the two input dicts (modified in place).
    """
    batch_size = params['BATCH_SIZE']
    if not batch_size:
        # Nothing to align to; also avoids dividing by zero.
        return processed_event_data, processed_event_track_data

    num_kept = (num_events_saved // batch_size) * batch_size
    if num_kept == num_events_saved:
        num_kept -= batch_size
    # A negative bound would slice as "all but the last batch_size entries" instead of "none".
    num_kept = max(num_kept, 0)

    for feature_dict in (processed_event_data, processed_event_track_data):
        for key in feature_dict:
            feature_dict[key] = feature_dict[key][:num_kept]

    return processed_event_data, processed_event_track_data

In [None]:



# Build per-sample feature lists from the preprocessed event files.
#
# Every selected event contributes one sample; a 2-track event contributes two (one per track
# of interest). Per-cell features keep only cells whose summed truth energy is > 0.

# Pull configuration from `params` so the cell does not depend on bare globals that no visible
# cell defines (a fresh kernel would otherwise raise NameError).
dataset = params.get('dataset', "delta")  # load_dataset_and_geometry lists delta files
include_delta_p_pi0 = params['include_delta_p_pi0']
include_delta_n_pi0 = params['include_delta_n_pi0']
include_delta_p_pipm = params['include_delta_p_pipm']
include_delta_n_pipm = params['include_delta_n_pipm']
niche_case = params['niche_case']
BATCH_SIZE = params['BATCH_SIZE']
LOG_ENERGY_MEAN = params['LOG_ENERGY_MEAN']
LOG_MEAN_TRACK_MOMETUM = params['LOG_MEAN_TRACK_MOMETUM']

# A track and a truth particle count as matched only if measure_track_part_dists is below this.
track_part_dist_thresh = 1


def _track_part_dist(event_data, event_idx, track_i, part_i):
    """measure_track_part_dists between reco track `track_i` and truth particle `part_i`."""
    return measure_track_part_dists(
        event_data["trackPhi"][event_idx][track_i],
        event_data["trackEta"][event_idx][track_i],
        event_data["trackPt"][event_idx][track_i],
        event_data["truthPartPhi"][event_idx][part_i],
        event_data["truthPartEta"][event_idx][part_i],
        event_data["truthPartPt"][event_idx][part_i],
    )


# cluster data dict to look up data by feature name (one list entry per saved sample)
event_feature_names = [
    *node_feature_names, "cell_eta", "trackP", "trackEta", "trackPhi",
    "truthPartE", "truthPartPt", "truthPartEta", "truthPartPhi",
    'x', 'y', 'z', 'cell_hitsTruthIndex', 'cell_hitsTruthE',
    'frac_pi0_energy', 'class_frac_pi0_energy', 'cell_E', 'cell_E_weight', 'sampling_layer',
    'clus_idx', 'clus_em_prob', 'decay_group', 'cell_labels', 'cell_part_deposit_labels',
    "E_frac_focused", "cell_weights",
]
track_feature_names = [
    'x', 'y', 'z', 'P', 'min_dist', 'min_eta', 'min_phi', 'sampling_layer',
    'track_eta', 'track_phi', 'non_null_tracks', 'track_classes',
]
processed_event_data = {feature: [] for feature in event_feature_names}
processed_event_track_data = {feature: [] for feature in track_feature_names}
processed_event_track_flags = []

cell_geo_ID = cell_geo_data['cell_geo_ID'][0]

samples_count = 0
max_cells = 0
num_cells = 0

for file_i, preprocessed_file_name in enumerate(file_names):
    # NOTE(review): allow_pickle=True unpickles arbitrary objects -- only load trusted files.
    events_arr = np.load(preprocessed_file_name, allow_pickle=True).item()
    event_data = events_arr

    num_events_saved = 0
    max_cells = 0
    num_cells = 0
    num_events = len(event_data["eventNumber"])

    # The "1_track_1_n_3_pi0" niche case keeps only 1 of every 6 delta -> n + pi+/- events.
    delta_n_pipm_count = 0

    print("LOAD FILE:", file_i, "/", len(file_names))
    for event_idx in tqdm(range(num_events)):
        num_tracks = event_data["nTrack"][event_idx]

        # ---- event selection ---------------------------------------------------------------
        if dataset == "delta":
            event_decay_group = event_data["decay_group"][event_idx]
            is_delta_p_pi0 = event_decay_group == decay_group["delta+_p"]
            is_delta_n_pipm = (event_decay_group == decay_group["delta+_n"]
                               or event_decay_group == decay_group["delta-"])
            is_delta_n_pi0 = event_decay_group == decay_group["delta0_n"]
            is_delta_p_pipm = (event_decay_group == decay_group["delta0_p"]
                               or event_decay_group == decay_group["delta++"])
            selected = (
                (num_tracks == 1 and ((include_delta_p_pi0 and is_delta_p_pi0)
                                      or (include_delta_n_pipm and is_delta_n_pipm)))
                or (num_tracks == 0 and include_delta_n_pi0 and is_delta_n_pi0)
                or (num_tracks == 2 and include_delta_p_pipm and is_delta_p_pipm)
            )
        else:
            is_delta_p_pi0 = is_delta_n_pipm = is_delta_n_pi0 = is_delta_p_pipm = False
            selected = dataset == "rho" and num_tracks == 1

        # Events without cluster cells are skipped for every dataset, rho included.
        if not (len(event_data["cluster_cell_ID"][event_idx]) and selected):
            continue

        if niche_case == "1_track_1_n_3_pi0" and is_delta_n_pipm:
            if delta_n_pipm_count == 0:
                delta_n_pipm_count = 1  # keep this event
            else:
                delta_n_pipm_count = (delta_n_pipm_count + 1) % 6
                continue  # the following 5 delta -> n + pi+/- events are dropped

        # ---- track hit positions (one point per calorimeter layer) ------------------------
        if num_tracks > 0:
            x_tracks, y_tracks, z_tracks = [], [], []
            eta_tracks, phi_tracks, rPerp_track = [], [], []
            non_null_tracks = []
            for track_idx in range(num_tracks):
                track_etas, track_phis, track_rperps = [], [], []
                for layer_name in calo_layers:
                    eta = event_data['trackEta_' + layer_name][event_idx][track_idx]
                    track_etas.append(eta)
                    track_phis.append(event_data['trackPhi_' + layer_name][event_idx][track_idx])
                    if has_fixed_r[layer_name]:
                        track_rperps.append(fixed_r[layer_name])
                    else:
                        # Layer at fixed z: r_perp = z / sinh(|eta|), written with exponentials.
                        aeta = np.abs(eta)
                        track_rperps.append(fixed_z[layer_name] * 2 * np.exp(aeta) / (np.exp(2 * aeta) - 1))
                eta_tracks.append(track_etas)
                phi_tracks.append(track_phis)
                rPerp_track.append(track_rperps)

                # convert each hit to cartesian coords
                thetas = [2 * np.arctan(np.exp(-track_eta)) for track_eta in track_etas]
                track_x, track_y, track_z = spherical_to_cartesian(track_rperps, track_phis, thetas)
                x_tracks.append(track_x)
                y_tracks.append(track_y)
                z_tracks.append(track_z)

                # Hits with |eta| >= 2.5 or |phi| > pi are flagged null (presumably outside
                # acceptance or placeholder values -- TODO confirm).
                track_is_valid = np.full(NUM_TRACK_POINTS, True)
                track_is_valid[(np.abs(track_etas) >= 2.5) | (np.abs(track_phis) > np.pi)] = False
                non_null_tracks.append(track_is_valid)

            if dataset == "delta" and len(non_null_tracks) != 2:
                # Pad with an all-null track so every delta sample has two track slots.
                null_track = np.full((1, NUM_TRACK_POINTS), False)
                non_null_tracks = np.concatenate((non_null_tracks, null_track))
                x_tracks = np.concatenate((x_tracks, null_track))
                y_tracks = np.concatenate((y_tracks, null_track))
                z_tracks = np.concatenate((z_tracks, null_track))
        else:
            non_null_tracks = np.full((max_num_tracks, NUM_TRACK_POINTS), False)
            x_tracks = np.zeros((max_num_tracks, NUM_TRACK_POINTS))
            y_tracks = np.zeros((max_num_tracks, NUM_TRACK_POINTS))
            z_tracks = np.zeros((max_num_tracks, NUM_TRACK_POINTS))

        # ---- cell geometry and energies ---------------------------------------------------
        cell_IDs = event_data['cluster_cell_ID'][event_idx]
        # Row of each cluster cell in the CellGeo arrays.
        cell_ID_map = sorter[np.searchsorted(cell_geo_ID, cell_IDs, sorter=sorter)]

        cell_E = event_data["cluster_cell_E"][event_idx]
        cell_weights = cell_E / np.sum(cell_E)

        node_features = {feature: cell_geo_data[feature][0][cell_ID_map] for feature in node_feature_names}

        thetas = [2 * np.arctan(np.exp(-eta)) for eta in node_features["cell_geo_eta"]]
        x, y, z = spherical_to_cartesian(node_features["cell_geo_rPerp"], node_features["cell_geo_phi"], thetas)

        # ---- cell labels ------------------------------------------------------------------
        cell_hits_truth_E = ak.Array(event_data["cluster_cell_hitsTruthE"][event_idx])
        cell_hits_truth_index = ak.Array(event_data["cluster_cell_hitsTruthIndex"][event_idx])
        cell_truth_E = ak.sum(cell_hits_truth_E, axis=1)
        # Fraction of each cell's truth energy from truth particle index 1 (proton/neutron in delta
        # events). Computed for every dataset so E_frac_focused is always defined.
        frac_cell_energy_from_part_idx_1 = ak.sum(cell_hits_truth_E[cell_hits_truth_index == 1], axis=1) / cell_truth_E

        if dataset == "delta":
            # Each decay group has two energy-depositing particles: particle idx 1 and a pion.
            class_part_idx_1 = 0
            class_part_idx_not_1 = 0
            if is_delta_p_pi0:
                class_part_idx_1 = part_deposit_type_class["track_of_interest"]  # proton
                class_part_idx_not_1 = part_deposit_type_class["pi0"]  # pi0
            elif is_delta_n_pipm:
                class_part_idx_1 = part_deposit_type_class["pi0"]  # neutron  TODO: temporary class choice
                class_part_idx_not_1 = part_deposit_type_class["track_of_interest"]  # pi+/-
            elif is_delta_n_pi0:
                class_part_idx_1 = part_deposit_type_class["other_neutral"]  # neutron
                class_part_idx_not_1 = part_deposit_type_class["pi0"]  # pi0
            # delta++ / delta0_p (2 tracks) are relabelled per track of interest further down.
            cell_part_deposit_labels = [
                class_part_idx_not_1 if frac < 0.5 else class_part_idx_1
                for frac in frac_cell_energy_from_part_idx_1
            ]
        elif dataset == "rho":
            # Only the pi0 and pi+/- deposit energy: label 1 if the pi0 deposits the majority.
            frac_pi0_energy = ak.sum(cell_hits_truth_E[cell_hits_truth_index != 1], axis=1) / cell_truth_E
            cell_part_deposit_labels = [1 if frac > 0.5 else 0 for frac in frac_pi0_energy]

        # ---- 2-track events: match tracks to truth particles ------------------------------
        # delta++ / delta0 -> proton + pi+/-: truth idx 1 is the proton, idx 2 the charged pion.
        pairing_one = None
        if num_tracks == 2:
            part1_idx = 1  # proton
            part2_idx = 2  # charged pion

            part1_track1_dist = _track_part_dist(event_data, event_idx, 0, part1_idx)
            part1_track2_dist = _track_part_dist(event_data, event_idx, 1, part1_idx)
            part2_track1_dist = _track_part_dist(event_data, event_idx, 0, part2_idx)
            part2_track2_dist = _track_part_dist(event_data, event_idx, 1, part2_idx)

            # Pairing one: part1-track1 + part2-track2; pairing two: part1-track2 + part2-track1.
            # A pairing is valid only if BOTH of its distances are below threshold; invalid ones get
            # the sentinel sum 2 * thresh, which a valid pairing can never reach.
            invalid_sum = 2 * track_part_dist_thresh
            if part1_track1_dist < track_part_dist_thresh and part2_track2_dist < track_part_dist_thresh:
                paring_one_sum_dist = part1_track1_dist + part2_track2_dist
            else:
                paring_one_sum_dist = invalid_sum
            if part1_track2_dist < track_part_dist_thresh and part2_track1_dist < track_part_dist_thresh:
                paring_two_sum_dist = part1_track2_dist + part2_track1_dist
            else:
                paring_two_sum_dist = invalid_sum

            if min(paring_one_sum_dist, paring_two_sum_dist) >= invalid_sum:
                # Neither pairing is valid: the event is kept as a single sample with all track
                # classes 0. NOTE(review): the track hit positions are still stored.
                num_tracks = 0
            else:
                pairing_one = paring_one_sum_dist < paring_two_sum_dist

        non_null_tracks = np.array(non_null_tracks)
        x_tracks = np.array(x_tracks)
        y_tracks = np.array(y_tracks)
        z_tracks = np.array(z_tracks)
        # Null track points get zero coordinates.
        x_tracks[~non_null_tracks] = 0
        y_tracks[~non_null_tracks] = 0
        z_tracks[~non_null_tracks] = 0

        cell_has_E_deposit = cell_truth_E > 0
        num_cells = len(cell_E[cell_has_E_deposit])
        num_track_slots = 2 if dataset == "delta" else 1

        track_idx = 0
        added_one_sample = False  # every selected event adds at least one sample
        # Runs once for 0/1-track events and once per track of interest for 2-track events.
        while not added_one_sample or track_idx < num_tracks:
            track_classes = np.zeros((num_track_slots, NUM_TRACK_POINTS))
            track_Ps = np.zeros((num_track_slots, NUM_TRACK_POINTS))

            if num_tracks == 2:
                # Track classes: 1 => track of interest (track `track_idx`), 2 => other track.
                interest_idx, other_idx = (0, 1) if track_idx == 0 else (1, 0)
                track_classes[interest_idx] = np.ones(NUM_TRACK_POINTS)
                track_classes[other_idx] = np.full(NUM_TRACK_POINTS, 2)

                # Cell classes: 0 => particle on the track of interest, 1 => other tracked particle.
                if (pairing_one and track_idx == 0) or (not pairing_one and track_idx == 1):
                    class_part_idx_1, class_part_idx_not_1 = 0, 1  # particle idx 1 is on the track of interest
                else:
                    class_part_idx_1, class_part_idx_not_1 = 1, 0
                # Relabel BEFORE the appends below so each 2-track sample stores its own labels.
                cell_part_deposit_labels = [
                    class_part_idx_not_1 if frac < 0.5 else class_part_idx_1
                    for frac in frac_cell_energy_from_part_idx_1
                ]
                # Same log-momentum offset as 1-track samples (LOG_ENERGY_MEAN is the energy offset).
                track_1_P = np.log10(event_data["trackP"][event_idx][0]) - LOG_MEAN_TRACK_MOMETUM
                track_2_P = np.log10(event_data["trackP"][event_idx][1]) - LOG_MEAN_TRACK_MOMETUM
                track_Ps[0] = np.full(NUM_TRACK_POINTS, track_1_P)
                track_Ps[1] = np.full(NUM_TRACK_POINTS, track_2_P)
            elif num_tracks == 1:
                track_P = np.log10(event_data['trackP'][event_idx][0]) - LOG_MEAN_TRACK_MOMETUM
                track_Ps[0] = np.full(NUM_TRACK_POINTS, track_P)
                track_classes[0] = np.ones(NUM_TRACK_POINTS)
            # else: no tracks => classes and momenta stay 0

            track_classes[~non_null_tracks] = 0
            track_Ps[~non_null_tracks] = 0

            # per-cell features (only cells with a truth energy deposit)
            processed_event_data["cell_E"].append(cell_E[cell_has_E_deposit])
            processed_event_data["x"].append(x[cell_has_E_deposit])
            processed_event_data["y"].append(y[cell_has_E_deposit])
            processed_event_data["z"].append(z[cell_has_E_deposit])
            processed_event_data["cell_weights"].append(cell_weights[cell_has_E_deposit])

            # extra features for analysis
            processed_event_data["clus_idx"].append(ak.Array(event_data["clus_idx"][event_idx][cell_has_E_deposit]))
            processed_event_data["clus_em_prob"].append(ak.Array(event_data["clus_em_prob"][event_idx][cell_has_E_deposit]))
            processed_event_data["sampling_layer"].append(ak.Array(node_features["cell_geo_sampling"][cell_has_E_deposit]))
            processed_event_data["cell_hitsTruthIndex"].append(ak.Array(event_data["cluster_cell_hitsTruthIndex"][event_idx][cell_has_E_deposit]))
            processed_event_data["cell_hitsTruthE"].append(ak.Array(event_data["cluster_cell_hitsTruthE"][event_idx][cell_has_E_deposit]))
            # Masked like every other per-cell feature so it stays aligned with x/y/z and the labels.
            processed_event_data["cell_eta"].append(node_features["cell_geo_eta"][cell_has_E_deposit])

            processed_event_data["truthPartPt"].append(ak.Array(event_data["truthPartPt"][event_idx]))
            processed_event_data["truthPartE"].append(ak.Array(event_data["truthPartE"][event_idx]))
            processed_event_data["truthPartEta"].append(ak.Array(event_data["truthPartEta"][event_idx]))
            processed_event_data["truthPartPhi"].append(ak.Array(event_data["truthPartPhi"][event_idx]))

            processed_event_data["trackPhi"].append(ak.Array(event_data["trackPhi"][event_idx]))
            processed_event_data["trackEta"].append(ak.Array(event_data["trackEta"][event_idx]))
            processed_event_data["trackP"].append(ak.Array(event_data["trackP"][event_idx]))

            processed_event_data["decay_group"].append(event_data["decay_group"][event_idx])
            processed_event_data["E_frac_focused"].append(np.array(frac_cell_energy_from_part_idx_1)[cell_has_E_deposit])
            processed_event_data["cell_part_deposit_labels"].append(np.array(cell_part_deposit_labels)[cell_has_E_deposit])

            # track features
            processed_event_track_data["x"].append(x_tracks)
            processed_event_track_data["y"].append(y_tracks)
            processed_event_track_data["z"].append(z_tracks)
            processed_event_track_data["P"].append(track_Ps)  # log10(P) minus LOG_MEAN_TRACK_MOMETUM
            # track classes - 0 => point, 1 => track of interest, 2 => other track
            processed_event_track_data["track_classes"].append(track_classes)

            max_cells = max(max_cells, num_cells + NUM_TRACK_POINTS * max_num_tracks)

            num_events_saved += 1
            added_one_sample = True
            track_idx += 1

    if BATCH_SIZE:
        # Keep a whole number of batches from this file so it lines up with the model's
        # predictions; an exact multiple still drops one full batch.
        BS_multiple_num_samples = math.floor(num_events_saved / BATCH_SIZE) * BATCH_SIZE
        print("BS_multiple_num_samples", BS_multiple_num_samples)
        print("num_events_saved", num_events_saved)
        if BS_multiple_num_samples == num_events_saved:
            BS_multiple_num_samples -= BATCH_SIZE
        # Never negative: a file with no saved events must not delete samples of earlier files.
        BS_multiple_num_samples = max(BS_multiple_num_samples, 0)
        num_dropped = num_events_saved - BS_multiple_num_samples

        # This file's samples sit at the tail of each list; keys never filled stay empty.
        for feature_lists in (processed_event_data, processed_event_track_data):
            for values in feature_lists.values():
                del values[max(len(values) - num_dropped, 0):]

        print("num dropped events:", num_dropped)
        num_events_saved = BS_multiple_num_samples