In this assignment you will learn how to apply the REINFORCE algorithm within the OpenAI Gym environment. Make sure OpenAI gym is installed on your machine. Now let's import some relevant packages.

In [1]:
import gym
from gym import wrappers, logger
import matplotlib.pyplot as plt
import tqdm
import numpy as np
from chainer import Chain
import chainer.links as L
import chainer.functions as F
from chainer.optimizers import Adam
from chainer import Variable

%matplotlib inline

We will make use of the classic CartPole environment provided by OpenAI Gym. Figure out what the details of this environment are.

https://github.com/openai/gym/wiki/CartPole-v0

**Observation**

Type: Box(4)
Num 	Observation 	Min 	Max

0 	Cart Position 	-2.4 	2.4

1 	Cart Velocity 	-Inf 	Inf

2 	Pole Angle 	~ -41.8° 	~ 41.8°

3 	Pole Velocity At Tip 	-Inf 	Inf

[Cart Position, Cart Velocity, Pole Angle, Pole Velocity]

**Actions**

Type: Discrete(2)

Num 	Action

0 	Push cart to the left

1 	Push cart to the right

Note: The amount the velocity is reduced or increased is not fixed as it depends on the angle the pole is pointing. This is because the center of gravity of the pole increases the amount of energy needed to move the cart underneath it.

**Reward**

Reward is 1 for every step taken, including the termination step

**Starting State**

All observations are assigned a uniform random value between ±0.05
Episode Termination

1. Pole Angle is more than ±12°
2. Cart Position is more than ±2.4 (center of the cart reaches the edge of the display)
3. Episode length is greater than 200

**Solved Requirements**

Considered solved when the average reward is greater than or equal to 195.0 over 100 consecutive trials.

In [2]:
env_id = 'CartPole-v0'

# You can set the level to logger.DEBUG or logger.WARN if you want to change the amount of output. INFO
logger.set_level(logger.DEBUG)

In [3]:
### TEST
env = gym.make('CartPole-v0')
for i_episode in range(2):
    print('\n>> episode', i_episode)
    observation = env.reset()
    
    # actions inside each episode
    for t in range(20):
#         env.render()
        print('observation',observation, end=' ')
        action = env.action_space.sample()
        print('action', action, end=' ')
        observation, reward, done, info = env.step(action)
        print('reward', reward)
        if done:
            print("\nEpisode finished after {} timesteps".format(t+1))
            break

INFO: Making new env: CartPole-v0

