In [1]:
import os
from datetime import datetime
import numpy as np
import pandas as pd
import random
import xarray as xr
from numba import njit
import plotly.graph_objects as go

# --------------------------
# SETTINGS
# --------------------------
case = "RPN"

# Planet parameters
RM = 2440.0          # Mercury radius [km]
RC = 2400.0          # radius of the seed sphere inside the conductive layer [km] (40 km below RM)

# Radius [km] of the sphere the seeds are placed on. The same value is later passed
# to the tracer as its stop radius and to classify() as the surface-hit radius.
# The case ID is validated BEFORE any path is built or directory created, so a typo
# fails fast instead of leaving an empty output folder behind.
if case in ["RPS", "RPN"]:
    plot_depth = RM
elif case in ["CPS", "CPN"]:
    plot_depth = RC
else:
    raise ValueError("Invalid case ID, pick one of RPS, RPN, CPS, CPN")

# input_folder = f"/Volumes/data_backup/extreme_base/{case}_Base/plane_product/object/"
# ncfile = os.path.join(input_folder, f"Amitis_{case}_Base_115000_xz_comp.nc")
input_folder = f"/Volumes/data_backup/extreme_base/{case}_Base/05/out/"
ncfile = os.path.join(input_folder, f"Amitis_{case}_Base_115000.nc")

output_folder = f"/Users/danywaller/Projects/mercury/extreme/jfield_topology/{case}_Base/"
os.makedirs(output_folder, exist_ok=True)

# Seed settings
n_lat = 60           # number of seed latitudes, -90..90 deg inclusive
n_lon = n_lat*2      # number of seed longitudes, -180..180 deg inclusive
max_steps = 5000     # max trajectory length per direction (points, seed included)
h_step = 50.0        # RK step length [km, grid units]; the integrated direction is a unit vector
surface_tol = 75.0   # NOTE(review): currently unused -- the trace cell does not pass it, so the
                     # tracer's default (-1.0) applies. Seeds start on |r| = plot_depth, so a
                     # tolerance comparable to h_step would end every trace after one step.

max_lines = 100      # max number of trajectories drawn per topology in the 3-D plot


In [2]:
# --------------------------
# NUMBA FUNCTIONS
# --------------------------
@njit
def _cell_and_fraction(grid, value):
    """Return (idx, frac): the bracketing cell of `value` in ascending `grid` and its offset.

    `idx` is clamped to [0, len(grid) - 2], so for points outside the grid `frac`
    falls outside [0, 1] and the caller extrapolates linearly from the edge cell.
    """
    idx = np.searchsorted(grid, value) - 1
    idx = max(0, min(idx, len(grid) - 2))
    frac = (value - grid[idx]) / (grid[idx + 1] - grid[idx])
    return idx, frac


@njit
def _lerp(lo, hi, t):
    """Linear blend lo*(1-t) + hi*t."""
    return lo * (1 - t) + hi * t


@njit
def trilinear_interp(x_grid, y_grid, z_grid, B, xi, yi, zi):
    """Trilinearly interpolate the 3-D array `B` (indexed [ix, iy, iz]) at (xi, yi, zi).

    Interpolation collapses x first, then y, then z.
    """
    i, xd = _cell_and_fraction(x_grid, xi)
    j, yd = _cell_and_fraction(y_grid, yi)
    k, zd = _cell_and_fraction(z_grid, zi)

    # Collapse along x on the four cell edges...
    edge_y0_z0 = _lerp(B[i, j, k], B[i + 1, j, k], xd)
    edge_y0_z1 = _lerp(B[i, j, k + 1], B[i + 1, j, k + 1], xd)
    edge_y1_z0 = _lerp(B[i, j + 1, k], B[i + 1, j + 1, k], xd)
    edge_y1_z1 = _lerp(B[i, j + 1, k + 1], B[i + 1, j + 1, k + 1], xd)
    # ...then along y on the two z faces...
    face_z0 = _lerp(edge_y0_z0, edge_y1_z0, yd)
    face_z1 = _lerp(edge_y0_z1, edge_y1_z1, yd)
    # ...and finally along z.
    return _lerp(face_z0, face_z1, zd)

@njit
def get_B(r, Bx, By, Bz, x_grid, y_grid, z_grid):
    """Unit direction of the interpolated vector field at point `r`.

    Each component is trilinearly interpolated on (x_grid, y_grid, z_grid).
    Returns a zero vector where the interpolated magnitude is exactly 0.
    """
    px = r[0]
    py = r[1]
    pz = r[2]
    vec = np.empty(3)
    vec[0] = trilinear_interp(x_grid, y_grid, z_grid, Bx, px, py, pz)
    vec[1] = trilinear_interp(x_grid, y_grid, z_grid, By, px, py, pz)
    vec[2] = trilinear_interp(x_grid, y_grid, z_grid, Bz, px, py, pz)
    magnitude = np.linalg.norm(vec)
    if magnitude != 0.0:
        return vec / magnitude
    return np.zeros(3)

