In [None]:
from ReplayBuffer import ReplayBuffer
from DQN_Solver import DQN_Solver
from env import MPSPEnv
from DQN import DQN
import numpy as np
import wandb
import torch
import os
# Set the notebook name in the environment before logging in to Weights & Biases.
# NOTE(review): presumably wandb reads WANDB_NOTEBOOK_NAME to identify this
# notebook — confirm against the wandb environment-variable docs.
os.environ.update(WANDB_NOTEBOOK_NAME='main.ipynb')
wandb.login()

In [None]:
# Single source of truth for the run's settings. The later cells read these
# values by key and pass the whole mapping to wandb.init as the run config.
config = dict(
    # --- Environment ---
    ROWS=10,
    COLUMNS=10,
    N_PORTS=8,
    # --- Training ---
    EPISODES=1_000,
    LEARNING_RATE=1e-4,
    MEM_SIZE=10_000,
    BATCH_SIZE=64,
    GAMMA=0.95,
    EXPLORATION_MAX=1.0,
    EXPLORATION_DECAY=0.999,
    EXPLORATION_MIN=0.001,
    # --- Model ---
    HIDDEN_SIZE_1=1024,
    HIDDEN_SIZE_2=512,
)

In [None]:
# Build the environment from the row, column and port counts in the config.
env = MPSPEnv(config['ROWS'], config['COLUMNS'], config['N_PORTS'])

# The observation space is indexed as two components. The network input size
# is the total number of elements across both, i.e. the length of the
# flattened, concatenated observation.
config['OBSERVATION_SPACE'] = sum(
    np.prod(env.observation_space[part].shape) for part in (0, 1)
)
# Number of discrete actions, as reported by the environment's action space.
config['ACTION_SPACE'] = env.action_space.n

In [None]:
# The run name encodes the grid dimensions and the number of ports,
# e.g. "10x10_8-ports" for a 10x10 grid with 8 ports.
run_name = f"{config['ROWS']}x{config['COLUMNS']}_{config['N_PORTS']}-ports"

# Start the Weights & Biases run, passing the full config mapping along with it.
wandb.init(
    entity="rl-msps",
    project="Q-learning",
    name=run_name,
    tags=["test"],
    config=config,
)

In [None]:
# NOTE(review): the two assignments below rebind the imported class names
# `ReplayBuffer` and `DQN` to *instances*. After this cell has run once, the
# classes are no longer reachable under those names, so re-running this cell
# without re-running the import cell will fail. The names are kept because
# later code refers to the model instance as `DQN`.

# Replay memory, sized from the config.
ReplayBuffer = ReplayBuffer(
    batch_size=config['BATCH_SIZE'],
    mem_size=config['MEM_SIZE'],
    observation_space=config['OBSERVATION_SPACE'],
)

# Network: input sized to the flattened observation, one output per action.
DQN = DQN(
    learning_rate=config['LEARNING_RATE'],
    input_size=config['OBSERVATION_SPACE'],
    hidden_size_1=config['HIDDEN_SIZE_1'],
    hidden_size_2=config['HIDDEN_SIZE_2'],
    output_size=config['ACTION_SPACE'],
)

# The solver is given the buffer and network instances created above, plus
# the batch size, gamma and exploration settings from the config.
agent = DQN_Solver(
    ReplayBuffer=ReplayBuffer,
    DQN=DQN,
    batch_size=config['BATCH_SIZE'],
    gamma=config['GAMMA'],
    exploration_max=config['EXPLORATION_MAX'],
    exploration_min=config['EXPLORATION_MIN'],
    exploration_decay=config['EXPLORATION_DECAY'],
)


In [None]:
def _flatten_observation(observation):
    """Return the first two components of `observation` as one 1-D array.

    Each component is flattened and the two results are concatenated, in
    order, matching the flattened observation size stored in the config.
    """
    return np.concatenate((observation[0].flatten(), observation[1].flatten()))


# Run exactly config['EPISODES'] episodes. (Iterating over
# range(1, EPISODES) would run one episode fewer than configured.)
for _ in range(config['EPISODES']):
    observation, info = env.reset()
    state = _flatten_observation(observation)
    score = 0
    done = False

    while not done:
        # The action mask from the latest `info` is handed to the agent
        # together with the current state and the environment.
        action = agent.choose_action(state, info['mask'], env)
        next_observation, reward, done, info = env.step(action)
        next_state = _flatten_observation(next_observation)

        # Store the transition, then let the agent run a learning step.
        agent.memory.add(state, action, reward, next_state, done)
        agent.learn()

        state = next_state
        score += reward

    # Log the summed reward once per finished episode.
    wandb.log({
        "Episode Reward": score
    })


In [None]:
# Save the network's state dict into the current wandb run directory, then
# close the run. `DQN` here is the model instance, not the imported class.
checkpoint_path = os.path.join(wandb.run.dir, "dqn.pt")
torch.save(DQN.state_dict(), checkpoint_path)
wandb.finish()