In [1]:
# module load Bear-Python-DataScience/2019b-fosscuda-2019b-Python3.7.4
# module load sklearn-crfsuite/0.3.6-foss-2019b-Python3.7.47
# module load openpyxl...


# Load paths
data_tag = "OPTIMAL_conditions_for_clustering"
path_to_features = f"/rds/projects/g/gokhalkm-optimal/DataforCharles/{data_tag}.csv"  
# Save paths
path_to_model = f"/rds/projects/g/gokhalkm-optimal/DataforCharles/mmVAE_output/" 
path_to_figs = path_to_model + 'plots/'
path_to_outputs = path_to_model + 'output/'

import os
os.chdir('/rds/homes/g/gaddcz/Projects/mum-predict-repos/mmVAE/src/')

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from experiment import fit_restarts
from model.plotting import *
import helpers

torch.manual_seed(0)
np.random.seed(0)
%matplotlib inline

# import sys
# !{sys.executable} -m pip install --quiet --user scikit-learn==1.0.1


# Load CPRD data
Here we load the CPRD data from outside the repository. 
We then re-order the the columns (which won't affect modelling) to  aid visualisation later. 
Finally we remove non-multimorbidity cases (with fewer than 2 conditions)

In [2]:
# Load all data
df_id = pd.read_csv(path_to_features)
print(f"Loaded {len(df_id.index)} patients for study")

Loaded 8220386 patients for study


In [3]:
display(df_id.iloc[-1].values)     # Last patient has nan values - exclude from study

# Drop NaN rows
df_id = df_id.dropna()
print(f"Retained {len(df_id.index)} patients for study")

array([1.11756422e+12, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       1.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 1.00000000e+00,
       0.00000000e+00, 1.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 1.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
      

Retained 8220385 patients for study


In [4]:
# Remove patients with only one condition
df_id = df_id[df_id.sum(axis = 1) >= 2]
print(f"Retained {len(df_id.index)} patients for study")    # All retained

Retained 8220385 patients for study


In [5]:
# Remove patient identifiers
df = df_id.drop(df_id.columns[0], axis=1)

In [6]:
# Reset dataframe index
diag_frame = df.reset_index(level=None, drop=True, col_level=0, col_fill='')
display(diag_frame)

Unnamed: 0,Heart_failure,Atrial_fibrillation,Stroke,Hypertension,Ischaemic_heart_disease,Peripheral_vascular_disease,Heart_valve_disorders,Aortic_aneurysm,Type_1_diabetes,Type_2_diabetes,...,Menieres_disease,Peripheral_neuropathy,Intellectual_disabilites,Down_syndrome,Pernicious_anaemia,Sickle_cell_anaemia,Psoriasis,Psoriatic_arthritis,Interstitial_lung_disease,Haemochromatosis
0,0,0,0,0,0,0,0,0,0,0,...,0,0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,0,0,0,0,0,0,0,0,0,0,...,0,0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,0,0,0,0,0,0,0,0,0,0,...,0,0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,0,0,0,0,0,0,0,0,0,0,...,0,0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,0,0,0,0,0,0,0,0,0,0,...,0,0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
8220380,0,0,0,0,0,0,0,0,0,0,...,0,0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
8220381,0,0,0,0,0,0,0,0,0,0,...,0,0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
8220382,0,0,0,0,0,0,0,0,0,0,...,0,0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
8220383,0,0,0,0,0,0,0,0,0,0,...,0,0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [7]:
N = diag_frame.shape[0]
D = diag_frame.shape[1]

# Check prevalence of diseases
print(f"{f'{D} conditions'.ljust(50)[:20]} | Count \n -------------------------------")
for i, j in zip(diag_frame.columns, diag_frame.to_numpy().sum(axis=0)):
    print(f"{i.ljust(50)[:20]} | {int(j)}")
print(f"-------------------------------\n" + "Total".ljust(50)[:20] + f" | {N}")

63 conditions        | Count 
 -------------------------------
Heart_failure        | 117947
Atrial_fibrillation  | 189709
Stroke               | 149000
Hypertension         | 1422280
Ischaemic_heart_dise | 355189
Peripheral_vascular_ | 61606
Heart_valve_disorder | 118748
Aortic_aneurysm      | 24428
Type_1_diabetes      | 44329
Type_2_diabetes      | 508626
Chronic_kidney_disea | 326432
Depression           | 1491649
Anxiety              | 1292656
Bipolar_disorder     | 37671
Eating_disorder      | 60214
Schizophrenia        | 44565
Post_traumatic_stres | 46169
Autism               | 34492
Drug_or_alcohol_misu | 547103
Alcoholic_liver_dise | 17404
Non_alc_fatty_liver_ | 62650
Other_chronic_liver_ | 101932
Inflammatory_bowel_d | 72714
Irritable_bowel_synd | 491337
Dementia             | 93585
Parkinsons_disease   | 18969
Epilepsy             | 116836
Cancer_excluding_BCC | 416408
Haematological_cance | 40552
Asthma               | 1371635
Chronic_obstr_pulmon | 230304
Obstructive_sleep

# Run experiment

Plot order
* the prevalence of each condition in each quantised cluster
* the log odds ratio of each quantised cluster
* the cluster factor association matrix
* the prevalence within factors
* the log odds ratio within factors


In [None]:
latent_factors = 15
batch_size = 2048

architecture = {'enc_h': [D, D, D, latent_factors], 
                'dec_h': [latent_factors, D, D, D], 
                'constrain': ['L0', 'L0'],
               }
model_params = {'tmp_schedule': [4, 0.4, 0.4],
                 'epochs': 50,
                 'batch_size': batch_size,
                 'lr': 1e-3,
                 'verbose': 1,
                 'anneal': True,
                 'norm_beta': -0.4
                }


method_name = f'{data_tag}_Beta{model_params["norm_beta"]:.2f}_L{latent_factors}_batch{batch_size}'.replace(".", "_").replace("-", "Neg")

print(f"{method_name}\n===============")
restart_dicts, losses, labels, avg_ami = fit_restarts(diag_frame, architecture, model_params,
                                                      n_restarts=3,
                                                      save_path=f"{path_to_model}{method_name}",
                                                      force_retrain=False
                                                     )

# plot_restarts(diag_frame, best_dict, labels, plot_path=)

OPTIMAL_conditions_for_clustering_BetaNeg0_40_L15_batch2048
Failed to load /rds/projects/g/gokhalkm-optimal/DataforCharles/mmVAE_output/OPTIMAL_conditions_for_clustering_BetaNeg0_40_L15_batch2048_0.pickle, training...


## Check Adjusted Mutual Information and choose a seed. 

We perform multiple restarts to check the consistency of the clustering obtained under the chosen parameters (see ML4H paper).

Whilst we choose the lowest reconstruction loss, a smaller value doesn't necessarily imply better clustering.


In [None]:
best_dict = restart_dicts[np.argmin(losses)]

## Process the model output to find statistics of interest 
### - this also saves to .xls at the given path

In [None]:
output_dict = helpers.post_process(diag_frame, best_dict, save_path=path_to_outputs)

for k in output_dict.keys():
    print(k)

# Plot some of the post-processed outputs

## Clusters 

Not plotting any clusters containing less than 1% of the total samples

In [None]:
cluster_labels = output_dict['cluster_labels']     
cluster_prevalence = output_dict['prevalence_clusters']
cluster_RR = output_dict['RR_clusters']
cluster_OR = output_dict['OR_clusters']
cluster_counts = output_dict['count_clusters']


# Prevalence 
_increments = 2500
plot_grid(cluster_prevalence, cluster_names=cluster_labels, condition_names=diag_frame.columns.tolist(), counts=cluster_counts, perc_threshold=1,
          save_path=f'{path_to_figs}ClusterPrevalence',
          ylabel='Cluster',  cbar_label="Condition prevalence",
          bins = [int(_increments * i) for i in range(1, np.int(np.ceil(np.max(cluster_prevalence) / _increments)))]
          )


# Relative risk 
plot_grid(cluster_RR, cluster_names=cluster_labels, condition_names=diag_frame.columns.tolist(), counts=cluster_counts, perc_threshold=1,
          save_path=f'{path_to_figs}ClusterRelativeRisk',
          ylabel='Cluster',  cbar_label="Relative risk",
          bins = [1, 2, 4, 6]
          )  

# Odds ratio
plot_grid(cluster_OR, cluster_names=cluster_labels, condition_names=diag_frame.columns.tolist(), counts=cluster_counts, perc_threshold=1,
          save_path=f'{path_to_figs}ClusterOddsRatio',
          ylabel='Cluster',  cbar_label="Odds ratio",
          bins = [1, 2, 4, 6]
          )  


## Latent factors

In [None]:
factor_labels = output_dict['topic_labels']     
factor_prevalence = output_dict['prevalence_topics']
factor_RR = output_dict['RR_topics']
factor_OR = output_dict['OR_topics']
factor_counts = output_dict['count_topics']

# Prevalence 
_increments = 20000
plot_grid(factor_prevalence, cluster_names=factor_labels, condition_names=diag_frame.columns.tolist(), counts=factor_counts, perc_threshold=None,
          save_path=f'{path_to_figs}FactorPrevalence',
          ylabel='Factor',  cbar_label="Condition prevalence",
          bins = [int(_increments * i) for i in range(1, np.int(np.ceil(np.max(factor_prevalence) / _increments)))]
          )


# Relative risk 
plot_grid(factor_RR, cluster_names=factor_labels, condition_names=diag_frame.columns.tolist(), counts=factor_counts, perc_threshold=None,
          save_path=f'{path_to_figs}FactorRelativeRisk',
          ylabel='Factor',  cbar_label="Relative risk",
          bins = [1, 2, 4, 6]
          )  

# Odds ratio
plot_grid(factor_OR, cluster_names=factor_labels, condition_names=diag_frame.columns.tolist(), counts=factor_counts, perc_threshold=None,
          save_path=f'{path_to_figs}FactorOddsRatio',
          ylabel='Factor',  cbar_label="Odds ratio",
          bins = [1, 2, 4, 6]
          )  

## And look at the relationship between clusters and factors

In [None]:
association_matrix = output_dict['cluster_factors'].T

#  Cluster-factor association matrix
cluster_factor_association(association_matrix,
                           xlabel='Cluster', ylabel='Latent factors',
                           x_ticks=cluster_labels, y_ticks=factor_labels, 
                           figsize=[3, 1.5], save_path=f'{path_to_figs}Cluster_factor_association')