# Optimizing sBs under loss channel and factoring the recovery matrix with unitaries
9/20/24

In [4]:
import dynamiqs as dq
import qutip as qt
import jax
from jaxtyping import Array
import ruff
import diffrax as dx
import jax.numpy as jnp
import jax.scipy.linalg as jla
import equinox as eqx
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
from IPython.display import display, Latex, Math
import strawberryfields as sf
import os
from jaxpulse.controllers import *
from jaxpulse.optimizers import *
from gkp_utils.utils import *
# Ask XLA to preallocate at most half of the accelerator memory.
# NOTE(review): this is set after `import jax` (and the jax-based packages above);
# confirm the backend has not been initialised yet, otherwise it may have no effect.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".5"

In [5]:
# optimizer of square GKP
class SBS(eqx.Module):
    """Parametrised small-Big-small (sBs) round followed by a loss channel.

    Each attribute is indexed with ``[0]`` and ``[1]``; index 0 parametrises
    the first unitary applied in a round and index 1 the second (named
    ``x`` / ``xp`` evolution in ``__call__``).

    The helpers ``CD`` and ``R_x`` and the constants ``GKP_N``, ``I2``,
    ``IN`` and ``NI`` are not defined in this class.
    NOTE(review): presumably they come from ``gkp_utils.utils``, with
    ``GKP_N`` the oscillator truncation, ``I2`` / ``IN`` the qubit /
    oscillator identities and ``NI`` the number operator on the joint
    space -- confirm against that module.
    """

    # Parameters passed to ``CD`` for the first (rightmost) displacement.
    small_1: Array
    # Angles passed to ``R_x``; the adjoint of that rotation is applied.
    rot_1: Array
    # Parameters passed to ``CD`` for the middle displacement.
    big: Array
    # Angles passed to ``R_x`` for the second rotation.
    rot_2: Array
    # Parameters passed to ``CD`` for the last (leftmost) displacement.
    small_2: Array

    def __init__(self, small_1, rot_1, big, rot_2, small_2):
        self.small_1 = small_1
        self.rot_1 = rot_1
        self.big = big
        self.rot_2 = rot_2
        self.small_2 = small_2

    def generate_U(self):
        """Build the two sBs unitaries as one stacked array.

        Returns:
            Array whose leading axis has length 2; entry ``k`` is
            ``CD(small_1[k]) (I x R_x(rot_1[k])^dag) CD(big[k])
            (I x R_x(rot_2[k])) CD(small_2[k])``.
        """
        # Both rotations act on the qubit only, so they are padded with the
        # same oscillator identity (the original mixed ``GKP_N`` and ``N``).
        eye_osc = dq.eye(GKP_N)

        def single_unitary(k):
            return CD(self.small_1[k]) @ (
                dq.tensor(eye_osc, dq.dag(R_x(self.rot_1[k]))) @ (
                    CD(self.big[k]) @ (
                        dq.tensor(eye_osc, R_x(self.rot_2[k])) @ (
                            CD(self.small_2[k])))))

        return jnp.array([single_unitary(k) for k in range(2)])

    def __call__(self, N_t: int, loss_rate: float = 0.):
        """Iterate the sBs round plus loss channel and record the state history.

        The history starts from oscillator Fock state 0 tensored with the
        qubit ``|+>`` state. Each subsequent entry is obtained from the
        previous one by: tracing out the qubit, re-attaching a fresh
        ``|+><+|``, applying the two sBs unitaries in order, and applying
        the two-operator Kraus map built from ``loss_rate``.

        Args:
            N_t: Number of stored states, including the initial one. Used
                as an array shape, so it must be a concrete Python int.
            loss_rate: Coefficient of the loss Kraus operator
                ``sqrt(loss_rate) * a``.

        Returns:
            Tuple ``(n_t, rho_final)`` where ``n_t`` is
            ``dq.expect(NI, history)`` and ``rho_final`` is the last
            stored density matrix.

        Raises:
            ValueError: If ``N_t < 1`` or if ``loss_rate`` exceeds
                ``1 / GKP_N``.
        """
        if N_t < 1:
            raise ValueError("N_t must be at least 1")
        # Validate before building any operator that depends on loss_rate.
        if 1 - loss_rate*GKP_N < 0:
            raise ValueError(f"Loss rate must not exceed {1/GKP_N}")

        plus = 1/jnp.sqrt(2.0)*(dq.fock(2,0) + dq.fock(2,1))
        plusdm = dq.todm(plus)

        U_sbs = self.generate_U()
        U_sbs_dag = dq.dag(U_sbs)

        rho_sbs = dq.tensor(dq.fock_dm(GKP_N,0),plusdm)

        # The history must be complex: writing complex density matrices into
        # a real buffer would silently drop their imaginary parts.
        state_dtype = jnp.result_type(rho_sbs, U_sbs, jnp.complex64)
        rho_t = jnp.zeros((N_t,GKP_N*2,GKP_N*2), dtype=state_dtype)
        rho_t = rho_t.at[0,:,:].set(rho_sbs)

        # Kraus pair: K0^dag K0 + K1^dag K1 = loss_rate*n + (1 - loss_rate*n).
        K0 = jnp.sqrt(loss_rate)*dq.tensor(dq.destroy(GKP_N),I2)
        K0_dag = dq.dag(K0)
        # Element-wise sqrt is used on the operand, which equals the matrix
        # square root only when that operand is diagonal.
        K1 = dq.tensor(jnp.sqrt(IN - loss_rate*dq.number(GKP_N)), I2)
        K1_dag = dq.dag(K1)

        def loop_body(i, history):
            # Evolve only the previous step, not the whole history.
            previous = history[i - 1]
            traced_out = dq.ptrace(previous,0,(GKP_N,2))
            recombo = dq.tensor(traced_out, plusdm)
            # unitary evolution
            xevolve = U_sbs[0]@recombo@U_sbs_dag[0]
            xpevolve = U_sbs[1]@xevolve@U_sbs_dag[1]
            # kraus map
            mapped = K0 @ xpevolve @ K0_dag + K1 @ xpevolve @ K1_dag
            return history.at[i,:,:].set(mapped)

        # Start at 1 so the initial state stored at index 0 is preserved.
        rho_t = jax.lax.fori_loop(1,N_t,loop_body,rho_t)

        n_t = dq.expect(NI, rho_t)
        return n_t, rho_t[-1,:,:]



        
