In [None]:
import matplotlib.pyplot as plt
# NOTE(review): throwaway single-point plot issued before the remaining imports.
# Its purpose is not evident from this notebook (possibly a backend warm-up) —
# TODO confirm it is needed, otherwise remove it.
plt.plot([1])

import numpy as np
import gymnasium as gym
from gymnasium import spaces

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import sax

import logging
import sys
from functools import partial
from pprint import pprint

import gdsfactory as gf
import jax
import jax.example_libraries.optimizers as opt
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import sax
from gdsfactory.generic_tech import get_generic_pdk
from numpy.fft import fft2, fftfreq, fftshift, ifft2
from rich.logging import RichHandler
from scipy import constants
from sklearn.linear_model import LinearRegression
from tqdm.notebook import trange

import gplugins.sax as gs
import gplugins.tidy3d as gt
from gplugins.common.config import PATH

# Enable gdsfactory's rich output and activate the generic PDK before any
# component is built.
gf.config.rich_output()
PDK = get_generic_pdk()
PDK.activate()

# Route root-logger output through rich at WARNING level.
logger = logging.getLogger()
# NOTE(review): `sys.stderr` is a stream, not a `logging.Handler`, so this call
# presumably removes nothing from the root logger — confirm the intent (e.g.
# whether existing handlers should be cleared before `basicConfig`).
logger.removeHandler(sys.stderr)
logging.basicConfig(level="WARNING", datefmt="[%X]", handlers=[RichHandler()])

gf.config.set_plot_options(show_subports=False)

from math import e

In [None]:
# --- Optimisation problem configuration (read by `circuit_env`) ---------------
# These are ordinary module-level names; no `global` statement is needed at
# module scope, so the previous no-op declaration has been dropped.

# Physical bounds of the tunable parameters. The order matches how
# `s_parameter_extraction` indexes its argument:
#   [0] -> delta_length, [1] -> length_y, [2] -> length_x
parameter_min_array = [19.5, 1, 0]
parameter_max_array = [20.5, 5, 4]

# Normalised starting point: one zero per parameter, i.e. the centre of each
# [min, max] interval. The length follows the bounds instead of being hard-coded.
parameter_array_norm = np.zeros(len(parameter_min_array))

# Bundle consumed by `circuit_env.__init__`: (normalised start, minima, maxima).
circuit_parameter_input = (parameter_array_norm, parameter_min_array, parameter_max_array)

# Penalty added once for every action component that would leave (-1, 1).
punishment_parameter_1 = 1
# Weight applied to the summed absolute action (discourages large steps).
punishment_parameter_2 = 1

def loss_function(S):
    """Return the squared magnitude of the ``("o1", "o2")`` S-matrix entry.

    Parameters
    ----------
    S : mapping
        Anything indexable with the port pair ``("o1", "o2")`` that yields a
        (complex) number or array.

    Returns
    -------
    ``abs(S["o1", "o2"]) ** 2``. The environment's reward is ``1 - loss`` minus
    penalties, so a smaller value here is better.
    """
    through_amplitude = abs(S["o1", "o2"])
    return through_amplitude ** 2

