# Example-26: Module

In [1]:
# Given an objective function or a model, it can be wrapped with torch.nn.Module
# This allows to use different optimization methods (torch.optim, pytorch-optimizer, ...)

# In this example chromaticity is optimized by 
# 1) wrapping objective function (R^n x R^m x ... -> R) with a torch module (no data is passed to forward call)
# 2) wrapping chromaticity function with a torch module (plane index is passed as feature to forward call)

# In the first case, regular optimization is performed using all available data

# For the second case, mini-batched optimization can be performed
# Planes (horizontal or vertical) are used as features
# Alternatively, location indices along the ring can be used as features (values of twiss parameters at location)
# Or location pairs (phase advance)

In [2]:
# Import

import torch
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
torch.set_printoptions(linewidth=128)

import matplotlib
from matplotlib import pyplot as plt
matplotlib.rcParams['text.usetex'] = True

from twiss import twiss

from ndmap.signature import chop
from ndmap.evaluate import evaluate
from ndmap.pfp import parametric_fixed_point

from model.library.drift import Drift
from model.library.quadrupole import Quadrupole
from model.library.sextupole import Sextupole
from model.library.dipole import Dipole
from model.library.line import Line

from model.command.wrapper import group
from model.command.wrapper import Wrapper

In [3]:
# Define simple FODO based lattice using nested lines

# Drift and dipole instances are shared by all four cells

DR = Drift('DR', 0.25)
BM = Dipole('BM', 3.50, torch.pi/4.0)

# Each cell (A, B, C, D) has its own quadrupoles and sextupoles with distinct names
# All focusing (QF) and all defocusing (QD) quadrupoles are created with identical arguments

QF_A, QF_B, QF_C, QF_D = [Quadrupole(f'QF_{cell}', 0.5, +0.20) for cell in 'ABCD']
QD_A, QD_B, QD_C, QD_D = [Quadrupole(f'QD_{cell}', 0.5, -0.19) for cell in 'ABCD']

# All sextupoles are created with 0.00 as the last argument

SF_A, SF_B, SF_C, SF_D = [Sextupole(f'SF_{cell}', 0.25, 0.00) for cell in 'ABCD']
SD_A, SD_B, SD_C, SD_D = [Sextupole(f'SD_{cell}', 0.25, 0.00) for cell in 'ABCD']

# Keyword options shared by all lines (cells and ring)

_LINE_OPTIONS = dict(propagate=True, dp=0.0, exact=False, output=False, matrix=False)

def _fodo(name, qf, qd, sf, sd):
    """Build one FODO cell line, the element sequence is mirror symmetric around the cell center."""
    half = [qf, DR, sf, DR, BM, DR, sd, DR, qd]
    return Line(name, half + half[::-1], **_LINE_OPTIONS)

FODO_A = _fodo('FODO_A', QF_A, QD_A, SF_A, SD_A)
FODO_B = _fodo('FODO_B', QF_B, QD_B, SF_B, SD_B)
FODO_C = _fodo('FODO_C', QF_C, QD_C, SF_C, SD_C)
FODO_D = _fodo('FODO_D', QF_D, QD_D, SF_D, SD_D)

# Ring is a line of the four cell lines

RING = Line('RING', [FODO_A, FODO_B, FODO_C, FODO_D], **_LINE_OPTIONS)

In [4]:
# Set parametric mapping
# Only the first item returned by group is kept
# The resulting ring is called below as ring(state, ms, dp), matching the order of the 'ms' and 'dp' groups

ring, *_ = group(
    RING,
    'FODO_A',
    'FODO_D',
    ('ms', ['Sextupole'], None, None),
    ('dp', None, None, None),
    root=True
)

# Set deviation parameters (all zeros, double precision)
# fp has four components, ms has eight (the lattice above defines eight sextupoles), dp has one

fp = torch.zeros(4, dtype=torch.float64)
ms = torch.zeros(8, dtype=torch.float64)
dp = torch.zeros(1, dtype=torch.float64)

In [5]:
# Define parametric chromaticity function

# Compute parametric fixed point (first order dispersion)
# Only the first returned item is kept, it is later evaluated with evaluate(pfp, [ms, dp])
# NOTE(review): (0, 1) presumably sets the computation orders with respect to (ms, dp) -- confirm against ndmap documentation

