In [1]:
import numpy as np
import gymnasium as gym
from gymnasium import spaces

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import sax

class filter_env(gym.Env):

    # Define constants for clearer code
    UP = 1
    DOWN = -1

    def __init__(self, render_mode="console"):
        super(filter_env, self).__init__()

        self.ring_length = 5

        # Define action and observation space
        # They must be gym.spaces objects
        # Example when using discrete actions, we have two: left and right

        # For continuous action space, use the following three lines:

        # self.action_space = spaces.Box(
        #     low=-1, high=1, shape=(1,), dtype=np.float32
        # )

        # For discrete action space, use the  following two lines: 
        n_actions = 2
        self.action_space = spaces.Discrete(n_actions)
        
        # The observation will be the coordinate of the agent
        # this can be described both by Discrete and Box space
        self.observation_space = spaces.Box(
            low=1, high=30, shape=(1,), dtype=np.float32
        )

    def reset(self, seed=None, options=None):
        """
        Important: the observation must be a numpy array
        :return: (np.array)
        """
        super().reset(seed=seed, options=options)
        # Initialize the agent at the right of the grid
        self.ring_length = 5
        # here we convert to float32 to make it more general (in case we want to use continuous actions)
        return np.array([self.ring_length]).astype(np.float32), {}  # empty info dict

    def all_pass_analytical(self, ring_length, wl, coupling, wl0, ng, neff, loss):
        """Analytic Frequency Domain Response of an all pass filter"""
        detected = jnp.zeros_like(wl)
        transmission = 1 - coupling
        neff_wl = (
            neff + (wl0 - wl) * (ng - neff) / wl0
        )  # we expect a linear behavior with respect to wavelength
        out = jnp.sqrt(transmission) - 10 ** (-loss * ring_length / 20.0) * jnp.exp(
            2j * jnp.pi * neff_wl * ring_length / wl
        )
        out /= 1 - jnp.sqrt(transmission) * 10 ** (-loss * ring_length / 20.0) * jnp.exp(
            2j * jnp.pi * neff_wl * ring_length / wl
        )
        detected = abs(out) ** 2
        return detected

    def filter_simulation(self, ring_length):

        loss = 0.1  # [dB/μm] (alpha) waveguide loss
        neff = 2.34  # Effective index of the waveguides
        ng = 3.4  # Group index of the waveguides
        wl0 = 1.55  # [μm] the wavelength at which neff and ng are defined
        # ring_length = 10.0  # [μm] Length of the ring
        coupling = 0.5  # [] coupling of the coupler
        wl = jnp.linspace(1.5, 1.6, 1000)  # [μm] Wavelengths to sweep over
        
        detected = self.all_pass_analytical(ring_length, wl, coupling, wl0, ng, neff, loss) # non-jitted evaluation time

        T_min = min(detected)
        T_max = max(detected)
        wl_min = wl[(detected==T_min)]
        # display(T_min, T_max, wl_min)
        return T_min, T_max, wl_min
        
    def loss_function(self, T_min, T_max, wl_filter):
        T_filter_min_target = 0
        T_filter_max_target = 1
        wl_filter_norm_min = 1.5
        wl_filter_norm_max = 1.6
        wl_filter_target = 1.55
        wl_filter_norm = (wl_filter-wl_filter_norm_min)/(wl_filter_norm_max-wl_filter_norm_min)
        wl_filter_target_norm = (wl_filter_target-wl_filter_norm_min)/(wl_filter_norm_max-wl_filter_norm_min)
        loss = (pow(T_min-T_filter_min_target,2)+pow(T_max-T_filter_max_target,2)+pow(wl_filter_norm-wl_filter_target_norm,2))/3
        return loss

    def step(self, action):
        if action == self.DOWN:
            self.ring_length -= 1
        elif action == self.UP:
            self.ring_length += 1


        terminated = False
        truncated = False  # we do not limit the number of steps here

        T_filter_min, T_filter_max, wl_filter = self.filter_simulation(self.ring_length)
        reward = -self.loss_function(T_filter_min, T_filter_max, wl_filter)
        reward = float(reward)

        info = {}

        return (
            np.array([self.ring_length]).astype(np.float32),
            reward,
            terminated,
            truncated,
            info,
        )

    def close(self):
        pass

In [2]:
from stable_baselines3.common.env_checker import check_env

In [3]:
env = filter_env()
# If the environment don't follow the interface, an error will be thrown
check_env(env, warn=True)

  reward = float(reward)


In [4]:
env = filter_env()

obs, _ = env.reset()

print(env.observation_space)
print(env.action_space)
print(env.action_space.sample())

GO_DOWN = -1
# Hardcoded best agent: always go left!
n_steps = 3
for step in range(n_steps):
    print(f"Step {step + 1}")
    obs, reward, terminated, truncated, info = env.step(GO_DOWN)
    done = terminated or truncated
    print("obs=", obs, "reward=", reward, "done=", done)
    if done:
        print("Goal reached!", "reward=", reward)
        break

