In [1]:
import numpy as np
import torch
from torch.nn import Module

In [2]:
class TorchSectorDipole(Module):
    """Sector dipole element whose 6x6 transport matrix is built with torch ops.

    The matrix is rebuilt from ``rho`` and ``theta`` on every call, so both may
    be plain Python/NumPy numbers or torch tensors. A tensor ``theta`` with
    ``requires_grad=True`` stays connected to the autograd graph, so gradients
    of downstream quantities flow back into it.

    Args:
        name: label for the element (stored only; not used in the maths here).
        rho: bending radius. The element length is taken as ``rho * theta``.
            NOTE(review): units presumably metres — TODO confirm.
        theta: bending angle in radians (fed to ``torch.cos``/``torch.sin``).

    NOTE(review): the matrix layout matches the TRANSPORT convention with
    phase-space ordering (x, x', y, y', l, delta) — confirm against callers.
    """

    def __init__(self, name, rho, theta):
        super().__init__()
        self.name = name
        self.rho = rho
        self.theta = theta

    def get_transport_matrix(self):
        """Return the 6x6 transport matrix for the current ``rho`` and ``theta``.

        Returns:
            torch.Tensor of shape (6, 6) in torch's default dtype, placed on the
            same device as ``theta`` (CPU for plain numbers).
        """
        # torch.as_tensor (unlike torch.tensor) neither copies nor detaches an
        # existing tensor, so gradients w.r.t. a tensor-valued theta survive.
        theta = torch.as_tensor(self.theta)
        M = torch.eye(6, device=theta.device)

        C = torch.cos(theta)
        S = torch.sin(theta)
        h = 1 / self.rho         # curvature
        L = self.rho * theta     # arc length through the magnet

        # Horizontal plane, including the columns driven by index 5.
        M[0, 0] = C
        M[0, 1] = S / h
        M[0, 5] = (1 - C) / h
        M[1, 0] = -h * S
        M[1, 1] = C
        M[1, 5] = S
        # Vertical plane: no focusing here, only a drift of length L.
        M[2, 3] = L
        # Row 4 picks up contributions from the horizontal coordinates and index 5.
        M[4, 0] = S
        M[4, 1] = (1 - C) / h
        M[4, 5] = (theta - S) / h

        return M

    def transport_particle(self, particle_phasespace_vector):
        """Apply the transport matrix to a phase-space vector.

        Args:
            particle_phasespace_vector: tensor whose leading dimension is 6
                (a single vector or a 6xN batch of column vectors).

        Returns:
            ``M @ particle_phasespace_vector``.
        """
        return self.get_transport_matrix() @ particle_phasespace_vector

    def transport_sigma_matrix(self, sigma_matrix):
        """Propagate a 6x6 beam sigma matrix through the element as ``M @ sigma @ M^T``."""
        M = self.get_transport_matrix()
        return M @ sigma_matrix @ M.T

In [3]:
# Example element: a sector dipole with bending radius 3 (units presumably
# metres — TODO confirm) bending through 30 degrees, converted to radians.
D1 = TorchSectorDipole(name='D1', rho=3, theta=np.radians(30))

In [4]:
# Display D1's 6x6 transport matrix; as the cell's last expression it is rendered as output.
D1.get_transport_matrix()

tensor([[ 0.8660,  1.5000,  0.0000,  0.0000,  0.0000,  0.4019],
        [-0.1667,  0.8660,  0.0000,  0.0000,  0.0000,  0.5000],
        [ 0.0000,  0.0000,  1.0000,  1.5708,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  1.0000,  0.0000,  0.0000],
        [ 0.5000,  0.4019,  0.0000,  0.0000,  1.0000,  0.0708],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  1.0000]])