In [1]:
from typing import NamedTuple
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

from env import SuikaEnv

# TensorBoard writer shared by the rest of the notebook. It is constructed with
# no arguments, so the library's default log directory is used.
writer = SummaryWriter()

In [2]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs
# (slowly) on machines without CUDA instead of failing at first tensor use.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Run size
n_envs = 1
total_timesteps = 1_000_000

# PPO hyperparameters (consumed by the buffer below and by training cells
# that are outside this section of the notebook).
learning_rate = 3e-4
n_rollout_steps = 1024
batch_size = 64
n_epochs = 10
# gamma / gae_lambda are read as module-level globals by
# RolloutBuffer.compute_returns_and_advantage.
gamma = 0.99
gae_lambda = 0.95
clip_range = 0.2
normalize_advantage = True
ent_coef = 0.0
vf_coef = 0.5
max_grad_norm = 0.5

# NOTE(review): RolloutBuffer allocates storage of shape
# (buffer_size, n_envs, ...), i.e. it already has an n_envs axis. If this
# product is what gets passed as its `buffer_size`, the buffer is over-sized
# by a factor of n_envs once n_envs > 1 -- confirm against the construction site.
buffer_size = n_envs * n_rollout_steps


In [3]:
# Instantiate the project-local Suika environment (imported from `env`).
# The recorded output of this cell shows a connection to COM3 being opened on
# construction -- presumably a serial port, so the corresponding hardware
# likely has to be attached before running this cell (confirm in env.py).
env = SuikaEnv()

connecting to COM3


In [4]:
# Per-episode bookkeeping; both lists start empty and are presumably appended
# to by the training loop (not visible in this section of the notebook).
episode_frame_numbers, episode_rewards = [], []

In [5]:
class RolloutBufferSamples(NamedTuple):
    """One minibatch of rollout data, as built by ``RolloutBuffer._get_samples``.

    All fields are selected with the same minibatch indices and converted with
    ``RolloutBuffer.to_torch``, which places them on the buffer's device.
    """

    # Observations of the sampled transitions.
    observations: torch.Tensor
    # Actions taken in the sampled transitions.
    actions: torch.Tensor
    # Value estimates recorded at rollout time (flattened to 1-D).
    old_values: torch.Tensor
    # Log-probabilities of the actions recorded at rollout time (flattened to 1-D).
    old_log_prob: torch.Tensor
    # Advantages from compute_returns_and_advantage (flattened to 1-D).
    advantages: torch.Tensor
    # Returns, i.e. advantages + values (flattened to 1-D).
    returns: torch.Tensor

