In [None]:
from sb3_contrib.common.wrappers import ActionMasker
from stable_baselines3.common.monitor import Monitor
from CustomEncoder import CustomCombinedExtractor
from wandb.integration.sb3 import WandbCallback
from sb3_contrib.ppo_mask import MaskablePPO
from env import MPSPEnv
from benchmark import get_benchmarking_data
import wandb
import torch
import os
# Export the W&B environment variables in one go, then authenticate.
os.environ.update({
    'WANDB_NOTEBOOK_NAME': 'sb3.ipynb',
    'WANDB_SILENT': 'true',
})
wandb.login()

In [None]:
# Hyperparameters for one run; the same mapping is passed to wandb.init as the
# run configuration.
config = dict(
    # Environment
    ROWS=5,
    COLUMNS=5,
    N_PORTS=7,
    # Model
    EMBEDDING_SIZE=10,
    PI_LAYER_SIZES=[256] * 3,
    VF_LAYER_SIZES=[256] * 3,
    # Training
    TOTAL_TIMESTEPS=400_000,
    START_LEARNING_RATE=7e-5,
    END_LEARNING_RATE=4e-6,
    BATCH_SIZE=128,
)

In [None]:
# Open the W&B run; its name is derived from the environment dimensions in
# `config`, and the full config is attached to the run.
run = wandb.init(
    # where the run is stored
    entity="rl-msps",
    project="PPO-SB3",
    # how the run is labelled
    name=f"{config['ROWS']}x{config['COLUMNS']}_{config['N_PORTS']}-ports",
    tags=["test"],
    # what is recorded with it
    config=config,
    sync_tensorboard=True,
)

In [None]:
# Build the environment from the configured dimensions and wrap it in Monitor
# in a single expression.
env = Monitor(
    MPSPEnv(*(config[key] for key in ('ROWS', 'COLUMNS', 'N_PORTS')))
)

In [None]:
def linear_schedule(start: float, end: float):
    """
    Linear learning rate schedule.

    The returned callable interpolates linearly between ``start`` and ``end``
    as a function of the remaining training progress.

    :param start: learning rate at the beginning of training
      (``progress_remaining == 1``).
    :param end: learning rate at the end of training
      (``progress_remaining == 0``).
    :return: schedule that computes the current learning rate depending on
      remaining progress
    """

    def func(progress_remaining: float) -> float:
        """
        Progress will decrease from 1 (beginning) to 0.

        :param progress_remaining: fraction of training still to be done.
        :return: current learning rate
        """
        # progress_remaining counts DOWN, so it must weight the distance from
        # `end` back up to `start`; anchoring on `start` instead would run the
        # schedule backwards (beginning at `end` and finishing at `start`).
        return end + progress_remaining * (start - end)

    return func

In [None]:
# Keyword arguments forwarded to the policy: activation, separate actor (pi) and
# critic (vf) layer sizes, and the custom feature extractor with its settings.
# NOTE(review): recent stable-baselines3 releases appear to prefer `net_arch`
# as a plain dict rather than a one-element list — confirm against the
# installed version before changing it.
policy_kwargs = dict(
    activation_fn=torch.nn.ReLU,
    net_arch=[
        dict(pi=config['PI_LAYER_SIZES'], vf=config['VF_LAYER_SIZES']),
    ],
    features_extractor_class=CustomCombinedExtractor,
    features_extractor_kwargs=dict(
        n_ports=config['N_PORTS'],
        embedding_size=config['EMBEDDING_SIZE'],
    ),
)

# Instantiate the agent; TensorBoard logs go under a directory named after the
# W&B run id.
model = MaskablePPO(
    env=env,
    policy='MultiInputPolicy',
    policy_kwargs=policy_kwargs,
    learning_rate=linear_schedule(
        start=config['START_LEARNING_RATE'],
        end=config['END_LEARNING_RATE'],
    ),
    batch_size=config['BATCH_SIZE'],
    tensorboard_log=f"runs/{run.id}",
    verbose=1,
)

In [None]:
# Train for the configured number of timesteps; the W&B callback is given a
# per-run directory under models/ as its save path.
model.learn(
    callback=WandbCallback(model_save_path=f"models/{run.id}"),
    total_timesteps=config['TOTAL_TIMESTEPS'],
)

In [None]:
# Close the W&B run opened by wandb.init once training is done.
run.finish()