
# A Random Forest Classifier for Autism Spectrum Disorder (ASD)
#### Final Project for TEWA I, supervisors: Dominik Pegler, Mengfan Zhang, Jozsef Arato
*Jannis Breßgott, Sara Binder, Marla Pinkert, 04.07.2024*

Autism spectrum disorder (ASD) is a neurological and developmental disorder that manifests in differences in social communication and interaction, learning, and sensation, with a wide range of associated subjective and objective impairments of quality of life (for review, see Lord et al., 2020). As demonstrated by Ilioska and colleagues in a large-scale analysis of resting-state fMRI data from 1824 individuals (796 with ASD), ASD is characterized by both hypo- and hyperconnectivity (Ilioska et al., 2023). Previous studies have trained machine learning models to predict ASD diagnosis from structural and functional brain differences, reaching prediction accuracies of ? to ? (SOURCES). 

Accordingly, we trained a random forest model to predict subjects' ASD diagnosis from functional connectivity, using preprocessed data from the Autism Brain Imaging Data Exchange I (ABIDE I, Craddock et al., 2013). This gave us access to 883 subjects (??? with ASD). We downloaded timecourses preprocessed with the CPAC pipeline (as outlined in http://preprocessed-connectomes-project.org/abide/Pipelines.html), using band-pass filtering and global signal regression (the `filt_global` strategy), as this pipeline previously performed best in a similar classification attempt (SOURCE). Timecourses were extracted using the Craddock 200 (CC200) atlas (Craddock et al., 2011).

## Import dependencies and download data

In [216]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from itertools import combinations
from glob import glob
from nilearn import connectome
from nilearn import plotting
import networkx as nx
from sklearn.model_selection import train_test_split
from sklearn.model_selection import RandomizedSearchCV
from sklearn.model_selection import cross_val_score, cross_val_predict
from sklearn.ensemble import RandomForestClassifier
from sklearn.decomposition import PCA

In [None]:
# The script below downloads the ABIDE I CC200 ROI timecourses (CPAC pipeline, filt_global strategy)
!python resources/download_abide_preproc.py -d rois_cc200 -p cpac -s filt_global -o output

Downloading data for ASD and TDC participants
No upper age threshold specified
No lower age threshold specified
No site specified, using all sites...
No sex specified, using all sexes...

In [None]:
# We read in the phenotype dataframe and create a new df with only diagnoses
# The Phenotypic_V1_0b_preprocessed1.csv can be downloaded from http://preprocessed-connectomes-project.org/abide/download.html
pheno_df = pd.read_csv("Phenotypic_V1_0b_preprocessed1.csv", index_col = 0)
diagnose_df = pheno_df.loc[:,["FILE_ID", "DX_GROUP", "DSM_IV_TR"]]

In [147]:
# We list the paths for each subject using glob
subj_paths = sorted(glob("cpac/filt_global/rois_cc200/*.1D"))

## Model 1: Functional Connectivity, all Predictors

For our first model, we followed an approach previously used by Chen et al. (2015), who trained a random forest classifier on the ABIDE data and used the individual correlations between all ROIs as predictors without any dimensionality reduction. Our procedure was as follows:
1. Calculate each subject's connectivity matrix from the 200 ROI timecourses
2. Vectorize the connectivity matrix, keeping only the unique correlations (one triangle without the diagonal, i.e. 200 × 199 / 2 = 19,900 correlations per subject)
3. Repeat the first two steps for all subjects to create a DataFrame with rows corresponding to subjects and columns to correlations
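The three steps can be sketched with NumPy and pandas alone. This is a minimal illustration under stated assumptions, not the notebook's implementation: it uses simulated timecourses in place of real `.1D` files, and plain Pearson correlation via `np.corrcoef`, whereas `nilearn`'s `ConnectivityMeasure(kind="correlation")` by default estimates the covariance with a Ledoit-Wolf shrinkage estimator, so its values will differ slightly. The helper `vectorize_upper_triangle` is hypothetical.

```python
import numpy as np
import pandas as pd


def vectorize_upper_triangle(timecourses):
    """Hypothetical helper: correlate ROI timecourses and keep each ROI pair once.

    timecourses: array of shape (n_timepoints, n_rois)
    Returns a dict mapping "ixj" (i < j) to the Pearson correlation of ROIs i and j.
    """
    # Step 1: ROI x ROI connectivity matrix (plain Pearson correlation)
    connectivity = np.corrcoef(timecourses, rowvar=False)

    # Step 2: upper triangle without the diagonal; np.triu_indices yields the pairs
    # in the same order as itertools.combinations(range(n_rois), 2)
    rows, cols = np.triu_indices(connectivity.shape[0], k=1)
    return {f"{i}x{j}": connectivity[i, j] for i, j in zip(rows, cols)}


# Simulated data standing in for real CC200 timecourses (150 timepoints, 200 ROIs)
rng = np.random.default_rng(0)
features = vectorize_upper_triangle(rng.standard_normal((150, 200)))
print(len(features))       # 19900
print(list(features)[:2])  # ['0x1', '0x2']

# Step 3: one row per (simulated) subject, one column per ROI pair
df = pd.DataFrame([vectorize_upper_triangle(rng.standard_normal((150, 200)))
                   for _ in range(3)])
print(df.shape)            # (3, 19900)
```

Indexing with `np.triu_indices` avoids a Python-level loop over `DataFrame.iloc`. If we read the nilearn documentation correctly, `ConnectivityMeasure(..., vectorize=True, discard_diagonal=True)` returns a comparable flat vector directly, although its entries follow the lower triangle, so column names would need to be built accordingly.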

### Prepare Data

In [None]:
# Define class which gets a dict of unique correlations for one subject
class UniqueCorrelations:
    '''
        Parameters:
            - correlation_kind: {"covariance", "correlation", "partial correlation", "tangent", "precision"}
              Kind of correlation to calculate, takes the same arguments as nilearn.connectome.ConnectivityMeasure()
            - time_course_path: path to a 1D file containing the timecourses of all relevant regions
              for one specific subject
            - subject_ID: whether "FILE_ID" should be added to the dict, default = True
    '''

    def __init__(self, correlation_kind, time_course_path, subject_ID=True):

        self.correlation_kind = correlation_kind
        self.time_course_path = time_course_path
        self.subject_ID = subject_ID

    def get_corr_list(self):
        '''
        Creates: self.unique_corrs, self.connectivity_matrix
        - unique_corrs: dictionary with the unique correlations between timecourses,
          keys corresponding to the correlation of two timeseries,
          e.g. key '0x4' corresponds to the correlation of regions 0 and 4.
        - connectivity_matrix: connectivity matrix of all timeseries for one subject;
          can be plotted with nilearn.plotting.plot_matrix()

        Returns self, so calls can be chained.
        '''

        # Read the file and convert it to an array (rows: timepoints, columns: ROIs)
        timecourses = pd.read_csv(self.time_course_path, sep="\t").to_numpy()

        # Calculate the connectivity matrix for all ROIs
        correlation_measure = connectome.ConnectivityMeasure(kind=self.correlation_kind)
        connectivity_matrix = correlation_measure.fit_transform([timecourses])[0]

        # Index the upper triangle without the diagonal, i.e. every ROI pair once;
        # the order matches itertools.combinations(range(n_rois), 2)
        rows, cols = np.triu_indices(len(connectivity_matrix), k=1)

        # Save the unique correlations in a dictionary keyed by region pair,
        # i.e. the correlation of region 1 and region 4 is called "1x4"
        unique_corrs = {f"{i}x{j}": connectivity_matrix[i, j] for i, j in zip(rows, cols)}

        # Add the file name stem as "FILE_ID" so subjects can later be matched to the
        # phenotype data (works with both / and \ path separators)
        if self.subject_ID:
            file_name = self.time_course_path.replace("\\", "/").split("/")[-1]
            unique_corrs["FILE_ID"] = file_name.split("_rois")[0]

        self.unique_corrs = unique_corrs
        self.connectivity_matrix = connectivity_matrix
        return self
        

In [None]:
# Create DataFrame with all subjects
list_unique_corrs = []
for path in subj_paths:
    get_unique_corr = UniqueCorrelations("correlation", path)
    get_unique_corr.get_corr_list()
    list_unique_corrs.append(get_unique_corr.unique_corrs)

corr_df = pd.DataFrame(list_unique_corrs)
corr_df.to_csv("corr_df2.csv")

### Model 1.a. Baseline Parameters
First, we used the default hyperparameters of scikit-learn's `RandomForestClassifier`, with out-of-bag (OOB) scoring enabled.
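For reference, the defaults of an untuned forest can be inspected directly with `get_params()`. This is a small sketch; the values noted below reflect our understanding of recent scikit-learn releases and may differ in other versions (for example, `max_features` defaulted to `"auto"` before version 1.1), so it is worth checking the installed version.

```python
from sklearn.ensemble import RandomForestClassifier

# Default hyperparameters of an untuned random forest, as used in Model 1.a
params = RandomForestClassifier().get_params()
for name in ["n_estimators", "max_depth", "max_features", "min_samples_leaf", "bootstrap"]:
    print(f"{name}: {params[name]}")
```

With these defaults, each of the 100 trees grows until its leaves are pure and considers a random subset of features at each split, which is the behaviour the hyperparameter tuning sections would later vary.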

In [None]:
corr_merged_df = pd.merge(corr_df, diagnose_df, on = "FILE_ID", how = "left")
corr_merged_df = corr_merged_df.dropna()

In [None]:
X = corr_merged_df.drop(["FILE_ID", "DX_GROUP", "DSM_IV_TR"], axis=1)
y = corr_merged_df["DX_GROUP"]

# Split Data into test and train
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

In [None]:
# Initialize basic rf
rf = RandomForestClassifier(oob_score=True)
rf.fit(X_train, y_train)
rf_cross_val = cross_val_score(rf, X_train, y_train, cv=5)

In [None]:
print(f"This model has an OOB score of {rf.oob_score_}")
print(f"The mean cross validation score was {rf_cross_val.mean()}")
print(f"The standard deviation of our cross validation was {rf_cross_val.std()}")


### Model 1.b. Hyperparameter Tuning

## Model 2: Functional Connectivity, PCA

### Prepare Data

### Model 2.a. Baseline Parameters

### Model 2.b. Hyperparameter Tuning

## Model 3: Betweenness Centrality

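To employ a dimensionality reduction innate to functional connectivity data, with Model 3 we summarized each ROI's role in the network with a single graph-theoretical measure, betweenness centrality, reducing the feature space from all pairwise correlations to one value per ROI. In a graph built from the thresholded connectivity matrix, ROIs are "nodes" and suprathreshold connections are "edges". Betweenness centrality quantifies how often a node lies on the shortest paths between all other pairs of nodes, i.e., how strongly an ROI acts as a bridge within the network; it is distinct from degree centrality, which is simply the number of edges a node has. Although other graph-theoretical measures such as global efficiency and clustering coefficient might also be of interest for our classification attempt, we restricted our analysis to betweenness centrality due to time constraints.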
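Formally, the betweenness centrality of a node $v$ is $c_B(v) = \sum_{s \neq v \neq t} \sigma_{st}(v) / \sigma_{st}$, where $\sigma_{st}$ is the number of shortest paths between nodes $s$ and $t$ and $\sigma_{st}(v)$ the number of those passing through $v$; `networkx` normalizes the values to lie between 0 and 1 by default. The toy graph below (purely illustrative, not ABIDE data) shows why betweenness differs from degree: node 3 has only two edges, like the corner nodes of each triangle, but it is the sole bridge between the two triangles.

```python
import networkx as nx

# Two triangles (0-1-2 and 4-5-6) joined through a single "bridge" node 3
G = nx.Graph()
G.add_edges_from([(0, 1), (1, 2), (0, 2),   # first triangle
                  (4, 5), (5, 6), (4, 6),   # second triangle
                  (2, 3), (3, 4)])          # bridge 2 - 3 - 4

degree = dict(G.degree())
betweenness = nx.betweenness_centrality(G)  # normalized by default

for node in sorted(G.nodes):
    print(node, degree[node], round(betweenness[node], 3))
```

Node 3 lies on the shortest paths of 9 of the 15 pairs formed by the other six nodes, giving a betweenness of 0.6, while nodes 0 and 1 have the same degree but a betweenness of 0.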

Because thresholding the correlations between ROIs to create a binary graph involves a somewhat arbitrary choice of threshold, we obtained a more stable estimate by averaging betweenness centrality over proportional thresholds from the 70th to the 90th percentile in steps of 2 (i.e., 11 thresholds). For this model, connectivity was computed as partial correlation.
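The averaging over proportional thresholds can be sketched independently of the class below. This is a minimal sketch under assumptions: it uses simulated timecourses and plain Pearson correlation rather than the partial correlations used here, the helper `mean_betweenness` is hypothetical, and, unlike the notebook, it computes each percentile over the off-diagonal ROI pairs only and removes self-loops, so that the diagonal of the matrix does not shift the threshold.

```python
import networkx as nx
import numpy as np


def mean_betweenness(conn, percentiles):
    """Hypothetical helper: average betweenness centrality over several
    proportional thresholds of a symmetric connectivity matrix."""
    n = conn.shape[0]
    upper = conn[np.triu_indices(n, k=1)]  # each ROI pair counted once
    per_threshold = []
    for p in percentiles:
        cutoff = np.percentile(upper, p)
        # Binarize: keep only connections at or above the cutoff, no self-loops
        adjacency = (conn >= cutoff).astype(int)
        np.fill_diagonal(adjacency, 0)
        G = nx.from_numpy_array(adjacency)
        bc = nx.betweenness_centrality(G)
        per_threshold.append([bc[i] for i in range(n)])
    # One value per ROI, averaged over all thresholds
    return np.mean(per_threshold, axis=0)


rng = np.random.default_rng(1)
conn = np.corrcoef(rng.standard_normal((120, 30)), rowvar=False)  # 30 simulated ROIs
thresholds = np.arange(70, 91, 2)  # 70, 72, ..., 90 as in the notebook
bc = mean_betweenness(conn, thresholds)
print(bc.shape)  # (30,)
```

Averaging smooths out the idiosyncrasies of any single cutoff: at the 90th percentile the graph may fragment into disconnected components, while at the 70th it is much denser.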

### Prepare Data

In [230]:
# Create class "GraphTheoryBetweenness" to calculate betweenness centrality.
# The connectivity matrix is calculated once per subject; betweenness centrality
# is then computed on binarized versions of it (one per threshold) and averaged.
class GraphTheoryBetweenness:
    '''
        Parameters:
            - thresh_percentile: int or list of ints, percentile(s) at which the matrix is
              thresholded (e.g. 80)
            - correlation_kind: {"covariance", "correlation", "partial correlation", "tangent", "precision"}
              Kind of correlation to calculate, takes the same arguments as nilearn.connectome.ConnectivityMeasure()
            - time_course_path: path to a 1D file containing the timecourses of all relevant regions
              for one specific subject
            - subject_ID: whether "FILE_ID" should be added to the dict, default = True

        Returns (via calculate_bet_centrality):
            - bet_cent: dictionary with the betweenness centrality of each ROI and the subject ID as key "FILE_ID"
    '''

    def __init__(self, thresh_percentile, correlation_kind, time_course_path, subject_ID=True):

        self.thresh_percentile = thresh_percentile
        self.correlation_kind = correlation_kind
        self.time_course_path = time_course_path
        self.subject_ID = subject_ID

    def corr_matrix(self):
        '''
        Calculates a connectivity matrix from a 1D timeseries file.

            Returns:
                - correlation_matrix: connectivity matrix as a 2D array (ROIs x ROIs)
        '''

        # Read in time courses as a dataframe (rows: timepoints, columns: ROIs)
        timecourse_df = pd.read_csv(self.time_course_path, sep="\t")

        # Instantiate nilearn connectivity measure object
        correlation_measure = connectome.ConnectivityMeasure(kind=self.correlation_kind)

        # Calculate the connectivity matrix of the requested kind for this subject
        correlation_matrix = correlation_measure.fit_transform([timecourse_df.to_numpy()])[0]

        return correlation_matrix

    def calculate_bet_centrality(self):
        '''
        Calculates the betweenness centrality of each ROI, averaged over all thresholds.

            Returns:
                - bet_cent: dictionary containing the betweenness centrality value for each ROI
        '''

        # Compute the connectivity matrix once instead of once per threshold
        matrix = self.corr_matrix()

        # Treat a single threshold as a list with one element
        percentiles = self.thresh_percentile
        if not isinstance(percentiles, (list, tuple, np.ndarray)):
            percentiles = [percentiles]

        thresh_list = []
        for thresh_perc in percentiles:

            # Define proportional threshold
            thresh = np.nanpercentile(matrix, thresh_perc)

            # Set edges to 1 or 0 depending on threshold
            # (a single comparison stays correct even if the threshold is <= 0)
            matrix_thresh = (matrix >= thresh).astype(int)

            # Create graph and calculate betweenness centrality
            G = nx.from_numpy_array(matrix_thresh)
            thresh_list.append(nx.betweenness_centrality(G))

        # Average over thresholds (with a single threshold this is the value itself)
        bet_cent = pd.DataFrame(thresh_list).mean(axis=0).to_dict()

        if self.subject_ID:
            # Get "FILE_ID" from the file name (works with both / and \ separators)
            file_name = self.time_course_path.replace("\\", "/").split("/")[-1]
            bet_cent["FILE_ID"] = file_name.split("_rois")[0]

        return bet_cent


In [None]:
# Here, we calculate betweenness centrality for all ROIs for all subjects

# Define thresholds we want to use
thresh_steps = np.arange(70, 91, 2).tolist()

# Create empty list
subj_dict_list = []

# Add dict for each subject to list
for path in subj_paths:
    graph_measure = GraphTheoryBetweenness(thresh_steps, "partial correlation", path) 
    subj_dict_list.append(graph_measure.calculate_bet_centrality())

# From list of dicts, create dataframe
graph_df = pd.DataFrame(subj_dict_list)

In [None]:
# This is our betweenness centrality dataframe
graph_df

In [203]:
graph_df.to_csv("graph_df.csv", index = False)

### Model 3.a. Baseline Parameters

In [None]:
# The following cells were run on a server

In [207]:
# We read in the phenotype dataframe and create a new df with only diagnoses
pheno_df = pd.read_csv("Phenotypic_V1_0b_preprocessed1.csv", index_col = 0)
diagnose_df = pheno_df.loc[:,["FILE_ID", "DX_GROUP", "DSM_IV_TR"]]

# We merge on FILE_ID (the file name stem)
graph_merged_df = pd.merge(graph_df, diagnose_df, on = "FILE_ID", how = "left")
graph_merged_df = graph_merged_df.dropna()

In [208]:
X = graph_merged_df.drop(["FILE_ID", "DX_GROUP", "DSM_IV_TR"], axis=1)
y = graph_merged_df["DX_GROUP"]

# Split Data into test and train
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

In [213]:
# Initialize basic rf
rf = RandomForestClassifier(oob_score=True)

rf.fit(X_train, y_train)
rf_cross_val = cross_val_score(rf, X_train, y_train, cv=5)

In [214]:
print(f"This model has an OOB score of {rf.oob_score_}")
print(f"The mean cross validation score was {rf_cross_val.mean()}")
print(f"The standard deviation of our cross validation was {rf_cross_val.std()}")

This model has an OOB score of 0.5226537216828478
The mean cross validation score was 0.5371754523996853
The standard deviation of our cross validation was 0.03495660489716523
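The Discussion refers to accuracy on test data, whereas the cells above report only OOB and cross-validation scores on the training split. One way to obtain a held-out estimate is to fit once on the training split and score the untouched 30% test split. The sketch below uses synthetic data from `make_classification` as a stand-in for `X` and `y` (these are not ABIDE features), with labels shifted to 1 and 2 to mirror `DX_GROUP`.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

# Synthetic stand-in: 300 "subjects", 200 features, labels 1 and 2
X, y = make_classification(n_samples=300, n_features=200, n_informative=10, random_state=0)
y = y + 1

# stratify=y keeps the class proportions equal in the training and test splits
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42, stratify=y)

rf = RandomForestClassifier(oob_score=True, random_state=0).fit(X_train, y_train)
test_accuracy = accuracy_score(y_test, rf.predict(X_test))
print(f"OOB: {rf.oob_score_:.3f}, test accuracy: {test_accuracy:.3f}")
```

Stratifying the split is a design choice not made in the cells above; with roughly balanced diagnostic groups it mainly reduces split-to-split variance of the test estimate.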


### Model 3.b. Hyperparameter Tuning

## Appendix: Earlier Function-Based Implementation (to be made object-oriented)

In [60]:
# Define function which gets a dict of unique correlations for one subject
def get_corr_list(path, kind):
    
    '''
    Inputs:
    - path: path to the extracted timecourses for one subject.
    - kind: {"covariance", "correlation", "partial correlation", "tangent", "precision"}
      Kind of correlation, takes the same arguments as nilearn.connectome.ConnectivityMeasure()
      
    Outputs: unique_corrs, connectivity_matrix
    - unique_corrs: dictionary with the unique correlations between timecourses,
      keys corresponding to the correlation of two timeseries,
      e.g. key '0x4' corresponds to the correlation of regions 0 and 4.
    - connectivity_matrix: connectivity matrix of all timeseries for one subject;
      can be plotted with nilearn.plotting.plot_matrix()
    
    '''

    # read the file into a DataFrame
    df = pd.read_csv(path, sep="\t")
    df = np.array(df)
    
    # calculate corr matrix for DataFrame (all rois)
    correlation_measure = connectome.ConnectivityMeasure(
        kind=kind
    )
    
    connectivity_matrix = correlation_measure.fit_transform([df])[0]
    corrs = pd.DataFrame(connectivity_matrix)

    # create index with only unique combinations of row x column;
    # this gives us the upper triangle of the correlation matrix
    # without the diagonal
    rois_numbers = np.linspace(0, (len(corrs)-1), len(corrs), dtype = "int")
    comb_ind = [comb for comb in combinations(rois_numbers, 2)]

    # use new index to select only the unique combinations from our DataFrame
    # we save this as a dictionary with the combination of values as keys
    # i.e., the correlation of region 1 and region 4 is called "1x4"
    unique_corrs = {}
    for ind in comb_ind:
        unique_corrs[f"{ind[0]}x{ind[1]}"] = corrs.iloc[ind]

    # We add the name of the file as "ID" to the dictionary.
    # This allows us to later identify the subjects when we create our
    # DataFrame with all our subjects
    unique_corrs["ID"] = path.split("/")[-1]

    return(unique_corrs, connectivity_matrix)

The cell below takes approximately ?? minutes to run.

In [64]:
# Create DataFrame with all subjects
list_unique_corrs = []
for path in subj_paths:
    list_unique_corrs.append(get_corr_list(path, kind = "partial correlation")[0])

complete_df = pd.DataFrame(list_unique_corrs)
complete_df.to_csv("partial_corr_df.csv")


In [66]:
complete_df

| | 0x1 | 0x2 | 0x3 | 0x4 | 0x5 | 0x6 | 0x7 | 0x8 | 0x9 | 0x10 | ... | 195x197 | 195x198 | 195x199 | 196x197 | 196x198 | 196x199 | 197x198 | 197x199 | 198x199 | ID |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.007605 | -0.093767 | 0.025952 | -0.004101 | 0.029709 | -0.024893 | -0.056791 | -0.067321 | -0.047638 | 0.032802 | ... | -0.010210 | -0.002754 | 0.012114 | 0.029563 | 0.084258 | 0.021982 | 0.031459 | 0.046450 | 0.032715 | rois_cc200\CMU_a_0050649_rois_cc200.1D |
| 1 | 0.020877 | -0.055024 | 0.005723 | -0.026987 | -0.019686 | 0.032149 | -0.027109 | -0.013659 | 0.026269 | -0.023326 | ... | -0.000000 | 0.012072 | -0.023838 | -0.000000 | -0.016615 | 0.012068 | -0.000000 | 0.000000 | 0.112876 | rois_cc200\CMU_a_0050653_rois_cc200.1D |
| 2 | -0.021543 | -0.002342 | -0.077224 | 0.010351 | -0.008699 | -0.061259 | -0.036689 | -0.016293 | -0.000000 | 0.042201 | ... | -0.000000 | -0.168892 | 0.026266 | -0.000000 | 0.054605 | 0.034935 | 0.000000 | 0.000000 | -0.085090 | rois_cc200\CMU_b_0050651_rois_cc200.1D |
| 3 | -0.087209 | -0.013003 | 0.019987 | 0.007853 | 0.070366 | -0.021554 | 0.001785 | -0.080408 | 0.013837 | -0.025278 | ... | -0.008020 | 0.076996 | 0.090542 | -0.003828 | 0.040650 | -0.056931 | 0.000223 | -0.013338 | 0.026678 | rois_cc200\CMU_b_0050657_rois_cc200.1D |
| 4 | 0.011371 | 0.001493 | -0.025405 | -0.004347 | -0.024392 | -0.037107 | -0.011480 | -0.044027 | -0.004783 | -0.009429 | ... | 0.000737 | 0.022216 | 0.044547 | -0.007328 | -0.068902 | -0.062551 | -0.013311 | -0.007835 | -0.044507 | rois_cc200\CMU_b_0050669_rois_cc200.1D |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 879 | 0.013888 | -0.038484 | 0.008113 | 0.008713 | -0.028483 | -0.034524 | 0.006456 | -0.012639 | -0.022639 | 0.012105 | ... | 0.008112 | 0.042715 | -0.021487 | 0.012022 | 0.006185 | 0.008967 | -0.018410 | 0.055779 | -0.024578 | rois_cc200\Yale_0050624_rois_cc200.1D |
| 880 | 0.015597 | 0.047033 | -0.028770 | -0.025831 | 0.009842 | -0.099223 | 0.023746 | 0.005844 | 0.016276 | 0.007632 | ... | 0.000456 | -0.036027 | -0.006571 | 0.007715 | -0.017228 | 0.054851 | 0.004269 | 0.017597 | 0.067953 | rois_cc200\Yale_0050625_rois_cc200.1D |
| 881 | 0.057704 | -0.040731 | 0.036248 | -0.028012 | -0.017422 | 0.040367 | -0.013773 | -0.008708 | -0.029883 | 0.026940 | ... | -0.040245 | -0.028593 | -0.023434 | -0.007964 | -0.009148 | 0.003982 | -0.004194 | 0.004152 | 0.038866 | rois_cc200\Yale_0050626_rois_cc200.1D |
| 882 | -0.022718 | 0.013849 | 0.031890 | 0.062709 | 0.005144 | -0.019952 | -0.065440 | -0.000002 | -0.011738 | -0.001780 | ... | 0.000452 | 0.056748 | -0.067353 | -0.001106 | 0.042089 | -0.001142 | 0.000731 | -0.000946 | 0.005785 | rois_cc200\Yale_0050627_rois_cc200.1D |


## Discussion
We were able to predict ASD diagnosis with ??% accuracy in the test data. While this points to general differences in functional connectivity between people with ASD and neurotypical individuals, our results do not provide evidence for a potential diagnostic tool based on functional connectivity. However, we left several avenues unexplored, which may serve as starting points for more comprehensive analyses and comparisons aimed at both higher classification accuracies and inferences about the neural correlates of ASD. Utilizing different graph theoretical measures might ...

## References

Craddock, R. C., Benhajali, Y., Chu, C., Chouinard, F., Evans, A., Jakab, A., Khundrakpam, B. S., Lewis, J. D., Li, Q., Milham, M., Yan, C. & Bellec, P. (2013). The Neuro Bureau Preprocessing Initiative: open sharing of preprocessed neuroimaging data and derivatives. *Frontiers in Neuroinformatics, 7*. https://doi.org/10.3389/conf.fninf.2013.09.00041

Craddock, R. C., James, G., Holtzheimer, P. E., Hu, X. P. & Mayberg, H. S. (2011). A whole brain fMRI atlas generated via spatially constrained spectral clustering. *Human Brain Mapping, 33*(8), 1914–1928. https://doi.org/10.1002/hbm.21333

Ilioska, I., Oldehinkel, M., Llera, A., Chopra, S., Looden, T., Chauvin, R., Van Rooij, D., Floris, D. L., Tillmann, J., Moessnang, C., Banaschewski, T., Holt, R. J., Loth, E., Charman, T., Murphy, D. G., Ecker, C., Mennes, M., Beckmann, C. F., Fornito, A. & Buitelaar, J. K. (2023). Connectome-wide Mega-analysis Reveals Robust Patterns of Atypical Functional Connectivity in Autism. *Biological Psychiatry, 94*(1), 29–39. https://doi.org/10.1016/j.biopsych.2022.12.018

Lord, C., Brugha, T. S., Charman, T., Cusack, J., Dumas, G., Frazier, T., Jones, E. J. H., Jones, R. M., Pickles, A., State, M. W., Taylor, J. L. & Veenstra-VanderWeele, J. (2020). Autism spectrum disorder. *Nature Reviews. Disease Primers, 6*(1). https://doi.org/10.1038/s41572-019-0138-4 