In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
import gym
import matplotlib.pyplot as plt

In [14]:
LR_A = 0.001
LR_C = 0.01     # learning rate for critic
MAX_EPISODE = 3000
DISPLAY_REWARD_THRESHOLD = 200  # renders environment if total episode reward is greater then this threshold
MAX_EP_STEPS = 1000   # maximum time step in one episode
RENDER = True  # rendering wastes time
GAMMA = 0.9     # reward discount in TD error
env = gym.make('CartPole-v0')
env.seed(1)  # reproducible
env = env.unwrapped

N_F = env.observation_space.shape[0]
N_A = env.action_space.n

In [16]:
class Actor_Net(nn.Module):
    def __init__(self):
        super(Actor_Net, self).__init__()
        self.fc1 = nn.Linear(N_F, 20)
        self.fc1.weight.data.normal_(0, 0.1)
        self.out = nn.Linear(20, N_A)
        self.out.weight.data.normal_(0, 0.1)
    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        x = self.out(x)
        x = x - x.max()
        prob_actions = F.softmax(x, dim=1)
        return prob_actions
    
class Critic_Net(nn.Module):
    def __init__(self):
        super(Critic_Net, self).__init__()
        self.fc1 = nn.Linear(N_F, 20)
        self.fc1.weight.data.normal_(0, 0.1)
        self.out = nn.Linear(20, N_A)
        self.out.weight.data.normal_(0, 0.1)
    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        x = self.out(x)
        x = x - x.max()
        prob_actions = F.softmax(x, dim=1)
        return prob_actions
        

In [17]:
class Actor(object):
    def __init__(self):
        self.actor_net = Actor_Net()
        self.optimizer = torch.optim.Adam(self.actor_net.parameters(), lr=LR_A)
        #self.loss_func = nn.CrossEntropyLoss()
       
    def learn(self, s, a, td):
        s = Variable(torch.unsqueeze(torch.FloatTensor(s),0))
        prob_all_actions = self.actor_net.forward(s)
        neg_log_prob = -torch.log(prob_all_actions)
        loss = torch.sum(neg_log_prob * td).mean()
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss
    def choose_action(self, s):
        s = Variable(torch.unsqueeze(torch.FloatTensor(s),0))
        prob_weights = self.actor_net.forward(s)
        action = np.random.choice(range(prob_weights.shape[1]), p=prob_weights.view(-1).detach().numpy())
        return action
        
        

In [18]:
class Critic(object):
    def __init__(self):
        self.critic_net = Critic_Net()
        self.optimizer = torch.optim.Adam(self.critic_net.parameters(), lr=LR_C)
    def learn(self, s, r, s_):
        s = Variable(torch.unsqueeze(torch.FloatTensor(s),0))
        s_ = Variable(torch.unsqueeze(torch.FloatTensor(s_),0))
        #print(r)
        #r = Variable(torch.FloatTensor(r))
        v_next = self.critic_net(s_)
        v = self.critic_net(s)
        td_error = r + GAMMA * v_next - v
        #print(td_error)
        loss = torch.mul(td_error, td_error).mean()
        print(loss)
        self.optimizer.zero_grad()
        loss.backward(retain_graph=True)
        self.optimizer.step()
        return td_error
        
        

In [19]:
actor = Actor()
critic = Critic()

In [20]:
for i_episode in range(MAX_EPISODE):
    s = env.reset()
    t = 0
    track_r = []
    while True:
        if RENDER: env.render()

        a = actor.choose_action(s)

        s_, r, done, info = env.step(a)

        if done: r = -20

        track_r.append(r)

        td_error = critic.learn(s, r, s_)  # gradient = grad[r + gamma * V(s_) - V(s)]
        actor.learn(s, a, td_error)     # true_gradient = grad[logPi(s,a) * td_error]

        s = s_
        t += 1

        if done or t >= MAX_EP_STEPS:
            ep_rs_sum = sum(track_r)

            if 'running_reward' not in globals():
                running_reward = ep_rs_sum
            else:
                running_reward = running_reward * 0.95 + ep_rs_sum * 0.05
            if running_reward > DISPLAY_REWARD_THRESHOLD: RENDER = True  # rendering
            print("episode:", i_episode, "  reward:", int(running_reward))
            break

tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9026, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(402.0025, grad_fn=<MeanBackward0>)
episode: 0   reward: 53
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025,

episode: 10   reward: 31
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(402.0025, grad_fn=<MeanBackward0>)
episode: 11   reward: 29
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(402.0025, grad_fn=<MeanBackward0>)
episode: 12   reward: 26
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
t

tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(402.0025, grad_fn=<MeanBackward0>)
episode: 22   reward: 15
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025

tensor(402.0025, grad_fn=<MeanBackward0>)
episode: 34   reward: 5
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(402.0025, grad_fn=<MeanBackward0>)
episode: 35   reward: 5
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(402.0025, grad_fn=<Me

tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(402.0025, grad_fn=<MeanBackward0>)
episode: 52   reward: -3
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(402.0025, grad_fn=<MeanBackward0>)
episode: 53   reward: -4
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<Me

tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(402.0025, grad_fn=<MeanBackward0>)
episode: 71   reward: -8
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(402.0025, grad_fn=<MeanBackward0>)
episode: 72   reward: -8
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<Me

tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(402.0025, grad_fn=<MeanBackward0>)
episode: 91   reward: -10
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(402.0025, grad_fn=<MeanBackward0>)
episode: 92   reward: -10
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(402.0025, grad_fn=<MeanBackward0>)
episode: 93   reward: 

tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(402.0025, grad_fn=<MeanBackward0>)
episode: 112   reward: -11
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(402.0025, grad_fn=<MeanBackward0>)
episode: 113   reward: -11
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn

tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(402.0025, grad_fn=<MeanBackward0>)
episode: 133   reward: -11
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(402.0025, grad_fn=<MeanBackward0>)
episode: 134   reward: -11
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn

tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(402.0025, grad_fn=<MeanBackward0>)
episode: 154   reward: -11
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(402.0025, grad_fn=<MeanBackward0>)
episode: 155   reward: -11
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(402.0024, grad_

tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(402.0025, grad_fn=<MeanBackward0>)
episode: 174   reward: -11
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(402.0025, grad_fn=<MeanBackward0>)
episode: 175   reward: -11
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(402.0025, grad_fn=<MeanBackward0>)
episode: 176   reward: -11
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackw

tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(402.0025, grad_fn=<MeanBackward0>)
episode: 195   reward: -11
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(402.0025, grad_fn=<MeanBackward0>)
episode: 196   reward: -11
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn=<MeanBackward0>)
tensor(0.9025, grad_fn

ValueError: probabilities contain NaN