In [None]:
import uproot
import glob
import numpy as np
import pandas as pd
import time
import os
import re

import pyvista as pv
pv.set_jupyter_backend('trame')  # or 'panel' if using panel

from scipy.constants import epsilon_0, e as q_e
from scipy.interpolate import griddata
from scipy.optimize import curve_fit
from scipy.spatial import cKDTree
from scipy.interpolate import interp1d

# Import interpolate for numerical method
from scipy.interpolate import CubicSpline
import matplotlib.gridspec as gridspec

from matplotlib.colors import LogNorm
import matplotlib.pyplot as plt
from matplotlib.colorbar import ColorbarBase
from matplotlib.colors import Normalize
import matplotlib.colors as mcolors
from matplotlib.collections import LineCollection
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import matplotlib as mpl

from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed

import trimesh
import h5py
from trimesh.points import PointCloud

from common_functions import *


## Quantitative Comparison with Zimmerman et al. 2016

In [None]:
# analyze one preprocessed h5 file

# directory = "/storage/scratch1/5/avira7/Grain-Charging-Simulation-Data/stacked-sphere/output111025/processed-fieldmaps"
# processedResults = load_h5_to_dict(f"{directory}/PE_425K_initial8max0.8final12_noDissipation_sphere50um-throughXX.h5")

In [None]:
# --- configuration ---
# Directory holding the processed field-map HDF5 files (absolute path on scratch storage).
directory = "/storage/scratch1/5/avira7/Grain-Charging-Simulation-Data/stacked-sphere/output111025/processed-fieldmaps"
# Only files matching PE_425K*Refined*pos-0.1*.h5 are selected; glob gives no ordering guarantee.
h5_filenames = glob.glob(f"{directory}/PE_425K*Refined*pos-0.1*.h5")

# --- helper function for one key ---
def process_key(args):
    """
    Average the field samples of one group inside a sphere around a target point.

    Parameters
    ----------
    args : tuple
        ``(keyIN, val, target_point)`` or ``(keyIN, val, target_point, radius)``,
        packed into a single argument so the function can be passed to ``Executor.map``.

        - keyIN: group name; the second "_"-separated token must parse as an int
          and is returned as the id.
        - val: mapping with "pos", "E" and "E_mag" arrays that share their first axis.
        - target_point: centre of the averaging sphere, broadcastable against ``val["pos"]``.
        - radius (optional): averaging radius in the same units as "pos"; defaults to 5e-3.

    Returns
    -------
    tuple
        ``(id, E_vec, E_mag, E_vec_errors, n_points)`` where ``E_vec`` and ``E_mag`` are the
        means over the selected samples, ``E_vec_errors`` is the per-component standard error
        of the mean (population std / sqrt(n_points)) and ``n_points`` is the sample count.
        When no sample lies inside the sphere the placeholder
        ``(-1, nan(3), nan, nan(3), 0)`` is returned instead.
    """
    keyIN, val, target_point, *optional = args

    # radius = 6e-3  # spherical averaging radius
    radius = optional[0] if optional else 5e-3  # spherical averaging radius

    points = val["pos"]
    vectors = val["E"]
    magnitudes = val["E_mag"]

    # Compare squared distances so no square root is needed.
    mask = np.sum((points - target_point)**2, axis=1) <= radius**2
    if not np.any(mask):
        # Return placeholders if no points in sphere
        return -1, np.full(3, np.nan), np.nan, np.full(3, np.nan), 0

    # Apply the mask once and reuse the selections.
    vectors_in = vectors[mask]
    magnitudes_in = magnitudes[mask]
    n_points = len(magnitudes_in)

    E_vec = vectors_in.mean(axis=0)
    E_mag = magnitudes_in.mean(axis=0)
    E_vec_errors = vectors_in.std(axis=0) / np.sqrt(n_points)
    return int(keyIN.split("_")[1]), E_vec, E_mag, E_vec_errors, n_points

    # tree = cKDTree(points)
    # dist, idx = tree.query(target_point)

    # # Nearest neighbor field
    # E_vec_at_point = vectors[idx]
    # E_mag_at_point = magnitudes[idx]

    # return int(keyIN.split("_")[1]), E_vec_at_point, E_mag_at_point, dist 

# --- extract metadata from filename ---
def parse_filename_metadata(filename):
    """
    Extract run metadata encoded in a processed field-map file name.

    No filtering happens here: every file name yields a metadata dict, and fields
    whose pattern is absent from the name are set to NaN.

    Example file name:
    PE_425K_initial8max0.8final12_RefinedGridDissipation_500000particles_Sphere20um_pos-0.1_through26.h5

    Parameters
    ----------
    filename : str
        Path or bare name of the file; only the basename is inspected.

    Returns
    -------
    dict
        Keys: "filename" (basename), "case" (text before the first "_"),
        "temperature" (int or NaN), "target_point" (length-3 array derived from the
        ``_pos<value>_through`` token), "sphere_um" (int or NaN) and "octree"
        (dict with "initial_depth", "percent_gradThreshold", "final_depth").
    """

    base = os.path.basename(filename)

    case = base.split("_")[0]

    # Temperature
    temp_match = re.search(r'_(\d+)K_', base)
    temperature = int(temp_match.group(1)) if temp_match else np.nan

    # Position
    pos_match = re.search(r'_pos([-+]?\d*\.?\d+)_through', base)
    pos_value = float(pos_match.group(1)) if pos_match else np.nan

    # Sphere size (if any)
    sphere_match = re.search(r'_sphere(\d+)um', base, re.IGNORECASE)
    sphere_um = int(sphere_match.group(1)) if sphere_match else np.nan

    # Octree parameters: initial, grad threshold, final
    octree_match = re.search(r'_initial(\d+)max([-+]?\d*\.?\d+)final(\d+)', base)
    if octree_match:
        octree_params = {
            "initial_depth": int(octree_match.group(1)),
            "percent_gradThreshold": float(octree_match.group(2)),
            "final_depth": int(octree_match.group(3))
        }
    else:
        # Same keys as the matched branch so consumers can index either result identically.
        octree_params = {"initial_depth": np.nan, "percent_gradThreshold": np.nan, "final_depth": np.nan}

    # Target point
    # x is shifted by 0.005 from the position parsed out of the file name; z is a fixed offset sum.
    #target_point = np.array([pos_value - 0.0073, 0, 0.1 - 0.015 + 0.037 - 0.00073])
    target_point = np.array([pos_value - 0.005, 0, 0.1 - 0.015 + 0.037])
    #target_point = np.array([pos_value, 0, 0.1 - 0.015 + 0.037])

    metadata = {
        "filename": base,
        "case": case,
        "temperature": temperature,
        "target_point": target_point,
        "sphere_um": sphere_um,
        "octree": octree_params,
       # "num_particles": num_particles
    }

    return metadata


# --- worker for one file ---
def process_file(fileIN):
    """
    Load one processed field-map file and reduce every group to a sphere-averaged field.

    Parameters
    ----------
    fileIN : str
        Path of the HDF5 file; its basename also supplies the metadata and the result key.

    Returns
    -------
    tuple or None
        ``(key_prefix, data)`` where ``key_prefix`` is the basename up to "_through" and
        ``data`` holds the arrays "iter", "E_vecs", "E_mags", "point_errors" (standard
        errors returned by ``process_key``), "N" (samples per group) and the "metadata" dict.
        ``None`` if the metadata parser returns ``None``.

    Raises
    ------
    ValueError
        If the loaded file contains no groups, so there is nothing to average.
    """
    metadata = parse_filename_metadata(fileIN)

    # # Filter: skip files with 'through' files that have processed < 40 iterations
    # through_idx = int(re.search(r'_through(\d+)\.h5', fileIN).group(1))
    # if through_idx < 40:
    #     return None

    # Kept as a guard for a parser variant that rejects files by returning None.
    if metadata is None:
        return None

    print(f"→ Started {os.path.basename(fileIN)}\n", flush=True)
    processedResults = load_h5_to_dict(fileIN)
    key_prefix = os.path.basename(fileIN).split('_through')[0]

    keys = list(processedResults.keys())
    if not keys:
        # Without this check the zip(*results) unpacking below fails with an opaque
        # "not enough values to unpack" message.
        raise ValueError(f"{os.path.basename(fileIN)} contains no field-map groups to process")

    # --- inner parallelization across keys ---
    args_list = [(k, processedResults[k], metadata["target_point"]) for k in keys]

    with ThreadPoolExecutor(max_workers=8) as tpool:  # inner parallel threads
        # map() preserves the order of args_list, so row i of every array belongs to keys[i].
        results = list(tpool.map(process_key, args_list))

    # Transpose the list of per-key tuples into one sequence per returned field.
    ids, E_vecs, E_mags, standard_errors, N = zip(*results)
    ids = np.array(ids)
    E_vecs = np.array(E_vecs)
    E_mags = np.array(E_mags)
    standard_errors = np.array(standard_errors)
    num_points = np.array(N)

    print(f"✓ Finished {os.path.basename(fileIN)}", flush=True)
    return key_prefix, {
        "iter": ids,"E_vecs": E_vecs,"E_mags": E_mags,"point_errors": standard_errors, "N":num_points, "metadata": metadata
    }

# --- parallel execution across files ---
num_cores = 2
all_processed = {}

with ProcessPoolExecutor(max_workers=num_cores) as executor:
    futures = {executor.submit(process_file, f): f for f in h5_filenames}
    for fut in as_completed(futures):
        fileIN = futures[fut]
        try:
            result = fut.result()
            if result is not None:
                key_prefix, data = result
                all_processed[key_prefix] = data
                print(f"✔ Processed {os.path.basename(fileIN)}\n", flush=True)
        except Exception as e:
            print(f"❌ Error in {os.path.basename(fileIN)}: {e}\n", flush=True)


In [None]:
# --- configuration ---
# NOTE: this cell rebinds `directory` and `h5_filenames` set by the earlier configuration cell.
#directory = "/storage/scratch1/5/avira7/Grain-Charging-Simulation-Data/stacked-sphere/output111025/processed-fieldmaps"
directory = "processed-fieldmaps/"
h5_filenames = glob.glob(f"{directory}/*425K*Refined*.h5")

# Define field averaging parameters
# Radius of the averaging sphere, in the same units as the "pos" arrays (units not shown here — TODO confirm).
FIELD_AVERAGE_RADIUS = 2e-3
# NOTE(review): offsetLimits is not read anywhere in this cell.
offsetLimits = 0.01

# Half-width of the target grid along each axis; dividing by 1 makes the step equal
# to the half-width, i.e. three samples per axis (-limit, 0, +limit).
offsetLimitsX = 0.006
offsetLimitsY = 0.006
new_step_x = offsetLimitsX/1
new_step_y = offsetLimitsY/1

# Create grid of target points
# The +1e-9 makes np.arange include the upper limit; rounding removes float noise from arange.
xoffset = np.round(np.arange(-offsetLimitsX, offsetLimitsX+1e-9, new_step_x),5)
yoffset = np.round(np.arange(-offsetLimitsY, offsetLimitsY+1e-9, new_step_y),5)
X, Y = np.meshgrid(xoffset, yoffset)
# One row per target point: (x, y, z) with the grid offsets applied to x and z and y fixed at 0.
target_points_array = np.vstack([
    -0.1 - X.flatten(), 
    np.zeros(len(X.flatten())), 
    0.1 - 0.015 + 0.037 - Y.flatten()
]).T

# The radius is multiplied by 1000 for display; the "um" label presumably assumes mm inputs — TODO confirm.
print(f"Processing {len(target_points_array)} target points with radius {FIELD_AVERAGE_RADIUS*1000} um")

# --- helper function for one key and one target point ---
def process_key_target(args):
    """
    Average the field samples of one group inside a sphere around one target point.

    Parameters
    ----------
    args : tuple
        ``(keyIN, val, target_point, radius)``, packed into one argument so the function
        can be passed to ``Executor.map``.

        - keyIN: group name; the second "_"-separated token must parse as an int
          and is returned as the id.
        - val: mapping with "pos", "E" and "E_mag" arrays that share their first axis.
        - target_point: centre of the averaging sphere, broadcastable against ``val["pos"]``.
        - radius: averaging radius in the same units as "pos".

    Returns
    -------
    tuple
        ``(id, E_vec, E_mag, E_vec_errors, n_points)``: mean field vector, mean field
        magnitude, per-component standard error of the mean (population std / sqrt(n))
        and number of samples inside the sphere. The fourth element is what the caller
        stores as "point_errors" and later plots as the field uncertainty, so the
        standard error is returned there (previously the mean sample position was
        returned and the computed error discarded).
        When no sample lies inside the sphere the placeholder
        ``(-1, nan(3), nan, nan(3), 0)`` is returned instead.
    """
    keyIN, val, target_point, radius = args

    points = val["pos"]
    vectors = val["E"]
    magnitudes = val["E_mag"]

    # Mask for spherical averaging around target point (squared distances, no sqrt needed)
    mask = np.sum((points - target_point)**2, axis=1) <= radius**2

    if not np.any(mask):
        # Return placeholders if no points in sphere
        return -1, np.full(3, np.nan), np.nan, np.full(3, np.nan), 0

    # Apply the mask once and reuse the selections.
    vectors_in = vectors[mask]
    magnitudes_in = magnitudes[mask]
    n_points = len(magnitudes_in)

    E_vec = vectors_in.mean(axis=0)
    E_mag = magnitudes_in.mean(axis=0)
    E_vec_errors = vectors_in.std(axis=0) / np.sqrt(n_points)

    return int(keyIN.split("_")[1]), E_vec, E_mag, E_vec_errors, n_points

