In [None]:
pip install Rtree

In [None]:
import os
import sys

# SUMO ships its Python TraCI/sumolib packages under $SUMO_HOME/tools;
# refuse to continue without it, otherwise put that directory on sys.path.
if 'SUMO_HOME' not in os.environ:
    sys.exit("Please declare environment variable 'SUMO_HOME'")
tools = os.path.join(os.environ['SUMO_HOME'], 'tools')
sys.path.append(tools)

import traci
import sumolib

print("traci and sumolib modules are installed correctly.")

In [None]:
import os
import numpy as np
import traci
import sumolib
import math
import csv
import signal
import sys
from rtree import index
import heapq

# Global buffers shared by the functions below.
# data: finished rows waiting to be appended to the CSV by save_data(); each row
#       has the 14 fields named in save_data()'s header.
data = []
# switch_time_queue: heapq min-heap of (next_switch_time, row) tuples for rows
#       whose per-direction switch times are still being measured
#       (pushed by collect_data, re-checked and drained in main via updateSwitchTime).
switch_time_queue = []

def stateDifferent(s1, s2):
    """Return True if a signal flips between red and green/yellow.

    A flip means one side is 'r'/'R' and the other is 'g'/'G'/'y'/'Y'.
    Any other characters (e.g. 'NA', or transitions within the same group
    such as 'G' -> 'y') are not counted as a change.
    """
    red = ('r', 'R')
    green_or_yellow = ('g', 'y', 'G', 'Y')
    if s1 in red:
        return s2 in green_or_yellow
    if s1 in green_or_yellow:
        return s2 in red
    return False

def distance(x1, y1, x2, y2):
    """Return the Euclidean distance between points (x1, y1) and (x2, y2)."""
    dx = x1 - x2
    dy = y1 - y2
    return math.sqrt(dx ** 2 + dy ** 2)

def get_lane_direction(lane_id):
    """Return the overall heading of ``lane_id`` in degrees.

    The heading is the angle of the straight line from the first to the last
    point of the lane's shape, computed with ``np.arctan2(dy, dx)``: measured
    from the +x axis, counter-clockwise positive, in the range [-180, 180].
    """
    lane_shape = traci.lane.getShape(lane_id)
    start, end = lane_shape[0], lane_shape[-1]
    dx = end[0] - start[0]
    dy = end[1] - start[1]
    return np.arctan2(dy, dx) * 180 / np.pi

def get_direction(angle):
    """Map an angle in degrees (0 = 'N', increasing through 'E', 'S', 'W') to a compass letter.

    Sectors are 90 degrees wide and centred on 0/90/180/270; both 0 and 360
    map to 'N'. Angles outside [0, 360] (or NaN) give 'Unknown'.
    """
    if not (0 <= angle <= 360):
        return 'Unknown'
    if angle < 45 or angle >= 315:
        return 'N'
    if angle < 135:
        return 'E'
    return 'S' if angle < 225 else 'W'

def is_right_turn(lane_id):
    try:
        current_direction = get_lane_direction(lane_id)
        links = traci.lane.getLinks(lane_id)
        for link in links:
            next_lane_id = link[0]
            next_direction = get_lane_direction(next_lane_id)
            turn_angle = ((next_direction - current_direction)%360 + 360) % 360
            # print(f"Right Turn Check: {lane_id} -> {next_lane_id}: Current {current_direction}, Next {next_direction}, Angle {turn_angle}")
            if (turn_angle > 180 and turn_angle < 315):
                return True
        return False
    except Exception as e:
        print(f"Error determining if lane {lane_id} is a right turn lane: {e}")
        return False

def is_left_turn(lane_id):
    try:
        current_direction = get_lane_direction(lane_id)
        links = traci.lane.getLinks(lane_id)
        for link in links:
            next_lane_id = link[0]
            next_direction = get_lane_direction(next_lane_id)
            turn_angle = ((next_direction - current_direction)%360 + 360) % 360
            # print(f"Left Turn Check: {lane_id} -> {next_lane_id}: Current {current_direction}, Next {next_direction}, Angle {turn_angle}")
            if 180 <= turn_angle <= 315:  # Check if the turn angle indicates a left turn
                return True
        return False
    except Exception as e:
        print(f"Error determining if lane {lane_id} is a left turn lane: {e}")
        return False

