### Imports and Configuration

In [2]:
import gymnasium as gym

import time
import numpy as np
import matplotlib.pyplot as plt
import itertools
from tqdm.notebook import tqdm
import pandas as pd
import pickle

In [3]:
configuration = {

    # Parametrization below cannot be changed
    "lanes_count" : 10, # The environment must always have 10 lanes
    "vehicles_count": 50, # The environment must always have 50 other vehicles
    "duration": 120,  # [s] The environment must never terminate before 120 seconds
    "other_vehicles_type": "highway_env.vehicle.behavior.IDMVehicle", # This is the policy of the other vehicles
    "initial_spacing": 2, # Initial spacing between vehicles needs to be at most 2

    # Refer to https://highway-env.farama.org/observations/ to change observation space type
    "observation": {
        "type": "Kinematics"
    },

    # Refer to https://highway-env.farama.org/actions/ to change action space type
    "action": {
        "type": "DiscreteMetaAction",
    },

    # Parameterization below can be changed (as it refers mostly to the reward system)
    # Earlier values, superseded by the active "collision_reward" / "reward_speed_range" entries further down:
    # "collision_reward": -10,  # The reward received when colliding with a vehicle. (Can be changed)
    # "reward_speed_range": [20, 30],  # [m/s] The reward for high speed is mapped linearly from this range to [0, HighwayEnv.HIGH_SPEED_REWARD]. (Can be changed)
    "simulation_frequency": 15,  # [Hz] (Can be changed)
    "policy_frequency": 5,  # [Hz] (Can be changed)

    "collision_reward": -1000,  # The reward received when colliding with a vehicle.
    "right_lane_reward": 0.1,  # The reward received when driving on the right-most lanes, linearly mapped to
    # zero for other lanes.
    "high_speed_reward": 5,  # The reward received when driving at full speed, linearly mapped to zero for
    # lower speeds according to config["reward_speed_range"].
    "lane_change_reward": 0,  # The reward received at each lane change action.
    "reward_speed_range": [20, 30],  # [m/s] Speed range over which the high-speed reward is mapped

    # Parameters defined below are purely for visualization purposes! You can alter them as you please
    "screen_width": 800,  # [px]
    "screen_height": 600,  # [px]
    "centering_position": [0.5, 0.5],
    "scaling": 5,
    "show_trajectories": True,
    "render_agent": True,
    "offscreen_rendering": False
}

In [4]:
# Shallow copy of the base configuration; nested dicts ("observation", "action")
# are still shared with `configuration`, so replace them rather than mutating them.
default_config = dict(configuration)

### Checking the environment 

In [5]:
# Sanity check: build an occupancy-grid variant of the base configuration,
# reset the environment once and look at the observation it returns.
occupancyGrid = configuration.copy()
occupancyGrid["observation"] =  {
    "type": "OccupancyGrid",
    # Only the "presence" feature is enabled. "x", "y", "vx", "vy" (with a matching
    # "features_range"), "vehicles_count", "as_image" and "align_to_vehicle_axes"
    # were tried and left disabled; "absolute" was noted as not implemented in the library.
    "features": ["presence"],
    "grid_size": [[-100, 100], [-100, 100]],    # X controls how many lanes, Y controls how far ahead
    "grid_step": [1, 1],
}

# The higher the policy and simulation frequencies, the slower the simulation
occupancyGrid["simulation_frequency"] = 15
occupancyGrid["policy_frequency"] = 5

env = gym.make('highway-v0', render_mode='human', config=occupancyGrid)
try:
    obs, info = env.reset(seed = 30)
finally:
    # Always release the render window, even if reset() raises.
    env.close()

obs

array([[[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]]], dtype=float32)

