# Discrete gesture decoder evaluation

In [None]:
import os
import torch
from pytorch_lightning import Trainer
from hydra import initialize, compose
from hydra.utils import instantiate
from omegaconf import OmegaConf
from matplotlib import pyplot as plt
import seaborn as sns

from generic_neuromotor_interface.cler import GestureType, compute_cler
from generic_neuromotor_interface.data import make_dataset

# Task identifier. Used below as the sub-directory name under MODELS_DIR and
# as the prefix of the corpus CSV filename (f"{TASK_NAME}_corpus.csv").
TASK_NAME = "discrete_gestures"

## Establish paths to data and model files

Before running this notebook you must make sure to download the data and model checkpoint as follows:
```
cd ~/generic-neuromotor-interface-data

./download_data.sh discrete_gestures small_subset <EMG_DATA_DIR>  # or full_data instead of small_subset

./download_models.sh discrete_gestures <MODELS_DIR>
```
where `<EMG_DATA_DIR>` and `<MODELS_DIR>` should match the directories specified by the `EMG_DATA_DIR` and `MODELS_DIR` variables defined in the next cell.

In [None]:
# Root directories for the downloaded data and model files. A leading "~" is
# fine here: every use below passes these through os.path.expanduser().
EMG_DATA_DIR = "~/emg_data/"  # path to EMG data
MODELS_DIR = "~/emg_models/"  # path to model files

In [None]:
# Fail fast, before any loading is attempted, if either directory is missing.
# The error messages report the paths exactly as configured (unexpanded).
emg_data_root = os.path.expanduser(EMG_DATA_DIR)
if not os.path.exists(emg_data_root):
    raise FileNotFoundError(f"The EMG data path does not exist: {EMG_DATA_DIR}")

models_root = os.path.expanduser(MODELS_DIR)
if not os.path.exists(models_root):
    raise FileNotFoundError(f"The models path does not exist: {MODELS_DIR}")

## Load model config

In [None]:
"""Load model config"""

# The config is stored at <MODELS_DIR>/<TASK_NAME>/model_config.yaml.
task_models_dir = os.path.join(os.path.expanduser(MODELS_DIR), TASK_NAME)
config_path = os.path.join(task_models_dir, "model_config.yaml")
config = OmegaConf.load(config_path)

## Load model checkpoint

In [None]:
"""Load model checkpoint"""

# The checkpoint is stored at <MODELS_DIR>/<TASK_NAME>/model_checkpoint.ckpt.
model_ckpt_path = os.path.join(
    os.path.expanduser(MODELS_DIR),
    TASK_NAME,
    "model_checkpoint.ckpt",
)

# Instantiate the module described by the config; this resolves the concrete
# LightningModule class that the checkpoint must be loaded with.
model = instantiate(config.lightning_module)

# `load_from_checkpoint` is a classmethod that returns a NEW module, so call
# it on the class rather than on the instance. Recent Lightning releases
# reject calls made through an instance; calling through the class behaves
# the same on older releases. Weights are mapped to CPU so no GPU is needed.
model = type(model).load_from_checkpoint(
    model_ckpt_path,
    map_location=torch.device("cpu"),
)

## Instantiate data module

In [None]:
"""Assemble the data module"""

# Point the DataModule config at the local EMG data directory.
data_root = os.path.expanduser(EMG_DATA_DIR)
data_module_cfg = config["data_module"]
data_module_cfg["data_location"] = data_root

# When the split's `_target_` mentions "from_csv", it also needs the path of
# the corpus CSV, which sits in the data directory.
split_cfg = data_module_cfg["data_split"]
if "from_csv" in split_cfg["_target_"]:
    split_cfg["csv_filename"] = os.path.join(data_root, f"{TASK_NAME}_corpus.csv")

datamodule = instantiate(config["data_module"])

## Run inference on one test dataset

In [None]:
"""Grab one test dataset"""

# Single recording to evaluate; name from discrete_gestures_mini_split.yaml.
# NOTE(review): the `None` value presumably means "no sub-range restriction"
# -- confirm against make_dataset.
single_session_partition = {"discrete_gestures_user_002_dataset_000": None}

# Reuse the data module's location and transform, with augmentation,
# windowing, striding and jitter all switched off.
test_dataset = make_dataset(
    datamodule.data_location,
    partition_dict=single_session_partition,
    transform=datamodule.transform,
    emg_augmentation=None,
    window_length=None,
    stride=None,
    jitter=False,
)

