### First double multiplexer, then single phase tuner
Including even ports template and odd ports template, use odd port nets for test

In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap, value_and_grad
from jax import random
from jax import custom_vjp
from jax import nn

from functools import partial
import numpy as np

key = random.PRNGKey(0)
CTYPE = jnp.complex64
RTYPE = jnp.float32
FIGDPI = 100
FIGSIZE = (5,4)

scale_factor = 128
# double ring parameters
r1 = 0.553
r2 = 0.84
dx = 0.
# dx = 0.
a = 0.98
lr = 1e-3

# initiate around the phase value
around = np.pi/5



### Helper functions

In [2]:
# this block won't be changed.
# this functions is a helper function to get the location; won't be modified
def get_ABCD_index(num_ports, parity):
    if parity=="even":
        ind = jnp.arange(0, num_ports-1, 2)
    elif parity=="odd":
        ind = jnp.arange(1, num_ports-1, 2)
    else:
        raise ValueError("parity should be `even` or `odd`, but get:", parity)
    ind_A = (ind, ind)
    ind_D = (ind+1, ind+1)
    ind_B = (ind, ind+1)
    ind_C = (ind+1, ind)
    return ind_A, ind_B, ind_C, ind_D

# these two functions can be isolated; won't be modified
def double_transfer(theta, r1, r2, a, dx):
    theta = -(theta + dx)
    denom = 1+a**2*jnp.exp(2*1j*theta)*r1**2-2*a*jnp.exp(1j*theta)*r1*r2
    numA = r1+a**2*jnp.exp(2*1j*theta)*r1-a*jnp.exp(1j*theta)*r2-a*jnp.exp(1j*theta)*r1**2*r2
    numB = -1j*a*jnp.exp(1j*theta)*(-1+r1**2)*jnp.sqrt(1-r2**2)
    numC = -1j*a*jnp.exp(1j*theta)*(-1+r1**2)*jnp.sqrt(1-r2**2)
    numD = r1+a**2*jnp.exp(2*1j*theta)*r1-a*jnp.exp(1j*theta)*r2-a*jnp.exp(1j*theta)*r1**2*r2
    return numA/denom, numB/denom, numC/denom, numD/denom

def single_transfer(phi, r, a):
    return jnp.exp(1j*jnp.pi+1j*phi)*(a-r*jnp.exp(-1j*phi))/(1.-r*a*jnp.exp(1j*phi))

# activation can be separated as well
def ring_activation(g, phi_b, alpha, a_act, r_act, xout):
    # ring
    dphi = 0.5*g*jnp.abs(xout)**2 + 0.5*phi_b
    t_ring = jnp.exp(1j*(np.pi+dphi))*(a_act-r_act*jnp.exp(1j*dphi))/(1.0-r_act*a_act*jnp.exp(1j*dphi))
    xout = (1j * jnp.sqrt(1.0-alpha) * t_ring) * xout
    return xout

@partial(jit, static_argnums=(0,))
def get_transfer_even(num_ports, dA, dB, dC, dD):
    """
    Get the even depths transfer matrix.
    """
    ind_ABCD = get_ABCD_index(num_ports, "even")
    T = jnp.eye(num_ports, dtype=CTYPE)
    T = T.at[ind_ABCD[0]].set(dA)
    T = T.at[ind_ABCD[1]].set(dB)
    T = T.at[ind_ABCD[2]].set(dC)
    T = T.at[ind_ABCD[3]].set(dD)
    return T

@partial(jit, static_argnums=(0,))
def get_transfer_odd(num_ports, dA, dB, dC, dD):
    """
    Get the odd depths transfer matrix.
    """
    ind_ABCD = get_ABCD_index(num_ports, "odd")
    T = jnp.eye(num_ports, dtype=CTYPE)
    T = T.at[ind_ABCD[0]].set(dA)
    T = T.at[ind_ABCD[1]].set(dB)
    T = T.at[ind_ABCD[2]].set(dC)
    T = T.at[ind_ABCD[3]].set(dD)
    return T

double_transfer_even_jit = jit(double_transfer)
single_transfer_even_jit = jit(single_transfer)
double_transfer_odd_jit = jit(double_transfer)
single_transfer_odd_jit = jit(single_transfer)
ring_activation_jit = jit(ring_activation)

