# Imitation learning with GAIL

In [1]:
import tianshou as ts 
from tianshou.utils import TensorboardLogger

import gymnasium as gym

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

import numpy as np

import os
from datetime import datetime

In [2]:
from utils import compute_frame_features, compute_foa_features, get_frames_by_type, get_feature_frames

from envs import FramesEnvironment, FramesTestEnvironment

from policy_gail_prime import GAILPrimePolicy

  from pkg_resources import resource_stream, resource_exists


# Data and environment initialisation

In [3]:
# Identifier of the video to work on; the matching .mat file name is derived from it.
vid_filename = "038"
mat_filename = f"{vid_filename}.mat"

# Index used further down to pick one subject's entry out of every frame.
target_subject = 0

In [4]:
# Load the per-frame data for the selected video: the "frames" and "dynamic" stacks
# come from get_frames_by_type, the patches stack from get_feature_frames.
video_frames = get_frames_by_type("frames", vid_filename)
dynamics_frames = get_frames_by_type("dynamic", vid_filename)
# get_feature_frames returns a pair; only its first element is kept here.
patches_frames, _ = get_feature_frames(vid_filename)

In [5]:
# Frames must be normalised before being fed into the model: every entry of the
# three frame collections is divided by 255.0 (broadcasting does the rest).
# NOTE(review): entries are overwritten in place, so re-running this cell divides
# the already-scaled values by 255.0 again.
for frame_idx in range(len(video_frames)):
    for frame_stack in (video_frames, dynamics_frames, patches_frames):
        frame_stack[frame_idx] = frame_stack[frame_idx] / 255.0

In [6]:
# Per-frame patch bounding boxes, patch centres and speaker information for the video.
patch_bounding_boxes_per_frame, patch_centres_per_frame, speaker_info_per_frame = compute_frame_features(
    vid_filename
)

# Per-frame, per-subject FOA centres (presumably "focus of attention" -- confirm in
# utils) and per-frame patch weights, computed from the .mat file and the patch centres.
foa_centres_per_frame_per_subject, patch_weights_per_frame = compute_foa_features(
    mat_filename, patch_centres_per_frame
)
# Keep only the target subject's entry of every frame.
foa_centres_per_frame = [frame[target_subject] for frame in foa_centres_per_frame_per_subject]

In [7]:
# Reference environment: it is used later to read the observation/action spaces and
# to roll out the expert policy.
# NOTE(review): the meaning of the leading positional `1` is not visible in this
# notebook -- confirm against FramesEnvironment's signature.
markov_env = FramesEnvironment(
    1,
    video_frames,
    dynamics_frames,
    patches_frames,
    patch_bounding_boxes_per_frame,
    patch_centres_per_frame,
    speaker_info_per_frame,
    foa_centres_per_frame,
    patch_weights_per_frame,
)

In [8]:
num_train_envs = 5
num_test_envs = 5


def make_frames_env():
    """Build a fresh FramesEnvironment over the already-loaded video data.

    Each vectorised worker needs its own environment instance. Returning the shared
    `markov_env` from every factory would make all workers (train and test alike)
    step and reset one and the same stateful object.
    """
    return FramesEnvironment(
        1,
        video_frames,
        dynamics_frames,
        patches_frames,
        patch_bounding_boxes_per_frame,
        patch_centres_per_frame,
        speaker_info_per_frame,
        foa_centres_per_frame,
        patch_weights_per_frame,
    )


# One independent environment per worker; the underlying frame data is passed by
# reference to every instance, it is not reloaded.
train_envs = ts.env.DummyVectorEnv([make_frames_env for _ in range(num_train_envs)])
test_envs = ts.env.DummyVectorEnv([make_frames_env for _ in range(num_test_envs)])

## Networks construction

We will proceed as in the PPO case, but we'll need to add a further network to the mix: the discriminator network.

