# Imitation learning with GAIL (Generative Adversarial Imitation Learning)

In [1]:
import tianshou as ts 
from tianshou.utils import TensorboardLogger

import gymnasium as gym

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

import numpy as np

import os
from datetime import datetime

In [2]:
from utils_preprocess import compute_frame_features, compute_foa_features
from utils_data import get_frames_by_type, get_feature_frames

from env_frames import FramesEnvironment
from env_frames_test import FramesTestEnvironment

from policy_gail_prime import GAILPrimePolicy

  from pkg_resources import resource_stream, resource_exists


In [3]:
vid_filename = "012"
mat_filename = vid_filename + ".mat"
target_subject = 0

In [4]:
video_frames = get_frames_by_type("frames", vid_filename)
dynamics_frames = get_frames_by_type("dynamic", vid_filename)
patches_frames, _ = get_feature_frames(vid_filename)

In [5]:
# need to normalise frames before feeding them into the model
for i in range(len(video_frames)):
    # broadcasting will do the rest...
    video_frames[i] = video_frames[i] / 255.0
    dynamics_frames[i] = dynamics_frames[i] / 255.0
    patches_frames[i] = patches_frames[i] / 255.0

In [6]:
patch_bounding_boxes_per_frame, patch_centres_per_frame, speaker_info_per_frame = compute_frame_features(
    vid_filename
)

foa_centres_per_frame_per_subject, patch_weights_per_frame = compute_foa_features(
    mat_filename, patch_centres_per_frame
)
foa_centres_per_frame = [frame[target_subject] for frame in foa_centres_per_frame_per_subject]

In [7]:
markov_env = FramesEnvironment(
    1,
    video_frames,
    dynamics_frames,
    patches_frames,
    patch_bounding_boxes_per_frame,
    patch_centres_per_frame,
    speaker_info_per_frame,
    foa_centres_per_frame,
    patch_weights_per_frame,
)

In [8]:
num_train_envs = 5
num_test_envs = 3

train_envs = ts.env.DummyVectorEnv([lambda: markov_env for _ in range(num_train_envs)])
test_envs = ts.env.DummyVectorEnv([lambda: markov_env for _ in range(num_test_envs)])

## Network creation

In [9]:
class ActorNet(nn.Module):
    def __init__(self, observation_space, action_shape):
        super().__init__()

        self.num_patches = observation_space['patch_centres'].shape[0]
        self.output_dim = np.prod(action_shape)

        self.rgb_frame_net = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=5, stride=2, padding=2),  # [3, 180, 320]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # [16, 45, 80]
            nn.Conv2d(16, 32, kernel_size=5, stride=2, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # [32, 11, 20]
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),  # [64, 6, 10]
            nn.Flatten(),
            nn.Linear(64 * 6 * 10, 64),
            nn.ReLU(inplace=True)
        )

        self.bitmap_frame_net = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=2, padding=2),  # [1, 180, 320]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # [16, 45, 80]
            nn.Conv2d(16, 32, kernel_size=5, stride=2, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # [32, 11, 20]
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),  # [64, 6, 10]
            nn.Flatten(),
            nn.Linear(64 * 6 * 10, 64),
            nn.ReLU(inplace=True)
        )

        # network for patch_centres
        self.patch_centres_net = nn.Sequential(
            nn.Linear(np.prod(observation_space['patch_centres'].shape), 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 64),
            nn.ReLU(inplace=True)
        )

        # network for speaker_info
        self.speaker_info_net = nn.Sequential(
            nn.Linear(np.prod(observation_space['speaker_info'].shape), 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 32),
            nn.ReLU(inplace=True)
        )

        # combining the outputs of all networks
        self.combined_net = nn.Sequential(
            nn.Linear(64 + 32 + 64 + 64 + 64, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, np.prod(action_shape), bias=False)
        )

    def forward(self, obs, state=None, info={}):
        patch_centres = torch.tensor(obs['patch_centres'], dtype=torch.float32)
        speaker_info = torch.tensor(obs['speaker_info'], dtype=torch.float32)
        
        patch_centres = patch_centres.view(patch_centres.size(0), -1)
        speaker_info = speaker_info.view(speaker_info.size(0), -1)

        video_frame = torch.tensor(obs['video_frames'], dtype=torch.float32).permute(0, 3, 1, 2)
        dynamics_frame = torch.tensor(obs['dynamics_frames'], dtype=torch.float32).permute(0, 3, 1, 2)
        patches_frame = torch.tensor(obs['patches_frames'], dtype=torch.float32).unsqueeze(1)

        # pass through respective networks
        patch_centres_out = self.patch_centres_net(patch_centres)
        speaker_info_out = self.speaker_info_net(speaker_info)

        video_frame_out = self.rgb_frame_net(video_frame)
        dynamics_frame_out = self.rgb_frame_net(dynamics_frame)
        patches_frame_out = self.bitmap_frame_net(patches_frame)

        # combine outputs
        combined = torch.cat([patch_centres_out, speaker_info_out, video_frame_out, dynamics_frame_out, patches_frame_out], dim=1)

        logits = self.combined_net(combined)

        return logits, state

