# Discrete gesture decoder evaluation

In [None]:
import os
import torch
from pytorch_lightning import Trainer
from hydra import initialize, compose
from hydra.utils import instantiate
import omegaconf
from matplotlib import pyplot as plt
import seaborn as sns

from generic_neuromotor_interface.cler import GestureType
from generic_neuromotor_interface.cler import compute_cler

## Establish paths to data and model files

Before running this notebook you must make sure to download the data and model checkpoint as follows:
```
cd ~/generic-neuromotor-interface-data
./download_data.sh discrete_gestures full_data <EMG_DATA_DIR>
./download_models.sh discrete_gestures <MODELS_DIR>
```
where `<EMG_DATA_DIR>` and `<MODELS_DIR>` should match the directories specified by the `EMG_DATA_DIR` and `MODELS_DIR` variables, respectively, defined in the next cell.

In [None]:
# A leading `~` is shell syntax: neither `os.path.join` nor `torch.load`
# expands it, so resolve the home directory once here and let every later
# cell work with real filesystem paths.
EMG_DATA_DIR = os.path.expanduser("~/emg_data/")  # path to EMG data
MODELS_DIR = os.path.expanduser("~/emg_models/")  # path to model files

## Load model checkpoint and config

In [None]:
"""Retrieve model checkpoint"""

# `torch.load` does not expand a leading `~`, so resolve it explicitly;
# otherwise a home-relative MODELS_DIR ends in FileNotFoundError.
ckpt_path = os.path.expanduser(
    os.path.join(MODELS_DIR, "discrete_gestures", "model_checkpoint.ckpt")
)

# Map every tensor to CPU so the checkpoint also loads on machines without
# a GPU.
# NOTE(review): torch>=2.6 defaults to `weights_only=True`; if this checkpoint
# stores non-tensor objects the call may need adjusting -- confirm against the
# pinned torch version.
model_ckpt = torch.load(ckpt_path, map_location=torch.device("cpu"))

In [None]:
"""Retrieve the config"""

from hydra import initialize_config_dir

# `initialize()` interprets `config_path` relative to the calling file and
# does not accept absolute paths, so a home-relative (`~/...`) or absolute
# MODELS_DIR can never resolve through it. `initialize_config_dir()` takes an
# absolute directory, so build one from MODELS_DIR first.
config_dir = os.path.abspath(
    os.path.expanduser(os.path.join(MODELS_DIR, "discrete_gestures"))
)

# The context manager only scopes Hydra's global state; `cfg` stays usable
# after the block exits.
with initialize_config_dir(config_dir=config_dir):
    cfg = compose(config_name="model_config")

## Instantiate model and data module

In [None]:
"""Load the model"""

# Build the Lightning module described by the `lightning_module` config node
model = instantiate(cfg.lightning_module)

# Restore the trained parameters stored under the checkpoint's "state_dict" key
trained_weights = model_ckpt["state_dict"]
model.load_state_dict(trained_weights)

In [None]:
"""Assemble the data module"""

# Location of the corpus CSV inside the EMG data directory
corpus_csv = os.path.join(EMG_DATA_DIR, "discrete_gestures_corpus.csv")

# Turn the DataModule config node into a plain container so the two
# data-path entries can be overridden with the local paths
datamodule_cfg = omegaconf.OmegaConf.to_container(cfg.data_module)
datamodule_cfg["data_location"] = EMG_DATA_DIR
datamodule_cfg["data_split"]["csv_filename"] = corpus_csv

# Build the DataModule from the patched config
datamodule = instantiate(datamodule_cfg)

## Run inference on one test dataset

In [None]:
"""Grab one test dataset"""

# Build the dataset for the test split via the DataModule's (private)
# `_make_dataset` helper, then take its first item as the example to decode.
test_split = datamodule.data_split.test
test_dataset = datamodule._make_dataset(test_split, "test")
sample = test_dataset[0]

In [None]:
"""Run inference"""

# put the module in evaluation mode before the forward pass
model.eval()

# pull the EMG signal, its timestamps and the prompt labels out of the sample
emg, emg_times, labels = sample["emg"], sample["timestamps"], sample["prompts"]

# forward pass without gradient tracking: add a leading axis on the way in
# and take the first (only) entry of the output on the way out
with torch.no_grad():
    batched_logits = model(emg.unsqueeze(0))
logits = batched_logits[0]

# squash logits into probabilities
to_probability = torch.nn.Sigmoid()
probs = to_probability(logits)

# timestamps associated with each predicted probability: start at the
# network's `left_context` and step through the EMG timestamps by its `stride`
net = model.network
prob_times = emg_times[net.left_context :: net.stride]

In [None]:
"""Evaluate CLER"""

# Score this one recording: pass the predicted probabilities, the timestamp
# of each prediction, and the sample's prompt labels to `compute_cler`.
cler = compute_cler(probs, prob_times, labels)

# Report the resulting value for this single test dataset
print("CLER on this dataset:", cler)

In [None]:
"""Plot predictions and targets"""

fig, axes = plt.subplots(2, 1, figsize=(12, 5), sharex=True, sharey=False)

# every x-coordinate below is plotted relative to the first EMG timestamp
t0 = emg_times[0]

# top panel: EMG traces, one per channel, stacked with a fixed vertical offset
ax = axes[0]
spacing = 200
for channel_index, channel_data in enumerate(emg):
    ax.plot(
        emg_times - t0,
        channel_data + channel_index * spacing,
        linewidth=1,
        color="0.7",
    )

ax.set_ylim([-spacing, len(emg) * spacing])
ax.set_yticks([])

sns.despine(ax=ax, left=True)

# bottom panel: one probability trace per gesture, each shifted up by its
# index so the traces do not overlap
ax = axes[1]
for gesture in GestureType:
    prob_index = gesture.value
    ax.plot(
        prob_times - t0,
        probs[prob_index] + prob_index,
        linewidth=1,
        label=gesture.name,
    )

ax.set_yticks([])

sns.despine(ax=ax, left=True)

ax.legend(
    loc="upper left",
    ncols=1,
    bbox_to_anchor=(1.0, 1.0),
    frameon=False,
)

# zoom in on a fixed window; the x-axis is shared, so this applies to both
# panels
ax.set_xlim([10, 17])

axes[0].set_ylabel("EMG\n(normalized)")
axes[1].set_ylabel("predicted gesture\nprobability")
axes[1].set_xlabel("time\n(sec)")


tmin, tmax = ax.get_xlim()
_, ymax = axes[0].get_ylim()

# Mark each ground-truth label that falls inside the visible window with a
# vertical line and its gesture name. `lines` keeps the most recent marker so
# it can serve as the legend handle; it stays None when the window contains no
# labels (previously that case raised NameError at the legend call).
lines = None
for label in labels.to_dict(orient="records"):
    gesture_name = label["name"]
    t = label["time"] - t0
    if tmin < t < tmax:
        lines = axes[0].axvline(t, color="k")
        axes[0].text(
            t - 0.1,
            ymax - 3,
            gesture_name,
            rotation="vertical",
            va="top",
            ha="left",
        )

# only add the legend entry when at least one marker was actually drawn
if lines is not None:
    axes[0].legend(
        [lines],
        ["ground truth labels"],
        loc="upper left",
        ncols=1,
        bbox_to_anchor=(1.0, 1.0),
        frameon=False,
    )

## Evaluate full test set

In [None]:
# Evaluate on the full test set: a default-configured Lightning Trainer runs
# its test routine on the model with data supplied by the datamodule.
trainer = Trainer()
test_results = trainer.test(model, datamodule=datamodule)