def s_parameter_extraction(parameter_array): # This part will change
    """Build an MZI from ``parameter_array`` and return its simulated S-parameters.

    The gdsfactory component, its netlist and the SAX circuit are all rebuilt on
    every call.

    Parameters
    ----------
    parameter_array : indexable of three numbers
        ``[0]`` is passed to ``gf.components.mzi`` as ``delta_length``, ``[1]``
        as ``length_y`` and ``[2]`` as ``length_x``.

    Returns
    -------
    The result of evaluating the SAX circuit at ``wl=1.53``. ``loss_function``
    indexes it with the port pair ``("o1", "o2")``.
    """

    def straight(wl=1.5, length=10.0, neff=2.4) -> sax.SDict:
        """Lossless straight: pure phase ``exp(2j*pi*neff*length/wl)`` between o1 and o2.

        Only used as the basis for ``bend_euler`` below; the "straight" entry of
        ``models`` is bound to ``waveguide`` instead.
        """
        return sax.reciprocal({("o1", "o2"): jnp.exp(2j * jnp.pi * neff * length / wl)})

    def mmi1x2():
        """Assumes a perfect 1x2 splitter"""
        return sax.reciprocal(
            {
                ("o1", "o2"): 0.5**0.5,
                ("o1", "o3"): 0.5**0.5,
            }
        )

    def bend_euler(wl=1.5, length=20.0):
        """Let's assume a reduced transmission for the euler bend compared to a straight.

        Every S-matrix entry of ``straight(wl, length)`` is scaled by 0.98.
        """
        return {k: 0.98 * v for k, v in straight(wl=wl, length=length).items()}

    def waveguide(wl=1.55, wl0=1.55, neff=2.34, ng=3.4, length=10.0, loss=0.0) -> sax.SDict:
        """Straight waveguide with first-order dispersion and propagation loss.

        The effective index is shifted linearly with ``wl - wl0`` using the slope
        ``(ng - neff) / wl0``; the amplitude is attenuated by
        ``10 ** (-loss * length / 20)`` (presumably ``loss`` is in dB per unit
        length — TODO confirm).
        """
        dwl = wl - wl0
        dneff_dwl = (ng - neff) / wl0
        neff = neff - dwl * dneff_dwl
        phase = 2 * jnp.pi * neff * length / wl
        transmission = 10 ** (-loss * length / 20) * jnp.exp(1j * phase)
        return sax.reciprocal(
            {
                ("o1", "o2"): transmission,
            }
        )

    # Component-name -> model lookup handed to sax.circuit. Note that netlist
    # components called "straight" are simulated with `waveguide`.
    models = {
        "bend_euler": bend_euler,
        "mmi1x2": mmi1x2,
        "straight": waveguide,
    }

    mzi_component = gf.components.mzi(
        delta_length=parameter_array[0], length_x=parameter_array[2], length_y=parameter_array[1],
    )

    # sax.circuit returns a pair; only the callable circuit model is kept.
    mzi_circuit, _ = sax.circuit(
        netlist=mzi_component.get_netlist(),
        models=models,
    )

    # Evaluate at a single wavelength, overriding the `length` setting of the
    # instances named "syl" and "straight_9" with delta_length / 2 + 2.
    # NOTE(review): these instance names presumably come from the generated
    # netlist and may differ between gdsfactory versions — confirm they exist.
    S = mzi_circuit(
        wl=1.53,
        syl={
            "length": parameter_array[0] / 2 + 2,
        },
        straight_9={
            "length": parameter_array[0] / 2 + 2,
        },
    )

    return S

In [None]:
class circuit_env(gym.Env):
    """Gymnasium environment that tunes normalised circuit parameters.

    State and observation are the circuit parameters normalised to (-1, 1); an
    action is an increment added to that normalised state. The module-level
    names ``circuit_parameter_input``, ``s_parameter_extraction``,
    ``loss_function``, ``punishment_parameter_1`` and ``punishment_parameter_2``
    supply the problem definition.

    Reward per step::

        1 - loss_function(S(new_state))
          - punishment_parameter_1 * (number of rejected action components)
          - punishment_parameter_2 * sum(|action|)
    """

    def __init__(self):
        super().__init__()

        initial_norm, lower_bounds, upper_bounds = circuit_parameter_input
        # Work on a private copy: step() updates the state in place and must not
        # silently modify the module-level configuration array.
        self.parameter_array_norm = np.array(initial_norm, dtype=float)
        self.parameter_min_array = lower_bounds
        self.parameter_max_array = upper_bounds

        n_params = len(self.parameter_array_norm)
        self.action_space = spaces.Box(
            low=-1, high=1, shape=(n_params,), dtype=np.float32
        )
        self.observation_space = spaces.Box(
            low=-1, high=1, shape=(n_params,), dtype=np.float32
        )

    def reset(self, seed=None, options=None):
        """Reset the normalised state to all zeros and return ``(observation, info)``."""
        super().reset(seed=seed, options=options)
        # Size follows the configured parameter count rather than a literal 3.
        self.parameter_array_norm = np.zeros_like(self.parameter_array_norm)
        return self.parameter_array_norm.astype(np.float32), {}

    def step(self, action):
        """Apply ``action`` component-wise and score the resulting circuit.

        A component is applied only if it keeps that parameter strictly inside
        (-1, 1); otherwise it is ignored and ``punishment_parameter_1`` is added
        to the penalty. The S-parameters are evaluated AFTER the update so that
        the reward describes the observation returned from this call (previously
        the reward was computed from the state before the action was applied).

        Returns the Gymnasium 5-tuple
        ``(observation, reward, terminated, truncated, info)``; the episode is
        never terminated or truncated by the environment itself.
        """
        action = np.asarray(action, dtype=float)
        punishment_total = 0.0

        for idx, delta in enumerate(action):
            candidate = self.parameter_array_norm[idx] + delta
            moving_down_ok = delta <= 0 and candidate > -1
            moving_up_ok = delta >= 0 and candidate < 1
            if moving_down_ok or moving_up_ok:
                self.parameter_array_norm[idx] = candidate
            else:
                punishment_total += punishment_parameter_1

        # Penalise large moves regardless of whether they were accepted.
        punishment_total += float(np.sum(np.abs(action))) * punishment_parameter_2

        S = s_parameter_extraction(self.reverse_normalization())
        reward = float(1 - loss_function(S) - punishment_total)

        terminated = False
        truncated = False  # we do not limit the number of steps here
        info = {}
        observation = self.parameter_array_norm.astype(np.float32)
        return observation, reward, terminated, truncated, info

    def close(self):
        """Nothing to release."""
        pass

    def reverse_normalization(self):
        """Map the normalised state from [-1, 1] back to physical parameter values.

        Computes ``min + (max - min) * (norm + 1) / 2`` element-wise and returns
        a new float array.
        """
        lower = np.asarray(self.parameter_min_array, dtype=float)
        upper = np.asarray(self.parameter_max_array, dtype=float)
        norm = np.asarray(self.parameter_array_norm, dtype=float)
        return lower + (upper - lower) * (norm + 1) / 2