In [10]:
class CriticNet(nn.Module):
    def __init__(self, observation_space):
        super().__init__()

        self.rgb_frame_net = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=5, stride=2, padding=2),  # [3, 180, 320]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # [16, 45, 80]
            nn.Conv2d(16, 32, kernel_size=5, stride=2, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # [32, 11, 20]
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),  # [64, 6, 10]
            nn.Flatten(),
            nn.Linear(64 * 6 * 10, 64),
            nn.ReLU(inplace=True)
        )

        self.bitmap_frame_net = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=2, padding=2),  # [1, 180, 320]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # [16, 45, 80]
            nn.Conv2d(16, 32, kernel_size=5, stride=2, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # [32, 11, 20]
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),  # [64, 6, 10]
            nn.Flatten(),
            nn.Linear(64 * 6 * 10, 64),
            nn.ReLU(inplace=True)
        )

        # network for patch_centres
        self.patch_centres_net = nn.Sequential(
            nn.Linear(np.prod(observation_space['patch_centres'].shape), 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 64),
            nn.ReLU(inplace=True)
        )

        # network for speaker_info
        self.speaker_info_net = nn.Sequential(
            nn.Linear(np.prod(observation_space['speaker_info'].shape), 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 32),
            nn.ReLU(inplace=True)
        )
        
        # combining the outputs of all networks
        self.combined_net = nn.Sequential(
            nn.Linear(64 + 32 + 64 + 64 + 64, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 1, bias=False)
        )

    def forward(self, obs, state=None, info={}):
        patch_centres = torch.tensor(obs['patch_centres'], dtype=torch.float32)
        speaker_info = torch.tensor(obs['speaker_info'], dtype=torch.float32)
        
        patch_centres = patch_centres.view(patch_centres.size(0), -1)
        speaker_info = speaker_info.view(speaker_info.size(0), -1)

        video_frame = torch.tensor(obs['video_frames'], dtype=torch.float32).permute(0, 3, 1, 2)
        dynamics_frame = torch.tensor(obs['dynamics_frames'], dtype=torch.float32).permute(0, 3, 1, 2)
        patches_frame = torch.tensor(obs['patches_frames'], dtype=torch.float32).unsqueeze(1)

        patch_centres_out = self.patch_centres_net(patch_centres)
        speaker_info_out = self.speaker_info_net(speaker_info)

        video_frame_out = self.rgb_frame_net(video_frame)
        dynamics_frame_out = self.rgb_frame_net(dynamics_frame)
        patches_frame_out = self.bitmap_frame_net(patches_frame)

        combined = torch.cat([patch_centres_out, speaker_info_out, video_frame_out, dynamics_frame_out, patches_frame_out], dim=1)
        state_value = self.combined_net(combined)

        return state_value

We'll need a further network for GAIL: the discriminator network.

