# Train a model to recogise static gestures

In [1]:
import warnings
warnings.filterwarnings('ignore')

## Prepare the data

In [2]:
import numpy as np
import pandas as pd
import pickle
import os

In [3]:
data = pickle.load(open('../out_data/robot_hand_keypoints_120421_03_26.pkl', 'rb'))
data

[{'hand_type': 'right',
  'gesture_id': 0,
  'gesture_name': 'no_gesture',
  'pred_gesture_name': 'FIVE',
  'keypoints': array([[ 2.61616558e-01,  8.62391472e-01, -1.18615622e-04],
         [ 3.40141207e-01,  8.32768440e-01, -3.68560776e-02],
         [ 3.95898253e-01,  7.64145434e-01, -5.05490154e-02],
         [ 4.23794150e-01,  6.91991389e-01, -6.58366755e-02],
         [ 4.54131007e-01,  6.34845853e-01, -8.54506791e-02],
         [ 3.52733880e-01,  6.40689015e-01, -3.90581018e-03],
         [ 3.70479733e-01,  5.49144268e-01, -1.90667845e-02],
         [ 3.82825851e-01,  4.93496358e-01, -3.28761302e-02],
         [ 3.93023401e-01,  4.43832815e-01, -4.70148660e-02],
         [ 3.07039738e-01,  6.26214385e-01, -2.82135443e-03],
         [ 3.11871171e-01,  5.25952458e-01, -1.28231077e-02],
         [ 3.18490505e-01,  4.61563885e-01, -2.94405650e-02],
         [ 3.22950721e-01,  4.09772158e-01, -4.30353209e-02],
         [ 2.64687926e-01,  6.30137026e-01, -1.12086385e-02],
         [ 2.

> **Note:** The data is collected for the right hand only.

In [4]:
gesture_ids = np.array([x['gesture_id'] for x in data]).reshape(-1, 1)
gesture_names = np.array([x['gesture_name'] for x in data]).reshape(-1, 1)
keypoints = np.array([x['keypoints'] for x in data]).reshape(-1, 63)

In [5]:
print(gesture_ids.shape, gesture_names.shape, keypoints.shape)

(1901, 1) (1901, 1) (1901, 63)


In [6]:
columns = ['gesture_id', 'gesture_name']
for i in range(keypoints.shape[1] // 3):
    columns += [f'kp{i}_x', f'kp{i}_y', f'kp{i}_z']
df = pd.DataFrame(data=np.hstack([gesture_ids, gesture_names, keypoints]), 
                  columns=columns)
df

Unnamed: 0,gesture_id,gesture_name,kp0_x,kp0_y,kp0_z,kp1_x,kp1_y,kp1_z,kp2_x,kp2_y,...,kp17_z,kp18_x,kp18_y,kp18_z,kp19_x,kp19_y,kp19_z,kp20_x,kp20_y,kp20_z
0,0,no_gesture,0.2616165578365326,0.862391471862793,-0.00011861562234116718,0.3401412069797516,0.832768440246582,-0.03685607761144638,0.3958982527256012,0.7641454339027405,...,-0.024227840825915337,0.21081572771072388,0.5820988416671753,-0.036938998848199844,0.2077687680721283,0.5350749492645264,-0.05315976217389107,0.20487046241760254,0.4915812313556671,-0.06644965708255768
1,0,no_gesture,0.2488212138414383,0.8327491879463196,-8.602064190199599e-05,0.3300294280052185,0.8062025308609009,-0.05029881373047829,0.3931019902229309,0.7401115298271179,...,-0.05144594609737396,0.18782782554626465,0.5499157309532166,-0.07304853200912476,0.178327739238739,0.5024149417877197,-0.09637845307588577,0.16922423243522644,0.4561491012573242,-0.11695774644613266
2,0,no_gesture,0.2098255455493927,0.806172788143158,-7.33856504666619e-05,0.293041467666626,0.8084140419960022,-0.035692423582077026,0.36657994985580444,0.7661011219024658,...,-0.03616072237491608,0.23157352209091187,0.5067298412322998,-0.056251414120197296,0.2460532784461975,0.4627677798271179,-0.08015701919794083,0.26056623458862305,0.42247650027275085,-0.09676968306303024
3,0,no_gesture,0.1882975697517395,0.796644926071167,-4.8352027079090476e-05,0.27230411767959595,0.8144708871841431,-0.03367951884865761,0.3501909077167511,0.7865704894065857,...,-0.022155076265335083,0.25076210498809814,0.5014303922653198,-0.03654671460390091,0.27215859293937683,0.46206140518188477,-0.05788211151957512,0.29213279485702515,0.42708343267440796,-0.07161156088113785
4,0,no_gesture,0.1724901795387268,0.7742319703102112,-6.687969289487228e-05,0.2561293840408325,0.804787814617157,-0.018885411322116852,0.3348514437675476,0.7911533117294312,...,-0.015632126480340958,0.27180129289627075,0.4929482340812683,-0.02364089898765087,0.298082560300827,0.4571201205253601,-0.041698750108480453,0.3216680884361267,0.4257204532623291,-0.053536172956228256
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1896,0,no_gesture,0.32307901978492737,0.6458818912506104,-2.1450407075462863e-05,0.365364134311676,0.7069713473320007,-0.01946178823709488,0.4299249053001404,0.737430214881897,...,-0.029674027115106583,0.4608999192714691,0.4699644148349762,-0.06489071995019913,0.4925169050693512,0.44605037569999695,-0.08293934166431427,0.5246965885162354,0.41828328371047974,-0.09062712639570236
1897,0,no_gesture,0.3506574034690857,0.7074780464172363,-6.25709944870323e-05,0.4089662432670593,0.7137359380722046,-0.028144683688879013,0.47107112407684326,0.6655962467193604,...,-0.056962061673402786,0.3344624638557434,0.4835580587387085,-0.09510712325572968,0.3327365517616272,0.4387051463127136,-0.11736176162958145,0.33124879002571106,0.38620853424072266,-0.13155074417591095
1898,0,no_gesture,0.3773188591003418,0.7767844200134277,-1.5660458302590996e-05,0.42634764313697815,0.7357620000839233,-0.046103738248348236,0.4525631368160248,0.6489238739013672,...,-0.046092595905065536,0.270196795463562,0.6188383102416992,-0.0734802708029747,0.24736957252025604,0.5900310277938843,-0.08739117532968521,0.2180003523826599,0.5517174601554871,-0.09553233534097672
1899,0,no_gesture,0.3849989175796509,0.8331529498100281,-4.887472096015699e-05,0.423456609249115,0.7704190611839294,-0.02924411930143833,0.42183586955070496,0.6738439202308655,...,-0.05820860713720322,0.23991179466247559,0.7386565208435059,-0.09852148592472076,0.21134266257286072,0.7226414680480957,-0.11997214704751968,0.17695461213588715,0.7038993239402771,-0.1325664520263672


In [7]:
df['gesture_name'].value_counts()

no_gesture    995
angle         306
move          300
grab          300
Name: gesture_name, dtype: int64

## Train sklearn model

In [8]:
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.metrics import make_scorer, accuracy_score, f1_score, precision_score, recall_score
from sklearn.model_selection import cross_val_score

In [9]:
X, y = df.iloc[:, 2:].values.astype('float'), df.iloc[:, 0].values.astype('int32')
print(X.shape, y.shape)
print(X[0], y[0], np.unique(y))

(1901, 63) (1901,)
[ 2.61616558e-01  8.62391472e-01 -1.18615622e-04  3.40141207e-01
  8.32768440e-01 -3.68560776e-02  3.95898253e-01  7.64145434e-01
 -5.05490154e-02  4.23794150e-01  6.91991389e-01 -6.58366755e-02
  4.54131007e-01  6.34845853e-01 -8.54506791e-02  3.52733880e-01
  6.40689015e-01 -3.90581018e-03  3.70479733e-01  5.49144268e-01
 -1.90667845e-02  3.82825851e-01  4.93496358e-01 -3.28761302e-02
  3.93023401e-01  4.43832815e-01 -4.70148660e-02  3.07039738e-01
  6.26214385e-01 -2.82135443e-03  3.11871171e-01  5.25952458e-01
 -1.28231077e-02  3.18490505e-01  4.61563885e-01 -2.94405650e-02
  3.22950721e-01  4.09772158e-01 -4.30353209e-02  2.64687926e-01
  6.30137026e-01 -1.12086385e-02  2.63029665e-01  5.37684083e-01
 -2.10358277e-02  2.67169356e-01  4.76384193e-01 -3.83302718e-02
  2.70671397e-01  4.25496340e-01 -5.44468910e-02  2.22248927e-01
  6.51061773e-01 -2.42278408e-02  2.10815728e-01  5.82098842e-01
 -3.69389988e-02  2.07768768e-01  5.35074949e-01 -5.31597622e-02
  2.04

In [10]:
idx = np.arange(X.shape[0])
np.random.shuffle(idx)
X, y = X[idx], y[idx]

In [11]:
class_amounts = {0: 400, 1: 200, 2: 200, 3: 200}

X_new, y_new = None, None
for class_id in np.unique(y):
    class_indices = np.argwhere(class_id == y)
    print(class_indices.shape)
    subsample_indices = np.random.choice(class_indices[:,0], size=class_amounts[class_id])
    print(subsample_indices.shape)
    if X_new is None:
        X_new = X[subsample_indices]
    else:
        X_new = np.vstack((X_new, X[subsample_indices]))
    if y_new is None:
        y_new = y[subsample_indices].reshape(-1,1)
    else:
        y_new = np.vstack((y_new, y[subsample_indices].reshape(-1,1)))

print(X_new.shape, y_new.shape)
X = X_new
y = y_new

(995, 1)
(400,)
(300, 1)
(200,)
(306, 1)
(200,)
(300, 1)
(200,)
(1000, 63) (1000, 1)


In [12]:
logreg = LogisticRegression()
cross_val_score(logreg, X=X, y=y, scoring=make_scorer(accuracy_score), cv=5)

array([0.86, 0.86, 0.86, 0.86, 0.89])

In [13]:
rfclf = RandomForestClassifier()
cross_val_score(rfclf, X=X, y=y, scoring=make_scorer(accuracy_score), cv=5)

array([0.9  , 0.915, 0.965, 0.945, 0.955])

In [14]:
gbclf = GradientBoostingClassifier()
cross_val_score(gbclf, X=X, y=y, scoring=make_scorer(accuracy_score), cv=5)

array([0.925, 0.91 , 0.975, 0.95 , 0.97 ])

In [15]:
models_root = './robot_models'
os.makedirs(models_root, exist_ok=True)
pickle.dump(gbclf, open(f'{models_root}/robot_gb_default.pkl', 'wb'))

## Train Keras model

In [35]:
import numpy as np
import tensorflow as tf
from tensorflow import keras

In [36]:
X, y = df.iloc[:, 2:].values.astype('float'), df.iloc[:, 0].values.astype('int32')
print(X.shape, y.shape)
print(X[0], y[0], np.unique(y))

(1901, 63) (1901,)
[ 2.61616558e-01  8.62391472e-01 -1.18615622e-04  3.40141207e-01
  8.32768440e-01 -3.68560776e-02  3.95898253e-01  7.64145434e-01
 -5.05490154e-02  4.23794150e-01  6.91991389e-01 -6.58366755e-02
  4.54131007e-01  6.34845853e-01 -8.54506791e-02  3.52733880e-01
  6.40689015e-01 -3.90581018e-03  3.70479733e-01  5.49144268e-01
 -1.90667845e-02  3.82825851e-01  4.93496358e-01 -3.28761302e-02
  3.93023401e-01  4.43832815e-01 -4.70148660e-02  3.07039738e-01
  6.26214385e-01 -2.82135443e-03  3.11871171e-01  5.25952458e-01
 -1.28231077e-02  3.18490505e-01  4.61563885e-01 -2.94405650e-02
  3.22950721e-01  4.09772158e-01 -4.30353209e-02  2.64687926e-01
  6.30137026e-01 -1.12086385e-02  2.63029665e-01  5.37684083e-01
 -2.10358277e-02  2.67169356e-01  4.76384193e-01 -3.83302718e-02
  2.70671397e-01  4.25496340e-01 -5.44468910e-02  2.22248927e-01
  6.51061773e-01 -2.42278408e-02  2.10815728e-01  5.82098842e-01
 -3.69389988e-02  2.07768768e-01  5.35074949e-01 -5.31597622e-02
  2.04

In [37]:
idx = np.arange(X.shape[0])
np.random.shuffle(idx)
X, y = X[idx], y[idx]

In [38]:
n_classes = len(np.unique(y))
n_feats = X.shape[1]

In [39]:
# class_amounts = {0: 400, 1: 200, 2: 200, 3: 200}

# X_new, y_new = None, None
# for class_id in np.unique(y):
#     class_indices = np.argwhere(class_id == y)
#     print(class_indices.shape)
#     subsample_indices = np.random.choice(class_indices[:,0], size=class_amounts[class_id])
#     print(subsample_indices.shape)
#     if X_new is None:
#         X_new = X[subsample_indices]
#     else:
#         X_new = np.vstack((X_new, X[subsample_indices]))
#     if y_new is None:
#         y_new = y[subsample_indices].reshape(-1,1)
#     else:
#         y_new = np.vstack((y_new, y[subsample_indices].reshape(-1,1)))

# print(X_new.shape, y_new.shape)
# X = X_new
# y = y_new

In [40]:
y_ohe = np.zeros((y.shape[0], n_classes))
y_ohe[np.arange(len(y)).reshape(-1,1), y.reshape(-1, 1)] = 1
print(y_ohe.sum(axis=0))

[995. 300. 306. 300.]


In [41]:
train_ratio = 0.75
X_train, y_train = X[:int(len(X)*train_ratio)], y_ohe[:int(len(X)*train_ratio)]
X_val, y_val = X[int(len(X)*train_ratio):], y[int(len(X)*train_ratio):]

* One neuron (linear / logistic regression):

In [42]:
loss_func = keras.losses.CategoricalCrossentropy(from_logits=True)

lr = 0.001
optimizer = tf.keras.optimizers.Adam(learning_rate=lr)

model = keras.Sequential([keras.layers.Dense(units=n_classes, input_shape=[n_feats])])
model.compile(optimizer=optimizer, loss=loss_func, metrics=['accuracy'])

In [43]:
model.fit(X_train, y_train, epochs=500)

Epoch 1/500
Epoch 2/500
Epoch 3/500
Epoch 4/500
Epoch 5/500
Epoch 6/500
Epoch 7/500
Epoch 8/500
Epoch 9/500
Epoch 10/500
Epoch 11/500
Epoch 12/500
Epoch 13/500
Epoch 14/500
Epoch 15/500
Epoch 16/500
Epoch 17/500
Epoch 18/500
Epoch 19/500
Epoch 20/500
Epoch 21/500
Epoch 22/500
Epoch 23/500
Epoch 24/500
Epoch 25/500
Epoch 26/500
Epoch 27/500
Epoch 28/500
Epoch 29/500
Epoch 30/500
Epoch 31/500
Epoch 32/500
Epoch 33/500
Epoch 34/500
Epoch 35/500
Epoch 36/500
Epoch 37/500
Epoch 38/500
Epoch 39/500
Epoch 40/500
Epoch 41/500
Epoch 42/500
Epoch 43/500
Epoch 44/500
Epoch 45/500
Epoch 46/500
Epoch 47/500
Epoch 48/500
Epoch 49/500
Epoch 50/500
Epoch 51/500
Epoch 52/500
Epoch 53/500
Epoch 54/500
Epoch 55/500
Epoch 56/500
Epoch 57/500
Epoch 58/500
Epoch 59/500
Epoch 60/500
Epoch 61/500
Epoch 62/500
Epoch 63/500
Epoch 64/500
Epoch 65/500
Epoch 66/500
Epoch 67/500
Epoch 68/500
Epoch 69/500
Epoch 70/500
Epoch 71/500
Epoch 72/500
Epoch 73/500
Epoch 74/500
Epoch 75/500
Epoch 76/500
Epoch 77/500
Epoch 78

Epoch 82/500
Epoch 83/500
Epoch 84/500
Epoch 85/500
Epoch 86/500
Epoch 87/500
Epoch 88/500
Epoch 89/500
Epoch 90/500
Epoch 91/500
Epoch 92/500
Epoch 93/500
Epoch 94/500
Epoch 95/500
Epoch 96/500
Epoch 97/500
Epoch 98/500
Epoch 99/500
Epoch 100/500
Epoch 101/500
Epoch 102/500
Epoch 103/500
Epoch 104/500
Epoch 105/500
Epoch 106/500
Epoch 107/500
Epoch 108/500
Epoch 109/500
Epoch 110/500
Epoch 111/500
Epoch 112/500
Epoch 113/500
Epoch 114/500
Epoch 115/500
Epoch 116/500
Epoch 117/500
Epoch 118/500
Epoch 119/500
Epoch 120/500
Epoch 121/500
Epoch 122/500
Epoch 123/500
Epoch 124/500
Epoch 125/500
Epoch 126/500
Epoch 127/500
Epoch 128/500
Epoch 129/500
Epoch 130/500
Epoch 131/500
Epoch 132/500
Epoch 133/500
Epoch 134/500
Epoch 135/500
Epoch 136/500
Epoch 137/500
Epoch 138/500
Epoch 139/500
Epoch 140/500
Epoch 141/500
Epoch 142/500
Epoch 143/500
Epoch 144/500
Epoch 145/500
Epoch 146/500
Epoch 147/500
Epoch 148/500
Epoch 149/500
Epoch 150/500
Epoch 151/500
Epoch 152/500
Epoch 153/500
Epoch 154/

Epoch 161/500
Epoch 162/500
Epoch 163/500
Epoch 164/500
Epoch 165/500
Epoch 166/500
Epoch 167/500
Epoch 168/500
Epoch 169/500
Epoch 170/500
Epoch 171/500
Epoch 172/500
Epoch 173/500
Epoch 174/500
Epoch 175/500
Epoch 176/500
Epoch 177/500
Epoch 178/500
Epoch 179/500
Epoch 180/500
Epoch 181/500
Epoch 182/500
Epoch 183/500
Epoch 184/500
Epoch 185/500
Epoch 186/500
Epoch 187/500
Epoch 188/500
Epoch 189/500
Epoch 190/500
Epoch 191/500
Epoch 192/500
Epoch 193/500
Epoch 194/500
Epoch 195/500
Epoch 196/500
Epoch 197/500
Epoch 198/500
Epoch 199/500
Epoch 200/500
Epoch 201/500
Epoch 202/500
Epoch 203/500
Epoch 204/500
Epoch 205/500
Epoch 206/500
Epoch 207/500
Epoch 208/500
Epoch 209/500
Epoch 210/500
Epoch 211/500
Epoch 212/500
Epoch 213/500
Epoch 214/500
Epoch 215/500
Epoch 216/500
Epoch 217/500
Epoch 218/500
Epoch 219/500
Epoch 220/500
Epoch 221/500
Epoch 222/500
Epoch 223/500
Epoch 224/500
Epoch 225/500
Epoch 226/500
Epoch 227/500
Epoch 228/500
Epoch 229/500
Epoch 230/500
Epoch 231/500
Epoch 

Epoch 240/500
Epoch 241/500
Epoch 242/500
Epoch 243/500
Epoch 244/500
Epoch 245/500
Epoch 246/500
Epoch 247/500
Epoch 248/500
Epoch 249/500
Epoch 250/500
Epoch 251/500
Epoch 252/500
Epoch 253/500
Epoch 254/500
Epoch 255/500
Epoch 256/500
Epoch 257/500
Epoch 258/500
Epoch 259/500
Epoch 260/500
Epoch 261/500
Epoch 262/500
Epoch 263/500
Epoch 264/500
Epoch 265/500
Epoch 266/500
Epoch 267/500
Epoch 268/500
Epoch 269/500
Epoch 270/500
Epoch 271/500
Epoch 272/500
Epoch 273/500
Epoch 274/500
Epoch 275/500
Epoch 276/500
Epoch 277/500
Epoch 278/500
Epoch 279/500
Epoch 280/500
Epoch 281/500
Epoch 282/500
Epoch 283/500
Epoch 284/500
Epoch 285/500
Epoch 286/500
Epoch 287/500
Epoch 288/500
Epoch 289/500
Epoch 290/500
Epoch 291/500
Epoch 292/500
Epoch 293/500
Epoch 294/500
Epoch 295/500
Epoch 296/500
Epoch 297/500
Epoch 298/500
Epoch 299/500
Epoch 300/500
Epoch 301/500
Epoch 302/500
Epoch 303/500
Epoch 304/500
Epoch 305/500
Epoch 306/500
Epoch 307/500
Epoch 308/500
Epoch 309/500
Epoch 310/500
Epoch 

Epoch 319/500
Epoch 320/500
Epoch 321/500
Epoch 322/500
Epoch 323/500
Epoch 324/500
Epoch 325/500
Epoch 326/500
Epoch 327/500
Epoch 328/500
Epoch 329/500
Epoch 330/500
Epoch 331/500
Epoch 332/500
Epoch 333/500
Epoch 334/500
Epoch 335/500
Epoch 336/500
Epoch 337/500
Epoch 338/500
Epoch 339/500
Epoch 340/500
Epoch 341/500
Epoch 342/500
Epoch 343/500
Epoch 344/500
Epoch 345/500
Epoch 346/500
Epoch 347/500
Epoch 348/500
Epoch 349/500
Epoch 350/500
Epoch 351/500
Epoch 352/500
Epoch 353/500
Epoch 354/500
Epoch 355/500
Epoch 356/500
Epoch 357/500
Epoch 358/500
Epoch 359/500
Epoch 360/500
Epoch 361/500
Epoch 362/500
Epoch 363/500
Epoch 364/500
Epoch 365/500
Epoch 366/500
Epoch 367/500
Epoch 368/500
Epoch 369/500
Epoch 370/500
Epoch 371/500
Epoch 372/500
Epoch 373/500
Epoch 374/500
Epoch 375/500
Epoch 376/500
Epoch 377/500
Epoch 378/500
Epoch 379/500
Epoch 380/500
Epoch 381/500
Epoch 382/500
Epoch 383/500
Epoch 384/500
Epoch 385/500
Epoch 386/500
Epoch 387/500
Epoch 388/500
Epoch 389/500
Epoch 

Epoch 398/500
Epoch 399/500
Epoch 400/500
Epoch 401/500
Epoch 402/500
Epoch 403/500
Epoch 404/500
Epoch 405/500
Epoch 406/500
Epoch 407/500
Epoch 408/500
Epoch 409/500
Epoch 410/500
Epoch 411/500
Epoch 412/500
Epoch 413/500
Epoch 414/500
Epoch 415/500
Epoch 416/500
Epoch 417/500
Epoch 418/500
Epoch 419/500
Epoch 420/500
Epoch 421/500
Epoch 422/500
Epoch 423/500
Epoch 424/500
Epoch 425/500
Epoch 426/500
Epoch 427/500
Epoch 428/500
Epoch 429/500
Epoch 430/500
Epoch 431/500
Epoch 432/500
Epoch 433/500
Epoch 434/500
Epoch 435/500
Epoch 436/500
Epoch 437/500
Epoch 438/500
Epoch 439/500
Epoch 440/500
Epoch 441/500
Epoch 442/500
Epoch 443/500
Epoch 444/500
Epoch 445/500
Epoch 446/500
Epoch 447/500
Epoch 448/500
Epoch 449/500
Epoch 450/500
Epoch 451/500
Epoch 452/500
Epoch 453/500
Epoch 454/500
Epoch 455/500
Epoch 456/500
Epoch 457/500
Epoch 458/500
Epoch 459/500
Epoch 460/500
Epoch 461/500
Epoch 462/500
Epoch 463/500
Epoch 464/500
Epoch 465/500
Epoch 466/500
Epoch 467/500
Epoch 468/500
Epoch 

Epoch 477/500
Epoch 478/500
Epoch 479/500
Epoch 480/500
Epoch 481/500
Epoch 482/500
Epoch 483/500
Epoch 484/500
Epoch 485/500
Epoch 486/500
Epoch 487/500
Epoch 488/500
Epoch 489/500
Epoch 490/500
Epoch 491/500
Epoch 492/500
Epoch 493/500
Epoch 494/500
Epoch 495/500
Epoch 496/500
Epoch 497/500
Epoch 498/500
Epoch 499/500
Epoch 500/500


<tensorflow.python.keras.callbacks.History at 0x7fb118344cd0>

In [44]:
print('Validation accuracy:', 
      accuracy_score(y_val, model(X_val).numpy().argmax(axis=1)))

Validation accuracy: 0.5987394957983193


* Few FC layers:

In [47]:
loss_func = keras.losses.CategoricalCrossentropy(from_logits=True)

lr = 0.001
optimizer = tf.keras.optimizers.Adam(learning_rate=lr)

model = keras.Sequential([
    keras.layers.Dense(units=256, input_shape=[n_feats]),
    keras.layers.ReLU(),
#     keras.layers.Dropout(0.2),
    keras.layers.Dense(units=n_classes)
])
model.compile(optimizer=optimizer, loss=loss_func, metrics=['accuracy'])

In [48]:
model.training = True
model.fit(X_train, y_train, epochs=500)

Epoch 1/500
Epoch 2/500
Epoch 3/500
Epoch 4/500
Epoch 5/500
Epoch 6/500
Epoch 7/500
Epoch 8/500
Epoch 9/500
Epoch 10/500
Epoch 11/500
Epoch 12/500
Epoch 13/500
Epoch 14/500
Epoch 15/500
Epoch 16/500
Epoch 17/500
Epoch 18/500
Epoch 19/500
Epoch 20/500
Epoch 21/500
Epoch 22/500
Epoch 23/500
Epoch 24/500
Epoch 25/500
Epoch 26/500
Epoch 27/500
Epoch 28/500
Epoch 29/500
Epoch 30/500
Epoch 31/500
Epoch 32/500
Epoch 33/500
Epoch 34/500
Epoch 35/500
Epoch 36/500
Epoch 37/500
Epoch 38/500
Epoch 39/500
Epoch 40/500
Epoch 41/500
Epoch 42/500
Epoch 43/500
Epoch 44/500
Epoch 45/500
Epoch 46/500
Epoch 47/500
Epoch 48/500
Epoch 49/500
Epoch 50/500
Epoch 51/500
Epoch 52/500
Epoch 53/500
Epoch 54/500
Epoch 55/500
Epoch 56/500
Epoch 57/500
Epoch 58/500
Epoch 59/500
Epoch 60/500
Epoch 61/500
Epoch 62/500
Epoch 63/500
Epoch 64/500
Epoch 65/500
Epoch 66/500
Epoch 67/500
Epoch 68/500
Epoch 69/500
Epoch 70/500
Epoch 71/500
Epoch 72/500
Epoch 73/500
Epoch 74/500
Epoch 75/500
Epoch 76/500
Epoch 77/500
Epoch 78

Epoch 82/500
Epoch 83/500
Epoch 84/500
Epoch 85/500
Epoch 86/500
Epoch 87/500
Epoch 88/500
Epoch 89/500
Epoch 90/500
Epoch 91/500
Epoch 92/500
Epoch 93/500
Epoch 94/500
Epoch 95/500
Epoch 96/500
Epoch 97/500
Epoch 98/500
Epoch 99/500
Epoch 100/500
Epoch 101/500
Epoch 102/500
Epoch 103/500
Epoch 104/500
Epoch 105/500
Epoch 106/500
Epoch 107/500
Epoch 108/500
Epoch 109/500
Epoch 110/500
Epoch 111/500
Epoch 112/500
Epoch 113/500
Epoch 114/500
Epoch 115/500
Epoch 116/500
Epoch 117/500
Epoch 118/500
Epoch 119/500
Epoch 120/500
Epoch 121/500
Epoch 122/500
Epoch 123/500
Epoch 124/500
Epoch 125/500
Epoch 126/500
Epoch 127/500
Epoch 128/500
Epoch 129/500
Epoch 130/500
Epoch 131/500
Epoch 132/500
Epoch 133/500
Epoch 134/500
Epoch 135/500
Epoch 136/500
Epoch 137/500
Epoch 138/500
Epoch 139/500
Epoch 140/500
Epoch 141/500
Epoch 142/500
Epoch 143/500
Epoch 144/500
Epoch 145/500
Epoch 146/500
Epoch 147/500
Epoch 148/500
Epoch 149/500
Epoch 150/500
Epoch 151/500
Epoch 152/500
Epoch 153/500
Epoch 154/

Epoch 161/500
Epoch 162/500
Epoch 163/500
Epoch 164/500
Epoch 165/500
Epoch 166/500
Epoch 167/500
Epoch 168/500
Epoch 169/500
Epoch 170/500
Epoch 171/500
Epoch 172/500
Epoch 173/500
Epoch 174/500
Epoch 175/500
Epoch 176/500
Epoch 177/500
Epoch 178/500
Epoch 179/500
Epoch 180/500
Epoch 181/500
Epoch 182/500
Epoch 183/500
Epoch 184/500
Epoch 185/500
Epoch 186/500
Epoch 187/500
Epoch 188/500
Epoch 189/500
Epoch 190/500
Epoch 191/500
Epoch 192/500
Epoch 193/500
Epoch 194/500
Epoch 195/500
Epoch 196/500
Epoch 197/500
Epoch 198/500
Epoch 199/500
Epoch 200/500
Epoch 201/500
Epoch 202/500
Epoch 203/500
Epoch 204/500
Epoch 205/500
Epoch 206/500
Epoch 207/500
Epoch 208/500
Epoch 209/500
Epoch 210/500
Epoch 211/500
Epoch 212/500
Epoch 213/500
Epoch 214/500
Epoch 215/500
Epoch 216/500
Epoch 217/500
Epoch 218/500
Epoch 219/500
Epoch 220/500
Epoch 221/500
Epoch 222/500
Epoch 223/500
Epoch 224/500
Epoch 225/500
Epoch 226/500
Epoch 227/500
Epoch 228/500
Epoch 229/500
Epoch 230/500
Epoch 231/500
Epoch 

Epoch 240/500
Epoch 241/500
Epoch 242/500
Epoch 243/500
Epoch 244/500
Epoch 245/500
Epoch 246/500
Epoch 247/500
Epoch 248/500
Epoch 249/500
Epoch 250/500
Epoch 251/500
Epoch 252/500
Epoch 253/500
Epoch 254/500
Epoch 255/500
Epoch 256/500
Epoch 257/500
Epoch 258/500
Epoch 259/500
Epoch 260/500
Epoch 261/500
Epoch 262/500
Epoch 263/500
Epoch 264/500
Epoch 265/500
Epoch 266/500
Epoch 267/500
Epoch 268/500
Epoch 269/500
Epoch 270/500
Epoch 271/500
Epoch 272/500
Epoch 273/500
Epoch 274/500
Epoch 275/500
Epoch 276/500
Epoch 277/500
Epoch 278/500
Epoch 279/500
Epoch 280/500
Epoch 281/500
Epoch 282/500
Epoch 283/500
Epoch 284/500
Epoch 285/500
Epoch 286/500
Epoch 287/500
Epoch 288/500
Epoch 289/500
Epoch 290/500
Epoch 291/500
Epoch 292/500
Epoch 293/500
Epoch 294/500
Epoch 295/500
Epoch 296/500
Epoch 297/500
Epoch 298/500
Epoch 299/500
Epoch 300/500
Epoch 301/500
Epoch 302/500
Epoch 303/500
Epoch 304/500
Epoch 305/500
Epoch 306/500
Epoch 307/500
Epoch 308/500
Epoch 309/500
Epoch 310/500
Epoch 

Epoch 319/500
Epoch 320/500
Epoch 321/500
Epoch 322/500
Epoch 323/500
Epoch 324/500
Epoch 325/500
Epoch 326/500
Epoch 327/500
Epoch 328/500
Epoch 329/500
Epoch 330/500
Epoch 331/500
Epoch 332/500
Epoch 333/500
Epoch 334/500
Epoch 335/500
Epoch 336/500
Epoch 337/500
Epoch 338/500
Epoch 339/500
Epoch 340/500
Epoch 341/500
Epoch 342/500
Epoch 343/500
Epoch 344/500
Epoch 345/500
Epoch 346/500
Epoch 347/500
Epoch 348/500
Epoch 349/500
Epoch 350/500
Epoch 351/500
Epoch 352/500
Epoch 353/500
Epoch 354/500
Epoch 355/500
Epoch 356/500
Epoch 357/500
Epoch 358/500
Epoch 359/500
Epoch 360/500
Epoch 361/500
Epoch 362/500
Epoch 363/500
Epoch 364/500
Epoch 365/500
Epoch 366/500
Epoch 367/500
Epoch 368/500
Epoch 369/500
Epoch 370/500
Epoch 371/500
Epoch 372/500
Epoch 373/500
Epoch 374/500
Epoch 375/500
Epoch 376/500
Epoch 377/500
Epoch 378/500
Epoch 379/500
Epoch 380/500
Epoch 381/500
Epoch 382/500
Epoch 383/500
Epoch 384/500
Epoch 385/500
Epoch 386/500
Epoch 387/500
Epoch 388/500
Epoch 389/500
Epoch 

Epoch 398/500
Epoch 399/500
Epoch 400/500
Epoch 401/500
Epoch 402/500
Epoch 403/500
Epoch 404/500
Epoch 405/500
Epoch 406/500
Epoch 407/500
Epoch 408/500
Epoch 409/500
Epoch 410/500
Epoch 411/500
Epoch 412/500
Epoch 413/500
Epoch 414/500
Epoch 415/500
Epoch 416/500
Epoch 417/500
Epoch 418/500
Epoch 419/500
Epoch 420/500
Epoch 421/500
Epoch 422/500
Epoch 423/500
Epoch 424/500
Epoch 425/500
Epoch 426/500
Epoch 427/500
Epoch 428/500
Epoch 429/500
Epoch 430/500
Epoch 431/500
Epoch 432/500
Epoch 433/500
Epoch 434/500
Epoch 435/500
Epoch 436/500
Epoch 437/500
Epoch 438/500
Epoch 439/500
Epoch 440/500
Epoch 441/500
Epoch 442/500
Epoch 443/500
Epoch 444/500
Epoch 445/500
Epoch 446/500
Epoch 447/500
Epoch 448/500
Epoch 449/500
Epoch 450/500
Epoch 451/500
Epoch 452/500
Epoch 453/500
Epoch 454/500
Epoch 455/500
Epoch 456/500
Epoch 457/500
Epoch 458/500
Epoch 459/500
Epoch 460/500
Epoch 461/500
Epoch 462/500
Epoch 463/500
Epoch 464/500
Epoch 465/500
Epoch 466/500
Epoch 467/500
Epoch 468/500
Epoch 

Epoch 477/500
Epoch 478/500
Epoch 479/500
Epoch 480/500
Epoch 481/500
Epoch 482/500
Epoch 483/500
Epoch 484/500
Epoch 485/500
Epoch 486/500
Epoch 487/500
Epoch 488/500
Epoch 489/500
Epoch 490/500
Epoch 491/500
Epoch 492/500
Epoch 493/500
Epoch 494/500
Epoch 495/500
Epoch 496/500
Epoch 497/500
Epoch 498/500
Epoch 499/500
Epoch 500/500


<tensorflow.python.keras.callbacks.History at 0x7fb1181ac8d0>

In [49]:
model.training = False
print('Validation accuracy:', 
      accuracy_score(y_val, model(X_val).numpy().argmax(axis=1)))

Validation accuracy: 0.9600840336134454


* Export to tflite:

In [51]:
models_root = './robot_models'
os.makedirs(models_root, exist_ok=True)

name = 'robot_fc256'

keras.models.save_model(model, f'{models_root}/{name}.h5')

# !!! AttributeError: module 'tensorflow._api.v2.compat.v1' has no attribute 'TFLiteConverter'
# converter = tf.compat.v1.TFLiteConverter.from_keras_model_file(keras_file)
# tflite_model = converter.convert()
# open('linear.tflite', 'wb').write(tflite_model)

converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
with open(f'{models_root}/{name}.tflite', 'wb') as file:
    file.write(tflite_model)