In this assignment you will learn how to apply the REINFORCE algorithm within the OpenAI Gym environment. Make sure OpenAI gym is installed on your machine. Now let's import some relevant packages.

In [2]:
import gym
from gym import wrappers, logger
import matplotlib.pyplot as plt
import tqdm
import numpy as np
from chainer import Chain
import chainer.links as L
import chainer.functions as F
from chainer.optimizers import Adam
from chainer import Variable

%matplotlib inline

We will make use of the classic CartPole environment provided by OpenAI Gym. Figure out what the details of this environment are.

https://github.com/openai/gym/wiki/CartPole-v0

**Observation**

Type: Box(4)
Num 	Observation 	Min 	Max

0 	Cart Position 	-2.4 	2.4

1 	Cart Velocity 	-Inf 	Inf

2 	Pole Angle 	~ -41.8° 	~ 41.8°

3 	Pole Velocity At Tip 	-Inf 	Inf

[Cart Position, Cart Velocity, Pole Angle, Pole Velocity]

**Actions**

Type: Discrete(2)

Num 	Action

0 	Push cart to the left

1 	Push cart to the right

Note: The amount the velocity is reduced or increased is not fixed as it depends on the angle the pole is pointing. This is because the center of gravity of the pole increases the amount of energy needed to move the cart underneath it.

**Reward**

Reward is 1 for every step taken, including the termination step

**Starting State**

All observations are assigned a uniform random value between ±0.05
Episode Termination

1. Pole Angle is more than ±12°
2. Cart Position is more than ±2.4 (center of the cart reaches the edge of the display)
3. Episode length is greater than 200

**Solved Requirements**

Considered solved when the average reward is greater than or equal to 195.0 over 100 consecutive trials.

In [3]:
env_id = 'CartPole-v0'

# You can set the level to logger.DEBUG or logger.WARN if you want to change the amount of output.
logger.set_level(logger.INFO)

In [4]:
### TEST
env = gym.make('CartPole-v0')
for i_episode in range(2):
    print('\n>> episode', i_episode)
    observation = env.reset()
    
    # actions inside each episode
    for t in range(20):
#         env.render()
        print('observation',observation, end=' ')
        action = env.action_space.sample()
        print('action', action, end=' ')
        observation, reward, done, info = env.step(action)
        print('reward', reward)
        if done:
            print("\nEpisode finished after {} timesteps".format(t+1))
            break

INFO: Making new env: CartPole-v0