### Define network parameters

In [3]:
# single ring parameters
rsingle = r1
asingle = a

# activation function parameters
# g = np.pi
# phi_b = 0
# alpha = 0.1
# a_act = 0.9
# r_act = 0.9

g = 0.314
phi_b = -0.1
alpha = 0.1
a_act = 0.9
r_act = 0.9


In [4]:
# depth = ports
num_ports = 16
num_depth = num_ports
num_layers = 2

# for even and odd depths (assume the first depth is even.)
num_depth_even = int(jnp.ceil(num_depth/2))
num_depth_odd = int(jnp.floor(num_depth/2))
num_ports_even = int(jnp.floor(num_ports/2))
num_ports_odd = int(jnp.ceil(num_ports/2) - 1)
num_total_even = num_layers*num_depth_even*num_ports_even 
num_total_odd = num_layers*num_depth_odd*num_ports_odd

# get the optimization parameters
# once the structure is setup, the shape of transfer matrix is setup.
key, subkey1, subkey2, subkey3, subkey4 = random.split(key, 5)

# our parameters
theta_even = random.uniform(subkey1, 
                            shape=(num_layers, num_depth_even, 
                                   num_ports_even),
                           )*around
theta_odd = random.uniform(subkey2, 
                           shape=(num_layers, num_depth_odd, 
                                  num_ports_odd),
                          )*around

phi_even = random.uniform(subkey3, 
                          shape=(num_layers, num_depth_even, 
                                 num_ports_even),
                         )*around
phi_odd = random.uniform(subkey4, 
                         shape=(num_layers, num_depth_odd, 
                                num_ports_odd),
                        )*around

params = {}
params['theta_even'] = theta_even
params['theta_odd'] = theta_odd
params['phi_even'] = phi_even
params['phi_odd'] = phi_odd
init_params = params

In [5]:
jit_matmul = jit(jnp.matmul)

# let's focus on the double ring only first
@jit
def clements_evenports(r1, r2, a, dx, rsingle, asingle, 
                       theta_even, theta_odd, phi_even, phi_odd, xin):
    """
    1-layer clements setup.
    theta_even: (num_depth, num_even_devices)
    theta_odd: (num_depth, num_odd_devices)
    phi_even: (num_depth, num_even_devices)
    phi_odd: (num_depth, num_even_devices)
    xin: (x_vec_shape, batch_size)
    """
    dAeven, dBeven, dCeven, dDeven = double_transfer_even_jit(theta_even, r1, r2, a, dx)
    dAodd, dBodd, dCodd, dDodd = double_transfer_odd_jit(theta_odd, r1, r2, a, dx)
    phase_even = single_transfer_even_jit(phi_even, rsingle, asingle)
    phase_odd = single_transfer_odd_jit(phi_odd, rsingle, asingle)
    # if the ports number is even
    for d in range(int(num_depth/2)):
        T_even = get_transfer_even(num_ports, 
                                   dAeven[d]*phase_even[d], dBeven[d]*phase_even[d], 
                                   dCeven[d], dDeven[d])
        T_odd = get_transfer_odd(num_ports, 
                                 dAodd[d]*phase_odd[d], dBodd[d]*phase_odd[d], 
                                 dCodd[d], dDodd[d])
        xin = jit_matmul(T_odd, jit_matmul(T_even, xin))
    return xin

@jit
def clements_oddports(r1, r2, a, dx, rsingle, asingle, 
                      theta_even, theta_odd, phi_even, phi_odd, xin):
    """
    1-layer clements setup.
    theta_even: (num_depth, num_even_devices)
    theta_odd: (num_depth, num_odd_devices)
    phi_even: (num_depth, num_even_devices)
    phi_odd: (num_depth, num_even_devices)
    xin: (x_vec_shape, batch_size)
    """
    dAeven, dBeven, dCeven, dDeven = double_transfer_even_jit(theta_even, r1, r2, a, dx)
    dAodd, dBodd, dCodd, dDodd = double_transfer_odd_jit(theta_odd, r1, r2, a, dx)
    phase_even = single_transfer_even_jit(phi_even, rsingle, asingle)
    phase_odd = single_transfer_odd_jit(phi_odd, rsingle, asingle)
    
    # if the ports number is odd.
    for d in range(int(num_depth/2)):
        T_even = get_transfer_even(num_ports, 
                                   dAeven[d]*phase_even[d], dBeven[d]*phase_even[d], 
                                   dCeven[d], dDeven[d])
        T_odd = get_transfer_odd(num_ports, 
                                 dAodd[d]*phase_odd[d], dBodd[d]*phase_odd[d], 
                                 dCodd[d], dDodd[d])
        xin = jit_matmul(T_odd, jit_matmul(T_even, xin))
        
    # note that jax won't jump error if we exeed the array sizes.
    T_even = get_transfer_even(num_ports, 
                               dAeven[d+1]*phase_even[d+1], dBeven[d+1]*phase_even[d+1], 
                               dCeven[d+1], dDeven[d+1])
    xin = jit_matmul(T_even, xin)
    return xin

