In [None]:
from collections import namedtuple

import numpy as np
import torch

# Library versions, displayed so results can be tied to a specific environment.
(torch.__version__, np.__version__)

# Particle named tuple

In [None]:
# Canonical phase-space coordinates as defined in the Bmad manual, section 15.4.2.
Particle = namedtuple('Particle', ['x', 'px', 'y', 'py', 'z', 'pz'])

In [None]:
# Seed the legacy global RNG so the test particles are reproducible.
np.random.seed(0)
# 10,000 test particles: one row per particle, one column per coordinate.
pvec = np.random.random_sample((10_000, 6))

## NumPy array test Particle

In [None]:
# One Particle whose fields are the columns of pvec (each a length-10_000 array).
p_np = Particle._make(pvec.T)
p_np

## PyTorch tensor test Particle

In [None]:
# Same test particles, with each field a row of a torch tensor copied from pvec.T.
p_torch = Particle._make(torch.tensor(pvec.T))
p_torch

# Reference particle

## Option 1: 

Create a `Coord` tuple whose first entry is a `Particle` tuple and whose remaining entries are the reference-particle parameters (similar to Bmad's `coord_struct`).

In [None]:
# Option 1 container: a Particle followed by the reference-particle parameters
# (analogous to Bmad's coord_struct).
Coord = namedtuple('Coord', ['particle', 's', 't', 'p0c', 'mc2'])

p0c = 10_000_000.0  # reference-particle momentum (p0*c), in eV
mc2 = 510_998.95    # electron rest-mass energy, in eV

### NumPy array Coord

In [None]:
# NumPy-backed Coord with the reference particle at s = 0, t = 0.
orb_np = Coord(particle=p_np, s=0, t=0, p0c=p0c, mc2=mc2)
orb_np

### PyTorch tensor Coord

In [None]:
# Torch-backed Coord with the reference particle at s = 0, t = 0.
orb_torch = Coord(particle=p_torch, s=0, t=0, p0c=p0c, mc2=mc2)
orb_torch

## Option 2: 

Create a separate tuple for the reference particle's parameters:

In [None]:
# Option 2: keep the reference-particle parameters in their own tuple,
# separate from the Particle coordinates.
ReferenceParticle = namedtuple('ReferenceParticle', ['s', 't', 'p0c', 'mc2'])
ref_par = ReferenceParticle(s=0, t=0, p0c=p0c, mc2=mc2)
ref_par

# track_a_drift

In [None]:
def make_track_a_drift(lib):
    """
    Build a drift-tracking function specialized to the array library ``lib``.

    Parameters
    ----------
    lib : module
        Array library providing ``sqrt`` (e.g. ``numpy`` or ``torch``).

    Returns
    -------
    callable
        ``track_a_drift(orb, L)``, documented below.
    """
    sqrt = lib.sqrt

    def track_a_drift(orb, L):
        """
        Track the particles held in the Coord ``orb`` through a drift of length ``L``.

        Drift map with reference-particle motion (cf. eq. 24.58 in the Bmad manual):

            x -> x + L * Px / Pl
            y -> y + L * Py / Pl
            z -> z + L * (beta/beta0 - 1/Pl)

        where Px = px/(1+pz) and Py = py/(1+pz) are the transverse momenta over
        the particle's total momentum, and Pl = sqrt(1 - Px**2 - Py**2).

        Parameters
        ----------
        orb : Coord
            Incoming coordinates. ``orb.particle`` holds (x, px, y, py, z, pz).
            ``orb.p0c`` is the reference momentum in eV and ``orb.mc2`` is the
            rest-mass energy in eV.
        L : float
            Drift length, in the same length units as x and z.

        Returns
        -------
        Particle
            Outgoing coordinates, of the same namedtuple type as ``orb.particle``.
            px, py and pz are unchanged. orb.s and orb.t are not advanced.
        """
        p = orb.particle

        P = 1 + p.pz        # particle total momentum / P0
        Px = p.px / P       # 'x' momentum / total momentum
        Py = p.py / P       # 'y' momentum / total momentum
        # Longitudinal momentum / total momentum. If Px**2 + Py**2 > 1 this is
        # NaN (the particle is lost); lost particles are not flagged here.
        Pl = sqrt(1 - (Px**2 + Py**2))

        # beta/beta0 expressed through r2 = (mc2/p0c)**2:
        #   beta  = P / sqrt(P**2 + r2),  beta0 = 1 / sqrt(1 + r2).
        # The `** 0.5` keeps the reference-only term valid when p0c/mc2 are floats.
        r2 = (orb.mc2 / orb.p0c) ** 2
        beta_ratio = P * (1 + r2) ** 0.5 / sqrt(P**2 + r2)

        return p._replace(
            x=p.x + L * Px / Pl,
            y=p.y + L * Py / Pl,
            z=p.z + L * (beta_ratio - 1.0 / Pl),
        )

    return track_a_drift

# Specialized functions
track_a_drift_np = make_track_a_drift(lib=np)
track_a_drift_torch = make_track_a_drift(lib=torch)