In [1]:
# --------------------------------------------------------
# Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
# NVIDIA Source Code License (1-Way Commercial)
# Code written by Seonwook Park, Shalini De Mello, Yufeng Zheng.
# --------------------------------------------------------
import numpy as np
from collections import OrderedDict
import gc
import json
import os
import cv2
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
import logging
import losses
from tqdm import tqdm

from dataset import HDFDataset
from utils import save_images, worker_init_fn, send_data_dict_to_gpu, recover_images, def_test_list, RunningStatistics,\
    adjust_learning_rate, script_init_common, get_example_images, save_model, load_model
from core import DefaultConfig
from models import STED
logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)


config = DefaultConfig()
script_init_common()

# Set device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
import warnings
warnings.filterwarnings('ignore')

if not config.skip_training:
    if config.semi_supervised: # Use semi-supervised.
        assert config.num_labeled_samples != 0
    if not os.path.exists(config.save_path):
        os.makedirs(config.save_path)
    # save configurations
    config.write_file_contents(config.save_path)

# Create the train and test datasets.
all_data = OrderedDict()

# Read GazeCapture train/val/test split
with open('./gazecapture_split.json', 'r') as f:
    all_gc_prefixes = json.load(f)

# [gc/val] full set size:              63518
# [gc/val] current set size:           1 270
# [gc/test] full set size:            191842
# [gc/test] current set size:           3836
# [mpi] full set size:                 34790
# [mpi] current set size:                695
# [gc/train] full set size:          1379083
# [gc/train] current set size:       1379083

2023-08-07 21:07:31,460 Written .../ST-ED/configs/combined.json
2023-08-07 21:07:31,461 Written .../ST-ED/configs/config_default.py
2023-08-07 21:11:01,138 Written source folder to .../ST-ED/src


In [3]:
"""
This part is to create the dataset.
"""

if not config.skip_training:
    # Define single training dataset
    train_prefixes = all_gc_prefixes['train']
    train_dataset = HDFDataset(hdf_file_path=config.gazecapture_file,
                               prefixes=train_prefixes,
                               is_bgr=False,
                               get_2nd_sample=True,
                               num_labeled_samples=config.num_labeled_samples if config.semi_supervised else None)
    
    # Define multiple val/test datasets for evaluation during training
    for tag, hdf_file, is_bgr, prefixes in [
        ('gc/val', config.gazecapture_file, False, all_gc_prefixes['val']),
        ('gc/test', config.gazecapture_file, False, all_gc_prefixes['test']),
        ('mpi', config.mpiigaze_file, False, None),
    ]:
        # Create evaluation dataset.
        dataset = HDFDataset(hdf_file_path=hdf_file,
                             prefixes=prefixes,
                             is_bgr=is_bgr,
                             get_2nd_sample=True,
                             pick_at_least_per_person=2)
        if tag == 'gc/test':
            # test pair visualization:
            test_list = def_test_list()
            test_visualize = get_example_images(dataset, test_list)
            test_visualize = send_data_dict_to_gpu(test_visualize, device)

        subsample = config.test_subsample
        # subsample test sets if requested
        if subsample < (1.0 - 1e-6):
            dataset = Subset(dataset, np.linspace(
                start=0, stop=len(dataset),
                num=int(subsample * len(dataset)),
                endpoint=False,
                dtype=np.uint32,
            ))

        all_data[tag] = {
            'dataset': dataset,
            'dataloader': DataLoader(dataset,
                                     batch_size=config.eval_batch_size,
                                     shuffle=False,
                                     num_workers=config.num_data_loaders,  # args.num_data_loaders,
                                     pin_memory=True,
                                     ),
        }

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=int(config.batch_size),
                                  shuffle=True,
                                  drop_last=True,
                                  num_workers=config.num_data_loaders,
                                  pin_memory=True,
                                  )
    all_data['gc/train'] = {'dataset': train_dataset, 'dataloader': train_dataloader}

    # Print some stats.
    logging.info('')
    for tag, val in all_data.items():
        tag = '[%s]' % tag
        dataset = val['dataset']
        original_dataset = dataset.dataset if isinstance(dataset, Subset) else dataset
        num_people = len(original_dataset.prefixes)
        num_original_entries = len(original_dataset)
        logging.info('%10s full set size:           %7d' % (tag, num_original_entries))
        logging.info('%10s current set size:        %7d' % (tag, len(dataset)))
        logging.info('')

    # Have dataloader re-open HDF to avoid multi-processing related errors.
    for tag, data_dict in all_data.items():
        dataset = data_dict['dataset']
        original_dataset = dataset.dataset if isinstance(dataset, Subset) else dataset
        original_dataset.close_hdf()

