https://github.com/dennybritz/reinforcement-learning/blob/master/PolicyGradient/Continuous%20MountainCar%20Actor%20Critic%20Solution.ipynb

In [1]:
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from matplotlib import pyplot as plt
from IPython.display import clear_output

### Actor-critic for MountainCarContinuous-v0

In [2]:
# Create the continuous Mountain Car environment and reset it once.
# `s` keeps whatever reset() returns and is used further down to smoke-test
# the policy network.
# NOTE(review): the code below assumes reset() returns the bare observation
# (classic gym API) -- confirm against the installed gym version.
env = gym.make("MountainCarContinuous-v0")
s = env.reset()

In [3]:
# Shapes of the observation and action spaces.  obs_shape[0] sizes the input
# layer of both networks; action_shape is the sample shape used when the
# policy draws an action.
obs_shape = env.observation_space.shape
action_shape = env.action_space.shape

In [4]:
class PolicyNetwork(nn.Module):
    """Gaussian policy (actor) with a single hidden layer.

    Two scalar heads on top of the hidden layer produce the mean and the
    (pre-softplus) standard deviation of a Normal distribution.  Each call to
    ``forward`` stores that distribution on ``self.normal`` so that
    ``get_entropy`` and ``get_log_prob`` can query it afterwards.
    """

    def __init__(self):
        super().__init__()
        self.dense1 = nn.Linear(obs_shape[0], 200)
        self.dense2 = nn.Linear(200, 1)  # mean head
        self.dense3 = nn.Linear(200, 1)  # std-dev head (before softplus)
        # Distribution built by the most recent forward pass; None until then.
        self.normal = None

    @staticmethod
    def _as_float_tensor(x):
        """Return a float32 tensor copy of ``x`` (tensor, ndarray or nested list)."""
        if isinstance(x, torch.Tensor):
            # Same result as torch.tensor(x), without the copy-construct warning.
            return x.detach().clone().to(dtype=torch.float32)
        # Go through one contiguous ndarray first: torch.tensor() on a list of
        # ndarrays (e.g. ``[state]``) converts element by element and warns.
        return torch.tensor(np.asarray(x), dtype=torch.float32)

    def _require_distribution(self):
        """Return the cached distribution or fail with a clear message."""
        if self.normal is None:
            raise RuntimeError(
                "PolicyNetwork.forward() must be called before querying "
                "the action distribution"
            )
        return self.normal

    def forward(self, x):
        """Sample an action for observation(s) ``x``.

        Args:
            x: observation(s) as a tensor, ndarray or (nested) list.

        Returns:
            A tensor sampled with ``sample_shape=action_shape`` from the
            current Normal distribution and clamped to the first-dimension
            bounds of ``env.action_space``.
        """
        x = self._as_float_tensor(x)
        hidden = F.relu(self.dense1(x))

        mu = torch.squeeze(self.dense2(hidden))
        # softplus keeps sigma positive; the epsilon keeps it away from zero.
        sigma = F.softplus(torch.squeeze(self.dense3(hidden))) + 1e-5

        self.normal = torch.distributions.normal.Normal(mu, sigma)
        actions = self.normal.sample(sample_shape=torch.Size(action_shape))
        # NOTE: the returned action is the clamped value, so a later
        # get_log_prob(action) is evaluated at the clamped action.
        low = float(env.action_space.low[0])
        high = float(env.action_space.high[0])
        return torch.clamp(actions, low, high)

    def get_entropy(self):
        """Entropy of the distribution produced by the last forward pass.

        Raises:
            RuntimeError: if ``forward`` has not been called yet.
        """
        return self._require_distribution().entropy()

    def get_log_prob(self, action):
        """Log-probability of ``action`` under the last forward pass.

        Raises:
            RuntimeError: if ``forward`` has not been called yet.
        """
        return self._require_distribution().log_prob(action)

In [5]:
class ValueNetwork(nn.Module):
    """State-value estimator (critic) with a single hidden layer."""

    def __init__(self):
        super().__init__()
        self.dense1 = nn.Linear(obs_shape[0], 200)
        self.dense2 = nn.Linear(200, 1)

    @staticmethod
    def _as_float_tensor(x):
        """Return a float32 tensor copy of ``x`` (tensor, ndarray or nested list)."""
        if isinstance(x, torch.Tensor):
            # Same result as torch.tensor(x), without the copy-construct warning.
            return x.detach().clone().to(dtype=torch.float32)
        # Go through one contiguous ndarray first: torch.tensor() on a list of
        # ndarrays (e.g. a list of collected states) converts element by
        # element and warns.
        return torch.tensor(np.asarray(x), dtype=torch.float32)

    def forward(self, x):
        """Estimate the value of observation(s) ``x``.

        Args:
            x: observation(s) as a tensor, ndarray or (nested) list.

        Returns:
            The output of the final linear layer; its last dimension has
            size 1.
        """
        x = self._as_float_tensor(x)
        hidden = F.relu(self.dense1(x))
        return self.dense2(hidden)

In [6]:
# Instantiate the actor and smoke-test one forward pass on the initial state,
# wrapped in a list.  The recorded cell output for the shape is torch.Size([1]).
policy_estimator = PolicyNetwork()
policy_estimator([s]).shape

torch.Size([1])

In [7]:
# Shape of an action sampled directly from the environment, for comparison
# with the policy output above.  The recorded cell output is (1,).
env.action_space.sample().shape

(1,)

In [8]:
# Instantiate the critic used by compute_td_loss and compute_J_hat.
value_estimator = ValueNetwork()

In [9]:
def generate_session(tmax=1000):
    """Roll out one episode with the current ``policy_estimator``.

    Args:
        tmax: maximum number of environment steps to take.

    Returns:
        A tuple ``(states, actions, rewards, dones, next_states)`` of
        parallel lists with one entry per step taken, including the step on
        which the environment reported ``done``.
    """
    states, actions, rewards, dones, next_states = [], [], [], [], []
    s = env.reset()
    for _ in range(tmax):
        action = policy_estimator([s])
        # The environment is stepped with a plain ndarray, not a tensor.
        new_s, reward, done, _ = env.step(action.detach().numpy())

        # Record the transition before checking `done`; otherwise the
        # terminal step and its reward would be dropped and `dones` could
        # never contain True.
        states.append(s)
        actions.append(action)
        rewards.append(reward)
        dones.append(done)
        next_states.append(new_s)

        if done:
            break
        s = new_s

    return states, actions, rewards, dones, next_states

In [10]:
# Roll out one episode with the policy as instantiated above and keep the
# collected transition lists.
states, actions, rewards, dones, next_states = generate_session()

In [11]:
def compute_td_loss(states, rewards, next_states, gamma):
    """Mean squared one-step TD error of ``value_estimator``.

    The target ``rewards + gamma * V(next_states)`` is treated as a constant;
    gradients only flow through ``V(states)``.

    Args:
        states: observation(s) the critic is evaluated on.
        rewards: scalar reward, or a sequence/array/tensor of rewards with
            one entry per state.
        next_states: successor observation(s) matching ``states``.
        gamma: discount factor.

    Returns:
        A scalar tensor holding the mean squared TD error.
    """
    with torch.no_grad():
        next_values = value_estimator(next_states)

    rewards = torch.as_tensor(rewards, dtype=torch.float32)
    # Align rewards with the critic output (e.g. (N,) -> (N, 1)).  Without
    # this a batch of N rewards broadcasts against an (N, 1) value tensor
    # into an (N, N) matrix.  A lone scalar reward is left to broadcast.
    if rewards.numel() == next_values.numel():
        rewards = rewards.reshape(next_values.shape)

    td_target = rewards + gamma * next_values
    squared_error = (td_target - value_estimator(states)) ** 2
    return torch.mean(squared_error)

