# Wrist decoder evaluation

In [None]:
import os
import glob
import torch
from pytorch_lightning import Trainer
from hydra import initialize, compose
from hydra.utils import instantiate
import omegaconf
from matplotlib import pyplot as plt
import seaborn as sns
import h5py

## Establish paths to data and model files

Before running this notebook, download the data and the model checkpoint:
```
cd ~/generic-neuromotor-interface-data

./download_data.sh wrist small_subset <EMG_DATA_DIR>  # or full_data instead of small_subset

./download_models.sh wrist <MODELS_DIR>
```
where `<EMG_DATA_DIR>` and `<MODELS_DIR>` must match the directories assigned to the `EMG_DATA_DIR` and `MODELS_DIR` variables in the next cell.

In [None]:
# Expand "~" once, here, so every consumer receives an absolute path. Some
# consumers (e.g. the data module config below) receive the string verbatim, and
# Python file APIs such as open() do not expand "~" on their own.
EMG_DATA_DIR = os.path.expanduser("~/emg_data/")  # path to EMG data (target of download_data.sh)
MODELS_DIR = os.path.expanduser("~/emg_models/")  # path to model files (target of download_models.sh)

## Load model checkpoint and config

In [None]:
"""Retrieve model checkpoint"""

model_ckpt_path = os.path.join(os.path.expanduser(MODELS_DIR), "wrist", "model_checkpoint.ckpt")
model_ckpt = torch.load(
    model_ckpt_path,
    map_location=torch.device("cpu")
)

In [None]:
"""Retrieve the config"""

config_path = os.path.relpath(os.path.join(os.path.expanduser(MODELS_DIR), "wrist"))
with initialize(config_path=config_path):
    cfg = compose(config_name="model_config")

## Instantiate model and data module

In [None]:
"""Load the model"""

# Instantiate model
model = instantiate(cfg.lightning_module)

# Load the checkpoint state_dict
model.load_state_dict(model_ckpt["state_dict"])

In [None]:
"""Assemble the data module"""

# Assemble DataModule config
datamodule_cfg = omegaconf.OmegaConf.to_container(cfg.data_module)
datamodule_cfg["data_location"] = EMG_DATA_DIR
if "from_csv" in datamodule_cfg["data_split"]["_target_"]:
    datamodule_cfg["data_split"]["csv_filename"] = os.path.join(EMG_DATA_DIR, "wrist_corpus.csv")

# Instantiate DataModule
datamodule = instantiate(datamodule_cfg)

## Run inference on one test stage

In [None]:
"""Grab one test stage"""

test_dataset = datamodule._make_dataset({"wrist_user_002_dataset_000": [(1713966045.6641605, 1713966207.9896057)]}, "test")  # from wrist_mini_split.yaml
sample = test_dataset[0]

In [None]:
"""Run inference"""

EMG_SAMPLING_RATE = 2000.

model_output_sampling_rate = EMG_SAMPLING_RATE / model.network.stride

model.eval()

# unpack sample
emg = sample["emg"]
wrist_angles = sample["wrist_angles"]

# compute model outputs
with torch.no_grad():
    predictions = model(emg.unsqueeze(0))

predictions = predictions[0]

# convert wrist angle units from radians to degrees
wrist_angles = torch.rad2deg(wrist_angles)
predictions = torch.rad2deg(predictions)

# convert model predictions from displacements (deg) to velocities (deg/sec)
predictions *= model_output_sampling_rate

# estimate wrist angle velocities with the displacement across
# adjacent wrist angle measurements
wrist_angles_sliced = wrist_angles[:, model.network.left_context :: model.network.stride]  # slice the wrist angles at the model output sampling rate
wirst_angle_displacements = torch.diff(wrist_angles_sliced, dim=1)
wrist_angle_velocities = wirst_angle_displacements * model_output_sampling_rate  # convert to deg/sec

# since we don't know the true displacement at the first timestep, remove the first prediction
predictions = predictions[:, 1:]

In [None]:
"""Evaluate Mean Absolute Error (MAE)"""

mae = torch.mean(torch.abs(predictions - wrist_angle_velocities))

print(f"MAE on this stage: {mae:.3f} (deg/sec)")

In [None]:
"""Plot predictions and targets"""

fig, axes = plt.subplots(3, 1, figsize=(12, 7.5), sharex=True, sharey=False)

# plot EMG
ax = axes[0]
spacing = 120
for channel_index, channel_data in enumerate(emg):
    ax.plot(
        torch.arange(len(channel_data)) / EMG_SAMPLING_RATE,
        channel_data + channel_index * spacing,
        linewidth=1,
        color="0.7",
    )
ax.set_ylim([-spacing, len(emg) * spacing])
ax.set_yticks([])

sns.despine(ax=ax, left=True)

# wrist angles
ax = axes[1]
assert wrist_angles.shape[0] == 1
ax.plot(
    torch.arange(wrist_angles.shape[1]) / EMG_SAMPLING_RATE,
    wrist_angles[0],
    linewidth=1,
    color="k",
)
sns.despine(ax=ax)

# wrist angle velocities
ax = axes[2]
assert wrist_angle_velocities.shape[0] == 1
ax.plot(
    torch.arange(wrist_angle_velocities.shape[1]) / model_output_sampling_rate,
    wrist_angle_velocities[0],
    linewidth=1,
    color="k",
    label="ground truth",
)
sns.despine(ax=ax)

# model predictions
assert predictions.shape[0] == 1
ax.plot(
    torch.arange(predictions.shape[1]) / model_output_sampling_rate,
    predictions[0],
    linewidth=2,
    color="r",
    label="model predictions",
)

ax.legend(
    loc="upper left",
    ncols=1,
    bbox_to_anchor=(1.0, 1.0),
    frameon=False
)
ax.set_ylim([-300, 300])

ax.set_xlim([145, 165])

axes[0].set_ylabel("EMG\n(normalized)")
axes[1].set_ylabel("wrist angle\n(deg)")
axes[2].set_ylabel("wrist angle velocity\n(deg / sec)")
axes[2].set_xlabel("time\n(sec)");

## Evaluate full test set

Note that this requires you to have downloaded the full dataset (`full_data` instead of `small_subset`) when invoking `./download_data.sh`.

In [None]:
# Run Lightning's test loop over the data module's test split with default
# Trainer settings; `test_results` holds whatever `Trainer.test` returns.
trainer = Trainer()
test_results = trainer.test(datamodule=datamodule, model=model)