In [17]:
import uproot
import awkward as ak
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# Define file locations
FILE_LOC = "/fast_scratch_1/jbohm/cell_particle_deposit_learning/delta/delta_root_files/delta_full.root"
CELL_GEO_FILE_LOC = "/data/atlas/data/rho_delta/rho_small.root"

# Open the ROOT files
events = uproot.open(FILE_LOC + ":EventTree")
cell_geo_tree = uproot.open(CELL_GEO_FILE_LOC + ":CellGeo")

# Load the number of truth particles per event
n_truth_part = events["nTruthPart"].array()

# Find the first event with 8 or more truth particles
num_particles = 8
first_index = ak.firsts(ak.where(n_truth_part >= num_particles))

if ak.count_nonzero(first_index) == 0:
    print(f"No event with {num_particles} or more truth particles was found.")
else:
    # Get the actual index of the first event that meets the condition
    event_index = first_index[0]
    print(f"Event {event_index} has {n_truth_part[event_index]} truth particles.")

    # Load data for the specific event
    event_data = events.arrays(entry_start=event_index, entry_stop=event_index + 1)
    cell_geo_data = cell_geo_tree.arrays()

    # Extract the geometric data for cells
    cell_ids = cell_geo_data["cell_geo_ID"]
    cell_eta = cell_geo_data["cell_geo_eta"]
    cell_phi = cell_geo_data["cell_geo_phi"]
    cell_rPerp = cell_geo_data["cell_geo_rPerp"]

    # Print the lengths and types of relevant arrays
    print('Length of cell_ids:', len(cell_ids))
    print('Type of cell_ids:', cell_ids.type)
    print('Length of cell_eta:', len(cell_eta))
    print('Type of cell_eta:', cell_eta.type)
    print('Length of cell_phi:', len(cell_phi))
    print('Type of cell_phi:', cell_phi.type)
    print('Length of cell_rPerp:', len(cell_rPerp))
    print('Type of cell_rPerp:', cell_rPerp.type)

    # Attempt to access nested data within cell_geo_data fields
    if len(cell_ids) == 1 and isinstance(cell_ids[0], ak.Array):
        print('Nested data found in cell_geo_data fields')
        cell_ids = cell_ids[0]
        cell_eta = cell_eta[0]
        cell_phi = cell_phi[0]
        cell_rPerp = cell_rPerp[0]
        print('Re-accessed cell_ids:', cell_ids)
        print('Re-accessed cell_eta:', cell_eta)
        print('Re-accessed cell_phi:', cell_phi)
        print('Re-accessed cell_rPerp:', cell_rPerp)

    # Print keys and a few entries of cell_geo_data to understand its structure
    print('Fields in cell_geo_data:', cell_geo_data.fields)
    for key in cell_geo_data.fields:
        print(f'First few entries of {key}:', cell_geo_data[key][:3])

    # Get the cell IDs for each cluster
    cluster_cell_ids = event_data["cluster_cell_ID"]
    # Print the structure of cluster_cell_ids
    print('Structure of cluster_cell_ids:', cluster_cell_ids.type)

    # Print the first few entries of cluster_cell_ids to see the nested structure
    print('First few entries of cluster_cell_ids:', cluster_cell_ids[:3])

    # Get the energy deposited in each cell for each cluster
    cluster_cell_energies = event_data["cluster_cell_E"]
    # Print the structure of cluster_cell_energies
    print('Structure of cluster_cell_energies:', cluster_cell_energies.type)

    # Print the first few entries of cluster_cell_energies to see the nested structure
    print('First few entries of cluster_cell_energies:', cluster_cell_energies[:3])

    # Calculate the average energy deposited per cell for each cluster
    cluster_energy = event_data["cluster_ENG_CALIB_TOT"]
    cluster_nCells = event_data["cluster_nCells"]

    # Compute the average energy per cell for each cluster
    average_cluster_energy = cluster_energy / cluster_nCells
    # Print the average energy for each cluster
    print('Average cluster energy:', average_cluster_energy)

    # Print a few entries of cluster_cell_ids to inspect their values
    print('Example cluster_cell_ids:', cluster_cell_ids[0][:10])
    # Extract the geometric data for cells
    cell_ids = ak.to_numpy(cell_geo_data["cell_geo_ID"]).flatten()

    # Mapping cell_geo_ID to indices
    cell_id_to_index = {id: idx for idx, id in enumerate(cell_ids)}

    # Convert cluster_cell_ids to a numpy array and flatten it
    cluster_cell_ids = ak.to_numpy(cluster_cell_ids).flatten()
    # Diagnostic prints to check shapes and sizes
    print('Shape of cell_ids:', cell_ids.shape)
    print('Shape of cluster_cell_ids:', cluster_cell_ids.shape)
    print('Length of cell_geo_data["cell_geo_eta"]:', len(cell_geo_data["cell_geo_eta"]))
    print('Maximum index in cluster_cell_ids:', np.max(cluster_cell_ids))

    # Verify that the maximum index in cluster_cell_ids does not exceed the length of cell_geo_data arrays
    if np.max(cluster_cell_ids) >= len(cell_geo_data["cell_geo_eta"]):
        print('Index out of bounds error likely to occur.')

    # Convert cell IDs in each cluster to their corresponding indices
    cluster_cell_indices = np.array([cell_id_to_index[id] for id in cluster_cell_ids])

    # Extract x, y, z coordinates for cells in each cluster
    x_cells = cell_rPerp[cluster_cell_indices] * np.cos(cell_phi[cluster_cell_indices])
    y_cells = cell_rPerp[cluster_cell_indices] * np.sin(cell_phi[cluster_cell_indices])
    z_cells = cell_rPerp[cluster_cell_indices] / np.tan(2 * np.arctan(np.exp(-cell_eta[cluster_cell_indices])))

    # Calculate the energy for each cell based on its cluster
    cell_energy = np.repeat(average_cluster_energy, cluster_nCells)

    # Calculate the logarithm of the cell energy
    min_positive_energy = np.min(cell_energy[cell_energy > 0])
    cell_energy_log = np.log(np.maximum(cell_energy, min_positive_energy))

    # Create a 3D plot
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    # Scatter plot using x, y, z coordinates and color by the log of cell energy
    scatter = ax.scatter(x_cells, y_cells, z_cells, c=cell_energy_log, cmap='viridis')

    # Add a color bar
    color_bar = plt.colorbar(scatter)
    color_bar.set_label('Log of Cell-Average Energy Deposits')

    # Set labels
    ax.set_xlabel('X Label')
    ax.set_ylabel('Y Label')
    ax.set_zlabel('Z Label')

    # Show the plot
    plt.show()


IndentationError: unexpected indent (3147777866.py, line 94)