In [9]:
class ActorNet(nn.Module):
    """Actor network: maps a dict observation to one logit per discrete action.

    Four encoders produce five feature vectors that are concatenated (width 288)
    and fed to ``combined_net``:

    * ``rgb_frame_net`` encodes both ``video_frames`` and ``dynamics_frames``
      (the same weights are used for the two streams);
    * ``bitmap_frame_net`` encodes the single-channel ``patches_frames``;
    * ``patch_centres_net`` and ``speaker_info_net`` are MLPs over the flattened
      ``patch_centres`` and ``speaker_info`` entries.

    The flattened convolutional feature size ``64 * 6 * 10`` corresponds to
    180x320 input frames.

    Args:
        observation_space: mapping whose ``'patch_centres'`` and ``'speaker_info'``
            entries expose a ``.shape``; used to size the two MLP encoders.
        action_shape: number of discrete actions (or a shape whose product is
            that number).
    """

    def __init__(self, observation_space, action_shape):
        super().__init__()

        self.num_patches = observation_space['patch_centres'].shape[0]
        # plain Python ints (not numpy scalars) for every layer size
        self.output_dim = int(np.prod(action_shape))
        patch_centres_dim = int(np.prod(observation_space['patch_centres'].shape))
        speaker_info_dim = int(np.prod(observation_space['speaker_info'].shape))

        self.rgb_frame_net = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=5, stride=2, padding=2),  # input: [3, 180, 320]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # [16, 45, 80]
            nn.Conv2d(16, 32, kernel_size=5, stride=2, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # [32, 11, 20]
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),  # [64, 6, 10]
            nn.Flatten(),
            nn.Linear(64 * 6 * 10, 64),
            nn.ReLU(inplace=True)
        )

        self.bitmap_frame_net = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=2, padding=2),  # input: [1, 180, 320]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # [16, 45, 80]
            nn.Conv2d(16, 32, kernel_size=5, stride=2, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # [32, 11, 20]
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),  # [64, 6, 10]
            nn.Flatten(),
            nn.Linear(64 * 6 * 10, 64),
            nn.ReLU(inplace=True)
        )

        # network for patch_centres
        self.patch_centres_net = nn.Sequential(
            nn.Linear(patch_centres_dim, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 64),
            nn.ReLU(inplace=True)
        )

        # network for speaker_info
        self.speaker_info_net = nn.Sequential(
            nn.Linear(speaker_info_dim, 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 32),
            nn.ReLU(inplace=True)
        )

        # combining the outputs of all networks
        self.combined_net = nn.Sequential(
            nn.Linear(64 + 32 + 64 + 64 + 64, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, self.output_dim, bias=False)
        )

    @staticmethod
    def _to_float_tensor(value):
        """Convert an array-like or tensor to float32, copying only when required."""
        return torch.as_tensor(value, dtype=torch.float32)

    def _encode_observation(self, obs):
        """Run every encoder on ``obs`` and return the concatenated features."""
        patch_centres = self._to_float_tensor(obs['patch_centres'])
        speaker_info = self._to_float_tensor(obs['speaker_info'])

        # flatten everything but the batch dimension for the MLP encoders
        patch_centres = patch_centres.view(patch_centres.size(0), -1)
        speaker_info = speaker_info.view(speaker_info.size(0), -1)

        # channels-last [B, H, W, C] -> channels-first [B, C, H, W]
        video_frame = self._to_float_tensor(obs['video_frames']).permute(0, 3, 1, 2)
        dynamics_frame = self._to_float_tensor(obs['dynamics_frames']).permute(0, 3, 1, 2)
        # [B, H, W] -> [B, 1, H, W]
        patches_frame = self._to_float_tensor(obs['patches_frames']).unsqueeze(1)

        # pass through respective networks
        patch_centres_out = self.patch_centres_net(patch_centres)
        speaker_info_out = self.speaker_info_net(speaker_info)

        video_frame_out = self.rgb_frame_net(video_frame)
        dynamics_frame_out = self.rgb_frame_net(dynamics_frame)
        patches_frame_out = self.bitmap_frame_net(patches_frame)

        # combine outputs
        return torch.cat(
            [patch_centres_out, speaker_info_out, video_frame_out, dynamics_frame_out, patches_frame_out],
            dim=1,
        )

    def forward(self, obs, state=None, info=None):
        """Compute the action logits for a batch of observations.

        Args:
            obs: mapping with keys ``'patch_centres'``, ``'speaker_info'``,
                ``'video_frames'``, ``'dynamics_frames'`` and ``'patches_frames'``;
                values may be numpy arrays or tensors with a leading batch dimension.
            state: passed through unchanged.
            info: unused; accepted for API compatibility.

        Returns:
            Tuple ``(logits, state)`` with ``logits`` of shape ``[batch, output_dim]``.
        """
        logits = self.combined_net(self._encode_observation(obs))

        return logits, state

