In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib
from pathlib import Path

In [None]:
from nilearn.connectome import ConnectivityMeasure
from nilearn import plotting

In [None]:
import gudhi as gd
import gudhi.representations

# Loading data

In [None]:
# Raw gene-expression table. low_memory=False makes pandas parse the file in a
# single pass instead of chunks, which avoids mixed-dtype column inference.
df = pd.read_csv('ADNI_Gene_Expression_Profile/ADNI_Gene_Expression_Profile.csv', low_memory=False)

In [None]:
# Two diagnosis tables; the subject-ID column is named differently in each
# ('PTID' vs 'Subject'), the baseline diagnosis column is 'DX.bl' in both.
subj_table_1 = pd.read_csv(
    'ADNI_Gene_Expression_Profile/ADNI_Training_Q1_APOE_July22.2014.csv',
    low_memory=False,
)
subj_list_1 = subj_table_1[['PTID', 'DX.bl']].values

subj_table_2 = pd.read_csv(
    'ADNI_Gene_Expression_Profile/ADNI_Training_Q3_APOE_CollectionADNI1Complete_July22.2014.csv',
    low_memory=False,
)
subj_list_2 = subj_table_2[['Subject', 'DX.bl']].values

# Stack both (subject id, diagnosis) arrays row-wise into one list of subjects.
subj_list = np.vstack((subj_list_1, subj_list_2))
print(len(subj_list_1), len(subj_list_2), len(subj_list))

In [None]:
# Dataset variant. Supported values:
#   'Genes'         - patient cohort is diagnosis 'AD' only
#   'GenesExtended' - patient cohort is diagnosis 'AD' or 'LMCI'
# The same string is used as the output directory name when results are saved.
dataset_name = 'Genes'

In [None]:
# Baseline diagnoses that count as "patient" for the selected dataset variant.
if dataset_name == 'GenesExtended':
    patient_diagnoses = {'AD', 'LMCI'}
elif dataset_name == 'Genes':
    patient_diagnoses = {'AD'}
else:
    # Without this branch an unknown name would leave subjects_AD undefined
    # (or silently stale from a previous run of the cell).
    raise ValueError(
        f"Unknown dataset_name {dataset_name!r}; expected 'Genes' or 'GenesExtended'"
    )

# Each subj_list row is (subject id, baseline diagnosis).
subjects_AD = [item[0] for item in subj_list if item[1] in patient_diagnoses]
subjects_control = [item[0] for item in subj_list if item[1] == 'CN']

In [None]:
# All subject ids used in the analysis: patients first, then controls.
subjects = subjects_AD + subjects_control

### Choose markers of AD

In [None]:
# Candidate marker genes: semicolon-separated file read without a header row,
# so the columns are labelled 0, 1, ...
genes_df = pd.read_csv('ADNI_Gene_Expression_Profile/genes/gene.csv', sep=';', header=None)
genes_df

In [None]:
# Column 1 of the marker file, rows 1..199.
# NOTE(review): row 0 is skipped - presumably it holds the column titles, since
# the file is read with header=None; confirm, and confirm that 199 (not 200)
# genes is the intended cut-off.
genes_array=genes_df[1][1:200]
genes_array

In [None]:
# Gene symbols present in the expression table (column 'Unnamed: 2', from row 8 on).
# The set is built once here; building it inside the comprehension would
# rescan the whole column for every candidate gene.
available_gene_symbols = set(df['Unnamed: 2'][8:])

# Keep only the candidate markers that actually occur in the expression table,
# preserving their order in genes_array.
genes_array_existed = [gene_symb for gene_symb in genes_array
                       if gene_symb in available_gene_symbols]

In [None]:
# Number of candidate marker genes that were found in the expression table.
len(genes_array_existed)