In [207]:
class OccupancyGrid():
    """
    Discretised state representation built on top of an "OccupancyGrid"
    observation of the 'highway-v0' environment.

    The "presence" layer of the observation is converted into the relative
    positions of the `n_closest` occupied cells, and each position is snapped
    to a coarse grid of bins so that the state can be used as a dictionary key.
    """

    def __init__(
            self,
            grid_size=[[-50, 50], [-50, 50]],  # X controls the lane-width, Y controls how far ahead
            grid_step=[1, 1],
            n_closest=3,
            ss_bins=[5,6],
            crop_dist=[[-10,10], [-10,25]],
            policy=None,
            sim_frequency=15,
            policy_frequency=5,
            render_mode = 'human',
            seed = 50,
    ):
        """
        Occupancy view class constructor
        Arguments:
            grid_size: list of lists, the size of the grid in the x and y direction. The original author notes
                that x controls the lane-width and y controls how far ahead, that lanes are 5m wide and that
                the car position is (0,0) -- TODO confirm against highway-env's axis convention.
            grid_step: list, the step size of the grid in the x and y direction, in meters
            n_closest: int, the number of closest cars to consider in the state space
            ss_bins: list, the number of bins to divide the x and y directions
            crop_dist: list of lists, [min, max] range of the bins in the x and y directions; positions outside
                the range are clipped to the nearest bin
            policy: callable taking no arguments and returning an action, used by `test_env`
                (a random action is sampled when it is None)
            sim_frequency: int, the frequency of the simulation
            policy_frequency: int, the frequency of the policy
            render_mode: str, the mode to render the simulation
            seed: int, the seed to use in the simulation
        """

        # Copy the list arguments: the defaults are mutable objects shared by every
        # call, so storing them directly would let one instance affect the others.
        self.grid_size = [list(axis) for axis in grid_size]
        self.grid_step = list(grid_step)
        self.config = default_config.copy()
        self.config["observation"] =  {
            "type": "OccupancyGrid",
            "features": ["presence"],
            "grid_size": self.grid_size,
            "grid_step": self.grid_step,
        }
        self.config["simulation_frequency"] = sim_frequency
        self.config["policy_frequency"] = policy_frequency
        self.render_mode = render_mode
        self.seed = seed
        self.policy = policy
        self.n_closest = n_closest
        self.ss_bins = list(ss_bins)
        self.crop_dist = [list(axis) for axis in crop_dist]
        self.initialize_states()

    def initialize_states(self):
        """
        Build the bin values and enumerate every possible discrete state.

        A state has the form ((x1, ..., xn), (y1, ..., yn)) where n is `n_closest`
        and every xi / yi is one of the values in `x_bins` / `y_bins`.
        """
        # ss_bins[0] evenly spaced values over crop_dist[0], ss_bins[1] over crop_dist[1]
        self.x_bins = np.linspace(self.crop_dist[0][0], self.crop_dist[0][1], self.ss_bins[0])
        self.y_bins = np.linspace(self.crop_dist[1][0], self.crop_dist[1][1], self.ss_bins[1])

        x_keys = list(itertools.product(self.x_bins, repeat=self.n_closest))
        y_keys = list(itertools.product(self.y_bins, repeat=self.n_closest))
        self.states = list(itertools.product(x_keys, y_keys))

    def get_car_positions(self):
        """
        Convert the non-zero cells of the first observation layer into positions.

        Returns:
            Array of shape (n_cells, 2); each row is (x, y), computed as
            cell index * grid_step + lower bound of grid_size for that axis.
        """
        occupied = np.nonzero(self.current_obs[0])
        x = occupied[0]*self.grid_step[0] + self.grid_size[0][0]
        y = occupied[1]*self.grid_step[1] + self.grid_size[1][0]
        return np.array([x, y]).T

    def get_n_closest(self):
        """
        Return the positions of the `n_closest` occupied cells nearest to the origin.

        The nearest cell is skipped because it is assumed to be the agent itself.
        When fewer than `n_closest` cells remain, the result is padded with rows
        equal to (crop_dist[0][0], crop_dist[1][0]).

        Returns:
            Array of shape (n_closest, 2).
        """
        car_positions = self.get_car_positions()
        distances = np.linalg.norm(car_positions, axis=1)

        # Remove the agent position (assumed to be the closest entry)
        nearest = np.argsort(distances)[1:self.n_closest+1]
        closest_car_positions = car_positions[nearest]

        n_missing = self.n_closest - len(closest_car_positions)
        if n_missing > 0:
            # Pad row by row: np.pad's constant_values=(a, b) means (before, after),
            # not (x, y), so it would fill both columns with the y lower bound.
            filler = [[self.crop_dist[0][0], self.crop_dist[1][0]]] * n_missing
            closest_car_positions = np.array(closest_car_positions.tolist() + filler)

        return closest_car_positions

    def get_state(self):
        """
        Discretise the closest car positions into a hashable state.

        Returns:
            Tuple ((x1, ..., xn), (y1, ..., yn)) of bin values, matching the
            entries of `self.states`.
        """
        closest = self.get_n_closest()
        # np.digitize returns 0 for values below the first bin; without clipping,
        # "- 1" would give index -1 and wrap around to the *highest* bin.
        x_idx = np.clip(np.digitize(closest[:, 0], self.x_bins) - 1, 0, len(self.x_bins) - 1)
        y_idx = np.clip(np.digitize(closest[:, 1], self.y_bins) - 1, 0, len(self.y_bins) - 1)
        return (tuple(self.x_bins[x_idx]), tuple(self.y_bins[y_idx]))

    def test_env(self):
        """
        Run one episode with `self.policy`, or with random actions when it is None.

        Prints the wall-clock time of every step.

        Returns:
            info["score"] from the last step.
        """
        env = gym.make('highway-v0', render_mode=self.render_mode, config=self.config)
        try:
            obs, info = env.reset(seed = self.seed)
            self.current_obs = obs
            done = False
            while not done:
                start = time.time()
                if self.policy is None:
                    action = env.action_space.sample()
                else:
                    action = self.policy()
                obs, reward, terminated, truncated, info = env.step(action)
                # The episode also ends when it is truncated (time limit reached).
                done = terminated or truncated
                self.current_obs = obs
                end = time.time()
                print(f"Time taken: {end-start}")
        finally:
            # Release the environment even if a step raises.
            env.close()
        # NOTE(review): assumes the environment's info dict has a "score" key -- TODO confirm.
        return info["score"]


