In [None]:
import numpy as np

from flow.core.params import InitialConfig
from flow.core.params import TrafficLightParams
from flow.networks.base import Network

# Default network parameters; IntersectionNetwork requires every key below
# to be present in net_params.additional_params.
ADDITIONAL_NET_PARAMS = dict(
    radius_intersection=15,  # radius of the central intersection node
    lanes=3,                 # number of lanes on every edge
    speed_limit=13.9,        # speed limit applied to all edges
    resolution=40,           # resolution of the curved portions
)

class IntersectionNetwork(Network):
    """Four-way intersection: a central node joined to four terminal nodes.

    Requires from net_params:

    * **radius_intersection** : radius of the central intersection node
    * **resolution** : number of nodes resolution in the circular portions
    * **lanes** : number of lanes in the network
    * **speed_limit** : max speed of vehicles in the network

    Every arm is ``intersection_len`` (100) long. The turning connections
    below use lanes 0, 1 and 2 explicitly, so the network assumes
    ``lanes >= 3``.

    Usage
    -----
    >>> from flow.core.params import NetParams
    >>> from flow.core.params import VehicleParams
    >>> from flow.core.params import InitialConfig
    >>> from road_network import IntersectionNetwork
    >>>
    >>> network = IntersectionNetwork(
    >>>     name='intersection',
    >>>     vehicles=VehicleParams(),
    >>>     net_params=NetParams(
    >>>         additional_params={
    >>>             'radius_intersection': 15,
    >>>             'lanes': 3,
    >>>             'speed_limit': 13.9,
    >>>             'resolution': 40
    >>>         },
    >>>     )
    >>> )
    """

    # (edge id, from node, to node), in declaration order.
    _EDGES = (
        ("b_c", "bottom", "center"),
        ("c_t", "center", "top"),
        ("r_c", "right", "center"),
        ("c_l", "center", "left"),
        ("t_c", "top", "center"),
        ("c_r", "center", "right"),
        ("l_c", "left", "center"),
        ("c_b", "center", "bottom"),
    )

    # (incoming edge, outgoing edge, lane) at node "center". Given the node
    # coordinates in specify_nodes, lane 0 feeds the right turn, lane 1 goes
    # straight and lane 2 feeds the left turn on every approach.
    _CONNECTIONS = (
        ("b_c", "c_t", 1), ("b_c", "c_r", 0), ("b_c", "c_l", 2),
        ("t_c", "c_b", 1), ("t_c", "c_l", 0), ("t_c", "c_r", 2),
        ("r_c", "c_l", 1), ("r_c", "c_t", 0), ("r_c", "c_b", 2),
        ("l_c", "c_r", 1), ("l_c", "c_b", 0), ("l_c", "c_t", 2),
    )

    def __init__(self,
                 name,
                 vehicles,
                 net_params,
                 initial_config=None,
                 traffic_lights=None):
        """Validate net_params and build the network.

        ``initial_config`` / ``traffic_lights`` default to fresh
        ``InitialConfig()`` / ``TrafficLightParams()`` instances per network
        (a default-argument instance would be shared by every network).

        Raises:
            KeyError: if a key of ADDITIONAL_NET_PARAMS is missing.
        """
        if initial_config is None:
            initial_config = InitialConfig()
        if traffic_lights is None:
            traffic_lights = TrafficLightParams()

        for p in ADDITIONAL_NET_PARAMS.keys():
            if p not in net_params.additional_params:
                raise KeyError('Network parameter "{}" not supplied'.format(p))

        # Must be set before super().__init__, which calls specify_nodes /
        # specify_edges that read it.
        self.intersection_len = 100

        super().__init__(name, vehicles, net_params, initial_config,
                         traffic_lights)

    def specify_nodes(self, net_params):
        """See parent class."""
        r = net_params.additional_params["radius_intersection"]
        d = self.intersection_len

        nodes = [{
            "id": "center",
            "x": 0,
            "y": 0,
            "radius": r,
            "type": "priority"
        }]
        for node_id, x, y in (("right", d, 0), ("top", 0, d),
                              ("left", -d, 0), ("bottom", 0, -d)):
            nodes.append({"id": node_id, "x": x, "y": y, "type": "priority"})

        return nodes

    def specify_edges(self, net_params):
        """See parent class."""
        return [{
            "id": edge_id,
            "type": "edgeType",
            "from": src,
            "to": dst,
            "length": self.intersection_len
        } for edge_id, src, dst in self._EDGES]

    def specify_types(self, net_params):
        """See parent class."""
        lanes = net_params.additional_params["lanes"]
        speed_limit = net_params.additional_params["speed_limit"]
        types = [{
            "id": "edgeType",
            "numLanes": lanes,
            "speed": speed_limit
        }]

        return types

    def specify_connections(self, net_params):
        """See parent class."""
        conn = [{"from": src,
                 "to": dst,
                 "fromLane": str(lane),
                 "toLane": str(lane)}
                for src, dst, lane in self._CONNECTIONS]
        return {"center": conn}

    def specify_routes(self, net_params):
        """See parent class."""
        rts = {
            "r_c":
                [(["r_c", "c_l"], 1/3), (["r_c", "c_t"], 1/3),
                    (["r_c", "c_b"], 1/3)],
            "b_c":
                [(["b_c", "c_l"], 1/3), (["b_c", "c_t"], 1/3),
                    (["b_c", "c_r"], 1/3)],
            "t_c":
                [(["t_c", "c_b"], 1/3), (["t_c", "c_l"], 1/3),
                    (["t_c", "c_r"], 1/3)],
            "l_c":
                [(["l_c", "c_r"], 1/3), (["l_c", "c_t"], 1/3),
                    (["l_c", "c_b"], 1/3)],
            "c_r":
                ["c_r"],
            "c_l":
                ["c_l"],
            "c_t":
                ["c_t"],
            "c_b":
                ["c_b"],
            "human_0":
                ["r_c", "c_t"],
            "human_1":
                ["t_c", "c_b"],
            "human_2":
                ["l_c", "c_t"]
            }

        return rts