In [None]:
# Rows of the expression table whose gene symbol is one of the selected markers.
# A set makes each membership test O(1); scanning the list for every row of the
# table is needlessly quadratic. The resulting boolean mask is the same.
top_gene_symbols = set(genes_array_existed)
is_top_gene = [symbol in top_gene_symbols for symbol in df['Unnamed: 2']]
df_top_genes = df.iloc[is_top_gene]
df_top_genes

In [None]:
# Expression values of the selected marker genes only: columns 3 .. second-to-last
# of the table (the first three columns and the last one are dropped).
# NOTE(review): presumably one column per subject - see gene_subjects; confirm.
df_matrix = df_top_genes[df.columns[3:-1]]
df_matrix

In [None]:
# Row with index label 1 of the same columns as df_matrix. Its values are
# compared against the subject ids below, i.e. it is used as the subject id of
# each expression column.
gene_subjects = df[df.columns[3:-1]].loc[1]
gene_subjects

In [None]:
# Subject ids of the expression columns that belong to either cohort.
subject_ids = list(gene_subjects)
in_cohort = np.array([subject_id in subjects for subject_id in subject_ids])
labels = np.array(gene_subjects)[in_cohort]
labels

In [None]:
# Keep the expression columns whose subject id is in the patient cohort.
is_AD_column = np.array([subject_id in subjects_AD for subject_id in list(gene_subjects)])
AD_columns = df_matrix.columns[is_AD_column]
df_matrix_AD = df_matrix[AD_columns]
df_matrix_AD

In [None]:
# Keep the expression columns whose subject id is in the control cohort.
is_control_column = np.array([subject_id in subjects_control for subject_id in list(gene_subjects)])
control_columns = df_matrix.columns[is_control_column]
df_matrix_control = df_matrix[control_columns]
df_matrix_control

In [None]:
# Row labels of the form "<gene symbol>_<row index in df>"; the row index is
# appended so that repeated gene symbols still get distinct labels.
gene_symbols = df_top_genes[df.columns[2]]
row_ids = [str(row_id) for row_id in df_top_genes.index]
indexes = gene_symbols + '_' + row_ids
print(indexes)

In [None]:
# Display the selected marker rows again for comparison with the labels above.
df_top_genes

In [None]:
# Cast both expression matrices to float64 for the numeric work that follows.
df_matrix_AD = df_matrix_AD.astype('float64')
df_matrix_control = df_matrix_control.astype('float64')

In [None]:
# Use the "<gene symbol>_<row index>" strings as row labels of both matrices.
# NOTE(review): both matrices are column subsets of df_top_genes, so their rows
# are expected to be in the same order as `indexes` - confirm if the selection
# cells above are changed.
df_matrix_control = df_matrix_control.set_index(indexes)
df_matrix_AD = df_matrix_AD.set_index(indexes)

In [None]:
# Column-normalised copies: every column is divided by its own sum.
AD_column_totals = df_matrix_AD.sum(axis=0)
control_column_totals = df_matrix_control.sum(axis=0)
df_matrix_AD_normed = df_matrix_AD.div(AD_column_totals, axis='columns')
df_matrix_control_normed = df_matrix_control.div(control_column_totals, axis='columns')

### Data view

