In [None]:
import disco
import numpy as np
import h5py
from matplotlib import pyplot as plt
from astropy import units as u
from astropy import constants
from scipy.constants import elementary_charge
import ai.cs
from mpl_toolkits.mplot3d import Axes3D


# Load Regridded File

In [None]:
# Read every dataset of the regridded file into memory. Opening read-only inside a
# `with` block guarantees the handle is closed even if a read fails (older h5py
# defaulted to mode 'a', which could create or modify the file).
regrid_data = {}
with h5py.File('regrid.h5', 'r') as hdf:
    for key in hdf.keys():
        regrid_data[key] = hdf[key][:]

# Physical units of the stored datasets; each array becomes an astropy Quantity.
REGRID_UNITS = {
    'Bx': u.nT,
    'By': u.nT,
    'Bz': u.nT,
    'Ex': u.mV / u.m,
    'Ey': u.mV / u.m,
    'Ez': u.mV / u.m,
    'n': u.cm**(-3),
    'T': u.eV,
    'xaxis': u.R_earth,
    'yaxis': u.R_earth,
    'zaxis': u.R_earth,
}
for key, unit in REGRID_UNITS.items():
    regrid_data[key] = regrid_data[key] * unit

# Make Axes Instance

In [None]:
# Two-point time axis spanning one day either side of t = 0. The field arrays built
# below repeat one static snapshot along time, so two samples are enough.
taxis = np.array([-1, 1]) * u.day

# Spatial axes come straight from the regridded file; r_inner is the radius of the
# inner boundary sphere (2.5 R_E).
spatial_axes = [regrid_data[axis_name] for axis_name in ('xaxis', 'yaxis', 'zaxis')]
axes = disco.Axes(*spatial_axes, taxis, r_inner=2.5 * u.R_earth)
axes

In [None]:
# Spatial shape of the regridded Bx array (no time axis yet).
np.shape(regrid_data['Bx'])

# Make FieldModel Instance

In [None]:
# Particle species: protons.
mass = constants.m_p
charge = elementary_charge * u.coulomb


def _repeat_along_time(field, n_times):
    """Tile a static field along a new trailing time axis.

    Parameters
    ----------
    field : astropy Quantity
        Field sampled on the spatial grid.
    n_times : int
        Number of time samples to repeat the field over.

    Returns
    -------
    Quantity of shape ``field.shape + (n_times,)`` in ``field.unit`` whose every
    time slice equals ``field``.
    """
    return np.repeat(field.value[..., np.newaxis], n_times, axis=-1) * field.unit


FIELD_COMPONENTS = ('Bx', 'By', 'Bz', 'Ex', 'Ey', 'Ez')
Bx, By, Bz, Ex, Ey, Ez = (
    _repeat_along_time(regrid_data[name], taxis.size) for name in FIELD_COMPONENTS
)
fields = (Bx, By, Bz, Ex, Ey, Ez)

# Cells where every E and B component is exactly zero presumably hold no data
# (e.g. outside the source model's domain) -- mark them NaN.
no_data = np.logical_and.reduce([f.value == 0 for f in fields])
print('Fraction of cells set to NaN:', no_data.sum() / no_data.size)
for f in fields:
    f[no_data] = np.nan

field_model = disco.FieldModel(Bx, By, Bz, Ex, Ey, Ez, mass, charge, axes)
field_model

# Make ParticleState

In [None]:

# Starting positions: a grid of magnetic longitude x invariant latitude, mapped along
# dipole field lines to a fixed geocentric radius.
particle_height = 3 * u.R_earth

particle_long_axis = np.arange(0, 2*np.pi, .1)  # magnetic longitude (radians)
particle_invlat_axis = np.arange(65, 85, .1)    # invariant magnetic latitude (degrees)

# Dipole field line: r = L R_E cos^2(lat) with L = 1 / cos^2(invlat), so the line
# crosses radius r at cos^2(lat) = (r / R_E) cos^2(invlat). np.cos expects radians,
# hence deg2rad. Field lines with L < r / R_E never reach r and would give NaN
# (65-85 deg at 3 R_E is safely inside the valid range).
particle_lat_axis = np.arccos(
    np.sqrt((particle_height / u.R_earth) * np.cos(np.deg2rad(particle_invlat_axis))**2)
)

# indexing='ij' returns grids in argument order, so (lat, long) unpack correctly;
# this matches the (lat, long, energy) grid built later.
particle_lat, particle_long = np.meshgrid(particle_lat_axis, particle_long_axis, indexing='ij')

# NOTE(review): assumes ai.cs.sp2cart(r, azimuth, latitude) -- confirm against ai.cs.
pos_x, pos_y, pos_z = ai.cs.sp2cart(particle_height, particle_long, particle_lat)

In [None]:

fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111, projection='3d')

# Unit sphere standing in for the Earth (plot coordinates are in R_E).
uu = np.linspace(0, 2 * np.pi, 100)
vv = np.linspace(0, np.pi, 100)
x = np.outer(np.cos(uu), np.sin(vv))
y = np.outer(np.sin(uu), np.sin(vv))
z = np.outer(np.ones(np.size(uu)), np.cos(vv))
ax.plot_surface(x, y, z)

# Axes3D.plot needs 1-D coordinates; the start positions are a 2-D (lat x long)
# grid, so flatten them to plot each grid point as a marker.
ax.plot(np.ravel(pos_x), np.ravel(pos_y), np.ravel(pos_z), '.', zorder=100)
ax.set_xlim(-5, 5)
ax.set_ylim(-5, 5)
ax.set_zlim(-5, 5)
ax.set_xlabel('X (SM)')
ax.set_ylabel('Y (SM)')
ax.set_zlabel('Z (SM)')
ax.set_title('MARBLE Grid');

