In [None]:
import os
import random

import numpy as np
import torch
import wandb

from DQN import DQN
from DQN_Solver import DQN_Solver
from ReplayBuffer import ReplayBuffer
from env import MPSPEnv
# Name this notebook for wandb before logging in.
os.environ.update({'WANDB_NOTEBOOK_NAME': 'main.ipynb'})

# Print tensors in fixed-point notation rather than scientific notation.
torch.set_printoptions(sci_mode=False)

wandb.login()

In [None]:
config = {
    # Env
    'ROWS': 3,
    'COLUMNS': 3,
    'N_PORTS': 5,
    # Training
    'EPISODES': 2000,
    'LEARNING_RATE': 0.001,
    'ADAM_EPSILON': 0.01,
    'MEM_SIZE': 10000,
    'BATCH_SIZE': 100,
    'GAMMA': 0.95,
    'EXPLORATION_MAX': 1.0,
    'EXPLORATION_DECAY': 0.999,
    'EXPLORATION_MIN': 0.005,
    'EVAL_EPISODES': 50,  # NOTE(review): not read by any cell in this notebook
    'MAX_EPISODE_STEPS': 200,  # hard cap on steps per training episode
    'TARGET_UPDATE_FREQ': 500,
    'GRADIENT_CLIP': 5,
    # Model
    'HIDDEN_SIZE': 256,
    'N_LAYERS': 4,
    # Reproducibility (also logged to wandb via the run config)
    'SEED': 42,
}

# Seed the global RNGs before any network is built or any action sampled, so a
# fresh-kernel Run All reproduces the same run.
# NOTE(review): MPSPEnv / DQN_Solver may use their own RNG instances that these
# calls do not cover — confirm.
random.seed(config['SEED'])
np.random.seed(config['SEED'])
torch.manual_seed(config['SEED'])


In [None]:
env = MPSPEnv(config['ROWS'], config['COLUMNS'], config['N_PORTS'])

# Observation components [0] and [1] are each flattened and concatenated before
# reaching the network, so its input width is the sum of their element counts.
config.update({
    'OBSERVATION_SPACE': (
        np.prod(env.observation_space[0].shape)
        + np.prod(env.observation_space[1].shape)
    ),
    'ACTION_SPACE': env.action_space.n,
})


In [None]:
# Run name encodes the environment size, e.g. "3x3_5-ports".
run_name = f"{config['ROWS']}x{config['COLUMNS']}_{config['N_PORTS']}-ports"
wandb.init(
    project="Q-learning",
    entity="rl-msps",
    name=run_name,
    config=config,
    tags=["test"],
)

In [None]:
# Bind the buffer to a name distinct from the imported ReplayBuffer class.
# Rebinding the class name to an instance would make this cell fail when re-run.
replay_buffer = ReplayBuffer(
    mem_size=config['MEM_SIZE'],
    observation_space=config['OBSERVATION_SPACE'],
    batch_size=config['BATCH_SIZE'],
)

# The online and target networks are built from the same configuration.
dqn_kwargs = dict(
    input_size=config['OBSERVATION_SPACE'],
    output_size=config['ACTION_SPACE'],
    hidden_size=config['HIDDEN_SIZE'],
    n_layers=config['N_LAYERS'],
    learning_rate=config['LEARNING_RATE'],
    adam_epsilon=config['ADAM_EPSILON'],
)
current_DQN = DQN(**dqn_kwargs)
target_DQN = DQN(**dqn_kwargs)

# Start the target network as an exact copy of the online one (standard DQN
# initialisation) instead of an independent random init.
# Assumes DQN is a torch.nn.Module; its state_dict() is also used when saving.
target_DQN.load_state_dict(current_DQN.state_dict())

agent = DQN_Solver(
    ReplayBuffer=replay_buffer,
    DQN=current_DQN,
    target_DQN=target_DQN,
    batch_size=config['BATCH_SIZE'],
    exploration_max=config['EXPLORATION_MAX'],
    gamma=config['GAMMA'],
    exploration_decay=config['EXPLORATION_DECAY'],
    exploration_min=config['EXPLORATION_MIN'],
    target_update_freq=config['TARGET_UPDATE_FREQ'],
    gradient_clip=config['GRADIENT_CLIP'],
)

In [None]:
def flatten_observation(observation):
    """Flatten observation components 0 and 1 and concatenate them into one 1-D array.

    The result's length matches config['OBSERVATION_SPACE'], which sums the
    element counts of the same two components.
    """
    return np.concatenate((observation[0].flatten(), observation[1].flatten()))


agent.train()

# Run exactly config['EPISODES'] episodes (range(1, EPISODES) ran one too few).
for episode in range(config['EPISODES']):
    observation, info = env.reset()
    state = flatten_observation(observation)
    episode_reward = 0
    episode_loss = 0
    n_steps = 0  # steps taken this episode; avoids shadowing the builtin `iter`

    while n_steps < config['MAX_EPISODE_STEPS']:
        # `info` is refreshed by every env.step(), so the mask tracks the current state.
        action, _ = agent.choose_action(state, info['mask'], env)
        next_observation, reward, done, info = env.step(action)
        next_state = flatten_observation(next_observation)
        agent.memory.add(state, action, reward, next_state, done)
        episode_loss += agent.learn()
        state = next_state
        episode_reward += reward
        n_steps += 1

        if done:
            break

    wandb.log({
        "Sum Episode Reward": episode_reward,
        # max(..., 1) avoids ZeroDivisionError if MAX_EPISODE_STEPS is 0.
        "Avg. Episode Loss": episode_loss / max(n_steps, 1),
        "Exploration Rate": agent.exploration_rate
    })

In [None]:
# Save the online network's weights into the wandb run directory, then close the run.
checkpoint_path = os.path.join(wandb.run.dir, "dqn.pt")
torch.save(current_DQN.state_dict(), checkpoint_path)
wandb.finish()