In [12]:
import cityflow
import math
import pandas as pd
import os
import argparse
import json

def args_parser():
    """Build the run configuration as an ``argparse.Namespace``.

    The parser is fed an explicit empty argument list, so every option takes
    its declared default (``None`` where no default is declared) and the
    process command line is not consulted.

    Returns:
        argparse.Namespace: with attributes ``scenario``, ``max_episode``,
        ``max_step``, ``max_buffer`` and ``max_total_reward``.
    """
    option_table = (
        ('--scenario', {'type': str}),
        # training parameters
        ('--max_episode', {'type': int, 'default': 5000}),
        ('--max_step', {'type': int, 'default': 200}),
        ('--max_buffer', {'type': int, 'default': 10000}),
        ('--max_total_reward', {'type': float}),
    )

    cli = argparse.ArgumentParser()
    for option_name, option_kwargs in option_table:
        cli.add_argument(option_name, **option_kwargs)

    return cli.parse_args(args=[])

In [13]:
# Build the default configuration, then set the scenario name for this run.
args = args_parser()
args.scenario = 'hangzhou_1x1_bc-tyc_18041607_1h'
# Step budget for the simulation, copied from the parsed configuration.
num_step = args.max_step

In [19]:
import pandas as pd
import os

class CityFlowEnv():
    '''
    Simulator Environment with CityFlow

    Wraps a ``cityflow.Engine`` and controls the first non-virtual intersection
    listed in the roadnet file. The environment tracks the phase currently
    applied, the number of consecutive steps that phase has been held, and a
    log holding one phase entry per simulated step.
    '''
    def __init__(self, args, config_file='examples/config_control.json',
                 roadnet_file='examples/roadnet.json'):
        '''
        Args:
            args: object exposing ``max_step`` (number of steps kept by ``log``).
            config_file: path of the CityFlow engine configuration file.
            roadnet_file: path of the roadnet JSON parsed by ``parse_roadnet``.
        '''
        self.eng = cityflow.Engine(config_file=config_file, thread_num=1)
        self.num_step = args.max_step
        self.lane_phase_info = self.parse_roadnet(roadnet_file)  # "intersection_1_1"

        # Only the first non-virtual intersection of the roadnet is controlled.
        self.intersection_id = list(self.lane_phase_info.keys())[0]
        intersection_info = self.lane_phase_info[self.intersection_id]
        self.start_lane = intersection_info['start_lane']
        self.phase_list = intersection_info["phase"]
        self.phase_startLane_mapping = intersection_info["phase_startLane_mapping"]

        self.current_phase = self.phase_list[0]
        self.current_phase_time = 0
        # Not used inside this class; callers read it to decide how many steps
        # to hold phase 0 before switching phases.
        self.yellow_time = 5

        self.phase_log = []

    def parse_roadnet(self, roadnetFile):
        '''
        Extract lane and phase information for every non-virtual intersection.

        Args:
            roadnetFile: path of the roadnet JSON file.

        Returns:
            dict mapping intersection id to a dict with the keys
            ``start_lane`` / ``end_lane`` (sorted unique lane ids of the form
            ``<road>_<laneIndex>``), ``phase`` (light-phase indices, index 0
            excluded), ``phase_startLane_mapping`` (phase -> start lanes) and
            ``phase_roadLink_mapping`` (phase -> list of (start, end) lane pairs).
        '''
        # Context manager so the file handle is closed deterministically.
        with open(roadnetFile) as roadnet_fp:
            roadnet = json.load(roadnet_fp)
        lane_phase_info_dict = {}

        # Many intersections exist in the roadnet; virtual ones are skipped and
        # only the remaining ones get an entry.
        for intersection in roadnet["intersections"]:
            if intersection['virtual']:
                continue
            road_links = intersection["roadLinks"]

            start_lane = []
            end_lane = []
            # roadLink index -> lane pairs (start_lane, end_lane) it contains
            roadLink_lane_pair = {ri: [] for ri in range(len(road_links))}

            for ri, road_link in enumerate(road_links):
                for lane_link in road_link["laneLinks"]:
                    sl = road_link['startRoad'] + "_" + str(lane_link["startLaneIndex"])
                    el = road_link['endRoad'] + "_" + str(lane_link["endLaneIndex"])
                    start_lane.append(sl)
                    end_lane.append(el)
                    roadLink_lane_pair[ri].append((sl, el))

            intersection_info = {"start_lane": sorted(set(start_lane)),
                                 "end_lane": sorted(set(end_lane)),
                                 "phase": [],
                                 "phase_startLane_mapping": {},
                                 "phase_roadLink_mapping": {}}

            light_phases = intersection["trafficLight"]["lightphases"]
            # Light phase index 0 is deliberately left out of the phase list.
            for phase_i in range(1, len(light_phases)):
                lane_pair = []
                phase_start_lane = []
                for ri in light_phases[phase_i]["availableRoadLinks"]:
                    link_pairs = roadLink_lane_pair[ri]
                    lane_pair.extend(link_pairs)
                    # The start lane of a road link is taken from its first lane
                    # pair; a road link without lane links contributes nothing
                    # instead of raising IndexError.
                    if link_pairs and link_pairs[0][0] not in phase_start_lane:
                        phase_start_lane.append(link_pairs[0][0])
                intersection_info["phase"].append(phase_i)
                intersection_info["phase_startLane_mapping"][phase_i] = phase_start_lane
                intersection_info["phase_roadLink_mapping"][phase_i] = lane_pair

            lane_phase_info_dict[intersection['id']] = intersection_info

        return lane_phase_info_dict

    def reset(self):
        '''Reset the engine and return all phase bookkeeping to its initial values.'''
        self.eng.reset()
        self.phase_log = []
        # Keep the tracked phase consistent with the freshly reset engine, the
        # same way __init__ initialises it.
        self.current_phase = self.phase_list[0]
        self.current_phase_time = 0

    def step(self, next_phase):
        '''
        Apply ``next_phase`` to the controlled intersection and advance one step.

        ``current_phase_time`` counts consecutive steps with the same phase and
        restarts at 1 whenever the phase changes. The applied phase is appended
        to ``phase_log``.
        '''
        if self.current_phase == next_phase:
            self.current_phase_time += 1
        else:
            self.current_phase = next_phase
            self.current_phase_time = 1

        self.eng.set_tl_phase(self.intersection_id, self.current_phase)
        self.eng.next_step()
        self.phase_log.append(self.current_phase)

    def get_state(self):
        '''
        Collect the current observation from the engine.

        Returns:
            dict with engine readings (lane counts, waiting counts, vehicles,
            speeds, distances, time) plus ``current_phase`` and
            ``current_phase_time`` as tracked by this environment.
        '''
        # Query the engine once and reuse the result for both count entries.
        lane_vehicle_count = self.eng.get_lane_vehicle_count()  # {lane_id: lane_count, ...}

        state = {}
        state['lane_vehicle_count'] = lane_vehicle_count
        state['start_lane_vehicle_count'] = {lane: lane_vehicle_count[lane] for lane in self.start_lane}
        state['lane_waiting_vehicle_count'] = self.eng.get_lane_waiting_vehicle_count()  # {lane_id: lane_waiting_count, ...}
        state['lane_vehicles'] = self.eng.get_lane_vehicles()  # {lane_id: [vehicle1_id, vehicle2_id, ...], ...}
        state['vehicle_speed'] = self.eng.get_vehicle_speed()  # {vehicle_id: vehicle_speed, ...}
        state['vehicle_distance'] = self.eng.get_vehicle_distance()  # {vehicle_id: distance, ...}
        state['current_time'] = self.eng.get_current_time()
        state['current_phase'] = self.current_phase
        state['current_phase_time'] = self.current_phase_time

        return state

    def get_reward(self):
        '''Sample reward: the negated total number of waiting vehicles over all lanes.'''
        lane_waiting_vehicle_count = self.eng.get_lane_waiting_vehicle_count()
        reward = -1 * sum(lane_waiting_vehicle_count.values())
        return reward

    def log(self):
        '''
        Write the first ``num_step`` logged phases to
        ``examples/signal_plan_template.txt`` as CSV, in a single column named
        after the controlled intersection and without an index column.
        '''
        df = pd.DataFrame({self.intersection_id: self.phase_log[:self.num_step]})
        df.to_csv(os.path.join('examples', 'signal_plan_template.txt'), index=False)

