Tracking elements with autodifferentiable capabilities

In [1]:
import numpy as np
import torch
from collections import namedtuple
from torch.autograd.functional import jacobian

torch.__version__, np.__version__
torch.set_printoptions(precision=10)

# Constants

In [2]:
c_light = 2.99792458e8 #speed of light in m/s
m_e = 0.510998950e6 #electron mass in eV

# Tuples

In [3]:
Drift = namedtuple('Drift', 'L')
Quadrupole = namedtuple('Quadrupole', 'L K1 n_step')
Particle = namedtuple('Particle', 'x px y py z pz s t p0c mc2')
#Canonical phase space coordinates as defined in Bmad manual section 15.4.2. Added s t p0c and mc2 = 0

PyTorch tensor test Particle:

In [4]:
np.random.seed(0)
pvec = np.random.rand(10_000, 6) #test particles
s = 0.0
t = 0.0
p0c = 4.1996891E+07 #Reference particle momentum in eV
mc2 = 1*m_e # electron mass in eV
ts = torch.tensor(s, dtype=torch.float64)
tt = torch.tensor(t, dtype=torch.float64)
tp0c = torch.tensor(p0c, dtype=torch.float64)
tmc2 = torch.tensor(mc2, dtype=torch.float64) #necessary? 
tpvec = torch.tensor(pvec, requires_grad=True, dtype=torch.float64)
p_torch = Particle(*tpvec.T, ts, tt, tp0c, tmc2)
p_torch

