# Imitation learning with GAIL (Generative Adversarial Imitation Learning)

In [1]:
import tianshou as ts 
from tianshou.utils import TensorboardLogger

import gymnasium as gym

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

import numpy as np

import os
from datetime import datetime

In [21]:
from utils_preprocess import compute_frame_features, compute_foa_features
from utils_data import get_frames_by_type, get_feature_frames

from env_frames_flat import FlatFramesEnvironment
from gail_prime import GAILPrimePolicy

In [3]:
vid_filename = "012"
mat_filename = vid_filename + ".mat"
target_subject = 0

In [4]:
video_frames = get_frames_by_type("frames", vid_filename)
dynamics_frames = get_frames_by_type("dynamic", vid_filename)
patches_frames, _ = get_feature_frames(vid_filename)

In [5]:
# need to normalise frames before feeding them into the model
for i in range(len(video_frames)):
    # broadcasting will do the rest...
    video_frames[i] = video_frames[i] / 255.0
    dynamics_frames[i] = dynamics_frames[i] / 255.0
    patches_frames[i] = patches_frames[i] / 255.0

In [6]:
patch_bounding_boxes_per_frame, patch_centres_per_frame, speaker_info_per_frame = compute_frame_features(
    vid_filename
)

foa_centres_per_frame_per_subject, patch_weights_per_frame = compute_foa_features(
    mat_filename, patch_centres_per_frame
)
foa_centres_per_frame = [frame[target_subject] for frame in foa_centres_per_frame_per_subject]

In [9]:
flat_markov_env = FlatFramesEnvironment(
    1,
    video_frames,
    dynamics_frames,
    patches_frames,
    patch_bounding_boxes_per_frame,
    patch_centres_per_frame,
    speaker_info_per_frame,
    foa_centres_per_frame,
    patch_weights_per_frame,
)

In [10]:
num_train_envs = 5
num_test_envs = 3

train_envs = ts.env.DummyVectorEnv([lambda: flat_markov_env for _ in range(num_train_envs)])
test_envs = ts.env.DummyVectorEnv([lambda: flat_markov_env for _ in range(num_test_envs)])

## Network creation

In [11]:
class ActorNet(nn.Module):
    def __init__(self, observation_shape, action_shape):
        super().__init__()

        # Tianshou needs this attribute for GAIL to work
        self.output_dim = np.prod(action_shape)

        self.net = nn.Sequential(
            nn.Linear(observation_shape, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, np.prod(action_shape), bias=False)
        )

    def forward(self, obs, state=None, info={}):
        if not isinstance(obs, torch.Tensor):
            obs = torch.tensor(obs, dtype=torch.float32)
        
        obs = obs.view(obs.size(0), -1)  # Ensure obs is flat

        logits = self.net(obs)

        return logits, state

In [12]:
class CriticNet(nn.Module):
    def __init__(self, observation_shape):
        super().__init__()
        
        self.net = nn.Sequential(
            nn.Linear(observation_shape, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 1, bias=False)  # Output is the value function for the critic
        )

    def forward(self, obs, state=None, info={}):
        # Directly use the flat observation array
        # Ensure obs is a torch.Tensor
        if not isinstance(obs, torch.Tensor):
            obs = torch.tensor(obs, dtype=torch.float32)
        
        obs = obs.view(obs.size(0), -1)

        state_value = self.net(obs)

        return state_value

We'll need a further network for GAIL: the discriminator network.

In [13]:
class DiscriminatorNet(nn.Module):
    def __init__(self, observation_shape, action_shape):
        super().__init__()

        action_size = np.prod(action_shape)

        self.device = torch.device("cpu")

        self.net = nn.Sequential(
            nn.Linear(observation_shape + action_size, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 1, bias=False)  # the probability of being from expert
        )

    def forward(self, obs_action, state=None, info={}):
        if not isinstance(obs_action, torch.Tensor):
            obs_action = torch.tensor(obs_action, dtype=torch.float32, device=self.device)
        
        obs_action = obs_action.view(obs_action.size(0), -1)

        prob = self.net(obs_action)

        return prob

In [14]:
flat_markov_env.reset() # this doesn't distrub anything: Tianshou resets the environment before each run
original_shapes = flat_markov_env._get_info()["shapes"]
state_shape = sum(np.prod(shape) for shape in original_shapes.values())

action_shape = 1 # a Discrete value

actor_net = ActorNet(state_shape, action_shape)
critic_net = CriticNet(state_shape)
discriminator_net = DiscriminatorNet(state_shape, action_shape)

# using a single optimizer for actor and critic simplifies the training loop and is more computationally efficient (plus it guarantees some consistency across networks)
# BUT gradient updates in one network will influence the gradient updates in the other, and this might create unexpected problems...
combined_params = list(actor_net.parameters()) + list(critic_net.parameters())
optimizer = torch.optim.Adam(combined_params, lr=3e-4)