In [None]:
from stable_baselines3.common.env_checker import check_env
env = circuit_env()
# If the environment does not follow the expected interface, an error will be thrown
check_env(env, warn=True)

from stable_baselines3 import PPO, A2C, DQN, DDPG
from stable_baselines3.common.env_util import make_vec_env

# Instantiate the env
# make_vec_env is given the class, so `vec_env` wraps its own circuit_env
# instance and does not share state with `env` above.
vec_env = make_vec_env(circuit_env, n_envs=1)

In [16]:
# Train the agent
# PPO with an MLP policy on the plain `env` instance for 4000 timesteps.
# No `seed` is passed, so results are not reproducible from run to run.
model = PPO("MlpPolicy", env, verbose=0, n_steps=20, gamma=0.95).learn(4000, progress_bar=False)



In [None]:
# Test the trained agent
# using the vecenv
#
# Roll the deterministic policy forward and record observations, rewards and
# actions. Normalised values map to physical ones as centre + half_range * value
# with (centre, half_range) = (20, 0.5) for delta_length, (3, 2) for length_y and
# (2, 2) for length_x, mirroring parameter_min_array / parameter_max_array.

obs = vec_env.reset()
obs_array = []
reward_array = []
action_array = []
n_steps = 20
steps_array = np.linspace(1, n_steps, n_steps)
for step in steps_array:
    action, _ = model.predict(obs, deterministic=True)
    # Note that the VecEnv resets automatically when a done signal is encountered.
    obs, reward, done, info = vec_env.step(action)
    # No vec_env.render() here: circuit_env defines no render mode, so there is
    # nothing to draw.
    obs_array.append(obs)
    reward_array.append(reward)
    # Keep the raw normalised action; the physical scaling is applied exactly
    # once, at plot time (it used to be applied twice, shrinking the curves).
    action_array.append(action)

display("Final results: delta_length = " + str(obs[0][0]*0.5+20) + ", length_y = " + str(obs[0][1]*2+3) + ", length_x = " + str(obs[0][2]*2+2) +  ", reward = "+str(reward) )

plt.figure()
plt.plot(reward_array)
plt.xlabel("steps")
plt.ylabel("reward")
plt.grid()

# Shape (n_steps, n_envs=1, n_params): slicing [:, :, k] gives one column per
# env, i.e. a single line per parameter.
obs_array_ = np.array(obs_array)
action_array_ = np.array(action_array)

plt.figure()
line1, = plt.plot(obs_array_[:, :, 0]*0.5+20, label="delta_length")
line2, = plt.plot(obs_array_[:, :, 1]*2+3, label="length_y")
line3, = plt.plot(obs_array_[:, :, 2]*2+2, label="length_x")
plt.xlabel("steps")
plt.legend(handles=[line1, line2, line3])
plt.grid()

plt.figure()
line1, = plt.plot(action_array_[:, :, 0]*0.5, label="delta_length-action")
line2, = plt.plot(action_array_[:, :, 1]*2, label="length_y-action")
line3, = plt.plot(action_array_[:, :, 2]*2, label="length_x-action")
plt.xlabel("steps")
plt.legend(handles=[line1, line2, line3])
plt.grid()

In [None]:
# Train the agent
# Shorter run with smaller rollouts and a lower discount factor than `model`.
model_2 = PPO("MlpPolicy", env, verbose=0, n_steps=10, gamma=0.90).learn(500, progress_bar=True)
# Test the trained agent
# using the vecenv
#
# Normalised values map to physical ones as centre + half_range * value with
# (centre, half_range) = (20, 0.5) for delta_length, (3, 2) for length_y and
# (2, 2) for length_x, mirroring parameter_min_array / parameter_max_array.

