In [1]:
# Load the autoreload extension and use mode 2 so edited modules are re-imported automatically.
%load_ext autoreload
%autoreload 2
# Alternative interactive matplotlib backend, kept disabled:
#%matplotlib notebook
# Render matplotlib figures inline in the cell output.
%matplotlib inline

In [2]:
import sys
import os

# Parent of the current working directory; added to the import path so the
# project packages imported in the next cell can be resolved.
# NOTE(review): assumes the notebook is started from a direct subdirectory of the project root.
ROOT = os.path.dirname(os.getcwd())
# Guard the append so re-running this cell does not pile up duplicate entries.
if ROOT not in sys.path:
    sys.path.append(ROOT)

In [3]:
from env.env import Env
from agents.random_agent import RandomAgent
from utils.tools import set_random_seeds

from multiprocessing.dummy import Pool as ThreadPool
from copy import deepcopy

In [4]:
# Fixed seed for the experiments in this notebook.
SEED = 42
# Project helper from utils.tools; presumably seeds every RNG the env/agent rely on — TODO confirm which ones.
set_random_seeds(SEED)

In [5]:
# Single agent instance shared by all three data-collection benchmarks below.
agent = RandomAgent()

In [6]:
# Number of games (environment copies / resets) per benchmark run.
N_GAMES = 50
# Forwarded to agent.collect_data as T; presumably the number of steps collected per game — TODO confirm.
T = 30

In [7]:
%%timeit -n15 -r20
# Benchmark A: reset one environment, deep-copy it N_GAMES times, then collect from each copy in turn.
env = Env()
env.reset()
# Every copy is cloned from the same post-reset `env`.
envs  = [deepcopy(env) for _ in range(N_GAMES)]
# Flatten the per-environment results into one list.
data = [data_ for env_ in envs for data_ in agent.collect_data(env_, T=T)]

311 ms ± 10.1 ms per loop (mean ± std. dev. of 20 runs, 15 loops each)


In [8]:
%%timeit -n15 -r20
# Benchmark B: reuse a single environment, resetting it before each of the N_GAMES collections.
env = Env()
data = []
for _ in range(N_GAMES):
    env.reset()
    # `+=` extends the flat list in place with this game's items.
    data += agent.collect_data(env, T=T)

295 ms ± 3.13 ms per loop (mean ± std. dev. of 20 runs, 15 loops each)


In [9]:
%%timeit -n15 -r20
# Benchmark C: same deep-copied environments as benchmark A, but collected
# through a pool of two worker threads.
env = Env()
env.reset()
envs = [deepcopy(env) for _ in range(N_GAMES)]
# The context manager terminates the pool on exit; without it every timed
# iteration would leave two worker threads behind.
with ThreadPool(2) as pool:
    per_game_data = pool.map(lambda env_: agent.collect_data(env_, T=T), envs)
# pool.map yields one result per environment; flatten it so `data` has the
# same flat layout as in the two sequential benchmarks.
data = [data_ for game_data in per_game_data for data_ in game_data]

336 ms ± 2.64 ms per loop (mean ± std. dev. of 20 runs, 15 loops each)


In [11]:
import numpy as np

In [44]:
# Synthetic single-trajectory inputs for the return/advantage loops.
t = 30          # number of time steps in the fake trajectory
discount = 1.0  # multiplier applied to the running return at every step


def _uniform_list(shape):
    """Draw uniform samples of the given shape and return them as (nested) Python lists."""
    return np.random.uniform(size=shape).tolist()


# Draw order matters for reproducibility under a fixed global NumPy seed.
rewards = _uniform_list(t)
actions = _uniform_list(t)
action_masks = _uniform_list((t, 132))
bootstrapped_state_values = _uniform_list(t)

In [45]:
# Reference computation: walk the trajectory from the last step to the first,
# accumulating the discounted return and its difference from the bootstrapped
# state value. Both result lists are therefore in reverse time order.
assert all(len(seq) == t for seq in (rewards, actions, bootstrapped_state_values))

reversed_value_targets, reversed_advantages = [], []
value = 0.
for reward, baseline in zip(reversed(rewards), reversed(bootstrapped_state_values)):
    value = reward + discount * value
    reversed_value_targets.append(value)
    reversed_advantages.append(value - baseline)

In [46]:
# Extended variant of the backward pass: besides the discounted returns and
# advantages it also collects rewards, actions and action masks, all stored
# last-step-first.
assert all(len(seq) == t for seq in (rewards, actions, bootstrapped_state_values))

# Reversed copies of the per-step inputs (the inner mask lists are shared, not copied).
all_rewards = rewards[::-1]
all_actions = actions[::-1]
all_action_masks = action_masks[::-1]

value = 0.
values = []
advantages = []
for reward, baseline in zip(all_rewards, reversed(bootstrapped_state_values)):
    value = reward + discount * value
    values.append(value)
    advantage = value - baseline
    advantages.append(advantage)

In [53]:
# Both loops perform the same float operations in the same order, so exact
# equality (not approximate) is expected.
assert values == reversed_value_targets, "extended loop returns differ from the reference loop"
assert advantages == reversed_advantages, "extended loop advantages differ from the reference loop"
# Un-reversing every collected per-step list must give back the forward-ordered input.
assert all_rewards[::-1] == rewards, "all_rewards is not the reverse of rewards"
assert all_actions[::-1] == actions, "all_actions is not the reverse of actions"
assert all_action_masks[::-1] == action_masks, "all_action_masks is not the reverse of action_masks"

In [56]:
# The exact type of an array instance is np.ndarray, so this comparison is True.
type(np.array([1, 2, 3])) == np.ndarray

True

