# gazeNet

This Jupyter notebook is based on [gazeNet](https://github.com/r-zemblys/gazeNet) by Raimondas Zemblys and uses his pretrained model to label eye-tracking data. At present, the provided model can only be used to label data containing fixations, saccades, and PSOs (post-saccadic oscillations). Follow the instructions below to label your data.

## Supported data formats
- npy
- tsv (created by Tobii devices)

**Todo**
- add progress bar

## Data upload
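Copy the folder containing your trials into ``etdata/`` (left panel). In the cell below, set ``dataset`` to that folder's name and ``format`` to ``'tsv'`` or ``'npy'`` to match your files.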

In [None]:
import glob

root = 'etdata'         # The root folder of your trials.
dataset = 'trials'      # Folder containing example trials. Replace with your own folder name.
format = 'tsv'          # File extension of the trials: 'tsv' or 'npy'. (Shadows the builtin; later cells read this name.)

# glob returns files in arbitrary, filesystem-dependent order; sort so trials
# are always processed (and reported) in the same, predictable order.
FILES = sorted(glob.glob('%s/%s/*.%s' % (root, dataset, format)))

## Input trial setup

Enter the measurements of your eye-tracking setup. These values are needed to convert the data internally.

The default values for monitor width and height are for a 24" 16:9 monitor.

In [None]:
# Physical size and resolution of the stimulus display plus the viewing
# distance; handed to tsv_to_npy when the trials are converted.
geom = dict(
    screen_width=53.1,        # physical screen width [cm]
    screen_height=29.8,       # physical screen height [cm]
    display_width_pix=1920,   # horizontal resolution [px]
    display_height_pix=1080,  # vertical resolution [px]
    eye_distance=60,          # eye-to-screen viewing distance [cm]
)

## Labeling
Your annotated data and plots will be saved in ``etdata/<dataset>_gazeNet`` (i.e. ``etdata/trials_gazeNet`` for the default ``dataset = 'trials'``).

In [None]:
import os
import copy

import numpy as np
import pandas as pd

## model configuration
from utils_lib import utils

dev = False
model_dir = 'model_final'
model_name = 'gazeNET_0004_00003750'
model_name = '%s.pth.tar' % model_name

# Load the configuration stored with the model (logdir/<model_dir>/config.json)
# and override a few entries for inference.
logdir = os.path.join('logdir', model_dir)
fname_config = os.path.join(logdir, 'config.json')
configuration = utils.Config(fname_config)
config = configuration.params

config['split_seqs'] = False  # NOTE(review): presumably keeps each trial as one sequence — confirm in utils_lib
config['augment'] = False     # NOTE(review): presumably disables training-time augmentation — confirm
config['batch_size'] = 1
cuda = False
num_classes = len(config['events'])

## load model
from model import gazeNET as gazeNET
import model as model_func

model = gazeNET(config, num_classes)
model_func.load(model, model_dir, config, model_name)
model.eval()

## load data / inference helpers (imported once, not per trial)
from utils_lib.etdata import ETData, tsv_to_npy
from utils_lib.data_loader import EMDataset, GazeDataLoader
from utils_lib.ETeval import run_infer

# Options forwarded to run_infer; identical for every trial.
kwargs = {
    'cuda': cuda,
    'use_tqdm': False,
    'eval': False,
}

# Output directory for annotated data and plots.
sdir = '%s/%s_gazeNet' % (root, dataset)

if not FILES:
    print("No *.%s files found in %s/%s" % (format, root, dataset))

for fpath in FILES:
    fname = os.path.basename(fpath)
    if format == 'tsv':
        # x_px / y_px are written back into X_test after inference (see postprocessing).
        x_px, y_px, X_test = tsv_to_npy(fpath, geom)
    else:
        X_test = np.load(fpath)

    # A sample is marked valid only when both coordinates are present and its
    # event label is one of the events listed in the model configuration.
    _status = (np.isnan(X_test['x'])
               | np.isnan(X_test['y'])
               | ~np.isin(X_test['evt'], config['events']))
    X_test['status'] = ~_status
    test_dataset = EMDataset(config=config, gaze_data=[X_test])
    n_samples = len(test_dataset)
    test_loader = GazeDataLoader(test_dataset, batch_size=config['batch_size'],
                                 num_workers=0,
                                 shuffle=False)

    ## label data
    print("Predicting %s" % fname)
    _gt, _pr, pr_raw = run_infer(model, n_samples, test_loader, **kwargs)

    ## postprocessing

    # revert to cartesian (pixel) coordinates
    if format == 'tsv':
        X_test['x'] = x_px
        X_test['y'] = y_px

    # glue back the predictions: the first sample of each sequence gets 0,
    # the rest get the most probable class index shifted by one.
    _data_pr = copy.deepcopy(test_dataset.data)
    for _d, _pred in zip(_data_pr, pr_raw):
        _d['evt'] = 0
        _d['evt'][1:] = np.argmax(_pred, axis=1) + 1
    _data_pr = pd.concat([pd.DataFrame(_d) for _d in _data_pr]).reset_index(drop=True)
    _data = pd.DataFrame(X_test)
    _data = _data.merge(_data_pr, on='t', suffixes=('', '_pred'), how='left')
    # Samples without a prediction get label 0. The left merge introduced NaN
    # and thus a float column; cast back so labels are written as integers.
    _data['evt'] = _data['evt_pred'].fillna(0).astype(int)

    etdata_pr = ETData()
    etdata_pr.load(_data[['t', 'x', 'y', 'status', 'evt']].values, source='np_array')

    if not os.path.exists(sdir):
        os.makedirs(sdir)
    spath = '%s/%s' % (sdir, fname.replace('.tsv', ''))
    if format == 'tsv':     # add predictions to the original data
        data = pd.read_csv(fpath, sep='\t')
        # Inner join on row index. NOTE(review): assumes tsv_to_npy keeps the
        # TSV's row order and row count — confirm.
        data = data.merge(_data['evt'], left_index=True, right_index=True)
        data.to_csv(spath, sep='\t', index=False)
    else:
        etdata_pr.save(spath)
    etdata_pr.plot_px(show=False, save=True, spath=spath)
    etdata_pr.plot_xy(show=False, save=True, spath=spath)