In [20]:
class SOTLAgent():
    ''' Agent using a waiting-vehicle threshold rule to control the traffic signal

    The current phase is kept for at least ``phi`` steps. After that, the agent
    advances to the next phase when few vehicles wait on the lanes served by
    the current phase and many wait on the other start lanes. (The class name
    suggests this is a SOTL, self-organizing traffic light, controller.)
    '''

    def __init__(self, args, roadnet_file='examples/roadnet.json', phi=20,
                 min_green_vehicle=20, max_red_vehicle=30):
        '''
        Args:
            args: run configuration; stored on the agent, not otherwise read here.
            roadnet_file: path of the roadnet JSON parsed by ``parse_roadnet``.
            phi: minimum number of steps a phase is held before a switch is considered.
            min_green_vehicle: switch only if the waiting count on the current
                phase's start lanes is at most this value.
            max_red_vehicle: switch only if the waiting count on the remaining
                start lanes is strictly greater than this value.
        '''
        self.args = args
        self.lane_phase_info = self.parse_roadnet(roadnet_file)  # "intersection_1_1"

        # Only the first non-virtual intersection of the roadnet is controlled.
        self.intersection_id = list(self.lane_phase_info.keys())[0]
        intersection_info = self.lane_phase_info[self.intersection_id]
        self.start_lane = intersection_info['start_lane']
        self.phase_list = intersection_info["phase"]
        self.phase_startLane_mapping = intersection_info["phase_startLane_mapping"]

        self.phi = phi
        self.min_green_vehicle = min_green_vehicle
        self.max_red_vehicle = max_red_vehicle

        self.action = self.phase_list[0]

    def choose_action(self, state):
        '''
        Decide which phase to apply next.

        Args:
            state: dict providing ``current_phase``, ``current_phase_time`` and
                ``lane_waiting_vehicle_count`` (lane id -> waiting count).

        Returns:
            The phase to apply: the previously chosen one, unless the switching
            condition holds, in which case the phase following ``current_phase``
            (wrapping from the last phase back to phase 1).
        '''
        cur_phase = state["current_phase"]
        if state["current_phase_time"] >= self.phi:
            waiting = state["lane_waiting_vehicle_count"]
            num_green_vehicle = sum(waiting[lane] for lane in self.phase_startLane_mapping[cur_phase])
            num_red_vehicle = sum(waiting[lane] for lane in self.start_lane) - num_green_vehicle
            if num_green_vehicle <= self.min_green_vehicle and num_red_vehicle > self.max_red_vehicle:
                # Phases are numbered 1..len(phase_list), so this cycles through them.
                self.action = cur_phase % len(self.phase_list) + 1
        return self.action

    def parse_roadnet(self, roadnetFile):
        '''
        Extract lane and phase information for every non-virtual intersection.

        Args:
            roadnetFile: path of the roadnet JSON file.

        Returns:
            dict mapping intersection id to a dict with the keys
            ``start_lane`` / ``end_lane`` (sorted unique lane ids of the form
            ``<road>_<laneIndex>``), ``phase`` (light-phase indices, index 0
            excluded), ``phase_startLane_mapping`` (phase -> start lanes) and
            ``phase_roadLink_mapping`` (phase -> list of (start, end) lane pairs).
        '''
        # Context manager so the file handle is closed deterministically.
        with open(roadnetFile) as roadnet_fp:
            roadnet = json.load(roadnet_fp)
        lane_phase_info_dict = {}

        # Many intersections exist in the roadnet; virtual ones are skipped and
        # only the remaining ones get an entry.
        for intersection in roadnet["intersections"]:
            if intersection['virtual']:
                continue
            road_links = intersection["roadLinks"]

            start_lane = []
            end_lane = []
            # roadLink index -> lane pairs (start_lane, end_lane) it contains
            roadLink_lane_pair = {ri: [] for ri in range(len(road_links))}

            for ri, road_link in enumerate(road_links):
                for lane_link in road_link["laneLinks"]:
                    sl = road_link['startRoad'] + "_" + str(lane_link["startLaneIndex"])
                    el = road_link['endRoad'] + "_" + str(lane_link["endLaneIndex"])
                    start_lane.append(sl)
                    end_lane.append(el)
                    roadLink_lane_pair[ri].append((sl, el))

            intersection_info = {"start_lane": sorted(set(start_lane)),
                                 "end_lane": sorted(set(end_lane)),
                                 "phase": [],
                                 "phase_startLane_mapping": {},
                                 "phase_roadLink_mapping": {}}

            light_phases = intersection["trafficLight"]["lightphases"]
            # Light phase index 0 is deliberately left out of the phase list.
            for phase_i in range(1, len(light_phases)):
                lane_pair = []
                phase_start_lane = []
                for ri in light_phases[phase_i]["availableRoadLinks"]:
                    link_pairs = roadLink_lane_pair[ri]
                    lane_pair.extend(link_pairs)
                    # The start lane of a road link is taken from its first lane
                    # pair; a road link without lane links contributes nothing
                    # instead of raising IndexError.
                    if link_pairs and link_pairs[0][0] not in phase_start_lane:
                        phase_start_lane.append(link_pairs[0][0])
                intersection_info["phase"].append(phase_i)
                intersection_info["phase_startLane_mapping"][phase_i] = phase_start_lane
                intersection_info["phase_roadLink_mapping"][phase_i] = lane_pair

            lane_phase_info_dict[intersection['id']] = intersection_info

        return lane_phase_info_dict

