In [22]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, Flatten, Dense, concatenate, Input
import matplotlib.pyplot as plt
import numpy as np
import cv2 as cv

In [23]:
# Load the three arrays stored in the dataset archive. The context manager
# closes the underlying .npz file handle deterministically instead of relying
# on rebinding the name and waiting for garbage collection.
# X and Z are later fed to the model as its two inputs; Y is the fit target.
with np.load("datasets/df.npz") as archive:
    X = archive["x"]
    Y = archive["y"]
    Z = archive["z"]

print(X.shape)

(3012, 224, 224)


In [28]:
def get_model():
    """Build and compile the two-input regression network.

    The image branch feeds a 224x224 single-channel input through five
    stride-2 3x3 convolutions (``32 + 2 ** (4 + i)`` filters for ``i`` in
    0..4; the first uses SELU, the remaining ones ReLU) and flattens the
    result. The second branch passes a length-2 vector through an 8-unit
    ReLU dense layer. Both branches are concatenated and projected to two
    linear outputs.

    Returns:
        A Keras ``Model`` taking ``[image, vector]`` inputs, compiled with
        the Adam optimizer and mean-squared-error as both loss and metric.
    """
    input1 = Input(shape=(224, 224, 1))
    input2 = Input(shape=(2,))

    x = input1
    for i in range(5):
        # Only the first convolution uses SELU; every later one uses ReLU.
        act = "selu" if i == 0 else "relu"
        x = Conv2D(32 + 2 ** (4 + i), (3, 3), strides=(2, 2), activation=act)(x)
    x = Flatten()(x)

    y = Dense(8, activation="relu")(input2)

    # The branch tensors are joined directly; wrapping each branch in its own
    # Model just to read back `.output` yields the same tensors.
    z = concatenate([x, y])
    z = Dense(2, activation="linear")(z)

    model = Model(inputs=[input1, input2], outputs=z)

    # `metrics` is documented as a list; a bare string is not accepted by
    # every Keras release.
    model.compile(optimizer="adam", loss="mse", metrics=["mse"])
    return model

# Instantiate the compiled network and print its layer-by-layer summary.
model = get_model()
model.summary()

Model: "model_9"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_7 (InputLayer)           [(None, 224, 224, 1  0           []                               
                                )]                                                                
                                                                                                  
 conv2d_49 (Conv2D)             (None, 111, 111, 48  480         ['input_7[0][0]']                
                                )                                                                 
                                                                                                  
 conv2d_50 (Conv2D)             (None, 55, 55, 64)   27712       ['conv2d_49[0][0]']              
                                                                                            

In [29]:
# Write a checkpoint only when val_loss improves on the best value seen so far.
# `callbacks` is never imported in this notebook, so it is reached through the
# already-imported `tf` namespace (avoids a NameError on a fresh kernel).
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    "weights/tmp.h5", monitor="val_loss", save_best_only=True
)
# NOTE(review): validation_split=0.7 holds out 70% of the samples for
# validation, leaving only 30% for training — confirm this is intended
# (0.3 may have been meant).
history = model.fit(
    [X, Z],
    Y,
    epochs=20,
    batch_size=16,
    validation_split=0.7,
    callbacks=[checkpoint],
)
# Reload the best checkpoint so the model does not keep the last epoch's weights.
model.load_weights("weights/tmp.h5")

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


In [30]:
# Write the model's current weights to weights/eye.h5.
model.save_weights("weights/eye.h5")

In [32]:
# Run inference on the same [X, Z] inputs that were passed to fit;
# verbose=0 silences the progress output.
preds = model.predict([X, Z], verbose=0)

In [33]:
# Print every prediction side by side with its corresponding target row.
for predicted, target in zip(preds, Y):
    print(predicted, target)


[0.50948894 0.6230823 ] [0.51367188 0.47106481]
[0.530283  0.5917458] [0.51367188 0.47106481]
[0.53984857 0.6038803 ] [0.51367188 0.47106481]
[0.53732175 0.5747205 ] [0.51367188 0.47106481]
[0.53732175 0.5747205 ] [0.51367188 0.47106481]
[0.5394434  0.57248384] [0.51367188 0.47106481]
[0.5092219 0.6026573] [0.51367188 0.47106481]
[0.5092219 0.6026573] [0.51367188 0.47106481]
[0.51144224 0.58815444] [0.51367188 0.47106481]
[0.50478923 0.5576386 ] [0.51432292 0.46643519]
[0.50478923 0.5576386 ] [0.51497396 0.46296296]
[0.50478923 0.5576386 ] [0.51497396 0.45833333]
[0.519693   0.34362286] [0.51627604 0.45138889]
[0.519693   0.34362286] [0.51627604 0.44328704]
[0.54318064 0.38264865] [0.51757812 0.42824074]
[0.52480394 0.40288937] [0.52148438 0.40046296]
[0.52480394 0.40288937] [0.5234375  0.37731481]
[0.52480394 0.40288937] [0.52408854 0.36226852]
[0.5077057 0.4382816] [0.52539062 0.34722222]
[0.5077057 0.4382816] [0.52669271 0.32986111]
[0.5106501  0.36924666] [0.52734375 0.31712963]
[0

[0.87547654 0.9433735 ] [0.86848958 0.92708333]
[0.87547654 0.9433735 ] [0.87369792 0.93055556]
[0.8983733 0.9521365] [0.87630208 0.93287037]
[0.8983733 0.9521365] [0.8828125 0.9375   ]
[0.8529511  0.82148206] [0.88932292 0.94212963]
[0.8425428 0.7989834] [0.88932292 0.94212963]
[0.8425428 0.7989834] [0.88932292 0.94212963]
[0.8425428 0.7989834] [0.88932292 0.94212963]
[0.82952636 0.7789737 ] [0.88932292 0.94212963]
[0.82952636 0.7789737 ] [0.89778646 0.94560185]
[0.7774013  0.73370695] [0.92317708 0.95717593]
[0.7774013  0.73370695] [0.92447917 0.95833333]
[0.7789993  0.72580737] [0.92838542 0.95949074]
[0.84077823 0.7768039 ] [0.93684896 0.96296296]
[0.84077823 0.7768039 ] [0.94921875 0.96412037]
[0.83561283 0.7474956 ] [0.95377604 0.96527778]
[0.8146219  0.74613166] [0.96223958 0.96759259]
[0.91232204 0.8805789 ] [0.96679688 0.96990741]
[0.91232204 0.8805789 ] [0.97005208 0.96990741]
[0.9224159 0.8815497] [0.97526042 0.97106481]
[0.9224159 0.8815497] [0.98046875 0.97453704]
[0.92459