In [None]:
import gym
from gym import spaces

import random
import scipy
import numpy as np
import pandas as pd

import stable_baselines3
from stable_baselines3.sac.policies import MlpPolicy
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3 import DQN

In [None]:
#Function for normal distribution truncation:

from scipy.stats import truncnorm

def get_truncated_normal(mean, sd, low, upp):
    """Return a frozen scipy ``truncnorm`` distribution bounded to [low, upp].

    ``truncnorm`` expects its clipping bounds in standard-deviation units
    relative to ``loc``, so the absolute bounds are standardised first.

    Args:
        mean: Location (``loc``) of the underlying normal distribution.
        sd: Scale (``scale``) of the underlying normal distribution.
        low: Lower bound of the support, in the same units as ``mean``.
        upp: Upper bound of the support, in the same units as ``mean``.

    Returns:
        A frozen ``scipy.stats.truncnorm`` distribution.
    """
    lower_std = (low - mean) / sd
    upper_std = (upp - mean) / sd
    return truncnorm(lower_std, upper_std, loc=mean, scale=sd)

In [None]:
#Function to get date sequence based on start_date and num_of_weeks:
# Works for Single-Agent?

def get_date_seq(start_date_arr, num_of_weeks_arr):
    """Build the weekly date sequence of every request.

    For request ``i`` the sequence starts at ``start_date_arr[i]`` and contains
    ``num_of_weeks_arr[i]`` dates spaced 7 days apart.

    Args:
        start_date_arr: list or np.ndarray of start dates (cast to int).
        num_of_weeks_arr: list or np.ndarray of week counts (cast to int),
            same length as ``start_date_arr``.

    Returns:
        A list with one list of int dates per request.

    Raises:
        AssertionError: if an argument is not a list/ndarray or the lengths differ.
    """
    # Both arguments must be lists or NumPy arrays.
    assert isinstance(start_date_arr, (list, np.ndarray)), f"Expected start_date_arr to be list or array, got {type(start_date_arr)}"
    assert isinstance(num_of_weeks_arr, (list, np.ndarray)), f"Expected num_of_weeks_arr to be list or array, got {type(num_of_weeks_arr)}"

    # One week count is required per start date.
    assert len(start_date_arr) == len(num_of_weeks_arr), f"start_date_arr and num_of_weeks_arr lengths mismatch: {len(start_date_arr)} vs {len(num_of_weeks_arr)}"

    return [
        list(range(int(first_day), int(first_day) + int(n_weeks) * 7, 7))
        for first_day, n_weeks in zip(start_date_arr, num_of_weeks_arr)
    ]



In [None]:
#Function to get observation (still for single env, need to modify for multi-agent env):

def full_obs(_cap_dem_chosen_req, number_of_actions):
    """Return the per-slot minimum remaining capacity, zero-padded at both ends.

    The result has ``288 + number_of_actions - 1`` integer entries. Row ``i``
    of ``_cap_dem_chosen_req`` contributes ``min(row)`` at position
    ``i + int((number_of_actions - 1) / 2)``; all other positions stay 0.

    Args:
        _cap_dem_chosen_req: sequence of rows (one per time slot), each a
            non-empty sequence of remaining-capacity values.
        number_of_actions: size of the action set; determines the padding.

    Returns:
        1-D integer np.ndarray holding the padded row minima.
    """
    pad = int((number_of_actions - 1) / 2)
    padded_min = np.zeros(288 + number_of_actions - 1, dtype=int)
    for slot, slot_values in enumerate(_cap_dem_chosen_req):
        padded_min[slot + pad] = min(slot_values)
    return padded_min

In [None]:
#Function to get the one-hot-encoded vectors for departure and arrival airports:

def one_hot_encode_airport(airport, num_airports):
    """Return a one-hot float vector of length ``num_airports``.

    Args:
        airport: index of the airport to mark with 1.
        num_airports: total number of airports (length of the vector).

    Returns:
        np.ndarray of zeros with a 1 at position ``airport``.
    """
    one_hot = np.zeros(num_airports)
    one_hot[airport] = 1
    return one_hot

# Example usage: encode three airports and show their one-hot vectors.
num_airports = 3
airport1, airport2, airport3 = 0, 1, 2

encoded_airport1 = one_hot_encode_airport(airport1, num_airports)
encoded_airport2 = one_hot_encode_airport(airport2, num_airports)
encoded_airport3 = one_hot_encode_airport(airport3, num_airports)

# One vector per line, in airport order.
print(encoded_airport1, encoded_airport2, encoded_airport3, sep="\n")

In [None]:
#Generate full info for the arrival sides:

def generate_info_arv(requests):
    """Derive arrival-side time slot, start date and date sequence per request.

    For every row the arrival slot is ``row[1] + row[7] / 5``. When that value
    exceeds 287 it is reduced by 287 and the start date (``row[2]``) is moved
    to the next day; otherwise the start date is kept.

    Args:
        requests: 2-D np.ndarray of requests; columns 1, 2, 3 and 7 are read.

    Returns:
        Tuple ``(ts_arv, start_date_arv, date_seq_arv)``: two object arrays of
        length ``len(requests)`` and the list produced by ``get_date_seq`` from
        the arrival start dates and column 3.
    """
    n_requests = len(requests)
    ts_arv = np.empty(shape=(n_requests,), dtype='object')
    start_date_arv = np.empty(shape=(n_requests,), dtype='object')
    for pos, req in enumerate(requests):
        arrival_slot = req[1] + req[7] / 5
        # NOTE(review): the wrap-around subtracts 287 although other code in
        # this file uses 288 slots per day (0..287) - confirm this is intended.
        crosses_midnight = arrival_slot > 287
        ts_arv[pos] = arrival_slot - 287 if crosses_midnight else arrival_slot
        start_date_arv[pos] = req[2] + 1 if crosses_midnight else req[2]
    date_seq_arv = get_date_seq(start_date_arv, requests[:, 3])
    return ts_arv, start_date_arv, date_seq_arv

In [None]:
def get_violate_id_set(airport_req_dict, num_airports):
    """Collect the unique ids of requests flagged in any per-airport table.

    A row is flagged when the sum of its columns 8 and 12 is at least 1.

    Args:
        airport_req_dict: dict mapping ``'req_<i>'`` to a 2-D request array.
        num_airports: number of ``'req_<i>'`` entries to scan.

    Returns:
        Sorted np.ndarray of unique ids (column 0) of the flagged rows.
    """
    ids_per_airport = []
    for airport in range(num_airports):
        airport_rows = airport_req_dict['req_{}'.format(airport)]
        flagged = (airport_rows[:, 8] + airport_rows[:, 12]) >= 1
        ids_per_airport.append(airport_rows[flagged, :][:, 0])
    return np.unique(np.concatenate(ids_per_airport, axis=0))

In [None]:
def get_violate_id_set_req_full(requests_full):
    """Return the ids (column 0) of rows whose columns 8 and 12 sum to >= 1.

    Args:
        requests_full: 2-D np.ndarray of requests.

    Returns:
        np.ndarray with the id of every flagged row, in row order.
    """
    flag_total = requests_full[:, 8] + requests_full[:, 12]
    flagged_rows = requests_full[flag_total >= 1, :]
    return flagged_rows[:, 0]

In [None]:
def get_req(violate_set, requests_full):
    """Pick a random violating request id and return its row(s).

    Args:
        violate_set: list or 1-D np.ndarray of request ids to choose from.
        requests_full: 2-D np.ndarray whose column 0 holds the request id.

    Returns:
        The rows of ``requests_full`` whose id equals the chosen id
        (a 2-D array; one row when ids are unique).

    Raises:
        ValueError: if ``violate_set`` is ``None`` or has no elements.
    """
    # Explicit length check: the truth value of a NumPy array is ambiguous for
    # more than one element and wrongly "empty" for a single id equal to 0.
    if violate_set is None or len(violate_set) == 0:
        raise ValueError("The provided violate_set is empty!")
    # Index with a random position so that any sequence type (list, ndarray)
    # is handled without relying on its truthiness.
    _violate_index = violate_set[random.randrange(len(violate_set))]
    chosen_req = requests_full[requests_full[:, 0] == _violate_index]
    return chosen_req

In [None]:
def flatten_cap_dem_dict(cap_dem_dict, num_airports):
    """Return a new dict with every per-airport array flattened to 1-D.

    Args:
        cap_dem_dict: dict mapping ``'req_<i>'`` to an np.ndarray.
        num_airports: number of ``'req_<i>'`` entries to flatten.

    Returns:
        dict with the same ``'req_<i>'`` keys and flattened copies as values.
    """
    airport_keys = ('req_{}'.format(airport) for airport in range(num_airports))
    return {key: cap_dem_dict[key].flatten() for key in airport_keys}

In [None]:
#Get separated req per airport and store in a dict:

def get_airport_req_dict(requests_full, num_airports):
    """Split the full request table into one request table per airport.

    For every airport ``k`` a one-hot list of length ``num_airports`` is built.
    Each request is compared against these lists: when element 5 of the
    request equals the one-hot list of airport ``k`` the request is appended
    to that airport's table with the two extra values ``(1, 0)`` (the
    ``_dep_req`` case); when element 6 matches it is appended with ``(0, 1)``
    (the ``_arv_req`` case). A message is printed for every request for which
    both matches were not found.

    Args:
        requests_full: indexable collection of requests; elements 5 and 6 of
            each request are compared with ``==`` against a Python list.
            NOTE(review): this presumably requires those elements to be plain
            lists (a NumPy array would give an element-wise result) - confirm
            against the caller.
        num_airports: number of airports, i.e. number of ``'req_<k>'`` keys.

    Returns:
        Tuple ``(airport_req_dict, _belong_airport_dict)``: the per-airport
        request arrays (15 columns) and the one-hot list of every airport,
        both keyed by ``'req_<k>'``.
    """
    airport_req_dict = {}
    _belong_airport_dict = {}
    for i in range(num_airports):
        # Start from an empty 15-column table; 15 = request length + 2 flags.
        airport_req_dict['req_{}'.format(i)] = np.empty((0, 15)) #This one depends on the number of elements of a final request
        # One-hot identifier of airport i, stored as a plain list of floats.
        _belong_airport_dict['req_{}'.format(i)] = np.full(num_airports, 0.0, dtype=float)
        _belong_airport_dict['req_{}'.format(i)][i] = float(1.0)
        _belong_airport_dict['req_{}'.format(i)] = _belong_airport_dict['req_{}'.format(i)].tolist()
        
    for i in range(len(requests_full)):
        # Track whether the departure and the arrival airport were located.
        _found_dep = 0
        _found_arv = 0
        for k in range(num_airports):
            #_found_dep = 0
            #_found_arv = 0
            if requests_full[i][5] == _belong_airport_dict['req_{}'.format(k)]:
                # Element 5 matches airport k: append flags (1, 0).
                _dep_req = np.append(requests_full[i], 1)
                _dep_req = np.append(_dep_req, 0)
                airport_req_dict['req_{}'.format(k)] = np.vstack((airport_req_dict['req_{}'.format(k)], _dep_req))
                _found_dep = 1
                #airport_req_dict['req_{}'.format(k)] = np.append(airport_req_dict['req_{}'.format(k)], 1)
                #airport_req_dict['req_{}'.format(k)] = np.append(airport_req_dict['req_{}'.format(k)], 0)
                #break
            if requests_full[i][6] == _belong_airport_dict['req_{}'.format(k)]:
                # Element 6 matches airport k: append flags (0, 1).
                _arv_req = np.append(requests_full[i], 0)
                _arv_req = np.append(_arv_req, 1)
                airport_req_dict['req_{}'.format(k)] = np.vstack((airport_req_dict['req_{}'.format(k)], _arv_req))
                _found_arv = 1
                #airport_req_dict['req_{}'.format(k)] = np.append(airport_req_dict['req_{}'.format(k)], 0)
                #airport_req_dict['req_{}'.format(k)] = np.append(airport_req_dict['req_{}'.format(k)], 1)
                #break
            # Stop scanning airports once both sides have been placed.
            if _found_dep + _found_arv == 2:
                break
        if _found_dep + _found_arv != 2:
            print("Cannot found both dep and arv at req {}".format(i))
            
    return airport_req_dict, _belong_airport_dict

In [None]:
def generate_deterministic_capacity_dict(num_airports, cap_per_airport_arr):
    """Build a constant capacity table per airport.

    The tables cover a period of 182 days with 288 slots per day.

    Args:
        num_airports: number of airports, i.e. number of ``'req_<i>'`` keys.
        cap_per_airport_arr: indexable with the capacity value of airport ``i``
            at position ``i``.

    Returns:
        dict mapping ``'req_<i>'`` to a ``(288, 182)`` array filled with the
        capacity of airport ``i``.
    """
    return {
        'req_{}'.format(airport): np.full((288, 182), cap_per_airport_arr[airport])
        for airport in range(num_airports)
    }

In [None]:
def get_initial_pot_dem_per_airport(airport_req_dict, num_airports):
    """Count the demand per (time slot, date) cell for every airport.

    Row layout used here: column 13 is the departure flag and column 14 the
    arrival flag; columns 1 / 4 hold the departure time slot / date sequence
    and columns 9 / 11 the arrival time slot / date sequence. The flags select
    which pair is used by multiplication, so the date sequences are expected
    to be Python lists (``list * 1`` keeps the list, ``list * 0`` empties it).

    Args:
        airport_req_dict: dict mapping ``'req_<i>'`` to a 2-D request array.
        num_airports: number of ``'req_<i>'`` entries to process.

    Returns:
        dict mapping ``'req_<i>'`` to a ``(288, 182)`` integer demand array.
    """
    # TODO: increase speed
    pot_dem_dict = {}
    for airport in range(num_airports):
        key = 'req_{}'.format(airport)
        demand = np.full((288, 182), 0)
        for req in airport_req_dict[key]:
            dep_flag = int(req[13])
            arv_flag = int(req[14])
            slot = int(req[1]) * dep_flag + int(req[9]) * arv_flag
            dates = req[4] * dep_flag + req[11] * arv_flag
            demand[slot, dates] += 1
        pot_dem_dict[key] = demand
    return pot_dem_dict