In [6]:
@jit
def network_even(theta_even, theta_odd, phi_even, phi_odd, xin):
    xin = clements_evenports(r1, r2, a, dx, rsingle, asingle, 
                              theta_even[0], theta_odd[0], 
                              phi_even[0], phi_odd[0], xin)
    xin = ring_activation(g, phi_b, alpha, a_act, r_act, xin)
    for l in range(1, num_layers):
        xin = clements_evenports(r1, r2, a, dx, rsingle, asingle, 
                                  theta_even[l], theta_odd[l], 
                                  phi_even[l], phi_odd[l], xin)
        xin = ring_activation(g, phi_b, alpha, a_act, r_act, xin)
    
    xin = jnp.abs(xin)**2
    return xin

@jit
def network_odd(theta_even, theta_odd, phi_even, phi_odd, xin):
    xin = clements_oddports(r1, r2, a, dx, rsingle, asingle, 
                             theta_even[0], theta_odd[0], 
                             phi_even[0], phi_odd[0], xin)
    xin = ring_activation(g, phi_b, alpha, a_act, r_act, xin)
    for l in range(1, num_layers):
        xin = clements_oddports(r1, r2, a, dx, rsingle, asingle, 
                                 theta_even[l], theta_odd[l], 
                                 phi_even[l], phi_odd[l], xin)
        xin = ring_activation(g, phi_b, alpha, a_act, r_act, xin)
    
    xin = jnp.abs(xin)**2
    return xin

### MNIST dataset preparation

In [7]:
from mnist_data import *

mnist_dp = MNISTDataProcessor()
data_N16 = mnist_dp.fourier(2)
x_train, y_train, x_test, y_test = data_N16
x_train = jnp.array(x_train.T, dtype=CTYPE)*scale_factor
y_train = jnp.array(y_train.T, dtype=RTYPE)
x_test = jnp.array(x_test.T, dtype=CTYPE)*scale_factor
y_test = jnp.array(y_test.T, dtype=RTYPE)

### Cross-Entropy Loss + Optimizer + train

In [8]:
from jax.experimental import optimizers
import time
from tqdm import trange, tqdm



In [9]:
@jit
def loss(params, xin, Y):
    xin = network_even(params['theta_even'], params['theta_odd'], 
                       params['phi_even'], params['phi_odd'], xin)
    y_pred = xin[:10, :]
    log_softmax_y_pred = nn.log_softmax(y_pred, axis=0)
    l = - log_softmax_y_pred*Y[:10, :]
    return l.sum(axis=0).mean()

def accuracy(params, x, y, batch_size):
    num_batches = int(x.shape[1]/batch_size)
    y_pred = jnp.empty((10,0), dtype=RTYPE)

    for batch in range(num_batches):
        x_sub = x[:, batch*batch_size:(batch+1)*batch_size]
        x_pred = network_even(params['theta_even'], params['theta_odd'], 
                              params['phi_even'], params['phi_odd'], x_sub)
        y_pred = jnp.hstack((y_pred, x_pred[:10,:]))
    
    y_pred = jnp.argmax(y_pred, axis=0)
    y_true = y[:, :y_pred.shape[0]]
    y_true = jnp.argmax(y_true, axis=0)
    acc = jnp.sum(jnp.equal(y_pred, y_true))/float(y_pred.shape[0])
    return acc, y_pred, y_true