In [None]:
class DataGeneExpressions:
    """Expression data of one cohort with its connectivity matrix and persistence.

    On construction the rows of ``expressions`` are selected and ordered by
    ``labels``, a gene-by-gene connectivity matrix is estimated with nilearn's
    ``ConnectivityMeasure``, and a Rips complex is built on the distance
    ``1 - matrix`` to obtain the persistence diagram.

    Parameters
    ----------
    expressions : pandas.DataFrame
        Table indexed by gene label; only ``.loc[labels]`` of it is kept.
    phenotype : any
        Cohort tag, stored unchanged in ``self.phenotype``.
    labels : index-like
        Row labels to keep, in the order they should appear in the matrix.
        ``plot_matrix`` indexes it with a list of integer positions.
    connectivity_measure_kind : str, default 'correlation'
        Forwarded as ``kind`` to ``ConnectivityMeasure``.
    rips_complex_max_dimension : int, default 2
        Forwarded as ``max_dimension`` to ``create_simplex_tree``.
    """

    # input parameters
    expressions = None
    labels = None
    phenotypes_array = []
    connectivity_measure_kind = None
    rips_complex_max_dimension = None

    # derived parameters
    matrix = None
    diagram = None
    simplex_tree = None

    def __init__(self, expressions, phenotype, labels,
                 connectivity_measure_kind='correlation',
                 rips_complex_max_dimension=2):
        self.expressions = expressions.loc[labels]
        self.labels = labels
        self.phenotype = phenotype
        self.connectivity_measure_kind = connectivity_measure_kind
        self.rips_complex_max_dimension = rips_complex_max_dimension

        # Derived state: connectivity matrix first, then its persistence.
        self.create_matrix()
        self.create_persistence_view()

    def create_matrix(self):
        """Estimate the connectivity matrix between rows of ``self.expressions``."""
        # The estimator receives the transposed values, so the rows of
        # ``self.expressions`` become the variables that are compared.
        transposed_values = self.expressions.values.T
        estimator = ConnectivityMeasure(kind=self.connectivity_measure_kind)
        # fit_transform works on a list of inputs; a single one is passed here.
        self.matrix = estimator.fit_transform([transposed_values])[0]

    def create_persistence_view(self):
        """Build the Rips complex on ``1 - matrix`` and compute its persistence."""
        distances = 1 - self.matrix
        complex_ = gudhi.RipsComplex(distance_matrix=distances, max_edge_length=2)
        tree = complex_.create_simplex_tree(max_dimension=self.rips_complex_max_dimension)
        self.diagram = tree.persistence()
        self.simplex_tree = tree

    def get_persistence_intervals(self, dim):
        """Return the persistence intervals of the simplex tree in dimension ``dim``."""
        tree = self.simplex_tree
        return tree.persistence_intervals_in_dimension(dim)

    # visualize

    def plot_matrix(self, reorder=True):
        """Plot the connectivity matrix and return the labels in displayed order.

        The diagonal is zeroed on a copy before plotting. Ticks are labelled
        with integer positions, which are read back from the x axis after
        plotting and used to index ``self.labels``.
        """
        shown = self.matrix.copy()
        np.fill_diagonal(shown, 0)
        n_labels = len(self.labels)
        print(n_labels, shown.shape)
        image = plotting.plot_matrix(shown, figure=(10, 8), labels=range(n_labels),
                                     vmax=1, vmin=-1, reorder=reorder)
        displayed_positions = [int(tick.get_text()) for tick in image.axes.get_xticklabels()]
        return self.labels[displayed_positions]

    def plot_persistence_diagram(self):
        """Draw the persistence diagram of this cohort (all intervals)."""
        gudhi.plot_persistence_diagram(self.diagram, legend=True, max_intervals=0)

    def plot_persistence_barcode(self):
        """Draw the persistence barcode of this cohort (all intervals)."""
        gudhi.plot_persistence_barcode(self.diagram, legend=True, max_intervals=0)

    def plot_persistence_density(self):
        """Draw the persistence density of the dimension-1 intervals."""
        gudhi.plot_persistence_density(self.diagram, dimension=1)
        

In [None]:
# Patient cohort, with genes in their original row order.
data_genes_AD = DataGeneExpressions(
    expressions=df_matrix_AD,
    phenotype='AD',
    labels=df_matrix_AD.index,
    rips_complex_max_dimension=2,
)


In [None]:
# Plot the patient matrix with reordering enabled; plot_matrix returns the gene
# labels in the order in which they are displayed on the x axis.
reordered_labels = data_genes_AD.plot_matrix(reorder=True)

In [None]:
# Control cohort, built with the gene order of the reordered patient matrix so
# that both matrices list the genes in the same order.
data_genes_control = DataGeneExpressions(
    expressions=df_matrix_control,
    phenotype='CN',
    labels=reordered_labels,
    rips_complex_max_dimension=2,
)