In [10]:
class CriticNet(nn.Module):
    """Critic network: maps a dict observation to a scalar state value.

    Four encoders produce five feature vectors that are concatenated (width 288)
    and fed to ``combined_net``:

    * ``rgb_frame_net`` encodes both ``video_frames`` and ``dynamics_frames``
      (the same weights are used for the two streams);
    * ``bitmap_frame_net`` encodes the single-channel ``patches_frames``;
    * ``patch_centres_net`` and ``speaker_info_net`` are MLPs over the flattened
      ``patch_centres`` and ``speaker_info`` entries.

    The flattened convolutional feature size ``64 * 6 * 10`` corresponds to
    180x320 input frames.

    Args:
        observation_space: mapping whose ``'patch_centres'`` and ``'speaker_info'``
            entries expose a ``.shape``; used to size the two MLP encoders.
    """

    def __init__(self, observation_space):
        super().__init__()

        # plain Python ints (not numpy scalars) for the layer sizes
        patch_centres_dim = int(np.prod(observation_space['patch_centres'].shape))
        speaker_info_dim = int(np.prod(observation_space['speaker_info'].shape))

        self.rgb_frame_net = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=5, stride=2, padding=2),  # input: [3, 180, 320]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # [16, 45, 80]
            nn.Conv2d(16, 32, kernel_size=5, stride=2, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # [32, 11, 20]
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),  # [64, 6, 10]
            nn.Flatten(),
            nn.Linear(64 * 6 * 10, 64),
            nn.ReLU(inplace=True)
        )

        self.bitmap_frame_net = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=2, padding=2),  # input: [1, 180, 320]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # [16, 45, 80]
            nn.Conv2d(16, 32, kernel_size=5, stride=2, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # [32, 11, 20]
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),  # [64, 6, 10]
            nn.Flatten(),
            nn.Linear(64 * 6 * 10, 64),
            nn.ReLU(inplace=True)
        )

        # network for patch_centres
        self.patch_centres_net = nn.Sequential(
            nn.Linear(patch_centres_dim, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 64),
            nn.ReLU(inplace=True)
        )

        # network for speaker_info
        self.speaker_info_net = nn.Sequential(
            nn.Linear(speaker_info_dim, 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 32),
            nn.ReLU(inplace=True)
        )

        # combining the outputs of all networks
        self.combined_net = nn.Sequential(
            nn.Linear(64 + 32 + 64 + 64 + 64, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 1, bias=False)
        )

    @staticmethod
    def _to_float_tensor(value):
        """Convert an array-like or tensor to float32, copying only when required."""
        return torch.as_tensor(value, dtype=torch.float32)

    def _encode_observation(self, obs):
        """Run every encoder on ``obs`` and return the concatenated features."""
        patch_centres = self._to_float_tensor(obs['patch_centres'])
        speaker_info = self._to_float_tensor(obs['speaker_info'])

        # flatten everything but the batch dimension for the MLP encoders
        patch_centres = patch_centres.view(patch_centres.size(0), -1)
        speaker_info = speaker_info.view(speaker_info.size(0), -1)

        # channels-last [B, H, W, C] -> channels-first [B, C, H, W]
        video_frame = self._to_float_tensor(obs['video_frames']).permute(0, 3, 1, 2)
        dynamics_frame = self._to_float_tensor(obs['dynamics_frames']).permute(0, 3, 1, 2)
        # [B, H, W] -> [B, 1, H, W]
        patches_frame = self._to_float_tensor(obs['patches_frames']).unsqueeze(1)

        patch_centres_out = self.patch_centres_net(patch_centres)
        speaker_info_out = self.speaker_info_net(speaker_info)

        video_frame_out = self.rgb_frame_net(video_frame)
        dynamics_frame_out = self.rgb_frame_net(dynamics_frame)
        patches_frame_out = self.bitmap_frame_net(patches_frame)

        return torch.cat(
            [patch_centres_out, speaker_info_out, video_frame_out, dynamics_frame_out, patches_frame_out],
            dim=1,
        )

    def forward(self, obs, state=None, info=None):
        """Compute the state value for a batch of observations.

        Args:
            obs: mapping with keys ``'patch_centres'``, ``'speaker_info'``,
                ``'video_frames'``, ``'dynamics_frames'`` and ``'patches_frames'``;
                values may be numpy arrays or tensors with a leading batch dimension.
            state: unused; accepted for API compatibility.
            info: unused; accepted for API compatibility.

        Returns:
            Tensor of shape ``[batch, 1]`` holding the state values.
        """
        state_value = self.combined_net(self._encode_observation(obs))

        return state_value

