# Data Analysis

In this notebook, we load the framewise displacement data and perform some exploratory data analysis by generating some visualizations of the data.

In [2]:
import numpy as np
import random
import torch
from sklearn.metrics import confusion_matrix, classification_report
import sys
sys.path.insert(1, '../src')

import data_processing as dp
import train_eval as te
import models_nn as mnn


In [10]:
from importlib import reload
reload(dp)
reload(te)
reload(mnn)
reload(mstats)


<module 'models_stats' from '/Users/kevin/Repos/head-motion-analysis/notebooks/../src/models_stats.py'>

In [7]:
# Load the data from the .npz file
data_file = '../framewise_displacement_data.npz'
data = np.load(data_file, allow_pickle=True)

# Convert the loaded data to a dictionary
data_dict = {key: data[key].item() for key in data}

pd_keys = dp.filter_valid_subjects(data_dict, 'PD')
control_keys = dp.filter_valid_subjects(data_dict, 'Control')


PD: Removed 283 out of 603
Control: Removed 38 out of 82


In [8]:
# ----------------------- Hyper-parameters ----------------------- #
MAX_LEN = 100  # truncate / zero-pad sequences to this length
BATCH_SIZE = 32
EPOCHS = 100
LR = 2e-3
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
RNG_SEED = 42

random.seed(RNG_SEED)
np.random.seed(RNG_SEED)
torch.manual_seed(RNG_SEED)

keys = pd_keys + control_keys
random.shuffle(keys)

dataset = dp.MotionDataset(data_dict, keys)
train_loader, val_loader = te.train_val_split(
        dataset, val_size=0.2, random_state=RNG_SEED, batch_size=BATCH_SIZE
)


In [11]:
models = {
    "GRU": mnn.RNNClassifier(cell="gru"),
    "LSTM": mnn.RNNClassifier(cell="lstm"),
    "Transformer": mnn.TransformerClassifier()
}

for name, model in models.items():
    print(f"\nTraining {name}…")
    te.train(
        model, train_loader, val_loader,
        device=DEVICE, epochs=EPOCHS, lr=LR,
    )
    acc, bal_acc, auc, _, _ = te.evaluate(model, val_loader, device=DEVICE)
    print(
        f"{name} final  ACC={acc:.3f}  BAL_ACC={bal_acc:.3f}  AUC={auc:.3f}"
    )

    preds, targets = te.get_predictions(model, val_loader, device=DEVICE)
    cm = confusion_matrix(targets, preds)
    print("Confusion matrix:\n", cm)
    print(classification_report(targets, preds, digits=3))



Training GRU…
Epoch 05/100  val_acc=0.486  val_bal_acc=0.602  val_auc=0.699
Epoch 10/100  val_acc=0.847  val_bal_acc=0.586  val_auc=0.635
Epoch 15/100  val_acc=0.792  val_bal_acc=0.609  val_auc=0.525
Epoch 20/100  val_acc=0.681  val_bal_acc=0.602  val_auc=0.557
Epoch 25/100  val_acc=0.694  val_bal_acc=0.500  val_auc=0.412
Epoch 30/100  val_acc=0.764  val_bal_acc=0.484  val_auc=0.582
Epoch 35/100  val_acc=0.722  val_bal_acc=0.570  val_auc=0.539
Epoch 40/100  val_acc=0.833  val_bal_acc=0.688  val_auc=0.598
Epoch 45/100  val_acc=0.792  val_bal_acc=0.500  val_auc=0.590
Epoch 50/100  val_acc=0.736  val_bal_acc=0.469  val_auc=0.580
Epoch 55/100  val_acc=0.750  val_bal_acc=0.477  val_auc=0.582
Epoch 60/100  val_acc=0.778  val_bal_acc=0.547  val_auc=0.598
Epoch 65/100  val_acc=0.750  val_bal_acc=0.586  val_auc=0.586
Epoch 70/100  val_acc=0.764  val_bal_acc=0.484  val_auc=0.646
Epoch 75/100  val_acc=0.764  val_bal_acc=0.430  val_auc=0.600
Epoch 80/100  val_acc=0.778  val_bal_acc=0.438  val_auc

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
