# Handwriting decoder evaluation

In [None]:
import os
import torch
from pytorch_lightning import Trainer
from hydra import initialize, compose
from hydra.utils import instantiate
from omegaconf import OmegaConf

from generic_neuromotor_interface.handwriting_utils import CharacterErrorRates

# Task identifier. Later cells use it as the sub-directory name under MODELS_DIR
# and as the prefix of the "<TASK_NAME>_corpus.csv" file name.
TASK_NAME = "handwriting"

## Establish paths to data and model files

Before running this notebook you must make sure to download the data and model checkpoint as follows:
```
cd ~/generic-neuromotor-interface-data

./download_data.sh handwriting small_subset <EMG_DATA_DIR>  # or full_data instead of small_subset

./download_models.sh handwriting <MODELS_DIR>
```
where `<EMG_DATA_DIR>` and `<MODELS_DIR>` must match the directories specified by the `EMG_DATA_DIR` and `MODELS_DIR` variables defined in the next cell.

In [None]:
# User-editable locations. A leading "~" is fine: every later cell passes these
# values through os.path.expanduser before touching the filesystem.
EMG_DATA_DIR = "~/emg_data/"  # path to EMG data
MODELS_DIR = "~/emg_models/"  # path to model files

In [None]:
# Fail fast, before any config or checkpoint is read, if either configured
# directory is missing. The EMG data directory is checked first, then the
# models directory; the error reports the path exactly as configured.
for _path_kind, _configured_dir in (
    ("EMG data", EMG_DATA_DIR),
    ("models", MODELS_DIR),
):
    _resolved_dir = os.path.expanduser(_configured_dir)
    if os.path.exists(_resolved_dir):
        continue
    raise FileNotFoundError(f"The {_path_kind} path does not exist: {_configured_dir}")

## Load model config

In [None]:
"""Load model config"""

# Resolve <MODELS_DIR>/<TASK_NAME>/model_config.yaml (expanding a leading "~")
# and parse it with OmegaConf.
task_model_dir = os.path.join(os.path.expanduser(MODELS_DIR), TASK_NAME)
config_path = os.path.join(task_model_dir, "model_config.yaml")
config = OmegaConf.load(config_path)

## Load model checkpoint

In [None]:
"""Load model checkpoint"""

# Resolve <MODELS_DIR>/<TASK_NAME>/model_checkpoint.ckpt, expanding a leading "~".
model_ckpt_path = os.path.join(
    os.path.expanduser(MODELS_DIR),
    TASK_NAME,
    "model_checkpoint.ckpt"
)

# Instantiate the module described by the config; this tells us which
# LightningModule class the checkpoint has to be restored into.
model = instantiate(config.lightning_module)

# `load_from_checkpoint` is a classmethod that returns a *new* module, so it is
# called on the class rather than on the instance. Older Lightning releases
# accept both forms identically, while Lightning 2.x raises a TypeError when it
# is invoked on an instance.
model = type(model).load_from_checkpoint(
    model_ckpt_path,
    map_location=torch.device("cpu"),  # remap checkpoint tensors onto the CPU
)

## Instantiate data module

In [None]:
"""Assemble the data module"""

# Point the DataModule section of the config at the local EMG data directory.
emg_data_root = os.path.expanduser(EMG_DATA_DIR)
data_module_cfg = config["data_module"]
data_module_cfg["data_location"] = emg_data_root

# Only CSV-driven splits (target path contains "from_csv") need to be told
# where the corpus file lives.
data_split_cfg = data_module_cfg["data_split"]
if "from_csv" in data_split_cfg["_target_"]:
    corpus_csv = os.path.join(emg_data_root, f"{TASK_NAME}_corpus.csv")
    data_split_cfg["csv_filename"] = corpus_csv

# Build the DataModule from the patched config section.
datamodule = instantiate(data_module_cfg)

## Run inference on one prompt

In [None]:
"""Grab one test prompt"""

# Session key taken from handwriting_mini_split.yaml.
session_name = "handwriting_user_001_dataset_000"
test_dataset = datamodule._make_dataset({session_name: None}, "test")

# Index 55 is an arbitrary prompt from this dataset.
prompt_index = 55
sample = test_dataset[prompt_index]

In [None]:
"""Run inference"""

# Switch the module to evaluation mode before running it.
model.eval()

# Pull the EMG signal and its target labels out of the sample.
emg, labels = sample["emg"], sample["prompts"]

# Gradients are not needed for inference.
with torch.no_grad():
    # Transpose the signal and add a leading dimension of size 1
    # (presumably the batch axis - TODO confirm against the network's input spec).
    batched_emg = emg.T.unsqueeze(0)
    emissions, _slice = model(batched_emg)

    # The decoder takes a NumPy array with the first two emission axes swapped.
    emissions_np = emissions.movedim(0, 1).numpy()

    # Length of the emission sequence after the network's time downsampling.
    emission_lengths = model.network.compute_time_downsampling(
        emg_lengths=torch.as_tensor([len(emg)]), slc=_slice
    )

    # compute greedy decode outputs
    decoded_batch = model.decoder.decode_batch(
        emissions=emissions_np,
        emission_lengths=emission_lengths,
    )

# Only one prompt was decoded, so keep the first (and only) batch entry.
predictions = torch.as_tensor(decoded_batch[0])

# convert predictions and labels to characters
charset = model.decoder._charset
predictions = charset.labels_to_str(predictions)
labels = charset.labels_to_str(labels)

In [None]:
"""Evaluate CER on this prompt"""

# Accumulate this single prediction/target pair into a fresh metric object,
# then read the aggregated results back out.
metric = CharacterErrorRates()
metric.update(prediction=predictions, target=labels)
aggregate_metrics = metric.compute()

# Report only the "CER" entry of the computed metrics.
prompt_cer = aggregate_metrics["CER"]
print("CER of above prompt decode:", prompt_cer)

In [None]:
"""Print predictions and target"""

# Show the decoded text directly above the ground-truth text for comparison.
prediction_line = f"Prediction: \t {predictions} \n"
target_line = f"Target: \t {labels}"
print(prediction_line + target_line)

## Evaluate full test set

Note that this requires you to have downloaded the full dataset (`full_data` instead of `small_subset`) when invoking `./download_data.sh`.

In [None]:
# Build a Lightning Trainer that runs on the CPU accelerator.
trainer = Trainer(accelerator="cpu")
# Run the test loop for `model` over the data supplied by `datamodule`;
# whatever `Trainer.test` returns is kept in `test_results`.
test_results = trainer.test(model=model, datamodule=datamodule)