## Solver development
This notebook showcases the development of solvers. As an example, a plasma oscillation is simulated with an electrostatic solver and with an electromagnetic spectral solver that uses the Boris pusher.

## Set up simulation for plasma oscillation 

In [None]:
import pipic
from pipic import consts, types
import numpy as np
from numba import cfunc, carray
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML, display, clear_output, Image

In [None]:
# simulation variables
# NOTE(review): the expressions below are the Gaussian (CGS) forms; this assumes
# pipic.consts provides CGS values — confirm against the pipic documentation.
temperature = 1e-6 * consts.electron_mass * consts.light_velocity**2  # 1e-6 of the electron rest energy m_e*c^2
density = 1e18  # value returned by density_profile (presumably cm^-3 — confirm)
# Debye length: sqrt(T / (4*pi*n*e^2))
debye_length = np.sqrt(temperature / (4 * np.pi * density * consts.electron_charge**2))
# plasma period 2*pi/omega_p with omega_p^2 = 4*pi*n*e^2/m_e, i.e. sqrt(pi*m_e / (n*e^2))
plasma_period = np.sqrt(np.pi * consts.electron_mass / (density * consts.electron_charge**2))
l = 256 * debye_length  # domain length: 256 Debye lengths
xmin, xmax = -l / 2, l / 2  # domain is centred on x = 0
nx = 256  # grid size passed to pipic.init; (xmax - xmin) / nx equals one Debye length
dt = plasma_period / 256  # time step: 256 steps per plasma period

In [None]:
# Define functions for initiating the simulation
@cfunc(types.add_particles_callback)
def density_profile(r, data_double, data_int):
    """Return the particle density at position r: the constant `density`, independent of r."""
    return density

# magnitude of the initial drift momentum: 1% of m_e * c
momentum = 0.01 * consts.electron_mass * consts.light_velocity


@cfunc(types.particle_loop_callback)
def add_initial_momentum_m(r, p, w, id, data_double, data_int):
    """Shift the first momentum component (p[0]) of a particle by -momentum."""
    p[0] = p[0] - momentum

@cfunc(types.particle_loop_callback)
def add_initial_momentum_p(r, p, w, id, data_double, data_int):
    """Shift the first momentum component (p[0]) of a particle by +momentum."""
    p[0] = p[0] + momentum

##### Define simulation with the electrostatic solver


In [None]:
# initialize the simulation with the 'electrostatic_1d' solver
sim = pipic.init(solver='electrostatic_1d', xmin=xmin, xmax=xmax, nx=nx)

# First electron population, distributed according to density_profile.
# NOTE: despite the '_l' suffix, this species is shifted by +momentum in p[0]
# (handler add_initial_momentum_p).
sim.add_particles(
    name='electron_l',
    number=10 * nx,  # total number of particles to add
    density=density_profile.address,
    charge=-consts.electron_charge,
    mass=consts.electron_mass,
    temperature=temperature,
)
sim.particle_loop(name='electron_l', handler=add_initial_momentum_p.address)

# Second electron population, distributed according to density_profile.
# NOTE: despite the '_r' suffix, this species is shifted by -momentum in p[0]
# (handler add_initial_momentum_m).
sim.add_particles(
    name='electron_r',
    number=10 * nx,  # total number of particles to add
    density=density_profile.address,
    charge=-consts.electron_charge,
    mass=consts.electron_mass,
    temperature=temperature,
)
sim.particle_loop(name='electron_r', handler=add_initial_momentum_m.address)

##### Define simulation with the spectral EM solver and Boris pusher

In [None]:

# initialize the simulation with the 'fourier_boris' solver
sim_em_solver = pipic.init(solver='fourier_boris', xmin=xmin, xmax=xmax, nx=nx)

# First electron population, distributed according to density_profile.
# NOTE: despite the '_l' suffix, this species is shifted by +momentum in p[0]
# (handler add_initial_momentum_p).
sim_em_solver.add_particles(
    name='electron_l',
    number=10 * nx,  # total number of particles to add
    density=density_profile.address,
    charge=-consts.electron_charge,
    mass=consts.electron_mass,
    temperature=temperature,
)
sim_em_solver.particle_loop(name='electron_l', handler=add_initial_momentum_p.address)

