In [1]:
import sys
sys.path.append("../../")

import torch

from cpymad.madx import Madx
import pysixtrack
import sixtracklib as stl

from GroundThinLens.SampleBeam import Beam
from ThinLens.Models import SIS18_Lattice, F0D0Model

In [2]:
dim = 6
slices = 10
quadSliceMultiplicity = 1
dtype = torch.double
device = torch.device("cuda")
outputPerElement = False  # exceeds outputAtBPM
outputAtBPM = False

# prepare models
Lattice = SIS18_Lattice

model = Lattice(dim=dim, slices=slices, quadSliceMultiplicity=quadSliceMultiplicity,
                dtype=dtype, cellsIdentical=True).to(device)

# model = F0D0Model(k1=0.3, slices=slices, dim=dim, dtype=dtype)

In [3]:
# train set
if dim == 6:
    beam = Beam(mass=18.798, energy=19.0, exn=1.258e-6, eyn=2.005e-6, sigt=0.00, sige=0.000, particles=int(1e5))
    bunch = beam.bunch.to(device)

turns = 100

In [4]:
import time 

t0 = time.time()

trackBunch = bunch
with torch.no_grad():
    for i in range(turns):
        trackBunch = model(trackBunch)

print("tracking completed within {:.2f}s".format(time.time() - t0))

bunch = bunch.cpu()
trackBunch = trackBunch.cpu()

tracking completed within 81.96s


set up STL tracking

In [5]:
output = True
madx = Madx(stdout=False)
madx.options.echo = output 
madx.options.warn = output 
madx.options.info = output 

# specify beam
assert madx.command.beam(mass=18.798, charge=7, exn=1.258e-6, eyn=2.005e-6, gamma=19/18.798)  

# activate sequence
madx.input(model.thinMultipoleMadX())
madx.command.use(sequence="SIS18")

# load into STL
pysixtrack_elements = pysixtrack.Line.from_madx_sequence(
        madx.sequence.sis18,
        exact_drift=False, install_apertures=False,
    )

elements = stl.Elements.from_line(pysixtrack_elements)

elements.BeamMonitor(num_stores=turns)

<BeamMonitor at 61784
  num_stores:100
  start:0
  skip:1
  out_address:0
  max_particle_id:0
  min_particle_id:0
  is_rolling:0
  is_turn_ordered:1
>

set up bunch

In [6]:
# set up bunch
beam = madx.sequence["sis18"].beam
particles = stl.Particles.from_ref(len(bunch), p0c=beam.pc, mass0=beam.mass)

particles.x = bunch[:,0]
particles.px = bunch[:,1]
particles.y = bunch[:,2]
particles.py = bunch[:,3]


do track

In [7]:
t0 = time.time()

# track
job = stl.TrackJob(elements, particles, device=None)

job.track_until(turns)
job.collect()

print("STL tracking completed within {:.2f}s".format(time.time() - t0))

x = job.output.particles[0].x.reshape(-1, len(bunch)).transpose()
y = job.output.particles[0].y.reshape(-1, len(bunch)).transpose()

STL tracking completed within 157.11s


In [8]:
xDiff = x[:,-1] - trackBunch[:,0].numpy()
print("max deviation in x: {:.2e}".format(xDiff.max()))

yDiff = y[:,-1] - trackBunch[:,2].numpy()
print("max deviation in y: {:.2e}".format(yDiff.max()))

max deviation in x: 3.10e-05
max deviation in y: 9.81e-05
