# Proximal Policy Optimization
Link to [arXiv paper](https://arxiv.org/pdf/1707.06347.pdf)

In [3]:
import functools
from itertools import chain
import gym
import numpy as np
import time
import torch
from torch import optim
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import warnings
warnings.filterwarnings('ignore')

## Network Topology, Datasets

In [4]:
# Policy Network
# Chooses the next move.
class PolicyNetwork(nn.Module):
    def __init__(self, n_in, n_out):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_in, 10),
            nn.ReLU(),
            nn.Linear(10, 10),
            nn.ReLU(),
            nn.Linear(10, n_out),
            nn.Softmax(dim=1)
        )
        
    def forward(self, x):
        return self.net(x)

# Value Network
# Estimates the value of a state.
class ValueNetwork(nn.Module):
    def __init__(self, n_in):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_in, 10),
            nn.ReLU(),
            nn.Linear(10, 10),
            nn.ReLU(),
            nn.Linear(10, 1)
        )
        
    def forward(self, x):
        return self.net(x)
    
    

# These datasets sample *observations*, not rollouts.
# Policy dataset returns (prev_state, action, reward, ret) tuples.
class PolicyDataset(Dataset):
    def __init__(self, experience):
        super(PolicyDataset, self).__init__()
        self._exp = []
        for x in experience:
            self._exp.extend(x)
        self._length = len(self._exp)

    def __getitem__(self, index):
        return self._exp[index]

    def __len__(self):
        return self._length

## Calculate Returns, Advantages
## Rollout and Train

In [12]:
def _calculate_returns(rollouts, gamma, value, device):
    """
    Modifies `rollouts` in-place from (state, action, reward) to (state, action, reward, return).
    """
    for i, rollout in enumerate(rollouts):
        current_return = 0
        for j in reversed(range(len(rollout))):
            state, action_dist, action, reward = rollout[j]
            ret = reward + gamma*current_return
            adv = ret - value(_prepare_numpy(state, device)).to('cpu')
            rollouts[i][j] = (state, action_dist, action, reward, ret, adv)
            current_return = ret

def get_ratio(current_action_dist, old_action_dist, action):
    """
    Parameters:
        current_action_dist ((n,4) ndarray) - the batched action distributions.
        old_action_dist ((n,4) ndarray): the batched action distributions, frozen.
        action ((n,) ndarray): the actions taken at each observation.
    Returns:
        ratio ((n,) ndarray): (new likelihood)/(old likelihood)
    """
    n = current_action_dist.shape[0]
    current_likelihood = current_action_dist[range(n), action].unsqueeze(1)
    old_likelihood = old_action_dist[range(n), action].unsqueeze(1)
    ratio = current_likelihood/old_likelihood
    return ratio

def likelihood_fn(action_dist, action):
    n = action_dist.shape[0]
    likelihood = action_dist[range(n), action].unsqueeze(1)
    return likelihood

def _prepare_numpy(ndarray, device):
    return torch.Tensor(ndarray).to(device)

def _prepare_tensor_batch(tensor, device):
    return tensor.detach().float().to(device)