In [None]:
def order_vehicles(state):
    """Return the vehicle ids of ``state`` sorted by distance from the origin.

    Args:
        state: mapping ``veh_id -> features`` where ``features[0]`` and
            ``features[1]`` are the vehicle's x and y coordinates.

    Returns:
        List of every key of ``state``, nearest first. Exact ties are broken
        randomly (via the global NumPy RNG) so no vehicle is systematically
        favoured. Unlike keying a dict by float distance, equal distances
        never drop a vehicle.
    """
    sort_keys = {
        veh: (np.hypot(feat[0], feat[1]), np.random.random_sample())
        for veh, feat in state.items()
    }
    return sorted(state, key=sort_keys.__getitem__)

In [None]:
from flow.envs.base_gpt import Env
import torch
from gym.spaces.box import Box
from gym.spaces import MultiBinary

#import numpy as np

# One-hot encoding per lane index.
lanes = {
    0: [0.0, 0.0, 1.0],
    1: [0.0, 1.0, 0.0],
    2: [1.0, 0.0, 0.0],
}

# One-hot encoding of the (incoming, outgoing) edge pair a vehicle follows.
# With the node layout of the intersection network, [1, 0, 0] is the right
# turn, [0, 1, 0] straight on and [0, 0, 1] the left turn for every approach.
ways = {
    # approaching from the top
    ('t_c', 'c_l'): [1.0, 0.0, 0.0],
    ('t_c', 'c_b'): [0.0, 1.0, 0.0],
    ('t_c', 'c_r'): [0.0, 0.0, 1.0],
    # approaching from the right
    ('r_c', 'c_t'): [1.0, 0.0, 0.0],
    ('r_c', 'c_l'): [0.0, 1.0, 0.0],
    ('r_c', 'c_b'): [0.0, 0.0, 1.0],
    # approaching from the bottom
    ('b_c', 'c_r'): [1.0, 0.0, 0.0],
    ('b_c', 'c_t'): [0.0, 1.0, 0.0],
    ('b_c', 'c_l'): [0.0, 0.0, 1.0],
    # approaching from the left
    ('l_c', 'c_b'): [1.0, 0.0, 0.0],
    ('l_c', 'c_r'): [0.0, 1.0, 0.0],
    ('l_c', 'c_t'): [0.0, 0.0, 1.0],
}

# One-hot encoding of the incoming edge (the branch of approach).
queues = {
    't_c': [1.0, 0.0, 0.0, 0.0],
    'b_c': [0.0, 1.0, 0.0, 0.0],
    'r_c': [0.0, 0.0, 1.0, 0.0],
    'l_c': [0.0, 0.0, 0.0, 1.0],
}

ADDITIONAL_ENV_PARAMS = dict(
    max_speed=13.9,  # maximum velocity for autonomous vehicles, in m/s
)