Box(1.0, 30.0, (1,), float32)
Discrete(2)
1
Step 1
obs= [4.] reward= -0.11491471108418626 done= False
Step 2
obs= [3.] reward= -0.4018542897279109 done= False
Step 3
obs= [2.] reward= -0.19656161952983242 done= False


  reward = float(reward)


In [5]:
from stable_baselines3 import PPO, A2C, DQN, DDPG
from stable_baselines3.common.env_util import make_vec_env

# Instantiate the env
vec_env = make_vec_env(filter_env, n_envs=1)

In [6]:
# Train the agent
model = A2C("MlpPolicy", env, verbose=1).learn(100)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


  reward = float(reward)


In [7]:
# Test the trained agent
# using the vecenv
obs = vec_env.reset()
n_steps = 40
for step in range(n_steps):
    action, _ = model.predict(obs, deterministic=True)
    print(f"Step {step + 1}")
    print("Action: ", action)
    obs, reward, done, info = vec_env.step(action)
    print("obs=", obs, "reward=", reward, "done=", done)
    vec_env.render()
    if done:
        # Note that the VecEnv resets automatically
        # when a done signal is encountered
        print("Goal reached!", "reward=", reward)
        break

Step 1
Action:  [1]
obs= [[6.]] reward= [-0.06692415] done= [False]
Step 2
Action:  [1]
obs= [[7.]] reward= [-0.10934637] done= [False]


  reward = float(reward)


Step 3
Action:  [1]
obs= [[8.]] reward= [-0.0388408] done= [False]
Step 4
Action:  [1]
obs= [[9.]] reward= [-0.10321518] done= [False]
Step 5
Action:  [1]
obs= [[10.]] reward= [-0.02256382] done= [False]
Step 6
Action:  [0]
obs= [[10.]] reward= [-0.02256382] done= [False]
Step 7
Action:  [0]
obs= [[10.]] reward= [-0.02256382] done= [False]
Step 8
Action:  [0]
obs= [[10.]] reward= [-0.02256382] done= [False]
Step 9
Action:  [0]
obs= [[10.]] reward= [-0.02256382] done= [False]
Step 10
Action:  [0]
obs= [[10.]] reward= [-0.02256382] done= [False]
Step 11
Action:  [0]
obs= [[10.]] reward= [-0.02256382] done= [False]
Step 12
Action:  [0]
obs= [[10.]] reward= [-0.02256382] done= [False]
Step 13
Action:  [0]
obs= [[10.]] reward= [-0.02256382] done= [False]
Step 14
Action:  [0]
obs= [[10.]] reward= [-0.02256382] done= [False]
Step 15
Action:  [0]
obs= [[10.]] reward= [-0.02256382] done= [False]
Step 16
Action:  [0]
obs= [[10.]] reward= [-0.02256382] done= [False]
Step 17
Action:  [0]
obs= [[10

In [8]:
# Train the agent 
# Only works with continuous action spaces.

model_2 = DDPG("MlpPolicy", env, verbose=1).learn(100)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


AssertionError: The algorithm only supports (<class 'gymnasium.spaces.box.Box'>,) as action spaces but Discrete(2) was provided

In [None]:
# Test the trained agent
# using the vecenv
obs = vec_env.reset()
n_steps = 20
for step in range(n_steps):
    action, _ = model_2.predict(obs, deterministic=True)
    print(f"Step {step + 1}")
    print("Action: ", action)
    obs, reward, done, info = vec_env.step(action)
    print("obs=", obs, "reward=", reward, "done=", done)
    vec_env.render()
    if done:
        # Note that the VecEnv resets automatically
        # when a done signal is encountered
        print("Goal reached!", "reward=", reward)
        break

Step 1
Action:  [1]
obs= [[10.]] reward= [0.] done= [False]
..........x
Step 2
Action:  [1]
obs= [[10.]] reward= [0.] done= [False]
..........x
Step 3
Action:  [1]
obs= [[10.]] reward= [0.] done= [False]
..........x
Step 4
Action:  [1]
obs= [[10.]] reward= [0.] done= [False]
..........x
Step 5
Action:  [1]
obs= [[10.]] reward= [0.] done= [False]
..........x
Step 6
Action:  [1]
obs= [[10.]] reward= [0.] done= [False]
..........x
Step 7
Action:  [1]
obs= [[10.]] reward= [0.] done= [False]
..........x
Step 8
Action:  [1]
obs= [[10.]] reward= [0.] done= [False]
..........x
Step 9
Action:  [1]
obs= [[10.]] reward= [0.] done= [False]
..........x
Step 10
Action:  [1]
obs= [[10.]] reward= [0.] done= [False]
..........x
Step 11
Action:  [1]
obs= [[10.]] reward= [0.] done= [False]
..........x
Step 12
Action:  [1]
obs= [[10.]] reward= [0.] done= [False]
..........x
Step 13
Action:  [1]
obs= [[10.]] reward= [0.] done= [False]
..........x
Step 14
Action:  [1]
obs= [[10.]] reward= [0.] done= [False]