In [None]:
from collections import deque
import gym
import numpy as np
import torch
class PolicyNetwork(torch.nn.Module):
    """Linear softmax policy mapping an observation to action probabilities.

    The network is a single bias-free linear layer followed by a softmax
    over the last dimension.

    Args:
        n_observations: Size of the last dimension of the input tensor.
            Defaults to 4.
        n_actions: Number of action probabilities produced. Defaults to 2.
    """

    def __init__(self, n_observations=4, n_actions=2):
        super().__init__()
        # One logit per action; the attribute name ``fc`` is kept so that
        # existing state dicts continue to load.
        self.fc = torch.nn.Linear(n_observations, n_actions, bias=False)

    def forward(self, x):
        """Return the softmax of ``fc(x)`` taken over the last dimension."""
        logits = self.fc(x)
        return torch.nn.functional.softmax(logits, dim=-1)
# Train the policy with a Monte-Carlo policy-gradient update: one full
# episode is rolled out, then a single SGD step is taken on
# -sum_t G_t * log pi(a_t | s_t).
pi = PolicyNetwork()
optimizer = torch.optim.SGD(pi.parameters(), lr=0.001)
env = gym.make('CartPole-v0')
max_steps = 200
discount = 0.995  # discount factor gamma
# Value the return is seeded with when an episode lasts all max_steps steps.
# NOTE(review): 200 equals both max_steps and 1 / (1 - discount); presumably
# it stands in for reward beyond the cut-off -- confirm which was intended.
timeout_return = 200
reward_history = deque(maxlen=100)  # episode totals for the running average


# try/finally so the environment is closed even if training raises or the
# run is interrupted part-way through.
try:
    for episode in range(1, 10000 + 1):
        # This code unpacks the classic gym API: reset() yields the
        # observation and step() yields a 4-tuple.
        state = env.reset()
        state = torch.tensor(state, dtype=torch.float32)
        history = []  # one [reward, probability of the sampled action] per step
        rewards = 0
        for t in range(1, max_steps + 1):
            #env.render()
            probs = pi(state)
            action = torch.multinomial(probs, 1).item()
            state_next, reward, done, info = env.step(action)
            rewards += reward
            # probs[action] stays attached to the autograd graph so the
            # loss below can backpropagate into the policy weights.
            history.append([reward, probs[action]])
            if done:
                break
            state = torch.tensor(state_next, dtype=torch.float32)
        # Exactly one history entry is added per step taken, so the episode
        # used every allowed step iff the history holds max_steps entries.
        timeout = len(history) == max_steps
        # compute average reward over 100 episodes
        reward_history.append(rewards)
        avg = np.mean(reward_history)
        print(f'episode: {episode}, reward: {rewards}, avg: {avg}')
        # update policy
        loss = 0
        G = timeout_return if timeout else 0
        # Walk the episode backwards so each step's discounted return is
        # built from the return of the step after it.
        for r, prob in reversed(history):
            G = r + discount*G
            loss -= G * prob.log()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
finally:
    env.close()


episode: 1, reward: 14.0, avg: 14.0
episode: 2, reward: 16.0, avg: 15.0
episode: 3, reward: 10.0, avg: 13.333333333333334
episode: 4, reward: 16.0, avg: 14.0
episode: 5, reward: 14.0, avg: 14.0
episode: 6, reward: 32.0, avg: 17.0
episode: 7, reward: 16.0, avg: 16.857142857142858
episode: 8, reward: 13.0, avg: 16.375
episode: 9, reward: 53.0, avg: 20.444444444444443
episode: 10, reward: 15.0, avg: 19.9
episode: 11, reward: 25.0, avg: 20.363636363636363
episode: 12, reward: 12.0, avg: 19.666666666666668
episode: 13, reward: 48.0, avg: 21.846153846153847
episode: 14, reward: 36.0, avg: 22.857142857142858
episode: 15, reward: 22.0, avg: 22.8
episode: 16, reward: 26.0, avg: 23.0
episode: 17, reward: 10.0, avg: 22.235294117647058
episode: 18, reward: 13.0, avg: 21.72222222222222
episode: 19, reward: 14.0, avg: 21.31578947368421
episode: 20, reward: 14.0, avg: 20.95
episode: 21, reward: 13.0, avg: 20.571428571428573
episode: 22, reward: 34.0, avg: 21.181818181818183
episode: 23, reward: 17.0,

episode: 189, reward: 164.0, avg: 96.92
episode: 190, reward: 58.0, avg: 97.05
episode: 191, reward: 118.0, avg: 97.01
episode: 192, reward: 42.0, avg: 96.57
episode: 193, reward: 92.0, avg: 96.25
episode: 194, reward: 62.0, avg: 95.58
episode: 195, reward: 86.0, avg: 95.33
episode: 196, reward: 112.0, avg: 95.07
episode: 197, reward: 165.0, avg: 94.84
episode: 198, reward: 81.0, avg: 94.86
episode: 199, reward: 52.0, avg: 94.03
episode: 200, reward: 77.0, avg: 93.37
episode: 201, reward: 73.0, avg: 93.71
episode: 202, reward: 62.0, avg: 92.43
episode: 203, reward: 57.0, avg: 92.53
episode: 204, reward: 72.0, avg: 92.7
episode: 205, reward: 104.0, avg: 92.6
episode: 206, reward: 77.0, avg: 92.65
episode: 207, reward: 93.0, avg: 92.07
episode: 208, reward: 56.0, avg: 92.1
episode: 209, reward: 75.0, avg: 91.53
episode: 210, reward: 61.0, avg: 91.43
episode: 211, reward: 72.0, avg: 91.19
episode: 212, reward: 83.0, avg: 91.42
episode: 213, reward: 66.0, avg: 91.31
episode: 214, reward: 6

