Source: [Solving the CartPole Reinforcement Learning problem with PyTorch](https://bytepawn.com/solving-the-cartpole-reinforcement-learning-problem-with-pytorch.html)

In [23]:
import gym
from random import random
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
# One shared CartPole-v1 environment; the evaluation, training and rendering
# code in this notebook all step this same `env`.
ENV_ID = 'CartPole-v1'
env = gym.make(ENV_ID)

In [3]:
def goodness_score(select_action, num_episodes=100, num_steps=500, environment=None):
    """Estimate how well a policy keeps the episode going, as a fraction in [0, 1].

    Runs ``num_episodes`` episodes, each capped at ``num_steps`` steps, and
    returns the mean episode length divided by ``num_steps``.  A score of 1.0
    means every episode lasted the full ``num_steps`` steps.

    Args:
        select_action: callable mapping an observation to an action.
        num_episodes: number of evaluation episodes (must be >= 1).
        num_steps: per-episode step cap and the normaliser of the score
            (must be >= 1).
        environment: environment to evaluate in; defaults to the notebook's
            global ``env``.  Expects the old Gym API: ``reset()`` returns the
            observation and ``step()`` returns a 4-tuple.

    Returns:
        float: sum(episode lengths) / (num_episodes * num_steps).

    Raises:
        ValueError: if ``num_episodes`` or ``num_steps`` is less than 1
            (previously this surfaced as a ZeroDivisionError / NameError).
    """
    if num_episodes < 1:
        raise ValueError('num_episodes must be >= 1')
    if num_steps < 1:
        raise ValueError('num_steps must be >= 1')
    if environment is None:
        environment = env  # notebook-global environment from the first cell
    ts = []
    for _ in range(num_episodes):
        state = environment.reset()
        for t in range(1, num_steps + 1):
            action = select_action(state)
            state, _, done, _ = environment.step(action)
            if done:
                break
        ts.append(t)  # number of steps this episode lasted
    return sum(ts) / (len(ts) * num_steps)

In [4]:
def select_action_random(state):
    """Baseline policy: ignore the observation and pick action 0 or 1 with equal probability."""
    coin = random()
    return 1 if coin >= 0.5 else 0

In [5]:
def select_action_simple(state):
    """Heuristic policy that looks only at ``state[2]``.

    ``state[2]`` is the pole angle in the CartPole observation (per the Gym
    docs).  A negative value yields action 0; anything else (zero, positive,
    or NaN) yields action 1.
    """
    pole_angle = state[2]
    return 0 if pole_angle < 0 else 1

In [10]:
# Score the two baselines: the angle-only heuristic, then the coin-flip policy.
for baseline_policy in (select_action_simple, select_action_random):
    print(goodness_score(baseline_policy))

0.08398
0.04204


In [11]:
def select_action_good(state):
    """Heuristic policy driven by ``state[2] + state[3]``.

    Indices 2 and 3 are the pole angle and pole angular velocity in the
    CartPole observation (per the Gym docs).  A negative sum yields action 0;
    otherwise action 1.
    """
    lean = state[2] + state[3]
    return 0 if lean < 0 else 1

# Score the angle + angular-velocity heuristic (left as the cell's displayed value).
goodness_score(select_action=select_action_good, num_episodes=100)

0.98592

In [19]:
class PolicyNN(nn.Module):
    """Linear policy network: observation -> probability distribution over actions.

    A single ``nn.Linear`` layer maps ``n_inputs`` observation features to
    ``n_actions`` logits, and ``forward`` turns them into probabilities with a
    softmax.  The defaults (4 inputs, 2 actions) match the CartPole-v1
    observation and action spaces.
    """

    def __init__(self, n_inputs=4, n_actions=2):
        super().__init__()
        self.fc = nn.Linear(n_inputs, n_actions)

    def forward(self, x):
        """Return action probabilities for ``x``.

        The softmax is taken over the last dimension, so both a batch of shape
        (batch, n_inputs) and a single unbatched observation of shape
        (n_inputs,) are accepted; for 2-D input dim=-1 is the same as dim=1.
        """
        logits = self.fc(x)
        return F.softmax(logits, dim=-1)

In [13]:
def select_action_from_policy(model, state):
    """Sample an action from the policy network's distribution for ``state``.

    ``state`` must be a NumPy array (it goes through ``torch.from_numpy``); it
    is cast to float32 and given a leading batch dimension before ``model`` is
    called.  Returns ``(action, log_prob)``: the sampled action as a Python
    int, and its log-probability as a tensor computed from ``model``'s output
    (so it can be back-propagated through in a policy-gradient loss).
    """
    batch = torch.from_numpy(state).float().unsqueeze(0)
    dist = Categorical(model(batch))
    sampled = dist.sample()
    log_prob = dist.log_prob(sampled)
    return sampled.item(), log_prob

def select_action_from_policy_best(model, state):
    """Greedy action: pick whichever of the two actions ``model`` rates more likely.

    ``state`` must be a NumPy array; it is cast to float32 and given a leading
    batch dimension before ``model`` is called.  The forward pass runs under
    ``torch.no_grad()`` because only the comparison result is needed, so no
    autograd graph is built on each evaluation step.  Ties (equal
    probabilities) resolve to action 1.

    Returns:
        int: 0 if P(action 0) > P(action 1), else 1.
    """
    batch = torch.from_numpy(state).float().unsqueeze(0)
    with torch.no_grad():
        probs = model(batch)
    if probs[0][0] > probs[0][1]:
        return 0
    return 1

In [25]:
# Fresh linear policy plus its Adam optimizer; both are module-level globals
# that the training cell reads and updates.
LEARNING_RATE = 0.01
model = PolicyNN()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [26]:
def train_simple(num_episodes=10*1000, num_steps=500, policy=None, opt=None,
                 environment=None, log_every=1):
    """Train the policy with REINFORCE, weighting each step by its reward-to-go.

    Each episode is rolled out with ``select_action_from_policy``, keeping the
    log-probability of every sampled action.  For an episode lasting ``t``
    steps, step ``i`` (0-based) is weighted by the number of steps remaining,
    ``t - i``.  Under CartPole's +1-per-step reward (per the Gym docs) that is
    the undiscounted return from step ``i``.  The loss is
    ``-sum_i (t - i) * log_prob_i``, and one optimizer step is taken per
    episode.

    Training stops early once the mean length of the last 10 episodes reaches
    95% of ``num_steps``.

    Args:
        num_episodes: maximum number of training episodes.
        num_steps: per-episode step cap (must be >= 1).
        policy: policy network; defaults to the notebook-global ``model``.
        opt: optimizer over ``policy``'s parameters; defaults to the global
            ``optimizer``.
        environment: environment to train in; defaults to the global ``env``.
            Expects the old Gym API: ``reset()`` -> obs, ``step()`` -> 4-tuple.
        log_every: print ``episode, length, loss`` every ``log_every``
            episodes; 0 or None disables per-episode logging.

    Returns:
        list[int]: the length of every episode that was run.

    Raises:
        ValueError: if ``num_steps`` is less than 1.
    """
    if num_steps < 1:
        raise ValueError('num_steps must be >= 1')
    if policy is None:
        policy = model
    if opt is None:
        opt = optimizer
    if environment is None:
        environment = env
    window = 10
    target = num_steps * 0.95
    ts = []
    for episode in range(num_episodes):
        state = environment.reset()
        log_probs = []
        for t in range(1, num_steps + 1):
            action, log_prob = select_action_from_policy(policy, state)
            log_probs.append(log_prob)
            state, _, done, _ = environment.step(action)
            if done:
                break
        # Reward-to-go weights t, t-1, ..., 1 for steps 0 .. t-1.
        returns = torch.arange(t, 0, -1, dtype=torch.float32)
        loss = -(returns * torch.stack(log_probs).flatten()).sum()
        if log_every and episode % log_every == 0:
            print(episode, t, loss.item())
        opt.zero_grad()
        loss.backward()
        opt.step()
        ts.append(t)
        # Check the stopping condition as soon as a full window of episodes exists.
        if len(ts) >= window and sum(ts[-window:]) / window >= target:
            print('Stopping training, looks good...')
            return ts
    return ts

# Run training; keep any return value in a variable instead of letting the notebook echo it.
episode_lengths = train_simple()

0 53 949.7922973632812
1 24 199.6031036376953
2 15 79.35723876953125
3 27 261.7921142578125
4 23 195.65345764160156
5 16 91.99872589111328
6 23 180.47142028808594
7 14 70.39389038085938
8 35 409.5982971191406
9 33 372.71356201171875
10 18 118.32518768310547
11 21 155.21478271484375
12 18 119.1595687866211
13 14 80.6568832397461
14 31 320.3471374511719
15 21 155.1772003173828
16 27 251.7634735107422
17 15 82.27651977539062
18 50 835.3651123046875
19 70 1628.2569580078125
20 21 164.67274475097656
21 18 112.89387512207031
22 24 196.91929626464844
23 25 229.04403686523438
24 26 226.33184814453125
25 28 274.42425537109375
26 54 966.0028686523438
27 15 82.3408432006836
28 28 267.94500732421875
29 21 155.88262939453125
30 38 481.6607666015625
31 39 500.85980224609375
32 45 659.1751098632812
33 85 2367.63232421875
34 160 8301.09375
35 20 141.44638061523438
36 68 1513.7879638671875
37 33 364.7498474121094
38 57 1053.075927734375
39 61 1253.705810546875
40 72 1703.3406982421875
41 21 154.7506256

339 220 14773.037109375
340 161 7586.578125
341 84 2200.77685546875
342 376 42140.58984375
343 500 75833.2265625
344 119 4369.19384765625
345 164 8319.4736328125
346 134 5543.5087890625
347 351 35859.7421875
348 251 19710.201171875
349 154 6978.95947265625
350 275 22167.794921875
351 50 812.6509399414062
352 339 34185.4296875
353 500 74906.734375
354 293 25663.90625
355 451 61126.4375
356 209 12871.0732421875
357 186 10861.2138671875
358 497 72408.4296875
359 264 21322.978515625
360 500 74573.2265625
361 500 73963.4921875
362 285 24332.91796875
363 458 62735.58203125
364 414 52500.13671875
365 410 49860.4609375
366 256 19263.79296875
367 486 68741.28125
368 308 28593.279296875
369 390 46279.2890625
370 500 75469.1171875
371 458 61753.62890625
372 165 8296.8642578125
373 395 45625.375
374 500 74355.890625
375 460 62589.640625
376 96 2981.64990234375
377 500 74602.0390625
378 156 7440.8994140625
379 423 52917.203125
380 500 74030.4609375
381 171 9189.1591796875
382 500 76181.3671875
383 

In [27]:


# Score the trained policy twice: sampling from its distribution, then greedily.
sampled_score = goodness_score(lambda state: select_action_from_policy(model, state)[0])
greedy_score = goodness_score(lambda state: select_action_from_policy_best(model, state))
print(sampled_score, greedy_score)



0.82902 1.0


In [28]:
# Watch the hand-written heuristic (select_action_good) for a few rendered episodes.
# The environment is closed in `finally` so its render window / rendering
# resources are released even if the loop is interrupted.
# NOTE(review): older Gym CartPole can render again after close(); confirm for the
# installed version before re-running earlier cells that use `env`.
num_episodes = 10
num_steps = 500
try:
    for episode in range(num_episodes):
        state = env.reset()
        for t in range(1, num_steps+1):
            action = select_action_good(state)
            state, _, done, _ = env.step(action)
            env.render()
            if done:
                break
finally:
    env.close()