In [1]:
from collections import namedtuple
from functools import wraps

from gym.wrappers import RescaleAction, TimeLimit
import numpy as np
import pandas as pd
from stable_baselines3 import TD3
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
from tqdm import tqdm

from environments import ARESEASequential

In [2]:
def load_sequential(model_name, max_episode_steps=50, measure_beam="us"):
    """Restore a saved TD3 agent together with the environment used to evaluate it.

    Everything is read from ``models/<model_name>``: the policy from ``model``
    and the ``VecNormalize`` wrapper state from ``vec_normalize.pkl``.

    Args:
        model_name: Name of the sub-directory of ``models/`` holding the saved run.
        max_episode_steps: Step limit handed to the ``TimeLimit`` wrapper.
        measure_beam: Forwarded to the environment as ``backendargs["measure_beam"]``.

    Returns:
        A ``ModelSetup`` named tuple with the fields
        ``(name, env, model, max_episode_steps, measure_beam)``.
    """
    model_dir = f"models/{model_name}"

    def build_env():
        # Wrapping order: the step limit sits directly on the raw environment,
        # the action rescaling to [-1, 1] is the outermost wrapper.
        return RescaleAction(
            TimeLimit(
                ARESEASequential(
                    backend="machine",
                    backendargs={"measure_beam": measure_beam},
                ),
                max_episode_steps=max_episode_steps,
            ),
            -1,
            1,
        )

    vec_env = VecNormalize.load(
        f"{model_dir}/vec_normalize.pkl",
        DummyVecEnv([build_env]),
    )
    # Evaluation settings: no further updates of the normalisation statistics
    # and rewards are passed through un-normalised.
    vec_env.training = False
    vec_env.norm_reward = False

    agent = TD3.load(f"{model_dir}/model")

    ModelSetup = namedtuple(
        "ModelSetup",
        ["name", "env", "model", "max_episode_steps", "measure_beam"],
    )
    return ModelSetup(
        name=model_name,
        env=vec_env,
        model=agent,
        max_episode_steps=max_episode_steps,
        measure_beam=measure_beam,
    )

In [10]:
def pack_dataframe(fn):
    """Decorator that packs the result of an episode rollout into a ``pandas.DataFrame``.

    The decorated function is called as ``fn(setup, desired)`` and must return a
    3-tuple ``(observations, rewards, beam_images)`` where

    * ``observations`` holds one entry per recorded step, each with at least
      13 values in the order of ``observation_columns`` below,
    * ``rewards`` is any iterable with exactly one entry fewer than
      ``observations`` (the first observation has no reward), and
    * ``beam_images`` holds one entry per recorded step.

    The returned frame has one row per step with the columns ``step``, the 13
    observation columns, ``reward``, ``beam_image`` and the constant metadata
    columns ``model_name``, ``max_episode_steps`` and ``measure_beam`` taken
    from ``setup``.

    Args:
        fn: Rollout function with the signature ``fn(setup, desired)``.

    Returns:
        A wrapper with the same signature (and the same ``__name__`` /
        docstring as ``fn``) that returns the ``DataFrame``.
    """
    observation_columns = (
        "q1", "q2", "cv", "q3", "ch",
        "mup_x", "mup_y", "sigmap_x", "sigmap_y",
        "mu_x", "mu_y", "sigma_x", "sigma_y",
    )

    @wraps(fn)  # keep the rollout function's name and docstring visible
    def wrapper(setup, desired):
        observations, rewards, beam_images = fn(setup, desired)
        observations = np.array(observations)

        df = pd.DataFrame(np.arange(len(observations)), columns=["step"])
        for index, column in enumerate(observation_columns):
            df[column] = observations[:, index]

        # Row 0 is the initial observation, for which no reward exists.
        # Unpacking instead of ``[np.nan] + rewards`` keeps this a concatenation
        # for tuples and NumPy arrays as well (``list + ndarray`` would broadcast).
        df["reward"] = [np.nan, *rewards]
        df["beam_image"] = beam_images

        df["model_name"] = setup.name
        df["max_episode_steps"] = setup.max_episode_steps
        df["measure_beam"] = setup.measure_beam

        return df

    return wrapper