# --- extract metadata from filename ---
def parse_filename_metadata(filename):
    """
    Extract run metadata encoded in a processed field-map file name.

    Every file name yields a metadata dict; fields whose pattern is absent from the
    name are set to NaN.

    Example:
    PE_425K_initial8max0.8final12_RefinedGridDissipation_500000particles_Sphere20um_pos-0.1_through26.h5

    Parameters
    ----------
    filename : str
        Path or bare name of the file; only the basename is inspected.

    Returns
    -------
    dict
        Keys: "filename" (basename), "case" (text before the first "_"),
        "temperature" (int or NaN), "sphere_um" (int or NaN) and "octree"
        (dict with "initial_depth", "percent_gradThreshold", "final_depth").
    """

    base = os.path.basename(filename)
    case = base.split("_")[0]

    # Temperature
    temp_match = re.search(r'_(\d+)K_', base)
    temperature = int(temp_match.group(1)) if temp_match else np.nan

    # Sphere size (if any)
    sphere_match = re.search(r'_sphere(\d+)um', base, re.IGNORECASE)
    sphere_um = int(sphere_match.group(1)) if sphere_match else np.nan

    # Octree parameters: initial, grad threshold, final
    octree_match = re.search(r'_initial(\d+)max([-+]?\d*\.?\d+)final(\d+)', base)
    if octree_match:
        octree_params = {
            "initial_depth": int(octree_match.group(1)),
            "percent_gradThreshold": float(octree_match.group(2)),
            "final_depth": int(octree_match.group(3))
        }
    else:
        # Same keys as the matched branch so consumers can index either result identically.
        octree_params = {"initial_depth": np.nan, "percent_gradThreshold": np.nan, "final_depth": np.nan}

    metadata = {
        "filename": base,
        "case": case,
        "temperature": temperature,
        "sphere_um": sphere_um,
        "octree": octree_params,
    }

    return metadata


# --- worker for one file ---
def process_file(fileIN):
    """
    Load one processed field-map file and average the field around every target point.

    Reads the module-level ``target_points_array`` (one row per target point) and
    ``FIELD_AVERAGE_RADIUS``.

    Parameters
    ----------
    fileIN : str
        Path of the HDF5 file; its basename also supplies the metadata and the result key.

    Returns
    -------
    tuple or None
        ``(key_prefix, data)`` where ``key_prefix`` is the basename up to "_through" and
        ``data`` is ``{"target_points": {...}, "metadata": {...}}``. Each entry of
        "target_points" is keyed ``target_0000``, ``target_0001``, ... and holds the arrays
        "iter", "E_vecs", "E_mags", "point_errors" (fourth value returned by
        ``process_key_target``), "N", plus the "target_point" and "radius" used.
        ``None`` if the metadata parser returns ``None``.

    Raises
    ------
    ValueError
        If the loaded file contains no groups, so there is nothing to average.
    """
    metadata = parse_filename_metadata(fileIN)

    # Kept as a guard for a parser variant that rejects files by returning None.
    if metadata is None:
        return None

    print(f"→ Started {os.path.basename(fileIN)}\n", flush=True)
    processedResults = load_h5_to_dict(fileIN)
    key_prefix = os.path.basename(fileIN).split('_through')[0]

    keys = list(processedResults.keys())
    if not keys:
        # Without this check the zip(*results) unpacking below fails with an opaque
        # "not enough values to unpack" message.
        raise ValueError(f"{os.path.basename(fileIN)} contains no field-map groups to process")

    # Dictionary to hold results for each target point
    target_point_results = {}

    # One thread pool serves all target points instead of being rebuilt on every iteration.
    with ThreadPoolExecutor(max_workers=8) as tpool:
        # Process each target point
        for tp_idx, target_point in enumerate(target_points_array):
            tp_key = f"target_{tp_idx:04d}"

            # --- inner parallelization across keys for this target point ---
            args_list = [(k, processedResults[k], target_point, FIELD_AVERAGE_RADIUS) for k in keys]

            # map() preserves the order of args_list, so row i of every array belongs to keys[i].
            results = list(tpool.map(process_key_target, args_list))

            # Transpose the list of per-key tuples into one sequence per returned field.
            ids, E_vecs, E_mags, standard_errors, N = zip(*results)
            ids = np.array(ids)
            E_vecs = np.array(E_vecs)
            E_mags = np.array(E_mags)
            standard_errors = np.array(standard_errors)
            num_points = np.array(N)

            target_point_results[tp_key] = {
                "iter": ids,
                "E_vecs": E_vecs,
                "E_mags": E_mags,
                "point_errors": standard_errors,
                "N": num_points,
                "target_point": target_point,
                "radius": FIELD_AVERAGE_RADIUS
            }

            # Progress line every 10 target points.
            if (tp_idx + 1) % 10 == 0:
                print(f"  Processed {tp_idx + 1}/{len(target_points_array)} target points", flush=True)

    print(f"✓ Finished {os.path.basename(fileIN)}", flush=True)
    return key_prefix, {
        "target_points": target_point_results,
        "metadata": metadata
    }

# --- parallel execution across files ---
num_cores = 2
all_processed = {}

with ProcessPoolExecutor(max_workers=num_cores) as executor:
    futures = {executor.submit(process_file, f): f for f in h5_filenames}
    for fut in as_completed(futures):
        fileIN = futures[fut]
        try:
            result = fut.result()
            if result is not None:
                key_prefix, data = result
                all_processed[key_prefix] = data
                print(f"✔ Processed {os.path.basename(fileIN)}\n", flush=True)
        except Exception as e:
            print(f"❌ Error in {os.path.basename(fileIN)}: {e}\n", flush=True)

print(f"\n=== Processing Complete ===")
print(f"Total files processed: {len(all_processed)}")
print(f"Target points per file: {len(target_points_array)}")

# --- Access results example ---
# all_processed[key_prefix]["target_points"]["target_0000"]["E_vecs"]
# all_processed[key_prefix]["target_points"]["target_0000"]["target_point"]
# all_processed[key_prefix]["metadata"]

In [None]:
# External LaTeX rendering is switched off; math in labels goes through matplotlib's own text engine
mpl.rcParams['text.usetex'] = False
# Set the global font size
mpl.rcParams.update({'font.size': 14})

# --- 1. CONFIGURATION AND DATA LOADING ---
print("\n--- Starting Data Processing and Plot Generation ---")

# Define Data Keys
# Must equal a key of all_processed, i.e. a file basename up to "_through".
#KEY_TARGET_PE = 'PE_425K_initial8max0.8final12_RefinedGridDissipation_sphere40um_pos-0.1'
KEY_TARGET_PE = 'PE_425K_initial8max0.8final12_RefinedGridDissipation_sphere40um_adjustedworld'

# Load external literature data
# Paths are relative to the notebook's working directory.
zimmerman_SWdata = pd.read_csv("literature-data/Fig7a-SW.csv")
zimmerman_PEdata = pd.read_csv("literature-data/Fig7a-PE.csv")

# Simulation Parameters (used for time conversion)
WORLD_XY_AREA_SQ_M = 300 * 300 / (1e6**2)  # World area (m^2)

# PE Conversion Factor
# seconds per iteration = (particles per iteration / world area) / (particle flux per second)
PARTICLES_PER_ITERATION_PE = 81775
FLUX_PER_ITERATION_PE = PARTICLES_PER_ITERATION_PE / WORLD_XY_AREA_SQ_M
# presumably a current density (A/m^2) times elementary charges per coulomb — TODO confirm
PE_ION_FLUX = 4e-6 * 6.241509e18
CONVERT_ITERATION_PE_TIME = FLUX_PER_ITERATION_PE / PE_ION_FLUX
print(f"PE Conversion Factor (s/iteration): {CONVERT_ITERATION_PE_TIME:.3e}")

# SW Conversion Factor
# Same construction as the PE factor, with the SW particle count and flux.
PARTICLES_PER_ITERATION_SW = 30601
FLUX_PER_ITERATION_SW = PARTICLES_PER_ITERATION_SW / WORLD_XY_AREA_SQ_M
SW_ION_FLUX = 3e-7 * 6.241509e18
CONVERT_ITERATION_SW_TIME = FLUX_PER_ITERATION_SW / SW_ION_FLUX
print(f"SW Conversion Factor (s/iteration): {CONVERT_ITERATION_SW_TIME:.3e}")

# --- 2. PREPARE LITERATURE DATA ---
# The CSV "x" column is raised as a power of ten before use; the y column name
# carries a leading space (" y") as written in the CSV header.
x_lit_sw = 10**zimmerman_SWdata["x"]
y_lit_sw = zimmerman_SWdata[" y"]
x_lit_pe = 10**zimmerman_PEdata["x"]
y_lit_pe = zimmerman_PEdata[" y"]

# --- 3. PLOTTING SETUP ---
fig, ax_main = plt.subplots(figsize=(8.01, 4.6))

# --- 4. PLOT LITERATURE DATA ---
# Both literature curves are drawn as thick translucent black lines without legend entries.
ax_main.plot(x_lit_sw, y_lit_sw, '-', color="k", lw=3, alpha=0.3, label=None)
ax_main.plot(x_lit_pe, y_lit_pe, '-', color="k", lw=3, alpha=0.3, label=None)

# --- 5. PLOT SIMULATION DATA FOR EACH TARGET POINT ---
# Requires all_processed in the nested layout: all_processed[key]["target_points"][target_key].
selectkey = KEY_TARGET_PE

# Color Map Setup
CMAP_NAME = 'jet'
n_targets = len(all_processed[selectkey]["target_points"].keys())
# n_targets + 1 colors are sampled; zip() below consumes only the first n_targets of them.
discrete_cmap = plt.get_cmap(CMAP_NAME, n_targets + 1)
color_list_rgba = [discrete_cmap(i) for i in np.linspace(0, 1, n_targets + 1)]

# Get case type for time conversion
case = all_processed[selectkey]["metadata"]["case"]
# Any case other than "PE" falls back to the SW conversion factor.
factor = CONVERT_ITERATION_PE_TIME if case == "PE" else CONVERT_ITERATION_SW_TIME

# Plot each target point
for targetIN, colorIN in zip(all_processed[selectkey]["target_points"].keys(), color_list_rgba):
    
    # Extract data
    # Iteration ids are shifted by one before conversion to seconds, so id 1 maps to t = 0.
    x_data = np.array(all_processed[selectkey]["target_points"][targetIN]["iter"] - 1) * factor
    # Magnitude of the x component of the averaged field vector.
    y_data = abs(all_processed[selectkey]["target_points"][targetIN]["E_vecs"][:, 0])
    # NOTE(review): y_err is only used by the commented-out error band below.
    y_err = all_processed[selectkey]["target_points"][targetIN]["point_errors"][:, 0]
    target_point = all_processed[selectkey]["target_points"][targetIN]["target_point"]
    
    # Plot the data line
    ax_main.plot(x_data, y_data, '-', color=colorIN, lw=1.5, 
                 label=f'x:{target_point[0]:.3f} z:{target_point[2]:.3f}')
    
    # Optional: Add error region
    # ax_main.fill_between(x_data, y_data - y_err, y_data + y_err, 
    #                      color=colorIN, alpha=0.15, label=None)

# --- 6. FORMAT MAIN PLOT ---
ax_main.set_xlabel("Time [s]")
ax_main.set_ylabel(r"$|E_x|$ (V/m)")
# Force scientific notation on the y axis for all magnitudes.
ax_main.ticklabel_format(axis='y', style='sci', scilimits=(0, 0))
# ax_main.set_ylim(0, 2.2e5)
ax_main.set_xlim(0, 1)
ax_main.legend(loc="upper right", fontsize=8, ncol=2)
ax_main.grid(True, linestyle=':', alpha=0.5)

# --- 7. SAVE AND SHOW ---
plt.tight_layout()
# Written relative to the working directory, into "figures/".
plt.savefig("figures/field_evolution_target_points.jpeg", bbox_inches="tight", dpi=300)
plt.show()

print("\n--- Plot Generation Complete ---")

In [None]:
# External LaTeX rendering is switched off; math in labels goes through matplotlib's own text engine
mpl.rcParams['text.usetex'] = False
# Set the global font size
mpl.rcParams.update({'font.size': 14})

# --- 1. CONFIGURATION AND DATA LOADING ---
print("\n--- Starting Data Processing and Plot Generation ---")

# Define Data Keys for Fitting and Subtraction
# Each key must equal a key of all_processed, i.e. a file basename up to "_through".
KEY_FIT = 'PE_425K_initial8max0.8final12_noDissipation_sphere50um_pos-0.1'
KEY_TARGET_PE = 'PE_425K_initial8max0.8final12_RefinedGridDissipation_sphere40um_pos-0.1' # Renamed for clarity
KEY_TARGET_SW = 'SW_425K_initial8max0.8final12_RefinedGridDissipation_sphere20um_pos-0.1' # Key for SW comparison

# Load external literature data
# Paths are relative to the notebook's working directory.
zimmerman_SWdata = pd.read_csv("literature-data/Fig7a-SW.csv")
zimmerman_PEdata = pd.read_csv("literature-data/Fig7a-PE.csv")
zimmerman_PEandSWdata = pd.read_csv("literature-data/Fig7a-PE+SW.csv")

# Simulation Parameters (used for time conversion)
WORLD_XY_AREA_SQ_M = 300 * 300 / (1e6**2) # World area (m^2)