Particle(x=tensor([0.5488135039, 0.4375872113, 0.5680445611,  ..., 0.1673917659,
        0.8283865988, 0.4274187576], dtype=torch.float64,
       grad_fn=<UnbindBackward0>), px=tensor([0.7151893664, 0.8917730008, 0.9255966383,  ..., 0.3514514993,
        0.3534266484, 0.8832471531], dtype=torch.float64,
       grad_fn=<UnbindBackward0>), y=tensor([0.6027633761, 0.9636627605, 0.0710360582,  ..., 0.2700843196,
        0.9201338773, 0.3800812269], dtype=torch.float64,
       grad_fn=<UnbindBackward0>), py=tensor([0.5448831830, 0.3834415188, 0.0871292997,  ..., 0.0870263901,
        0.8137479865, 0.1690041298], dtype=torch.float64,
       grad_fn=<UnbindBackward0>), z=tensor([0.4236547993, 0.7917250381, 0.0202183974,  ..., 0.3048581211,
        0.6832710885, 0.2639306656], dtype=torch.float64,
       grad_fn=<UnbindBackward0>), pz=tensor([0.6458941131, 0.5288949198, 0.8326198455,  ..., 0.1314423113,
        0.8690973434, 0.3011667628], dtype=torch.float64,
       grad_fn=<UnbindBackward0>)

# Drift

## track_a_drift

In [5]:
def make_track_a_drift(lib):
    """
    Makes track_a_drift given the library lib
    """
    
    sqrt = lib.sqrt
    
    def track_a_drift(p_in, drift):
        """
        Tracks the incoming Particle p_in though drift element
        and returns the outgoing particle. See eqs 24.58 in bmad manual
        """
        L = drift.L
        
        s = p_in.s
        t = p_in.t
        p0c = p_in.p0c
        mc2 = p_in.mc2
        
        x = p_in.x
        px = p_in.px
        y = p_in.y
        py = p_in.py
        z = p_in.z
        pz = p_in.pz
        
        P = 1 + pz #Particle's total momentum over p0
        Px = px / P #Particle's 'x' momentum over p0
        Py = py / P #Particle's 'y' momentum over p0
        Pxy2 = Px**2 + Py**2 #Particle's transverse mometum^2 over p0^2
        Pl = sqrt(1-Pxy2)  #Particle's longitudinal momentum over p0
        
        x = x + L * Px / Pl
        y = y + L * Py / Pl
        
        beta = sqrt( 1 - 1 /( 1 + (P*p0c/mc2)**2 ) )
        beta_ref = sqrt( 1 - 1 /( 1 + (p0c/mc2)**2 ) )
        z = z + L * ( beta/beta_ref - 1.0/Pl )
        s = s + L
        t = t + L / ( beta * Pl * c_light )
        
        return Particle(x, px, y, py, z, pz, s, t, p0c, mc2)
    
    return track_a_drift

# Specialized functions
track_a_drift_torch = make_track_a_drift(lib=torch)

## drift test

drift:

In [6]:
L=1.0 # Dipole length in m
d1 = Drift(torch.tensor(L, dtype=torch.float64))

test particle:

In [7]:
s = 0.0 #initial s
t = 0.0 #initial t
p0c = 4.1996891E+07 #Reference particle momentum in eV
mc2 = 1*m_e # electron mass in eV
pvec1 = [2e-3,3e-3,-3e-3,-1e-3,2e-3,-2e-3] 
#pvec1[0] = 1e-3
#pvec1[1] = 1e-3
#pvec1[5] = 1e-3
tvec1 = torch.tensor(pvec1, requires_grad=True, dtype=torch.float64)

track_a_drift test

In [8]:
track_a_drift_torch(Particle(*tvec1,ts, tt, tp0c, tmc2), d1)

Particle(x=tensor(0.0050060271, dtype=torch.float64, grad_fn=<AddBackward0>), px=tensor(0.0030000000, dtype=torch.float64, grad_fn=<UnbindBackward0>), y=tensor(-0.0040020090, dtype=torch.float64, grad_fn=<AddBackward0>), py=tensor(-0.0010000000, dtype=torch.float64, grad_fn=<UnbindBackward0>), z=tensor(0.0019946830, dtype=torch.float64, grad_fn=<AddBackward0>), pz=tensor(-0.0020000000, dtype=torch.float64, grad_fn=<UnbindBackward0>), s=tensor(1., dtype=torch.float64), t=tensor(3.3359055992e-09, dtype=torch.float64, grad_fn=<AddBackward0>), p0c=tensor(41996891., dtype=torch.float64), mc2=tensor(510998.9500000000, dtype=torch.float64))

In [9]:
f2 = lambda x: track_a_drift_torch(Particle(*x,ts, tt, tp0c, tmc2), d1)[:6]
with torch.autograd.detect_anomaly():
    J = jacobian(f2, tvec1)
J

  with torch.autograd.detect_anomaly():


(tensor([ 1.0000000000e+00,  1.0020180925e+00,  0.0000000000e+00,
         -3.0181176940e-06,  0.0000000000e+00, -3.0120814586e-03],
        dtype=torch.float64),
 tensor([0., 1., 0., 0., 0., 0.], dtype=torch.float64),
 tensor([ 0.0000000000e+00, -3.0181176940e-06,  1.0000000000e+00,
          1.0020100442e+00,  0.0000000000e+00,  1.0040271529e-03],
        dtype=torch.float64),
 tensor([0., 0., 0., 1., 0., 0.], dtype=torch.float64),
 tensor([ 0.0000000000e+00, -3.0120814586e-03,  0.0000000000e+00,
          1.0040271529e-03,  1.0000000000e+00,  1.5897915887e-04],
        dtype=torch.float64),
 tensor([0., 0., 0., 0., 0., 1.], dtype=torch.float64))

In [10]:
J

(tensor([ 1.0000000000e+00,  1.0020180925e+00,  0.0000000000e+00,
         -3.0181176940e-06,  0.0000000000e+00, -3.0120814586e-03],
        dtype=torch.float64),
 tensor([0., 1., 0., 0., 0., 0.], dtype=torch.float64),
 tensor([ 0.0000000000e+00, -3.0181176940e-06,  1.0000000000e+00,
          1.0020100442e+00,  0.0000000000e+00,  1.0040271529e-03],
        dtype=torch.float64),
 tensor([0., 0., 0., 1., 0., 0.], dtype=torch.float64),
 tensor([ 0.0000000000e+00, -3.0120814586e-03,  0.0000000000e+00,
          1.0040271529e-03,  1.0000000000e+00,  1.5897915887e-04],
        dtype=torch.float64),
 tensor([0., 0., 0., 0., 0., 1.], dtype=torch.float64))

# track_a_quadrupole

In [11]:
Quadrupole = namedtuple('Drift', 'k1 l n_step')
k1 = 2.000
l = 0.1 
n_step = 50
d1 = Quadrupole(k1,l, n_step)
d1

Drift(k1=2.0, l=0.1, n_step=50)

In [12]:
def make_track_a_quadrupole(lib):
    """
    Makes track_a_drift given the library lib
    """
    
    sqrt = lib.sqrt
    absolute = lib.abs
    sin = lib.sin
    cos = lib.cos
    sinh = lib.sinh
    cosh = lib.cosh
    hstack = lib.hstack
    matmul = lib.matmul
    if lib==np:
        matrix = np.array
    elif lib==torch:
        matrix = torch.tensor
    
    def track_a_quadrupole(p_in, quad):
        """
        Tracks the incoming Particle p_in though drift element
        and returns the outgoing particle. See eqs 24.58 in bmad manual
        """
        K1 = quad.K1
        L = quad.L
        n_step = quad.n_step
        l = L/n_step
        
        s = p_in.s
        t = p_in.t
        p0c = p_in.p0c
        mc2 = p_in.mc2
        
        x = p_in.x
        px = p_in.px
        xpx = hstack(x,px)
        
        y = p_in.y
        py = p_in.py
        z = p_in.z
        pz = p_in.pz
        
        P = 1 + pz #Particle's total momentum over p0
        Px = px / P #Particle's 'x' momentum over p0
        Py = py / P #Particle's 'y' momentum over p0
        Pxy2 = Px**2 + Py**2 #Particle's transverse mometum^2 over p0^2
        Pl = sqrt(1-Pxy2)  #Particle's longitudinal momentum over p0
        
        sqrt_k = sqrt(absolute(K1))
        sk_l = sqrt_k * l
        
        for i in range(n_step):
            if (abs(sk_l) < 1e-10):
                k_l2 = k1 * l**2
                cx = 1 + k_l2 / 2
                sx = (1 + k_l2 / 6) * length
            elif (k1 < 0):# focus
                cx = cos(sk_l)
                sx = sin(sk_l) / sqrt_k
            else:# defocus
                cx = cosh(sk_l)
                sx = sinh(sk_l) / sqrt_k
            
            kmatx = matrix([[cx, sx/p0c],[K1*sx*p0c,cx]])
            
            dcx = -k1 * sx * length / (2 * rel_p)
            dsx = (sx - length * cx) / (2 * rel_p)
            dz_dx1 =   -zc(1)/rel_p - k1 * (cx * dsx + dcx * sx) / 4
            dz_dx2 = -2*zc(2)/rel_p - k1 * sx * dsx/rel_p
            dz_dx3 = -2*zc(3)/rel_p - (cx * dsx + dcx * sx) / (4 * rel_p**2)
            
            dz_x1 = 0 #Derivative. 
            dz_x2 = 0
            dz_x3 = 0
            dz_y1 = 0 
            dz_y2 = 0
            dz_y3 = 0
            
            z = z + dz_x1 * x**2 + dz_dx2 * x * px + dz_dx3 * px**2 + \
            dz_y1 * y**2 + dz_dy2 * y * py + dz_dy3 * py**2
            
            x = 
            px = 
            y = 
            py =
            
        
        return Particle(x, px, y, py, z, pz, s, t, p0c, mc2)
    
    return track_a_quadrupole

# Specialized functions
track_a_quadrupole_torch = make_track_a_quadrupole(lib=torch)