# Second electron population, distributed according to density_profile.
# NOTE: despite the '_r' suffix, this species is shifted by -momentum in p[0]
# (handler add_initial_momentum_m).
sim_em_solver.add_particles(
    name='electron_r',
    number=10 * nx,  # total number of particles to add
    density=density_profile.address,
    charge=-consts.electron_charge,
    mass=consts.electron_mass,
    temperature=temperature,
)
sim_em_solver.particle_loop(name='electron_r', handler=add_initial_momentum_m.address)

In [None]:
# define functions and arrays for reading and saving field and particle phase space
# buffer receiving the first electric-field component (E[0], plotted as E_x), one value per cell
field_dd = np.zeros((nx,), dtype=np.double)


@cfunc(types.field_loop_callback)
def field_callback(ind, r, E, B, data_double, data_int):
    """Copy E[0] of the visited grid node into the buffer at index ind[0].

    data_double must be the address of an array shaped like field_dd.
    """
    out = carray(data_double, field_dd.shape, dtype=np.double)
    cell = ind[0]
    out[cell] = E[0]

particle_dd = np.zeros((64, nx), dtype=np.double)  # array for saving particle (integrated) phase-space: (momentum bins, x cells)

tt = np.zeros((2,), dtype=np.double)  # [0]: summed particle kinetic energy, [1]: summed p[0] momentum
@cfunc(types.particle_loop_callback)
def particle_callback_energy(r, p, w, id, data_double, data_int):
    """Accumulate the weighted kinetic energy and x-momentum of one particle.

    data_double must be the address of an array shaped like tt:
    element 0 receives w * (sqrt(m^2 c^4 + c^2 |p|^2) - m c^2),
    element 1 receives w * p[0].
    The caller is responsible for zeroing the array before the loop.
    """
    data = carray(data_double, tt.shape, dtype=np.double)
    # Use the full momentum vector: the kinetic energy depends on |p|^2, not on
    # p[0]^2 alone, and the field energy it is added to also sums all components.
    p2 = p[0]**2 + p[1]**2 + p[2]**2
    rest_energy = consts.electron_mass * consts.light_velocity**2
    te = w[0] * (np.sqrt(rest_energy**2 + consts.light_velocity**2 * p2) - rest_energy)

    tm = p[0] * w[0]
    data[0] += te
    data[1] += tm

tte = np.zeros((2,), dtype=np.double)  # [0]: total field energy, [1]: total field x-momentum
@cfunc(types.field_loop_callback)
def field_callback_energy(ind, r, E, B, data_double, data_int):
    """Accumulate the field energy and field x-momentum of one grid node.

    data_double must be the address of an array shaped like tte:
    element 0 receives (|E|^2 + |B|^2) / (8 pi) times the cell length,
    element 1 receives (E x B)_x / (4 pi c) times the cell length.
    Both are per unit transverse area, since only the cell length along x
    is used. The caller is responsible for zeroing the array before the loop.
    """
    data = carray(data_double, tte.shape, dtype=np.double)
    data[0] += (E[0]**2 + E[1]**2 + E[2]**2 + 
                B[0]**2 + B[1]**2 + B[2]**2) / (8*np.pi) * (xmax - xmin) / nx
    # Gaussian-units momentum density of the field is (E x B) / (4 pi c); without
    # this factor the value is not comparable with the particle momentum in tt[1].
    data[1] += ((E[1]*B[2]-E[2]*B[1]) / (4*np.pi*consts.light_velocity)
                * (xmax - xmin) / nx)