def calculate_duration(tls_id):
    """For every controlled link of ``tls_id``, sum phase durations until its signal next flips.

    Walks forward from the light's current phase through the phases of the
    currently running program, wrapping around, for at most twice the number
    of phases. For each link, the ``minDur``/``maxDur`` of every phase
    visited are accumulated until the link's state changes between red and
    green/yellow (as judged by ``stateDifferent``). The sums start with the
    FULL duration of the current phase, not its remaining time.

    Returns:
        (minDur_list, maxDur_list): one value per entry of
        ``traci.trafficlight.getControlledLanes(tls_id)``; -1 where no such
        change was found during the walk.
    """
    phase_index = traci.trafficlight.getPhase(tls_id)
    all_programs = traci.trafficlight.getCompleteRedYellowGreenDefinition(tls_id)
    active_id = traci.trafficlight.getProgram(tls_id)
    program = next(prog for prog in all_programs if prog.programID == active_id)
    phase = program.phases[phase_index]

    link_count = len(traci.trafficlight.getControlledLanes(tls_id))

    def _states(p):
        # Per-link signal characters of phase ``p``, indexed like the controlled lanes.
        return [p.state[k] for k in range(link_count)]

    state_now = _states(phase)
    min_total = [0] * link_count
    max_total = [0] * link_count
    min_until_flip = [-1] * link_count
    max_until_flip = [-1] * link_count

    for _ in range(len(program.phases) * 2):
        phase_index = (phase_index + 1) % len(program.phases)
        upcoming = program.phases[phase_index]
        state_next = _states(upcoming)

        for link in range(link_count):
            if min_until_flip[link] != -1:
                continue  # this link's flip was already found
            min_total[link] += phase.minDur
            max_total[link] += phase.maxDur
            if stateDifferent(state_now[link], state_next[link]):
                min_until_flip[link] = min_total[link]
                max_until_flip[link] = max_total[link]

        phase, state_now = upcoming, state_next

    return min_until_flip, max_until_flip