obs = vec_env.reset()
obs_array = []
reward_array = []
action_array = []
n_steps = 50
steps_array = np.linspace(1, n_steps, n_steps)
for step in steps_array:
    action, _ = model_2.predict(obs, deterministic=True)
    # Note that the VecEnv resets automatically when a done signal is encountered.
    obs, reward, done, info = vec_env.step(action)
    # No vec_env.render() here: circuit_env defines no render mode, so there is
    # nothing to draw.
    obs_array.append(obs)
    reward_array.append(reward)
    # Keep the raw normalised action; the physical scaling is applied exactly
    # once, at plot time (it used to be applied twice, shrinking the curves).
    action_array.append(action)

display("Final results: delta_length = " + str(obs[0][0]*0.5+20) + ", length_y = " + str(obs[0][1]*2+3) + ", length_x = " + str(obs[0][2]*2+2) +  ", reward = "+str(reward) )

plt.figure()
plt.plot(reward_array)
plt.xlabel("steps")
plt.ylabel("reward")
plt.grid()

# Shape (n_steps, n_envs=1, n_params): slicing [:, :, k] gives one column per
# env, i.e. a single line per parameter.
obs_array_ = np.array(obs_array)
action_array_ = np.array(action_array)

plt.figure()
line1, = plt.plot(obs_array_[:, :, 0]*0.5+20, label="delta_length")
line2, = plt.plot(obs_array_[:, :, 1]*2+3, label="length_y")
line3, = plt.plot(obs_array_[:, :, 2]*2+2, label="length_x")
plt.xlabel("steps")
plt.legend(handles=[line1, line2, line3])
plt.grid()

plt.figure()
line1, = plt.plot(action_array_[:, :, 0]*0.5, label="delta_length-action")
line2, = plt.plot(action_array_[:, :, 1]*2, label="length_y-action")
line3, = plt.plot(action_array_[:, :, 2]*2, label="length_x-action")
plt.xlabel("steps")
plt.legend(handles=[line1, line2, line3])
plt.grid()

In [None]:
# Second deterministic rollout of `model_2`, this time for 40 steps.
# Normalised values map to physical ones as centre + half_range * value with
# (centre, half_range) = (20, 0.5) for delta_length, (3, 2) for length_y and
# (2, 2) for length_x, mirroring parameter_min_array / parameter_max_array.
obs = vec_env.reset()
obs_array = []
reward_array = []
action_array = []
n_steps = 40
steps_array = np.linspace(1, n_steps, n_steps)
for step in steps_array:
    action, _ = model_2.predict(obs, deterministic=True)
    # Note that the VecEnv resets automatically when a done signal is encountered.
    obs, reward, done, info = vec_env.step(action)
    # No vec_env.render() here: circuit_env defines no render mode, so there is
    # nothing to draw.
    obs_array.append(obs)
    reward_array.append(reward)
    # Keep the raw normalised action; the physical scaling is applied exactly
    # once, at plot time (it used to be applied twice, shrinking the curves).
    action_array.append(action)

display("Final results: delta_length = " + str(obs[0][0]*0.5+20) + ", length_y = " + str(obs[0][1]*2+3) + ", length_x = " + str(obs[0][2]*2+2) +  ", reward = "+str(reward) )

plt.figure()
plt.plot(reward_array)
plt.xlabel("steps")
plt.ylabel("reward")
plt.grid()

# Shape (n_steps, n_envs=1, n_params): slicing [:, :, k] gives one column per
# env, i.e. a single line per parameter.
obs_array_ = np.array(obs_array)
action_array_ = np.array(action_array)

plt.figure()
line1, = plt.plot(obs_array_[:, :, 0]*0.5+20, label="delta_length")
line2, = plt.plot(obs_array_[:, :, 1]*2+3, label="length_y")
line3, = plt.plot(obs_array_[:, :, 2]*2+2, label="length_x")
plt.xlabel("steps")
plt.legend(handles=[line1, line2, line3])
plt.grid()

plt.figure()
line1, = plt.plot(action_array_[:, :, 0]*0.5, label="delta_length-action")
line2, = plt.plot(action_array_[:, :, 1]*2, label="length_y-action")
line3, = plt.plot(action_array_[:, :, 2]*2, label="length_x-action")
plt.xlabel("steps")
plt.legend(handles=[line1, line2, line3])
plt.grid()