pmin = -momentum*2 # minimum momentum
pmax = momentum*2 # maximum momentum
dp = (pmax - pmin) / particle_dd.shape[0] # momentum step
dx = (xmax - xmin) / particle_dd.shape[1] # position step
@cfunc(types.particle_loop_callback)
def particle_callback(r, p, w, id, data_double, data_int):
    """Add one particle to the (p[0], r[0]) phase-space histogram.

    data_double must be the address of an array shaped like particle_dd
    (momentum bins along axis 0, position bins along axis 1). Each particle
    contributes w / (dx * dp) / density to its bin, i.e. the histogram is
    normalised by the bin area and by the reference density. Particles
    outside [pmin, pmax) x [xmin, xmax) are ignored.
    """
    data = carray(data_double, particle_dd.shape, dtype=np.double)
    # Test the range on the floating-point values first: int() truncates towards
    # zero, so a value up to one bin below pmin (or xmin) would otherwise be
    # turned into index 0 and counted in the first bin.
    if p[0] >= pmin and p[0] < pmax and r[0] >= xmin and r[0] < xmax:
        ip = int(particle_dd.shape[0] * (p[0] - pmin) / (pmax - pmin))
        ix = int(particle_dd.shape[1] * (r[0] - xmin) / (xmax - xmin))
        # guard against rounding pushing a value just below the upper edge
        # onto the one-past-the-end index
        if ip < particle_dd.shape[0] and ix < particle_dd.shape[1]:
            data[ip, ix] += w[0] / (dx * dp) / density


#### Create figures and run simulation

In [None]:
# create custom colormaps
import matplotlib.pylab as pl
from matplotlib.colors import ListedColormap


def _with_linear_alpha(base_cmap):
    """Return a copy of base_cmap whose alpha channel rises linearly from 0 to 1."""
    rgba = base_cmap(np.arange(base_cmap.N))  # extract the RGBA table
    rgba[:, -1] = np.linspace(0, 1, base_cmap.N)  # overwrite the alpha column
    return ListedColormap(rgba)


# Reds with a linearly fading alpha
red = pl.cm.Reds  # original colormap
fading_red = _with_linear_alpha(red)

# Blues with a linearly fading alpha
blue = pl.cm.Blues  # original colormap
fading_blue = _with_linear_alpha(blue)

In [None]:
# initialize plot
fig = plt.figure(figsize=(10, 8))
gs = fig.add_gridspec(3, 2)  # 3 rows, 2 columns

# Rows 0-1 hold four panels (left column: ES solver, right column: EM solver);
# row 2 is one panel spanning both columns for the total-energy history.
axs = [fig.add_subplot(gs[0, 0]), fig.add_subplot(gs[0, 1]),
       fig.add_subplot(gs[1, 0]), fig.add_subplot(gs[1, 1]),
       fig.add_subplot(gs[2, :])]

field_amplitude = 5e4  # symmetric y-limit of the field panels
x_axis = np.linspace(xmin, xmax, nx)


def _make_field_line(ax):
    """Plot field_dd against x_axis on ax, fix the y-range and return the line."""
    line = ax.plot(x_axis, field_dd)[0]
    ax.set_ylim(-field_amplitude, field_amplitude)
    return line


def _make_phase_space_image(ax, cmap):
    """Create an empty x-p_x image on ax; the data is supplied later via set_data."""
    # An all-zero placeholder avoids a second, different normalisation factor
    # here; the displayed data and its scaling are set when frames are drawn.
    return ax.imshow(np.zeros_like(particle_dd),
                     extent=[xmin, xmax, pmin, pmax],
                     aspect='auto', origin='lower',
                     cmap=cmap, vmin=0, vmax=1,
                     interpolation='none')


# with ES-solver (the *_plot names say "Ez"/"zpz" but hold the E[0] line and
# the x-p[0] images)
Ez_plot = _make_field_line(axs[2])
zpz_plot_r = _make_phase_space_image(axs[0], fading_blue)
zpz_plot_l = _make_phase_space_image(axs[0], fading_red)

# with EM-solver
Ez_plot_em = _make_field_line(axs[3])
zpz_plot_em_r = _make_phase_space_image(axs[1], fading_blue)
zpz_plot_em_l = _make_phase_space_image(axs[1], fading_red)


# set titles and axis labels; every panel of the 2x2 grid gets a y-label
axs[0].set_title('Electrostatic solver')
axs[1].set_title('Spectral solver with  boris pusher')
axs[2].set_xlabel('x (cm)')
axs[3].set_xlabel('x (cm)')
axs[0].set_ylabel('$p_x$ (cm g/s)')
axs[1].set_ylabel('$p_x$ (cm g/s)')
axs[2].set_ylabel('$E_x$ (StatV/cm)')
axs[3].set_ylabel('$E_x$ (StatV/cm)')

fig.tight_layout()

