In [None]:
from stable_baselines3.common.env_util import make_vec_env
from wandb.integration.sb3 import WandbCallback
from sb3_contrib.ppo_mask import MaskablePPO
from benchmark import get_benchmarking_data
from env import MPSPEnv
import numpy as np
import torch
import wandb
import os
# Set the W&B environment variables in one go before authenticating.
os.environ.update({
    'WANDB_NOTEBOOK_NAME': 'sb3.ipynb',
    'WANDB_SILENT': 'true',
})
wandb.login()

In [None]:
from torch import nn

In [None]:
# Toy network for experimenting with embeddings on integer grids:
# flatten each sample, look up a 2-dimensional vector for every entry
# (7 distinct ids are available), then flatten the looked-up vectors again.
test = nn.Sequential(
    nn.Flatten(),
    nn.Embedding(num_embeddings=7, embedding_dim=2),
    nn.Flatten(),
)

In [None]:
# Batch of two 2x3 integer grids; every value is a valid index (< 7) for the
# embedding layer in `test`.
test_tensor = torch.tensor([
    [[1, 2, 3], [4, 5, 6]],
    [[4, 2, 3], [4, 5, 1]],
])

In [None]:
# Batch of two identical 3x3 grids holding the values 1..9 in row-major order.
square_tensor = torch.arange(1, 10).reshape(1, 3, 3).repeat(2, 1, 1)
square_tensor

In [None]:
# (row, column) coordinates of the strictly upper-triangular entries of a 3x3
# matrix; offset=1 skips the main diagonal. Row 0 of the result holds the row
# coordinates and row 1 holds the column coordinates.
indeces = torch.triu_indices(row=3, col=3, offset=1)

In [None]:
# Pick the strictly upper-triangular entries of every matrix in the batch:
# the leading slice keeps the batch axis, the two index vectors select the
# (row, column) positions.
upper_rows, upper_cols = indeces
square_tensor[:, upper_rows, upper_cols]

In [None]:
# Run the toy flatten -> embedding -> flatten network on the sample batch and
# display the result.
test(test_tensor)

In [None]:
# Single source of truth for the run; the same dict is passed to wandb.init()
# so every value is recorded with the experiment.
config = dict(
    # Environment
    ROWS=10,
    COLUMNS=4,
    N_PORTS=10,
    # Model: hidden layer widths of the policy (pi) and value (vf) networks
    PI_LAYER_SIZES=[64, 128, 64],
    VF_LAYER_SIZES=[64, 128, 64],
    # Training
    TOTAL_TIMESTEPS=4_800_000,
    BATCH_SIZE=128,
)

In [None]:
# Start the W&B run. The run name encodes the problem size (ports, rows,
# columns) and the full config dict is attached to the run.
run = wandb.init(
    entity="rl-msps",
    project="PPO-SB3",
    name=f"N{config['N_PORTS']}_R{config['ROWS']}_C{config['COLUMNS']}",
    tags=["test"],
    config=config,
    sync_tensorboard=True,
)

In [None]:
def _make_training_env():
    """Build one MPSPEnv instance sized according to `config`."""
    return MPSPEnv(
        config['ROWS'],
        config['COLUMNS'],
        config['N_PORTS']
    )


# Vectorized training environment: 8 copies of the env (M2 with 8 cores).
env = make_vec_env(_make_training_env, n_envs=8)

In [None]:
# Stand-alone environment used only to obtain a sample observation for the
# convolution experiments below. It deliberately gets its own name: rebinding
# `env` here would replace the vectorized environment created for training,
# and the model would then silently be trained on a single environment.
sample_env = MPSPEnv(
    config['ROWS'],
    config['COLUMNS'],
    config['N_PORTS']
)

# Initial observation of a freshly reset environment.
observation = sample_env.reset()

In [None]:
import torch.nn as nn

In [None]:
# 3x3 convolution mapping 1 input channel to 3 feature maps; stride 1 with
# padding 1 keeps the spatial size of the input unchanged.
conv1 = nn.Conv2d(
    1,
    3,
    kernel_size=(3, 3),
    stride=(1, 1),
    padding=(1, 1),
)

In [None]:
# Convert the observation's bay matrix to float32 and prepend a size-1 axis,
# which conv1 consumes as its single input channel.
# NOTE(review): presumably bay_matrix is a 2-D rows x columns grid - confirm
# against MPSPEnv.
input = torch.tensor(observation['bay_matrix'], dtype=torch.float32)[None, ...]
input