# train_dataset.__getitem__(0).keys()
# dict_keys(['key', 'image_a', 'gaze_a', 'head_a', 'image_b', 'gaze_b', 'head_b'])

2023-08-07 21:11:10,075 
2023-08-07 21:11:10,076   [gc/val] full set size:             63518
2023-08-07 21:11:10,077   [gc/val] current set size:           1270
2023-08-07 21:11:10,077 
2023-08-07 21:11:10,077  [gc/test] full set size:            191842
2023-08-07 21:11:10,078  [gc/test] current set size:           3836
2023-08-07 21:11:10,078 
2023-08-07 21:11:10,078      [mpi] full set size:             34790
2023-08-07 21:11:10,079      [mpi] current set size:            695
2023-08-07 21:11:10,079 
2023-08-07 21:11:10,080 [gc/train] full set size:           1379083
2023-08-07 21:11:10,080 [gc/train] current set size:        1379083
2023-08-07 21:11:10,082 


In [4]:
"""
This part is to create the network.
"""

# Create redirection network
network = STED().to(device)
# Load weights if available
from checkpoints_manager import CheckpointsManager

saver = CheckpointsManager(network.GazeHeadNet_eval, config.eval_gazenet_savepath)
_ = saver.load_last_checkpoint()
del saver

saver = CheckpointsManager(network.GazeHeadNet_train, config.gazenet_savepath)
_ = saver.load_last_checkpoint()
del saver

if config.load_step != 0:
    model_path = "/home/ethentsao/Desktop/STED-gaze/Our fully-supervised gaze redirector model.pt"
    # model_path = os.path.join(config.save_path, "checkpoints", str(config.load_step) + '.pt')
    print("Load model from", model_path)
    load_model(network, model_path)
    logging.info("Loaded checkpoints from step " + str(config.load_step))

# Transfer on the GPU before constructing and optimizer
if torch.cuda.device_count() > 1:
    logging.info('Using %d GPUs!' % torch.cuda.device_count())
    network.encoder = nn.DataParallel(network.encoder)
    network.decoder = nn.DataParallel(network.decoder)
    network.redirtrans_p = nn.DataParallel(network.redirtrans_p)
    network.redirtrans_dp = nn.DataParallel(network.redirtrans_dp)
    network.fusion = nn.DataParallel(network.fusion)
    network.discriminator = nn.DataParallel(network.discriminator)
    network.GazeHeadNet_eval = nn.DataParallel(network.GazeHeadNet_eval)
    network.GazeHeadNet_train = nn.DataParallel(network.GazeHeadNet_train)
    network.lpips = nn.DataParallel(network.lpips)
    network.pretrained_arcface = nn.DataParallel(network.pretrained_arcface)


Loading e4e over the pSp framework from checkpoint: pretrained_models/e4e_ffhq_encode.pt
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /home/ethentsao/Desktop/Ours/ours/lib/python3.8/site-packages/lpips/weights/v0.1/alex.pth


2023-08-07 21:11:14,763 Using 2 GPUs!


Finish model initialization.


In [5]:
"""
This part is to prepare for the training step.
"""

def execute_training_step(current_step):
    global train_data_iterator
    try:
        input = next(train_data_iterator)
    except StopIteration:
        np.random.seed()  # Ensure randomness
        # Some cleanup
        train_data_iterator = None
        torch.cuda.empty_cache()
        gc.collect() # 显式地触发垃圾回收。当你想要立即释放内存或者优化内存使用时，这个函数很有用。
        # Restart!
        global train_dataloader
        train_data_iterator = iter(train_dataloader)
        input = next(train_data_iterator)
    input = send_data_dict_to_gpu(input, device)

    network.train()
    network.encoder.eval()
    network.decoder.eval()
    network.GazeHeadNet_eval.eval()
    network.GazeHeadNet_train.eval()
    network.lpips.eval()
    network.pretrained_arcface.eval()

    # forward + backward + optimize
    loss_dict, generated = network.optimize(input, current_step)

    # save training samples in tensorboard
    if config.use_tensorboard and current_step % config.save_freq_images == 0 and current_step != 0:
        for image_index in range(5):
            tensorboard.add_image('train/input_image',
                                  torch.clamp((input['image_a'][image_index] + 1) * (255.0 / 2.0), 0, 255).type(
                                      torch.cuda.ByteTensor), current_step)
            tensorboard.add_image('train/target_image',
                                  torch.clamp((input['image_b'][image_index] + 1) * (255.0 / 2.0), 0, 255).type(
                                      torch.cuda.ByteTensor), current_step)
            tensorboard.add_image('train/generated_image',
                                  torch.clamp((generated[image_index] + 1) * (255.0 / 2.0), 0, 255).type(
                                      torch.cuda.ByteTensor), current_step)
    # If doing multi-GPU training, just take an average
    for key, value in loss_dict.items():
        if value.dim() > 0:
            value = torch.mean(value)
            loss_dict[key] = value
    # Store values for logging later
    for key, value in loss_dict.items():
        loss_dict[key] = value.detach().cpu()
    for key, value in loss_dict.items():
        running_losses.add(key, value.numpy())

"""
This part is to prepare for the testing step.
"""

def execute_test(tag, data_dict):
    test_losses = RunningStatistics()
    with torch.no_grad():
        network.eval()
        for input_dict in tqdm(data_dict['dataloader']):
            input_dict = send_data_dict_to_gpu(input_dict, device)
            output_dict, loss_dict = network(input_dict)
            for key, value in loss_dict.items():
                test_losses.add(key, value.detach().cpu().numpy())
    test_loss_means = test_losses.means()
    logging.info('Test Losses at [%7d] for %10s: %s' %
                 (current_step, '[' + tag + ']',
                  ', '.join(['%s: %.6f' % v for v in test_loss_means.items()])))
    if config.use_tensorboard:
        for k, v in test_loss_means.items():
            tensorboard.add_scalar('test/%s/%s' % (tag, k), v, current_step)

"""
This part is to prepare for the visualization step.
"""

def execute_visualize(data):
    output_dict, losses_dict = network(test_visualize)
    keys = data['key'].cpu().numpy()
    for i in tqdm(range(len(keys))):
        path = os.path.join(config.save_path, 'samples', str(keys[i]))
        if not os.path.exists(path):
            os.makedirs(path)
        cv2.imwrite(os.path.join(path, 'redirect_' + str(current_step) + '.png'),
                    recover_images(output_dict['image_b_hat'][i]))
        cv2.imwrite(os.path.join(path, 'redirect_all_' + str(current_step) + '.png'),
                    recover_images(output_dict['image_b_hat_all'][i]))
    # walks = network.latent_walk(test_visualize)
    # save_images(os.path.join(config.save_path, 'samples'), walks, keys, cycle=True)


if config.use_tensorboard and ((not config.skip_training) or config.compute_full_result):
    from tensorboardX import SummaryWriter
    if not os.path.exists(config.log_path):
        os.mkdir(config.log_path)
        print(f"Make {config.log_path}")

    tensorboard = SummaryWriter(logdir=config.log_path)
current_step = config.load_step

In [6]:

"""
This part is to start thr training step.
"""

