In [None]:
# Imports
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

from sklearn.model_selection import train_test_split
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import os

import wandb
from wandb.keras import WandbCallback


In [None]:
config = {
    "window_size": 30,
    "window_skip": 1,
    "epochs": 30,
    "batch_size": 32,
    "test_frac": 0.25,
    "use_class_weights": True,
    # Maps hidden-layer number -> width. String keys, matching how the model
    # cell reads this dict (the key used to be the int 1).
    "n_hidden_units": {
        "1": 128,
    },
    "lr": 1e-4,
    "seed": 42,
}
# Seed NumPy's and TensorFlow's global RNGs for more repeatable runs.
np.random.seed(config["seed"])
tf.random.set_seed(config["seed"])
wandb.init(
    project="ergo",
    entity="beyarkay",
    config=config
)

In [None]:
# Function and constant definitions

# Per-finger x/y/z column names, in the order they follow `datetime` and
# `gesture` in each CSV: left hand from finger 5 down to 1, then right hand
# from finger 1 up to 5.
FINGERS = [
    f'{hand}-{finger}-{axis}'
    for hand, finger_order in (('left', (5, 4, 3, 2, 1)), ('right', (1, 2, 3, 4, 5)))
    for finger in finger_order
    for axis in ('x', 'y', 'z')
]

def make_batches(X, y, window_size=10, window_skip=1):
    """Slice an aligned (X, y) series into fixed-length windows.

    Each window covers rows ``X[end - window_size:end]`` and is labelled with
    ``y[end]``, the label of the row immediately after the window. Window end
    positions start at ``window_size`` and advance by ``window_skip``.

    Args:
        X: array indexed by row along axis 0 (e.g. shape ``(n_rows, n_features)``).
        y: sequence of labels aligned row-for-row with ``X``.
        window_size: number of consecutive rows per window.
        window_skip: stride between consecutive window end positions (>= 1).

    Returns:
        ``(batched_X, batched_y)``:
        - ``batched_X`` is a float64 array of shape
          ``(n_windows, window_size, *X.shape[1:])``.
        - ``batched_y`` is an object array of shape ``(n_windows,)``.

    Raises:
        ValueError: if ``window_size`` is negative or ``window_skip`` < 1.
    """
    if window_size < 0:
        raise ValueError(f'window_size must be >= 0, got {window_size}')
    if window_skip < 1:
        raise ValueError(f'window_skip must be >= 1, got {window_skip}')
    X = np.asarray(X)
    # The last usable end is len(y) - 1, since its label y[len(y) - 1] exists.
    # The former upper bound of `len(y) - 1` silently dropped that final window.
    ends = np.arange(window_size, len(y), window_skip)
    starts = ends - window_size
    # Row indices of every window, shape (n_windows, window_size); one fancy
    # index gathers all windows at once.
    row_idx = starts[:, None] + np.arange(window_size)
    batched_X = X[row_idx].astype(np.float64, copy=False)
    batched_y = np.empty((len(ends),), dtype='object')
    for i, end in enumerate(ends):
        batched_y[i] = y[end]
    return batched_X, batched_y

def gestures_and_indices(y):
    """Build converters between gesture labels and contiguous integer indices.

    Indices are assigned in sorted label order, so the smallest label maps to 0.

    Args:
        y: iterable of gesture labels; every distinct value gets an index.

    Returns:
        ``(g2i, i2g)``:
        - ``g2i`` maps an iterable of labels to a numpy array of indices.
        - ``i2g`` maps an iterable of indices back to a numpy array of labels.
        Both raise ``KeyError`` for values not present in ``y``.
    """
    vocabulary = sorted(np.unique(y))
    index_by_gesture = dict(zip(vocabulary, range(len(vocabulary))))
    gesture_by_index = dict(enumerate(vocabulary))

    def g2i(g):
        return np.array([index_by_gesture[gesture] for gesture in g])

    def i2g(i):
        return np.array([gesture_by_index[index] for index in i])

    return g2i, i2g

def one_hot_and_back(y_all):
    """Build converters between integer class indices and one-hot tensors.

    The encoding depth is the number of distinct values in ``y_all``, computed
    once here. ``y_all`` is presumably the contiguous 0..n-1 index array
    produced by ``g2i``; TODO confirm before using with other inputs.

    Args:
        y_all: array containing every class index that can occur.

    Returns:
        ``(to_one_hot, from_one_hot)``:
        - ``to_one_hot`` maps indices to a ``tf.one_hot`` tensor.
        - ``from_one_hot`` maps a tensor with classes along axis 1 back to
          indices via ``tf.argmax``.
    """
    depth = len(np.unique(y_all))

    def to_one_hot(y):
        return tf.one_hot(y, depth)

    def from_one_hot(onehot):
        # The original lambda referenced the undefined name `one_hot` here,
        # so every call raised NameError.
        return tf.argmax(onehot, axis=1)

    return to_one_hot, from_one_hot

