In [None]:
import uproot
import glob
import numpy as np
import pandas as pd
import time
import os
import re

import pyvista as pv
pv.set_jupyter_backend('trame')  # or 'panel' if using panel

from scipy.constants import epsilon_0, e as q_e
from scipy.interpolate import griddata
from scipy.optimize import curve_fit
from scipy.spatial import cKDTree
from scipy.interpolate import interp1d

# Import interpolate for numerical method
from scipy.interpolate import CubicSpline
import matplotlib.gridspec as gridspec

from matplotlib.colors import LogNorm
import matplotlib.pyplot as plt
from matplotlib.colorbar import ColorbarBase
from matplotlib.colors import Normalize
import matplotlib.colors as mcolors
from matplotlib.collections import LineCollection
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import matplotlib as mpl
import matplotlib.ticker as ticker

from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed

import trimesh
import h5py
from trimesh.points import PointCloud

from common_functions import *


## Load in data to make figures

In [None]:
## READ IN GEOMETRY ##
# Surface mesh of the grains; all later clipping, interpolation and charge
# binning cells operate on this trimesh object.
GEOMETRY_PATH = 'geometry/isolated_grains_interpolated.stl'
geometry = trimesh.load_mesh(GEOMETRY_PATH)

In [None]:
# Field-map iteration to load for each run type.
iteration_SW, iteration_PE = 80, 42

configIN = "onlysolarwind"  # "both", "onlyphotoemission", "onlysolarwind"
directory = "raw-files/WangInitial8max0.7final11"

# Fail loudly on a typo: an unknown value used to match no branch, and the
# notebook only broke later with a NameError on df_SW / df_PE.
_VALID_FIELD_CONFIGS = ("both", "onlyphotoemission", "onlysolarwind")
if configIN not in _VALID_FIELD_CONFIGS:
    raise ValueError(f"configIN must be one of {_VALID_FIELD_CONFIGS}, got {configIN!r}")

# NOTE(review): "*{iteration}*" also matches any filename that merely contains
# those digits (e.g. 80 also matches 180); tighten if the naming scheme allows.
if configIN in ("both", "onlyphotoemission"):
    filenames = sorted(glob.glob(f"{directory}/*{iteration_PE}*onlyphotoemission*.txt"))
    print(filenames)
    if not filenames:
        raise FileNotFoundError(
            f"No photoemission field files for iteration {iteration_PE} in {directory}"
        )
    df_PE = read_data_format_efficient(filenames, scaling=True)

if configIN in ("both", "onlysolarwind"):
    filenames = sorted(glob.glob(f"{directory}/*{iteration_SW}*onlysolarwind*.txt"))
    print(filenames)
    if not filenames:
        raise FileNotFoundError(
            f"No solar-wind field files for iteration {iteration_SW} in {directory}"
        )
    df_SW = read_data_format_efficient(filenames, scaling=True)

In [None]:
# Load the particle-tracking ROOT output for one run and reduce it with calculate_stats.
configIN = "onlysolarwind"
directory_path = "raw-files/WangInitial8max0.7final11/"

root_matches = sorted(glob.glob(os.path.join(directory_path, f"*iteration*{configIN}*1000000.root")))
if not root_matches:
    raise FileNotFoundError(
        f"No '*iteration*{configIN}*1000000.root' file found in {directory_path}"
    )
fileIN = root_matches[0]

# os.path.basename respects the platform separator; splitting on "/" does not.
root_name = os.path.basename(fileIN)
print(root_name)
# Assumes a "<prefix>_<...iterationN...>_<config>_..." filename layout -- TODO confirm.
number_str = root_name.split("_")[1]
iterationNUM = int(''.join(filter(str.isdigit, number_str)))
configIN = root_name.split("_")[2]

# Stored under a per-iteration name (e.g. dfOutput_num0) in the notebook namespace.
globals()[f"dfOutput_num{iterationNUM}"] = calculate_stats(
    read_rootfile(root_name, directory_path=directory_path), config=configIN
)

print(78*"-")

#solar wind returns: protons_inside, electrons_inside 
#photoemission returns: gamma_initial_leading_e_creation, electrons_stopped, gamma_initial_leading_to_e_ejection
#all particles returns: gamma_initial_leading_e_creation, electrons_stopped, protons_inside, electrons_inside                                                                                                             

In [None]:
# Save the e- and proton positions from the run loaded above. For the
# solar-wind config the results are (protons_inside, electrons_inside), so
# index 0 holds protons and index 1 electrons.
# Looked up by the parsed iteration number: a hard-coded `dfOutput_num0`
# only exists when the loaded file happens to be iteration 0.
dfOutput_current = globals()[f"dfOutput_num{iterationNUM}"]
e_pos = np.array(dfOutput_current[1]["Post_Step_Position_mm"].tolist())
p_pos = np.array(dfOutput_current[0]["Post_Step_Position_mm"].tolist())

## Calculate quantities for 2D & 3D plotting

In [None]:
# ----------------------------------------------------
# Step 0: Dimension of World from Fieldmap
# ----------------------------------------------------
# Timer start; the final plotting cell reports elapsed time since this point.
start_time = time.time()

fieldIN = df_SW[iteration_SW]
center = geometry.centroid

# Keep only field samples with a non-zero magnitude.
initial_mask = fieldIN["E_mag"] > 0
points, vectors, magnitudes = (fieldIN[key][initial_mask] for key in ("pos", "E", "E_mag"))

# Two clouds over the same samples: `field_cloud` is the interpolation source
# for mesh faces, `point_cloud` is clipped to the slab for the 2D views.
field_cloud = pv.PolyData(points)
for key, col in (("E_x", 0), ("E_y", 1), ("E_z", 2)):
    field_cloud[key] = vectors[:, col]

point_cloud = pv.PolyData(points)
point_cloud["E_mag"] = magnitudes
for key, col in (("Ex_val", 0), ("Ey_val", 1), ("Ez_val", 2)):
    point_cloud[key] = vectors[:, col]

bbox_bounds = field_cloud.bounds
print(f"Starting points (filtered by mag > 0): {len(points)}")

In [None]:
# ============================================================
# Interpolate field to face centers
# ============================================================
# Surface mesh as PyVista PolyData: reuse it if it already is one, otherwise
# convert the trimesh triangles to VTK's padded layout (rows of [3, i0, i1, i2]).
if isinstance(geometry, pv.PolyData):
    pv_mesh = geometry.copy()
else:
    pv_mesh = pv.PolyData(
        geometry.vertices,
        np.column_stack([np.full(len(geometry.faces), 3), geometry.faces]),
    )

face_centers = pv_mesh.cell_centers().points

# Sample E at every face centre; with strategy='closest_point', centres that
# have no field sample within `radius` take the nearest sample's value.
face_field_cloud = pv.PolyData(face_centers)
face_field_interp = face_field_cloud.interpolate(
    field_cloud,
    strategy='closest_point',
    radius=0.002,
    sharpness=3.0,
    null_value=0.0,
)

E_x_faces, E_y_faces, E_z_faces = (face_field_interp[name] for name in ("E_x", "E_y", "E_z"))
E_vec_faces = np.column_stack((E_x_faces, E_y_faces, E_z_faces))

# Per-face normals and their Cartesian components.
face_normals = pv_mesh.cell_normals
nx, ny, nz = (face_normals[:, k] for k in range(3))

# ============================================================
# Compute Maxwell electric pressure (normal)
# ============================================================
# Normal component of the Maxwell stress traction on each face:
#     P_n = epsilon_0 * ((E . n)^2 - |E|^2 / 2)
E_dot_n = np.einsum('ij,ij->i', E_vec_faces, face_normals)
E_mag_sq = np.einsum('ij,ij->i', E_vec_faces, E_vec_faces)
P_normal_faces = epsilon_0 * (E_dot_n**2 - 0.5 * E_mag_sq)

# P_n projected onto each axis. This is only the normal part of the traction;
# the tangential part epsilon_0 * (E . n) * E_tangential is not included.
P_x_faces, P_y_faces, P_z_faces = (P_normal_faces * comp for comp in (nx, ny, nz))

# Attach per-face results to the mesh as cell data.
for name, values in (
    ("electric_pressure", P_normal_faces),
    ("E_x", E_x_faces),
    ("E_y", E_y_faces),
    ("E_z", E_z_faces),
):
    pv_mesh.cell_data[name] = values

In [None]:
# ----------------------------------------------------
# Step 1: Per-face proton "illumination" (hit counts)
# ----------------------------------------------------
# Protons are entry 0 of the run loaded above (solar wind returns
# protons_inside, electrons_inside). They get their own variable: this cell
# used to reassign `points`, clobbering the field-sample positions.
# NOTE(review): this uses Pre_Step_Position_mm, while the charge binning uses
# Post_Step_Position_mm -- confirm the difference is intended.
proton_pre_pos = np.array(
    globals()[f"dfOutput_num{iterationNUM}"][0]["Pre_Step_Position_mm"].tolist()
)
print(f"Proton positions loaded: {len(proton_pre_pos)}")

# Index of the mesh face (cell) closest to each proton.
n_faces = pv_mesh.n_cells
if len(proton_pre_pos) > 0:
    face_ids = pv_mesh.find_closest_cell(proton_pre_pos, return_closest_point=False)
else:
    face_ids = np.array([], dtype=int)
closest_points = face_ids  # former name, kept for compatibility

# Hit count per face (float, matching the previous np.zeros accumulator).
face_illum = np.bincount(np.atleast_1d(face_ids), minlength=n_faces).astype(float)

# Store the counts as cell (face) data for plotting.
pv_mesh.cell_data['proton_illumination'] = face_illum

In [None]:
# ============================================================
# Step 2: Calculate surface charge on FULL geometry
# ============================================================
# Each stopped electron/proton is deposited as a point charge on the face
# whose centre is nearest; charge / face area gives surface charge density.

face_charges_full = np.zeros(n_faces)

# Face areas in m^2. The 1e-6 factor assumes mesh coordinates in mm
# (mm^2 -> m^2) -- TODO confirm units. The 1e-20 floor avoids division by
# zero on degenerate faces.
face_areas_m2_full = pv_mesh.compute_cell_sizes()['Area'] * 1e-6 + 1e-20

# Exact elementary charge from scipy.constants (imported as q_e) instead of
# the rounded literal 1.602e-19.
q_proton = +q_e
q_electron = -q_e

# Get face centers for full geometry
face_centers_full = pv_mesh.cell_centers().points

# Build KDTree for fast nearest neighbor search
tree = cKDTree(face_centers_full)

# --- Bin electrons to closest face ---
if len(e_pos) > 0:
    _, face_id_e = tree.query(e_pos, k=1)
    unique_faces, counts = np.unique(face_id_e, return_counts=True)
    face_charges_full[unique_faces] += counts * q_electron
    print(f"Binned {len(e_pos)} electrons to {len(unique_faces)} unique faces")

# --- Bin protons to closest face ---
if len(p_pos) > 0:
    _, face_id_p = tree.query(p_pos, k=1)
    unique_faces, counts = np.unique(face_id_p, return_counts=True)
    face_charges_full[unique_faces] += counts * q_proton
    print(f"Binned {len(p_pos)} protons to {len(unique_faces)} unique faces")

# --- Calculate surface charge density on full geometry ---
sigma_per_face_SI_full = face_charges_full / face_areas_m2_full  # C/m^2
sigma_per_face_uC_full = sigma_per_face_SI_full * 1e6  # uC/m^2

# Store charge data on full geometry
pv_mesh.cell_data['charge_density_SI'] = sigma_per_face_SI_full
pv_mesh.cell_data['charge_density_uC'] = sigma_per_face_uC_full

## 3D slice of geometry

In [None]:
# --- 1. SLAB PARAMETERS ---
# The slice is a box THICKNESS thick, rotated 45 degrees about Y, whose centre
# is the mesh centroid shifted by (offsetX, 0, offsetZ).
# NOTE(review): clip_mesh_with_plane is defined in a later cell of this
# notebook. On a fresh kernel this cell only runs if that name also comes from
# `common_functions` (star import); otherwise run the definition cell first.
THICKNESS = 0.003  # slab thickness, in mesh units
offsetX = 0.025  # X shift of the slab centre from the mesh centroid
offsetZ = -0.006  # Z shift of the slab centre from the mesh centroid

# Clip the mesh to the slab; also returns the box and a plane at its mid-surface.
geo_slice, plane, clipping_box, _ = clip_mesh_with_plane(
    pv_mesh=pv_mesh,
    geometry=geometry,
    bbox_bounds=bbox_bounds,
    angle=45,
    axis='y',
    thickness=THICKNESS,
    offsetX=offsetX,
    offsetZ=offsetZ,
    plane_offset=0.0  # Adjust if plane needs to move along normal
)


# --- VISUALIZATION ---

pl = pv.Plotter()
pl.set_background('white')

# Full mesh, coloured by proton hit count per face.
pl.add_mesh(
    pv_mesh,
    scalars='proton_illumination', 
    cmap='OrRd',
    clim=[0, 15],
    show_edges=False,
    opacity=1,
    scalar_bar_args={
        'title': 'Particle Hits',
        'vertical': True,
        'position_x': 0.85,
        'position_y': 0.15}
)

# The 45-degree slab, drawn with edges so it stands out.
pl.add_mesh(
    geo_slice,
    scalars='proton_illumination',
    cmap='OrRd', 
    clim=[0, 15],
    show_edges=True,
    opacity=1.0,
    label='45 Degree Slice'
)

# Clipping box (translucent cyan).
pl.add_mesh(
    clipping_box, 
    color='cyan', 
    opacity=0.1, 
    show_edges=False,
    edge_color='black',
    line_width=0.1,
    label='Clipping Box'
)

# Mid-plane of the slab (solid black).
pl.add_mesh(
    plane, 
    color='k', 
    opacity=1, 
    show_edges=False,
    edge_color='black',
    line_width=0.1,
    label='Clipping Plane'
)

# Saved camera views (not applied; view_xz() below sets the camera).
# Location closest to G4 orientation:
# pl.camera_position = [
#     (0.019695016454329653, -0.3381202307807713, -0.006061249518755622),
#     (0.002166604469844069, 0.0019058251597530163, -0.014826609860567566),
#     (-0.046333590448935605, 0.02335526536874514, 0.9986529577263896)
# ]
# Desired new orientation to match G4 visualization:
# pl.camera_position = [(0.11707360681992973, -0.3901610806871454, 0.1307194646773579),
#  (0.009330123662948608, 0.0010192572048619564, -0.008168031981064326),
#  (-0.018642837523090962, 0.3299623014407656, 0.9438100043107198)]

pl.view_xz()

pl.show()

In [None]:
def create_clipping_plane(geometry, bbox_bounds, angle=45, axis='y', thickness=0.025, 
                         offsetX=0.025, offsetZ=-0.006, plane_offset=0.0, pivot=None):
    """
    Create a thin rotated clipping box (a slab) and a plane lying at its mid-surface.

    The box is first built axis-aligned around ``center`` (the mesh centroid
    shifted by ``offsetX``/``offsetZ``). It is thin (``thickness``) along X,
    spans the full Y extent of ``bbox_bounds`` along Y, and spans the XZ
    diagonal of ``bbox_bounds`` along Z. It is then rotated by ``-angle``
    degrees about ``axis``. The plane's normal is the box's rotated X axis, so
    the plane is parallel to the slab's two large faces.

    Args:
        geometry: Mesh geometry object with a ``centroid`` attribute.
        bbox_bounds: Bounding box bounds [xmin, xmax, ymin, ymax, zmin, zmax].
        angle: Rotation angle in degrees (default: 45); the box is rotated by -angle.
        axis: Rotation axis - 'x', 'y' or 'z', case-insensitive (default: 'y').
        thickness: Thickness of the clipping slab (default: 0.025).
        offsetX: X-offset of the slab centre from the centroid (default: 0.025).
        offsetZ: Z-offset of the slab centre from the centroid (default: -0.006).
        plane_offset: Shift of the plane along its normal (default: 0.0).
        pivot: Point the box is rotated about. ``None`` (default) keeps PyVista's
            default pivot (the origin, per the PyVista docs). That default also
            moves the box away from ``center`` unless ``center`` is the origin.
            Pass ``center`` to tilt the slab in place.

    Returns:
        tuple: (clipping_box, plane, center). ``center`` is the pre-rotation box
        centre; the plane is centred on the rotated box's centre.

    Raises:
        ValueError: If ``axis`` is not 'x', 'y' or 'z'.
    """
    axis_key = axis.lower()
    if axis_key not in ('x', 'y', 'z'):
        # Previously fell through and failed later on an unbound `normal`.
        raise ValueError(f"axis must be 'x', 'y' or 'z', got {axis!r}")

    # --- 1. CENTRE AND SPANS ---
    center = np.array([
        geometry.centroid[0] + offsetX,
        geometry.centroid[1], 
        geometry.centroid[2] + offsetZ
    ])

    x_span = bbox_bounds[1] - bbox_bounds[0]
    y_span = bbox_bounds[3] - bbox_bounds[2]
    z_span = bbox_bounds[5] - bbox_bounds[4]
    # Long edge of the slab: the XZ diagonal, so a tilted slab still crosses the domain.
    diagonal_span_xz = np.sqrt(x_span**2 + z_span**2)

    # --- 2. AXIS-ALIGNED BOX, THEN ROTATE ---
    clipping_box = pv.Box(bounds=[
        center[0] - thickness/2, center[0] + thickness/2,      # X-span (thin edge)
        center[1] - y_span/2, center[1] + y_span/2,            # Y-span (full height)
        center[2] - diagonal_span_xz/2, center[2] + diagonal_span_xz/2  # Z-span (long edge)
    ])

    # The box is already centred on `center`; this translate is effectively a no-op.
    clipping_box.translate(center - clipping_box.center, inplace=True)

    rotate = {
        'x': clipping_box.rotate_x,
        'y': clipping_box.rotate_y,
        'z': clipping_box.rotate_z,
    }[axis_key]
    if pivot is None:
        rotate(-angle, inplace=True)
    else:
        rotate(-angle, point=pivot, inplace=True)

    # --- 3. PLANE PARALLEL TO THE SLAB'S LARGE FACES ---
    # Its normal is the box's thin (X) axis after a right-handed rotation by phi.
    phi = np.radians(-angle)
    if axis_key == 'y':
        # R_y(phi) maps the X axis to (cos phi, 0, -sin phi).
        normal = np.array([np.cos(phi), 0, -np.sin(phi)])
        i_size = y_span * 1.1
        j_size = diagonal_span_xz * 1.1
    elif axis_key == 'x':
        # Rotating about X leaves the thin X direction unchanged.
        normal = np.array([1, 0, 0])
        i_size = y_span * 1.1
        j_size = diagonal_span_xz * 1.1
    else:
        # R_z(phi) maps the X axis to (cos phi, sin phi, 0). The old code used
        # -sin(phi) here, which mirrored the plane about XZ relative to the box.
        normal = np.array([np.cos(phi), np.sin(phi), 0])
        i_size = diagonal_span_xz * 1.1
        j_size = y_span * 1.1

    plane = pv.Plane(
        center=clipping_box.center,
        direction=normal,
        i_size=j_size,  # Swapped to fix orientation
        j_size=i_size,  # Swapped to fix orientation
        i_resolution=10,
        j_resolution=10
    )

    # Optional shift along the normal to re-centre the plane within the slab.
    if plane_offset != 0.0:
        plane.translate(normal * plane_offset, inplace=True)

    return clipping_box, plane, center