In [None]:
def get_cap_dem_dict(num_airports, cap_dict, pot_dem_dict):
    """Compute the remaining capacity (capacity minus demand) per airport.

    Args:
        num_airports: number of ``'req_<i>'`` entries to process.
        cap_dict: dict mapping ``'req_<i>'`` to the capacity array.
        pot_dem_dict: dict mapping ``'req_<i>'`` to the demand array.

    Returns:
        dict mapping ``'req_<i>'`` to ``cap_dict[key] - pot_dem_dict[key]``.
    """
    airport_keys = ['req_{}'.format(airport) for airport in range(num_airports)]
    return {key: cap_dict[key] - pot_dem_dict[key] for key in airport_keys}

In [None]:
def update_status_capacity(airport_req_dict, num_airports, cap_dem_dict, requests_full):
    """Refresh the over-capacity flags of every per-airport request in place.

    For each row the (time slot, date sequence) pair is selected with the
    departure flag (column 13) and arrival flag (column 14). When every cell
    of the airport's remaining capacity at that slot/dates is >= 0, columns 8
    and 12 of the row are reset to 0. Otherwise column 8 (departure rows) or
    column 12 (all other rows) is set to 1, both in the per-airport row and in
    the rows of ``requests_full`` carrying the same id (column 0).

    Args:
        airport_req_dict: dict mapping ``'req_<i>'`` to a 2-D request array;
            modified in place.
        num_airports: number of ``'req_<i>'`` entries to process.
        cap_dem_dict: dict mapping ``'req_<i>'`` to the remaining-capacity array.
        requests_full: 2-D request array; flags are only ever set to 1 here,
            never cleared.

    Returns:
        None.
    """
    for airport in range(num_airports):
        key = 'req_{}'.format(airport)
        airport_rows = airport_req_dict[key]
        remaining_cap = cap_dem_dict[key]
        for pos in range(len(airport_rows)):
            req = airport_rows[pos]
            dep_flag = int(req[13])
            arv_flag = int(req[14])
            _time_slot = int(req[1]) * dep_flag + int(req[9]) * arv_flag
            _date_seq = req[4] * dep_flag + req[11] * arv_flag
            if all(x >= 0 for x in remaining_cap[_time_slot, _date_seq]):
                # Within capacity on every date: clear both status flags.
                req[8] = 0
                req[12] = 0
                continue
            # Over capacity: flag the departure (8) or arrival (12) status
            # column and mirror it into the full request table.
            status_col = 8 if req[13] == 1 else 12
            req[status_col] = 1
            _indices = np.where(requests_full[:, 0] == req[0])
            requests_full[_indices, status_col] = 1

In [None]:
#Previous single-agent simulator:

class SchedEnv(gym.Env):
    """Single-agent slot-scheduling environment (classic gym API).

    Every scenario contains ``number_of_request`` requests. A request occupies
    one of 288 daily time slots on a weekly sequence of dates. Each
    (time slot, date) cell has a fixed capacity. At every step the agent is
    shown one request that currently sits in an over-capacity cell and shifts
    it by ``action - do_nothing_action`` time slots. The episode ends when no
    cell is over capacity or after ``5 * number_of_request`` steps.

    Requests are stored in 1-D structured arrays with the fields ``index``,
    ``ts``, ``start_date``, ``num_of_weeks``, ``date_seq`` (a list of dates)
    and ``status_cap`` (1 when the request sits in an over-capacity cell).

    ``reset`` returns an observation and ``step`` returns the 4-tuple
    ``(obs, reward, done, info)``.

    Attributes:
        req_arr: the requests as originally generated.
        final_sched_arr: working copy of ``req_arr`` holding the shifted slots.
        cap_arr: ``(288, number_of_days)`` capacity table.
        pot_dem_arr: ``(288, number_of_days)`` demand table.
        cap_dem_arr: ``cap_arr - pot_dem_arr`` (negative = over capacity).
        chosen_req_arr: record of the request currently offered to the agent.
        num_step: number of steps taken in the current episode.
    """

    # Number of time slots per day; full_obs() pads its output for this size.
    NUM_TIME_SLOTS = 288
    # Largest start date generate_scenario() can draw (5 * 30).
    MAX_START_DATE = 150

    def __init__(self, number_of_actions, number_of_request, number_of_days,
                 capacity_per_slot=20, max_scenario_attempts=1000):
        """Create the environment and generate a first violating scenario.

        Args:
            number_of_actions: size of the discrete action set; should be odd
                so that the middle action means "do not shift".
            number_of_request: number of requests per scenario.
            number_of_days: length of the scheduling period; must exceed 150
                because start dates are drawn from {0, 30, ..., 150}.
            capacity_per_slot: capacity of every (time slot, date) cell.
            max_scenario_attempts: how many scenarios may be generated while
                searching for one with at least one over-capacity cell.

        Raises:
            ValueError: if ``number_of_days`` is too small for the start dates
                or the requests can never exceed the capacity.
            RuntimeError: if no violating scenario is found within
                ``max_scenario_attempts`` attempts.
        """
        super(SchedEnv, self).__init__()
        self.number_of_actions = int(number_of_actions)  # Number of actions should be an odd number
        self.do_nothing_action = int((self.number_of_actions - 1) / 2)
        self.number_of_request = int(number_of_request)
        self.number_of_days = int(number_of_days)
        self.capacity_per_slot = capacity_per_slot
        self.max_scenario_attempts = int(max_scenario_attempts)
        self.chosen_req_arr = None

        if self.number_of_days <= self.MAX_START_DATE:
            raise ValueError(
                "number_of_days must be greater than {} because start dates are "
                "drawn from 0, 30, ..., {}".format(self.MAX_START_DATE, self.MAX_START_DATE)
            )

        self._generate_violating_scenario()

        # Number of steps taken in the current episode.
        self.num_step = 0

        # Define action and observation space:
        self.action_space = spaces.Discrete(self.number_of_actions)  # TODO can change later if increase the number of shifting slot
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(self.number_of_actions + 1,), dtype=float)

    def _generate_violating_scenario(self):
        """Generate scenarios until one has an over-capacity cell, then flag requests."""
        # Each request adds at most 1 to any cell, so with this few requests a
        # violation is impossible and the search below could never succeed.
        if self.number_of_request <= self.capacity_per_slot:
            raise ValueError(
                "number_of_request ({}) must exceed the per-slot capacity ({}); "
                "otherwise no scenario can violate capacity".format(
                    self.number_of_request, self.capacity_per_slot)
            )
        for _ in range(self.max_scenario_attempts):
            self.generate_scenario()
            if (self.cap_dem_arr < 0).any():
                self.update_status_capacity()
                return
        raise RuntimeError(
            "No scenario with a capacity violation found in {} attempts".format(
                self.max_scenario_attempts)
        )

    def generate_scenario(self):
        """Draw a random set of requests and derive capacity and demand tables."""
        n_req = self.number_of_request
        last_slot = self.NUM_TIME_SLOTS - 1

        # Generate time slots around the two peak slots 72 and 216
        # (TODO to be changed accordingly later). Splitting as n // 2 and the
        # remainder keeps the total equal to n_req for odd request counts.
        n_first_peak = n_req // 2
        n_second_peak = n_req - n_first_peak
        ts_72 = np.round(get_truncated_normal(mean=72, sd=12, low=0, upp=last_slot).rvs(size=n_first_peak))
        ts_216 = np.round(get_truncated_normal(mean=216, sd=12, low=0, upp=last_slot).rvs(size=n_second_peak))
        ts_arr = np.concatenate((ts_72, ts_216)).astype(int)

        # Generate start dates (this is only for more than 150 days):
        start_date_arr = np.random.randint(low=0, high=6, size=n_req) * 30

        # Generate number of weeks so that the last date stays inside the period:
        _remaining_days_arr = (self.number_of_days - 1) - start_date_arr
        _max_num_of_weeks_arr = _remaining_days_arr // 7
        if (_max_num_of_weeks_arr + 2 > 1).all():
            num_of_weeks_arr = np.random.randint(1, _max_num_of_weeks_arr + 2, size=_max_num_of_weeks_arr.shape)
        else:
            # A random number cannot be drawn in the given range; fall back to one week.
            num_of_weeks_arr = np.ones_like(_max_num_of_weeks_arr)

        # Weekly date sequence of every request (module-level helper).
        date_seq_arr = get_date_seq(start_date_arr, num_of_weeks_arr)

        dtypes = [
            ('index', int),
            ('ts', float),
            ('start_date', int),
            ('num_of_weeks', int),
            ('date_seq', 'O'),  # Object, can hold list or array
            ('status_cap', int)
        ]
        combined_array = np.zeros(n_req, dtype=dtypes)
        combined_array['index'] = np.arange(n_req)
        combined_array['ts'] = ts_arr
        combined_array['start_date'] = start_date_arr
        combined_array['num_of_weeks'] = num_of_weeks_arr
        # Store the lists one by one: a bulk assignment is broadcast as a 2-D
        # array (and fails) whenever all sequences have the same length.
        for pos, date_seq in enumerate(date_seq_arr):
            combined_array['date_seq'][pos] = date_seq
        # 'status_cap' stays 0 until update_status_capacity() runs.

        self.req_arr = combined_array

        # Generate capacity:
        self.cap_arr = np.full((self.NUM_TIME_SLOTS, self.number_of_days), self.capacity_per_slot)

        # Create final_sched:
        self.final_sched_arr = self.req_arr.copy()

        # Get potential demand:
        self.pot_dem_arr = self.get_initial_pot_dem()

        # Get remaining cap:
        self.cap_dem_arr = self.cap_arr - self.pot_dem_arr

    @staticmethod
    def get_date_seq(row):
        """Return the weekly date list of a record with ``start_date`` and ``num_of_weeks``."""
        return list(range(row['start_date'], row['start_date'] + row['num_of_weeks'] * 7, 7))

    @staticmethod
    def check_in_date_seq(row, value):
        """Return True when ``value`` is one of the dates in ``row['date_seq']``."""
        return value in row['date_seq']

    def _demand_from_schedule(self, sched_arr):
        """Count the requests of ``sched_arr`` per (time slot, date) cell."""
        # TODO: increase speed
        demand = np.full((self.NUM_TIME_SLOTS, self.number_of_days), 0)
        for req in sched_arr:
            demand[int(req['ts']), req['date_seq']] += 1
        return demand

    def get_initial_pot_dem(self):
        """Return the demand table of the current schedule (``final_sched_arr``)."""
        return self._demand_from_schedule(self.final_sched_arr)

    def update_pot_dem(self, action):
        """Move the chosen request's demand from its old slot to the shifted slot.

        ``chosen_req_arr`` is a copy taken before the shift, so its ``ts`` is
        still the slot the demand currently sits in. Nothing happens for an
        out-of-bounds shift or the do-nothing action.
        """
        if self.check_outbound(action) or action == self.do_nothing_action:
            return
        _old_slot = int(self.chosen_req_arr['ts'])
        _new_slot = _old_slot + action - self.do_nothing_action
        _date_seq = self.chosen_req_arr['date_seq']
        self.pot_dem_arr[_old_slot, _date_seq] -= 1
        self.pot_dem_arr[_new_slot, _date_seq] += 1

    def update_status_capacity(self):
        """Set ``status_cap`` to 1 for requests in an over-capacity cell, else 0."""
        for pos in range(len(self.final_sched_arr)):
            _time_slot = int(self.final_sched_arr['ts'][pos])
            _date_seq = self.final_sched_arr['date_seq'][pos]
            _over_capacity = bool((self.cap_dem_arr[_time_slot, _date_seq] < 0).any())
            self.final_sched_arr['status_cap'][pos] = 1 if _over_capacity else 0

    def get_req(self):
        """Randomly select one violating request and store it in ``chosen_req_arr``.

        The status column is refreshed after every action because moving one
        request also changes the status of other requests.

        Raises:
            RuntimeError: if no request is currently flagged as violating.
        """
        # Boolean-mask indexing returns a copy, so the chosen record keeps the
        # pre-shift values even after final_sched_arr is modified.
        _violating = self.final_sched_arr[self.final_sched_arr['status_cap'] == 1]
        if len(_violating) == 0:
            raise RuntimeError("No request is currently violating capacity")
        self.chosen_req_arr = _violating[random.randrange(len(_violating))]

    # TODO: change to 288 when changing time slot
    def _take_action(self, action):
        """Shift the chosen request in ``final_sched_arr`` unless out of bounds."""
        if not self.check_outbound(action):
            # 'index' equals the position in final_sched_arr (it is an arange).
            _pos = int(self.chosen_req_arr['index'])
            self.final_sched_arr['ts'][_pos] += action - self.do_nothing_action

    def _next_observation(self):
        """Choose the next violating request and build its observation.

        The observation holds, for the ``number_of_actions`` slots centred on
        the request's slot, the minimum remaining capacity over the request's
        dates (0 outside the day), followed by the request's number of weeks.
        """
        self.get_req()
        _time_slot = int(self.chosen_req_arr['ts'])
        _num_of_weeks = self.chosen_req_arr['num_of_weeks']
        _date_seq = self.chosen_req_arr['date_seq']
        _cap_dem_chosen_req = self.cap_dem_arr[:, _date_seq].copy()
        # full_obs pads by (number_of_actions - 1) / 2 on each side, so the
        # window starting at _time_slot is centred on the request's slot.
        _obs_min_arr = full_obs(_cap_dem_chosen_req, self.number_of_actions)
        _cap_dem_obs = _obs_min_arr[_time_slot:_time_slot + self.number_of_actions]

        self.obs = np.append(_cap_dem_obs, _num_of_weeks).astype(float)
        return self.obs

    def step(self, action):
        """Apply ``action`` to the chosen request and return the transition.

        Returns:
            Tuple ``(obs, reward, done, info)``. The observation is all zeros
            when the episode is done.

        Raises:
            RuntimeError: if called before ``reset()``.
        """
        if self.chosen_req_arr is None:
            raise RuntimeError("Call reset() before step()")
        action = int(action)

        self._take_action(action)
        self.update_pot_dem(action)
        # Update remaining capacity and the per-request status flags:
        self.cap_dem_arr = self.cap_arr - self.pot_dem_arr
        self.update_status_capacity()

        # Reward part: (TODO change to 288), must add reward for violating capacity
        if self.check_outbound(action):
            local_reward = -1
        else:
            # TODO change if increase number of actions
            local_reward = 0.1 * (-abs(action - self.do_nothing_action) * 0.5 * self.chosen_req_arr['num_of_weeks'])

        reward_action = self.reward_action(action)  # Must be computed before the next observation

        self.num_step += 1
        done = False

        _violations = self.cap_dem_arr[self.cap_dem_arr < 0]
        if _violations.size == 0:
            done = True
            obs = np.zeros((self.number_of_actions + 1,))  # Change to a positive number
            global_reward = 100  # TODO change according to the number of req
        elif self.num_step >= len(self.req_arr) * 5:  # TODO parameter
            done = True
            obs = np.zeros((self.number_of_actions + 1,))
            global_reward = -np.abs(_violations).sum() * 10
        else:
            global_reward = 0
            obs = self._next_observation()

        reward_time_step = -0.5

        reward = float(local_reward + global_reward + reward_action + reward_time_step)

        return obs, reward, done, {}

    def reward_action(self, action):
        """Reward for moving the chosen request into or out of over-capacity cells.

        Must be called after the demand tables were updated for ``action`` and
        before the next request is chosen.

        Returns:
            0 for an out-of-bounds shift or when the over-capacity state of
            the old and new slot is the same; ``+bonus`` when the old slot was
            over capacity and the new one is not; ``-bonus`` for the opposite,
            with ``bonus = number_of_actions * num_of_weeks * 0.1 / 4``.
        """
        if self.check_outbound(action):
            return 0

        _time_slot = int(self.chosen_req_arr['ts'])
        _num_of_weeks = self.chosen_req_arr['num_of_weeks']
        _date_seq = self.chosen_req_arr['date_seq']
        _cap_dem_chosen_req = self.cap_dem_arr[:, _date_seq].copy()

        if action == self.do_nothing_action:
            _cap_dem_initial = _cap_dem_chosen_req[_time_slot]
        else:
            # Demand was already moved away, so undo that to get the old state.
            _cap_dem_initial = _cap_dem_chosen_req[_time_slot] - 1
        _cap_dem_new_slot = _cap_dem_chosen_req[_time_slot + action - self.do_nothing_action]

        initial_under = not (_cap_dem_initial < 0).any()
        new_slot_under = not (_cap_dem_new_slot < 0).any()

        bonus = (self.number_of_actions * _num_of_weeks * 0.1) / 4
        if initial_under and not new_slot_under:  # TODO update to -5*num_of_weeks
            return -bonus
        if not initial_under and new_slot_under:
            return bonus
        return 0

    def check_outbound(self, action):
        """Return True when ``action`` would shift the chosen request outside 0..287."""
        _time_slot = int(self.chosen_req_arr['ts'])
        _new_slot = _time_slot + (action - self.do_nothing_action)
        return bool(_new_slot < 0 or _new_slot > self.NUM_TIME_SLOTS - 1)

    # TODO change to 288
    def check(self):
        """Verify that demand and remaining capacity are consistent with the requests.

        Raises:
            ValueError: if the total demand differs from the summed number of
                weeks, or the total remaining capacity differs from the total
                capacity minus that sum.
        """
        _total_weeks = self.req_arr['num_of_weeks'].sum()
        if self.pot_dem_arr.sum() == _total_weeks:
            print('Pot_dem_arr passed!')
        else:
            print('+++++++++++++++++++++++++')
            raise ValueError('Check Pot_dem_arr!')

        if self.cap_dem_arr.sum() == self.cap_arr.sum() - _total_weeks:
            print('Pot_dem_df passed!')
        else:
            print('+++++++++++++++++++++++++')
            raise ValueError('Check pot_dem_df!')

    # TODO: Implement for evaluation:
    def check_num_violate(self):
        """Placeholder; not implemented yet."""
        pass

    def reset(self):
        """Start a new episode with a fresh violating scenario.

        Returns:
            The first observation of the episode.
        """
        self._generate_violating_scenario()
        print('Number of violation: ', int((self.cap_dem_arr < 0).sum()))

        # Number of steps taken in the current episode.
        self.num_step = 0

        return self._next_observation()

    # TODO: also evaluate schedule delay among requests (min, max, mean schedule delay)
    def eval(self):
        """Print and return violation counts and schedule-delay statistics.

        Returns:
            Tuple ``(initial_violations, final_violations, total_schedule_delay,
            max_shift)`` comparing ``req_arr`` with ``final_sched_arr``.
        """
        _initial_cap_dem = self.cap_arr - self._demand_from_schedule(self.req_arr)
        _initial_violate = int((_initial_cap_dem < 0).sum())
        print('Initial violation is: ', _initial_violate)
        _final_violate = int((self.cap_dem_arr < 0).sum())
        print('Final violation is: ', _final_violate)
        _shift = np.abs(self.req_arr['ts'] - self.final_sched_arr['ts'])
        _total_sched_delay = (_shift * self.req_arr['num_of_weeks']).sum()
        print('Total schedule delay is: ', _total_sched_delay)
        _max_shift = _shift.max()
        print('Max shift: ', _max_shift)
        return _initial_violate, _final_violate, _total_sched_delay, _max_shift

    def get_scenario(self):
        """Return the original requests and the capacity table."""
        return self.req_arr, self.cap_arr
        

