In [6]:
import os
import sys
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

import numpy as np
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.callbacks import CheckpointCallback
from gymnasium import spaces
import torch

In [7]:
# Make the project's `src` folder importable. The path is resolved relative to the
# current working directory ('..' -> parent of the cwd, then 'src').
# The membership guard keeps the cell idempotent: re-running it does not append
# duplicate entries to sys.path.
src_dir = os.path.join(os.path.abspath('..'), 'src')
if src_dir not in sys.path:
    sys.path.append(src_dir)
import environment

In [14]:
# Wrap the TrafficEnvironment into a gym-like environment
class TrafficEnvWrapper(gym.Env):
    """Gymnasium adapter around ``environment.TrafficEnvironment``.

    Actions are ``Discrete(3)`` (0: left, 1: stay, 2: right) and are shifted to
    -1 / 0 / +1 before being handed to the internal environment.

    Observations are a flat ``float32`` vector of length ``3 * (2 + lanes)``:
    the three most recent states, each holding the distance, the current lane
    and one clearance rate per lane.
    """

    # Number of per-step states stacked into a single observation.
    _HISTORY_LEN = 3
    # Floor for the observation upper bound (the value that used to be hard-coded).
    _MIN_OBS_HIGH = 4000

    def __init__(self, lanes=5, initial_distance=4000):
        """
        Args:
        - lanes (int): Number of lanes (default is 5).
        - initial_distance (int): The distance from destination.
        """
        super().__init__()
        self.lanes = lanes
        self.initial_distance = initial_distance
        self.internal_env = environment.TrafficEnvironment(lanes=self.lanes, initial_distance=self.initial_distance)

        # Action space: 3 actions (0: left, 1: stay, 2: right)
        self.action_space = spaces.Discrete(3)

        # Observation space: flattened 1D vector with shape (3 * (2 + lanes),).
        # The upper bound grows with `initial_distance` so that a start further
        # away than 4000 still fits inside the declared Box; it never drops
        # below the previous fixed bound of 4000.
        obs_dim = self._HISTORY_LEN * (2 + self.lanes)
        obs_high = float(max(self._MIN_OBS_HIGH, self.initial_distance))
        self.observation_space = spaces.Box(low=0.0, high=obs_high, shape=(obs_dim,), dtype=np.float32)

    def _flatten_history(self, history):
        """Turn a state history into one flat float32 observation vector.

        A history shorter than ``_HISTORY_LEN`` is front-padded with copies of
        its first entry. A longer history is cut down to its last
        ``_HISTORY_LEN`` entries so the result always matches
        ``observation_space`` (this assumes the history is ordered
        oldest-first, as the front-padding implies -- TODO confirm against
        TrafficEnvironment).

        Raises:
            IndexError: if the history is empty.
        """
        states = list(history)
        if not states:
            raise IndexError("internal environment returned an empty state history")

        missing = self._HISTORY_LEN - len(states)
        if missing > 0:
            states = [states[0]] * missing + states
        elif missing < 0:
            states = states[-self._HISTORY_LEN:]

        return np.array([value for state in states for value in state], dtype=np.float32)

    def reset(self, seed=None, options=None):
        """Reset the internal environment and return ``(observation, info)``.

        ``seed`` is forwarded to ``gym.Env.reset`` only; the internal
        environment's ``reset()`` is called without arguments, so the seed is
        not passed on to it by this wrapper. ``options`` is accepted for API
        compatibility and not used.
        """
        super().reset(seed=seed)

        initial_state = self.internal_env.reset()
        return self._flatten_history(initial_state), {}

    def step(self, action):
        """Apply ``action`` and return ``(obs, reward, terminated, truncated, info)``.

        ``truncated`` is always ``False`` because the wrapper imposes no time limit.
        """
        # Map 0 -> -1 (left), 1 -> 0 (stay), 2 -> 1 (right). int() turns the
        # numpy integer that samplers/policies produce into a plain Python int.
        mapped_action = int(action) - 1

        next_state, reward, done = self.internal_env.step(mapped_action)
        obs = self._flatten_history(next_state)

        # Normalise to the plain Python types the Gymnasium step API expects.
        truncated = False
        return obs, float(reward), bool(done), truncated, {}

    def render(self, mode="human"):
        """No-op: this wrapper does not render anything."""
        pass

In [15]:
# Build the wrapped traffic environment with its default arguments
wrapped_env = TrafficEnvWrapper()

# Validate the wrapper with stable-baselines3's environment checker (warnings enabled)
check_env(wrapped_env, warn=True)

In [17]:
# Seed handed to PPO for its pseudo-random generators so repeated runs are comparable.
# Whether TrafficEnvironment itself is deterministic under this seed is not visible
# from this notebook -- verify before relying on exact reproducibility.
SEED = 42

# Define the PPO model using the single wrapped environment
ppo_model = PPO("MlpPolicy", wrapped_env, verbose=1, seed=SEED)

# Set up a checkpoint callback to save the model every 1000 steps
checkpoint_callback = CheckpointCallback(
    save_freq=1000,
    save_path='./ppo_checkpoints/',
    name_prefix='ppo_traffic_model',
)

# Train the PPO model. The trailing semicolon keeps the returned model's repr
# out of the cell output.
total_timesteps = 50000
ppo_model.learn(total_timesteps=total_timesteps, callback=checkpoint_callback);

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
-----------------------------
| time/              |      |
|    fps             | 1201 |
|    iterations      | 1    |
|    time_elapsed    | 1    |
|    total_timesteps | 2048 |
-----------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 2.57e+03    |
|    ep_rew_mean          | -3.01e+04   |
| time/                   |             |
|    fps                  | 838         |
|    iterations           | 2           |
|    time_elapsed         | 4           |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.013220936 |
|    clip_fraction        | 0.202       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.09       |
|    explained_variance   | -0.00121    |
|    learning_rate        | 0.0003      |
|    loss               

<stable_baselines3.ppo.ppo.PPO at 0x1ee906d3fd0>