def conf_mat(model, X, y):
    """Draw a labelled confusion-matrix heatmap of ``model`` on ``(X, y)``.

    Predictions are the argmax of ``model.predict(X)`` along axis 1. Tick
    labels come from the global ``i2g`` index->gesture converter, which must
    already exist (the data-loading cell creates it).

    Args:
        model: object with a Keras-style ``predict`` returning per-class scores.
        X: model inputs.
        y: true integer class indices aligned with ``X``.

    Returns:
        The ``tf.math.confusion_matrix`` tensor (rows = true, columns = predicted).
    """
    scores = model.predict(X)
    matrix = tf.math.confusion_matrix(labels=y, predictions=np.argmax(scores, axis=1))
    tick_labels = list(i2g(range(matrix.shape[0])))

    _, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(
        matrix,
        annot=True,
        fmt='g',
        xticklabels=tick_labels,
        yticklabels=tick_labels,
        ax=ax,
    )
    ax.set_xlabel('Predicted Label')
    ax.set_ylabel('True Label')
    return matrix


In [None]:
# Read in data and format to {X,y}_{train,test}
root = '../gesture_data/train/'
dfs = []
X_parts, y_parts = [], []
# Sort the listing: os.listdir order is arbitrary, and the concatenation order
# feeds straight into train_test_split. An unsorted listing therefore makes
# the split irreproducible even with a fixed random_state.
for path in sorted(os.listdir(root)):
    file_df = pd.read_csv(
        os.path.join(root, path),
        names=['datetime', 'gesture'] + FINGERS,
        # Column 0 is `datetime`; the old `parse_dates=[1]` targeted `gesture`.
        parse_dates=['datetime'],
    )
    dfs.append(file_df)
    # Window each CSV separately so that no window straddles two files.
    file_X, file_y = make_batches(
        file_df.drop(['datetime', 'gesture'], axis=1).to_numpy(),
        file_df['gesture'].to_numpy(),
        window_size=wandb.config['window_size'],
        window_skip=wandb.config['window_skip'],
    )
    X_parts.append(file_X)
    y_parts.append(file_y)
df = pd.concat(dfs)
X = np.concatenate(X_parts)
y = np.concatenate(y_parts)

# Get functions to convert between gestures and indices
g2i, i2g = gestures_and_indices(y)
y = g2i(y)
# Get functions to convert between indices and one hot encodings
i2ohe, ohe2i = one_hot_and_back(y)

total = len(y)
gesture_indices, gesture_counts = np.unique(y, return_counts=True)
n_unique = len(gesture_indices)
# Log the gesture names (not the 0..n-1 indices `y` now holds) as plain strings.
wandb.config['gestures'] = [str(g) for g in i2g(gesture_indices)]
# "Balanced" weights, total / (n_classes * count): rarer gestures weigh more.
class_weight = {
    int(class_): total / (n_unique * count)
    for class_, count in zip(gesture_indices, gesture_counts)
}

wandb.config['class_weight'] = class_weight if wandb.config.use_class_weights else None

# NOTE(review): with window_skip=1, neighbouring windows share all but one
# row. A random split therefore puts near-duplicates of validation windows
# into the training set, and validation metrics will be optimistic. Consider
# splitting by file or by time.
X_train, X_valid, y_train, y_valid = train_test_split(
    X,
    y,
    test_size=wandb.config['test_frac'],
    random_state=42
)


In [None]:
# Number of windows per gesture, as (gesture name, count) pairs.
window_label_ids, window_label_counts = np.unique(y, return_counts=True)
list(zip(i2g(window_label_ids), window_label_counts))

In [None]:
# Compile the model
inputs = layers.Input(shape=X_train.shape[1:])

# Per-feature normalisation along the last axis, statistics from X_train only.
normalizer = layers.Normalization(axis=-1)
normalizer.adapt(X_train)
x = normalizer(inputs)

