In [None]:
import disco
from disco.readers import GenericHdf5FieldModel

import numpy as np
import h5py
from matplotlib import pyplot as plt
from astropy import units as u
from astropy import constants
from scipy.constants import elementary_charge
import ai.cs
from mpl_toolkits.mplot3d import Axes3D
import time
import pandas as pd

import cupy 
# Index of the CUDA device that CuPy is pointed at for this notebook.
# NOTE(review): hard-coded; presumably this fails on hosts with fewer than
# four visible GPUs -- confirm before running elsewhere.
gpu_device_index = 3
cupy.cuda.runtime.setDevice(gpu_device_index)

# Load Regridded File

In [None]:
# Path of the regridded HDF5 field file, resolved relative to the notebook's
# working directory.
regridded_field_path = 'regrid2.h5'
field_model = GenericHdf5FieldModel(regridded_field_path)

# Make ParticleState

In [None]:
# --- Launch grid configuration (consumed by the cells below) ---------------

# Radial distance of the launch shell.
particle_height = 3     # Re

# Energy axis: `energy_count` samples between the two bounds.
energy_min = 5          # eV
energy_max = 2.0e4      # eV
energy_count = 30

# Invariant-latitude range and spacing.
inv_lat_min = 75        # deg
inv_lat_max = 85        # deg
inv_lat_step = 1        # deg

# Longitude spacing.
long_step = 1           # deg

In [None]:
# Angular axes of the launch grid, in radians. np.arange excludes its stop
# value, so 360 deg and inv_lat_max themselves are not sampled.
particle_long_axis = np.deg2rad(np.arange(0, 360, long_step)) # magnetic longitude
particle_invlat_axis = np.deg2rad(np.arange(inv_lat_min, inv_lat_max, inv_lat_step)) # invariant magnetic latitude

# Latitude on the launch shell derived from each invariant latitude.
# NOTE(review): for a dipole field line cos^2(lat) = r * cos^2(invlat) (r in Re);
# this expression uses sin and 1/r instead -- confirm it is the intended mapping.
particle_lat_axis = np.arcsin(np.sqrt((1 / particle_height) * np.sin(particle_invlat_axis)**2))

# 2-D (longitude, latitude) preview grid. The outputs are named in the same
# order as the meshgrid inputs, so particle_long holds longitudes and
# particle_lat holds latitudes (previously the two names were swapped and the
# sp2cart call swapped them back). The sp2cart argument order now matches the
# 3-D grid cell further down; the resulting pos_x/pos_y/pos_z are unchanged.
particle_long, particle_lat = np.meshgrid(particle_long_axis, particle_lat_axis, indexing='ij')
pos_x, pos_y, pos_z = ai.cs.sp2cart(particle_height, particle_lat, particle_long)

In [None]:
# Energy samples spaced evenly in log10 between energy_min and energy_max
# (both endpoints included), tagged with eV.
log10_energy_lo = np.log10(energy_min)
log10_energy_hi = np.log10(energy_max)
log10_energy_samples = np.linspace(log10_energy_lo, log10_energy_hi, energy_count)
energy_axis = 10**log10_energy_samples * u.eV

In [None]:
# Full launch grid: one entry per (latitude, longitude, energy) combination.
particle_lat, particle_long, particle_energy = np.meshgrid(
    particle_lat_axis, particle_long_axis, energy_axis, indexing='ij',
)

# Cartesian launch positions on the shell of radius particle_height,
# tagged with Earth-radius units.
pos_x, pos_y, pos_z = (
    coord * u.R_earth
    for coord in ai.cs.sp2cart(particle_height, particle_lat, particle_long)
)

# Speed from the non-relativistic relation E = m v^2 / 2 using the proton
# mass, then the Lorentz factor evaluated at that speed.
proton_mass = constants.m_p
particle_vel = np.sqrt(2 * particle_energy / proton_mass)
beta = particle_vel / constants.c
gamma = 1 / np.sqrt(1 - beta ** 2)

# The whole momentum is assigned to the parallel component; the magnetic
# moment is set to zero for every particle.
ppar = gamma * proton_mass * particle_vel
magnetic_moment = np.zeros(ppar.shape) * u.MeV / u.nT