In [None]:
# Apply the convolution and overwrite `input` with the 3-channel result.
# NOTE(review): this cell is not idempotent - running it a second time feeds
# the 3-channel output back into conv1, which was built with in_channels=1.
input = conv1(input)
input

In [None]:
# 2x2 max pooling with stride 2; the pooled result replaces `input`.
pool = nn.MaxPool2d(kernel_size=2, stride=2)
input = pool(input)
input

In [None]:
# Display the full observation returned by the reset() call above.
observation

In [None]:
# Network architecture handed to the policy: ReLU activations, with separate
# hidden-layer stacks for the policy (pi) and value (vf) heads taken from config.
policy_kwargs = dict(
    activation_fn=torch.nn.ReLU,
    net_arch=[dict(
        pi=config['PI_LAYER_SIZES'],
        vf=config['VF_LAYER_SIZES'],
    )],
)

# Set to a W&B run path to resume from that run's saved model; leave as None
# to start training from scratch.
wandb_run_path = None

if not wandb_run_path:
    # Fresh model, logging TensorBoard data under this run's id.
    model = MaskablePPO(
        policy='MultiInputPolicy',
        env=env,
        batch_size=config['BATCH_SIZE'],
        verbose=0,
        tensorboard_log=f"runs/{run.id}",
        policy_kwargs=policy_kwargs
    )
else:
    # Fetch the checkpoint stored with the given run and attach the current env.
    model_file = wandb.restore('model.zip', run_path=wandb_run_path)
    model = MaskablePPO.load(model_file.name, env=env)

In [None]:
# Train the model and let W&B store checkpoints along the way.
#
# The callback's save frequency is measured in callback invocations, and each
# invocation advances every parallel environment by one step. Dividing by the
# number of environments keeps the checkpoints a quarter of the total timesteps
# apart regardless of how many environments the model trains on (with a single
# environment this is exactly TOTAL_TIMESTEPS // 4).
n_checkpoints = 4
model_save_freq = max(
    config['TOTAL_TIMESTEPS'] // n_checkpoints // model.n_envs,
    1
)

model.learn(
    total_timesteps=config['TOTAL_TIMESTEPS'],
    callback=WandbCallback(
        model_save_path=f"models/{run.id}",
        model_save_freq=model_save_freq,
    )
)

In [None]:
# Load the benchmark set and keep only the instances whose dimensions
# (rows, columns, ports) match the configuration of this run.
eval_data = [
    e for e in get_benchmarking_data('rl-mpsp-benchmark/set_2')
    if all(
        e[field] == config[setting]
        for field, setting in (
            ('R', 'ROWS'),
            ('C', 'COLUMNS'),
            ('N', 'N_PORTS'),
        )
    )
]

In [None]:
# Creating a separate env for evaluation. It has its own name so the training
# `env` stays bound to the environment the model was built with.
eval_env = MPSPEnv(
    config['ROWS'],
    config['COLUMNS'],
    config['N_PORTS']
)

# Negative because env returns negative reward for shifts
paper_rewards = [-e['paper_result'] for e in eval_data]
paper_seeds = [e['seed'] for e in eval_data]

# Play one full episode per benchmark instance and record its total reward.
eval_rewards = []
for e in eval_data:
    total_reward = 0
    obs = eval_env.reset(
        transportation_matrix=e['transportation_matrix']
    )
    done = False
    while not done:
        # deterministic=True picks the policy's best action instead of sampling
        # one, so the benchmark numbers are reproducible across runs.
        action, _ = model.predict(
            obs,
            action_masks=eval_env.action_masks(),
            deterministic=True
        )
        obs, reward, done, _ = eval_env.step(action)
        total_reward += reward

    eval_rewards.append(total_reward)

# Stored under a name that does not shadow the builtin `eval`.
evaluation = {
    'mean_reward': np.mean(eval_rewards),
    'mean_paper_reward': np.mean(paper_rewards),
    'rewards': eval_rewards,
    'paper_rewards': paper_rewards,
    'paper_seeds': paper_seeds
}
run.summary['evaluation_benchmark'] = evaluation

In [None]:
# Close the W&B run that was started with wandb.init().
run.finish()