## Backbone feed-forward network for the actor and critic

In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np

class FeedForwardNN(nn.Module):
    """Three-layer MLP (two ReLU hidden layers) used as the actor/critic backbone.

    Args:
        in_dim: size of the input (state) vector.
        out_dim: size of the output (e.g. action mean for the actor, a single
            value for the critic).
        hidden_dim: width of both hidden layers (defaults to 64).
    """

    def __init__(self, in_dim, out_dim, hidden_dim=64):
        super().__init__()

        self.layer1 = nn.Linear(in_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, hidden_dim)
        self.layer3 = nn.Linear(hidden_dim, out_dim)

    def forward(self, states):
        """Map a state (or a batch of states) to the network output.

        NumPy arrays are converted to float32 tensors first so raw environment
        observations can be passed in directly.
        """
        if isinstance(states, np.ndarray):
            states = torch.tensor(states, dtype=torch.float)

        activation1 = F.relu(self.layer1(states))
        # The second hidden layer must be layer2 (hidden_dim -> hidden_dim);
        # reusing layer1 here would expect in_dim inputs.
        activation2 = F.relu(self.layer2(activation1))
        return self.layer3(activation2)

## PPO

### Steps:

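1. Initialize the actor and critic networks.
2. Collect a batch of trajectories by running the current actor in the environment, and compute the rewards-to-go for every timestep.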

In [3]:
from torch.distributions import MultivariateNormal

class PPO:
    """Proximal Policy Optimization agent (work in progress).

    Holds an actor network whose output is used as the mean of a Gaussian
    action distribution, and a critic network with a single output. So far
    only data collection (rollouts + rewards-to-go) is implemented; the
    actor/critic update steps are not written yet.
    """

    def __init__(self, env):
        """Build the actor and critic networks for ``env``.

        Args:
            env: environment exposing ``observation_space.shape``,
                ``action_space.shape``, ``reset()`` and ``step(action)``.
        """
        self.env = env
        self.states_dim = env.observation_space.shape[0]
        self.act_dim = env.action_space.shape[0]

        # Hyperparameters must exist before collect_data() is used.
        self._init_params()

        ## STEP 1
        # Input is the state for both networks; the critic outputs a single
        # value and the actor outputs the mean of the action distribution.
        self.actor = FeedForwardNN(self.states_dim, self.act_dim)
        self.critic = FeedForwardNN(self.states_dim, 1)

        # Fixed diagonal covariance used when sampling actions during data
        # collection: every action dimension gets variance 0.5.
        # `size` must be a tuple -- `(self.act_dim)` is just an int.
        self.cov_var = torch.full(size=(self.act_dim,), fill_value=0.5)
        self.cov_mat = torch.diag(self.cov_var)

    def _init_params(self):
        """Set rollout hyperparameters."""
        self.timesteps_per_batch = 4800        # env steps to collect per batch
        self.max_timesteps_per_episode = 1600  # cap on one episode's length
        self.gamma = 0.95                      # discount factor for rewards-to-go

    def _reset_env(self):
        """Reset the env and return only the observation.

        Handles both ``reset() -> obs`` (classic gym) and
        ``reset() -> (obs, info)`` (gym>=0.26 / gymnasium).
        """
        out = self.env.reset()
        return out[0] if isinstance(out, tuple) else out

    def _step_env(self, action):
        """Step the env and return ``(obs, reward, done)`` for either gym API."""
        out = self.env.step(action)
        if len(out) == 5:
            # gym>=0.26 / gymnasium: (obs, reward, terminated, truncated, info)
            obs, reward, terminated, truncated, _ = out
            return obs, reward, terminated or truncated
        obs, reward, done, _ = out
        return obs, reward, done

    def collect_data(self):
        """Run the current actor in the env until a batch is full.

        Whole episodes are played until at least ``self.timesteps_per_batch``
        steps have been collected (the last episode is not cut short, so the
        batch may be slightly larger). Each episode is capped at
        ``self.max_timesteps_per_episode`` steps.

        Returns:
            batch_states: float tensor, (timesteps in batch, states_dim)
            batch_acts: float tensor, (timesteps in batch, act_dim)
            batch_log_probs: float tensor, (timesteps in batch,)
            batch_rewards_to_go: float tensor, (timesteps in batch,)
            batch_lens: list of episode lengths, (number of episodes,)
        """
        batch_states = []
        batch_acts = []
        batch_log_probs = []
        batch_rewards = []  # one list of rewards per episode
        batch_lens = []

        # Number of timesteps run so far this batch
        t = 0

        while t < self.timesteps_per_batch:
            ep_rewards = []
            ep_len = 0

            states = self._reset_env()

            for _ in range(self.max_timesteps_per_episode):
                t += 1
                ep_len += 1

                # Record the state the action is taken from.
                batch_states.append(states)

                action, log_prob = self.get_action(states)
                states, reward, done = self._step_env(action)

                ep_rewards.append(reward)
                batch_acts.append(action)
                batch_log_probs.append(log_prob)

                if done:
                    break

            batch_lens.append(ep_len)
            batch_rewards.append(ep_rewards)

        # Stack into single arrays first: torch.tensor on a list of ndarrays
        # is very slow.
        batch_states = torch.tensor(np.array(batch_states), dtype=torch.float)
        batch_acts = torch.tensor(np.array(batch_acts), dtype=torch.float)
        batch_log_probs = torch.stack(batch_log_probs).to(torch.float)

        batch_rewards_to_go = self.compute_rewards_to_go(batch_rewards)

        return batch_states, batch_acts, batch_log_probs, batch_rewards_to_go, batch_lens

    def compute_rewards_to_go(self, batch_rewards):
        """Compute discounted rewards-to-go for every timestep in the batch.

        For each episode, ``rtg[k] = r[k] + gamma * rtg[k + 1]``. The discount
        restarts at each episode boundary.

        Args:
            batch_rewards: list of per-episode reward lists.

        Returns:
            Float tensor of shape (total timesteps,), in timestep order.
        """
        rewards_to_go = []
        for ep_rewards in batch_rewards:
            discounted = 0.0
            ep_rtg = []
            # Walk backwards so each step can reuse the next step's value.
            for reward in reversed(ep_rewards):
                discounted = reward + self.gamma * discounted
                ep_rtg.append(discounted)
            ep_rtg.reverse()
            rewards_to_go.extend(ep_rtg)
        return torch.tensor(rewards_to_go, dtype=torch.float)

    def learn(self, total_timesteps):
        """Collect batches until ``total_timesteps`` env steps have been run.

        The PPO actor/critic update is not implemented yet; this currently
        only gathers data.
        """
        t_so_far = 0  # Timesteps simulated until now

        while t_so_far < total_timesteps:
            (batch_states, batch_acts, batch_log_probs,
             batch_rewards_to_go, batch_lens) = self.collect_data()
            # Advance the counter by the steps just collected; otherwise the
            # loop never terminates.
            t_so_far += sum(batch_lens)
            # TODO: compute advantages and update the actor and critic.

    def get_action(self, states):
        """Sample an action for ``states`` from the current Gaussian policy.

        Returns:
            ``(action, log_prob)``: the sampled action as a NumPy array and
            its log-probability as a detached tensor.
        """
        # Gradients are not needed while collecting data, and both results
        # are returned detached.
        with torch.no_grad():
            mean = self.actor(states)
            dist = MultivariateNormal(mean, self.cov_mat)
            action = dist.sample()
            log_prob = dist.log_prob(action)

        return action.detach().numpy(), log_prob.detach()