In [258]:
import gym
from gym import spaces

import random
import scipy
import numpy as np
import pandas as pd

import stable_baselines3
from stable_baselines3.sac.policies import MlpPolicy
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3 import DQN

In [259]:
#Function for normal distribution truncation:

from scipy.stats import truncnorm

def get_truncated_normal(mean, sd, low, upp):
    """Return a frozen normal distribution truncated to the interval [low, upp].

    Parameters
    ----------
    mean, sd : location and scale of the underlying normal distribution.
    low, upp : lower and upper truncation bounds, in the same units as ``mean``.

    Returns
    -------
    A frozen ``scipy.stats.truncnorm`` distribution (call ``.rvs(n)`` to sample).
    """
    # truncnorm expects its bounds expressed in standard deviations from `loc`.
    lower_z = (low - mean) / sd
    upper_z = (upp - mean) / sd
    return truncnorm(lower_z, upper_z, loc=mean, scale=sd)

In [260]:
#Function to get date sequence based on start_date and num_of_weeks:
# Works for Multi-Agent?

def get_date_seq(start_date_arr, num_of_weeks_arr): #start_date index 2, num_of_weeks index 3
    """Build, for every request, the list of days on which it operates.

    Entry ``i`` of the result is the list
    ``[start, start + 7, ..., start + 7 * (weeks - 1)]`` computed from
    ``start_date_arr[i]`` and ``num_of_weeks_arr[i]`` (both cast to ``int``).

    Returns a 1-D object array with one Python list per request.
    """
    n_requests = len(start_date_arr)
    weekly_days = np.empty(shape=(n_requests,), dtype='object')
    for pos in range(n_requests):
        first_day = int(start_date_arr[pos])
        span_in_days = int(num_of_weeks_arr[pos]) * 7
        # One occurrence per week, starting on the first day.
        weekly_days[pos] = list(range(first_day, first_day + span_in_days, 7))
    return weekly_days

In [261]:
#Function to get observation (still for single env, need to modify for multi-agent env):

def full_obs(_cap_dem_chosen_req, number_of_actions):
    """Build a zero-padded vector holding the minimum of each input row.

    The result has ``288 + number_of_actions - 1`` integer entries.  The minimum
    of row ``i`` of ``_cap_dem_chosen_req`` is stored at position
    ``i + int((number_of_actions - 1) / 2)``; every other entry stays 0.
    """
    left_pad = int((number_of_actions - 1) / 2)
    padded_min = np.full((288 + number_of_actions - 1, ), 0)
    for row_pos, row in enumerate(_cap_dem_chosen_req):
        padded_min[row_pos + left_pad] = min(row)
    return padded_min

In [262]:
#Function to get the one-hot-encoded vectors for departure and arrival airports:

def one_hot_encode_airport(airport, num_airports):
    """Return a float vector of length ``num_airports`` with a 1 at ``airport``.

    All other positions are 0.
    """
    one_hot = np.zeros(num_airports)
    one_hot[airport] = 1
    return one_hot

# Example usage: encode each of the three airport indices and show the vectors.
num_airports = 3
airport1, airport2, airport3 = 0, 1, 2

encoded_airport1, encoded_airport2, encoded_airport3 = (
    one_hot_encode_airport(airport_code, num_airports)
    for airport_code in (airport1, airport2, airport3)
)

# One vector per line, same as three separate print calls.
print(encoded_airport1, encoded_airport2, encoded_airport3, sep='\n')

[1. 0. 0.]
[0. 1. 0.]
[0. 0. 1.]


In [263]:
#Generate full info for the arrival sides:

def generate_info_arv(requests):
    """Derive the arrival-side slot, start day and day list of every request.

    For each row the arrival slot is the departure slot (column 1) plus the
    flying time (column 7) divided by 5.  When that exceeds 287 the arrival is
    moved to the following day: 287 is subtracted from the slot and the start
    day (column 2) is increased by one.

    NOTE(review): the roll-over subtracts 287, although slots appear to be
    indexed 0..287 (288 per day), which would suggest subtracting 288.
    ``check_correspond_time_slot`` uses the same 287 convention - confirm the
    intended behaviour before changing either.

    Returns ``(ts_arv, start_date_arv, date_seq_arv)``, three 1-D object arrays
    with one entry per request.
    """
    n_requests = len(requests)
    ts_arv = np.empty(shape=(n_requests,), dtype='object')
    start_date_arv = np.empty(shape=(n_requests,), dtype='object')
    for pos in range(n_requests):
        req = requests[pos]
        raw_arrival_slot = req[1] + req[7]/5
        lands_next_day = raw_arrival_slot > 287
        ts_arv[pos] = raw_arrival_slot - 287 if lands_next_day else raw_arrival_slot
        start_date_arv[pos] = req[2] + 1 if lands_next_day else req[2]
    # Arrival days repeat weekly for the same number of weeks as the departure.
    date_seq_arv = get_date_seq(start_date_arv, requests[:, 3])
    return ts_arv, start_date_arv, date_seq_arv

In [264]:
#Generate the multi-agent scenario:

#=======================================================

