# DL4H Final Project (Draft)

## Graph Representation Learning for Familial Relationships

[Link to original paper](https://arxiv.org/pdf/2304.05010.pdf)

[Link to Github](https://github.com/dsgelab/family-EHR-graphs)

[Link to replicated Github](https://github.com/cosmocat27/ehr_graph_replica/tree/main)

(Run this notebook in a Python Colab environment so that the setup works as expected.)

---

# Introduction

## Background of the problem

The general problem is that of using family medical history to predict disease. Solving this problem would mean more accurate prediction and understanding of certain heritable diseases, leading to more useful medical interventions. Family history is a well-established indicator of health risks, but its assessment is often
complicated by the interactions of genetic, environmental, and lifestyle factors. The wide
availability of EHR presents opportunities for deep learning to learn complex representations of
patient data that would be useful in clinical prediction.

Various methods have been used to model family history and hereditary disease, including polygenic risk scoring and BLUP, but these approaches are limited by the availability of data such as genetic data. Alternatively, clinical baselines use rule-based and MLP approaches that do not consider graph structure, and they do not perform as well as graph-based approaches.

## Paper explanation

This paper formulates disease risk prediction from family history as a graph modeling problem, and uses graph-based deep learning and LSTMs to learn supervised representations of the family history. The authors show that the approach predicts 10-year disease risk better than the baseline approaches, as measured by AUC-ROC and AUC-PRC. Furthermore, graph explainability techniques can be used to identify specific features of the family history that are useful for disease prediction.

The paper contributes a novel graph-based method to this research area and improves on existing performance, advancing the state of the art in using deep learning to model heritable disease risk.

# Scope of Reproducibility:

## Hypotheses and corresponding experiments


1.   Hypothesis 1: The graph neural network based model provides better performance in predicting 10-year disease onset of certain diseases (adult asthma, colorectal cancer, coronary heart disease, depression and suicide, and type 2 diabetes), compared to the baseline (rule-based or static MLP). To test this, we will run the full GNN model and compare the AUC-ROC/PRC results with the baseline models.
2.   Hypothesis 2: Features for family history provide incremental predictive value when encoded into a graph neural network and used to classify a patient’s development of disease within 10 years. To test this, we will run the ablation studies proposed in the paper, which incrementally add family history and graph connectivity features to a baseline model to understand their incremental value.

# Setup

Sets up the packages, repositories, and paths used in this Colab notebook.

In [None]:
# importing files from public directory instead
#from google.colab import drive
#drive.mount('/content/drive')

!git clone https://github.com/dsgelab/family-EHR-graphs.git
!git clone https://github.com/cosmocat27/ehr_graph_replica.git
!pip install torch_geometric
!mkdir results

import sys
sys.path.append('/content/family-EHR-graphs/src')

### The following section is for data generation only. The result of the following code has been uploaded via Drive.

#import os
#os.environ['LD_LIBRARY_PATH'] = os.environ.get("LD_LIBRARY_PATH", "") + ":/content/gsl-2.7/.libs"

#!wget "https://ftp.gnu.org/gnu/gsl/gsl-2.7.tar.gz" && tar -xvzf gsl-2.7.tar.gz
#!cd gsl-2.7 && ./configure && make && make install
#!gcc "/content/drive/My Drive/project/SimPedPheno_V1.1.c" -o PhenoPedSim -Lgsl_lib_directory -Igsl_include_directory -lm -lgsl -fPIC -lcblas -lblas
#!./PhenoPedSim "/content/drive/My Drive/project/syn_data_params_1.txt" 1000

In [None]:
# import packages you need
import argparse
import torch
import numpy as np
import time
import pandas as pd
import matplotlib.pyplot as plt
from data import DataFetch, Data, GraphData
from torch_geometric.loader import DataLoader
from model import Baseline, BaselineLongitudinal, GNN, GNNLongitudinal, GNNExplainabilityLSTM
from sklearn import metrics
from sklearn.metrics import matthews_corrcoef
from sklearn.metrics import confusion_matrix
from tqdm import tqdm
from utils import EarlyStopping, get_classification_threshold_auc, get_classification_threshold_precision_recall, WeightedBCELoss
import explainability
import json

from main import *

# set the seed
seed = 1000
np.random.seed(seed)
torch.manual_seed(seed)

# Methodology

In [None]:
# setup the params to be used for the first experiment

sqlpath = 'long.db'
params = {'model_type':'graph',
        'gnn_layer':'graphconv',
        'pooling_method':'target',
        'outpath':'results',
        'outname':'G2_TestDisease',
        'obs_window_start':1990,
        'obs_window_end':2010,
        'batchsize':250,
        'num_workers':6,
        'max_epochs':100,
        'patience':8,
        'learning_rate':0.001,
        'main_hidden_dim':20,
        'lstm_hidden_dim':20,
        'loss':'bce_weighted_sum',
        'gamma':1,
        'alpha':1,
        'beta':1,
        'delta':1,
        'dropout_rate':0.5,
        'threshold_opt':'precision_recall',
        'ratio':0.5,
        'local_test':True,
        'explainability_mode':False,
        'embeddings_mode':False,
        'explainer_input':'',
        'device_specification':'na',
        'num_positive_samples':5000}

if params['device_specification'] != 'na':
    params['device'] = torch.device(params['device_specification'])
else:
    params['device'] = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print("Using {} device".format(params['device']))

## Data
### Source of the data

The experiments in the paper use a nationwide health registry dataset, which cannot be publicly
shared for data privacy reasons. Instead, the authors have provided code and instructions for
generating synthetic datasets that mimic the real dataset. We use the synthetic dataset provided with the code at https://github.com/dsgelab/family-EHR-graphs/tree/main/test.

The data consist of 4 different file types:

**Maskfile**: Specifies which samples belong to the target cohort (patients to predict health outcomes for) and which samples belong to the graph cohort (relatives of the target patients). This file also specifies the train, validation and test split for the dataset.

**Statfile**: Contains the (static node) feature dataset for all samples in both the target and graph cohorts. This file also contains the data for the label being predicted for the binary classification task.

**Edgefile**: Contains the edge pairs for the family graphs, where each patient in the target cohort has a separate family graph. This file also contains the data for the edge features.

**Featfile**: Specifies which features to use for training the model, for 4 types of features: static, longitudinal, label and edge.

In [None]:
# dir and function to load raw data
raw_data_dir = 'family-EHR-graphs/test/'
maskfile_path = raw_data_dir + 'Gen3_50k_0.7_142857_maskfile.csv'
statfile_path = raw_data_dir + 'Gen3_50k_0.7_142857_statfile.csv'
edgefile_path = raw_data_dir + 'Gen3_50k_0.7_142857_edgefile.csv'

def load_raw_data(maskfile_path, statfile_path, edgefile_path):
  masks = pd.read_csv(maskfile_path)
  stats = pd.read_csv(statfile_path)
  edges = pd.read_csv(edgefile_path)
  return masks, stats, edges

masks, stats, edges = load_raw_data(maskfile_path, statfile_path, edgefile_path)

filepaths = {'maskfile':maskfile_path,
            'featfile':raw_data_dir + 'featfiles/featfile_G2.csv',
            'alt_featfile':raw_data_dir + 'featfiles/featfile_A5.csv',
            'statfile':statfile_path,
            'edgefile':edgefile_path}

### Statistics

**Size:**

* Maskfile: 150k rows, 5 columns
* Statfile: 150k rows, 33 columns
* Edgefile: 1.1M rows, 14 columns
* Featfile: Depends on the model, but only a few rows specifying the features.

**Label distribution:**

150k total patients

Target population (patients): 4357 positive, 34940 negative

Non-target population (relatives): 57763 negative, 52940 negative

**Train / validation / test split:**

Out of 39k target patients

28k train, 4k validation, 8k test
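
The label counts above imply a strong class imbalance in the target cohort. As a quick worked check (plain Python, using only the target-cohort counts reported above; the variable names are ours), the prevalence of the positive class can be computed as follows:

```python
# Target-cohort label counts as reported in the statistics above.
target_pos = 4357
target_neg = 34940

n_target = target_pos + target_neg
prevalence = target_pos / n_target
neg_per_pos = target_neg / target_pos

print("Target cohort size: {}".format(n_target))
print("Positive prevalence: {:.3f}".format(prevalence))
print("Negatives per positive: {:.1f}".format(neg_per_pos))
```

A classifier that flagged patients at random would have a precision close to this prevalence, which is a useful reference point when reading the precision values in the Results section.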

In [None]:
# calculate statistics
def calculate_stats(masks, stats, edges):
  n_rows, n_columns = masks.shape
  print("Maskfile size: {} rows, {} columns".format(n_rows, n_columns))
  n_rows, n_columns = stats.shape
  print("Statfile size: {} rows, {} columns".format(n_rows, n_columns))
  n_rows, n_columns = edges.shape
  print("Edgefile size: {} rows, {} columns".format(n_rows, n_columns))

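  # value_counts().sort_index() orders counts by label value (0 = negative first, then 1 = positive);
  # this assumes EndPtStat is coded 0/1 and that masks and stats share the same row order
  target_neg, target_pos = stats[masks['target']==1].EndPtStat.value_counts().sort_index().values
  print("Target population: {} positive, {} negative".format(target_pos, target_neg))
  nontarget_neg, nontarget_pos = stats[masks['target']==0].EndPtStat.value_counts().sort_index().values
  print("Non-target population: {} positive, {} negative".format(nontarget_pos, nontarget_neg))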
  nontarget, train, valid, test = masks.train.value_counts().sort_index().values
  print("Cross validation split: {} train, {} validation, {} test".format(train, valid, test))
  return None

calculate_stats(masks, stats, edges)

### Data processing

Most of the data is processed already via data generation. We use the provided function DataFetch to split the data into train / validation / test sets, then use get_data_and_loader to prepare the datasets for modeling.

In [None]:
fetch_data = DataFetch(filepaths['maskfile'], filepaths['featfile'], filepaths['statfile'], filepaths['edgefile'], sqlpath, params, alt_featfile=filepaths['alt_featfile'], local=params['local_test'])

train_patient_list = fetch_data.train_patient_list
params['num_batches_train'] = int(np.ceil(len(train_patient_list)/params['batchsize']))
params['num_samples_train_dataset'] = len(fetch_data.train_patient_list)
params['num_samples_train_minority_class'] = fetch_data.num_samples_train_minority_class
params['num_samples_train_majority_class'] = fetch_data.num_samples_train_majority_class
validate_patient_list = fetch_data.validate_patient_list
params['num_batches_validate'] = int(np.ceil(len(validate_patient_list)/params['batchsize']))
params['num_samples_valid_dataset'] = len(fetch_data.validate_patient_list)
params['num_samples_valid_minority_class'] = fetch_data.num_samples_valid_minority_class
params['num_samples_valid_majority_class'] = fetch_data.num_samples_valid_majority_class
test_patient_list = fetch_data.test_patient_list
params['num_batches_test'] = int(np.ceil(len(test_patient_list)/params['batchsize']))

train_dataset, train_loader = get_data_and_loader(train_patient_list, fetch_data, params, shuffle=True)
validate_dataset, validate_loader = get_data_and_loader(validate_patient_list, fetch_data, params, shuffle=True)
test_dataset, test_loader = get_data_and_loader(test_patient_list, fetch_data, params, shuffle=False)
params['include_longitudinal'] = train_dataset.include_longitudinal
params['num_features_static'] = len(fetch_data.static_features)
if params['model_type'] in ['graph', 'graph_no_target', 'explainability']: params['num_features_alt_static'] = len(fetch_data.alt_static_features)
params['num_features_longitudinal'] = len(fetch_data.longitudinal_features)

## Model
### Model architecture
The GNN Longitudinal model is a graph embedding model with two separate paths, one for family-based data and one for patient-based data, whose outputs are combined for the final classification.

The patient part consists of LSTM, Linear, and Dropout layers with ReLU activation.

The family part consists of LSTM and graph convolutional layers (the 'graphconv' layer type in the configuration above), with ReLU activation.

The combined part concatenates the patient and family outputs and passes them through a linear layer and sigmoid activation function to get the final output.
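
To make the family path more concrete, the following is a minimal, framework-free sketch of one graph convolution step followed by a target-node readout. It assumes the update rule of a GraphConv-style layer (a node's new state is a linear map of its own features plus a linear map of the sum of its neighbours' features) and interprets the 'target' pooling method as keeping only the target patient's embedding. The function names, toy graph, and identity weights are illustrative assumptions, not the repository's implementation, which relies on torch_geometric and learned weights.

```python
def matvec(weights, vector):
    """Multiply a matrix (list of rows) by a vector (list of floats)."""
    return [sum(w * v for w, v in zip(row, vector)) for row in weights]


def graphconv_step(features, edges, w_self, w_neigh):
    """One GraphConv-style update with ReLU (illustrative sketch).

    features: dict mapping node id -> feature list
    edges:    list of (source, destination) pairs; messages flow source -> destination
    """
    dim = len(next(iter(features.values())))
    # Sum the features of each node's incoming neighbours.
    aggregated = {node: [0.0] * dim for node in features}
    for src, dst in edges:
        aggregated[dst] = [a + b for a, b in zip(aggregated[dst], features[src])]

    updated = {}
    for node, x in features.items():
        own = matvec(w_self, x)
        neigh = matvec(w_neigh, aggregated[node])
        combined = [s + n for s, n in zip(own, neigh)]
        updated[node] = [max(0.0, value) for value in combined]  # ReLU
    return updated


# Toy family graph: node 0 is the target patient, nodes 1 and 2 are relatives.
features = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 1.0]}
edges = [(1, 0), (2, 0)]
identity = [[1.0, 0.0], [0.0, 1.0]]  # placeholder weights; a real layer learns these

hidden = graphconv_step(features, edges, identity, identity)
target_embedding = hidden[0]  # assumed meaning of "target" pooling: read out the target node only
print(target_embedding)
```

In the full model, this target embedding would then be concatenated with the output of the patient path before the final linear layer and sigmoid.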

### Training objectives
Loss function: WeightedBCELoss (BCELoss with weights adjusted for class imbalance)

Optimizer: Adam with learning rate = 0.001
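
As an illustration of how class weighting changes the loss, here is a framework-free sketch of a weighted binary cross-entropy. It assumes the common "balanced" weighting, where each class weight is n_total / (2 * n_class), and that the positive class is the minority class. The exact weighting used by WeightedBCELoss in the repository may differ, so treat this as a sketch under those assumptions rather than a copy of that class.

```python
import math


def weighted_bce(y_true, y_prob, n_total, n_minority, n_majority, eps=1e-7):
    """Mean class-weighted binary cross-entropy (illustrative sketch).

    Assumes label 1 is the minority class and uses balanced weights
    n_total / (2 * n_class); both are assumptions made for this example.
    """
    w_pos = n_total / (2 * n_minority)
    w_neg = n_total / (2 * n_majority)
    total = 0.0
    for y, p in zip(y_true, y_prob):
        p = min(max(p, eps), 1 - eps)  # clamp to avoid log(0)
        weight = w_pos if y == 1 else w_neg
        total += -weight * (y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(y_true)


# With balanced classes both weights are 1 and the loss reduces to plain BCE.
balanced_loss = weighted_bce([1, 0], [0.5, 0.5], n_total=2, n_minority=1, n_majority=1)

# With a 1:9 imbalance, an equally confident mistake costs more on the minority class.
missed_positive = weighted_bce([1], [0.1], n_total=10, n_minority=1, n_majority=9)
missed_negative = weighted_bce([0], [0.9], n_total=10, n_minority=1, n_majority=9)
print(balanced_loss, missed_positive, missed_negative)
```

The design intent of such a weighting is that the rare positive cases contribute as much to the total loss as the much more numerous negative cases.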

### Others
The model was trained on the synthetic dataset. For expediency, we load the trained version of the model, but we also run one epoch of training for demonstration.

### Model Training

In [None]:
model = get_model(params)
model_path = '{}/{}_model.pth'.format(params['outpath'], params['outname'])

optimizer = torch.optim.Adam(model.parameters(), lr=params['learning_rate'])
if params['loss']=='bce_weighted_single' or params['loss']=='bce_weighted_sum':
  train_criterion = WeightedBCELoss(params['num_samples_train_dataset'], params['num_samples_train_minority_class'], params['num_samples_train_majority_class'], params['device'])
  valid_criterion = WeightedBCELoss(params['num_samples_valid_dataset'], params['num_samples_valid_minority_class'], params['num_samples_valid_majority_class'], params['device'])

# normal training model
del fetch_data # free up memory no longer needed
del train_dataset
del validate_dataset
del test_dataset

# model training (we run 1 epoch just for demonstration)
params['max_epochs'] = 1
start_time_train = time.time()
# train the model for at most max_epochs
model, threshold = train_model(model, train_loader, validate_loader, params)
end_time_train = time.time()
# we will load the pretrained model
#torch.save(model.state_dict(), model_path)
threshold = 0.639   # threshold determined from training
params['threshold'] = threshold
params['training_time'] = end_time_train - start_time_train

### Model Testing

In [None]:
model_path = 'ehr_graph_replica/results_G2/G2_TestDisease_model.pth'
results_path = '{}/{}_results.csv'.format(params['outpath'], params['outname'])
stats_path = '{}/{}_stats.csv'.format(params['outpath'], params['outname'])
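# map_location lets a checkpoint saved on GPU load on a CPU-only runtime
model.load_state_dict(torch.load(model_path, map_location=params['device']))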

# model testing
results, metric_results = test_model(model, test_loader, threshold, params)
results.to_csv(results_path, index=None)
params.update(metric_results)
stats = pd.DataFrame({'name':list(params.keys()), 'value':list(params.values())})
stats.to_csv(stats_path, index=None)

# Results

This section summarizes the results of the models we tested. The main model from the paper is the GNN model with longitudinal data. We also report results for the following comparison models, which were trained and run separately and whose results are loaded directly:

* Baseline model with age and sex data
* Age, sex, and family history MLP
* Age, sex, and graph connectivity MLP

## GNN model with longitudinal data

Test set size: 7860

Accuracy: 0.77

Recall: 0.64

Precision: 0.28

F1 Score: 0.388

ROC_AUC: 0.803
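
As a quick consistency check on the metrics above, F1 is the harmonic mean of precision and recall. The sketch below recomputes it from the rounded values reported here, so a small difference from the reported F1 is expected from rounding. The helper name is ours.

```python
def f1_from_precision_recall(precision, recall):
    """Harmonic mean of precision and recall."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# Rounded values reported above for the GNN model with longitudinal data.
f1 = f1_from_precision_recall(precision=0.28, recall=0.64)
print("F1 from rounded precision and recall: {:.3f}".format(f1))
```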

**Figure 1: Loss curve for the GNN model**

![sample_image.png](https://drive.google.com/uc?export=view&id=1D3T-qjw5mYLfL8h5X9NwEVT2w3IDtbSI)

In [None]:
# load the results directly if we did not run the model in the previous step
#results = pd.read_csv('ehr_graph_replica/results_G2/G2_TestDisease_results.csv')
#stats = pd.read_csv('ehr_graph_replica/results_G2/G2_TestDisease_stats.csv')

y_true = results['actual']
y_score = results['pred_raw']
y_pred = results['pred_binary']

# metrics to evaluate my model
print("Accuracy: ", round(metrics.accuracy_score(y_true, y_pred), 3))
print("Recall: ", round(metrics.recall_score(y_true, y_pred), 3))
print("Precision: ", round(metrics.precision_score(y_true, y_pred), 3))
print("F1: ", round(metrics.f1_score(y_true, y_pred), 3))
print("ROC AUC: ", round(metrics.roc_auc_score(y_true, y_score), 3))

## Baseline model with age and sex data

Test set size: 7860

Accuracy: 0.65

Recall: 0.51

Precision: 0.16

F1 Score: 0.243

ROC_AUC: 0.626

In [None]:
# load the results directly (we trained and ran the model separately)
results = pd.read_csv('ehr_graph_replica/results_A1/A1_TestDisease_results.csv')
stats = pd.read_csv('ehr_graph_replica/results_A1/A1_TestDisease_stats.csv')

y_true = results['actual']
y_score = results['pred_raw']
y_pred = results['pred_binary']

# metrics to evaluate my model
print("Accuracy: ", round(metrics.accuracy_score(y_true, y_pred), 3))
print("Recall: ", round(metrics.recall_score(y_true, y_pred), 3))
print("Precision: ", round(metrics.precision_score(y_true, y_pred), 3))
print("F1: ", round(metrics.f1_score(y_true, y_pred), 3))
print("ROC AUC: ", round(metrics.roc_auc_score(y_true, y_score), 3))

## Age, sex and family history MLP

Test set size: 7860

Accuracy: 0.73

Recall: 0.63

Precision: 0.24

F1 Score: 0.344

ROC_AUC: 0.768

In [None]:
# load the results directly (we trained and ran the model separately)
results = pd.read_csv('ehr_graph_replica/results_A2/A2_TestDisease_results.csv')
stats = pd.read_csv('ehr_graph_replica/results_A2/A2_TestDisease_stats.csv')

y_true = results['actual']
y_score = results['pred_raw']
y_pred = results['pred_binary']

# metrics to evaluate my model
print("Accuracy: ", round(metrics.accuracy_score(y_true, y_pred), 3))
print("Recall: ", round(metrics.recall_score(y_true, y_pred), 3))
print("Precision: ", round(metrics.precision_score(y_true, y_pred), 3))
print("F1: ", round(metrics.f1_score(y_true, y_pred), 3))
print("ROC AUC: ", round(metrics.roc_auc_score(y_true, y_score), 3))

## Age, sex and graph connectivity MLP

Test set size: 7860

Accuracy: 0.67

Recall: 0.49

Precision: 0.17

F1 Score: 0.249

ROC_AUC: 0.638

In [None]:
# load the results directly (we trained and ran the model separately)
results = pd.read_csv('ehr_graph_replica/results_A3/A3_TestDisease_results.csv')
stats = pd.read_csv('ehr_graph_replica/results_A3/A3_TestDisease_stats.csv')

y_true = results['actual']
y_score = results['pred_raw']
y_pred = results['pred_binary']

# metrics to evaluate my model
print("Accuracy: ", round(metrics.accuracy_score(y_true, y_pred), 3))
print("Recall: ", round(metrics.recall_score(y_true, y_pred), 3))
print("Precision: ", round(metrics.precision_score(y_true, y_pred), 3))
print("F1: ", round(metrics.f1_score(y_true, y_pred), 3))
print("ROC AUC: ", round(metrics.roc_auc_score(y_true, y_score), 3))

## Analysis and Plans

The results suggest that the graph-based model with longitudinal data outperforms the non-graph-based models when making predictions on the synthetic dataset. They also suggest that family history plays an important role in predicting the disease, since the models that included it performed substantially better than those that did not.

In further work, we could try different combinations of features or look into how the models perform on datasets with different characteristics like heritability to see if the pattern continues to hold.

## Model comparison

### Results of models in the paper for predicting coronary heart disease (AUC-ROC):

Baseline model with age and sex data: 0.696

Age, sex and family history MLP: 0.710

Age, sex and graph connectivity MLP: 0.696

GNN model with longitudinal data: 0.775

### Results of models in replicated experiment for predicting on synthetic dataset (AUC-ROC):

Baseline model with age and sex data: 0.626

Age, sex and family history MLP: 0.768

Age, sex and graph connectivity MLP: 0.638

GNN model with longitudinal data: 0.803
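
To make the comparison easier to read, the AUC-ROC values listed above can be tabulated together with each model's gain over the age-and-sex baseline. This is a plain-Python sketch that only rearranges the numbers reported above; the dictionary layout and names are ours.

```python
# AUC-ROC values copied from the two lists above
# (paper: coronary heart disease; replica: synthetic dataset).
auc = {
    "Baseline (age, sex)": {"paper": 0.696, "replica": 0.626},
    "Age, sex, family history MLP": {"paper": 0.710, "replica": 0.768},
    "Age, sex, graph connectivity MLP": {"paper": 0.696, "replica": 0.638},
    "GNN with longitudinal data": {"paper": 0.775, "replica": 0.803},
}

baseline = auc["Baseline (age, sex)"]
gains = {}
print("{:<34} {:>7} {:>7} {:>9} {:>9}".format("Model", "Paper", "Replica", "dPaper", "dReplica"))
for name, scores in auc.items():
    # Gain in AUC-ROC relative to the age-and-sex baseline from the same source.
    gains[name] = {
        source: round(scores[source] - baseline[source], 3) for source in ("paper", "replica")
    }
    print("{:<34} {:>7.3f} {:>7.3f} {:>+9.3f} {:>+9.3f}".format(
        name, scores["paper"], scores["replica"],
        gains[name]["paper"], gains[name]["replica"]))
```

In both settings the GNN ranks first and the family history MLP second, and the gains over the baseline are larger on the synthetic data than in the paper.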

# Discussion

Based on the experiment, we were able to reproduce two of the major findings in the paper: 1) graph-based representation learning on family history provides better predictive value than non-graph-based models, and 2) family history is an important factor in how well the model performs, since all models that included family history performed substantially better than those that did not.

We should note the caveat that these experiments used synthetic data instead of the real dataset, which may yield different results. This is probably the main reason for the gap between the original results and the replicated results.

The easy part was that most of the code for producing the data and training the models was available on GitHub, making it relatively easy for someone to run similar code in their own environment. The difficult part is that the results on the new dataset cannot easily be compared with those on the original dataset, because the two datasets come from different sources.

The authors might consider going into detail about methods that use less sensitive data that can be accessed more widely, even though this is not optimal, so that follow-up experiments can be compared using the same datasets.

# References

1.   Sophie Wharrie, Zhiyu Yang, Andrea Ganna, Samuel Kaski. (2023). Characterizing personalized
effects of family information on disease risk using graph representation learning. Proceedings of
the 8th Machine Learning for Healthcare Conference, in Proceedings of Machine Learning
Research. 219:824-845. Available from https://arxiv.org/abs/2304.05010.

