In [1]:
!pip3 install mushroom_rl

Collecting mushroom_rl
[?25l  Downloading https://files.pythonhosted.org/packages/00/c6/8752704c263e8e20c471cf824a807370f67d70cc897ccd3ea8ec175cb2f4/mushroom_rl-1.4.0-py3-none-any.whl (171kB)
[K     |████████████████████████████████| 174kB 3.4MB/s 
Collecting pygame
[?25l  Downloading https://files.pythonhosted.org/packages/8e/24/ede6428359f913ed9cd1643dd5533aefeb5a2699cc95bea089de50ead586/pygame-1.9.6-cp36-cp36m-manylinux1_x86_64.whl (11.4MB)
[K     |████████████████████████████████| 11.4MB 9.5MB/s 
Installing collected packages: pygame, mushroom-rl
Successfully installed mushroom-rl-1.4.0 pygame-1.9.6


In [2]:
import matplotlib
matplotlib.use('Agg')

import numpy as np
from matplotlib import pyplot as plt
from joblib import Parallel, delayed

from mushroom_rl.algorithms.value import QLearning, DoubleQLearning,\
    WeightedQLearning, SpeedyQLearning, SARSA
from mushroom_rl.core import Core
from mushroom_rl.environments import *
from mushroom_rl.policy import EpsGreedy
from mushroom_rl.utils.callbacks import CollectDataset, CollectMaxQ
from mushroom_rl.utils.dataset import parse_dataset
from mushroom_rl.utils.parameters import ExponentialParameter

pygame 1.9.6
Hello from the pygame community. https://www.pygame.org/contribute.html


In [0]:
def experiment(algorithm_class, exp):
    """
    Run one independent training session on ``GridWorldVanHasselt``.

    Args:
        algorithm_class: agent class to instantiate; it is called as
            ``algorithm_class(mdp_info, policy, learning_rate=...)``;
        exp: exponent handed to the ``ExponentialParameter`` that is used
            as the agent's learning rate.

    Returns:
        A pair ``(reward, max_Qs)``: the reward entry of the dataset
        gathered by ``CollectDataset`` (as split by ``parse_dataset``) and
        whatever ``CollectMaxQ.get()`` returns for the start state
        (presumably one max-Q value per fit -- confirm against the
        MushroomRL callback documentation).

    """
    # Re-seed NumPy's global generator without a fixed value; presumably so
    # that runs launched in separate worker processes do not share the same
    # random stream -- TODO confirm.
    np.random.seed()

    # Environment
    env = GridWorldVanHasselt()

    # Exploration policy: epsilon-greedy with an exponential parameter
    # (exponent fixed to .5) sized on the observation space
    exploration = EpsGreedy(
        epsilon=ExponentialParameter(
            value=1, exp=.5, size=env.info.observation_space.size))

    # Learner: its learning rate uses the caller-provided exponent
    step_size = ExponentialParameter(value=1, exp=exp, size=env.info.size)
    learner = algorithm_class(env.info, exploration, learning_rate=step_size)

    # Callbacks: full transition log plus max-Q tracking at the start state
    start_state = env.convert_to_int(env._start, env._width)
    max_q_tracker = CollectMaxQ(learner.approximator, start_state)
    transition_log = CollectDataset()
    runner = Core(learner, env, [transition_log, max_q_tracker])

    # Training: 10000 steps, fitting after every single step
    runner.learn(n_steps=10000, n_steps_per_fit=1, quiet=True)

    # Only the third field of the parsed dataset (the reward) is needed
    _, _, reward, _, _, _ = parse_dataset(transition_log.get())

    return reward, max_q_tracker.get()

In [6]:
# Number of independent runs averaged for every (exponent, algorithm) pair.
n_experiment = 10

# Short tags for the learning-rate exponents and the algorithm classes; they
# are used in the log lines, the figure titles and the legend entries.
names = {1: '1', .8: '08', QLearning: 'Q', DoubleQLearning: 'DQ',
         WeightedQLearning: 'WQ', SpeedyQLearning: 'SPQ', SARSA: 'SARSA'}

#dir_name = 'TestResults/'
for e in [1, .8]:
    print('Exp: ', e)
    # One figure per exponent with two stacked panels. Explicit axes handles
    # are used so that plotting does not depend on which subplot pyplot
    # currently considers "active".
    # NOTE(review): a non-interactive backend is selected in the import
    # cell, so these figures are presumably only observable once the
    # savefig line below is enabled -- confirm.
    fig, (ax_reward, ax_max_q) = plt.subplots(2, 1)
    fig.suptitle(names[e])
    legend_labels = []
    for a in [QLearning, DoubleQLearning, WeightedQLearning,
              SpeedyQLearning, SARSA]:
        print('Alg: ', names[a])
        # Run the independent repetitions in parallel on all available cores.
        out = Parallel(n_jobs=-1)(
            delayed(experiment)(a, e) for _ in range(n_experiment))
        r = np.array([o[0] for o in out])
        max_Qs = np.array([o[1] for o in out])

        # Average over the runs (axis 0); the reward curve is additionally
        # smoothed with a 100-sample moving average.
        r = np.convolve(np.mean(r, 0), np.ones(100) / 100., 'valid')
        max_Qs = np.mean(max_Qs, 0)

        print(max_Qs)
        print(r)
        #np.save(dir_name + names[a] + '_' + names[e] + '_r.npy', r)
        #np.save(dir_name + names[a] + '_' + names[e] + '_maxQ.npy', max_Qs)

        ax_reward.plot(r)
        ax_max_q.plot(max_Qs)
        legend_labels.append(names[a])

    ax_reward.set_ylabel('reward (moving avg.)')
    ax_max_q.set_ylabel('max Q')
    ax_max_q.set_xlabel('step')
    # Both panels contain one curve per algorithm, in the same order, so both
    # get the same legend.
    ax_reward.legend(legend_labels)
    ax_max_q.legend(legend_labels)
    #fig.savefig(dir_name + 'test_' + names[e] + '.png')

Exp:  1
Alg:  Q
[ 7.         10.325      10.67625    ... 14.18160214 14.18167181
 14.18151943]
[-1.2   -1.222 -1.266 ... -1.181 -1.159 -1.093]
Alg:  DQ
[ 3.          3.75        3.875      ... -3.45778095 -3.45690245
 -3.45622817]
[-1.555 -1.533 -1.511 ... -0.515 -0.515 -0.464]
Alg:  WQ




[ 6.          6.849555    7.38471629 ... -1.5526372  -1.5503396
 -1.55245758]
[-0.135 -0.201 -0.245 ... -1.286 -1.242 -1.281]
Alg:  SPQ




[ 6.          6.475       5.95       ... 11.81982151 11.8205875
 11.8205875 ]
[-0.317 -0.339 -0.273 ... -0.719 -0.653 -0.587]
Alg:  SARSA




[7.         8.75       8.70208333 ... 5.69288794 5.69288794 5.72722249]
[-0.77  -0.858 -0.836 ... -0.528 -0.528 -0.479]
Exp:  0.8
Alg:  Q




[ 8.          9.55919991 11.42025855 ...  5.27350515  5.27230061
  5.27230061]
[-1.037 -1.103 -1.191 ... -0.717 -0.739 -0.751]
Alg:  DQ
[ 1.5         1.5         2.         ... -1.34103163 -1.34103163
 -1.33487594]
[-0.418 -0.357 -0.291 ... -0.171 -0.181 -0.11 ]
Alg:  WQ




[ 4.          5.          5.71158    ... -0.80510907 -0.80510907
 -0.80461336]
[-0.781 -0.83  -0.786 ... -0.012  0.017 -0.005]
Alg:  SPQ




[7.         4.42240027 5.23160018 ... 7.38108302 7.38108302 7.38237666]
[-0.496 -0.54  -0.496 ... -0.582 -0.692 -0.719]
Alg:  SARSA




[ 8.00000000e+00  8.00000000e+00  9.00000000e+00 ...  1.27498745e-02
  4.91865586e-03 -4.98180083e-04]
[-0.445 -0.494 -0.494 ...  0.065 -0.001 -0.018]