class SpeedEnv(Env):
    """Fully observed velocity environment.

    This environment used to train autonomous vehicles to improve traffic flows
    when velocity actions are permitted by the rl agent.

    Required from env_params:

    * max_speed: maximum speed for autonomous vehicles, in m/s (upper bound
      of the action space)

    States
        The state consists of (for each vehicle in the network):
        - relative position to the center of the intersection on the x-axis
        - relative position to the center of the intersection on the y-axis
        - vehicle speed
        - vehicle orientation angle
        - lane of approach (one-hot)
        - way the vehicle will follow (one-hot)
        - intersection branch through which the vehicle is approaching (one-hot)
        i.e. 14 features per vehicle, each clipped to [-1, 1].

    Actions
        Actions are a list of speeds for each rl vehicle

    Rewards
        One reward per requested vehicle:
        - -100 if the vehicle collided
        - +100 if the intersection was crossed (or the vehicle has already
          left the network)
        - -0.25 otherwise, to encourage crossing as fast as possible

    Termination
        A rollout is terminated if the time horizon is reached or if two
        vehicles collide into one another.
    """

    # Number of features per vehicle produced by get_state.
    _N_FEAT = 14

    def __init__(self, env_params, sim_params, network, simulator='traci'):
        for p in ADDITIONAL_ENV_PARAMS.keys():
            if p not in env_params.additional_params:
                raise KeyError(
                    'Environment parameter \'{}\' not supplied'.format(p))

        super().__init__(env_params, sim_params, network, simulator)

    @property
    def action_space(self):
        """See class definition."""
        num_vehicles = len(self.k.vehicle.get_ids())
        return Box(
            low=0,
            high=self.env_params.additional_params['max_speed'],
            shape=(num_vehicles, ),
            dtype=np.float32)

    @property
    def observation_space(self):
        """See class definition."""
        vehs = len(self.k.vehicle.get_ids())
        obs_space = Box(low=-1, high=1, shape=(vehs, 14 * vehs))

        return obs_space

    def _apply_rl_actions(self, rl_actions, vehs):
        """See class definition."""
        self.k.vehicle.apply_velocity(vehs, rl_actions)

    def compute_reward(self, vehs, **kwargs):
        """Return ``(rewards, dones)`` float tensors with one entry per id in ``vehs``."""
        ids = self.k.vehicle.get_ids()
        coll_veh = self.k.simulation.collided_vehicles()
        succ_veh = self.k.simulation.successful_vehicles()

        rewards = []
        dones = []
        for veh in vehs:
            if veh not in ids:
                # vehicle already left the network: counts as a success
                reward, done = 100.0, 1.0
            elif veh in coll_veh:
                reward, done = -100.0, 1.0
            elif veh in succ_veh:
                reward, done = 100.0, 1.0
            else:
                reward, done = -0.25, 0.0
            rewards.append(reward)
            dones.append(done)

        return torch.tensor(rewards), torch.tensor(dones)

    def _vehicle_features(self, veh):
        """Return the 14 normalized features of one vehicle as a list."""
        pos = self.k.vehicle.get_2d_position(veh)
        half_speed = 13.9 / 2
        obs = [
            np.clip((pos[0] - 100) / 100, -1, 1),
            np.clip((pos[1] - 100) / 100, -1, 1),
            np.clip((self.k.vehicle.get_speed(veh) - half_speed) / half_speed, -1, 1),
            np.clip((self.k.vehicle.get_orientation(veh)[2] - 180) / 180, -1, 1),
        ]

        route = self.k.vehicle.get_route(veh)
        if route == '':  # just to fix a simulator bug
            lane = [0.0, 0.0, 0.0]
            way = [0.0, 0.0, 0.0]
            queue = [0.0, 0.0, 0.0, 0.0]
        else:
            way = ways[route]
            lane = [way[2], way[1], way[0]]
            queue = queues[route[0]]

        return obs + lane + way + queue

    def get_state(self):
        """Build the (V+A, 14*(V+A)) observation matrix and the vehicle order.

        Row k holds vehicle k's features followed by every other vehicle's
        features, with vehicles ordered by ``order_vehicles``. The matrix is
        zero-padded with one extra row/feature block per vehicle that arrived
        this step (A = number of arrived vehicles).

        Returns:
            (state, ord_vehs): float32 tensor on cuda if available else cpu,
            and the list of vehicle ids matching the first V rows.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        n_feat = self._N_FEAT

        state_dict = {veh: self._vehicle_features(veh)
                      for veh in self.k.vehicle.get_ids()}

        ord_vehs = order_vehicles(state_dict)
        num_vehs = len(ord_vehs)

        rows = []
        for k, ego in enumerate(ord_vehs):
            row = list(state_dict[ego])
            for q, other in enumerate(ord_vehs):
                if q != k:
                    row.extend(state_dict[other])
            rows.append(row)
        state = torch.tensor(rows, dtype=torch.float32).reshape(
            num_vehs, n_feat * num_vehs)

        num_arrived = self.k.vehicle.get_num_arrived()
        if num_arrived > 0:
            total = num_vehs + num_arrived
            padded = torch.zeros((total, n_feat * total))
            padded[:num_vehs, :n_feat * num_vehs] = state
            state = padded

        state = state.to(torch.float32)

        return state.to(device), ord_vehs

In [None]:
# Imported under an alias: an unaliased import would rebind
# ADDITIONAL_ENV_PARAMS and hide the SpeedEnv parameters ('max_speed'),
# making SpeedEnv.__init__ raise KeyError.
from flow.envs.ring.accel import ADDITIONAL_ENV_PARAMS as ACCEL_ENV_PARAMS
from flow.core.params import EnvParams
from flow.envs.ring.accel import AccelEnv

# SpeedEnv's parameters take precedence over the accel defaults.
env_params = EnvParams(
    additional_params={**ACCEL_ENV_PARAMS, **ADDITIONAL_ENV_PARAMS})

from flow.core.params import SumoParams

random_seed = np.random.choice(1000)
sim_params = SumoParams(sim_step=0.25, render=False, seed=random_seed)

from flow.core.params import TrafficLightParams

traffic_lights = TrafficLightParams()

from flow.core.params import InitialConfig

initial_config = InitialConfig()

from flow.core.params import VehicleParams

vehicles = VehicleParams()

from flow.controllers.rlcontroller import RLController
from flow.controllers.routing_controllers import ContinuousRouter
from flow.core.params import SumoCarFollowingParams

# All vehicles enter through inflows, so none are placed initially.
vehicles.add("rl",
             acceleration_controller=(RLController, {}),
             routing_controller=(ContinuousRouter, {}),
             car_following_params=SumoCarFollowingParams(
                speed_mode="aggressive"),
             num_vehicles=0)

from flow.core.params import InFlows

inflow = InFlows()

# One identical probabilistic inflow on each incoming edge.
for in_edge in ("t_c", "b_c", "r_c", "l_c"):
    inflow.add(veh_type="rl",
               edge=in_edge,
               depart_lane="best",
               probability=1/18)

In [None]:
from flow.core.params import NetParams

net_params = NetParams(
    additional_params=ADDITIONAL_NET_PARAMS,
    inflows=inflow,
)

# Experiment description consumed by Flow's experiment/registry utilities.
flow_params = {
    'exp_tag': 'test',
    'env_name': SpeedEnv,
    'network': IntersectionNetwork,
    'simulator': 'traci',
    'sim': sim_params,
    'env': env_params,
    'net': net_params,
    'veh': vehicles,
    'initial': initial_config,
    'tls': traffic_lights,
}

# Episode length, in simulation steps.
flow_params['env'].horizon = 1200

In [None]:
import random
import numpy as np

class PrioritizedReplayBuffer(object):
    """Proportional prioritized experience replay backed by segment trees.

    Transitions live in a ring buffer of ``capacity`` slots. Priorities are
    kept in two array-backed binary trees (sum and min) whose leaves are
    stored at indices ``capacity .. 2*capacity - 1`` and whose root is at 1.
    A power-of-two capacity keeps the trees perfectly balanced.

    Args:
        capacity: maximum number of stored transitions.
        alpha: priority exponent (0 = uniform sampling).
        beta: initial importance-sampling exponent, annealed towards 1.
    """

    def __init__(self, capacity, alpha=0.6, beta=0.4):
        self.capacity = capacity
        self.alpha = alpha
        self.beta = beta  # importance-sampling, from initial value increasing to 1, often 0.4
        # small amount to avoid zero priority; NOTE(review): not applied
        # inside this class, presumably the caller adds it — confirm.
        self.epsilon = 0.01
        self.beta_increment_per_sampling = 1e-4  # annealing the bias, often 1e-3

        # segment binary trees to take sum and find minimum over a range
        self.priority_sum = [0 for _ in range(2 * self.capacity)]
        self.priority_min = [float('inf') for _ in range(2 * self.capacity)]

        self.max_priority = 1.  # current max priority to be assigned to new transitions

        # obs / next_obs are variable-length nested lists, so they are kept
        # in fixed-size Python lists indexed by the same ring slot as the
        # numeric arrays (appending would desynchronize them after wrap-around).
        self.data = {
            'obs': [None] * capacity,
            'action': np.zeros((capacity, 1)),
            'reward': np.zeros((capacity, 1)),
            'next_obs': [None] * capacity,
            'not_done': np.zeros((capacity, 1)),
        }

        self.next_idx = 0
        self.size = 0

    def add(self, obs, action, reward, next_obs, done):
        """Store one transition (torch tensors) in the next ring slot with max priority."""
        idx = self.next_idx

        self.data['obs'][idx] = obs.detach().cpu().tolist()
        self.data['action'][idx] = action.numpy()
        self.data['reward'][idx] = reward.numpy()
        self.data['next_obs'][idx] = next_obs.detach().cpu().tolist()
        self.data['not_done'][idx] = 1. - done.numpy()

        self.next_idx = (idx + 1) % self.capacity
        self.size = min(self.capacity, self.size + 1)

        priority_alpha = self.max_priority ** self.alpha  # new samples get max_priority

        self._set_priority_min(idx, priority_alpha)
        self._set_priority_sum(idx, priority_alpha)

    def _set_priority_min(self, idx, priority_alpha):
        """Set leaf ``idx`` of the min-tree and refresh its ancestors."""
        idx += self.capacity  # leaf of the binary tree
        self.priority_min[idx] = priority_alpha

        while idx >= 2:  # walk up to the root
            idx //= 2
            self.priority_min[idx] = min(self.priority_min[2 * idx],
                                         self.priority_min[2 * idx + 1])

    def _set_priority_sum(self, idx, priority):
        """Set leaf ``idx`` of the sum-tree and refresh its ancestors."""
        idx += self.capacity  # leaf of the binary tree
        self.priority_sum[idx] = priority

        while idx >= 2:  # walk up to the root
            idx //= 2
            self.priority_sum[idx] = (self.priority_sum[2 * idx]
                                      + self.priority_sum[2 * idx + 1])

    def _sum(self):
        return self.priority_sum[1]  # the root node keeps the sum of all values

    def _min(self):
        return self.priority_min[1]  # the root node keeps the min of all values

    def find_prefix_sum_idx(self, prefix_sum):
        """Descend the sum-tree and return the slot whose cumulative mass covers ``prefix_sum``."""
        idx = 1
        while idx < self.capacity:
            if self.priority_sum[idx * 2] >= prefix_sum:  # target lies in the left subtree
                idx = 2 * idx
            else:  # otherwise go right, discounting the left subtree's mass
                prefix_sum -= self.priority_sum[idx * 2]
                idx = 2 * idx + 1

        return idx - self.capacity  # we are at the leaf node

    def sample(self, batch_size):
        """Sample ``batch_size`` transitions proportionally to their priority.

        Returns:
            dict with 'weights' (normalized importance-sampling weights),
            'indexes' (ring slots, for update_priorities) and one entry per
            stored field ('obs'/'next_obs' as lists, the rest as arrays).

        Raises:
            ValueError: if the buffer is empty.
        """
        if self.size == 0:
            raise ValueError('cannot sample from an empty replay buffer')

        samples = {
            'weights': np.zeros(shape=(batch_size), dtype=np.float32),
            'indexes': np.zeros(shape=(batch_size), dtype=np.int32)
        }

        self.beta = min(1., self.beta + self.beta_increment_per_sampling)  # max = 1

        total = self._sum()
        for i in range(batch_size):
            idx = self.find_prefix_sum_idx(random.random() * total)
            # Float drift can steer the descent onto an unfilled (zero
            # priority) leaf; filled slots are always 0 .. size-1.
            samples['indexes'][i] = min(idx, self.size - 1)

        prob_min = self._min() / total
        max_weight = (prob_min * self.size) ** (-self.beta)

        for i, idx in enumerate(samples['indexes']):
            prob = self.priority_sum[idx + self.capacity] / total
            weight = (prob * self.size) ** (-self.beta)
            samples['weights'][i] = weight / max_weight

        idxs = samples['indexes']
        for k, v in self.data.items():
            if k in ('obs', 'next_obs'):
                samples[k] = [v[i] for i in idxs]
            else:
                samples[k] = v[idxs]

        return samples

    def update_priorities(self, indexes, priorities):
        """Assign new (un-exponentiated) priorities to the given slots."""
        for idx, priority in zip(indexes, priorities):
            self.max_priority = max(self.max_priority, priority)

            priority_alpha = priority ** self.alpha

            self._set_priority_min(idx, priority_alpha)
            self._set_priority_sum(idx, priority_alpha)

    def is_full(self):
        return self.capacity == self.size

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Actor(nn.Module):
    """Deterministic policy: LSTM encoder over vehicle states plus an MLP head.

    One observation is the concatenation of per-vehicle feature vectors of
    length ``ego_state_dim``, ego vehicle first. The vehicles are fed one at
    a time to an ``LSTMCell``; its final hidden state is concatenated with
    the ego features and mapped through four ReLU layers to a ``tanh``
    output in [-1, 1].

    Args:
        ego_state_dim: number of features per vehicle.
        action_dim: size of the action output.
        max_action: stored on the instance; the output is NOT scaled by it
            here (presumably the caller rescales — verify).
    """

    def __init__(self, ego_state_dim, action_dim, max_action):
        super(Actor, self).__init__()

        self.ego_state_dim = ego_state_dim

        self.lstm = nn.LSTMCell(ego_state_dim, 256)
        self.fc1 = nn.Linear(256 + ego_state_dim, 1024)
        self.fc2 = nn.Linear(1024, 1024)
        self.fc3 = nn.Linear(1024, 512)
        self.fc4 = nn.Linear(512, 256)
        self.pi = nn.Linear(256, action_dim)

        self.max_action = max_action

        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.to(self.device)

    def _encode(self, seq):
        """Run the LSTM over ``seq`` (num_vehs, ego_state_dim); return cat(ego, h_final)."""
        hx = torch.zeros(self.lstm.hidden_size, device=self.device, dtype=torch.float32)
        cx = torch.zeros_like(hx)
        for veh_state in seq:
            hx, cx = self.lstm(veh_state, (hx, cx))
        return torch.cat((seq[0], hx))

    def _head(self, x):
        """Map encoded features (B, 256 + ego_state_dim) to actions in [-1, 1]."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))
        return torch.tanh(self.pi(x))

    def forward_train(self, state):
        """Evaluate a training batch.

        ``state`` is a list of batch elements; each element is a flat list of
        ``ego_state_dim * num_veh`` features for that timestep (num_veh may
        differ between elements). Returns a (B, action_dim) tensor.
        """
        encoded = []
        for element in state:
            seq = torch.as_tensor(element, device=self.device, dtype=torch.float32)
            encoded.append(self._encode(seq.view(-1, self.ego_state_dim)))
        return self._head(torch.stack(encoded, dim=0))

    def forward(self, state):
        """Evaluate one timestep.

        ``state`` is a (V, ego_state_dim * V) tensor, one row per vehicle.
        Returns a (V, action_dim) tensor, or an empty tensor when the row
        width holds no vehicle.
        """
        num_vehs = state.shape[1] // self.ego_state_dim
        if num_vehs == 0:
            return torch.tensor([], device=self.device)

        encoded = [self._encode(row.view(num_vehs, self.ego_state_dim)) for row in state]
        return self._head(torch.stack(encoded, dim=0))
    
