# A first stab: DQN

[DQN](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf) is a classic deep RL algorithm, which should make it a good baseline for further work.

Classical tabular RL techniques would probably not work well here without further feature engineering, because the state space is quite large.
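To put a rough number on "quite large": a tabular method stores one Q-value per (state, action) pair. Assuming, purely for illustration, that each of the 8 patch-centre coordinates the network later receives (the `in_features=8` layer printed below) were discretised into 10 bins, and ignoring the other observation components entirely, the table would already need hundreds of millions of entries. The helper name is hypothetical.

```python
def tabular_q_entries(num_features: int, bins_per_feature: int, num_actions: int) -> int:
    """Size of a Q-table over a discretised observation.

    Every feature is split into `bins_per_feature` bins, so there are
    bins_per_feature ** num_features distinct states, and the table holds
    one Q-value per (state, action) pair.
    """
    num_states = bins_per_feature ** num_features
    return num_states * num_actions


# hypothetical discretisation: 8 patch-centre coordinates, 10 bins each, 4 actions
print(tabular_q_entries(8, 10, 4))  # 400000000
```

A function approximator such as a neural network sidesteps this by sharing parameters across states instead of enumerating them.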

In [1]:
import tianshou as ts 
import torch
from torch import nn
import numpy as np

In [2]:
from markov_gaze_env import MarkovGazeEnv
from preprocess_utils import compute_frame_features, compute_foa_features



## Data and environment initialisation

In [3]:
vid_filename = "012"
mat_filename = vid_filename + ".mat"
target_subject = 0

In [4]:
patch_bounding_boxes, patch_centres, speaker_info = compute_frame_features(
    vid_filename
)

foa_centres, patch_weights_per_frame = compute_foa_features(
    mat_filename, patch_centres
)
foa_centres_single_subject = [frame[target_subject] for frame in foa_centres]

In [6]:
env = MarkovGazeEnv(
    patch_bounding_boxes,
    patch_centres,
    speaker_info,
    foa_centres_single_subject,
    patch_weights_per_frame,
)

In [8]:
# env.observation_space.sample(), env.action_space.sample()

For efficiency, it's a good idea to set up some vectorized environments.

In [9]:
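num_train_envs = 5
num_test_envs = 10

def make_env():
    # a fresh instance per worker; `lambda: env` would make all workers share one env
    return MarkovGazeEnv(patch_bounding_boxes, patch_centres, speaker_info,
                         foa_centres_single_subject, patch_weights_per_frame)

train_envs = ts.env.DummyVectorEnv([make_env for _ in range(num_train_envs)])
test_envs = ts.env.DummyVectorEnv([make_env for _ in range(num_test_envs)])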
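`DummyVectorEnv` creates each sub-environment by calling the corresponding function in the list, so those functions should build a new environment on every call. The pure-Python sketch below, with a hypothetical `ToyEnv` standing in for `MarkovGazeEnv`, shows the difference between returning one pre-built object and calling a factory.

```python
class ToyEnv:
    """Hypothetical stand-in for an environment with internal state."""

    def __init__(self):
        self.t = 0

    def step(self):
        self.t += 1
        return self.t


shared = ToyEnv()
shared_fns = [lambda: shared for _ in range(3)]  # every call returns the same object
factory_fns = [ToyEnv for _ in range(3)]         # every call builds a new object

shared_envs = [fn() for fn in shared_fns]
fresh_envs = [fn() for fn in factory_fns]

# step each "environment" exactly once
for toy in shared_envs:
    toy.step()
for toy in fresh_envs:
    toy.step()

print([toy.t for toy in shared_envs])  # [3, 3, 3]: three "envs", one shared counter
print([toy.t for toy in fresh_envs])   # [1, 1, 1]: each env advanced once
```

With a shared instance, stepping one worker silently advances (or resets) all of them, which corrupts every episode collected in parallel.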

## The Q-network

First, let's construct the network.

The main difficulty is the observation: it is a dictionary with several components (`patch_centres`, `patch_bounding_boxes` and `speaker_info`). We therefore build one small sub-network per component and combine their outputs in a final network that produces one Q-value per action.

In [10]:
class Net(nn.Module):
    def __init__(self, observation_space, action_shape):
        super().__init__()

        self.num_patches = observation_space['patch_centres'].shape[0]

        # network for patch_centres
        self.patch_centres_net = nn.Sequential(
            nn.Linear(np.prod(observation_space['patch_centres'].shape), 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 64),
            nn.ReLU(inplace=True)
        )

        # network for patch_bounding_boxes
        self.patch_bboxes_net = nn.Sequential(
            nn.Linear(np.prod(observation_space['patch_bounding_boxes'].shape), 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 64),
            nn.ReLU(inplace=True)
        )

        # network for speaker_info
        self.speaker_info_net = nn.Sequential(
            nn.Linear(np.prod(observation_space['speaker_info'].shape), 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 32),
            nn.ReLU(inplace=True)
        )

        # combining the outputs of all networks
        self.combined_net = nn.Sequential(
            nn.Linear(64 + 64 + 32, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, np.prod(action_shape))
        )

    def forward(self, obs, state=None, info=None):
        # Tianshou passes a batch of dict observations; convert each component to a
        # float tensor (as_tensor avoids an extra copy/warning if it already is one)
        patch_centres = torch.as_tensor(obs['patch_centres'], dtype=torch.float32)
        patch_bboxes = torch.as_tensor(obs['patch_bounding_boxes'], dtype=torch.float32)
        speaker_info = torch.as_tensor(obs['speaker_info'], dtype=torch.float32)

        # flatten everything except the batch dimension
        patch_centres = patch_centres.flatten(start_dim=1)
        patch_bboxes = patch_bboxes.flatten(start_dim=1)
        speaker_info = speaker_info.flatten(start_dim=1)

        # pass through respective networks
        patch_centres_out = self.patch_centres_net(patch_centres)
        patch_bboxes_out = self.patch_bboxes_net(patch_bboxes)
        speaker_info_out = self.speaker_info_net(speaker_info)

        # combine outputs along the feature dimension
        combined = torch.cat([patch_centres_out, patch_bboxes_out, speaker_info_out], dim=1)

        # one Q-value per action (Tianshou's DQNPolicy treats these "logits" as Q-values)
        q_values = self.combined_net(combined)

        return q_values, state

In [11]:
observation_space = env.observation_space  # a dict-like space, passed whole to Net
num_actions = env.action_space.n

net = Net(observation_space, num_actions)
optim = torch.optim.Adam(net.parameters(), lr=1e-3)

In [12]:
net

Net(
  (patch_centres_net): Sequential(
    (0): Linear(in_features=8, out_features=64, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=64, out_features=64, bias=True)
    (3): ReLU(inplace=True)
  )
  (patch_bboxes_net): Sequential(
    (0): Linear(in_features=16, out_features=64, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=64, out_features=64, bias=True)
    (3): ReLU(inplace=True)
  )
  (speaker_info_net): Sequential(
    (0): Linear(in_features=4, out_features=32, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=32, out_features=32, bias=True)
    (3): ReLU(inplace=True)
  )
  (combined_net): Sequential(
    (0): Linear(in_features=160, out_features=128, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=128, out_features=128, bias=True)
    (3): ReLU(inplace=True)
    (4): Linear(in_features=128, out_features=4, bias=True)
  )
)
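The printed layer sizes follow from flattening each observation component and concatenating the branch outputs (64 + 64 + 32 = 160). The NumPy sketch below traces those shapes with random weights and only one layer per branch; the per-component shapes `(4, 2)`, `(4, 4)` and `(4,)` are assumptions inferred from the `in_features` values above, not read from `MarkovGazeEnv`.

```python
import numpy as np

rng = np.random.default_rng(0)
batch = 5

# assumed per-observation shapes, chosen to match the printed in_features (8, 16, 4)
obs = {
    "patch_centres": rng.random((batch, 4, 2)),
    "patch_bounding_boxes": rng.random((batch, 4, 4)),
    "speaker_info": rng.random((batch, 4)),
}
branch_width = {"patch_centres": 64, "patch_bounding_boxes": 64, "speaker_info": 32}


def dense_relu(x, out_dim):
    """One randomly initialised linear layer followed by ReLU (illustrative weights only)."""
    w = rng.standard_normal((x.shape[1], out_dim)) * 0.1
    return np.maximum(x @ w, 0.0)


# 1) flatten each component except the batch axis, 2) run it through its own branch
branch_outputs = [
    dense_relu(obs[key].reshape(batch, -1), width) for key, width in branch_width.items()
]
# 3) concatenate along the feature axis: 64 + 64 + 32 = 160 features
combined = np.concatenate(branch_outputs, axis=1)
# 4) a final linear map to one Q-value per action (4 actions here)
q_values = combined @ (rng.standard_normal((combined.shape[1], 4)) * 0.1)

print(combined.shape, q_values.shape)  # (5, 160) (5, 4)
```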

## Setting up DQN

First, we need to set up the policy, which is readily done in Tianshou.

In [13]:
policy = ts.policy.DQNPolicy(
    model=net, 
    optim=optim, 
    discount_factor=0.99,
    estimation_step=3,
    target_update_freq=50
)
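In these arguments, `discount_factor` is the discount γ, `estimation_step=3` turns the Q-learning target into a 3-step return, and `target_update_freq=50` copies the online network into a separate target network every 50 gradient updates. For a single transition, the n-step target is $G_t^{(n)} = \sum_{k=0}^{n-1} \gamma^k r_{t+k} + \gamma^n \max_a Q_{\text{target}}(s_{t+n}, a)$, with the bootstrap term dropped if the episode ends inside the window. A plain-Python sketch (the function and its arguments are ours, not Tianshou's):

```python
def n_step_target(rewards, dones, bootstrap_q, gamma=0.99):
    """n-step TD target for a single transition.

    rewards:     the n rewards r_t, ..., r_{t+n-1}
    dones:       the matching episode-termination flags
    bootstrap_q: max_a Q_target(s_{t+n}, a), from the target network
    """
    target, discount = 0.0, 1.0
    for reward, done in zip(rewards, dones):
        target += discount * reward
        discount *= gamma
        if done:
            # the episode ended inside the window: no next state to bootstrap from
            return target
    # no termination: add the discounted value estimate of s_{t+n}
    return target + discount * bootstrap_q


# 3-step target with gamma = 0.5 for easy arithmetic: 1 + 0.5 + 0.25 + 0.125 * 10
print(n_step_target([1.0, 1.0, 1.0], [False, False, False], 10.0, gamma=0.5))  # 3.0
```

The network is then trained to move its prediction $Q(s_t, a_t)$ toward this target, while the slowly updated target network keeps the target from chasing itself.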

Next, we set up the collectors, i.e., the objects that interact with the environment according to the policy above and collect the generated data.

In classical DQN fashion, we store the data in a replay buffer.

In [14]:
train_collector = ts.data.Collector(policy, train_envs, ts.data.VectorReplayBuffer(1000, num_train_envs))

test_collector = ts.data.Collector(policy, test_envs)
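A replay buffer is essentially a fixed-capacity FIFO store from which training batches are sampled uniformly at random, which breaks up the correlation between consecutive transitions. As far as we understand Tianshou's `VectorReplayBuffer(1000, num_train_envs)`, it keeps one such sub-buffer per training environment and splits the total size of 1000 between them (200 each here). A minimal single-buffer sketch, where the `ReplayBuffer` class is hypothetical and not Tianshou's:

```python
import random


class ReplayBuffer:
    """Fixed-size FIFO store of transitions with uniform random sampling."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.data = []
        self.next_idx = 0  # slot to overwrite once the buffer is full

    def add(self, transition):
        if len(self.data) < self.capacity:
            self.data.append(transition)
        else:
            self.data[self.next_idx] = transition  # overwrite the oldest entry
        self.next_idx = (self.next_idx + 1) % self.capacity

    def sample(self, batch_size):
        # uniform sampling with replacement
        return random.choices(self.data, k=batch_size)

    def __len__(self):
        return len(self.data)


buf = ReplayBuffer(capacity=3)
for step in range(5):
    buf.add(step)   # transitions 0 and 1 get evicted
print(sorted(buf.data))  # [2, 3, 4]
```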

## Training

In [15]:
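num_epochs = 10
num_steps_per_epoch = 1000
step_per_collect = 10 # collect 10 env steps (summed over all training envs) between rounds of updates
episode_per_test = 5 # use 5 episodes to test the policy
batch_size = 1 # transitions sampled from the replay buffer per gradient step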
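Taken together, these settings fix how much data is collected and how many gradient updates are made. Assuming the trainer's default `update_per_step=1` (one gradient step per collected environment step, performed in a block after each collect) and that each epoch runs exactly `num_steps_per_epoch` environment steps, as the `env_step` counter in the log below suggests, the run amounts to the following. The helper is hypothetical and only does the arithmetic.

```python
def training_schedule(max_epoch, step_per_epoch, step_per_collect,
                      update_per_step=1.0, target_update_freq=50):
    """Rough bookkeeping for an off-policy trainer run (assumes no early stopping)."""
    env_steps = max_epoch * step_per_epoch
    collects = env_steps // step_per_collect
    # after each collect: round(update_per_step * step_per_collect) gradient steps
    grad_updates = collects * round(update_per_step * step_per_collect)
    # the target network is synced every `target_update_freq` gradient steps
    target_syncs = grad_updates // target_update_freq
    return {"env_steps": env_steps, "collects": collects,
            "grad_updates": grad_updates, "target_syncs": target_syncs}


print(training_schedule(10, 1000, 10))
# {'env_steps': 10000, 'collects': 1000, 'grad_updates': 10000, 'target_syncs': 200}
```

With `batch_size = 1`, each of those updates is computed from a single sampled transition, which gives very noisy gradients; DQN implementations commonly use minibatches of 32 or more.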

In [16]:
result = ts.trainer.offpolicy_trainer(
    policy, 
    train_collector, 
    test_collector,
    max_epoch=num_epochs,
    step_per_epoch=num_steps_per_epoch,
    step_per_collect=step_per_collect,
    episode_per_test=episode_per_test,
    batch_size=batch_size,
    test_in_train=False
)

Epoch #1: 1001it [00:02, 450.86it/s, env_step=1000, len=79, loss=0.319, n/ep=0, n/st=10, rew=7.00]                          


Epoch #1: test_reward: 26.600000 ± 12.419340, best_reward: 26.600000 ± 14.650597 in #0


Epoch #2: 1001it [00:02, 491.01it/s, env_step=2000, len=79, loss=0.567, n/ep=0, n/st=10, rew=4.00]                           


Epoch #2: test_reward: 26.400000 ± 13.908271, best_reward: 26.600000 ± 14.650597 in #0


Epoch #3: 1001it [00:02, 492.37it/s, env_step=3000, len=79, loss=0.349, n/ep=0, n/st=10, rew=9.00]                           


Epoch #3: test_reward: 26.400000 ± 14.934524, best_reward: 26.600000 ± 14.650597 in #0


Epoch #4: 1001it [00:02, 497.65it/s, env_step=4000, len=79, loss=0.701, n/ep=0, n/st=10, rew=7.00]                           


Epoch #4: test_reward: 26.400000 ± 15.213152, best_reward: 26.600000 ± 14.650597 in #0


Epoch #5: 1001it [00:02, 488.10it/s, env_step=5000, len=79, loss=0.403, n/ep=0, n/st=10, rew=8.00]                           


Epoch #5: test_reward: 26.400000 ± 14.079773, best_reward: 26.600000 ± 14.650597 in #0


Epoch #6: 1001it [00:02, 494.19it/s, env_step=6000, len=79, loss=1.308, n/ep=0, n/st=10, rew=9.00]                           


Epoch #6: test_reward: 26.400000 ± 15.409088, best_reward: 26.600000 ± 14.650597 in #0


Epoch #7: 1001it [00:02, 495.55it/s, env_step=7000, len=79, loss=1.071, n/ep=0, n/st=10, rew=7.00]                           


Epoch #7: test_reward: 26.400000 ± 13.734628, best_reward: 26.600000 ± 14.650597 in #0


Epoch #8: 1001it [00:02, 494.18it/s, env_step=8000, len=79, loss=2.441, n/ep=0, n/st=10, rew=8.00]                           


Epoch #8: test_reward: 26.400000 ± 14.827002, best_reward: 26.600000 ± 14.650597 in #0


Epoch #9: 1001it [00:02, 489.81it/s, env_step=9000, len=79, loss=1.202, n/ep=0, n/st=10, rew=5.00]                           


Epoch #9: test_reward: 26.400000 ± 14.813507, best_reward: 26.600000 ± 14.650597 in #0


Epoch #10: 1001it [00:02, 499.02it/s, env_step=10000, len=79, loss=2.410, n/ep=0, n/st=10, rew=9.00]                          


Epoch #10: test_reward: 26.400000 ± 11.377170, best_reward: 26.600000 ± 14.650597 in #0


In [17]:
result

{'duration': '22.12s',
 'train_time/model': '18.96s',
 'test_step': 13057,
 'test_episode': 55,
 'test_time': '1.60s',
 'test_speed': '8174.61 step/s',
 'best_reward': 26.6,
 'best_result': '26.60 ± 14.65',
 'train_step': 10000,
 'train_episode': 20,
 'train_time/collector': '1.56s',
 'train_speed': '487.24 step/s'}
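One thing worth checking in this setup is exploration. As far as we can tell from Tianshou's 0.4.x API, `DQNPolicy` starts with ε = 0, and a `Collector` only applies the policy's exploration noise when constructed with `exploration_noise=True`; the usual pattern is to pass that flag to the training collector and set ε through the trainer callbacks, e.g. `train_fn=lambda epoch, env_step: policy.set_eps(0.1)` and `test_fn=lambda epoch, env_step: policy.set_eps(0.05)`. Neither is done above, so the training data were presumably collected greedily. The ε-greedy rule itself is simple; a plain-Python sketch (not Tianshou's implementation):

```python
import random
from typing import Sequence


def epsilon_greedy(q_values: Sequence[float], eps: float, rng: random.Random) -> int:
    """With probability eps pick a uniformly random action, otherwise the greedy one."""
    if rng.random() < eps:
        return rng.randrange(len(q_values))
    # greedy action: index of the largest Q-value
    return max(range(len(q_values)), key=lambda a: q_values[a])


py_rng = random.Random(0)
q = [0.1, 0.5, 0.2, 0.0]
print(epsilon_greedy(q, eps=0.0, rng=py_rng))  # 1, always the greedy action
actions = [epsilon_greedy(q, eps=0.1, rng=py_rng) for _ in range(1000)]
# fraction of greedy picks; in expectation 1 - eps + eps / 4 = 0.925
print(actions.count(1) / len(actions))
```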