def get_traffic_light_info(vehicle_id):
    """Describe the next traffic light ahead of ``vehicle_id`` as seen from its current lane.

    Uses the first entry of ``traci.vehicle.getNextTLS`` that has a non-empty id
    and builds three-slot vectors ordered [left, straight, right]. The vehicle's
    own controlled link fills the slot for its lane's turn type
    (``is_left_turn`` is tested before ``is_right_turn``; anything else counts
    as straight). The remaining slots are filled from other controlled links of
    the same light:
      - a left-turn lane scans higher link indices for a right-turn link and a
        straight link;
      - a right-turn lane scans lower link indices for a left-turn link and a
        straight link;
      - a straight lane scans lower indices for a left-turn link and higher
        indices for a right-turn link.

    Returns:
        8-tuple ``(tls_position, signal_vector, distance_to_tls, minDur, maxDur,
        time_to_next_switch, tls_id, lane_id)``. ``signal_vector`` holds SUMO
        signal characters or 'NA'. ``minDur``/``maxDur`` hold per-slot values
        from ``calculate_duration``, or -1.
        - If the vehicle's current lane is not among the light's controlled
          lanes, the three vectors are all 'NA' / -1, but the other fields are
          still filled in.
        - If there is no upcoming light, or any exception occurs (the error is
          printed), returns
          ``(None, ['NA']*3, None, [-1]*3, [-1]*3, None, None, None)``.

    NOTE(review): ``time_to_next_switch`` is the raw value of
    ``traci.trafficlight.getNextSwitch``. TraCI documents this as the absolute
    simulation time of the next switch, not a remaining duration — confirm.
    """
    try:
        tls_list = traci.vehicle.getNextTLS(vehicle_id)
        # Only the first light with a non-empty id is used: every path inside the `if` returns.
        for tls in tls_list:
            tls_id = tls[0]
            if tls_id:
                # NOTE(review): assumes the traffic-light id is also a junction id.
                tls_position = traci.junction.getPosition(tls_id)
                tls_state = traci.trafficlight.getRedYellowGreenState(tls_id)
                tls_controlled_lanes = traci.trafficlight.getControlledLanes(tls_id)
                distance_to_tls = tls[2]
                lane_id = traci.vehicle.getLaneID(vehicle_id)
                time_to_next_switch = traci.trafficlight.getNextSwitch(tls_id)

                # Index of the first controlled link that starts on the vehicle's
                # current lane; tls_state and the duration lists are indexed the same way.
                # NOTE(review): getNextTLS entries also carry a link index in tls[1],
                # which may identify the vehicle's own link more precisely — consider it.
                try:
                    lane_index = tls_controlled_lanes.index(lane_id)
                except ValueError:
                    lane_index = None
                if lane_index is None:
                    # Vehicle is not on a lane controlled by this light (e.g. still upstream).
                    return tls_position, ['NA', 'NA', 'NA'], distance_to_tls, [-1,-1,-1], [-1,-1,-1], time_to_next_switch, tls_id, lane_id

                # Slots are ordered [left, straight, right].
                signal_vector = ['NA', 'NA', 'NA']
                minDur = [-1, -1, -1]
                maxDur = [-1, -1, -1]

                current_lane_signal = tls_state[lane_index]
                minDur_list, maxDur_list = calculate_duration(tls_id)

                # The vehicle's own lane fills the slot of its turn type (left wins over right).
                is_left, is_right = is_left_turn(lane_id), is_right_turn(lane_id)
                if is_left:
                    signal_vector[0] = current_lane_signal
                    maxDur[0] = maxDur_list[lane_index]
                    minDur[0] = minDur_list[lane_index]
                elif is_right:
                    signal_vector[2] = current_lane_signal
                    maxDur[2] = maxDur_list[lane_index]
                    minDur[2] = minDur_list[lane_index]
                else:
                    signal_vector[1] = current_lane_signal
                    maxDur[1] = maxDur_list[lane_index]
                    minDur[1] = minDur_list[lane_index]

                if is_left:
                    # Left-turn lane: look at higher link indices for right-turn / straight links.
                    for offset in range(1, len(tls_controlled_lanes)):
                        right_idx = lane_index + offset
                        if right_idx < len(tls_controlled_lanes):
                            current_lane_id = tls_controlled_lanes[right_idx]
                            if signal_vector[2] == 'NA' and is_right_turn(current_lane_id):
                                signal_vector[2] = tls_state[right_idx]
                                maxDur[2] = maxDur_list[right_idx]
                                minDur[2] = minDur_list[right_idx]
                            elif signal_vector[1] == 'NA' and not is_right_turn(current_lane_id) and not is_left_turn(current_lane_id):
                                signal_vector[1] = tls_state[right_idx]
                                maxDur[1] = maxDur_list[right_idx]
                                minDur[1] = minDur_list[right_idx]
                        if all(s != 'NA' for s in signal_vector):
                            break
                elif is_right:
                    # Right-turn lane: look at lower link indices for left-turn / straight links.
                    for offset in range(1, len(tls_controlled_lanes)):
                        left_idx = lane_index - offset
                        if left_idx >= 0:
                            current_lane_id = tls_controlled_lanes[left_idx]
                            if signal_vector[0] == 'NA' and is_left_turn(current_lane_id):
                                signal_vector[0] = tls_state[left_idx]
                                maxDur[0] = maxDur_list[left_idx]
                                minDur[0] = minDur_list[left_idx]
                            elif signal_vector[1] == 'NA' and not is_right_turn(current_lane_id) and not is_left_turn(current_lane_id):
                                signal_vector[1] = tls_state[left_idx]
                                maxDur[1] = maxDur_list[left_idx]
                                minDur[1] = minDur_list[left_idx]
                        if all(s != 'NA' for s in signal_vector):
                            break
                else:
                    # Straight lane: lower indices for a left-turn link, higher indices for a right-turn link.
                    for offset in range(1, len(tls_controlled_lanes)):
                        left_idx = lane_index - offset
                        right_idx = lane_index + offset
                        if left_idx >= 0:
                            current_lane_id = tls_controlled_lanes[left_idx]
                            if signal_vector[0] == 'NA' and is_left_turn(current_lane_id):
                                signal_vector[0] = tls_state[left_idx]
                                maxDur[0] = maxDur_list[left_idx]
                                minDur[0] = minDur_list[left_idx]
                        if right_idx < len(tls_controlled_lanes):
                            current_lane_id = tls_controlled_lanes[right_idx]
                            if signal_vector[2] == 'NA' and is_right_turn(current_lane_id):
                                signal_vector[2] = tls_state[right_idx]
                                maxDur[2] = maxDur_list[right_idx]
                                minDur[2] = minDur_list[right_idx]
                        if all(s != 'NA' for s in signal_vector):
                            break

                return tls_position, signal_vector, distance_to_tls, minDur, maxDur, time_to_next_switch, tls_id, lane_id

    except Exception as e:
        print(f"Error getting traffic light info for vehicle {vehicle_id}: {e}")
    # Reached when there is no upcoming light or an error occurred.
    return None, ['NA', 'NA', 'NA'], None, [-1,-1,-1], [-1,-1,-1], None, None, None