# First (index 0) item of the dataset.
sample = test_dataset[0]

In [None]:
"""Run inference"""

# Switch the module to evaluation mode before running it.
model.eval()

# Pull the EMG signal, its timestamps and the prompt labels out of the sample.
emg, emg_times, labels = sample["emg"], sample["timestamps"], sample["prompts"]

# Forward pass without gradient tracking; the model is fed the sample with a
# leading batch dimension added, and that dimension is dropped again afterwards.
with torch.no_grad():
    batched_logits = model(emg.unsqueeze(0))
logits = batched_logits[0]

# Element-wise sigmoid turns logits into probabilities.
probs = torch.sigmoid(logits)

# Timestamps associated with each predicted probability: start at the
# network's left context and step by its stride.
network = model.network
prob_times = emg_times[network.left_context :: network.stride]

In [None]:
"""Evaluate CLER"""

# Compute CLER for this recording from the predicted probabilities, the
# timestamps of those predictions, and the prompt labels.
cler = compute_cler(probs, prob_times, labels)

print("CLER on this dataset:", cler)

In [None]:
"""Plot predictions and targets"""

# Two stacked panels sharing the time axis: EMG on top, probabilities below.
fig, axes = plt.subplots(2, 1, figsize=(12, 5), sharex=True, sharey=False)
emg_ax, prob_ax = axes

# All times are plotted relative to the first EMG timestamp.
t0 = emg_times[0]
rel_emg_times = emg_times - t0
rel_prob_times = prob_times - t0

# --- top panel: EMG, one trace per channel, stacked `spacing` apart ---
spacing = 200
for channel_index, channel_data in enumerate(emg):
    emg_ax.plot(
        rel_emg_times,
        channel_data + channel_index * spacing,
        linewidth=1,
        color="0.7",
    )

emg_ax.set_ylim([-spacing, len(emg) * spacing])
emg_ax.set_yticks([])
sns.despine(ax=emg_ax, left=True)

# --- bottom panel: one probability trace per gesture ---
# The enum value is used both as the row index into `probs` and as the
# vertical offset of the trace.
for gesture in GestureType:
    row = gesture.value
    prob_ax.plot(
        rel_prob_times,
        probs[row] + row,
        linewidth=1,
        label=gesture.name,
    )

prob_ax.set_yticks([])
sns.despine(ax=prob_ax, left=True)

# get_legend_handles_labels() returns (handles, labels); reverse both so the
# legend order follows the top-to-bottom order of the stacked traces.
handles, names = prob_ax.get_legend_handles_labels()
prob_ax.legend(
    handles[::-1],
    names[::-1],
    loc="upper left",
    ncols=1,
    bbox_to_anchor=(1.0, 1.0),
    frameon=False,
)

# Zoom in on a 5-second window (seconds relative to t0); the x-axis is shared.
prob_ax.set_xlim([352, 357])

emg_ax.set_ylabel("EMG\n(normalized)")
prob_ax.set_ylabel("predicted gesture\nprobability")
prob_ax.set_xlabel("time\n(sec)")

tmin, tmax = prob_ax.get_xlim()
_, ymax = emg_ax.get_ylim()

# Mark every ground-truth label that falls strictly inside the visible
# window with a vertical line and its gesture name on the EMG panel.
truth_line = None
for label in labels.to_dict(orient="records"):
    t = label["time"] - t0
    if not ((t > tmin) and (t < tmax)):
        continue
    truth_line = emg_ax.axvline(t, color="k")
    emg_ax.text(
        t - 0.075,
        ymax + 200,
        label["name"],
        rotation="vertical",
        va="top",
        ha="left",
    )

# Only add the ground-truth legend when at least one label was drawn.
if truth_line is not None:
    emg_ax.legend(
        [truth_line],
        ["ground truth labels"],
        loc="upper left",
        ncols=1,
        bbox_to_anchor=(1.0, 1.0),
        frameon=False,
    )

## Evaluate full test set

Note that this requires you to have downloaded the full dataset (`full_data` instead of `small_subset`) when invoking `./download_data.sh`.

In [None]:
# CPU-only trainer, consistent with the checkpoint having been mapped to CPU.
trainer = Trainer(accelerator="cpu")
# Run the test loop for the loaded model over the data module and keep what
# it returns. Requires the full dataset download, not just the small subset.
test_results = trainer.test(model=model, datamodule=datamodule)