# DQN again, but with more information

Before abandoning DQN for a more full-featured approach, I want to try to apply it to a more informative environment: one that feeds the dynamic and the video frames, along with the patch data.

In [1]:
import tianshou as ts 
from tianshou.utils import TensorboardLogger

import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter

import numpy as np

import os
from datetime import datetime

In [2]:
from utils_preprocess import compute_frame_features, compute_foa_features
from utils_data import get_frames_by_type, get_feature_frames

from env_frames import FramesEnvironment
from env_frames_test import FramesTestEnvironment

  from pkg_resources import resource_stream, resource_exists


In [3]:
vid_filename = "012"
mat_filename = vid_filename + ".mat"
target_subject = 0

In [4]:
video_frames = get_frames_by_type("frames", vid_filename)
dynamics_frames = get_frames_by_type("dynamic", vid_filename)
patches_frames, _ = get_feature_frames(vid_filename)

In [5]:
# need to normalise frames before feeding them into the model
for i in range(len(video_frames)):
    # broadcasting will do the rest...
    video_frames[i] = video_frames[i] / 255.0
    dynamics_frames[i] = dynamics_frames[i] / 255.0
    patches_frames[i] = patches_frames[i] / 255.0

In [6]:
patch_bounding_boxes_per_frame, patch_centres_per_frame, speaker_info_per_frame = compute_frame_features(
    vid_filename
)

foa_centres_per_frame_per_subject, patch_weights_per_frame = compute_foa_features(
    mat_filename, patch_centres_per_frame
)
foa_centres_per_frame = [frame[target_subject] for frame in foa_centres_per_frame_per_subject]

In [7]:
markov_env = FramesEnvironment(
    1,
    video_frames,
    dynamics_frames,
    patches_frames,
    patch_bounding_boxes_per_frame,
    patch_centres_per_frame,
    speaker_info_per_frame,
    foa_centres_per_frame,
    patch_weights_per_frame,
)

In [8]:
# env.observation_space.sample(), env.action_space.sample()

In [9]:
num_train_envs = 5
num_test_envs = 10

train_envs = ts.env.DummyVectorEnv([lambda: markov_env for _ in range(num_train_envs)])
test_envs = ts.env.DummyVectorEnv([lambda: markov_env for _ in range(num_test_envs)])

## DQN

In [10]:
class Net(nn.Module):
    def __init__(self, observation_space, action_shape):
        super().__init__()

        self.num_patches = observation_space['patch_centres'].shape[0]

        self.rgb_frame_net = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=5, stride=2, padding=2),  # [3, 180, 320]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # [16, 45, 80]
            nn.Conv2d(16, 32, kernel_size=5, stride=2, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # [32, 11, 20]
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),  # [64, 6, 10]
            nn.Flatten(),
            nn.Linear(64 * 6 * 10, 64),
            nn.ReLU(inplace=True)
        )

        self.bitmap_frame_net = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=2, padding=2),  # [1, 180, 320]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # [16, 45, 80]
            nn.Conv2d(16, 32, kernel_size=5, stride=2, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # [32, 11, 20]
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),  # [64, 6, 10]
            nn.Flatten(),
            nn.Linear(64 * 6 * 10, 64),
            nn.ReLU(inplace=True)
        )

        # network for patch_centres
        self.patch_centres_net = nn.Sequential(
            nn.Linear(np.prod(observation_space['patch_centres'].shape), 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 64),
            nn.ReLU(inplace=True)
        )

        # network for speaker_info
        self.speaker_info_net = nn.Sequential(
            nn.Linear(np.prod(observation_space['speaker_info'].shape), 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 32),
            nn.ReLU(inplace=True)
        )

        # combining the outputs of all networks
        self.combined_net = nn.Sequential(
            nn.Linear(64 + 32 + 64 + 64 + 64, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, np.prod(action_shape))
        )

    def forward(self, obs, state=None, info={}):
        patch_centres = torch.tensor(obs['patch_centres'], dtype=torch.float32)
        speaker_info = torch.tensor(obs['speaker_info'], dtype=torch.float32)
        
        patch_centres = patch_centres.view(patch_centres.size(0), -1)
        speaker_info = speaker_info.view(speaker_info.size(0), -1)

        video_frame = torch.tensor(obs['video_frames'], dtype=torch.float32).permute(0, 3, 1, 2)
        dynamics_frame = torch.tensor(obs['dynamics_frames'], dtype=torch.float32).permute(0, 3, 1, 2)
        patches_frame = torch.tensor(obs['patches_frames'], dtype=torch.float32).unsqueeze(1)

        # pass through respective networks
        patch_centres_out = self.patch_centres_net(patch_centres)
        speaker_info_out = self.speaker_info_net(speaker_info)

        video_frame_out = self.rgb_frame_net(video_frame)
        dynamics_frame_out = self.rgb_frame_net(dynamics_frame)
        patches_frame_out = self.bitmap_frame_net(patches_frame)

        # combine outputs
        combined = torch.cat([patch_centres_out, speaker_info_out, video_frame_out, dynamics_frame_out, patches_frame_out], dim=1)

        logits = self.combined_net(combined)

        return logits, state

In [11]:
state_shape = markov_env.observation_space
action_shape = markov_env.action_space.n

net = Net(state_shape, action_shape)
optim = torch.optim.Adam(net.parameters(), lr=1e-3)

In [12]:
net

