In [12]:
from ReplayBuffer import ReplayBuffer
from DQN_Solver import DQN_Solver
from env import MPSPEnv
from DQN import DQN
import numpy as np
import torch
import wandb
import os

os.environ['WANDB_NOTEBOOK_NAME'] = 'reproduce.ipynb'
wandb.login()
api = wandb.Api()

In [2]:
run_path = "rl-msps/Q-learning/3psmd7fm"
files_to_restore = [
    'dqn.pt'
]
run = api.run(run_path)
files = run.files()

for file in files:
    if file.name in files_to_restore:
        file.download(replace=True, root='./saved_models')

config = run.config

In [3]:
env = MPSPEnv(
    config['ROWS'],
    config['COLUMNS'],
    config['N_PORTS']
)

In [4]:
DQN = DQN(
    input_size=config['OBSERVATION_SPACE'],
    output_size=config['ACTION_SPACE'],
    hidden_size=config['HIDDEN_SIZE'],
    n_layers=config['N_LAYERS'],
    learning_rate=config['LEARNING_RATE']
)
DQN.load('./saved_models/dqn.pt')
agent = DQN_Solver(
    ReplayBuffer=None,
    DQN=DQN,
    batch_size=None,
    exploration_max=None,
    gamma=None,
    exploration_decay=None,
    exploration_min=None
)

In [33]:
def get_writer(filename):
    # Clears file
    with open(filename, 'w') as f:
        pass
    
    def write_to_file(text):
        with open(filename, 'a') as f:
            f.write(text)
    return write_to_file


writer = get_writer('run.txt')

In [34]:
agent.eval()
sum_reward = 0

for i in range(1, config['EVAL_EPISODES']):
    writer(f"--- Episode {i}:\n")
    state, info = env.reset()
    writer(f"- Port {env.port}\n")
    writer(f"Bay:\n{state[0]}\n")
    writer(f"Transportation:\n{state[1]}\n\n\n")
    state = np.concatenate((state[0].flatten(), state[1].flatten()))

    while True:
        action, q_values = agent.choose_action(state, info['mask'], env)
        q_values = np.round(q_values.cpu().numpy(), 2)[0]
        state_, reward, done, info = env.step(action)
        writer(f"- Port {env.port}\n")
        writer(f"Bay:\n{state_[0]}\n")
        writer(f"Transportation:\n{state_[1]}\n")
        writer(f"Action: {action}\n")
        writer(f"Q-values: {q_values}\n")
        writer(f"Reward: {reward}\n\n")


        state_ = np.concatenate((state_[0].flatten(), state_[1].flatten()))
        state = state_
        sum_reward += reward

        if done:
            break

print(f'Average reward: {sum_reward / config["EVAL_EPISODES"]}')


Average reward: -0.86