>> episode 0
observation [-0.00515196  0.04703844  0.03652002  0.03551209] action 0 reward 1.0
observation [-0.00421119 -0.14858767  0.03723026  0.3394901 ] action 1 reward 1.0
observation [-0.00718294  0.0459853   0.04402007  0.05877606] action 1 reward 1.0
observation [-0.00626324  0.24044935  0.04519559 -0.21969972] action 0 reward 1.0
observation [-0.00145425  0.04471147  0.04080159  0.08689026] action 1 reward 1.0
observation [-0.00056002  0.23922554  0.0425394  -0.19264566] action 1 reward 1.0
observation [ 0.00422449  0.43371395  0.03868648 -0.47161155] action 1 reward 1.0
observation [ 0.01289877  0.62826873  0.02925425 -0.75185438] action 1 reward 1.0
observation [ 0.02546415  0.82297532  0.01421717 -1.03518976] action 1 reward 1.0
observation [ 0.04192365  1.01790538 -0.00648663 -1.3233756 ] action 1 reward 1.0
observation [ 0.06228176  1.21310867 -0.03295414 -1.61808144] action 1 reward 1.0
observation [ 0.08654393  1.4086032  -0.06531577 -

Let's define a baseline agent which just emits random actions.

In [5]:
class RandomAgent(object):
    """The world's simplest agent!"""

    def __init__(self, action_space):
        self.action_space = action_space

    def act(self, observation, reward, done):
        return self.action_space.sample()


Let's run the agent on the environment.

In [None]:
env = gym.make(env_id)
env.seed(0)
agent = RandomAgent(env.action_space)

episode_count = 1000
done = False
reward = 0
    
R0 = np.zeros(episode_count)

random_history = {}

for i in tqdm.trange(episode_count):

    ob = env.reset()
    temp_history = {'obs':[], 'act':[], 'reward':0}
    temp_history['obs'].append(ob)
    while True:

        action = agent.act(ob, reward, done)
        temp_history['act'].append(action)
        ob, reward, done, _ = env.step(action)
        temp_history['obs'].append(ob)
        R0[i] += reward
        
        if done:
            temp_history['reward'] = R0[i]
            random_history[i] = temp_history
            break

# Close the env and write monitor result info to disk
env.close()
print(max(R0))


In [9]:
for eps in random_history:
    print(eps, random_history[eps])
    print('---')

0 {'obs': [array([-0.04456399,  0.04653909,  0.01326909, -0.02099827]), array([-0.04363321,  0.24146826,  0.01284913, -0.30946528]), array([-0.03880385,  0.43640481,  0.00665982, -0.59806842]), array([-0.03007575,  0.63143294, -0.00530154, -0.88864615]), array([-0.01744709,  0.82662643, -0.02307447, -1.18299093]), array([-0.00091456,  0.63181137, -0.04673429, -0.89762942]), array([ 0.01172166,  0.43735302, -0.06468687, -0.61999524]), array([ 0.02046872,  0.24319131, -0.07708678, -0.34836648]), array([ 0.02533255,  0.43932019, -0.08405411, -0.66432722]), array([ 0.03411895,  0.6355046 , -0.09734065, -0.98224772]), array([ 0.04682905,  0.83178663, -0.11698561, -1.30384998]), array([ 0.06346478,  1.0281817 , -0.14306261, -1.63074441]), array([ 0.08402841,  1.22466534, -0.1756775 , -1.96437604]), array([ 0.10852172,  1.03178449, -0.21496502, -1.7308918 ])], 'act': [1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0], 'reward': 13.0}
---
1 {'obs': [array([-0.03975157,  0.01730763, -0.01074233,  0.016984

80 {'obs': [array([-0.0374842 , -0.04376657, -0.02952087,  0.03913226]), array([-0.03835953,  0.151766  , -0.02873823, -0.26271667]), array([-0.03532421, -0.0429342 , -0.03399256,  0.02076516]), array([-0.03618289, -0.23755259, -0.03357726,  0.30253221]), array([-0.04093394, -0.4321803 , -0.02752661,  0.58443937]), array([-0.04957755, -0.62690606, -0.01583782,  0.86832555]), array([-0.06211567, -0.43157224,  0.00152869,  0.57070545]), array([-0.07074712, -0.23647176,  0.0129428 ,  0.2785045 ]), array([-0.07547655, -0.04153682,  0.01851289, -0.01006834]), array([-0.07630729, -0.23691931,  0.01831152,  0.28839765]), array([-0.08104567, -0.43229754,  0.02407947,  0.58679914]), array([-0.08969162, -0.62774831,  0.03581546,  0.88696899]), array([-0.10224659, -0.82333764,  0.05355483,  1.19069233]), array([-0.11871334, -0.62894899,  0.07736868,  0.915265  ]), array([-0.13129232, -0.82502722,  0.09567398,  1.23122621]), array([-0.14779287, -0.63125712,  0.12029851,  0.96998556]), array([-0.16

158 {'obs': [array([ 0.01225246,  0.04483665, -0.0074833 , -0.01806378]), array([ 0.01314919, -0.15017719, -0.00784457,  0.27224871]), array([ 0.01014565,  0.04505581, -0.0023996 , -0.02289806]), array([ 0.01104676, -0.15003164, -0.00285756,  0.2690268 ]), array([ 0.00804613, -0.3451127 ,  0.00252298,  0.56080706]), array([ 0.00114388, -0.15002625,  0.01373912,  0.26892007]), array([-0.00185665, -0.34534155,  0.01911752,  0.56590454]), array([-0.00876348, -0.54072641,  0.03043561,  0.86454849]), array([-0.01957801, -0.73624914,  0.04772658,  1.16664345]), array([-0.03430299, -0.93195858,  0.07105945,  1.47389961]), array([-0.05294216, -0.73777345,  0.10053744,  1.2042309 ]), array([-0.06769763, -0.93404088,  0.12462206,  1.52665368]), array([-0.08637845, -0.74062377,  0.15515513,  1.27532214]), array([-0.10119093, -0.54778319,  0.18066157,  1.03496879]), array([-0.11214659, -0.74478692,  0.20136095,  1.3784911 ]), array([-0.12704233, -0.94177191,  0.22893077,  1.72679942])], 'act': [0,

242 {'obs': [array([-0.0469342 , -0.00097274, -0.02154066, -0.00047239]), array([-0.04695365, -0.19577924, -0.0215501 ,  0.28533712]), array([-0.05086924, -0.39058732, -0.01584336,  0.57114615]), array([-0.05868098, -0.58548356, -0.00442044,  0.85879606]), array([-0.07039065, -0.39030167,  0.01275548,  0.56472648]), array([-0.07819669, -0.19536099,  0.02405001,  0.27608925]), array([-0.08210391, -0.39081766,  0.0295718 ,  0.5762595 ]), array([-0.08992026, -0.19612244,  0.04109699,  0.29303719]), array([-0.09384271, -0.0016098 ,  0.04695773,  0.0135937 ]), array([-0.09387491, -0.19737263,  0.04722961,  0.32071477]), array([-0.09782236, -0.39313425,  0.0536439 ,  0.6279101 ]), array([-0.10568504, -0.19880041,  0.0662021 ,  0.3525924 ]), array([-0.10966105, -0.39479824,  0.07325395,  0.66539433]), array([-0.11755702, -0.59085855,  0.08656184,  0.98021349]), array([-0.12937419, -0.78702741,  0.10616611,  1.29878246]), array([-0.14511474, -0.59340132,  0.13214176,  1.04113283]), array([-0.1

336 {'obs': [array([ 0.03303116, -0.04102206,  0.01634288,  0.01112251]), array([ 0.03221072, -0.23637454,  0.01656533,  0.30891668]), array([ 0.02748323, -0.43172855,  0.02274367,  0.60677749]), array([ 0.01884866, -0.23693187,  0.03487922,  0.32134409]), array([ 0.01411002, -0.04232353,  0.0413061 ,  0.03986155]), array([ 0.01326355, -0.23801272,  0.04210333,  0.34528549]), array([ 0.0085033 , -0.43370752,  0.04900904,  0.65094227]), array([-1.70852401e-04, -2.39301188e-01,  6.20278853e-02,  3.74085802e-01]), array([-0.00495688, -0.04511266,  0.0695096 ,  0.10158769]), array([-0.00585913,  0.14894783,  0.07154136, -0.16838075]), array([-0.00288017,  0.34297671,  0.06817374, -0.43766436]), array([ 0.00397936,  0.14695936,  0.05942045, -0.12429378]), array([ 0.00691855, -0.04896131,  0.05693458,  0.1865279 ]), array([ 0.00593932,  0.14530178,  0.06066514, -0.08766486]), array([ 0.00884536, -0.05063491,  0.05891184,  0.22352413]), array([ 0.00783266,  0.14359772,  0.06338232, -0.0500093

421 {'obs': [array([ 0.02280422, -0.04393853, -0.04104338,  0.00842486]), array([ 0.02192545, -0.23844856, -0.04087488,  0.28788083]), array([ 0.01715648, -0.04276827, -0.03511726, -0.01740831]), array([ 0.01630111,  0.15283925, -0.03546543, -0.32096103]), array([ 0.0193579 ,  0.34844784, -0.04188465, -0.62461407]), array([ 0.02632686,  0.15393489, -0.05437693, -0.34541102]), array([ 0.02940555,  0.34978649, -0.06128515, -0.65473314]), array([ 0.03640128,  0.15556895, -0.07437981, -0.38196031]), array([ 0.03951266,  0.35166388, -0.08201902, -0.69713693]), array([ 0.04654594,  0.54782165, -0.09596176, -1.01447103]), array([ 0.05750237,  0.74408336, -0.11625118, -1.33567863]), array([ 0.07238404,  0.55060222, -0.14296475, -1.08151671]), array([ 0.08339608,  0.74729193, -0.16459509, -1.41542991]), array([ 0.09834192,  0.94402527, -0.19290368, -1.75471433]), array([ 0.11722243,  0.75154346, -0.22799797, -1.52770953])], 'act': [0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0], 'reward': 14.0}
---


498 {'obs': [array([ 0.01514087, -0.03326202, -0.0203848 , -0.00656762]), array([ 0.01447562,  0.16214626, -0.02051615, -0.30561186]), array([ 0.01771855,  0.35755447, -0.02662839, -0.60469382]), array([ 0.02486964,  0.55303849, -0.03872227, -0.90564372]), array([ 0.03593041,  0.35846168, -0.05683514, -0.62537879]), array([ 0.04309964,  0.55432906, -0.06934272, -0.935406  ]), array([ 0.05418622,  0.36020744, -0.08805084, -0.66529378]), array([ 0.06139037,  0.55643665, -0.10135671, -0.98435082]), array([ 0.07251911,  0.75275955, -0.12104373, -1.30707169]), array([ 0.0875743 ,  0.94918972, -0.14718516, -1.63506109]), array([ 0.10655809,  0.75606967, -0.17988638, -1.39162947]), array([ 0.12167948,  0.56358507, -0.20771897, -1.16016329]), array([ 0.13295119,  0.76071679, -0.23092224, -1.51013567])], 'act': [1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1], 'reward': 12.0}
---
499 {'obs': [array([ 0.00158175, -0.03030606, -0.02697292,  0.04866514]), array([ 0.00097563,  0.16519207, -0.02599962, -0.25240

594 {'obs': [array([-0.01558697,  0.00986731,  0.01459264,  0.01806602]), array([-0.01538963,  0.20477698,  0.01495396, -0.26997727]), array([-0.01129409,  0.39968238,  0.00955442, -0.55790637]), array([-0.00330044,  0.20442761, -0.00160371, -0.26222864]), array([ 0.00078811,  0.39957242, -0.00684828, -0.55541697]), array([ 0.00877956,  0.20454729, -0.01795662, -0.26489951]), array([ 0.01287051,  0.00968617, -0.02325461,  0.02206608]), array([ 0.01306423,  0.20513377, -0.02281329, -0.27786228]), array([ 0.01716691,  0.01034458, -0.02837054,  0.00753895]), array([ 0.0173738 ,  0.20586167, -0.02821976, -0.29395849]), array([ 0.02149103,  0.40137435, -0.03409893, -0.59540626]), array([ 0.02951852,  0.20674581, -0.04600705, -0.3136563 ]), array([ 0.03365343,  0.40249193, -0.05228018, -0.62048609]), array([ 0.04170327,  0.59830354, -0.0646899 , -0.92916557]), array([ 0.05366934,  0.79423629, -0.08327321, -1.24145552]), array([ 0.06955407,  0.60027621, -0.10810232, -0.97597761]), array([ 0.0

697 {'obs': [array([-0.01693744, -0.00503525,  0.01164486,  0.03722985]), array([-0.01703815,  0.18991779,  0.01238945, -0.25175637]), array([-0.01323979,  0.38486065,  0.00735433, -0.54050582]), array([-0.00554258,  0.1896361 , -0.00345579, -0.24551476]), array([-0.00174986,  0.38480724, -0.00836609, -0.53928572]), array([ 0.00594629,  0.58004579, -0.0191518 , -0.8345929 ]), array([ 0.0175472 ,  0.38519066, -0.03584366, -0.54799408]), array([ 0.02525102,  0.1905901 , -0.04680354, -0.26681652]), array([ 0.02906282, -0.00383371, -0.05213987,  0.01074461]), array([ 0.02898614,  0.19199573, -0.05192498, -0.29792253]), array([ 0.03282606, -0.00234907, -0.05788343, -0.02205732]), array([ 0.03277908, -0.1965952 , -0.05832457,  0.25181567]), array([ 0.02884717, -0.000691  , -0.05328826, -0.05867867]), array([ 0.02883335,  0.19515291, -0.05446183, -0.36768687]), array([ 0.03273641,  0.39100473, -0.06181557, -0.67703289]), array([ 0.04055651,  0.19679366, -0.07535623, -0.40443489]), array([ 0.0

773 {'obs': [array([-0.01492393, -0.00893212,  0.04533678,  0.00411195]), array([-0.01510257,  0.1855113 ,  0.04541901, -0.27392905]), array([-0.01139234,  0.37995677,  0.03994043, -0.55194786]), array([-0.00379321,  0.57449569,  0.02890148, -0.83178416]), array([ 0.00769671,  0.37899094,  0.01226579, -0.53015364]), array([ 0.01527653,  0.18369861,  0.00166272, -0.23363113]), array([ 0.0189505 , -0.01144706, -0.0030099 ,  0.05957581]), array([ 0.01872156,  0.18371791, -0.00181839, -0.23405524]), array([ 0.02239591, -0.01137801, -0.00649949,  0.05805355]), array([ 0.02216835,  0.18383653, -0.00533842, -0.2366729 ]), array([ 0.02584509,  0.37903434, -0.01007188, -0.53103494]), array([ 0.03342577,  0.57429651, -0.02069258, -0.82687441]), array([ 0.0449117 ,  0.76969521, -0.03723006, -1.12599291]), array([ 0.06030561,  0.57508039, -0.05974992, -0.84521607]), array([ 0.07180721,  0.77096448, -0.07665424, -1.1560739 ]), array([ 0.0872265 ,  0.96699749, -0.09977572, -1.47177423]), array([ 0.1

848 {'obs': [array([-0.0376424 , -0.03108196, -0.04397774,  0.02952036]), array([-0.03826404, -0.22554656, -0.04338733,  0.30800993]), array([-0.04277497, -0.4200243 , -0.03722713,  0.58670013]), array([-0.05117545, -0.22440127, -0.02549313,  0.28252665]), array([-0.05566348, -0.02892515, -0.0198426 , -0.01808642]), array([-0.05624198, -0.223757  , -0.02020433,  0.26827048]), array([-0.06071712, -0.41858487, -0.01483892,  0.55451297]), array([-0.06908882, -0.22325773, -0.00374866,  0.25719204]), array([-0.07355397, -0.02808247,  0.00139518, -0.03667091]), array([-0.07411562, -0.2232244 ,  0.00066177,  0.25645189]), array([-0.07858011, -0.41835579,  0.0057908 ,  0.54934347]), array([-0.08694723, -0.22331566,  0.01677767,  0.25849067]), array([-0.09141354, -0.0284372 ,  0.02194749, -0.02885347]), array([-0.09198228, -0.2238669 ,  0.02137042,  0.27067252]), array([-0.09645962, -0.02905632,  0.02678387, -0.0151942 ]), array([-0.09704075,  0.16567148,  0.02647998, -0.29930762]), array([-0.0

922 {'obs': [array([-0.00165004,  0.0242935 , -0.02040191,  0.01108511]), array([-0.00116417,  0.21970201, -0.02018021, -0.28796441]), array([ 0.00322987,  0.41510584, -0.0259395 , -0.58694305]), array([ 0.01153198,  0.2203566 , -0.03767836, -0.30254283]), array([ 0.01593912,  0.02579134, -0.04372922, -0.02197689]), array([ 0.01645494, -0.1686771 , -0.04416875,  0.25659462]), array([ 0.0130814 , -0.36314153, -0.03903686,  0.53502538]), array([ 0.00581857, -0.167493  , -0.02833635,  0.23030225]), array([ 0.00246871,  0.02802217, -0.02373031, -0.07118255]), array([ 0.00302915, -0.16675168, -0.02515396,  0.21391991]), array([-3.05879928e-04, -3.61505144e-01, -2.08755620e-02,  4.98563210e-01]), array([-0.00753598, -0.55632665, -0.0109043 ,  0.78459482]), array([-0.01866252, -0.36105657,  0.0047876 ,  0.48850132]), array([-0.02588365, -0.16600249,  0.01455763,  0.19733111]), array([-0.0292037 ,  0.02890824,  0.01850425, -0.09072422]), array([-0.02862553, -0.16647398,  0.01668976,  0.2077388

995 {'obs': [array([-0.02184312,  0.00821809,  0.04346495,  0.02218891]), array([-0.02167876,  0.20269063,  0.04390873, -0.25646972]), array([-0.01762495,  0.00697018,  0.03877934,  0.04973301]), array([-0.01748555, -0.18868575,  0.039774  ,  0.35439453]), array([-0.02125926, -0.38435001,  0.04686189,  0.65934927]), array([-0.02894626, -0.1899105 ,  0.06004887,  0.38178254]), array([-0.03274447,  0.00430969,  0.06768453,  0.10862111]), array([-0.03265828,  0.19839972,  0.06985695, -0.16196326]), array([-0.02869028,  0.00235085,  0.06661768,  0.15191404]), array([-0.02864327, -0.19365862,  0.06965596,  0.46484719]), array([-0.03251644, -0.38969215,  0.07895291,  0.77864672]), array([-0.04031028, -0.19573943,  0.09452584,  0.51181198]), array([-0.04422507, -0.00206718,  0.10476208,  0.25034973]), array([-0.04426641, -0.19851703,  0.10976908,  0.57415268]), array([-0.04823675, -0.00509133,  0.12125213,  0.31796769]), array([-0.04833858,  0.18811379,  0.12761148,  0.06584864]), array([-0.0

Let's create the REINFORCE agent. We assume that the policy is computed using an MLP with a softmax output.

In [11]:
class MLP(Chain):
    """Multilayer perceptron"""

    def __init__(self, n_output=1, n_hidden=5):
        super(MLP, self).__init__(l1=L.Linear(None, n_hidden), 
                                  l2=L.Linear(n_hidden, n_output))

    def __call__(self, x):
        return self.l2(F.relu(self.l1(x)))

1: A skeleton for the REINFORCEAgent is given. Implement the compute_loss and compute_score functions. 

In [33]:
class REINFORCEAgent(object):
    """Agent trained using REINFORCE"""

    def __init__(self, action_space, model, optimizer=Adam()):

        self.action_space = action_space

        self.model = model

        self.optimizer = optimizer
        self.optimizer.setup(self.model)

        # monitor score and reward
        self.rewards = []
        self.scores = []


    def act(self, observation, reward, done):

        # linear outputs reflecting the log action probabilities and the value
        policy = self.model(Variable(np.atleast_2d(np.asarray(observation, 'float32'))))
        print('obs', Variable(np.atleast_2d(np.asarray(observation, 'float32'))))
        print('policy', policy)
        

        # generate action according to policy
        p = F.softmax(policy).data
#         print('p before',p, end=' ')
        # normalize p in case tiny floating precision problems occur
        row_sums = p.sum(axis=1)
        
        p /= row_sums[:, np.newaxis]
                                             # a, size=None, replace=True, p=None
        action = np.asarray([np.random.choice(p.shape[1], None, True, p[0])])
        print('action', action)
        print('policy', policy)
        return action, policy


    def compute_loss(self):
        """
        Return loss for this episode based on computed scores and accumulated rewards
        """
    
        return Variable(np.array([0]))

    def compute_score(self, action, policy):
        """
        Computes score

        Args:
            action (int):
            policy:

        Returns:
            score
        """

        pass

Now we run the REINFORCE agent on the CartPole environment. Note that we update the agent after each episode for simplicity.

In [35]:
env = gym.make(env_id)
env.seed(0)

print(env.action_space.n)

network = MLP(n_output=env.action_space.n, n_hidden=3)
agent = REINFORCEAgent(env.action_space, network, optimizer=Adam())

episode_count = 2
done = False
reward = 0
    
R = np.zeros(episode_count)

INFO: Making new env: CartPole-v0
2


In [36]:

for i in tqdm.trange(episode_count):

    ob = env.reset()

    loss = 0
    while True:

        action, policy = agent.act(ob, reward, done)

        ob, reward, done, _ = env.step(action[0])

        # get reward associated with taking the previous action in the previous state
        agent.rewards.append(reward)
        R[i] += reward

        # recompute score function: grad_theta log pi_theta (s_t, a_t) * v_t
        agent.scores.append(agent.compute_score(action, policy))

        # we learn at the end of each episode
        if done:
            
            loss += agent.compute_loss()
            
            agent.model.cleargrads()
            loss.backward()
            loss.unchain_backward()
            agent.optimizer.update()

            break

100%|██████████| 2/2 [00:00<00:00, 30.98it/s]

obs variable([[-0.04456399  0.04653909  0.01326909 -0.02099827]])
policy variable([[0.0016998 0.009274 ]])
action [0]
policy variable([[0.0016998 0.009274 ]])
obs variable([[-0.04363321 -0.14877062  0.01284913  0.2758415 ]])
policy variable([[0.04039792 0.1036236 ]])
action [0]
policy variable([[0.04039792 0.1036236 ]])
obs variable([[-0.04660862 -0.3440735   0.01836596  0.5725492 ]])
policy variable([[0.08420952 0.18547148]])
action [1]
policy variable([[0.08420952 0.18547148]])
obs variable([[-0.05349009 -0.14921382  0.02981694  0.28570825]])
policy variable([[0.04226245 0.1276412 ]])
action [0]
policy variable([[0.04226245 0.1276412 ]])
obs variable([[-0.05647437 -0.34474805  0.03553111  0.5876441 ]])
policy variable([[0.08690345 0.21220927]])
action [1]
policy variable([[0.08690345 0.21220927]])
obs variable([[-0.06336933 -0.15014124  0.04728399  0.3063621 ]])
policy variable([[0.04584122 0.15736428]])
action [0]
policy variable([[0.04584122 0.15736428]])
obs variable([[-0.06637216




In [0]:
# You may want to run a video of the trained agent performing in the environment using the env.render() function.
#
# for i in range(3):
#
#     ob = env.reset()
#
#     while True:
#
#         action, policy = agent.act(ob, reward, done)
#
#         ob, reward, done, _ = env.step(action[0])
#
#         if done:
#             break
#       
#         env.render()

2: Plot the cumulative reward for both RandomAgent and REINFORCEAgent.