In [2]:
import gym
import pybullet
import pybullet_envs
from gym import wrappers
from datetime import datetime

# DR TRPO related files
from train_helper import *
from value import NNValueFunction
from utils import Logger
from dr_policy import DRPolicyKL, DRPolicyWass

# Discrete State Space - KL DR TRPO Policy
### Candidate discrete environments (set `env_name` in the next cell): `Taxi-v3`, `Roulette-v0`, `NChain-v0`, `FrozenLake-v0`, `CliffWalking-v0`, `FrozenLake8x8-v0`

In [None]:
# --- Configuration: KL-constrained DR-TRPO on a discrete-state gym env -------
env_name = 'Taxi-v3'
# Open a pybullet DIRECT (no-GUI) client only if one is not already connected,
# so re-running this cell (or the Wasserstein cell later) does not keep
# spawning additional physics clients.
if not pybullet.isConnected():
    pybullet.connect(pybullet.DIRECT)
env = gym.make(env_name)
sta_num = env.observation_space.n   # number of discrete states
act_num = env.action_space.n        # number of discrete actions
policy = DRPolicyKL(sta_num, act_num)
# NOTE(review): args presumably (obs_dim, hidden-size multiplier); the printed
# "Value Params -- h1: 10 ..." is consistent with 1 * 10 — confirm in value.py.
val_func = NNValueFunction(1, 10)
gamma = 0.9        # discount factor (used for returns, GAE and visitation freqs)
lam = 1            # lambda passed to add_gae when computing advantages
total_eps = 5000   # stop once at least this many episodes have been collected
batch_eps = 60     # episodes collected per policy/value update
logger = Logger(logname=env_name + '_DR-KL_Batch=' + str(batch_eps),
                now=datetime.utcnow().strftime("%b-%d_%H:%M:%S"))

# --- Training loop -----------------------------------------------------------
# try/finally guarantees the logger is closed even if a long run is
# interrupted (e.g. KeyboardInterrupt) or an update raises.
eps = 0
try:
    while eps < total_eps:
        trajectories = run_policy(env, policy, batch_eps, logger)
        eps += len(trajectories)
        # add estimated values to episodes
        add_value(trajectories, val_func)
        # calculate discounted sum of rewards
        add_disc_sum_rew(trajectories, gamma, logger)
        # calculate advantage
        add_gae(trajectories, gamma, lam)
        # concatenate all episodes into single NumPy arrays
        observes, actions, advantages, disc_sum_rew = build_train_set(trajectories)
        log_batch_stats(observes, actions, advantages, disc_sum_rew, eps, logger)
        # discounted state-visitation frequencies consumed by the DR policy update
        disc_freqs = find_disc_freqs(trajectories, sta_num, gamma)
        policy.update(observes, actions, advantages, disc_freqs)
        val_func.fit(observes, disc_sum_rew, logger)
        # write logger results to file and stdout
        logger.write(display=True)
finally:
    logger.close()

In [None]:
# Visual sanity check: roll out the trained KL policy for 10 episodes, printing
# a dashed separator after each one. The trailing positional `True` appears to
# enable rendering (the saved output of the later demo cell shows Taxi frames)
# — TODO confirm against run_episode (presumably from train_helper).
DEMO_SEPARATOR = '------------------------'
for _ in range(10):
    run_episode(env, policy, True)
    print(DEMO_SEPARATOR)

# Discrete State Space - Wasserstein DR TRPO Policy
### Candidate discrete environments (set `env_name` in the next cell): `Taxi-v3`, `Roulette-v0`, `NChain-v0`, `FrozenLake-v0`, `CliffWalking-v0`, `FrozenLake8x8-v0`

In [3]:
# --- Configuration: Wasserstein-constrained DR-TRPO on a discrete-state env --
env_name = 'Taxi-v3'
# Open a pybullet DIRECT (no-GUI) client only if one is not already connected;
# the KL cell above may already have opened one, and re-running this cell
# should not keep spawning additional physics clients.
if not pybullet.isConnected():
    pybullet.connect(pybullet.DIRECT)