#Modify the distribution based on historical data later:
def generate_scenario(number_of_requests, num_airports, cap_dict):
    """Generate a random multi-airport slot-request scenario.

    Parameters
    ----------
    number_of_requests : int
        Number of request series to generate.
    num_airports : int
        Number of airports (at least 2, so origin and destination can differ).
    cap_dict : dict
        Per-airport capacity tables keyed ``'req_<i>'`` (see
        ``generate_deterministic_capacity_dict``).

    Returns
    -------
    (requests_full, airport_req_dict, _belong_airport_dict, pot_dem_dict, cap_dem_dict)

    ``requests_full`` is an object array with one row per request and columns:
    0 id, 1 departure slot, 2 departure start day, 3 number of weeks,
    4 departure day list, 5 origin one-hot, 6 destination one-hot,
    7 flying time, 8 departure status, 9 arrival slot, 10 arrival start day,
    11 arrival day list, 12 arrival status.

    Raises
    ------
    ValueError
        If ``num_airports`` is smaller than 2.

    Notes
    -----
    The distributions are placeholders, to be replaced by historical data later.
    """
    if num_airports < 2:
        # With one airport no distinct destination exists and the destination
        # search below would never terminate.
        raise ValueError("num_airports must be at least 2, got {}".format(num_airports))

    # Departure slots come from two peaks (around slots 72 and 216).  Floor
    # division plus the remainder keeps the total equal to number_of_requests
    # for odd counts too; rounding both halves independently does not.
    n_first_peak = int(number_of_requests) // 2
    n_second_peak = int(number_of_requests) - n_first_peak

    ts_72 = np.round(get_truncated_normal(mean=72, sd=12, low=0, upp=287).rvs(n_first_peak))
    ts_216 = np.round(get_truncated_normal(mean=216, sd=12, low=0, upp=287).rvs(n_second_peak))
    ts_dep = np.concatenate((ts_72, ts_216)).astype(int)

    # Start day: the period is 182 days and every series spans at least
    # 5 weeks (35 days), hence the upper bound of 146.
    start_date_dep = np.random.randint(low=0, high=146, size=number_of_requests)

    # Number of weeks: between 5 and the number of whole weeks left in the period.
    _remaining_days = (182 - 1) - start_date_dep
    _max_num_of_weeks = _remaining_days // 7
    num_of_weeks = np.random.randint(5, _max_num_of_weeks + 1)

    # Request ids.
    index = np.arange(number_of_requests)

    # Origin and destination, stored as one-hot lists.  The destination is a
    # reshuffle of the origin vector, repeated until it differs from the origin.
    origin_airport = np.empty(shape=(number_of_requests,), dtype='object')
    destination_airport = np.empty(shape=(number_of_requests,), dtype='object')
    for i in range(number_of_requests):
        _org_airport = one_hot_encode_airport(random.randint(0, num_airports - 1), num_airports)
        origin_airport[i] = _org_airport.tolist()
        _dest_airport = _org_airport.copy()
        while np.array_equal(_dest_airport, _org_airport):
            np.random.shuffle(_dest_airport)
        destination_airport[i] = _dest_airport.tolist()

    # Flying time: fixed at 120 between airports 0 and 1, arbitrary
    # (60, 120 or 180) for every other pair.
    fly_time = np.empty(shape=(number_of_requests,), dtype='object')
    for i in range(number_of_requests):
        _org_idx = origin_airport[i].index(1.0)
        _dest_idx = destination_airport[i].index(1.0)
        if {_org_idx, _dest_idx} == {0, 1}:
            fly_time[i] = 120
        else:
            fly_time[i] = random.choice([60, 120, 180])

    # Departure day list of every request (each entry is a Python list).
    date_seq_dep = get_date_seq(start_date_dep, num_of_weeks)

    # Capacity-status flags start at 0 for both sides.
    status_cap_dep = np.full((number_of_requests,), 0)
    status_cap_arv = np.full((number_of_requests,), 0)

    requests = np.stack((index, ts_dep, start_date_dep, num_of_weeks, date_seq_dep, origin_airport, destination_airport, fly_time, status_cap_dep), axis=1)

    # Arrival-side slot, start day and day list.
    ts_arv, start_date_arv, date_seq_arv = generate_info_arv(requests)

    # Assemble requests_full column by column into an object array, so that the
    # list-valued columns are stored as single cells.
    data = [index, ts_dep, start_date_dep, num_of_weeks, date_seq_dep, origin_airport, destination_airport, fly_time, status_cap_dep, ts_arv, start_date_arv, date_seq_arv, status_cap_arv]
    requests_full = np.empty((len(index), len(data)), dtype=object)
    for i, column_data in enumerate(data):
        requests_full[:, i] = column_data

    # airport_req_dict maps 'req_<i>' to the rows of the requests that depart
    # from or arrive at airport i.
    airport_req_dict, _belong_airport_dict = get_airport_req_dict(requests_full, num_airports)

    pot_dem_dict = get_initial_pot_dem_per_airport(airport_req_dict, num_airports)

    cap_dem_dict = get_cap_dem_dict(num_airports, cap_dict, pot_dem_dict)

    return requests_full, airport_req_dict, _belong_airport_dict, pot_dem_dict, cap_dem_dict

    #Generate capacity:

    #cap_arr = np.full((288, 182), 20)

    #Create final_sched:

    #final_sched_arr = req_arr.copy()

    #Get potential demand: #Check again the function here (turn to array)
    #pot_dem_arr = get_initial_pot_dem()

    #Get remaining cap:
    #cap_dem_arr = cap_arr - pot_dem_arr

In [265]:
#  identify which requests from a set of airports violate a specific condition
# For each airport (from 0 to num_airports - 1), the function checks a condition based on the status_cap_dep and status_cap_arv 
# of the associated numpy array. If the sum of these two columns is greater than or equal to 1 (mask), it means that capacity is exceeded.
# For every airport that has violations, the function extracts the IDs of these violating requests.
# The function accumulates these IDs in the violate_set list.

def get_violate_id_set(airport_req_dict, num_airports):
    """Return the sorted, de-duplicated ids of requests flagged in any airport table.

    A row counts as violating when its departure status (column 8) plus its
    arrival status (column 12) is at least 1.  The id is read from column 0.
    """
    flagged_per_airport = []
    for airport in range(num_airports):
        table = airport_req_dict['req_{}'.format(airport)]
        is_flagged = (table[:, 8] + table[:, 12]) >= 1
        flagged_per_airport.append(table[is_flagged, 0])
    return np.unique(np.concatenate(flagged_per_airport, axis=0))

In [266]:
def get_violate_id_set_req_full(requests_full):
    """Return the ids (column 0) of the rows of ``requests_full`` that are flagged.

    A row is flagged when column 8 plus column 12 is at least 1.  The ids keep
    the row order of ``requests_full``.
    """
    status_total = requests_full[:, 8] + requests_full[:, 12]
    return requests_full[status_total >= 1, 0]

In [267]:
def get_req(violate_set, requests_full):
    """Pick one id at random from ``violate_set`` and return its row(s).

    The result is the 2-D slice of ``requests_full`` whose first column equals
    the picked id.  An empty ``violate_set`` is not handled here.
    """
    picked_id = random.choice(violate_set)
    id_matches = requests_full[:, 0] == picked_id
    return requests_full[id_matches]

In [268]:
def flatten_cap_dem_dict(cap_dem_dict, num_airports):
    """Return a new dict with each airport's table flattened to 1-D.

    Only the keys ``'req_0'`` .. ``'req_<num_airports - 1>'`` are copied.
    """
    airport_keys = ['req_{}'.format(airport) for airport in range(num_airports)]
    return {key: cap_dem_dict[key].flatten() for key in airport_keys}

In [269]:
#Get separated req per airport and store in a dict:

def get_airport_req_dict(requests_full, num_airports):
    """Split the requests into one table per airport.

    For every airport ``k`` the result holds, under the key ``'req_k'``, one row
    for each request that departs from it (origin one-hot in column 5) and one
    row for each request that arrives at it (destination one-hot in column 6).
    Each row is the request row followed by two flags: ``1, 0`` for a
    departure row and ``0, 1`` for an arrival row.

    Returns
    -------
    (airport_req_dict, _belong_airport_dict)
        ``_belong_airport_dict`` maps ``'req_k'`` to the one-hot list of airport k.
        An airport without any request keeps an empty ``(0, 15)`` array.
    """
    airport_keys = ['req_{}'.format(k) for k in range(num_airports)]

    _belong_airport_dict = {}
    for k, key in enumerate(airport_keys):
        one_hot = [0.0] * num_airports
        one_hot[k] = 1.0
        _belong_airport_dict[key] = one_hot

    # Collect the rows per airport and stack them once at the end; growing the
    # arrays with np.vstack inside the loop copies every table on each request.
    rows_per_airport = {key: [] for key in airport_keys}

    for i in range(len(requests_full)):
        _found_dep = 0
        _found_arv = 0
        for key in airport_keys:
            if requests_full[i][5] == _belong_airport_dict[key]:
                rows_per_airport[key].append(np.append(requests_full[i], [1, 0]))
                _found_dep = 1
            if requests_full[i][6] == _belong_airport_dict[key]:
                rows_per_airport[key].append(np.append(requests_full[i], [0, 1]))
                _found_arv = 1
            if _found_dep + _found_arv == 2:
                break
        if _found_dep + _found_arv != 2:
            print("Cannot found both dep and arv at req {}".format(i))

    airport_req_dict = {}
    for key in airport_keys:
        rows = rows_per_airport[key]
        # 15 = 13 request columns + the departure and arrival flags.
        airport_req_dict[key] = np.vstack(rows) if rows else np.empty((0, 15))

    return airport_req_dict, _belong_airport_dict