# ------------------- SARSA -------------------    
class Sarsa(OccupancyGrid):
    """
    Tabular SARSA agent whose states are the discretised occupancy-grid states
    of `OccupancyGrid` and whose Q-table is a dict keyed by (state, action).
    """

    # Number of discrete actions, 0-4: left, idle, right, accelerate, decelerate
    N_ACTIONS = 5

    def __init__(
        self,
        alpha=0.75,
        gamma=0.99,
        m=1000,
        epsilon=0.5,
        **kwargs,
        ):
        """
        SARSA class constructor
        Arguments:
            alpha: float, the learning rate
            gamma: float, the discount factor
            m: int, the number of episodes to train the agent for
            epsilon: float, the epsilon value for the epsilon-greedy policy
            **kwargs: forwarded unchanged to the OccupancyGrid constructor
        """

        super().__init__(**kwargs)
        self.alpha, self.gamma, self.m, self.epsilon = alpha, gamma, m, epsilon
        self.initialize_Q()

    def policy_Q(self, state):
        """Return the greedy action for `state` (first maximum wins on ties)."""
        values = [self.Q[(state, action)] for action in range(self.N_ACTIONS)]
        return np.argmax(values)

    def initialize_Q(self):
        """Create the Q-table with a zero entry for every (state, action) pair."""
        keys = list(itertools.product(self.states, range(self.N_ACTIONS)))
        if len(keys) > 150000:
            print("Warning: The number of states is too large, consider reducing the number of states")
        self.Q = {key: 0 for key in keys}

    def epsilon_greedy(self, state):
        """Return a uniformly random action with probability epsilon, else the greedy one."""
        if np.random.rand() < self.epsilon:
            return np.random.randint(self.N_ACTIONS)
        return self.policy_Q(state)

    def train(self):
        """
        Train the Q-table with on-policy SARSA for `self.m` episodes.

        Only the first reset is seeded, so later episodes start from different
        initial conditions. The environment is closed when training ends or fails.
        """
        env = gym.make('highway-v0', render_mode=None, config=self.config)
        try:
            for episode in tqdm(range(self.m)):
                if episode == 0:
                    obs, info = env.reset(seed = self.seed)
                else:
                    obs, info = env.reset()
                # Start every episode from the state the environment was reset to,
                # not from the last state of the previous episode.
                self.current_obs = obs
                state = self.get_state()
                action = self.epsilon_greedy(state)
                done = False
                while not done:
                    next_obs, reward, terminated, truncated, info = env.step(action)
                    # get_state() reads self.current_obs, so it must be updated first;
                    # otherwise next_state would describe the previous observation.
                    self.current_obs = next_obs
                    next_state = self.get_state()
                    next_action = self.epsilon_greedy(next_state)
                    # No bootstrapping from a terminal state: its future return is zero.
                    future = 0 if terminated else self.gamma*self.Q[(next_state, next_action)]
                    self.Q[(state, action)] += self.alpha*(reward + future - self.Q[(state, action)])
                    state, action = next_state, next_action
                    # A truncated episode (time limit reached) ends the episode as well.
                    done = terminated or truncated
        finally:
            env.close()

    def test(self):
        """
        Run one episode following the greedy policy, printing each visited state.

        Returns:
            info["score"] from the last step.
        """
        env = gym.make('highway-v0', render_mode=self.render_mode, config=self.config)
        try:
            obs, info = env.reset(seed = self.seed)
            self.current_obs = obs
            state = self.get_state()
            done = False
            while not done:
                action = self.policy_Q(state)
                next_obs, reward, terminated, truncated, info = env.step(action)
                # Update the observation before deriving the state from it.
                self.current_obs = next_obs
                state = self.get_state()
                print(state)
                done = terminated or truncated
        finally:
            env.close()
        # NOTE(review): assumes the environment's info dict has a "score" key -- TODO confirm.
        return info["score"]

    def run(self, n_episodes):
        """Placeholder, not implemented yet."""
        pass