def get_tl_state(tls_id, lane_id):
    """Return the current [left, straight, right] signal vector for ``lane_id`` at ``tls_id``.

    Uses the same slot-filling rules as ``get_traffic_light_info``, so the
    result can be compared slot by slot with a vector recorded there. The
    lane's own controlled link fills the slot for its turn type (left is
    tested before right; anything else is straight). Other controlled links
    then fill the remaining slots:
      - a left-turn lane scans higher link indices for right-turn / straight links;
      - a right-turn lane scans lower link indices for left-turn / straight links;
      - a straight lane scans lower indices for a left-turn link and higher
        indices for a right-turn link.

    Returns:
        A list of three SUMO signal characters, with 'NA' for slots that could
        not be filled. ``['NA', 'NA', 'NA']`` is returned when ``lane_id`` is
        not controlled by ``tls_id``, or when any error occurs (the error is
        printed). The fallback always has the same 3-slot shape, because
        callers index it element-wise.
    """
    try:
        tls_state = traci.trafficlight.getRedYellowGreenState(tls_id)
        controlled_lanes = traci.trafficlight.getControlledLanes(tls_id)

        try:
            lane_index = controlled_lanes.index(lane_id)
        except ValueError:
            return ['NA', 'NA', 'NA']

        num_links = len(controlled_lanes)
        signal_vector = ['NA', 'NA', 'NA']
        own_signal = tls_state[lane_index]

        # Classify the lane once; a lane with both left and right links counts as left.
        lane_is_left = is_left_turn(lane_id)
        lane_is_right = not lane_is_left and is_right_turn(lane_id)

        if lane_is_left:
            signal_vector[0] = own_signal
            for idx in range(lane_index + 1, num_links):
                other = controlled_lanes[idx]
                other_is_right = is_right_turn(other)
                if signal_vector[2] == 'NA' and other_is_right:
                    signal_vector[2] = tls_state[idx]
                elif signal_vector[1] == 'NA' and not other_is_right and not is_left_turn(other):
                    signal_vector[1] = tls_state[idx]
                if 'NA' not in signal_vector:
                    break
        elif lane_is_right:
            signal_vector[2] = own_signal
            for idx in range(lane_index - 1, -1, -1):
                other = controlled_lanes[idx]
                other_is_left = is_left_turn(other)
                if signal_vector[0] == 'NA' and other_is_left:
                    signal_vector[0] = tls_state[idx]
                elif signal_vector[1] == 'NA' and not other_is_left and not is_right_turn(other):
                    signal_vector[1] = tls_state[idx]
                if 'NA' not in signal_vector:
                    break
        else:
            signal_vector[1] = own_signal
            for offset in range(1, num_links):
                left_idx = lane_index - offset
                right_idx = lane_index + offset
                if (left_idx >= 0 and signal_vector[0] == 'NA'
                        and is_left_turn(controlled_lanes[left_idx])):
                    signal_vector[0] = tls_state[left_idx]
                if (right_idx < num_links and signal_vector[2] == 'NA'
                        and is_right_turn(controlled_lanes[right_idx])):
                    signal_vector[2] = tls_state[right_idx]
                if 'NA' not in signal_vector:
                    break

        return signal_vector

    except Exception as e:
        print(f"Error getting traffic light info {e}")
        return ['NA', 'NA', 'NA']


def collect_data(vehicle_id, current_time, detection_distance, recording_distance, idx):
    try:
        x, y = traci.vehicle.getPosition(vehicle_id)
        tls_position, tls_state, distance_to_tls, phase_min_durations, phase_max_durations, time_to_next_switch, tls_id, lane_id = get_traffic_light_info(vehicle_id)
        if distance_to_tls is not None and distance_to_tls <= recording_distance:
            tls_x, tls_y = tls_position if tls_position else (None, None)
            
            nearest_vehicles = list(idx.intersection((x - detection_distance, y - detection_distance, x + detection_distance, y + detection_distance)))
            
            for other_veh_id in nearest_vehicles:
                other_veh_id = str(other_veh_id)
                other_tls_list = traci.vehicle.getNextTLS(other_veh_id)
                if other_tls_list and other_tls_list[0][0] == traci.vehicle.getNextTLS(vehicle_id)[0][0]:
                    # only collect the data from vehicle who facing the same intersection
                    try:
                        if other_veh_id not in traci.vehicle.getIDList():
                            print(f"Other vehicle {other_veh_id} not found in simulation")
                            continue

                        other_x, other_y = traci.vehicle.getPosition(other_veh_id)
                        # if distance(x, y, other_x, other_y) < detection_distance:
                        other_speed = traci.vehicle.getSpeed(other_veh_id)
                        other_acceleration = traci.vehicle.getAcceleration(other_veh_id)
                        other_direction = get_direction(traci.vehicle.getAngle(other_veh_id))

                        switch_time = [-1, -1, -1]
                        cur_data = [current_time, vehicle_id, other_veh_id, other_speed, other_acceleration, other_direction, other_x, other_y, tls_x, tls_y, tls_state, phase_min_durations, phase_max_durations, switch_time, tls_id, lane_id, time_to_next_switch]
                        heapq.heappush(switch_time_queue, (time_to_next_switch, cur_data))
                    except Exception as e:
                        print(f"Error collecting data for vehicle {other_veh_id}: {e}")
    except Exception as e:
        print(f"Error collecting data for vehicle {vehicle_id}: {e}")
    return #return nothing

