# Getting Started

This notebook showcases the basic functionality of the codebase.

Here, we download the dataset metadata and a small example dataset, then run inference with a pre-trained model.

In [None]:
# Automatically reload imported modules before each cell runs, so edits to the
# editable-installed package (see the install cell below) are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path
from tempfile import gettempdir

# Root directory for everything this notebook downloads: the metadata CSV, the mini
# dataset archive, and the model checkpoints.
# NOTE: the system temp directory may be cleaned between sessions, which forces
# re-downloads; point this at a persistent location (e.g. Path.home()) to keep the files.
# Kept as a Path so it can be joined with `/`; it still formats as a plain path string
# inside f-strings and `!` shell commands.
DATA_DOWNLOAD_DIR = Path(gettempdir())

In [None]:
import os

repo_path = Path("/src/platform/generic_neuromotor_interface/")
assert os.path.exists(repo_path)

print("Installing package...")
!pip install -e {repo_path} -qqq

print("Please restart kernel if first install!")

In [None]:
# Sanity check that the package installed above is importable. If this fails right
# after the first install, restart the kernel and rerun from the top.
import generic_neuromotor_interface

## Download Dataset Metadata

In [None]:
!cd {DATA_DOWNLOAD_DIR} && wget https://fb-ctrl-oss.s3.amazonaws.com/emg2pose/emg2pose_metadata.csv -O emg2pose_metadata.csv

In [None]:
import pandas as pd

# Load the metadata CSV downloaded in the previous cell and preview its first rows.
metadata_df = pd.read_csv(Path(DATA_DOWNLOAD_DIR).joinpath("emg2pose_metadata.csv"))
metadata_df.head(n=5)

## Download a Smaller (~600 MiB) Version of the Dataset

In [None]:
!cd {DATA_DOWNLOAD_DIR} && wget "https://fb-ctrl-oss.s3.amazonaws.com/emg2pose/emg2pose_dataset_mini.tar" -O emg2pose_dataset_mini.tar

# Unpack the tar to ~/emg2pose_dataset_mini
!cd {DATA_DOWNLOAD_DIR} && tar -xvf emg2pose_dataset_mini.tar

In [None]:
import glob
import os

# All HDF5 session recordings from the unpacked mini dataset. Sorting in place makes the
# order deterministic, which matters because later cells pick a session by its index.
sessions = glob.glob(os.path.join(DATA_DOWNLOAD_DIR, "emg2pose_dataset_mini/*.hdf5"))
sessions.sort()
sessions

## Let's Look at a Single Session

In [None]:
from generic_neuromotor_interface.data import Emg2PoseSessionData

# Position (in the sorted `sessions` list) of the recording to inspect; any valid index works.
SESSION_INDEX = 15
if SESSION_INDEX >= len(sessions):
    # Usually means the mini dataset was not downloaded/unpacked where the glob above looked.
    raise IndexError(
        f"SESSION_INDEX={SESSION_INDEX} is out of range: only {len(sessions)} session "
        "file(s) found. Check that the mini dataset was downloaded and unpacked."
    )

session = sessions[SESSION_INDEX]
data = Emg2PoseSessionData(hdf5_path=session)

In [None]:
# List the session's fields, then the shape of each array used below.
print(data.fields)
print()

for field_name in ("emg", "joint_angles", "time"):
    label = f"{field_name} shape: "
    print(f"{label:<20} {data[field_name].shape}")

In [None]:
# Metadata row(s) whose "filename" matches the session loaded above.
metadata_df.loc[metadata_df["filename"].eq(data.metadata["filename"])]

## Let's Load a Checkpoint and Generate Some Predictions

In [None]:
!cd {DATA_DOWNLOAD_DIR} \
&& wget "https://fb-ctrl-oss.s3.amazonaws.com/emg2pose/emg2pose_model_checkpoints.tar.gz" -O emg2pose_model_checkpoints.tar.gz && \
tar -xvzf emg2pose_model_checkpoints.tar.gz

In [None]:
from generic_neuromotor_interface.utils import generate_hydra_config_from_overrides