In [None]:
# reorder=False keeps the gene order this cohort was constructed with.
data_genes_control.plot_matrix(reorder=False)

In [None]:
# Persistence diagram of the control cohort.
data_genes_control.plot_persistence_diagram()

In [None]:
# Persistence diagram of the patient cohort.
data_genes_AD.plot_persistence_diagram()

In [None]:
# Persistence barcode of the control cohort.
data_genes_control.plot_persistence_barcode()

In [None]:
# Persistence barcode of the patient cohort.
data_genes_AD.plot_persistence_barcode()

# Results

In [None]:
import matplotlib.patches as mpatches

### Dim 0

In [None]:
# Dimension-0 persistence intervals of both cohorts (control first, then patients).
intervals_array_control_0, intervals_array_patient_0 = (
    cohort.get_persistence_intervals(0)
    for cohort in (data_genes_control, data_genes_AD)
)

In [None]:
# Overlay the dimension-0 intervals of both cohorts in a single diagram.
# The "dimension" slot of each (dim, (birth, death)) entry is reused as a group
# id (0 = control, 1 = patient) so the plot colours the two cohorts differently.
group_colors = ['red', 'blue']  # position = group id; same colours as the legend patches
patient_group = 'AD + LMCI' if dataset_name == 'GenesExtended' else 'AD'

# Iterate the interval arrays directly: np.vstack raises on an empty array,
# while plain iteration simply contributes no points.
data_dim0 = [(0, (birth, death)) for birth, death in intervals_array_control_0]
data_dim0 += [(1, (birth, death)) for birth, death in intervals_array_patient_0]

# Passing the colormap explicitly keeps the plotted colours identical to the
# legend patches instead of relying on the library's default palette.
axis = gudhi.plot_persistence_diagram(data_dim0, legend=True, max_intervals=0, fontsize=10,
                                      colormap=group_colors)
axis.set_title(f'H0. Both controls and {patient_group} patients')
patch1 = mpatches.Patch(color='red', label='Control')
patch2 = mpatches.Patch(color='blue', label='Patient')
axis.legend(handles=[patch1, patch2])

In [None]:
# Dimension-0 diagrams of the two cohorts side by side.
fig, ax = plt.subplots(1, 2, figsize=(10,4))
patch1 = mpatches.Patch(color='red', label='Control')
patch2 = mpatches.Patch(color='blue', label='Patient')

# `colormap` is indexed by homology dimension, so it must be a sequence of
# colours. A bare string such as 'red' only worked because 'red'[0] == 'r'
# happens to be a valid colour code; a one-element list states the intent.
# The interval arrays are passed as they are (np.vstack of a 2-D array is a
# no-op and fails on an empty one).
gudhi.plot_persistence_diagram(intervals_array_control_0, legend=True, max_intervals=0,
                               colormap=['red'], axes=ax[0], fontsize=10)
ax[0].set_title('Control dim 0')
ax[0].legend(handles=[patch1], loc='upper right')

gudhi.plot_persistence_diagram(intervals_array_patient_0, legend=True, max_intervals=0,
                               colormap=['blue'], axes=ax[1], fontsize=10)
ax[1].set_title('Patient dim 0')
ax[1].legend(handles=[patch2], loc='upper right')


### Dim 1

In [None]:
# Dimension-1 persistence intervals of both cohorts (control first, then patients).
intervals_array_control_1, intervals_array_patient_1 = (
    cohort.get_persistence_intervals(1)
    for cohort in (data_genes_control, data_genes_AD)
)

In [None]:
# Overlay the dimension-1 intervals of both cohorts in a single diagram.
# The "dimension" slot of each (dim, (birth, death)) entry is reused as a group
# id (0 = control, 1 = patient) so the plot colours the two cohorts differently.
group_colors = ['red', 'blue']  # position = group id; same colours as the legend patches
# The patient cohort is AD only unless the extended dataset is selected, so the
# title must follow dataset_name instead of always claiming "AD + LMCI".
patient_group = 'AD + LMCI' if dataset_name == 'GenesExtended' else 'AD'

