# DQN again, but with visual information

Before abandoning DQN for a more full-featured approach, I want to try to apply it to a more informative environment: one that feeds the dynamic and the video frames, along with the patch data.

In [1]:
import tianshou as ts 
from tianshou.utils import TensorboardLogger

import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter

import numpy as np

import os
from datetime import datetime

In [2]:
from utils_preprocess import compute_frame_features, compute_foa_features
from utils_data import get_frames_by_type, get_feature_frames

from env_frames import FramesEnvironment
from env_frames_test import FramesTestEnvironment

  from pkg_resources import resource_stream, resource_exists


## Data and environment initialisation

In [3]:
vid_filename = "038"

mat_filename = vid_filename + ".mat"
target_subject = 0

In [4]:
video_frames = get_frames_by_type("frames", vid_filename)
dynamics_frames = get_frames_by_type("dynamic", vid_filename)
patches_frames, _ = get_feature_frames(vid_filename)

In [5]:
# need to normalise frames before feeding them into the model
for i in range(len(video_frames)):
    # broadcasting will do the rest...
    video_frames[i] = video_frames[i] / 255.0
    dynamics_frames[i] = dynamics_frames[i] / 255.0
    patches_frames[i] = patches_frames[i] / 255.0

In [6]:
patch_bounding_boxes_per_frame, patch_centres_per_frame, speaker_info_per_frame = compute_frame_features(
    vid_filename
)

foa_centres_per_frame_per_subject, patch_weights_per_frame = compute_foa_features(
    mat_filename, patch_centres_per_frame
)

foa_centres_per_frame = [frame[target_subject] for frame in foa_centres_per_frame_per_subject]

In [7]:
markov_env = FramesEnvironment(
    1,
    video_frames,
    dynamics_frames,
    patches_frames,
    patch_bounding_boxes_per_frame,
    patch_centres_per_frame,
    speaker_info_per_frame,
    foa_centres_per_frame,
    patch_weights_per_frame,
)

In [8]:
# env.observation_space.sample(), env.action_space.sample()

In [9]:
num_train_envs = 5
num_test_envs = 5

train_envs = ts.env.DummyVectorEnv([lambda: markov_env for _ in range(num_train_envs)])
test_envs = ts.env.DummyVectorEnv([lambda: markov_env for _ in range(num_test_envs)])

## Network construction

In [10]:
class Net(nn.Module):
    def __init__(self, observation_space, action_shape):
        super().__init__()

        self.num_patches = observation_space['patch_centres'].shape[0]

        self.rgb_frame_net = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=5, stride=2, padding=2),  # [3, 180, 320]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # [16, 45, 80]
            nn.Conv2d(16, 32, kernel_size=5, stride=2, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # [32, 11, 20]
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),  # [64, 6, 10]
            nn.Flatten(),
            nn.Linear(64 * 6 * 10, 64),
            nn.ReLU(inplace=True)
        )

        self.bitmap_frame_net = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=2, padding=2),  # [1, 180, 320]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # [16, 45, 80]
            nn.Conv2d(16, 32, kernel_size=5, stride=2, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # [32, 11, 20]
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),  # [64, 6, 10]
            nn.Flatten(),
            nn.Linear(64 * 6 * 10, 64),
            nn.ReLU(inplace=True)
        )

        # network for patch_centres
        self.patch_centres_net = nn.Sequential(
            nn.Linear(np.prod(observation_space['patch_centres'].shape), 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 64),
            nn.ReLU(inplace=True)
        )

        # network for speaker_info
        self.speaker_info_net = nn.Sequential(
            nn.Linear(np.prod(observation_space['speaker_info'].shape), 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 32),
            nn.ReLU(inplace=True)
        )

        # combining the outputs of all networks
        self.combined_net = nn.Sequential(
            nn.Linear(64 + 32 + 64 + 64 + 64, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, np.prod(action_shape))
        )

    def forward(self, obs, state=None, info={}):
        patch_centres = torch.tensor(obs['patch_centres'], dtype=torch.float32)
        speaker_info = torch.tensor(obs['speaker_info'], dtype=torch.float32)
        
        patch_centres = patch_centres.view(patch_centres.size(0), -1)
        speaker_info = speaker_info.view(speaker_info.size(0), -1)

        video_frame = torch.tensor(obs['video_frames'], dtype=torch.float32).permute(0, 3, 1, 2)
        dynamics_frame = torch.tensor(obs['dynamics_frames'], dtype=torch.float32).permute(0, 3, 1, 2)
        patches_frame = torch.tensor(obs['patches_frames'], dtype=torch.float32).unsqueeze(1)

        # pass through respective networks
        patch_centres_out = self.patch_centres_net(patch_centres)
        speaker_info_out = self.speaker_info_net(speaker_info)

        video_frame_out = self.rgb_frame_net(video_frame)
        dynamics_frame_out = self.rgb_frame_net(dynamics_frame)
        patches_frame_out = self.bitmap_frame_net(patches_frame)

        # combine outputs
        combined = torch.cat([patch_centres_out, speaker_info_out, video_frame_out, dynamics_frame_out, patches_frame_out], dim=1)

        logits = self.combined_net(combined)

        return logits, state