In [6]:
class RolloutBuffer:
    """Fixed-size on-policy rollout storage with GAE(lambda) advantage estimation.

    Storage lives on the CPU with shape ``(buffer_size, n_envs, ...)``; samples
    are moved to ``device`` only when minibatches are drawn.

    Typical cycle: ``reset()`` -> ``add()`` x ``buffer_size`` ->
    ``compute_returns_and_advantage()`` -> ``get()`` (any number of epochs).

    Args:
        buffer_size: Number of steps stored per environment.
        n_envs: Number of parallel environments.
        obs_shape: Shape of a single observation.
        action_dim: Size of the trailing action axis.
        device: Device the sampled minibatches are moved to.
        gamma: Optional discount factor. When ``None`` the module-level
            ``gamma`` is used (the original behaviour).
        gae_lambda: Optional GAE trace-decay. When ``None`` the module-level
            ``gae_lambda`` is used (the original behaviour).
    """

    def __init__(self, buffer_size, n_envs, obs_shape, action_dim, device,
                 gamma=None, gae_lambda=None):
        self.buffer_size = buffer_size
        self.n_envs = n_envs
        self.obs_shape = obs_shape
        self.action_dim = action_dim
        self.device = device
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        # Allocate storage right away so a freshly built buffer is usable
        # without the caller having to remember an explicit reset().
        self.reset()

    def reset(self):
        """(Re)allocate zeroed storage and rewind the write position."""
        steps_by_envs = (self.buffer_size, self.n_envs)
        self.observations = torch.zeros((*steps_by_envs, *self.obs_shape), dtype=torch.float32)
        self.actions = torch.zeros((*steps_by_envs, self.action_dim), dtype=torch.int64)
        self.rewards = torch.zeros(steps_by_envs, dtype=torch.float32)
        self.returns = torch.zeros(steps_by_envs, dtype=torch.float32)
        self.episode_starts = torch.zeros(steps_by_envs, dtype=torch.float32)
        self.values = torch.zeros(steps_by_envs, dtype=torch.float32)
        self.log_probs = torch.zeros(steps_by_envs, dtype=torch.float32)
        self.advantages = torch.zeros(steps_by_envs, dtype=torch.float32)
        self.pos = 0
        self.generator_ready = False

    def add(self, obs, action, reward, episode_start, value, log_prob):
        """Store one step (for all environments) at the current write position.

        Raises:
            IndexError: If ``buffer_size`` steps have already been stored.
        """
        if self.pos >= self.buffer_size:
            raise IndexError(
                "RolloutBuffer is full; call reset() before adding more transitions"
            )
        self.observations[self.pos] = obs
        self.actions[self.pos] = action
        self.rewards[self.pos] = reward
        self.episode_starts[self.pos] = episode_start
        # detach(): rollout statistics are constants for the update step. An
        # in-place write of a grad-carrying tensor would otherwise make the
        # whole storage tensor require grad and keep every rollout graph alive.
        self.values[self.pos] = value.detach().cpu().flatten()
        self.log_probs[self.pos] = log_prob.detach().cpu()
        self.pos += 1

    def compute_returns_and_advantage(self, last_values, dones):
        """Fill ``advantages`` (GAE) and ``returns`` (= advantages + values).

        Must be called before ``get()``, while storage still has its
        ``(buffer_size, n_envs)`` layout.

        Args:
            last_values: Value estimates for the observation following the
                last stored step.
            dones: Whether that following observation starts a new episode
                (tensor, array or bool).
        """
        # Explicit constructor values win; otherwise fall back to the
        # notebook-level globals.
        discount = self.gamma if self.gamma is not None else gamma
        trace_decay = self.gae_lambda if self.gae_lambda is not None else gae_lambda

        last_values = last_values.detach().cpu().flatten()
        final_non_terminal = 1.0 - torch.as_tensor(dones, dtype=torch.float32, device="cpu")

        last_gae_lam = 0
        for step in reversed(range(self.buffer_size)):
            if step == self.buffer_size - 1:
                next_non_terminal = final_non_terminal
                next_values = last_values
            else:
                # episode_starts[t + 1] == 1 means step t ended its episode.
                next_non_terminal = 1.0 - self.episode_starts[step + 1]
                next_values = self.values[step + 1]
            delta = self.rewards[step] + discount * next_values * next_non_terminal - self.values[step]
            last_gae_lam = delta + discount * trace_decay * next_non_terminal * last_gae_lam
            self.advantages[step] = last_gae_lam
        self.returns = self.advantages + self.values

    @staticmethod
    def swap_and_flatten(arr):
        """Reshape ``(steps, envs, *rest)`` to ``(envs * steps, *rest)``, env-major.

        Inputs with fewer than three dimensions get a trailing axis of size 1.
        """
        shape = arr.shape
        if len(shape) < 3:
            shape = (*shape, 1)
        return arr.swapaxes(0, 1).reshape(shape[0] * shape[1], *shape[2:])

    def get(self, batch_size):
        """Yield shuffled ``RolloutBufferSamples`` minibatches covering the buffer once.

        The first call flattens the storage in place; until the next
        ``reset()`` the arrays keep that flattened layout, so ``add()`` and
        ``compute_returns_and_advantage()`` must not be called in between.
        The last minibatch may be smaller than ``batch_size``.

        Raises:
            ValueError: If ``batch_size`` is not positive (raised on first
                iteration, since this is a generator).
        """
        if batch_size <= 0:
            # A non-positive step would never advance start_idx below.
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        total = self.buffer_size * self.n_envs
        indices = np.random.permutation(total)

        if not self.generator_ready:
            self.observations = self.swap_and_flatten(self.observations)
            self.actions = self.swap_and_flatten(self.actions)
            self.values = self.swap_and_flatten(self.values)
            self.log_probs = self.swap_and_flatten(self.log_probs)
            self.advantages = self.swap_and_flatten(self.advantages)
            self.returns = self.swap_and_flatten(self.returns)
            self.generator_ready = True

        for start_idx in range(0, total, batch_size):
            yield self._get_samples(indices[start_idx : start_idx + batch_size])

    def to_torch(self, array):
        """Return ``array`` as a tensor on ``self.device``."""
        return torch.as_tensor(array, device=self.device)

    def _get_samples(
        self,
        batch_inds
    ):
        """Gather the rows at ``batch_inds`` into a ``RolloutBufferSamples`` on ``self.device``."""
        data = (
            self.observations[batch_inds],
            self.actions[batch_inds],
            self.values[batch_inds].flatten(),
            self.log_probs[batch_inds].flatten(),
            self.advantages[batch_inds].flatten(),
            self.returns[batch_inds].flatten(),
        )
        return RolloutBufferSamples(*tuple(map(self.to_torch, data)))

