# Data Analysis

In this notebook, we load the framewise displacement data, then train and evaluate sequence classifiers (GRU, LSTM, and Transformer) that distinguish PD subjects from control subjects on a held-out validation split.

In [None]:
# add modules to path
import sys
sys.path.insert(1, '../src')

# library imports
import torch
import random
import numpy as np
from sklearn.metrics import confusion_matrix, classification_report

# project imports
import data_processing as dp
import train_eval as te
import models_nn as mnn

# autoreload all modules
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [28]:
# Load the data from the .npz file
# `dp.load_data()` returns three values, unpacked here as `data_dict`,
# `pd_keys` and `control_keys`; the call also prints a summary of how many
# subjects of each group were loaded.
# NOTE(review): presumably `data_dict` maps subject keys to their sequences and
# the two key collections index into it — confirm against
# `data_processing.load_data`.
# NOTE(review): in the recorded run 320 PD vs. 44 control subjects were loaded,
# so the classes are heavily imbalanced; read plain accuracy with care.
data_dict, pd_keys, control_keys = dp.load_data()

Loaded 320/364 PD subjects and 44/88 Control subjects


In [29]:
# ----------------------- Hyper-parameters ----------------------- #
RNG_SEED = 42
MAX_LEN = 200      # sequences are truncated / zero-padded to this length
BATCH_SIZE = 32
EPOCHS = 100
LR = 2e-3
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Seed the Python, NumPy and PyTorch generators with the same value; the key
# shuffle below draws from Python's `random` module.
torch.manual_seed(RNG_SEED)
np.random.seed(RNG_SEED)
random.seed(RNG_SEED)

# Pool the subject keys of both groups and shuffle them in place.
keys = pd_keys + control_keys
random.shuffle(keys)

# Wrap the shuffled subjects in a dataset, then derive the train / validation
# loaders (val_size=0.2 — presumably a 20% validation fraction; confirm in
# `train_eval.train_val_split`).
dataset = dp.MotionDataset(data_dict, keys, MAX_LEN)
train_loader, val_loader = te.train_val_split(
    dataset,
    val_size=0.2,
    random_state=RNG_SEED,
    batch_size=BATCH_SIZE,
)


In [30]:
# One freshly constructed classifier per architecture, keyed by display name.
models = dict(
    GRU=mnn.RNNClassifier(cell="gru"),
    LSTM=mnn.RNNClassifier(cell="lstm"),
    Transformer=mnn.TransformerClassifier(),
)

# Train each model on the same loaders, then report its validation results.
for name, model in models.items():
    print(f"\nTraining {name}…")
    te.train(
        model,
        train_loader,
        val_loader,
        device=DEVICE,
        epochs=EPOCHS,
        lr=LR,
    )

    # evaluate() yields five values; only the first three are reported here.
    acc, bal_acc, auc, _, _ = te.evaluate(model, val_loader, device=DEVICE)
    print(f"{name} final  ACC={acc:.3f}  BAL_ACC={bal_acc:.3f}  AUC={auc:.3f}")

    # Per-class breakdown on the validation loader.
    preds, targets = te.get_predictions(model, val_loader, device=DEVICE)
    cm = confusion_matrix(targets, preds)
    print("Confusion matrix:\n", cm)
    print(classification_report(targets, preds, digits=3))



Training GRU…
Epoch 05/100  val_acc=0.556  val_bal_acc=0.641  val_auc=0.721
Epoch 10/100  val_acc=0.778  val_bal_acc=0.602  val_auc=0.686
Epoch 15/100  val_acc=0.708  val_bal_acc=0.508  val_auc=0.555
Epoch 20/100  val_acc=0.583  val_bal_acc=0.547  val_auc=0.527
Epoch 25/100  val_acc=0.778  val_bal_acc=0.602  val_auc=0.537
Epoch 30/100  val_acc=0.681  val_bal_acc=0.547  val_auc=0.564
Epoch 35/100  val_acc=0.722  val_bal_acc=0.461  val_auc=0.467
Epoch 40/100  val_acc=0.681  val_bal_acc=0.438  val_auc=0.436
Epoch 45/100  val_acc=0.736  val_bal_acc=0.469  val_auc=0.467
Epoch 50/100  val_acc=0.694  val_bal_acc=0.555  val_auc=0.574
Epoch 55/100  val_acc=0.722  val_bal_acc=0.570  val_auc=0.525
Epoch 60/100  val_acc=0.750  val_bal_acc=0.586  val_auc=0.504
Epoch 65/100  val_acc=0.708  val_bal_acc=0.508  val_auc=0.510
Epoch 70/100  val_acc=0.778  val_bal_acc=0.547  val_auc=0.559
Epoch 75/100  val_acc=0.778  val_bal_acc=0.547  val_auc=0.547
Epoch 80/100  val_acc=0.764  val_bal_acc=0.484  val_auc