In [11]:
state_shape = markov_env.observation_space
action_shape = markov_env.action_space.n

net = Net(state_shape, action_shape)
optim = torch.optim.Adam(net.parameters(), lr=1e-3)

In [17]:
net

Net(
  (rgb_frame_net): Sequential(
    (0): Conv2d(3, 16, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(16, 32, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Flatten(start_dim=1, end_dim=-1)
    (9): Linear(in_features=3840, out_features=64, bias=True)
    (10): ReLU(inplace=True)
  )
  (bitmap_frame_net): Sequential(
    (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(16, 32, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_si

## DQN

In [13]:
policy = ts.policy.DQNPolicy(
    model=net, 
    optim=optim, 
    discount_factor=0.99,
    target_update_freq=50
)

In [14]:
train_collector = ts.data.Collector(policy, train_envs, ts.data.VectorReplayBuffer(2000, num_train_envs))

test_collector = ts.data.Collector(policy, test_envs)

In [15]:
num_epochs = 20
num_steps_per_epoch = 1000

step_per_collect = 10
episode_per_test = 2
batch_size = 10 

timestamp = datetime.now().strftime("%d%m%Y-%H%M%S")
log_path = os.path.join("logs", "dqn", "frames", f"video_{vid_filename}", f"subject_{target_subject}", timestamp)
writer = SummaryWriter(log_path)
logger = TensorboardLogger(writer)

## Training

In [16]:
result = ts.trainer.offpolicy_trainer(
    policy, 
    train_collector, 
    test_collector,
    max_epoch=num_epochs,
    step_per_epoch=num_steps_per_epoch,
    step_per_collect=step_per_collect,
    episode_per_test=episode_per_test,
    batch_size=batch_size,
    logger=logger,
)

Epoch #1: 1001it [02:25,  6.89it/s, env_step=1000, len=105, loss=0.052, n/ep=0, n/st=10, rew=2.00]                          


Epoch #1: test_reward: 10.500732 ± 2.499268, best_reward: 10.500732 ± 2.499268 in #1


Epoch #2: 1001it [02:39,  6.27it/s, env_step=2000, len=200, loss=0.064, n/ep=0, n/st=10, rew=8.50]                          


Epoch #2: test_reward: 24.505371 ± 11.499268, best_reward: 24.505371 ± 11.499268 in #2


Epoch #3: 1001it [03:16,  5.09it/s, env_step=3000, len=200, loss=0.045, n/ep=0, n/st=10, rew=8.00]                          


Epoch #3: test_reward: 28.007568 ± 10.002930, best_reward: 28.007568 ± 10.002930 in #3


Epoch #4: 1001it [03:10,  5.24it/s, env_step=4000, len=200, loss=0.078, n/ep=0, n/st=10, rew=7.00]                          


Epoch #4: test_reward: 28.007568 ± 18.002197, best_reward: 28.007568 ± 10.002930 in #3


Epoch #5: 1001it [03:25,  4.88it/s, env_step=5000, len=200, loss=0.059, n/ep=0, n/st=10, rew=10.50]                          


Epoch #5: test_reward: 28.007568 ± 20.002930, best_reward: 28.007568 ± 10.002930 in #3


Epoch #6: 1001it [03:12,  5.20it/s, env_step=6000, len=200, loss=0.065, n/ep=0, n/st=10, rew=8.00]                          


Epoch #6: test_reward: 28.007568 ± 13.002930, best_reward: 28.007568 ± 10.002930 in #3


Epoch #7: 1001it [03:02,  5.48it/s, env_step=7000, len=200, loss=0.056, n/ep=0, n/st=10, rew=9.50]                          


Epoch #7: test_reward: 26.507202 ± 14.506470, best_reward: 28.007568 ± 10.002930 in #3


Epoch #8: 1001it [03:20,  5.00it/s, env_step=8000, len=200, loss=0.061, n/ep=0, n/st=10, rew=12.00]                          


Epoch #8: test_reward: 28.007568 ± 14.001465, best_reward: 28.007568 ± 10.002930 in #3


Epoch #9: 1001it [03:20,  4.99it/s, env_step=9000, len=200, loss=0.043, n/ep=0, n/st=10, rew=11.50]                          


Epoch #9: test_reward: 28.007568 ± 18.006104, best_reward: 28.007568 ± 10.002930 in #3


Epoch #10: 1001it [03:10,  5.26it/s, env_step=10000, len=200, loss=0.111, n/ep=0, n/st=10, rew=10.50]                          


Epoch #10: test_reward: 28.007568 ± 14.005371, best_reward: 28.007568 ± 10.002930 in #3


Epoch #11: 1001it [03:10,  5.27it/s, env_step=11000, len=200, loss=0.054, n/ep=0, n/st=10, rew=7.00]                          


Epoch #11: test_reward: 28.007568 ± 16.002930, best_reward: 28.007568 ± 10.002930 in #3


Epoch #12: 1001it [03:02,  5.47it/s, env_step=12000, len=200, loss=0.074, n/ep=0, n/st=10, rew=9.51]                          


Epoch #12: test_reward: 1.000000 ± 0.000000, best_reward: 28.007568 ± 10.002930 in #3


Epoch #13: 1001it [03:01,  5.52it/s, env_step=13000, len=200, loss=0.036, n/ep=0, n/st=10, rew=5.50]                          


Epoch #13: test_reward: 28.007568 ± 14.005371, best_reward: 28.007568 ± 10.002930 in #3


Epoch #14: 1001it [02:55,  5.70it/s, env_step=14000, len=200, loss=0.076, n/ep=0, n/st=10, rew=8.50]                          


Epoch #14: test_reward: 28.007568 ± 12.000732, best_reward: 28.007568 ± 10.002930 in #3


Epoch #15: 1001it [02:53,  5.76it/s, env_step=15000, len=200, loss=0.088, n/ep=0, n/st=10, rew=6.50]                          


Epoch #15: test_reward: 28.007568 ± 14.006104, best_reward: 28.007568 ± 10.002930 in #3


Epoch #16: 1001it [03:04,  5.44it/s, env_step=16000, len=200, loss=0.104, n/ep=0, n/st=10, rew=9.50]                          


Epoch #16: test_reward: 1.000000 ± 0.000000, best_reward: 28.007568 ± 10.002930 in #3


Epoch #17: 1001it [03:04,  5.43it/s, env_step=17000, len=200, loss=0.090, n/ep=0, n/st=10, rew=6.50]                          


Epoch #17: test_reward: 28.007568 ± 16.001465, best_reward: 28.007568 ± 10.002930 in #3


Epoch #18: 1001it [03:05,  5.39it/s, env_step=18000, len=200, loss=0.072, n/ep=0, n/st=10, rew=9.51]                          


Epoch #18: test_reward: 28.007568 ± 16.001465, best_reward: 28.007568 ± 10.002930 in #3


Epoch #19: 1001it [03:03,  5.46it/s, env_step=19000, len=200, loss=0.085, n/ep=0, n/st=10, rew=7.00]                          


Epoch #19: test_reward: 28.007568 ± 11.002197, best_reward: 28.007568 ± 10.002930 in #3


Epoch #20: 1001it [03:06,  5.38it/s, env_step=20000, len=200, loss=0.099, n/ep=0, n/st=10, rew=7.50]                          


Epoch #20: test_reward: 28.007568 ± 9.001465, best_reward: 28.007568 ± 10.002930 in #3


In [12]:
policy_path = os.path.join("weights", "dqn", "frames", f"video_{vid_filename}", f"subject_{target_subject}")

try:
    os.makedirs(policy_path)
except OSError:
    # the directory structure already exists => just move on
    # (this case is useful for easily re-running things without annoyances)
    pass

torch.save(policy.state_dict(), os.path.join(policy_path, "state_dict"))

In [18]:
result

{'duration': '3869.17s',
 'train_time/model': '3457.95s',
 'test_step': 22008,
 'test_episode': 42,
 'test_time': '177.48s',
 'test_speed': '124.00 step/s',
 'best_reward': 28.007568359476863,
 'best_result': '28.01 ± 10.00',
 'train_step': 20000,
 'train_episode': 40,
 'train_time/collector': '233.75s',
 'train_speed': '5.42 step/s'}

## Testing

In [13]:
# #! make sure you run all the code up to the instantiation of the models and optimizers before this cell
policy = ts.policy.DQNPolicy(
    model=net, 
    optim=optim, 
    discount_factor=0.99,
    target_update_freq=50
)

policy.load_state_dict(torch.load(os.path.join(policy_path, "state_dict")))

<All keys matched successfully>

In [14]:
test_markov_env = FramesTestEnvironment(
    1,
    video_frames,
    dynamics_frames,
    patches_frames,
    patch_bounding_boxes_per_frame,
    patch_centres_per_frame,
    speaker_info_per_frame,
    foa_centres_per_frame,
    patch_weights_per_frame,
)

In [15]:
num_test_envs = 1

testing_envs = ts.env.DummyVectorEnv([lambda: test_markov_env for _ in range(num_test_envs)])

In [16]:
policy.eval()

collector = ts.data.Collector(policy, testing_envs)
collector.collect(n_episode=3)

{'n/ep': 3,
 'n/st': 1572,
 'rews': array([28., 28., 28.]),
 'lens': array([524, 524, 524]),
 'idxs': array([0, 0, 0]),
 'rew': 28.0,
 'len': 524.0,
 'rew_std': 0.0,
 'len_std': 0.0}

In [21]:
np.mean(collector.collect(n_episode=100, random=True)["rews"])

19.23

## TensorBoard visualisation

In [None]:
%load_ext tensorboard

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [None]:
%tensorboard --logdir logs/dqn/frames

Reusing TensorBoard on port 6006 (pid 4585), started 0:02:03 ago. (Use '!kill 4585' to kill it.)