In [1]:
from ReplayBuffer import ReplayBuffer
from DQN_Solver import DQN_Solver
from env import MPSPEnv
from DQN import DQN
import wandb
import numpy as np
import torch
import gym
import os
# Export the notebook's file name through the environment before logging in,
# then authenticate with Weights & Biases. The login call stays last so its
# return value is shown as the cell output.
os.environ.update(WANDB_NOTEBOOK_NAME='main.ipynb')
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mhojmax[0m ([33mrl-msps[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [2]:
# Single source of truth for every tunable value in this notebook.
# Later cells read from it, add the derived OBSERVATION_SPACE / ACTION_SPACE
# entries to it, and hand it to wandb.init as the run configuration.
config = dict(
    # --- Environment (grid dimensions and port count for MPSPEnv) ---
    ROWS=3,
    COLUMNS=3,
    N_PORTS=5,
    # --- Training ---
    EPISODES=400,
    LEARNING_RATE=0.001,
    ADAM_EPSILON=0.01,
    MEM_SIZE=10000,
    BATCH_SIZE=100,
    GAMMA=0.95,
    EXPLORATION_MAX=1.0,
    EXPLORATION_DECAY=0.999,
    EXPLORATION_MIN=0.005,
    EVAL_EPISODES=50,
    MAX_EPISODE_STEPS=200,
    # --- Model ---
    HIDDEN_SIZE=64,
    N_LAYERS=2,
)


In [3]:
# CartPole is currently used in place of the project's own environment.
# The MPSP environment would instead be created with:
#     MPSPEnv(config['ROWS'], config['COLUMNS'], config['N_PORTS'])
env = gym.make('CartPole-v1')

# Store the network input/output sizes in the config so they are available to
# the model constructor and logged with the run.
# With the MPSPEnv tuple observation, the flattened input size would be:
#     np.prod(env.observation_space[0].shape) + np.prod(env.observation_space[1].shape)
config.update(
    OBSERVATION_SPACE=env.observation_space.shape[0],
    ACTION_SPACE=env.action_space.n,
)


In [4]:
# Start the Weights & Biases run, attaching the full config dict to it.
# For MPSP runs the name was instead built from the grid settings:
#     f"{config['ROWS']}x{config['COLUMNS']}_{config['N_PORTS']}-ports"
wandb.init(
    entity="rl-msps",
    project="Q-learning",
    name="cartpole",
    tags=["test"],
    config=config,
)

In [5]:
# Build the three training components from the config: the replay memory, the
# Q-network, and the solver that ties them together.
#
# NOTE(review): the names `ReplayBuffer` and `DQN` are rebound from the imported
# classes to instances here. Later cells rely on `DQN` being the instance, so
# the names are kept, but re-running this cell without re-running the import
# cell will call the instances instead of the classes.
ReplayBuffer = ReplayBuffer(
    observation_space=config['OBSERVATION_SPACE'],
    mem_size=config['MEM_SIZE'],
    batch_size=config['BATCH_SIZE'],
)

DQN = DQN(
    # Network shape
    input_size=config['OBSERVATION_SPACE'],
    hidden_size=config['HIDDEN_SIZE'],
    n_layers=config['N_LAYERS'],
    output_size=config['ACTION_SPACE'],
    # Optimizer settings
    learning_rate=config['LEARNING_RATE'],
    adam_epsilon=config['ADAM_EPSILON'],
)

agent = DQN_Solver(
    ReplayBuffer=ReplayBuffer,
    DQN=DQN,
    batch_size=config['BATCH_SIZE'],
    gamma=config['GAMMA'],
    # Exploration-rate schedule: start value, per-step decay factor, floor.
    exploration_max=config['EXPLORATION_MAX'],
    exploration_decay=config['EXPLORATION_DECAY'],
    exploration_min=config['EXPLORATION_MIN'],
)


In [6]:
# Training loop: run config['EPISODES'] episodes, each capped at
# config['MAX_EPISODE_STEPS'] steps, learning after every environment step and
# logging one summary row per episode to wandb.
agent.train()  # presumably switches the solver/network into training mode; confirm in DQN_Solver

# range(EPISODES) rather than range(1, EPISODES): the latter ran one episode
# fewer than configured.
for episode in range(config['EPISODES']):
    state, info = env.reset()
    # For the MPSP tuple observation the state would need flattening first:
    #     state = np.concatenate((state[0].flatten(), state[1].flatten()))
    sum_reward = 0
    sum_loss = 0
    steps = 0  # named `steps` so the builtin `iter` is not shadowed

    while steps < config['MAX_EPISODE_STEPS']:
        # All-ones mask (every action allowed), sized from the configured action
        # count instead of a hard-coded 2. A fresh array is built each step, as
        # before. The MPSP env would pass info['mask'] here instead.
        action_mask = np.ones(config['ACTION_SPACE'], dtype=np.int8)
        action, _ = agent.choose_action(state, action_mask, env)

        # env.step returns (obs, reward, terminated, truncated, info).
        state_, reward, terminated, truncated, info = env.step(action)
        # Same flattening as above would apply to state_ for the MPSP env.

        # Only `terminated` is stored as the done flag, exactly as before.
        agent.memory.add(state, action, reward, state_, terminated)
        sum_loss += agent.learn()
        state = state_
        sum_reward += reward
        steps += 1

        # End the episode on truncation as well as termination, so the
        # environment is never stepped after it has signalled the end.
        if terminated or truncated:
            break

    wandb.log({
        "Sum Episode Reward": sum_reward,
        # max(..., 1) avoids a ZeroDivisionError if MAX_EPISODE_STEPS <= 0.
        "Avg. Episode Loss": sum_loss / max(steps, 1),
        "Exploration Rate": agent.exploration_rate
    })

states tensor([[ 1.3962e-02, -3.5283e-03,  2.4769e-02,  1.1606e-04],
        [ 2.2651e-02,  1.8366e-01,  3.7733e-02, -2.9899e-01],
        [ 1.3638e-02, -3.9502e-01,  2.5681e-02,  6.1312e-01],
        [ 1.0382e-02,  3.5199e-02,  3.7234e-02,  3.4426e-02],
        [ 8.6221e-02, -2.3438e-01,  1.3861e-01,  1.0898e+00],
        [ 2.8617e-02, -1.6681e-01, -3.0768e-02,  3.0039e-01],
        [ 7.8768e-03,  3.4125e-02,  4.4695e-02,  5.8133e-02],
        [ 1.3857e-02, -1.5529e-01,  3.6611e-02,  2.7101e-01],
        [ 3.1241e-02,  1.7092e-01,  8.6152e-02,  1.6402e-01],
        [ 3.7287e-02,  1.8218e-01,  1.5061e-02, -2.6609e-01],
        [ 1.1322e-01,  5.7383e-01, -9.6943e-02, -8.8600e-01],
        [ 7.8618e-02,  9.6245e-01, -4.5193e-02, -1.4329e+00],
        [-1.0826e-02, -5.5631e-01,  3.2626e-02,  8.6955e-01],
        [ 8.5593e-03, -1.6161e-01,  4.5858e-02,  3.6458e-01],
        [ 8.6977e-02, -3.7791e-02,  1.2339e-01,  7.6102e-01],
        [ 1.3962e-02, -3.5283e-03,  2.4769e-02,  1.1606e-04],
 

  nonzero_finite_vals = torch.masked_select(


tensor([[ 1.3891e-02,  1.9123e-01,  2.4772e-02, -2.8465e-01],
        [ 2.6324e-02,  3.7822e-01,  3.1753e-02, -5.7954e-01],
        [ 5.7381e-03, -5.9049e-01,  3.7944e-02,  9.1378e-01],
        [ 1.1085e-02, -1.6044e-01,  3.7922e-02,  3.3862e-01],
        [ 8.1534e-02, -4.1327e-02,  1.6041e-01,  8.4367e-01],
        [ 2.5281e-02, -3.6148e-01, -2.4761e-02,  5.8321e-01],
        [ 8.5593e-03, -1.6161e-01,  4.5858e-02,  3.6458e-01],
        [ 1.0751e-02, -3.5092e-01,  4.2032e-02,  5.7501e-01],
        [ 3.4660e-02,  3.6471e-01,  8.9433e-02, -1.0029e-01],
        [ 4.0930e-02, -1.3157e-02,  9.7392e-03,  3.1309e-02],
        [ 1.2470e-01,  7.7013e-01, -1.1466e-01, -1.2075e+00],
        [ 9.7867e-02,  7.6792e-01, -7.3850e-02, -1.1546e+00],
        [-2.1952e-02, -7.5186e-01,  5.0017e-02,  1.1723e+00],
        [ 5.3271e-03,  3.2833e-02,  5.3149e-02,  8.6698e-02],
        [ 8.6221e-02, -2.3438e-01,  1.3861e-01,  1.0898e+00],
        [ 1.3891e-02,  1.9123e-01,  2.4772e-02, -2.8465e-01],
        

In [None]:
# Write the network's state dict into the wandb run directory as "dqn.pt",
# then close the run. `DQN` is the network instance at this point, not the class.
checkpoint_path = os.path.join(wandb.run.dir, "dqn.pt")
torch.save(DQN.state_dict(), checkpoint_path)
wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.042 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.023933…

0,1
Avg. Episode Loss,▁████▇▇▇▇▇▇▇▇▇▇▇▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆
Exploration Rate,█▇▆▆▅▄▄▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Sum Episode Reward,▂▂▁▁▁▁▁▁▁▂▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▂▂▂▁▂▁█▁▁▂▂▂

0,1
Avg. Episode Loss,0.69993
Exploration Rate,0.005
Sum Episode Reward,24.0
