# Tracking elements with automatic-differentiation (autodiff) capabilities

In [1]:
import numpy as np
import torch
from collections import namedtuple
from torch.autograd.functional import jacobian

PRINT_PRECISION = 10  # digits of precision used when printing tensors
torch.set_printoptions(precision=PRINT_PRECISION)
torch.__version__, np.__version__  # library versions this notebook was run with

('1.10.2', '1.22.2')

# Constants

In [2]:
c_light = 299_792_458.0  # speed of light in vacuum [m/s]
m_e = 510_998.950        # electron mass expressed as energy [eV]

# Element and particle types (namedtuples)

In [3]:
# Immutable containers for beamline elements and particles.
Drift = namedtuple('Drift', ['L'])  # drift length
Quadrupole = namedtuple('Quadrupole', ['L', 'K1', 'NUM_STEPS'])  # length, focusing strength, tracking steps
# Particle: Bmad canonical phase-space coordinates (x, px, y, py, z, pz) as defined
# in Bmad manual section 15.4.2, extended with the position s, time t, the
# reference momentum p0c (eV) and the particle mass mc2 (eV).
Particle = namedtuple('Particle', ['x', 'px', 'y', 'py', 'z', 'pz', 's', 't', 'p0c', 'mc2'])

## Batch of PyTorch test particles

In [4]:
# Batch of 10_000 random test particles, stored as float64 tensors that track gradients.
np.random.seed(0)
pvec = np.random.rand(10_000, 6)  # columns map to x, px, y, py, z, pz; values uniform in [0, 1)
# NOTE(review): momenta this large are far from paraxial -- for some rows
# (px/(1+pz))**2 + (py/(1+pz))**2 > 1, so track_a_drift's sqrt yields NaN.
s, t = 0.0, 0.0
p0c = 4.1996891E+07  # reference particle momentum in eV
mc2 = 1*m_e  # electron mass in eV
ts, tt, tp0c, tmc2 = (torch.tensor(v, dtype=torch.float64) for v in (s, t, p0c, mc2))
tpvec = torch.tensor(pvec, requires_grad=True, dtype=torch.float64)
p_torch = Particle(*tpvec.unbind(dim=1), ts, tt, tp0c, tmc2)
p_torch

