In [84]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import seaborn as sns
import scipy
import glob
from tqdm import tqdm
from collections import defaultdict
import eye_metrics_utils
import data_utils
import gaze_entropy


from sklearn.model_selection import train_test_split


import warnings
# warnings.filterwarnings(action='once')
warnings.filterwarnings('ignore')

In [2]:
# Sort the matches: glob returns paths in arbitrary, OS-dependent order, and
# everything downstream (window order, seeded train/test splits) inherits it.
csv_files = sorted(glob.glob("data/*.csv"))

In [3]:
def _paths_tagged(paths, tag):
    """Return the entries of ``paths`` whose text contains ``tag``."""
    return [path for path in paths if tag in path]


# Partition the recordings by the scenario tag embedded in each file name.
csv_files_one = _paths_tagged(csv_files, "One Gaze-Vergence")
csv_files_two = _paths_tagged(csv_files, "Two Gaze-Vergence")
csv_files_three = _paths_tagged(csv_files, "Three Go-Around Gaze-Vergence")

In [4]:
# Participant table: keep the last three characters of each ID (the tag that is
# later matched against the recording file names) and the flight hours,
# as two parallel lists.
df_par = pd.read_csv("participant.csv")
par_id_arr = [full_id[-3:] for full_id in df_par['ID'].tolist()]
flight_exp_arr = df_par['Flight_hour'].tolist()

# Show each participant tag next to its flight hours.
for tag, hours in zip(par_id_arr, flight_exp_arr):
    print(tag, hours)

032 10
027 18
031 26
028 28
004 35
008 35
010 35
029 40
003 42
007 49
023 50
021 63
006 90
019 98
022 100
015 133
016 200
014 155
005 180
025 200
002 220
001 230
020 230
011 300
017 420
013 23
024 28
018 116
026 150
012 175
009 220
033 1300


In [12]:
# Spot-check: flight hours stored for participant tag "010".
row_010 = par_id_arr.index("010")
flight_exp_arr[row_010]

35

In [235]:
# Accumulators filled by the window-building loop below:
# X collects one feature array per window, y the matching flight hours.
X, y = [], []

def norm(df_x):
    """Z-score every column of ``df_x`` using that frame's own statistics.

    Args:
        df_x: DataFrame of numeric columns.

    Returns:
        DataFrame of the same shape, each column shifted by its mean and
        divided by its sample standard deviation (pandas default, ddof=1).
        A column whose standard deviation is zero or undefined (constant
        values, or a single row) is returned as zeros instead of NaN/inf.
    """
    col_mean = df_x.mean()
    col_std = df_x.std()
    # A constant column (e.g. all zeros after NaN-filling) has std == 0 and
    # would turn into NaN; divide by 1 instead so it normalises to 0.
    safe_std = col_std.where(col_std > 0, 1.0)
    return (df_x - col_mean) / safe_std

# Build the training windows: every kept 1200-sample slice of a recording
# contributes one normalised feature array to X and the participant's flight
# hours to y.
FEATURE_COLUMNS = ['X Pos', 'Y Pos', 'Pupil Diameter']
ID_MARKER = "PISSS_ID_"

# The loop variables are deliberately NOT called `csv_files` / `csv`: rebinding
# the global `csv_files` here would silently change what the scenario-filter
# cell sees if it is re-run afterwards.
for scenario_files in [csv_files_two]:
    for csv_path in scenario_files:
        # The participant tag is the three characters right after "PISSS_ID_".
        # Locating the marker (rather than slicing fixed offsets) keeps this
        # correct if the directory prefix changes; a path without the marker
        # yields "" and is skipped by the membership test below.
        par_id = csv_path.partition(ID_MARKER)[2][:3]

        if par_id not in par_id_arr:
            continue

        df_data = pd.read_csv(csv_path)
        print(csv_path, len(df_data))
        exp = flight_exp_arr[par_id_arr.index(par_id)]
        for df_slice in data_utils.data_slicing(df_data, window_length = 1200, stride = 1200, min_length=1200):
            # NOTE(review): the original comment read "if missing value > 50%,
            # remove", but the threshold is 0.3 and what
            # check_percentage_null returns is not visible here - confirm
            # both the direction and the cut-off of this filter.
            if data_utils.check_percentage_null(df_slice) < 0.3:
                continue

            # Fill gaps on a copy of the feature columns only, rather than
            # mutating the slice handed out by data_slicing in place.
            features = df_slice[FEATURE_COLUMNS].fillna(0.0)

            X.append(norm(features).values)
            y.append(exp)
            

data\PISSS_ID_001_Approach Two Gaze-Vergence.csv 9554
data\PISSS_ID_002_Approach Two Gaze-Vergence.csv 9430
data\PISSS_ID_003_Approach Two Gaze-Vergence.csv 9368
data\PISSS_ID_004_Approach Two Gaze-Vergence.csv 9862
data\PISSS_ID_005_Approach Two Gaze-Vergence.csv 9245
data\PISSS_ID_006_Approach Two Gaze-Vergence.csv 9739
data\PISSS_ID_007_Approach Two Gaze-Vergence.csv 9677
data\PISSS_ID_008_Approach Two Gaze-Vergence.csv 9923
data\PISSS_ID_009_Approach Two Gaze-Vergence.csv 9243
data\PISSS_ID_010_Approach Two Gaze-Vergence.csv 9923
data\PISSS_ID_011_Approach Two Gaze-Vergence.csv 9492
data\PISSS_ID_012_Approach Two Gaze-Vergence.csv 9431
data\PISSS_ID_013_Approach Two Gaze-Vergence.csv 8691
data\PISSS_ID_014_Approach Two Gaze-Vergence.csv 9307
data\PISSS_ID_015_Approach Two Gaze-Vergence.csv 8812
data\PISSS_ID_016_Approach Two Gaze-Vergence.csv 8259
data\PISSS_ID_017_Approach Two Gaze-Vergence.csv 9184
data\PISSS_ID_018_Approach Two Gaze-Vergence.csv 8937
data\PISSS_ID_019_Approach T