def main(policy, value):
    # Initialize environment and networks
    env = gym.make('CartPole-v0')
    device = torch.device('cuda')
    policy = policy.to(device)
    value = value.to(device)
    
    # Hyperparameters
    epochs = 100
    n_episodes = 100
    max_episode_length = 200
    gamma = 0.99
    policy_epochs = 5
    batch_size = 256
    epsilon = 0.2
    lr = 1e-3
    weight_decay = 1e-2
    
    # Optimizers
    params = chain(value.parameters(), policy.parameters())
    optimizer = optim.Adam(params, lr=lr) #, weight_decay=weight_decay)
    
    # Training
    for _ in range(epochs):
        # Generate Rollouts
        rollouts = []
        rewards_per_rollout = []
        for _ in range(n_episodes):
            state = env.reset() # Reset environment each episode

            rollout = []
            reward_total = 0
            for _ in range(max_episode_length):
                # Compute policy probabilities
                action_dist = policy(_prepare_numpy(state, device).unsqueeze(0))[0,:].to('cpu')
                action_one_hot = np.random.multinomial(1, action_dist.cpu().detach().numpy())
                action = np.argmax(action_one_hot)

                # Take Action
                state, reward, done, _ = env.step(action)
                state = state.astype(np.float32)
                rollout.append((state, action_dist, action, reward))
                reward_total += reward

                if done:
                    break

            rollouts.append(rollout)
            rewards_per_rollout.append(reward_total)
        # End Rollouts
        
        # Prepare Data
        _calculate_returns(rollouts, gamma, value, device) # modifies rollouts inplace
        avg_reward = sum(rewards_per_rollout) / len(rewards_per_rollout)
        experience_dataset = PolicyDataset(rollouts)
        data_loader = DataLoader(experience_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
        # End Prepare Data
        
        # Train
        value_criteria = nn.MSELoss()              
        for _ in range(policy_epochs):
            avg_value_loss = 0
            avg_policy_loss = 0
            for state, action_dist, action, reward, ret, adv in data_loader:
                # state is an (n_timesteps,4) tensor; action/reward/ret are all (n_timestaps,) tensors
                optimizer.zero_grad()
                
                # Convert float64 to float32, detach tensors
                state = _prepare_tensor_batch(state, device)
                action_dist = _prepare_tensor_batch(action_dist, device)
                action = _prepare_tensor_batch(action, device).long()
                reward = _prepare_tensor_batch(reward, device)
                ret = _prepare_tensor_batch(ret, device)
                adv = _prepare_tensor_batch(adv, device)
                
                # Calculate the ratio term
                current_action_dist = policy(state)
                ratio = get_ratio(current_action_dist, action_dist, action)
                
                # Value Loss
                expected_returns = value(state).view(-1)
                value_loss = value_criteria(expected_returns, ret)
        
                # Proximal Policy Optimization
#                 advantage = ret - expected_returns.detach()
                advantage = adv
                lhs = ratio * advantage
                rhs = torch.clamp(ratio, 1-epsilon, 1+epsilon) * advantage
                policy_loss = -1*torch.mean(torch.min(lhs, rhs))
        
                # Policy gradient
                likelihood = likelihood_fn(current_action_dist, action)
                policy_loss = torch.mean(torch.log(likelihood)*advantage)
                
                # Logging
                avg_value_loss += value_loss.item()
                avg_policy_loss += policy_loss.item()
                
                # Backpropagate
                loss = policy_loss + value_loss
                loss.backward()
                optimizer.step()
              
            avg_value_loss /= len(data_loader)
            avg_policy_loss /= len(data_loader)
         
        ratio_print = ratio.mean().item()
        adv_print = advantage.mean().item()
        ret_print = ret[0].item()
        exp_ret_print = expected_returns.mean().item()
#         print("ret: {:.3f}, exp_ret: {:.3f}".format(ret_print, exp_ret_print))
        print("avg reward: {:.3f}, value loss: {:.3f}, policy loss: {:.3f}, ratio: {:.3f}, adv: {:.3f}".format(avg_reward, 
                                    avg_value_loss, avg_policy_loss, ratio_print, adv_print))

n_in = 4
n_out = 2  
device = 'cuda'
policy = PolicyNetwork(n_in, n_out).to(device)
value = ValueNetwork(n_in).to(device)
main(policy, value)
print("done")

avg reward: 21.430, value loss: 240.249, policy loss: -9.423, ratio: 0.961, adv: 13.355
avg reward: 18.870, value loss: 169.881, policy loss: -9.116, ratio: 0.876, adv: 10.703
avg reward: 14.210, value loss: 72.340, policy loss: -5.641, ratio: 0.871, adv: 7.302
avg reward: 13.470, value loss: 59.011, policy loss: -5.217, ratio: 0.869, adv: 6.503
avg reward: 14.730, value loss: 68.046, policy loss: -7.001, ratio: 0.761, adv: 7.069
avg reward: 13.310, value loss: 43.797, policy loss: -5.058, ratio: 0.747, adv: 6.157
avg reward: 13.190, value loss: 34.896, policy loss: -3.851, ratio: 0.747, adv: 4.793
avg reward: 12.640, value loss: 25.403, policy loss: -2.198, ratio: 0.606, adv: 2.566
avg reward: 12.240, value loss: 21.614, policy loss: -0.765, ratio: 0.585, adv: 1.660
avg reward: 11.730, value loss: 19.407, policy loss: -0.216, ratio: 0.702, adv: 0.861
avg reward: 12.090, value loss: 19.544, policy loss: -0.269, ratio: 0.668, adv: 1.106
avg reward: 11.830, value loss: 17.592, policy los

In [59]:
type(rollouts[0][:,0])

numpy.ndarray