disc_params = list(discriminator_net.parameters())
disc_optimizer = torch.optim.Adam(disc_params, lr=1e-3)

In [15]:
actor_net

ActorNet(
  (net): Sequential(
    (0): Linear(in_features=403214, out_features=128, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=128, out_features=64, bias=True)
    (3): ReLU(inplace=True)
    (4): Linear(in_features=64, out_features=1, bias=False)
  )
)

In [16]:
critic_net

CriticNet(
  (net): Sequential(
    (0): Linear(in_features=403214, out_features=128, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=128, out_features=64, bias=True)
    (3): ReLU(inplace=True)
    (4): Linear(in_features=64, out_features=1, bias=False)
  )
)

In [17]:
discriminator_net

DiscriminatorNet(
  (net): Sequential(
    (0): Linear(in_features=403215, out_features=128, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=128, out_features=64, bias=True)
    (3): ReLU(inplace=True)
    (4): Linear(in_features=64, out_features=1, bias=False)
  )
)

## Setting up GAIL

In [22]:
def expert_policy(info):
    next_frame_idx = info["current_frame_idx"] + 1
    attended_centre = foa_centres_per_frame[next_frame_idx]

    # always picks the correct patch
    return np.array(patch_centres_per_frame[next_frame_idx].index(attended_centre))

We'll need a buffer of expert data, which is meant to be imitated by our agent.

In [23]:
buffer_size = 1000 
expert_buffer = ts.data.ReplayBuffer(size=buffer_size)

obs, info = flat_markov_env.reset()
for _ in range(buffer_size):
    action = expert_policy(info) 

    next_obs, _, terminated, truncated, next_info = flat_markov_env.step(action)
    
    # the reward is not used by GAIL, but is necessary, as per Tianshou API
    expert_buffer.add(ts.data.Batch(obs=obs, act=action, rew=0.0, obs_next=next_obs, terminated=terminated, truncated=truncated))
    
    if terminated:
        # reset the environment if we terminate before filling the buffer
        obs, info = flat_markov_env.reset()
    else:
        obs, info = next_obs, next_info

In [24]:
def dist_fn(logits: torch.Tensor):
    return torch.distributions.Categorical(logits=logits)

policy = GAILPrimePolicy(
    actor_net, 
    critic_net, 
    optimizer,
    dist_fn,
    expert_buffer,
    discriminator_net,
    disc_optimizer,
)

TypeError: '>' not supported between instances of 'DiscriminatorNet' and 'float'

In [None]:
train_collector = ts.data.Collector(policy, train_envs)

test_collector = ts.data.Collector(policy, test_envs)

In [None]:
num_epochs = 10
num_steps_per_epoch = 500
step_per_collect = 5
episode_per_test = 3
batch_size = 10

timestamp = datetime.now().strftime("%d%m%Y-%H%M%S")
log_path = os.path.join("logs", "gail", "frames", f"video_{vid_filename}", f"subject_{target_subject}", timestamp)
writer = SummaryWriter(log_path)
logger = TensorboardLogger(writer)

In [None]:
result = ts.trainer.onpolicy_trainer(
    policy, 
    train_collector, 
    test_collector,
    repeat_per_collect=1,
    max_epoch=num_epochs,
    step_per_epoch=num_steps_per_epoch,
    step_per_collect=step_per_collect,
    episode_per_test=episode_per_test,
    batch_size=batch_size,
    logger=logger,
)

Epoch #1:  17%|#7        | 85/500 [00:18<01:28,  4.68it/s, env_step=85, len=0, loss=9.934, loss/clip=0.000, loss/disc=4.656, loss/ent=0.000, loss/vf=19.867, n/ep=0, n/st=5, rew=0.00, stats/acc_exp=0.553, stats/acc_pi=0.494]  


KeyboardInterrupt: 

In [None]:
result

{'duration': '190.98s',
 'train_time/model': '142.87s',
 'test_step': 4356,
 'test_episode': 22,
 'test_time': '26.18s',
 'test_speed': '166.40 step/s',
 'best_reward': 89.12511129030099,
 'best_result': '89.13 ± 1.63',
 'train_step': 5000,
 'train_episode': 10,
 'train_time/collector': '21.93s',
 'train_speed': '30.34 step/s'}

In [None]:
policy.eval()

collector = ts.data.Collector(policy, train_envs)
collector.collect(n_episode=10)

{'n/ep': 10,
 'n/st': 3162,
 'rews': array([ 34.12632742,  27.83506823,  38.06989051,  45.2986403 ,
         42.55142094,  40.5135169 , 254.14732912, 233.32957745,
        362.57293248, 292.67269146]),
 'lens': array([ 79,  79,  79,  79,  79,  79, 573, 573, 771, 771]),
 'idxs': array([4, 4, 4, 4, 4, 4, 2, 3, 0, 1]),
 'rew': 137.111739480876,
 'len': 316.2,
 'rew_std': 125.31813530663396,
 'len_std': 297.1803492830574}