In [10]:
def run_xor(x_train, y_train, x_test, y_test, params, num_epochs, batch_size=32, lr=1e-3, optimizer=optimizers.adam):
    jit_value_grad  = jit(value_and_grad(loss, (0, )))
    _, num_train_samples = x_train.shape
    num_train_batches = int(num_train_samples/batch_size)
    
    @jit
    def update(params, x, y, opt_state):
        # we record the loss parameters before the updated one
        value, grads = jit_value_grad(params, x, y)
        opt_state = opt_update(0, grads[0], opt_state)
        return get_params(opt_state), opt_state, value

    # Defining an optimizer in Jax
    opt_init, opt_update, get_params = optimizer(lr)
    opt_state = opt_init(params)
    
    # record by iteration steps (record loss for every batch)
    loss_list = []
    # record by epoch
    acc_train_list = []
    acc_test_list = []
    
    start_time = time.time()
    acc_test, _, _ = accuracy(params, x_test, y_test, batch_size)
    acc_train, _, _ = accuracy(params, x_train, y_train, batch_size)
    acc_train_list.append(acc_train)
    acc_test_list.append(acc_test)
    
    t = trange(num_epochs, desc='MNIST', position=0, leave=True)
    for epoch in t:
        for batch in range(num_train_batches):
            params, opt_state, l = update(params, 
                                          x_train[:, batch*batch_size:(batch+1)*batch_size], 
                                          y_train[:, batch*batch_size:(batch+1)*batch_size], 
                                          opt_state)
            loss_list.append(l)

        # evaluation process
        acc_test, _, _ = accuracy(params, x_test, y_test, batch_size)
        acc_train, _, _ = accuracy(params, x_train, y_train, batch_size)
        acc_train_list.append(acc_train)
        acc_test_list.append(acc_test)
#         t.set_description('Train accuracy=%g, ' % acc_train + 'Test accuracy=%g, ' % acc_test + 'last batch loss=%g'% l)
        t.set_description('Test accuracy=%g' % acc_test)
#         print("epoch: ", epoch)
#         print("time elapse: ", time.time()-start_time)
#         print("train accuracy: ", acc_train)
#         print("test accuracy: ", acc_test)
    
    return loss_list, acc_train_list, acc_test_list, params

### Start the training

In [11]:
num_epochs = 200
batch_size = 128

# num_epochs = 200
# lr = 2e-5
# batch_size = 256

In [12]:
start_time = time.time()
loss_list, acc_train_list, acc_test_list, trained_params = run_xor(x_train, y_train, 
                                                                   x_test, y_test, 
                                                                   params, num_epochs, 
                                                                   batch_size=batch_size,
                                                                   lr=lr)
print("Total training time: ", time.time()-start_time)

MNIST:   0%|          | 0/200 [00:00<?, ?it/s]2023-06-13 10:57:49.355860: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:55] 
********************************
Slow compile?  XLA was built without compiler optimizations, which can be slow.  Try rebuilding with -c opt.
Compiling module jit_update.27288
********************************
Test accuracy=0.878906:  49%|████▉     | 98/200 [09:04<08:11,  4.82s/it] 

### Analyze results

In [None]:
import matplotlib.pyplot as plt
import seaborn as sn
import pandas as pd
from sklearn.metrics import confusion_matrix

In [None]:
acc, y_pred, y_true = accuracy(trained_params, x_test, y_test, batch_size)
y_pred = np.array(y_pred)
y_true = np.array(y_true)
conf_m = confusion_matrix(y_true, y_pred)
conf_m = conf_m / conf_m.astype(np.float).sum(axis=1)
df_cm = pd.DataFrame(conf_m)

plt.figure(figsize=FIGSIZE, dpi=FIGDPI)
sn.heatmap(np.round(df_cm*100, 1), annot=True, annot_kws={"size": 10})
# plt.savefig(filepath+prename+posname+'_confusion.svg', dpi=FIGDPI, format="svg")

In [None]:
fig, ax = plt.subplots(figsize=FIGSIZE, dpi=FIGDPI) 
ax.plot(acc_test_list, label='test')
ax.plot(acc_train_list, label='train')
ax.legend()
ax.set_xlabel('epochs')
ax.set_ylabel('accuracy')
ax.set_ylim([0.0,1])
# plt.savefig(filepath+prename+posname+'_accuracy.svg', dpi=FIGDPI, format="svg")