# Unitary Matrix Networks in the Frequency Domain

## Imports

In [None]:
# standard library
from collections import OrderedDict
import sys
sys.path.append('../')

# photontorch
import torch
import photontorch as pt

# other
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import trange

## Settings

In [None]:
%matplotlib inline
# All tensors created by the `tensor` helper are placed on this device.
DEVICE = 'cpu'
# Seed numpy (used by rand_phase / unitary_matrix) and torch for reproducible runs.
np.random.seed(0)
torch.manual_seed(0)
# Compact printing of the (complex) matrices printed throughout the notebook.
np.set_printoptions(precision=2, suppress=True)
# Frequency-domain simulation with a single time step; grad=True presumably
# lets photontorch track gradients so the phases can be trained — confirm.
env = pt.Environment(freqdomain=True, num_t=1, grad=True)
pt.set_environment(env);
# Last expression of the cell: shows the active environment as cell output.
pt.current_environment()

## Unitary Matrices

A unitary matrix is a matrix $U$ for which
\begin{align*}
U\cdot U^\dagger = U^\dagger \cdot U = I
\end{align*}

A unitary matrix is easily implemented in photonics. Indeed, according to the paper *"[Experimental Realization of Any Discrete Unitary Matrix](https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.73.58)"* by Reck et al., any unitary matrix can be written as a combination of phase shifters and directional couplers with variable coupling (or MZIs) (Figure (a)).

