In [1]:
import argparse
import gym
import numpy as np
from itertools import count
from random import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

In [2]:
# Shared CartPole-v1 environment; the functions in later cells read it as a global.
env = gym.make('CartPole-v1')
# Report the per-episode step cap via the registered spec (public API) instead
# of reaching into the private `_max_episode_steps` attribute.
print(env.spec.max_episode_steps)

500


In [3]:
def select_action_random(state):
    """Pick action 0 or 1 with equal probability, ignoring `state`.

    `state` is accepted only so that this policy can be called the same way
    as the other `select_action_*` functions.
    """
    # A single uniform draw decides the action: below 0.5 -> 0, otherwise -> 1.
    return 0 if random() < 0.5 else 1

def goodness_score(select_action, num_episodes=100, num_steps=500, environment=None):
    """Return the average fraction of `num_steps` a policy survives per episode.

    Every episode resets the environment and then steps it with the actions
    returned by `select_action` until the environment reports `done` or
    `num_steps` steps have been taken. The score is the mean episode length
    divided by `num_steps`, so 1.0 means every episode reached the step cap.

    Args:
        select_action: callable taking the current state and returning an action.
        num_episodes: number of episodes to average over (must be >= 1).
        num_steps: per-episode step cap, also used to normalise the score
            (must be >= 1).
        environment: object exposing `reset()` -> state and
            `step(action)` -> `(state, reward, done, info)`. Defaults to the
            module-level `env`.

    Returns:
        A float in (0, 1].

    Raises:
        ValueError: if `num_episodes` or `num_steps` is smaller than 1
            (previously `num_episodes=0` surfaced as a ZeroDivisionError).
    """
    if num_episodes < 1:
        raise ValueError("num_episodes must be >= 1, got %r" % (num_episodes,))
    if num_steps < 1:
        raise ValueError("num_steps must be >= 1, got %r" % (num_steps,))
    if environment is None:
        # Fall back to the notebook-wide environment for existing callers.
        environment = env

    episode_lengths = []
    for _episode in range(num_episodes):
        state = environment.reset()
        # `steps_taken` keeps the index of the last executed step after the loop.
        for steps_taken in range(1, num_steps + 1):
            action = select_action(state)
            state, _reward, done, _info = environment.step(action)
            #environment.render()
            if done:
                break
        episode_lengths.append(steps_taken)

    return sum(episode_lengths) / (len(episode_lengths) * num_steps)

# Baseline: score of the coin-flip policy over the default number of episodes.
print(goodness_score(select_action_random))

0.04226


In [4]:
def select_action_simple(state):
    """Choose an action from the sign of the third state component.

    Returns 0 when `state[2]` is negative and 1 otherwise (including zero).
    NOTE(review): `state[2]` is presumably the pole angle of the CartPole
    observation -- confirm against the Gym environment documentation.
    """
    third_component = state[2]
    return 0 if third_component < 0 else 1

# Score the sign-of-state[2] heuristic with the default settings.
goodness_score(select_action_simple)

0.08632

In [5]:
def select_action_good(state):
    """Choose an action from the sign of `state[2] + state[3]`.

    Returns 0 when the sum of the third and fourth state components is
    negative, and 1 otherwise (including a sum of exactly zero).
    NOTE(review): these components are presumably the pole angle and its
    angular velocity -- confirm against the Gym environment documentation.
    """
    combined = state[2] + state[3]
    if combined < 0:
        return 0
    return 1

# Score the heuristic that uses the sign of state[2] + state[3].
goodness_score(select_action_good)

0.95548

In [6]:
class PolicyNN(nn.Module):
    """Single linear layer followed by a softmax over actions.

    Args:
        num_inputs: length of the observation vector fed to the layer
            (default 4, the value previously hard-coded).
        num_actions: number of action probabilities produced
            (default 2, the value previously hard-coded).
    """

    def __init__(self, num_inputs=4, num_actions=2):
        super().__init__()
        self.fc = nn.Linear(num_inputs, num_actions)

    def forward(self, x):
        """Return action probabilities for `x`.

        `x` must have `num_inputs` as its last dimension. The softmax is taken
        over the last dimension, so both a batch of shape (N, num_inputs) and
        a single unbatched observation of shape (num_inputs,) are accepted;
        for 2-D input this is identical to the former `dim=1`.
        """
        logits = self.fc(x)
        return F.softmax(logits, dim=-1)

