# A first stab: DQN

DQN is a classical RL algorithm which should provide a nice baseline for further work.
Classical RL techniques woul probably not work very well without further feature engineering, because the current state space is quite large.

In [1]:
import tianshou as ts 
import torch
from torch import nn
import numpy as np

In [2]:
from markov_gaze_env import MarkovGazeEnv
from preprocess_utils import compute_frame_features, compute_foa_features

  from pkg_resources import resource_stream, resource_exists


## Data and environment initialisation

In [3]:
vid_filename = "012"
mat_filename = vid_filename + ".mat"
target_subject = 0

In [4]:
patch_bounding_boxes, patch_centres, speaker_info = compute_frame_features(
    vid_filename
)

foa_centres, patch_weights_per_frame = compute_foa_features(
    mat_filename, patch_centres
)
foa_centres_single_subject = [frame[target_subject] for frame in foa_centres]

In [5]:
env = MarkovGazeEnv(
    patch_bounding_boxes,
    patch_centres,
    speaker_info,
    foa_centres_single_subject,
    patch_weights_per_frame,
)

## DQN

First, let's construct the network.

The biggest headache comes from the observations: they're quite complex. So, we build multiple networks, each processing a part of an observation and combining their outputs in the end!

In [7]:
class Net(nn.Module):
    def __init__(self, observation_space, action_shape):
        super().__init__()

        self.num_patches = observation_space['patch_centres'].shape[0]

        # network for patch_centres
        self.patch_centres_net = nn.Sequential(
            nn.Linear(np.prod(observation_space['patch_centres'].shape), 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 64),
            nn.ReLU(inplace=True)
        )

        # network for patch_bounding_boxes
        self.patch_bboxes_net = nn.Sequential(
            nn.Linear(np.prod(observation_space['patch_bounding_boxes'].shape), 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 64),
            nn.ReLU(inplace=True)
        )

        # network for speaker_info
        self.speaker_info_net = nn.Sequential(
            nn.Linear(observation_space['speaker_info'].n, 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 32),
            nn.ReLU(inplace=True)
        )

        # combining the outputs of all networks
        self.combined_net = nn.Sequential(
            nn.Linear(64 + 64 + 32, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, np.prod(action_shape))
        )

    def forward(self, obs, state=None, info={}):
        patch_centres = torch.tensor(obs['patch_centres'], dtype=torch.float32)
        patch_bboxes = torch.tensor(obs['patch_bounding_boxes'], dtype=torch.float32)
        # speaker_info = torch.from_numpy(obs['speaker_info']).float()
        # speaker_info creates problems, so let's make sure everything works first
        speaker_info = torch.zeros(1, 4)


        patch_centres = patch_centres.view(patch_centres.size(0), -1)
        patch_bboxes = patch_bboxes.view(patch_bboxes.size(0), -1)
        speaker_info = torch.tensor(speaker_info, dtype=torch.float32)

        # pass through respective networks
        patch_centres_out = self.patch_centres_net(patch_centres)
        patch_bboxes_out = self.patch_bboxes_net(patch_bboxes)
        speaker_info_out = self.speaker_info_net(speaker_info)

        # combine outputs
        combined = torch.cat([patch_centres_out, patch_bboxes_out, speaker_info_out], dim=1)

        logits = self.combined_net(combined)

        return logits, state

In [8]:
state_shape = env.observation_space
action_shape = env.action_space.shape or env.action_space.n

net = Net(state_shape, action_shape)
optim = torch.optim.Adam(net.parameters(), lr=1e-3)

In [9]:
net

Net(
  (patch_centres_net): Sequential(
    (0): Linear(in_features=8, out_features=64, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=64, out_features=64, bias=True)
    (3): ReLU(inplace=True)
  )
  (patch_bboxes_net): Sequential(
    (0): Linear(in_features=16, out_features=64, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=64, out_features=64, bias=True)
    (3): ReLU(inplace=True)
  )
  (speaker_info_net): Sequential(
    (0): Linear(in_features=4, out_features=32, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=32, out_features=32, bias=True)
    (3): ReLU(inplace=True)
  )
  (combined_net): Sequential(
    (0): Linear(in_features=160, out_features=128, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=128, out_features=128, bias=True)
    (3): ReLU(inplace=True)
    (4): Linear(in_features=128, out_features=4, bias=True)
  )
)

## Setting up DQN

In [10]:
# set up the policy
policy = ts.policy.DQNPolicy(
    model=net, 
    optim=optim, 
    discount_factor=0.99,
    estimation_step=3,
    target_update_freq=50
)

In [11]:
# set up the collectors
train_collector = ts.data.Collector(policy, env, ts.data.ReplayBuffer(size=100))

test_collector = ts.data.Collector(policy, env, ts.data.ReplayBuffer(size=100))



## Training

In [12]:
num_epochs = 100
num_steps_per_epoch = 10
num_steps_before_train = 100
num_test_steps = 10
# the number of frames
num_steps_per_epoch = len(patch_centres)  
num_epochs = 1 # one epoch => once we finish the frames of a video, we're done
num_steps_before_train = 0  # start training immediately
step_per_collect=10

def stop_fn(result):
    return result["n/ep"] >= num_steps_per_epoch

In [13]:
result = ts.trainer.offpolicy_trainer(
    policy, 
    train_collector, 
    test_collector,
    max_epoch=num_epochs,
    step_per_epoch=num_steps_per_epoch,
    step_per_collect=step_per_collect,
    episode_per_test=num_test_steps,
    batch_size=64,
    stop_fn = stop_fn,
    test_in_train=False
)

  speaker_info = torch.tensor(speaker_info, dtype=torch.float32)


Type of patch_centres_out: <class 'torch.Tensor'>
Type of patch_bboxes_out: <class 'torch.Tensor'>
Type of speaker_info_out: <class 'torch.Tensor'>
Type of patch_centres_out: <class 'torch.Tensor'>
Type of patch_bboxes_out: <class 'torch.Tensor'>
Type of speaker_info_out: <class 'torch.Tensor'>
Type of patch_centres_out: <class 'torch.Tensor'>
Type of patch_bboxes_out: <class 'torch.Tensor'>
Type of speaker_info_out: <class 'torch.Tensor'>
Type of patch_centres_out: <class 'torch.Tensor'>
Type of patch_bboxes_out: <class 'torch.Tensor'>
Type of speaker_info_out: <class 'torch.Tensor'>
Type of patch_centres_out: <class 'torch.Tensor'>
Type of patch_bboxes_out: <class 'torch.Tensor'>
Type of speaker_info_out: <class 'torch.Tensor'>
Type of patch_centres_out: <class 'torch.Tensor'>
Type of patch_bboxes_out: <class 'torch.Tensor'>
Type of speaker_info_out: <class 'torch.Tensor'>
Type of patch_centres_out: <class 'torch.Tensor'>
Type of patch_bboxes_out: <class 'torch.Tensor'>
Type of speak

IndexError: list index out of range