# PE Conversion Factor
# seconds per iteration = (particles per iteration / world area) / (particle flux per second)
PARTICLES_PER_ITERATION_PE = 81775
FLUX_PER_ITERATION_PE = PARTICLES_PER_ITERATION_PE / WORLD_XY_AREA_SQ_M
# presumably a current density (A/m^2) times elementary charges per coulomb — TODO confirm
PE_ION_FLUX = 4e-6 * 6.241509e18
CONVERT_ITERATION_PE_TIME = FLUX_PER_ITERATION_PE / PE_ION_FLUX
print(f"PE Conversion Factor (s/iteration): {CONVERT_ITERATION_PE_TIME:.3e}")

# SW Conversion Factor
# Same construction as the PE factor, with the SW particle count and flux.
PARTICLES_PER_ITERATION_SW = 30601
FLUX_PER_ITERATION_SW = PARTICLES_PER_ITERATION_SW / WORLD_XY_AREA_SQ_M
SW_ION_FLUX = 3e-7 * 6.241509e18
CONVERT_ITERATION_SW_TIME = FLUX_PER_ITERATION_SW / SW_ION_FLUX
print(f"SW Conversion Factor (s/iteration): {CONVERT_ITERATION_SW_TIME:.3e}")

# Color Map Setup
# One color per entry of all_processed plus one extra, sampled evenly over the colormap.
CMAP_NAME = 'Dark2'
discrete_cmap = plt.get_cmap(CMAP_NAME, len(all_processed.keys()) + 1)
color_list_rgba = [discrete_cmap(i) for i in np.linspace(0, 1, len(all_processed.keys()) + 1)]

# --- 2. CURVE FITTING AND EXTRAPOLATION (UNCHANGED) ---

# Define the new fitting function (Polynomial of Order 3)
def poly_curve(t, a, b, c, d):
    """
    Cubic polynomial model: a*t^3 + b*t^2 + c*t + d.

    ``t`` may be a scalar or a NumPy array (evaluated element-wise); a, b, c, d are
    the coefficients from the highest to the lowest order.
    """
    cubic_term = a * t**3
    quadratic_term = b * t**2
    linear_term = c * t
    # Summed highest order first, matching the formula in the docstring.
    return cubic_term + quadratic_term + linear_term + d

# Prepare data for fitting (PE_425K)
# NOTE(review): expects the flat per-file layout with top-level "iter", "E_vecs" and
# "point_errors" arrays (as returned by the single-target process_file), not the nested
# "target_points" layout — confirm which processing cell populated all_processed.
data_fit = all_processed[KEY_FIT]
x_fit = np.array(data_fit["iter"] - 1) * CONVERT_ITERATION_PE_TIME
y_fit = abs(data_fit["E_vecs"][:, 0])
y_fit_errors = data_fit["point_errors"][:, 0]

# Prepare Target Data (PE_425K Refined)
data_target_pe = all_processed[KEY_TARGET_PE]
x_extrapolate = np.array(data_target_pe["iter"] - 1) * CONVERT_ITERATION_PE_TIME
y_target_pe = abs(data_target_pe["E_vecs"][:, 0])
y_target_pe_errors = data_target_pe["point_errors"][:, 0]

# --- Fitting ---
method_label = f"Fit of {KEY_FIT.split('_')[1]} (Poly Order 3)"
try:
    # Weighted least squares; absolute_sigma=True makes pcov reflect sigma in absolute units.
    popt, pcov = curve_fit(poly_curve, x_fit, y_fit,
                           p0=[0, 0, 0, 1e4], sigma=y_fit_errors, absolute_sigma=True)

    A_fit, B_fit, C_fit, D_fit = popt
    print(f"\nFit Parameters for {KEY_FIT.split('_')[1]} (Poly): a={A_fit:.2e}, b={B_fit:.2e}, c={C_fit:.2e}, d={D_fit:.2e}")

    # Extrapolate and Calculate Subtraction
    y_extrapolated = poly_curve(x_extrapolate, *popt)
    y_subtraction = y_extrapolated - y_target_pe

    # Error estimation for the extrapolation
    perr = np.sqrt(np.diag(pcov))
    # Jacobian of the model with respect to its PARAMETERS (a, b, c, d) at each time:
    # d/da = t^3, d/db = t^2, d/dc = t, d/dd = 1. (The time derivative [3t^2, 2t, 1, 0]
    # used previously is not the quantity that pairs with the parameter covariance.)
    J = np.column_stack([x_extrapolate**3, x_extrapolate**2, x_extrapolate,
                         np.ones_like(x_extrapolate)])
    # Row-wise J_i @ pcov @ J_i^T, i.e. diag(J pcov J^T) without building the N x N matrix.
    y_extrapolated_errors = np.sqrt(np.einsum('ij,jk,ik->i', J, pcov, J))


except RuntimeError:
    print("\n⚠️ Warning: Curve fitting failed. Falling back to zeros.")
    # Stay on the target time grid: everything below (and the later literature comparison)
    # combines these arrays element-wise with y_target_pe / y_target_pe_errors, so switching
    # to the x_fit grid would break as soon as the two datasets differ in length.
    y_extrapolated = np.zeros_like(x_extrapolate)
    y_subtraction = np.zeros_like(x_extrapolate)
    method_label = f"Fit of {KEY_FIT.split('_')[1]} (Failed, showing Zeros)"
    y_extrapolated_errors = np.zeros_like(x_extrapolate)

# --- Calculate Percent Difference for Fit vs. Target ---
# Exact zeros in the fit are replaced by a tiny number to avoid division by zero.
y_denominator = np.where(y_extrapolated == 0, 1e-10, y_extrapolated)
y_percent_diff = (y_subtraction / y_denominator) * 100
# Fit and target uncertainties are combined in quadrature (treated as independent).
y_subtraction_errors = np.sqrt(y_extrapolated_errors**2 + y_target_pe_errors**2)
y_percent_diff_errors = (y_subtraction_errors / np.abs(y_denominator)) * 100

# --------------------------------------------------------------------------
# --- 3. ZIMMERMAN COMPARISON CALCULATIONS (UPDATED FOR ERROR) ---
# --------------------------------------------------------------------------

# --- Literature curves (Zimmerman) ---
# Both curves are prepared unconditionally: the main plot draws the SW and PE reference
# lines even when no SW simulation is available, so they must not depend on the SW branch.
# The CSV "x" column is raised as a power of ten; the y column name is " y" (leading space).
x_lit_pe = 10**zimmerman_PEdata["x"]
y_lit_pe = zimmerman_PEdata[" y"]
x_lit_sw = 10**zimmerman_SWdata["x"]
y_lit_sw = zimmerman_SWdata[" y"]

# --- PE Simulation vs. Zimmerman PE Literature ---
# Interpolate Zimmerman PE data to the simulation time points (x_extrapolate);
# times outside the literature range are linearly extrapolated.
f_interp_pe = interp1d(x_lit_pe, y_lit_pe, kind='linear', fill_value="extrapolate")
y_lit_pe_interp = f_interp_pe(x_extrapolate)

# Calculate % Difference: (Sim - Lit) / Lit
y_zimmerman_pe_diff = ((y_target_pe - y_lit_pe_interp) / y_lit_pe_interp) * 100
# Calculate % Difference Error: (Sim_err / Lit_interp) * 100 (Assuming Lit error is negligible)
y_zimmerman_pe_diff_errors = (y_target_pe_errors / np.abs(y_lit_pe_interp)) * 100


# --- SW Simulation vs. Zimmerman SW Literature ---
# 1. Get SW simulation data (None when the SW run was not processed)
KEY_SIM_SW = KEY_TARGET_SW
data_target_sw = all_processed.get(KEY_SIM_SW, None)

if data_target_sw is not None:
    x_target_sw = np.array(data_target_sw["iter"] - 1) * CONVERT_ITERATION_SW_TIME
    y_target_sw = abs(data_target_sw["E_vecs"][:, 0])
    y_target_sw_errors = data_target_sw["point_errors"][:, 0] # EXTRACTED SW ERRORS

    # 2. Interpolate Zimmerman SW data to the SW simulation time points (x_target_sw)
    f_interp_sw = interp1d(x_lit_sw, y_lit_sw, kind='linear', fill_value="extrapolate")
    y_lit_sw_interp = f_interp_sw(x_target_sw)

    # Calculate % Difference: (Sim - Lit) / Lit
    y_zimmerman_sw_diff = ((y_target_sw - y_lit_sw_interp) / y_lit_sw_interp) * 100
    # Calculate % Difference Error: (Sim_err / Lit_interp) * 100 (Assuming Lit error is negligible)
    y_zimmerman_sw_diff_errors = (y_target_sw_errors / np.abs(y_lit_sw_interp)) * 100
else:
    print(f"⚠️ Warning: SW target key '{KEY_SIM_SW}' not found in all_processed. Skipping SW comparison.")
    # Empty arrays let the plotting code test `.size` instead of checking for None.
    x_target_sw = np.array([])
    y_zimmerman_sw_diff = np.array([])
    y_zimmerman_sw_diff_errors = np.array([]) # Defined for consistent error calculation

# --------------------------------------------------------------------------

# --- 4. PLOTTING SETUP (MAIN + 2 SUBPLOTS) ---

# Set up figure and grid layout (5:1:1 height ratio for main plot vs. residual vs. benchmark plot)
fig = plt.figure(figsize=(8.01, 4.6))
gs = gridspec.GridSpec(3, 1, hspace=0.15, height_ratios=[4, 0.8, 0.8]) 

# Main Plot (Top)
ax_main = fig.add_subplot(gs[0])
# Fit Residual Plot (Middle), sharing the x-axis
ax_fit_res = fig.add_subplot(gs[1], sharex=ax_main)
# Zimmerman Benchmark Plot (Bottom), sharing the x-axis
ax_zimm_bench = fig.add_subplot(gs[2], sharex=ax_main)

# color_list is defined here for use in section 7
# Indices 2, 3 and 6 require color_list_rgba to hold at least 7 colors,
# i.e. all_processed must contain at least 6 entries.
color_list = [color_list_rgba[2],color_list_rgba[3],color_list_rgba[6]]

# --- 5. MAIN PLOT GENERATION (ax_main) ---

# Plot reference data (Zimmerman)
ax_main.plot(x_lit_sw, y_lit_sw, '-', color="k", lw=3, alpha=0.3, label="Zimmerman SW/PE/PE+SW Ref.")
ax_main.plot(x_lit_pe, y_lit_pe, '-', color="k", lw=3, alpha=0.3)
#ax_main.plot(10**zimmerman_PEandSWdata["x"], zimmerman_PEandSWdata[" y"], '--', color="k", lw=3, alpha=0.3)

# Plot simulation data
# Three fixed colors; `i` counts the datasets actually drawn and indexes into this list.
color_list = [color_list_rgba[2],color_list_rgba[3],color_list_rgba[6]]
i=0

# # Plot the fitted/extrapolated curve
ax_main.plot(x_extrapolate, y_extrapolated, '-', color=color_list[-1], lw=2, 
             label=f"{method_label} (Extrapolated)")

# NOTE(review): colorIN from the zip is not used; colors come from color_list[i] instead.
for keyIN, colorIN in zip(all_processed.keys(), color_list_rgba):
    
    # Per-dataset plotting variables (recomputed on every iteration)
    case = all_processed[keyIN]["metadata"]["case"]
    # Any case other than "PE" falls back to the SW conversion factor.
    factor = CONVERT_ITERATION_PE_TIME if case == "PE" else CONVERT_ITERATION_SW_TIME
    tempIN = all_processed[keyIN]["metadata"]["temperature"]
    targetIN = all_processed[keyIN]["metadata"]["target_point"]
    
    # Iteration ids are shifted by one before conversion to seconds, so id 1 maps to t = 0.
    x_data = np.array(all_processed[keyIN]["iter"] - 1) * factor
    y_data = abs(all_processed[keyIN]["E_vecs"][:, 0])
    y_err = all_processed[keyIN]["point_errors"][:, 0]
    
    # (repeats the two assignments made above)
    tempIN = all_processed[keyIN]["metadata"]["temperature"]
    targetIN = all_processed[keyIN]["metadata"]["target_point"]
    # noteIN = keyIN.split("_")[2] # Not used in label for brevity

    # Filter plotting to only the relevant cases (e.g., specific position and T=425)
    # The fourth "_"-separated token of the key must not contain "Total".
    if (targetIN[0] < 0) & (tempIN == 425) & ("Total" not in keyIN.split("_")[3]):
        
        print(keyIN)

        # Raises IndexError once more than len(color_list) datasets have been drawn.
        plot_color = color_list[i]

        if keyIN == KEY_FIT:
            # The fit dataset itself is not drawn; `continue` also skips the i += 1 below.
            # # Plot the data line
            # ax_main.plot(x_data, y_data, '--', color=plot_color, lw=1)
            continue
        else:
            # Plot the data line
            ax_main.plot(x_data, y_data, '-', color=plot_color, lw=1.5)
            
            # Use fill_between for the error region (Replaces errorbars)
            ax_main.fill_between(x_data, y_data - y_err, y_data + y_err, 
                                color=plot_color, alpha=0.15, 
                                label=None) # Set label=None to avoid extra legend entry
        i+=1