In [57]:
# np.array is the function that builds arrays, not the array class, so comparing a type against it is False.
type(np.array([1, 2, 3])) == np.array

False

In [41]:
# Inspect the forward-ordered rewards.
# NOTE(review): this cell's execution count (41) is lower than that of the cell that assigns `rewards` (44), so the output shown may be stale — re-run to confirm.
rewards

[0.635853817507843,
 0.44521484561352576,
 0.23877797788909838,
 0.8520482631995127,
 0.8552880424022111,
 0.9677533035051455,
 0.9499621470491301,
 0.7465185482802892,
 0.02106913440851399,
 0.9777099483299123,
 0.529360432710409,
 0.10810423136647063,
 0.24177793265530423,
 0.7557356770267889,
 0.07837662123565281,
 0.24445739451823378,
 0.06082527544226324,
 0.7821811114146323,
 0.9380938804148642,
 0.5589552195221329,
 0.08732523503985157,
 0.8268912080411411,
 0.6757952098737677,
 0.2928408621380427,
 0.7204799201718771,
 0.6588402458716641,
 0.29180737051434713,
 0.055586198316447955,
 0.2838405777148374,
 0.15728747744118965]

In [43]:
# Inspect the collected rewards flipped back to forward order.
# NOTE(review): this cell's execution count (43) is lower than that of the cell that builds `all_rewards` (46), so the output shown may be stale — re-run to confirm.
all_rewards[::-1]

[0.635853817507843,
 0.44521484561352576,
 0.23877797788909838,
 0.8520482631995127,
 0.8552880424022111,
 0.9677533035051455,
 0.9499621470491301,
 0.7465185482802892,
 0.02106913440851399,
 0.9777099483299123,
 0.529360432710409,
 0.10810423136647063,
 0.24177793265530423,
 0.7557356770267889,
 0.07837662123565281,
 0.24445739451823378,
 0.06082527544226324,
 0.7821811114146323,
 0.9380938804148642,
 0.5589552195221329,
 0.08732523503985157,
 0.8268912080411411,
 0.6757952098737677,
 0.2928408621380427,
 0.7204799201718771,
 0.6588402458716641,
 0.29180737051434713,
 0.055586198316447955,
 0.2838405777148374,
 0.15728747744118965]

In [47]:
# Discounted returns from the extended loop, stored last-step-first.
values

[0.7942045582052102,
 0.9419061569109328,
 1.682318779006784,
 2.3606379366838346,
 3.3276174546425463,
 3.354543021458425,
 3.6243849432687423,
 4.3048690500014875,
 4.977310659138583,
 5.065385982866573,
 5.973637472714793,
 6.376095173995536,
 7.299867215336802,
 7.635378623955518,
 7.756462586476523,
 8.24479770426585,
 9.158132117029377,
 9.857428428915778,
 10.301420277930266,
 10.787209731651204,
 11.560996333757279,
 12.353970826722254,
 12.906607377100158,
 13.374277848570008,
 13.915457199469612,
 13.925758516334371,
 14.093474467180169,
 14.941078313238835,
 15.224136909432545,
 15.829276833961378]

In [48]:
# Advantages from the reference loop (return minus bootstrapped state value), stored last-step-first.
reversed_advantages

[0.3477539930040401,
 0.7191167240703181,
 0.7244314023958224,
 1.8353445711003273,
 3.008307041848842,
 2.842033593565865,
 2.7014643349643594,
 4.276612142582398,
 4.414043471452351,
 4.10076963136792,
 5.904000612001189,
 5.839087600920779,
 6.769270410671908,
 6.695187148653584,
 6.788804271157968,
 7.619855924122331,
 8.537763465955718,
 9.374559888919816,
 10.166714281820065,
 9.832368275894048,
 11.526938813150199,
 11.872516440691946,
 12.539229015637567,
 12.460587342202448,
 13.22642663383443,
 13.901897443432361,
 13.88826828706658,
 14.545548835717996,
 14.816311487567274,
 15.576701825970288]

In [49]:
# Advantages from the extended loop (return minus bootstrapped state value), stored last-step-first.
advantages

[0.3477539930040401,
 0.7191167240703181,
 0.7244314023958224,
 1.8353445711003273,
 3.008307041848842,
 2.842033593565865,
 2.7014643349643594,
 4.276612142582398,
 4.414043471452351,
 4.10076963136792,
 5.904000612001189,
 5.839087600920779,
 6.769270410671908,
 6.695187148653584,
 6.788804271157968,
 7.619855924122331,
 8.537763465955718,
 9.374559888919816,
 10.166714281820065,
 9.832368275894048,
 11.526938813150199,
 11.872516440691946,
 12.539229015637567,
 12.460587342202448,
 13.22642663383443,
 13.901897443432361,
 13.88826828706658,
 14.545548835717996,
 14.816311487567274,
 15.576701825970288]

In [50]:
# Discounted returns from the reference loop, stored last-step-first.
reversed_value_targets

[0.7942045582052102,
 0.9419061569109328,
 1.682318779006784,
 2.3606379366838346,
 3.3276174546425463,
 3.354543021458425,
 3.6243849432687423,
 4.3048690500014875,
 4.977310659138583,
 5.065385982866573,
 5.973637472714793,
 6.376095173995536,
 7.299867215336802,
 7.635378623955518,
 7.756462586476523,
 8.24479770426585,
 9.158132117029377,
 9.857428428915778,
 10.301420277930266,
 10.787209731651204,
 11.560996333757279,
 12.353970826722254,
 12.906607377100158,
 13.374277848570008,
 13.915457199469612,
 13.925758516334371,
 14.093474467180169,
 14.941078313238835,
 15.224136909432545,
 15.829276833961378]