# PPO

In [51]:
from typing import Optional
from collections import namedtuple, deque
import random
import numpy as np

import gymnasium as gym
from gym import Env

import torch
from torch import Tensor
from torch import nn
import torch.nn.functional as F
from torch.distributions import Beta
from torch.optim import Adam

In [16]:
# Use the first CUDA device when one is visible, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Episode length cap passed to gym.make, and the replay buffer capacity.
max_episode_steps = 100
replay_buffer_capacity = 100

## Pendulum Environment

In [17]:
# Pendulum environment, truncated after `max_episode_steps` steps.
env = gym.make(id="Pendulum-v1", max_episode_steps=max_episode_steps)

In [18]:
class StateActionSpec:
    """Read state/action dimensions and bounds off an environment's spaces.

    Every property returns ``None`` instead of raising when the underlying
    space does not expose the requested information (for example
    ``n_actions`` on a space without an ``n`` attribute).

    Parameters
    ----------
    env : Env
        Environment whose ``observation_space`` and ``action_space`` are
        inspected.
    """

    def __init__(self, env: "Env") -> None:

        self._env = env

    @property
    def state_dim(self) -> Optional[int]:
        """Length of the first axis of the observation space, or ``None``."""

        try:
            return int(self._env.observation_space.shape[0])
        except (AttributeError, IndexError, TypeError):
            # No ``shape`` attribute, an empty shape ``()``, or ``shape is None``.
            return None

    @property
    def action_dim(self) -> Optional[int]:
        """Length of the first axis of the action space, or ``None``.

        A space whose shape is empty (``()``) yields ``None``.
        """

        try:
            return int(self._env.action_space.shape[0])
        except (AttributeError, IndexError, TypeError):
            return None

    @property
    def n_actions(self) -> Optional[int]:
        """Number of discrete actions (``action_space.n``), or ``None``."""

        try:
            return int(self._env.action_space.n)
        except (AttributeError, TypeError):
            return None

    @property
    def action_min(self) -> Optional[np.ndarray]:
        """Lower action bound (``action_space.low``), or ``None``."""

        # Read from this instance's env rather than a module-level global.
        try:
            return self._env.action_space.low
        except AttributeError:
            return None

    @property
    def action_max(self) -> Optional[np.ndarray]:
        """Upper action bound (``action_space.high``), or ``None``."""

        try:
            return self._env.action_space.high
        except AttributeError:
            return None


In [19]:
# Dimension/bound lookup for the Pendulum env created above.
state_action_spec = StateActionSpec(env=env)

In [20]:
# Immutable record of one environment step: (s, a, s', r).
Transition = namedtuple("Transition", "state action next_state reward")

In [21]:
class ReplayBuffer:
    """Fixed-capacity FIFO store of ``Transition`` records.

    Backed by ``deque(maxlen=capacity)``: once full, each new transition
    evicts the oldest one.

    Parameters
    ----------
    capacity : int
        Maximum number of transitions kept in memory.
    """

    def __init__(self, capacity: int) -> None:

        self._capacity = capacity
        self._transitions = deque([], maxlen=self._capacity)

    def __len__(self) -> int:
        return len(self._transitions)

    @property
    def capacity(self) -> int:
        """Maximum number of transitions in memory.
        """
        return self._capacity

    def add(self, *args) -> None:
        """Append one transition, evicting the oldest if the buffer is full.

        Accepts either a single ``Transition`` instance or the individual
        fields ``(state, action, next_state, reward)`` positionally.

        Raises
        ------
        TypeError
            If a single argument is given that is not a ``Transition``.
        ValueError
            If the number of positional arguments is neither 1 nor the
            number of ``Transition`` fields.
        """

        n_fields = len(Transition._fields)

        if len(args) == 1:
            transition = args[0]
            # Explicit check instead of ``assert``: asserts are stripped
            # under ``python -O`` and bad records would be stored silently.
            if not isinstance(transition, Transition):
                raise TypeError(
                    f"Expected a Transition, got {type(transition).__name__}"
                )

        elif len(args) == n_fields:
            transition = Transition(*args)

        else:
            raise ValueError(
                f"add() takes 1 Transition or {n_fields} fields "
                f"{Transition._fields}, got {len(args)} arguments"
            )

        self._transitions.append(transition)

    def sample(self, batch_size: int) -> list[Transition]:
        """Draw ``batch_size`` transitions uniformly, without replacement.

        Raises
        ------
        ValueError
            If ``batch_size`` is negative or exceeds ``len(self)``
            (propagated from ``random.sample``).
        """

        return random.sample(
            population=self._transitions,
            k=batch_size
        )