pfp, *_ = parametric_fixed_point((0, 1), fp, [ms, dp], ring)
# NOTE(review): the result of chop is not assigned, presumably pfp is modified in place -- confirm
chop(pfp)

# Define ring around parametric fixed point

def mapping(state, ms, dp):
    """
    Ring mapping expressed relative to the parametric fixed point.

    The parametric fixed point is evaluated once for the given deviations,
    added to the input state before the ring is applied and subtracted from the result.

    Parameters
    ----------
    state: state given as a deviation from the parametric fixed point
    ms: values of the 'ms' deviation group (sextupoles)
    dp: value of the 'dp' deviation group

    Returns
    -------
    ring(state + point, ms, dp) - point, where point = evaluate(pfp, [ms, dp])

    """
    # Evaluate the fixed point only once, the same value is used on both sides of the ring
    point = evaluate(pfp, [ms, dp])
    return ring(state + point, ms, dp) - point

# Define tunes

def tune(ms, dp):
    """
    Tunes as a function of the deviation parameters.

    The derivative of mapping with respect to its first argument (state) is evaluated at fp,
    the resulting matrix is passed to twiss and the first returned item is kept.

    """
    jacobian = torch.func.jacrev(mapping, argnums=0)
    tunes, *_ = twiss(jacobian(fp, ms, dp))
    return tunes

# Define chromaticity

def chromaticity(ms):
    """
    Chromaticity as a function of ms.

    Computed as the derivative of tune with respect to its second argument (dp),
    evaluated at the global dp, singleton dimensions are removed from the result.

    """
    derivative = torch.func.jacrev(tune, argnums=1)
    return derivative(ms, dp).squeeze()

# Compute natural chromaticity
# ms holds only zeros at this point

print(chromaticity(ms))

tensor([-2.0649, -0.8260], dtype=torch.float64)


In [6]:
# Chromaticity can be corrected in a single step
# (the test at the end of this cell reproduces the target values after one step)

# Compute starting values

psix, psiy = chromaticity(ms)

# Set target values

psix_target = torch.tensor(5.0, dtype=torch.float64)
psiy_target = torch.tensor(5.0, dtype=torch.float64)

# Perform correction
# The step is minus the pseudo-inverse of the chromaticity jacobian (with respect to ms) applied to the chromaticity error
# The step is added to the starting knobs ms, so the result does not silently rely on ms being all zeros
# (for the zero starting knobs used here the result is unchanged)

dpsix = psix - psix_target
dpsiy = psiy - psiy_target

solution = ms - torch.linalg.pinv((torch.func.jacrev(chromaticity)(ms)).squeeze()) @ torch.stack([dpsix, dpsiy])
print(solution)

# Test solution

print(chromaticity(solution))

tensor([ 0.7439, -1.2084,  0.7439, -1.2084,  0.7439, -1.2084,  0.7439, -1.2084], dtype=torch.float64)
tensor([5.0000, 5.0000], dtype=torch.float64)


In [7]:
# Optimization (wrapping objective function)

# Set model parameters
# Parameters are not cloned inside the module on initialization, values will change during optimization!

ms = torch.zeros(8, dtype=torch.float64)

# Define scalar objective function

def objective(ms):
    """Return the Euclidean distance between the current and the target chromaticities."""
    psix, psiy = chromaticity(ms)
    dx = psix - psix_target
    dy = psiy - psiy_target
    return (dx**2 + dy**2).sqrt()

# Set model (forward returns evaluated objective)

model = Wrapper(objective, ms)

# Set optimizer

optimizer = torch.optim.Adam(model.parameters(), lr=1.0E-2)

# Perform optimization

epochs = 256
for epoch in range(epochs):

    # Evaluate model and compute derivatives
    error = model()
    error.backward()

    # Perform optimization step and set gradient to zero
    optimizer.step()
    optimizer.zero_grad()

    # Verbose
    # Objective value (evaluated before the step) and distance from updated knobs to the single step solution
    knobs, *_ = model.parameters()
    distance = (knobs.detach() - solution).norm()
    print(error.detach(), distance)