In [206]:
def generate_deterministic_capacity_dict(num_airports, cap_per_airport_arr): #This function is for a period of 182 days and 288 slots/ day
    """Build a constant capacity table for every airport.

    Airport ``i`` gets a ``(288, 182)`` array filled with
    ``cap_per_airport_arr[i]``, stored under the key ``'req_<i>'``.
    """
    return {
        'req_{}'.format(airport): np.full((288, 182), cap_per_airport_arr[airport])
        for airport in range(num_airports)
    }

In [270]:
def get_initial_pot_dem_per_airport(airport_req_dict, num_airports): #Replace req_df to req_df_update to update pot_dem_df #To be replaced with final_sched
    """Count the requested movements per (slot, day) at every airport.

    Row columns used: 1 departure slot, 9 arrival slot, 4 departure day list,
    11 arrival day list, 13 departure flag, 14 arrival flag.  The flags select
    which slot and which day list apply to the row.

    Returns a dict mapping ``'req_<i>'`` to a ``(288, 182)`` integer table.
    """
    pot_dem_dict = {}
    # TODO: increase speed
    for airport in range(num_airports):
        key = 'req_{}'.format(airport)
        demand = np.full((288, 182), 0)
        pot_dem_dict[key] = demand
        for row in airport_req_dict[key]:
            dep_flag = int(row[13])
            arv_flag = int(row[14])
            # Multiplying by the 0/1 flags keeps only the side this row stands for
            # (a list times 0 is the empty list).
            slot = int(row[1]) * dep_flag + int(row[9]) * arv_flag
            days = row[4] * dep_flag + row[11] * arv_flag
            demand[slot, days] += 1
    return pot_dem_dict

In [271]:
def get_cap_dem_dict(num_airports, cap_dict, pot_dem_dict):
    """Return, per airport, capacity minus potential demand.

    Both inputs are keyed ``'req_<i>'``; negative entries in the result mark
    (slot, day) cells where demand exceeds capacity.
    """
    airport_keys = ['req_{}'.format(airport) for airport in range(num_airports)]
    return {key: cap_dict[key] - pot_dem_dict[key] for key in airport_keys}

In [272]:
def update_status_capacity(airport_req_dict, num_airports, cap_dem_dict, requests_full):
    """Refresh the capacity-status flags of every request, in place.

    For each row of each airport table the slot and day list of the side the
    row stands for (departure if column 13 is 1, otherwise arrival) are looked
    up in the airport's remaining-capacity table.  If every (slot, day) cell is
    non-negative the row's status columns 8 and 12 are cleared; otherwise the
    status column of that side (8 for departure, 12 for arrival) is set to 1.

    The same side's status is mirrored into ``requests_full`` for the request
    with the same id (column 0).  It is now also reset to 0 once capacity is
    sufficient; previously a flag raised in ``requests_full`` was never cleared.

    The status columns themselves are created when the scenario is generated.
    Returns ``None``.
    """
    # Map each request id to its row position(s) in requests_full once, rather
    # than scanning the whole id column for every airport row.
    rows_by_id = {}
    for pos, req_id in enumerate(requests_full[:, 0]):
        rows_by_id.setdefault(req_id, []).append(pos)

    for i in range(num_airports):
        key = 'req_{}'.format(i)
        airport_rows = airport_req_dict[key]
        remaining_cap = cap_dem_dict[key]
        for k in range(len(airport_rows)):
            row = airport_rows[k]  # view: writes below update the airport table
            dep_flag = int(row[13])
            arv_flag = int(row[14])
            _time_slot = int(row[1]) * dep_flag + int(row[9]) * arv_flag
            _date_seq = row[4] * dep_flag + row[11] * arv_flag
            status_col = 8 if row[13] == 1 else 12

            if np.all(remaining_cap[_time_slot, _date_seq] >= 0):
                row[8] = 0
                row[12] = 0
                new_status = 0
            else:
                row[status_col] = 1
                new_status = 1

            for pos in rows_by_id.get(row[0], ()):
                requests_full[pos, status_col] = new_status