if not config.skip_training:
    logging.info('Training')
    running_losses = RunningStatistics()
    train_data_iterator = iter(train_dataloader)
    # main training loop
    for current_step in tqdm(range(config.load_step, config.num_training_steps)): 
        # Save model
        if current_step % config.save_interval == 0 and current_step != config.load_step:
            save_model(network, current_step)
            print("Finish save model.")

        # lr decay
        if (current_step % config.decay_steps == 0) or current_step == config.load_step:
            lr = adjust_learning_rate(network.optimizers, config.decay, int(current_step /config.decay_steps), config.lr)
            if config.use_tensorboard:
                tensorboard.add_scalar('train/lr', lr, current_step)
            print("Finish lr decay.")

        # Testing loop: every specified iterations compute the test statistics
        if current_step % config.print_freq_test == 0 and current_step != 0:
            network.eval()
            network.clean_up()
            torch.cuda.empty_cache()
            for tag, data_dict in list(all_data.items())[:-1]:
                execute_test(tag, data_dict)
                # This might help with memory leaks
                torch.cuda.empty_cache()
            print("Finish test model.")

        # Visualization loop
        if (current_step != 0 and current_step % config.save_freq_images == 0) or current_step == config.num_training_steps - 1:
            network.eval()
            torch.cuda.empty_cache()
            with torch.no_grad():
                # save redirected, style modified samples
                execute_visualize(test_visualize)
            torch.cuda.empty_cache()
            print("Finish visualization.")

        # Training step
        execute_training_step(current_step)

        # Print training loss
        if current_step != 0 and (current_step % config.print_freq_train == 0):
            running_loss_means = running_losses.means()
            logging.info('Losses at [%7d]: %s' %
                         (current_step,
                          ', '.join(['%s: %.5f' % v
                                     for v in running_loss_means.items()])))
            if config.use_tensorboard:
                for k, v in running_loss_means.items():
                    tensorboard.add_scalar('train/' + k, v, current_step)
            running_losses.reset

    logging.info('Finished Training')
    # Save model parameters
    save_model(network, config.num_training_steps) # Save final model.
    del all_data


2023-08-07 21:11:14,798 Training
  0%|          | 0/206865 [00:00<?, ?it/s]

Finish lr decay.


  0%|          | 200/206865 [01:12<19:58:41,  2.87it/s]2023-08-07 21:12:28,344 Losses at [    200]: l1: 1.65070, perceptual: 0.85853, redirection_feature_loss: 18.62209, gaze_redirection: 17.95755, head_redirection: 19.62083, id: 0.96279, gaze_a: 93.59990, head_a: 108.67809, embedding_consistency: 0.01612
  0%|          | 400/206865 [02:24<20:23:48,  2.81it/s]2023-08-07 21:13:39,816 Losses at [    400]: l1: 1.68195, perceptual: 0.86340, redirection_feature_loss: 17.95756, gaze_redirection: 17.42968, head_redirection: 19.27347, id: 0.95396, gaze_a: 93.63182, head_a: 108.85184, embedding_consistency: 0.01248
  0%|          | 600/206865 [03:35<19:48:43,  2.89it/s]2023-08-07 21:14:50,930 Losses at [    600]: l1: 1.68273, perceptual: 0.86430, redirection_feature_loss: 17.63455, gaze_redirection: 17.95383, head_redirection: 18.55249, id: 0.95384, gaze_a: 93.63341, head_a: 108.74879, embedding_consistency: 0.01041
  0%|          | 800/206865 [04:46<19:55:09,  2.87it/s]2023-08-07 21:16:02,291 

Finish test model.


100%|██████████| 5/5 [00:00<00:00, 131.45it/s]


Finish visualization.


IndexError: index 2 is out of bounds for dimension 0 with size 2

In [None]:

"""
This part is to evaluate.
"""
    
# Compute evaluation results on complete test sets
if config.compute_full_result:
    logging.info('Computing complete test results for final model...')
    all_data = OrderedDict()
    for tag, hdf_file, is_bgr, prefixes in [
        ('gc/val', config.gazecapture_file, False, all_gc_prefixes['val']),
        ('gc/test', config.gazecapture_file, False, all_gc_prefixes['test']),
        ('mpi', config.mpiigaze_file, False, None),
    ]:
        # Define dataset structure based on selected prefixes
        dataset = HDFDataset(hdf_file_path=hdf_file,
                             prefixes=prefixes,
                             is_bgr=is_bgr,
                             get_2nd_sample=True,
                             pick_at_least_per_person=2)
        if tag == 'gc/test':
            # test pair visualization:
            test_list = def_test_list()
            test_visualize = get_example_images(dataset, test_list)
            test_visualize = send_data_dict_to_gpu(test_visualize, device)
            with torch.no_grad():
                # save redirected, style modified samples
                execute_visualize(test_visualize)
        all_data[tag] = {
            'dataset': dataset,
            'dataloader': DataLoader(dataset,
                                     batch_size=config.eval_batch_size,
                                     shuffle=False,
                                     num_workers=config.num_data_loaders,
                                     pin_memory=True,
                                     worker_init_fn=worker_init_fn),
        }
    logging.info('')

    for tag, val in all_data.items():
        tag = '[%s]' % tag
        dataset = val['dataset']
        original_dataset = dataset.dataset if isinstance(dataset, Subset) else dataset
        num_entries = len(original_dataset)
        num_people = len(original_dataset.prefixes)
        logging.info('%10s set size:                %7d' % (tag, num_entries))
        logging.info('%10s num people:              %7d' % (tag, num_people))
        logging.info('')

    for tag, data_dict in all_data.items():
        dataset = data_dict['dataset']
        # Have dataloader re-open HDF to avoid multi-processing related errors.
        original_dataset = dataset.dataset if isinstance(dataset, Subset) else dataset
        original_dataset.close_hdf()

    network.eval()
    torch.cuda.empty_cache()
    for tag, data_dict in list(all_data.items()):
        execute_test(tag, data_dict)
    if config.use_tensorboard:
        tensorboard.close()
        del tensorboard
    # network.clean_up()
    torch.cuda.empty_cache()