def select_action_from_policy(model, state):
    """Sample an action from the policy's output distribution.

    `state` is a NumPy array; it is converted to a float tensor and given a
    leading batch dimension of size 1 before being passed to `model`.

    Returns:
        A tuple `(action, log_prob)` where `action` is a Python int and
        `log_prob` is the tensor log-probability of the sampled action under
        the distribution built from the model's output.
    """
    observation = torch.from_numpy(state).float()
    batch = observation.unsqueeze(0)
    distribution = Categorical(model(batch))
    sampled = distribution.sample()
    log_probability = distribution.log_prob(sampled)
    return sampled.item(), log_probability

def select_action_from_policy_best(model, state):
    """Return the action the policy considers more probable.

    `state` is a NumPy array; it is converted to a float tensor with a leading
    batch dimension of size 1 before being passed to `model`.

    Returns 0 only when the first probability is strictly greater than the
    second; ties therefore resolve to action 1.
    """
    batch = torch.from_numpy(state).float().unsqueeze(0)
    # Pure evaluation: no gradient is ever taken through this call, so skip
    # building the autograd graph.
    with torch.no_grad():
        probs = model(batch)
    prob_first, prob_second = probs[0][0], probs[0][1]
    return 0 if prob_first > prob_second else 1

In [7]:
# Freshly constructed (untrained) policy network, scored before any training.
model_untrained = PolicyNN()

# First score: actions are sampled from the policy's output distribution.
# Second score: the action with the higher probability is always taken.
print(
    goodness_score(lambda state: select_action_from_policy(model_untrained, state)[0]),
    goodness_score(lambda state: select_action_from_policy_best(model_untrained, state))
)

0.0313 0.01876


In [8]:
# Policy network to be trained and its Adam optimizer (learning rate 0.01).
# Both are read as globals by the training functions defined in this cell.
model = PolicyNN()
optimizer = optim.Adam(model.parameters(), lr=0.01)

def train_wont_work(num_episodes=100):
    """Deliberately broken training loop, kept as a counter-example.

    It plays episodes with the global `model` and then tries to back-propagate
    a loss computed purely from the episode length. That loss is a plain
    Python float, not a tensor derived from the model's output, so
    `loss.backward()` raises AttributeError. Do not call this expecting it
    to train anything.

    Args:
        num_episodes: number of episodes the loop would attempt to run.
    """
    num_steps = 500
    for episode in range(num_episodes):
        state = env.reset()
        for t in range(1, num_steps+1):
            # Sample from the policy under training (the original referenced an
            # undefined `select_action`, which raised NameError before the
            # intended failure below was ever reached). The log-probability is
            # discarded here, which is why nothing links the loss to the model.
            action, _ = select_action_from_policy(model, state)
            state, _, done, _ = env.step(action)
            if done:
                break
        loss = 1.0 - t / num_steps
        # this doesn't actually work, because
        # the loss function is not an explicit
        # function of the model's output; it's
        # a function of book keeping variables
        optimizer.zero_grad()
        loss.backward() # AttributeError: 'float' object has no attribute 'backward'
        optimizer.step()