In [273]:
class TwoAirportSchedEnv(gym.Env):
    """Gym environment that shifts one capacity-violating slot request in time.

    The constructor builds deterministic per-airport capacity tables, generates
    a random scenario with ``generate_scenario``, flags capacity violations and
    picks one violating request (``self.chosen_req``).  Each ``step`` moves that
    request's departure and arrival by ``action - do_nothing_action`` slots.

    Request row columns used below: 0 id, 1 departure slot, 2 departure start
    day, 3 number of weeks, 4 departure day list, 5 origin one-hot,
    6 destination one-hot, 8 departure status, 9 arrival slot, 10 arrival start
    day, 11 arrival day list, 12 arrival status.  Rows in ``airport_req_dict``
    additionally carry 13 (departure flag) and 14 (arrival flag).

    TODO: ``reset`` and a real observation (``next_observation``) are not
    implemented yet; ``step`` returns an all-zero observation.
    TODO: wire in the scenario checks (check_if_same_org_dest,
    check_correspond_time_slot, check_correspond_date, check_exceed_period, ...).
    """

    SLOTS_PER_DAY = 288  # slot indices run 0 .. 287, matching the (288, 182) tables
    NUM_DAYS = 182       # day indices run 0 .. 181

    def __init__(self, number_of_actions, number_of_requests, num_airports, cap_per_airport_arr):
        """Generate the scenario and choose the first violating request.

        number_of_actions: size of the discrete action space; the middle action
            (``int((number_of_actions - 1) / 2)``) means "do not move".
        number_of_requests, num_airports: forwarded to ``generate_scenario``.
        cap_per_airport_arr: constant capacity per airport.
        """
        # Zero-argument super() keeps working even if the module-level name
        # TwoAirportSchedEnv is later rebound to something else.
        super().__init__()
        self.number_of_actions = int(number_of_actions)
        self.do_nothing_action = int((self.number_of_actions - 1)/2)
        self.number_of_requests = number_of_requests
        self.num_airports = num_airports
        self.cap_per_airport_arr = cap_per_airport_arr
        self.cap_dict = generate_deterministic_capacity_dict(self.num_airports, self.cap_per_airport_arr)
        (self.requests_full,
         self.airport_req_dict,
         self._belong_airport_dict,
         self.pot_dem_dict,
         self.cap_dem_dict) = generate_scenario(number_of_requests, num_airports, cap_dict=self.cap_dict)

        self.action_space = spaces.Discrete(self.number_of_actions)

        # Flag the requests whose slot/day cells exceed capacity.
        update_status_capacity(self.airport_req_dict, self.num_airports, self.cap_dem_dict, self.requests_full)
        self.num_step = 0

        self.get_req()
        self.dep_time_slot = self.chosen_req[1]

    def get_req(self):
        """Choose at random the violating request to be moved next.

        Stores a copy of its ``requests_full`` row in ``self.chosen_req``.
        Raises IndexError when no request violates capacity.
        """
        violate_set = get_violate_id_set(self.airport_req_dict, self.num_airports)
        if len(violate_set) == 0:
            raise IndexError("No capacity-violating request is left to choose from.")
        _violate_index = random.choice(violate_set)
        self.chosen_req = self.requests_full[self.requests_full[:, 0] == _violate_index][0]

    def check_outbound(self, action):
        """Return True if ``action`` moves the departure or arrival slot outside 0..287."""
        change_slot = action - self.do_nothing_action
        last_slot = self.SLOTS_PER_DAY - 1
        new_dep_time_slot = self.chosen_req[1] + change_slot
        new_arv_time_slot = self.chosen_req[9] + change_slot
        return bool(
            new_dep_time_slot < 0 or new_dep_time_slot > last_slot
            or new_arv_time_slot < 0 or new_arv_time_slot > last_slot
        )

    def dep_or_arv(self):
        """Placeholder; not implemented yet."""
        pass

    def update_dem(self, dep_airport, arv_airport, time_slot_dep, time_slot_arv, new_time_slot_dep, new_time_slot_arv, date_seq_dep, date_seq_arv, new_date_seq_dep=None, new_date_seq_arv=None):
        """Move one request's demand from its old cells to its new cells.

        ``date_seq_dep`` / ``date_seq_arv`` are the days of the old position.
        ``new_date_seq_dep`` / ``new_date_seq_arv`` are the days of the new
        position and default to the old days (the previous behaviour).
        """
        if new_date_seq_dep is None:
            new_date_seq_dep = date_seq_dep
        if new_date_seq_arv is None:
            new_date_seq_arv = date_seq_arv
        dep_demand = self.pot_dem_dict['req_{}'.format(dep_airport)]
        arv_demand = self.pot_dem_dict['req_{}'.format(arv_airport)]

        # Add the demand at the new position ...
        dep_demand[int(new_time_slot_dep), new_date_seq_dep] += 1
        arv_demand[int(new_time_slot_arv), new_date_seq_arv] += 1

        # ... and remove it from the old position.
        dep_demand[int(time_slot_dep), date_seq_dep] -= 1
        arv_demand[int(time_slot_arv), date_seq_arv] -= 1

    def update_cap_dem(self):
        """Recompute remaining capacity (capacity minus demand) for every airport."""
        for airport in range(self.num_airports):
            key = 'req_{}'.format(airport)
            if key not in self.cap_dict:
                raise KeyError(f"'{key}' not found in cap_dict. Available keys: {list(self.cap_dict.keys())}")
            self.cap_dem_dict[key] = self.cap_dict[key] - self.pot_dem_dict[key]

    @staticmethod
    def update_violate_set(curr_violate, not_violate_update, violate_update):
        """Return ``curr_violate`` without ``not_violate_update`` and with ``violate_update``.

        The result is a list without duplicates.  Raises KeyError if an id in
        ``not_violate_update`` is not currently in the set.
        """
        curr_violate = set(curr_violate)
        for req_index in not_violate_update:
            curr_violate.remove(req_index)
        for req_index in violate_update:
            curr_violate.add(req_index)
        return list(curr_violate)

    @staticmethod
    def _weekly_dates(start_date, num_weeks):
        """Days of a weekly series: start, start + 7, ... (``num_weeks`` entries)."""
        first_day = int(start_date)
        return list(range(first_day, first_day + int(num_weeks) * 7, 7))

    def _dates_in_horizon(self, dates):
        """True if every day in ``dates`` lies inside 0 .. NUM_DAYS - 1."""
        return all(0 <= day < self.NUM_DAYS for day in dates)

    def _airport_rows(self, airport, req_id, flag_col):
        """Return the airport table and the positions of ``req_id``'s rows in it.

        Only rows whose ``flag_col`` (13 departure, 14 arrival) equals 1 match.
        """
        table = self.airport_req_dict['req_{}'.format(airport)]
        matches = (table[:, 0] == req_id) & (table[:, flag_col] == 1)
        return table, np.flatnonzero(matches)

    def step(self, action):
        """Shift the chosen request by ``action - do_nothing_action`` slots.

        A shift past either end of the day wraps to the neighbouring day and
        is penalised as "outbound".  If the wrapped series would leave the
        182-day period the request stays where it is (still penalised).

        Returns ``(obs, total_reward, done, info)``; ``obs`` is currently all zeros.
        """
        _index = self.chosen_req[0]
        # Integer indexing addresses the stored row itself; indexing with a
        # boolean mask returns a copy, so assigning through it changes nothing.
        req_row = int(np.where(self.requests_full[:, 0] == _index)[0][0])

        time_slot_dep = int(self.chosen_req[1])
        time_slot_arv = int(self.chosen_req[9])
        change_slot = int(action) - self.do_nothing_action
        _dep_airport = self.chosen_req[5].index(1.0)
        _arv_airport = self.chosen_req[6].index(1.0)
        _date_seq_dep = self.chosen_req[4]
        _date_seq_arv = self.chosen_req[11]
        _start_date_dep = int(self.chosen_req[2])
        _start_date_arv = int(self.chosen_req[10])
        _num_weeks = int(self.chosen_req[3])

        outbound = self.check_outbound(action)

        # divmod gives the day shift (-1, 0 or +1 for one-slot moves) and the
        # wrapped slot; departure and arrival are handled independently.
        day_shift_dep, new_time_slot_dep = divmod(time_slot_dep + change_slot, self.SLOTS_PER_DAY)
        day_shift_arv, new_time_slot_arv = divmod(time_slot_arv + change_slot, self.SLOTS_PER_DAY)

        new_start_date_dep = _start_date_dep + day_shift_dep
        new_start_date_arv = _start_date_arv + day_shift_arv
        new_date_seq_dep = _date_seq_dep
        new_date_seq_arv = _date_seq_arv
        if day_shift_dep != 0:
            new_date_seq_dep = self._weekly_dates(new_start_date_dep, _num_weeks)
        if day_shift_arv != 0:
            new_date_seq_arv = self._weekly_dates(new_start_date_arv, _num_weeks)

        shifted_out_of_period = (
            (day_shift_dep != 0 and not self._dates_in_horizon(new_date_seq_dep))
            or (day_shift_arv != 0 and not self._dates_in_horizon(new_date_seq_arv))
        )
        if shifted_out_of_period:
            # Negative or too-large day indices would silently wrap or fail in
            # the demand tables, so the request is left in place.
            new_time_slot_dep, new_time_slot_arv = time_slot_dep, time_slot_arv
            new_start_date_dep, new_start_date_arv = _start_date_dep, _start_date_arv
            new_date_seq_dep, new_date_seq_arv = _date_seq_dep, _date_seq_arv

        # Column -> new value, applied to requests_full and to both airport rows.
        new_values = (
            (1, new_time_slot_dep),
            (2, new_start_date_dep),
            (4, new_date_seq_dep),
            (9, new_time_slot_arv),
            (10, new_start_date_arv),
            (11, new_date_seq_arv),
        )
        for col, value in new_values:
            self.requests_full[req_row, col] = value
        for airport, flag_col in ((_dep_airport, 13), (_arv_airport, 14)):
            table, positions = self._airport_rows(airport, _index, flag_col)
            for pos in positions:
                for col, value in new_values:
                    table[pos, col] = value

        # Demand leaves the old cells (old days) and enters the new cells (new days).
        self.update_dem(_dep_airport, _arv_airport, time_slot_dep, time_slot_arv, new_time_slot_dep, new_time_slot_arv, _date_seq_dep, _date_seq_arv, new_date_seq_dep, new_date_seq_arv)

        self.update_cap_dem()

        # Re-evaluate the status flags against the updated remaining capacity.
        update_status_capacity(self.airport_req_dict, self.num_airports, self.cap_dem_dict, self.requests_full)

        # Take the moved request's status from its airport rows, then refresh
        # the snapshot so the termination test below sees the current state.
        for airport, flag_col, status_col in ((_dep_airport, 13, 8), (_arv_airport, 14, 12)):
            table, positions = self._airport_rows(airport, _index, flag_col)
            if len(positions):
                self.requests_full[req_row, status_col] = int(any(table[pos, status_col] == 1 for pos in positions))
        self.chosen_req = self.requests_full[req_row].copy()
        self.dep_time_slot = self.chosen_req[1]

        # Reward part:
        if outbound:
            local_reward = -1
        else:
            local_reward = 0.1*(-abs(change_slot)*0.5*_num_weeks) #TODO change if increase number of actions

        self.num_step += 1
        done = False
        global_reward = 0
        obs = np.zeros((self.number_of_actions + 1,))

        if self.chosen_req[8] == 0 and self.chosen_req[12] == 0:
            # The chosen request no longer violates capacity on either side.
            done = True
            global_reward = 100
        elif self.num_step == self.number_of_requests*5:
            # Step budget exhausted: 500 plus the (negative) total overflow, times 10.
            done = True
            negative_sum = 500
            for value in self.cap_dem_dict.values():
                if isinstance(value, np.ndarray):
                    negative_sum += np.sum(value[value < 0])
                elif value < 0:
                    negative_sum += value
            global_reward = negative_sum*10

        reward_time_step = -0.5

        total_reward = float(local_reward + global_reward + reward_time_step)

        return obs, total_reward, done, {}

    def take_action(self):
        """Sample a random action on which two independent samplers agree.

        Usage: ``env = TwoAirportSchedEnv(...); env.take_action()``.
        Two actions (one per airport side) are sampled repeatedly until they
        are equal; that common action is returned.
        """
        action_dep = self.action_space.sample()  # For the departure airport
        action_arv = self.action_space.sample()  # For the arrival airport

        while action_dep != action_arv:
            action_dep = self.action_space.sample()
            action_arv = self.action_space.sample()

        return action_dep

    # Implement next_observation

    # Implement reset


