# Deep learning model validation 

This notebook validates a trained deep learning model. The first cells load the data and the model into memory; the cells that follow define the validation functions and run them.

In [129]:
import sys, os, fnmatch, csv
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Put the parent of the current working directory first on the import path so that
# the `config` module imported below can be found there (presumably the project root).
sys.path.insert(0, os.path.dirname(os.getcwd()))

from config import PATH_RAW_DATA, PATH_DATA_PROCESSED_DL, PATH_MODELS

# 1. Preparing data, model, and helper functions

## Configuration variables

In [130]:
# N_AVERAGE is handed to every DataGenerator below as `n_average`;
# MODEL_NAME is the file name of the trained model, looked up inside PATH_MODELS.
N_AVERAGE, MODEL_NAME = 30, 'Fully_connected_regressor_02.hdf5'

## Load data

In [131]:
from sklearn.model_selection import train_test_split

# Step 1: Get all the files in the output folder
file_names = os.listdir(PATH_DATA_PROCESSED_DL)

# Step 2: Full paths of the processed signal files, without the ".zarr" extension
files = [os.path.splitext(os.path.join(PATH_DATA_PROCESSED_DL, file_name))[0]
         for file_name in fnmatch.filter(file_names, "*.zarr")]

# Step 3: Load the metadata CSV that belongs to each signal file
# ("processed_raw_<name>" -> "processed_metadata_<name>.csv")
frames = [pd.read_csv(feature_file.replace("processed_raw_", "processed_metadata_") + ".csv")
          for feature_file in files]

# NOTE(review): every CSV keeps its own index, so the concatenated frame has duplicate
# index labels. Pass ignore_index=True if a unique index is needed (check DataGenerator first).
df_metadata = pd.concat(frames)

# Step 4: Add missing age information based on the age group (in months) the subject is in.
# Assign the filled column back instead of calling fillna(inplace=True) on df[col]: that
# chained in-place call does not update df_metadata under pandas Copy-on-Write.
df_metadata['age_months'] = df_metadata['age_months'].fillna(df_metadata['age_group'])
df_metadata['age_days'] = df_metadata['age_days'].fillna(df_metadata['age_group'] * 30)
df_metadata['age_years'] = df_metadata['age_years'].fillna(df_metadata['age_group'] / 12)

# Step 5: List all the unique subject IDs
# NOTE(review): set() iteration order depends on insertion order, which follows os.listdir()
# and is filesystem dependent, so the seeded train/test split may differ between machines.
# Left unchanged so the split keeps matching the one used to train the model.
subject_ids = list(set(df_metadata["code"].tolist()))

In [132]:
from sklearn.model_selection import train_test_split

# Split at the subject level: about 70 % of the subjects go to training, and the
# remaining 30 % are halved into a test set and a validation set. random_state=42
# makes the split repeatable for a given order of subject_ids.
IDs_train, IDs_temp = train_test_split(subject_ids, random_state=42, test_size=0.3)
IDs_test, IDs_val = train_test_split(IDs_temp, random_state=42, test_size=0.5)

In [133]:
from dataset_generator import DataGenerator

# Keyword arguments shared by all four generators: the same data location, metadata,
# averaging, batch size, input shape (501 timepoints x 30 channels) and shuffling.
_generator_defaults = dict(
    BASE_PATH=PATH_DATA_PROCESSED_DL,
    metadata=df_metadata,
    n_average=N_AVERAGE,
    batch_size=10,
    n_timepoints=501,
    n_channels=30,
    shuffle=True,
)

# Training subjects: a plain generator and one created with gaussian_noise=0.01.
train_generator = DataGenerator(list_IDs=IDs_train, iter_per_epoch=30, **_generator_defaults)
train_generator_noise = DataGenerator(list_IDs=IDs_train, iter_per_epoch=30,
                                      gaussian_noise=0.01, **_generator_defaults)

# Held-out subjects, with more iterations per epoch.
val_generator = DataGenerator(list_IDs=IDs_val, iter_per_epoch=100, **_generator_defaults)
test_generator = DataGenerator(list_IDs=IDs_test, iter_per_epoch=100, **_generator_defaults)

In [134]:
df_metadata.head(n=5)  # first five metadata rows as a quick sanity check