class Critic(nn.Module):
    """Twin Q-networks (TD3 style), each an LSTM encoder plus an MLP head.

    Each branch runs its own ``LSTMCell`` over the per-vehicle feature
    vectors (length ``ego_state_dim``, ego first), concatenates
    ``(ego features, action, final hidden state)`` and maps that through four
    ReLU layers to a scalar Q-value. Layer attribute names are unchanged so
    saved state dicts stay loadable.
    """

    def __init__(self, ego_state_dim, action_dim):
        super(Critic, self).__init__()

        self.ego_state_dim = ego_state_dim

        self.lstm_1 = nn.LSTMCell(ego_state_dim, 256)
        self.fc1_1 = nn.Linear(256 + ego_state_dim + action_dim, 1024)
        self.fc2_1 = nn.Linear(1024, 1024)
        self.fc3_1 = nn.Linear(1024, 512)
        self.fc4_1 = nn.Linear(512, 256)
        self.q_1 = nn.Linear(256, 1)

        self.lstm_2 = nn.LSTMCell(ego_state_dim, 256)
        self.fc1_2 = nn.Linear(256 + ego_state_dim + action_dim, 1024)
        self.fc2_2 = nn.Linear(1024, 1024)
        self.fc3_2 = nn.Linear(1024, 512)
        self.fc4_2 = nn.Linear(512, 256)
        self.q_2 = nn.Linear(256, 1)

        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.to(self.device)

    def _to_sequences(self, state):
        """Turn each flat batch element into a (num_vehs, ego_state_dim) tensor."""
        return [
            torch.as_tensor(element, device=self.device, dtype=torch.float32)
            .view(-1, self.ego_state_dim)
            for element in state
        ]

    def _branch(self, seqs, action, lstm, layers, head):
        """Evaluate one Q branch over pre-converted sequences; returns (B, 1)."""
        encoded = []
        for i, seq in enumerate(seqs):
            hx = torch.zeros(lstm.hidden_size, device=self.device, dtype=torch.float32)
            cx = torch.zeros_like(hx)
            for veh_state in seq:
                hx, cx = lstm(veh_state, (hx, cx))
            # (ego features, action, final LSTM output)
            encoded.append(torch.cat((seq[0], action[i], hx)))

        x = torch.stack(encoded, dim=0)
        for layer in layers:
            x = F.relu(layer(x))
        return head(x)

    def forward(self, state, action):
        """Return ``(Q1, Q2)``, each of shape (B, 1).

        - state is a list of batch elements; each is a flat list of
          ``ego_state_dim * num_veh`` features for that timestep
        - action is indexable per batch element (e.g. a (B, action_dim) tensor)
        """
        seqs = self._to_sequences(state)
        out_1 = self._branch(seqs, action, self.lstm_1,
                             (self.fc1_1, self.fc2_1, self.fc3_1, self.fc4_1),
                             self.q_1)
        out_2 = self._branch(seqs, action, self.lstm_2,
                             (self.fc1_2, self.fc2_2, self.fc3_2, self.fc4_2),
                             self.q_2)
        return out_1, out_2

    def Q1(self, state, action):
        """Return only the first Q branch, shape (B, 1)."""
        return self._branch(self._to_sequences(state), action, self.lstm_1,
                            (self.fc1_1, self.fc2_1, self.fc3_1, self.fc4_1),
                            self.q_1)

