# Proximal Policy Optimization

### Import the libraries

In [23]:
# Calculations
import numpy as np
from numpy import random, sqrt

# Network libraries
import torch
import torch.nn.functional as F, torch.nn as nn, torch.optim as optim
from torch.distributions import Categorical as Categorical

# Game library
from TicTacToe import initialState, randomState, printState, drawState, move, possibleMoves, possibleMovesMask, isOver, gameScore

### Generalized Advantage Estimate (GAE)

The GAE is calculated using the formula:

$ \begin{equation*}
\hat{A_t}^{GAE(\gamma, \lambda)} = (1-\lambda)\left( \hat{A}^{(1)}_t + \lambda\hat{A}^{(2)}_t + \lambda^2\hat{A}^{(3)}_t + ...\right) = \sum_{l=0}^{\infty}(\gamma\lambda)^l\delta^V_{t+l}
\end{equation*}$

where $\delta^V_{t}$ is the TD residual defined by:

$ \begin{equation*}
\delta^V_{t} = r_t + \gamma V^\pi(s_{t+1})-V^\pi(s_t)
\end{equation*}
$



Note that, by definition, we can easily evaluate $ \hat{A_t}^{GAE(\gamma, \lambda)} $ from $ \hat{A}_{t+1}^{GAE(\gamma, \lambda)} $:

\begin{equation*}
\hat{A_t}^{GAE(\gamma, \lambda)} = \delta^V_t + \gamma\lambda \hat{A}_{t+1}^{GAE(\gamma, \lambda)} 
\end{equation*}


In [24]:
def gae(values, rewards, lmbda = 0.95, gamma = 0.99):
	"""Compute the Generalized Advantage Estimate for every step of a trajectory.

	The estimates are built backwards with the recursion
	A_t = delta_t + gamma * lmbda * A_{t+1}, where
	delta_t = rewards[t] + gamma * values[t+1] - values[t] and the advantage
	past the last step is taken to be 0.

	Parameters:
		values:  indexable sequence of value estimates; values[t+1] is read for
		         every reward index t, so it needs len(rewards) + 1 entries.
		rewards: sized, indexable sequence of per-step rewards.
		lmbda:   GAE mixing coefficient (lambda).
		gamma:   discount factor.

	Returns:
		A list with one advantage per reward, in chronological order.
	"""
	horizon = len(rewards)

	# Per-step decay applied to the advantage carried back from step t+1
	decay = gamma * lmbda

	# Pre-size the output so it can be filled from the last step to the first
	advantages = [None] * horizon

	# Advantage of the step after the one being processed (0 beyond the end)
	running = 0

	for step in range(horizon - 1, -1, -1):
		# TD residual of the current step
		td_residual = rewards[step] + gamma * values[step + 1] - values[step]
		# Fold in the discounted advantage of the following step
		running = td_residual + decay * running
		advantages[step] = running

	return advantages

### Actor-Critic Network

In [25]:
class ActorCritic(nn.Module):
	"""Actor-critic pair of small convolutional networks for a 3x3 board.

	Both networks take input of shape (N, 1, 3, 3). Two 2x2 convolutions shrink
	the board to a single 32-channel cell, which is flattened and passed to a
	linear head: 9 outputs (one per board cell) for the actor, 1 output (the
	state value) for the critic.

	The module is not meant to be called directly; use act() to sample a move
	and evaluate() to score stored (state, action) pairs.
	"""

	# Initialize the actor and critic
	def __init__(self):

		super().__init__()

		# Actor: one score per board cell
		self.actor = nn.Sequential(
			nn.Conv2d(1, 8, 2),
			nn.ReLU(),
			nn.Conv2d(8, 32, 2),
			nn.ReLU(),
			nn.Flatten(),
			nn.Linear(32, 9)
		)

		# Critic: a single scalar value per state
		self.critic = nn.Sequential(
			nn.Conv2d(1, 8, 2),
			nn.ReLU(),
			nn.Conv2d(8, 32, 2),
			nn.ReLU(),
			nn.Flatten(),
			nn.Linear(32, 1)
		)

		self.softmax = nn.Softmax(dim=-1)


	# Instead of using a forward function, the work is split into
	# one method for the actor (act) and one for both networks (evaluate).
	# Arbitrary arguments are accepted so that calling the module, e.g.
	# model(state), raises the intended NotImplementedError rather than an
	# unrelated TypeError about the number of arguments.
	def forward(self, *args, **kwargs):
		raise NotImplementedError(
			"ActorCritic has no forward pass; use act() or evaluate() instead."
		)


	# Shared by act() and evaluate() so both always use the same masked policy.
	# Returns the network-ready state tensor and the action distribution.
	def _policy(self, state):
		# Change the state into suitable format: (N, 1, 3, 3) float tensor
		tensor_state = torch.from_numpy(state.reshape(-1, 1, 3,3)).float()

		# Calculate the raw scores of the actor
		scores = self.actor(tensor_state)

		# Apply the invalid action masking: scores are kept where the mask is
		# true and replaced by -1e8 elsewhere, which the softmax maps to ~0.
		# NOTE(review): presumably possibleMovesMask returns a boolean tensor
		# marking legal moves, broadcastable against the (N, 9) scores, and the
		# second argument selects the player - confirm against TicTacToe.
		scores = torch.where(possibleMovesMask(state, 1), scores, torch.tensor([-1e8]*9))

		# Apply the softmax function to obtain probabilities
		probs = self.softmax(scores)

		# Create a distribution
		return tensor_state, Categorical(probs)


	# For a given state, return an action according to a current policy and the log_prob
	# of performing this action
	def act(self, state):
		"""Sample an action from the current policy.

		Parameters:
			state: numpy array that can be reshaped to (-1, 1, 3, 3).

		Returns:
			(action, log_prob): the sampled action index tensor and its
			log-probability, neither attached to the autograd graph.
		"""
		# No gradients are needed for sampling, so skip building the graph
		with torch.no_grad():
			_, dist = self._policy(state)

			# Pick an action using the generated distribution
			action = dist.sample()

			# Evaluate the log_prob of the chosen action
			log_prob = dist.log_prob(action)

		# Return the action and the log_probability
		return action, log_prob


	# For a given state and action, return:
	# 	- 	the log_probabilities of the action, according to the current policy,
	# 	-	the estimated value of the state, according to the critic,
	#	-	the entropy of the distribution
	def evaluate(self, state, action):
		"""Score stored actions under the current policy, with gradients.

		Parameters:
			state:  numpy array that can be reshaped to (-1, 1, 3, 3).
			action: tensor of action indices to score, as accepted by
			        Categorical.log_prob.

		Returns:
			(log_prob, state_eval, dist_entropy): log-probability of `action`,
			the critic's output for `state`, and the entropy of the masked
			action distribution.
		"""
		tensor_state, dist = self._policy(state)

		# Evaluate the log_prob of the chosen action
		log_prob = dist.log_prob(action)

		# Calculate the entropy of the distribution
		dist_entropy = dist.entropy()

		# Calculate the evaluation of the position
		state_eval = self.critic(tensor_state)

		# Return the log_probabilities, the state values and the entropy
		return log_prob, state_eval, dist_entropy