# Discrete gesture decoder evaluation

In [None]:
import os

import omegaconf
import seaborn as sns
import torch
from hydra import compose, initialize, initialize_config_dir
from hydra.utils import instantiate
from matplotlib import pyplot as plt
from pytorch_lightning import Trainer

from generic_neuromotor_interface.cler import GestureType
from generic_neuromotor_interface.cler import compute_cler
from generic_neuromotor_interface.scripts.download_data import download_data
from generic_neuromotor_interface.scripts.download_models import download_models

## Establish paths to data and model files

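Before running the rest of this notebook, set the paths below and make sure the data and model checkpoint have been downloaded (uncomment the download calls in the cell that follows if you haven't already):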

In [None]:
# Root directories for the downloaded EMG recordings and the model files.
# Paths may start with "~"; later cells pass them through os.path.expanduser.
EMG_DATA_DIR = "~/emg_data/"
MODELS_DIR = "~/emg_models/"

# Task to evaluate and which portion of its dataset to use
# (one of 'small_subset' or 'full_data').
TASK = "discrete_gestures"
DATASET_TYPE = "small_subset"

In [None]:
# Uncomment to download the dataset (DATASET_TYPE) and the model files for
# TASK into EMG_DATA_DIR and MODELS_DIR, if you haven't already.

# download_data(TASK, DATASET_TYPE, EMG_DATA_DIR)
# download_models(TASK, MODELS_DIR)

## Load model checkpoint and config

In [None]:
"""Retrieve model checkpoint"""

model_ckpt_path = os.path.join(
    os.path.expanduser(MODELS_DIR),
    "discrete_gestures",
    "model_checkpoint.ckpt",
)

# Fail early with a readable message instead of a torch.load traceback.
if not os.path.exists(model_ckpt_path):
    raise FileNotFoundError(f"The model checkpoint path does not exist: {model_ckpt_path}")

# map_location places every stored tensor on the CPU.
# NOTE: weights_only=False lets torch.load unpickle arbitrary Python objects,
# so only load checkpoints from a source you trust.
model_ckpt = torch.load(model_ckpt_path, map_location=torch.device("cpu"), weights_only=False)

In [None]:
"""Retrieve the config"""

config_path = os.path.join(os.path.expanduser(MODELS_DIR), "discrete_gestures")

if not os.path.exists(config_path):
    raise FileNotFoundError(f"The config path does not exist: {config_path}")

# hydra's `initialize` only accepts a path relative to the calling file/module,
# so converting with os.path.relpath made this cell depend on the current
# working directory (and os.path.relpath raises ValueError on Windows when the
# two paths are on different drives). `initialize_config_dir` takes an
# absolute directory directly, which is exactly what we have here.
with initialize_config_dir(config_dir=os.path.abspath(config_path), version_base="1.1"):
    cfg = compose(config_name="model_config")

## Instantiate model and data module

In [None]:
"""Load the model"""

# Build the LightningModule described by `cfg.lightning_module`, then copy in
# the trained weights stored under "state_dict" in the checkpoint loaded above.
model = instantiate(cfg.lightning_module)
model.load_state_dict(model_ckpt["state_dict"])

In [None]:
"""Assemble the data module"""

data_path = os.path.expanduser(EMG_DATA_DIR)

if not os.path.exists(data_path):
    raise FileNotFoundError(f"The EMG data path does not exist: {data_path}")

# Convert the data-module config to a plain, mutable container so the data
# paths can be overridden before instantiation.
datamodule_cfg = omegaconf.OmegaConf.to_container(cfg.data_module)

# Use the user-expanded path rather than the raw EMG_DATA_DIR (which may still
# contain "~"), so data_location agrees with the existence check above and with
# the CSV path below.
datamodule_cfg["data_location"] = data_path

# If the split is read from a CSV corpus file, point it at the local copy.
if "from_csv" in datamodule_cfg["data_split"]["_target_"]:
    datamodule_cfg["data_split"]["csv_filename"] = os.path.join(data_path, "discrete_gestures_corpus.csv")

datamodule = instantiate(datamodule_cfg)

## Run inference on one test dataset

In [None]:
"""Grab one test dataset"""

# Build the "test" dataset for a single recording (name taken from
# discrete_gestures_mini_split.yaml) via the datamodule's private
# `_make_dataset` helper, then take its first item.
test_dataset = datamodule._make_dataset(
    {"discrete_gestures_user_002_dataset_000": None},
    "test",
)
sample = test_dataset[0]

In [None]:
"""Run inference"""

# Switch the model to evaluation mode before predicting.
model.eval()

emg, emg_times, labels = sample["emg"], sample["timestamps"], sample["prompts"]

# Add a leading batch dimension of size 1, run the network without gradient
# tracking, and strip the batch dimension from the output again.
with torch.no_grad():
    logits = model(emg.unsqueeze(0))[0]

# Element-wise sigmoid turns each logit into a probability.
probs = torch.sigmoid(logits)

# Pick the EMG timestamps that line up with the network outputs: start at
# `left_context` and take every `stride`-th sample.
prob_times = emg_times[model.network.left_context :: model.network.stride]

In [None]:
"""Evaluate CLER"""

# Score the predicted probabilities (with their timestamps) against the
# prompted gesture labels of this recording.
cler = compute_cler(probs, prob_times, labels)
print("CLER on this dataset:", cler)

In [None]:
"""Plot predictions and targets"""

# Vertical offset between stacked EMG channel traces; also used to place the
# ground-truth label text above the EMG panel.
spacing = 200

# Zoom window, in seconds from the first EMG timestamp. These bounds are
# specific to the recording selected above; adjust them for other data.
plot_tmin, plot_tmax = 352, 357

fig, axes = plt.subplots(2, 1, figsize=(12, 5), sharex=True, sharey=False)

t0 = emg_times[0]

# Top panel: one vertically offset trace per EMG channel.
ax = axes[0]
for channel_index, channel_data in enumerate(emg):
    ax.plot(
        emg_times - t0,
        channel_data + channel_index * spacing,
        linewidth=1,
        color="0.7",
    )

ax.set_ylim([-spacing, len(emg) * spacing])
ax.set_yticks([])

sns.despine(ax=ax, left=True)

# Bottom panel: predicted probability for each gesture, offset by its index.
ax = axes[1]
for gesture in GestureType:
    prob_index = gesture.value
    ax.plot(
        prob_times - t0,
        probs[prob_index] + prob_index,
        linewidth=1,
        label=gesture.name,
    )

ax.set_yticks([])

sns.despine(ax=ax, left=True)

# get_legend_handles_labels() returns (handles, labels). Name them accordingly;
# `legend_labels` avoids shadowing the `labels` prompts table used below.
# Both are reversed so the last-plotted gesture is listed first.
handles, legend_labels = ax.get_legend_handles_labels()
ax.legend(
    handles[::-1],
    legend_labels[::-1],
    loc="upper left",
    ncols=1,
    bbox_to_anchor=(1.0, 1.0),
    frameon=False,
)

ax.set_xlim([plot_tmin, plot_tmax])

axes[0].set_ylabel("EMG\n(normalized)")
axes[1].set_ylabel("predicted gesture\nprobability")
axes[1].set_xlabel("time\n(sec)")

# Overlay the ground-truth prompts that fall strictly inside the visible window.
tmin, tmax = ax.get_xlim()
_, ymax = axes[0].get_ylim()

label_line = None  # last vertical marker drawn; reused as the legend handle

for label in labels.to_dict(orient="records"):
    gesture_name = label["name"]
    t = label["time"] - t0
    if not (tmin < t < tmax):
        continue
    label_line = axes[0].axvline(t, color="k")
    axes[0].text(
        t - 0.075,
        ymax + spacing,
        gesture_name,
        rotation="vertical",
        va="top",
        ha="left",
    )

# Only add the legend when at least one label marker was drawn.
if label_line is not None:
    axes[0].legend(
        [label_line],
        ["ground truth labels"],
        loc="upper left",
        ncols=1,
        bbox_to_anchor=(1.0, 1.0),
        frameon=False,
    )

## Evaluate full test set

Note that this requires you to have downloaded the full dataset (`full_data` instead of `small_subset`).

In [None]:
# Uncomment to download the full dataset into EMG_DATA_DIR if you haven't
# already; evaluating the full test set needs `full_data`, not `small_subset`.

# download_data(TASK, "full_data", EMG_DATA_DIR)

In [None]:
# Run Lightning's test loop over the datamodule's test split with default
# Trainer settings; per the Lightning API the result is a list of metric dicts.
trainer = Trainer()
test_results = trainer.test(
    model=model,
    datamodule=datamodule,
)