#### Notes
- Relative probabilities are interesting
    - The probability of an image of a firetruck being classified as a BMW is small, but much larger than the probability of it being classified as a toad.
- An ensemble of specialised learners
    - Create a group of specialists, each of which only learns to distinguish between easily-confused classes and puts every other class into an "OTHER" group-class
- High precision       => Most of the predictions of class X are actually class X
- Increasing precision => Fewer non-X gestures are being predicted as class X
- High recall          => Most of the true class-X gestures are classified as class X
- Increasing recall    => Fewer class-X gestures are predicted as anything but X

#### Things to try
- Recurrent/LSTM
- Convolutional NN
- Dropout
- Drop some g255 to make the classes equal
- Tweak the class weightings
- pass std as a feature
- Make sure the gestures are labelled correctly
- Why isn't the validation loss getting weighted like the training loss is?
- Get some dumb thresholds to spot the rising/falling edge and only then send it off to the model

#### Unanswered questions
- Why is recall so low for every gesture except g255?
- Why don't the inverted frequencies work as class weights?
- Why does precision.g255 decrease as all others increase?

In [None]:
# Imports
import tensorflow as tf
from tensorflow import keras
keras.utils.set_random_seed(42)
from tensorflow.keras import layers

from sklearn.model_selection import train_test_split

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import os
import datetime

import wandb
from wandb.keras import WandbCallback
from ipywidgets import interact
import ipywidgets as widgets

In [None]:
# Hyper-parameters and switches used throughout the notebook.
config = dict(
    # Rows of finger readings in each model input window (see make_batches)
    window_size=15,
    # Stride between consecutive windows, passed to make_batches
    window_skip=1,
    epochs=50,
    batch_size=2048,
    # Fraction of windows held out for validation (train_test_split test_size)
    test_frac=0.25,
    # Whether the computed class weights are passed to model.fit
    use_class_weights=True,
    # Hidden Dense layer number -> number of units
    n_hidden_units={layer: 128 for layer in (1, 2, 3)},
    # Learning rate for the Adam optimiser below
    lr=1e-3,
    loss_fn=keras.losses.SparseCategoricalCrossentropy(),
    # Activation of the output layer
    activation="softmax",
    # Drop every window labelled 'gesture0255'
    omit_0255=False,
    # Collapse the labels into 'gesture0255' vs everything else
    g255_vs_rest=False,
    # Log runs and per-class metrics to Weights & Biases
    use_wandb=False,
)
# NOTE(review): this single optimiser instance is reused every time the model
# cell is re-run; consider creating a fresh one per model.
config['optimiser'] = keras.optimizers.Adam(learning_rate=config['lr'])

In [None]:
# config['use_wandb'] = True
# wandb.init(
#     project="ergo",
#     entity="beyarkay",
#     config=config,
# )

In [None]:
# Function and constant definitions

# Names of the 30 measurement columns, one per (hand, finger, axis). parse_csvs
# uses this list as the CSV column names after 'datetime' and 'gesture'. The
# order is the left hand from finger 5 down to 1, then the right hand from 1 up
# to 5, with x, y, z for each finger.
FINGERS = [
    f'{hand}-{finger}-{axis}'
    for hand, finger_order in (('left', (5, 4, 3, 2, 1)), ('right', (1, 2, 3, 4, 5)))
    for finger in finger_order
    for axis in 'xyz'
]

def make_batches(X, y, window_size=10, window_skip=1):
    """Cut a time series into fixed-length windows, each paired with one label.

    Window i covers rows ``end - window_size`` .. ``end - 1`` of `X` and is
    labelled with ``y[end]`` (the row immediately after the window). ``end``
    runs from `window_size` up to and including ``len(y) - 1`` in steps of
    `window_skip`. Nothing here stops a window from straddling the join
    between two concatenated recordings.

    Args:
        X: 2-D array-like of shape (n_rows, n_features).
        y: sequence of n_rows labels.
        window_size: number of rows per window.
        window_skip: stride between the starts of consecutive windows (>= 1).

    Returns:
        (batched_X, batched_y): a float array of shape
        (n_windows, window_size, n_features) and an object array of n_windows
        labels.

    Raises:
        ValueError: if window_skip < 1.
    """
    if window_skip < 1:
        raise ValueError(f'window_skip must be >= 1, got {window_skip}')
    X = np.asarray(X)
    # The last row is a valid label too, so `ends` goes up to len(y) - 1.
    ends = np.arange(window_size, len(y), window_skip)
    batched_X = np.empty((len(ends), window_size, X.shape[1]))
    batched_y = np.empty((len(ends),), dtype='object')
    for i, end in enumerate(ends):
        batched_X[i] = X[end - window_size:end]
        batched_y[i] = y[end]
    return batched_X, batched_y