In [212]:
#Not done
def next_observation(self):
    # Unfinished draft of an observation builder, written as a free function
    # taking `self`. It has no return statement, so it currently returns None.
    # NOTE(review): as written it cannot run - `_date_seq`,
    # `_cap_dem_chosen_req` and `_time_slot` are never assigned in this
    # function, and `self.cap_dem_arr` / `self.cap_dem` are not attributes set
    # by TwoAirportSchedEnv.__init__ (which sets `cap_dem_dict`).
    # NOTE(review): the `[:, k]` indexing assumes a 2-D `chosen_req`, whereas
    # TwoAirportSchedEnv.get_req stores a single 1-D row - confirm.
    _chosen_req_info = self.chosen_req
    
    # Departure-side fields of the chosen request.
    _dep_time_slot = self.chosen_req[:,1]
    _num_of_weeks = self.chosen_req[:,3]
    _dep_date_seq = self.chosen_req[:,4]
    _dep_belong = self.chosen_req[:,5][0]
    
    # Arrival-side fields of the chosen request.
    _arv_time_slot = self.chosen_req[:,9]
    _arv_date_seq = self.chosen_req[:,11]
    _arv_belong = self.chosen_req[:,6][0]

    _cap_dem_dep = self.cap_dem_arr[:, _date_seq].copy()
    _cap_dem_arv = self.cap_dem
    #print(self.number_of_actions, type(self.number_of_actions))
    # Padded per-slot minima, then the window of `number_of_actions` slots
    # starting at `_time_slot`.
    _obs_min_arr = full_obs(_cap_dem_chosen_req, self.number_of_actions)
    _obs_time_slot_related = list(range(_time_slot, _time_slot + self.number_of_actions, 1))
    _cap_dem_obs = _obs_min_arr[_obs_time_slot_related]    

In [274]:
# Build the environment: 3 actions, 15000 requests, 3 airports with capacity 9 each.
# NOTE(review): this rebinds the name TwoAirportSchedEnv from the class to the
# instance, so re-running this cell would call the instance instead of the
# class; the class cell has to be run again first.
TwoAirportSchedEnv = TwoAirportSchedEnv(number_of_actions=3, number_of_requests=15000, num_airports=3, cap_per_airport_arr= [9,9,9])

In [275]:
# Pull the generated scenario tables off the environment for inspection.
# These are references to the environment's own objects, not copies.
requests_full = TwoAirportSchedEnv.requests_full
airport_req_dict = TwoAirportSchedEnv.airport_req_dict
_belong_airport_dict = TwoAirportSchedEnv._belong_airport_dict
pot_dem_dict = TwoAirportSchedEnv.pot_dem_dict

In [276]:
# Per-airport capacity tables built in the environment constructor.
cap_dict = TwoAirportSchedEnv.cap_dict

In [277]:
# Recompute the remaining capacity (capacity minus potential demand) for the
# 3 airports and display it.
cap_dem_dict = get_cap_dem_dict(3, cap_dict, pot_dem_dict)
cap_dem_dict

{'req_0': array([[9, 9, 9, ..., 9, 9, 9],
        [9, 9, 9, ..., 9, 9, 9],
        [9, 9, 9, ..., 9, 9, 9],
        ...,
        [9, 9, 9, ..., 9, 9, 9],
        [9, 9, 9, ..., 9, 9, 9],
        [9, 9, 9, ..., 9, 9, 9]]),
 'req_1': array([[9, 9, 9, ..., 9, 9, 9],
        [9, 9, 9, ..., 9, 9, 9],
        [9, 9, 9, ..., 9, 9, 9],
        ...,
        [9, 9, 9, ..., 9, 9, 9],
        [9, 9, 9, ..., 9, 9, 9],
        [9, 9, 9, ..., 9, 9, 9]]),
 'req_2': array([[9, 9, 9, ..., 9, 9, 9],
        [9, 9, 9, ..., 9, 9, 9],
        [9, 9, 9, ..., 9, 9, 9],
        ...,
        [9, 9, 9, ..., 9, 9, 9],
        [9, 9, 9, ..., 9, 9, 9],
        [9, 9, 9, ..., 9, 9, 9]])}

In [278]:
_belong_airport_dict

{'req_0': [1.0, 0.0, 0.0], 'req_1': [0.0, 1.0, 0.0], 'req_2': [0.0, 0.0, 1.0]}

In [279]:
# _dep_belong = chosen_req[:,5][0]
# _dep_belong

# Fetch the request currently chosen by the environment.
chosen_req = TwoAirportSchedEnv.chosen_req

# get_req stores a single request row, so index 5 is the origin airport's
# one-hot list.
_dep_belong = chosen_req[5]
print(_dep_belong)

[0.0, 1.0, 0.0]


In [280]:
def getKey(dct,value):
    """Return every key of ``dct`` whose value equals ``value``, in dict order."""
    return [key for key, stored in dct.items() if stored == value]

In [281]:
# Look up which airport key(s) carry the origin one-hot vector of the chosen request.
dep_key = getKey(_belong_airport_dict, _dep_belong)
dep_key

['req_1']

In [282]:
# _obs_time_slot_related = list(range(_dep_time_slot[0] - 2, _dep_time_slot[0] + 2, 1)) #5 is number of actions
# _obs_time_slot_related