In [None]:
# NOTE(review): `s` is not defined in any visible cell and `a` is only created by the
# `a = Sarsa()` cell below, so this cell fails on a fresh kernel (Restart & Run All).
# Presumably `s` was a previously trained agent whose Q-table is reused -- TODO confirm.
a.Q = s.Q.copy()

In [208]:
# Create a SARSA agent with the default hyperparameters and grid settings.
a = Sarsa()

In [211]:
# Percentage of (state, action) entries of the Q-table that are no longer zero.
count = sum(1 for value in a.Q.values() if value != 0)
print(100*count/len(a.Q))

3.071851851851852


In [212]:
# Train the agent; the number of episodes is the agent's `m` attribute.
a.train()

  0%|          | 0/1000 [00:00<?, ?it/s]

In [132]:
# Build the occupancy-grid helper without rendering.
# (The class is named `OccupancyGrid`; the previous spelling `OcupancyGrid` raised a NameError.)
new_occupancy = OccupancyGrid(render_mode=None)
# print(new_occupancy.x_bins, new_occupancy.y_bins, len(new_occupancy.states))
# new_occupancy.test_env()

___________________________
### Old stuff

In [203]:
def return_car_positions(obs, grid_step=1, grid_size=50):
    """
    Convert the non-zero cells of the first layer of `obs` into positions.

    Each index is mapped to `index*grid_step - grid_size`.

    Returns:
        List of (x, y) tuples, one per non-zero cell, in row-major order.
    """
    occupied = np.nonzero(obs[0])
    xs = occupied[0]*grid_step - grid_size
    ys = occupied[1]*grid_step - grid_size
    return [(x, y) for x, y in zip(xs, ys)]

def get_n_closest(obs, grid_step=1, grid_size=50, n=3):
    """
    Return up to `n` car positions closest to the origin.

    The nearest entry is skipped (it is taken to be the agent itself).
    """
    cars = return_car_positions(obs, grid_step, grid_size)
    norms = list(map(np.linalg.norm, cars))
    order = np.argsort(norms)
    return [cars[idx] for idx in order[1:n+1]]

def get_distances(car_positions):
    """Return the Euclidean norm of every position in `car_positions` as a list."""
    distances = []
    for position in car_positions:
        distances.append(np.linalg.norm(position))
    return distances

In [45]:
# Draw one random action from the environment's action space.
env.action_space.sample()

0

The action space is as follows:
<br><br>
ACTIONS_ALL = {
        0: 'LANE_LEFT',
        1: 'IDLE',
        2: 'LANE_RIGHT',
        3: 'FASTER',
        4: 'SLOWER'
    }

In [204]:
# Render the environment slow motion and print the observations
env = gym.make('highway-v0', render_mode='human', config=occupancyGrid)
env.reset(seed = 500)
env.render()
for _ in range(100):
    print(env.action_space.sample())
    obs, reward, done, truncate, info = env.step(env.action_space.sample())
    env.render()
    time.sleep(0.1)
    #print(obs)
    if done:
        break
env.close()

[9.848857801796104, 22.47220505424423]
[10.295630140987, 22.47220505424423]
[10.63014581273465, 21.095023109728988]
[11.313708498984761, 21.095023109728988]
[12.041594578792296, 20.591260281974]
[13.601470508735444, 19.697715603592208]
[15.264337522473747, 18.384776310850235]
[16.1245154965971, 18.027756377319946]
[16.492422502470642, 17.88854381999832]
[16.278820596099706, 18.788294228055936]
[15.132745950421556, 18.788294228055936]
[14.560219778561036, 17.88854381999832]
[14.560219778561036, 17.88854381999832]
[13.601470508735444, 17.88854381999832]
[12.649110640673518, 17.88854381999832]
[11.40175425099138, 18.788294228055936]
[9.486832980505138, 18.788294228055936]
[7.615773105863909, 18.384776310850235]
[5.385164807134504, 18.973665961010276]
[3.605551275463989, 18.439088914585774]