# Clean up main plot
ax_main.set_ylabel(r"$|E_x|$ (V/m)")
# Force scientific notation on the y axis for all magnitudes.
ax_main.ticklabel_format(axis='y', style='sci', scilimits=(0, 0))
ax_main.set_ylim(0,2.2e5)
# ax_main.axvline(x=0.32754498952096356)
# ax_main.axvline(x=3.2390560074850843)
# ax_main.grid(True, linestyle=':', alpha=0.5) 
# ax_main.legend(loc='lower left', fontsize=8, ncol=2)
# Remove X-tick labels from the main plot
plt.setp(ax_main.get_xticklabels(), visible=False) 


# --- 6. FIT RESIDUAL PLOT GENERATION (ax_fit_res) ---

# Plot the percent difference between the fitted curve and the target simulation.
# The shaded error band below is currently commented out, so only the line is drawn.
ax_fit_res.plot(x_extrapolate, y_percent_diff, '-', color=color_list[-2], lw=2,
            label=r"% Diff: $\frac{|E_{Fit}| - |E_{Target}|}{|E_{Fit}|}$")
# Shaded region
# ax_fit_res.fill_between(x_extrapolate, y_percent_diff - y_percent_diff_errors, 
#                     y_percent_diff + y_percent_diff_errors, 
#                     color=color_list[2], alpha=0.2, label="Fit Error Region")
# ax_fit_res.set_ylabel(r"Fit % Diff")
# ax_fit_res.grid(True, linestyle=':', alpha=0.6)
# ax_fit_res.legend(loc='upper right', fontsize=8)
#ax_fit_res.set_ylim(0,10) 
#ax_fit_res.set_yticks([0,4,8])
# Remove X-tick labels from the fit residual plot
plt.setp(ax_fit_res.get_xticklabels(), visible=False) 

# --- 7. ZIMMERMAN BENCHMARK PLOT GENERATION (ax_zimm_bench) ---
# The shaded error regions below are currently commented out, so only the lines are drawn.

# PE Comparison
ax_zimm_bench.plot(x_extrapolate, y_zimmerman_pe_diff, '-', color=color_list[1], lw=2,
                   label=r"% Diff: $\frac{|E_{\text{PE Sim}}| - |E_{\text{PE Lit}}|}{|E_{\text{PE Lit}}|}$")

# # PE Shaded region (using same color as line, color_list[1])
# ax_zimm_bench.fill_between(x_extrapolate, y_zimmerman_pe_diff - y_zimmerman_pe_diff_errors, 
#                            y_zimmerman_pe_diff + y_zimmerman_pe_diff_errors, 
#                            color=color_list[1], alpha=0.2, label=None)

# SW Comparison
# x_target_sw is an empty array when the SW dataset was not found, which skips this branch.
if x_target_sw.size > 0:
    ax_zimm_bench.plot(x_target_sw, y_zimmerman_sw_diff, '-', color=color_list[0], lw=2,
                       label=r"% Diff: $\frac{|E_{\text{SW Sim}}| - |E_{\text{SW Lit}}|}{|E_{\text{SW Lit}}|}$")
    
    # # SW Shaded region (using same color as line, color_list[0])
    # ax_zimm_bench.fill_between(x_target_sw, y_zimmerman_sw_diff - y_zimmerman_sw_diff_errors, 
    #                            y_zimmerman_sw_diff + y_zimmerman_sw_diff_errors, 
    #                            color=color_list[0], alpha=0.2, label="Error Region") # Labeled the SW error region

ax_zimm_bench.axhline(0, color='k', linestyle='-', lw=0.5, alpha=0.8) # Zero line
ax_zimm_bench.set_xlabel("Time [s]")
ax_zimm_bench.set_ylabel(r"% Diff",loc="top")
# ax_zimm_bench.grid(True, linestyle=':', alpha=0.6) 
# ax_zimm_bench.legend(loc='upper right', fontsize=8)
# The x axis is shared, so this limit also applies to the main and residual panels.
ax_zimm_bench.set_xlim(0,8) 
ax_zimm_bench.set_ylim(-40, 20) # A wider limit for benchmark comparison
ax_zimm_bench.set_yticks([-40,-20, 0, 20])


# --- 8. SAVE AND SHOW ---
# Written relative to the working directory, into "figures/".
plt.savefig("figures/zimmerman_benchmark_summary_with_lit_comparison.jpeg", bbox_inches="tight", dpi=300)
plt.show()

In [None]:
# External LaTeX rendering is switched off; math in labels goes through matplotlib's own text engine
mpl.rcParams['text.usetex'] = False
# Set the global font size
mpl.rcParams.update({'font.size': 12})

# --- 1. CONFIGURATION AND DATA LOADING ---
print("\n--- Starting Data Processing and Plot Generation ---")

# Define Data Keys for Fitting and Subtraction
# Each key must equal a key of all_processed, i.e. a file basename up to "_through".
KEY_FIT = 'PE_425K_initial8max0.8final12_noDissipation_sphere50um_pos-0.1'
KEY_TARGET = 'PE_425K_initial8max0.8final12_RefinedGridDissipation_sphere20um_pos-0.1'

# Load external literature data
# Paths are relative to the notebook's working directory.
zimmerman_SWdata = pd.read_csv("literature-data/Fig7a-SW.csv")
zimmerman_PEdata = pd.read_csv("literature-data/Fig7a-PE.csv")
zimmerman_PEandSWdata = pd.read_csv("literature-data/Fig7a-PE+SW.csv")

# Simulation Parameters (used for time conversion)
WORLD_XY_AREA_SQ_M = 300 * 300 / (1e6**2) # World area (m^2)

# PE Conversion Factor
# seconds per iteration = (particles per iteration / world area) / (particle flux per second)
PARTICLES_PER_ITERATION_PE = 81775
FLUX_PER_ITERATION_PE = PARTICLES_PER_ITERATION_PE / WORLD_XY_AREA_SQ_M
# presumably a current density (A/m^2) times elementary charges per coulomb — TODO confirm
PE_ION_FLUX = 4e-6 * 6.241509e18
CONVERT_ITERATION_PE_TIME = FLUX_PER_ITERATION_PE / PE_ION_FLUX
print(f"PE Conversion Factor (s/iteration): {CONVERT_ITERATION_PE_TIME:.3e}")

# SW Conversion Factor
# Same construction as the PE factor, with the SW particle count and flux.
PARTICLES_PER_ITERATION_SW = 30601
FLUX_PER_ITERATION_SW = PARTICLES_PER_ITERATION_SW / WORLD_XY_AREA_SQ_M
SW_ION_FLUX = 3e-7 * 6.241509e18
CONVERT_ITERATION_SW_TIME = FLUX_PER_ITERATION_SW / SW_ION_FLUX
print(f"SW Conversion Factor (s/iteration): {CONVERT_ITERATION_SW_TIME:.3e}")

# Color Map Setup
# One color per entry of all_processed plus one extra, sampled evenly over the colormap.
CMAP_NAME = 'Dark2'
discrete_cmap = plt.get_cmap(CMAP_NAME, len(all_processed.keys()) + 1)
color_list_rgba = [discrete_cmap(i) for i in np.linspace(0, 1, len(all_processed.keys()) + 1)]

# --- 2. CURVE FITTING AND EXTRAPOLATION ---

# Define the new fitting function (Polynomial of Order 3)
def poly_curve(t, a, b, c, d):
    """
    Third-order polynomial used as the fit model: a*t^3 + b*t^2 + c*t + d.

    Works for scalar ``t`` and, element-wise, for NumPy arrays; the coefficients
    are ordered from the cubic term down to the constant.
    """
    terms = (a * t**3, b * t**2, c * t)
    # Accumulate highest order first, then add the constant offset last.
    return terms[0] + terms[1] + terms[2] + d


# Prepare data for fitting (PE_425K)
# NOTE(review): expects the flat per-file layout with top-level "iter", "E_vecs" and
# "point_errors" arrays (as returned by the single-target process_file), not the nested
# "target_points" layout — confirm which processing cell populated all_processed.
data_fit = all_processed[KEY_FIT]
x_fit = np.array(data_fit["iter"] - 1) * CONVERT_ITERATION_PE_TIME
y_fit = abs(data_fit["E_vecs"][:, 0])
y_fit_errors = data_fit["point_errors"][:, 0] # Get errors for KEY_FIT (50um)

# Prepare Target Data (dataset selected by KEY_TARGET)
data_target = all_processed[KEY_TARGET]
x_extrapolate = np.array(data_target["iter"] - 1) * CONVERT_ITERATION_PE_TIME
y_target = abs(data_target["E_vecs"][:, 0])
y_target_errors = data_target["point_errors"][:, 0] # Get errors for KEY_TARGET (20um)

# --- Choose Fitting/Extrapolation Method ---

# Method 1: Basic Polynomial Fit (Order 3) - CURRENT DEFAULT
method_label = f"Fit of {KEY_FIT.split('_')[1]} (Poly Order 3)"
try:
    # Using y_fit_errors as sigma for weighted fit; absolute_sigma=True makes pcov
    # reflect sigma in absolute units.
    popt, pcov = curve_fit(poly_curve, x_fit, y_fit,
                           p0=[0, 0, 0, 1e4], sigma=y_fit_errors, absolute_sigma=True) # Added sigma

    A_fit, B_fit, C_fit, D_fit = popt
    print(f"\nFit Parameters for {KEY_FIT.split('_')[1]} (Poly): a={A_fit:.2e}, b={B_fit:.2e}, c={C_fit:.2e}, d={D_fit:.2e}")

    # Extrapolate and Calculate Subtraction
    y_extrapolated = poly_curve(x_extrapolate, *popt)
    y_subtraction = y_extrapolated - y_target

    # Estimate error in extrapolated fit for error propagation
    perr = np.sqrt(np.diag(pcov))
    # Jacobian of the model with respect to its PARAMETERS (a, b, c, d) at each time:
    # d/da = t^3, d/db = t^2, d/dc = t, d/dd = 1. (The time derivative [3t^2, 2t, 1, 0]
    # used previously is not the quantity that pairs with the parameter covariance.)
    J = np.column_stack([x_extrapolate**3, x_extrapolate**2, x_extrapolate,
                         np.ones_like(x_extrapolate)])
    # Row-wise J_i @ pcov @ J_i^T, i.e. diag(J pcov J^T) without building the N x N matrix.
    y_extrapolated_errors = np.sqrt(np.einsum('ij,jk,ik->i', J, pcov, J))


except RuntimeError:
    print("\n⚠️ Warning: Curve fitting failed. Check initial guess (p0) or fitting range.")
    # Stay on the target time grid: the arrays below are combined element-wise with
    # y_target_errors, so switching to the x_fit grid would break as soon as the two
    # datasets differ in length.
    y_extrapolated = np.zeros_like(x_extrapolate)
    y_subtraction = np.zeros_like(x_extrapolate)
    method_label = f"Fit of {KEY_FIT.split('_')[1]} (Failed, showing Zeros)"
    y_extrapolated_errors = np.zeros_like(x_extrapolate) # Set errors to zero if fit fails


# --- Calculate Percent Difference and its Error ---
# Exact zeros in the fit are replaced by a tiny number to avoid division by zero.
y_denominator = np.where(y_extrapolated == 0, 1e-10, y_extrapolated)
y_percent_diff = (y_subtraction / y_denominator) * 100

# Error propagation for the difference: sqrt(error_fit^2 + error_target^2)
# Direct assignment works because y_target_errors is sampled on the x_extrapolate grid
y_target_errors_interp = y_target_errors

y_subtraction_errors = np.sqrt(y_extrapolated_errors**2 + y_target_errors_interp**2)

# Error propagation for the percentage
y_percent_diff_errors = (y_subtraction_errors / np.abs(y_denominator)) * 100
# ---------------------------------------------------

# --- 3. PLOTTING SETUP (MAIN + SUBPLOT) ---

# One figure, two stacked panels: the main panel is five times taller than the
# residual panel underneath it.
fig = plt.figure(figsize=(8.01, 3.22))
gs = gridspec.GridSpec(nrows=2, ncols=1, height_ratios=[5, 1], hspace=0.1)

ax_main = fig.add_subplot(gs[0])                 # top panel
ax_sub = fig.add_subplot(gs[1], sharex=ax_main)  # bottom panel, x-axis tied to the top one

# --- 4. MAIN PLOT GENERATION (ax_main) ---

# Zimmerman reference curves. The stored "x" column is plotted as 10**x;
# only the first curve carries the legend label.
reference_style = dict(color="k", lw=3, alpha=0.3)
ax_main.plot(10**zimmerman_SWdata["x"], zimmerman_SWdata[" y"], '-',
             label="Zimmerman SW/PE/PE+SW Ref.", **reference_style)
ax_main.plot(10**zimmerman_PEdata["x"], zimmerman_PEdata[" y"], '-', **reference_style)
#ax_main.plot(10**zimmerman_PEandSWdata["x"], zimmerman_PEandSWdata[" y"], '--', color="k", lw=3, alpha=0.3)

# Colours for the simulation curves; the plotting loop indexes this list with counter `i`
color_list = [color_list_rgba[idx] for idx in (2, 3, 3)]
i = 0

# Fitted / extrapolated curve, dotted, in the last colour of color_list
ax_main.plot(x_extrapolate, y_extrapolated, ':', color=color_list[-1], lw=2,
             label=f"{method_label} (Extrapolated)")