2023-08-07 21:04:14,429 Computing complete test results for final model...


Start execute_visualize
test_visualize['image_b'].shape torch.Size([5, 3, 128, 128])


100%|██████████| 5/5 [00:00<00:00, 129.22it/s]
2023-08-07 21:04:15,905 
2023-08-07 21:04:15,905   [gc/val] set size:                  63518
2023-08-07 21:04:15,906   [gc/val] num people:                   48
2023-08-07 21:04:15,906 
2023-08-07 21:04:15,907  [gc/test] set size:                 191842
2023-08-07 21:04:15,907  [gc/test] num people:                  139
2023-08-07 21:04:15,908 
2023-08-07 21:04:15,908      [mpi] set size:                  34790
2023-08-07 21:04:15,908      [mpi] num people:                   15
2023-08-07 21:04:15,909 
  0%|          | 48/31759 [00:47<8:37:02,  1.02it/s]

KeyboardInterrupt: 

In [None]:

# """
# This part is to evaluate.
# """

# # Use Redirector to create new training data
# if config.store_redirect_dataset:
#     train_tag = 'gc/train'
#     train_prefixes = all_gc_prefixes['train']
#     train_dataset = HDFDataset(hdf_file_path=config.gazecapture_file,
#                                prefixes=train_prefixes,
#                                num_labeled_samples=config.num_labeled_samples,
#                                sample_target_label=True
#                                )
#     train_dataset.close_hdf()
#     train_dataloader = DataLoader(train_dataset,
#                                   batch_size=config.eval_batch_size,
#                                   shuffle=False,
#                                   num_workers=config.num_data_loaders,
#                                   pin_memory=True,
#                                   )
#     current_person_id = None
#     current_person_data = {}
#     ofpath = os.path.join(config.save_path, 'Redirected_samples.h5')
#     ofdir = os.path.dirname(ofpath)
#     if not os.path.isdir(ofdir):
#         os.makedirs(ofdir)
#     import h5py

#     h5f = h5py.File(ofpath, 'w')

#     def store_person_predictions():
#         global current_person_data
#         if len(current_person_data) > 0:
#             g = h5f.create_group(current_person_id)
#             for key, data in current_person_data.items():
#                 g.create_dataset(key, data=data, chunks=tuple([1] + list(np.asarray(data).shape[1:])),
#                                  compression='lzf', dtype=
#                                  np.float32)
#         current_person_data = {}

#     with torch.no_grad():
#         np.random.seed()
#         num_batches = int(np.ceil(len(train_dataset) / config.eval_batch_size))
#         for i, input_dict in enumerate(train_dataloader):
#             batch_size = input_dict['image_a'].shape[0]
#             input_dict = send_data_dict_to_gpu(input_dict, device)
#             output_dict = network.redirect(input_dict)
#             zipped_data = zip(
#                 input_dict['key'],
#                 input_dict['image_a'].cpu().numpy().astype(np.float32),
#                 input_dict['gaze_a'].cpu().numpy().astype(np.float32),
#                 input_dict['head_a'].cpu().numpy().astype(np.float32),
#                 output_dict['image_b_hat_r'].cpu().numpy().astype(np.float32),
#                 input_dict['gaze_b_r'].cpu().numpy().astype(np.float32),
#                 input_dict['head_b_r'].cpu().numpy().astype(np.float32)
#             )