def clip_mesh_with_plane(pv_mesh, geometry, bbox_bounds, angle=45, axis='y', 
                        thickness=0.025, offsetX=0.025, offsetZ=-0.006, plane_offset=0.0):
    """
    Cut ``pv_mesh`` down to a rotated slab and return the pieces needed to plot it.

    The slab (box) and its mid-plane come from ``create_clipping_plane``; the
    mesh is then clipped against the box surface with ``clip_surface``.

    Args:
        pv_mesh: PyVista mesh to clip.
        geometry: Mesh geometry object with a ``centroid`` attribute.
        bbox_bounds: Bounding box bounds [xmin, xmax, ymin, ymax, zmin, zmax].
        angle: Rotation angle in degrees (default: 45).
        axis: Rotation axis - 'x', 'y' or 'z' (default: 'y').
        thickness: Thickness of the clipping slab (default: 0.025).
        offsetX: X-offset from centroid (default: 0.025).
        offsetZ: Z-offset from centroid (default: -0.006).
        plane_offset: Offset of the plane along its normal (default: 0.0).

    Returns:
        tuple: (geo_slice, plane, clipping_box, center)
    """
    box, mid_plane, box_center = create_clipping_plane(
        geometry=geometry,
        bbox_bounds=bbox_bounds,
        angle=angle,
        axis=axis,
        thickness=thickness,
        offsetX=offsetX,
        offsetZ=offsetZ,
        plane_offset=plane_offset,
    )
    return pv_mesh.clip_surface(box), mid_plane, box, box_center

In [None]:
# Inspect the centre of the current clipping box (whichever slab cell assigned `clipping_box` last).
clipping_box.center

In [None]:

# Example: clip the mesh to a rotated slab and show the clipped mesh, its
# mid-plane and the clipping box together.
# NOTE(review): this cell replaces geo_slice, plane, clipping_box and center
# with a 0.025-thick slab. The 2D electric-field cells below clip the field
# point cloud with this clipping_box and draw this geo_slice, not the THICKNESS
# slab from the 3D-slice cell. Confirm that is intended.

# Clipped mesh plus the plane and box used to make it.
geo_slice, plane, clipping_box, center = clip_mesh_with_plane(
    pv_mesh=pv_mesh,
    geometry=geometry,
    bbox_bounds=bbox_bounds,
    angle=45,
    axis='y',
    thickness=0.025,
    offsetX=0.025,
    offsetZ=-0.006,
    plane_offset=0.0  # Adjust if plane needs to move along normal
)
# create_clipping_plane(geometry, bbox_bounds, ...) alone returns
# (clipping_box, plane, center) without clipping the mesh. Its defaults equal
# the arguments above, so calling it here would rebuild the identical box and plane.

# Visualize the result
plotter = pv.Plotter()
plotter.add_mesh(geo_slice, color='lightblue', opacity=0.8, label='Clipped Mesh')
plotter.add_mesh(plane, color='red', opacity=0.3, label='Clipping Plane')
plotter.add_mesh(clipping_box, color='yellow', style='wireframe', 
                line_width=2, label='Clipping Box')
plotter.add_legend()
plotter.show_axes()
plotter.show()

# Alternative: Use with different rotation axes
# For XY diagonal (rotate around Z-axis)
geo_slice_z, plane_z, _, _ = clip_mesh_with_plane(
    pv_mesh, geometry, bbox_bounds, angle=45, axis='z'
)

# For YZ diagonal (rotate around X-axis)
geo_slice_x, plane_x, _, _ = clip_mesh_with_plane(
    pv_mesh, geometry, bbox_bounds, angle=45, axis='x'
)

## 2D Representation of the Electric Field

In [None]:
# ----------------------------------------------------
## Step 2: Return all points within the clipping_box
# ----------------------------------------------------
# Keep only the field samples that lie inside the (rotated) clipping box.
points_in_box_cloud = point_cloud.clip_box(clipping_box, invert=False)

# NumPy views of the surviving samples, used for downsampling below.
points_slice_full = points_in_box_cloud.points
vectors_slice_full = np.column_stack(
    [points_in_box_cloud[name] for name in ("Ex_val", "Ey_val", "Ez_val")]
)
magnitudes_slice_full = points_in_box_cloud["E_mag"]

print(f"Points within clipping box: {len(points_slice_full)}")

In [None]:
# # # ----------------------------------------------------
# # # Voxel Downsampling Helper Function (for XY plane)
# # # ----------------------------------------------------
# def voxel_downsample_points(points, spacing, num_components=2):
#     """
#     Downsamples points that lie on an arbitrary 2D plane embedded in 3D space 
#     by projecting them onto the plane's local coordinates (using PCA) 
#     and then applying 2D voxel binning.

#     Args:
#         points (np.ndarray): Nx3 array of points. Assumed to lie mostly on a 2D plane.
#         spacing (float): Voxel grid spacing for downsampling.
#         num_components (int): Number of principal components to use for projection (must be 2 for 2D slicing).

#     Returns:
#         np.ndarray: Indices of the unique points selected for downsampling.
#     """
#     if len(points) < 2:
#         return np.array([], dtype=int) if len(points) == 0 else np.array([0], dtype=int)

#     # 1. Center the data
#     mean_point = np.mean(points, axis=0)
#     centered_points = points - mean_point

#     # 2. Perform PCA to find the plane's local coordinate system (2 largest eigenvectors)
#     # The covariance matrix calculation finds the variance along all axes.
#     cov_matrix = np.cov(centered_points.T)
#     eigenvalues, eigenvectors = np.linalg.eigh(cov_matrix)
    
#     # Sort eigenvectors by eigenvalues in descending order to get the two principal directions
#     sorted_indices = np.argsort(eigenvalues)[::-1]
    
#     # Get the 2 eigenvectors corresponding to the largest eigenvalues (these form the plane basis)
#     plane_basis = eigenvectors[:, sorted_indices[:num_components]]
    
#     # 3. Project the 3D points onto the 2D plane (N, 2 array)
#     projected_points = centered_points @ plane_basis
    
#     # 4. Perform 2D Voxel Binning on the projected coordinates
    
#     # Find min projected coordinates (for offset)
#     min_coords = projected_points.min(axis=0)
    
#     # Calculate voxel indices (2D)
#     x_indices = np.floor((projected_points[:, 0] - min_coords[0]) / spacing).astype(int)
#     y_indices = np.floor((projected_points[:, 1] - min_coords[1]) / spacing).astype(int)
    
#     # Create a unique key for each voxel (1D index)
#     max_x_index = x_indices.max() + 1
#     voxel_keys = y_indices * max_x_index + x_indices

#     # Find the index of the first point encountered in each unique voxel
#     _, unique_indices = np.unique(voxel_keys, return_index=True)
    
#     return unique_indices

# ----------------------------------------------------
# Voxel Downsampling Helper Function (3D)
# ----------------------------------------------------
def voxel_downsample_points(points, spacing):
    """
    Downsample 3D points by keeping one representative point per cubic voxel.

    Points are binned into a regular grid with cell size ``spacing`` anchored
    at the per-axis minimum. For every occupied voxel, the index of the first
    point (in input order) that falls into it is kept.

    Args:
        points (array_like): (N, 3) point coordinates; lists are accepted.
        spacing (float): Voxel edge length, in the same units as ``points``.

    Returns:
        np.ndarray: Integer indices into ``points`` of the kept points, ordered
        by voxel (z, then y, then x). Empty for N == 0 and ``[0]`` for N == 1.

    Raises:
        ValueError: If ``spacing`` is not strictly positive.
    """
    points = np.asarray(points, dtype=float)
    if len(points) < 2:
        return np.array([], dtype=int) if len(points) == 0 else np.array([0], dtype=int)
    if not spacing > 0:  # also rejects NaN
        raise ValueError(f"spacing must be positive, got {spacing!r}")

    # Integer voxel coordinates relative to the per-axis minimum.
    voxel_idx = np.floor((points - points.min(axis=0)) / spacing).astype(int)

    # Unique (z, y, x) rows in lexicographic order. This gives the same ordering
    # as a flattened key z*(nx*ny) + y*nx + x, without its overflow risk;
    # return_index yields each voxel's first point.
    _, unique_indices = np.unique(voxel_idx[:, ::-1], axis=0, return_index=True)
    return unique_indices

ARROW_VOXEL_SPACING = 0.01 
VECTOR_SCALE_FACTOR = 1e-8

# ----------------------------------------------------
# Step 2: Vector Field Glyphs (Arrows) - full E vector
# ----------------------------------------------------
start_time_vectors = time.time()

# Thin the slab samples to about one arrow per voxel.
unique_indices = voxel_downsample_points(points_slice_full, ARROW_VOXEL_SPACING)
points_slice, vectors_slice, magnitudes_slice = (
    arr[unique_indices]
    for arr in (points_slice_full, vectors_slice_full, magnitudes_slice_full)
)

# Arrow length is magnitude * VECTOR_SCALE_FACTOR; capping the magnitude here
# limits arrows to half a voxel.
MAGNITUDE_MAX_CLAMP = ARROW_VOXEL_SPACING / VECTOR_SCALE_FACTOR / 2 
magnitudes_slice_clamped = np.minimum(magnitudes_slice, MAGNITUDE_MAX_CLAMP)

slice_mesh_vectors = pv.PolyData(points_slice)
slice_mesh_vectors['vectors'] = vectors_slice
slice_mesh_vectors['magnitude'] = magnitudes_slice_clamped

# Arrows oriented by the full vector, sized by the clamped magnitude.
arrow = pv.Arrow(tip_length=0.3, tip_radius=0.2, shaft_radius=0.04)
glyphs = slice_mesh_vectors.glyph(
    geom=arrow,
    orient='vectors',
    scale='magnitude',
    factor=VECTOR_SCALE_FACTOR,
)

# Resample the slab's field samples onto the mid-plane mesh.
field_slice_interpolated = plane.interpolate(
    points_in_box_cloud,
    strategy='closest_point',
    radius=0.0001,
    sharpness=1.0,
    null_value=0,
)

In [None]:
# Removed the leftover `plane.interpolate??` source lookup (interactive help only).
# See the PyVista `DataSetFilters.interpolate` documentation for its arguments.

In [None]:
# Colour limits for the slab's electric_pressure (epsilon_0 * E^2 units; Pa if E is in V/m).
vmin, vmax = (-2,2)

# ----------------------------------------------------
# Step 3: Visualization
# ----------------------------------------------------
pl = pv.Plotter()
pl.set_background('white')

# pl.add_mesh(
#     pv_mesh,
#     cmap="seismic",
#     opacity=0.6,
#     show_edges=False,
#     clim=[vmin, vmax],
#     interpolate_before_map=False,
#     preference="cell"
# )

# # Add the original mesh
# pl.add_mesh(
#     pv_mesh,
#     scalars='charge_density_uC', 
#     cmap='YlGnBu', #'seismic',
#     clim=[vmin, vmax],
#     show_edges=False,
#     opacity=1,  # More transparent to see the slice better
#     # scalar_bar_args={
#     #     'title': 'Particle Hits',
#     #     'vertical': True,
#     #     'position_x': 0.85,
#     #     'position_y': 0.15}
# )


# pl.add_mesh(
#     field_slice_interpolated,
#     scalars="Ez_val",  # Changed to Z component
#     cmap="YlGnBu",
#     opacity=1,
#     show_edges=False,
#     clim=[-2e5, 2e5],
#     scalar_bar_args={
#         'title': None, #'Ez (V/m)',
#         'vertical': False,
#         'position_x': 0.20,
#         'position_y': 0.12,
#         'width': 0.6,
#         'height': 0.05,
#     }
# )


# Slab coloured by the normal Maxwell pressure on each face.
pl.add_mesh(
    geo_slice,
    scalars='electric_pressure', 
    cmap='seismic',
    clim=[vmin, vmax],
    show_edges=False,
    opacity=1,
    # scalar_bar_args={
    #     'title': 'Particle Hits',
    #     'vertical': True,
    #     'position_x': 0.85,
    #     'position_y': 0.15}
)

#pl.add_mesh(geo_slice, color="black", line_width=3, opacity=1)
# Downsampled E-field arrows.
pl.add_mesh(glyphs, color='black', show_scalar_bar=False, line_width=5, opacity=1)

#pl.add_mesh(clipping_box, color='green', opacity=0.1, show_edges=True)

# Camera saved from an interactive session. It used to be assigned and then
# immediately replaced by view_xz(), so it never took effect. Set
# USE_SAVED_CAMERA = True to actually use it.
USE_SAVED_CAMERA = False
SAVED_CAMERA_POSITION = [(0.2405486054033282, -0.014129531514704384, 0.33634817371357223),
 (0.002166604469844069, 0.0009990661314442928, -1.9069368376797846e-06),
 (-0.8157505803865249, 0.005166285560794769, 0.5783807570213173)]
# Earlier alternative (desired new orientation to match G4 visualization):
# [(0.023504925144732475, 0.004486350775623978, 0.41194153528845395),
#  (0.002166604469844069, 0.0010035066774800594, -6.701385451800268e-06),
#  (-0.9984056206000426, 0.023060878700134766, 0.05152099210827938)]

if USE_SAVED_CAMERA:
    pl.camera_position = SAVED_CAMERA_POSITION
    pl.enable_parallel_projection()
else:
    pl.enable_parallel_projection()
    #pl.enable_2d_style()
    pl.view_xz()

# pl.screenshot(f'figures/wang_{configIN}#{iteration_SW}.jpeg', scale=4)

# start_time is set in the field-map cell, so this includes time spent between cells.
print(f"Total execution time: {time.time() - start_time:.2f}s")
pl.show()

In [None]:
# Inspect the PyVista surface mesh, including the cell data attached in the cells above.
pv_mesh

In [None]:
# Inspect the camera of the most recent plotter (e.g. to copy it into a saved camera position).
pl.camera_position

## Planar slice of geometry

In [None]:
# --- 1. SETUP THICKNESS AND BOUNDS ---
# A thin 45-degree slab built inline (mirrors create_clipping_plane with axis='y').
THICKNESS = 0.001  # slab thickness, in mesh units
offsetX = 0.025  # X shift of the slab centre from the mesh centroid
offsetZ = -0.006  # Z shift of the slab centre from the mesh centroid
# add the offset directly to the center
center = np.array([geometry.centroid[0]+offsetX,geometry.centroid[1], geometry.centroid[2]+offsetZ])

# Extents of the field-map domain. The slab's long edge spans the XZ diagonal
# so the tilted slab still crosses the whole domain.
x_span = bbox_bounds[1] - bbox_bounds[0]
y_span = bbox_bounds[3] - bbox_bounds[2]
z_span = bbox_bounds[5] - bbox_bounds[4]
diagonal_span_xz = np.sqrt(x_span**2 + z_span**2)

# --- 2. CREATE AND ROTATE THE CLIPPING BOX (XZ Diagonal) ---
# Axis-aligned box: thin along X (THICKNESS), full Y span, XZ diagonal along Z.
clipping_box = pv.Box(bounds=[
    center[0] - THICKNESS/2, center[0] + THICKNESS/2, # X-span (thin edge)
    center[1] - y_span/2, center[1] + y_span/2, # Y-span (full height)
    center[2] - diagonal_span_xz/2, center[2] + diagonal_span_xz/2 # Z-span (long edge)
])

# The box is already centred on `center`; this translate is effectively a no-op.
clipping_box.translate(center - clipping_box.center, inplace=True)

# Tilt the thin X-dimension by 45 degrees within the XZ plane.
# NOTE(review): rotate_y pivots about PyVista's default point (the origin),
# not `center`, so the slab is also displaced. Confirm it lands where intended.
clipping_box.rotate_y(-45, inplace=True)

# Clip the mesh with the rotated box
geo_slice = pv_mesh.clip_surface(clipping_box)

# --- VISUALIZATION ---

pl = pv.Plotter()
pl.set_background('white')

# Full mesh, coloured by proton hit count per face.
pl.add_mesh(
    pv_mesh,
    scalars='proton_illumination', 
    cmap='OrRd',
    clim=[0, 15],
    show_edges=False,
    opacity=1,
    scalar_bar_args={
        'title': 'Particle Hits',
        'vertical': True,
        'position_x': 0.85,
        'position_y': 0.15}
)

# The 45-degree slab, drawn with edges so it stands out.
pl.add_mesh(
    geo_slice,
    scalars='proton_illumination',
    cmap='OrRd', 
    clim=[0, 15],
    show_edges=True,
    opacity=1.0,
    label='45 Degree Slice'
)

# Clipping box (translucent cyan).
pl.add_mesh(
    clipping_box, 
    color='cyan', 
    opacity=0.1, 
    show_edges=False,
    edge_color='black',
    line_width=0.1,
    label='Clipping Box'
)

# Optional: intersection of the slab with the field-map bounding box. VTK
# boolean operations are slow and the result was never drawn, so it is only
# computed on request.
SHOW_INTERSECTION = False
intersection = None
if SHOW_INTERSECTION:
    # Boolean operations need triangulated inputs.
    bounding_world = pv.Box(bounds=bbox_bounds)
    bounding_world_tri = bounding_world.triangulate()
    clipping_box_tri = clipping_box.triangulate()
    intersection = bounding_world_tri.boolean_intersection(clipping_box_tri).clean(tolerance=1e-6)
    pl.add_mesh(
        intersection, 
        color='green', 
        opacity=0.5, 
        show_edges=False,
        edge_color='black',
        label='Intersection Volume (Clipping Box ^ Bounding Box)'
    )

# Saved camera views (not applied; view_xz() below sets the camera).
# Location closest to G4 orientation:
# pl.camera_position = [
#     (0.019695016454329653, -0.3381202307807713, -0.006061249518755622),
#     (0.002166604469844069, 0.0019058251597530163, -0.014826609860567566),
#     (-0.046333590448935605, 0.02335526536874514, 0.9986529577263896)
# ]
# Desired new orientation to match G4 visualization:
# pl.camera_position = [(0.11707360681992973, -0.3901610806871454, 0.1307194646773579),
#  (0.009330123662948608, 0.0010192572048619564, -0.008168031981064326),
#  (-0.018642837523090962, 0.3299623014407656, 0.9438100043107198)]

pl.view_xz()

pl.show()

## 2D Representation of the Electric Field

In [None]:
# ----------------------------------------------------
## Step 2: Return all points within the clipping_box
# ----------------------------------------------------
# Field samples inside the thin slab defined above.
points_in_box_cloud = point_cloud.clip_box(clipping_box, invert=False)

