In [2]:
# --------------------------------------------------------
# Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
# NVIDIA Source Code License (1-Way Commercial)
# Code written by Seonwook Park, Shalini De Mello, Yufeng Zheng.
# --------------------------------------------------------
import numpy as np
from collections import OrderedDict
import gc
import json
import os
import cv2
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
import logging
import losses
from tqdm import tqdm

from dataset import HDFDataset
from utils import save_images, worker_init_fn, send_data_dict_to_gpu, recover_images, def_test_list, RunningStatistics,\
    adjust_learning_rate, script_init_common, get_example_images, save_model, load_model
from core import DefaultConfig
from models import STED
logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)


# Build the experiment configuration first, then run the shared script initialisation
# from utils (kept in this order, as in the original setup).
config = DefaultConfig()
script_init_common()

# Use the first CUDA GPU when torch can see one; otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


In [3]:
import warnings
# NOTE(review): this silences *every* warning for the rest of the kernel session
# (deprecations, numerical warnings, ...); consider narrowing the filter by category/module.
warnings.filterwarnings('ignore')

if not config.skip_training:
    if config.semi_supervised:  # Use semi-supervised.
        assert config.num_labeled_samples != 0, \
            'semi-supervised training requires config.num_labeled_samples != 0'
    # exist_ok=True replaces the racy "check, then create" pattern and is a no-op
    # when the directory already exists.
    os.makedirs(config.save_path, exist_ok=True)
    # Save configurations next to the run outputs.
    config.write_file_contents(config.save_path)

# Create the train and test datasets; entries are keyed by tag ('gc/val', 'gc/test',
# 'mpi', 'gc/train') and filled in by the dataset-creation cell.
all_data = OrderedDict()

# Read GazeCapture train/val/test split. Its 'train', 'val' and 'test' entries are
# passed as `prefixes` to HDFDataset by the dataset-creation cell.
with open('./gazecapture_split.json', 'r') as f:
    all_gc_prefixes = json.load(f)

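# Dataset sizes logged by a previous run of the dataset-creation cell
# ("current" sizes are after config.test_subsample is applied to the evaluation sets):
# [gc/val] full set size:              63518
# [gc/val] current set size:            1270
# [gc/test] full set size:            191842
# [gc/test] current set size:           3836
# [mpi] full set size:                 34790
# [mpi] current set size:                695
# [gc/train] full set size:          1379083
# [gc/train] current set size:       1379083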


2023-08-24 23:28:40,915 Written output/ST-ED/save_2/configs/combined.json
2023-08-24 23:28:40,916 Written output/ST-ED/save_2/configs/config_default.py
2023-08-24 23:33:23,502 Written source folder to output/ST-ED/save_2/src


In [7]:
"""
This part is to create the dataset.
"""

if not config.skip_training:
    # Define single training dataset
    train_prefixes = all_gc_prefixes['train']
    train_dataset = HDFDataset(hdf_file_path=config.gazecapture_file,
                               prefixes=train_prefixes,
                               is_bgr=False,
                               get_2nd_sample=True,
                               num_labeled_samples=config.num_labeled_samples if config.semi_supervised else None)
    
    # Define multiple val/test datasets for evaluation during training
    for tag, hdf_file, is_bgr, prefixes in [
        ('gc/val', config.gazecapture_file, False, all_gc_prefixes['val']),
        ('gc/test', config.gazecapture_file, False, all_gc_prefixes['test']),
        ('mpi', config.mpiigaze_file, False, None),
    ]:
        # Create evaluation dataset.
        dataset = HDFDataset(hdf_file_path=hdf_file,
                             prefixes=prefixes,
                             is_bgr=is_bgr,
                             get_2nd_sample=True,
                             pick_at_least_per_person=2)
        if tag == 'gc/test':
            # test pair visualization:
            test_list = def_test_list()
            test_visualize = get_example_images(dataset, test_list)
            test_visualize = send_data_dict_to_gpu(test_visualize, device)

        subsample = config.test_subsample
        # subsample test sets if requested
        if subsample < (1.0 - 1e-6):
            dataset = Subset(dataset, np.linspace(
                start=0, stop=len(dataset),
                num=int(subsample * len(dataset)),
                endpoint=False,
                dtype=np.uint32,
            ))

        all_data[tag] = {
            'dataset': dataset,
            'dataloader': DataLoader(dataset,
                                     batch_size=config.eval_batch_size,
                                     shuffle=False,
                                     num_workers=config.num_data_loaders,  # args.num_data_loaders,
                                     pin_memory=True,
                                     ),
        }

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=int(config.batch_size),
                                  shuffle=True,
                                  drop_last=True,
                                  num_workers=config.num_data_loaders,
                                  pin_memory=True,
                                  )
    all_data['gc/train'] = {'dataset': train_dataset, 'dataloader': train_dataloader}

    # Print some stats.
    logging.info('')
    for tag, val in all_data.items():
        tag = '[%s]' % tag
        dataset = val['dataset']
        original_dataset = dataset.dataset if isinstance(dataset, Subset) else dataset
        num_people = len(original_dataset.prefixes)
        num_original_entries = len(original_dataset)
        logging.info('%10s full set size:           %7d' % (tag, num_original_entries))
        logging.info('%10s current set size:        %7d' % (tag, len(dataset)))
        logging.info('')

    # Have dataloader re-open HDF to avoid multi-processing related errors.
    for tag, data_dict in all_data.items():
        dataset = data_dict['dataset']
        original_dataset = dataset.dataset if isinstance(dataset, Subset) else dataset
        original_dataset.close_hdf()

# train_dataset.__getitem__(0).keys()
# dict_keys(['key', 'image_a', 'gaze_a', 'head_a', 'image_b', 'gaze_b', 'head_b'])


2023-08-24 23:36:47,595 
2023-08-24 23:36:47,596   [gc/val] full set size:             63518
2023-08-24 23:36:47,596   [gc/val] current set size:           1270
2023-08-24 23:36:47,597 
2023-08-24 23:36:47,598  [gc/test] full set size:            191842
2023-08-24 23:36:47,598  [gc/test] current set size:           3836
2023-08-24 23:36:47,599 
2023-08-24 23:36:47,599      [mpi] full set size:             34790
2023-08-24 23:36:47,600      [mpi] current set size:            695
2023-08-24 23:36:47,600 
2023-08-24 23:36:47,601 [gc/train] full set size:           1379083
2023-08-24 23:36:47,601 [gc/train] current set size:        1379083
2023-08-24 23:36:47,602 


In [61]:
import os

import matplotlib.pyplot as plt

# Explore the gaze / head-pose label distributions of one dataset split and save a 2x4 grid
# of histograms (row 1: sample A, row 2: sample B) to
# `<dataset_explore>/<split name with '/' replaced by '_'>.png`.
dataset_explore = 'dataset_explore'
name = 'gc/train'

test = all_data[name]['dataset']
if len(test) == 0:
    raise ValueError(f'Dataset {name!r} is empty; nothing to plot.')

gaze_yaw_list_a = []
gaze_pitch_list_a = []
head_yaw_list_a = []
head_pitch_list_a = []

gaze_yaw_list_b = []
gaze_pitch_list_b = []
head_yaw_list_b = []
head_pitch_list_b = []

# NOTE(review): each test[i] goes through the full dataset __getitem__ (images included);
# reading only the labels from the HDF file would likely be much faster — verify against HDFDataset.
# Index 1 is recorded as yaw and index 0 as pitch, as in the original analysis — TODO confirm
# against the dataset's label convention.
for i in tqdm(range(len(test))):
    sample = test[i]
    gaze_a, head_a = sample['gaze_a'], sample['head_a']
    gaze_b, head_b = sample['gaze_b'], sample['head_b']

    # float() turns 0-d tensors / numpy scalars into plain floats (readable prints, numeric hist).
    gaze_yaw_list_a.append(float(gaze_a[1]))
    gaze_pitch_list_a.append(float(gaze_a[0]))
    head_yaw_list_a.append(float(head_a[1]))
    head_pitch_list_a.append(float(head_a[0]))

    # Side B must read the B labels; reading 'gaze_a' here made the B gaze plots copies of A.
    gaze_yaw_list_b.append(float(gaze_b[1]))
    gaze_pitch_list_b.append(float(gaze_b[0]))
    head_yaw_list_b.append(float(head_b[1]))
    head_pitch_list_b.append(float(head_b[0]))

# (quantity, side, values, histogram colour) in subplot order: row A, then row B.
panels = [
    ('Gaze Yaw', 'A', gaze_yaw_list_a, 'blue'),
    ('Gaze Pitch', 'A', gaze_pitch_list_a, 'green'),
    ('Head Yaw', 'A', head_yaw_list_a, 'red'),
    ('Head Pitch', 'A', head_pitch_list_a, 'purple'),
    ('Gaze Yaw', 'B', gaze_yaw_list_b, 'blue'),
    ('Gaze Pitch', 'B', gaze_pitch_list_b, 'green'),
    ('Head Yaw', 'B', head_yaw_list_b, 'red'),
    ('Head Pitch', 'B', head_pitch_list_b, 'purple'),
]

# Lists are sorted in place (kept sorted for any later use), so first/last are min/max.
for _, _, values, _ in panels:
    values.sort()

# Get min and max values
min_max_values = {
    f'{quantity} {side}': (values[0], values[-1])
    for quantity, side, values, _ in panels
}

# Print min and max values
for key, (min_value, max_value) in min_max_values.items():
    print(f"{key}: Min={min_value}, Max={max_value}")

# Plot the distributions
fig, axes = plt.subplots(2, 4, figsize=(10, 6))
for ax, (quantity, side, values, color) in zip(axes.flat, panels):
    ax.hist(values, bins=20, color=color, alpha=0.7)
    ax.set_title(f'{quantity} List {side}')
fig.tight_layout()

# 'gc/train' -> 'gc_train'; names without '/' (e.g. 'mpi') are used unchanged.
new_name = name.replace('/', '_')
os.makedirs(dataset_explore, exist_ok=True)
fig.savefig(os.path.join(dataset_explore, f'{new_name}.png'))




[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A