# Let's try PPO!

In [1]:
import tianshou as ts 
from tianshou.utils import TensorboardLogger

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

import numpy as np

import os
from datetime import datetime

In [2]:
from utils_preprocess import compute_frame_features, compute_foa_features
from utils_data import get_frames_by_type, get_feature_frames

from env_frames import FramesEnvironment
from env_frames_test import FramesTestEnvironment

  from pkg_resources import resource_stream, resource_exists


## Data and environment initialisation

In [3]:
vid_filename = "038"

mat_filename = vid_filename + ".mat"
target_subject = 0

In [4]:
video_frames = get_frames_by_type("frames", vid_filename)
dynamics_frames = get_frames_by_type("dynamic", vid_filename)
patches_frames, _ = get_feature_frames(vid_filename)

In [5]:
# need to normalise frames before feeding them into the model
for i in range(len(video_frames)):
    # broadcasting will do the rest...
    video_frames[i] = video_frames[i] / 255.0
    dynamics_frames[i] = dynamics_frames[i] / 255.0
    patches_frames[i] = patches_frames[i] / 255.0

In [6]:
patch_bounding_boxes_per_frame, patch_centres_per_frame, speaker_info_per_frame = compute_frame_features(
    vid_filename
)

foa_centres_per_frame_per_subject, patch_weights_per_frame = compute_foa_features(
    mat_filename, patch_centres_per_frame
)

foa_centres_per_frame = [frame[target_subject] for frame in foa_centres_per_frame_per_subject]

In [7]:
markov_env = FramesEnvironment(
    1,
    video_frames,
    dynamics_frames,
    patches_frames,
    patch_bounding_boxes_per_frame,
    patch_centres_per_frame,
    speaker_info_per_frame,
    foa_centres_per_frame,
    patch_weights_per_frame,
)

In [8]:
# env.observation_space.sample(), env.action_space.sample()

In [9]:
num_train_envs = 5
num_test_envs = 5

train_envs = ts.env.DummyVectorEnv([lambda: markov_env for _ in range(num_train_envs)])
test_envs = ts.env.DummyVectorEnv([lambda: markov_env for _ in range(num_test_envs)])

## Networks construction

Unlike for the DQN case, we'll need to networks: an actor network and a critic network.

In [10]:
class ActorNet(nn.Module):
    def __init__(self, observation_space, action_shape):
        super().__init__()

        self.num_patches = observation_space['patch_centres'].shape[0]

        self.rgb_frame_net = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=5, stride=2, padding=2),  # [3, 180, 320]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # [16, 45, 80]
            nn.Conv2d(16, 32, kernel_size=5, stride=2, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # [32, 11, 20]
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),  # [64, 6, 10]
            nn.Flatten(),
            nn.Linear(64 * 6 * 10, 64),
            nn.ReLU(inplace=True)
        )

        self.bitmap_frame_net = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=2, padding=2),  # [1, 180, 320]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # [16, 45, 80]
            nn.Conv2d(16, 32, kernel_size=5, stride=2, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # [32, 11, 20]
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),  # [64, 6, 10]
            nn.Flatten(),
            nn.Linear(64 * 6 * 10, 64),
            nn.ReLU(inplace=True)
        )

        # network for patch_centres
        self.patch_centres_net = nn.Sequential(
            nn.Linear(np.prod(observation_space['patch_centres'].shape), 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 64),
            nn.ReLU(inplace=True)
        )

        # network for speaker_info
        self.speaker_info_net = nn.Sequential(
            nn.Linear(np.prod(observation_space['speaker_info'].shape), 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 32),
            nn.ReLU(inplace=True)
        )

        # combining the outputs of all networks
        self.combined_net = nn.Sequential(
            nn.Linear(64 + 32 + 64 + 64 + 64, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, np.prod(action_shape))
        )

    def forward(self, obs, state=None, info={}):
        patch_centres = torch.tensor(obs['patch_centres'], dtype=torch.float32)
        speaker_info = torch.tensor(obs['speaker_info'], dtype=torch.float32)
        
        patch_centres = patch_centres.view(patch_centres.size(0), -1)
        speaker_info = speaker_info.view(speaker_info.size(0), -1)

        video_frame = torch.tensor(obs['video_frames'], dtype=torch.float32).permute(0, 3, 1, 2)
        dynamics_frame = torch.tensor(obs['dynamics_frames'], dtype=torch.float32).permute(0, 3, 1, 2)
        patches_frame = torch.tensor(obs['patches_frames'], dtype=torch.float32).unsqueeze(1)

        # pass through respective networks
        patch_centres_out = self.patch_centres_net(patch_centres)
        speaker_info_out = self.speaker_info_net(speaker_info)

        video_frame_out = self.rgb_frame_net(video_frame)
        dynamics_frame_out = self.rgb_frame_net(dynamics_frame)
        patches_frame_out = self.bitmap_frame_net(patches_frame)

        # combine outputs
        combined = torch.cat([patch_centres_out, speaker_info_out, video_frame_out, dynamics_frame_out, patches_frame_out], dim=1)

        logits = self.combined_net(combined)

        return logits, state