# The environment instance is bound to the name TwoAirportSchedEnv in this notebook.
# Departure slot of the chosen request, as stored by the constructor.
_dep_time_slot = TwoAirportSchedEnv.dep_time_slot

# Slots around the departure slot.
# NOTE(review): range() excludes its end, so this yields the 4 slots
# [slot - 2, slot + 1]; slot + 2 is not included - confirm this is intended.
_obs_time_slot_related = list(range(_dep_time_slot - 2, _dep_time_slot + 2, 1))
print(_obs_time_slot_related)

[214, 215, 216, 217]


In [283]:
# Column 4 of a request row is its departure day list.
_dep_date_seq = chosen_req[4]
_dep_date_seq

[59, 66, 73, 80, 87, 94, 101]

In [284]:
# Haven't modified this yet
def full_obs(_cap_dem_chosen_req, number_of_actions):
    """Build a zero-padded vector holding the minimum of each input row.

    Same behaviour as the ``full_obs`` defined earlier in this notebook; this
    later cell rebinds the name.  The result has
    ``288 + number_of_actions - 1`` integer entries, with the minimum of row
    ``i`` stored at position ``i + int((number_of_actions - 1) / 2)``.
    """
    left_pad = int((number_of_actions - 1) / 2)
    padded_min = np.full((288 + number_of_actions - 1, ), 0)
    for row_pos, row in enumerate(_cap_dem_chosen_req):
        padded_min[row_pos + left_pad] = min(row)
    return padded_min

In [285]:
# Violating request ids derived from the per-airport tables (3 airports).
violate_id = get_violate_id_set(airport_req_dict, 3)

In [286]:
# Violating request ids derived directly from requests_full, for comparison.
violate_req = get_violate_id_set_req_full(requests_full)

In [287]:
# Number of violating ids found via the per-airport tables.
len(violate_id)

8715

In [288]:
# Number of violating ids found via requests_full; should match the count above.
len(violate_req)

8715

In [289]:
# Display the ids obtained from the per-airport tables.
violate_id

array([0, 1, 3, ..., 14994, 14997, 14998], dtype=object)

In [290]:
# Display the ids obtained from requests_full.
violate_req

array([0, 1, 3, ..., 14994, 14997, 14998], dtype=object)

In [291]:
# Check that the airport request dict was generated correctly: count the
# distinct request ids across the three airport tables as the sum of the
# table sizes minus the pairwise overlaps (no triple-overlap term is added
# back). The result should equal the number of requests.
a = TwoAirportSchedEnv.airport_req_dict['req_0'][:,0]
b = TwoAirportSchedEnv.airport_req_dict['req_1'][:,0]
c = TwoAirportSchedEnv.airport_req_dict['req_2'][:,0]
ab = np.intersect1d(a,b)
bc = np.intersect1d(b,c)
ca = np.intersect1d(c,a)
len(a) + len(b) + len(c) - len(ab) - len(bc) - len(ca)

15000

In [292]:
def get_req(violate_set, requests_full): 
    """Pick one id at random from ``violate_set`` and return its row(s).

    Same behaviour as the ``get_req`` defined earlier in this notebook; this
    later cell rebinds the name.  The result is the 2-D slice of
    ``requests_full`` whose first column equals the picked id.
    """
    picked_id = random.choice(violate_set)
    id_matches = requests_full[:, 0] == picked_id
    return requests_full[id_matches]

In [293]:
def check_random_n_violate_req(n, violate_set, requests_full): #violate_set <- violate_id or violate_req
    """Spot-check ``n`` randomly drawn violating requests.

    Each drawn request is reported when its departure status (column 8) plus
    its arrival status (column 12) is 0, i.e. when it is listed as violating
    although neither flag is set.  A summary line is printed at the end.
    """
    for _ in range(n):
        sampled_req = get_req(violate_set, requests_full)
        status_total = sampled_req[:,8] + sampled_req[:,12]
        if status_total == 0:
            print('There is error at req {}'.format(sampled_req))
    print('Check {} violated reqs done!'.format(n))

In [294]:
# Violating ids taken from the environment's own airport tables.
violate_set = get_violate_id_set(TwoAirportSchedEnv.airport_req_dict, TwoAirportSchedEnv.num_airports)

In [295]:
# Spot-check 100 randomly drawn violating requests.
check_random_n_violate_req(100, violate_set, requests_full)

Check 100 violated reqs done!


In [None]:
def check_violate_set(violate_set, airport_req_dict, ): #TODO
    """Placeholder for a consistency check of ``violate_set``; does nothing yet."""

In [None]:
# NOTE(review): no `check_request` function is defined in the visible part of
# this notebook (only a commented-out method stub inside the class), and this
# cell has no execution count - confirm it is defined elsewhere before running.
check_request(TwoAirportSchedEnv.requests_full, TwoAirportSchedEnv.airport_req_dict, TwoAirportSchedEnv._belong_airport_dict, TwoAirportSchedEnv.num_airports, TwoAirportSchedEnv.pot_dem_dict)

In [None]:
def check_union_req_per_airport(airport_req_dict, num_airports, num_requests):
    """Placeholder for checking that the per-airport request sets together cover all requests.

    Not implemented yet: the function currently does nothing and returns None.
    """
    pass
# The equivalent check is written inline in an earlier cell; it has not yet been wrapped in this function.

In [296]:
def check_pot_dem_dict(airport_req_dict, num_airports, pot_dem_dict):
    """Compare, per airport, the total of request column 3 with the total demand.

    Args:
        airport_req_dict: Maps 'req_<i>' to a 2-D request array for airport i.
        num_airports: Number of airports; keys 'req_0' .. 'req_<n-1>' are read.
        pot_dem_dict: Maps 'req_<i>' to a 2-D demand array for airport i.

    Prints one status line per airport; returns None.
    """
    for airport_idx in range(num_airports):
        key = 'req_{}'.format(airport_idx)
        requested_total = sum(airport_req_dict[key][:,3])
        # The inner sum collapses the rows, the outer sum collapses what is left.
        demand_total = sum(sum(pot_dem_dict[key]))
        if requested_total != demand_total:
            print('req {} has problem!'. format(airport_idx))
            continue
        print('req {} checked. No issue!'.format(airport_idx))

In [300]:
#Functions to check the generated scenarios:

def check_exceed_period(requests_full, period_length=182):
    """Report requests whose date sequences reach past the scheduling period.

    Args:
        requests_full: Sequence of request rows. Index 4 holds the departure
            date sequence and index 11 the arrival date sequence.
        period_length: First date index that is outside the period; any date
            ``>= period_length`` is reported. Defaults to 182, the value that
            was previously hard-coded (presumably the season length in days).

    Prints one line per offending sequence and a final completion line;
    returns None. Empty date sequences are skipped instead of crashing
    ``max()``.
    """
    for i in range(len(requests_full)):
        dep_dates = requests_full[i][4]
        arv_dates = requests_full[i][11]
        # len() guard: max() raises ValueError on an empty sequence.
        if len(dep_dates) > 0 and max(dep_dates) >= period_length:
            print('Exceed period departure at request {}'.format(i))
        if len(arv_dates) > 0 and max(arv_dates) >= period_length:
            print('Exceed period arrival at request {}'.format(i))
    print('Check exceed period done!')

    