#             for (person_id, image_a, gaze_a, head_a, image_b_r, gaze_b_r, head_b_r) in zipped_data:
#                 # Store predictions if moved on to next person
#                 if person_id != current_person_id:
#                     store_person_predictions()
#                     current_person_id = person_id
#                 # Now write it
#                 to_write = {
#                     'real': True,
#                     'gaze': gaze_a,
#                     'head': head_a,
#                     'image': image_a,
#                 }
#                 for k, v in to_write.items():
#                     if k not in current_person_data:
#                         current_person_data[k] = []
#                     current_person_data[k].append(v)

#                 to_write = {
#                     'real': False,
#                     'gaze': gaze_b_r,
#                     'head': head_b_r,
#                     'image': image_b_r,
#                 }
#                 for k, v in to_write.items():
#                     current_person_data[k].append(v)

#             logging.info('processed batch [%04d/%04d] with %d entries.' %
#                          (i + 1, num_batches, len(next(iter(input_dict.values())))))
#         store_person_predictions()
#     logging.info('Completed processing')
#     logging.info('Done')
#     del train_dataset, train_dataloader

In [None]:
"""
This part is to evaluate.
"""

# Use Redirector to create new training data
if config.store_redirect_dataset:
    train_tag = 'gc/train'
    train_prefixes = all_gc_prefixes['train']
    train_dataset = HDFDataset(hdf_file_path=config.gazecapture_file,
                               prefixes=train_prefixes,
                               num_labeled_samples=config.num_labeled_samples,
                               sample_target_label=True
                               )
    train_dataset.close_hdf()
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=config.eval_batch_size,
                                  shuffle=False,
                                  num_workers=config.num_data_loaders,
                                  pin_memory=True,
                                  )
    current_person_id = None
    current_person_data = {}
    ofpath = os.path.join(config.save_path, 'Redirected_samples.h5')
    ofdir = os.path.dirname(ofpath)
    if not os.path.isdir(ofdir):
        os.makedirs(ofdir)
    import h5py

    h5f = h5py.File(ofpath, 'w')

    def store_person_predictions():
        global current_person_data
        if len(current_person_data) > 0:
            g = h5f.create_group(current_person_id)
            for key, data in current_person_data.items():
                g.create_dataset(key, data=data, chunks=tuple([1] + list(np.asarray(data).shape[1:])),
                                 compression='lzf', dtype=
                                 np.float32)
        current_person_data = {}

    with torch.no_grad():
        np.random.seed()
        num_batches = int(np.ceil(len(train_dataset) / config.eval_batch_size))
        for i, input_dict in enumerate(train_dataloader):
            batch_size = input_dict['image_a'].shape[0]
            input_dict = send_data_dict_to_gpu(input_dict, device)
            output_dict = network.redirect(input_dict)
            zipped_data = zip(
                input_dict['key'],
                input_dict['image_a'].cpu().numpy().astype(np.float32),
                input_dict['gaze_a'].cpu().numpy().astype(np.float32),
                input_dict['head_a'].cpu().numpy().astype(np.float32),
                output_dict['image_b_hat_r'].cpu().numpy().astype(np.float32),
                input_dict['gaze_b_r'].cpu().numpy().astype(np.float32),
                input_dict['head_b_r'].cpu().numpy().astype(np.float32)
            )

            for (person_id, image_a, gaze_a, head_a, image_b_r, gaze_b_r, head_b_r) in zipped_data:
                # Store predictions if moved on to next person
                if person_id != current_person_id:
                    store_person_predictions()
                    current_person_id = person_id
                # Now write it
                to_write = {
                    'real': True,
                    'gaze': gaze_a,
                    'head': head_a,
                    'image': image_a,
                }
                for k, v in to_write.items():
                    if k not in current_person_data:
                        current_person_data[k] = []
                    current_person_data[k].append(v)

                to_write = {
                    'real': False,
                    'gaze': gaze_b_r,
                    'head': head_b_r,
                    'image': image_b_r,
                }
                for k, v in to_write.items():
                    current_person_data[k].append(v)

            logging.info('processed batch [%04d/%04d] with %d entries.' %
                         (i + 1, num_batches, len(next(iter(input_dict.values())))))
        store_person_predictions()
    logging.info('Completed processing')
    logging.info('Done')
    del train_dataset, train_dataloader