In [None]:
# 30 kinetic energies, logarithmically spaced from 5 eV to 20 keV.
energy_axis = np.logspace(np.log10(5), np.log10(20e3), 30) * u.eV
energy_axis

In [None]:
# Full (lat x long x energy) grid of initial conditions; indexing='ij' keeps that axis
# order so every flattened array below shares one particle ordering.
particle_lat, particle_long, particle_energy = np.meshgrid(
    particle_lat_axis, particle_long_axis, energy_axis, indexing='ij')
pos_x, pos_y, pos_z = ai.cs.sp2cart(particle_height, particle_long, particle_lat)

# Express positions in R_E. u.Quantity converts if sp2cart already returned a length
# Quantity (a second `* u.R_earth` would give R_E**2) and attaches the unit if it
# returned bare numbers.
pos_x = u.Quantity(pos_x, u.R_earth)
pos_y = u.Quantity(pos_y, u.R_earth)
pos_z = u.Quantity(pos_z, u.R_earth)

# Exact relativistic kinematics for kinetic energy K and rest energy m c^2:
#   gamma = 1 + K / (m c^2),   p c = sqrt(K^2 + 2 K m c^2),   v = p / (gamma m)
rest_energy = (mass * constants.c**2).to(u.eV)
gamma = 1 + particle_energy / rest_energy
ppar = (np.sqrt(particle_energy**2 + 2 * particle_energy * rest_energy)
        / constants.c).to(u.kg * u.m / u.s)
particle_vel = (ppar / (gamma * mass)).to(u.m / u.s)

# Zero magnetic moment: all momentum is field-aligned (no perpendicular component).
magnetic_moment = np.zeros(ppar.shape) * u.MeV/u.nT

In [None]:
# One particle per element of the flattened (lat, long, energy) grids. Use the same
# `mass` handed to the FieldModel so the species cannot silently diverge.
particle_state = disco.ParticleState(
    pos_x.flatten(), pos_y.flatten(), pos_z.flatten(),
    ppar.flatten(), magnetic_moment.flatten(), mass, charge,
)

# Make a TraceConfig

In [None]:
# Trace each particle backwards in time for one hour. rtol is presumably the
# integrator's relative tolerance; 1e-2 is loose (fast but coarse).
trace_kwargs = {
    'integrate_backwards': True,
    'output_freq': None,
    'rtol': 1e-2,
    'h_initial': 5 * u.ms,
    't_final': -1 * u.hr,
}
config = disco.TraceConfig(**trace_kwargs)

# Perform Traces

In [None]:
import time

# perf_counter is monotonic and high-resolution, unlike wall-clock time.time,
# so the elapsed time is immune to system clock adjustments.
start_time = time.perf_counter()
hist = disco.trace_trajectory(config, particle_state, field_model)
elapsed = time.perf_counter() - start_time
print('took', elapsed, 's')

In [None]:
# Final particle positions, compared against the regrid axes (in R_E).
xx = hist.x.flatten()
yy = hist.y.flatten()
zz = hist.z.flatten()

BOUNDARY_TOL = 0.1  # R_E; how close an endpoint must be to a boundary to count as hitting it

# Outer boundary: within tolerance of either end of any of the three box axes.
out_bd = np.zeros(xx.shape, dtype=bool)
for coord, axis_key in ((xx, 'xaxis'), (yy, 'yaxis'), (zz, 'zaxis')):
    axis_values = regrid_data[axis_key].value
    out_bd |= np.abs(coord - axis_values[0]) < BOUNDARY_TOL
    out_bd |= np.abs(coord - axis_values[-1]) < BOUNDARY_TOL

# Inner boundary: within tolerance of the r_inner sphere. u.Quantity accepts r_inner
# either as a Quantity or as a bare number already in R_E.
r_inner = u.Quantity(field_model.axes.r_inner, u.R_earth).to_value(u.R_earth)
inner_bd = np.abs(np.sqrt(xx**2 + yy**2 + zz**2) - r_inner) < BOUNDARY_TOL

# Anything that hit neither boundary is counted as stopping at the time limit;
# computing it from the union avoids double-subtracting points flagged twice.
time_limit = ~(out_bd | inner_bd)

fig, ax = plt.subplots()
ax.bar([0, 1, 2], [time_limit.sum(), inner_bd.sum(), out_bd.sum()])
ax.set_xticks([0, 1, 2])
ax.set_xticklabels(['Time Limit', 'Inner Boundary', 'Outer Boundary'])
ax.set_ylabel('Number of particles')
ax.set_title('Particle Trace Endpoint');

In [None]:
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111, projection='3d')

# Plot only endpoints that reached the inner or outer boundary, coloured by the
# particle's initial kinetic energy in keV.
b = inner_bd | out_bd
energy_kev = u.Quantity(particle_energy, u.eV).to_value(u.keV).flatten()
im = ax.scatter(hist.x.flatten()[b], hist.y.flatten()[b], hist.z.flatten()[b],
                c=energy_kev[b], s=.1)
fig.colorbar(im, ax=ax, shrink=0.6, label='Initial energy (keV)')
ax.set_xlim(-16, 16)
ax.set_ylim(-16, 16)
ax.set_zlim(-16, 16)
ax.set_xlabel('X (SM)')
ax.set_ylabel('Y (SM)')
ax.set_zlabel('Z (SM)')
ax.set_title('Trace endpoints at a boundary');