In [7]:
class PolicyValueNet(nn.Module):
    """Actor-critic network: shared conv trunk with a policy head and a value head.

    Args:
        in_channels: Number of input channels (default 4).
        n_actions: Number of discrete actions, i.e. policy logits (default 4).
    """

    def __init__(self, in_channels=4, n_actions=4):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        # 3136 = 64 * 7 * 7, which is what this conv stack produces for an
        # 84x84 input; other spatial sizes will fail at this layer.
        self.fc = nn.Linear(3136, 512)
        self.fc_p = nn.Linear(512, n_actions)
        self.fc_v = nn.Linear(512, 1)

    def extract_features(self, x):
        """Run the shared trunk and return the 512-d feature vector per sample."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.fc(x.flatten(1)))
        return x

    def forward(self, x):
        """Return ``(policy_logits, value)`` for a batch of observations.

        Both heads are left linear. A ReLU on the outputs would make negative
        value estimates impossible and would zero the gradient of every head
        unit whose pre-activation is negative.
        """
        features = self.extract_features(x)
        return self.fc_p(features), self.fc_v(features)

    def predict_values(self, x):
        """Return only the value head output, shape ``(batch, 1)``."""
        return self.fc_v(self.extract_features(x))

    @staticmethod
    def log_prob(value, logits):
        """Log-probability of the chosen actions.

        Args:
            value: Integer action indices with a trailing axis of size 1.
            logits: Normalized logits (log-probabilities) over actions.

        Returns:
            Tensor of log-probabilities with the trailing axis removed.
        """
        value, log_pmf = torch.broadcast_tensors(value, logits)
        value = value[..., :1]
        log_prob = log_pmf.gather(-1, value).squeeze(-1)
        return log_prob

    @staticmethod
    def entropy(logits):
        """Entropy of the categorical distribution defined by ``logits``.

        The logits are normalized here, so both raw logits and
        log-probabilities are accepted (normalizing log-probabilities is a
        no-op).
        """
        log_p = F.log_softmax(logits, dim=-1)
        # Clamp -inf (zero-probability entries) so that 0 * -inf cannot yield NaN.
        log_p = torch.clamp(log_p, min=torch.finfo(log_p.dtype).min)
        return -(log_p * log_p.exp()).sum(-1)

    def sample(self, obs):
        """Sample one action per observation.

        Returns:
            ``(actions, values, log_prob)`` where ``actions`` has shape
            ``(batch, 1)``, ``values`` has shape ``(batch, 1)`` and
            ``log_prob`` has shape ``(batch,)``.
        """
        logits, values = self.forward(obs)
        log_p = F.log_softmax(logits, dim=-1)
        actions = torch.multinomial(log_p.exp(), 1, True)
        return actions, values, self.log_prob(actions, log_p)

    def evaluate_actions(self, obs, actions):
        """Score stored ``actions`` under the current policy.

        Returns:
            ``(values, log_prob, entropy)`` for the given observations.
        """
        logits, values = self.forward(obs)
        log_p = F.log_softmax(logits, dim=-1)
        return values, self.log_prob(actions, log_p), self.entropy(log_p)