However, there exists an alternative approach to achieve any unitary operation, first proposed by Clements et al.: [Optimal design for universal multiport interferometers](https://www.osapublishing.org/optica/abstract.cfm?uri=optica-3-12-1460) (Figure (b)).

![Unitary Matrix Paper](images/clements.jpeg)


## 2x2 Unitary Matrix (Reck)

### Functions

In [None]:
def array(tensor):
    """Convert a torch tensor to a numpy array.

    If the leading axis has length 2 it is interpreted as stacked
    (real, imaginary) parts and merged into one complex array; any other
    tensor is returned as a plain real numpy array.
    """
    values = tensor.detach().cpu().numpy()
    if values.shape[0] != 2:
        return values
    real, imag = values[0], values[1]
    return real + 1j * imag

def tensor(array):
    """Convert a numpy array (or any array-like) to a torch tensor on ``DEVICE``.

    Complex input is split into a stacked ``(real, imag)`` pair along a new
    leading axis of length 2, which is the layout ``array`` converts back
    from. Real input is converted as-is. The result always uses torch's
    default dtype.
    """
    # Accept lists/tuples too; the original read ``.dtype`` directly.
    array = np.asarray(array)
    # iscomplexobj checks the dtype kind instead of two hard-coded complex dtypes.
    if np.iscomplexobj(array):
        array = np.stack([np.real(array), np.imag(array)])
    return torch.tensor(array, dtype=torch.get_default_dtype(), device=DEVICE)

def rand_phase():
    """Draw a phase uniformly from [0, 2*pi) using numpy's global RNG.

    Returned as a builtin ``float`` so it can be passed straight to the
    photontorch component constructors.
    """
    unit_sample = np.random.rand()
    return float(2 * np.pi * unit_sample)

class Network(pt.Network):
    """Base network for this notebook, built on photontorch's ``pt.Network``.

    It can be called directly with a stacked (real, imag) matrix of source
    amplitudes. It returns the complex (``power=False``) detector amplitudes
    instead of detected power, so propagating an identity matrix yields the
    network's transfer matrix.
    """

    def _handle_source(self, matrix, **kwargs):
        """Embed the source matrix into the full source array expected by photontorch.

        ``matrix`` is indexed as ``(2, d, e)``. Two singleton axes are
        inserted after the real/imag axis, and the ``d`` rows are zero-padded
        up to ``self.num_mc`` along the second-to-last axis. Presumably ``d``
        equals the number of sources and ``e`` acts as the batch size
        (TODO confirm against photontorch).
        """
        expanded_matrix = matrix[:, None, None, :, :]
        a, b, c, d, e = expanded_matrix.shape
        padding = torch.zeros(
            (a, b, c, self.num_mc - d, e),
            # Match the input dtype so torch.cat never mixes dtypes
            # (the original always used torch's default dtype here).
            dtype=expanded_matrix.dtype,
            device=expanded_matrix.device,
        )
        return torch.cat([expanded_matrix, padding], -2)

    def forward(self, matrix):
        """Propagate ``matrix`` through the network and return complex outputs.

        Args:
            matrix: tensor of shape ``(2, num_sources, k)`` holding stacked
                real/imaginary parts, one input vector per column. The
                notebook uses ``k == num_sources``.

        Returns:
            Tensor of shape ``(2, num_detectors, k)`` with the complex
            (non-power) detector amplitudes.
        """
        result = super().forward(matrix, power=False)
        # Keep index 0 of axes 1 and 2 (presumably time step and wavelength).
        return result[:, 0, 0, :, :]

    def count_params(self):
        """Return the total number of scalar entries across all parameters."""
        return sum(p.numel() for p in self.parameters())

def unitary_matrix(m, n):
    """Return a random m x n matrix with orthonormal columns (m >= n) or rows (m < n).

    A complex matrix is sampled from numpy's global RNG, with real and
    imaginary parts uniform in [0, 1) and the real part drawn first. It is
    factorized with a reduced SVD. The left singular vectors are returned
    for tall or square shapes, and the right singular vectors (``Vh``) for
    wide shapes. For ``m == n`` the result is a square unitary matrix.
    """
    samples = np.random.rand(m, n) + 1j * np.random.rand(m, n)
    left_vectors, _, right_vectors_h = np.linalg.svd(samples, full_matrices=False)
    return left_vectors if m >= n else right_vectors_h

### Define Network

In [None]:
class Network2x2(Network):
    """Hand-built 2x2 network: one trainable MZI followed by a trainable
    zero-length waveguide (phase) on each of its two outputs.

    Trainable phases: the MZI's ``phi`` and ``theta`` plus one ``phase`` per
    waveguide, all randomly initialized with ``rand_phase``.
    """
    def __init__(self):
        super(Network2x2, self).__init__()
        # NOTE(review): component assignment order is presumably significant for
        # source/detector ordering, and it also fixes the order of rand_phase draws.
        self.s1 = pt.Source()
        self.s2 = pt.Source()
        self.d1 = pt.Detector()
        self.d2 = pt.Detector()
        self.mzi = pt.Mzi(length=0, phi=rand_phase(), theta=rand_phase(), trainable=True)
        self.wg1 = pt.Waveguide(length=0, phase=rand_phase(), trainable=True)
        self.wg2 = pt.Waveguide(length=0, phase=rand_phase(), trainable=True)
        # s1 -> MZI port 0, out of port 1 -> wg1 -> d1
        self.link('s1:0', '0:mzi:1', '0:wg1:1', '0:d1')
        # s2 -> MZI port 3, out of port 2 -> wg2 -> d2
        self.link('s2:0', '3:mzi:2', '0:wg2:1', '0:d2')
        
# Build the network, move it to DEVICE and initialize it.
nw2x2 = Network2x2().to(DEVICE).initialize()

### Check unitarity

To see which unitary matrix the network represents, we propagate an identity matrix through the network. Each column of the identity excites a single source, so the outputs form the columns of the network's transfer matrix. The power flag is set to `False` because we are interested in the full complex output of the system rather than the detected power. To show that this matrix is indeed unitary, we multiply it by its conjugate transpose:

In [None]:
def check_unitarity(nw):
    """Print U @ U^dagger, where U is the transfer matrix of ``nw``.

    U is obtained by propagating the (complex) identity matrix through the
    network. For a unitary network the printed product is numerically the
    identity.
    """
    identity = tensor(np.eye(nw.num_sources, dtype=complex))
    transfer = array(nw(identity))
    print(transfer @ np.conj(transfer).T)

check_unitarity(nw2x2)

### Check Universality

However, it would be more interesting to show that this network can act like *any* unitary matrix. We will now train the network to equal another (random) unitary matrix $U_0$ by using the unitary property $U_0\cdot U_0^\dagger=I$. If the network implements $U$, then $U\cdot U_0^\dagger = I$ holds exactly when $U = U_0$. Hence we train the network to output $I$ when given $U_0^\dagger$ as input.

In [None]:
def check_universality(nw, num_epochs=500, lr=0.1, tol=1e-6):
    """Train ``nw`` to reproduce a random unitary matrix and print the result.

    A random unitary U0 is drawn with ``unitary_matrix``. Because
    U @ U0^dagger == I holds exactly when U == U0, the network is trained
    with U0^dagger as input and the identity as target. Training uses Adam
    on an MSE loss. It stops after ``num_epochs`` steps, or as soon as the
    loss drops below ``tol``. Finally the learned matrix (the network's
    response to the identity) is printed, followed by the target U0.

    Args:
        nw: network whose ``forward`` maps stacked (real, imag) input
            columns to stacked (real, imag) outputs.
        num_epochs: maximum number of optimization steps.
        lr: Adam learning rate.
        tol: early-stopping threshold on the training loss.
    """
    matrix_to_approximate = unitary_matrix(nw.num_sources, nw.num_sources)
    matrix_input = tensor(matrix_to_approximate.T.conj().copy())
    eye = tensor(np.eye(nw.num_sources) + 0j)
    optimizer = torch.optim.Adam(nw.parameters(), lr=lr)
    lossfunc = torch.nn.MSELoss()
    epochs = trange(num_epochs)
    for _ in epochs:
        optimizer.zero_grad()
        result = nw(matrix_input)
        loss = lossfunc(result, eye)
        loss.backward()
        optimizer.step()
        # Read the scalar once; used for both the progress bar and early stop.
        loss_value = loss.item()
        epochs.set_postfix(loss=f'{loss_value:.7f}')
        if loss_value < tol:
            break

    matrix_approximated = array(nw(eye))
    print(matrix_approximated)
    print(matrix_to_approximate)

In [None]:
# Train the 2x2 network towards a random 2x2 unitary (default epoch budget).
check_universality(nw2x2)

## 3x3 Unitary Matrix (Reck)

In [None]:
class Reck3x3(Network):
    """Hand-built 3x3 Reck-style (triangular) mesh.

    Three trainable MZIs are arranged as: mzi2 on rows 2-3, then mzi1 on
    rows 1-2, then mzi3 on rows 2-3. Row 3 also carries three trainable
    zero-length waveguide phases (wg1-wg3). All phases are randomly
    initialized with ``rand_phase``.
    """
    def __init__(self):
        super(Reck3x3, self).__init__()
        # NOTE(review): assignment order is presumably significant (port ordering
        # and order of rand_phase draws) — keep it as is.
        self.s1 = pt.Source()
        self.s2 = pt.Source()
        self.s3 = pt.Source()
        self.d1 = pt.Detector()
        self.d2 = pt.Detector()
        self.d3 = pt.Detector()
        self.mzi1 = pt.Mzi(length=0, phi=rand_phase(), theta=rand_phase(), trainable=True)
        self.mzi2 = pt.Mzi(length=0, phi=rand_phase(), theta=rand_phase(), trainable=True)
        self.mzi3 = pt.Mzi(length=0, phi=rand_phase(), theta=rand_phase(), trainable=True)
        self.wg1 = pt.Waveguide(length=0, phase=rand_phase(), trainable=True)
        self.wg2 = pt.Waveguide(length=0, phase=rand_phase(), trainable=True)
        self.wg3 = pt.Waveguide(length=0, phase=rand_phase(), trainable=True)
        # One link() per row, traced from source to detector; the extra spaces
        # align components that sit at the same stage of the mesh.
        self.link("s1:0",                         "0:mzi1:1",                        "0:d1")
        self.link("s2:0",             "0:mzi2:1", "3:mzi1:2", "0:mzi3:1",            "0:d2")
        self.link("s3:0", "0:wg1:1",  "3:mzi2:2", "0:wg2:1",  "3:mzi3:2", "0:wg3:1", "0:d3")
# Build the mesh, move it to DEVICE and initialize it.
reck3x3 = Reck3x3().to(DEVICE).initialize()

### Check Unitarity

In [None]:
# Print U @ U^dagger for the hand-built Reck 3x3 mesh; expect ~identity.
check_unitarity(reck3x3)

### Check Universality

In [None]:
# Train the Reck 3x3 mesh towards a random 3x3 unitary (default epoch budget).
check_universality(reck3x3)

## 3x3 Unitary Matrix (Clements)

In [None]:
class Clements3x3(Network):
    """Hand-built 3x3 Clements-style (rectangular) mesh.

    Three trainable MZIs are arranged as: mzi1 on rows 1-2, then mzi2 on
    rows 2-3, then mzi3 on rows 1-2. Each row ends in a trainable
    zero-length waveguide phase (wg1-wg3) before its detector. All phases
    are randomly initialized with ``rand_phase``.
    """
    def __init__(self):
        super(Clements3x3, self).__init__()
        # NOTE(review): assignment order is presumably significant (port ordering
        # and order of rand_phase draws) — keep it as is.
        self.s1 = pt.Source()
        self.s2 = pt.Source()
        self.s3 = pt.Source()
        self.d1 = pt.Detector()
        self.d2 = pt.Detector()
        self.d3 = pt.Detector()
        self.mzi1 = pt.Mzi(length=0, phi=rand_phase(), theta=rand_phase(), trainable=True)
        self.mzi2 = pt.Mzi(length=0, phi=rand_phase(), theta=rand_phase(), trainable=True)
        self.mzi3 = pt.Mzi(length=0, phi=rand_phase(), theta=rand_phase(), trainable=True)
        self.wg1 = pt.Waveguide(length=0, phase=rand_phase(), trainable=True)
        self.wg2 = pt.Waveguide(length=0, phase=rand_phase(), trainable=True)
        self.wg3 = pt.Waveguide(length=0, phase=rand_phase(), trainable=True)
        # One link() per row, traced from source to detector; the extra spaces
        # align components that sit at the same stage of the mesh.
        self.link("s1:0", "0:mzi1:1",             "0:mzi3:1", "0:wg1:1", "0:d1")
        self.link("s2:0", "3:mzi1:2", "0:mzi2:1", "3:mzi3:2", "0:wg2:1", "0:d2")
        self.link("s3:0",             "3:mzi2:2",             "0:wg3:1", "0:d3")
# Build the mesh, move it to DEVICE and initialize it.
clem3x3 = Clements3x3().to(DEVICE).initialize()

### Check Unitarity

In [None]:
# Print U @ U^dagger for the hand-built Clements 3x3 mesh; expect ~identity.
check_unitarity(clem3x3)

### Check Universality

In [None]:
# Allow up to 1000 epochs (instead of the default 500) for the Clements 3x3 mesh.
check_universality(clem3x3, num_epochs=1000)

## NxN Unitary Matrix (Reck)

Creating these networks by hand is quite cumbersome. However, they are also implemented in photontorch (`pt.ReckNxN` and `pt.ClementsNxN`), which then handles the creation of the networks:

In [None]:
# photontorch's built-in Reck meshes. terminate() presumably attaches sources
# and detectors to the open ports (TODO confirm in photontorch docs).
# NOTE(review): reck2x2 appears unused in the cells shown; removing it would
# still change later random draws if its construction consumes RNG state.
reck2x2 = pt.ReckNxN(N=2).to(DEVICE).terminate().initialize()
reck5x5 = pt.ReckNxN(N=5).to(DEVICE).terminate().initialize()
# quick monkeypatch to have the same behavior as the classes defined above
# (swapping __class__ makes the instance use Network.forward/_handle_source, so it
# accepts a stacked (real, imag) matrix; methods defined only on ReckNxN are no
# longer reachable)
reck5x5.__class__ = Network

### Check Unitarity

In [None]:
# Print U @ U^dagger for the built-in 5x5 Reck mesh; expect ~identity.
check_unitarity(reck5x5)

### Check Universality

In [None]:
# Train the built-in 5x5 Reck mesh towards a random 5x5 unitary (default epoch budget).
check_universality(reck5x5)

## NxN Unitary Matrix (Clements)

In [None]:
# photontorch's built-in Clements meshes for N=5 (odd) and N=6 (even).
clem5x5 = pt.ClementsNxN(N=5).to(DEVICE).terminate().initialize()
clem6x6 = pt.ClementsNxN(N=6).to(DEVICE).terminate().initialize()
# quick monkeypatch to have the same behavior as the classes defined above
# (swapping __class__ makes both instances use Network.forward/_handle_source;
# methods defined only on ClementsNxN are no longer reachable)
clem5x5.__class__ = clem6x6.__class__ = Network

### Check Unitarity

In [None]:
# Print U @ U^dagger for both built-in Clements meshes; expect ~identity.
check_unitarity(clem5x5)
check_unitarity(clem6x6)

### Check Universality

In [None]:
# Train both built-in Clements meshes with up to 1000 epochs each.
check_universality(clem5x5, num_epochs=1000)
check_universality(clem6x6, num_epochs=1000)