@njit
def rk45_step(f, r, h, Bx, By, Bz, x_grid, y_grid, z_grid):
    """Advance `r` by one fixed step `h` using the Runge-Kutta-Fehlberg (RKF45) stages.

    Only the 5th-order solution (weights 16/135, 0, 6656/12825, 28561/56430,
    -9/50, 2/55) is returned. The embedded 4th-order estimate is not computed,
    so despite the name there is no error control or step-size adaptation.

    Parameters
    ----------
    f : jitted callable f(r, Bx, By, Bz, x_grid, y_grid, z_grid) -> (3,) array
        Right-hand side dr/ds; in this notebook it is get_B (unit field direction).
    r : (3,) float array
        Current position.
    h : float
        Step length; its sign selects the tracing direction.
    Bx, By, Bz, x_grid, y_grid, z_grid
        Forwarded unchanged to `f`.

    Returns
    -------
    (3,) array
        Position after the step.
    """
    # Stage nodes c = 0, 1/4, 3/8, 12/13, 1, 1/2 (Fehlberg tableau).
    k1 = f(r, Bx, By, Bz, x_grid, y_grid, z_grid)
    k2 = f(r + h*k1*0.25, Bx, By, Bz, x_grid, y_grid, z_grid)
    k3 = f(r + h*(3*k1+9*k2)/32, Bx, By, Bz, x_grid, y_grid, z_grid)
    k4 = f(r + h*(1932*k1 - 7200*k2 + 7296*k3)/2197, Bx, By, Bz, x_grid, y_grid, z_grid)
    k5 = f(r + h*(439*k1/216 - 8*k2 + 3680*k3/513 - 845*k4/4104), Bx, By, Bz, x_grid, y_grid, z_grid)
    k6 = f(r + h*(-8*k1/27 + 2*k2 - 3544*k3/2565 + 1859*k4/4104 - 11*k5/40), Bx, By, Bz, x_grid, y_grid, z_grid)
    # 5th-order combination (k2 has weight 0).
    r_next = r + h*(16*k1/135 + 6656*k3/12825 + 28561*k4/56430 - 9*k5/50 + 2*k6/55)
    return r_next

@njit
def trace_field_line_rk(seed, Bx, By, Bz, x_grid, y_grid, z_grid, RM, max_steps=5000, h=50.0, surface_tol=-1.0):
    """Integrate one field line from `seed` with fixed RK steps along the unit field direction.

    Parameters
    ----------
    seed : (3,) float array
        Starting point; stored as the first trajectory point.
    Bx, By, Bz : 3-D arrays indexed [ix, iy, iz] on x_grid, y_grid, z_grid.
    RM : float
        Stop radius; tracing halts once |r| <= RM + surface_tol.
    max_steps : int
        Capacity of the trajectory buffer (seed included).
    h : float
        Signed step length; a negative value traces against the field.
    surface_tol : float
        Offset added to RM for the stop test (default -1.0, just inside RM).

    Returns
    -------
    (traj, exit_y_boundary)
        traj : (n, 3) array of visited points, 1 <= n <= max_steps.
        exit_y_boundary : True only when the trace left the grid through a y face.

    Checks, in order: zero field at the current point (that point is not re-added),
    then for the new point: inside the stop radius, outside the grid in x or z,
    outside the grid in y.
    """
    path = np.empty((max_steps, 3), dtype=np.float64)
    path[0] = seed
    pos = seed.copy()
    left_via_y = False
    stop_radius = RM + surface_tol
    x_min, x_max = x_grid[0], x_grid[-1]
    y_min, y_max = y_grid[0], y_grid[-1]
    z_min, z_max = z_grid[0], z_grid[-1]

    for step in range(1, max_steps):
        direction = get_B(pos, Bx, By, Bz, x_grid, y_grid, z_grid)
        # get_B yields an all-zero vector where |B| == 0: the line cannot continue.
        if np.all(direction == 0.0):
            return path[:step], left_via_y

        pos = rk45_step(get_B, pos, h, Bx, By, Bz, x_grid, y_grid, z_grid)
        path[step] = pos
        kept = step + 1

        if np.linalg.norm(pos) <= stop_radius:
            return path[:kept], left_via_y
        outside_xz = pos[0] < x_min or pos[0] > x_max or pos[2] < z_min or pos[2] > z_max
        if outside_xz:
            return path[:kept], left_via_y
        if pos[1] < y_min or pos[1] > y_max:
            left_via_y = True
            return path[:kept], left_via_y

    return path, left_via_y