In [11]:
@pack_dataframe
def run(setup, desired):
    """Roll out a single deterministic episode of ``setup.model`` in ``setup.env``.

    Before the reset, the underlying environment's ``next_initial`` is set to
    five zeros and its ``next_desired`` to ``desired``. The policy is then
    stepped with ``deterministic=True`` until the environment reports ``done``.

    Args:
        setup: Object providing ``env``, ``model`` and ``max_episode_steps``.
        desired: Value assigned to the environment's ``next_desired`` attribute.

    Returns:
        ``(observations, rewards, beam_images)``: un-normalised observations
        and beam images for the reset plus every step, and one reward per
        step. The ``pack_dataframe`` decorator turns this tuple into a
        ``DataFrame``.
    """
    vec_env = setup.env
    agent = setup.model

    vec_env.get_attr("unwrapped")[0].next_initial = np.zeros(5)
    vec_env.get_attr("unwrapped")[0].next_desired = desired

    def current_beam_image():
        # Looked up anew on every call, exactly as often as a sample is recorded.
        return vec_env.get_attr("backend")[0].last_beam_image

    normalized_obs = vec_env.reset()
    observations = [vec_env.unnormalize_obs(normalized_obs).squeeze()]
    beam_images = [current_beam_image()]
    rewards = []

    with tqdm(total=setup.max_episode_steps) as progress:
        while True:
            action, _ = agent.predict(normalized_obs, deterministic=True)
            normalized_obs, reward, done, infos = vec_env.step(action)

            observations.append(vec_env.unnormalize_obs(normalized_obs).squeeze())
            rewards.append(reward.squeeze())
            # NOTE(review): if the vectorised env auto-resets when ``done`` is
            # set, the image recorded for the last step may belong to the reset
            # rather than to the terminal step -- confirm against the backend.
            beam_images.append(current_beam_image())

            progress.update(1)
            if done:
                break

    # The observation returned together with ``done`` is replaced by the
    # terminal observation reported in the step info (presumably because the
    # returned one already belongs to the next episode).
    observations[-1] = vec_env.unnormalize_obs(infos[0]["terminal_observation"].squeeze())

    return observations, rewards, beam_images

In [12]:
# Load the saved agent and its normalised environment from models/amber-mountain-976,
# using the default step limit (50) and beam measurement mode ("us").
setup = load_sequential("amber-mountain-976")

In [13]:
# Run one episode with an all-zero `desired` value; the result is one DataFrame row per step.
# The recorded run below needed about 7 minutes for 50 steps.
df = run(setup, np.zeros(4))

100%|██████████| 50/50 [07:14<00:00,  8.68s/it]


Unnamed: 0,step,q1,q2,cv,q3,ch,mup_x,mup_y,sigmap_x,sigmap_y,mu_x,mu_y,sigma_x,sigma_y,reward,beam_image,model_name,max_episode_steps,measure_beam
0,0,-8.423128e-08,8.02262e-09,1.95862e-08,1.964785e-14,6.518805e-13,6.019531e-14,-2.626839e-13,4.67142e-12,-4.424122e-12,-0.002191068,0.000318097,0.00035,0.000411,,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...",amber-mountain-976,50,us
1,1,-1.278085,2.920598,3.0,-0.0001214446,0.0005999974,6.019531e-14,-2.626839e-13,4.67142e-12,-4.424122e-12,-0.0009494628,0.0003034156,0.000186,0.000582,0.36849228,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...",amber-mountain-976,50,us
2,2,-4.205281,5.516942,5.994088,-0.0003000994,0.001076514,6.019531e-14,-2.626839e-13,4.67142e-12,-4.424122e-12,-0.0002058276,4.8938e-05,9.6e-05,0.00049,0.6697913,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...",amber-mountain-976,50,us
3,3,-6.357828,7.190644,8.364227,-0.0003508193,0.0006845113,6.019531e-14,-2.626839e-13,4.67142e-12,-4.424122e-12,0.0002257464,-0.0001125574,2.3e-05,0.000353,0.26966745,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...",amber-mountain-976,50,us
4,4,-7.643978,7.363797,9.814283,-0.0002847884,0.000111344,6.019531e-14,-2.626839e-13,4.67142e-12,-4.424122e-12,0.0001726296,-0.000146814,2.3e-05,0.000266,0.19535112,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...",amber-mountain-976,50,us
5,5,-7.708496,5.925903,11.04495,-0.0002015532,-0.000209068,6.019531e-14,-2.626839e-13,4.67142e-12,-4.424122e-12,-0.0001062336,-9.29822e-05,2.3e-05,0.000254,0.17663863,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...",amber-mountain-976,50,us
6,6,-7.243416,4.27239,12.3737,-0.0001444104,-9.251322e-05,6.019531e-14,-2.626839e-13,4.67142e-12,-4.424122e-12,-0.0001062336,-2.4469e-05,2.3e-05,0.000254,0.09561317,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...",amber-mountain-976,50,us
7,7,-6.763373,2.689986,13.6969,-0.0001300711,1.631229e-05,6.019531e-14,-2.626839e-13,4.67142e-12,-4.424122e-12,-2.65584e-05,-9.7876e-06,2.3e-05,0.000249,0.16293322,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...",amber-mountain-976,50,us
8,8,-6.342036,1.260316,14.64613,-0.0001276819,2.470362e-05,6.019531e-14,-2.626839e-13,4.67142e-12,-4.424122e-12,-1.32792e-05,-1.46814e-05,2.3e-05,0.000245,0.02920192,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...",amber-mountain-976,50,us
9,9,-6.08275,0.2698187,15.21941,-0.0001205976,5.052481e-06,6.019531e-14,-2.626839e-13,4.67142e-12,-4.424122e-12,-2.65584e-05,-1.46814e-05,2.3e-05,0.000237,0.005954224,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...",amber-mountain-976,50,us
