<a href="https://colab.research.google.com/github/denis-chakarov/tlc-project/blob/master/TLC_Project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np

In [None]:
class TrafficLightNetworkEnv():
    """
    Gym-style environment for RL on a network of traffic-light-controlled
    intersections. The class structure may also be inherited from OpenAI Gym.

    Parameters:
        n_time_steps:       int
                            Total number of time steps within each episode.
        n_traffic_nodes:    int
                            Number of traffic nodes (intersections); each node
                            owns 8 traffic lights. Must be one of 1, 2, 4, 6.
                            NOTE: routes are currently only defined for 2 nodes.
        n_tl_queue_places:  int
                            Number of queue places in front of each traffic
                            light; places are numbered 1..n_tl_queue_places.
        n_initial_cars:     int
                            Number of cars randomly placed in the network on reset.
        seed:               int
                            Seed of the RNG (for reproducibility).
    """

    # Number of network exits (possible entrances/destinations) per supported node count.
    _DESTINATIONS_PER_NODE_COUNT = {1: 4, 2: 6, 4: 8, 6: 10}

    def __init__(self, n_time_steps, n_traffic_nodes, n_tl_queue_places, n_initial_cars, seed):
        """
        Initialize the environment, build the route map and reset to an initial state.

        Raises:
            ValueError: if n_traffic_nodes is not a supported node count, or if
                        n_initial_cars cannot fit into the available queue places.
        """
        if n_traffic_nodes not in self._DESTINATIONS_PER_NODE_COUNT:
            raise ValueError(
                f"n_traffic_nodes must be one of "
                f"{sorted(self._DESTINATIONS_PER_NODE_COUNT)}, got {n_traffic_nodes}"
            )

        self.n_time_steps = n_time_steps
        self.n_traffic_nodes = n_traffic_nodes
        self.n_traffic_lights = n_traffic_nodes * 8
        self.n_initial_cars = n_initial_cars
        self.n_tl_queue_places = n_tl_queue_places
        self.n_destinations = self._DESTINATIONS_PER_NODE_COUNT[n_traffic_nodes]

        self.traffic_nodes = {}   # node id -> list of traffic light ids
        self.traffic_lights = {}  # light id -> (tl_state, queue_occupied_places)
        self.cars = {}            # car id -> (entrance, destination, current_tl_id,
                                  #            current_queue_place, current_route_step, [route])

        ### define action space variables
        # car agent actions
        self.car_actions = np.array([0, 1, 2])  # [go straight, turn left, turn right]

        # traffic node (intersection) agent actions
        self.intersection_actions = np.array([0, 1, 2, 3, 4, 5])  # [south-north, east-west, south, north, east, west]

        ### define state space variables
        # state space is defined by each car with values [tl, place, des]

        self.set_seed(seed)
        self.set_route_map()
        self.reset()

    def step(self, action):
        """
        Interface between environment and agent. Performs one step in the environment.
        Parameters:
            action: int
                    the index of the respective action in the action array
        Returns:
            output: (object, float, bool)
                    information provided by the environment about its current state:
                    (state, reward, done)
        Raises:
            NotImplementedError: the transition dynamics are not implemented yet.
        """
        raise NotImplementedError("TrafficLightNetworkEnv.step is not implemented yet")

    def set_seed(self, seed=0):
        """
        Seed numpy's global RNG (used by reset for car placement).
        """
        np.random.seed(seed)

    def reset(self):
        """
        Resets the environment to its initial values: all lights red and
        n_initial_cars cars placed at distinct (traffic light, queue place) slots.
        Returns:
            state:  object
                    the initial state of the environment
        Raises:
            ValueError: if n_initial_cars exceeds the number of distinct slots
                        reachable as a route's first traffic light.
        """
        self.current_step = 0

        ### re-create nodes: node k (1-based) owns traffic lights 8*(k-1)+1 .. 8*k
        self.traffic_nodes = {
            node: list(range((node - 1) * 8 + 1, node * 8 + 1))
            for node in range(1, self.n_traffic_nodes + 1)
        }

        ### initialize traffic lights - all red, empty queues
        self.traffic_lights = {tl: ('red', 0) for tl in range(1, self.n_traffic_lights + 1)}

        # Cars always start at the first light of their route; without this guard
        # the rejection loop below would never terminate once every slot is taken.
        entry_lights = {route[0] for route in self.route_map.values() if route}
        capacity = len(entry_lights) * max(self.n_tl_queue_places, 0)
        if self.n_initial_cars > capacity:
            raise ValueError(
                f"cannot place {self.n_initial_cars} cars: only {capacity} distinct "
                f"(traffic light, queue place) slots are available"
            )

        ### place randomly cars in the network
        self.cars = {}
        occupied = set()
        while len(self.cars) < self.n_initial_cars:
            # a car is defined by (entrance, destination, current_tl_id, current_queue_place, current_route_step, [route])
            entrance, destination = (
                int(x) for x in np.random.choice(range(1, self.n_destinations + 1), 2, replace=False)
            )
            route = self.get_route(entrance, destination)
            tl = route[0]
            # randint's upper bound is exclusive: +1 so the last place can be drawn too.
            place = int(np.random.randint(1, self.n_tl_queue_places + 1))
            if (tl, place) in occupied:
                continue  # discard cars at duplicate places
            occupied.add((tl, place))
            car_id = len(self.cars) + 1
            self.cars[car_id] = (entrance, destination, tl, place, 0, route)
            tl_state, n_queued = self.traffic_lights[tl]
            self.traffic_lights[tl] = (tl_state, n_queued + 1)

        self.state = self.extract_state()

        return self.state

    def render(self):
        """
        Plots the state of the environment. For visualization purposes only.

        """
        pass

    def set_route_map(self):
        """
        Creates and stores all possible routes in the network, keyed by
        (entrance, destination). A route consists of a list of traffic light ids.
        NOTE: only the 2-node network (6 exits) is defined; for any other node
        count the map stays empty.
        """
        self.route_map = {}

        if self.n_traffic_nodes == 2:  # 6 network exits
            self.route_map[(1, 2)] = [2]
            self.route_map[(1, 3)] = [1, 10]
            self.route_map[(1, 4)] = [1, 9]
            self.route_map[(1, 5)] = [1, 9]
            self.route_map[(1, 6)] = [1]

            self.route_map[(2, 1)] = [3]
            self.route_map[(2, 3)] = [4, 10]
            self.route_map[(2, 4)] = [4, 9]
            self.route_map[(2, 5)] = [4, 9]
            self.route_map[(2, 6)] = [3]

            self.route_map[(3, 1)] = [11, 5]
            self.route_map[(3, 2)] = [11, 5]
            self.route_map[(3, 4)] = [12]
            self.route_map[(3, 5)] = [11]
            self.route_map[(3, 6)] = [11, 6]

            self.route_map[(4, 1)] = [13, 5]
            self.route_map[(4, 2)] = [13, 5]
            self.route_map[(4, 3)] = [13]
            self.route_map[(4, 5)] = [14]
            self.route_map[(4, 6)] = [13, 6]

            self.route_map[(5, 1)] = [16, 5]
            self.route_map[(5, 2)] = [16, 5]
            self.route_map[(5, 3)] = [15]
            self.route_map[(5, 4)] = [15]
            self.route_map[(5, 6)] = [16, 6]

            self.route_map[(6, 1)] = [8]
            self.route_map[(6, 2)] = [7]
            self.route_map[(6, 3)] = [7, 10]
            self.route_map[(6, 4)] = [7, 9]
            self.route_map[(6, 5)] = [7, 9]

    def get_route(self, entrance, destination):
        """
        Provides the route from entrance to destination in the network (city).
        Returns:
          output: list of traffic light ids. The list is a copy, so callers may
                  mutate it without corrupting self.route_map.
        Raises:
          KeyError: if no route is defined for (entrance, destination).
        """
        try:
            route = self.route_map[(entrance, destination)]
        except KeyError:
            raise KeyError(
                f"no route from entrance {entrance} to destination {destination} "
                f"in a network with {self.n_traffic_nodes} traffic node(s)"
            ) from None
        return list(route)

    def extract_state(self):
        """
        Extracts the state from self.cars.
        Returns:
            state:  list of (current_tl_id, current_queue_place, destination)
                    tuples, one per car in car-id insertion order.
        """
        return [(car[2], car[3], car[1]) for car in self.cars.values()]
    