x = layers.Flatten()(x)
# `n_hidden_units` maps hidden-layer number -> width. The keys may be ints or
# strings depending on how the config was declared/stored, so sort by int()
# and index with the key as-is (this cell used to hard-code the lookup '1').
n_hidden_units = wandb.config.n_hidden_units
for layer_id in sorted(n_hidden_units, key=int):
    # ReLU: without a non-linearity this Dense layer composes with the output
    # layer into one linear map, making the hidden layer pointless.
    x = layers.Dense(units=n_hidden_units[layer_id], activation='relu')(x)

outputs = layers.Dense(len(np.unique(y)), activation="softmax")(x)

model = keras.Model(inputs=inputs, outputs=outputs)

model.compile(
    optimizer=keras.optimizers.RMSprop(learning_rate=wandb.config.lr),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=[keras.metrics.SparseCategoricalAccuracy('acc')],
)

In [None]:
class MultiClassAccAndRecallCallback(keras.callbacks.Callback):
    """Log per-class precision and recall to wandb at the end of every epoch.

    Metrics are computed on both the validation and the training data. They are
    logged with ``commit=False``, so they are committed together with the next
    committing ``wandb.log`` call. Place this callback before any callback that
    commits the epoch's step.

    Args:
        validation_data: ``(X, y)`` tuple; ``y`` holds integer class indices.
        training_data: ``(X, y)`` tuple; ``y`` holds integer class indices.
    """

    def __init__(self, validation_data, training_data):
        super().__init__()
        self.validation_data = validation_data
        self.training_data = training_data

    def on_epoch_end(self, epoch, logs=None):
        # Calculate per-class {validation,training} {recall,precision}
        splits = (
            ('valid', self.validation_data),
            ('train', self.training_data),
        )
        for key, (X, y) in splits:
            y_pred = np.argmax(self.model.predict(X, verbose=0), axis=1)
            # confusion_matrix(labels, predictions): rows = true, cols = predicted.
            # The arguments used to be swapped, which transposed the matrix and
            # so swapped the logged precision and recall.
            cm = tf.math.confusion_matrix(y, y_pred).numpy()
            true_positives = np.diag(cm)
            # A class never predicted (or never present) gives 0/0 -> NaN.
            with np.errstate(divide='ignore', invalid='ignore'):
                precision = true_positives / cm.sum(axis=0)
                recall = true_positives / cm.sum(axis=1)

            prec_and_recall = {
                i2g([i])[0]: {'precision': p, 'recall': r}
                for i, (p, r) in enumerate(zip(precision, recall))
            }
            wandb.log({key: prec_and_recall}, commit=False)

In [None]:
# Fit the model
# The per-class callback logs with commit=False, so it must run before
# WandbCallback commits the epoch's step. In the opposite order its metrics
# were attached to the following epoch.
history = model.fit(
    X_train,
    y_train,
    batch_size=wandb.config['batch_size'],
    epochs=wandb.config['epochs'],
    validation_data=(X_valid, y_valid),
    callbacks=[
        MultiClassAccAndRecallCallback((X_valid, y_valid), (X_train, y_train)),
        WandbCallback(),
    ],
    class_weight=wandb.config['class_weight'],
)
wandb.finish()

In [None]:
# Confusion matrices with `gesture0255` rows removed from the ground truth.
# Only the true labels are filtered: the model may still predict `gesture0255`.
gesture0255_idx = g2i(['gesture0255'])[0]
# Append ('Training set', X_train, y_train) to also plot the training split.
splits = [
    ('Validation set', X_valid, y_valid),
]
for split_name, X_split, y_split in splits:
    keep = y_split != gesture0255_idx
    conf_mat(model, X_split[keep], y_split[keep])
    plt.title(f'{split_name}\n(`gesture0255` removed)')
    plt.show()

In [None]:
# Per-class precision and recall on the validation set.
y_valid_pred = np.argmax(model.predict(X_valid, verbose=0), axis=1)
# confusion_matrix(labels, predictions): rows = true, cols = predicted. The
# arguments used to be swapped, which swapped the printed precision and recall.
confusion_mtx = tf.math.confusion_matrix(y_valid, y_valid_pred).numpy()
# A class never predicted (or never present) gives 0/0 -> NaN.
with np.errstate(divide='ignore', invalid='ignore'):
    precision = np.diag(confusion_mtx) / confusion_mtx.sum(axis=0)
    recall = np.diag(confusion_mtx) / confusion_mtx.sum(axis=1)
ipr = list(zip(range(len(precision)), precision, recall))
prec_and_recall = {i2g([i])[0]: {'precision': p, 'recall': r} for i, p, r in ipr}

print('\n'.join([f'{i2g([i])[0]}:    precision:{p:.3f}, recall: {r:.3f}' for i, p, r in ipr]))