In [12]:
def compute_J_hat(states, actions, rewards, next_states, gamma):
    """Policy-gradient surrogate loss with an entropy term.

    Computes ``-entropy - log_prob(actions) * advantage`` where the advantage
    is the one-step TD error ``rewards + gamma * V(next_states) - V(states)``
    taken as a constant.

    The entropy and log-probability come from the distribution cached by the
    most recent ``policy_estimator`` forward pass, so the caller must have
    just run the policy on ``states``.

    Args:
        states: observation(s) the action(s) were sampled for.
        actions: action tensor returned by ``policy_estimator``.
        rewards: scalar reward, or a sequence/array/tensor of rewards with
            one entry per state.
        next_states: successor observation(s) matching ``states``.
        gamma: discount factor.

    Returns:
        The un-reduced loss tensor (entropy term plus policy-gradient term).
    """
    # The advantage is a constant for the policy update, so it is computed
    # without recording an autograd graph through the critic.
    with torch.no_grad():
        values = value_estimator(states)
        next_values = value_estimator(next_states)

        rewards = torch.as_tensor(rewards, dtype=torch.float32)
        # Align rewards with the critic output (e.g. (N,) -> (N, 1)) so a
        # batch does not broadcast into an (N, N) matrix.  A lone scalar
        # reward is left to broadcast.
        if rewards.numel() == next_values.numel():
            rewards = rewards.reshape(next_values.shape)

        advantage = rewards + gamma * next_values - values

    neg_entropy = -policy_estimator.get_entropy()
    J_hat = -policy_estimator.get_log_prob(actions) * advantage
    return neg_entropy + J_hat

In [13]:
def train(gamma=0.98, episodes=10, tmax=1000):
    """Train the actor and critic online, one update per environment step.

    At every step the critic is updated on the TD loss and the actor on the
    policy loss for the transition just observed.  Per-episode mean rewards
    are printed, and both losses (sampled every 100 steps) are plotted once
    training finishes.

    Args:
        gamma: discount factor passed to both loss functions.
        episodes: number of episodes to run.
        tmax: maximum number of steps per episode.

    Returns:
        A numpy array with the mean step reward of each episode.
    """
    # Each optimizer owns the parameters of the network whose loss it steps.
    value_opt = torch.optim.Adam(value_estimator.parameters(), lr=0.001)
    policy_opt = torch.optim.Adam(policy_estimator.parameters(), lr=0.001)
    rewards = []
    loss_Ltd = []
    loss_J = []
    for i_episode in range(episodes):
        state = env.reset()
        episode_rewards = []

        for t in range(tmax):
            action = policy_estimator(state)
            # Step the environment with a plain ndarray, not a tensor.
            next_state, reward, done, _ = env.step(action.detach().numpy())

            episode_rewards.append(reward)

            # Critic update.
            # NOTE: compute_td_loss takes no `done` flag, so the target still
            # bootstraps from V(next_state) on the final step of an episode.
            Ltd = compute_td_loss(state, reward, next_state, gamma)
            value_opt.zero_grad()
            Ltd.backward()
            value_opt.step()

            # Actor update.  This relies on the distribution cached by the
            # forward pass above.  mean() is a no-op for the single-element
            # loss but guarantees a scalar for backward().
            J = compute_J_hat(state, action, reward, next_state, gamma).mean()
            policy_opt.zero_grad()
            J.backward()
            policy_opt.step()

            if t % 100 == 0:
                # Store plain floats: tensors would keep their autograd
                # graphs alive and cannot be plotted directly.
                loss_Ltd.append(Ltd.item())
                loss_J.append(J.item())

            if done:
                break

            # Advance to the state the environment is actually in.
            state = next_state

        mean_reward = np.mean(np.array(episode_rewards))
        rewards.append(mean_reward)
        print("episode: {}, reward: {}".format(i_episode + 1, mean_reward))
    plt.subplot(121)
    plt.plot(loss_Ltd)
    plt.title("Ltd loss")
    plt.subplot(122)
    plt.plot(loss_J)
    plt.title("J_hat loss")
    plt.show()

    return np.array(rewards)