In [1]:
# Act with an actor-critic agent on Lunar Lander (continuous actions)

In [2]:
import gym
import numpy as np
import matplotlib.pyplot as plt
import pdb

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal

import sys
import time
from IPython.display import display, clear_output

from utils import ReturnTracker, Agent
from tqdm.auto import tqdm

In [3]:
class PolicyNet(nn.Module):
    """Tanh-squashed Gaussian policy network.

    Two ReLU hidden layers produce, per action dimension, the mean and the
    (clamped) log standard deviation of a Normal distribution. Samples are
    squashed with ``tanh`` so actions lie in [-1, 1].

    Args:
        n_observations: size of the flat observation vector.
        n_actions: number of continuous action dimensions.
        nodes: width of each hidden layer.
        noise: small epsilon added inside the log of the tanh Jacobian
            correction so it stays finite when ``|action|`` reaches 1.
    """

    def __init__(self, n_observations, n_actions, nodes=256, noise=1e-6):
        super(PolicyNet, self).__init__()
        self.nodes = nodes
        self.fc1 = nn.Linear(n_observations, nodes)
        self.fc2 = nn.Linear(nodes, nodes)
        self.mu = nn.Linear(nodes, n_actions)
        self.log_std = nn.Linear(nodes, n_actions)
        self.noise = noise

    def forward(self, observations):
        """Return ``(mu, log_std)``, each with last dimension ``n_actions``.

        ``log_std`` is clamped to [-10, 2] so ``exp(log_std)`` stays in a
        numerically safe range.
        """
        x = F.relu(self.fc1(observations))
        x = F.relu(self.fc2(x))
        mu = self.mu(x)
        log_std = self.log_std(x)
        log_std = torch.clamp(log_std, min=-10, max=2)
        return mu, log_std

    def sample(self, observations, greedy=False):
        """Draw tanh-squashed actions and their log-probabilities.

        Args:
            observations: 2-D tensor ``(batch, n_observations)``; the final
                sum over ``dim=1`` assumes this batch-first layout.
            greedy: if True, use the mean instead of a reparameterized sample.

        Returns:
            ``(actions, log_probs)`` with shapes ``(batch, n_actions)`` and
            ``(batch, 1)``.
        """
        mu, log_std = self.forward(observations)
        sigma = log_std.exp()

        probs = Normal(mu, sigma)
        # rsample keeps the sample differentiable w.r.t. mu and sigma.
        sample = mu if greedy else probs.rsample()
        actions = torch.tanh(sample)

        log_probs = probs.log_prob(sample)
        # Change-of-variables correction for the tanh squashing:
        #   log pi(a) = log N(u) - sum log(1 - tanh(u)^2)
        # The epsilon is ADDED so the log argument stays > 0 when tanh
        # saturates to exactly +-1 in float32; subtracting it made the
        # argument negative there and produced NaN log-probs/gradients.
        log_probs = log_probs - torch.log(1 - actions.pow(2) + self.noise)
        log_probs = log_probs.sum(dim=1, keepdim=True)

        return actions, log_probs

In [4]:
class ValueNet(nn.Module):
    """State-value critic V(s): two ReLU hidden layers and a scalar head.

    Args:
        n_observations: size of the flat observation vector.
        nodes: width of each hidden layer.
    """

    def __init__(self, n_observations, nodes=256):
        super(ValueNet, self).__init__()
        self.nodes = nodes
        # Layer names/order are kept stable: they define the state_dict keys
        # and the order in which parameters are randomly initialized.
        self.fc1 = nn.Linear(n_observations, nodes)
        self.fc2 = nn.Linear(nodes, nodes)
        self.fc3 = nn.Linear(nodes, 1)

    def forward(self, observations):
        """Map observations ``(..., n_observations)`` to values ``(..., 1)``."""
        hidden = F.relu(self.fc1(observations))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