In [11]:
class DiscriminatorNet(nn.Module):
    def __init__(self, observation_space, action_space):
        super().__init__()

        #! everything in this notebook runs on CPU, so if you want to change the device here, remember to do that for EVERY tensor and EVERY network
        self.device = "cpu"

        self.rgb_frame_net = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=5, stride=2, padding=2),  # [3, 180, 320]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # [16, 45, 80]
            nn.Conv2d(16, 32, kernel_size=5, stride=2, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # [32, 11, 20]
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),  # [64, 6, 10]
            nn.Flatten(),
            nn.Linear(64 * 6 * 10, 64),
            nn.ReLU(inplace=True)
        )

        self.bitmap_frame_net = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=2, padding=2),  # [1, 180, 320]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # [16, 45, 80]
            nn.Conv2d(16, 32, kernel_size=5, stride=2, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # [32, 11, 20]
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),  # [64, 6, 10]
            nn.Flatten(),
            nn.Linear(64 * 6 * 10, 64),
            nn.ReLU(inplace=True)
        )

        # network for patch_centres
        self.patch_centres_net = nn.Sequential(
            nn.Linear(np.prod(observation_space['patch_centres'].shape), 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 64),
            nn.ReLU(inplace=True)
        )

        # network for speaker_info
        self.speaker_info_net = nn.Sequential(
            nn.Linear(np.prod(observation_space['speaker_info'].shape), 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 32),
            nn.ReLU(inplace=True)
        )
        
        # combining the outputs of all networks
        self.action_net = nn.Sequential(
            nn.Linear(np.prod(action_space), 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 16)
        )

        # combining the outputs of all networks
        self.combined_net = nn.Sequential(
            nn.Linear(64 + 32 + 64 + 64 + 64 + 16, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 1, bias=False)
        )

    def forward(self, obs, oh_action, state=None, info={}):
        patch_centres = torch.tensor(obs['patch_centres'], dtype=torch.float32)
        speaker_info = torch.tensor(obs['speaker_info'], dtype=torch.float32)
        
        patch_centres = patch_centres.view(patch_centres.size(0), -1)
        speaker_info = speaker_info.view(speaker_info.size(0), -1)

        video_frame = torch.tensor(obs['video_frames'], dtype=torch.float32).permute(0, 3, 1, 2)
        dynamics_frame = torch.tensor(obs['dynamics_frames'], dtype=torch.float32).permute(0, 3, 1, 2)
        patches_frame = torch.tensor(obs['patches_frames'], dtype=torch.float32).unsqueeze(1)

        patch_centres_out = self.patch_centres_net(patch_centres)
        speaker_info_out = self.speaker_info_net(speaker_info)

        video_frame_out = self.rgb_frame_net(video_frame)
        dynamics_frame_out = self.rgb_frame_net(dynamics_frame)
        patches_frame_out = self.bitmap_frame_net(patches_frame)

        action_out = self.action_net(oh_action)

        combined = torch.cat([patch_centres_out, speaker_info_out, video_frame_out, dynamics_frame_out, patches_frame_out, action_out], dim=1)
        combined_out = self.combined_net(combined)

        return combined_out

In [12]:
state_shape = markov_env.observation_space
action_shape = markov_env.action_space.n

actor_net = ActorNet(state_shape, action_shape)
critic_net = CriticNet(state_shape)
discriminator_net = DiscriminatorNet(state_shape, action_shape)

# using a single optimizer for actor and critic simplifies the training loop and is more computationally efficient (plus it guarantees some consistency across networks)
# BUT gradient updates in one network will influence the gradient updates in the other, and this might create unexpected problems...
combined_params = list(actor_net.parameters()) + list(critic_net.parameters())
optimizer = torch.optim.Adam(combined_params, lr=3e-4)

disc_params = list(discriminator_net.parameters())
disc_optimizer = torch.optim.Adam(disc_params, lr=1e-3)

In [13]:
actor_net