episode: 393, reward: 200.0, avg: 155.13
episode: 394, reward: 200.0, avg: 156.14
episode: 395, reward: 200.0, avg: 156.84
episode: 396, reward: 200.0, avg: 158.04
episode: 397, reward: 200.0, avg: 158.81
episode: 398, reward: 200.0, avg: 159.42
episode: 399, reward: 200.0, avg: 160.68
episode: 400, reward: 200.0, avg: 160.94
episode: 401, reward: 200.0, avg: 162.15
episode: 402, reward: 200.0, avg: 162.64
episode: 403, reward: 200.0, avg: 163.56
episode: 404, reward: 200.0, avg: 163.71
episode: 405, reward: 200.0, avg: 164.53
episode: 406, reward: 200.0, avg: 165.25
episode: 407, reward: 200.0, avg: 165.35
episode: 408, reward: 200.0, avg: 165.77
episode: 409, reward: 200.0, avg: 165.77
episode: 410, reward: 200.0, avg: 166.97
episode: 411, reward: 200.0, avg: 167.83
episode: 412, reward: 200.0, avg: 167.97
episode: 413, reward: 200.0, avg: 168.6
episode: 414, reward: 200.0, avg: 168.86
episode: 415, reward: 200.0, avg: 169.62
episode: 416, reward: 200.0, avg: 170.77
episode: 417, rew

episode: 601, reward: 200.0, avg: 162.53
episode: 602, reward: 166.0, avg: 162.19
episode: 603, reward: 169.0, avg: 161.88
episode: 604, reward: 200.0, avg: 161.88
episode: 605, reward: 200.0, avg: 161.88
episode: 606, reward: 200.0, avg: 161.88
episode: 607, reward: 200.0, avg: 161.88
episode: 608, reward: 200.0, avg: 161.88
episode: 609, reward: 200.0, avg: 161.88
episode: 610, reward: 200.0, avg: 161.88
episode: 611, reward: 200.0, avg: 161.88
episode: 612, reward: 200.0, avg: 161.88
episode: 613, reward: 200.0, avg: 161.88
episode: 614, reward: 200.0, avg: 161.88
episode: 615, reward: 200.0, avg: 161.88
episode: 616, reward: 200.0, avg: 161.88
episode: 617, reward: 200.0, avg: 161.88
episode: 618, reward: 200.0, avg: 161.88
episode: 619, reward: 200.0, avg: 161.88
episode: 620, reward: 200.0, avg: 161.88
episode: 621, reward: 200.0, avg: 161.88
episode: 622, reward: 200.0, avg: 161.88
episode: 623, reward: 200.0, avg: 161.88
episode: 624, reward: 200.0, avg: 161.88
episode: 625, re

episode: 802, reward: 200.0, avg: 187.43
episode: 803, reward: 200.0, avg: 187.43
episode: 804, reward: 200.0, avg: 187.43
episode: 805, reward: 200.0, avg: 187.43
episode: 806, reward: 200.0, avg: 187.43
episode: 807, reward: 200.0, avg: 187.43
episode: 808, reward: 200.0, avg: 187.43
episode: 809, reward: 200.0, avg: 187.43
episode: 810, reward: 188.0, avg: 187.31
episode: 811, reward: 200.0, avg: 187.31
episode: 812, reward: 200.0, avg: 187.31
episode: 813, reward: 200.0, avg: 187.31
episode: 814, reward: 200.0, avg: 187.31
episode: 815, reward: 92.0, avg: 186.23
episode: 816, reward: 137.0, avg: 185.6
episode: 817, reward: 117.0, avg: 184.77
episode: 818, reward: 102.0, avg: 183.79
episode: 819, reward: 200.0, avg: 183.79
episode: 820, reward: 200.0, avg: 183.79
episode: 821, reward: 200.0, avg: 183.79
episode: 822, reward: 151.0, avg: 183.3
episode: 823, reward: 200.0, avg: 183.3
episode: 824, reward: 200.0, avg: 183.3
episode: 825, reward: 200.0, avg: 183.3
episode: 826, reward: 

episode: 1006, reward: 200.0, avg: 198.13
episode: 1007, reward: 200.0, avg: 198.13
episode: 1008, reward: 200.0, avg: 198.13
episode: 1009, reward: 200.0, avg: 198.13
episode: 1010, reward: 200.0, avg: 198.13
episode: 1011, reward: 200.0, avg: 198.13
episode: 1012, reward: 200.0, avg: 198.13
episode: 1013, reward: 200.0, avg: 198.13
episode: 1014, reward: 200.0, avg: 198.13
episode: 1015, reward: 200.0, avg: 198.13
episode: 1016, reward: 200.0, avg: 198.13
episode: 1017, reward: 200.0, avg: 198.13
episode: 1018, reward: 200.0, avg: 198.13
episode: 1019, reward: 200.0, avg: 198.13
episode: 1020, reward: 200.0, avg: 198.13
episode: 1021, reward: 200.0, avg: 198.13
episode: 1022, reward: 200.0, avg: 198.13
episode: 1023, reward: 200.0, avg: 198.13
episode: 1024, reward: 200.0, avg: 198.13
episode: 1025, reward: 200.0, avg: 198.13
episode: 1026, reward: 200.0, avg: 198.13
episode: 1027, reward: 200.0, avg: 198.13
episode: 1028, reward: 180.0, avg: 197.93
episode: 1029, reward: 186.0, avg: