In [2]:
import numpy as np
import pandas as pd
import gymnasium as gym
from gymnasium import spaces
from copy import deepcopy

from sb3_contrib import MaskablePPO
from sb3_contrib.common.maskable.utils import get_action_masks
from sb3_contrib.common.wrappers import ActionMasker
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy

### Prepare the dataset

In [3]:
tasks = pd.read_csv(r"complexity_by_hour.csv")
tasks['hourly_time'] = pd.to_datetime(tasks['hourly_time'])
tasks['time'] = (tasks['hourly_time'] - tasks['hourly_time'].iloc[0]).dt.total_seconds() / 3600 
tasks['complexity'] = tasks['task_complexity'] / 60
tasks = tasks[['time', 'complexity']]

tasks = tasks.to_dict(orient='records')
tasks[:5]

[{'time': 0.0, 'complexity': 0.05},
 {'time': 1.0, 'complexity': 0.016666666666666666},
 {'time': 9.0, 'complexity': 22.283333333333335},
 {'time': 10.0, 'complexity': 19.95},
 {'time': 11.0, 'complexity': 2.033333333333333}]

In [4]:
TRAIN_SIZE = int(len(tasks) * 0.9)

### Define the Environment

In [None]:
class TaskSchedulingEnv(gym.Env):
    def __init__(self, N_total_nodes, total_tasks):
        super(TaskSchedulingEnv, self).__init__()

        self.waiting_capacity = 10000  # Maximum waiting tasks
        self.N_total_nodes = N_total_nodes
        self.total_tasks = deepcopy(total_tasks)

        self.upcoming_tasks = self.total_tasks.copy()
        self.waiting_tasks = []
        self.executing_tasks = []
        self.available_nodes = self.N_total_nodes
        self.current_time = self.total_tasks[0]['time']

        self.average_waiting_time = 0
        self.num_executed_tasks = 0

        self.observation_space = spaces.Box(
                low=0, high=np.inf, shape=(1 + self.waiting_capacity,), dtype=np.int32
            )
        self.action_space = spaces.Discrete(self.N_total_nodes + 1)
        self.state = np.array([self.available_nodes] + [0] * self.waiting_capacity, dtype=np.int32)

    def reset(self, seed=42, options=None):
        super().reset(seed=seed)

        self.upcoming_tasks = self.total_tasks.copy()
        self.waiting_tasks = []
        self.executing_tasks = []
        self.available_nodes = self.N_total_nodes
        self.current_time = self.total_tasks[0]['time']

        padded_waiting_times = [0] * self.waiting_capacity
        state = np.array([self.available_nodes] + padded_waiting_times, dtype=np.int32)

        return state, {}

    def step(self, action):
        # Validate the action
        if action < 0 or action > self.available_nodes:
            return self.state, -1000, False, False, {}

        # Add new task to waiting list
        if self.upcoming_tasks and self.current_time == self.upcoming_tasks[0]['time']:
            self.upcoming_tasks[0]['waiting_time'] = 0
            self.waiting_tasks.append(self.upcoming_tasks.pop(0))

        # Update executing tasks
        for task in self.executing_tasks:
            task['complexity'] -= task['nodes_alloc']
            if task['complexity'] <= 0:
                self.available_nodes += task['nodes_alloc']
                self.average_waiting_time = (self.average_waiting_time*self.num_executed_tasks + task['waiting_time'])/(self.num_executed_tasks+1)
                self.num_executed_tasks += 1

        self.executing_tasks = [task for task in self.executing_tasks if task['complexity'] > 0]

        # Allocate nodes to the first waiting task
        if self.waiting_tasks and action > 0:
            self.waiting_tasks[0]['nodes_alloc'] = int(action)
            self.executing_tasks.append(self.waiting_tasks.pop(0))
            self.available_nodes -= action

        # Update waiting times
        for task in self.waiting_tasks:
            task['waiting_time'] += 1

        # Compute next state
        waiting_times = [task['waiting_time'] for task in self.waiting_tasks]
        padded_waiting_times = waiting_times + [0] * (self.waiting_capacity - len(waiting_times))

        state = np.array([self.available_nodes] + padded_waiting_times, dtype=np.int32)
        reward = -np.sum(waiting_times)
        done = len(self.upcoming_tasks) == 0 and len(self.waiting_tasks) == 0

        # print(f"Time: {self.current_time}. Allocated {action} nodes. Waiting tasks: {self.waiting_tasks}")

        self.current_time += 1
        return state, reward, done, False, {}

    def get_action_mask(self):
        # Generate a binary mask for valid actions
        mask = np.zeros(self.N_total_nodes + 1, dtype=np.int32)
        if self.available_nodes == 0:
            mask[0] = 1
        else:
          for i in range(1, self.available_nodes + 1):
              mask[i] = 1
        return mask
    
    def get_average_waiting_times(self):
        remaining_waiting_time = np.sum([np.ceil(task['complexity']/task['nodes_alloc']) for task in self.executing_tasks])
        return (self.average_waiting_time*self.num_executed_tasks + remaining_waiting_time)/(self.num_executed_tasks + len(self.executing_tasks))
      