def gestures_and_indices(y):
    """Build converters between gesture labels and integer class indices.

    The distinct values of `y` are numbered 0..n-1 in sorted order. Each
    returned converter accepts either a single value (and returns a numpy
    scalar) or any 1-D sequence such as a list, tuple, range or array (and
    returns a numpy array).

    Returns:
        (g2i, i2g): label -> index and index -> label converters.

    Raises (from the converters):
        KeyError: for a label or index that was not present in `y`.
    """
    labels = np.unique(y)  # np.unique returns the values already sorted
    g2i_dict = {g: i for i, g in enumerate(labels)}
    i2g_dict = dict(enumerate(labels))

    def _convert(mapping, values):
        # np.ndim is 0 for scalars (including strings) and >= 1 for any sequence
        is_scalar = np.ndim(values) == 0
        result = np.array([mapping[v] for v in ([values] if is_scalar else values)])
        return result[0] if is_scalar else result

    def g2i(g):
        return _convert(g2i_dict, g)

    def i2g(i):
        return _convert(i2g_dict, i)

    return g2i, i2g

def one_hot_and_back(y_all):
    """Return converters between integer class ids and one-hot tensors.

    The one-hot depth is the number of distinct values in `y_all`, computed
    once here.

    Returns:
        (to_one_hot, from_one_hot): `to_one_hot(y)` wraps tf.one_hot;
        `from_one_hot(onehot)` takes the argmax over axis 1.
    """
    n_classes = len(np.unique(y_all))

    def to_one_hot(y):
        return tf.one_hot(y, n_classes)

    def from_one_hot(onehot):
        return tf.argmax(onehot, axis=1)

    return to_one_hot, from_one_hot

def conf_mat(model, X, y):
    """Plot a heatmap of `model`'s confusion matrix on (X, y) and return it.

    Rows are true labels and columns are predicted labels (the argument order
    of tf.math.confusion_matrix(labels, predictions)). Tick labels come from
    the global `i2g`.

    Returns:
        The confusion matrix as a numpy array.
    """
    y_pred = np.argmax(model.predict(X), axis=1)
    confusion_mtx = tf.math.confusion_matrix(y, y_pred).numpy()
    labels = [i2g(i) for i in range(confusion_mtx.shape[0])]

    # The colour scale's upper bound ignores the last row/column (presumably
    # the dominant 'gesture0255' class - TODO confirm) so the remaining cells
    # stay distinguishable. With a single class there is nothing to exclude,
    # and the [:-1, :-1] slice would be empty (its .max() raises).
    if confusion_mtx.shape[0] > 1:
        scale_region = confusion_mtx[:-1, :-1]
    else:
        scale_region = confusion_mtx

    plt.figure(figsize=(10, 8))
    sns.heatmap(
        confusion_mtx,
        annot=True,
        fmt='g',
        xticklabels=labels,
        yticklabels=labels,
        vmin=confusion_mtx.min(),
        vmax=scale_region.max(),
    )
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    return confusion_mtx

def _shade_gesture_runs(ax, y, ymin, ymax):
    """Shade and label each run of non-rest labels in `y` on `ax`.

    'gesture0255' and 'g255' are the rest labels. A run of non-rest samples
    ends at the next rest sample (or at the final sample). The run is labelled
    with its first value, with 'gesture0' shortened to 'g'.
    """
    rest_labels = ('gesture0255', 'g255')
    last = len(y) - 1
    run_length = 0
    for time, label in enumerate(y):
        is_rest = label in rest_labels
        if not is_rest and time != last:
            run_length += 1
            continue
        if is_rest and run_length == 0:
            continue
        run_start = time - run_length
        # axvspan covers the full height of the axes, so the shading is right
        # for negative data and survives later changes to the y-limits.
        ax.axvspan(run_start - .5, time - .5, color='grey', alpha=0.1)
        ax.text(
            time - run_length / 2 - .5,
            (ymax - ymin) / 2 + ymin,
            y[run_start].replace('gesture0', 'g'),
            va='baseline',
            ha='center',
            rotation=90,
        )
        run_length = 0