# Draw every selected simulation curve with a shaded +/- error band.
# colorIN is not used for drawing, but zipping with color_list_rgba caps the number of
# keys visited at len(color_list_rgba), so the zip is kept.
for keyIN, colorIN in zip(all_processed.keys(), color_list_rgba):

    entry = all_processed[keyIN]
    metadata = entry["metadata"]

    case = metadata["case"]
    tempIN = metadata["temperature"]
    targetIN = metadata["target_point"]

    # Iteration number -> time, with a case-specific conversion factor
    factor = CONVERT_ITERATION_PE_TIME if case == "PE" else CONVERT_ITERATION_SW_TIME
    x_data = np.array(entry["iter"] - 1) * factor
    y_data = abs(entry["E_vecs"][:, 0])
    y_err = entry["point_errors"][:, 0]

    # Keep only negative-x targets at T = 425 whose fourth "_" field does not contain "Total"
    is_selected = (targetIN[0] < 0) & (tempIN == 425) & ("Total" not in keyIN.split("_")[3])
    if not is_selected:
        continue

    print(keyIN)
    plot_color = color_list[i]

    # The fitted dataset is represented by the extrapolated curve instead; it is skipped
    # here and does not advance the colour counter.
    if keyIN == KEY_FIT:
        continue

    ax_main.plot(x_data, y_data, '-', color=plot_color, lw=1.5)

    # Shaded error region; label=None keeps it out of the legend
    ax_main.fill_between(x_data, y_data - y_err, y_data + y_err,
                         color=plot_color, alpha=0.15,
                         label=None)
    i += 1

# Main-panel axis cosmetics: label, fixed y-range, scientific notation on y
ax_main.set_ylabel(r"$|E_x|$ (V/m)")
ax_main.set_ylim(0, 2.4e5)
ax_main.ticklabel_format(axis='y', style='sci', scilimits=(0, 0))

# Hide the x tick labels of the main panel; the residual panel below shares the x-axis
for tick_label in ax_main.get_xticklabels():
    tick_label.set_visible(False)

# --- 5. SUBTRACTION PLOT GENERATION (ax_sub) ---

# Percent difference between the fitted and the target curve
percent_diff_label = r"Relative Error: $\frac{|E_{Fit}| - |E_{Target}|}{|E_{Fit}|}$"
ax_sub.plot(x_extrapolate, y_percent_diff, '-', color=color_list_rgba[3], lw=2,
            label=percent_diff_label)

# The error band (y_percent_diff_errors) is currently not drawn on this panel.

# Axis label, fixed ranges and a sparse set of y ticks. Because the x-axis is shared,
# the x-limits also apply to the main panel.
ax_sub.set_xlabel("Time [s]")
ax_sub.set_xlim(0, 6)
ax_sub.set_ylim(0, 10)
ax_sub.set_yticks([0, 4, 8])

# --- 6. SAVE AND SHOW ---
# savefig does not create missing folders, so make sure the output directory exists
# (otherwise a fresh checkout fails here with FileNotFoundError).
os.makedirs("figures", exist_ok=True)
fig.savefig("figures/zimmerman_benchmark_summary.jpeg", bbox_inches="tight", dpi=300)
plt.show()

In [None]:
# Echo `discrete_cmap` as this cell's output; the name is defined outside this cell.
discrete_cmap

In [None]:
# Convert an iteration number to its photoemission (PE) equivalent time.
# Iterations are 1-based, hence the (iteration - 1).
iteration = 90
pe_equivalent_time = (iteration - 1) * CONVERT_ITERATION_PE_TIME
print(f"PE equivalent time for iteration#{iteration}: {pe_equivalent_time}")

## 2D Representation of the Electric Field

In [None]:
## READ IN STACKED SPHERES GEOMETRY ##

# Surface mesh of the grain geometry, loaded from an STL file with trimesh
stacked_spheres_stl = '../sphere-charging/geometry/isolated_grain_interpolated.stl'
stacked_spheres = trimesh.load_mesh(stacked_spheres_stl)

# Uncomment for a trimesh preview of the geometry:
# stacked_spheres.show()

In [None]:
# Show the centroid of the loaded geometry; later cells use it as the slice-plane centre.
stacked_spheres.centroid

In [None]:
iteration = 8
configIN = "onlyphotoemission"

# Alternative data location kept for reference:
#directory = "/storage/scratch1/5/avira7/Grain-Charging-Simulation-Data/build-dissipationRefinedGrid-initial8max0.8final12/"
directory = "../build-wang-comparison"

# Single-digit iterations are matched with a "00" prefix; larger ones by the bare number.
# NOTE(review): the bare-number pattern (e.g. "*13*") also matches any file name that merely
# contains those digits - confirm the fieldmap naming scheme rules that out.
iteration_tag = f"00{iteration}" if iteration < 10 else f"{iteration}"
filenames = sorted(glob.glob(f"{directory}/fieldmaps/*{iteration_tag}*{configIN}*.txt"))
print(filenames)

# The result is indexed as df[iteration] further down, so it is presumably keyed by iteration.
df = read_data_format_efficient(filenames, scaling=True)

# check to make sure this matches the total nodes in outputlogs
#len(df[iteration]["E_mag"])

In [None]:
## SETTINGS HERE ARE OPTIMIZED FOR ITERATION 86 ##
# NOTE(review): the loading cell above sets iteration = 8 - confirm these settings still apply.

# Field map of the currently selected iteration
fieldIN = df[iteration]

center = stacked_spheres.centroid  # geometry centroid; used below as slice-plane centre and slice origin
N_DOWNSAMPLE_EMAG = 1  # stride applied to the filtered field points (1 keeps every point)
ARROW_VOXEL_SPACING = 0.02  # X/Z bin size used to thin the arrows to one per voxel
Y_SLICE = 0.0 + center[1]  # y-coordinate of the ZX slice (through the centroid)
THICKNESS = 0.001  # samples with |y - Y_SLICE| < THICKNESS supply the arrows
VECTOR_SCALE_FACTOR = 1e-6 # `factor` passed to glyph(); values tried before: 2e-7, 2e-3, 5e-6
FIELD_AVERAGE_RADIUS = 2.5e-3 # radius of the red marker sphere drawn at red_point; previously 2e-3

vmin, vmax = (-2e5, 2e5) # colour limits applied to the "Ex_val" slice (not a log scale)
red_point = np.array([-0.1, 0, 0.1 - 0.015 + 0.037]) # centre of the red marker sphere
# ----------------------------------------------------
# Voxel Downsampling Helper Function
# Ensures uniform spatial distribution of points in the slice
# ----------------------------------------------------
def voxel_downsample_points(points, spacing):
    """
    Selects one point per voxel defined by the spacing.
    Assumes points are 3D, but only uses X and Z for 2D density control.
    """
    # 1. Normalize coordinates to voxel indices (focus on X and Z for the 2D slice)
    min_x, _, min_z = points.min(axis=0)
    
    # Calculate bin indices for the points
    # We use X (column 0) and Z (column 2)
    x_indices = np.floor((points[:, 0] - min_x) / spacing).astype(int)
    z_indices = np.floor((points[:, 2] - min_z) / spacing).astype(int)
    
    # Combine X and Z indices into a unique hash/key
    max_x_index = x_indices.max() + 1
    voxel_keys = z_indices * max_x_index + x_indices

    # 2. Find the unique keys and their first occurrence
    # `return_index=True` gives the index of the first occurrence of each unique key
    unique_keys, unique_indices = np.unique(voxel_keys, return_index=True)
    
    return unique_indices

# ----------------------------------------------------
# Step 0: Load, Filter, and Downsample Data (Single Pass)
# ----------------------------------------------------
start_time = time.time()

# Keep only samples above z = 0 that carry a non-zero field magnitude
initial_mask = (fieldIN["pos"][:, 2] > 0) & (fieldIN["E_mag"] > 0)
points, vectors, magnitudes = (
    fieldIN[array_key][initial_mask] for array_key in ("pos", "E", "E_mag")
)

# Optional stride-based thinning (N_DOWNSAMPLE_EMAG = 1 keeps every sample)
stride = slice(None, None, N_DOWNSAMPLE_EMAG)
points_ds = points[stride]
vectors_ds = vectors[stride]
magnitudes_ds = magnitudes[stride]

# PyVista point cloud carrying the (linear, not log) magnitude and the x / z field components
point_cloud = pv.PolyData(points_ds)
point_cloud["E_mag"] = magnitudes_ds
point_cloud["Ex_val"] = vectors_ds[:, 0]
point_cloud["Ez_val"] = vectors_ds[:, 2]

print(f"Starting points (filtered by z > 0 & mag > 0): {len(points)}")

# ----------------------------------------------------
# Step 1: Geometry Setup and Slicing
# ----------------------------------------------------
start_time_geo = time.time()

# 1a. Wrap the trimesh surface as PyVista PolyData. The face array prepends the vertex
#     count (3, triangles) to every face, hence the extra column of 3s.
triangle_cells = np.hstack([np.full((len(stacked_spheres.faces), 1), 3), stacked_spheres.faces])
pv_spheres = pv.PolyData(stacked_spheres.vertices, triangle_cells).compute_normals()

# Keep only the part of the geometry inside the bounding box of the field samples
bbox_bounds = point_cloud.bounds  # (xmin, xmax, ymin, ymax, zmin, zmax)
bbox = pv.Box(bounds=bbox_bounds)
pv_spheres_cropped = pv_spheres.clip_box(bbox, invert=False)

# 1b. Slice plane: the ZX plane, i.e. normal along Y
normal = [0, 1, 0]

# Regular 250 x 250 plane centred on the geometry centroid; the field is sampled onto it.
# (The previously defined `plane_bounds` list was never used and has been removed.)
field_slice_mesh = pv.Plane(
    center=center,
    direction=normal,
    i_size=bbox_bounds[5] - bbox_bounds[4],  # Z span
    j_size=bbox_bounds[1] - bbox_bounds[0],  # X span
    i_resolution=250,
    j_resolution=250,
)

# Transfer the point-cloud arrays onto the plane. With strategy='closest_point' a plane
# node that finds no sample inside `radius` takes the value of its nearest sample rather
# than `null_value`, so the slice has no holes.
field_slice_interpolated = field_slice_mesh.interpolate(
    point_cloud,
    strategy='closest_point',
    radius=0.001,
    sharpness=3.0,
    null_value=1,
)

# Pin the plane to the slice coordinate
field_slice_interpolated.points[:, 1] = Y_SLICE

# Outline of the geometry where it crosses the same plane
geo_slice = pv_spheres_cropped.slice(normal=normal, origin=center)
print(f"Geometry and slicing preparation complete in {time.time() - start_time_geo:.2f}s")

# ----------------------------------------------------
# Step 2: Vector Field Glyphs (Arrows)
# ----------------------------------------------------
start_time_vectors = time.time()

# 2a. Keep only the (already downsampled) samples within THICKNESS of the slice plane
vector_mask = np.abs(points_ds[:, 1] - Y_SLICE) < THICKNESS
points_slice_full = points_ds[vector_mask]
vectors_slice_full = vectors_ds[vector_mask]
magnitudes_slice_full = magnitudes_ds[vector_mask]

# 2b. One arrow per ARROW_VOXEL_SPACING-sized (X, Z) voxel for a uniform arrow density.
#     Integer-array indexing returns copies, so the in-place edits in 2c do not touch
#     the *_slice_full arrays.
unique_indices = voxel_downsample_points(points_slice_full, ARROW_VOXEL_SPACING)

points_slice = points_slice_full[unique_indices]
vectors_slice = vectors_slice_full[unique_indices]
magnitudes_slice = magnitudes_slice_full[unique_indices]

# ----------------------------------------------------
# Arrow-length clamp
# ----------------------------------------------------
# glyph() scales each arrow by magnitude * VECTOR_SCALE_FACTOR. Capping the magnitude at
# ARROW_VOXEL_SPACING / VECTOR_SCALE_FACTOR / 2 therefore limits the scaled length to
# half the voxel spacing, so neighbouring arrows do not overlap.
MAGNITUDE_MAX_CLAMP = ARROW_VOXEL_SPACING / VECTOR_SCALE_FACTOR /2
magnitudes_slice_clamped = np.clip(magnitudes_slice, a_min=None, a_max=MAGNITUDE_MAX_CLAMP)

# 2c. Create a PolyData object for glyphs
# Arrow anchors are placed at a common y, offset 2*THICKNESS from the slice
# (presumably so the arrows are not hidden by the coloured plane - confirm).
points_slice[:,1] = Y_SLICE - 2*THICKNESS
# Zero the out-of-plane component so the arrows lie in the ZX plane.
# (This used to be assigned `0.0 - 2*THICKNESS`, a copy-paste of the position offset
# that gave every vector a spurious Y component.)
vectors_slice[:,1] = 0.0
slice_mesh_vectors = pv.PolyData(points_slice)
slice_mesh_vectors['vectors'] = vectors_slice
# Arrow length follows the CLAMPED magnitude; direction follows 'vectors'
slice_mesh_vectors['magnitude'] = magnitudes_slice_clamped
#slice_mesh_vectors['magnitude'] = np.log10(magnitudes_slice)

print(f"Points in vector slice (after density control): {len(points_slice)}, old length: {len(points_slice_full)}...")

# 2d. Create the glyphs
arrow = pv.Arrow(tip_length=0.3, tip_radius=0.2, shaft_radius=0.04)
glyphs = slice_mesh_vectors.glyph(
    orient='vectors',
    scale='magnitude',
    factor=VECTOR_SCALE_FACTOR,
    geom=arrow
)

# ----------------------------------------------------
# Step 3: Visualization
# ----------------------------------------------------
pl = pv.Plotter()
pl.set_background('white')