Unnamed: 0,code,cnt_path,cnt_file,age_group,age_days,age_months,age_years
0,23,/Volumes/Seagate Expansion Drive/ePodium/Data/...,023_35_mc_mmn36,35,1052.0,35.066667,2.922222
0,337,/Volumes/Seagate Expansion Drive/ePodium/Data/...,337_23_jc_mmn_36_wk,23,692.0,23.066667,1.922222
0,456,/Volumes/Seagate Expansion Drive/ePodium/Data/...,456_23_md_mmn36_wk,23,691.0,23.033333,1.919444
0,328,/Volumes/Seagate Expansion Drive/ePodium/Data/...,328_23_jc_mmn36_wk,23,699.0,23.3,1.941667
0,314,/Volumes/Seagate Expansion Drive/ePodium/Data/...,314_29_mmn_36_wk,29,877.0,29.233333,2.436111


## Load model

In [135]:
import tensorflow as tf

# Resolve the saved model inside the models directory, then deserialize it with Keras.
model_path = os.path.join(PATH_MODELS, MODEL_NAME)
loaded_model = tf.keras.models.load_model(filepath=model_path)

## Helper functions for validation

In [136]:
def evaluate_model(model, generators=None):
    """Evaluate ``model`` on the train, validation and test data and return the results.

    Args:
        model: a compiled Keras model, or anything with an ``evaluate(generator)`` method.
        generators: optional mapping of split name -> data generator. When omitted, the
            notebook-level ``train_generator``, ``val_generator`` and ``test_generator``
            are evaluated, in that order.

    Returns:
        dict mapping each split name to the value returned by ``model.evaluate`` for it
        (for Keras: the scalar loss, or a list of loss and metric values).
    """
    if generators is None:
        # Looked up at call time, so the generators may be (re)created after this cell runs.
        generators = {
            'train': train_generator,
            'validation': val_generator,
            'test': test_generator,
        }
    return {name: model.evaluate(generator) for name, generator in generators.items()}
    
def print_few_predictions(model):
    """Sanity check: print "true value -> prediction" for every sample in the first test batch."""
    x_batch, y_batch = test_generator[0]
    predicted = model.predict(x_batch)
    for target, guess in zip(y_batch, predicted):
        print(f"{target} -> {guess}")

Definition of error stability (Vandenbosch et al., 2018): 

_"Stability was assessed as the correlation between the prediction errors (estimated minus actual age) of subjects at baseline with their own prediction error at follow-up."_


As I understand it, this means: take a subject's prediction error at time 1 and pair it with their prediction error at time 2, pair time 2 with time 3, and pair time 3 with time 4. This yields two lists of errors, TIME_N and TIME_N+1, and the stability is the correlation between those two lists.

A stable error would therefore show up as a positive correlation: if a subject's error is low at time N, you expect it to be low at time N+1 as well.

In [138]:
import zarr
from scipy.stats import pearsonr