def save_data(file_path='./numo raw dataset/NUMO_raw_v2.csv'):
    """Append all buffered rows in the global ``data`` to a CSV file, then clear the buffer.

    The header row is written only when the file does not exist yet or is
    empty. The parent directory is created if it is missing. When the buffer
    is empty, nothing is written.

    Args:
        file_path: CSV destination. Defaults to the dataset path used by this script.
    """
    global data
    if not data:
        print("No data collected to write to file")
        return

    parent_dir = os.path.dirname(file_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    # An existing but empty file still needs the header.
    write_header = not os.path.isfile(file_path) or os.path.getsize(file_path) == 0

    with open(file_path, mode='a', newline='', encoding='utf-8') as file:
        writer = csv.writer(file)
        if write_header:
            writer.writerow(['Step', 'VehicleID', 'Obs_VehicleID', 'Obs_Speed', 'Obs_Acceleration', 'Obs_Direction', 'Obs_Location_X', 'Obs_Location_Y', 'TLS_X', 'TLS_Y', 'TLS_State', 'minDur', 'maxDur', 'timeSwitch'])
        writer.writerows(data)
    print(f"Data written to file: {len(data)} entries")
    data.clear()  # clean up memory

def signal_handler(sig, frame):
    """SIGINT handler: close the TraCI connection, flush buffered rows to CSV and exit.

    Closing is best-effort. If it fails (e.g. the connection is already gone),
    the error is printed and the buffered rows are still saved. Rows still
    waiting in ``switch_time_queue`` are not written.

    Args:
        sig, frame: standard ``signal`` handler arguments (unused).
    """
    print("Terminating simulation and saving data...")
    try:
        traci.close()
    except Exception as e:
        print(f"Error closing TraCI connection: {e}")
    save_data()
    sys.exit(0)

def updateSwitchTime(prev_data, curTime):
    """Compare a queued row with the light's current state and record any switch times.

    ``prev_data`` is a row built by ``collect_data``:
        [current_time, vehicle_id, other_veh_id, other_speed, other_acceleration,
         other_direction, other_x, other_y, tls_x, tls_y, tls_state,
         phase_min_durations, phase_max_durations, switch_time,
         tls_id, lane_id, time_to_next_switch]
    The row is updated in place. Its last field is refreshed from
    ``traci.trafficlight.getNextSwitch``. A [left, straight, right] slot gets
    ``curTime - current_time`` stored in ``switch_time`` when both hold:
      - its recorded state now differs (red <-> green/yellow, see
        ``stateDifferent``) from ``get_tl_state``;
      - it has no switch time yet.

    Returns:
        (prev_data, isDone): ``isDone`` is True only if a switch was recorded
        in this call and every slot then either has a switch time or was
        recorded as 'NA'. If an error occurs during the update, it is printed
        and ``(prev_data, False)`` is returned.
    """
    recorded_at = prev_data[0]
    vehicle_id = prev_data[1]
    recorded_state = prev_data[10]
    switch_time = prev_data[-4]
    tls_id = prev_data[-3]
    lane_id = prev_data[-2]
    finished = False
    try:
        # The expected next switch moment is refreshed on every check.
        prev_data[-1] = traci.trafficlight.getNextSwitch(tls_id)
        state_now = get_tl_state(tls_id, lane_id)

        for slot in range(3):
            flipped = stateDifferent(recorded_state[slot], state_now[slot])
            if not (flipped and switch_time[slot] == -1):
                continue
            switch_time[slot] = curTime - recorded_at
            finished = all(
                switch_time[k] != -1 or recorded_state[k] == "NA" for k in range(3)
            )
        return prev_data, finished

    except Exception as e:
        print(f"Error updating switch time for vehicle {vehicle_id}: {e}")
    return prev_data, False

def _build_vehicle_index(vehicle_ids):
    """Return an R-tree holding each vehicle's current position as a point, keyed by ``int(vehicle id)``.

    A vehicle that cannot be indexed (e.g. a non-numeric id, for which
    ``int()`` fails) is reported and skipped, instead of aborting the index
    for every vehicle after it.
    """
    idx = index.Index()
    for veh_id in vehicle_ids:
        try:
            veh_x, veh_y = traci.vehicle.getPosition(veh_id)
            idx.insert(int(veh_id), (veh_x, veh_y, veh_x, veh_y))
        except Exception as e:
            print(f"Error building R-tree index for vehicle {veh_id}: {e}")
    return idx


def _drain_switch_queue(current_time):
    """Re-check queued rows whose expected switch time is earlier than ``current_time``.

    Finished rows are appended to the global ``data`` without their last three
    bookkeeping fields (tls_id, lane_id, next switch time). Slots that had no
    signal are written as 'r'. Unfinished rows are re-queued under their
    refreshed next switch time, but only after the drain, so a row whose
    refreshed time is still in the past cannot make this loop spin forever.
    """
    pending = []
    while switch_time_queue and switch_time_queue[0][0] < current_time:
        _, prev_data = heapq.heappop(switch_time_queue)
        try:
            update_data, isDone = updateSwitchTime(prev_data, current_time)
        except Exception as e:
            print(f"Error updating queued switch-time row: {e}")
            continue
        if isDone:
            update_data[10] = ['r' if s == 'NA' else s for s in update_data[10]]  # update tls_state
            data.append(update_data[:-3])  # drop tls_id, lane_id, next_switch_time
        else:
            pending.append((update_data[-1], update_data))
    for item in pending:
        heapq.heappush(switch_time_queue, item)


def main(sumo_cfg="/Users/hy.c/Desktop/Numo_env/numo/nagoya.sumocfg", use_gui=True,
         detection_distance=100, recording_distance=150, time_interval=1):
    """Run the SUMO scenario and record traffic-light observations to CSV.

    Each simulation step does the following:
      - index all vehicle positions in an R-tree;
      - let every vehicle queue observation rows via ``collect_data``;
      - finalize queued rows whose light has switched;
      - every ``50 * time_interval`` steps, append finished rows to the CSV.

    The TraCI connection is closed and the remaining buffered rows are saved
    even if the loop ends with an exception.

    Args:
        sumo_cfg: path of the ``.sumocfg`` file to run.
        use_gui: run ``sumo-gui`` (started immediately via ``--start``) instead of ``sumo``.
        detection_distance: half-width (m) of the square window in which two cars communicate.
        recording_distance: only record vehicles within this distance (m) of their next traffic light.
        time_interval: rows are saved every ``50 * time_interval`` steps.
    """
    sumoBinary = sumolib.checkBinary('sumo-gui' if use_gui else 'sumo')
    sumoCmd = [sumoBinary, "-c", sumo_cfg]
    if use_gui:
        sumoCmd.append("--start")
    traci.start(sumoCmd)

    signal.signal(signal.SIGINT, signal_handler)

    step = 0
    try:
        while traci.simulation.getMinExpectedNumber() > 0:
            traci.simulationStep()

            vehicle_ids = traci.vehicle.getIDList()
            idx = _build_vehicle_index(vehicle_ids)

            current_time = traci.simulation.getTime()
            for veh_id in vehicle_ids:
                try:
                    collect_data(veh_id, current_time, detection_distance, recording_distance, idx)
                except Exception as e:
                    print(f"Error processing data for vehicle {veh_id}: {e}")
            # Draining depends only on time, so it runs once per step (even with no vehicles).
            _drain_switch_queue(current_time)

            # Debug print to check data collection
            print(f"Step {step}, Collected data: {len(data)} entries")

            # Save data periodically
            if step % (time_interval * 50) == 0:
                save_data()

            step += 1
    finally:
        try:
            traci.close()
        except Exception as e:
            print(f"Error closing TraCI connection: {e}")
        save_data()

# Entry point: start the simulation. (In a Jupyter/IPython kernel __name__ is
# also "__main__", so running this cell starts it as well.)
if __name__ == "__main__":
    main()