# Implementing linear optics

In [1]:
import math

import numpy as np
from scipy.linalg import block_diag

from discopy import cat, monoidal
from discopy.monoidal import PRO
from discopy.tensor import Dim

In [2]:
# code for computing the permanent from https://github.com/scipy/scipy/issues/7151
def npperm(M):
    """
    Compute the permanent of a square matrix with Glynn's formula.

    The sum runs over the 2**(n-1) sign vectors ``d`` whose last entry is +1.
    Each loop iteration flips a single sign ``d[j]``, chosen by the focus
    pointers ``f``, so the signed row sum ``v`` is updated with one row per
    step instead of being recomputed.

    :param M: Square ``(n, n)`` array-like (integer, real or complex).
    :returns: The permanent of ``M``; ``1.0`` for the empty ``(0, 0)`` matrix.
    """
    M = np.asarray(M)
    n = M.shape[0]
    if n == 0:
        # By convention the permanent of the empty matrix is 1; the general
        # formula below would divide by 2**-1 and return 2.
        return 1.0
    # Accumulate in floating point (or complex). With an integer ``M`` the
    # in-place update of ``v`` below cannot store a float result in an int array.
    dtype = np.result_type(M.dtype, float)
    signs = np.ones(n)
    parity = 1
    focus = np.arange(n)
    v = M.sum(axis=0).astype(dtype)
    total = np.prod(v)
    j = 0
    while j < n - 1:
        v -= 2 * signs[j] * M[j]
        signs[j] = -signs[j]
        parity = -parity
        total += parity * np.prod(v)
        focus[0] = 0
        focus[j] = focus[j + 1]
        focus[j + 1] = j + 1
        j = focus[0]
    return total / 2 ** (n - 1)

In [3]:
@monoidal.Diagram.subclass
class OpticsDiagram(monoidal.Diagram):
    """
    Diagram with beam splitters and phases.

    Each box contributes a matrix; :attr:`array` composes them into the matrix
    of the whole diagram, and :meth:`amp` turns that matrix into photon
    amplitudes via a permanent.
    """
    def __repr__(self):
        return super().__repr__().replace('Diagram', 'OpticsDiagram')

    @property
    def array(self):
        """
        The array corresponding to the diagram.

        Each layer is embedded as the block-diagonal matrix
        ``identity(left) (+) box.array (+) identity(right)`` and the layers are
        multiplied left to right. :meth:`amp` reads the result as
        ``array[input_mode, output_mode]``.
        """
        scan, array = self.dom, np.identity(len(self.dom))
        for box, off in zip(self.boxes, self.offsets):
            left, right = len(scan[:off]), len(scan[off + len(box.dom):])
            array = np.matmul(array, block_diag(np.identity(left), box.array, np.identity(right)))
            # Advance to the wires after this layer, so that a box whose
            # codomain differs from its domain shifts the later offsets correctly.
            scan = scan[:off] @ box.cod @ scan[off + len(box.dom):]
        return array

    def amp(self, n_photons, x, y, permanent=npperm):
        """
        Evaluate the amplitude of ``x >> self >> y``.

        :param n_photons: Total number of photons. It is not used in the
            computation; ``x`` and ``y`` are each expected to sum to it.
        :param x: Input occupation numbers, one natural number per mode.
        :param y: Output occupation numbers, one natural number per mode.
        :param permanent: Function computing the permanent of a square matrix.
        :returns: The amplitude; ``np.array(0)`` if ``x`` and ``y`` carry
            different numbers of photons, ``1.0`` for vacuum to vacuum.
        """
        if sum(x) != sum(y):
            return np.array(0)
        n_modes = len(self.dom)
        assert len(x) == len(y) == n_modes
        if sum(x) == 0:
            # Vacuum to vacuum: empty permanent is 1 and every 0! is 1.
            return np.array(1.0)
        unitary = self.array
        # Row i of the diagram matrix is repeated x[i] times and column j is
        # repeated y[j] times; the amplitude is the permanent of that matrix.
        rows = np.repeat(np.arange(n_modes), x)
        cols = np.repeat(np.arange(n_modes), y)
        matrix = unitary[np.ix_(rows, cols)]
        # Multiply the factorials as floats: an int64 product overflows silently
        # (already 13! * 13! > 2**63).
        factorials = [float(math.factorial(n)) for n in list(x) + list(y)]
        divisor = np.sqrt(np.prod(factorials))
        return permanent(matrix) / divisor
    