ActorNet(
  (rgb_frame_net): Sequential(
    (0): Conv2d(3, 16, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(16, 32, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Flatten(start_dim=1, end_dim=-1)
    (9): Linear(in_features=3840, out_features=64, bias=True)
    (10): ReLU(inplace=True)
  )
  (bitmap_frame_net): Sequential(
    (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(16, 32, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kern

In [14]:
critic_net

CriticNet(
  (rgb_frame_net): Sequential(
    (0): Conv2d(3, 16, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(16, 32, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Flatten(start_dim=1, end_dim=-1)
    (9): Linear(in_features=3840, out_features=64, bias=True)
    (10): ReLU(inplace=True)
  )
  (bitmap_frame_net): Sequential(
    (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(16, 32, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(ker

In [15]:
discriminator_net

DiscriminatorNet(
  (rgb_frame_net): Sequential(
    (0): Conv2d(3, 16, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(16, 32, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Flatten(start_dim=1, end_dim=-1)
    (9): Linear(in_features=3840, out_features=64, bias=True)
    (10): ReLU(inplace=True)
  )
  (bitmap_frame_net): Sequential(
    (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(16, 32, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPoo

## Setting up GAIL

In [16]:
def expert_policy(info):
    next_frame_idx = info["current_frame_idx"] + 1
    attended_centre = foa_centres_per_frame[next_frame_idx]

    # always picks the correct patch
    return np.array(patch_centres_per_frame[next_frame_idx].index(attended_centre))

We'll need a buffer of expert data, which is meant to be imitated by our agent.

In [17]:
buffer_size = 1000 
expert_buffer = ts.data.ReplayBuffer(size=buffer_size)

obs, info = markov_env.reset()
for _ in range(buffer_size):
    action = expert_policy(info) 

    next_obs, _, terminated, truncated, next_info = markov_env.step(action)
    
    # the reward is not used by GAIL, but is necessary, as per Tianshou API
    expert_buffer.add(ts.data.Batch(obs=obs, act=action, rew=0.0, obs_next=next_obs, terminated=terminated, truncated=truncated))
    
    if terminated:
        # reset the environment if we terminate before filling the buffer
        obs, info = markov_env.reset()
    else:
        obs, info = next_obs, next_info

In [18]:
def dist_fn(logits: torch.Tensor):
    return torch.distributions.Categorical(logits=logits)

policy = GAILPrimePolicy(
    actor_net, 
    critic_net, 
    optimizer,
    dist_fn,
    expert_buffer,
    discriminator_net,
    disc_optimizer,
)

In [19]:
train_collector = ts.data.Collector(policy, train_envs)

test_collector = ts.data.Collector(policy, test_envs)

In [23]:
num_epochs = 10
num_steps_per_epoch = 500
step_per_collect = 5
episode_per_test = 3
batch_size = 10

timestamp = datetime.now().strftime("%d%m%Y-%H%M%S")
log_path = os.path.join("logs", "gail", "frames", f"video_{vid_filename}", f"subject_{target_subject}", timestamp)
writer = SummaryWriter(log_path)
logger = TensorboardLogger(writer)

In [24]:
result = ts.trainer.onpolicy_trainer(
    policy, 
    train_collector, 
    test_collector,
    repeat_per_collect=1,
    max_epoch=num_epochs,
    step_per_epoch=num_steps_per_epoch,
    step_per_collect=step_per_collect,
    episode_per_test=episode_per_test,
    batch_size=batch_size,
    logger=logger,
)

Epoch #1: 501it [00:35, 14.30it/s, env_step=500, len=100, loss=5.013, loss/clip=-0.000, loss/disc=1.364, loss/ent=1.252, loss/vf=10.052, n/ep=0, n/st=5, rew=29.07, stats/acc_exp=0.680, stats/acc_pi=0.470]                         


Epoch #1: test_reward: 101.768463 ± 75.373435, best_reward: 117.034717 ± 83.713906 in #0


Epoch #2: 501it [00:33, 14.89it/s, env_step=1000, len=100, loss=4.777, loss/clip=0.000, loss/disc=1.292, loss/ent=1.138, loss/vf=9.577, n/ep=0, n/st=5, rew=33.22, stats/acc_exp=0.630, stats/acc_pi=0.690]                          


Epoch #2: test_reward: 94.542585 ± 78.930204, best_reward: 117.034717 ± 83.713906 in #0


Epoch #3: 501it [00:33, 14.98it/s, env_step=1500, len=100, loss=4.465, loss/clip=0.000, loss/disc=1.351, loss/ent=1.067, loss/vf=8.952, n/ep=0, n/st=5, rew=44.56, stats/acc_exp=0.700, stats/acc_pi=0.440]                          


Epoch #3: test_reward: 91.719925 ± 69.954405, best_reward: 117.034717 ± 83.713906 in #0


Epoch #4: 501it [00:33, 14.94it/s, env_step=2000, len=100, loss=4.105, loss/clip=0.000, loss/disc=1.332, loss/ent=1.219, loss/vf=8.235, n/ep=0, n/st=5, rew=29.38, stats/acc_exp=0.680, stats/acc_pi=0.510]                          


Epoch #4: test_reward: 86.856749 ± 54.105594, best_reward: 117.034717 ± 83.713906 in #0


Epoch #5: 501it [00:34, 14.43it/s, env_step=2500, len=100, loss=4.399, loss/clip=0.000, loss/disc=1.312, loss/ent=1.130, loss/vf=8.821, n/ep=0, n/st=5, rew=35.90, stats/acc_exp=0.710, stats/acc_pi=0.500]                         


Epoch #5: test_reward: 109.905463 ± 90.921338, best_reward: 117.034717 ± 83.713906 in #0


Epoch #6: 501it [00:38, 13.07it/s, env_step=3000, len=100, loss=4.051, loss/clip=0.000, loss/disc=1.306, loss/ent=1.103, loss/vf=8.124, n/ep=0, n/st=5, rew=35.64, stats/acc_exp=0.680, stats/acc_pi=0.610]                          


Epoch #6: test_reward: 114.317953 ± 95.193676, best_reward: 117.034717 ± 83.713906 in #0


Epoch #7: 501it [00:34, 14.68it/s, env_step=3500, len=100, loss=3.970, loss/clip=0.000, loss/disc=1.332, loss/ent=0.978, loss/vf=7.960, n/ep=0, n/st=5, rew=35.34, stats/acc_exp=0.350, stats/acc_pi=0.720]                          


Epoch #7: test_reward: 113.634080 ± 100.634117, best_reward: 117.034717 ± 83.713906 in #0


Epoch #8: 501it [00:35, 14.05it/s, env_step=4000, len=100, loss=3.823, loss/clip=0.000, loss/disc=1.375, loss/ent=0.874, loss/vf=7.664, n/ep=0, n/st=5, rew=43.80, stats/acc_exp=0.310, stats/acc_pi=0.760]                         


Epoch #8: test_reward: 127.086236 ± 111.766529, best_reward: 127.086236 ± 111.766529 in #8


Epoch #9: 501it [00:35, 14.29it/s, env_step=4500, len=100, loss=3.915, loss/clip=0.000, loss/disc=1.386, loss/ent=0.879, loss/vf=7.847, n/ep=0, n/st=5, rew=36.48, stats/acc_exp=0.390, stats/acc_pi=0.690]                         


Epoch #9: test_reward: 116.353462 ± 102.959098, best_reward: 127.086236 ± 111.766529 in #8


Epoch #10: 501it [00:33, 15.03it/s, env_step=5000, len=100, loss=4.067, loss/clip=-0.000, loss/disc=1.294, loss/ent=0.995, loss/vf=8.153, n/ep=0, n/st=5, rew=32.77, stats/acc_exp=0.730, stats/acc_pi=0.480]                         


Epoch #10: test_reward: 104.232144 ± 87.981237, best_reward: 127.086236 ± 111.766529 in #8


In [25]:
result

{'duration': '397.15s',
 'train_time/model': '323.99s',
 'test_step': 8701,
 'test_episode': 33,
 'test_time': '50.21s',
 'test_speed': '173.30 step/s',
 'best_reward': 127.08623618715016,
 'best_result': '127.09 ± 111.77',
 'train_step': 5000,
 'train_episode': 10,
 'train_time/collector': '22.96s',
 'train_speed': '14.41 step/s'}

## Testing

In [26]:
test_markov_env = FramesTestEnvironment(
    1,
    video_frames,
    dynamics_frames,
    patches_frames,
    patch_bounding_boxes_per_frame,
    patch_centres_per_frame,
    speaker_info_per_frame,
    foa_centres_per_frame,
    patch_weights_per_frame,
)

{'n/ep': 10,
 'n/st': 3162,
 'rews': array([ 22.22135627,  28.51254907,  25.16103925,  27.08127978,
         30.67060244,  23.82440383, 203.39745404, 218.22130575,
        255.43788278, 283.68417918]),
 'lens': array([ 79,  79,  79,  79,  79,  79, 573, 573, 771, 771]),
 'idxs': array([4, 4, 4, 4, 4, 4, 2, 3, 0, 1]),
 'rew': 111.82120523819722,
 'len': 316.2,
 'rew_std': 106.70425692970272,
 'len_std': 297.1803492830574}

In [None]:
num_test_envs = 1

testing_envs = ts.env.DummyVectorEnv([lambda: test_markov_env for _ in range(num_test_envs)])

In [None]:
policy.eval()
policy.set_eps(0.05)

collector = ts.data.Collector(policy, testing_envs)
# should be the same values 3 times (if not, there's a problem)
collector.collect(n_episode=3)