In [11]:
class CriticNet(nn.Module):
    def __init__(self, observation_space):
        super().__init__()

        self.num_patches = observation_space['patch_centres'].shape[0]

        self.rgb_frame_net = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=5, stride=2, padding=2),  # [3, 180, 320]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # [16, 45, 80]
            nn.Conv2d(16, 32, kernel_size=5, stride=2, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # [32, 11, 20]
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),  # [64, 6, 10]
            nn.Flatten(),
            nn.Linear(64 * 6 * 10, 64),
            nn.ReLU(inplace=True)
        )

        self.bitmap_frame_net = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=2, padding=2),  # [1, 180, 320]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # [16, 45, 80]
            nn.Conv2d(16, 32, kernel_size=5, stride=2, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # [32, 11, 20]
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),  # [64, 6, 10]
            nn.Flatten(),
            nn.Linear(64 * 6 * 10, 64),
            nn.ReLU(inplace=True)
        )

        # network for patch_centres
        self.patch_centres_net = nn.Sequential(
            nn.Linear(np.prod(observation_space['patch_centres'].shape), 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 64),
            nn.ReLU(inplace=True)
        )

        # network for speaker_info
        self.speaker_info_net = nn.Sequential(
            nn.Linear(np.prod(observation_space['speaker_info'].shape), 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 32),
            nn.ReLU(inplace=True)
        )
        
        # combining the outputs of all networks
        self.combined_net = nn.Sequential(
            nn.Linear(64 + 32 + 64 + 64 + 64, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 1)
        )

    def forward(self, obs, state=None, info={}):
        patch_centres = torch.tensor(obs['patch_centres'], dtype=torch.float32)
        speaker_info = torch.tensor(obs['speaker_info'], dtype=torch.float32)
        
        patch_centres = patch_centres.view(patch_centres.size(0), -1)
        speaker_info = speaker_info.view(speaker_info.size(0), -1)

        video_frame = torch.tensor(obs['video_frames'], dtype=torch.float32).permute(0, 3, 1, 2)
        dynamics_frame = torch.tensor(obs['dynamics_frames'], dtype=torch.float32).permute(0, 3, 1, 2)
        patches_frame = torch.tensor(obs['patches_frames'], dtype=torch.float32).unsqueeze(1)

        patch_centres_out = self.patch_centres_net(patch_centres)
        speaker_info_out = self.speaker_info_net(speaker_info)

        video_frame_out = self.rgb_frame_net(video_frame)
        dynamics_frame_out = self.rgb_frame_net(dynamics_frame)
        patches_frame_out = self.bitmap_frame_net(patches_frame)

        combined = torch.cat([patch_centres_out, speaker_info_out, video_frame_out, dynamics_frame_out, patches_frame_out], dim=1)
        state_value = self.combined_net(combined)

        return state_value

In [12]:
state_shape = markov_env.observation_space
action_shape = markov_env.action_space.n

actor_net = ActorNet(state_shape, action_shape)
critic_net = CriticNet(state_shape)

