In [1]:
from ReplayBuffer import ReplayBuffer
from DQN_Solver import DQN_Solver
from env import MPSPEnv
from DQN import DQN
import numpy as np
import wandb
import torch
import os
# Record the notebook's file name in the environment before touching wandb.
# NOTE(review): presumably wandb reads WANDB_NOTEBOOK_NAME to associate the run
# with this notebook — confirm against the wandb environment-variable docs.
os.environ['WANDB_NOTEBOOK_NAME'] = 'main.ipynb'
# Authenticate with Weights & Biases; the call's return value (True in the
# saved output) is displayed as the cell result.
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mhojmax[0m ([33mrl-msps[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [2]:
# Single source of truth for every tunable value in this notebook. The same
# dict is later extended with derived sizes and handed to wandb as run config.
config = dict(
    # --- Environment dimensions ---
    ROWS=3,
    COLUMNS=3,
    N_PORTS=5,
    # --- Training schedule and optimisation ---
    EPISODES=1_000,
    LEARNING_RATE=1e-5,
    MEM_SIZE=10_000,
    BATCH_SIZE=64,
    GAMMA=0.95,
    EXPLORATION_MAX=1.0,
    EXPLORATION_DECAY=0.999999,
    EXPLORATION_MIN=0.001,
    EVAL_EPISODES=50,
    MAX_EPISODE_STEPS=200,
    # --- Network architecture ---
    HIDDEN_SIZE=128,
    N_LAYERS=6,
)

In [3]:
# Instantiate the environment from the grid/port dimensions in the config.
env = MPSPEnv(config['ROWS'], config['COLUMNS'], config['N_PORTS'])

# The observation space has (at least) two indexed components. The network
# consumes them as one flat vector, so its input width is the total number of
# elements across both components.
first_component = env.observation_space[0]
second_component = env.observation_space[1]
config['OBSERVATION_SPACE'] = (
    np.prod(first_component.shape) + np.prod(second_component.shape)
)

# Number of discrete actions, used as the network's output width.
config['ACTION_SPACE'] = env.action_space.n

In [4]:
# Run name encodes the problem size, e.g. "3x3_5-ports".
run_name = f"{config['ROWS']}x{config['COLUMNS']}_{config['N_PORTS']}-ports"

# Start the wandb run; the full config dict (including the derived
# OBSERVATION_SPACE / ACTION_SPACE entries, if already set) is attached to it.
wandb.init(
    project="Q-learning",
    entity="rl-msps",
    name=run_name,
    config=config,
    tags=["test"],
)

In [5]:
# NOTE(review): the two assignments below rebind the imported names
# `ReplayBuffer` and `DQN` from classes to instances. A later cell relies on
# `DQN` being the instance, so the names are kept, but re-running this cell
# without first re-running the import cell will most likely fail.

# Experience replay storage sized from the config.
ReplayBuffer = ReplayBuffer(
    observation_space=config['OBSERVATION_SPACE'],
    mem_size=config['MEM_SIZE'],
    batch_size=config['BATCH_SIZE'],
)

# Q-network: flattened observation in, one output per action.
DQN = DQN(
    input_size=config['OBSERVATION_SPACE'],
    hidden_size=config['HIDDEN_SIZE'],
    n_layers=config['N_LAYERS'],
    output_size=config['ACTION_SPACE'],
    learning_rate=config['LEARNING_RATE'],
)

# The solver ties the buffer and network together with the exploration
# schedule and discount factor from the config.
agent = DQN_Solver(
    ReplayBuffer=ReplayBuffer,
    DQN=DQN,
    batch_size=config['BATCH_SIZE'],
    gamma=config['GAMMA'],
    exploration_max=config['EXPLORATION_MAX'],
    exploration_min=config['EXPLORATION_MIN'],
    exploration_decay=config['EXPLORATION_DECAY'],
)


In [6]:
# Presumably switches the agent into its training mode — see DQN_Solver.train.
agent.train()

# Run exactly config['EPISODES'] episodes. (The previous range(1, EPISODES)
# silently ran one episode fewer than configured.)
for episode in range(config['EPISODES']):
    state, info = env.reset()
    # The observation comes as two arrays; flatten both into one input vector.
    state = np.concatenate((state[0].flatten(), state[1].flatten()))
    sum_reward = 0
    sum_loss = 0
    # Named `step` rather than `iter` so the builtin iter() is not shadowed
    # for the rest of the kernel session.
    step = 0

    # Cap the episode length so an episode that never reaches `done` still ends.
    while step < config['MAX_EPISODE_STEPS']:
        action, _ = agent.choose_action(state, info['mask'], env)
        state_, reward, done, info = env.step(action)
        state_ = np.concatenate((state_[0].flatten(), state_[1].flatten()))
        agent.memory.add(state, action, reward, state_, done)
        # learn() receives True only on the first step of each episode and
        # returns the loss that is accumulated here.
        sum_loss += agent.learn(step == 0)
        state = state_
        sum_reward += reward
        step += 1

        if done:
            break

    wandb.log({
        "Sum Episode Reward": sum_reward,
        # max(..., 1) guards against division by zero when
        # MAX_EPISODE_STEPS is configured as 0 and no step was taken.
        "Avg. Episode Loss": sum_loss / max(step, 1),
        "Exploration Rate": agent.exploration_rate
    })


Rewards:
tensor([ 0.,  0., -1.,  0.,  0.,  0.,  0., -1.,  0., -1., -1., -1.,  0.,  0.,
         0., -1.,  0., -1.,  0., -1.,  0.,  0.,  0.,  0., -1.,  0., -1., -1.,
         0., -1., -1.,  0., -1.,  0.,  0.,  0., -1.,  0.,  0., -1.,  0., -1.,
         0.,  0., -1., -1.,  0.,  0., -1.,  0., -1., -1.,  0., -1.,  0., -1.,
         0.,  0., -1., -1., -1., -1.,  0.,  0.], device='mps:0')
Target:
tensor([ 0.,  0., -1.,  0.,  0.,  0.,  0., -1.,  0., -1., -1., -1.,  0.,  0.,
         0., -1.,  0., -1.,  0., -1.,  0.,  0.,  0.,  0., -1.,  0., -1., -1.,
         0., -1., -1.,  0., -1.,  0.,  0.,  0., -1.,  0.,  0., -1.,  0., -1.,
         0.,  0., -1., -1.,  0.,  0., -1.,  0., -1., -1.,  0., -1.,  0., -1.,
         0.,  0., -1., -1., -1., -1.,  0.,  0.], device='mps:0')
Predicted:
tensor([-1.1506e-03, -2.7169e-03, -9.8151e-01,  2.0850e-02, -1.3024e-03,
        -2.4603e-03, -8.3304e-03, -8.8881e-01, -6.0909e-03, -9.7924e-01,
        -1.0394e+00, -1.0826e+00, -6.1674e-03, -1.1971e-02, -3.6280e-03,

In [None]:
# Presumably switches the agent into its evaluation mode — see DQN_Solver.eval.
agent.eval()
sum_reward = 0

# Run exactly config['EVAL_EPISODES'] episodes so the divisor used for the
# average below matches the number of episodes actually played. (The previous
# range(1, EVAL_EPISODES) played N-1 episodes but divided by N.)
for episode in range(config['EVAL_EPISODES']):
    state, info = env.reset()
    # The observation comes as two arrays; flatten both into one input vector.
    state = np.concatenate((state[0].flatten(), state[1].flatten()))

    # Use the same step cap as training: without it, a policy that never
    # reaches `done` would hang the notebook in an endless loop.
    for _step in range(config['MAX_EPISODE_STEPS']):
        action, _ = agent.choose_action(state, info['mask'], env)
        state_, reward, done, info = env.step(action)
        state = np.concatenate((state_[0].flatten(), state_[1].flatten()))
        sum_reward += reward

        if done:
            break

# Mean total reward per evaluation episode, stored in the run summary.
wandb.summary['Avg. Eval Episode Reward'] = (
    sum_reward / config['EVAL_EPISODES']
)

In [None]:
# Write the network weights into the wandb run directory.
# `DQN` is the network instance created earlier in the notebook, not the class.
weights_path = os.path.join(wandb.run.dir, "dqn.pt")
torch.save(DQN.state_dict(), weights_path)

# Close the wandb run.
wandb.finish()