# Iterate the interval arrays directly: a cohort may have no dimension-1
# intervals at all, and np.vstack raises on an empty array.
data_dim1 = [(0, (birth, death)) for birth, death in intervals_array_control_1]
data_dim1 += [(1, (birth, death)) for birth, death in intervals_array_patient_1]

# Passing the colormap explicitly keeps the plotted colours identical to the
# legend patches instead of relying on the library's default palette.
axis = gudhi.plot_persistence_diagram(data_dim1, legend=True, max_intervals=0, fontsize=10,
                                      colormap=group_colors)
axis.set_title(f'H1. Both patients ({patient_group}) and controls')
patch1 = mpatches.Patch(color='red', label='Control')
patch2 = mpatches.Patch(color='blue', label='Patient')
axis.legend(handles=[patch1, patch2])

In [None]:
# Raw persistence output of the control cohort, as returned by simplex_tree.persistence().
data_genes_control.diagram

In [None]:
# Barcode of the patient cohort with a hand-made legend per homology dimension.
axis = gudhi.plot_persistence_barcode(
    data_genes_AD.diagram,
    legend=True,
    max_intervals=0,
    fontsize=10,
    alpha=0.8,
)
axis.set_title('AD patients')
patch1, patch2 = (
    mpatches.Patch(color=patch_color, label=patch_label)
    for patch_color, patch_label in (('red', 'dim 0'), ('blue', 'dim 1'))
)
axis.legend(handles=[patch1, patch2])

In [None]:
# Barcode of the control cohort with a hand-made legend per homology dimension.
axis = gudhi.plot_persistence_barcode(
    data_genes_control.diagram,
    legend=True,
    max_intervals=0,
    fontsize=10,
    alpha=0.8,
)
axis.set_title('Controls')
patch1, patch2 = (
    mpatches.Patch(color=patch_color, label=patch_label)
    for patch_color, patch_label in (('red', 'dim 0'), ('blue', 'dim 1'))
)
axis.legend(handles=[patch1, patch2])

In [None]:
# Two two-colour lists holding the same colours in opposite order:
# the first Set2 colour and plain red.
set2_first = matplotlib.cm.Set2.colors[0]
colors1, colors2 = [set2_first, 'red'], ['red', set2_first]

In [None]:
# Overlay the barcodes of both cohorts on one axis, one colour per cohort.
control_color = 'red'
patient_color = matplotlib.cm.Set2.colors[0]
# The patient cohort is AD only unless the extended dataset is selected.
patient_group = 'AD + LMCI' if dataset_name == 'GenesExtended' else 'AD'

# The colormap is indexed by homology dimension. The legend below says that
# colour encodes the cohort, so every dimension of a cohort must get that
# cohort's colour; two swapped two-colour lists would draw the control H1 bars
# in the patient colour and the patient H1 bars in the control colour.
n_dims = 1 + max(
    (dim for diagram in (data_genes_control.diagram, data_genes_AD.diagram)
     for dim, _ in diagram),
    default=0,
)
control_colormap = [control_color] * n_dims
patient_colormap = [patient_color] * n_dims

axis = gudhi.plot_persistence_barcode(data_genes_control.diagram,
                                      legend=True,
                                      fontsize=10, colormap=control_colormap,
                                      alpha=0.5)
gudhi.plot_persistence_barcode(data_genes_AD.diagram,
                               legend=True, fontsize=10, alpha=0.5,
                               colormap=patient_colormap, axes=axis)
patch1 = mpatches.Patch(color=control_color, label='Control')
patch2 = mpatches.Patch(color=patient_color, label='Patient')
axis.legend(handles=[patch1, patch2])
axis.set_title(f'H0 and H1. Both patients ({patient_group}) and controls')