env = gym.make(env_name)
sta_num = env.observation_space.n   # number of discrete states
act_num = env.action_space.n        # number of discrete actions
policy = DRPolicyWass(sta_num, act_num)
# NOTE(review): args presumably (obs_dim, hidden-size multiplier); the printed
# "Value Params -- h1: 10 ..." is consistent with 1 * 10 — confirm in value.py.
val_func = NNValueFunction(1, 10)
gamma = 0.9        # discount factor (used for returns, GAE and visitation freqs)
lam = 1            # lambda passed to add_gae when computing advantages
total_eps = 5000   # stop once at least this many episodes have been collected
batch_eps = 60     # episodes collected per policy/value update
logger = Logger(logname=env_name + '_DR-Wass_Batch=' + str(batch_eps),
                now=datetime.utcnow().strftime("%b-%d_%H:%M:%S"))

# --- Training loop -----------------------------------------------------------
# try/finally guarantees the logger is closed even if a long run is
# interrupted (e.g. KeyboardInterrupt) or an update raises.
eps = 0
try:
    while eps < total_eps:
        trajectories = run_policy(env, policy, batch_eps, logger)
        eps += len(trajectories)
        # add estimated values to episodes
        add_value(trajectories, val_func)
        # calculate discounted sum of rewards
        add_disc_sum_rew(trajectories, gamma, logger)
        # calculate advantage
        add_gae(trajectories, gamma, lam)
        # concatenate all episodes into single NumPy arrays
        observes, actions, advantages, disc_sum_rew = build_train_set(trajectories)
        # discounted state-visitation frequencies consumed by the DR policy update
        disc_freqs = find_disc_freqs(trajectories, sta_num, gamma)
        log_batch_stats(observes, actions, advantages, disc_sum_rew, eps, logger)
        # the Wasserstein policy's update additionally receives env_name
        policy.update(observes, actions, advantages, disc_freqs, env_name)
        val_func.fit(observes, disc_sum_rew, logger)
        # write logger results to file and stdout
        logger.write(display=True)
finally:
    logger.close()

Value Params -- h1: 10, h2: 7, h3: 5, lr: 0.00378
Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Use tf.cast instead.
200
200
200
200
192
200
200
200
200
200
168
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
84
200
200
190
200
200
Instructions for updating:
Use tf.cast instead.
***** Episode 60, Mean Return = -761.6, Mean Discounted Return = -40.6 *****
ExplainedVarNew: -7.52e-10
ExplainedVarOld: -1.29e-07
ValFuncLoss: 715


200
184
183
200
200
200
200
200
200
200
200
200
200
79
200
200
200
200
200
200
200
200
200
200
200
200
200
200
78
200
200
200
200
200
200
200
200
200
176
200
200
200
200
200
200
200
200
200
137
200
200
200
143
200
200
200
200
200
200
200
***** Episode 120, Mean Return = -726.5, Mean Discounted Return = -38.4 *****
ExplainedVarNew: -5.92e-11
ExplainedVarOld: -1.05e-09
ValFuncLoss: 236


116


200
10
200
200
23
200
200
200
34
200
200
200
200
11
14
46
29
200
200
200
200
200
200
200
***** Episode 1260, Mean Return = -239.4, Mean Discounted Return = -18.6 *****
ExplainedVarNew: -0.0188
ExplainedVarOld: -0.013
ValFuncLoss: 153


55
72
35
91
16
200
200
29
69
9
200
64
27
10
28
28
12
18
200
89
200
12
200
12
12
21
8
200
200
200
200
34
17
200
200
200
13
200
200
200
140
200
15
200
200
79
200
200
57
200
28
57
22
11
200
90
200
27
200
200
***** Episode 1320, Mean Return = -224.1, Mean Discounted Return = -18.1 *****
ExplainedVarNew: -0.0128
ExplainedVarOld: -0.0161
ValFuncLoss: 160


200
71
13
20
200
7
200
42
107
200
9
21
54
200
200
200
200
200
200
200
17
200
200
53
14
200
19
59
200
27
200
10
14
200
17
29
11
16
18
200
15
200
11
13
16
22
25
91
138
29
19
28
59
10
31
22
35
200
200
200
***** Episode 1380, Mean Return = -180.3, Mean Discounted Return = -16.7 *****
ExplainedVarNew: -0.0346
ExplainedVarOld: -0.0335
ValFuncLoss: 160


50
200
200
45
20
200
200
16
17
30
12
61
14
200
200
12
176
200

***** Episode 2640, Mean Return = -107.8, Mean Discounted Return = -8.9 *****
ExplainedVarNew: -0.0295
ExplainedVarOld: -0.0259
ValFuncLoss: 37