Particle(x=tensor([0.5488135039, 0.4375872113, 0.5680445611,  ..., 0.1673917659,
        0.8283865988, 0.4274187576], dtype=torch.float64,
       grad_fn=<UnbindBackward0>), px=tensor([0.7151893664, 0.8917730008, 0.9255966383,  ..., 0.3514514993,
        0.3534266484, 0.8832471531], dtype=torch.float64,
       grad_fn=<UnbindBackward0>), y=tensor([0.6027633761, 0.9636627605, 0.0710360582,  ..., 0.2700843196,
        0.9201338773, 0.3800812269], dtype=torch.float64,
       grad_fn=<UnbindBackward0>), py=tensor([0.5448831830, 0.3834415188, 0.0871292997,  ..., 0.0870263901,
        0.8137479865, 0.1690041298], dtype=torch.float64,
       grad_fn=<UnbindBackward0>), z=tensor([0.4236547993, 0.7917250381, 0.0202183974,  ..., 0.3048581211,
        0.6832710885, 0.2639306656], dtype=torch.float64,
       grad_fn=<UnbindBackward0>), pz=tensor([0.6458941131, 0.5288949198, 0.8326198455,  ..., 0.1314423113,
        0.8690973434, 0.3011667628], dtype=torch.float64,
       grad_fn=<UnbindBackward0>)

# Drift

## track_a_drift

In [9]:
def make_track_a_drift(lib):
    """
    Build a drift-tracking function bound to the array library ``lib``
    (for example torch or numpy). Only ``lib.sqrt`` is taken from it.
    """
    root = lib.sqrt

    def track_a_drift(p_in, drift):
        """
        Track the Particle ``p_in`` through a field-free drift of length
        ``drift.L`` and return the outgoing Particle (see eqs 24.58 in the
        Bmad manual).

        px, py, pz, p0c and mc2 pass through unchanged; x, y, z, s and t
        are advanced.
        """
        length = drift.L
        p0c, mc2 = p_in.p0c, p_in.mc2

        rel_p = 1 + p_in.pz  # total momentum P over p0
        px_n = p_in.px / rel_p  # x momentum over P
        py_n = p_in.py / rel_p  # y momentum over P
        p_long = root(1 - (px_n**2 + py_n**2))  # longitudinal momentum over P

        x_out = p_in.x + length * px_n / p_long
        y_out = p_in.y + length * py_n / p_long

        # Particle and reference speeds (v/c) from momentum and mass.
        beta = rel_p * p0c / root((rel_p * p0c)**2 + mc2**2)
        beta_ref = p0c / root(p0c**2 + mc2**2)

        z_out = p_in.z + length * (beta / beta_ref - 1.0 / p_long)
        s_out = p_in.s + length
        # Path length through the drift is length / p_long.
        t_out = p_in.t + length / (beta * p_long * c_light)

        return Particle(x_out, p_in.px, y_out, p_in.py, z_out, p_in.pz,
                        s_out, t_out, p0c, mc2)

    return track_a_drift

# Specialized functions: the generic drift tracker bound to PyTorch.
track_a_drift_torch = make_track_a_drift(torch)

## drift test

drift:

In [10]:
L = 1.0  # drift length in m
d1 = Drift(L=torch.tensor(L, dtype=torch.float64))
d1

Drift(L=tensor(1., dtype=torch.float64))

Single test particle (phase-space coordinates x, px, y, py, z, pz):

In [21]:
# Reference values shared by the single test particle.
s, t = 0.0, 0.0  # initial s and t
p0c = 4.1996891E+07  # reference particle momentum in eV
mc2 = 1*m_e  # electron mass in eV
ts, tt, tp0c, tmc2 = (torch.tensor(v, dtype=torch.float64) for v in (s, t, p0c, mc2))
# x, px, y, py, z, pz of the test particle; tracked with gradients for the Jacobian below.
pvec1 = [2e-3, 3e-3, -3e-3, -1e-3, 2e-3, -2e-3]
tvec1 = torch.tensor(pvec1, requires_grad=True, dtype=torch.float64)
tvec1

tensor([ 0.0020000000,  0.0030000000, -0.0030000000, -0.0010000000,
         0.0020000000, -0.0020000000], dtype=torch.float64, requires_grad=True)

Track the test particle through the drift:

In [22]:
track_a_drift_torch(p_in=Particle(*tvec1, ts, tt, tp0c, tmc2), drift=d1)  # one particle through d1

Particle(x=tensor(0.0050060271, dtype=torch.float64, grad_fn=<AddBackward0>), px=tensor(0.0030000000, dtype=torch.float64, grad_fn=<UnbindBackward0>), y=tensor(-0.0040020090, dtype=torch.float64, grad_fn=<AddBackward0>), py=tensor(-0.0010000000, dtype=torch.float64, grad_fn=<UnbindBackward0>), z=tensor(8.3583408871e-06, dtype=torch.float64, grad_fn=<AddBackward0>), pz=tensor(-0.0020000000, dtype=torch.float64, grad_fn=<UnbindBackward0>), s=tensor(1., dtype=torch.float64), t=tensor(4.0804661003e-08, dtype=torch.float64, grad_fn=<AddBackward0>), p0c=tensor(41996.8910000000, dtype=torch.float64), mc2=tensor(510998.9500000000, dtype=torch.float64))

Jacobian of the drift map with respect to the initial phase-space coordinates:

In [None]:
def f2(x):
    """Send the 6 phase-space coordinates x through drift d1 and return the 6 outgoing ones."""
    return track_a_drift_torch(Particle(*x, ts, tt, tp0c, tmc2), d1)[:6]

with torch.autograd.detect_anomaly():  # report the op that creates NaN during backward
    J = jacobian(f2, tvec1)

In [None]:
J  # tuple of 6 tensors: J[i] = d(output coordinate i) / d(tvec1), each of shape (6,)

# Quadrupole

## track_a_quadrupole

In [None]:
def make_track_a_quadrupole(lib):
    """
    Makes track_a_quadrupole given the library lib.

    lib is the array library (e.g. torch or numpy). Its sqrt, abs, sin, cos,
    sinh, cosh, where and ones_like are used, so the same tracking code runs
    on plain arrays or on autodifferentiable tensors.
    """
    sqrt = lib.sqrt
    absolute = lib.abs
    sin = lib.sin
    cos = lib.cos
    sinh = lib.sinh
    cosh = lib.cosh
    where = lib.where
    ones_like = lib.ones_like

    def quad_mat2_calc(k1, length, rel_p):
        """
        Returns 2x2 transfer matrix elements aij and the coefficients to calculate
        the change in z position.
        Input:
            k1 -- Quad strength, already divided by rel_p: k1 > 0 ==> defocus.
                  May be a scalar or hold one value per particle.
            length -- Quad length
            rel_p -- Relative momentum P/P0
        Output:
            a11, a12, a21, a22 : transfer matrix elements acting on (x, px)
            c1, c2, c3 : coefficients such that the change in z over `length` is
                         dz = c1 * x_0^2 + c2 * x_0 * px_0 + c3 * px_0^2
        """
        k_l2 = k1 * length**2
        # Same test as |sqrt(|k1|) * length| < 1e-10, written without sqrt.
        # lib.where replaces if/elif so k1 may be per-particle.
        small = absolute(k_l2) < 1e-20
        # Substitute a dummy strength where `small` holds, so the trig branches
        # below never divide by zero. With torch.where, a NaN/inf in the
        # unselected branch would still poison the gradients.
        k1_safe = where(small, ones_like(k1), k1)
        sqrt_k = sqrt(absolute(k1_safe))
        sk_l = sqrt_k * length
        focusing = k1_safe < 0

        cx = where(small, 1 + k_l2 / 2,
                   where(focusing, cos(sk_l), cosh(sk_l)))
        sx = where(small, (1 + k_l2 / 6) * length,
                   where(focusing, sin(sk_l), sinh(sk_l)) / sqrt_k)

        a11 = cx
        a12 = sx / rel_p
        a21 = k1 * sx * rel_p
        a22 = cx

        c1 = k1 * (-cx * sx + length) / 4
        c2 = -k1 * sx**2 / (2 * rel_p)
        c3 = -(cx * sx + length) / (4 * rel_p**2)

        return a11, a12, a21, a22, c1, c2, c3

    def low_energy_z_correction(pz, p0c, mass, ds):
        """
        Returns the change in z over a length ds caused by the particle speed
        differing from the reference speed (speed < c_light).
        Input:
            pz -- relative momentum deviation, P/P0 - 1
            p0c -- reference particle momentum in eV
            mass -- particle mass in eV
            ds -- path length of the step
        Output:
            dz -- ds * (beta - beta0) / beta0, evaluated as a series in pz when
                  the exact difference would suffer from cancellation
        """
        p_c = (1 + pz) * p0c
        beta = p_c / sqrt(p_c**2 + mass**2)
        e_tot = sqrt(p0c**2 + mass**2)  # reference total energy
        beta0 = p0c / e_tot
        mass_ratio2 = (mass / e_tot)**2

        f = beta0**2 * (2 * beta0**2 - mass_ratio2 / 2)
        dz_series = ds * pz * (1 - 3 * (pz * beta0**2) / 2 + pz**2 * f) * mass_ratio2
        dz_exact = ds * (beta - beta0) / beta0
        # Element-wise selection, so pz may hold one value per particle.
        return where(mass * (beta0 * pz)**2 < 3e-7 * e_tot, dz_series, dz_exact)

    def track_a_quadrupole(p_in, quad):
        """
        Tracks the incoming Particle p_in through quad element
        and returns the outgoing particle.

        The quad is split into quad.NUM_STEPS equal steps. Each step:
          - adds the quadratic z change (coefficients from quad_mat2_calc);
          - applies the 2x2 maps to (x, px) and (y, py);
          - adds the low-energy z correction.
        The strength quad.K1 is divided by (1 + pz), so particles with pz > 0
        see a weaker quad. Positive K1 focuses in x.
        """
        l = quad.L
        # NUM_STEPS may arrive as a float tensor; range() needs a Python int.
        n_step = max(int(round(float(quad.NUM_STEPS))), 1)  # number of divisions
        step_len = l / n_step  # length of division

        s = p_in.s
        t = p_in.t
        p0c = p_in.p0c
        mc2 = p_in.mc2

        x = p_in.x
        px = p_in.px
        y = p_in.y
        py = p_in.py
        z = p_in.z
        pz = p_in.pz

        rel_p = 1 + pz  # Particle's relative momentum (P/P0)
        k1 = quad.K1 / rel_p  # chromatic scaling of the focusing strength

        # pz does not change inside the quad, so the per-step maps, z coefficients
        # and low-energy z correction are identical for every step.
        tx11, tx12, tx21, tx22, dz_x1, dz_x2, dz_x3 = quad_mat2_calc(-k1, step_len, rel_p)
        ty11, ty12, ty21, ty22, dz_y1, dz_y2, dz_y3 = quad_mat2_calc(k1, step_len, rel_p)
        dz_low_energy = low_energy_z_correction(pz, p0c, mc2, step_len)

        for _ in range(n_step):
            # The z change uses the coordinates at the start of the step.
            z = (z
                 + dz_x1 * x**2 + dz_x2 * x * px + dz_x3 * px**2
                 + dz_y1 * y**2 + dz_y2 * y * py + dz_y3 * py**2)
            # Tuple assignment: px/py must be computed from the *old* x/y.
            x, px = tx11 * x + tx12 * px, tx21 * x + tx22 * px
            y, py = ty11 * y + ty12 * py, ty21 * y + ty22 * py
            z = z + dz_low_energy

        s = s + l
        # NOTE(review): t is passed through unchanged (track_a_drift advances it) -- TODO.
        return Particle(x, px, y, py, z, pz, s, t, p0c, mc2)

    return track_a_quadrupole

# Specialized functions: the generic quadrupole tracker bound to PyTorch.
track_a_quadrupole_torch = make_track_a_quadrupole(torch)

## quadrupole test

In [None]:
# Create quad 1
L = 1.0  # length in m
K1 = 10.0  # quad focusing strength. Positive is focusing in x subspace
# Number of tracking steps. It drives a Python range(), so keep it a plain int:
# a step count is not differentiable anyway.
# NOTE(review): 50 is believed to be the Bmad default -- TODO confirm.
NUM_STEPS = 50
q1 = Quadrupole(
    torch.tensor(L, dtype=torch.float64),
    torch.tensor(K1, dtype=torch.float64),
    NUM_STEPS,
)
q1

In [None]:
# Scratch: small autograd experiments with float64 leaf tensors.
a = torch.tensor(1.0, dtype=torch.float64, requires_grad=True)
b = torch.ones(2, dtype=torch.float64, requires_grad=True)
b

In [None]:
torch.mul(a, b)  # the 0-d tensor a broadcasts against b

In [None]:
torch.cat([a.reshape(1), a.reshape(1)])  # two copies of the scalar a as a 1-D tensor

In [None]:
out = (a * torch.tensor([1.0, 0], dtype=torch.float64)
       + a * torch.tensor([0, 1.0], dtype=torch.float64))  # [a, a] built with masks; dtype checked next

In [None]:
out.dtype  # float64 inputs only, so this should be torch.float64

In [None]:
torch.autograd.backward(out[0])  # accumulates d(out[0])/da into a.grad

In [None]:
print(a.grad)  # out[0] = 1*a + 0*a, so d(out[0])/da = 1 (accumulates if the backward cell is re-run)