# NOTE(review): the charge is negative while the mass is the proton mass --
# confirm this pairing is intentional.
charge = -elementary_charge * u.C

In [None]:
# ParticleState is given 1-D inputs: flatten each per-particle grid, then
# append the mass and charge, which are shared by all particles.
flat_fields = [
    grid.flatten()
    for grid in (pos_x, pos_y, pos_z, ppar, magnetic_moment)
]
particle_state = disco.ParticleState(*flat_fields, constants.m_p, charge)

In [None]:
# Sanity check: launch positions projected onto the X-Y plane, drawn as
# pixel markers with equal axis scaling.
ax = plt.gca()
ax.plot(pos_x.flatten(), pos_y.flatten(), ',')
ax.set_aspect('equal')

# Make a TraceConfig

In [None]:
# Options for the backwards trace. t_final is minus infinity, so presumably
# each trace ends on iters_max or an internal stop condition rather than on a
# time limit -- TODO confirm against the disco documentation.
trace_options = dict(
    t_final=-np.inf * u.hr,
    rtol=1e-2,
    integrate_backwards=True,
    iters_max=2500,
)
config = disco.TraceConfig(**trace_options)

# Perform Traces

In [None]:
# Run the trace for every particle and report the wall-clock time it took.
start_time = time.time()
hist = disco.trace_trajectory(config, particle_state, field_model)
elapsed_seconds = time.time() - start_time
print('took', elapsed_seconds, 's')

# Plotting

In [None]:
# Time value stored at the last history step for each particle, converted to
# minutes and negated so that backwards-integrated time reads as positive.
final_step_times = hist.t[-1, :].flatten()
minutes_traced = -disco._undim_time(final_step_times).to(u.min).value

# Bin edges at 0, 1, ..., 59 minutes; log-scaled counts.
ax = plt.gca()
ax.hist(minutes_traced, bins=np.arange(60))
ax.set_yscale('log')
ax.set_xlabel('Time Backwards Integrated (minutes)')
ax.set_ylabel('Particles Finished (Bin Count)')
ax.set_title('Integration Time Required')
None  # keep the cell from echoing the last return value

In [None]:
# History step whose positions are tabulated.
# NOTE(review): the plots below are titled "Endpoints" and the histogram cell
# reads hist.t[-1, :]; confirm whether step should be -1 instead of 1.
step = 1

x_at_step = hist.x[step, :]
y_at_step = hist.y[step, :]
z_at_step = hist.z[step, :]

# One row per particle: Cartesian position at `step` and its radial distance.
# r uses all three components (it previously squared x twice and left out y).
df = pd.DataFrame(dict(
    x=x_at_step,
    y=y_at_step,
    z=z_at_step,
    r=np.sqrt(x_at_step**2 + y_at_step**2 + z_at_step**2),
))
df.head()

In [None]:
from matplotlib.colors import LogNorm

# Random subset for plotting. The size is capped at the table length because
# DataFrame.sample raises ValueError when n exceeds the number of rows
# (sampling without replacement). The rows already come back in random order,
# so no extra reshuffle is needed.
max_points = int(1e4)
df_small = df.sample(n=min(max_points, len(df)))

# (horizontal column, vertical column, colour column, figure title)
projections = [
    ('x', 'y', 'z', 'Backwards Tracing Endpoints - Equatorial'),
    ('x', 'z', 'y', 'Backwards Tracing Endpoints - Meridional'),
]

# One figure per projection; the coordinate not on an axis sets the colour.
for horiz, vert, colour, title in projections:
    plt.figure()
    plt.scatter(df_small[horiz], df_small[vert], c=df_small[colour], s=.5)
    plt.colorbar().set_label(f'{colour.upper()} (SM)')
    plt.xlabel(f'{horiz.upper()} (SM)')
    plt.ylabel(f'{vert.upper()} (SM)')
    plt.title(title)
    # Optional fixed view window:
    #plt.xlim(-15, 15)
    #plt.ylim(-15, 15)