In [236]:
# Sanity check: dimensions of the stacked windows.
np.asarray(X).shape

(221, 1200, 3)

In [237]:
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import LSTM
from sklearn.preprocessing import MinMaxScaler, StandardScaler, Normalizer
from sklearn.metrics import mean_squared_error

In [238]:
import tensorflow as tf

In [239]:
def sqrt_loss_function(y_true, y_pred):
    """Root-mean-squared-error loss.

    Args:
        y_true: Tensor of target values.
        y_pred: Tensor of predictions, broadcast-compatible with ``y_true``.

    Returns:
        Scalar tensor ``sqrt(mean((y_true - y_pred) ** 2) + eps)``.
    """
    # The derivative of sqrt is infinite at 0, so a batch that is fit exactly
    # would back-propagate NaN; a tiny offset keeps the gradient finite while
    # leaving the loss value effectively unchanged.
    eps = 1e-12
    mean_sq_err = tf.reduce_mean(tf.square(y_true - y_pred))
    return tf.sqrt(mean_sq_err + eps)

In [240]:
# Turn the accumulated Python lists into NumPy arrays for the steps below.
X, y = np.array(X), np.array(y)

In [261]:
# Standardise the flight-hour targets; the fitted scaler is kept so that
# predictions can be mapped back with inverse_transform later.
# NOTE(review): the scaler is fitted on every target before the train/test
# split, so test-set statistics influence the scaling - consider fitting on
# the training targets only.
y_column = y.reshape(-1, 1)
scaler = StandardScaler().fit(y_column)
y_scale = scaler.transform(y_column)

In [262]:
# Hold out 20% of the windows as the test set (seeded).
# NOTE(review): rows of X are individual windows, so windows from the same
# participant can land on both sides of this split - confirm that is intended.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y_scale,
    test_size=0.2,
    random_state=5,
)

In [263]:
# Carve a validation set (25% of the remaining windows) out of the training set.
# NOTE(review): this rebinds X_train / y_train from themselves, so running the
# cell a second time shrinks the training set again.
X_train, X_val, y_train, y_val = train_test_split(
    X_train,
    y_train,
    test_size=0.25,
    random_state=1,
)

In [265]:
# create and fit the LSTM network
model = Sequential()
model.add(LSTM(128, input_shape=(1200, 3), return_sequences=True))
model.add(LSTM(64))

model.add(Dense(10))
model.add(Dense(1))

model.compile(loss=sqrt_loss_function, optimizer = tf.keras.optimizers.RMSprop(0.001))

model.summary()

Model: "sequential_21"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
lstm_33 (LSTM)               (None, 1200, 128)         67584     
_________________________________________________________________
lstm_34 (LSTM)               (None, 64)                49408     
_________________________________________________________________
dense_24 (Dense)             (None, 10)                650       
_________________________________________________________________
dense_25 (Dense)             (None, 1)                 11        
Total params: 117,653
Trainable params: 117,653
Non-trainable params: 0
_________________________________________________________________


In [266]:
# Train for 20 epochs in mini-batches of 8, reporting validation loss each epoch.
model.fit(
    X_train,
    y_train,
    validation_data=(X_val, y_val),
    epochs=20,
    batch_size=8,
    verbose=1,
)

Train on 132 samples, validate on 44 samples
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
  8/132 [>.............................] - ETA: 52s - loss: 3.1413

KeyboardInterrupt: 

In [267]:
# Predict the (standardised) target for every held-out test window.
pred = model.predict(X_test)

In [268]:
# Undo the target standardisation so predictions are back on the flight-hour scale.
p = scaler.inverse_transform(pred)

In [269]:
# Undo the target standardisation on the test targets as well, for comparison with p.
t = scaler.inverse_transform(y_test)

In [271]:
# Print each prediction next to its true value, one test window per line.
for predicted, actual in zip(p, t):
    print(predicted, actual)

[191.27005] [35.]
[165.07442] [35.]
[91.85651] [10.]
[159.3215] [35.]
[165.81815] [35.]
[76.1859] [26.]
[153.78723] [50.]
[167.87288] [35.]
[149.74536] [230.]
[100.490295] [26.]
[111.88648] [230.]
[165.74924] [35.]
[67.9944] [155.]
[102.338356] [28.]
[105.44446] [28.]
[124.25095] [100.]
[150.41997] [116.]
[106.63169] [90.]
[167.45726] [155.]
[138.85191] [26.]
[147.91446] [35.]
[101.16445] [40.]
[143.58311] [133.]
[144.41733] [90.]
[133.576] [200.]
[157.8999] [49.]
[164.49812] [300.]
[120.60966] [50.]
[119.57893] [133.]
[161.90514] [230.]
[157.96245] [23.]
[95.20293] [23.]
[180.8774] [49.]
[95.65531] [40.]
[147.74167] [180.]
[116.164276] [28.]
[164.21478] [49.]
[89.30718] [18.]
[160.13904] [98.]
[101.57508] [50.]
[119.11367] [28.]
[115.04357] [40.]
[145.19247] [42.]
[159.13416] [1300.]
[138.99033] [90.]