In [22]:
# Shared experience store filled by play_one_episode.
replay_buffer = ReplayBuffer(capacity=replay_buffer_capacity)

In [23]:
class FeatureNetwork(nn.Module):
    """Shared state encoder: a ReLU linear layer followed by an LSTM.

    Parameters
    ----------
    state_action_spec : StateActionSpec
        Provides ``state_dim``, the length of each input state vector.
    fc_features : int, default 64
        Width of the linear layer that feeds the LSTM.
    hidden_size : int, default 128
        LSTM hidden size; this is also ``n_features``.

    Notes
    -----
    Each call feeds the LSTM a length-1 sequence and no initial hidden
    state, so nothing is carried between calls. As written, the LSTM
    therefore acts as a per-step feed-forward layer.
    """

    def __init__(
            self,
            state_action_spec: "StateActionSpec",
            fc_features: int = 64,
            hidden_size: int = 128,
        ) -> None:

        super().__init__()

        self._state_action_spec = state_action_spec
        self.fc = nn.Linear(
            self._state_action_spec.state_dim,
            fc_features
        )

        self.lstm = nn.LSTM(
            input_size=fc_features,
            hidden_size=hidden_size,
            batch_first=True
        )

    @property
    def state_action_spec(self) -> "StateActionSpec":
        return self._state_action_spec

    @property
    def n_features(self) -> int:
        """Size of the last axis of ``forward``'s output."""
        return self.lstm.hidden_size

    def forward(self, states: Tensor) -> Tensor:
        """Calculate the features of a batch of states.

        Parameters
        ----------
        states : Tensor
            Batch of observed states.
            Shape: (N, state_dim)

        Returns
        -------
        Tensor
            LSTM output for each state.
            Shape: (N, H) with H = ``n_features``
        """

        x = F.relu(self.fc(states))

        # (N, F) -> (N, 1, F): a batch of length-1 sequences for the
        # batch_first LSTM.
        x = x.unsqueeze(dim=1)
        x, _ = self.lstm(x)

        # (N, 1, H) -> (N, H)
        return x.squeeze(dim=1)


In [109]:
class PolicyNetwork(nn.Module):
    """Stochastic policy head with Beta-distributed, range-scaled actions.

    For each action dimension the head predicts a pair of Beta
    concentrations ``(alpha, beta)``. A draw from ``Beta(alpha, beta)`` lies
    in ``[0, 1]`` and is mapped affinely onto ``[action_min, action_max]``.

    Parameters
    ----------
    feature_network : FeatureNetwork
        Shared encoder providing ``n_features`` and ``state_action_spec``
        (``action_dim``, ``action_min``, ``action_max``).
    """

    def __init__(
            self,
            feature_network: "FeatureNetwork"
        ) -> None:

        super().__init__()

        self.feature_network = feature_network

        self.fc = nn.Linear(
            self.feature_network.n_features,
            2 * self.feature_network.state_action_spec.action_dim
        )

    def distribution(self, states: Tensor) -> Beta:
        """Return the Beta distribution over the unit-interval action draw.

        Useful for ``log_prob`` / ``entropy`` in a policy-gradient loss.

        Parameters
        ----------
        states : Tensor
            Batch of observed states.
            Shape: (N, state_dim)

        Returns
        -------
        Beta
            Distribution with batch shape (N, action_dim).
        """

        batch_size = states.shape[0]
        x = self.feature_network(states)
        x = self.fc(x)

        # (N, 2 * action_dim) -> (N, action_dim, 2): one (alpha, beta) pair
        # per action dimension.
        x = x.view(batch_size, -1, 2)

        # Softmax keeps both concentrations positive (required by Beta);
        # scaling by 10 pins alpha + beta == 10.
        # NOTE(review): with alpha + beta fixed, the policy cannot adjust its
        # spread independently of its mean. ``F.softplus(x) + 1`` is a common
        # alternative; confirm before changing.
        x = F.softmax(x, dim=-1) * 10

        return Beta(x[..., 0], x[..., 1])

    def forward(self, states: Tensor) -> Tensor:
        """Sample actions for a batch of states.

        Parameters
        ----------
        states : Tensor
            Batch of observed states.
            Shape: (N, state_dim)

        Returns
        -------
        Tensor
            Actions scaled to ``[action_min, action_max]``.
            Shape: (N, action_dim)
        """

        # rsample() gives a reparameterised draw that stays differentiable
        # with respect to the network parameters. sample() runs under
        # no_grad, so its result can never propagate gradients back.
        sample = self.distribution(states).rsample()

        spec = self.feature_network.state_action_spec
        # Match the draw's dtype/device so the module also works off-CPU.
        action_min = torch.as_tensor(
            spec.action_min, dtype=sample.dtype, device=sample.device
        )
        action_max = torch.as_tensor(
            spec.action_max, dtype=sample.dtype, device=sample.device
        )

        return action_min + sample * (action_max - action_min)