In [5]:
class ActorCritic(Agent):
    """One-step (TD(0)) actor-critic agent for continuous-action gym envs.

    The actor is a tanh-squashed Gaussian ``PolicyNet``; the critic is a
    state-value ``ValueNet``. Both are updated online after every
    environment step using the one-step TD error.

    Args:
        env: gym-style environment whose observation and action spaces expose
            ``.shape`` (``reset()`` -> ``(obs, info)``, ``step()`` -> 5-tuple).
        lr_policy: Adam learning rate for the actor.
        lr_critic: Adam learning rate for the critic.
        gamma: discount factor.
    """

    def __init__(self, env, lr_policy=2e-5, lr_critic=1e-3, gamma=.99):
        self.env = env
        self.lr_policy = lr_policy
        self.lr_critic = lr_critic
        self.gamma = gamma

        self.policy = PolicyNet(env.observation_space.shape[0], env.action_space.shape[0])
        self.policy_optim = torch.optim.Adam(self.policy.parameters(), lr=lr_policy)

        self.critic = ValueNet(env.observation_space.shape[0])
        self.critic_optim = torch.optim.Adam(self.critic.parameters(), lr=lr_critic)

        self.total_num_steps = 0

    def update(self, state, action, log_prob, reward, new_state, steps, terminated=False):
        """Apply one TD(0) update to the critic, then to the actor.

        Args:
            state: current observation tensor, shape ``(1, n_observations)``
                as built in ``learn``.
            action: action taken (unused; kept for interface compatibility).
            log_prob: log pi(action | state) from ``PolicyNet.sample``; must
                still carry the policy's autograd graph.
            reward: scalar reward returned by ``env.step``.
            new_state: next observation tensor, same shape as ``state``.
            steps: step index within the episode (unused; the textbook
                ``gamma ** steps`` actor discount is deliberately omitted).
            terminated: True if ``new_state`` is terminal. The bootstrap term
                ``gamma * V(new_state)`` is then dropped because no future
                reward follows a terminal state.
        """
        # Bootstrapped TD target; V(s') is a constant (semi-gradient TD), so
        # no autograd graph is needed for it.
        with torch.no_grad():
            new_value = self.critic(new_state)
            not_done = 0.0 if terminated else 1.0
            target = float(reward) + self.gamma * not_done * new_value

        # train critic
        value = self.critic(state)
        critic_loss = F.mse_loss(value, target)
        self.critic_optim.zero_grad()
        critic_loss.backward()
        torch.nn.utils.clip_grad_value_(self.critic.parameters(), 0.5)
        self.critic_optim.step()

        # train policy: the TD error (from the pre-update critic) is the
        # advantage estimate.
        td_error = target - value.detach()
        policy_loss = (-td_error * log_prob).sum()
        self.policy_optim.zero_grad()
        policy_loss.backward()
        torch.nn.utils.clip_grad_value_(self.policy.parameters(), 0.5)
        self.policy_optim.step()

    def learn(self, episodes, max_steps=None, tracker=None, verb=True):
        """Train online for ``episodes`` episodes.

        Args:
            episodes: number of episodes to run.
            max_steps: optional cap on environment steps per episode; an
                episode stops after at most ``max_steps`` steps.
            tracker: optional object with ``new_episode()`` and ``add(reward)``
                (e.g. ``ReturnTracker``) that records rewards.
            verb: if True, show a tqdm progress bar with the last return.
        """
        self.train()

        iterator = range(episodes)
        if verb:
            iterator = tqdm(iterator, leave=True)

        for episode in iterator:

            if tracker is not None:
                tracker.new_episode()

            state, info = self.env.reset()
            state = torch.tensor(state).float().reshape(1, -1)

            returns = 0
            steps = 0
            terminated = False
            truncated = False

            while not terminated and not truncated:

                action, log_prob = self.policy.sample(state)

                new_state, reward, terminated, truncated, info = self.env.step(action.flatten().detach().numpy())
                new_state = torch.tensor(new_state).float().reshape(1, -1)

                # Pass `terminated` (not `truncated`): states cut off by a
                # time limit are not terminal and should still bootstrap.
                self.update(state, action, log_prob, reward, new_state, steps, terminated=terminated)

                state = new_state
                self.total_num_steps += 1
                steps += 1
                returns += reward

                if tracker is not None:
                    tracker.add(reward)

                # `>=` so exactly max_steps steps run (`>` ran one extra).
                if max_steps is not None and steps >= max_steps:
                    break

            if verb:
                iterator.set_description(f"total steps: {self.total_num_steps}, episode: {episode}, return: {returns:.4f}")

    def predict(self, state):
        """Return the greedy action (tanh of the policy mean) for ``state``.

        Args:
            state: a single observation (1-D) or a batch (2-D); numpy arrays,
                lists and tensors are accepted.

        Returns:
            ``(action, log_probs)`` as numpy arrays of shapes
            ``(batch, n_actions)`` and ``(batch, 1)``.
        """
        self.policy.eval()

        # as_tensor avoids the copy-construct warning torch.tensor() emits
        # when `state` is already a tensor.
        state = torch.as_tensor(state, dtype=torch.float32)

        if len(state.shape) != 2:
            state = state.unsqueeze(0)

        # Inference only: no autograd graph needed.
        with torch.no_grad():
            action, log_probs = self.policy.sample(state, greedy=True)

        return action.numpy(), log_probs.numpy()

    def train(self):
        """Put both networks in training mode."""
        self.policy.train()
        self.critic.train()

    def eval(self):
        """Put both networks in evaluation mode."""
        self.policy.eval()
        self.critic.eval()