In [None]:
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class TD3(object):
    """Twin Delayed DDPG (TD3) agent wrapping the notebook's ``Actor`` and ``Critic``.

    The agent holds online and target copies of both networks, their Adam
    optimizers, the prioritized-replay training loop, and checkpoint save/load.

    Args:
        state_dim: forwarded to ``Actor`` and ``Critic``.
        action_dim: forwarded to ``Actor`` and ``Critic``.
        max_action: forwarded to ``Actor`` and stored. Target actions in
            :meth:`train` are clamped to [-1, 1], so the networks appear to
            work in a normalized action space.
        discount: reward discount factor.
        tau: Polyak coefficient for the target-network soft updates.
        policy_noise: standard deviation of the Gaussian noise added to
            target actions.
        noise_clip: absolute bound on that target-policy noise.
        policy_freq: the actor and the target networks are updated once
            every ``policy_freq`` critic steps.
        filename: path prefix used by :meth:`save` / :meth:`load`.
    """

    def __init__(
        self,
        state_dim,
        action_dim,
        max_action,
        discount=0.99,
        tau=4e-3,
        policy_noise=0.2,
        noise_clip=0.3,
        policy_freq=2,
        filename='LSTM_AIM'
    ):

        self.actor = Actor(state_dim, action_dim, max_action)
        self.actor_target = copy.deepcopy(self.actor)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=1e-5)

        self.critic = Critic(state_dim, action_dim)
        self.critic_target = copy.deepcopy(self.critic)
        # Only the critic gets a small L2 penalty.
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=1e-4, weight_decay=1e-6)

        self.max_action = max_action
        self.discount = discount
        self.tau = tau
        self.policy_noise = policy_noise
        self.noise_clip = noise_clip
        self.policy_freq = policy_freq
        self.filename = filename

        # Number of train() calls so far (incremented once per call).
        self.total_it = 0

    def select_action(self, state):
        """Return the actor's action for ``state`` as a detached CPU tensor."""
        # no_grad: inference only, so no autograd graph needs to be built.
        with torch.no_grad():
            return self.actor.forward(state).detach().cpu()

    def mse(self, expected, targets, weights):
        """Importance-sampling-weighted mean squared error.

        Returns ``mean(weights * (expected - targets) ** 2)`` over all
        elements of the broadcast product.
        """
        td_error = expected - targets
        weighted_squared_error = weights * td_error * td_error
        return torch.sum(weighted_squared_error) / torch.numel(weighted_squared_error)

    @staticmethod
    def _soft_update(source, target, tau):
        """Polyak-average ``source``'s parameters into ``target`` in place."""
        for param, target_param in zip(source.parameters(), target.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

    def train(self, replay_buffer, batch_size=128, iterations=200):
        """Run ``iterations`` TD3 critic steps with delayed actor/target updates.

        ``replay_buffer.sample(batch_size)`` must return a dict with the keys
        'obs', 'action', 'next_obs', 'reward', 'not_done', 'weights' and
        'indexes'. ``replay_buffer.update_priorities(indexes, priorities)`` is
        called after every critic step with the absolute TD errors of Q1.

        The actor and both target networks are updated when
        ``i % policy_freq == 0`` and ``i > 0``.

        Args:
            replay_buffer: prioritized replay buffer (see above).
            batch_size: transitions per sampled batch.
            iterations: number of sampled batches / critic steps per call.
        """
        self.total_it += 1
        device = self.actor.device

        for i in range(iterations):
            # Sample replay buffer
            batch = replay_buffer.sample(batch_size)
            state = batch['obs']
            next_state = batch['next_obs']
            action = torch.as_tensor(batch['action'], device=device, dtype=torch.float32)
            reward = torch.as_tensor(batch['reward'], device=device, dtype=torch.float32)
            # Explicit float dtype: a bool or float64 buffer column would
            # otherwise leak into the target's dtype.
            not_done = torch.as_tensor(batch['not_done'], device=device, dtype=torch.float32)
            weights = torch.as_tensor(batch['weights'], device=device, dtype=torch.float32)

            with torch.no_grad():
                # Target policy smoothing: clipped Gaussian noise on the target action.
                noise = (
                    torch.randn_like(action) * self.policy_noise
                ).clamp(-self.noise_clip, self.noise_clip)

                next_action = (
                    self.actor_target.forward_train(next_state) + noise
                ).clamp(-1, 1)

                # Clipped double-Q target.
                target_Q1, target_Q2 = self.critic_target(next_state, next_action)
                target_Q = torch.min(target_Q1, target_Q2)
                # Align per-sample scalars with the critic output. A (B,) reward
                # would otherwise broadcast against (B, 1) Q-values into (B, B).
                reward = reward.reshape(target_Q.shape)
                not_done = not_done.reshape(target_Q.shape)
                target_Q = reward + not_done * self.discount * target_Q

            # Get current Q estimates
            current_Q1, current_Q2 = self.critic(state, action)
            weights = weights.reshape(current_Q1.shape)

            # Compute critic loss
            critic_loss = self.mse(current_Q1, target_Q, weights) + self.mse(current_Q2, target_Q, weights)

            # Optimize the critic
            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            self.critic_optimizer.step()

            # New priorities = |TD error| of the first critic head.
            errors1 = np.abs((current_Q1 - target_Q).detach().cpu().numpy())
            replay_buffer.update_priorities(batch['indexes'], errors1)

            # Delayed policy updates
            if i % self.policy_freq == 0 and i > 0:

                # Compute actor loss (maximize Q1 of the actor's own actions)
                actor_loss = -self.critic.Q1(state, self.actor.forward_train(state)).mean()

                # Optimize the actor
                self.actor_optimizer.zero_grad()
                actor_loss.backward()
                self.actor_optimizer.step()

                # Update the frozen target models
                self._soft_update(self.critic, self.critic_target, self.tau)
                self._soft_update(self.actor, self.actor_target, self.tau)

    def save(self):
        """Write actor/critic weights and optimizer states under ``self.filename``."""
        filename = self.filename
        torch.save(self.critic.state_dict(), filename + "_critic")
        torch.save(self.critic_optimizer.state_dict(), filename + "_critic_optimizer")

        torch.save(self.actor.state_dict(), filename + "_actor")
        torch.save(self.actor_optimizer.state_dict(), filename + "_actor_optimizer")

    def load(self):
        """Restore what :meth:`save` wrote; the targets become copies of the online nets.

        ``map_location`` maps tensors onto each network's ``device``, so a
        checkpoint saved on a GPU can be loaded on a CPU-only machine.
        """
        filename = self.filename
        critic_device = self.critic.device
        actor_device = self.actor.device

        self.critic.load_state_dict(torch.load(filename + "_critic", map_location=critic_device))
        self.critic_optimizer.load_state_dict(
            torch.load(filename + "_critic_optimizer", map_location=critic_device))
        self.critic_target = copy.deepcopy(self.critic)

        self.actor.load_state_dict(torch.load(filename + "_actor", map_location=actor_device))
        self.actor_optimizer.load_state_dict(
            torch.load(filename + "_actor_optimizer", map_location=actor_device))
        self.actor_target = copy.deepcopy(self.actor)

In [None]:
def trim(state, num_features=14):
    """Strip trailing all-zero (padding) vehicle rows from a state matrix.

    The training loop describes states as (V, F*V), with F features per
    vehicle. Each time a trailing row is removed, the last ``num_features``
    columns are dropped too, so the column count stays a multiple of the
    remaining row count.

    A row counts as padding only if every entry is zero. The previous
    ``sum == 0`` test also discarded real vehicles whose (possibly negative)
    features happened to sum to zero.

    Args:
        state: 2-D tensor of shape (V, num_features * V); may be empty.
        num_features: number of columns describing one vehicle.

    Returns:
        The trimmed tensor (a slice of ``state``), possibly with zero rows.
    """
    while state.shape[0] > 0 and not state[-1].any():
        state = state[:-1, :state.shape[1] - num_features]
    return state

In [None]:
from flow.utils.registry import make_create_env
import numpy as np
import torch

# Build the Flow environment from ``flow_params`` (defined in an earlier cell).
# make_create_env yields a creator plus the env name; only the creator is kept.
create_env, _ = make_create_env(flow_params)
env = create_env()

# Episode budget and running bookkeeping.
num_eps = 1_000_000
max_ep_steps = env.env_params.horizon  # per-episode step cap from the env config
total_steps = 0
returns_list, ep_steps_list = [], []

# 14 features per vehicle, one scalar action per vehicle. The training loop maps
# a normalized action a in [-1, 1] to a * max_action + max_action.
state_dim, action_dim = 14, 1
max_action = 13.9 / 2

memory = PrioritizedReplayBuffer(2 ** 20)
aim = TD3(
    state_dim=state_dim,
    action_dim=action_dim,
    max_action=max_action,
    discount=0.99,
    tau=4e-3,
    policy_noise=0.2,
    noise_clip=0.5,
    policy_freq=2,
    filename='LSTM_AIM',
)

def rl_actions(state):
    """Random warm-up actions, one per vehicle row of ``state``.

    Draws standard-normal samples clipped to [-1, 1], the normalized action
    range the training loop expects. Sampling happens directly on the CPU: the
    result was always returned on the CPU, and the former hard-coded
    ``device="cuda"`` crashed the warm-up episode on machines without a GPU.

    Args:
        state: tensor whose first dimension is the number of vehicles.

    Returns:
        A CPU float tensor of shape (state.shape[0],).
    """
    num = state.shape[0]
    return torch.randn((num,)).clamp(-1, 1)

import os

# np.save below writes into results/; create it up front so the first
# episode does not end in FileNotFoundError on a fresh checkout.
os.makedirs('results', exist_ok=True)

try:
    for i in range(num_eps):
        returns = 0
        ep_steps = 0

        # state is a 2-dim tensor
        state = env.reset()  # (V, F*V) where V: number of vehicles and F: number of features of each vehicle

        for j in range(max_ep_steps):

            # actions: (V,) ordered tensor
            if i > 0:
                # Policy action plus clipped Gaussian exploration noise.
                actions = aim.select_action(state)
                noise = (
                    torch.randn_like(actions) * 0.1).clamp(-0.5, 0.5)
                actions = (actions + noise).clamp(-1, 1)
            else:
                # First episode: random warm-up actions.
                actions = rl_actions(state)

            # next_state: (V, F*V) ordered tensor
            # reward: (V,) ordered tensor
            # done: (V,) ordered tensor
            # crash: boolean
            # Normalized actions in [-1, 1] become commands in [0, 2 * max_action].
            next_state, reward, done, crash = env.step(actions*max_action + max_action)

            # Store one transition per vehicle row of the current (trimmed) state.
            # NOTE(review): next_state is indexed with the same row k before it is
            # trimmed, and `done` is passed where TD3.train later reads
            # 'not_done' -- confirm PrioritizedReplayBuffer.add handles both.
            if state.shape[0] > 0:
                for k in range(state.shape[0]):
                    memory.add(state[k,:], actions[k], reward[k], next_state[k,:], done[k])
            if total_steps % 400 == 0 and i > 0:
                aim.train(memory)

            state = next_state
            state = trim(state)

            returns += sum(reward.tolist())
            ep_steps += 1
            total_steps += 1

            if crash:
                break

        returns_list.append(returns)
        ep_steps_list.append(ep_steps)
        print('Episode number: {}, Episode steps: {}, Episode return: {}'.format(i, ep_steps, returns))
        np.save('results/returns.npy', returns_list)
        np.save('results/ep_steps.npy', ep_steps_list)

    np.save('results/num_eps.npy', np.arange(num_eps))
    aim.save()
finally:
    # Release the environment even if training raises or is interrupted.
    env.terminate()

In [None]:
# TODO:
# figure out what "optimizer epochs" means
# test the repository from the previous paper