Net(
  (rgb_frame_net): Sequential(
    (0): Conv2d(3, 16, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(16, 32, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Flatten(start_dim=1, end_dim=-1)
    (9): Linear(in_features=3840, out_features=64, bias=True)
    (10): ReLU(inplace=True)
  )
  (bitmap_frame_net): Sequential(
    (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(16, 32, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_si

## Setting up DQN

Unlike in the simpler DQN case, I need to significanly reduce the memory footprint of my code to avoid crashing the jupyter kernel and running out of RAM.

Naturally, if you're on a more powerful machine, feel free to increase the various memory-intensive parameters.

(I could also consider a streaming approach to loading the various frames, seeing as they're the main culprits behind the kernel crashes, but that would require quite the engineering effort and would be quite out of scope.)

In [13]:
policy = ts.policy.DQNPolicy(
    model=net, 
    optim=optim, 
    discount_factor=0.99,
    estimation_step=1,
    target_update_freq=10
)

In [14]:
train_collector = ts.data.Collector(policy, train_envs, ts.data.VectorReplayBuffer(100, num_train_envs))

test_collector = ts.data.Collector(policy, test_envs)

In [15]:
num_epochs = 10
num_steps_per_epoch = 50
step_per_collect = 5 
episode_per_test = 2
batch_size = 10 

timestamp = datetime.now().strftime("%d%m%Y-%H%M%S")
log_path = os.path.join("logs", "dqn", "frames", f"video_{vid_filename}", f"subject_{target_subject}", timestamp)
writer = SummaryWriter(log_path)
logger = TensorboardLogger(writer)

## Training

In [16]:
result = ts.trainer.offpolicy_trainer(
    policy, 
    train_collector, 
    test_collector,
    max_epoch=num_epochs,
    step_per_epoch=num_steps_per_epoch,
    step_per_collect=step_per_collect,
    episode_per_test=episode_per_test,
    batch_size=batch_size,
    logger=logger,
)

Epoch #1: 51it [00:06,  7.69it/s, env_step=50, len=0, loss=0.592, n/ep=0, n/st=5, rew=0.00]                        


Epoch #1: test_reward: 55.022509 ± 1.961869, best_reward: 55.022509 ± 3.044663 in #0


Epoch #2: 51it [00:06,  7.78it/s, env_step=100, len=0, loss=0.529, n/ep=0, n/st=5, rew=0.00]                        


Epoch #2: test_reward: 55.022509 ± 5.037302, best_reward: 55.022509 ± 3.044663 in #0


Epoch #3: 51it [00:06,  7.77it/s, env_step=150, len=0, loss=0.449, n/ep=0, n/st=5, rew=0.00]                        


Epoch #3: test_reward: 55.022509 ± 8.124387, best_reward: 55.022509 ± 3.044663 in #0


Epoch #4: 51it [00:06,  7.76it/s, env_step=200, len=0, loss=0.427, n/ep=0, n/st=5, rew=0.00]                        


Epoch #4: test_reward: 55.022509 ± 3.746921, best_reward: 55.022509 ± 3.044663 in #0


Epoch #5: 51it [00:06,  7.79it/s, env_step=250, len=0, loss=0.352, n/ep=0, n/st=5, rew=0.00]                        


Epoch #5: test_reward: 55.022509 ± 6.822969, best_reward: 55.022509 ± 3.044663 in #0


Epoch #6: 51it [00:06,  7.60it/s, env_step=300, len=0, loss=0.334, n/ep=0, n/st=5, rew=0.00]                        


Epoch #6: test_reward: 55.022509 ± 5.836692, best_reward: 55.022509 ± 3.044663 in #0


Epoch #7: 51it [00:06,  7.41it/s, env_step=350, len=0, loss=0.389, n/ep=0, n/st=5, rew=0.00]                        


Epoch #7: test_reward: 55.022509 ± 0.484511, best_reward: 55.022509 ± 3.044663 in #0


Epoch #8: 51it [00:07,  7.28it/s, env_step=400, len=0, loss=0.397, n/ep=0, n/st=5, rew=0.00]                        


Epoch #8: test_reward: 55.022509 ± 0.274288, best_reward: 55.022509 ± 3.044663 in #0


Epoch #9: 51it [00:06,  7.33it/s, env_step=450, len=0, loss=0.340, n/ep=0, n/st=5, rew=0.00]                        


Epoch #9: test_reward: 55.022509 ± 2.407658, best_reward: 55.022509 ± 3.044663 in #0


Epoch #10: 51it [00:07,  7.07it/s, env_step=500, len=0, loss=0.280, n/ep=0, n/st=5, rew=0.00]                        


Epoch #10: test_reward: 55.022509 ± 9.313679, best_reward: 55.022509 ± 3.044663 in #0


In [17]:
result

{'duration': '93.75s',
 'train_time/model': '65.78s',
 'test_step': 4356,
 'test_episode': 22,
 'test_time': '25.95s',
 'test_speed': '167.89 step/s',
 'best_reward': 55.022508984555316,
 'best_result': '55.02 ± 3.04',
 'train_step': 500,
 'train_episode': 0,
 'train_time/collector': '2.02s',
 'train_speed': '7.37 step/s'}

## Testing

In [18]:
test_markov_env = FramesTestEnvironment(
    1,
    video_frames,
    dynamics_frames,
    patches_frames,
    patch_bounding_boxes_per_frame,
    patch_centres_per_frame,
    speaker_info_per_frame,
    foa_centres_per_frame,
    patch_weights_per_frame,
)

In [21]:
num_test_envs = 1

testing_envs = ts.env.DummyVectorEnv([lambda: test_markov_env for _ in range(num_test_envs)])

In [22]:
policy.eval()
policy.set_eps(0.05)

collector = ts.data.Collector(policy, testing_envs)
# should be the same values 3 times (if not, there's a problem)
collector.collect(n_episode=3)

KeyboardInterrupt: 

## TensorBoard visualisation

In [None]:
%load_ext tensorboard

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [None]:
%tensorboard --logdir logs/dqn

Reusing TensorBoard on port 6006 (pid 4585), started 0:02:03 ago. (Use '!kill 4585' to kill it.)