12
34
21
24
11
11
200
200
21
13
200
123
200
24
15
26
21
200
11
51
200
22
17
11
12
13
16
32
9
29
19
13
200
15
200
200
56
19
9
200
100
11
200
38
14
200
16
10
16
200
200
14
10
21
200
16
47
14
200
200
***** Episode 2700, Mean Return = -66.2, Mean Discounted Return = -7.1 *****
ExplainedVarNew: -0.0515
ExplainedVarOld: -0.0486
ValFuncLoss: 52.2


200
13
200
200
200
16
19
13
11
200
88
13
45
10
134
14
200
200
73
12
20
13
200
15
8
200
146
12
14
16
200
200
14
200
200
14
13
200
21
200
200
13
103
112
200
200
200
200
14
200
61
200
26
20
200
23
22
23
35
200
***** Episode 2760, Mean Return = -94.5, Mean Discounted Return = -8.2 *****
ExplainedVarNew: -0.0343
ExplainedVarOld: -0.0268
ValFuncLoss: 33.5


200
200
13
36
32
17
14
200
8
145
12
17
10
13
200
12
16
114
131
200
18
16
27
200
34
12
200
34
120
200
15
22
12
11
9
200
15
84
200
7
200
10
37
10
200
52
26
14

12
25
200
17
200
19
14
200
18
18
68
12
13
23
10
200
15
200
15
16
13
15
135
***** Episode 4080, Mean Return = -72.0, Mean Discounted Return = -5.6 *****
ExplainedVarNew: -0.0477
ExplainedVarOld: -0.0416
ValFuncLoss: 33.3


200
25
16
200
200
18
200
200
13
10
34
8
15
200
9
8
200
16
12
27
78
11
11
13
200
200
200
69
19
12
18
200
200
14
23
40
200
23
10
200
15
200
17
26
200
21
175
109
7
30
16
13
200
48
12
33
9
15
200
200
***** Episode 4140, Mean Return = -67.4, Mean Discounted Return = -5.9 *****
ExplainedVarNew: -0.0577
ExplainedVarOld: -0.0594
ValFuncLoss: 35.4


15
88
45
16
8
11
200
17
12
16
200
15
200
200
200
59
32
12
200
200
102
16
200
200
12
200
15
17
200
111
11
10
14
17
200
67
17
10
31
12
18
200
13
17
200
21
200
200
16
12
21
13
200
14
200
18
13
40
13
11
***** Episode 4200, Mean Return = -64.6, Mean Discounted Return = -6.2 *****
ExplainedVarNew: -0.0505
ExplainedVarOld: -0.0486
ValFuncLoss: 41


200
24
200
11
24
200
200
74
14
151
15
26
200
32
9
19
25
200
200
17
32
12
11
200
12
22
200
2

In [9]:
# Visual sanity check: roll out the trained Wasserstein policy for 10 episodes,
# printing a dashed separator after each one. The trailing positional `True`
# appears to enable rendering (the saved output below shows Taxi frames)
# — TODO confirm against run_episode (presumably from train_helper).
DEMO_SEPARATOR = '------------------------'
for _ in range(10):
    run_episode(env, policy, True)
    print(DEMO_SEPARATOR)

+---------+
|[34;1mR[0m: | : :G|
| : | : : |
| : : : : |
| | : |[43m [0m: |
|[35mY[0m| : |B: |
+---------+

+---------+
|[34;1mR[0m: | : :G|
| : | : : |
| : : :[43m [0m: |
| | : | : |
|[35mY[0m| : |B: |
+---------+
  (North)
+---------+
|[34;1mR[0m: | : :G|
| : | :[43m [0m: |
| : : : : |
| | : | : |
|[35mY[0m| : |B: |
+---------+
  (North)
+---------+
|[34;1mR[0m: | : :G|
| : |[43m [0m: : |
| : : : : |
| | : | : |
|[35mY[0m| : |B: |
+---------+
  (West)
+---------+
|[34;1mR[0m: | : :G|
| : |[43m [0m: : |
| : : : : |
| | : | : |
|[35mY[0m| : |B: |
+---------+
  (West)
+---------+
|[34;1mR[0m: | : :G|
| : | : : |
| : :[43m [0m: : |
| | : | : |
|[35mY[0m| : |B: |
+---------+
  (South)
+---------+
|[34;1mR[0m: | : :G|
| : | : : |
| :[43m [0m: : : |
| | : | : |
|[35mY[0m| : |B: |
+---------+
  (West)
+---------+
|[34;1mR[0m: | : :G|
| :[43m [0m| : : |
| : : : : |
| | : | : |
|[35mY[0m| : |B: |
+---------+
  (North)
+---------+
|[34;1mR[0m:[