In [None]:
# Build a 2-intersection network with 20 queue places per light and 20 cars,
# then show the node -> traffic-light layout and the initial per-car state.
env = TrafficLightNetworkEnv(
    n_time_steps=1,
    n_traffic_nodes=2,
    n_tl_queue_places=20,
    n_initial_cars=20,
    seed=1,
)
print(env.traffic_nodes)
print(env.state)

{1: [1, 2, 3, 4, 5, 6, 7, 8], 2: [9, 10, 11, 12, 13, 14, 15, 16]}
[(11, 10, 2), (11, 8, 6), (13, 6, 1), (3, 15, 6), (1, 1, 6), (11, 2, 1), (13, 14, 3), (7, 8, 5), (7, 18, 3), (3, 14, 6), (16, 16, 6), (14, 10, 5), (1, 7, 3), (7, 15, 5), (15, 14, 3), (7, 7, 3), (13, 8, 6), (11, 5, 2), (13, 19, 3), (16, 8, 2)]


In [None]:
# Display the traffic lights: light id -> (tl_state, queue_occupied_places).
env.traffic_lights

{1: ('red', 2),
 2: ('red', 0),
 3: ('red', 2),
 4: ('red', 0),
 5: ('red', 0),
 6: ('red', 0),
 7: ('red', 4),
 8: ('red', 0),
 9: ('red', 0),
 10: ('red', 0),
 11: ('red', 4),
 12: ('red', 0),
 13: ('red', 4),
 14: ('red', 1),
 15: ('red', 1),
 16: ('red', 2)}

In [None]:
# Display the cars: car id -> (entrance, destination, current_tl_id,
# current_queue_place, current_route_step, [route]).
env.cars

{1: (3, 2, 11, 10, 0, [11, 5]),
 2: (3, 6, 11, 8, 0, [11, 6]),
 3: (4, 1, 13, 6, 0, [13, 5]),
 4: (2, 6, 3, 15, 0, [3]),
 5: (1, 6, 1, 1, 0, [1]),
 6: (3, 1, 11, 2, 0, [11, 5]),
 7: (4, 3, 13, 14, 0, [13]),
 8: (6, 5, 7, 8, 0, [7, 9]),
 9: (6, 3, 7, 18, 0, [7, 10]),
 10: (2, 6, 3, 14, 0, [3]),
 11: (5, 6, 16, 16, 0, [16, 6]),
 12: (4, 5, 14, 10, 0, [14]),
 13: (1, 3, 1, 7, 0, [1, 10]),
 14: (6, 5, 7, 15, 0, [7, 9]),
 15: (5, 3, 15, 14, 0, [15]),
 16: (6, 3, 7, 7, 0, [7, 10]),
 17: (4, 6, 13, 8, 0, [13, 6]),
 18: (3, 2, 11, 5, 0, [11, 5]),
 19: (4, 3, 13, 19, 0, [13]),
 20: (5, 2, 16, 8, 0, [16, 5])}