def check_if_same_org_dest(requests):
    """Report every request whose index-5 and index-6 fields are equal.

    Args:
        requests: Sequence of request rows; index 5 and 6 are compared
            (origin and destination airport, going by the function name).

    Prints one line per offending row index and a final completion line.
    """
    for idx, req in enumerate(requests):
        origin, destination = req[5], req[6]
        if origin == destination:
            print('Problem at {}'.format(idx))
    print('Check if any same origin and destination done!')
    
    
def check_correspond_time_slot(requests):
    """Verify that each row's index-9 value follows from index 1 and index 7.

    The expected value is ``row[1] + row[7]/5``; when that exceeds 287 the
    expected value is that sum minus 287.

    Args:
        requests: Sequence of request rows (index 1, 7 and 9 are read).

    Prints one line per mismatching row index and a final completion line.
    """
    for idx, req in enumerate(requests):
        raw_arrival = req[1] + req[7]/5
        if raw_arrival > 287:
            # Past the last slot of the day: wrap using the file's convention.
            expected_slot = raw_arrival - 287
        else:
            expected_slot = raw_arrival
        if req[9] != expected_slot:
            print("Not corresponding time slot at request {}".format(idx))
    print('Check corresponding time slot done!')
    
    
def check_correspond_date(requests):
    """Verify that each row's index-10 value follows from its index-2 value.

    When ``row[1] + row[7]/5`` exceeds 287 the expected index-10 value is
    ``row[2] + 1``; otherwise it is ``row[2]`` unchanged.

    Args:
        requests: Sequence of request rows (index 1, 2, 7 and 10 are read).

    Prints one line per mismatching row index and a final completion line.
    """
    for idx, req in enumerate(requests):
        rolls_over = req[1] + req[7]/5 > 287
        expected_start = req[2] + 1 if rolls_over else req[2]
        if req[10] != expected_start:
            print("Not corresponding start date at request {}".format(idx))
    print("Check start date done!")
    
    
def check_date_seq(requests):
    """Verify that each row's index-11 sequence follows from its index-4 sequence.

    When ``row[1] + row[7]/5`` exceeds 287 the expected index-11 sequence is
    the index-4 sequence with every entry increased by one; otherwise the two
    sequences must be equal.

    Args:
        requests: Sequence of request rows (index 1, 4, 7 and 11 are read).

    Prints one line per mismatching row index and a final completion line.
    """
    for idx, req in enumerate(requests):
        if req[1] + req[7]/5 > 287:
            expected_seq = [day + 1 for day in req[4]]
        else:
            expected_seq = req[4]
        if req[11] != expected_seq:
            print('Not corresponding date seq at request {}'.format(idx))
    print('Check date seq done!')
    
    
def check_duplication_each_req_airport(airport_req_dict, num_airports):
    """Report airports whose request array repeats a value in column 0.

    Args:
        airport_req_dict: Maps 'req_<i>' to a 2-D request array for airport i.
        num_airports: Number of airports; keys 'req_0' .. 'req_<n-1>' are read.

    Prints one line per airport with duplicates and a final completion line.
    """
    for airport_idx in range(num_airports):
        airport_reqs = airport_req_dict['req_{}'.format(airport_idx)]
        distinct_ids = set(airport_reqs[:,0])
        # Fewer distinct column-0 values than rows means some value repeats.
        if len(distinct_ids) != len(airport_reqs):
            print('Duplication at airport {}'.format(airport_idx))
    print('Check duplication per airport req done!')
    
    
def check_correct_dep_arv_binary_values(airport_req_dict, _belong_airport_dict, num_airports):
    """Validate the index-13 / index-14 flags of every per-airport request row.

    For each airport i and each of its rows, with ``home`` being
    ``_belong_airport_dict['req_<i>']``:
      * if index 5 equals ``home``, index 13 must be 1;
      * if index 6 equals ``home``, index 14 must be 1;
      * at least one of index 5 / index 6 must equal ``home``;
      * index 13 + index 14 must equal 1.

    Args:
        airport_req_dict: Maps 'req_<i>' to the request rows of airport i.
        _belong_airport_dict: Maps 'req_<i>' to the airport value those rows
            are compared against.
        num_airports: Number of airports; keys 'req_0' .. 'req_<n-1>' are read.

    Prints one line per violation and a final completion line.
    """
    for airport_idx in range(num_airports):
        key = 'req_{}'.format(airport_idx)
        airport_reqs = airport_req_dict[key]
        for k in range(len(airport_reqs)):
            row = airport_reqs[k]
            home = _belong_airport_dict[key]
            origin, destination = row[5], row[6]
            dep_flag, arv_flag = row[13], row[14]
            if origin == home and dep_flag != 1:
                print('Not correct dep at airport {} and req {}!'.format(airport_idx, k))
            if destination == home and arv_flag != 1:
                print('Not correct arv at airport {} and req {}!'.format(airport_idx, k))
            if origin != home and destination != home:
                print('Req {} not belong to airport {}'.format(k, airport_idx))
                print('====')
                print(origin, destination, home)
            if dep_flag + arv_flag != 1:
                print('Not unique in dep and arv status at airport {} and req {}!'.format(airport_idx, k))
    print('Check dep and arv status done!')
    

def check_pot_dem_dict(airport_req_dict, num_airports, pot_dem_dict):
    """Check that each airport's demand array accounts for all requested slots.

    For airport i, the sum of column 3 of ``airport_req_dict['req_<i>']`` must
    equal the grand total of ``pot_dem_dict['req_<i>']``.

    Args:
        airport_req_dict: Maps 'req_<i>' to a 2-D request array for airport i.
        num_airports: Number of airports; keys 'req_0' .. 'req_<n-1>' are read.
        pot_dem_dict: Maps 'req_<i>' to a 2-D demand array for airport i.

    Prints one status line per airport; returns None.
    """
    # NOTE(review): a function with this same name and logic is also defined in an
    # earlier cell; this later definition is the one in effect once this cell runs.
    for airport_idx in range(num_airports):
        key = 'req_{}'.format(airport_idx)
        totals_match = sum(airport_req_dict[key][:,3]) == sum(sum(pot_dem_dict[key]))
        if totals_match:
            status = 'req {} checked. No issue!'.format(airport_idx)
        else:
            status = 'req {} has problem!'. format(airport_idx)
        print(status)