### Train the Model

In [None]:
env = TaskSchedulingEnv(N_total_nodes=5, total_tasks=tasks[:10])
def action_mask(env: TaskSchedulingEnv):
    return env.get_action_mask()

wrapped_env = ActionMasker(env, action_mask)

model = MaskablePPO(MaskableActorCriticPolicy, wrapped_env, verbose=1)
model.learn(total_timesteps=100000)

### Inference

In [162]:
obs, info = wrapped_env.reset()

# Run inference loop
while True:
    # Get the current action mask
    action_masks = get_action_masks(wrapped_env)
    print(f"Time: {wrapped_env.env.current_time}. Available nodes: {wrapped_env.env.available_nodes}.")
    
    action, _ = model.predict(obs, action_masks=action_masks, deterministic=True)
    obs, reward, done, truncated, info = wrapped_env.step(action)

    print(f"Action: {action}. Reward: {reward}. Executing: {wrapped_env.env.executing_tasks}. Waiting: {wrapped_env.env.waiting_tasks}. Upcoming: {wrapped_env.env.upcoming_tasks}")
    
    # Exit the loop if the episode is finished
    if done or truncated:
        print(f"Average waiting time: {wrapped_env.env.get_average_waiting_times()}")
        break

Time: 0.0. Available nodes: 5.
Action: 1. Reward: -0.0. Executing: [{'time': 0.0, 'complexity': -7.95, 'waiting_time': 0, 'nodes_alloc': 1}]. Waiting: []. Upcoming: [{'time': 1.0, 'complexity': -3.9833333333333334, 'waiting_time': 0, 'nodes_alloc': 1}, {'time': 9.0, 'complexity': -6.716666666666665, 'waiting_time': 0, 'nodes_alloc': 4}, {'time': 10.0, 'complexity': -4.050000000000001, 'waiting_time': 0, 'nodes_alloc': 1}, {'time': 11.0, 'complexity': -5.966666666666667, 'waiting_time': 0, 'nodes_alloc': 4}, {'time': 12.0, 'complexity': -3.9833333333333334, 'waiting_time': 0, 'nodes_alloc': 1}, {'time': 13.0, 'complexity': 5019.266666666666, 'waiting_time': 0, 'nodes_alloc': 2}, {'time': 14.0, 'complexity': 465956.55, 'waiting_time': 0, 'nodes_alloc': 3}, {'time': 15.0, 'complexity': -1.849999999999909, 'waiting_time': 935, 'nodes_alloc': 2}, {'time': 16.0, 'complexity': 6754.55, 'waiting_time': 934, 'nodes_alloc': 1}]
Time: 1.0. Available nodes: 4.
Action: 1. Reward: -0.0. Executing: [

list