# E_x on the slice plane, with a fixed colour range and a horizontal scalar bar
pl.add_mesh(
    field_slice_interpolated,
    scalars="Ex_val",
    cmap="YlGnBu",
    opacity=1,
    show_edges=False,
    clim=[vmin, vmax],   # fixed colour range
    # --- COLORBAR POSITIONING ---
    scalar_bar_args={
        'title':None,                 # no title on the scalar bar
        'vertical': False,            # Make it horizontal
        'position_x': 0.20,           # User-specified start position
        'position_y': 0.12,           # User-specified vertical position
        'width': 0.6,                 # User-specified width
        'height': 0.05,               # User-specified height
    }
)

# Add sliced geometry (outline only)
pl.add_mesh(geo_slice, color="black", line_width=5,opacity=0.5)

# Add vector glyphs
pl.add_mesh(glyphs, color='black', show_scalar_bar=False, line_width=4,opacity=1)

# Red marker sphere at red_point
sphere = pv.Sphere(radius=FIELD_AVERAGE_RADIUS, center=red_point)
pl.add_mesh(sphere, color="red", opacity=1)

# Force 2D (orthographic) projection and camera alignment for the ZX slice
pl.enable_parallel_projection()
pl.enable_2d_style()

# Align camera perpendicular to the slice
pl.view_xz()

# Save the figure before showing it. screenshot() does not create missing folders,
# so make sure the output directory exists first.
os.makedirs("figures", exist_ok=True)
pl.screenshot(f'figures/fieldvectors_{configIN}#{iteration}.jpeg', scale=4)

# Show the plot
print(f"Total execution time: {time.time() - start_time:.2f}s")
pl.show(jupyter_backend='static')

In [None]:
# -----------------------------------------
# Electric pressure on the grain surface for df[iteration]: inputs and filtering
# -----------------------------------------
fieldIN = df[iteration]
FIELD_AVERAGE_RADIUS = 2.5e-3
vmin, vmax = (-0.005, 0.005)  # colour limits for the pressure plot

# Centroid of the loaded trimesh geometry. (This previously referenced an undefined name,
# `stacked_spheres_centroid`, which raised NameError on a fresh kernel.)
geometry_center = stacked_spheres.centroid
red_point = np.array([-0.1, 0., 0.1 + 0.037]) + geometry_center

# ----------------------------------------------------
# Step 0: Filter data
# ----------------------------------------------------
start_time = time.time()
points = fieldIN["pos"]
vectors = fieldIN["E"]
magnitudes = fieldIN["E_mag"]

# Drop samples with zero field magnitude
initial_mask = (magnitudes > 0)
points = points[initial_mask]
vectors = vectors[initial_mask]
magnitudes = magnitudes[initial_mask]

# Vacuum permittivity [F/m]. NOTE: this rebinds the `epsilon_0` imported from
# scipy.constants at the top of the notebook.
epsilon_0 = 8.854187817e-12

# Point cloud carrying the three field components, used as the interpolation source
field_cloud = pv.PolyData(points)
field_cloud["E_x"] = vectors[:, 0]
field_cloud["E_y"] = vectors[:, 1]
field_cloud["E_z"] = vectors[:, 2]

print(f"Starting points (filtered): {len(points)}")
 
# ----------------------------------------------------
# Step 1: Geometry Setup from vertices
# ----------------------------------------------------
start_time_geo = time.time()

# Vertex coordinates of the loaded trimesh geometry as an (N, 3) array.
# (np.array(stacked_spheres) would wrap the Trimesh object itself instead of its
# vertices, which PolyData cannot consume.)
vertices = np.asarray(stacked_spheres.vertices)

# Create a point cloud
cloud = pv.PolyData(vertices)

# Rebuild a surface from the bare vertices
print("before the construction")
pv_spheres = cloud.reconstruct_surface(nbr_neighbors=10)
print("finished the construction")
pv_spheres.compute_normals(inplace=True)

# Crop the mesh to the bounding box of the field samples, then recompute point and
# cell normals on the extracted surface
bbox_bounds = field_cloud.bounds
bbox = pv.Box(bounds=bbox_bounds)
pv_spheres_cropped = (
    pv_spheres
    .clip_box(bbox, invert=False)
    .extract_surface()
    .compute_normals(point_normals=True, cell_normals=True, inplace=False)
)
 
# ============================================================
# Field at the face centres -> x-directed electric pressure
# ============================================================
# Sample the field cloud at the centre of every surface cell
face_centers = pv_spheres_cropped.cell_centers().points
face_field_cloud = pv.PolyData(face_centers)
face_field_interp = face_field_cloud.interpolate(
    field_cloud,
    strategy='closest_point',
    radius=0.002,
    sharpness=3.0,
    null_value=0.0,
)

# Per-face field components and the stacked (n_faces, 3) field vectors
E_x_faces, E_y_faces, E_z_faces = (
    face_field_interp[component] for component in ("E_x", "E_y", "E_z")
)
E_vec_faces = np.stack([E_x_faces, E_y_faces, E_z_faces], axis=1)

# Cell normals and their x component
face_normals = pv_spheres_cropped.cell_normals
nx = face_normals[:, 0]

# Row-wise dot products: E.n and |E|^2
E_dot_n = np.einsum('ij,ij->i', E_vec_faces, face_normals)
E_mag_sq = np.einsum('ij,ij->i', E_vec_faces, E_vec_faces)

# Normal pressure: epsilon_0 * (E_n^2 - |E|^2 / 2)
maxwell_term = E_dot_n**2 - 0.5 * E_mag_sq
P_normal_faces = epsilon_0 * maxwell_term

# x-directed part: the normal pressure weighted by the x component of the face normal
P_x_faces = P_normal_faces * nx
pv_spheres_cropped.cell_data["electric_pressure_x"] = P_x_faces

print(f"Computed x-directed electric pressure in {time.time() - start_time_geo:.2f}s")
 
# ============================================================
# Plotting using Px
# ============================================================
pl = pv.Plotter()
pl.set_background('white')

# Colour the cropped surface by the per-cell x-directed pressure
pressure_mesh_kwargs = dict(
    scalars="electric_pressure_x",
    preference="cell",
    cmap="seismic",
    clim=[vmin, vmax],
    opacity=1,
    show_edges=False,
    interpolate_before_map=False,
)
pl.add_mesh(pv_spheres_cropped, **pressure_mesh_kwargs)

pl.view_xy()
pl.show(jupyter_backend='static')
 
# ============================================================
# Extract Line Plot Data (y=0 top surface)
# ============================================================
y_tolerance = 0.01
z_min = 0.08

x_centers = face_centers[:, 0]
y_centers = face_centers[:, 1]
z_centers = face_centers[:, 2]

Px_faces = P_x_faces  # shorthand

# Faces within y_tolerance of y = 0 and above z_min
# NOTE(review): the y window is centred on 0, not on geometry_center[1] - confirm intended.
line_mask = (np.abs(y_centers) < y_tolerance) & (z_centers > z_min)

x_line = x_centers[line_mask]
z_line = z_centers[line_mask]
Px_line = Px_faces[line_mask]

if x_line.size == 0:
    raise ValueError("No mesh faces fall inside the |y| < y_tolerance, z > z_min selection; nothing to bin")

# Bin + average
x_bin_width = 0.005
x_min, x_max = x_line.min(), x_line.max()
x_bins = np.arange(x_min, x_max + x_bin_width, x_bin_width)
x_bin_centers = (x_bins[:-1] + x_bins[1:]) / 2

bin_indices = np.digitize(x_line, x_bins)

x_line_avg, z_line_avg, Px_line_avg = [], [], []

# Average over every occupied bin. Iterating over the indices that actually occur
# (instead of range(1, len(x_bins))) also keeps samples at or beyond the last bin
# edge, which np.digitize labels len(x_bins) and which were silently dropped before
# (e.g. the largest x when the span is a whole number of bins, or a single sample).
for i in np.unique(bin_indices):
    mask = (bin_indices == i)
    x_line_avg.append(x_line[mask].mean())
    z_line_avg.append(z_line[mask].mean())
    Px_line_avg.append(Px_line[mask].mean())

# Sort by mean x
x_line_sorted_PE = np.array(x_line_avg)
z_line_sorted_PE = np.array(z_line_avg)
pressure_line_sorted_PE = np.array(Px_line_avg)

sort_idx = np.argsort(x_line_sorted_PE)
x_line_sorted_PE = x_line_sorted_PE[sort_idx]
z_line_sorted_PE = z_line_sorted_PE[sort_idx]
pressure_line_sorted_PE = pressure_line_sorted_PE[sort_idx]

print(f"Extracted {len(x_line)} raw points, averaged into {len(x_line_sorted_PE)} bins along y=0 line")

In [None]:
# ----------------------------------------------------
# Step 3: Visualization
# ----------------------------------------------------
# Time this cell on its own; otherwise the "Total execution time" printed at the end
# is measured from whichever earlier cell last set start_time.
start_time = time.time()

# Pin the E_x colour range here. The pressure cell that precedes this one rebinds
# vmin / vmax to +/-0.005, which would saturate the whole slice if inherited.
vmin, vmax = (-2e5, 2e5)

pl = pv.Plotter()
pl.set_background('white')

# E_x on the slice plane, with a fixed colour range and a horizontal scalar bar
pl.add_mesh(
    field_slice_interpolated,
    scalars="Ex_val",
    cmap="YlGnBu",
    opacity=1,
    show_edges=False,
    clim=[vmin, vmax],   # fixed colour range
    # --- COLORBAR POSITIONING ---
    scalar_bar_args={
        'title':None,                 # no title on the scalar bar
        'vertical': False,            # Make it horizontal
        'position_x': 0.20,           # User-specified start position
        'position_y': 0.12,           # User-specified vertical position
        'width': 0.6,                 # User-specified width
        'height': 0.05,               # User-specified height
    }
)

# Add sliced geometry (outline only)
pl.add_mesh(geo_slice, color="black", line_width=5,opacity=0.5)

# Add vector glyphs
pl.add_mesh(glyphs, color='black', show_scalar_bar=False, line_width=4,opacity=1)



# Reference marker position (rebound by the loop below)
red_point = np.array([-0.1, 0, 0.1 - 0.015 + 0.036])

# Radius of each averaging sphere
FIELD_AVERAGE_RADIUS = 2e-3

# Offset grid: half-widths along each direction, stepped by a third of the half-width
offsetLimitsX, offsetLimitsY = 0.009, 0.011
new_step_x = offsetLimitsX / 3
new_step_y = offsetLimitsY / 3

# NOTE(review): np.arange excludes the stop value, so whether +offsetLimits shows up as a
# last element depends on floating-point rounding of the step.
xoffset = np.arange(-offsetLimitsX, offsetLimitsX, new_step_x).round(4)
yoffset = np.arange(-offsetLimitsY, offsetLimitsY, new_step_y).round(4)

# Every (x-offset, y-offset) pair becomes one target point in the y = 0 plane;
# the Y grid offsets are applied to the z coordinate.
X, Y = np.meshgrid(xoffset, yoffset)
target_points_array = np.array([
    -0.1 - X.flatten(),
    np.zeros(X.size),
    0.1 - 0.015 + 0.037 - Y.flatten(),
]).T

print(f"Processing {len(target_points_array)} target points with radius {FIELD_AVERAGE_RADIUS} mm")

# One opaque red sphere per target point
for red_point in target_points_array:
    sphere = pv.Sphere(radius=FIELD_AVERAGE_RADIUS, center=red_point)
    pl.add_mesh(sphere, color="red", opacity=1)



# Circle marker drawn as 100 points in the XZ plane around red_point

red_point = np.array([-0.1, 0, 0.1 - 0.015 + 0.037])
center = red_point            # the circle is centred on the marker position
radius_mm = (40 / 1000) / 2   # 20 µm

# Parametrise the circle by angle; y is held at the centre's y so it stays in the XZ plane
theta = np.linspace(0, 2*np.pi, num=100)
x = radius_mm * np.cos(theta) + center[0]
y = np.full_like(theta, center[1])
z = radius_mm * np.sin(theta) + center[2]

# (N, 3) coordinates -> point set added to the plot
points = np.column_stack((x, y, z))
polyline = pv.PolyData(points)
pl.add_mesh(polyline, color='k', point_size=0.5, opacity=0.8)


# (A vertical guide line at x = -0.1 used to be drawn here; it is currently disabled.)

# Orthographic projection with 2D interaction, camera looking at the ZX slice
pl.enable_parallel_projection()
pl.enable_2d_style()
pl.view_xz()

# Screenshot export is disabled for this overview figure:
#pl.screenshot(f'figures/fieldvectors_{configIN}#{iteration}.jpeg', scale=4)

# Report the elapsed time, then render statically in the notebook
elapsed_seconds = time.time() - start_time
print(f"Total execution time: {elapsed_seconds:.2f}s")
pl.show(jupyter_backend='static')

In [None]:
# Define field averaging parameters
FIELD_AVERAGE_RADIUS = 2e-3

# Offset grid: half-widths along each direction, stepped by a third of the half-width
offsetLimitsX, offsetLimitsY = 0.009, 0.011
new_step_x = offsetLimitsX / 3
new_step_y = offsetLimitsY / 3

# NOTE(review): np.arange excludes the stop value, so whether +offsetLimits shows up as a
# last element depends on floating-point rounding of the step.
xoffset = np.arange(-offsetLimitsX, offsetLimitsX, new_step_x).round(4)
yoffset = np.arange(-offsetLimitsY, offsetLimitsY, new_step_y).round(4)