>> episode 0
observation [-0.0473875   0.0249805  -0.02103165  0.00051361] action 0 reward 1.0
observation [-0.04688789 -0.16983361 -0.02102137  0.28648731] action 1 reward 1.0
observation [-0.05028456  0.02558173 -0.01529163 -0.01275077] action 1 reward 1.0
observation [-0.04977293  0.22091961 -0.01554664 -0.31021889] action 0 reward 1.0
observation [-0.04535453  0.02602257 -0.02175102 -0.02247919] action 1 reward 1.0
observation [-0.04483408  0.22144959 -0.0222006  -0.32194467] action 1 reward 1.0
observation [-0.04040509  0.41688053 -0.0286395  -0.62154529] action 1 reward 1.0
observation [-0.03206748  0.61239047 -0.0410704  -0.9231088 ] action 1 reward 1.0
observation [-0.01981967  0.8080425  -0.05953258 -1.22841067] action 1 reward 1.0
observation [-0.00365882  1.0038779  -0.08410079 -1.53913562] action 1 reward 1.0
observation [ 0.01641874  1.19990495 -0.11488351 -1.85683343] action 1 reward 1.0
observation [ 0.04041684  1.39608583 -0.15202017 -

Let's define a baseline agent which just emits random actions.

In [4]:
class RandomAgent(object):
    """The world's simplest agent!"""

    def __init__(self, action_space):
        self.action_space = action_space

    def act(self, observation, reward, done):
        return self.action_space.sample()


Let's run the agent on the environment.

In [5]:
env = gym.make(env_id)
env.seed(0)
agent = RandomAgent(env.action_space)

episode_count = 1000
done = False
reward = 0
    
R0 = np.zeros(episode_count)

random_history = {}

for i in tqdm.trange(episode_count):

    ob = env.reset()
    temp_history = {'obs':[], 'act':[], 'reward':0}
    temp_history['obs'].append(ob)
    while True:

        action = agent.act(ob, reward, done)
        temp_history['act'].append(action)
        ob, reward, done, _ = env.step(action)
        temp_history['obs'].append(ob)
        R0[i] += reward
        
        if done:
            temp_history['reward'] = R0[i]
            random_history[i] = temp_history
            break

# Close the env and write monitor result info to disk
env.close()
print(max(R0))


 36%|███▌      | 361/1000 [00:00<00:00, 3605.44it/s]

INFO: Making new env: CartPole-v0


100%|██████████| 1000/1000 [00:00<00:00, 3822.26it/s]

114.0





In [6]:
for eps in random_history:
    print(eps, random_history[eps])
    print('---')

0 {'obs': [array([-0.04456399,  0.04653909,  0.01326909, -0.02099827]), array([-0.04363321, -0.14877061,  0.01284913,  0.2758415 ]), array([-0.04660862,  0.04616568,  0.01836596, -0.01276126]), array([-0.04568531,  0.24101949,  0.01811073, -0.2995934 ]), array([-0.04086492,  0.04564414,  0.01211887, -0.00125416]), array([-0.03995204,  0.24059021,  0.01209378, -0.29008894]), array([-0.03514023,  0.43553764,  0.006292  , -0.57893321]), array([-0.02642948,  0.24032808, -0.00528666, -0.28427483]), array([-0.02162292,  0.04528193, -0.01097216,  0.00673604]), array([-0.02071728,  0.2405595 , -0.01083744, -0.28938844]), array([-0.01590609,  0.04559375, -0.0166252 , -0.00014314]), array([-0.01499422,  0.24095014, -0.01662807, -0.2980248 ]), array([-0.01017521,  0.43630513, -0.02258856, -0.5959052 ]), array([-0.00144911,  0.6317358 , -0.03450667, -0.8956169 ]), array([ 0.01118561,  0.8273082 , -0.05241901, -1.19894382]), array([ 0.02773177,  1.02306774, -0.07639788, -1.50758392]), array([ 0.048

68 {'obs': [array([-0.02513393, -0.02394553,  0.04248389,  0.02388029]), array([-0.02561284,  0.17054223,  0.0429615 , -0.25510161]), array([-0.02220199,  0.36502529,  0.03785946, -0.53393022]), array([-0.01490149,  0.16939193,  0.02718086, -0.22956244]), array([-0.01151365,  0.36411513,  0.02258961, -0.51354917]), array([-0.00423135,  0.55891177,  0.01231863, -0.79902879]), array([ 0.00694689,  0.36362301, -0.00366195, -0.50249629]), array([ 0.01421935,  0.16855286, -0.01371187, -0.21096965]), array([ 0.0175904 ,  0.36386815, -0.01793127, -0.50794621]), array([ 0.02486777,  0.16900338, -0.02809019, -0.22096755]), array([ 0.02824784,  0.36451535, -0.03250954, -0.52237724]), array([ 0.03553814,  0.56007944, -0.04295709, -0.82512449]), array([ 0.04673973,  0.36557052, -0.05945958, -0.54625576]), array([ 0.05405114,  0.17133218, -0.07038469, -0.27288397]), array([ 0.05747779,  0.36738417, -0.07584237, -0.58690983]), array([ 0.06482547,  0.17340176, -0.08758057, -0.31904896]), array([ 0.06

84 {'obs': [array([-0.0273527 ,  0.04917296, -0.0241096 ,  0.03781551]), array([-0.02636925,  0.2446322 , -0.02335329, -0.26237568]), array([-0.0214766 ,  0.04985125, -0.02860081,  0.02285084]), array([-0.02047958, -0.14484911, -0.02814379,  0.30637445]), array([-0.02337656,  0.05066234, -0.0220163 ,  0.00495022]), array([-0.02236331, -0.14413706, -0.0219173 ,  0.29060623]), array([-0.02524605, -0.33893974, -0.01610517,  0.57629691]), array([-0.03202485, -0.14359578, -0.00457923,  0.27858426]), array([-0.03489676,  0.05159119,  0.00099245, -0.01553943]), array([-0.03386494, -0.14354498,  0.00068166,  0.27745646]), array([-0.03673584,  0.05156724,  0.00623079, -0.01501139]), array([-0.03570449, -0.14364351,  0.00593056,  0.2796309 ]), array([-0.03857736,  0.05139334,  0.01152318, -0.01117566]), array([-0.0375495 ,  0.24634816,  0.01129967, -0.30020072]), array([-0.03262253,  0.44130724,  0.00529565, -0.58929865]), array([-0.02379639,  0.24611154, -0.00649032, -0.29495229]), array([-0.01

139 {'obs': [array([0.04864489, 0.00245833, 0.02213174, 0.00834026]), array([ 0.04869406, -0.19297391,  0.02229854,  0.30792303]), array([ 0.04483458, -0.38840638,  0.028457  ,  0.60755408]), array([ 0.03706645, -0.5839144 ,  0.04060808,  0.90906261]), array([ 0.02538817, -0.38936492,  0.05878934,  0.62941446]), array([ 0.01760087, -0.58525586,  0.07137762,  0.94001735]), array([ 0.00589575, -0.39116485,  0.09017797,  0.67058936]), array([-0.00192755, -0.58741711,  0.10358976,  0.99024802]), array([-0.01367589, -0.78376155,  0.12339472,  1.31358683]), array([-0.02935112, -0.59039849,  0.14966646,  1.06193364]), array([-0.04115909, -0.39754076,  0.17090513,  0.81971887]), array([-0.04910991, -0.59453786,  0.18729951,  1.16091137]), array([-0.06100066, -0.4022826 ,  0.21051773,  0.93231677])], 'act': [0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1], 'reward': 12.0}
---
140 {'obs': [array([-0.01115444, -0.04525447,  0.02523291, -0.04214219]), array([-0.01205953, -0.24072899,  0.02439006,  0.25839395]

165 {'obs': [array([-0.02563561,  0.01538109,  0.01392455,  0.03274226]), array([-0.02532799,  0.21030062,  0.01457939, -0.255515  ]), array([-0.02112198,  0.4052114 ,  0.00946909, -0.54356392]), array([-0.01301775,  0.20995767, -0.00140219, -0.24791258]), array([-0.0088186 ,  0.01485577, -0.00636044,  0.04432773]), array([-0.00852148, -0.1801744 , -0.00547388,  0.33499712]), array([-0.01212497,  0.01502503,  0.00122606,  0.04059305]), array([-0.01182447,  0.21012938,  0.00203792, -0.25170279]), array([-0.00762188,  0.01497838, -0.00299614,  0.04162224]), array([-0.00732232, -0.18010048, -0.00216369,  0.33335835]), array([-0.01092433,  0.0150522 ,  0.00450348,  0.03999389]), array([-0.01062328,  0.21010929,  0.00530335, -0.25126475]), array([-0.0064211 ,  0.01491201,  0.00027806,  0.04308623]), array([-0.00612286,  0.21002997,  0.00113978, -0.24950895]), array([-0.00192226,  0.01489176, -0.0038504 ,  0.04353327]), array([-0.00162442, -0.18017477, -0.00297973,  0.33499888]), array([-0.0

217 {'obs': [array([-0.02284648, -0.03935665,  0.01580209,  0.01798373]), array([-0.02363361, -0.23470161,  0.01616176,  0.31561022]), array([-0.02832764, -0.43005   ,  0.02247397,  0.61334589]), array([-0.03692864, -0.62547868,  0.03474089,  0.91302167]), array([-0.04943822, -0.82105296,  0.05300132,  1.21641783]), array([-0.06585928, -0.62665314,  0.07732968,  0.94080268]), array([-0.07839234, -0.82272737,  0.09614573,  1.25674736]), array([-0.09484689, -1.01893948,  0.12128068,  1.57792955]), array([-0.11522568, -0.8254532 ,  0.15283927,  1.32540118]), array([-0.13173474, -1.02213862,  0.17934729,  1.6617501 ]), array([-0.15217751, -1.21883928,  0.21258229,  2.00451446])], 'act': [0, 0, 0, 0, 1, 0, 0, 1, 0, 0], 'reward': 10.0}
---
218 {'obs': [array([ 0.01728698, -0.02090673,  0.01970131, -0.04146262]), array([ 0.01686884,  0.17392725,  0.01887206, -0.32786507]), array([ 0.02034739, -0.02145822,  0.01231476, -0.02929099]), array([ 0.01991822, -0.21675459,  0.01172894,  0.26725179]),

246 {'obs': [array([-0.00875235, -0.03831792, -0.0221506 ,  0.01222768]), array([-0.00951871,  0.15711458, -0.02190604, -0.2873609 ]), array([-0.00637642,  0.35254197, -0.02765326, -0.58687151]), array([ 6.74421835e-04,  5.48040079e-01, -3.93906909e-02, -8.88135585e-01]), array([ 0.01163522,  0.35347424, -0.0571534 , -0.60809108]), array([ 0.01870471,  0.15919594, -0.06931522, -0.33394342]), array([ 0.02188863, -0.03487453, -0.07599409, -0.06389913]), array([ 0.02119114,  0.16125008, -0.07727208, -0.37955773]), array([ 0.02441614,  0.35737945, -0.08486323, -0.6955693 ]), array([ 0.03156373,  0.55356949, -0.09877462, -1.01371516]), array([ 0.04263512,  0.74986028, -0.11904892, -1.33570872]), array([ 0.05763232,  0.5564225 , -0.14576309, -1.08252354]), array([ 0.06876077,  0.36349328, -0.16741356, -0.83890139]), array([ 0.07603064,  0.17100434, -0.18419159, -0.60319474]), array([ 0.07945072, -0.0211288 , -0.19625549, -0.37371668]), array([ 0.07902815,  0.17616055, -0.20372982, -0.7212992

302 {'obs': [array([-0.00874277,  0.01196509, -0.00342707,  0.03533495]), array([-0.00850347, -0.18310755, -0.00272037,  0.32693463]), array([-0.01216562, -0.37819067,  0.00381832,  0.61875843]), array([-0.01972943, -0.18312226,  0.01619349,  0.32728053]), array([-0.02339187,  0.01176545,  0.0227391 ,  0.03974798]), array([-0.02315657, -0.18367507,  0.02353406,  0.33951769]), array([-0.02683007,  0.01110424,  0.03032441,  0.054348  ]), array([-0.02660798, -0.1844391 ,  0.03141137,  0.35644217]), array([-0.03029676, -0.37999323,  0.03854022,  0.6588621 ]), array([-0.03789663, -0.57562979,  0.05171746,  0.96342704]), array([-0.04940922, -0.77140703,  0.070986  ,  1.27189841]), array([-0.06483737, -0.57725923,  0.09642397,  1.00226249]), array([-0.07638255, -0.77352802,  0.11646922,  1.32360286]), array([-0.09185311, -0.96991258,  0.14294127,  1.65034857]), array([-0.11125136, -0.77672114,  0.17594825,  1.40539864]), array([-0.12678578, -0.97353615,  0.20405622,  1.7475255 ]), array([-0.1

319 {'obs': [array([-0.02645633,  0.03282085, -0.03823991,  0.03500493]), array([-0.02579991,  0.22846971, -0.03753981, -0.26949361]), array([-0.02123052,  0.42410671, -0.04292968, -0.57377647]), array([-0.01274838,  0.2296121 , -0.05440521, -0.2949208 ]), array([-0.00815614,  0.42546578, -0.06030363, -0.60425388]), array([ 0.00035317,  0.23123672, -0.0723887 , -0.3311584 ]), array([ 0.00497791,  0.03721583, -0.07901187, -0.06215295]), array([ 0.00572223, -0.15668959, -0.08025493,  0.20459243]), array([ 0.00258843, -0.3505776 , -0.07616308,  0.47091871]), array([-0.00442312, -0.54454577, -0.06674471,  0.7386572 ]), array([-0.01531403, -0.34856873, -0.05197156,  0.4257378 ]), array([-0.02228541, -0.54291748, -0.04345681,  0.70159455]), array([-0.03314376, -0.34722097, -0.02942492,  0.39555464]), array([-0.04008818, -0.54191333, -0.02151382,  0.67881708]), array([-0.05092644, -0.73672991, -0.00793748,  0.96464975]), array([-0.06566104, -0.93174434,  0.01135551,  1.25482858]), array([-0.0

---
382 {'obs': [array([ 0.0043969 ,  0.01449824,  0.01309462, -0.03760022]), array([ 0.00468686,  0.20942999,  0.01234262, -0.32612307]), array([ 0.00887546,  0.40437406,  0.00582016, -0.61488826]), array([ 0.01696294,  0.5994142 , -0.00647761, -0.90573241]), array([ 0.02895123,  0.79462327, -0.02459226, -1.20044426]), array([ 0.04484369,  0.59982791, -0.04860114, -0.91556897]), array([ 0.05684025,  0.40539572, -0.06691252, -0.63854818]), array([ 0.06494817,  0.21126743, -0.07968348, -0.36766499]), array([ 0.06917352,  0.40742589, -0.08703678, -0.6843712 ]), array([ 0.07732203,  0.60364158, -0.10072421, -1.00313714]), array([ 0.08939487,  0.79995435, -0.12078695, -1.32567669]), array([ 0.10539395,  0.60654674, -0.14730048, -1.07310406]), array([ 0.11752489,  0.41364602, -0.16876257, -0.83003631]), array([ 0.12579781,  0.22118307, -0.18536329, -0.59482427]), array([ 0.13022147,  0.02907289, -0.19725978, -0.36578071]), array([ 0.13080293, -0.16277928, -0.20457539, -0.14120226]), array([

398 {'obs': [array([ 0.02893066, -0.01340268,  0.02558562,  0.00354373]), array([ 0.02866261,  0.18134317,  0.0256565 , -0.28095809]), array([ 0.03228947,  0.37608993,  0.02003733, -0.5654399 ]), array([ 0.03981127,  0.18069268,  0.00872854, -0.26651228]), array([ 0.04342513, -0.01455275,  0.00339829,  0.02891087]), array([ 0.04313407,  0.1805203 ,  0.00397651, -0.26269793]), array([ 0.04674448, -0.01465818, -0.00127745,  0.03123658]), array([ 0.04645131, -0.20976179, -0.00065272,  0.32351618]), array([ 0.04225608, -0.01463055,  0.0058176 ,  0.03062748]), array([ 0.04196347,  0.18040749,  0.00643015, -0.26021426]), array([ 0.04557162,  0.37543706,  0.00122587, -0.55086212]), array([ 0.05308036,  0.18029791, -0.00979137, -0.25779321]), array([ 0.05668632, -0.01468289, -0.01494724,  0.03178536]), array([ 0.05639266, -0.20958734, -0.01431153,  0.31971505]), array([ 0.05220091, -0.01426452, -0.00791723,  0.02255338]), array([ 0.05191562,  0.18097007, -0.00746616, -0.27261697]), array([ 0.0

448 {'obs': [array([-0.02688643, -0.01764518,  0.03856924,  0.01832714]), array([-0.02723933, -0.21329842,  0.03893579,  0.32292529]), array([-0.0315053 , -0.40895256,  0.04539429,  0.62762828]), array([-0.03968435, -0.6046777 ,  0.05794686,  0.93425491]), array([-0.05177791, -0.80053145,  0.07663196,  1.24456928]), array([-0.06778854, -0.60647175,  0.10152334,  0.97684177]), array([-0.07991797, -0.80279781,  0.12106018,  1.29961112]), array([-0.09597393, -0.60940251,  0.1470524 ,  1.04714613]), array([-0.10816198, -0.80613733,  0.16799532,  1.38214117]), array([-0.12428473, -1.00290935,  0.19563815,  1.72230257]), array([-0.14434291, -0.81049156,  0.2300842 ,  1.49633319])], 'act': [0, 0, 0, 0, 1, 0, 1, 0, 0, 1], 'reward': 10.0}
---
449 {'obs': [array([-0.01940813, -0.02459215, -0.04635088, -0.03694244]), array([-0.01989997, -0.21901983, -0.04708972,  0.24076358]), array([-0.02428037, -0.41343859, -0.04227445,  0.51822918]), array([-0.03254914, -0.60794064, -0.03190987,  0.7972966 ]),

462 {'obs': [array([-0.03129897, -0.04005572, -0.04735088,  0.01960904]), array([-0.03210008,  0.15571219, -0.0469587 , -0.28762964]), array([-0.02898584,  0.35147125, -0.05271129, -0.59474533]), array([-0.02195641,  0.15712515, -0.0646062 , -0.31912101]), array([-0.01881391,  0.35310489, -0.07098862, -0.63145834]), array([-0.01175181,  0.54914175, -0.08361778, -0.94562613]), array([-0.00076897,  0.35523952, -0.10253031, -0.68034401]), array([ 0.00633582,  0.55162488, -0.11613719, -1.00346614]), array([ 0.01736831,  0.74809061, -0.13620651, -1.33024718]), array([ 0.03233013,  0.9446423 , -0.16281145, -1.6622651 ]), array([ 0.05122297,  0.75174814, -0.19605675, -1.42440564]), array([ 0.06625793,  0.94867791, -0.22454487, -1.77141336])], 'act': [1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1], 'reward': 11.0}
---
463 {'obs': [array([-0.03483219, -0.03053419,  0.01607204, -0.00109612]), array([-0.03544287, -0.2258829 ,  0.01605012,  0.29661408]), array([-0.03996053, -0.03099339,  0.0219824 ,  0.009036  

532 {'obs': [array([ 0.04611483, -0.01693597, -0.00257142, -0.04738308]), array([ 0.04577611, -0.21202095, -0.00351908,  0.24448743]), array([ 0.04153569, -0.40709247,  0.00137066,  0.53605828]), array([ 0.03339384, -0.60223366,  0.01209183,  0.82917277]), array([ 0.02134917, -0.40727908,  0.02867529,  0.54031719]), array([ 0.01320359, -0.60279213,  0.03948163,  0.84189558]), array([ 0.00114774, -0.4082307 ,  0.05631954,  0.56188545]), array([-0.00701687, -0.21394245,  0.06755725,  0.28746434]), array([-0.01129572, -0.0198457 ,  0.07330654,  0.01683008]), array([-0.01169263,  0.17415256,  0.07364314, -0.25185295]), array([-0.00820958,  0.36814989,  0.06860608, -0.5204287 ]), array([-0.00084658,  0.17213262,  0.05819751, -0.20694012]), array([ 0.00259607, -0.02377116,  0.0540587 ,  0.10351846]), array([ 0.00212064, -0.21962448,  0.05612907,  0.4127548 ]), array([-0.00227184, -0.02534122,  0.06438417,  0.13828215]), array([-0.00277867, -0.22132343,  0.06714981,  0.45056222]), array([-0.0

558 {'obs': [array([-0.03505166,  0.02001774, -0.01020205, -0.03004183]), array([-0.0346513 ,  0.2152845 , -0.01080289, -0.32592607]), array([-0.03034561,  0.41055858, -0.01732141, -0.62199609]), array([-0.02213444,  0.21568274, -0.02976133, -0.33481841]), array([-0.01782079,  0.41121533, -0.0364577 , -0.63673597]), array([-0.00959648,  0.60682624, -0.04919242, -0.94067355]), array([ 0.00254005,  0.4124006 , -0.06800589, -0.66384461]), array([ 0.01078806,  0.21828731, -0.08128278, -0.39332605]), array([ 0.0151538 ,  0.02440722, -0.08914931, -0.12733692]), array([ 0.01564195, -0.16933201, -0.09169604,  0.13594267]), array([ 0.01225531, -0.36302908, -0.08897719,  0.39834678]), array([ 0.00499473, -0.5567835 , -0.08101025,  0.66170393]), array([-0.00614094, -0.36063338, -0.06777618,  0.34465269]), array([-0.01335361, -0.554729  , -0.06088312,  0.61521674]), array([-0.02444819, -0.74894981, -0.04857879,  0.88811959]), array([-0.03942719, -0.94337999, -0.0308164 ,  1.16514425]), array([-0.0

593 {'obs': [array([ 0.02337172, -0.03330959, -0.00628882,  0.01430278]), array([ 0.02270553, -0.22834079, -0.00600277,  0.30499489]), array([ 1.81387158e-02, -4.23376684e-01,  9.71326051e-05,  5.95778662e-01]), array([ 0.00967118, -0.61849999,  0.01201271,  0.88849218]), array([-0.00269882, -0.8137829 ,  0.02978255,  1.18492705]), array([-0.01897448, -0.61905966,  0.05348109,  0.9017265 ]), array([-0.03135567, -0.8148638 ,  0.07151562,  1.21072873]), array([-0.04765294, -1.0108325 ,  0.0957302 ,  1.52493806]), array([-0.06786959, -1.20697101,  0.12622896,  1.84590047]), array([-0.09200902, -1.01344613,  0.16314697,  1.5949356 ]), array([-0.11227794, -0.82059183,  0.19504568,  1.35724821]), array([-0.12868977, -0.6283767 ,  0.22219064,  1.13137299])], 'act': [0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1], 'reward': 11.0}
---
594 {'obs': [array([-0.01558697,  0.00986731,  0.01459264,  0.01806602]), array([-0.01538963, -0.18546084,  0.01495396,  0.31531714]), array([-0.01909884,  0.00944494,  0.02126

617 {'obs': [array([-0.03604001, -0.03415266, -0.04183784, -0.01903662]), array([-0.03672306,  0.16154354, -0.04221857, -0.32462078]), array([-0.03349219,  0.35724041, -0.04871099, -0.63031308]), array([-0.02634738,  0.16283079, -0.06131725, -0.35335992]), array([-0.02309077, -0.03136809, -0.06838445, -0.08062502]), array([-0.02371813,  0.16466413, -0.06999695, -0.39407543]), array([-0.02042485, -0.02939837, -0.07787846, -0.1242568 ]), array([-0.02101281,  0.16674791, -0.08036359, -0.44045759]), array([-0.01767785,  0.36290975, -0.08917275, -0.75735217]), array([-0.01041966,  0.16912252, -0.10431979, -0.49400833]), array([-0.00703721,  0.3655491 , -0.11419996, -0.81766084]), array([ 2.73772805e-04,  1.72160154e-01, -1.30553173e-01, -5.62968811e-01]), array([ 0.00371698, -0.02091151, -0.14181255, -0.31409956]), array([ 0.00329875, -0.21375854, -0.14809454, -0.06928616]), array([-0.00097643, -0.01685798, -0.14948026, -0.4047849 ]), array([-0.00131358,  0.18003262, -0.15757596, -0.7406111

683 {'obs': [array([-0.00631616,  0.01887685,  0.0259909 , -0.04877411]), array([-0.00593862, -0.17660796,  0.02501542,  0.25199454]), array([-0.00947078, -0.37207802,  0.03005531,  0.55246166]), array([-0.01691234, -0.56760888,  0.04110455,  0.85446043]), array([-0.02826452, -0.37307053,  0.05819375,  0.57498072]), array([-0.03572593, -0.56895794,  0.06969337,  0.88541375]), array([-0.04710509, -0.76495336,  0.08740164,  1.19916644]), array([-0.06240415, -0.57106414,  0.11138497,  0.93510613]), array([-0.07382544, -0.3776065 ,  0.13008709,  0.67939744]), array([-0.08137757, -0.57427255,  0.14367504,  1.01004178]), array([-0.09286302, -0.77098918,  0.16387588,  1.34417169]), array([-0.1082828 , -0.96774873,  0.19075931,  1.68332101]), array([-0.12763777, -0.77527879,  0.22442573,  1.45559677])], 'act': [0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1], 'reward': 12.0}
---
684 {'obs': [array([ 0.03631258,  0.03929911, -0.03990244, -0.02315821]), array([ 0.03709857, -0.15522856, -0.04036561,  0.25667

703 {'obs': [array([-0.0471531 ,  0.01611485, -0.01453779, -0.01621193]), array([-0.04683081, -0.17879563, -0.01486203,  0.27184893]), array([-0.05040672,  0.01653521, -0.00942505, -0.02548424]), array([-0.05007601, -0.17845032, -0.00993474,  0.26421014]), array([-0.05364502,  0.01681201, -0.00465054, -0.03158967]), array([-0.05330878,  0.21200035, -0.00528233, -0.32573626]), array([-0.04906877,  0.016954  , -0.01179705, -0.03472382]), array([-0.04872969, -0.17799681, -0.01249153,  0.25421379]), array([-0.05228963,  0.01730126, -0.00740726, -0.04238286]), array([-0.0519436 , -0.1777137 , -0.00825491,  0.24795383]), array([-0.05549788,  0.01752517, -0.00329584, -0.0473214 ]), array([-0.05514737, -0.17754937, -0.00424226,  0.24431984]), array([-0.05869836,  0.01763292,  0.00064413, -0.04969818]), array([-0.0583457 ,  0.21274563, -0.00034983, -0.34217781]), array([-0.05409079,  0.40787255, -0.00719339, -0.63497103]), array([-0.04593334,  0.60309409, -0.01989281, -0.92991061]), array([-0.0

761 {'obs': [array([ 0.03580356,  0.03570364,  0.03671508, -0.01193267]), array([ 0.03651763,  0.23028034,  0.03647642, -0.29280921]), array([ 0.04112324,  0.42486375,  0.03062024, -0.57376848]), array([ 0.04962051,  0.22932617,  0.01914487, -0.27159866]), array([ 0.05420704,  0.42416978,  0.0137129 , -0.55818227]), array([ 0.06269043,  0.22885805,  0.00254925, -0.26121074]), array([ 0.06726759,  0.0336998 , -0.00267496,  0.03227516]), array([ 0.06794159,  0.22886001, -0.00202946, -0.26125055]), array([ 0.07251879,  0.03376708, -0.00725447,  0.03079157]), array([ 0.07319413,  0.22899231, -0.00663864, -0.26417136]), array([ 0.07777398,  0.42420839, -0.01192207, -0.55894077]), array([ 0.08625814,  0.22925579, -0.02310088, -0.27003766]), array([ 0.09084326,  0.42469965, -0.02850164, -0.56991629]), array([ 0.09933725,  0.22998878, -0.03989996, -0.28634694]), array([ 0.10393703,  0.42565637, -0.0456269 , -0.59134221]), array([ 0.11245016,  0.62138642, -0.05745374, -0.89804138]), array([ 0.1

779 {'obs': [array([-0.01876971,  0.04487113, -0.01474687,  0.00669617]), array([-0.01787229,  0.24020144, -0.01461295, -0.29060285]), array([-0.01306826,  0.04529087, -0.02042501, -0.00256426]), array([-0.01216244,  0.2406997 , -0.02047629, -0.30162091]), array([-0.00734845,  0.43610742, -0.02650871, -0.60069065]), array([ 0.0013737 ,  0.24136614, -0.03852252, -0.31647401]), array([ 0.00620102,  0.04681346, -0.044852  , -0.03618433]), array([ 0.00713729,  0.24254895, -0.04557569, -0.34267435]), array([ 0.01198827,  0.43828867, -0.05242918, -0.64937366]), array([ 0.02075404,  0.24393475, -0.06541665, -0.37365049]), array([ 0.02563274,  0.04980009, -0.07288966, -0.10229014]), array([ 0.02662874,  0.24588685, -0.07493546, -0.41704987]), array([ 0.03154648,  0.05190241, -0.08327646, -0.14890021]), array([ 0.03258453, -0.14193441, -0.08625447,  0.11639243]), array([ 0.02974584, -0.33572138, -0.08392662,  0.38066486]), array([ 0.02303141, -0.52955753, -0.07631332,  0.64575023]), array([ 0.0

854 {'obs': [array([ 0.00739466,  0.02714963, -0.02802944, -0.04518985]), array([ 0.00793766,  0.22266207, -0.02893323, -0.34658288]), array([ 0.0123909 ,  0.41818337, -0.03586489, -0.64824727]), array([ 0.02075457,  0.61378613, -0.04882984, -0.95200475]), array([ 0.03303029,  0.41935409, -0.06786993, -0.67505459]), array([ 0.04141737,  0.6153503 , -0.08137102, -0.98831049]), array([ 0.05372438,  0.42140649, -0.10113723, -0.72225403]), array([ 0.06215251,  0.22781813, -0.11558231, -0.46303904]), array([ 0.06670887,  0.42436758, -0.12484309, -0.78980168]), array([ 0.07519622,  0.23116093, -0.14063913, -0.53885584]), array([ 0.07981944,  0.03826709, -0.15141625, -0.29358358]), array([ 0.08058478,  0.23518701, -0.15728792, -0.62993033]), array([ 0.08528852,  0.4321135 , -0.16988652, -0.9677259 ]), array([ 0.09393079,  0.62905886, -0.18924104, -1.30859782]), array([ 0.10651197,  0.43677018, -0.215413  , -1.0806195 ])], 'act': [1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0], 'reward': 14.0}
---


868 {'obs': [array([ 0.01568612, -0.01912795, -0.0024498 , -0.04687277]), array([ 0.01530356, -0.21421469, -0.00338726,  0.24503621]), array([ 0.01101926, -0.01904452,  0.00151346, -0.04871321]), array([ 0.01063837,  0.1760557 ,  0.0005392 , -0.34091824]), array([ 0.01415949,  0.37116997, -0.00627916, -0.63343109]), array([ 0.02158289,  0.17613618, -0.01894779, -0.34273222]), array([ 0.02510561, -0.01871116, -0.02580243, -0.05608401]), array([ 0.02473139, -0.21345382, -0.02692411,  0.22834768]), array([ 0.02046231, -0.40818087, -0.02235716,  0.51241766]), array([ 0.01229869, -0.6029809 , -0.0121088 ,  0.79797222]), array([ 2.39074954e-04, -7.97934638e-01,  3.85064048e-03,  1.08682150e+00]), array([-0.01571962, -0.99310716,  0.02558707,  1.3807102 ]), array([-0.03558176, -0.79831375,  0.05320127,  1.09613751]), array([-0.05154804, -0.60393123,  0.07512402,  0.82061006]), array([-0.06362666, -0.40991328,  0.09153623,  0.55246966]), array([-0.07182493, -0.21618799,  0.10258562,  0.2899723

921 {'obs': [array([0.04351317, 0.03153022, 0.00797112, 0.00831776]), array([ 0.04414377, -0.16370513,  0.00813747,  0.30350497]), array([0.04086967, 0.03129991, 0.01420757, 0.01339949]), array([ 0.04149567, -0.16402289,  0.01447556,  0.31053099]), array([ 0.03821521, -0.35934805,  0.02068618,  0.60774371]), array([ 0.03102825, -0.16452133,  0.03284105,  0.32164746]), array([ 0.02773782, -0.36009519,  0.039274  ,  0.6245036 ]), array([ 0.02053592, -0.16554291,  0.05176408,  0.34444381]), array([0.01722506, 0.02880591, 0.05865295, 0.0685229 ]), array([ 0.01780118,  0.22304005,  0.06002341, -0.20509322]), array([ 0.02226198,  0.41725457,  0.05592154, -0.47825408]), array([ 0.03060707,  0.22138953,  0.04635646, -0.16848343]), array([ 0.03503486,  0.41581835,  0.04298679, -0.44618943]), array([ 0.04335123,  0.22011546,  0.03406301, -0.14027205]), array([0.04775354, 0.02452262, 0.03125756, 0.1629596 ]), array([ 0.04824399, -0.17103253,  0.03451676,  0.46533727]), array([ 0.04482334, -0.3666

960 {'obs': [array([ 0.04663933, -0.03232175,  0.02614503, -0.04625079]), array([ 0.04599289,  0.16241573,  0.02522002, -0.33057142]), array([ 0.04924121, -0.03305598,  0.01860859, -0.0300432 ]), array([ 0.04858009,  0.16179424,  0.01800773, -0.31679732]), array([ 0.05181597, -0.03357951,  0.01167178, -0.01849021]), array([ 0.05114438, -0.22886689,  0.01130197,  0.27785232]), array([ 0.04656704, -0.03390798,  0.01685902, -0.01124465]), array([ 0.04588888,  0.16096818,  0.01663413, -0.29856104]), array([ 0.04910825,  0.35584912,  0.01066291, -0.5859518 ]), array([ 0.05622523,  0.16057945, -0.00105613, -0.28992912]), array([ 0.05943682, -0.03452742, -0.00685471,  0.00242053]), array([ 0.05874627,  0.16069216, -0.0068063 , -0.29241723]), array([ 0.06196011,  0.35591049, -0.01265464, -0.58723898]), array([ 0.06907832,  0.16096804, -0.02439942, -0.29856908]), array([ 0.07229768,  0.35642914, -0.03037081, -0.59884618]), array([ 0.07942627,  0.55196256, -0.04234773, -0.90093871]), array([ 0.0

Let's create the REINFORCE agent. We assume that the policy is computed using an MLP with a softmax output.

In [7]:
'''
Input - Observations (1x4)
Output - Policy (1x2)
'''
class MLP(Chain):
    """Multilayer perceptron"""

    def __init__(self, n_output=1, n_hidden=5):
        super(MLP, self).__init__(l1=L.Linear(None, n_hidden), 
                                  l2=L.Linear(n_hidden, n_output))

    def __call__(self, x):
        return self.l2(F.relu(self.l1(x)))

1: A skeleton for the REINFORCEAgent is given. Implement the compute_loss and compute_score functions. 

# To Figure out

1. How the Agent and Model interact with each other. 
2. How state (s), action (a), policy (pi) work in this context with the NN
3. What is 'compute_loss'
4. What is 'compute_score'
5. What will be the loss function here - How it would maximize reward (refer slides)

Maximize expected return - using Gradient Ascent


# Questions

1. Do we use both Policy and Value gradient for 'compute_loss'?
2. What is that E function in slides 15 (expected return)
3. Is the (X) equn. in my book correctly correspond to V_t in 'compute_loss'
4. what is V_t?
5. How to maximize Reward?

In [8]:
class REINFORCEAgent(object):
    """Agent trained using REINFORCE"""

    def __init__(self, action_space, model, optimizer=Adam()):

        self.action_space = action_space

        self.model = model

        self.optimizer = optimizer
        self.optimizer.setup(self.model)

        # monitor score and reward
        self.rewards = []
        self.scores = []


    def act(self, observation, reward, done):

        # linear outputs reflecting the log action probabilities and the value
        policy = self.model(Variable(np.atleast_2d(np.asarray(observation, 'float32'))))
        print('---\nobs', Variable(np.atleast_2d(np.asarray(observation, 'float32'))))
        print('policy', policy)
        

        # generate action according to policy
        p = F.softmax(policy).data
#         print('p before',p, end=' ')
        # normalize p in case tiny floating precision problems occur
        row_sums = p.sum(axis=1)
        
        p /= row_sums[:, np.newaxis]
        print('choose policy', p.shape[1], 'prob', p[0])
                                             # a, size=None, replace=True, p=None
        action = np.asarray([np.random.choice(p.shape[1], None, True, p[0])])
        print('action', action)
#         print('policy', policy)
        return action, policy


    def compute_loss(self):
        """
        Return loss for this episode based on computed scores and accumulated rewards
        """
    
        return Variable(np.array([0]))

    def compute_score(self, action, policy):
        """
        Computes score
        
        # recompute score function: grad_theta log pi_theta (s_t, a_t) * v_t

        Args:
            action (int):
            policy [pi_theta]:

        Returns:
            score
        """

        pass

# compute_score

$$ \nabla_{\theta} log \pi_{\theta}(s_t, a_t)\cdot v_t $$

Now we run the REINFORCE agent on the CartPole environment. Note that we update the agent after each episode for simplicity.

In [9]:
env = gym.make(env_id)
env.seed(0)

print(env.action_space.n)

network = MLP(n_output=env.action_space.n, n_hidden=3)
agent = REINFORCEAgent(env.action_space, network, optimizer=Adam())

episode_count = 2
done = False
reward = 0
    
R = np.zeros(episode_count)

INFO: Making new env: CartPole-v0
2


In [13]:

for i in tqdm.trange(episode_count):

    ob = env.reset()

    loss = 0
    print('\n###################################### eps' ,i,' ################################\n')
    while True:

        action, policy = agent.act(ob, reward, done)

        ob, reward, done, _ = env.step(action[0])

        # get reward associated with taking the previous action in the previous state
        agent.rewards.append(reward)
        R[i] += reward

        # recompute score function: grad_theta log pi_theta (s_t, a_t) * v_t
        agent.scores.append(agent.compute_score(action, policy))

        # we learn at the end of each episode
        if done:
            
            loss += agent.compute_loss()
            
            agent.model.cleargrads()
            loss.backward()
            loss.unchain_backward()
            agent.optimizer.update()

            break

100%|██████████| 2/2 [00:00<00:00, 14.95it/s]


###################################### eps 0  ################################

---
obs variable([[-0.00465297  0.04027819 -0.04591077  0.04221768]])
policy variable([[-0.00372359  0.00671818]])
choose policy 2 prob [0.49738958 0.50261045]
action [1]
---
obs variable([[-0.00384741  0.23602739 -0.04506642 -0.26458976]])
policy variable([[0. 0.]])
choose policy 2 prob [0.5 0.5]
action [0]
---
obs variable([[ 0.00087314  0.04157668 -0.05035821  0.01354511]])
policy variable([[-0.00185563  0.00334797]])
choose policy 2 prob [0.49869913 0.50130093]
action [0]
---
obs variable([[ 0.00170468 -0.15278825 -0.05008731  0.28992385]])
policy variable([[-0.04102376 -0.00535791]])
choose policy 2 prob [0.4910845  0.50891554]
action [0]
---
obs variable([[-0.00135109 -0.34716153 -0.04428884  0.5663986 ]])
policy variable([[-0.09027581 -0.02508044]])
choose policy 2 prob [0.48370692 0.51629305]
action [0]
---
obs variable([[-0.00829432 -0.5416351  -0.03296086  0.84480625]])
policy variable([[-0.14071




In [11]:
# You may want to run a video of the trained agent performing in the environment using the env.render() function.
#
# for i in range(3):
#
#     ob = env.reset()
#
#     while True:
#
#         action, policy = agent.act(ob, reward, done)
#
#         ob, reward, done, _ = env.step(action[0])
#
#         if done:
#             break
#       
#         env.render()

2: Plot the cumulative reward for both RandomAgent and REINFORCEAgent.