In [None]:
!pip install stable_baselines3

Collecting stable_baselines3
  Downloading stable_baselines3-2.4.0-py3-none-any.whl.metadata (4.5 kB)
Collecting gymnasium<1.1.0,>=0.29.1 (from stable_baselines3)
  Downloading gymnasium-1.0.0-py3-none-any.whl.metadata (9.5 kB)
Collecting farama-notifications>=0.0.1 (from gymnasium<1.1.0,>=0.29.1->stable_baselines3)
  Downloading Farama_Notifications-0.0.4-py3-none-any.whl.metadata (558 bytes)
Downloading stable_baselines3-2.4.0-py3-none-any.whl (183 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m183.9/183.9 kB[0m [31m11.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading gymnasium-1.0.0-py3-none-any.whl (958 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m958.1/958.1 kB[0m [31m26.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading Farama_Notifications-0.0.4-py3-none-any.whl (2.5 kB)
Installing collected packages: farama-notifications, gymnasium, stable_baselines3
Successfully installed farama-notifications-0.0.4 gymnasium-1.0.0 stable_baselines3-2.

In [None]:
import gym
import numpy as np
import matplotlib.pyplot as plt
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.monitor import Monitor
from gym.spaces import Box, MultiDiscrete

  from jax import xla_computation as _xla_computation


In [None]:
class SolventSwitchEnv(gym.Env):
    """Solvent-switch control environment driven by incremental discrete actions.

    The 14-element state vector uses the layout noted by the original author:

        0: volume      1: x_crys      2: x_reac      3: x_imp      4: x_api
        5: T_reb       6: jacket temperature         7: solvent flowrate
        8: condenser load             9: inert gas

    Indices 10-13 are not named anywhere in this file and start at zero.

    NOTE(review): ``step`` only modifies the jacket temperature (6) and the
    flowrate (7). Volume and compositions never change, so from the reset
    state the termination condition in ``step`` cannot become true. The
    process dynamics still have to be implemented.
    """

    # Positions of the controlled / constrained quantities in the state vector.
    IDX_VOLUME = 0
    IDX_X_CRYS = 1
    IDX_X_REAC = 2
    IDX_X_IMP = 3
    IDX_JACKET_TEMP = 6
    IDX_FLOWRATE = 7
    IDX_CONDENSER_LOAD = 8

    # (flowrate increment, jacket-temperature increment) for each action index.
    _ACTION_DELTAS = (
        (0.0, 0.0),     # 0: hold
        (0.05, 1.0),    # 1: small increase
        (-0.05, -1.0),  # 2: small decrease
        (0.1, 5.0),     # 3: medium increase
        (-0.1, -5.0),   # 4: medium decrease
        (0.0, 10.0),    # 5: large temperature increase only
        (0.0, -10.0),   # 6: large temperature decrease only
    )

    # Which state entry each named constraint applies to.
    _CONSTRAINT_INDEX = {
        'volume': IDX_VOLUME,
        'jacket_temp': IDX_JACKET_TEMP,
        'flowrate': IDX_FLOWRATE,
        'condenser_load': IDX_CONDENSER_LOAD,
    }

    def __init__(self):
        super().__init__()

        # Action space: one discrete choice among 7 combined increments of
        # jacket temperature and solvent flowrate.
        self.action_space = MultiDiscrete([7])

        # Observation space: 14 unbounded continuous variables.
        self.observation_space = Box(low=-np.inf, high=np.inf, shape=(14,), dtype=np.float32)

        # Initial state and constraints
        self.state = None
        self.done = False
        self.constraints = {
            'volume': (500, 1200),
            'jacket_temp': (20, 80),
            'flowrate': (0, 90),
            'condenser_load': (0, 160)
        }

    def _observation(self):
        """Return a float32 copy of the state matching ``observation_space``."""
        return self.state.astype(np.float32)

    def reset(self):
        """Restore the initial state and return the first observation."""
        # Initialize the state based on Table 1; the 7 trailing entries start at 0.
        self.state = np.array(
            [1139.48, 0.0, 69.55, 13.45, 17.0, 22.72, 23.86] + [0.0] * 7,
            dtype=np.float64,
        )
        self.done = False
        return self._observation()

    def step(self, action):
        """Apply one control increment and score the resulting state.

        Args:
            action: Index 0-6 into the increment table. A length-1 array (as
                produced for ``MultiDiscrete([7])``) or a plain integer is
                accepted; an index outside 0-6 leaves the state unchanged.

        Returns:
            Tuple ``(observation, reward, done, info)`` where ``observation``
            is a float32 array of length 14, ``reward`` is a float, ``done``
            is a bool and ``info`` is an empty dict.
        """
        # MultiDiscrete hands over an array; reduce it to a scalar index so
        # the lookup below does not depend on array truthiness.
        action_index = int(np.asarray(action).reshape(-1)[0])

        next_state = self.state.copy()
        if 0 <= action_index < len(self._ACTION_DELTAS):
            flow_delta, temp_delta = self._ACTION_DELTAS[action_index]
            next_state[self.IDX_FLOWRATE] += flow_delta
            next_state[self.IDX_JACKET_TEMP] += temp_delta

        # Reward: exponential decay in the amount of reaction solvent left.
        # NOTE(review): with the reset values (x_reac=69.55, volume=1139.48)
        # the exponent is about -7.9e4, so this term underflows to 0.0.
        reaction_solvent_amount = next_state[self.IDX_X_REAC] * next_state[self.IDX_VOLUME]
        reward = float(np.exp(-reaction_solvent_amount))

        # Logarithmic constraint penalties. log1p keeps the penalty at zero on
        # the boundary and strictly positive beyond it; the previous
        # log(violation + 1e-6) was negative for violations below 1 and
        # therefore rewarded small constraint breaches.
        for name, (min_val, max_val) in self.constraints.items():
            value = next_state[self._CONSTRAINT_INDEX[name]]
            violation = max(min_val - value, value - max_val, 0.0)
            reward -= 100.0 * float(np.log1p(violation))

        # Check termination conditions: solvent and impurity low enough and
        # crystallisation-solvent fraction high enough.
        if (
            next_state[self.IDX_X_REAC] <= 10.62
            and next_state[self.IDX_X_IMP] <= 6.09
            and next_state[self.IDX_X_CRYS] >= 48.07
        ):
            self.done = True

        self.state = next_state
        return self._observation(), reward, self.done, {}

In [None]:
!pip install shimmy

Collecting shimmy
  Downloading Shimmy-2.0.0-py3-none-any.whl.metadata (3.5 kB)
Downloading Shimmy-2.0.0-py3-none-any.whl (30 kB)
Installing collected packages: shimmy
Successfully installed shimmy-2.0.0


In [None]:
# DummyVecEnv takes a list of zero-argument factories; the environment class
# itself is such a factory, so a single-environment vector wraps it directly.
env = DummyVecEnv(env_fns=[SolventSwitchEnv])



In [None]:
# PPO configuration. With a single environment the rollout buffer holds
# n_steps * n_envs = 32 samples, so the minibatch size is set to 32 as well:
# the previous batch_size=64 exceeded the buffer, which makes stable_baselines3
# warn and fall back to a truncated 32-sample minibatch anyway.
model = PPO(
    "MlpPolicy",
    env,
    verbose=1,
    learning_rate=0.001,
    gamma=0.998,
    n_steps=32,
    batch_size=32,
    clip_range=0.2,
    policy_kwargs={"net_arch": [16, 16]},
)

# Train PPO Model
model.learn(total_timesteps=100000)
model.save("ppo_solvent")

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
-----------------------------
| time/              |      |
|    fps             | 1146 |
|    iterations      | 1    |
|    time_elapsed    | 0    |
|    total_timesteps | 32   |
-----------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 584          |
|    iterations           | 2            |
|    time_elapsed         | 0            |
|    total_timesteps      | 64           |
| train/                  |              |
|    approx_kl            | 0.0073076505 |
|    clip_fraction        | 0            |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.94        |
|    explained_variance   | 1.19e-07     |
|    learning_rate        | 0.001        |
|    loss                 | -0.0389      |
|    n_updates            | 10           |
|    policy_gradient_loss | -0.0192      |
|    val



[1;30;43mStreaming output truncated to the last 5000 lines.[0m
|    time_elapsed         | 252           |
|    total_timesteps      | 91136         |
| train/                  |               |
|    approx_kl            | 3.2726675e-06 |
|    clip_fraction        | 0             |
|    clip_range           | 0.2           |
|    entropy_loss         | -1.94         |
|    explained_variance   | -1.19e-07     |
|    learning_rate        | 0.001         |
|    loss                 | 1.56e+08      |
|    n_updates            | 28470         |
|    policy_gradient_loss | -0.000351     |
|    value_loss           | 3.13e+08      |
-------------------------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 360          |
|    iterations           | 2849         |
|    time_elapsed         | 252          |
|    total_timesteps      | 91168        |
| train/                  |              |
|    approx_kl    

In [None]:
# Roll out the saved policy and report how long it takes to finish an episode.
model = PPO.load('ppo_solvent')
obs = env.reset()

MINUTES_PER_STEP = 5     # each step corresponds to 5 minutes
MAX_EVAL_STEPS = 10_000  # safety cap: the loop must end even if `done` never fires

time_steps = 0
done = False

while not done and time_steps < MAX_EVAL_STEPS:
    # Deterministic actions make the reported time reproducible between runs.
    action, _states = model.predict(obs, deterministic=True)  # Get the agent's action
    obs, reward, dones, info = env.step(action)  # Apply the action
    done = bool(dones[0])  # the VecEnv returns one flag per wrapped environment
    time_steps += 1

time_in_minutes = time_steps * MINUTES_PER_STEP
if done:
    print(f"Time required to reach x_reac <= 10.62: {time_in_minutes} minutes")
else:
    # Do not claim success when the rollout was merely cut off by the cap.
    print(f"Target not reached within {MAX_EVAL_STEPS} steps ({time_in_minutes} minutes simulated)")