def error_stability(model, IDs_test):
    """Compute the error stability of the age predictions (Vandenbosch et al., 2018).

    The function only processes subjects in ``IDs_test`` that were recorded in at least two
    age groups. For each such subject and age group:

    - the epochs of all files in that age group are concatenated and averaged into a single
      model input;
    - the prediction error (predicted minus actual ``age_months``) is computed.

    The errors of consecutive age groups of the same subject form (time N, time N+1) pairs.
    The stability is the Pearson correlation between the time-N and the time-N+1 errors.

    The function reads the notebook-level ``df_metadata`` and ``PATH_DATA_PROCESSED_DL``.

    Args:
        model: object with a Keras-style ``predict`` method. It receives one array of shape
            ``(1, 501, 30)`` per subject and age group; its first output value is taken as
            the predicted age.
        IDs_test: iterable of subject codes, matched against ``df_metadata['code']``.

    Returns:
        The Pearson correlation coefficient as a float (it is also printed), or ``nan``
        when fewer than two (time N, time N+1) pairs could be formed.
    """
    errors_time_N = []
    errors_time_N1 = []

    # Step 1: Iterate over subjects
    for ID in IDs_test:
        print(f"Predicting for ID: {ID}..")

        # Step 2: Find all files of a subject
        df_subject = df_metadata[df_metadata['code'] == ID]

        # Step 3: Find all the age groups the subject was found in
        ages_subject = sorted(set(df_subject['age_group'].tolist()))
        if len(ages_subject) < 2:
            continue  # no follow-up measurement to compare with

        # Step 4: Loop over all the ages of a subject, in increasing order
        prev_prediction_error = None
        for age_group in ages_subject:
            df_age = df_subject[df_subject['age_group'] == age_group]

            # Step 5: Concatenate the epochs of all files in this age group. The empty
            # float64 seed fixes the expected (epochs, 30, 501) shape and the result dtype.
            epochs = [np.zeros((0, 30, 501))]
            for cnt_file in df_age['cnt_file']:
                filename = os.path.join(PATH_DATA_PROCESSED_DL, 'processed_raw_' + cnt_file + '.zarr')
                epochs.append(zarr.open(filename, mode='r'))
            X_data = np.concatenate(epochs, axis=0)

            # Average all epochs of the subject at this age into one (1, 501, 30) input
            X = np.swapaxes(np.expand_dims(X_data.mean(axis=0), axis=0), 1, 2)

            # NOTE(review): uses the age of the first file in this age group
            actual_age = df_age['age_months'].values[0]
            curr_prediction_error = model.predict(X).flatten()[0] - actual_age

            # Step 6: Pair with the previous age group. Compare with None explicitly:
            # an error of exactly 0.0 is a valid value and must not be skipped.
            if prev_prediction_error is not None:
                errors_time_N.append(prev_prediction_error)
                errors_time_N1.append(curr_prediction_error)
            prev_prediction_error = curr_prediction_error

    # Step 7: Look at correlation between prediction error and follow-up
    if len(errors_time_N) < 2:
        print(f"Not enough follow-up pairs for a correlation ({len(errors_time_N)} found).")
        return float('nan')
    corr, _ = pearsonr(errors_time_N, errors_time_N1)
    print(f"Pearsons correlation: {corr:.3f}")
    return float(corr)
    
error_stability(model=loaded_model, IDs_test=IDs_test)

Predicting for ID: 712..
Predicting for ID: 420..
Predicting for ID: 758..
Predicting for ID: 28..
Predicting for ID: 732..
Predicting for ID: 613..
Predicting for ID: 164..
Predicting for ID: 709..
Predicting for ID: 121..
Predicting for ID: 711..
Predicting for ID: 329..
Predicting for ID: 169..
Predicting for ID: 474..
Predicting for ID: 154..
Predicting for ID: 428..
Predicting for ID: 159..
Predicting for ID: 472..
Predicting for ID: 632..
Predicting for ID: 451..
Predicting for ID: 426..
Predicting for ID: 158..
Predicting for ID: 122..
Predicting for ID: 496..
Predicting for ID: 485..
Predicting for ID: 425..
Predicting for ID: 149..
Predicting for ID: 317..
Predicting for ID: 105..
Predicting for ID: 301..
Predicting for ID: 304..
Predicting for ID: 310..
Predicting for ID: 135..
Predicting for ID: 641..
Predicting for ID: 719..
Predicting for ID: 108..
Predicting for ID: 466..
Predicting for ID: 156..
Predicting for ID: 29..
Predicting for ID: 733..
Predicting for ID: 455..
Pr

In [80]:
import zarr
from scipy.stats import pearsonr