points_slice_full = points_in_box_cloud.points
magnitudes_slice_full = points_in_box_cloud["E_mag"]
# (N, 3) field vectors assembled from the per-component arrays.
vectors_slice_full = np.stack(
    (
        points_in_box_cloud["Ex_val"],
        points_in_box_cloud["Ey_val"],
        points_in_box_cloud["Ez_val"],
    ),
    axis=1,
)

print(f"Points within clipping box: {len(points_slice_full)}")

In [None]:
# # # ----------------------------------------------------
# # # Voxel Downsampling Helper Function (for XY plane)
# # # ----------------------------------------------------
# def voxel_downsample_points(points, spacing, num_components=2):
#     """
#     Downsamples points that lie on an arbitrary 2D plane embedded in 3D space 
#     by projecting them onto the plane's local coordinates (using PCA) 
#     and then applying 2D voxel binning.

#     Args:
#         points (np.ndarray): Nx3 array of points. Assumed to lie mostly on a 2D plane.
#         spacing (float): Voxel grid spacing for downsampling.
#         num_components (int): Number of principal components to use for projection (must be 2 for 2D slicing).

#     Returns:
#         np.ndarray: Indices of the unique points selected for downsampling.
#     """
#     if len(points) < 2:
#         return np.array([], dtype=int) if len(points) == 0 else np.array([0], dtype=int)

#     # 1. Center the data
#     mean_point = np.mean(points, axis=0)
#     centered_points = points - mean_point

#     # 2. Perform PCA to find the plane's local coordinate system (2 largest eigenvectors)
#     # The covariance matrix calculation finds the variance along all axes.
#     cov_matrix = np.cov(centered_points.T)
#     eigenvalues, eigenvectors = np.linalg.eigh(cov_matrix)
    
#     # Sort eigenvectors by eigenvalues in descending order to get the two principal directions
#     sorted_indices = np.argsort(eigenvalues)[::-1]
    
#     # Get the 2 eigenvectors corresponding to the largest eigenvalues (these form the plane basis)
#     plane_basis = eigenvectors[:, sorted_indices[:num_components]]
    
#     # 3. Project the 3D points onto the 2D plane (N, 2 array)
#     projected_points = centered_points @ plane_basis
    
#     # 4. Perform 2D Voxel Binning on the projected coordinates
    
#     # Find min projected coordinates (for offset)
#     min_coords = projected_points.min(axis=0)
    
#     # Calculate voxel indices (2D)
#     x_indices = np.floor((projected_points[:, 0] - min_coords[0]) / spacing).astype(int)
#     y_indices = np.floor((projected_points[:, 1] - min_coords[1]) / spacing).astype(int)
    
#     # Create a unique key for each voxel (1D index)
#     max_x_index = x_indices.max() + 1
#     voxel_keys = y_indices * max_x_index + x_indices

#     # Find the index of the first point encountered in each unique voxel
#     _, unique_indices = np.unique(voxel_keys, return_index=True)
    
#     return unique_indices

# ----------------------------------------------------
# Voxel Downsampling Helper Function (3D)
# ----------------------------------------------------
def voxel_downsample_points(points, spacing):
    """
    Pick one representative point per cubic voxel of edge length ``spacing``.

    The voxel grid is anchored at the per-axis minimum of ``points``; inside
    each occupied voxel the earliest point (lowest row index) is kept.

    Args:
        points (np.ndarray): Nx3 array of point coordinates.
        spacing (float): Voxel edge length (same units as ``points``).

    Returns:
        np.ndarray: Integer row indices into ``points``, ordered by voxel
        (z slowest, then y, then x fastest).
    """
    n_points = len(points)
    if n_points == 0:
        return np.array([], dtype=int)
    if n_points == 1:
        return np.array([0], dtype=int)

    # Integer voxel coordinates relative to the lower corner of the cloud.
    origin = points.min(axis=0)
    ix = np.floor((points[:, 0] - origin[0]) / spacing).astype(int)
    iy = np.floor((points[:, 1] - origin[1]) / spacing).astype(int)
    iz = np.floor((points[:, 2] - origin[2]) / spacing).astype(int)

    # Row-major flattening: key = iz*(nx*ny) + iy*nx + ix, unique per voxel
    # because 0 <= ix < nx and 0 <= iy < ny.
    nx = ix.max() + 1
    ny = iy.max() + 1
    flat_keys = (iz * ny + iy) * nx + ix

    # return_index yields the first occurrence of every key, sorted by key.
    first_rows = np.unique(flat_keys, return_index=True)[1]
    return first_rows

# Voxel edge length used to thin the arrow glyphs (same units as the point
# coordinates).
ARROW_VOXEL_SPACING = 0.01 
# Multiplier applied to the (clamped) field magnitude to size each arrow.
VECTOR_SCALE_FACTOR = 1e-8

# ----------------------------------------------------
# Step 2: Vector Field Glyphs (Arrows) - Z component
# ----------------------------------------------------
# NOTE(review): points_slice_full / vectors_slice_full / magnitudes_slice_full
# come from an earlier cell not shown here. A later cell redefines
# voxel_downsample_points as an XY-only (2D) variant under the same name; if
# that cell has run, re-running this one silently uses the 2D version.
start_time_vectors = time.time()

# Keep one representative point per voxel so arrows do not overlap.
unique_indices = voxel_downsample_points(points_slice_full, ARROW_VOXEL_SPACING)

points_slice = points_slice_full[unique_indices]
vectors_slice = vectors_slice_full[unique_indices]
magnitudes_slice = magnitudes_slice_full[unique_indices]

# Cap magnitudes so that magnitude * VECTOR_SCALE_FACTOR never exceeds half a
# voxel (ARROW_VOXEL_SPACING / 2).
MAGNITUDE_MAX_CLAMP = ARROW_VOXEL_SPACING / VECTOR_SCALE_FACTOR / 2 
magnitudes_slice_clamped = np.clip(magnitudes_slice, a_min=None, a_max=MAGNITUDE_MAX_CLAMP)

slice_mesh_vectors = pv.PolyData(points_slice)
# Arrows are oriented by the full vectors but sized by the clamped magnitude.
slice_mesh_vectors['vectors'] = vectors_slice
slice_mesh_vectors['magnitude'] = magnitudes_slice_clamped

arrow = pv.Arrow(tip_length=0.3, tip_radius=0.2, shaft_radius=0.04)
glyphs = slice_mesh_vectors.glyph(
    orient='vectors',
    scale='magnitude',
    factor=VECTOR_SCALE_FACTOR,
    geom=arrow
)

In [None]:
# # Create a slice through your original geometry using the box orientation
# bounds = clipping_box.bounds
# plane_center = [
#     (bounds[0] + bounds[1]) / 2,
#     (bounds[2] + bounds[3]) / 2,
#     (bounds[4] + bounds[5]) / 2
# ]

# normal = [0, 0, 1]

# field_slice_mesh = pv.Plane(
#     center=plane_center,  # <-- FIXED: Use explicit center at Y_SLICE
#     direction=normal,
#     j_size=bbox_bounds[1] - bbox_bounds[0],
#     i_size=bbox_bounds[5] - bbox_bounds[4],
#     i_resolution=250, 
#     j_resolution=250
# )

# print(f"2D slice points: {field_slice_mesh.n_points}")
# print(f"2D slice cells: {field_slice_mesh.n_cells}")

# # Interpolate your point data onto this 2D slice
# field_slice_interpolated = field_slice_mesh.interpolate(
#     point_cloud,
#     sharpness=3.0,
#     radius=0.001,
#     null_value=0,
#     strategy='closest_point'
# )

In [None]:
vmin, vmax = (-2,2)  # colour limits for the 'electric_pressure' scalars

# ----------------------------------------------------
# Step 3: Visualization
# ----------------------------------------------------
# NOTE(review): geo_slice and start_time come from cells not shown here, and
# glyphs from the arrow-glyph cell above. start_time is presumably set at
# "Step 0" of this pipeline, so the printed "Total execution time" spans
# several cells (including idle time between them).
pl = pv.Plotter()
pl.set_background('white')

# pl.add_mesh(
#     pv_mesh,
#     cmap="seismic",
#     opacity=0.6,
#     show_edges=False,
#     clim=[vmin, vmax],
#     interpolate_before_map=False,
#     preference="cell"
# )

# # Add the original mesh
# pl.add_mesh(
#     pv_mesh,
#     scalars='charge_density_uC', 
#     cmap='YlGnBu', #'seismic',
#     clim=[vmin, vmax],
#     show_edges=False,
#     opacity=1,  # More transparent to see the slice better
#     # scalar_bar_args={
#     #     'title': 'Particle Hits',
#     #     'vertical': True,
#     #     'position_x': 0.85,
#     #     'position_y': 0.15}
# )


# pl.add_mesh(
#     field_slice_interpolated,
#     scalars="Ex_val",  # Changed to Z component
#     cmap="YlGnBu",
#     opacity=1,
#     show_edges=False,
#     clim=[-2e5, 2e5],
#     scalar_bar_args={
#         'title': None, #'Ez (V/m)',
#         'vertical': False,
#         'position_x': 0.20,
#         'position_y': 0.12,
#         'width': 0.6,
#         'height': 0.05,
#     }
# )


# Geometry slice coloured by electric pressure.
pl.add_mesh(
    geo_slice,
    scalars='electric_pressure', 
    cmap='seismic',
    clim=[vmin, vmax],
    show_edges=False,
    opacity=1,  # fully opaque
    # scalar_bar_args={
    #     'title': 'Particle Hits',
    #     'vertical': True,
    #     'position_x': 0.85,
    #     'position_y': 0.15}
)

#pl.add_mesh(geo_slice, color="black", line_width=3, opacity=1)
# Field arrows drawn in black on top of the slice.
pl.add_mesh(glyphs, color='black', show_scalar_bar=False, line_width=5, opacity=1)

#pl.add_mesh(clipping_box, color='green', opacity=0.1, show_edges=True)

# # desired new orientation to match G4 visualization
# pl.camera_position = [(0.023504925144732475, 0.004486350775623978, 0.41194153528845395),
#  (0.002166604469844069, 0.0010035066774800594, -6.701385451800268e-06),
#  (-0.9984056206000426, 0.023060878700134766, 0.05152099210827938)]
# Camera given as (position, focal point, view-up).
pl.camera_position = [(0.2405486054033282, -0.014129531514704384, 0.33634817371357223),
 (0.002166604469844069, 0.0009990661314442928, -1.9069368376797846e-06),
 (-0.8157505803865249, 0.005166285560794769, 0.5783807570213173)]

pl.enable_parallel_projection()
#pl.enable_2d_style()
# NOTE(review): view_xz() presumably re-aims the camera, overriding the
# camera_position assigned just above. Confirm which view is intended.
pl.view_xz()

# pl.screenshot(f'figures/wang_{configIN}#{iteration_SW}.jpeg', scale=4)

print(f"Total execution time: {time.time() - start_time:.2f}s")
pl.show()

In [None]:
# Display the geometry mesh (built in an earlier cell) as this cell's output.
pv_mesh

## Electric-field streamlines (experimental attempt)

In [None]:

# Trace streamlines seeded on the geometry faces, then plot them (1) over the
# illumination map and (2) over the electric-pressure map.
# NOTE(review): relies on pv_mesh, point_cloud, vmin/vmax and
# create_streamlines_from_faces (presumably from common_functions) being
# defined by earlier cells / imports.


def _add_streamline_tubes(plotter, lines, radius, cmap, scalar_bar_title=None, label=None):
    """Add ``lines`` to ``plotter`` as tubes coloured by the norm of 'vectors'.

    Does nothing when ``lines`` has no points. When a 'vectors' array is
    present, its per-point norm is stored as 'velocity' and used both as the
    ``tube(scalars=...)`` input and as the colouring scalars. Otherwise plain
    tubes are drawn rather than failing on a missing 'velocity' array.
    """
    if lines.n_points == 0:
        return
    if 'vectors' in lines.array_names:
        lines['velocity'] = np.linalg.norm(lines['vectors'], axis=1)
        tube = lines.tube(radius=radius, scalars='velocity')
        bar_args = {'title': scalar_bar_title} if scalar_bar_title else None
        plotter.add_mesh(
            tube,
            scalars='velocity',
            cmap=cmap,
            show_scalar_bar=True,
            scalar_bar_args=bar_args,
            label=label,
        )
    else:
        plotter.add_mesh(lines.tube(radius=radius), label=label)


# --- Compute the streamlines first: both plots below need them. ---
# The field components are stacked into an (N, 3) array under a dedicated
# name, so this cell does not overwrite `vectors_slice_full`. Other cells
# pair that variable with `points_slice_full`.
point_cloud_vectors = np.column_stack([
    point_cloud["Ex_val"],
    point_cloud["Ey_val"],
    point_cloud["Ez_val"]
])

# NOTE(review): stored as *cell* data (presumably one vertex cell per point);
# confirm this is what create_streamlines_from_faces expects rather than
# point data.
point_cloud.cell_data["vectors"] = point_cloud_vectors
streamlines, seed_points = create_streamlines_from_faces(
    pv_mesh,      # geometry mesh whose faces provide the seeds
    point_cloud,  # dataset carrying the vector field
    n_seeds_per_face=1,
    max_time=150.0
)

# --- Plot 1: streamlines over the illumination map ---
# Use 'illumination' when present. Otherwise fall back to
# 'proton_illumination' (the name used by the seeding/plotting cells below),
# or draw the mesh uncoloured.
illumination_key = next(
    (key for key in ('illumination', 'proton_illumination') if key in pv_mesh.cell_data),
    None,
)

pl = pv.Plotter()
pl.set_background('white')
pl.add_mesh(
    pv_mesh,
    scalars=illumination_key,
    cmap='OrRd',
    clim=[0, 15],
    show_edges=False,
    opacity=0.7,
    scalar_bar_args={'title': 'Particle Hits'}
)
pl.add_mesh(
    seed_points,
    color='red',
    point_size=8,
    render_points_as_spheres=True,
    label='Seed Points'
)
_add_streamline_tubes(pl, streamlines, radius=0.02, cmap='viridis',
                      scalar_bar_title='Velocity', label='Streamlines')
pl.add_legend()
pl.show()

# --- Plot 2: streamlines over the electric-pressure map ---
# NOTE(review): tube radii 0.02 / 0.05 are large compared with the 1.5e-4
# radius used for the streamline plot further down; confirm the units.
pl = pv.Plotter()
pl.add_mesh(
    pv_mesh,
    scalars='electric_pressure',
    cmap='seismic',
    clim=[vmin, vmax],
    show_edges=False,
    opacity=1,
)
_add_streamline_tubes(pl, streamlines, radius=0.05, cmap='plasma')
pl.show()

In [None]:
import numpy as np
import pyvista as pv

def create_proper_vector_field(points, vectors, geometry, method='structured',
                               radius=None, sharpness=3.0, strategy='null_value',
                               allow_fallback=True):
    """
    Interpolate a scattered vector field onto a regular grid for streamline tracing.

    The grid spans ``geometry.bounds``, padded on every side by 10% of the
    largest extent. It has 20 x 20 x 15 points for ``method='structured'``
    (a ``pv.StructuredGrid``) or 25 x 25 x 15 points for
    ``method='rectilinear'`` (a ``pv.RectilinearGrid``). The samples are
    interpolated with ``grid.interpolate(cloud, ...)``: PyVista interpolates
    the *argument's* point data onto the *caller's* points.

    Args:
        points (array-like): (N, 3) sample positions.
        vectors (array-like): (N, 3) field vectors at ``points``.
        geometry: object with a PyVista-style ``bounds`` tuple
            (xmin, xmax, ymin, ymax, zmin, zmax).
        method (str): 'structured' or 'rectilinear'.
        radius (float, optional): interpolation kernel radius. Defaults to the
            largest grid spacing.
        sharpness (float): Gaussian kernel sharpness passed to ``interpolate``.
        strategy (str): PyVista null-point strategy. With the default
            'null_value', grid points with no sample within ``radius`` get 0.
        allow_fallback (bool): if interpolation raises, substitute a synthetic
            vortex field (with a RuntimeWarning) instead of re-raising.

    Returns:
        The grid, carrying a point-data array 'vectors' set as active vectors.

    Raises:
        ValueError: unknown ``method``, no samples, or ``vectors`` whose shape
            does not match ``points``.
    """
    import warnings

    points = np.asarray(points, dtype=float)
    vectors = np.asarray(vectors, dtype=float)
    # Catch mismatched inputs early (e.g. a vectors array overwritten by
    # another cell) instead of failing deep inside VTK.
    if vectors.ndim != 2 or vectors.shape[1] != 3 or len(vectors) != len(points):
        raise ValueError(
            f"vectors must have shape (len(points), 3); got points {points.shape}, "
            f"vectors {vectors.shape}"
        )
    if len(points) == 0:
        raise ValueError("cannot build a vector field from zero samples")

    bounds = geometry.bounds
    padding = 0.1 * max(bounds[1]-bounds[0], bounds[3]-bounds[2], bounds[5]-bounds[4])

    resolutions = {'structured': (20, 20, 15), 'rectilinear': (25, 25, 15)}
    if method not in resolutions:
        raise ValueError(f"unknown method {method!r}; expected 'structured' or 'rectilinear'")

    # One evenly spaced axis per dimension (inclusive of both padded ends).
    axes = [
        np.linspace(bounds[2 * k] - padding, bounds[2 * k + 1] + padding, n)
        for k, n in enumerate(resolutions[method])
    ]
    if method == 'structured':
        x, y, z = np.meshgrid(*axes, indexing='ij')
        grid = pv.StructuredGrid(x, y, z)
    else:
        grid = pv.RectilinearGrid(*axes)

    if radius is None:
        radius = max(axis[1] - axis[0] for axis in axes)

    cloud = pv.PolyData(points)
    cloud['vectors'] = vectors

    cloud_magnitudes = np.linalg.norm(vectors, axis=1)
    print(f"Cloud points: {cloud.n_points}")
    print(f"Cloud vectors shape: {vectors.shape}")
    print(f"Grid points: {grid.n_points}")
    print(f"Vector magnitude range in cloud: [{cloud_magnitudes.min():.6f}, {cloud_magnitudes.max():.6f}]")

    try:
        # Target (`self`) is the grid; the cloud supplies the data.
        interpolated = grid.interpolate(
            cloud,
            sharpness=sharpness,
            radius=radius,
            strategy=strategy,
            null_value=0.0,
        )
        field = np.array(interpolated.point_data['vectors'], dtype=float)
    except Exception as exc:
        if not allow_fallback:
            raise
        print(f"Interpolation failed: {exc}")
        warnings.warn(
            "Interpolation failed; using a SYNTHETIC vortex field. Streamlines "
            "will not reflect the input vectors.",
            RuntimeWarning,
        )
        grid_points = np.asarray(grid.points)
        r = grid_points - grid_points.mean(axis=0)
        field = np.column_stack([-r[:, 1], r[:, 0], 0.1 * r[:, 2]])

    # Replace any NaN vectors with zeros so the stream tracer stops there.
    nan_mask = np.isnan(field).any(axis=1)
    if nan_mask.any():
        field[nan_mask] = 0.0
        print(f"Fixed {np.sum(nan_mask)} NaN vectors")

    grid.point_data['vectors'] = field
    grid.set_active_vectors('vectors')

    # Verify the vector field
    magnitudes = np.linalg.norm(field, axis=1)
    print(f"Final vector magnitude range: [{magnitudes.min():.6f}, {magnitudes.max():.6f}]")
    print(f"Non-zero vectors: {np.sum(magnitudes > 1e-10)} / {len(magnitudes)}")

    return grid