@njit
def classify(traj_fwd, traj_bwd, RM, exit_fwd_y=False, exit_bwd_y=False):
    """Label a traced field line from where its forward and backward halves ended.

    A half counts as reaching the body only if it took at least one step AND its
    last point satisfies |r| <= RM. The step requirement matters because this
    notebook places seeds on the sphere |r| = RM. A half that stalled at its seed
    (zero field there, trajectory of length 1) would otherwise count as a hit
    depending on round-off in |seed|. A seed in zero field would then be reported
    as "closed".

    Parameters
    ----------
    traj_fwd, traj_bwd : (n, 3) arrays from trace_field_line_rk.
    RM : float
        Radius used for the hit test.
    exit_fwd_y, exit_bwd_y : bool
        Whether each half left the grid through a y face.

    Returns
    -------
    str
        "TBD" if either half left through y or neither half hit; "closed" if both
        halves hit; "open" if exactly one did.
    """
    if exit_fwd_y or exit_bwd_y:
        return "TBD"
    hit_fwd = traj_fwd.shape[0] > 1 and np.linalg.norm(traj_fwd[-1]) <= RM
    hit_bwd = traj_bwd.shape[0] > 1 and np.linalg.norm(traj_bwd[-1]) <= RM
    if hit_fwd and hit_bwd:
        return "closed"
    if hit_fwd or hit_bwd:
        return "open"
    return "TBD"


In [3]:
# --------------------------
# CREATE SEEDS ON SPHERE
# --------------------------
# Regular lat/lon grid on the sphere of radius `plot_depth` (lat outer, lon inner).
# Duplicate points are skipped because they waste traces and bias the random
# downsampling in the plot:
#   * lon = +180 coincides with lon = -180 (linspace includes both endpoints);
#   * at lat = +/-90 every longitude maps to the same pole, so only the first is kept.
lats_surface = np.linspace(-90, 90, n_lat)
lons_surface = np.linspace(-180, 180, n_lon)
last_lon = len(lons_surface) - 1
seeds = []
for lat in lats_surface:
    at_pole = abs(lat) == 90.0
    for j, lon in enumerate(lons_surface):
        if j > 0 and (at_pole or j == last_lon):
            continue
        phi = np.radians(lat)      # latitude [rad]
        theta = np.radians(lon)    # longitude [rad]
        x_s = plot_depth*np.cos(phi)*np.cos(theta)
        y_s = plot_depth*np.cos(phi)*np.sin(theta)
        z_s = plot_depth*np.sin(phi)
        seeds.append(np.array([x_s, y_s, z_s]))
seeds = np.array(seeds)


In [4]:
# --------------------------
# LOAD VECTOR FIELD FROM NETCDF
# --------------------------
def load_field(ncfile, field="J", time_index=0):
    """Read the grid coordinates and one vector field from a NetCDF file.

    Parameters
    ----------
    ncfile : str or path-like
        NetCDF file to open with xarray.
    field : str
        Prefix of the component variables; reads f"{field}x", f"{field}y",
        f"{field}z" (default "J").
    time_index : int
        Positional index along the "time" dimension (default 0, the first record).

    Returns
    -------
    x, y, z : 1-D arrays from the "Nx", "Ny", "Nz" variables.
    Fx, Fy, Fz : component arrays with axes reversed by transpose (2, 1, 0), so they
        are indexed [ix, iy, iz] as trilinear_interp expects.
        NOTE(review): assumes the on-disk order after selecting time is (z, y, x) -- confirm.
    """
    # The context manager closes the file even if a variable lookup raises.
    with xr.open_dataset(ncfile) as ds:
        x = ds["Nx"].values
        y = ds["Ny"].values
        z = ds["Nz"].values
        components = tuple(
            np.transpose(ds[f"{field}{axis}"].isel(time=time_index).values, (2, 1, 0))
            for axis in "xyz"
        )
    return (x, y, z) + components

x, y, z, Jx, Jy, Jz = load_field(ncfile)

# Wall-clock time taken once the read has finished.
start = datetime.now()
print("Loaded {} at {}".format(ncfile, start))


Loaded /Volumes/data_backup/extreme_base/RPN_Base/05/out/Amitis_RPN_Base_115000.nc at 2026-01-15 16:40:25.866624