In [21]:
# Instantiate the simulator wrapper and the signal controller from the same
# configuration; each constructor parses the roadnet file on its own.
env = CityFlowEnv(args)
agent = SOTLAgent(args)

In [22]:
# reset initially
t = 0
env.reset()
last_action = agent.choose_action(env.get_state())

# Drive the simulation for at most args.max_step engine steps. Whenever the
# agent asks for a different phase, phase 0 is held for env.yellow_time steps
# first; those steps count against the same budget.
while t < args.max_step:
    state = env.get_state()
    action = agent.choose_action(state)
    # Initialised on every iteration so the check below never reads an unbound
    # name, even when env.yellow_time is 0 and the yellow loop body never runs.
    flag = False
    if action != last_action:
        for _ in range(env.yellow_time):
            env.step(0)  # required yellow time
            t += 1
            flag = (t >= args.max_step)
            if flag:
                break
    if flag:
        # The step budget ran out during the yellow interval.
        break
    env.step(action)
    last_action = action
    t += 1
    # Reports the state observed before this iteration's step(s) were applied.
    print("Time: {}, Phase: {}, lane_vehicle_count: {}".format(state['current_time'], state['current_phase'],
                                                                       state['lane_vehicle_count']))

# log environment files
env.log()

Time: 0.0, Phase: 1, lane_vehicle_count: {'road_0_1_0_0': 0, 'road_0_1_0_1': 0, 'road_0_1_0_2': 0, 'road_0_1_0_3': 0, 'road_0_1_0_4': 0, 'road_0_1_0_5': 0, 'road_0_1_0_6': 0, 'road_1_0_1_0': 0, 'road_1_0_1_1': 0, 'road_1_0_1_2': 0, 'road_1_0_1_3': 0, 'road_1_0_1_4': 0, 'road_1_0_1_5': 0, 'road_1_0_1_6': 0, 'road_1_1_0_0': 0, 'road_1_1_0_1': 0, 'road_1_1_0_2': 0, 'road_1_1_0_3': 0, 'road_1_1_0_4': 0, 'road_1_1_0_5': 0, 'road_1_1_0_6': 0, 'road_1_1_1_0': 0, 'road_1_1_1_1': 0, 'road_1_1_1_2': 0, 'road_1_1_1_3': 0, 'road_1_1_1_4': 0, 'road_1_1_1_5': 0, 'road_1_1_1_6': 0, 'road_1_1_2_0': 0, 'road_1_1_2_1': 0, 'road_1_1_2_2': 0, 'road_1_1_2_3': 0, 'road_1_1_2_4': 0, 'road_1_1_2_5': 0, 'road_1_1_2_6': 0, 'road_1_1_3_0': 0, 'road_1_1_3_1': 0, 'road_1_1_3_2': 0, 'road_1_1_3_3': 0, 'road_1_1_3_4': 0, 'road_1_1_3_5': 0, 'road_1_1_3_6': 0, 'road_1_2_3_0': 0, 'road_1_2_3_1': 0, 'road_1_2_3_2': 0, 'road_1_2_3_3': 0, 'road_1_2_3_4': 0, 'road_1_2_3_5': 0, 'road_1_2_3_6': 0, 'road_2_1_2_0': 0, 'road_2_