In [11]:
class DiscriminatorNet(nn.Module):
    """Discriminator network: scores an (observation, action) pair.

    The observation is encoded exactly as in the actor/critic (shared-weight
    ``rgb_frame_net`` for ``video_frames`` and ``dynamics_frames``,
    ``bitmap_frame_net`` for ``patches_frames``, MLPs for ``patch_centres`` and
    ``speaker_info``). The action vector goes through ``action_net`` and all
    features (width 288 + 16) are combined into a single unnormalised score --
    no sigmoid is applied inside this module.

    The flattened convolutional feature size ``64 * 6 * 10`` corresponds to
    180x320 input frames.

    Args:
        observation_space: mapping whose ``'patch_centres'`` and ``'speaker_info'``
            entries expose a ``.shape``; used to size the two MLP encoders.
        action_space: width of the action vector fed to ``forward`` (an int, or
            a shape whose product is that width).
    """

    def __init__(self, observation_space, action_space):
        super().__init__()

        #! everything in this notebook runs on CPU, so if you want to change the device here, remember to do that for EVERY tensor and EVERY network
        self.device = "cpu"

        # plain Python ints (not numpy scalars) for the layer sizes
        patch_centres_dim = int(np.prod(observation_space['patch_centres'].shape))
        speaker_info_dim = int(np.prod(observation_space['speaker_info'].shape))
        action_dim = int(np.prod(action_space))

        self.rgb_frame_net = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=5, stride=2, padding=2),  # input: [3, 180, 320]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # [16, 45, 80]
            nn.Conv2d(16, 32, kernel_size=5, stride=2, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # [32, 11, 20]
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),  # [64, 6, 10]
            nn.Flatten(),
            nn.Linear(64 * 6 * 10, 64),
            nn.ReLU(inplace=True)
        )

        self.bitmap_frame_net = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=2, padding=2),  # input: [1, 180, 320]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # [16, 45, 80]
            nn.Conv2d(16, 32, kernel_size=5, stride=2, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # [32, 11, 20]
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),  # [64, 6, 10]
            nn.Flatten(),
            nn.Linear(64 * 6 * 10, 64),
            nn.ReLU(inplace=True)
        )

        # network for patch_centres
        self.patch_centres_net = nn.Sequential(
            nn.Linear(patch_centres_dim, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 64),
            nn.ReLU(inplace=True)
        )

        # network for speaker_info
        self.speaker_info_net = nn.Sequential(
            nn.Linear(speaker_info_dim, 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 32),
            nn.ReLU(inplace=True)
        )

        # network for the action vector (presumably one-hot, going by `oh_action`)
        self.action_net = nn.Sequential(
            nn.Linear(action_dim, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 16)
        )

        # combining the outputs of all networks
        self.combined_net = nn.Sequential(
            nn.Linear(64 + 32 + 64 + 64 + 64 + 16, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 1, bias=False)
        )

    @staticmethod
    def _to_float_tensor(value):
        """Convert an array-like or tensor to float32, copying only when required."""
        return torch.as_tensor(value, dtype=torch.float32)

    def _encode_observation(self, obs):
        """Run every observation encoder and return the concatenated features."""
        patch_centres = self._to_float_tensor(obs['patch_centres'])
        speaker_info = self._to_float_tensor(obs['speaker_info'])

        # flatten everything but the batch dimension for the MLP encoders
        patch_centres = patch_centres.view(patch_centres.size(0), -1)
        speaker_info = speaker_info.view(speaker_info.size(0), -1)

        # channels-last [B, H, W, C] -> channels-first [B, C, H, W]
        video_frame = self._to_float_tensor(obs['video_frames']).permute(0, 3, 1, 2)
        dynamics_frame = self._to_float_tensor(obs['dynamics_frames']).permute(0, 3, 1, 2)
        # [B, H, W] -> [B, 1, H, W]
        patches_frame = self._to_float_tensor(obs['patches_frames']).unsqueeze(1)

        patch_centres_out = self.patch_centres_net(patch_centres)
        speaker_info_out = self.speaker_info_net(speaker_info)

        video_frame_out = self.rgb_frame_net(video_frame)
        dynamics_frame_out = self.rgb_frame_net(dynamics_frame)
        patches_frame_out = self.bitmap_frame_net(patches_frame)

        return torch.cat(
            [patch_centres_out, speaker_info_out, video_frame_out, dynamics_frame_out, patches_frame_out],
            dim=1,
        )

    def forward(self, obs, oh_action, state=None, info=None):
        """Score a batch of (observation, action) pairs.

        Args:
            obs: mapping with keys ``'patch_centres'``, ``'speaker_info'``,
                ``'video_frames'``, ``'dynamics_frames'`` and ``'patches_frames'``;
                values may be numpy arrays or tensors with a leading batch dimension.
            oh_action: action vectors of shape ``[batch, action_dim]``; any numeric
                array-like or tensor is accepted and cast to float32.
            state: unused; accepted for API compatibility.
            info: unused; accepted for API compatibility.

        Returns:
            Tensor of shape ``[batch, 1]`` with the raw discriminator scores.
        """
        observation_features = self._encode_observation(obs)
        # a float32 tensor passes through untouched, so gradients are preserved
        action_out = self.action_net(self._to_float_tensor(oh_action))

        combined = torch.cat([observation_features, action_out], dim=1)
        combined_out = self.combined_net(combined)

        return combined_out