def plot_timeseries(X, y, t=None, per='dimension'):
    """Plot every column of `X` over time and shade the gesture runs in `y`.

    Args:
        X: 2-D array with one row per timestep and one column per FINGERS entry.
        y: array of per-timestep labels, the same length as X.
        t: optional per-timestep tick labels (not drawn for > 4000 rows).
        per: 'dimension' gives 3 stacked axes (x, y, z), each showing every
            finger. 'finger' gives a 5x2 grid with one axis per finger,
            showing its x/y/z lines.

    Returns:
        (fig, axs): the figure and a flat array of its Axes.

    Raises:
        ValueError: if `per` is not 'dimension' or 'finger'.
    """
    if per not in ('dimension', 'finger'):
        raise ValueError(f"per must be 'dimension' or 'finger', not {per!r}")
    # Make sure the given dataset is correctly formatted
    assert X.shape[0] == y.shape[0], 'There must be one y value for each X value'
    assert X.shape[1] == len(FINGERS), f'{X.shape[1]=} doesn\'t equal the number of finger labels ({len(FINGERS)})'
    assert not np.isnan(X).any(), f'Input dataset has {np.isnan(X).sum()} NaN values. Should have 0'

    by_dimension = per == 'dimension'
    # If we've got many many points, only show an abridged version of the plot
    abridged = X.shape[0] > 4000
    nrows, ncols = (3, 1) if by_dimension else (5, 2)
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(13, 8))
    if len(axs.shape) > 1:
        # Column-major, so the first five axes are the first grid column
        axs = axs.T.flatten()

    ymin, ymax = X.min(), X.max()
    max_std = X.std(axis=0).max()
    if max_std == 0:
        max_std = 1.0  # constant input: avoid a 0/0 (NaN) alpha below

    for d, finger in enumerate(FINGERS):
        ax = axs[d % 3 if by_dimension else d // 3]
        column = X[:, d]
        ax.plot(
            column,
            alpha=np.clip(column.std() / max_std, 0.05, 1.0),
            label=finger,
            c=None if by_dimension else ('tab:red', 'tab:green', 'tab:blue')[d % 3],
        )
        # e.g. 'left-5-x' -> 'x' per dimension, 'left-5' per finger
        ax.set_title(finger[-1] if by_dimension else finger[:-2])

    # Plot the ticks and legend for each axis
    num_labels = 40 if by_dimension else 20
    ticks_per_label = max(1, X.shape[0] // num_labels)
    for i, ax in enumerate(axs):
        if abridged:
            ax.set_xticks([])
            ax.set_xticklabels([])
        else:
            ax.set_xticks(range(0, X.shape[0], ticks_per_label))
            # Only the bottom axis of each column gets tick labels
            is_bottom = (i == len(axs) - 1) if by_dimension else (i % 5 == 4)
            if not is_bottom:
                ax.set_xticklabels([])
            elif t is not None:
                ax.set_xticklabels(t[::ticks_per_label], rotation=90)
        if by_dimension:
            handles, labels = ax.get_legend_handles_labels()
            ax.legend(
                handles,
                labels,
                loc='center left',
                bbox_to_anchor=(1., 0.5)
            )

    # Plot the labels for each timestep and axis
    for ax in axs:
        if not abridged:
            _shade_gesture_runs(ax, y, ymin, ymax)
        ax.set_ylim((ymin, ymax))

    plt.tight_layout()
    return fig, axs


In [None]:
# Read in data and format to {X,y}_{train,test}

def parse_csvs(root='../gesture_data/train/'):
    """Read every file in `root` as a headerless gesture CSV and concatenate them.

    Each file's columns are 'datetime', 'gesture', then the FINGERS readings.
    Files are read in sorted filename order so the row order, and therefore
    the windows and train/validation split built from it, is reproducible.

    Returns:
        One DataFrame with a fresh 0..n-1 index and a datetime64 'datetime' column.
    """
    print("TODO does this take into account how batches shouldn't straddle non-contiguous datasets?")
    dfs = [
        pd.read_csv(
            os.path.join(root, path),
            names=['datetime', 'gesture'] + FINGERS,
            # By name: position 1 would be the 'gesture' column
            parse_dates=['datetime'],
        )
        for path in sorted(os.listdir(root))
    ]
    df = pd.concat(dfs, ignore_index=True)
    df['datetime'] = pd.to_datetime(df['datetime'])
    return df
# Load all training CSVs, then cut them into windows of consecutive finger
# readings, each paired with a single gesture label (see make_batches).
df = parse_csvs()
# df = pd.read_csv('../gesture_data/relabelled_gesture0001.csv')
X, y = make_batches(    
    df.drop(['datetime', 'gesture'], axis=1).to_numpy(), 
    df['gesture'].to_numpy(),
    window_size=config['window_size'],
    window_skip=config['window_skip'],
)

# Optionally drop every window labelled 'gesture0255'
if config.get('omit_0255', False):
    X = X[y != 'gesture0255']
    y = y[y != 'gesture0255']
    
# Optionally make the task binary: 'gesture0255' vs everything else
# (all other labels become 'gesture0000')
if config.get('g255_vs_rest', False):
    y = np.where(
        y == 'gesture0255',
        'gesture0255',
        'gesture0000'
    )

# Get functions to convert between gestures and indices
# (built after the filtering above, so indices cover only the remaining labels)
g2i, i2g = gestures_and_indices(y)
y = g2i(y)
# Get functions to convert between indices and one hot encodings
# (not used by the cells visible here)
i2ohe, ohe2i = one_hot_and_back(y)

total = len(y)
n_unique = len(np.unique(y))
config['gestures'] = np.unique(y)
# "Balanced" weights: total / (n_unique * count). Each class's weights then sum
# to total / n_unique, so every class contributes equally to a weighted loss.
# Keys are inserted in ascending class-index order (np.unique is sorted).
class_weight = {
    int(class_): (1/freq * total/n_unique) for class_, freq in zip(*np.unique(y, return_counts=True))
}

# class_weight[g2i('gesture0255')] *= 2

config['class_weight'] = class_weight if config['use_class_weights'] else None

# NOTE(review): windows overlap by window_size - window_skip rows, and
# train_test_split shuffles by default. Near-identical neighbouring windows
# therefore land in both train and validation, so validation metrics are
# likely optimistic. Consider splitting by time/recording instead.
X_train, X_valid, y_train, y_valid = train_test_split(
    X, 
    y, 
    test_size=config['test_frac'], 
    random_state=42
)

In [None]:
# (gesture, number of windows) for every class in y
_class_ids, _class_counts = np.unique(y, return_counts=True)
list(zip(i2g(_class_ids), _class_counts))

In [None]:
# Compile the model
inputs = layers.Input(shape=X_train.shape[1:])

# Standardise each feature (last axis) with statistics from the training set
normalizer = layers.Normalization(axis=-1)
normalizer.adapt(X_train)
x = normalizer(inputs)

x = layers.Flatten()(x)
# Without a non-linearity, stacked Dense layers collapse into a single linear
# map and the hidden layers add no capacity. Set config['hidden_activation']
# to override the default.
for num_units in config.get("n_hidden_units").values():
    x = layers.Dense(
        units=num_units,
        activation=config.get("hidden_activation", "relu"),
    )(x)


def init_biases(shape, dtype=None):
    """Keras bias initializer that biases the output towards the class priors.

    The global class_weight is expected to hold
    class_weight[c] = total / (n_unique * count[c]). Then
    1 / class_weight[c] = count[c] * n_unique / total, so the bias for class c
    is log(prior[c]) + log(n_unique). Softmax ignores the constant shift.
    `dtype` is accepted for the initializer signature but ignored here.

    Raises:
        AssertionError: if `shape` isn't (number of classes,).
    """
    n_classes = len(class_weight)
    # Keras may pass `shape` as a list or a tuple, so compare as tuples
    assert tuple(shape) == (n_classes,), f"Shape {shape} isn't ({n_classes},)"
    # Order by class index so bias i lines up with output unit i
    inv_freqs = np.array([1 / class_weight[c] for c in sorted(class_weight)])
    return np.log(inv_freqs)

# Output layer: one unit per class, biases initialised by init_biases
outputs = layers.Dense(
    units=len(np.unique(y)),
    activation=config.get("activation"),
    bias_initializer=init_biases,
)(x)
model = keras.Model(inputs=inputs, outputs=outputs)

# Passed as weighted_metrics, so Keras weights them by any sample/class
# weights supplied to fit/evaluate
metrics = [
    keras.metrics.SparseCategoricalAccuracy(name='sca'),
    keras.metrics.SparseCategoricalCrossentropy(name='scce'),
]
model.compile(
    loss=config['loss_fn'],
    optimizer=config['optimiser'],
    weighted_metrics=metrics,
)

In [None]:
# Define a custom callback for multi-class accuracy/recall/precision
import time
class MultiClassAccAndRecallCallback(keras.callbacks.Callback):
    """Compute per-class precision and recall on validation and training data.

    At the end of every epoch, both datasets are predicted and a confusion
    matrix is built. When config['use_wandb'] is set, per-gesture
    precision/recall and the time taken are logged to wandb (uncommitted).
    """

    def __init__(self, validation_data, training_data):
        super().__init__()
        self.validation_data = validation_data  # (X, y) with integer labels
        self.training_data = training_data      # (X, y) with integer labels

    def on_epoch_end(self, epoch, logs=None):
        start = time.time()
        # Calculate per-class {validation,training} {recall,precision}
        for key, (X, y) in (('valid', self.validation_data), ('train', self.training_data)):
            probs = self.model.predict(X, verbose=0)
            # confusion_matrix(labels, predictions): rows = true class,
            # columns = predicted class. Passing predictions first silently
            # swaps precision and recall. num_classes keeps the matrix
            # full-size even when the highest classes are absent.
            cm = tf.math.confusion_matrix(
                y,
                np.argmax(probs, axis=1),
                num_classes=probs.shape[1],
            ).numpy()
            true_pos = np.diag(cm)
            # A class never predicted (or never present) gives 0/0 -> NaN
            with np.errstate(divide='ignore', invalid='ignore'):
                precision = true_pos / cm.sum(axis=0)  # / number predicted as the class
                recall = true_pos / cm.sum(axis=1)     # / number truly in the class

            prec_and_recall = {
                i2g([i])[0]: {'precision': p, 'recall': r}
                for i, (p, r) in enumerate(zip(precision, recall))
            }
            if config['use_wandb']:
                wandb.log({key: prec_and_recall}, commit=False)
        duration = time.time() - start
        if config['use_wandb']:
            wandb.log({'tt_calc_precision_recall': duration}, commit=False)
        
# class LogConfusionMatrixCallback(keras.callbacks.Callback):
#     def __init__(self, validation_data):
#         super().__init__()
#         self.validation_data = validation_data

#     def on_epoch_end(self, epoch, logs=None):
#         labels = np.unique(self.validation_data[1])        
#         cm = wandb.plot.confusion_matrix(
#             y_true=self.validation_data[1],
#             preds=np.argmax(self.model.predict(self.validation_data[0], verbose=0), axis=1),
#             class_names=labels
#         )
#         wandb.log({"conf_mat": cm}, commit=False)
        

In [None]:
%%time
# Fit the model
if config['use_wandb']:
    # The per-class callback only reports to wandb, so it is only added here
    callbacks = [
        WandbCallback(),
        MultiClassAccAndRecallCallback((X_valid, y_valid), (X_train, y_train)),
    ]
else:
    callbacks = []

# NOTE: Keras applies class_weight to the training loss only. validation_data
# carries no sample weights here, so val_loss is unweighted.
history = model.fit(
    X_train,
    y_train,
    batch_size=config['batch_size'],
    epochs=config['epochs'],
    validation_data=(X_valid, y_valid),
    callbacks=callbacks,
    class_weight=config['class_weight'],
)
if config['use_wandb']:
    wandb.finish()
# str(datetime.now()) contains ' ' and ':', which are awkward or invalid in
# paths on some platforms (':' on Windows), so use a filesystem-safe stamp.
model.save(f'models/{datetime.datetime.now():%Y-%m-%dT%H-%M-%S}')

In [None]:
# Plot metrics from history
# Plot every metric that has a matching 'val_' entry: training vs validation.
# This doesn't depend on the key order or on every key having a 'val_' twin.
items = [key for key in history.history if f'val_{key}' in history.history]
# squeeze=False keeps `axs` iterable even when there is only one metric
_fig, axs = plt.subplots(1, len(items), figsize=(5 * len(items), 5), squeeze=False)
axs = axs[0]

for item, ax in zip(items, axs):
    ax.plot(history.history[item])
    ax.plot(history.history[f'val_{item}'])
    ax.set_title(f'model {item}')

    if item == 'sca':
        ax.set_ylim((0, 1))
    else:
        ylim = ax.get_ylim()
        ax.set_ylim((0, ylim[1]))
    ax.set_ylabel(item)
    ax.set_xlabel('epoch')
    ax.legend(['train', 'val'], loc='best')

plt.suptitle('Without "fixed" gesture0001')
plt.tight_layout()
plt.show()

In [None]:
# Confusion matrices for the validation and training splits
subtitle = 'Without fixed gesture0001'
for _split_name, (_X_split, _y_split) in (
    ('Validation set', (X_valid, y_valid)),
    ('Training set', (X_train, y_train)),
):
    conf_mat(model, _X_split, _y_split)
    plt.title(f'{_split_name}\n{subtitle}')
    plt.show()


In [None]:
# Calculate precision and recall for each gesture
y_valid_probs = model.predict(X_valid, verbose=0)
# confusion_matrix(labels, predictions): rows = true class, columns =
# predicted class. Passing predictions first silently swaps precision and
# recall. num_classes keeps the matrix full-size if top classes are absent.
confusion_mtx = tf.math.confusion_matrix(
    y_valid,
    np.argmax(y_valid_probs, axis=1),
    num_classes=y_valid_probs.shape[1],
).numpy()
true_pos = np.diag(confusion_mtx)
# A class never predicted (or never present) gives 0/0 -> NaN
with np.errstate(divide='ignore', invalid='ignore'):
    precision = true_pos / confusion_mtx.sum(axis=0)  # / number predicted as the class
    recall = true_pos / confusion_mtx.sum(axis=1)     # / number truly in the class
ipr = list(zip(range(len(precision)), precision, recall))
prec_and_recall = {i2g(i): {'precision': p, 'recall': r} for i, p, r in ipr}

print('\n'.join([f'{i2g(i)}:    precision:{p:.3f}, recall: {r:.3f}' for i, p, r in ipr]))


In [None]:
# Ratio of true predictions over g255
# Bar i is confusion_mtx[i, i] / confusion_mtx[i, -1]. The last class is
# presumably 'gesture0255' (labels are sorted) - TODO confirm. The title's
# reading ("class i classified as g255") holds only if confusion_mtx rows are
# true labels, i.e. it was built as tf.math.confusion_matrix(labels,
# predictions); check how confusion_mtx was built. A zero in the last column
# gives an inf/NaN bar.
class_ids = list(range(confusion_mtx.shape[0]))
fig, ax = plt.subplots()
ax.bar(
    x=class_ids,
    height=np.diag(confusion_mtx) / confusion_mtx[:, -1],
)
ax.set_xticks(class_ids)
labels = [label.replace('gesture0', 'g') for label in i2g(class_ids)]
ax.set_xticklabels(labels)
ax.set_title('Ratio of True Predictions over classified-as-g255')
plt.show()

### Work in progress

In [None]:
# Only referenced by the commented-out slice at the bottom of this cell
START_IDX = 0
FINSH_IDX = 150

df = pd.read_csv('../gesture_data/relabelled.csv')
# Drop the index column that a DataFrame.to_csv without index=False writes, if present
df = df.drop(columns=['Unnamed: 0'], errors='ignore')
df['datetime'] = pd.to_datetime(df['datetime'])

# NOTE: rebinds the global `normalizer` name; layers already wired into the
# model keep their own instance.
normalizer = layers.Normalization(axis=-1)
data = df.drop(['datetime', 'gesture'], axis=1).to_numpy()
normalizer.adapt(data)
data = normalizer(data).numpy()
# data = data[START_IDX:FINSH_IDX, :]

In [None]:
# Explore the dataset between a given range
# @interact(rng=widgets.IntRangeSlider(
#     value=[0, 150],
#     min=0,
#     max=(df.shape[0] // 10),
#     step=5,
#     continuous_update=False
# ))
STEP = 25
@interact(rng=widgets.IntSlider(
    value=0,
    min=0,
    # Stop at the last row so the plotted window is never empty
    max=max(0, df.shape[0] - 1),
    step=STEP,
    continuous_update=False
))
def interact_plot_timeseries(rng):
    """Plot normalised rows [rng, rng + 2*STEP) of df, one axis per finger.

    Uses the global `normalizer`, whichever cell last assigned it.
    """
    start, finsh = rng, rng + 2 * STEP
    window = df.iloc[start:finsh]
    # Normalization is element-wise, so normalising just the slice gives the
    # same values as normalising everything and then slicing
    data = window.drop(['datetime', 'gesture'], axis=1).to_numpy()
    plot_timeseries(
        normalizer(data).numpy(),
        window['gesture'].to_numpy(),
        # %b = month abbreviation, %M = minutes (%m would be the month number)
        window['datetime'].dt.strftime('%b-%d %H:%M:%S').to_numpy(),
        per='finger'
    )

Rules for re-labelling measurements:
1. Measurements are <250ms before or after an existing label
2. Labels are all contiguous
3. Windows are no longer than 60ms (3 measurements)
4. Measurements have std > 3 in more than 1 dimension


Also delete all non-0255 gestures that occur in the first 250ms of a recording

Does the batching process take into account that the df is non-contiguous?

In [None]:
# Relabel the dataset
gesture = 'gesture0006'
window_size = 20

df = pd.read_csv('../gesture_data/relabelled.csv')
if 'Unnamed: 0' in df.columns:
    df = df.drop(['Unnamed: 0'], axis=1)
df['datetime'] = pd.to_datetime(df['datetime'])


def _gesture_onsets(labels, gesture):
    """Return the indices where `labels` switches from another label to `gesture`.

    The (non-existent) sample before index 0 is treated as 'gesture0255'.
    """
    # `labels` shifted right by one
    previous = np.concatenate((['gesture0255'], labels[:-1]))
    return np.nonzero((labels == gesture) & (previous != gesture))[0]


# Read after (re)loading df so the INDEX bound below matches the data being edited
y_orig = df['gesture'].to_numpy()


@interact(
    val=widgets.IntSlider(
        value=0,
        min=-window_size,
        max=window_size,
        step=1,
        continuous_update=False,
    ),
    INDEX=widgets.BoundedIntText(
        value=0,
        min=0,
        # INDEX picks one onset (not one labelled sample); the last valid
        # value is #onsets - 1
        max=max(0, len(_gesture_onsets(y_orig, gesture)) - 1),
        step=1,
        description='INDEX:',
    ),
)
def interact_fix_labels(val=0, INDEX=0):
    """Preview, and optionally save, a moved label for one onset of `gesture`.

    Shows the 2*window_size+1 rows around onset number INDEX. The proposal
    sets the 10 rows around the onset to 'gesture0255' and puts a single-row
    `gesture` label `val` rows from the onset. The old labels are shown as
    tick labels. Clicking "Save" writes the proposal into `df`.
    """
    y_orig = df['gesture'].to_numpy()
    X_orig = df.drop(['datetime', 'gesture'], axis=1).to_numpy()
    indices = _gesture_onsets(y_orig, gesture)
    if INDEX >= len(indices):
        print(f'{gesture} has only {len(indices)} onsets; INDEX {INDEX} is out of range.')
        return

    idx = indices[INDEX]
    print(f'{INDEX / len(indices) * 100:.0f}%')
    window_start = idx - window_size
    window_finsh = idx + window_size + 1
    if window_start < 0 or window_finsh > len(y_orig):
        # A negative slice start would wrap around to the end of the array
        print(f'Onset at row {idx} is too close to the edge of the data to show a full window.')
        return

    X = X_orig[window_start:window_finsh]
    y_old = y_orig[window_start:window_finsh]
    y_new = y_old.copy()
    # Remove the old label
    y_new[window_size - 5:window_size + 5] = 'gesture0255'
    # Set the new single-row label, `val` rows from the onset
    y_new[window_size + val:window_size + val + 1] = gesture

    # Plot the new labels, with the old labels as tick labels
    plot_timeseries(
        X,
        y_new,
        np.array([yi.replace('gesture0255', 'g255') for yi in y_old]),
        per='finger'
    )

    def change_df(_):
        # The window was sliced by position, so write it back by position.
        # Matching on timestamps could also hit rows outside the window
        # whenever timestamps repeat.
        df.iloc[window_start:window_finsh, df.columns.get_loc('gesture')] = y_new
        print(f'Modified {window_finsh - window_start} measurements.')

    button = widgets.Button(
        description='Save',
        button_style='danger',
    )
    button.on_click(change_df)
    display(button)

In [None]:
# df.to_csv('../gesture_data/relabelled.csv')

In [None]:
# Visualise confidence intervals per finger
# Plots the mean +/- one standard deviation of the readings in the windows
# around every onset of `gesture`.
# NOTE: `window_size` is set in the relabelling cell; run that cell first.
gesture = 'gesture0006'
y_orig = df['gesture'].to_numpy()
X_orig = df.drop(['datetime', 'gesture'], axis=1).to_numpy()
# Get a series which is y_orig, but shifted backwards by one
y_offset = np.concatenate((['gesture0255'], y_orig[:-1]))
# Get all the indices where the gesture goes [..., !=gesture, ==gesture, ...]
indices = np.nonzero((y_orig == gesture) & (y_offset != gesture))[0]
# Keep only onsets with window_size rows on both sides. A negative slice start
# wraps around and a slice past the end is truncated, so edge windows would
# have the wrong contents or the wrong length.
indices = indices[(indices >= window_size) & (indices + window_size < len(y_orig))]
if len(indices) == 0:
    raise ValueError(f'No onsets of {gesture} have {window_size} rows on each side')

Xs = np.stack([X_orig[idx - window_size:idx + window_size + 1] for idx in indices])
X_mean = Xs.mean(axis=0)
X_std = Xs.std(axis=0)

# One (rest) label per row of X_mean, as plot_timeseries requires
blank_labels = np.array(['g255'] * X_mean.shape[0])

PER = 'finger'
_fig, axs = plot_timeseries(
    X_mean,
    blank_labels,
    per=PER
)

ymin = float('inf')
ymax = float('-inf')
max_std = X_mean.std(axis=0).max()
if max_std == 0:
    max_std = 1.0  # flat means: avoid a 0/0 (NaN) alpha below

for d in range(X_mean.shape[1]):
    ax = axs[d % 3 if PER == 'dimension' else d // 3]

    high = X_mean[:, d] + X_std[:, d]
    low = X_mean[:, d] - X_std[:, d]
    ymin = min(ymin, low.min())
    ymax = max(ymax, high.max())

    kwargs = {} if PER == 'dimension' else {'color': ('tab:red', 'tab:green', 'tab:blue')[d % 3]}
    ax.fill_between(
        range(len(X_mean[:, d])),
        low,
        high,
        alpha=np.clip(X_mean[:, d].std() / (4 * max_std), 0.05, 1.0),
        **kwargs
    )

# Pad by 10% of the range. Scaling the limits by 0.9/1.1 would instead shrink
# the view whenever ymin or ymax is negative.
margin = 0.1 * (ymax - ymin)
for ax in axs:
    ax.set_ylim((ymin - margin, ymax + margin))

title = f'Mean measurements for {gesture}\n(per {PER})'
plt.suptitle(title)
plt.tight_layout()
# The title contains a newline and brackets, so keep only alphanumeric runs
# joined by '-', e.g. 'mean-measurements-for-gesture0006-per-finger'
filename = '-'.join(''.join(c if c.isalnum() else ' ' for c in title.lower()).split())
os.makedirs('imgs', exist_ok=True)
plt.savefig(os.path.join('imgs', filename + '.pdf'))


# TODO: Fix the dataset
- Try training on a small but very high-quality dataset