# Every (x-offset, y-offset) pair becomes one target point in the y = 0 plane;
# the Y grid offsets are applied to the z coordinate.
X, Y = np.meshgrid(xoffset, yoffset)
target_points_array = np.array([
    -0.1 - X.flatten(),
    np.zeros(X.size),
    0.1 - 0.015 + 0.037 - Y.flatten(),
]).T

print(f"Processing {len(target_points_array)} target points with radius {FIELD_AVERAGE_RADIUS} mm")

In [None]:
offsetLimitsX = 0.009
new_step_x = offsetLimitsX / 3

# np.arange(start, stop, step) excludes `stop`, and no epsilon is added to it here.
# NOTE(review): with a floating-point step the presence of a final +0.009 element depends
# on rounding - inspect xoffset before relying on its length.
xoffset = np.arange(-offsetLimitsX, offsetLimitsX, new_step_x).round(4)

In [None]:
# Inspect the x offsets produced by np.arange (check whether the +0.009 end point is present)
xoffset

In [None]:
# Inspect the y offsets produced by np.arange (check whether the +0.011 end point is present)
yoffset

In [None]:
# Stand-alone horizontal colorbar for the E_x slice figures.

# Use Matplotlib's built-in mathtext (external LaTeX rendering is switched OFF)
mpl.rcParams['text.usetex'] = False
# Set the global font size
mpl.rcParams.update({'font.size': 16})

cmap = plt.cm.YlGnBu
vmin, vmax = (-2e5, 2e5)  # E_x limits in V/m, matching the slice plots (linear, not log)
norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)

# Very flat axes: the colorbar is the whole figure
fig, ax = plt.subplots(figsize=(10, 0.1))

# --- 1. Create and Configure the ScalarFormatter ---
# Referenced through `mpl.ticker` because no bare `ticker` name is imported in this
# notebook; matplotlib.ticker is loaded as a side effect of importing pyplot.
formatter = mpl.ticker.ScalarFormatter(useMathText=True)

formatter.set_useOffset(False)      # no additive offset on the tick labels
formatter.set_powerlimits((0, 0))   # always use scientific notation

# --- 2. Create the Colorbar and apply the Formatter ---
cb = mpl.colorbar.ColorbarBase(
    ax,
    cmap=cmap,
    norm=norm,
    orientation='horizontal', label=r"E$_x$ (V/m)"
)

# Apply the formatter to the colorbar's x-axis
cb.ax.xaxis.set_major_formatter(formatter)

# --- 3. Display the Plot ---
plt.show()

## 3D Representation of the Electric Pressure

In [None]:
# Iterations to compare for the solar-wind (SW) and photoemission (PE) runs.
# Iterations are 1-based, hence the (iteration - 1) in the time conversion.
iteration_SW = 10
iteration_PE = 13

sw_equivalent_time = (iteration_SW - 1) * CONVERT_ITERATION_SW_TIME
pe_equivalent_time = (iteration_PE - 1) * CONVERT_ITERATION_PE_TIME
print(f"SW equivalent time for iteration#{iteration_SW}: {sw_equivalent_time}")
print(f"PE equivalent time for iteration#{iteration_PE}: {pe_equivalent_time}")

In [None]:
configIN = "onlyphotoemission"
directory = "../build-zimmerman-comparison"
# Alternative data location kept for reference:
#directory = "/storage/scratch1/5/avira7/Grain-Charging-Simulation-Data/build-dissipationRefinedGrid-initial8max0.8final12"

# Single-digit iterations are matched with a "00" prefix; larger ones by the bare number.
# NOTE(review): the bare-number pattern (e.g. "*13*") also matches any file name that merely
# contains those digits - confirm the fieldmap naming scheme rules that out.
iteration_tag = f"00{iteration_PE}" if iteration_PE < 10 else f"{iteration_PE}"
filenames = sorted(glob.glob(f"{directory}/fieldmaps/*{iteration_tag}*{configIN}*.txt"))
print(filenames)

# Indexed as df_PE[iteration_PE] further down, so presumably keyed by iteration
df_PE = read_data_format_efficient(filenames, scaling=True)

In [None]:
configIN = "onlysolarwind"
directory = "../build-zimmerman-SWonly-comparison"
# Alternative data location kept for reference:
#directory = "/storage/scratch1/5/avira7/Grain-Charging-Simulation-Data/build-dissipationRefinedGrid-initial8max0.8final12"

# Single-digit iterations are matched with a "00" prefix; larger ones by the bare number.
# NOTE(review): the bare-number pattern (e.g. "*10*") also matches any file name that merely
# contains those digits - confirm the fieldmap naming scheme rules that out.
iteration_tag = f"00{iteration_SW}" if iteration_SW < 10 else f"{iteration_SW}"
filenames = sorted(glob.glob(f"{directory}/fieldmaps/*{iteration_tag}*{configIN}*.txt"))
print(filenames)

# Indexed as df_SW[iteration_SW] further down, so presumably keyed by iteration
df_SW = read_data_format_efficient(filenames, scaling=True)

In [None]:
# -----------------------------------------
# Photoemission (PE) case: inputs and field filtering
# -----------------------------------------
fieldIN = df_PE[iteration_PE]
FIELD_AVERAGE_RADIUS = 2.5e-3
vmin, vmax = (-0.005, 0.005)  # colour limits for the pressure plot

geometry_center = stacked_spheres.centroid
red_point = np.array([-0.1, 0., 0.1 + 0.037]) + geometry_center

# ----------------------------------------------------
# Step 0: Filter data
# ----------------------------------------------------
start_time = time.time()
points = fieldIN["pos"]
vectors = fieldIN["E"]
magnitudes = fieldIN["E_mag"]

# Keep non-zero-field samples inside a slab of +/-0.1 in y around the geometry centre
y_lower = -0.1 + geometry_center[1]
y_upper = 0.1 + geometry_center[1]
initial_mask = (magnitudes > 0) & (points[:, 1] >= y_lower) & (points[:, 1] <= y_upper)
points, vectors, magnitudes = points[initial_mask], vectors[initial_mask], magnitudes[initial_mask]

# Vacuum permittivity; rebinds the epsilon_0 imported from scipy.constants at the top of the notebook
epsilon_0 = 8.854187817e-12

# Point cloud carrying the three field components, used as the interpolation source
field_cloud = pv.PolyData(points)
for axis_index, array_name in enumerate(("E_x", "E_y", "E_z")):
    field_cloud[array_name] = vectors[:, axis_index]

print(f"Starting points (filtered): {len(points)}")
 
# ----------------------------------------------------
# Step 1: Geometry Setup
# ----------------------------------------------------
start_time_geo = time.time()

# Wrap the trimesh surface as PyVista PolyData; each triangle is prefixed with its
# vertex count (3) in the face array.
n_faces = len(stacked_spheres.faces)
vtk_faces = np.hstack([np.full((n_faces, 1), 3), stacked_spheres.faces])
pv_spheres = pv.PolyData(stacked_spheres.vertices, vtk_faces).compute_normals()

# Crop to the bounding box of the field samples, then recompute point and cell normals
# on the extracted surface
bbox_bounds = field_cloud.bounds
bbox = pv.Box(bounds=bbox_bounds)
clipped_surface = pv_spheres.clip_box(bbox, invert=False).extract_surface()
pv_spheres_cropped = clipped_surface.compute_normals(
    point_normals=True, cell_normals=True, inplace=False
)
 
# ============================================================
# Field at the face centres -> x-directed electric pressure
# ============================================================
# Sample the field cloud at the centre of every surface cell
face_centers = pv_spheres_cropped.cell_centers().points

face_field_cloud = pv.PolyData(face_centers)
face_field_interp = face_field_cloud.interpolate(
    field_cloud,
    strategy='closest_point',
    radius=0.002,
    sharpness=3.0,
    null_value=0.0,
)

# Per-face field components and the stacked (n_faces, 3) field vectors
E_x_faces, E_y_faces, E_z_faces = (
    face_field_interp[component] for component in ("E_x", "E_y", "E_z")
)
E_vec_faces = np.stack([E_x_faces, E_y_faces, E_z_faces], axis=1)

# Cell normals and their x component
face_normals = pv_spheres_cropped.cell_normals
nx = face_normals[:, 0]

# Row-wise dot products: E.n and |E|^2
E_dot_n = np.einsum('ij,ij->i', E_vec_faces, face_normals)
E_mag_sq = np.einsum('ij,ij->i', E_vec_faces, E_vec_faces)

# Normal pressure: epsilon_0 * (E_n^2 - |E|^2 / 2)
maxwell_term = E_dot_n**2 - 0.5 * E_mag_sq
P_normal_faces = epsilon_0 * maxwell_term

# x-directed part: the normal pressure weighted by the x component of the face normal
P_x_faces = P_normal_faces * nx

# Store as per-cell data on the cropped surface
pv_spheres_cropped.cell_data["electric_pressure_x"] = P_x_faces

print(f"Computed x-directed electric pressure in {time.time() - start_time_geo:.2f}s")
 
# ============================================================
# Plotting using Px
# ============================================================
pl = pv.Plotter()
pl.set_background('white')

# Colour the cropped surface by the per-cell x-directed pressure
pressure_mesh_kwargs = dict(
    scalars="electric_pressure_x",
    preference="cell",
    cmap="seismic",
    clim=[vmin, vmax],
    opacity=1,
    show_edges=False,
    interpolate_before_map=False,
)
pl.add_mesh(pv_spheres_cropped, **pressure_mesh_kwargs)

pl.view_xy()
pl.show(jupyter_backend='static')
 
# ============================================================
# Extract Line Plot Data (y=0 top surface)
# ============================================================
y_tolerance = 0.01
z_min = 0.08

x_centers = face_centers[:, 0]
y_centers = face_centers[:, 1]
z_centers = face_centers[:, 2]

Px_faces = P_x_faces  # shorthand

# Faces within y_tolerance of y = 0 and above z_min
# NOTE(review): the y window is centred on 0, while the field filter above is centred on
# geometry_center[1] - confirm this is intended.
line_mask = (np.abs(y_centers) < y_tolerance) & (z_centers > z_min)

x_line = x_centers[line_mask]
z_line = z_centers[line_mask]
Px_line = Px_faces[line_mask]

if x_line.size == 0:
    raise ValueError("No mesh faces fall inside the |y| < y_tolerance, z > z_min selection; nothing to bin")

# Bin + average
x_bin_width = 0.005
x_min, x_max = x_line.min(), x_line.max()
x_bins = np.arange(x_min, x_max + x_bin_width, x_bin_width)
x_bin_centers = (x_bins[:-1] + x_bins[1:]) / 2

bin_indices = np.digitize(x_line, x_bins)

x_line_avg, z_line_avg, Px_line_avg = [], [], []

# Average over every occupied bin. Iterating over the indices that actually occur
# (instead of range(1, len(x_bins))) also keeps samples at or beyond the last bin
# edge, which np.digitize labels len(x_bins) and which were silently dropped before
# (e.g. the largest x when the span is a whole number of bins, or a single sample).
for i in np.unique(bin_indices):
    mask = (bin_indices == i)
    x_line_avg.append(x_line[mask].mean())
    z_line_avg.append(z_line[mask].mean())
    Px_line_avg.append(Px_line[mask].mean())

# Sort by mean x
x_line_sorted_PE = np.array(x_line_avg)
z_line_sorted_PE = np.array(z_line_avg)
pressure_line_sorted_PE = np.array(Px_line_avg)

sort_idx = np.argsort(x_line_sorted_PE)
x_line_sorted_PE = x_line_sorted_PE[sort_idx]
z_line_sorted_PE = z_line_sorted_PE[sort_idx]
pressure_line_sorted_PE = pressure_line_sorted_PE[sort_idx]

print(f"Extracted {len(x_line)} raw points, averaged into {len(x_line_sorted_PE)} bins along y=0 line")

In [None]:
# Show the geometry centroid assigned to geometry_center in the pressure cell above
geometry_center

In [None]:
# -----------------------------------------
# Solar-wind (SW) case: inputs and field filtering
# -----------------------------------------
fieldIN = df_SW[iteration_SW]
FIELD_AVERAGE_RADIUS = 2.5e-3
vmin, vmax = (-0.01, 0.01)  # colour limits for the pressure plot

geometry_center = stacked_spheres.centroid
red_point = np.array([-0.1, 0., 0.1 + 0.037]) + geometry_center

# ----------------------------------------------------
# Step 0: Filter data
# ----------------------------------------------------
start_time = time.time()
points = fieldIN["pos"]
vectors = fieldIN["E"]
magnitudes = fieldIN["E_mag"]

# Keep non-zero-field samples inside a slab of +/-0.1 in y around the geometry centre
y_lower = -0.1 + geometry_center[1]
y_upper = 0.1 + geometry_center[1]
initial_mask = (magnitudes > 0) & (points[:, 1] >= y_lower) & (points[:, 1] <= y_upper)
points, vectors, magnitudes = points[initial_mask], vectors[initial_mask], magnitudes[initial_mask]

# Vacuum permittivity; rebinds the epsilon_0 imported from scipy.constants at the top of the notebook
epsilon_0 = 8.854187817e-12

# Point cloud carrying the three field components, used as the interpolation source
field_cloud = pv.PolyData(points)
for axis_index, array_name in enumerate(("E_x", "E_y", "E_z")):
    field_cloud[array_name] = vectors[:, axis_index]

print(f"Starting points (filtered): {len(points)}")
 
# ----------------------------------------------------
# Step 1: Geometry Setup
# ----------------------------------------------------
start_time_geo = time.time()