class OpticsBox(OpticsDiagram, monoidal.Box):
    """
    A single optical element acting on ``len(dom)`` modes.

    :param name: Label of the box.
    :param dom: Input wires; must be a ``PRO``.
    :param cod: Output wires; must be a ``PRO``.
    :param data: Parameters of the element, interpreted by :attr:`array`.
    :raises TypeError: If ``dom`` or ``cod`` is not a ``PRO``.
    """
    def __init__(self, name, dom, cod, data, **params):
        # The message is built here rather than via ``discopy.messages``, which
        # this module never imports (referencing it raised NameError instead).
        for ty in (dom, cod):
            if not isinstance(ty, PRO):
                raise TypeError("Expected {}, got {} of type {} instead.".format(
                    PRO.__name__, repr(ty), type(ty).__name__))
        monoidal.Box.__init__(self, name, dom, cod, data=data, **params)
        # Initialise the diagram structure as the one-box diagram [self] at offset 0.
        OpticsDiagram.__init__(self, dom, cod, [self], [0], layers=self.layers)

    def __repr__(self):
        return super().__repr__().replace('Box', 'OpticsBox')

    @property
    def array(self):
        """ The array inside the box, computed from ``self.data``. """
        if isinstance(self, PhaseShift):
            # Single-mode phase: exp(i * phase), a 0-d array.
            return np.array(np.exp(self.data[0] * 1j))
        if isinstance(self, BeamSplitter):
            # [[sin, cos], [cos, -sin]] of half the angle.
            cos, sin = np.cos(self.data[0] / 2), np.sin(self.data[0] / 2)
            return np.array([sin, cos, cos, -sin]).reshape((2, 2))
        if isinstance(self, MZI):
            # Beam splitter of angle data[1] with its first row multiplied by exp(i * data[0]).
            exp, cos, sin = np.exp(1j * self.data[0]), np.cos(self.data[1] / 2), np.sin(self.data[1] / 2)
            return np.array([exp * sin, exp * cos, cos, -sin]).reshape((2, 2))
        # Generic box: data reshaped to Dim(len(dom)) @ Dim(len(cod)), or to
        # shape (1,) when that type is empty.
        return np.array(self.data).reshape(Dim(len(self.dom)) @ Dim(len(self.cod)) or (1, ))

class Id(monoidal.Id, OpticsDiagram):
    """
    Identity OpticsDiagram on ``dom``: it has no boxes.

    :param dom: A ``PRO`` type, or an ``int`` giving the number of modes.
    """
    def __init__(self, dom=PRO()):
        if isinstance(dom, int):
            # Convenience: ``Id(2)`` is shorthand for ``Id(PRO(2))``.
            dom = PRO(dom)
        monoidal.Id.__init__(self, dom)
        OpticsDiagram.__init__(self, dom, dom, [], [], layers=cat.Id(dom))
        
class PhaseShift(OpticsBox):
    """
    Single-mode phase shifter on ``PRO(1)``, storing ``[phase]`` as its data.
    """
    def __init__(self, phase):
        OpticsBox.__init__(self, 'Phase shift', PRO(1), PRO(1), [phase])

        
class BeamSplitter(OpticsBox):
    """
    Two-mode beam splitter on ``PRO(2)``, storing ``[angle]`` as its data.
    """
    def __init__(self, angle):
        OpticsBox.__init__(self, 'Beam splitter', PRO(2), PRO(2), [angle])
        
class MZI(OpticsBox):
    """
    Two-mode interferometer box on ``PRO(2)``, storing ``[phase, angle]`` as its data.
    """
    def __init__(self, phase, angle):
        OpticsBox.__init__(self, 'MZI', PRO(2), PRO(2), [phase, angle])
    

In [4]:
# Output probabilities of an MZI fed with two photons in each mode, for every
# split [i, 4 - i] of the four photons between the two output modes.
amps = []
for i in range(5):
    amplitude = MZI(0.3, 0.5).amp(4, [2, 2], [i, 4 - i])
    probability = np.absolute(amplitude) ** 2
    print(probability)
    amps.append(probability)
sum(amps)

0.019811434686576455
0.265527531852589
0.4293220669216679
0.265527531852589
0.01981143468657651


0.9999999999999989

In [5]:
# The same experiment with the circuit assembled from elementary boxes:
# a beam splitter followed by a phase shift on the second mode.
mach = BeamSplitter(0.5) >> (Id(PRO(1)) @ PhaseShift(0.3))
amps = []
for i in range(5):
    amplitude = mach.amp(4, [2, 2], [i, 4 - i])
    probability = np.absolute(amplitude) ** 2
    print(probability)
    amps.append(probability)
sum(amps)

0.01981143468657656
0.2655275318525892
0.4293220669216681
0.2655275318525891
0.019811434686576517


0.9999999999999996

In [6]:
# Alternatively, the permanent can be computed with Xanadu's thewalrus library,
# passed to ``amp`` through its ``permanent`` argument.
import thewalrus
from thewalrus import perm

amps = []
for i in range(5):
    amplitude = MZI(0.3, 0.5).amp(4, [2, 2], [i, 4 - i], permanent=perm)
    probability = np.absolute(amplitude) ** 2
    print(probability)
    amps.append(probability)
sum(amps)

0.019811434686576503
0.2655275318525891
0.4293220669216682
0.2655275318525891
0.01981143468657651


0.9999999999999993