In [12]:
# The spaces are read off the reference environment: `state_shape` is the whole
# observation space object, `action_shape` the number of discrete actions.
state_shape = markov_env.observation_space
action_shape = markov_env.action_space.n

actor_net = ActorNet(state_shape, action_shape)
critic_net = CriticNet(state_shape)
discriminator_net = DiscriminatorNet(state_shape, action_shape)

# A single optimizer for actor and critic keeps the training loop simple and cheap
# (and gives some consistency across the two networks), BUT updates to one network
# then influence updates to the other, which might create unexpected problems...
combined_params = [*actor_net.parameters(), *critic_net.parameters()]
optimizer = torch.optim.Adam(combined_params, lr=3e-4)

# The discriminator is trained separately, with its own optimizer and learning rate.
disc_params = [*discriminator_net.parameters()]
disc_optimizer = torch.optim.Adam(disc_params, lr=1e-3)

In [13]:
# Display the module tree as a quick sanity check of the layer sizes.
actor_net

ActorNet(
  (rgb_frame_net): Sequential(
    (0): Conv2d(3, 16, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(16, 32, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Flatten(start_dim=1, end_dim=-1)
    (9): Linear(in_features=3840, out_features=64, bias=True)
    (10): ReLU(inplace=True)
  )
  (bitmap_frame_net): Sequential(
    (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(16, 32, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kern

In [14]:
# Display the module tree as a quick sanity check of the layer sizes.
critic_net

CriticNet(
  (rgb_frame_net): Sequential(
    (0): Conv2d(3, 16, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(16, 32, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Flatten(start_dim=1, end_dim=-1)
    (9): Linear(in_features=3840, out_features=64, bias=True)
    (10): ReLU(inplace=True)
  )
  (bitmap_frame_net): Sequential(
    (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(16, 32, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(ker

In [15]:
# Display the module tree as a quick sanity check of the layer sizes.
discriminator_net

DiscriminatorNet(
  (rgb_frame_net): Sequential(
    (0): Conv2d(3, 16, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(16, 32, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Flatten(start_dim=1, end_dim=-1)
    (9): Linear(in_features=3840, out_features=64, bias=True)
    (10): ReLU(inplace=True)
  )
  (bitmap_frame_net): Sequential(
    (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(16, 32, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPoo

## GAIL

We'll need a buffer of expert data, which is meant to be imitated by our agent.

In [16]:
def expert_policy(info):
    """Return the expert's action for the frame following the current one.

    Reads ``info["current_frame_idx"]``, looks up the attended centre of the next
    frame in ``foa_centres_per_frame`` and returns, wrapped in a numpy array, the
    position of that centre within the next frame's ``patch_centres_per_frame`` entry.
    """
    upcoming_idx = info["current_frame_idx"] + 1
    target_centre = foa_centres_per_frame[upcoming_idx]
    candidate_centres = patch_centres_per_frame[upcoming_idx]

    # the expert always picks the patch whose centre is the attended one
    return np.array(candidate_centres.index(target_centre))

In [17]:
buffer_size = 2000
expert_buffer = ts.data.ReplayBuffer(size=buffer_size)

# Roll the expert policy out on the reference environment until the buffer is full.
obs, info = markov_env.reset()
for _ in range(buffer_size):
    action = expert_policy(info)

    next_obs, _, terminated, truncated, next_info = markov_env.step(action)

    # the reward is not used by GAIL, but is necessary, as per Tianshou API
    expert_buffer.add(ts.data.Batch(obs=obs, act=action, rew=0.0, obs_next=next_obs, terminated=terminated, truncated=truncated))

    if terminated or truncated:
        # the episode is over either way (terminal state or truncation), so the
        # environment must be reset before it is stepped again
        obs, info = markov_env.reset()
    else:
        obs, info = next_obs, next_info

In [18]:
def dist_fn(logits: torch.Tensor):
    """Wrap the actor's raw logits in a categorical distribution over actions."""
    categorical_cls = torch.distributions.Categorical
    return categorical_cls(logits=logits)

# Assemble the GAIL policy from the actor/critic pair (which share `optimizer`), the
# action-distribution factory, the expert demonstrations, and the discriminator with
# its own optimizer.
# NOTE(review): the arguments are positional, so their order must match
# GAILPrimePolicy's signature, which is not visible in this notebook.
policy = GAILPrimePolicy(
    actor_net, 
    critic_net, 
    optimizer,
    dist_fn,
    expert_buffer,
    discriminator_net,
    disc_optimizer,
)

In [19]:
# The training collector writes into a VectorReplayBuffer(2000, num_train_envs).
train_collector = ts.data.Collector(policy, train_envs, ts.data.VectorReplayBuffer(2000, num_train_envs))

# The test collector is not given a buffer explicitly.
test_collector = ts.data.Collector(policy, test_envs)

In [20]:
# Training schedule.
num_epochs, num_steps_per_epoch = 20, 1000

# Collection / evaluation / update granularity.
step_per_collect, episode_per_test, batch_size = 10, 2, 10

# One TensorBoard run directory per launch, keyed by video, subject and start time.
timestamp = f"{datetime.now():%d%m%Y-%H%M%S}"
log_path = os.path.join(
    "logs",
    "gail",
    "frames",
    f"video_{vid_filename}",
    f"subject_{target_subject}",
    timestamp,
)
writer = SummaryWriter(log_path)
logger = TensorboardLogger(writer)

## Training

In [21]:
# Run the on-policy training loop with the schedule configured above; the summary
# it returns is kept in `result` and displayed further down.
result = ts.trainer.onpolicy_trainer(
    policy, 
    train_collector, 
    test_collector,
    repeat_per_collect=1,
    max_epoch=num_epochs,
    step_per_epoch=num_steps_per_epoch,
    step_per_collect=step_per_collect,
    episode_per_test=episode_per_test,
    batch_size=batch_size,
    logger=logger,
)

Epoch #1: 1001it [01:03, 15.72it/s, env_step=1000, len=105, loss=6.238, loss/clip=-0.000, loss/disc=1.336, loss/ent=2.972, loss/vf=12.535, n/ep=0, n/st=10, rew=6.63, stats/acc_exp=0.680, stats/acc_pi=0.515]                          


Epoch #1: test_reward: 48.982790 ± 32.024125, best_reward: 48.982790 ± 32.024125 in #1


Epoch #2: 1001it [01:02, 16.11it/s, env_step=2000, len=200, loss=5.097, loss/clip=-0.000, loss/disc=1.368, loss/ent=1.881, loss/vf=10.231, n/ep=0, n/st=10, rew=16.22, stats/acc_exp=0.545, stats/acc_pi=0.680]                          


Epoch #2: test_reward: 41.108053 ± 14.976943, best_reward: 48.982790 ± 32.024125 in #1


Epoch #3: 1001it [01:02, 16.08it/s, env_step=3000, len=200, loss=4.566, loss/clip=0.000, loss/disc=1.263, loss/ent=2.056, loss/vf=9.174, n/ep=0, n/st=10, rew=17.59, stats/acc_exp=0.550, stats/acc_pi=0.745]                           


Epoch #3: test_reward: 86.307233 ± 34.927462, best_reward: 86.307233 ± 34.927462 in #3


Epoch #4: 1001it [01:03, 15.79it/s, env_step=4000, len=200, loss=4.375, loss/clip=-0.000, loss/disc=1.284, loss/ent=2.764, loss/vf=8.805, n/ep=0, n/st=10, rew=23.27, stats/acc_exp=0.665, stats/acc_pi=0.545]                          


Epoch #4: test_reward: 41.253185 ± 16.721661, best_reward: 86.307233 ± 34.927462 in #3


Epoch #5: 1001it [01:03, 15.65it/s, env_step=5000, len=200, loss=4.560, loss/clip=-0.000, loss/disc=1.340, loss/ent=2.492, loss/vf=9.169, n/ep=0, n/st=10, rew=14.87, stats/acc_exp=0.560, stats/acc_pi=0.625]                          


Epoch #5: test_reward: 38.053725 ± 15.970516, best_reward: 86.307233 ± 34.927462 in #3


Epoch #6: 1001it [01:03, 15.78it/s, env_step=6000, len=200, loss=4.261, loss/clip=0.000, loss/disc=1.311, loss/ent=3.015, loss/vf=8.583, n/ep=0, n/st=10, rew=21.73, stats/acc_exp=0.800, stats/acc_pi=0.435]                          


Epoch #6: test_reward: 45.112522 ± 25.832289, best_reward: 86.307233 ± 34.927462 in #3


Epoch #7: 1001it [01:02, 15.91it/s, env_step=7000, len=200, loss=4.444, loss/clip=-0.000, loss/disc=1.309, loss/ent=2.452, loss/vf=8.937, n/ep=0, n/st=10, rew=14.10, stats/acc_exp=0.505, stats/acc_pi=0.715]                          


Epoch #7: test_reward: 35.831503 ± 13.843098, best_reward: 86.307233 ± 34.927462 in #3


Epoch #8: 1001it [01:04, 15.58it/s, env_step=8000, len=200, loss=3.353, loss/clip=-0.000, loss/disc=1.351, loss/ent=2.625, loss/vf=6.759, n/ep=0, n/st=10, rew=19.06, stats/acc_exp=0.610, stats/acc_pi=0.615]                          


Epoch #8: test_reward: 47.234467 ± 28.962157, best_reward: 86.307233 ± 34.927462 in #3


Epoch #9: 1001it [01:03, 15.76it/s, env_step=9000, len=200, loss=4.437, loss/clip=-0.000, loss/disc=1.284, loss/ent=2.336, loss/vf=8.922, n/ep=0, n/st=10, rew=22.12, stats/acc_exp=0.700, stats/acc_pi=0.580]                          


Epoch #9: test_reward: 45.031834 ± 22.188875, best_reward: 86.307233 ± 34.927462 in #3


Epoch #10: 1001it [01:02, 15.99it/s, env_step=10000, len=200, loss=4.524, loss/clip=0.000, loss/disc=1.261, loss/ent=2.537, loss/vf=9.098, n/ep=0, n/st=10, rew=15.00, stats/acc_exp=0.640, stats/acc_pi=0.570]                          


Epoch #10: test_reward: 51.224614 ± 26.468755, best_reward: 86.307233 ± 34.927462 in #3


Epoch #11: 1001it [01:04, 15.48it/s, env_step=11000, len=200, loss=4.010, loss/clip=0.000, loss/disc=1.250, loss/ent=2.629, loss/vf=8.073, n/ep=0, n/st=10, rew=19.41, stats/acc_exp=0.605, stats/acc_pi=0.655]                          


Epoch #11: test_reward: 57.513918 ± 34.053382, best_reward: 86.307233 ± 34.927462 in #3


Epoch #12: 1001it [01:04, 15.51it/s, env_step=12000, len=200, loss=4.899, loss/clip=-0.000, loss/disc=1.261, loss/ent=2.876, loss/vf=9.856, n/ep=0, n/st=10, rew=14.46, stats/acc_exp=0.625, stats/acc_pi=0.650]                          


Epoch #12: test_reward: 55.666877 ± 34.122730, best_reward: 86.307233 ± 34.927462 in #3


Epoch #13: 1001it [01:03, 15.81it/s, env_step=13000, len=200, loss=4.874, loss/clip=0.000, loss/disc=1.297, loss/ent=2.379, loss/vf=9.796, n/ep=0, n/st=10, rew=16.62, stats/acc_exp=0.675, stats/acc_pi=0.595]                          


Epoch #13: test_reward: 51.228343 ± 24.386174, best_reward: 86.307233 ± 34.927462 in #3


Epoch #14: 1001it [01:04, 15.52it/s, env_step=14000, len=200, loss=4.980, loss/clip=0.000, loss/disc=1.303, loss/ent=2.549, loss/vf=10.011, n/ep=0, n/st=10, rew=13.46, stats/acc_exp=0.705, stats/acc_pi=0.550]                          


Epoch #14: test_reward: 52.485027 ± 32.790120, best_reward: 86.307233 ± 34.927462 in #3


Epoch #15: 1001it [01:02, 16.09it/s, env_step=15000, len=200, loss=4.370, loss/clip=0.000, loss/disc=1.318, loss/ent=2.345, loss/vf=8.787, n/ep=0, n/st=10, rew=16.60, stats/acc_exp=0.680, stats/acc_pi=0.460]                          


Epoch #15: test_reward: 58.454852 ± 34.335089, best_reward: 86.307233 ± 34.927462 in #3


Epoch #16: 1001it [01:02, 16.01it/s, env_step=16000, len=200, loss=4.350, loss/clip=-0.000, loss/disc=1.294, loss/ent=2.263, loss/vf=8.746, n/ep=0, n/st=10, rew=17.95, stats/acc_exp=0.710, stats/acc_pi=0.595]                          


Epoch #16: test_reward: 64.042181 ± 41.377560, best_reward: 86.307233 ± 34.927462 in #3


Epoch #17: 1001it [01:04, 15.43it/s, env_step=17000, len=200, loss=4.763, loss/clip=-0.000, loss/disc=1.227, loss/ent=2.070, loss/vf=9.568, n/ep=0, n/st=10, rew=14.03, stats/acc_exp=0.695, stats/acc_pi=0.640]                          


Epoch #17: test_reward: 74.365170 ± 41.113986, best_reward: 86.307233 ± 34.927462 in #3


Epoch #18: 1001it [01:04, 15.62it/s, env_step=18000, len=200, loss=4.247, loss/clip=0.000, loss/disc=1.322, loss/ent=1.818, loss/vf=8.531, n/ep=0, n/st=10, rew=24.50, stats/acc_exp=0.650, stats/acc_pi=0.505]                          


Epoch #18: test_reward: 67.435053 ± 38.916334, best_reward: 86.307233 ± 34.927462 in #3


Epoch #19: 1001it [01:04, 15.63it/s, env_step=19000, len=200, loss=4.345, loss/clip=-0.000, loss/disc=1.235, loss/ent=2.022, loss/vf=8.731, n/ep=0, n/st=10, rew=21.82, stats/acc_exp=0.720, stats/acc_pi=0.550]                          


Epoch #19: test_reward: 77.715454 ± 41.328612, best_reward: 86.307233 ± 34.927462 in #3


Epoch #20: 1001it [01:04, 15.60it/s, env_step=20000, len=200, loss=4.410, loss/clip=0.000, loss/disc=1.257, loss/ent=2.206, loss/vf=8.864, n/ep=0, n/st=10, rew=21.29, stats/acc_exp=0.720, stats/acc_pi=0.530]                          


Epoch #20: test_reward: 56.836347 ± 40.389282, best_reward: 86.307233 ± 34.927462 in #3


In [22]:
policy_path = os.path.join("weights", "gail", "frames", f"video_{vid_filename}", f"subject_{target_subject}")
# exist_ok=True keeps re-runs painless when the directory structure is already
# there, while genuine failures (e.g. missing permissions) still raise instead of
# being silently swallowed.
os.makedirs(policy_path, exist_ok=True)

torch.save(policy.state_dict(), os.path.join(policy_path, "state_dict"))

In [23]:
# Summary returned by the trainer (timings, step/episode counts, best reward).
result

{'duration': '1410.76s',
 'train_time/model': '1171.28s',
 'test_step': 22008,
 'test_episode': 42,
 'test_time': '139.40s',
 'test_speed': '157.88 step/s',
 'best_reward': 86.30723254102486,
 'best_result': '86.31 ± 34.93',
 'train_step': 20000,
 'train_episode': 40,
 'train_time/collector': '100.08s',
 'train_speed': '15.73 step/s'}

## Testing

In [24]:
# #! before running this cell, make sure you have run all the code up to the
# #! instantiation of the models and optimizers, plus the cells defining `dist_fn`,
# #! `expert_buffer` and `policy_path`, which are reused here
policy = GAILPrimePolicy(
    actor_net, 
    critic_net, 
    optimizer,
    dist_fn,
    expert_buffer,
    discriminator_net,
    disc_optimizer,
)

# Restore the weights written to `policy_path` by the saving cell.
policy.load_state_dict(torch.load(os.path.join(policy_path, "state_dict")))

<All keys matched successfully>

In [25]:
# Evaluation environment: a FramesTestEnvironment built from the same data (and the
# same leading positional `1`) as the training environment.
test_markov_env = FramesTestEnvironment(
    1,
    video_frames,
    dynamics_frames,
    patches_frames,
    patch_bounding_boxes_per_frame,
    patch_centres_per_frame,
    speaker_info_per_frame,
    foa_centres_per_frame,
    patch_weights_per_frame,
)

In [26]:
# NOTE: this rebinds `num_test_envs`, which was 5 when the training-time vector
# environments were built; from here on it is 1.
num_test_envs = 1

# With a single worker, wrapping the one `test_markov_env` instance is sufficient.
testing_envs = ts.env.DummyVectorEnv([lambda: test_markov_env for _ in range(num_test_envs)])

In [32]:
# Put the policy in evaluation mode before collecting test episodes.
policy.eval()

collector = ts.data.Collector(policy, testing_envs)
# why does it return three different values, if all conditions are the same?
# => GAIL is policy gradient + GAN, so there's more stochasticity involved
# (also, more training could've helped quite a bit in this case)
# NOTE(review): presumably actions are still sampled from the policy's Categorical
# distribution in eval mode -- confirm against GAILPrimePolicy.
collector.collect(n_episode=3)

{'n/ep': 3,
 'n/st': 1572,
 'rews': array([61., 56., 45.]),
 'lens': array([524, 524, 524]),
 'idxs': array([0, 0, 0]),
 'rew': 54.0,
 'len': 524.0,
 'rew_std': 6.683312551921141,
 'len_std': 0.0}

In [28]:
# Baseline for comparison: mean episode reward over 100 episodes collected with
# random=True (presumably random actions instead of the policy -- confirm in Tianshou).
np.mean(collector.collect(n_episode=100, random=True)["rews"])

19.54

### TensorBoard visualisation

In [29]:
%load_ext tensorboard

In [30]:
%tensorboard --logdir logs/gail/frames