# NOTE(review): this cell (In [80]) re-defines `error_stability` from the cell above and,
# when the notebook is run top to bottom, replaces it; consider keeping only one copy.
def error_stability(model, IDs_test):
    """Compute the error stability of the age predictions (Vandenbosch et al., 2018).

    The function only processes subjects in ``IDs_test`` that were recorded in at least two
    age groups. For each such subject and age group:

    - the epochs of all files in that age group are concatenated and averaged into a single
      model input;
    - the prediction error (predicted minus actual ``age_months``) is computed.

    The errors of consecutive age groups of the same subject form (time N, time N+1) pairs.
    The stability is the Pearson correlation between the time-N and the time-N+1 errors.

    The function reads the notebook-level ``df_metadata`` and ``PATH_DATA_PROCESSED_DL``.

    Args:
        model: object with a Keras-style ``predict`` method. It receives one array of shape
            ``(1, 501, 30)`` per subject and age group; its first output value is taken as
            the predicted age.
        IDs_test: iterable of subject codes, matched against ``df_metadata['code']``.

    Returns:
        The Pearson correlation coefficient as a float (it is also printed), or ``nan``
        when fewer than two (time N, time N+1) pairs could be formed.
    """
    errors_time_N = []
    errors_time_N1 = []

    # Step 1: Iterate over subjects
    for ID in IDs_test:
        print(f"Predicting for ID: {ID}..")

        # Step 2: Find all files of a subject
        df_subject = df_metadata[df_metadata['code'] == ID]

        # Step 3: Find all the age groups the subject was found in
        ages_subject = sorted(set(df_subject['age_group'].tolist()))
        if len(ages_subject) < 2:
            continue  # no follow-up measurement to compare with

        # Step 4: Loop over all the ages of a subject, in increasing order
        prev_prediction_error = None
        for age_group in ages_subject:
            df_age = df_subject[df_subject['age_group'] == age_group]

            # Step 5: Concatenate the epochs of all files in this age group. The empty
            # float64 seed fixes the expected (epochs, 30, 501) shape and the result dtype.
            epochs = [np.zeros((0, 30, 501))]
            for cnt_file in df_age['cnt_file']:
                filename = os.path.join(PATH_DATA_PROCESSED_DL, 'processed_raw_' + cnt_file + '.zarr')
                epochs.append(zarr.open(filename, mode='r'))
            X_data = np.concatenate(epochs, axis=0)

            # Average all epochs of the subject at this age into one (1, 501, 30) input
            X = np.swapaxes(np.expand_dims(X_data.mean(axis=0), axis=0), 1, 2)

            # NOTE(review): uses the age of the first file in this age group
            actual_age = df_age['age_months'].values[0]
            curr_prediction_error = model.predict(X).flatten()[0] - actual_age

            # Step 6: Pair with the previous age group. Compare with None explicitly:
            # an error of exactly 0.0 is a valid value and must not be skipped.
            if prev_prediction_error is not None:
                errors_time_N.append(prev_prediction_error)
                errors_time_N1.append(curr_prediction_error)
            prev_prediction_error = curr_prediction_error

    # Step 7: Look at correlation between prediction error and follow-up
    if len(errors_time_N) < 2:
        print(f"Not enough follow-up pairs for a correlation ({len(errors_time_N)} found).")
        return float('nan')
    corr, _ = pearsonr(errors_time_N, errors_time_N1)
    print(f"Pearsons correlation: {corr:.3f}")
    return float(corr)
    
error_stability(model=loaded_model, IDs_test=IDs_test)

Predicting for ID: 712..
Predicting for ID: 420..
Predicting for ID: 758..
Predicting for ID: 28..
Predicting for ID: 732..
Predicting for ID: 613..
Predicting for ID: 164..
Predicting for ID: 709..
Predicting for ID: 121..
Predicting for ID: 711..
Predicting for ID: 329..
Predicting for ID: 169..
Predicting for ID: 474..
Predicting for ID: 154..
Predicting for ID: 428..
Predicting for ID: 159..
Predicting for ID: 472..
Predicting for ID: 632..
Predicting for ID: 451..
Predicting for ID: 426..
Predicting for ID: 158..
Predicting for ID: 122..
Predicting for ID: 496..
Predicting for ID: 485..
Predicting for ID: 425..
Predicting for ID: 149..
Predicting for ID: 317..
Predicting for ID: 105..
Predicting for ID: 301..
Predicting for ID: 304..
Predicting for ID: 310..
Predicting for ID: 135..
Predicting for ID: 641..
Predicting for ID: 719..
Predicting for ID: 108..
Predicting for ID: 466..
Predicting for ID: 156..
Predicting for ID: 29..
Predicting for ID: 733..
Predicting for ID: 455..
Pr

# 2. Model validation

In [139]:
evaluate_model(model=loaded_model)



In [140]:
print_few_predictions(model=loaded_model)

[22.9] -> [27.630173]
[17.03333333] -> [23.78022]
[10.96666667] -> [18.644392]
[11.33333333] -> [10.05482]
[47.4] -> [30.856066]
[40.9] -> [29.112087]
[35.5] -> [15.527671]
[23.06666667] -> [15.542363]
[34.96666667] -> [28.063675]
[35.03333333] -> [24.761223]


In [None]:
error_stability(model=loaded_model, IDs_test=IDs_test)