In [None]:
# Mean of all connectivity-matrix entries (diagonal included): control, then patients.
print(data_genes_control.matrix.mean(), data_genes_AD.matrix.mean())

In [None]:
# Shape of the control connectivity matrix.
data_genes_control.matrix.shape

# Distance

In [None]:
from gudhi import hera

In [None]:
# Wasserstein distance between the dimension-1 diagrams of controls and patients.
# order=1 is the Wasserstein exponent, internal_p=inf selects the L-infinity
# ground metric, delta=0.01 is the relative error accepted from the solver.
gudhi.hera.wasserstein_distance(intervals_array_control_1, intervals_array_patient_1,
                                order = 1, internal_p = np.inf, delta = 0.01)

In [None]:
# Wasserstein distance between the dimension-0 diagrams of controls and patients
# (same settings as for dimension 1).
# NOTE(review): dimension-0 diagrams normally contain an interval with infinite
# death; confirm the installed gudhi version handles points at infinity here.
gudhi.hera.wasserstein_distance(intervals_array_control_0, intervals_array_patient_0,
                                order = 1, internal_p = np.inf, delta = 0.01)

In [None]:
# Bottleneck distance between the dimension-1 diagrams; e=0 requests the exact
# (non-approximated) computation.
gudhi.bottleneck_distance(intervals_array_control_1, intervals_array_patient_1, e=0)

In [None]:
# Bottleneck distance between the dimension-0 diagrams; e=0 requests the exact
# (non-approximated) computation.
gudhi.bottleneck_distance(intervals_array_control_0, intervals_array_patient_0, e=0)

# Saving

### Weights

In [None]:
# Copy of the control connectivity matrix with the diagonal set to NaN, so the
# self-connections are not counted among the weights.
matrix_control = data_genes_control.matrix.copy()
matrix_control[np.diag_indices_from(matrix_control)] = np.nan
matrix_control

In [None]:
# Copy of the patient connectivity matrix with the diagonal set to NaN, so the
# self-connections are not counted among the weights.
matrix_AD = data_genes_AD.matrix.copy()
matrix_AD[np.diag_indices_from(matrix_AD)] = np.nan
matrix_AD

Save distributions

In [None]:
# Histogram of the patient weights in 50 bins. plt.hist returns
# (counts, bin edges, patches); elements 0 and 1 are saved further below.
hist_AD = plt.hist(matrix_AD.flatten(),bins=50)

In [None]:
# Histogram of the control weights in 50 bins. plt.hist returns
# (counts, bin edges, patches); elements 0 and 1 are saved further below.
hist_control = plt.hist(matrix_control.flatten(),bins=50)

In [None]:
# Save the weight histograms: suffix _0 holds the bin counts, _1 the bin edges.
hist_output_dir = Path(dataset_name) / 'diagrams'
# np.save does not create missing directories, so make sure the target exists
# (a fresh checkout or a new dataset_name would otherwise fail here).
hist_output_dir.mkdir(parents=True, exist_ok=True)

np.save(hist_output_dir / 'hist_control_0.npy', hist_control[0])
np.save(hist_output_dir / 'hist_AD_0.npy', hist_AD[0])
np.save(hist_output_dir / 'hist_control_1.npy', hist_control[1])
np.save(hist_output_dir / 'hist_AD_1.npy', hist_AD[1])

Save diagrams

In [None]:
# Save the persistence intervals of both cohorts, one file per cohort and dimension.
diagram_output_dir = Path(dataset_name) / 'diagrams'
# np.save does not create missing directories, so make sure the target exists
# (a fresh checkout or a new dataset_name would otherwise fail here).
diagram_output_dir.mkdir(parents=True, exist_ok=True)

for cohort_name, cohort in (('control', data_genes_control), ('AD', data_genes_AD)):
    for dim in (0, 1):
        np.save(diagram_output_dir / f'{cohort_name}_diagram_{dim}.npy',
                cohort.get_persistence_intervals(dim))