In [110]:
class TargetNetwork(nn.Module):
    """State-value head: a linear layer on top of the shared encoder.

    Parameters
    ----------
    feature_network : FeatureNetwork
        Shared encoder providing ``n_features``.
    """

    def __init__(
            self,
            feature_network: "FeatureNetwork"
        ) -> None:

        super().__init__()

        self.feature_network = feature_network

        self.fc = nn.Linear(self.feature_network.n_features, 1)

    def forward(self, states: Tensor) -> Tensor:
        """Calculate state values.

        Parameters
        ----------
        states : Tensor
            Batch of observed states.
            Shape: (N, state_dim)

        Returns
        -------
        Tensor
            State values.
            Shape: (N,), including when N == 1
        """

        x = self.feature_network(states)

        x = self.fc(x)

        # Drop only the trailing size-1 axis. A bare squeeze() would also
        # drop the batch axis when N == 1 and return a 0-d tensor.
        return x.squeeze(dim=-1)


In [118]:
# Both heads receive the same FeatureNetwork object, so its parameters are shared.
feature_network = FeatureNetwork(state_action_spec=state_action_spec)
policy_network = PolicyNetwork(feature_network=feature_network)
target_network = TargetNetwork(feature_network=feature_network)

In [119]:
# Random probe batch: 4 fake states with 3 components each (Pendulum's state_dim).
x = torch.randn((4, 3))
x

tensor([[-0.7457,  1.0544,  0.3091],
        [ 0.6257,  0.4877, -1.6655],
        [-0.0422, -0.5164, -0.7973],
        [-0.4540, -0.2207,  1.4826]])

In [121]:
policy_network(states=x)  # sampled actions for the probe batch

tensor([[ 0.1885],
        [ 0.1156],
        [-0.0206],
        [ 0.1475]], grad_fn=<AddBackward0>)

In [114]:
target_network(states=x)  # state values for the probe batch

tensor([-0.0417, -0.0764, -0.0544, -0.0449], grad_fn=<SqueezeBackward0>)

In [None]:
def play_one_episode(
        env: "Env",
        replay_buffer: "ReplayBuffer",
        policy: Optional[nn.Module] = None,
        verbose: bool = False,
    ) -> float:
    """Roll out one episode and store every step in ``replay_buffer``.

    Parameters
    ----------
    env : Env
        Environment whose ``reset()`` returns ``(state, info)`` and whose
        ``step(action)`` returns
        ``(next_state, reward, terminated, truncated, info)``.
    replay_buffer : ReplayBuffer
        Receives ``(state, action, next_state, reward)`` for each step.
    policy : nn.Module, optional
        Maps a (1, state_dim) state batch to a (1, action_dim) action batch.
        Defaults to the module-level ``policy_network``.
    verbose : bool, default False
        If True, print each action before it is executed.

    Returns
    -------
    float
        Sum of the rewards collected during the episode.
    """

    if policy is None:
        policy = policy_network

    episode_return = 0.0
    state, _ = env.reset()
    is_done = False
    while not is_done:
        # Acting needs no gradients; no_grad avoids building an autograd
        # graph for every environment step.
        with torch.no_grad():
            states = torch.as_tensor(state).unsqueeze(dim=0)
            action = policy(states).squeeze(dim=0).cpu().numpy()

        if verbose:
            print(action)

        # Interact with the environment
        next_state, reward, is_terminated, is_truncated, _ = env.step(action)
        is_done = is_terminated or is_truncated

        # Add this transition to memory.
        # NOTE(review): termination flags are not stored, so later code cannot
        # tell a terminal state from a time-limit truncation.
        replay_buffer.add(state, action, next_state, reward)

        episode_return += float(reward)
        state = next_state

    return episode_return


In [None]:
play_one_episode(env=env, replay_buffer=replay_buffer)  # fill the buffer with one rollout

In [None]:
replay_buffer.sample(batch_size=10)  # peek at 10 random stored transitions

In [None]:
state, _ = env.reset(seed=None)  # fresh, unseeded episode start

In [None]:
env.step(action=(1,))  # single manual step with a fixed action of 1

In [None]:
torch.unsqueeze(torch.tensor(state), dim=0)  # state as a batch of one

In [None]:
policy_network(states=torch.unsqueeze(torch.tensor(state), dim=0)).item()  # scalar action for this state