In [12]:
# Continuous-action LunarLander, rendered to RGB arrays for the test cell.
# Alternative environment tried: "Pendulum-v1".
# NOTE(review): `.env` unwraps the outermost wrapper returned by gym.make
# (presumably the TimeLimit wrapper — confirm for the installed gym version),
# so `truncated` may never be set and episode length is then bounded only by
# the step caps used in the training and test cells.
env = gym.make(
    "LunarLander-v2",
    render_mode='rgb_array',
    continuous=True,
).env

In [13]:
# Action dimensionality that ActorCritic sizes the policy head from
# (`env.action_space.shape[0]`); the continuous LunarLander shows (2,).
env.action_space.shape

(2,)

In [14]:
# Smoke-test the environment API: reset() returns (observation, info).
# `observation` is not referenced by the later cells in this section.
observation, info = env.reset()

In [15]:
# NOTE(review): both learning rates are well above the ActorCritic defaults
# (lr_policy=2e-5, lr_critic=1e-3); lr_policy here is 1000x the default —
# lower it if training diverges.
agent = ActorCritic(
    env,
    gamma=.99,
    lr_critic=2e-2,
    lr_policy=2e-2,
)



In [16]:
# Records rewards during `agent.learn` (via new_episode()/add()) for plotting below.
tracker = ReturnTracker()



In [None]:
# Online training: 2000 episodes, each capped by `max_steps` environment
# steps; rewards are recorded into `tracker`.
agent.learn(
    tracker=tracker,
    max_steps=500,
    episodes=2000,
)



  0%|          | 0/2000 [00:00<?, ?it/s]

In [None]:
# Plot the rewards recorded by `tracker` during training.
# NOTE(review): `smooth=1` presumably means no smoothing window — confirm in utils.ReturnTracker.
tracker.plot(smooth=1)


In [None]:
# test: greedy rollouts of the trained agent, rendered frame by frame.
# Bounded (instead of `while True`) so Restart & Run All terminates.
N_TEST_EPISODES = 3   # number of rendered evaluation episodes
MAX_TEST_STEPS = 100  # per-episode step cap

for _ in range(N_TEST_EPISODES):

    terminated, truncated = False, False

    state, info = env.reset()

    steps = 0
    returns = 0

    while not terminated and not truncated:

        time.sleep(.05)
        clear_output(wait=True)

        action, log_prob = agent.predict(state)
        new_state, reward, terminated, truncated, info = env.step(action.flatten())

        steps += 1
        # Keep the raw observation: predict() does its own tensor conversion,
        # and re-wrapping a tensor with torch.tensor() warns every step.
        state = new_state
        returns += reward

        print('steps: {}, returns: {}'.format(steps, returns))

        plt.imshow(env.render())
        plt.axis('off')
        plt.show()

        sys.stdout.flush()

        # `>=` so at most MAX_TEST_STEPS steps run (the old `> 100` ran 101).
        if steps >= MAX_TEST_STEPS:
            break

    time.sleep(1.)