In [5]:
# --------------------------
# TRACE FIELD LINES
# --------------------------
# Every seed is traced both ways (h = +h_step along J, -h_step against it) and
# labeled by classify(). Only "closed" and "open" lines are kept for plotting.
# Counts include "TBD", so the discarded fraction is visible.
# surface_tol is not passed, so the tracer uses its default stop radius
# plot_depth - 1. Seeds start on |r| = plot_depth, so a tolerance comparable to
# h_step would end every trace after its first step.
lines_by_topo = {"closed": [], "open": []}
topo_counts = {"closed": 0, "open": 0, "TBD": 0}

for seed in seeds:
    traj_fwd, exit_fwd_y = trace_field_line_rk(seed, Jx, Jy, Jz, x, y, z, plot_depth, max_steps=max_steps, h=h_step)
    traj_bwd, exit_bwd_y = trace_field_line_rk(seed, Jx, Jy, Jz, x, y, z, plot_depth, max_steps=max_steps, h=-h_step)
    topo = classify(traj_fwd, traj_bwd, plot_depth, exit_fwd_y, exit_bwd_y)
    topo_counts[topo] = topo_counts.get(topo, 0) + 1
    if topo in lines_by_topo:
        lines_by_topo[topo].append(traj_fwd)
        lines_by_topo[topo].append(traj_bwd)

classtime = datetime.now()
print(f"Classified all lines at {str(classtime)}")
print(f"Tracing took {classtime - start}; topology counts over {len(seeds)} seeds: {topo_counts}")


Classified all lines at 2026-01-15 16:40:59.858275


In [6]:
# --------------------------
# PLOT 3D FIELD LINES
# --------------------------
colors = {"closed": "blue", "open": "red"}
SAMPLE_SEED = 0  # fixed RNG seed so the downsampled set of lines is identical on every re-run


def _add_planet(fig, radius):
    """Draw a sphere of `radius` at the origin: light grey for X > 0, black for X <= 0."""
    colat, lon = np.meshgrid(
        np.linspace(0, np.pi, 100),      # colatitude
        np.linspace(0, 2*np.pi, 200),    # longitude
    )
    xs = radius * np.sin(colat) * np.cos(lon)
    ys = radius * np.sin(colat) * np.sin(lon)
    zs = radius * np.cos(colat)

    # Points outside a hemisphere are set to NaN so each Surface draws only its half.
    hemispheres = (
        (xs > 0, np.ones_like(xs), 'lightgrey'),   # X > 0
        (xs <= 0, np.zeros_like(xs), 'black'),     # X <= 0
    )
    for mask, surfacecolor, color in hemispheres:
        fig.add_trace(go.Surface(
            x=np.where(mask, xs, np.nan),
            y=np.where(mask, ys, np.nan),
            z=np.where(mask, zs, np.nan),
            surfacecolor=surfacecolor,
            colorscale=[[0, color], [1, color]],
            cmin=0,
            cmax=1,
            showscale=False
        ))


def _add_field_lines(fig, lines_by_topo, colors, max_lines, rng):
    """Add each topology's trajectories as 3-D lines, at most `max_lines` per topology.

    When a topology has more than `max_lines` trajectories, a random subset is drawn
    with `rng` (a random.Random). In this notebook the forward and backward halves of
    a seed are separate entries, so a sampled half may appear without its partner.
    Only the first trace of each topology is shown in the legend.
    """
    for topo, lines in lines_by_topo.items():
        if len(lines) > max_lines:
            lines_to_plot = rng.sample(lines, max_lines)
        else:
            lines_to_plot = lines
        for n, traj in enumerate(lines_to_plot):
            fig.add_trace(go.Scatter3d(
                x=traj[:, 0], y=traj[:, 1], z=traj[:, 2],
                mode='lines',
                line=dict(color=colors[topo], width=2),
                name=topo,
                showlegend=(n == 0),
            ))


fig = go.Figure()
_add_planet(fig, RM)
_add_field_lines(fig, lines_by_topo, colors, max_lines, random.Random(SAMPLE_SEED))

fig.update_layout(
    width=1000,
    height=800,
    scene=dict(
        xaxis=dict(title='X [km]', range=[-8*RM, 8*RM]),
        yaxis=dict(title='Y [km]', range=[-8*RM, 8*RM]),
        zaxis=dict(title='Z [km]', range=[-8*RM, 8*RM]),
        aspectmode='cube'
    ),
    title=f"{case} Current Field Line Topology"
)

out_html = f"{case}_J_vector_topology.html"
fig.write_html(os.path.join(output_folder, out_html), include_plotlyjs="cdn")
fig.write_image(os.path.join(output_folder, out_html.replace(".html", ".png")), scale=2)
