# Train a model to recogize static gestures

In [1]:
import warnings
warnings.filterwarnings('ignore')

## Prepare the data

In [2]:
import os
import numpy as np
import pandas as pd
import pickle

In [3]:
data = pickle.load(open('../out_data/rsl_hand_keypoints_150421_05_07.pkl', 'rb'))
data

[{'hand_type': 'right',
  'gesture_id': 0,
  'gesture_name': 'а',
  'pred_gesture_name': 'FIST',
  'keypoints': array([[ 2.67089456e-01,  6.60982251e-01, -2.08790916e-05],
         [ 3.30584913e-01,  6.34182751e-01, -4.25098836e-02],
         [ 3.82477403e-01,  5.80635548e-01, -7.67343044e-02],
         [ 3.96937668e-01,  5.13537645e-01, -1.06835015e-01],
         [ 3.69887769e-01,  4.76100266e-01, -1.18069492e-01],
         [ 3.34840119e-01,  5.07150948e-01, -1.71379764e-02],
         [ 3.37242931e-01,  4.45898712e-01, -8.58453214e-02],
         [ 3.37169409e-01,  5.10042250e-01, -8.25941339e-02],
         [ 3.40645224e-01,  5.32609284e-01, -4.55603376e-02],
         [ 2.96800733e-01,  5.02685487e-01, -1.93976499e-02],
         [ 2.97395349e-01,  4.43975806e-01, -1.20612361e-01],
         [ 3.00099164e-01,  5.27399480e-01, -1.20525457e-01],
         [ 3.04577440e-01,  5.35932124e-01, -8.16024840e-02],
         [ 2.59712875e-01,  5.06670117e-01, -3.23568396e-02],
         [ 2.58578002e

> **Note:** The data is collected for the right hand only.

In [5]:
gesture_ids = np.array([x['gesture_id'] for x in data]).reshape(-1, 1)
gesture_names = np.array([x['gesture_name'] for x in data]).reshape(-1, 1)
keypoints = np.array([x['keypoints'] for x in data]).reshape(-1, 63)
scaled_keypoints = np.array([x['scaled_keypoints'] for x in data]).reshape(-1, 63)

In [6]:
print(gesture_ids.shape, gesture_names.shape, keypoints.shape, scaled_keypoints.shape)

(3324, 1) (3324, 1) (3324, 63) (3324, 63)


In [7]:
columns = ['gesture_id', 'gesture_name']
for i in range(scaled_keypoints.shape[1] // 3):
    columns += [f'kp{i}_x', f'kp{i}_y', f'kp{i}_z']
df = pd.DataFrame(data=np.hstack([gesture_ids, gesture_names, scaled_keypoints]), 
                  columns=columns)
df

Unnamed: 0,gesture_id,gesture_name,kp0_x,kp0_y,kp0_z,kp1_x,kp1_y,kp1_z,kp2_x,kp2_y,...,kp17_z,kp18_x,kp18_y,kp18_z,kp19_x,kp19_y,kp19_z,kp20_x,kp20_y,kp20_z
0,0,а,0.4705820679664612,0.796766459941864,-5.952839637757279e-05,0.6320144534111023,0.6967858076095581,-0.12119996547698975,0.747340977191925,0.5494773983955383,...,-0.13368837535381317,0.25215715169906616,0.42273616790771484,-0.23778514564037323,0.3088401257991791,0.5225798487663269,-0.23095780611038208,0.32632893323898315,0.5603058934211731,-0.16367658972740173
1,0,а,0.4762500822544098,0.7909183502197266,-8.796168549451977e-05,0.6325406432151794,0.6932122111320496,-0.1254713237285614,0.7475526928901672,0.554513156414032,...,-0.13856087625026703,0.25062090158462524,0.4330465495586395,-0.23548181354999542,0.30781179666519165,0.5319973826408386,-0.2320072203874588,0.32554391026496887,0.5679558515548706,-0.16866394877433777
2,0,а,0.47361570596694946,0.8141859173774719,-2.2398537112167105e-05,0.6298989057540894,0.7206665873527527,-0.11891934275627136,0.7410799264907837,0.5806694626808167,...,-0.09826485812664032,0.26282739639282227,0.4407682716846466,-0.20334528386592865,0.31513866782188416,0.5448867678642273,-0.200063094496727,0.32395967841148376,0.5724753737449646,-0.1379202902317047
3,0,а,0.4597257971763611,0.8010740280151367,-3.998601096100174e-05,0.6129380464553833,0.6996431350708008,-0.11300965398550034,0.724448025226593,0.5550662279129028,...,-0.13049684464931488,0.25256723165512085,0.44187694787979126,-0.2183557152748108,0.3047855794429779,0.5375128388404846,-0.21153496205806732,0.3142794072628021,0.5576078295707703,-0.1528605818748474
4,0,а,0.46551647782325745,0.797554075717926,-2.4780123567325063e-05,0.6154492497444153,0.6900844573974609,-0.1179519072175026,0.7195952534675598,0.5500237345695496,...,-0.12548692524433136,0.26072946190834045,0.4423027038574219,-0.21988382935523987,0.31294265389442444,0.542239785194397,-0.21235616505146027,0.3187944293022156,0.558204710483551,-0.15105491876602173
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3319,11,no_gesture,0.41811972856521606,0.7935647964477539,-0.00026196689577773213,0.3369643986225128,0.6581730246543884,0.1398848593235016,0.3340877592563629,0.5164526700973511,...,-0.19694684445858002,0.6625296473503113,0.4708195626735687,-0.17526967823505402,0.7101671099662781,0.4238339066505432,-0.12373901903629303,0.7403766512870789,0.3978142738342285,-0.07982686161994934
3320,11,no_gesture,0.40528959035873413,0.7980723977088928,-0.00029728253139182925,0.31517839431762695,0.6708207726478577,0.1520850658416748,0.3056996464729309,0.5439714193344116,...,-0.10434993356466293,0.6577094197273254,0.45000284910202026,-0.07622505724430084,0.714701771736145,0.40628668665885925,-0.042804475873708725,0.7595261931419373,0.38513484597206116,-0.025336606428027153
3321,11,no_gesture,0.40624600648880005,0.8087884783744812,-0.0003180171479471028,0.3267473578453064,0.6900452971458435,0.18300282955169678,0.33200207352638245,0.5550119280815125,...,-0.16702590882778168,0.6883755922317505,0.4810240864753723,-0.11715154349803925,0.7536380887031555,0.45250770449638367,-0.0570104718208313,0.8005122542381287,0.4411378800868988,-0.019529180601239204
3322,11,no_gesture,0.4067753255367279,0.7988735437393188,-0.0003317581140436232,0.3170301616191864,0.6799787282943726,0.1768741011619568,0.3100930154323578,0.5549732446670532,...,-0.06930389255285263,0.6736761927604675,0.4573438763618469,-0.02326342649757862,0.7356796860694885,0.42045775055885315,0.026450440287590027,0.7826856970787048,0.4078677296638489,0.05795770883560181


In [8]:
df['gesture_name'].value_counts()

no_gesture    1004
б              241
а              231
я              221
е              211
ж              205
л              203
м              202
н              202
г              202
в              202
и              200
Name: gesture_name, dtype: int64

## Train sklearn model

In [9]:
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.metrics import make_scorer, accuracy_score, f1_score, precision_score, recall_score
from sklearn.model_selection import cross_val_score

In [10]:
X, y = df.iloc[:, 2:].values.astype('float'), df.iloc[:, 0].values.astype('int32')
print(X.shape, y.shape)
print(X[0], y[0], np.unique(y))

(3324, 63) (3324,)
[ 4.70582068e-01  7.96766460e-01 -5.95283964e-05  6.32014453e-01
  6.96785808e-01 -1.21199965e-01  7.47340977e-01  5.49477398e-01
 -2.18777239e-01  7.52097189e-01  4.00270224e-01 -3.04597408e-01
  6.57687724e-01  3.41591448e-01 -3.36628020e-01  5.77113450e-01
  4.30479914e-01 -4.88620996e-02  5.51611483e-01  3.01826924e-01
 -2.44753674e-01  5.85068703e-01  4.34843689e-01 -2.35484198e-01
  6.06517971e-01  4.79191691e-01 -1.29897118e-01  4.69632626e-01
  4.47839081e-01 -5.53046577e-02  4.40467149e-01  3.25721741e-01
 -3.43878001e-01  4.91718233e-01  4.96761739e-01 -3.43630224e-01
  5.08573472e-01  5.11315882e-01 -2.32656911e-01  3.69215965e-01
  4.82048899e-01 -9.22526121e-02  3.38155985e-01  3.72540414e-01
 -3.28931540e-01  4.00546640e-01  5.25510371e-01 -3.13929737e-01
  4.15574819e-01  5.43543398e-01 -2.16082990e-01  2.71617562e-01
  5.26124120e-01 -1.33688375e-01  2.52157152e-01  4.22736168e-01
 -2.37785146e-01  3.08840126e-01  5.22579849e-01 -2.30957806e-01
  3.26

In [11]:
idx = np.arange(X.shape[0])
np.random.shuffle(idx)
X, y = X[idx], y[idx]

In [12]:
# class_amounts = {0: 400, 1: 200, 2: 200, 3: 200}

# X_new, y_new = None, None
# for class_id in np.unique(y):
#     class_indices = np.argwhere(class_id == y)
#     print(class_indices.shape)
#     subsample_indices = np.random.choice(class_indices[:,0], size=class_amounts[class_id])
#     print(subsample_indices.shape)
#     if X_new is None:
#         X_new = X[subsample_indices]
#     else:
#         X_new = np.vstack((X_new, X[subsample_indices]))
#     if y_new is None:
#         y_new = y[subsample_indices].reshape(-1,1)
#     else:
#         y_new = np.vstack((y_new, y[subsample_indices].reshape(-1,1)))

# print(X_new.shape, y_new.shape)
# X = X_new
# y = y_new

In [13]:
logreg = LogisticRegression()
cross_val_score(logreg, X=X, y=y, scoring=make_scorer(accuracy_score), cv=5)

array([0.87630402, 0.88905547, 0.88687783, 0.85800604, 0.86535552])

In [14]:
rfclf = RandomForestClassifier()
cross_val_score(rfclf, X=X, y=y, scoring=make_scorer(accuracy_score), cv=5)

array([0.93740686, 0.95202399, 0.94419306, 0.94864048, 0.94856278])

In [15]:
gbclf = GradientBoostingClassifier()
cross_val_score(gbclf, X=X, y=y, scoring=make_scorer(accuracy_score), cv=5)

array([0.93889717, 0.95352324, 0.94570136, 0.94561934, 0.95007564])

In [16]:
models_root = './rsl_models'
os.makedirs(models_root, exist_ok=True)
pickle.dump(gbclf, open(f'{models_root}/rsl_gb_default_scaled.pkl', 'wb'))

## Train Keras model

In [17]:
import numpy as np
import tensorflow as tf
from tensorflow import keras

In [18]:
X, y = df.iloc[:, 2:].values.astype('float'), df.iloc[:, 0].values.astype('int32')
print(X.shape, y.shape)
print(X[0], y[0], np.unique(y))

(3324, 63) (3324,)
[ 4.70582068e-01  7.96766460e-01 -5.95283964e-05  6.32014453e-01
  6.96785808e-01 -1.21199965e-01  7.47340977e-01  5.49477398e-01
 -2.18777239e-01  7.52097189e-01  4.00270224e-01 -3.04597408e-01
  6.57687724e-01  3.41591448e-01 -3.36628020e-01  5.77113450e-01
  4.30479914e-01 -4.88620996e-02  5.51611483e-01  3.01826924e-01
 -2.44753674e-01  5.85068703e-01  4.34843689e-01 -2.35484198e-01
  6.06517971e-01  4.79191691e-01 -1.29897118e-01  4.69632626e-01
  4.47839081e-01 -5.53046577e-02  4.40467149e-01  3.25721741e-01
 -3.43878001e-01  4.91718233e-01  4.96761739e-01 -3.43630224e-01
  5.08573472e-01  5.11315882e-01 -2.32656911e-01  3.69215965e-01
  4.82048899e-01 -9.22526121e-02  3.38155985e-01  3.72540414e-01
 -3.28931540e-01  4.00546640e-01  5.25510371e-01 -3.13929737e-01
  4.15574819e-01  5.43543398e-01 -2.16082990e-01  2.71617562e-01
  5.26124120e-01 -1.33688375e-01  2.52157152e-01  4.22736168e-01
 -2.37785146e-01  3.08840126e-01  5.22579849e-01 -2.30957806e-01
  3.26

In [19]:
idx = np.arange(X.shape[0])
np.random.shuffle(idx)
X, y = X[idx], y[idx]

In [20]:
n_classes = len(np.unique(y))
n_feats = X.shape[1]

In [21]:
# class_amounts = {0: 400, 1: 200, 2: 200, 3: 200}

# X_new, y_new = None, None
# for class_id in np.unique(y):
#     class_indices = np.argwhere(class_id == y)
#     print(class_indices.shape)
#     subsample_indices = np.random.choice(class_indices[:,0], size=class_amounts[class_id])
#     print(subsample_indices.shape)
#     if X_new is None:
#         X_new = X[subsample_indices]
#     else:
#         X_new = np.vstack((X_new, X[subsample_indices]))
#     if y_new is None:
#         y_new = y[subsample_indices].reshape(-1,1)
#     else:
#         y_new = np.vstack((y_new, y[subsample_indices].reshape(-1,1)))

# print(X_new.shape, y_new.shape)
# X = X_new
# y = y_new

In [22]:
y_ohe = np.zeros((y.shape[0], n_classes))
y_ohe[np.arange(len(y)).reshape(-1,1), y.reshape(-1, 1)] = 1
print(y_ohe.sum(axis=0))

[ 231.  241.  202.  202.  211.  205.  200.  203.  202.  202.  221. 1004.]


In [23]:
train_ratio = 0.75
X_train, y_train = X[:int(len(X)*train_ratio)], y_ohe[:int(len(X)*train_ratio)]
X_val, y_val = X[int(len(X)*train_ratio):], y[int(len(X)*train_ratio):]

* One neuron (linear / logistic regression):

In [24]:
loss_func = keras.losses.CategoricalCrossentropy(from_logits=True)

lr = 0.001
optimizer = tf.keras.optimizers.Adam(learning_rate=lr)

model = keras.Sequential([keras.layers.Dense(units=n_classes, input_shape=[n_feats])])
model.compile(optimizer=optimizer, loss=loss_func, metrics=['accuracy'])

In [25]:
model.training = True
model.fit(X_train, y_train, epochs=500)

Epoch 1/500
Epoch 2/500
Epoch 3/500
Epoch 4/500
Epoch 5/500
Epoch 6/500
Epoch 7/500
Epoch 8/500
Epoch 9/500
Epoch 10/500
Epoch 11/500
Epoch 12/500
Epoch 13/500
Epoch 14/500
Epoch 15/500
Epoch 16/500
Epoch 17/500
Epoch 18/500
Epoch 19/500
Epoch 20/500
Epoch 21/500
Epoch 22/500
Epoch 23/500
Epoch 24/500
Epoch 25/500
Epoch 26/500
Epoch 27/500
Epoch 28/500
Epoch 29/500
Epoch 30/500
Epoch 31/500
Epoch 32/500
Epoch 33/500
Epoch 34/500
Epoch 35/500
Epoch 36/500
Epoch 37/500
Epoch 38/500
Epoch 39/500
Epoch 40/500
Epoch 41/500
Epoch 42/500
Epoch 43/500
Epoch 44/500
Epoch 45/500
Epoch 46/500
Epoch 47/500
Epoch 48/500
Epoch 49/500
Epoch 50/500
Epoch 51/500
Epoch 52/500
Epoch 53/500
Epoch 54/500
Epoch 55/500
Epoch 56/500
Epoch 57/500
Epoch 58/500
Epoch 59/500
Epoch 60/500
Epoch 61/500
Epoch 62/500
Epoch 63/500
Epoch 64/500
Epoch 65/500
Epoch 66/500
Epoch 67/500
Epoch 68/500
Epoch 69/500
Epoch 70/500
Epoch 71/500
Epoch 72/500
Epoch 73/500
Epoch 74/500
Epoch 75/500
Epoch 76/500
Epoch 77/500
Epoch 78

Epoch 82/500
Epoch 83/500
Epoch 84/500
Epoch 85/500
Epoch 86/500
Epoch 87/500
Epoch 88/500
Epoch 89/500
Epoch 90/500
Epoch 91/500
Epoch 92/500
Epoch 93/500
Epoch 94/500
Epoch 95/500
Epoch 96/500
Epoch 97/500
Epoch 98/500
Epoch 99/500
Epoch 100/500
Epoch 101/500
Epoch 102/500
Epoch 103/500
Epoch 104/500
Epoch 105/500
Epoch 106/500
Epoch 107/500
Epoch 108/500
Epoch 109/500
Epoch 110/500
Epoch 111/500
Epoch 112/500
Epoch 113/500
Epoch 114/500
Epoch 115/500
Epoch 116/500
Epoch 117/500
Epoch 118/500
Epoch 119/500
Epoch 120/500
Epoch 121/500
Epoch 122/500
Epoch 123/500
Epoch 124/500
Epoch 125/500
Epoch 126/500
Epoch 127/500
Epoch 128/500
Epoch 129/500
Epoch 130/500
Epoch 131/500
Epoch 132/500
Epoch 133/500
Epoch 134/500
Epoch 135/500
Epoch 136/500
Epoch 137/500
Epoch 138/500
Epoch 139/500
Epoch 140/500
Epoch 141/500
Epoch 142/500
Epoch 143/500
Epoch 144/500
Epoch 145/500
Epoch 146/500
Epoch 147/500
Epoch 148/500
Epoch 149/500
Epoch 150/500
Epoch 151/500
Epoch 152/500
Epoch 153/500
Epoch 154/

Epoch 161/500
Epoch 162/500
Epoch 163/500
Epoch 164/500
Epoch 165/500
Epoch 166/500
Epoch 167/500
Epoch 168/500
Epoch 169/500
Epoch 170/500
Epoch 171/500
Epoch 172/500
Epoch 173/500
Epoch 174/500
Epoch 175/500
Epoch 176/500
Epoch 177/500
Epoch 178/500
Epoch 179/500
Epoch 180/500
Epoch 181/500
Epoch 182/500
Epoch 183/500
Epoch 184/500
Epoch 185/500
Epoch 186/500
Epoch 187/500
Epoch 188/500
Epoch 189/500
Epoch 190/500
Epoch 191/500
Epoch 192/500
Epoch 193/500
Epoch 194/500
Epoch 195/500
Epoch 196/500
Epoch 197/500
Epoch 198/500
Epoch 199/500
Epoch 200/500
Epoch 201/500
Epoch 202/500
Epoch 203/500
Epoch 204/500
Epoch 205/500
Epoch 206/500
Epoch 207/500
Epoch 208/500
Epoch 209/500
Epoch 210/500
Epoch 211/500
Epoch 212/500
Epoch 213/500
Epoch 214/500
Epoch 215/500
Epoch 216/500
Epoch 217/500
Epoch 218/500
Epoch 219/500
Epoch 220/500
Epoch 221/500
Epoch 222/500
Epoch 223/500
Epoch 224/500
Epoch 225/500
Epoch 226/500
Epoch 227/500
Epoch 228/500
Epoch 229/500
Epoch 230/500
Epoch 231/500
Epoch 

Epoch 240/500
Epoch 241/500
Epoch 242/500
Epoch 243/500
Epoch 244/500
Epoch 245/500
Epoch 246/500
Epoch 247/500
Epoch 248/500
Epoch 249/500
Epoch 250/500
Epoch 251/500
Epoch 252/500
Epoch 253/500
Epoch 254/500
Epoch 255/500
Epoch 256/500
Epoch 257/500
Epoch 258/500
Epoch 259/500
Epoch 260/500
Epoch 261/500
Epoch 262/500
Epoch 263/500
Epoch 264/500
Epoch 265/500
Epoch 266/500
Epoch 267/500
Epoch 268/500
Epoch 269/500
Epoch 270/500
Epoch 271/500
Epoch 272/500
Epoch 273/500
Epoch 274/500
Epoch 275/500
Epoch 276/500
Epoch 277/500
Epoch 278/500
Epoch 279/500
Epoch 280/500
Epoch 281/500
Epoch 282/500
Epoch 283/500
Epoch 284/500
Epoch 285/500
Epoch 286/500
Epoch 287/500
Epoch 288/500
Epoch 289/500
Epoch 290/500
Epoch 291/500
Epoch 292/500
Epoch 293/500
Epoch 294/500
Epoch 295/500
Epoch 296/500
Epoch 297/500
Epoch 298/500
Epoch 299/500
Epoch 300/500
Epoch 301/500
Epoch 302/500
Epoch 303/500
Epoch 304/500
Epoch 305/500
Epoch 306/500
Epoch 307/500
Epoch 308/500
Epoch 309/500
Epoch 310/500
Epoch 

Epoch 319/500
Epoch 320/500
Epoch 321/500
Epoch 322/500
Epoch 323/500
Epoch 324/500
Epoch 325/500
Epoch 326/500
Epoch 327/500
Epoch 328/500
Epoch 329/500
Epoch 330/500
Epoch 331/500
Epoch 332/500
Epoch 333/500
Epoch 334/500
Epoch 335/500
Epoch 336/500
Epoch 337/500
Epoch 338/500
Epoch 339/500
Epoch 340/500
Epoch 341/500
Epoch 342/500
Epoch 343/500
Epoch 344/500
Epoch 345/500
Epoch 346/500
Epoch 347/500
Epoch 348/500
Epoch 349/500
Epoch 350/500
Epoch 351/500
Epoch 352/500
Epoch 353/500
Epoch 354/500
Epoch 355/500
Epoch 356/500
Epoch 357/500
Epoch 358/500
Epoch 359/500
Epoch 360/500
Epoch 361/500
Epoch 362/500
Epoch 363/500
Epoch 364/500
Epoch 365/500
Epoch 366/500
Epoch 367/500
Epoch 368/500
Epoch 369/500
Epoch 370/500
Epoch 371/500
Epoch 372/500
Epoch 373/500
Epoch 374/500
Epoch 375/500
Epoch 376/500
Epoch 377/500
Epoch 378/500
Epoch 379/500
Epoch 380/500
Epoch 381/500
Epoch 382/500
Epoch 383/500
Epoch 384/500
Epoch 385/500
Epoch 386/500
Epoch 387/500
Epoch 388/500
Epoch 389/500
Epoch 

Epoch 398/500
Epoch 399/500
Epoch 400/500
Epoch 401/500
Epoch 402/500
Epoch 403/500
Epoch 404/500
Epoch 405/500
Epoch 406/500
Epoch 407/500
Epoch 408/500
Epoch 409/500
Epoch 410/500
Epoch 411/500
Epoch 412/500
Epoch 413/500
Epoch 414/500
Epoch 415/500
Epoch 416/500
Epoch 417/500
Epoch 418/500
Epoch 419/500
Epoch 420/500
Epoch 421/500
Epoch 422/500
Epoch 423/500
Epoch 424/500
Epoch 425/500
Epoch 426/500
Epoch 427/500
Epoch 428/500
Epoch 429/500
Epoch 430/500
Epoch 431/500
Epoch 432/500
Epoch 433/500
Epoch 434/500
Epoch 435/500
Epoch 436/500
Epoch 437/500
Epoch 438/500
Epoch 439/500
Epoch 440/500
Epoch 441/500
Epoch 442/500
Epoch 443/500
Epoch 444/500
Epoch 445/500
Epoch 446/500
Epoch 447/500
Epoch 448/500
Epoch 449/500
Epoch 450/500
Epoch 451/500
Epoch 452/500
Epoch 453/500
Epoch 454/500
Epoch 455/500
Epoch 456/500
Epoch 457/500
Epoch 458/500
Epoch 459/500
Epoch 460/500
Epoch 461/500
Epoch 462/500
Epoch 463/500
Epoch 464/500
Epoch 465/500
Epoch 466/500
Epoch 467/500
Epoch 468/500
Epoch 

Epoch 477/500
Epoch 478/500
Epoch 479/500
Epoch 480/500
Epoch 481/500
Epoch 482/500
Epoch 483/500
Epoch 484/500
Epoch 485/500
Epoch 486/500
Epoch 487/500
Epoch 488/500
Epoch 489/500
Epoch 490/500
Epoch 491/500
Epoch 492/500
Epoch 493/500
Epoch 494/500
Epoch 495/500
Epoch 496/500
Epoch 497/500
Epoch 498/500
Epoch 499/500
Epoch 500/500


<tensorflow.python.keras.callbacks.History at 0x7f8a846ad490>

In [26]:
model.training = False
print('Validation accuracy:', 
      accuracy_score(y_val, model(X_val).numpy().argmax(axis=1)))

Validation accuracy: 0.9205776173285198


* Few FC layers:

In [27]:
loss_func = keras.losses.CategoricalCrossentropy(from_logits=True)

lr = 0.001
optimizer = tf.keras.optimizers.Adam(learning_rate=lr)

model = keras.Sequential([
    keras.layers.Dense(units=256, input_shape=[n_feats]),
    keras.layers.ReLU(),
    keras.layers.Dense(units=n_classes)
])
model.compile(optimizer=optimizer, loss=loss_func, metrics=['accuracy'])

In [28]:
model.training = True
model.fit(X_train, y_train, epochs=500)

Epoch 1/500
Epoch 2/500
Epoch 3/500
Epoch 4/500
Epoch 5/500
Epoch 6/500
Epoch 7/500
Epoch 8/500
Epoch 9/500
Epoch 10/500
Epoch 11/500
Epoch 12/500
Epoch 13/500
Epoch 14/500
Epoch 15/500
Epoch 16/500
Epoch 17/500
Epoch 18/500
Epoch 19/500
Epoch 20/500
Epoch 21/500
Epoch 22/500
Epoch 23/500
Epoch 24/500
Epoch 25/500
Epoch 26/500
Epoch 27/500
Epoch 28/500
Epoch 29/500
Epoch 30/500
Epoch 31/500
Epoch 32/500
Epoch 33/500
Epoch 34/500
Epoch 35/500
Epoch 36/500
Epoch 37/500
Epoch 38/500
Epoch 39/500
Epoch 40/500
Epoch 41/500
Epoch 42/500
Epoch 43/500
Epoch 44/500
Epoch 45/500
Epoch 46/500
Epoch 47/500
Epoch 48/500
Epoch 49/500
Epoch 50/500
Epoch 51/500
Epoch 52/500
Epoch 53/500
Epoch 54/500
Epoch 55/500
Epoch 56/500
Epoch 57/500
Epoch 58/500
Epoch 59/500
Epoch 60/500
Epoch 61/500
Epoch 62/500
Epoch 63/500
Epoch 64/500
Epoch 65/500
Epoch 66/500
Epoch 67/500
Epoch 68/500
Epoch 69/500
Epoch 70/500
Epoch 71/500
Epoch 72/500
Epoch 73/500
Epoch 74/500
Epoch 75/500
Epoch 76/500
Epoch 77/500
Epoch 78

Epoch 82/500
Epoch 83/500
Epoch 84/500
Epoch 85/500
Epoch 86/500
Epoch 87/500
Epoch 88/500
Epoch 89/500
Epoch 90/500
Epoch 91/500
Epoch 92/500
Epoch 93/500
Epoch 94/500
Epoch 95/500
Epoch 96/500
Epoch 97/500
Epoch 98/500
Epoch 99/500
Epoch 100/500
Epoch 101/500
Epoch 102/500
Epoch 103/500
Epoch 104/500
Epoch 105/500
Epoch 106/500
Epoch 107/500
Epoch 108/500
Epoch 109/500
Epoch 110/500
Epoch 111/500
Epoch 112/500
Epoch 113/500
Epoch 114/500
Epoch 115/500
Epoch 116/500
Epoch 117/500
Epoch 118/500
Epoch 119/500
Epoch 120/500
Epoch 121/500
Epoch 122/500
Epoch 123/500
Epoch 124/500
Epoch 125/500
Epoch 126/500
Epoch 127/500
Epoch 128/500
Epoch 129/500
Epoch 130/500
Epoch 131/500
Epoch 132/500
Epoch 133/500
Epoch 134/500
Epoch 135/500
Epoch 136/500
Epoch 137/500
Epoch 138/500
Epoch 139/500
Epoch 140/500
Epoch 141/500
Epoch 142/500
Epoch 143/500
Epoch 144/500
Epoch 145/500
Epoch 146/500
Epoch 147/500
Epoch 148/500
Epoch 149/500
Epoch 150/500
Epoch 151/500
Epoch 152/500
Epoch 153/500
Epoch 154/

Epoch 161/500
Epoch 162/500
Epoch 163/500
Epoch 164/500
Epoch 165/500
Epoch 166/500
Epoch 167/500
Epoch 168/500
Epoch 169/500
Epoch 170/500
Epoch 171/500
Epoch 172/500
Epoch 173/500
Epoch 174/500
Epoch 175/500
Epoch 176/500
Epoch 177/500
Epoch 178/500
Epoch 179/500
Epoch 180/500
Epoch 181/500
Epoch 182/500
Epoch 183/500
Epoch 184/500
Epoch 185/500
Epoch 186/500
Epoch 187/500
Epoch 188/500
Epoch 189/500
Epoch 190/500
Epoch 191/500
Epoch 192/500
Epoch 193/500
Epoch 194/500
Epoch 195/500
Epoch 196/500
Epoch 197/500
Epoch 198/500
Epoch 199/500
Epoch 200/500
Epoch 201/500
Epoch 202/500
Epoch 203/500
Epoch 204/500
Epoch 205/500
Epoch 206/500
Epoch 207/500
Epoch 208/500
Epoch 209/500
Epoch 210/500
Epoch 211/500
Epoch 212/500
Epoch 213/500
Epoch 214/500
Epoch 215/500
Epoch 216/500
Epoch 217/500
Epoch 218/500
Epoch 219/500
Epoch 220/500
Epoch 221/500
Epoch 222/500
Epoch 223/500
Epoch 224/500
Epoch 225/500
Epoch 226/500
Epoch 227/500
Epoch 228/500
Epoch 229/500
Epoch 230/500
Epoch 231/500
Epoch 

Epoch 240/500
Epoch 241/500
Epoch 242/500
Epoch 243/500
Epoch 244/500
Epoch 245/500
Epoch 246/500
Epoch 247/500
Epoch 248/500
Epoch 249/500
Epoch 250/500
Epoch 251/500
Epoch 252/500
Epoch 253/500
Epoch 254/500
Epoch 255/500
Epoch 256/500
Epoch 257/500
Epoch 258/500
Epoch 259/500
Epoch 260/500
Epoch 261/500
Epoch 262/500
Epoch 263/500
Epoch 264/500
Epoch 265/500
Epoch 266/500
Epoch 267/500
Epoch 268/500
Epoch 269/500
Epoch 270/500
Epoch 271/500
Epoch 272/500
Epoch 273/500
Epoch 274/500
Epoch 275/500
Epoch 276/500
Epoch 277/500
Epoch 278/500
Epoch 279/500
Epoch 280/500
Epoch 281/500
Epoch 282/500
Epoch 283/500
Epoch 284/500
Epoch 285/500
Epoch 286/500
Epoch 287/500
Epoch 288/500
Epoch 289/500
Epoch 290/500
Epoch 291/500
Epoch 292/500
Epoch 293/500
Epoch 294/500
Epoch 295/500
Epoch 296/500
Epoch 297/500
Epoch 298/500
Epoch 299/500
Epoch 300/500
Epoch 301/500
Epoch 302/500
Epoch 303/500
Epoch 304/500
Epoch 305/500
Epoch 306/500
Epoch 307/500
Epoch 308/500
Epoch 309/500
Epoch 310/500
Epoch 

Epoch 319/500
Epoch 320/500
Epoch 321/500
Epoch 322/500
Epoch 323/500
Epoch 324/500
Epoch 325/500
Epoch 326/500
Epoch 327/500
Epoch 328/500
Epoch 329/500
Epoch 330/500
Epoch 331/500
Epoch 332/500
Epoch 333/500
Epoch 334/500
Epoch 335/500
Epoch 336/500
Epoch 337/500
Epoch 338/500
Epoch 339/500
Epoch 340/500
Epoch 341/500
Epoch 342/500
Epoch 343/500
Epoch 344/500
Epoch 345/500
Epoch 346/500
Epoch 347/500
Epoch 348/500
Epoch 349/500
Epoch 350/500
Epoch 351/500
Epoch 352/500
Epoch 353/500
Epoch 354/500
Epoch 355/500
Epoch 356/500
Epoch 357/500
Epoch 358/500
Epoch 359/500
Epoch 360/500
Epoch 361/500
Epoch 362/500
Epoch 363/500
Epoch 364/500
Epoch 365/500
Epoch 366/500
Epoch 367/500
Epoch 368/500
Epoch 369/500
Epoch 370/500
Epoch 371/500
Epoch 372/500
Epoch 373/500
Epoch 374/500
Epoch 375/500
Epoch 376/500
Epoch 377/500
Epoch 378/500
Epoch 379/500
Epoch 380/500
Epoch 381/500
Epoch 382/500
Epoch 383/500
Epoch 384/500
Epoch 385/500
Epoch 386/500
Epoch 387/500
Epoch 388/500
Epoch 389/500
Epoch 

Epoch 398/500
Epoch 399/500
Epoch 400/500
Epoch 401/500
Epoch 402/500
Epoch 403/500
Epoch 404/500
Epoch 405/500
Epoch 406/500
Epoch 407/500
Epoch 408/500
Epoch 409/500
Epoch 410/500
Epoch 411/500
Epoch 412/500
Epoch 413/500
Epoch 414/500
Epoch 415/500
Epoch 416/500
Epoch 417/500
Epoch 418/500
Epoch 419/500
Epoch 420/500
Epoch 421/500
Epoch 422/500
Epoch 423/500
Epoch 424/500
Epoch 425/500
Epoch 426/500
Epoch 427/500
Epoch 428/500
Epoch 429/500
Epoch 430/500
Epoch 431/500
Epoch 432/500
Epoch 433/500
Epoch 434/500
Epoch 435/500
Epoch 436/500
Epoch 437/500
Epoch 438/500
Epoch 439/500
Epoch 440/500
Epoch 441/500
Epoch 442/500
Epoch 443/500
Epoch 444/500
Epoch 445/500
Epoch 446/500
Epoch 447/500
Epoch 448/500
Epoch 449/500
Epoch 450/500
Epoch 451/500
Epoch 452/500
Epoch 453/500
Epoch 454/500
Epoch 455/500
Epoch 456/500
Epoch 457/500
Epoch 458/500
Epoch 459/500
Epoch 460/500
Epoch 461/500
Epoch 462/500
Epoch 463/500
Epoch 464/500
Epoch 465/500
Epoch 466/500
Epoch 467/500
Epoch 468/500
Epoch 

Epoch 477/500
Epoch 478/500
Epoch 479/500
Epoch 480/500
Epoch 481/500
Epoch 482/500
Epoch 483/500
Epoch 484/500
Epoch 485/500
Epoch 486/500
Epoch 487/500
Epoch 488/500
Epoch 489/500
Epoch 490/500
Epoch 491/500
Epoch 492/500
Epoch 493/500
Epoch 494/500
Epoch 495/500
Epoch 496/500
Epoch 497/500
Epoch 498/500
Epoch 499/500
Epoch 500/500


<tensorflow.python.keras.callbacks.History at 0x7f8a84ae4bd0>

In [29]:
model.training = False
print('Validation accuracy:', 
      accuracy_score(y_val, model(X_val).numpy().argmax(axis=1)))

Validation accuracy: 0.9638989169675091


* Export to tflite:

In [30]:
models_root = './rsl_models'
os.makedirs(models_root, exist_ok=True)

name = 'rsl_fc256_scaled'

keras.models.save_model(model, f'{models_root}/{name}.h5')

# !!! AttributeError: module 'tensorflow._api.v2.compat.v1' has no attribute 'TFLiteConverter'
# converter = tf.compat.v1.TFLiteConverter.from_keras_model_file(keras_file)
# tflite_model = converter.convert()
# open('linear.tflite', 'wb').write(tflite_model)

converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
with open(f'{models_root}/{name}.tflite', 'wb') as file:
    file.write(tflite_model)

* Check:

In [31]:
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path=f'{models_root}/{name}.tflite')
interpreter.allocate_tensors()

# Print input shape and type
inputs = interpreter.get_input_details()
print('{} input(s):'.format(len(inputs)))
for i in range(0, len(inputs)):
    print('{} {}'.format(inputs[i]['shape'], inputs[i]['dtype']))

# Print output shape and type
outputs = interpreter.get_output_details()
print('\n{} output(s):'.format(len(outputs)))
for i in range(0, len(outputs)):
    print('{} {}'.format(outputs[i]['shape'], outputs[i]['dtype']))

1 input(s):
[ 1 63] <class 'numpy.float32'>

1 output(s):
[ 1 12] <class 'numpy.float32'>
