In [None]:
%matplotlib inline

import re

import numpy as np
import matplotlib.pyplot as plt

# Default resolution (dpi) and size (inches) applied to every figure below.
plt.rcParams.update({'figure.dpi': 100, 'figure.figsize': (10, 6)})

In [None]:
def get_rewards_and_qvalues(logfile):
    """Extract the mean-reward and mean-Q-value series from one training log.

    Every ``mean_reward:<value>,`` and ``mean_qvalue:<value>,`` occurrence in
    the file is collected in file order and parsed as float32.

    Parameters
    ----------
    logfile : str or path-like
        Path of the log file to read.

    Returns
    -------
    tuple of numpy.ndarray
        ``(rewards, qvalues)``, each a 1-D float32 array (possibly empty).
    """
    with open(logfile, 'r') as handle:
        text = handle.read()

    def _extract(pattern):
        # Non-greedy capture up to the next comma on the same line
        # ('.' does not match newlines).
        return np.array(re.findall(pattern, text), dtype='float32')

    return _extract(r'mean_reward:(.*?),'), _extract(r'mean_qvalue:(.*?),')


def _truncate_to_shortest(series_list):
    """Stack 1-D series into a 2-D array after cutting each to the shortest length."""
    shortest = min(len(series) for series in series_list)
    return np.array([series[:shortest] for series in series_list])


def get_results(logfiles):
    """Load reward and Q-value series from several runs and stack them.

    All simulations run for the same number of episodes but potentially a
    different number of epochs, so each series is truncated to the fewest
    epochs logged by any run before stacking. Rewards and Q-values are
    truncated independently, so the two returned arrays may differ in width.

    Parameters
    ----------
    logfiles : iterable of str or path-like
        One log file per run.

    Returns
    -------
    tuple of numpy.ndarray
        ``(all_rewards, all_qvalues)``, each 2-D with one row per log file.

    Raises
    ------
    ValueError
        If ``logfiles`` is empty.
    """
    logfiles = list(logfiles)
    if not logfiles:
        raise ValueError('get_results() needs at least one log file')

    per_run = [get_rewards_and_qvalues(logfile) for logfile in logfiles]
    all_rewards = _truncate_to_shortest([rewards for rewards, _ in per_run])
    all_qvalues = _truncate_to_shortest([qvalues for _, qvalues in per_run])
    return all_rewards, all_qvalues


def plot_experiment_data(all_data, ax=None, **kwds):
    """Plot the across-run median of ``all_data`` with a 10-90% quantile band.

    Parameters
    ----------
    all_data : array-like
        Expected to be 2-D with one row per run and one column per epoch (as
        produced by ``get_results``); statistics are taken over axis 0.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. A new figure and axes are created when omitted.
    **kwds
        Extra keyword arguments for ``ax.plot``. Everything except ``label``
        is also forwarded to ``ax.fill_between`` so the band shares the
        line's styling (e.g. ``color``) without adding its own legend entry.

    Returns
    -------
    matplotlib.axes.Axes
        The axes that were drawn on.
    """
    min_data = np.quantile(all_data, 0.1, axis=0)
    max_data = np.quantile(all_data, 0.9, axis=0)
    med_data = np.median(all_data, axis=0)
    # 1-based epoch index on the x axis.
    epochs = np.arange(1, len(med_data) + 1)

    if ax is None:
        # Axes.hold() was removed in matplotlib 3.0; overplotting is already
        # the default behaviour, so no extra call is needed.
        _, ax = plt.subplots(1, 1)

    ax.plot(epochs, med_data, **kwds)
    # Drop ``label`` for the band (tolerating its absence) so only the median
    # line appears in the legend.
    band_kwds = {key: value for key, value in kwds.items() if key != 'label'}
    ax.fill_between(epochs, min_data, max_data, alpha=0.2, **band_kwds)
    return ax


In [None]:
# Root directory containing the 'logs/' folder with the training logs.
DIR = '../'
# Number of runs logged per algorithm; files are suffixed -0 ... -(N_RUNS-1).
N_RUNS = 6


def _log_paths(prefix, n_runs=N_RUNS):
    """Return the log paths ``DIR + 'logs/<prefix>-<i>.log'`` for i in range(n_runs)."""
    return [DIR + f'logs/{prefix}-{i}.log' for i in range(n_runs)]


dqn_logfiles = _log_paths('cart-pole-dqn')
ddqn_logfiles = _log_paths('cart-pole-ddqn')
dueling_dqn_logfiles = _log_paths('cart-pole-dueling-dqn')

dqn_all_rewards, dqn_all_qvalues = get_results(dqn_logfiles)
ddqn_all_rewards, ddqn_all_qvalues = get_results(ddqn_logfiles)
dueling_dqn_all_rewards, dueling_dqn_all_qvalues = get_results(dueling_dqn_logfiles)

In [None]:
# Reward curves: across-run median line plus quantile band, one colour per
# algorithm, all drawn on a shared axes.
fig, ax_rewards = plt.subplots(1, 1)

reward_curves = [
    (dqn_all_rewards, 'C0', 'DQN'),
    (ddqn_all_rewards, 'C1', 'DDQN'),
    (dueling_dqn_all_rewards, 'C2', 'Dueling DQN'),
]
for curve_data, curve_colour, curve_label in reward_curves:
    _ = plot_experiment_data(curve_data, ax=ax_rewards,
                             color=curve_colour, label=curve_label)

# Scale must be set before the explicit limits.
# NOTE(review): on a 0.97-1.0 range a log axis is visually almost linear;
# presumably intentional, confirm.
ax_rewards.set_yscale('log')
ax_rewards.set_ylim(0.97, 1.0)
ax_rewards.legend()
ax_rewards.set_title('Cart-Pole')
ax_rewards.set_xlabel('Epoch (1k Updates)')
ax_rewards.set_ylabel('Reward')
ax_rewards.set_xlim((0, 100))

fig.savefig('cart-pole-rewards.png', transparent=True)

In [None]:
# Average action-value curves: across-run median line plus quantile band,
# one colour per algorithm, all drawn on a shared axes.
fig, ax_qvalues = plt.subplots(1, 1)

qvalue_curves = [
    (dqn_all_qvalues, 'C0', 'DQN'),
    (ddqn_all_qvalues, 'C1', 'DDQN'),
    (dueling_dqn_all_qvalues, 'C2', 'Dueling DQN'),
]
for curve_data, curve_colour, curve_label in qvalue_curves:
    _ = plot_experiment_data(curve_data, ax=ax_qvalues,
                             color=curve_colour, label=curve_label)

# NOTE(review): a log axis cannot show non-positive values; this assumes the
# logged mean Q-values stay > 0 — confirm.
ax_qvalues.set_yscale('log')
ax_qvalues.legend()
ax_qvalues.set_title('Cart-Pole')
ax_qvalues.set_xlabel('Epoch (1k Updates)')
ax_qvalues.set_ylabel('Avg. Action-Value')
ax_qvalues.set_xlim((0, 100))

fig.savefig('cart-pole-qvalues.png', transparent=True)