# Test update_dem
def check_update_dem():
    """Self-test for ``TwoAirportSchedEnv.update_dem``.

    Temporarily replaces ``TwoAirportSchedEnv.pot_dem_dict`` with all-zero
    (288, 7) arrays, marks one departure and one arrival cell, calls
    ``update_dem`` to move both to new time slots, and asserts that the old
    cells are back to 0 and the new cells are 1.

    The attribute is restored (or removed, if it did not exist) afterwards, so
    the mock no longer leaks into later cells that read
    ``TwoAirportSchedEnv.pot_dem_dict``.

    Raises:
        AssertionError: if ``update_dem`` does not move the demand as expected.
    """
    _missing = object()
    saved_pot_dem_dict = getattr(TwoAirportSchedEnv, 'pot_dem_dict', _missing)
    try:
        # Mock the pot_dem_dict to control our test case
        TwoAirportSchedEnv.pot_dem_dict = {
            'req_0': np.zeros((288, 7)),  # assuming 288 time slots and 7 dates
            'req_1': np.zeros((288, 7)),
            'req_2': np.zeros((288, 7))
        }
        dep_airport = 1  # deliberately not airport 0, to exercise the key lookup
        arv_airport = 2
        time_slot_dep = 60  # arbitrary time slots for the test
        time_slot_arv = 120
        new_time_slot_dep = 65
        new_time_slot_arv = 125
        date_seq_dep = 2  # used directly as the column index into the mocked arrays
        date_seq_arv = 5
        dep_key = 'req_{}'.format(dep_airport)
        arv_key = 'req_{}'.format(arv_airport)

        # Increase demand for the initial time slots to mock a pre-existing request
        TwoAirportSchedEnv.pot_dem_dict[dep_key][time_slot_dep, date_seq_dep] = 1
        TwoAirportSchedEnv.pot_dem_dict[arv_key][time_slot_arv, date_seq_arv] = 1

        # Call the function under test
        TwoAirportSchedEnv.update_dem(dep_airport, arv_airport, time_slot_dep, time_slot_arv, new_time_slot_dep, new_time_slot_arv, date_seq_dep, date_seq_arv)

        # Demand at the initial time slots must have been removed ...
        assert TwoAirportSchedEnv.pot_dem_dict[dep_key][time_slot_dep, date_seq_dep] == 0
        assert TwoAirportSchedEnv.pot_dem_dict[arv_key][time_slot_arv, date_seq_arv] == 0
        # ... and added at the new time slots.
        assert TwoAirportSchedEnv.pot_dem_dict[dep_key][new_time_slot_dep, date_seq_dep] == 1
        assert TwoAirportSchedEnv.pot_dem_dict[arv_key][new_time_slot_arv, date_seq_arv] == 1
    finally:
        # Put the real demand data back even when an assertion fails.
        if saved_pot_dem_dict is _missing:
            try:
                del TwoAirportSchedEnv.pot_dem_dict
            except AttributeError:
                pass
        else:
            TwoAirportSchedEnv.pot_dem_dict = saved_pot_dem_dict
    print("Update dem Test passed!")

def check_update_cap_dem():
    """Self-test that exercises ``TwoAirportSchedEnv.check_outbound``.

    NOTE(review): despite the name, nothing here touches an update_cap_dem
    routine; every assertion and the final message are about check_outbound.

    Side effect: overwrites ``TwoAirportSchedEnv.chosen_req[1]`` (first with 1,
    then with 286) and does not restore the previous value.

    Raises:
        AssertionError: if check_outbound returns an unexpected value.
    """
    # Manually set the `chosen_req`
    # (index 1 is presumably the request's current time slot - TODO confirm)
    TwoAirportSchedEnv.chosen_req[1] = 1
    # Call check_outbound and verify
    # Test case where the change_slot does not make it go outbound
    action = TwoAirportSchedEnv.do_nothing_action
    assert TwoAirportSchedEnv.check_outbound(action) == False, "Expected no outbound for do_nothing_action"
    # Test case where the change_slot makes it go outbound on the lower side
    # NOTE(review): the assert message below mentions the 5th time slot, but
    # chosen_req[1] was set to 1 above - confirm which one is intended.
    action = -2
    assert TwoAirportSchedEnv.check_outbound(action) == True, "Expected outbound for a negative shift from the 5th time slot"
    # Test case where the change_slot makes it go outbound on the upper side
    TwoAirportSchedEnv.chosen_req[1] = 286
    action = 3
    assert TwoAirportSchedEnv.check_outbound(action) == True, "Expected outbound for a positive shift from the 286th time slot"
    # Add more cases as needed...
    print("All check_outbound tests passed!")

#==================================================================================
    
    
def check_request(requests_full, airport_req_dict, _belong_airport_dict, num_airports, pot_dem_dict):
    """Run the full suite of scenario checks and environment self-tests.

    Order of execution:
      1. row-level checks on ``requests_full``;
      2. per-airport checks on ``airport_req_dict`` / ``pot_dem_dict``;
      3. the argument-free self-tests ``check_update_dem`` and
         ``check_update_cap_dem``.

    Args:
        requests_full: Sequence of request rows.
        airport_req_dict: Maps 'req_<i>' to the request rows of airport i.
        _belong_airport_dict: Maps 'req_<i>' to the airport those rows belong to.
        num_airports: Number of airports.
        pot_dem_dict: Maps 'req_<i>' to the demand array of airport i.
    """
    # 1. Checks that only need the flat request table.
    request_level_checks = (
        check_if_same_org_dest,
        check_correspond_time_slot,
        check_correspond_date,
        check_date_seq,
        check_exceed_period,
    )
    for run_check in request_level_checks:
        run_check(requests_full)

    # 2. Checks on the per-airport views.
    check_duplication_each_req_airport(airport_req_dict, num_airports)
    check_correct_dep_arv_binary_values(airport_req_dict, _belong_airport_dict, num_airports)
    check_pot_dem_dict(airport_req_dict, num_airports, pot_dem_dict)

    # 3. Unit-style self-tests of the environment helpers.
    for run_self_test in (check_update_dem, check_update_cap_dem):
        run_self_test()

In [301]:
# Generate 100 fresh scenarios and run the whole check suite on each one.
num_airports = 3
for i in range(100):
    requests_full, airport_req_dict, _belong_airport_dict, pot_dem_dict, cap_dem_dict = generate_scenario(
        number_of_requests = 15000,
        num_airports = num_airports,
        cap_dict = cap_dict,
    )
    check_request(requests_full, airport_req_dict, _belong_airport_dict, num_airports, pot_dem_dict)
    print('{}:======================================='.format(i))

Check if any same origin and destination done!
Check corresponding time slot done!
Check start date done!
Check date seq done!
Check exceed period done!
Check duplication per airport req done!
Check dep and arv status done!
req 0 checked. No issue!
req 1 checked. No issue!
req 2 checked. No issue!
Update dem Test passed!
All check_outbound tests passed!
Check if any same origin and destination done!
Check corresponding time slot done!
Check start date done!
Check date seq done!
Check exceed period done!
Check duplication per airport req done!
Check dep and arv status done!
req 0 checked. No issue!
req 1 checked. No issue!
req 2 checked. No issue!
Update dem Test passed!
All check_outbound tests passed!
Check if any same origin and destination done!
Check corresponding time slot done!
Check start date done!
Check date seq done!
Check exceed period done!
Check duplication per airport req done!
Check dep and arv status done!
req 0 checked. No issue!
req 1 checked. No issue!
req 2 checked.

KeyboardInterrupt: 

Tests for new functions

All check_outbound tests passed!


DON'T RUN CELL BELOW

In [None]:
# import matplotlib.pyplot as plt

# def test_simulator():
#     # 1. Initialization
#     num_actions = 10
#     num_requests = 15000
#     num_airports = 3
#     cap_per_airport = [10, 12, 14]  # Example capacities
#     env = TwoAirportSchedEnv(num_actions, num_requests, num_airports, cap_per_airport)

#     # Number of steps you want to test the simulator for
#     num_steps = 100
#     rewards = []

#     for _ in range(num_steps):
#         # 2. Reset the environment
#         # obs = env.reset()
#         # 3. Take steps
#         action = env.take_action()
#         next_obs, reward, done, info = env.step(action)
#         check_request(env.requests_full, env.airport_req_dict, env._belong_airport_dict, num_airports, env.pot_dem_dict)
#         rewards.append(reward)
        
#         if done:
#             break

#     # 4. Visualize Results
#     plt.plot(rewards)
#     plt.xlabel("Steps")
#     plt.ylabel("Reward")
#     plt.title("Reward over Time")
#     plt.show()

# # Call the test function
# test_simulator()