def train_simple(num_episodes=10*1000, num_steps=500, window=10, target_fraction=0.95):
    """Train the global `model` with a per-episode policy-gradient style update.

    For every episode the policy is sampled step by step in the global `env`
    and the log-probability of each chosen action is recorded. The loss is the
    negative sum of those log-probabilities, each weighted by the number of
    steps the episode lasted from that step onward (`t - i`), and one
    optimizer step is taken per episode. One line `episode t loss` is printed
    per episode.

    Training stops early once the mean length of the last `window` episodes
    reaches `num_steps * target_fraction`.

    Args:
        num_episodes: maximum number of episodes to train for.
        num_steps: per-episode step cap (default 500, previously hard-coded).
        window: number of most recent episodes averaged for the early-stop
            check; expected to be >= 1 (default 10, previously hard-coded).
        target_fraction: fraction of `num_steps` the windowed mean must reach
            to stop early (default 0.95, previously hard-coded).

    Returns:
        None. The global `model` is updated in place through `optimizer`.
    """
    episode_lengths = []
    for episode in range(num_episodes):
        state = env.reset()
        log_probs = []
        for t in range(1, num_steps+1):
            action, log_prob = select_action_from_policy(model, state)
            log_probs.append(log_prob)
            state, _, done, _ = env.step(action)
            if done:
                break
        # Step i (0-based) is weighted by t - i: the steps survived from i on.
        loss = 0
        for i, log_prob in enumerate(log_probs):
            loss += -1 * (t - i) * log_prob
        print(episode, t, loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        episode_lengths.append(t)
        # Check as soon as a full window exists; the former `len(ts) > 10`
        # needlessly waited for one episode more than the window size.
        recent = episode_lengths[-window:]
        if len(recent) == window and sum(recent) / float(window) >= num_steps * target_fraction:
            print('Stopping training, looks good...')
            return

# Run training with the default episode budget; one line is printed per episode.
train_simple()

0 15 73.55551147460938
1 12 43.6472282409668
2 19 122.97470092773438
3 11 36.83574676513672
4 39 540.05810546875
5 18 121.04826354980469
6 20 139.02340698242188
7 26 251.57608032226562
8 19 121.61566925048828
9 11 36.701202392578125
10 18 110.67570495605469
11 12 43.91122055053711
12 21 150.94691467285156
13 20 138.30221557617188
14 47 783.0886840820312
15 21 155.34681701660156
16 55 1063.359130859375
17 22 171.0105438232422
18 53 1004.13525390625
19 18 106.62135314941406
20 58 1155.60400390625
21 29 297.94146728515625
22 67 1612.26171875
23 11 41.799346923828125
24 20 134.62718200683594
25 21 148.46450805664062
26 37 472.6285400390625
27 11 42.9749641418457
28 31 348.74072265625
29 75 1902.5970458984375
30 38 483.49725341796875
31 28 259.1026611328125
32 17 96.3718032836914
33 28 260.8665771484375
34 26 231.17686462402344
35 24 193.5537567138672
36 39 528.433349609375
37 25 208.82374572753906
38 30 327.7794494628906
39 50 868.1632690429688
40 36 439.43695068359375
41 30 296.5043029785

338 352 35842.8359375
339 454 60004.265625
340 150 6966.24462890625
341 428 53134.45703125
342 418 51235.73046875
343 195 11596.97265625
344 101 3029.0712890625
345 500 75268.0234375
346 241 17025.296875
347 133 5361.77880859375
348 107 3406.32275390625
349 145 6306.47509765625
350 123 4705.58056640625
351 112 4032.36962890625
352 358 36724.33984375
353 148 6520.658203125
354 500 74122.640625
355 135 5545.19482421875
356 111 3754.09130859375
357 124 4789.892578125
358 500 73928.1953125
359 140 5701.27978515625
360 182 9680.3349609375
361 157 7105.388671875
362 375 42349.98828125
363 191 11081.9677734375
364 282 22960.91796875
365 375 41256.1015625
366 500 74403.1171875
367 163 7718.86767578125
368 150 6743.13037109375
369 409 49875.34765625
370 333 33138.30859375
371 196 11386.5673828125
372 146 6203.3173828125
373 500 74518.6640625
374 167 8187.93896484375
375 288 25828.783203125
376 426 55482.4453125
377 395 45578.3125
378 386 44929.40625
379 161 7640.96142578125
380 500 74476.742187

In [9]:
# Score the trained policy two ways: sampling actions from its distribution
# (first value) and always taking the higher-probability action (second value).
print(
    goodness_score(lambda state: select_action_from_policy(model, state)[0]),
    goodness_score(lambda state: select_action_from_policy_best(model, state))
)

0.90666 1.0