JY's code to train, test, and visualize single-agent reinforcement learning (SARL)

In [None]:
# Training the agent: Q-learning algorithm

import numpy as np

class QLearningAgent:
    """Epsilon-greedy agent with one learned value per action.

    NOTE(review): ``q_table`` is indexed by action only, so the observations
    passed to ``choose_action`` and ``learn`` are not used - confirm this
    state-independent table is intended.
    """

    def __init__(self, alpha, gamma, n_actions):
        """Store the learning rate, discount factor and action count."""
        self.alpha, self.gamma, self.n_actions = alpha, gamma, n_actions
        self.q_table = np.zeros(n_actions)

    def choose_action(self, observation, epsilon):
        """Return a random action with probability ``epsilon``, else the greedy one."""
        explore = np.random.random() < epsilon
        if explore:
            return np.random.choice(self.n_actions)
        return np.argmax(self.q_table)

    def learn(self, old_observation, reward, new_observation, action):
        """Move the value of ``action`` towards ``reward + gamma * max(q_table)``."""
        current_value = self.q_table[action]
        td_target = reward + self.gamma * np.max(self.q_table)
        self.q_table[action] = current_value + self.alpha * (td_target - current_value)


In [None]:
# Training loop
def train_agent(env, agent, n_epochs, epsilon=1.0, epsilon_decay=0.995, min_epsilon=0.01):
    """Run ``n_epochs`` episodes and return the accumulated reward of each.

    Args:
        env: environment with ``reset()`` returning an observation and
            ``step(action)`` returning ``(obs, reward, done, info)``.
        agent: object with ``choose_action(obs, epsilon)`` and
            ``learn(obs, reward, new_obs, action)``.
        n_epochs: number of episodes to run.
        epsilon: initial exploration rate.
        epsilon_decay: factor applied to ``epsilon`` after every episode.
        min_epsilon: lower bound for ``epsilon``.

    Returns:
        List with the accumulated reward of every episode.
    """
    rewards = []
    for i in range(n_epochs):
        state = env.reset()
        total_reward = 0
        episode_over = False
        while not episode_over:
            action = agent.choose_action(state, epsilon)
            next_state, reward, episode_over, info = env.step(action)
            agent.learn(state, reward, next_state, action)
            total_reward += reward
            state = next_state
        rewards.append(total_reward)
        # Decay exploration once per episode, never below min_epsilon.
        epsilon = max(min_epsilon, epsilon * epsilon_decay)
        print(f"Epoch {i + 1}/{n_epochs} completed. Accumulated reward: {total_reward:.2f}")
    return rewards


In [None]:
# Visualization
import matplotlib.pyplot as plt

def plot_rewards(rewards):
    """Plot the accumulated reward of every epoch and show the figure.

    Args:
        rewards: sequence of per-epoch rewards, e.g. as returned by ``train_agent``.
    """
    plt.plot(rewards, label='Reward per epoch')
    plt.title('Training Convergence')
    plt.xlabel('Epoch')
    plt.ylabel('Reward')
    plt.legend()
    plt.show()


In [None]:
# Execution
# Execution
if __name__ == "__main__":
    # Build the environment and an agent with a matching number of actions.
    # NOTE(review): the environment uses a per-slot capacity of 20, so with
    # only 10 requests no slot can ever be over capacity and a scenario with
    # a violation cannot be generated - increase number_of_request before
    # running this cell.
    env = SchedEnv(
        number_of_actions=5,
        number_of_request=10,
        number_of_days=200,
    )
    agent = QLearningAgent(
        alpha=0.1,
        gamma=0.95,
        n_actions=5,
    )

    # Train, then plot the per-epoch accumulated reward.
    rewards = train_agent(env, agent, n_epochs=1000)
    plot_rewards(rewards)