# Interpolate the slice field onto a structured grid (20 x 20 x 15 points)
# spanning the padded bounds of pv_mesh.
vector_grid = create_proper_vector_field(
    points=points_slice_full,
    vectors=vectors_slice_full,
    geometry=pv_mesh,
    method='structured',
)

In [None]:
import numpy as np
import pyvista as pv

def get_seeds_from_high_illumination(geometry, num_seeds=30, illumination_threshold=None, 
                                   use_top_faces=True, min_seeds_per_face=1, max_seeds_per_face=3,
                                   jitter=0.005, rng=None):
    """
    Generate streamline seed points near strongly illuminated faces.

    Seeds are face centres (``geometry.cell_centers()``) perturbed by isotropic
    Gaussian noise with standard deviation ``jitter``. The illumination is read
    from ``geometry.cell_data['proton_illumination']``.

    Selection modes:
      * ``illumination_threshold is None`` and ``use_top_faces=True``: take
        the ``max(1, num_seeds // 2)`` brightest faces. Each gets
        ``num_seeds // n_faces`` seeds, clipped to
        ``[min_seeds_per_face, max_seeds_per_face]``.
      * Otherwise: candidate faces are those with illumination >= threshold,
        or all faces when no threshold is given. ``num_seeds`` faces are then
        drawn (with replacement) with probability proportional to
        illumination. Negative values count as 0; if all are 0, the draw is
        uniform. ``min/max_seeds_per_face`` are not used in this mode.
      * No illumination array: up to ``num_seeds`` face centres are chosen
        uniformly, without jitter.
    If more than ``num_seeds`` seeds result, a random subset of ``num_seeds``
    is kept.

    Args:
        geometry: PyVista mesh, ideally with 'proton_illumination' cell data.
        num_seeds (int): Maximum number of seed points to return.
        illumination_threshold (float, optional): Minimum illumination for a
            candidate face. If None, the top-faces / all-faces modes apply.
        use_top_faces (bool): Use the top-N mode when no threshold is given.
        min_seeds_per_face (int): Lower bound of seeds per face (top-N mode).
        max_seeds_per_face (int): Upper bound of seeds per face (top-N mode).
        jitter (float): Std-dev of the Gaussian offset added to each seed,
            in mesh coordinate units.
        rng: Object with ``choice`` and ``normal`` methods (e.g. a
            ``np.random.Generator``). Defaults to the global ``np.random`` state.

    Returns:
        pv.PolyData: Seed positions (an empty PolyData if none were produced).
    """
    rng = np.random if rng is None else rng
    face_centers = np.asarray(geometry.cell_centers().points)

    if num_seeds <= 0 or len(face_centers) == 0:
        print("No seeds requested or mesh has no faces; returning an empty seed set.")
        return pv.PolyData()

    if 'proton_illumination' not in geometry.cell_data:
        print("No illumination data found! Using all faces.")
        if len(face_centers) > num_seeds:
            indices = rng.choice(len(face_centers), num_seeds, replace=False)
            return pv.PolyData(face_centers[indices])
        return pv.PolyData(face_centers)

    illumination = np.asarray(geometry.cell_data['proton_illumination'], dtype=float)

    print(f"Illumination range: [{illumination.min()}, {illumination.max()}]")
    print(f"Total faces: {len(illumination)}")

    if illumination_threshold is None and use_top_faces:
        # Use at least one face: for num_seeds == 1, num_seeds // 2 == 0, and a
        # `[-0:]` slice would silently select *every* face.
        n_top_faces = min(max(1, num_seeds // 2), len(illumination))
        # argsort is ascending, so the last n_top_faces entries are the brightest.
        top_indices = np.argsort(illumination)[len(illumination) - n_top_faces:]
        top_faces = face_centers[top_indices]
        top_values = illumination[top_indices]

        print(f"Selected top {n_top_faces} illuminated faces")
        print(f"Top face illumination range: [{top_values.min()}, {top_values.max()}]")

        seeds_per_face = max(min_seeds_per_face,
                             min(max_seeds_per_face, num_seeds // len(top_faces)))
        base_points = np.repeat(top_faces, max(0, seeds_per_face), axis=0)
    else:
        if illumination_threshold is not None:
            mask = illumination >= illumination_threshold
            candidate_faces = face_centers[mask]
            candidate_values = illumination[mask]
            print(f"Faces above threshold {illumination_threshold}: {len(candidate_faces)}")
        else:
            candidate_faces = face_centers
            candidate_values = illumination

        if len(candidate_faces) == 0:
            print("No high illumination faces found! Using all faces.")
            candidate_faces = face_centers
            candidate_values = np.ones(len(face_centers))

        # Draw faces with probability proportional to illumination (the
        # weights were previously computed but never used).
        weights = np.clip(candidate_values, 0.0, None)
        total = weights.sum()
        probabilities = weights / total if total > 0 else None
        face_idx = rng.choice(len(candidate_faces), size=num_seeds, replace=True, p=probabilities)
        base_points = candidate_faces[face_idx]

    seed_points = base_points + rng.normal(0.0, jitter, size=base_points.shape)

    # If we have too many seeds, randomly select
    if len(seed_points) > num_seeds:
        selected_indices = rng.choice(len(seed_points), num_seeds, replace=False)
        seed_points = seed_points[selected_indices]

    print(f"Generated {len(seed_points)} seed points from high illumination faces")

    if len(seed_points) == 0:
        return pv.PolyData()
    return pv.PolyData(seed_points)

# Usage examples:

# Method A: seeds on the brightest faces (top-N mode; N = num_seeds // 2 = 20).
seeds_top = get_seeds_from_high_illumination(
    geometry=pv_mesh,
    num_seeds=40,
    illumination_threshold=None,
    use_top_faces=True,
    min_seeds_per_face=1,
    max_seeds_per_face=3,
)


In [None]:

# Rebuild the interpolated field grid and the seeds, then trace streamlines.
# NOTE(review): this repeats the two previous cells; the seeding draws random
# numbers, so re-running produces different seeds unless the RNG is seeded.

# Create proper vector field using structured grid
vector_grid = create_proper_vector_field(points_slice_full, vectors_slice_full, pv_mesh, method='structured')

seeds_top = get_seeds_from_high_illumination(
    pv_mesh, 
    num_seeds=40,
    illumination_threshold=None,  # Use top faces
    use_top_faces=True,
    min_seeds_per_face=1,
    max_seeds_per_face=3
)


# Forward-only integration from each seed through the grid's 'vectors' field.
# NOTE(review): PyVista step lengths default to cell-length units
# (step_unit='cl'); confirm against the installed version.
streamlines = vector_grid.streamlines_from_source(
            seeds_top,
            vectors='vectors',
            max_time=200.0,
            integration_direction='forward',
            initial_step_length=0.05,
            max_step_length=0.5,
            terminal_speed=1e-8
        )


In [None]:

# Plot the illumination map, the seed points and the traced streamlines.
pl = pv.Plotter()
pl.set_background('white')

# Plot geometry with illumination
pl.add_mesh(
    pv_mesh,
    scalars='proton_illumination',
    cmap='OrRd',
    clim=[0, 15],
    show_edges=False,
    opacity=0.6
)

# Plot seed points
pl.add_mesh(
    seeds_top,
    color='red',
    point_size=8,
    render_points_as_spheres=True,
    label='Seed Points'
)

# Plot streamlines (skipped when the tracer produced no points).
# NOTE(review): labels are set but pl.add_legend() is never called, so
# presumably no legend is drawn.
if streamlines.n_points > 0:
    tube = streamlines.tube(radius=0.00015)
    pl.add_mesh(tube, color='blue', label='Streamlines')

pl.show()

In [None]:
# Compute the normal Maxwell electric pressure on every face of the full
# (uncropped) geometry from the solar-wind field at iteration_SW.
fieldIN = df_SW[iteration_SW]
vmin, vmax = (-2, 2) 
 
geometry_center = geometry.centroid
 
# ----------------------------------------------------
# Step 0: Filter data
# ----------------------------------------------------
start_time = time.time()
points = fieldIN["pos"]
vectors = fieldIN["E"]
magnitudes = fieldIN["E_mag"]
 
# Keep samples with a non-zero field inside the slab |y - y_centroid| <= 0.1.
initial_mask = (magnitudes > 0) & (points[:,1]>=-0.1+geometry_center[1]) & (points[:,1]<=0.1+geometry_center[1])
points = points[initial_mask]
vectors = vectors[initial_mask]
magnitudes = magnitudes[initial_mask]
 
# NOTE(review): this shadows the epsilon_0 imported from scipy.constants with
# a hard-coded literal (the two values differ negligibly).
epsilon_0 = 8.854187817e-12
 
# Create field cloud
field_cloud = pv.PolyData(points)
field_cloud["E_x"] = vectors[:, 0]
field_cloud["E_y"] = vectors[:, 1]
field_cloud["E_z"] = vectors[:, 2]
 
print(f"Starting points (filtered): {len(points)}")
 
# ----------------------------------------------------
# Step 1: Geometry Setup
# ----------------------------------------------------
start_time_geo = time.time()
 
# trimesh -> PyVista: each VTK face entry is [3, i, j, k] for a triangle.
pv_spheres = pv.PolyData(
    geometry.vertices,
    np.hstack([np.full((len(geometry.faces), 1), 3), geometry.faces])
).compute_normals()
 
bbox_bounds = field_cloud.bounds
bbox = pv.Box(bounds=bbox_bounds)

# No cropping in this cell: the "cropped" mesh is the full geometry (bbox is
# not used to clip here).
pv_spheres_cropped = pv_spheres

# ============================================================
# Interpolate field to face centers
# ============================================================
face_centers = pv_spheres_cropped.cell_centers().points
 
face_field_cloud = pv.PolyData(face_centers)
# Sample E at each face centre from the scattered field cloud.
face_field_interp = face_field_cloud.interpolate(
    field_cloud,
    radius=0.002,
    strategy='closest_point',
    sharpness=3.0,
    null_value=0.0
)
 
# Extract face-centered E-fields
E_x_faces = face_field_interp["E_x"]
E_y_faces = face_field_interp["E_y"]
E_z_faces = face_field_interp["E_z"]
E_vec_faces = np.stack([E_x_faces, E_y_faces, E_z_faces], axis=1)

# Face normals (one per face) and their Cartesian components
face_normals = pv_spheres_cropped.cell_normals
nx = face_normals[:, 0]   # x-component
ny = face_normals[:, 1]   # y-component
nz = face_normals[:, 2]   # z-component
 
# ============================================================
# Compute Maxwell electric pressure (normal)
# ============================================================
E_dot_n = np.einsum('ij,ij->i', E_vec_faces, face_normals)   # E . n per face
E_mag_sq = np.einsum('ij,ij->i', E_vec_faces, E_vec_faces)   # |E|^2 per face
 
# Normal component of the Maxwell traction: eps0 * ((E.n)^2 - |E|^2 / 2).
# Squaring E.n makes this independent of the normal orientation.
P_normal_faces = epsilon_0 * (E_dot_n**2 - 0.5 * E_mag_sq)
# Cartesian components of the normal-pressure vector P_normal * n.
# NOTE(review): this is only the normal part of the traction; the full
# x-traction also contains the shear term eps0 * (E.n) * E_tangential_x.
P_x_faces = P_normal_faces * nx
P_y_faces = P_normal_faces * ny
P_z_faces = P_normal_faces * nz

print(f"Computed electric pressure in {time.time() - start_time_geo:.2f}s")


In [None]:
# Attach the normal Maxwell pressure to the mesh as cell data so it can be
# used as plotting scalars.
pv_spheres_cropped.cell_data["electric_pressure"] = P_normal_faces
 
# ============================================================
# Plotting
# ============================================================
pl = pv.Plotter()
pl.set_background('white')
 
pl.add_mesh(
    pv_spheres_cropped,
    scalars="electric_pressure",
    cmap="seismic",
    opacity=1,
    show_edges=False,
    clim=[vmin, vmax],
    interpolate_before_map=False,
    preference="cell"
)

# YZ plane (normal along +X), offset 0.03 from the geometry centroid in X.
X_SLICE = 0.03 + geometry.centroid[0]
normal = [1, 0, 0]  # X-normal -> the plane spans Y and Z

# Plane centre: at X_SLICE, centred on the field-cloud bounding box in Y and Z.
# bbox_bounds is (xmin, xmax, ymin, ymax, zmin, zmax).
plane_center = [X_SLICE, 
    (bbox_bounds[2] + bbox_bounds[3]) / 2,
    (bbox_bounds[4] + bbox_bounds[5]) / 2,
]

# Marker plane at X_SLICE (not added to the scene unless the block below is
# uncommented). Sizes use abs() so they stay positive (max - min).
slice_plane = pv.Plane(
    center=plane_center,
    direction=normal,
    i_size=abs(bbox_bounds[3] - bbox_bounds[2]),  # Y extent of the field cloud
    j_size=abs(bbox_bounds[5] - bbox_bounds[4]),  # Z extent of the field cloud
    i_resolution=2,  # Low resolution for simple plane
    j_resolution=2
)

# # Show the marker plane
# pl.add_mesh(
#     slice_plane, 
#     color="blue",  # or "black", "lightgray", etc.
#     opacity=0.3,        # Semi-transparent
#     show_edges=True,
#     edge_color="black",
#     line_width=2
# )

# Camera as (position, focal point, view-up).
pl.camera_position = [
    (0.019695016454329653, -0.3381202307807713, -0.006061249518755622),
    (0.002166604469844069, 0.0019058251597530163, -0.014826609860567566),
    (-0.046333590448935605, 0.02335526536874514, 0.9986529577263896)
]
 
pl.show()

In [None]:
# Pressure map plus ten semi-transparent YZ planes marking X-slice positions at
# offsets np.linspace(-0.04, 0.04, 10) from the geometry centroid.
# Save to cell data
pv_spheres_cropped.cell_data["electric_pressure"] = P_normal_faces
 
# ============================================================
# Plotting
# ============================================================
pl = pv.Plotter()
pl.set_background('white')
 
pl.add_mesh(
    pv_spheres_cropped,
    scalars="electric_pressure",
    cmap="seismic",
    opacity=1,
    show_edges=False,
    clim=[vmin, vmax],
    interpolate_before_map=False,
    preference="cell"
)

for XIN in np.linspace(-0.04,0.04,10):
    # Define YZ plane parameters (slice perpendicular to X-axis)
    X_SLICE = XIN + geometry.centroid[0]
    normal = [1, 0, 0]  # X-normal for YZ plane

    # Center of the plane
    plane_center = [
        X_SLICE,
        (bbox_bounds[2] + bbox_bounds[3]) / 2,  # Y center
        (bbox_bounds[4] + bbox_bounds[5]) / 2,  # Z center
    ]

    # Create a visual plane at X_SLICE location
    slice_plane = pv.Plane(
        center=plane_center,
        direction=normal,
        i_size=abs(bbox_bounds[3] - bbox_bounds[2]),  # Y extent
        j_size=abs(bbox_bounds[5] - bbox_bounds[4]),  # Z extent
        i_resolution=2,
        j_resolution=2
    )

    # Optional: Create an actual slice of the data at this plane
    # NOTE(review): slice_data is recomputed on every iteration but never
    # plotted (the add_mesh below is commented out); only the last one
    # survives the loop.
    slice_data = pv_spheres_cropped.slice(normal=normal, origin=plane_center)
    # pl.add_mesh(
    #     slice_data,
    #     scalars="electric_pressure",
    #     cmap="seismic",
    #     clim=[vmin, vmax],
    #     show_edges=True,
    #     line_width=1
    # )


    # Add the YZ plane visualization
    pl.add_mesh(
        slice_plane, 
        color="blue",
        opacity=0.3,
        show_edges=True,
        line_width=2
    )

# Camera as (position, focal point, view-up).
pl.camera_position = [
    (0.019695016454329653, -0.3381202307807713, -0.006061249518755622),
    (0.002166604469844069, 0.0019058251597530163, -0.014826609860567566),
    (-0.046333590448935605, 0.02335526536874514, 0.9986529577263896)
]
 
pl.show()

In [None]:
# Quick check: index 1 of the (xmin, xmax, ymin, ymax, zmin, zmax) bounds is x-max.
bbox_bounds[1]

In [None]:
# Quick check of the current lower colour limit (set by whichever cell ran last).
vmin

In [None]:
# Same pressure map as above but with a hard-coded, wider colour range
# [-3, 3] instead of (vmin, vmax).
pl = pv.Plotter()
pl.set_background('white')
 
pl.add_mesh(
    pv_spheres_cropped,
    scalars="electric_pressure",
    cmap="seismic",
    opacity=1,
    show_edges=False,
    clim=[-3,3],
    interpolate_before_map=False,
    preference="cell"
)

# (Disabled) overlay of the slice_plane built in an earlier cell
# pl.add_mesh(
#     slice_plane, 
#     color="lightgray",  # or "black", "blue", etc.
#     opacity=0.3,        # Semi-transparent
#     show_edges=True,
#     edge_color="black",
#     line_width=2
# )

# Camera as (position, focal point, view-up).
pl.camera_position = [
    (0.019695016454329653, -0.3381202307807713, -0.006061249518755622),
    (0.002166604469844069, 0.0019058251597530163, -0.014826609860567566),
    (-0.046333590448935605, 0.02335526536874514, 0.9986529577263896)
]
 
pl.show()

In [None]:
# ============================================================
# Plotting using Px
# ============================================================
# Attach the x-component of the normal-pressure vector (P_normal_faces * nx,
# from the pressure cell). This makes sure the "electric_pressure_x" scalars
# requested below exist, instead of relying on them having been stored
# elsewhere.
# NOTE(review): P_normal * n_x is only the x-projection of the *normal*
# Maxwell traction; the full x-traction also has a shear term
# eps0 * (E.n) * E_tangential_x.
pv_spheres_cropped.cell_data["electric_pressure_x"] = P_x_faces

pl = pv.Plotter()
pl.set_background('white')
 
pl.add_mesh(
    pv_spheres_cropped,
    scalars="electric_pressure_x",
    cmap="seismic",
    opacity=1,
    show_edges=False,
    clim=[-1, 1],
    interpolate_before_map=False,
    preference="cell"
)
 
pl.view_xy()
pl.show()

In [None]:
# Same normal-pressure computation as the earlier pressure cell, but on the
# geometry cropped to the field-cloud bounding box.
fieldIN = df_SW[iteration_SW]
vmin, vmax = (-1, 1) 
 
geometry_center = geometry.centroid
# NOTE(review): red_point is not used in this cell.
red_point = np.array([-0.1, 0., 0.1 + 0.037]) + geometry_center
 
# ----------------------------------------------------
# Step 0: Filter data
# ----------------------------------------------------
start_time = time.time()
points = fieldIN["pos"]
vectors = fieldIN["E"]
magnitudes = fieldIN["E_mag"]
 
# Keep samples with a non-zero field inside the slab |y - y_centroid| <= 0.1.
initial_mask = (magnitudes > 0) & (points[:,1]>=-0.1+geometry_center[1]) & (points[:,1]<=0.1+geometry_center[1])
points = points[initial_mask]
vectors = vectors[initial_mask]
magnitudes = magnitudes[initial_mask]
 
# NOTE(review): this shadows the epsilon_0 imported from scipy.constants with
# a hard-coded literal (the two values differ negligibly).
epsilon_0 = 8.854187817e-12
 
# Create field cloud
field_cloud = pv.PolyData(points)
field_cloud["E_x"] = vectors[:, 0]
field_cloud["E_y"] = vectors[:, 1]
field_cloud["E_z"] = vectors[:, 2]
 
print(f"Starting points (filtered): {len(points)}")
 
# ----------------------------------------------------
# Step 1: Geometry Setup
# ----------------------------------------------------
start_time_geo = time.time()
 
# trimesh -> PyVista: each VTK face entry is [3, i, j, k] for a triangle.
pv_spheres = pv.PolyData(
    geometry.vertices,
    np.hstack([np.full((len(geometry.faces), 1), 3), geometry.faces])
).compute_normals()
 
bbox_bounds = field_cloud.bounds
bbox = pv.Box(bounds=bbox_bounds)
# Clip the geometry to the field-cloud box, turn the result back into a
# surface mesh, and recompute point and cell normals on it.
pv_spheres_cropped = (
    pv_spheres
    .clip_box(bbox, invert=False)
    .extract_surface()
    .compute_normals(point_normals=True, cell_normals=True, inplace=False)
)
 
# ============================================================
# Interpolate field to face centers
# ============================================================
face_centers = pv_spheres_cropped.cell_centers().points
 
face_field_cloud = pv.PolyData(face_centers)
# Sample E at each face centre from the scattered field cloud.
face_field_interp = face_field_cloud.interpolate(
    field_cloud,
    radius=0.002,
    strategy='closest_point',
    sharpness=3.0,
    null_value=0.0
)
 
# Extract face-centered E-fields
E_x_faces = face_field_interp["E_x"]
E_y_faces = face_field_interp["E_y"]
E_z_faces = face_field_interp["E_z"]
E_vec_faces = np.stack([E_x_faces, E_y_faces, E_z_faces], axis=1)
 
# Face normals (one per face) and their Cartesian components
face_normals = pv_spheres_cropped.cell_normals
nx = face_normals[:, 0]   # x-component
ny = face_normals[:, 1]   # y-component
nz = face_normals[:, 2]   # z-component
 
# ============================================================
# Compute Maxwell electric pressure (normal)
# ============================================================
E_dot_n = np.einsum('ij,ij->i', E_vec_faces, face_normals)   # E . n per face
E_mag_sq = np.einsum('ij,ij->i', E_vec_faces, E_vec_faces)   # |E|^2 per face
 
# Normal component of the Maxwell traction: eps0 * ((E.n)^2 - |E|^2 / 2)
P_normal_faces = epsilon_0 * (E_dot_n**2 - 0.5 * E_mag_sq)
# Cartesian components of the normal-pressure vector P_normal * n.
# NOTE(review): this is only the normal part of the traction; the full
# x-traction also contains the shear term eps0 * (E.n) * E_tangential_x.
P_x_faces = P_normal_faces * nx
P_y_faces = P_normal_faces * ny
P_z_faces = P_normal_faces * nz

In [None]:
# Attach the normal Maxwell pressure to the (cropped) mesh as cell data.
pv_spheres_cropped.cell_data["electric_pressure"] = P_normal_faces
 
# Elapsed time since start_time_geo was set in the pressure-computation cell
# (includes any idle time between running the cells). The quantity is the
# *normal* pressure, not an x-directed one.
print(f"Computed electric pressure in {time.time() - start_time_geo:.2f}s")
 
# ============================================================
# Plotting the normal electric pressure
# ============================================================
pl = pv.Plotter()
pl.set_background('white')
 
pl.add_mesh(
    pv_spheres_cropped,
    scalars="electric_pressure",
    cmap="seismic",
    opacity=1,
    show_edges=False,
    clim=[vmin, vmax],
    interpolate_before_map=False,
    preference="cell"
)
 
pl.view_xy()
pl.show() #jupyter_backend='static'
 

In [None]:
# Attach the normal Maxwell pressure to the (cropped) mesh as cell data.
# NOTE(review): this cell repeats the one above verbatim.
pv_spheres_cropped.cell_data["electric_pressure"] = P_normal_faces
 
# Elapsed time since start_time_geo was set in the pressure-computation cell
# (includes idle time between cells). The quantity is the *normal* pressure.
print(f"Computed electric pressure in {time.time() - start_time_geo:.2f}s")
 
# ============================================================
# Plotting the normal electric pressure
# ============================================================
pl = pv.Plotter()
pl.set_background('white')
 
pl.add_mesh(
    pv_spheres_cropped,
    scalars="electric_pressure",
    cmap="seismic",
    opacity=1,
    show_edges=False,
    clim=[vmin, vmax],
    interpolate_before_map=False,
    preference="cell"
)
 
pl.view_xy()
pl.show() #jupyter_backend='static'

In [None]:
# Plain-text (non-LaTeX) rendering at 14 pt for the standalone colourbar.
mpl.rcParams.update({'text.usetex': False, 'font.size': 14})

# Colour mapping matching the PyVista plots; vmin/vmax are whatever an
# earlier cell last assigned.
cmap = plt.cm.seismic 
norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)

# Tick labels in scientific notation with a shared x10^n multiplier and no
# additive offset.
formatter = ticker.ScalarFormatter(useMathText=True)
formatter.set_useOffset(False)
formatter.set_powerlimits((0, 0))

# Thin standalone horizontal colourbar.
fig, ax = plt.subplots(figsize=(5, 0.1))
cb = mpl.colorbar.ColorbarBase(
    ax,
    cmap=cmap,
    norm=norm,
    orientation='horizontal',
    label=r"Electric Pressure (Pa)",
)
cb.ax.xaxis.set_major_formatter(formatter)

plt.show()

In [None]:
# Plain-text (non-LaTeX) rendering at 14 pt for the standalone colourbar.
mpl.rcParams.update({'text.usetex': False, 'font.size': 14})

# Colour range for this colourbar. NOTE: this also rebinds the global
# vmin/vmax that the next plotting cell uses for its clim.
cmap = plt.cm.seismic 
vmin, vmax = (-0.05, 0.05)
norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)

# Tick labels in scientific notation with a shared x10^n multiplier and no
# additive offset.
formatter = ticker.ScalarFormatter(useMathText=True)
formatter.set_useOffset(False)
formatter.set_powerlimits((0, 0))

# Thin standalone horizontal colourbar.
fig, ax = plt.subplots(figsize=(5, 0.1))
cb = mpl.colorbar.ColorbarBase(
    ax,
    cmap=cmap,
    norm=norm,
    orientation='horizontal',
    label=r"Electric Pressure (Pa)",
)
cb.ax.xaxis.set_major_formatter(formatter)

plt.show()

## 2D Representation of the Electric Field

In [None]:
# Pressure map plus ten semi-transparent YZ planes at X offsets
# np.linspace(-0.04, 0.04, 10) from the geometry centroid.
# NOTE(review): clim uses the global vmin/vmax. In notebook order these are
# (-0.05, 0.05), set by the colourbar cell directly above.
# Save to cell data
pv_spheres_cropped.cell_data["electric_pressure"] = P_normal_faces
 
# ============================================================
# Plotting
# ============================================================
pl = pv.Plotter()
pl.set_background('white')
 
pl.add_mesh(
    pv_spheres_cropped,
    scalars="electric_pressure",
    cmap="seismic",
    opacity=1,
    show_edges=False,
    clim=[vmin, vmax],
    interpolate_before_map=False,
    preference="cell"
)

for XIN in np.linspace(-0.04,0.04,10):
    # Define YZ plane parameters (slice perpendicular to X-axis)
    X_SLICE = XIN + geometry.centroid[0]
    normal = [1, 0, 0]  # X-normal for YZ plane

    # Center of the plane
    plane_center = [
        X_SLICE,
        (bbox_bounds[2] + bbox_bounds[3]) / 2,  # Y center
        (bbox_bounds[4] + bbox_bounds[5]) / 2,  # Z center
    ]

    # Create a visual plane at X_SLICE location
    slice_plane = pv.Plane(
        center=plane_center,
        direction=normal,
        i_size=abs(bbox_bounds[3] - bbox_bounds[2]),  # Y extent
        j_size=abs(bbox_bounds[5] - bbox_bounds[4]),  # Z extent
        i_resolution=2,
        j_resolution=2
    )

    # Optional: Create an actual slice of the data at this plane
    # NOTE(review): slice_data is recomputed on every iteration but never
    # plotted (the add_mesh below is commented out).
    slice_data = pv_spheres_cropped.slice(normal=normal, origin=plane_center)
    # pl.add_mesh(
    #     slice_data,
    #     scalars="electric_pressure",
    #     cmap="seismic",
    #     clim=[vmin, vmax],
    #     show_edges=True,
    #     line_width=1
    # )


    # Add the YZ plane visualization
    pl.add_mesh(
        slice_plane, 
        color="blue",
        opacity=0.3,
        show_edges=True,
        line_width=2
    )

# Camera as (position, focal point, view-up).
pl.camera_position = [
    (0.019695016454329653, -0.3381202307807713, -0.006061249518755622),
    (0.002166604469844069, 0.0019058251597530163, -0.014826609860567566),
    (-0.046333590448935605, 0.02335526536874514, 0.9986529577263896)
]
 
pl.show()

In [None]:
## SETTINGS HERE ARE OPTIMIZED FOR ITERATION 86 ##
# NOTE(review): iteration_SW is set to 80 in the data-loading cell. Confirm
# these settings still apply.

fieldIN = df_SW[iteration_SW]

N_DOWNSAMPLE_EMAG = 10 # Stride downsampling: keep every 10th non-zero-field sample
ARROW_VOXEL_SPACING = 0.01 # Presumably the voxel edge for one-arrow-per-voxel thinning
Z_SLICE = 0.01 + geometry.centroid[2]  # 0.01 above the geometry centroid in Z
THICKNESS = 0.001 # Not used in the visible part of this cell
VECTOR_SCALE_FACTOR = 3e-8 # Presumably the glyph scale factor (arrow length per unit |E|)
FIELD_AVERAGE_RADIUS = 2.5e-3 # Not used in the visible part of this cell

vmin, vmax = (0, 2e5) # Colour limits, presumably for |E| in V/m

# ----------------------------------------------------
# Voxel Downsampling Helper Function (for XY plane)
# ----------------------------------------------------
def voxel_downsample_points(points, spacing, axes=(0, 1)):
    """
    Select one point per 2D voxel (square cell of side ``spacing``).

    Points are binned on the two coordinate columns named by ``axes``
    (default X and Y, the in-plane axes of a Z-slice); any other coordinate
    is ignored. For each occupied cell the index of its first point (in input
    order) is kept.

    Parameters
    ----------
    points : array-like, shape (N, D)
        Point coordinates.
    spacing : float
        Voxel edge length, in the units of ``points``.
    axes : tuple of two ints, optional
        Columns ``(a, b)`` used for binning. Default ``(0, 1)``.

    Returns
    -------
    numpy.ndarray of int
        Indices into ``points``, ordered by cell (row-major over ``(b, a)``).
        Empty when ``points`` is empty.
    """
    points = np.asarray(points)
    if len(points) == 0:
        # min() on an empty array raises; there is nothing to keep anyway.
        return np.empty(0, dtype=np.intp)

    a_coords = points[:, axes[0]]
    b_coords = points[:, axes[1]]

    # int64 so the combined key cannot overflow where the default int is 32-bit.
    a_idx = np.floor((a_coords - a_coords.min()) / spacing).astype(np.int64)
    b_idx = np.floor((b_coords - b_coords.min()) / spacing).astype(np.int64)

    # One integer key per (b, a) cell; a_idx < n_a, so keys never collide.
    n_a = a_idx.max() + 1
    voxel_keys = b_idx * n_a + a_idx

    # return_index yields the first occurrence of each unique key.
    _, unique_indices = np.unique(voxel_keys, return_index=True)
    return unique_indices

# ----------------------------------------------------
# Step 0: Load, Filter, and Downsample Data
# ----------------------------------------------------
start_time = time.time()
points = fieldIN["pos"]
vectors = fieldIN["E"]
magnitudes = fieldIN["E_mag"]

# Keep only points with a non-zero field magnitude.
initial_mask = (magnitudes > 0)
points = points[initial_mask]
vectors = vectors[initial_mask]
magnitudes = magnitudes[initial_mask]

# Stride downsampling: every N_DOWNSAMPLE_EMAG-th point.
points_ds = points[::N_DOWNSAMPLE_EMAG]
vectors_ds = vectors[::N_DOWNSAMPLE_EMAG]
magnitudes_ds = magnitudes[::N_DOWNSAMPLE_EMAG]

point_cloud = pv.PolyData(points_ds)
point_cloud["E_mag"] = magnitudes_ds
point_cloud["Ex_val"] = vectors_ds[:,0]
point_cloud["Ey_val"] = vectors_ds[:,1]
point_cloud["Ez_val"] = vectors_ds[:,2]

print(f"Starting points (filtered by mag > 0): {len(points)}")

# ----------------------------------------------------
# Step 1: Geometry Setup and Slicing (XY plane at Z_SLICE)
# ----------------------------------------------------
start_time_geo = time.time()

# trimesh faces -> VTK face array: each triangle prefixed by its vertex count (3).
pv_spheres = pv.PolyData(
    geometry.vertices,
    np.hstack([np.full((len(geometry.faces), 1), 3), geometry.faces])
).compute_normals()

bbox_bounds = point_cloud.bounds
bbox = pv.Box(bounds=bbox_bounds)
pv_spheres_cropped = pv_spheres.clip_box(bbox, invert=False)

# This cell builds a Z-slice. The arrow slab (|z - Z_SLICE| < THICKNESS), the
# X/Y voxel thinning and the flattening of the plane onto Z_SLICE all assume a
# plane perpendicular to Z. A YZ plane here would be collapsed into a line.
normal = [0, 0, 1]  # Z-normal for XY plane

plane_center = [
    (bbox_bounds[0] + bbox_bounds[1]) / 2,  # X center
    (bbox_bounds[2] + bbox_bounds[3]) / 2,  # Y center
    Z_SLICE,                                # Z at slice location
]

field_slice_mesh = pv.Plane(
    center=plane_center,
    direction=normal,
    i_size=bbox_bounds[1] - bbox_bounds[0],  # X size
    j_size=bbox_bounds[3] - bbox_bounds[2],  # Y size
    i_resolution=250,
    j_resolution=250
)

# Gaussian-kernel interpolation within `radius`; plane points with no data in
# range fall back to the closest data point (strategy='closest_point').
field_slice_interpolated = field_slice_mesh.interpolate(
    point_cloud,
    sharpness=3.0,
    radius=0.001,
    null_value=1,
    strategy='closest_point'
)

# Ensure all Z coordinates are exactly at Z_SLICE
field_slice_interpolated.points[:, 2] = Z_SLICE

# Slice geometry at the same location
geo_slice = pv_spheres_cropped.slice(normal=normal, origin=plane_center)
print(f"Geometry and slicing preparation complete in {time.time() - start_time_geo:.2f}s")

# ----------------------------------------------------
# Step 2: Vector Field Glyphs (Arrows) - XY plane
# ----------------------------------------------------
start_time_vectors = time.time()

vector_mask = np.abs(points_ds[:, 2] - Z_SLICE) < THICKNESS
points_slice_full = points_ds[vector_mask]
vectors_slice_full = vectors_ds[vector_mask]
magnitudes_slice_full = magnitudes_ds[vector_mask]
if len(points_slice_full) == 0:
    raise ValueError(
        f"No field points within THICKNESS={THICKNESS} of Z_SLICE={Z_SLICE}; "
        "increase THICKNESS or move the slice."
    )

unique_indices = voxel_downsample_points(points_slice_full, ARROW_VOXEL_SPACING)

points_slice = points_slice_full[unique_indices]
vectors_slice = vectors_slice_full[unique_indices]
magnitudes_slice = magnitudes_slice_full[unique_indices]

# Cap the scaling magnitude so no arrow is longer than half a voxel.
MAGNITUDE_MAX_CLAMP = ARROW_VOXEL_SPACING / VECTOR_SCALE_FACTOR / 2
magnitudes_slice_clamped = np.clip(magnitudes_slice, a_min=None, a_max=MAGNITUDE_MAX_CLAMP)

# Draw arrows 2*THICKNESS above the slice and (almost) flatten their Z component.
points_slice[:,2] = Z_SLICE + 2*THICKNESS
vectors_slice[:,2] = 0.0 - 2*THICKNESS
slice_mesh_vectors = pv.PolyData(points_slice)
slice_mesh_vectors['vectors'] = vectors_slice
slice_mesh_vectors['magnitude'] = magnitudes_slice_clamped

print(f"Points in vector slice (after density control + extra column): {len(points_slice)}, old length: {len(points_slice_full)}...")

arrow = pv.Arrow(tip_length=0.3, tip_radius=0.2, shaft_radius=0.04)
glyphs = slice_mesh_vectors.glyph(
    orient='vectors',
    scale='magnitude',
    factor=VECTOR_SCALE_FACTOR,
    geom=arrow
)
print(f"Vector glyphs built in {time.time() - start_time_vectors:.2f}s")

In [None]:
## SETTINGS HERE ARE OPTIMIZED FOR ITERATION 86 ##
# YZ-slice (normal along X) variant of the slice + arrows figure.
# NOTE(review): iteration_SW is assigned in an earlier cell (80 in the notebook
# header) — confirm these settings still suit the iteration loaded.

fieldIN = df_SW[iteration_SW]

N_DOWNSAMPLE_EMAG = 10 # Downsample factor for magnitude filtering (keep every Nth point)
ARROW_VOXEL_SPACING = 0.01 # voxel edge used to thin arrows; also caps arrow length

# X offset (from geometry.centroid[0]) of the YZ slice: the first of ten evenly
# spaced offsets in [-0.04, 0.04].
XIN = np.linspace(-0.04,0.04,10)[0]
# XIN is already a NumPy scalar (the [0] above), so it is used directly;
# indexing it again (XIN[0]) raises IndexError.
X_SLICE_OFFSET = XIN
THICKNESS = 0.001 # arrows come from the slab |x - X_SLICE| < THICKNESS
VECTOR_SCALE_FACTOR = 3e-8 # glyph length per unit field magnitude
FIELD_AVERAGE_RADIUS = 2.5e-3

vmin, vmax = (0, 2e5) # colour limits read by the visualization cells

# ----------------------------------------------------
# Voxel Downsampling Helper Function (for YZ plane)
# ----------------------------------------------------
def voxel_downsample_points(points, spacing, axes=(1, 2)):
    """
    Select one point per 2D voxel (square cell of side ``spacing``).

    Points are binned on the two coordinate columns named by ``axes``
    (default Y and Z, the in-plane axes of a YZ slice); any other coordinate
    is ignored. For each occupied cell the index of its first point (in input
    order) is kept.

    Parameters
    ----------
    points : array-like, shape (N, D)
        Point coordinates.
    spacing : float
        Voxel edge length, in the units of ``points``.
    axes : tuple of two ints, optional
        Columns ``(a, b)`` used for binning. Default ``(1, 2)``.

    Returns
    -------
    numpy.ndarray of int
        Indices into ``points``, ordered by cell (row-major over ``(b, a)``).
        Empty when ``points`` is empty.
    """
    points = np.asarray(points)
    if len(points) == 0:
        # min() on an empty array raises; there is nothing to keep anyway.
        return np.empty(0, dtype=np.intp)

    a_coords = points[:, axes[0]]
    b_coords = points[:, axes[1]]

    # int64 so the combined key cannot overflow where the default int is 32-bit.
    a_idx = np.floor((a_coords - a_coords.min()) / spacing).astype(np.int64)
    b_idx = np.floor((b_coords - b_coords.min()) / spacing).astype(np.int64)

    # One integer key per (b, a) cell; a_idx < n_a, so keys never collide.
    n_a = a_idx.max() + 1
    voxel_keys = b_idx * n_a + a_idx

    # return_index yields the first occurrence of each unique key.
    _, unique_indices = np.unique(voxel_keys, return_index=True)
    return unique_indices

# ----------------------------------------------------
# Step 0: Load, Filter, and Downsample Data
# ----------------------------------------------------
start_time = time.time()
points = fieldIN["pos"]
vectors = fieldIN["E"]
magnitudes = fieldIN["E_mag"]

# Keep only points with a non-zero field magnitude.
initial_mask = (magnitudes > 0)
points = points[initial_mask]
vectors = vectors[initial_mask]
magnitudes = magnitudes[initial_mask]

# Stride downsampling: every N_DOWNSAMPLE_EMAG-th point.
points_ds = points[::N_DOWNSAMPLE_EMAG]
vectors_ds = vectors[::N_DOWNSAMPLE_EMAG]
magnitudes_ds = magnitudes[::N_DOWNSAMPLE_EMAG]

point_cloud = pv.PolyData(points_ds)
point_cloud["E_mag"] = magnitudes_ds
point_cloud["Ex_val"] = vectors_ds[:,0]
point_cloud["Ey_val"] = vectors_ds[:,1]
point_cloud["Ez_val"] = vectors_ds[:,2]

print(f"Starting points (filtered by mag > 0): {len(points)}")

# ----------------------------------------------------
# Step 1: Geometry Setup and Slicing
# ----------------------------------------------------
start_time_geo = time.time()

# trimesh faces -> VTK face array: each triangle prefixed by its vertex count (3).
pv_spheres = pv.PolyData(
    geometry.vertices,
    np.hstack([np.full((len(geometry.faces), 1), 3), geometry.faces])
).compute_normals()

bbox_bounds = point_cloud.bounds
bbox = pv.Box(bounds=bbox_bounds)
pv_spheres_cropped = pv_spheres.clip_box(bbox, invert=False)

# Define YZ plane parameters (slice perpendicular to X-axis)
X_SLICE = X_SLICE_OFFSET + geometry.centroid[0]
normal = [1, 0, 0]  # X-normal for YZ plane

# Center of the plane
plane_center = [
    X_SLICE,
    (bbox_bounds[2] + bbox_bounds[3]) / 2,  # Y center
    (bbox_bounds[4] + bbox_bounds[5]) / 2,  # Z center
]

# Interpolation mesh for the YZ plane. For direction=[1, 0, 0], pv.Plane lays
# its i axis along Z and its j axis along Y. This is the same convention the
# ZX-slice cell relies on when it passes the Z span as i_size.
field_slice_mesh = pv.Plane(
    center=plane_center,
    direction=normal,
    i_size=abs(bbox_bounds[5] - bbox_bounds[4]),  # Z size
    j_size=abs(bbox_bounds[3] - bbox_bounds[2]),  # Y size
    i_resolution=250,
    j_resolution=250
)

# Gaussian-kernel interpolation within `radius`; plane points with no data in
# range fall back to the closest data point (strategy='closest_point').
field_slice_interpolated = field_slice_mesh.interpolate(
    point_cloud,
    sharpness=3.0,
    radius=0.001,
    null_value=1,
    strategy='closest_point'
)

# Ensure all X coordinates are exactly at X_SLICE
field_slice_interpolated.points[:, 0] = X_SLICE

# Slice geometry at the same location
geo_slice = pv_spheres_cropped.slice(normal=normal, origin=plane_center)
print(f"Geometry and slicing preparation complete in {time.time() - start_time_geo:.2f}s")

# ----------------------------------------------------
# Step 2: Vector Field Glyphs (Arrows) - YZ plane
# ----------------------------------------------------
start_time_vectors = time.time()

# Select vectors near the X_SLICE plane
vector_mask = np.abs(points_ds[:, 0] - X_SLICE) < THICKNESS
points_slice_full = points_ds[vector_mask]
vectors_slice_full = vectors_ds[vector_mask]
magnitudes_slice_full = magnitudes_ds[vector_mask]
if len(points_slice_full) == 0:
    raise ValueError(
        f"No field points within THICKNESS={THICKNESS} of X_SLICE={X_SLICE}; "
        "increase THICKNESS or move the slice."
    )

unique_indices = voxel_downsample_points(points_slice_full, ARROW_VOXEL_SPACING)

points_slice = points_slice_full[unique_indices]
vectors_slice = vectors_slice_full[unique_indices]
magnitudes_slice = magnitudes_slice_full[unique_indices]

# Cap the scaling magnitude so no arrow is longer than half a voxel.
MAGNITUDE_MAX_CLAMP = ARROW_VOXEL_SPACING / VECTOR_SCALE_FACTOR / 2
magnitudes_slice_clamped = np.clip(magnitudes_slice, a_min=None, a_max=MAGNITUDE_MAX_CLAMP)

# Offset arrows slightly in X direction and zero out X component of vectors
points_slice[:, 0] = X_SLICE + 2*THICKNESS
vectors_slice[:, 0] = 0.0  # Zero out X component for YZ plane visualization

slice_mesh_vectors = pv.PolyData(points_slice)
slice_mesh_vectors['vectors'] = vectors_slice
slice_mesh_vectors['magnitude'] = magnitudes_slice_clamped

print(f"Points in vector slice (after density control): {len(points_slice)}, old length: {len(points_slice_full)}...")

arrow = pv.Arrow(tip_length=0.3, tip_radius=0.2, shaft_radius=0.04)
glyphs = slice_mesh_vectors.glyph(
    orient='vectors',
    scale='magnitude',
    factor=VECTOR_SCALE_FACTOR,
    geom=arrow
)
print(f"Vector glyphs built in {time.time() - start_time_vectors:.2f}s")

In [None]:
# ----------------------------------------------------
# Step 3: Visualization
# ----------------------------------------------------
# Renders the most recently built slice (field_slice_interpolated, geo_slice,
# glyphs) with the cropped geometry, viewed along X, and saves a screenshot.
pl = pv.Plotter()
pl.set_background('white')

pl.add_mesh(
    pv_spheres_cropped,
    cmap="seismic",
    opacity=0.6,
    show_edges=False,
    clim=[vmin, vmax],
    interpolate_before_map=False,
    preference="cell"
)

# Field magnitude |E| interpolated on the slice plane.
pl.add_mesh(
    field_slice_interpolated,
    scalars="E_mag",
    cmap="YlGnBu",
    opacity=1,
    show_edges=False,
    clim=[vmin, vmax],
    scalar_bar_args={
        'title': None,
        'vertical': False,
        'position_x': 0.20,
        'position_y': 0.12,
        'width': 0.6,
        'height': 0.05,
    }
)

pl.add_mesh(geo_slice, color="black", line_width=3, opacity=1)
pl.add_mesh(glyphs, color='black', show_scalar_bar=False, line_width=5, opacity=1)

# Optional orthographic 2D view: pl.enable_parallel_projection(); pl.enable_2d_style()
pl.view_yz()

# The screenshot is written into figures/; create it so a fresh checkout does not fail.
os.makedirs('figures', exist_ok=True)
pl.screenshot(f'figures/wang_{configIN}#{iteration_SW}.jpeg', scale=4)

print(f"Total execution time: {time.time() - start_time:.2f}s")
pl.show()

In [None]:

# ----------------------------------------------------
# Step 3: Visualization
# ----------------------------------------------------
# Orthographic top-down (XY) view of the most recently built slice, coloured
# by field magnitude, with the geometry outline and arrow glyphs.
# NOTE(review): renders whatever slice cell ran last — run the Z-slice cell
# immediately before this one.
pl = pv.Plotter()
pl.set_background('white')

# The scalars are the field magnitude |E|, so the colorbar is titled |E|, not Ez.
pl.add_mesh(
    field_slice_interpolated,
    scalars="E_mag",
    cmap="YlGnBu",
    opacity=1,
    show_edges=False,
    clim=[0, vmax],
    scalar_bar_args={
        'title': '|E| (V/m)',
        'vertical': False,
        'position_x': 0.20,
        'position_y': 0.12,
        'width': 0.6,
        'height': 0.05,
    }
)

pl.add_mesh(
    pv_spheres_cropped,
    cmap="seismic",
    opacity=0.6,
    show_edges=False,
    clim=[0, vmax],
    interpolate_before_map=False,
    preference="cell"
)

pl.add_mesh(geo_slice, color="black", line_width=3, opacity=0.5)
pl.add_mesh(glyphs, color='black', show_scalar_bar=False, line_width=4, opacity=1)

pl.enable_parallel_projection()
pl.enable_2d_style()
pl.view_xy()

# To save: pl.screenshot(f'figures/fieldvectors_Z_{configIN}#{iteration_SW}.jpeg', scale=4)

print(f"Total execution time: {time.time() - start_time:.2f}s")
pl.show()

In [None]:

# Outline of the cropped geometry cut by a horizontal (Z-normal) plane 0.03
# above the geometry centroid. The slice height is a Z coordinate
# (centroid[2], normal [0, 0, 1]), so it is named accordingly.
Z_OUTLINE_SLICE = 0.03 + geometry.centroid[2]

normal = [0, 0, 1]

plane_center = [
    (bbox_bounds[0] + bbox_bounds[1]) / 2,  # X center
    (bbox_bounds[2] + bbox_bounds[3]) / 2,  # Y center
    Z_OUTLINE_SLICE,                        # Z at slice location
]

# Slice geometry at the same location
geo_slice = pv_spheres_cropped.slice(normal=normal, origin=plane_center)

pl = pv.Plotter()
pl.set_background('white')

pl.add_mesh(geo_slice, color="black", line_width=3, opacity=0.5)

pl.show()

In [None]:
# Inspect the mesh's bounding box (trimesh `bounding_box` primitive).
# NOTE(review): confirm `.values` exists on this primitive in the installed
# trimesh version; `geometry.bounds` may be what is wanted.
geometry.bounding_box.values

In [None]:
## 3D VISUALIZATION OF GEOMETRY AND ELECTRIC FIELD ##
# Draws the cropped geometry as a translucent wireframe-shaded surface plus the
# (downsampled) field points coloured by |E|, in an isometric 3D view.

fieldIN = df_SW[iteration_SW]

N_DOWNSAMPLE_EMAG = 10  # stride downsampling for 3D performance (same factor as the slice cells)
POINT_SIZE = 3  # rendered point size
OPACITY = 0.6  # opacity of the field points

vmin, vmax = (1e3, 1e6)  # colour limits for E_mag (applied on a linear scale since log_scale=False)

# ----------------------------------------------------
# Step 0: Load and Filter Data
# ----------------------------------------------------
start_time = time.time()
points = fieldIN["pos"]
vectors = fieldIN["E"]
magnitudes = fieldIN["E_mag"]

# Filter by magnitude (keep only non-zero field points)
initial_mask = (magnitudes > 0)
points = points[initial_mask]
vectors = vectors[initial_mask]
magnitudes = magnitudes[initial_mask]

# Downsample for performance (every N_DOWNSAMPLE_EMAG-th point)
points_ds = points[::N_DOWNSAMPLE_EMAG]
vectors_ds = vectors[::N_DOWNSAMPLE_EMAG]
magnitudes_ds = magnitudes[::N_DOWNSAMPLE_EMAG]

# Create point cloud carrying |E| and the full field vector
point_cloud = pv.PolyData(points_ds)
point_cloud["E_mag"] = magnitudes_ds
point_cloud["E_vec"] = vectors_ds

print(f"Total points after filtering and downsampling: {len(points_ds)}")

# ----------------------------------------------------
# Step 1: Prepare Geometry
# ----------------------------------------------------
# trimesh faces -> VTK face array: each triangle prefixed by its vertex count (3).
pv_spheres = pv.PolyData(
    geometry.vertices,
    np.hstack([np.full((len(geometry.faces), 1), 3), geometry.faces])
).compute_normals()

# Crop geometry to the bounds of the (downsampled) field data
bbox_bounds = point_cloud.bounds
bbox = pv.Box(bounds=bbox_bounds)
pv_spheres_cropped = pv_spheres.clip_box(bbox, invert=False)

print(f"Geometry preparation complete")

# ----------------------------------------------------
# Step 2: 3D Visualization
# ----------------------------------------------------
pl = pv.Plotter()
pl.set_background('white')

# Add geometry (semi-transparent, with edges)
pl.add_mesh(
    pv_spheres_cropped,
    color="lightgray",
    opacity=0.3,
    show_edges=True,
    edge_color="black",
    line_width=1
)

# Add electric field magnitude as colored points
pl.add_mesh(
    point_cloud,
    scalars="E_mag",
    cmap="hot",  # or "viridis", "plasma", "turbo"
    opacity=OPACITY,
    point_size=POINT_SIZE,
    render_points_as_spheres=True,
    clim=[vmin, vmax],
    log_scale=False,  # Set to True if you want logarithmic color scale
    scalar_bar_args={
        'title': 'E_mag (V/m)',
        'vertical': True,
        'position_x': 0.85,
        'position_y': 0.15,
        'width': 0.05,
        'height': 0.7,
    }
)

# # Optional: Add vector glyphs (use sparingly, they're expensive)
# # Downsample even more for arrows in 3D
# ARROW_DOWNSAMPLE = 100
# points_arrows = points_ds[::ARROW_DOWNSAMPLE]
# vectors_arrows = vectors_ds[::ARROW_DOWNSAMPLE]
# magnitudes_arrows = magnitudes_ds[::ARROW_DOWNSAMPLE]

# arrow_mesh = pv.PolyData(points_arrows)
# arrow_mesh['vectors'] = vectors_arrows
# arrow_mesh['magnitude'] = magnitudes_arrows

# # Normalize arrows for visualization
# arrow_scale = 5e-6  # Adjust this for arrow length
# arrows_3d = arrow_mesh.glyph(
#     orient='vectors',
#     scale='magnitude',
#     factor=arrow_scale,
#     geom=pv.Arrow(tip_length=0.25, tip_radius=0.1, shaft_radius=0.03)
# )

# pl.add_mesh(arrows_3d, color='black', opacity=0.8)

# Set up 3D camera view
pl.camera_position = 'iso'  # isometric view
# Alternatively: pl.camera_position = [(x, y, z), (focal_x, focal_y, focal_z), (up_x, up_y, up_z)]

# Add axes for orientation
pl.add_axes()

# Trackball-style mouse interaction
pl.enable_trackball_style()

# Screenshot (optional)
# pl.screenshot(f'figures/fieldvectors_3D_{configIN}#{iteration_SW}.jpeg', scale=4)

print(f"Total execution time: {time.time() - start_time:.2f}s")
pl.show()

In [None]:
# ----------------------------------------------------
# Step 2: 3D Visualization
# ----------------------------------------------------
# Re-plot the already-built geometry and field point cloud in 3D.
pl = pv.Plotter()
pl.set_background('white')

# Add geometry (semi-transparent, with edges)
pl.add_mesh(
    pv_spheres_cropped,
    color="lightgray",
    opacity=0.3,
    show_edges=True,
    edge_color="black",
    line_width=1
)

# Add electric field magnitude as colored points
pl.add_mesh(
    point_cloud,
    scalars="E_mag",
    cmap="hot",  # or "viridis", "plasma", "turbo"
    opacity=OPACITY,
    point_size=POINT_SIZE,
    render_points_as_spheres=True,
    clim=[vmin, vmax],
    log_scale=False,  # Set to True if you want logarithmic color scale
    scalar_bar_args={
        'title': 'E_mag (V/m)',
        'vertical': True,
        'position_x': 0.85,
        'position_y': 0.15,
        'width': 0.05,
        'height': 0.7,
    }
)

# Without show() the plotter is built but never rendered; the cell would only
# echo the repr of the last actor.
pl.camera_position = 'iso'
pl.add_axes()
pl.show()

In [None]:
## SETTINGS HERE ARE OPTIMIZED FOR ITERATION 86 ##
# ZX-slice (normal along Y) variant: the slice is coloured by Ex and overlaid
# with in-plane arrows and a red marker sphere.
# NOTE(review): iteration_SW is assigned in an earlier cell (80 in the notebook
# header) — confirm these settings still suit the iteration loaded.

fieldIN = df_SW[iteration_SW]

center = geometry.centroid
N_DOWNSAMPLE_EMAG = 1 # stride downsampling; 1 keeps every point
ARROW_VOXEL_SPACING = 0.02 # voxel edge for thinning arrows; also caps arrow length
Y_SLICE = 0.0 + center[1] # slice position: centroid Y plus a 0.0 offset
THICKNESS = 0.001 # arrows come from the slab |y - Y_SLICE| < THICKNESS
VECTOR_SCALE_FACTOR = 1e-6 #2e-7 #2e-3 #5e-6 #2e-3 #5e-6 # Global scaling for glyphs
FIELD_AVERAGE_RADIUS = 2.5e-3 #2e-3  (radius of the red marker sphere)

vmin, vmax = (-2e5, 2e5) # colour limits for the Ex slice (linear field component, not log units)
red_point = np.array([-0.1, 0, 0.1 - 0.015 + 0.037]) # marker centre, absolute coordinates (not centroid-relative)

# ----------------------------------------------------
# Voxel Downsampling Helper Function
# Ensures uniform spatial distribution of points in the slice
# ----------------------------------------------------
def voxel_downsample_points(points, spacing, axes=(0, 2)):
    """
    Select one point per 2D voxel (square cell of side ``spacing``).

    Points are binned on the two coordinate columns named by ``axes``
    (default X and Z, the in-plane axes of a ZX slice); any other coordinate
    is ignored. For each occupied cell the index of its first point (in input
    order) is kept.

    Parameters
    ----------
    points : array-like, shape (N, D)
        Point coordinates.
    spacing : float
        Voxel edge length, in the units of ``points``.
    axes : tuple of two ints, optional
        Columns ``(a, b)`` used for binning. Default ``(0, 2)``.

    Returns
    -------
    numpy.ndarray of int
        Indices into ``points``, ordered by cell (row-major over ``(b, a)``).
        Empty when ``points`` is empty.
    """
    points = np.asarray(points)
    if len(points) == 0:
        # min() on an empty array raises; there is nothing to keep anyway.
        return np.empty(0, dtype=np.intp)

    a_coords = points[:, axes[0]]
    b_coords = points[:, axes[1]]

    # int64 so the combined key cannot overflow where the default int is 32-bit.
    a_idx = np.floor((a_coords - a_coords.min()) / spacing).astype(np.int64)
    b_idx = np.floor((b_coords - b_coords.min()) / spacing).astype(np.int64)

    # One integer key per (b, a) cell; a_idx < n_a, so keys never collide.
    n_a = a_idx.max() + 1
    voxel_keys = b_idx * n_a + a_idx

    # return_index yields the first occurrence of each unique key.
    _, unique_indices = np.unique(voxel_keys, return_index=True)
    return unique_indices

# ----------------------------------------------------
# Step 0: Load, Filter, and Downsample Data (Single Pass)
# ----------------------------------------------------
start_time = time.time()
points = fieldIN["pos"]
vectors = fieldIN["E"]
magnitudes = fieldIN["E_mag"]

# Apply initial filtering (z > 0 and magnitude > 0)
initial_mask = (points[:, 2] > 0) & (magnitudes > 0)
points = points[initial_mask]
vectors = vectors[initial_mask]
magnitudes = magnitudes[initial_mask]

# Stride downsampling (N_DOWNSAMPLE_EMAG = 1 keeps every point)
points_ds = points[::N_DOWNSAMPLE_EMAG]
vectors_ds = vectors[::N_DOWNSAMPLE_EMAG]
magnitudes_ds = magnitudes[::N_DOWNSAMPLE_EMAG]

# Point cloud carrying |E| and the in-plane components Ex, Ez of the ZX slice
point_cloud = pv.PolyData(points_ds)
point_cloud["E_mag"] = magnitudes_ds
point_cloud["Ex_val"] = vectors_ds[:,0]
point_cloud["Ez_val"] = vectors_ds[:,2]

print(f"Starting points (filtered by z > 0 & mag > 0): {len(points)}")

# ----------------------------------------------------
# Step 1: Geometry Setup and Slicing
# ----------------------------------------------------
start_time_geo = time.time()

# 1a. Load and crop geometry (trimesh faces -> VTK face array with a leading vertex count of 3)
pv_spheres = pv.PolyData(
    geometry.vertices,
    np.hstack([np.full((len(geometry.faces), 1), 3), geometry.faces])
).compute_normals()

# Define bounding box based on the downsampled field data
bbox_bounds = point_cloud.bounds
bbox = pv.Box(bounds=bbox_bounds)
pv_spheres_cropped = pv_spheres.clip_box(bbox, invert=False)

# 1b. ZX plane (normal along Y) through the centroid in X/Z, at Y = Y_SLICE.
# Using Y_SLICE here rather than center[1] keeps the interpolation plane, the
# geometry outline and the arrow slab at the same Y if Y_SLICE is offset.
normal = [0, 1, 0]
slice_origin = np.array([center[0], Y_SLICE, center[2]])

# For direction=[0, 1, 0], pv.Plane lays i along Z and j along X.
field_slice_mesh = pv.Plane(
    center=slice_origin,
    direction=normal,
    j_size=bbox_bounds[1] - bbox_bounds[0], # X span
    i_size=bbox_bounds[5] - bbox_bounds[4], # Z span
    i_resolution=250,
    j_resolution=250
)

# Gaussian-kernel interpolation within `radius`; plane points with no data in
# range fall back to the closest data point (strategy='closest_point').
field_slice_interpolated = field_slice_mesh.interpolate(
    point_cloud,
    sharpness=3.0,
    radius=0.001,
    null_value=1,
    strategy='closest_point'
)

field_slice_interpolated.points[:, 1] = Y_SLICE

# Geometry outline at the same plane
geo_slice = pv_spheres_cropped.slice(normal=normal, origin=slice_origin)
print(f"Geometry and slicing preparation complete in {time.time() - start_time_geo:.2f}s")

# ----------------------------------------------------
# Step 2: Vector Field Glyphs (Arrows)
# ----------------------------------------------------
start_time_vectors = time.time()

# 2a. Points of the downsampled data inside the slab |y - Y_SLICE| < THICKNESS
vector_mask = np.abs(points_ds[:, 1] - Y_SLICE) < THICKNESS
points_slice_full = points_ds[vector_mask]
vectors_slice_full = vectors_ds[vector_mask]
magnitudes_slice_full = magnitudes_ds[vector_mask]
if len(points_slice_full) == 0:
    raise ValueError(
        f"No field points within THICKNESS={THICKNESS} of Y_SLICE={Y_SLICE}; "
        "increase THICKNESS or move the slice."
    )

# 2b. Voxel downsampling for a uniform arrow density
unique_indices = voxel_downsample_points(points_slice_full, ARROW_VOXEL_SPACING)

points_slice = points_slice_full[unique_indices]
vectors_slice = vectors_slice_full[unique_indices]
magnitudes_slice = magnitudes_slice_full[unique_indices]

# Clamp the scaling magnitude. Glyph length = magnitude * VECTOR_SCALE_FACTOR
# (pv.Arrow has unit length), so capping the magnitude at
# ARROW_VOXEL_SPACING / VECTOR_SCALE_FACTOR / 2 keeps every arrow within half a voxel.
MAGNITUDE_MAX_CLAMP = ARROW_VOXEL_SPACING / VECTOR_SCALE_FACTOR /2
magnitudes_slice_clamped = np.clip(magnitudes_slice, a_min=None, a_max=MAGNITUDE_MAX_CLAMP)

# 2c. Glyph input: move arrows 2*THICKNESS in -Y and (almost) flatten their Y component
points_slice[:,1] = Y_SLICE - 2*THICKNESS
vectors_slice[:,1] = 0.0 - 2* THICKNESS
slice_mesh_vectors = pv.PolyData(points_slice)
slice_mesh_vectors['vectors'] = vectors_slice
# Use the CLAMPED magnitude array for scaling
slice_mesh_vectors['magnitude'] = magnitudes_slice_clamped

print(f"Points in vector slice (after density control): {len(points_slice)}, old length: {len(points_slice_full)}...")

# 2d. Create the glyphs
arrow = pv.Arrow(tip_length=0.3, tip_radius=0.2, shaft_radius=0.04)
glyphs = slice_mesh_vectors.glyph(
    orient='vectors',
    scale='magnitude',
    factor=VECTOR_SCALE_FACTOR,
    geom=arrow
)
print(f"Vector glyphs built in {time.time() - start_time_vectors:.2f}s")

# ----------------------------------------------------
# Step 3: Visualization
# ----------------------------------------------------
pl = pv.Plotter()
pl.set_background('white')

# Interpolated Ex on the slice, with a horizontal colorbar below the plot
pl.add_mesh(
    field_slice_interpolated,
    scalars="Ex_val",
    cmap="YlGnBu",
    opacity=1,
    show_edges=False,
    clim=[vmin, vmax],
    scalar_bar_args={
        'title':None,
        'vertical': False,
        'position_x': 0.20,
        'position_y': 0.12,
        'width': 0.6,
        'height': 0.05,
    }
)

# Add sliced geometry (outline only)
pl.add_mesh(geo_slice, color="black", line_width=5,opacity=0.5)

# Add vector glyphs
pl.add_mesh(glyphs, color='black', show_scalar_bar=False, line_width=4,opacity=1)

# Marker sphere at red_point
sphere = pv.Sphere(radius=FIELD_AVERAGE_RADIUS, center=red_point)
pl.add_mesh(sphere, color="red", opacity=1)

# Orthographic projection, camera perpendicular to the ZX slice
pl.enable_parallel_projection()
pl.enable_2d_style()
pl.view_xz()

# To save: pl.screenshot(f'figures/fieldvectors_{configIN}#{iteration_SW}.jpeg', scale=4)

pl.show()

In [None]:
# -----------------------------------------
# Your existing setup code here...
# -----------------------------------------
# Surface electric pressure on a mesh reconstructed from a vertex cloud:
# interpolate E onto face centers, form eps0*((E.n)^2 - |E|^2/2), project it
# onto x and plot it on the cropped surface.
# NOTE(review): `df`, `iteration`, `stacked_spheres_centroid` and
# `stacked_spheres` are not defined in any cell visible here (other cells use
# df_SW / df_PE and iteration_SW / iteration_PE). On a fresh kernel this cell
# raises NameError until they are bound.
fieldIN = df[iteration]
FIELD_AVERAGE_RADIUS = 2.5e-3
# Colour limits for the pressure plot. These are module-level, so later cells
# that reuse vmin/vmax inherit them.
vmin, vmax = (-0.005, 0.005)

geometry_center = stacked_spheres_centroid  # replace with your centroid
# NOTE(review): red_point is not used anywhere else in this cell.
red_point = np.array([-0.1, 0., 0.1 + 0.037]) + geometry_center

# ----------------------------------------------------
# Step 0: Filter data
# ----------------------------------------------------
start_time = time.time()
points = fieldIN["pos"]
vectors = fieldIN["E"]
magnitudes = fieldIN["E_mag"]

# Keep only points with a non-zero field magnitude.
initial_mask = (magnitudes > 0)
points = points[initial_mask]
vectors = vectors[initial_mask]
magnitudes = magnitudes[initial_mask]

# Vacuum permittivity (F/m). Shadows the epsilon_0 imported from
# scipy.constants in the first cell.
epsilon_0 = 8.854187817e-12

# Field components stored as point data so they can be interpolated onto the surface
field_cloud = pv.PolyData(points)
field_cloud["E_x"] = vectors[:, 0]
field_cloud["E_y"] = vectors[:, 1]
field_cloud["E_z"] = vectors[:, 2]

print(f"Starting points (filtered): {len(points)}")

# ----------------------------------------------------
# Step 1: Geometry Setup from vertices
# ----------------------------------------------------
start_time_geo = time.time()

# Convert the vertex list to an array (presumably (N, 3) — TODO confirm)
vertices = np.array(stacked_spheres)  # replace stacked_spheres with your vertex list

# Create a point cloud
cloud = pv.PolyData(vertices)

# Reconstruct a surface mesh from the bare vertices
# (delaunay_2d would only suit a roughly planar cloud)
print("before the construction")
pv_spheres = cloud.reconstruct_surface(nbr_neighbors=10)  # safer for 3D shapes
print("finished the construction")
pv_spheres.compute_normals(inplace=True)

# Crop the mesh to the field bounds, re-extract it as a surface, and recompute
# point and cell normals on the cropped result
bbox_bounds = field_cloud.bounds
bbox = pv.Box(bounds=bbox_bounds)
pv_spheres_cropped = (
    pv_spheres
    .clip_box(bbox, invert=False)
    .extract_surface()
    .compute_normals(point_normals=True, cell_normals=True, inplace=False)
)

# ============================================================
# Interpolate field to face centers
# ============================================================
# Gaussian-kernel interpolation within radius=0.002. Face centers with no
# field point in range fall back to the closest one (strategy='closest_point').
face_centers = pv_spheres_cropped.cell_centers().points
face_field_cloud = pv.PolyData(face_centers)
face_field_interp = face_field_cloud.interpolate(
    field_cloud,
    radius=0.002,
    strategy='closest_point',
    sharpness=3.0,
    null_value=0.0
)

# Extract face-centered E-fields as an (n_faces, 3) array
E_x_faces = face_field_interp["E_x"]
E_y_faces = face_field_interp["E_y"]
E_z_faces = face_field_interp["E_z"]
E_vec_faces = np.stack([E_x_faces, E_y_faces, E_z_faces], axis=1)

# Face normals
face_normals = pv_spheres_cropped.cell_normals
nx = face_normals[:, 0]   # x-direction component

# ============================================================
# Compute Maxwell electric pressure (normal)
# ============================================================
# Row-wise dot products: E.n and |E|^2 per face
E_dot_n = np.einsum('ij,ij->i', E_vec_faces, face_normals)
E_mag_sq = np.einsum('ij,ij->i', E_vec_faces, E_vec_faces)

# Normal-normal component of the Maxwell stress, n.T.n = eps0*(E_n^2 - |E|^2/2)
P_normal_faces = epsilon_0 * (E_dot_n**2 - 0.5 * E_mag_sq)

# ============================================================
# Compute X-directed electric pressure
# ============================================================
# x-projection of the normal stress only. The tangential traction
# eps0*(E.n)*E_t is not included.
# NOTE(review): confirm this is intended rather than the full traction (T.n)_x.
P_x_faces = P_normal_faces * nx
pv_spheres_cropped.cell_data["electric_pressure_x"] = P_x_faces

print(f"Computed x-directed electric pressure in {time.time() - start_time_geo:.2f}s")

# ============================================================
# Plotting using Px
# ============================================================
pl = pv.Plotter()
pl.set_background('white')

pl.add_mesh(
    pv_spheres_cropped,
    scalars="electric_pressure_x",
    cmap="seismic",
    opacity=1,
    show_edges=False,
    clim=[vmin, vmax],
    interpolate_before_map=False,
    preference="cell"
)

pl.view_xy()
pl.show(jupyter_backend='static')
 
# ============================================================
# Extract Line Plot Data (y=0 top surface)
# ============================================================
# Face centers within y_tolerance of y = 0 and above z_min are binned in x
# (width x_bin_width). x, z and the x-directed pressure are averaged per bin.
# NOTE(review): y = 0 and z_min are absolute coordinates, not offsets from
# geometry_center — confirm for this geometry.
y_tolerance = 0.01
z_min = 0.08

x_centers = face_centers[:, 0]
y_centers = face_centers[:, 1]
z_centers = face_centers[:, 2]

Px_faces = P_x_faces  # shorthand

line_mask = (np.abs(y_centers) < y_tolerance) & (z_centers > z_min)
if not np.any(line_mask):
    raise ValueError(
        f"No faces with |y| < {y_tolerance} and z > {z_min}; "
        "adjust y_tolerance / z_min for this geometry."
    )

x_line = x_centers[line_mask]
z_line = z_centers[line_mask]
Px_line = Px_faces[line_mask]

# Bin + average. Edges are built from the bin count so that at least one bin
# always exists; np.arange returned a single edge (zero bins) when every point
# shared one x.
x_bin_width = 0.005
x_min, x_max = x_line.min(), x_line.max()
n_bins = max(int(np.ceil((x_max - x_min) / x_bin_width)), 1)
x_bins = x_min + x_bin_width * np.arange(n_bins + 1)
x_bin_centers = (x_bins[:-1] + x_bins[1:]) / 2

# np.digitize returns len(x_bins) for a point sitting exactly on the last edge.
# Clipping keeps such points in the last bin instead of silently dropping them.
bin_indices = np.clip(np.digitize(x_line, x_bins), 1, n_bins)

x_line_avg, z_line_avg, Px_line_avg = [], [], []

for i in range(1, n_bins + 1):
    mask = (bin_indices == i)
    if mask.any():
        x_line_avg.append(x_line[mask].mean())
        z_line_avg.append(z_line[mask].mean())
        Px_line_avg.append(Px_line[mask].mean())

# Sort by x (bins are already visited in increasing x; this is a safeguard)
x_line_sorted_PE = np.array(x_line_avg)
z_line_sorted_PE = np.array(z_line_avg)
pressure_line_sorted_PE = np.array(Px_line_avg)

sort_idx = np.argsort(x_line_sorted_PE)
x_line_sorted_PE = x_line_sorted_PE[sort_idx]
z_line_sorted_PE = z_line_sorted_PE[sort_idx]
pressure_line_sorted_PE = pressure_line_sorted_PE[sort_idx]

print(f"Extracted {len(x_line)} raw points, averaged into {len(x_line_sorted_PE)} bins along y=0 line")

In [None]:
# ----------------------------------------------------
# Step 3: Visualization
# ----------------------------------------------------
# Ex on the ZX slice with the geometry outline and arrows, a grid of red
# averaging spheres around the probe location, and a reference circle.

# Colour range for Ex is set here. Otherwise the plot would inherit whichever
# cell last assigned vmin/vmax; the pressure cell sets them to ±0.005, which
# saturates an Ex map of order 1e5.
EX_CLIM = (-2e5, 2e5)

pl = pv.Plotter()
pl.set_background('white')

# Interpolated Ex on the slice, with a horizontal colorbar below the plot
pl.add_mesh(
    field_slice_interpolated,
    scalars="Ex_val",
    cmap="YlGnBu",
    opacity=1,
    show_edges=False,
    clim=list(EX_CLIM),
    scalar_bar_args={
        'title':None,
        'vertical': False,
        'position_x': 0.20,
        'position_y': 0.12,
        'width': 0.6,
        'height': 0.05,
    }
)

# Add sliced geometry (outline only)
pl.add_mesh(geo_slice, color="black", line_width=5,opacity=0.5)

# Add vector glyphs
pl.add_mesh(glyphs, color='black', show_scalar_bar=False, line_width=4,opacity=1)

# Grid of field-averaging spheres: offsets around x = -0.1 and z = 0.122 in the y = 0 plane
FIELD_AVERAGE_RADIUS = 2e-3

offsetLimitsX = 0.009
new_step_x = offsetLimitsX/3

offsetLimitsY = 0.011
new_step_y = offsetLimitsY/3

xoffset = np.round(np.arange(-offsetLimitsX, offsetLimitsX, new_step_x),4)
yoffset = np.round(np.arange(-offsetLimitsY, offsetLimitsY, new_step_y),4)

X, Y = np.meshgrid(xoffset, yoffset)
target_points_array = np.vstack([
    -0.1 - X.flatten(),
    np.zeros(len(X.flatten())),
    0.1 - 0.015 + 0.037 - Y.flatten()
]).T

print(f"Processing {len(target_points_array)} target points with radius {FIELD_AVERAGE_RADIUS} mm")
for target_point in target_points_array:
    sphere = pv.Sphere(radius=FIELD_AVERAGE_RADIUS, center=target_point)
    pl.add_mesh(sphere, color="red", opacity=1)

# Reference circle in the XZ plane around the probe location. Dedicated names
# are used so the field `points` and the centroid `center` globals are not clobbered.
red_point = np.array([-0.1, 0, 0.1 - 0.015 + 0.037])
radius_mm = 40/1000/2   # 0.02 model units (20 µm if the model is in mm, as the print above assumes)
circle_center = red_point

theta = np.linspace(0, 2*np.pi, 100)
circle_points = np.column_stack((
    circle_center[0] + radius_mm * np.cos(theta),
    np.full_like(theta, circle_center[1]),   # constant y: circle lies in the XZ plane
    circle_center[2] + radius_mm * np.sin(theta),
))
polyline = pv.PolyData(circle_points)
pl.add_mesh(polyline, color='k', point_size=0.5, opacity=0.8)

# Orthographic projection, camera perpendicular to the ZX slice
pl.enable_parallel_projection()
pl.enable_2d_style()
pl.view_xz()

# To save: pl.screenshot(f'figures/fieldvectors_{configIN}#{iteration_SW}.jpeg', scale=4)

print(f"Total execution time: {time.time() - start_time:.2f}s")
pl.show(jupyter_backend='static')

In [None]:
# Field-averaging radius and the grid of probe points it is applied to.
FIELD_AVERAGE_RADIUS = 2e-3

# Half-width of the sampling window per axis; step = one third of it.
offsetLimitsX, offsetLimitsY = 0.009, 0.011
new_step_x, new_step_y = offsetLimitsX / 3, offsetLimitsY / 3

# NOTE: np.arange excludes the stop value, so the grid may not reach
# +offsetLimits* (the window is then not symmetric about the centre).
xoffset = np.round(np.arange(-offsetLimitsX, offsetLimitsX, new_step_x), 4)
yoffset = np.round(np.arange(-offsetLimitsY, offsetLimitsY, new_step_y), 4)
X, Y = np.meshgrid(xoffset, yoffset)

# One row per (x, z) probe point; y is fixed at 0 and the Y offsets act on z.
grid_dx = X.flatten()
grid_dz = Y.flatten()
target_points_array = np.column_stack((
    -0.1 - grid_dx,
    np.zeros(grid_dx.size),
    (0.1 - 0.015 + 0.037) - grid_dz,
))

print(f"Processing {len(target_points_array)} target points with radius {FIELD_AVERAGE_RADIUS} mm")

In [None]:
offsetLimitsX = 0.009
new_step_x = offsetLimitsX/3

# np.arange(start, stop, step) excludes `stop`. Extending the stop value by
# half a step guarantees that the upper limit (+0.009) is included despite
# floating-point rounding, giving a grid that is symmetric about zero.
xoffset = np.round(np.arange(-offsetLimitsX, offsetLimitsX + new_step_x / 2, new_step_x), 4)

In [None]:
# Inspect the x-offset grid from the most recently run grid cell. np.arange
# excludes its stop value, so check whether +offsetLimitsX is the last element.
xoffset

In [None]:
# Inspect the y-offset grid; in target_points_array these offsets are
# subtracted from the probe's z coordinate (0.1 - 0.015 + 0.037 - Y).
yoffset

In [None]:
# Stand-alone horizontal colorbar for the field magnitude |E| in V/m.
# text.usetex=False keeps matplotlib's built-in mathtext (no external LaTeX);
# both rcParams changes are global and persist for later cells.
mpl.rcParams.update({'text.usetex': False, 'font.size': 16})

# Linear colour scale from 0 to 2e5 V/m (Normalize is linear, not logarithmic).
cmap = plt.cm.YlGnBu
vmin, vmax = 0, 2e5
norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)

# A 10 x 0.1 inch figure: the single axes serves only as the colorbar strip.
fig, ax = plt.subplots(figsize=(10, 0.1))

# Tick formatter: scientific notation for every value (power limits (0, 0)),
# exponent rendered with mathtext, and no additive offset.
formatter = ticker.ScalarFormatter(useMathText=True)
formatter.set_useOffset(False)
formatter.set_powerlimits((0, 0))

cb = mpl.colorbar.ColorbarBase(
    ax,
    cmap=cmap,
    norm=norm,
    orientation='horizontal',
    label=r"|E| (V/m)",
)
# Horizontal colorbar -> the value axis is the x-axis.
cb.ax.xaxis.set_major_formatter(formatter)

plt.show()

In [None]:
# Plain mathtext (no external LaTeX) and a 14 pt global font; these rcParams
# changes are global and persist for later cells.
mpl.rcParams.update({'text.usetex': False, 'font.size': 14})

# --- 1. CONFIGURATION AND DATA LOADING ---
print("\n--- Starting Data Processing and Plot Generation ---")

# Keys into `all_processed`: the run the cubic is fitted to, the PE run the
# fit is compared against, and the SW run benchmarked against the literature.
KEY_FIT = 'PE_425K_initial8max0.8final12_noDissipation_sphere50um_pos-0.1'
KEY_TARGET_PE = 'PE_425K_initial8max0.8final12_RefinedGridDissipation_sphere40um_pos-0.1'
KEY_TARGET_SW = 'SW_425K_initial8max0.8final12_RefinedGridDissipation_sphere20um_pos-0.1'

# Digitised literature curves (Zimmerman, Fig. 7a): SW only, PE only, PE+SW.
zimmerman_SWdata = pd.read_csv("literature-data/Fig7a-SW.csv")
zimmerman_PEdata = pd.read_csv("literature-data/Fig7a-PE.csv")
zimmerman_PEandSWdata = pd.read_csv("literature-data/Fig7a-PE+SW.csv")

# Iteration -> time conversion:
#   s/iteration = (particles per iteration / world XY area) / particle flux,
#   particle flux = current density (presumably A/m^2 -- confirm) * charges per coulomb.
WORLD_XY_AREA_SQ_M = 300 * 300 / (1e6**2)  # 300 x 300 in units of 1e-6 m, expressed in m^2
ELEMENTARY_CHARGES_PER_COULOMB = 6.241509e18  # 1 / e

PARTICLES_PER_ITERATION_PE = 81775
FLUX_PER_ITERATION_PE = PARTICLES_PER_ITERATION_PE / WORLD_XY_AREA_SQ_M
PE_ION_FLUX = 4e-6 * ELEMENTARY_CHARGES_PER_COULOMB
CONVERT_ITERATION_PE_TIME = FLUX_PER_ITERATION_PE / PE_ION_FLUX
print(f"PE Conversion Factor (s/iteration): {CONVERT_ITERATION_PE_TIME:.3e}")

PARTICLES_PER_ITERATION_SW = 30601
FLUX_PER_ITERATION_SW = PARTICLES_PER_ITERATION_SW / WORLD_XY_AREA_SQ_M
SW_ION_FLUX = 3e-7 * ELEMENTARY_CHARGES_PER_COULOMB
CONVERT_ITERATION_SW_TIME = FLUX_PER_ITERATION_SW / SW_ION_FLUX
print(f"SW Conversion Factor (s/iteration): {CONVERT_ITERATION_SW_TIME:.3e}")

# Discrete palette with one more colour than there are processed runs,
# sampled evenly across the colormap.
CMAP_NAME = 'Dark2'
n_palette = len(all_processed.keys()) + 1
discrete_cmap = plt.get_cmap(CMAP_NAME, n_palette)
color_list_rgba = [discrete_cmap(frac) for frac in np.linspace(0, 1, n_palette)]

# --- 2. CURVE FITTING AND EXTRAPOLATION ---

def poly_curve(t, a, b, c, d):
    """Evaluate the cubic polynomial a*t**3 + b*t**2 + c*t + d.

    Serves as the model function for ``curve_fit``; ``t`` may be a scalar or a
    NumPy array (evaluated element-wise).

    Parameters
    ----------
    t : float or numpy.ndarray
        Time value(s).
    a, b, c, d : float
        Coefficients of t^3, t^2, t and the constant term.

    Returns
    -------
    float or numpy.ndarray
        Polynomial value(s), same shape as ``t``.
    """
    cubic_term = a * t**3
    square_term = b * t**2
    linear_term = c * t
    return cubic_term + square_term + linear_term + d

# Prepare data for fitting: the KEY_FIT run supplies the points the cubic is
# fitted to. `iter - 1` turns iteration numbers into elapsed iterations
# (presumably iter is 1-based -- confirm), which are then converted to seconds.
data_fit = all_processed[KEY_FIT]
x_fit = np.array(data_fit["iter"] - 1) * CONVERT_ITERATION_PE_TIME
y_fit = abs(data_fit["E_vecs"][:, 0])  # |E_x|
y_fit_errors = data_fit["point_errors"][:, 0]

# Target data (PE run with refined-grid dissipation): the fit is evaluated on
# this run's time axis and compared against it.
data_target_pe = all_processed[KEY_TARGET_PE]
x_extrapolate = np.array(data_target_pe["iter"] - 1) * CONVERT_ITERATION_PE_TIME
y_target_pe = abs(data_target_pe["E_vecs"][:, 0])
y_target_pe_errors = data_target_pe["point_errors"][:, 0]

# --- Fitting ---
fit_case_label = KEY_FIT.split('_')[1]
method_label = f"Fit of {fit_case_label} (Poly Order 3)"
try:
    # Weighted least squares; absolute_sigma=True treats y_fit_errors as
    # absolute 1-sigma uncertainties, so pcov is not rescaled.
    popt, pcov = curve_fit(poly_curve, x_fit, y_fit,
                           p0=[0, 0, 0, 1e4], sigma=y_fit_errors, absolute_sigma=True)

    A_fit, B_fit, C_fit, D_fit = popt
    print(f"\nFit Parameters for {fit_case_label} (Poly): a={A_fit:.2e}, b={B_fit:.2e}, c={C_fit:.2e}, d={D_fit:.2e}")

    # Extrapolate onto the target time axis and take the difference.
    y_extrapolated = poly_curve(x_extrapolate, *popt)
    y_subtraction = y_extrapolated - y_target_pe

    # 1-sigma uncertainties of the fitted parameters.
    perr = np.sqrt(np.diag(pcov))

    # Uncertainty of the extrapolated curve by linear error propagation.
    # J is the Jacobian of poly_curve with respect to its PARAMETERS (a, b, c, d)
    # at each time t: d/da = t^3, d/db = t^2, d/dc = t, d/dd = 1.
    J = np.column_stack([x_extrapolate**3, x_extrapolate**2, x_extrapolate,
                         np.ones_like(x_extrapolate)])
    # var_i = J_i . pcov . J_i  (= diag(J @ pcov @ J.T) without the N x N matrix);
    # clip tiny negative round-off before taking the square root.
    y_extrapolated_var = np.einsum("ij,jk,ik->i", J, pcov, J)
    y_extrapolated_errors = np.sqrt(np.clip(y_extrapolated_var, 0.0, None))

except RuntimeError:
    # curve_fit raises RuntimeError when the least-squares minimisation fails.
    print("\n⚠️ Warning: Curve fitting failed. Falling back to zeros.")
    # Zeros on the TARGET time axis, so every array below stays aligned with
    # y_target_pe / y_target_pe_errors and with the literature interpolation.
    y_extrapolated = np.zeros_like(x_extrapolate)
    y_subtraction = np.zeros_like(x_extrapolate)
    y_extrapolated_errors = np.zeros_like(x_extrapolate)
    method_label = f"Fit of {fit_case_label} (Failed, showing Zeros)"

# --- Percent difference (fit - target) / fit ---
# Exact zeros in the denominator are replaced by 1e-10 to avoid division by zero.
y_denominator = np.where(y_extrapolated == 0, 1e-10, y_extrapolated)
y_percent_diff = (y_subtraction / y_denominator) * 100
y_subtraction_errors = np.sqrt(y_extrapolated_errors**2 + y_target_pe_errors**2)
y_percent_diff_errors = (y_subtraction_errors / np.abs(y_denominator)) * 100

# --------------------------------------------------------------------------
# --- 3. ZIMMERMAN COMPARISON CALCULATIONS ---
# --------------------------------------------------------------------------

# Literature curves. The digitised "x" column is exponentiated (10**x) to get
# time, i.e. it appears to hold log10(time); the value column header is " y"
# (with a leading space).
x_lit_pe = 10**zimmerman_PEdata["x"]
y_lit_pe = zimmerman_PEdata[" y"]
# Loaded unconditionally: the main plot always draws the SW reference curve,
# even when the SW simulation run is missing from all_processed.
x_lit_sw = 10**zimmerman_SWdata["x"]
y_lit_sw = zimmerman_SWdata[" y"]

# --- PE Simulation vs. Zimmerman PE Literature ---
# Linear interpolation (extrapolating outside the digitised range) onto the
# PE simulation time points.
f_interp_pe = interp1d(x_lit_pe, y_lit_pe, kind='linear', fill_value="extrapolate")
y_lit_pe_interp = f_interp_pe(x_extrapolate)

# % difference (Sim - Lit) / Lit; its error propagates only the simulation
# uncertainty (literature error treated as negligible).
y_zimmerman_pe_diff = ((y_target_pe - y_lit_pe_interp) / y_lit_pe_interp) * 100
y_zimmerman_pe_diff_errors = (y_target_pe_errors / np.abs(y_lit_pe_interp)) * 100

# --- SW Simulation vs. Zimmerman SW Literature ---
KEY_SIM_SW = KEY_TARGET_SW
data_target_sw = all_processed.get(KEY_SIM_SW, None)

if data_target_sw is not None:
    x_target_sw = np.array(data_target_sw["iter"] - 1) * CONVERT_ITERATION_SW_TIME
    y_target_sw = abs(data_target_sw["E_vecs"][:, 0])
    y_target_sw_errors = data_target_sw["point_errors"][:, 0]

    # Interpolate the SW literature curve onto the SW simulation time points.
    f_interp_sw = interp1d(x_lit_sw, y_lit_sw, kind='linear', fill_value="extrapolate")
    y_lit_sw_interp = f_interp_sw(x_target_sw)

    # Same % difference and error convention as for PE.
    y_zimmerman_sw_diff = ((y_target_sw - y_lit_sw_interp) / y_lit_sw_interp) * 100
    y_zimmerman_sw_diff_errors = (y_target_sw_errors / np.abs(y_lit_sw_interp)) * 100
else:
    print(f"⚠️ Warning: SW target key '{KEY_SIM_SW}' not found in all_processed. Skipping SW comparison.")
    # Empty arrays so the plotting code can test x_target_sw.size.
    x_target_sw = np.array([])
    y_zimmerman_sw_diff = np.array([])
    y_zimmerman_sw_diff_errors = np.array([])

# --------------------------------------------------------------------------

# --- 4. PLOTTING SETUP (MAIN + 2 SUBPLOTS) ---

# Three stacked panels sharing the time axis (height ratios 4 : 0.8 : 0.8):
# |E_x| traces, fit residual, and benchmark against the literature.
fig = plt.figure(figsize=(8.01, 4.6))
gs = gridspec.GridSpec(3, 1, hspace=0.15, height_ratios=[4, 0.8, 0.8])

ax_main = fig.add_subplot(gs[0])
ax_fit_res = fig.add_subplot(gs[1], sharex=ax_main)
ax_zimm_bench = fig.add_subplot(gs[2], sharex=ax_main)

# Palette entries shared by all three panels (defined once).
color_list = [color_list_rgba[2], color_list_rgba[3], color_list_rgba[6]]

# --- 5. MAIN PLOT GENERATION (ax_main) ---

# Literature reference curves (grey, semi-transparent).
ax_main.plot(x_lit_sw, y_lit_sw, '-', color="k", lw=3, alpha=0.3, label="Zimmerman SW/PE/PE+SW Ref.")
ax_main.plot(x_lit_pe, y_lit_pe, '-', color="k", lw=3, alpha=0.3)

# Cubic fit of the KEY_FIT run, evaluated on the target-run time axis.
ax_main.plot(x_extrapolate, y_extrapolated, '-', color=color_list[-1], lw=2,
             label=f"{method_label} (Extrapolated)")

i = 0  # colour index, advanced once per simulation trace actually drawn
for keyIN, run_data in all_processed.items():
    run_meta = run_data["metadata"]
    tempIN = run_meta["temperature"]
    targetIN = run_meta["target_point"]

    # Keep only runs whose target point has negative x, at T == 425, and whose
    # fourth key field does not contain "Total".
    is_selected = (targetIN[0] < 0) and (tempIN == 425) and ("Total" not in keyIN.split("_")[3])
    if not is_selected:
        continue

    print(keyIN)

    # The fitted run itself is represented by the extrapolated curve above.
    if keyIN == KEY_FIT:
        continue

    case = run_meta["case"]
    factor = CONVERT_ITERATION_PE_TIME if case == "PE" else CONVERT_ITERATION_SW_TIME
    x_data = np.array(run_data["iter"] - 1) * factor
    y_data = abs(run_data["E_vecs"][:, 0])
    y_err = run_data["point_errors"][:, 0]

    # Wrap around instead of raising IndexError if more runs pass the filter
    # than there are colours.
    plot_color = color_list[i % len(color_list)]
    ax_main.plot(x_data, y_data, '-', color=plot_color, lw=1.5)
    # Shaded +/- point_errors band; label=None keeps it out of any legend.
    ax_main.fill_between(x_data, y_data - y_err, y_data + y_err,
                         color=plot_color, alpha=0.15, label=None)
    i += 1

ax_main.set_ylabel(r"$|E_x|$ (V/m)")
ax_main.ticklabel_format(axis='y', style='sci', scilimits=(0, 0))
ax_main.set_ylim(0, 2.2e5)
# x tick labels are shown only on the bottom (shared-x) panel.
ax_main.tick_params(axis='x', labelbottom=False)

# --- 6. FIT RESIDUAL PLOT GENERATION (ax_fit_res) ---
# Percent difference between the extrapolated fit and the target PE run
# (its uncertainty is available as y_percent_diff_errors if a band is wanted).
ax_fit_res.plot(x_extrapolate, y_percent_diff, '-', color=color_list[-2], lw=2,
                label=r"% Diff: $\frac{|E_{Fit}| - |E_{Target}|}{|E_{Fit}|}$")
ax_fit_res.tick_params(axis='x', labelbottom=False)

# --- 7. ZIMMERMAN BENCHMARK PLOT GENERATION (ax_zimm_bench) ---
# Percent differences versus the literature; uncertainties are available as
# y_zimmerman_pe_diff_errors / y_zimmerman_sw_diff_errors.
ax_zimm_bench.plot(x_extrapolate, y_zimmerman_pe_diff, '-', color=color_list[1], lw=2,
                   label=r"% Diff: $\frac{|E_{\text{PE Sim}}| - |E_{\text{PE Lit}}|}{|E_{\text{PE Lit}}|}$")

if x_target_sw.size > 0:
    ax_zimm_bench.plot(x_target_sw, y_zimmerman_sw_diff, '-', color=color_list[0], lw=2,
                       label=r"% Diff: $\frac{|E_{\text{SW Sim}}| - |E_{\text{SW Lit}}|}{|E_{\text{SW Lit}}|}$")

ax_zimm_bench.axhline(0, color='k', linestyle='-', lw=0.5, alpha=0.8)  # zero line
ax_zimm_bench.set_xlabel("Time [s]")
ax_zimm_bench.set_ylabel(r"% Diff", loc="top")
ax_zimm_bench.set_xlim(0, 8)
ax_zimm_bench.set_ylim(-40, 20)
ax_zimm_bench.set_yticks([-40, -20, 0, 20])

# --- 8. SAVE AND SHOW ---
# Save this figure explicitly (not whatever pyplot considers "current") and
# make sure the output directory exists.
os.makedirs("figures", exist_ok=True)
fig.savefig("figures/zimmerman_benchmark_summary_with_lit_comparison.jpeg", bbox_inches="tight", dpi=300)
plt.show()