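# Tracking through elements with autodiff

Particles are tracked through a drift and a quadrupole using PyTorch tensors. The outgoing coordinates, and the Jacobians obtained with `torch.autograd.functional.jacobian`, are compared against Bmad/Tao.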

In [1]:
import numpy as np
import torch
from collections import namedtuple
from torch.autograd.functional import jacobian
from pytao import Tao

# Print arrays/tensors with full double precision so that last-digit
# differences against Tao are visible in the outputs below.
np.set_printoptions(suppress=False, precision=16)
torch.set_printoptions(sci_mode=True, precision=16)
torch.__version__, np.__version__

('1.10.2', '1.22.2')

# Constants

In [2]:
c_light = 299_792_458.0  # speed of light [m/s]
m_e = 510_998.950  # electron rest mass [eV]

# Tuples

In [3]:
# Element records
Drift = namedtuple(typename='Drift', field_names='L')
Quadrupole = namedtuple(typename='Quadrupole', field_names='L K1 NUM_STEPS')
# x, px, y, py, z, pz: Bmad canonical phase-space coordinates (Bmad manual section 15.4.2).
# s, t, p0c, mc2: extra fields carried through tracking -- longitudinal position,
# time, reference momentum [eV] and particle rest mass [eV].
Particle = namedtuple(typename='Particle', field_names='x px y py z pz s t p0c mc2')

## Batch of random PyTorch test particles

In [4]:
# 10_000 random test particles; the columns of pvec are (x, px, y, py, z, pz).
np.random.seed(0)
pvec = np.random.rand(10_000, 6)
s, t = 0.0, 0.0
p0c = 4.0E+07  # reference particle momentum [eV]
mc2 = 1*m_e  # electron mass [eV]
# mc2 is needed: the Particle record has an mc2 field, and tracking reads p_in.mc2.
ts, tt, tp0c, tmc2 = (torch.tensor(v, dtype=torch.float64) for v in (s, t, p0c, mc2))
tpvec = torch.tensor(pvec, requires_grad=True, dtype=torch.float64)
# Each row of tpvec.T is one coordinate across the whole batch.
p_torch = Particle(*tpvec.T, ts, tt, tp0c, tmc2)
p_torch

Particle(x=tensor([5.4881350392732475e-01, 4.3758721126269251e-01, 5.6804456109393231e-01,
         ..., 1.6739176591167459e-01, 8.2838659877820775e-01,
        4.2741875758117731e-01], dtype=torch.float64,
       grad_fn=<UnbindBackward0>), px=tensor([7.1518936637241948e-01, 8.9177300078207977e-01, 9.2559663829266103e-01,
         ..., 3.5145149927123487e-01, 3.5342664838276150e-01,
        8.8324715311888191e-01], dtype=torch.float64,
       grad_fn=<UnbindBackward0>), y=tensor([6.0276337607164387e-01, 9.6366276050102928e-01, 7.1036058197886942e-02,
         ..., 2.7008431955854806e-01, 9.2013387734489693e-01,
        3.8008122694143009e-01], dtype=torch.float64,
       grad_fn=<UnbindBackward0>), py=tensor([5.4488318299689686e-01, 3.8344151882577771e-01, 8.7129299701540708e-02,
         ..., 8.7026390127975972e-02, 8.1374798653759683e-01,
        1.6900412983681179e-01], dtype=torch.float64,
       grad_fn=<UnbindBackward0>), z=tensor([4.2365479933890471e-01, 7.9172503808266459e-01,

# Drift

## `track_a_drift`

In [5]:
def make_track_a_drift(lib):
    """
    Makes track_a_drift given the library lib
    """
    
    sqrt = lib.sqrt
    
    def track_a_drift(p_in, drift):
        """
        Tracks the incoming Particle p_in though drift element
        and returns the outgoing particle. See eqs 24.58 in bmad manual
        """
        L = drift.L
        
        s = p_in.s
        t = p_in.t
        p0c = p_in.p0c
        mc2 = p_in.mc2
        
        x = p_in.x
        px = p_in.px
        y = p_in.y
        py = p_in.py
        z = p_in.z
        pz = p_in.pz
        
        P = 1 + pz #Particle's total momentum over p0
        Px = px / P #Particle's 'x' momentum over p0
        Py = py / P #Particle's 'y' momentum over p0
        Pxy2 = Px**2 + Py**2 #Particle's transverse mometum^2 over p0^2
        Pl = sqrt(1-Pxy2)  #Particle's longitudinal momentum over p0
        
        x = x + L * Px / Pl
        y = y + L * Py / Pl
        
        #beta = sqrt( 1 - 1 /( 1 + (P*p0c/mc2)**2 ) )
        #beta_ref = sqrt( 1 - 1 /( 1 + (p0c/mc2)**2 ) )
        #better alternative:
        beta = P * p0c / sqrt( (P*p0c)**2 + mc2**2) # hypot
        beta_ref = p0c / sqrt( p0c**2 + mc2**2)
        z = z + L * ( beta/beta_ref - 1.0/Pl )
        s = s + L
        t = t + L / ( beta * Pl * c_light )
        
        return Particle(x, px, y, py, z, pz, s, t, p0c, mc2)
    
    return track_a_drift

# Specialized functions
track_a_drift_torch = make_track_a_drift(lib=torch)

## drift test

Drift element (length 1 m):

In [6]:
L = 1.0  # drift length [m]
d1 = Drift(L=torch.tensor(L, dtype=torch.float64))
d1

Drift(L=tensor(1.0000000000000000e+00, dtype=torch.float64))

test particle:

In [7]:
# Single test particle: phase-space vector pvec1 plus s, t, p0c, mc2 bookkeeping.
s, t = 0.0, 0.0  # initial s and t
p0c = 4.0E+07  # reference particle momentum [eV]
mc2 = 1*m_e  # electron mass [eV]
ts, tt, tp0c, tmc2 = (torch.tensor(v, dtype=torch.float64) for v in (s, t, p0c, mc2))
pvec1 = [2e-3, 3e-3, -3e-3, -1e-3, 2e-3, -2e-3]
tvec1 = torch.tensor(pvec1, requires_grad=True, dtype=torch.float64)
tvec1

tensor([2.0000000000000000e-03, 3.0000000000000001e-03, -3.0000000000000001e-03,
        -1.0000000000000000e-03, 2.0000000000000000e-03, -2.0000000000000000e-03],
       dtype=torch.float64, requires_grad=True)

`track_a_drift` test:

In [8]:
p_out = track_a_drift_torch(p_in=Particle(*tvec1, ts, tt, tp0c, tmc2), drift=d1)

In [9]:
# Phase-space part (x, px, y, py, z, pz) of the outgoing particle, detached from autograd.
x_me = torch.stack(p_out[:6]).detach()
x_me

tensor([5.0060271145229325e-03, 3.0000000000000001e-03, -4.0020090381743109e-03,
        -1.0000000000000000e-03, 1.9946525738924175e-03, -2.0000000000000000e-03],
       dtype=torch.float64)

In [10]:
tao = Tao('-lat test_drift.bmad -noplot')
# Start Tao's particle at the same phase-space point as tvec1.
for coord, value in zip(('x', 'px', 'y', 'py', 'z', 'pz'), pvec1):
    tao.cmd('set particle_start ' + coord + '=' + str(value))
orbit_out = tao.orbit_at_s(ele=1)
orbit_out

{'x': 0.00500602711452293,
 'px': 0.003,
 'y': -0.00400200903817431,
 'py': -0.001,
 'z': 0.00199465257389236,
 'pz': -0.002,
 'spin': array([0., 0., 0.]),
 'field': array([0., 0.]),
 'phase': array([0., 0.]),
 's': 1.0,
 't': 3.32925913921535e-09,
 'charge': 0.0,
 'path_len': -1.64365049348802e-16,
 'p0c': 40000000.0,
 'beta': 0.999918082707881,
 'ix_ele': 1,
 'state': 'Alive',
 'direction': 1,
 'species': 'Electron',
 'location': 'Downstream_End'}

In [11]:
x_tao = torch.tensor([orbit_out[key] for key in ('x', 'px', 'y', 'py', 'z', 'pz')], dtype=torch.float64)
x_tao

tensor([5.0060271145229299e-03, 3.0000000000000001e-03, -4.0020090381743100e-03,
        -1.0000000000000000e-03, 1.9946525738923598e-03, -2.0000000000000000e-03],
       dtype=torch.float64)

In [12]:
# The tracked coordinates must agree with Tao to floating-point tolerance; fail loudly otherwise.
drift_coords_match = torch.allclose(x_me, x_tao)
assert drift_coords_match, 'track_a_drift disagrees with Tao'
drift_coords_match

True

In [13]:
# Exact bitwise equality: px, py, pz pass through unchanged, while the computed
# x, y, z differ from Tao only in the last digits (allclose above is the real check).
torch.eq(x_me, x_tao)

tensor([False,  True, False,  True, False,  True])

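Jacobian of the drift map with respect to the initial phase-space coordinates, compared with Tao's transfer matrix: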

In [14]:
def f2(x):
    """Phase-space map of drift d1: first six fields of the tracked Particle built from x."""
    return track_a_drift_torch(Particle(*x, ts, tt, tp0c, tmc2), d1)[:6]

J = jacobian(f2, tvec1)

In [15]:
# J holds one gradient row per output coordinate; stack them into the 6x6 matrix.
mat_me = torch.stack(J)
mat_me

tensor([[ 1.0000000000000000e+00,  1.0020180925273929e+00,
          0.0000000000000000e+00, -3.0181176940051169e-06,
          0.0000000000000000e+00, -3.0120814586171068e-03],
        [ 0.0000000000000000e+00,  1.0000000000000000e+00,
          0.0000000000000000e+00,  0.0000000000000000e+00,
          0.0000000000000000e+00,  0.0000000000000000e+00],
        [ 0.0000000000000000e+00, -3.0181176940051169e-06,
          1.0000000000000000e+00,  1.0020100442135422e+00,
          0.0000000000000000e+00,  1.0040271528723688e-03],
        [ 0.0000000000000000e+00,  0.0000000000000000e+00,
          0.0000000000000000e+00,  1.0000000000000000e+00,
          0.0000000000000000e+00,  0.0000000000000000e+00],
        [ 0.0000000000000000e+00, -3.0120814586171063e-03,
          0.0000000000000000e+00,  1.0040271528723688e-03,
          1.0000000000000000e+00,  1.7421652474771806e-04],
        [ 0.0000000000000000e+00,  0.0000000000000000e+00,
          0.0000000000000000e+00,  0.00000000000000

In [16]:
# Tao's linear transfer matrix between elements 0 and 1 -- presumably the drift
# defined in test_drift.bmad (TODO confirm the lattice layout).
drift_tao = tao.matrix(0,1)
mat_tao = torch.tensor(drift_tao['mat6'], dtype=torch.float64)
mat_tao

tensor([[ 1.0000000000000000e+00,  1.0020180925273900e+00,
          0.0000000000000000e+00, -3.0181176940051199e-06,
          0.0000000000000000e+00, -3.0120814586171098e-03],
        [ 0.0000000000000000e+00,  1.0000000000000000e+00,
          0.0000000000000000e+00,  0.0000000000000000e+00,
          0.0000000000000000e+00,  0.0000000000000000e+00],
        [ 0.0000000000000000e+00, -3.0181176940051199e-06,
          1.0000000000000000e+00,  1.0020100442135400e+00,
          0.0000000000000000e+00,  1.0040271528723699e-03],
        [ 0.0000000000000000e+00,  0.0000000000000000e+00,
          0.0000000000000000e+00,  1.0000000000000000e+00,
          0.0000000000000000e+00,  0.0000000000000000e+00],
        [ 0.0000000000000000e+00, -3.0120814586171098e-03,
          0.0000000000000000e+00,  1.0040271528723699e-03,
          1.0000000000000000e+00,  1.7421652474810300e-04],
        [ 0.0000000000000000e+00,  0.0000000000000000e+00,
          0.0000000000000000e+00,  0.00000000000000

In [17]:
# Exact bitwise equality of the matrices; computed entries differ from Tao in the
# last digits, so torch.allclose (next cell) is the meaningful comparison.
torch.eq(mat_me, mat_tao)

tensor([[ True, False,  True, False,  True, False],
        [ True,  True,  True,  True,  True,  True],
        [ True, False,  True, False,  True, False],
        [ True,  True,  True,  True,  True,  True],
        [ True, False,  True, False,  True, False],
        [ True,  True,  True,  True,  True,  True]])

In [18]:
# The autodiff Jacobian must match Tao's transfer matrix to floating-point tolerance.
drift_matrix_match = torch.allclose(mat_me, mat_tao)
assert drift_matrix_match, 'Jacobian of track_a_drift disagrees with Tao mat6'
drift_matrix_match

True

# Quadrupole

## `track_a_quadrupole`

In [19]:
def make_track_a_quadrupole(lib):
    """
    Makes track_a_quadrupole given the library lib.

    lib must provide sqrt, abs, sin, cos, sinh and cosh (e.g. numpy or torch).
    The helpers branch with Python `if` on values computed from the particle,
    so as written each Particle field should be a scalar / 0-d tensor rather
    than a batch of particles.
    """
    sqrt = lib.sqrt
    absolute = lib.abs
    sin = lib.sin
    cos = lib.cos
    sinh = lib.sinh
    cosh = lib.cosh

    def quad_mat2_calc(k1, length, rel_p):
        """
        Returns 2x2 transfer matrix elements aij and the coefficients to calculate
        the change in z position.
        Input:
            k1 -- Quad strength: k1 > 0 ==> defocus
            length -- Quad length
            rel_p -- Relative momentum P/P0
        Output:
            a11, a12, a21, a22 -- transfer matrix elements
            c1, c2, c3 -- coefficients of the change in z:
                        dz = c1 * x_0^2 + c2 * x_0 * px_0 + c3 * px_0^2
        """
        sqrt_k = sqrt(absolute(k1))
        sk_l = sqrt_k * length
        if absolute(sk_l) < 1e-10:
            # Nearly zero strength: series form avoids dividing by sqrt_k ~ 0.
            k_l2 = k1 * length**2
            cx = 1 + k_l2 / 2
            sx = (1 + k_l2 / 6) * length
        elif k1 < 0:
            # Focusing plane
            cx = cos(sk_l)
            sx = sin(sk_l) / sqrt_k
        else:
            # Defocusing plane
            cx = cosh(sk_l)
            sx = sinh(sk_l) / sqrt_k

        a11 = cx
        a12 = sx / rel_p
        a21 = k1 * sx * rel_p
        a22 = cx

        c1 = k1 * (-cx * sx + length) / 4
        c2 = -k1 * sx**2 / (2 * rel_p)
        c3 = -(cx * sx + length) / (4 * rel_p**2)

        return a11, a12, a21, a22, c1, c2, c3

    def low_energy_z_correction(pz, p0c, mass, ds):
        """
        Corrects the change in z-coordinate due to speed < c_light.
        Input:
            pz -- relative momentum deviation of the particle
            p0c -- reference particle momentum in eV
            mass -- particle mass in eV
            ds -- step length
        Output:
            dz -- ds * (beta - beta_ref) / beta_ref. For small pz this is
                  evaluated with a series expansion in pz, which avoids
                  subtracting two nearly equal speeds.
        """
        e_tot = sqrt(p0c**2 + mass**2)
        beta0 = p0c / e_tot

        if mass * (beta0 * pz)**2 < 3e-7 * e_tot:
            f = beta0**2 * (2 * beta0**2 - (mass / e_tot)**2 / 2)
            dz = ds * pz * (1 - 3 * (pz * beta0**2) / 2 + pz**2 * f) * (mass / e_tot)**2
        else:
            beta = (1 + pz) * p0c / sqrt(((1 + pz) * p0c)**2 + mass**2)
            dz = ds * (beta - beta0) / beta0

        return dz

    def track_a_quadrupole(p_in, quad):
        """
        Tracks the incoming Particle p_in through quad element
        and returns the outgoing particle.

        The quad is split into quad.NUM_STEPS equal steps. Each step applies the
        x/y 2x2 transfer matrices, the second-order change in z and the
        low-energy z correction. As in track_a_drift, s advances by the element
        length and t by the reference flight time plus the time lag implied by
        the change in z.
        """
        l = quad.L
        n_step = quad.NUM_STEPS  # number of divisions
        step_len = l / n_step  # length of division
        b1 = quad.K1 * l  # integrated strength; K1 refers to the reference momentum

        s = p_in.s
        t = p_in.t
        p0c = p_in.p0c
        mc2 = p_in.mc2

        x = p_in.x
        px = p_in.px
        y = p_in.y
        py = p_in.py
        z = p_in.z
        pz = p_in.pz

        # pz is never changed inside the quad, so everything that depends only on
        # it (strength seen by the particle, step matrices, z correction) is
        # loop-invariant and computed once.
        rel_p = 1 + pz  # Particle's relative momentum (P/P0)
        k1 = b1 / (l * rel_p)  # strength seen by an off-momentum particle
        tx11, tx12, tx21, tx22, dz_x1, dz_x2, dz_x3 = quad_mat2_calc(-k1, step_len, rel_p)
        ty11, ty12, ty21, ty22, dz_y1, dz_y2, dz_y3 = quad_mat2_calc( k1, step_len, rel_p)
        dz_low_energy = low_energy_z_correction(pz, p0c, mc2, step_len)

        z_in = z
        for _ in range(n_step):
            # Path-length change uses the coordinates at the start of the step.
            z = (z
                 + dz_x1 * x**2 + dz_x2 * x * px + dz_x3 * px**2
                 + dz_y1 * y**2 + dz_y2 * y * py + dz_y3 * py**2)
            # Simultaneous assignment: both right-hand sides use start-of-step values.
            x, px = tx11 * x + tx12 * px, tx21 * x + tx22 * px
            y, py = ty11 * y + ty12 * py, ty21 * y + ty22 * py
            z = z + dz_low_energy

        # Particle and reference speeds, same form as in track_a_drift.
        beta = rel_p * p0c / sqrt((rel_p * p0c)**2 + mc2**2)
        beta_ref = p0c / sqrt(p0c**2 + mc2**2)
        s = s + l
        # Bmad convention z = -beta*c*(t - t_ref), so
        # t_out = t_in + l/(beta_ref*c) + (z_in - z_out)/(beta*c).
        t = t + l / (beta_ref * c_light) + (z_in - z) / (beta * c_light)

        return Particle(x, px, y, py, z, pz, s, t, p0c, mc2)

    return track_a_quadrupole

# Specialized functions
track_a_quadrupole_torch = make_track_a_quadrupole(lib=torch)

## quadrupole test

In [20]:
# Quadrupole q1 under test
L = 0.1  # length [m]
K1 = 10  # focusing strength [1/m^2]; positive is focusing in x
NUM_STEPS = 1  # tracking divisions; 1 is the Bmad default when there are no other multipoles
q1 = Quadrupole(
    L=torch.tensor(L, dtype=torch.float64),
    K1=torch.tensor(K1, dtype=torch.float64),
    NUM_STEPS=NUM_STEPS,
)
q1

Quadrupole(L=tensor(1.0000000000000001e-01, dtype=torch.float64), K1=tensor(1.0000000000000000e+01, dtype=torch.float64), NUM_STEPS=1)

Test particle:

In [21]:
# Same single test particle as in the drift section.
s, t = 0.0, 0.0  # initial s and t
p0c = 4.0E+07  # reference particle momentum [eV]
mc2 = 1*m_e  # electron mass [eV]
ts, tt, tp0c, tmc2 = (torch.tensor(v, dtype=torch.float64) for v in (s, t, p0c, mc2))
pvec1 = [2e-3, 3e-3, -3e-3, -1e-3, 2e-3, -2e-3]
tvec1 = torch.tensor(pvec1, requires_grad=True, dtype=torch.float64)
tvec1

tensor([2.0000000000000000e-03, 3.0000000000000001e-03, -3.0000000000000001e-03,
        -1.0000000000000000e-03, 2.0000000000000000e-03, -2.0000000000000000e-03],
       dtype=torch.float64, requires_grad=True)

`track_a_quadrupole` test

In [22]:
p_out = track_a_quadrupole_torch(p_in=Particle(*tvec1, ts, tt, tp0c, tmc2), quad=q1)

In [23]:
# Phase-space part (x, px, y, py, z, pz) of the outgoing particle, detached from autograd.
x_me = torch.stack(p_out[:6]).detach()
x_me

tensor([2.1962397193025516e-03, 8.8418342648533682e-04, -3.2534419732692245e-03,
        -4.1008717415728594e-03, 1.9993946642253308e-03, -2.0000000000000000e-03],
       dtype=torch.float64)

In [24]:
tao = Tao('-lat test_quad.bmad -noplot')
# Start Tao's particle at the same phase-space point as tvec1.
for coord, value in zip(('x', 'px', 'y', 'py', 'z', 'pz'), pvec1):
    tao.cmd('set particle_start ' + coord + '=' + str(value))
orbit_out = tao.orbit_at_s(ele=1)

In [25]:
%%tao
show ele 1

-------------------------
Tao> show ele 1
Element # 1
Element Name: Q1
Key: Quadrupole
S_start, S:      0.000000,      0.100000
Ref_time:  3.335913E-10

Attribute values [Only non-zero values shown]:
    1  L                           =  1.0000000E-01 m
    4  K1                          =  1.0000000E+01 1/m^2    45  B1_GRADIENT                 = -1.3342564E+00 T/m
   10  FRINGE_TYPE                 =  None (1)               11  FRINGE_AT                   =  No_End (4)
   13  SPIN_FRINGE_ON              =  F (0)
   47  PTC_CANONICAL_COORDS        =  T (1)
   50  DELTA_REF_TIME              =  3.3359131E-10 sec
   53  P0C                         =  4.0000000E+07 eV           BETA                        =  9.9991841E-01
   54  E_TOT                       =  4.0003264E+07 eV           GAMMA                       =  7.8284435E+01
   65  INTEGRATOR_ORDER            = 0
   67  DS_STEP                     =  1.0000000E-01 m        66  NUM_STEPS                   = 1
   68  CSR_DS_STEP       

In [26]:
x_tao = torch.tensor([orbit_out[key] for key in ('x', 'px', 'y', 'py', 'z', 'pz')], dtype=torch.float64)
x_tao

tensor([2.1962397193025498e-03, 8.8418342648533704e-04, -3.2534419732692201e-03,
        -4.1008717415728603e-03, 1.9993946642253299e-03, -2.0000000000000000e-03],
       dtype=torch.float64)

In [27]:
# NOTE(review): identical to the previous cell (same code, same output); safe to delete.
x_tao=torch.tensor([orbit_out['x'],orbit_out['px'],orbit_out['y'],orbit_out['py'],orbit_out['z'],orbit_out['pz']],dtype=torch.float64)
x_tao

tensor([2.1962397193025498e-03, 8.8418342648533704e-04, -3.2534419732692201e-03,
        -4.1008717415728603e-03, 1.9993946642253299e-03, -2.0000000000000000e-03],
       dtype=torch.float64)

In [28]:
# The tracked coordinates must agree with Tao to floating-point tolerance; fail loudly otherwise.
quad_coords_match = torch.allclose(x_me, x_tao)
assert quad_coords_match, 'track_a_quadrupole disagrees with Tao'
quad_coords_match

True

In [29]:
# Element-wise version of the allclose comparison with Tao.
torch.isclose(x_me, x_tao)

tensor([True, True, True, True, True, True])

In [30]:
def f3(x):
    """Phase-space map of quad q1: first six fields of the tracked Particle built from x."""
    return track_a_quadrupole_torch(Particle(*x, ts, tt, tp0c, tmc2), q1)[:6]

J = jacobian(f3, tvec1)

In [31]:
# J holds one gradient row per output coordinate; stack them into the 6x6 matrix.
mat_me = torch.stack(J)
mat_me

tensor([[ 9.5031674318754977e-01,  9.8535410975817278e-02,
          0.0000000000000000e+00,  0.0000000000000000e+00,
          0.0000000000000000e+00, -1.9248585503177227e-04],
        [-9.8338340153865633e-01,  9.5031674318754977e-01,
          0.0000000000000000e+00,  0.0000000000000000e+00,
          0.0000000000000000e+00,  1.1496639089440819e-04],
        [ 0.0000000000000000e+00,  0.0000000000000000e+00,
          1.0505199385060540e+00,  1.0188215775106230e-01,
          0.0000000000000000e+00,  2.5690939373378331e-04],
        [ 0.0000000000000000e+00,  0.0000000000000000e+00,
          1.0167839343556018e+00,  1.0505199385060540e+00,
          0.0000000000000000e+00,  1.0174858226574043e-04],
        [ 8.0032908698420230e-05, -1.9425079143865157e-04,
          1.5433242974866441e-04,  2.5952207539749645e-04,
          1.0000000000000000e+00,  1.7567092021426940e-05],
        [ 0.0000000000000000e+00,  0.0000000000000000e+00,
          0.0000000000000000e+00,  0.00000000000000

In [32]:
# Tao's linear transfer matrix between elements 0 and 1; element 1 is the
# quadrupole Q1 (see `show ele 1` above).
# NOTE(review): the name `drift_tao` is carried over from the drift section --
# this is the quadrupole's matrix. Kept as-is in case later cells reference it.
drift_tao = tao.matrix(0,1)
mat_tao = torch.tensor(drift_tao['mat6'], dtype=torch.float64)
mat_tao

tensor([[ 9.5031674318754999e-01,  9.8535410975817306e-02,
          0.0000000000000000e+00,  0.0000000000000000e+00,
          0.0000000000000000e+00, -1.9248585503177200e-04],
        [-9.8338340153865600e-01,  9.5031674318754999e-01,
          0.0000000000000000e+00,  0.0000000000000000e+00,
          0.0000000000000000e+00,  1.1496639089440800e-04],
        [ 0.0000000000000000e+00,  0.0000000000000000e+00,
          1.0505199385060500e+00,  1.0188215775106201e-01,
          0.0000000000000000e+00,  2.5690939373378298e-04],
        [ 0.0000000000000000e+00,  0.0000000000000000e+00,
          1.0167839343556000e+00,  1.0505199385060500e+00,
          0.0000000000000000e+00,  1.0174858226574100e-04],
        [ 8.0032908698420203e-05, -1.9425079143865201e-04,
          1.5433242974866401e-04,  2.5952207539749601e-04,
          1.0000000000000000e+00,  1.7567092021426899e-05],
        [ 0.0000000000000000e+00,  0.0000000000000000e+00,
          0.0000000000000000e+00,  0.00000000000000

In [33]:
# Exact bitwise equality of the matrices; computed entries differ from Tao in the
# last digits, so torch.allclose (next cell) is the meaningful comparison.
torch.eq(mat_me, mat_tao)

tensor([[False, False,  True,  True,  True, False],
        [False, False,  True,  True,  True, False],
        [ True,  True, False, False,  True, False],
        [ True,  True, False, False,  True, False],
        [False, False, False, False,  True, False],
        [ True,  True,  True,  True,  True,  True]])

In [34]:
# The autodiff Jacobian must match Tao's transfer matrix to floating-point tolerance.
quad_matrix_match = torch.allclose(mat_me, mat_tao)
assert quad_matrix_match, 'Jacobian of track_a_quadrupole disagrees with Tao mat6'
quad_matrix_match

True

# misc