tensor(9.1573, dtype=torch.float64) tensor(2.8105, dtype=torch.float64)
tensor(9.0611, dtype=torch.float64) tensor(2.7830, dtype=torch.float64)
tensor(8.9651, dtype=torch.float64) tensor(2.7555, dtype=torch.float64)
tensor(8.8693, dtype=torch.float64) tensor(2.7280, dtype=torch.float64)
tensor(8.7737, dtype=torch.float64) tensor(2.7006, dtype=torch.float64)
tensor(8.6784, dtype=torch.float64) tensor(2.6732, dtype=torch.float64)
tensor(8.5833, dtype=torch.float64) tensor(2.6458, dtype=torch.float64)
tensor(8.4884, dtype=torch.float64) tensor(2.6184, dtype=torch.float64)
tensor(8.3938, dtype=torch.float64) tensor(2.5910, dtype=torch.float64)
tensor(8.2995, dtype=torch.float64) tensor(2.5636, dtype=torch.float64)
tensor(8.2054, dtype=torch.float64) tensor(2.5363, dtype=torch.float64)
tensor(8.1116, dtype=torch.float64) tensor(2.5090, dtype=torch.float64)
tensor(8.0181, dtype=torch.float64) tensor(2.4817, dtype=torch.float64)
tensor(7.9249, dtype=torch.float64) tensor(2.4544, dtype=torch.f

In [8]:
# Optimization (wrapping chromaticity function)

# Set model parameters
# Parameters are not cloned inside the module on initialization, values will change during optimization!

ms = torch.tensor(8*[0.0], dtype=torch.float64)

# Set features and labels
# X selects the plane (horizontal or vertical chromaticity)
# y is corresponding target chromaticity value for selected plane

X = torch.tensor([[0], [1]])
y = torch.stack([psix_target, psiy_target])

# Set dataset
# Note, batch size is one, technically this is not a mini-batch optimization

batch_size = 1
dataset = TensorDataset(X.clone(), y.clone())

# Set data loader
# A dedicated seeded generator fixes the shuffling order, so that the run is reproducible

generator = torch.Generator().manual_seed(0)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, generator=generator)

# Set objective (return horizontal or vertical chromaticity)

def objective(x, ms):
    """Return the chromaticity component(s) selected by the plane index x, with singleton dimensions removed."""
    return chromaticity(ms)[x].squeeze()

# Set model (forward returns evaluated objective)

model = Wrapper(objective, ms)

# Set optimizer

optimizer = torch.optim.Adam(model.parameters(), lr=1.0E-2)

# Set loss function

lf = torch.nn.MSELoss()

# Perform optimization

epochs = 256
for epoch in range(epochs):

    # Loop over batches of data
    # Batch names differ from X and y, so the full feature and label tensors are not overwritten
    for X_batch, y_batch in dataloader:

        # Evaluate model
        y_hat = model(X_batch)

        # Evaluate loss function
        error = lf(y_hat, y_batch.squeeze())

        # Compute derivatives
        error.backward()

        # Perform optimization step
        optimizer.step()

        # Set gradient to zero
        optimizer.zero_grad()

    # Verbose
    # error is the loss of the last batch processed in the epoch (one plane for batch size one)
    # Hence, printed values can jump between epochs depending on the shuffling order
    knobs, *_ = [*model.parameters()]
    print(error.detach(), (knobs.detach() - solution).norm())

tensor(35.4295, dtype=torch.float64) tensor(2.8297, dtype=torch.float64)
tensor(47.7980, dtype=torch.float64) tensor(2.8006, dtype=torch.float64)
tensor(33.9269, dtype=torch.float64) tensor(2.7718, dtype=torch.float64)
tensor(45.6336, dtype=torch.float64) tensor(2.7423, dtype=torch.float64)
tensor(32.6024, dtype=torch.float64) tensor(2.7134, dtype=torch.float64)
tensor(43.5373, dtype=torch.float64) tensor(2.6840, dtype=torch.float64)
tensor(42.6609, dtype=torch.float64) tensor(2.6547, dtype=torch.float64)
tensor(41.7582, dtype=torch.float64) tensor(2.6255, dtype=torch.float64)
tensor(29.9431, dtype=torch.float64) tensor(2.5970, dtype=torch.float64)
tensor(29.4738, dtype=torch.float64) tensor(2.5684, dtype=torch.float64)
tensor(28.9800, dtype=torch.float64) tensor(2.5399, dtype=torch.float64)
tensor(37.5585, dtype=torch.float64) tensor(2.5109, dtype=torch.float64)
tensor(36.7554, dtype=torch.float64) tensor(2.4820, dtype=torch.float64)
tensor(27.1260, dtype=torch.float64) tensor(2.4540,