# Wrap the trimesh surface as PyVista PolyData; each triangle is prefixed with its
# vertex count (3) in the face array.
n_faces = len(stacked_spheres.faces)
vtk_faces = np.hstack([np.full((n_faces, 1), 3), stacked_spheres.faces])
pv_spheres = pv.PolyData(stacked_spheres.vertices, vtk_faces).compute_normals()

# Crop to the bounding box of the field samples, then recompute point and cell normals
# on the extracted surface
bbox_bounds = field_cloud.bounds
bbox = pv.Box(bounds=bbox_bounds)
clipped_surface = pv_spheres.clip_box(bbox, invert=False).extract_surface()
pv_spheres_cropped = clipped_surface.compute_normals(
    point_normals=True, cell_normals=True, inplace=False
)
 
# ============================================================
# Interpolate field to face centers
# ============================================================
face_centers = pv_spheres_cropped.cell_centers().points
 
face_field_cloud = pv.PolyData(face_centers)
face_field_interp = face_field_cloud.interpolate(
    field_cloud,
    radius=0.002,
    strategy='closest_point',
    sharpness=3.0,
    null_value=0.0
)
 
# Extract face-centered E-fields
E_x_faces = face_field_interp["E_x"]
E_y_faces = face_field_interp["E_y"]
E_z_faces = face_field_interp["E_z"]
E_vec_faces = np.stack([E_x_faces, E_y_faces, E_z_faces], axis=1)
 
# Face normals
face_normals = pv_spheres_cropped.cell_normals
nx = face_normals[:, 0]   # <-- x-direction component
 
# ============================================================
# Compute Maxwell electric pressure (normal)
# ============================================================
E_dot_n = np.einsum('ij,ij->i', E_vec_faces, face_normals)
E_mag_sq = np.einsum('ij,ij->i', E_vec_faces, E_vec_faces)
 
# Normal pressure (scalar)
P_normal_faces = epsilon_0 * (E_dot_n**2 - 0.5 * E_mag_sq)
 
# ============================================================
# Compute X-directed electric pressure
# ============================================================
P_x_faces = P_normal_faces * nx   # <-- THIS IS WHAT YOU WANTED
 
# Save to cell data
pv_spheres_cropped.cell_data["electric_pressure_x"] = P_x_faces
 
print(f"Computed x-directed electric pressure in {time.time() - start_time_geo:.2f}s")
 
# ============================================================
# Plotting using Px
# ============================================================
# Static top-down (xy) rendering of the cropped surface, coloured per face by the
# "electric_pressure_x" cell array within [vmin, vmax].
pl = pv.Plotter()
pl.set_background('white')

px_mesh_style = dict(
    scalars="electric_pressure_x",
    cmap="seismic",
    opacity=1,
    show_edges=False,
    clim=[vmin, vmax],
    interpolate_before_map=False,
    preference="cell",
)
pl.add_mesh(pv_spheres_cropped, **px_mesh_style)

pl.view_xy()
pl.show(jupyter_backend='static')
 
# ============================================================
# Extract Line Plot Data (y=0 top surface) -- solar-wind case
# ============================================================
# Selects the faces close to the y = 0 plane above z_min, then averages x, z and the
# x-directed pressure in fixed-width x bins.
#
# The results are stored under *_SW names because this cell processes
# df_SW[iteration_SW]; writing them to the *_PE names would overwrite the photoemission
# results of the preceding cell and mislabel them in the later plots.
y_tolerance = 0.01
z_min = 0.08

x_centers = face_centers[:, 0]
y_centers = face_centers[:, 1]
z_centers = face_centers[:, 2]

Px_faces = P_x_faces  # shorthand

# NOTE(review): the line is selected around absolute y = 0, whereas the field filter is
# centred on geometry_center[1] -- confirm the geometry is centred at y = 0.
line_mask = (np.abs(y_centers) < y_tolerance) & (z_centers > z_min)
if not line_mask.any():
    raise ValueError(
        f"No faces found with |y| < {y_tolerance} and z > {z_min}; "
        "cannot extract the y=0 line."
    )

x_line = x_centers[line_mask]
z_line = z_centers[line_mask]
Px_line = Px_faces[line_mask]

# Bin + average
x_bin_width = 0.005
x_min, x_max = x_line.min(), x_line.max()
x_bins = np.arange(x_min, x_max + x_bin_width, x_bin_width)
x_bin_centers = (x_bins[:-1] + x_bins[1:]) / 2

bin_indices = np.digitize(x_line, x_bins)

x_line_avg, z_line_avg, Px_line_avg = [], [], []

# Iterate over the bin indices that actually occur. A fixed range(1, len(x_bins)) would
# skip index len(x_bins), which np.digitize assigns to values >= the last edge (x_max
# itself when the span is a multiple of the bin width, or every point when
# x_min == x_max), silently dropping those faces.
for i in np.unique(bin_indices):
    mask = (bin_indices == i)
    x_line_avg.append(x_line[mask].mean())
    z_line_avg.append(z_line[mask].mean())
    Px_line_avg.append(Px_line[mask].mean())

# Sort by x (bins are already visited in ascending order; kept as a cheap safeguard)
x_line_sorted_SW = np.array(x_line_avg)
z_line_sorted_SW = np.array(z_line_avg)
pressure_line_sorted_SW = np.array(Px_line_avg)

sort_idx = np.argsort(x_line_sorted_SW)
x_line_sorted_SW = x_line_sorted_SW[sort_idx]
z_line_sorted_SW = z_line_sorted_SW[sort_idx]
pressure_line_sorted_SW = pressure_line_sorted_SW[sort_idx]

print(f"Extracted {len(x_line)} raw points, averaged into {len(x_line_sorted_SW)} bins along y=0 line")

In [None]:
# Standalone horizontal colourbar for the electric pressure, range [-1, 1].
# NOTE(review): this cell rebinds the notebook-global `vmin`, `vmax`, `cmap` and `norm`;
# any cell run afterwards that reads them (e.g. as `clim`) will see these values.

# Use matplotlib's built-in mathtext rather than an external LaTeX installation
mpl.rcParams['text.usetex'] = False
# Set the global font size
mpl.rcParams.update({'font.size': 14})

cmap = plt.cm.seismic
vmin, vmax = (-1, 1)
norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)

fig, ax = plt.subplots(figsize=(5, 0.1))

# --- 1. Create and configure the ScalarFormatter ---
# Accessed through `mpl.ticker`: the bare name `ticker` is not imported in the
# notebook's import cell, so using it would raise NameError on a fresh kernel.
formatter = mpl.ticker.ScalarFormatter(useMathText=True)
formatter.set_useOffset(False)
formatter.set_powerlimits((0, 0))  # always use scientific notation (x10^n)

# --- 2. Create the colourbar and apply the formatter ---
# A ScalarMappable carries the norm/cmap; `cax=ax` draws the bar into the thin axes.
cb = fig.colorbar(
    mpl.cm.ScalarMappable(norm=norm, cmap=cmap),
    cax=ax,
    orientation='horizontal',
    label=r"Electric Pressure (Pa)",
)

# Apply the formatter to the colourbar's x-axis
cb.ax.xaxis.set_major_formatter(formatter)

# --- 3. Display the plot ---
plt.show()

In [None]:
# Standalone horizontal colourbar for the electric pressure, range [-0.05, 0.05].
# NOTE(review): this cell rebinds the notebook-global `vmin`, `vmax`, `cmap` and `norm`;
# any cell run afterwards that reads them (e.g. as `clim`) will see these values.

# Use matplotlib's built-in mathtext rather than an external LaTeX installation
mpl.rcParams['text.usetex'] = False
# Set the global font size
mpl.rcParams.update({'font.size': 14})

cmap = plt.cm.seismic
vmin, vmax = (-0.05, 0.05)
norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)

fig, ax = plt.subplots(figsize=(5, 0.1))

# --- 1. Create and configure the ScalarFormatter ---
# Accessed through `mpl.ticker`: the bare name `ticker` is not imported in the
# notebook's import cell, so using it would raise NameError on a fresh kernel.
formatter = mpl.ticker.ScalarFormatter(useMathText=True)
formatter.set_useOffset(False)
formatter.set_powerlimits((0, 0))  # always use scientific notation (x10^n)

# --- 2. Create the colourbar and apply the formatter ---
# A ScalarMappable carries the norm/cmap; `cax=ax` draws the bar into the thin axes.
cb = fig.colorbar(
    mpl.cm.ScalarMappable(norm=norm, cmap=cmap),
    cax=ax,
    orientation='horizontal',
    label=r"Electric Pressure (Pa)",
)

# Apply the formatter to the colourbar's x-axis
cb.ax.xaxis.set_major_formatter(formatter)

# --- 3. Display the plot ---
plt.show()

In [None]:

# ============================================================
# Geometry trace: outline of the top surface near y = 0
# ============================================================
# Draws only the grey scatter of face centres (no axes, ticks or pressure curve) and
# saves it as a transparent SVG.

fig, ax = plt.subplots(figsize=(8, 3))
ax.set_axis_off()

# Faces near the y = 0 plane, reaching 0.02 below the z_min used for the line extraction
top_mask = (np.abs(y_centers) < y_tolerance) & (z_centers > z_min - 0.02)
x_top = x_centers[top_mask]
z_top = z_centers[top_mask]

# The trace lives on a twin axis. `ax2` is kept as a global name because the following
# cell of the notebook refers to it.
ax2 = ax.twinx()
ax2.scatter(x_top * 1000, z_top * 1000, c='gray', alpha=0.2, s=3, label=None)
ax2.set_xlim([-150, 150])
ax2.set_axis_off()

# Make sure the output directory exists so savefig cannot fail on a fresh checkout
os.makedirs("figures", exist_ok=True)
fig.savefig("figures/sphere_trace.svg", transparent=True, dpi=300)
plt.show()

In [None]:

# ============================================================
# Create Line Plot
# ============================================================
# Binned x-directed electric pressure along the y = 0 line versus x position.

fig, ax = plt.subplots(figsize=(8, 3))

# Linear y-axis: the pressure is signed (zero reference line, negative lower limit),
# and a log axis can neither draw negative values nor accept a non-positive y-limit.
ax.plot(x_line_sorted_PE * 1000, pressure_line_sorted_PE, '-', linewidth=2, color='g',
        label='Photoemission', zorder=5)
ax.axhline(0, color='k', linestyle='-', lw=0.5, alpha=0.8)
ax.set_xlabel(r'X Position ($\mu$m)', fontsize=12)
ax.set_ylabel('Electric Pressure (Pa)', fontsize=12)
ax.set_ylim([-0.6, 0.3])
# Limits are set on this figure's own axes; a leftover `ax2` from an earlier cell
# belongs to a different figure (and does not exist on a fresh kernel).
ax.set_xlim([-150, 150])

# Optional solar-wind overlay on a twin axis of THIS figure (requires
# x_line_sorted_SW / pressure_line_sorted_SW to be defined):
# ax2 = ax.twinx()
# ax2.plot(x_line_sorted_SW * 1000, pressure_line_sorted_SW, '-', linewidth=2, color='b',
#          label='Solar Wind', zorder=5)
# ax2.set_ylim([-0.02, 0.01])

# Make sure the output directory exists so savefig cannot fail on a fresh checkout
os.makedirs("figures", exist_ok=True)
fig.savefig("figures/electric_pressure_linesplot.jpeg",transparent=True,dpi=300,bbox_inches="tight")
plt.show()

In [None]:
# ============================================================
# Optional: Show 3D visualization with line highlighted
# ============================================================
# Semi-transparent pressure-coloured surface with the binned line drawn on top in yellow.

pl = pv.Plotter()
pl.set_background('white')

# Add full mesh with face colors.
# "electric_pressure_x" is the cell array stored on pv_spheres_cropped by the pressure
# computation cell; no array named "electric_pressure" is created there.
# NOTE(review): vmin/vmax are whatever the most recently executed cell assigned (the
# colourbar cells rebind them) -- confirm the intended colour limits.
pl.add_mesh(
    pv_spheres_cropped,
    scalars="electric_pressure_x",
    cmap="seismic",
    opacity=0.7,
    show_edges=False,
    clim=[vmin, vmax],
    interpolate_before_map=False,
    preference="cell"
)

# Highlight the extracted line points (placed on the y = 0 plane).
# NOTE(review): uses the binned *_PE line arrays; the unsuffixed x_line_sorted /
# z_line_sorted names are not defined in the visible notebook -- switch to another
# extraction here if that is what should be highlighted.
line_points_3d = np.column_stack([
    x_line_sorted_PE,
    np.zeros_like(x_line_sorted_PE),
    z_line_sorted_PE,
])
n_line_points = len(line_points_3d)

# Skip the markers entirely when nothing was extracted, rather than handing the
# plotter an empty point set.
if n_line_points > 0:
    line_polydata = pv.PolyData(line_points_3d)
    pl.add_mesh(line_polydata, color='yellow', point_size=2, render_points_as_spheres=True)

# Add connecting line: one 2-point cell [2, k, k+1] per pair of consecutive points
if n_line_points > 1:
    segment_starts = np.arange(n_line_points - 1)
    line_cells = np.column_stack([
        np.full(n_line_points - 1, 2),
        segment_starts,
        segment_starts + 1,
    ]).ravel()
    line_mesh = pv.PolyData(line_points_3d, lines=line_cells)
    pl.add_mesh(line_mesh, color='yellow', line_width=5)

pl.view_xy()
pl.show(jupyter_backend='static')