# using a single optimizer for actor and critic simplifies the training loop and is more computationally efficient (plus it guarantees some consistency across networks)
# BUT gradient updates in one network will influence the gradient updates in the other, and this might create unexpected problems...
combined_params = list(actor_net.parameters()) + list(critic_net.parameters())
optimizer = torch.optim.Adam(combined_params, lr=3e-4)

In [13]:
actor_net

ActorNet(
  (rgb_frame_net): Sequential(
    (0): Conv2d(3, 16, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(16, 32, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Flatten(start_dim=1, end_dim=-1)
    (9): Linear(in_features=3840, out_features=64, bias=True)
    (10): ReLU(inplace=True)
  )
  (bitmap_frame_net): Sequential(
    (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(16, 32, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kern

In [14]:
critic_net

CriticNet(
  (rgb_frame_net): Sequential(
    (0): Conv2d(3, 16, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(16, 32, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Flatten(start_dim=1, end_dim=-1)
    (9): Linear(in_features=3840, out_features=64, bias=True)
    (10): ReLU(inplace=True)
  )
  (bitmap_frame_net): Sequential(
    (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(16, 32, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(ker

## PPO

In [15]:
def dist_fn(logits: torch.Tensor):
    return torch.distributions.Categorical(logits=logits)

#! the instantiation of the policy might seem simpler, but there are a lot of hyperparameters to play around with in PPO, i just decided to leave them to their default values
policy = ts.policy.PPOPolicy(
    actor_net, 
    critic_net, 
    optimizer,
    dist_fn=dist_fn,
)

In [16]:
train_collector = ts.data.Collector(policy, train_envs, ts.data.VectorReplayBuffer(2000, num_train_envs))

test_collector = ts.data.Collector(policy, test_envs)

In [17]:
num_epochs = 20
num_steps_per_epoch = 1000

step_per_collect = 10
episode_per_test = 2
batch_size = 10

timestamp = datetime.now().strftime("%d%m%Y-%H%M%S")
log_path = os.path.join("logs", "ppo", "frames", f"video_{vid_filename}", f"subject_{target_subject}", timestamp)
writer = SummaryWriter(log_path)
logger = TensorboardLogger(writer)

## Training

In [18]:
# PPO is an on-policy method
result = ts.trainer.onpolicy_trainer(
    policy, 
    train_collector, 
    test_collector,
    repeat_per_collect=1,
    max_epoch=num_epochs,
    step_per_epoch=num_steps_per_epoch,
    step_per_collect=step_per_collect,
    episode_per_test=episode_per_test,
    batch_size=batch_size,
    logger=logger,
)

Epoch #1: 1001it [00:28, 34.94it/s, env_step=1000, len=105, loss=0.077, loss/clip=0.000, loss/ent=2.827, loss/vf=0.211, n/ep=0, n/st=10, rew=4.31]                          


Epoch #1: test_reward: 63.664297 ± 31.780220, best_reward: 63.664297 ± 31.780220 in #1


Epoch #2: 1001it [00:27, 36.37it/s, env_step=2000, len=200, loss=0.596, loss/clip=0.000, loss/ent=0.324, loss/vf=1.198, n/ep=0, n/st=10, rew=27.91]                          


Epoch #2: test_reward: 94.934364 ± 59.950585, best_reward: 94.934364 ± 59.950585 in #2


Epoch #3: 1001it [00:27, 36.96it/s, env_step=3000, len=200, loss=0.447, loss/clip=-0.000, loss/ent=0.015, loss/vf=0.894, n/ep=0, n/st=10, rew=33.69]                          


Epoch #3: test_reward: 94.934364 ± 39.869754, best_reward: 94.934364 ± 59.950585 in #2


Epoch #4: 1001it [00:29, 34.43it/s, env_step=4000, len=200, loss=0.430, loss/clip=-0.000, loss/ent=0.011, loss/vf=0.860, n/ep=0, n/st=10, rew=32.04]                          


Epoch #4: test_reward: 94.934364 ± 51.892358, best_reward: 94.934364 ± 59.950585 in #2


Epoch #5: 1001it [00:26, 37.44it/s, env_step=5000, len=200, loss=0.515, loss/clip=0.000, loss/ent=0.005, loss/vf=1.029, n/ep=0, n/st=10, rew=37.91]                          


Epoch #5: test_reward: 94.059359 ± 38.283670, best_reward: 94.934364 ± 59.950585 in #2


Epoch #6: 1001it [00:26, 37.79it/s, env_step=6000, len=200, loss=0.509, loss/clip=0.000, loss/ent=0.006, loss/vf=1.018, n/ep=0, n/st=10, rew=34.15]                          


Epoch #6: test_reward: 94.934364 ± 41.570922, best_reward: 94.934364 ± 59.950585 in #2


Epoch #7: 1001it [00:26, 37.94it/s, env_step=7000, len=200, loss=0.438, loss/clip=0.000, loss/ent=0.019, loss/vf=0.876, n/ep=0, n/st=10, rew=33.36]                           


Epoch #7: test_reward: 94.934364 ± 45.306064, best_reward: 94.934364 ± 59.950585 in #2


Epoch #8: 1001it [00:27, 36.88it/s, env_step=8000, len=200, loss=0.450, loss/clip=-0.000, loss/ent=0.012, loss/vf=0.899, n/ep=0, n/st=10, rew=38.93]                          


Epoch #8: test_reward: 93.965495 ± 48.604088, best_reward: 94.934364 ± 59.950585 in #2


Epoch #9: 1001it [01:13, 13.62it/s, env_step=9000, len=200, loss=0.516, loss/clip=0.000, loss/ent=0.005, loss/vf=1.031, n/ep=0, n/st=10, rew=37.65]                          


Epoch #9: test_reward: 94.934364 ± 50.165667, best_reward: 94.934364 ± 59.950585 in #2


Epoch #10: 1001it [00:26, 37.37it/s, env_step=10000, len=200, loss=0.560, loss/clip=-0.000, loss/ent=0.006, loss/vf=1.120, n/ep=0, n/st=10, rew=35.61]                          


Epoch #10: test_reward: 94.934364 ± 43.964500, best_reward: 94.934364 ± 59.950585 in #2


Epoch #11: 1001it [00:28, 35.65it/s, env_step=11000, len=200, loss=0.529, loss/clip=0.000, loss/ent=0.006, loss/vf=1.058, n/ep=0, n/st=10, rew=37.59]                          


Epoch #11: test_reward: 94.934364 ± 56.831889, best_reward: 94.934364 ± 59.950585 in #2


Epoch #12: 1001it [00:27, 35.81it/s, env_step=12000, len=200, loss=0.487, loss/clip=-0.000, loss/ent=0.007, loss/vf=0.975, n/ep=0, n/st=10, rew=33.66]                          


Epoch #12: test_reward: 94.934364 ± 35.765854, best_reward: 94.934364 ± 59.950585 in #2


Epoch #13: 1001it [00:27, 36.57it/s, env_step=13000, len=200, loss=0.526, loss/clip=0.000, loss/ent=0.007, loss/vf=1.052, n/ep=0, n/st=10, rew=36.26]                          


Epoch #13: test_reward: 94.934364 ± 45.340467, best_reward: 94.934364 ± 59.950585 in #2


Epoch #14: 1001it [00:28, 34.96it/s, env_step=14000, len=200, loss=0.478, loss/clip=0.000, loss/ent=0.001, loss/vf=0.957, n/ep=0, n/st=10, rew=40.69]                          


Epoch #14: test_reward: 94.934364 ± 43.665366, best_reward: 94.934364 ± 59.950585 in #2


Epoch #15: 1001it [00:27, 36.67it/s, env_step=15000, len=200, loss=0.507, loss/clip=0.000, loss/ent=0.002, loss/vf=1.015, n/ep=0, n/st=10, rew=37.59]                          


Epoch #15: test_reward: 94.934364 ± 39.266137, best_reward: 94.934364 ± 59.950585 in #2


Epoch #16: 1001it [00:27, 36.51it/s, env_step=16000, len=200, loss=0.453, loss/clip=0.000, loss/ent=0.002, loss/vf=0.906, n/ep=0, n/st=10, rew=37.30]                          


Epoch #16: test_reward: 94.934364 ± 50.176386, best_reward: 94.934364 ± 59.950585 in #2


Epoch #17: 1001it [00:28, 35.32it/s, env_step=17000, len=200, loss=0.464, loss/clip=-0.000, loss/ent=0.002, loss/vf=0.928, n/ep=0, n/st=10, rew=42.48]                          


Epoch #17: test_reward: 94.934364 ± 53.537204, best_reward: 94.934364 ± 59.950585 in #2


Epoch #18: 1001it [00:28, 34.95it/s, env_step=18000, len=200, loss=0.575, loss/clip=-0.000, loss/ent=0.002, loss/vf=1.150, n/ep=0, n/st=10, rew=34.52]                          


Epoch #18: test_reward: 94.934364 ± 48.543899, best_reward: 94.934364 ± 59.950585 in #2


Epoch #19: 1001it [00:27, 35.77it/s, env_step=19000, len=200, loss=0.532, loss/clip=0.000, loss/ent=0.000, loss/vf=1.064, n/ep=0, n/st=10, rew=35.88]                          


Epoch #19: test_reward: 94.934364 ± 50.417487, best_reward: 94.934364 ± 59.950585 in #2


Epoch #20: 1001it [00:27, 36.00it/s, env_step=20000, len=200, loss=0.554, loss/clip=-0.000, loss/ent=0.000, loss/vf=1.108, n/ep=0, n/st=10, rew=29.74]                          


Epoch #20: test_reward: 94.934364 ± 40.884668, best_reward: 94.934364 ± 59.950585 in #2


In [19]:
policy_path = os.path.join("weights", "ppo", "frames", f"video_{vid_filename}", f"subject_{target_subject}")

try:
    os.makedirs(policy_path)
except OSError:
    # the directory structure already exists => just move on
    # (this case is useful for easily re-running things without annoyances)
    pass

torch.save(policy.state_dict(), os.path.join(policy_path, "state_dict"))

In [20]:
result

{'duration': '756.48s',
 'train_time/model': '512.13s',
 'test_step': 22008,
 'test_episode': 42,
 'test_time': '157.38s',
 'test_speed': '139.84 step/s',
 'best_reward': 94.93436350932097,
 'best_result': '94.93 ± 59.95',
 'train_step': 20000,
 'train_episode': 40,
 'train_time/collector': '86.97s',
 'train_speed': '33.38 step/s'}

## Testing

In [21]:
# #! make sure you run all the code up to the instantiation of the models and optimizers before this cell
policy = ts.policy.PPOPolicy(
    actor_net, 
    critic_net, 
    optimizer,
    dist_fn=dist_fn,
)

policy.load_state_dict(torch.load(os.path.join(policy_path, "state_dict")))

<All keys matched successfully>

In [22]:
test_markov_env = FramesTestEnvironment(
    1,
    video_frames,
    dynamics_frames,
    patches_frames,
    patch_bounding_boxes_per_frame,
    patch_centres_per_frame,
    speaker_info_per_frame,
    foa_centres_per_frame,
    patch_weights_per_frame,
)

In [23]:
num_test_envs = 1

testing_envs = ts.env.DummyVectorEnv([lambda: test_markov_env for _ in range(num_test_envs)])

In [24]:
policy.eval()

collector = ts.data.Collector(policy, testing_envs)
collector.collect(n_episode=3)

{'n/ep': 3,
 'n/st': 1572,
 'rews': array([69., 69., 69.]),
 'lens': array([524, 524, 524]),
 'idxs': array([0, 0, 0]),
 'rew': 69.0,
 'len': 524.0,
 'rew_std': 0.0,
 'len_std': 0.0}

In [27]:
np.mean(collector.collect(n_episode=100, random=True)["rews"])

19.94

### TensorBoard visualisation

In [25]:
%load_ext tensorboard

In [26]:
%tensorboard --logdir logs/ppo/frames