In [None]:
# ===============================SIMULATION======================================
# total run length: 8 plasma periods, expressed in time steps
simulation_steps = int(8 * plasma_period / dt)
# number of frames to show in the animation (animate() advances 8 steps per call)
frames = simulation_steps // 8
# number of animate() calls so far; used for the progress read-out
counter = 0

# per-frame energy histories: te_* = particle kinetic energy, tef_* = field energy
te_es, te_em = [], []
tef_es, tef_em = [], []

def _record_diagnostics(simulation, field_line, image_l, image_r,
                        particle_energy, field_energy):
    """Read field, phase space and energies from one simulation and update its artists.

    simulation: pipic simulation to read from.
    field_line: Line2D whose y-data is replaced by the current field_dd contents.
    image_l, image_r: images receiving the phase space of 'electron_l' / 'electron_r'.
    particle_energy, field_energy: lists to which tt[0] and tte[0] are appended.
    """
    tt.fill(0)
    tte.fill(0)

    # field profile; pass a copy so that the two solvers' lines never share the
    # module-level field_dd buffer, which is overwritten by the next read-out
    simulation.field_loop(handler=field_callback.address,
                          data_double=pipic.addressof(field_dd),
                          use_omp=True)
    field_line.set_ydata(field_dd.copy())

    # particle phase space, one histogram per species
    for species, image in (('electron_l', image_l), ('electron_r', image_r)):
        particle_dd.fill(0)
        simulation.particle_loop(name=species,
                                 handler=particle_callback.address,
                                 data_double=pipic.addressof(particle_dd))
        image.set_data(particle_dd / (5/pmax))

    # total energy and momentum: both species accumulate into tt, fields into tte
    for species in ('electron_l', 'electron_r'):
        simulation.particle_loop(name=species,
                                 handler=particle_callback_energy.address,
                                 data_double=pipic.addressof(tt))
    simulation.field_loop(handler=field_callback_energy.address,
                          data_double=pipic.addressof(tte))
    field_energy.append(tte[0])
    particle_energy.append(tt[0])


def animate(i):
    """Advance both simulations by one frame and refresh every panel.

    The frame index i is not used; progress is tracked with the module-level
    counter, which is incremented on every call.
    """
    global counter
    # time steps per animation frame; must match the divisor used for `frames`
    steps_per_frame = 8

    sim.advance(time_step=dt, number_of_iterations=steps_per_frame, use_omp=True)
    sim_em_solver.advance(time_step=dt, number_of_iterations=steps_per_frame, use_omp=True)

    # read diagnostics for ES solver, then for EM solver
    _record_diagnostics(sim, Ez_plot, zpz_plot_l, zpz_plot_r, te_es, tef_es)
    _record_diagnostics(sim_em_solver, Ez_plot_em, zpz_plot_em_l, zpz_plot_em_r,
                        te_em, tef_em)

    # Samples are taken after each advance, so sample k belongs to
    # t = (k + 1) * steps_per_frame * dt rather than to k * steps_per_frame * dt.
    time_es = (np.arange(len(te_es)) + 1) * steps_per_frame * dt
    time_em = (np.arange(len(te_em)) + 1) * steps_per_frame * dt

    axs[4].cla()
    axs[4].plot(time_es, np.array(te_es)+np.array(tef_es), label='Total energy (ES)', color='tab:blue')
    axs[4].plot(time_em, np.array(te_em)+np.array(tef_em), label='Total energy (EM)', color='tab:red')
    axs[4].set_ylabel('Total energy (erg)')
    axs[4].set_xlabel('Time (s)')
    axs[4].legend(loc='center right',frameon=False)

    clear_output()
    if counter <= frames:
        display(HTML(f'<pre> Progress: {100*counter/frames:.2f}</pre>'), display_id=True)
    counter += 1
    return
    
# Build the animation: animate is the per-frame update function and `frames`
# the number of frames; interval is the delay between frames in milliseconds.
ani = animation.FuncAnimation(fig, animate, frames=frames, interval = 40)


# Rendering to JS/HTML draws every frame, which is what runs the simulation.
html = HTML(ani.to_jshtml())
display(html)
# presumably closes the figure so that only the animation is shown — confirm
plt.close()