# Compose the Hydra config from the `tracking_vemg2pose` experiment preset, with the
# checkpoint path pointed at the regression checkpoint unpacked in the previous cell.
config = generate_hydra_config_from_overrides(
    overrides=[
        "experiment=tracking_vemg2pose",
        "checkpoint={}/emg2pose_model_checkpoints/regression_vemg2pose.ckpt".format(DATA_DOWNLOAD_DIR),
    ],
)

In [None]:
from generic_neuromotor_interface.lightning import Emg2PoseModule

# Restore the pretrained weights. The network, optimizer and LR scheduler are taken from
# the Hydra config composed above.
# map_location="cpu" lets a checkpoint saved on GPU load on a CPU-only machine; the
# inference cell converts outputs with .numpy(), which requires CPU tensors anyway.
module = Emg2PoseModule.load_from_checkpoint(
    config.checkpoint,
    map_location="cpu",
    network=config.network,
    optimizer=config.optimizer,
    lr_scheduler=config.lr_scheduler,
)

# Inference mode: makes any dropout / batch-norm layers in the network behave
# deterministically. Trailing `;` suppresses printing the module repr.
module.eval();

In [None]:
# Run inference on the slice [start_idx, stop_idx) of the session loaded above.
# NOTE: this rebinds `session` from the HDF5 path to the loaded session object.
session = data
start_idx, stop_idx = 0, 10_000

In [None]:
import numpy as np
import torch

session_window = session[start_idx:stop_idx]

# no_ik_failure is not a field so we slice separately
no_ik_failure_window = session.no_ik_failure[start_idx:stop_idx]


def _to_batch_tensor(array):
    """Return `array` as a float32 tensor with a leading batch axis of size 1.

    Going through a single contiguous float32 numpy array produces the same values as
    `torch.Tensor([array])`, but avoids torch's slow element-wise conversion of a list
    of ndarrays.
    """
    return torch.from_numpy(np.ascontiguousarray(array, dtype=np.float32)).unsqueeze(0)


batch = {
    "emg": _to_batch_tensor(session_window["emg"].T),  # BCT
    "joint_angles": _to_batch_tensor(session_window["joint_angles"].T),  # BCT
    "no_ik_failure": _to_batch_tensor(no_ik_failure_window),  # BT
}

# Inference only: skip building the autograd graph.
with torch.no_grad():
    preds, joint_angles, no_ik_failure = module(batch)

# Algorithms that use the initial state for ground truth will do poorly
# when the first joint angles are missing!
# With BCT layout the first time step is the LAST axis: check every channel at t=0.
# (`joint_angles[:, 0]` would instead select channel 0 across all time steps.)
if (joint_angles[..., 0] == 0).all():
    print(
        "Warning! Ground truth not available at first time step!"
    )

# BCT --> TC (as numpy)
preds = preds[0].T.detach().cpu().numpy()
joint_angles = joint_angles[0].T.detach().cpu().numpy()

In [None]:
# Predictions as a numpy array in (time, channel) layout after the BCT --> TC conversion.
preds.shape

In [None]:
# Ground-truth joint angles as a numpy array in (time, channel) layout, as for `preds`.
joint_angles.shape

### Compare the Ground Truth and Predictions Side-by-Side

In [None]:
import math

import matplotlib.pyplot as plt

# One subplot per joint-angle channel (column of the TC arrays), laid out N_COLS wide.
# The row count is derived from the channel count rather than hard-coded, so every
# channel is plotted and none is indexed out of range.
N_COLS = 2
n_joint_angles = joint_angles.shape[1]
N_ROWS = math.ceil(n_joint_angles / N_COLS)

# squeeze=False keeps `axs` 2-D even when there is only one row.
fig, axs = plt.subplots(N_ROWS, N_COLS, figsize=(4 * N_COLS, 2 * N_ROWS), squeeze=False)

axs_flattened = axs.flatten()
for i, ax in enumerate(axs_flattened):
    if i >= n_joint_angles:
        # Odd channel count: the trailing grid cell has nothing to show.
        ax.set_visible(False)
        continue

    ax.set_title(f"Joint Angle {i}")
    ax.plot(joint_angles[:, i], label="gt")
    ax.plot(preds[:, i], label="pred")

    ax.legend()

fig.suptitle("Predicted vs. Ground Truth Joint Angles")

plt.tight_layout()
# Leave room at the top for the suptitle after tight_layout packs the panels.
fig.subplots_adjust(top=0.95)

plt.show()