# BRAIN-BRIDGE
## Let's explore individual structural differences among the Human Connectome Project dataset! How can we ~bridge~ one brain with the rest?
***

In [14]:
import pandas as pd
import numpy as np
from nilearn import datasets, plotting 
import ipywidgets as widgets
from IPython.display import display, Markdown
import seaborn as sns
import matplotlib.pyplot as plt

from sklearn.pipeline import Pipeline
from sklearn.feature_selection import SelectKBest, f_regression
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, r2_score

struct_df = pd.read_csv('data/all_structural_metrics_HCP.csv')
#display(Markdown(f'## There are {struct_df.shape[0]} subjects in the HCP dataset.'))
display(Markdown(f'## We extracted all FreeSurfer parcellation metrics for the 74 regions of the Destrieux Atlas for {struct_df.shape[0]} subjects.'))
display(Markdown('### For each region, we will examine:'))
display(Markdown('####  1. Gray Matter Volume'))
display(Markdown('####  2. Average Cortical Thickness'))
display(Markdown('####  3. Surface Area'))
display(Markdown('####  4. Mean Curvature'))

## We extracted all FreeSurfer parcellation metrics for the 74 regions of the Destrieux Atlas for 1042 subjects.

### For each region, we will examine:

####  1. Gray Matter Volume

####  2. Average Cortical Thickness

####  3. Surface Area

####  4. Mean Curvature

In [15]:
# Create individual dataframes for each measure
thick_df = struct_df.filter(regex=('ThickAvg$'))
thick_df.insert(0, 'id_number', struct_df['id_number'])
vol_df = struct_df.filter(regex='GrayVol$')
vol_df.insert(0, 'id_number', struct_df['id_number'])
area_df = struct_df.filter(regex='SurfArea$')
area_df.insert(0, 'id_number', struct_df['id_number'])
numvert_df = struct_df.filter(regex='NumVert$')
numvert_df.insert(0, 'id_number', struct_df['id_number'])
meancurv_df = struct_df.filter(regex='MeanCurv$')
meancurv_df.insert(0, 'id_number', struct_df['id_number']) 
curvind_df = struct_df.filter(regex='CurvInd$')
curvind_df.insert(0, 'id_number', struct_df['id_number']) 
foldind_df = struct_df.filter(regex='FoldInd$')
foldind_df.insert(0, 'id_number', struct_df['id_number']) 
gauscurv_df = struct_df.filter(regex='GausCurv$')
gauscurv_df.insert(0, 'id_number', struct_df['id_number'])
thickstd_df = struct_df.filter(regex='ThickStd$')
thickstd_df.insert(0, 'id_number', struct_df['id_number'])

In [16]:
out_brain = widgets.Output()
out_table = widgets.Output()
out_bar = widgets.Output()
out_plot = widgets.Output()


def zscore(id_number, dtype):
    if dtype == 'Gray Matter Volume':
        df = vol_df
    elif dtype == 'Surface Area':
        df = area_df
    elif dtype == 'Average Cortical Thickness':
        df = thick_df
    elif dtype == 'Mean Curvature':
        df = meancurv_df
    elif dtype == 'Number Of Vertices':
        df = numvert_df
    elif dtype == 'Curvature Index':
        df = curvind_df
    elif dtype == 'Folding Index':
        df = foldind_df
    elif dtype == 'Gaussian Curvature':
        df = meancurv_df
    elif dtype == 'Thickness Std Dev':
        df = thickstd_df
        
    patient_df = df[df['id_number'] == id_number].drop(columns=['id_number'])
    control_df = df[df['id_number'] != id_number].drop(columns=['id_number'])
    control_means = np.mean(control_df, axis=0)
    control_stds = np.std(control_df, axis=0)
    return (patient_df - control_means) / control_stds

def plot_brain(z_data, dtype):
    # The left and right hemisphere views will be stored here
    left_hemi_widget = None
    right_hemi_widget = None

    # 1. Determine the global min and max values across both hemispheres for consistent color scaling
    global_vmin = min(z_data.min())
    global_vmax = max(z_data.max())

    for hemi in ['left', 'right']:
        z_data_tmp = z_data.copy()

        # Fetching Destrieux atlas and fsaverage
        destrieux_atlas = datasets.fetch_atlas_surf_destrieux()
        fsaverage = datasets.fetch_surf_fsaverage()

        # Filter z_data to only include variables starting with lh_ or rh_
        if hemi == 'left':
            z_data_tmp = z_data_tmp.filter(regex='^lh_')
            z_data_tmp.columns = z_data_tmp.columns.str.replace('lh_', '')
        else:
            z_data_tmp = z_data_tmp.filter(regex='^rh_')
            z_data_tmp.columns = z_data_tmp.columns.str.replace('rh_', '')

        # remove the ending of columns after last underscore
        z_data_tmp.columns = z_data_tmp.columns.str.rsplit(pat='_', n=1).str[0]
        
        if 'Lat_Fis-post' not in z_data_tmp.columns:
            # Handle the case when 'Lat_Fis-post' doesn't exist. Raise an error.
            raise ValueError('Lat_Fis-post does not exist. I must place a Medial_wall column in the dataframe after this column to match up with the atlas labels.')

        loc = z_data_tmp.columns.get_loc('Lat_Fis-post')
        z_data_tmp.insert(loc+1, 'Medial_wall', 0)
        z_data_tmp.insert(0, 'Unknown', 0)

        # Zip the values of z_data to the labels of the Destrieux atlas
        region_values = dict(zip(destrieux_atlas['labels'], z_data_tmp.values[0]))

        # Initialize an array with zeros
        mapped_values = np.zeros_like(destrieux_atlas[f'map_{hemi}'], dtype=float)

        # Populate the mapped_values array using regi)on_values dictionary
        for label, value in region_values.items():
            region_idx = destrieux_atlas['labels'].index(label)
            mapped_values[destrieux_atlas[f'map_{hemi}'] == region_idx] = value

        fsavg = fsaverage.infl_left if hemi == 'left' else fsaverage.infl_right

        # Plotting
        
        # For right hemisphere, set the view to "Right"
        hemi_view = "right" if hemi == "right" else "left"

        view = plotting.view_surf(fsavg, mapped_values,
                                cmap='coolwarm', symmetric_cmap=True,
                                vmax=np.max([np.abs(global_vmin), np.abs(global_vmax)]))
                                #view=hemi_view)

        widget = widgets.HTML(view.get_iframe())#.decode())

        if hemi == 'left':
            left_hemi_widget = widget
        else:
            right_hemi_widget = widget

        # Display the widgets side by side  
    with out_brain:
        display(widgets.HBox([left_hemi_widget, right_hemi_widget]))

def plot_bar_for_thresholded_regions(z_data, dtype, thresh):

    prominent_regions = [col for col in z_data.columns if z_data[col].abs().mean() > thresh]

    # Sort in descending order
    prominent_regions = sorted(prominent_regions, key=lambda x: z_data[x].mean(), reverse=True)

    # set color to be red if z-score is positive, blue if negative
    # color_map = lambda z: 'red' if z > 0 else 'blue'
    fig = plt.figure(figsize=(10, 5))
    plt.bar(prominent_regions, z_data[prominent_regions].values[0], color=(0.2, 0.4, 0.6, 0.6))
    plt.ylabel('Z-Score')
    plt.title(f'Regions with Abs(Z-Score) > {thresh}')
    plt.xticks(rotation=90)
    # add a horizontal line at z=0
    plt.axhline(y=0, color='black', linestyle='--')
    plt.tight_layout()
    plt.close(fig)
    # make the table into a widget object
    # plot_widget = widgets.HTML(value=plt.gcf().get_iframe().decode())

    # with out_table:
    #     # plt.show()
    #     # display in a horizontal box with the table
    #     display(widgets.HBox([plot_widget]))
    # convert figure into a widget object

    return fig


def create_plot(id_number, dtype, region):
    if dtype == 'Gray Matter Volume':
        df = vol_df
    elif dtype == 'Surface Area':
        df = area_df
    elif dtype == 'Average Cortical Thickness':
        df = thick_df
    elif dtype == 'Mean Curvature':
        df = meancurv_df
    elif dtype == 'Number Of Vertices':
        df = numvert_df
    elif dtype == 'Curvature Index':
        df = curvind_df
    elif dtype == 'Folding Index':
        df = foldind_df
    elif dtype == 'Gaussian Curvature':
        df = meancurv_df
    elif dtype == 'Thickness Std Dev':
        df = thickstd_df

    region_data = df[region]
    
    subject_data = df[df['id_number'] == id_number][region].values[0]
    
    # initialize subplots but make the right plot wider than the left
    fig, ax = plt.subplots(1, 2, figsize=(14, 7), gridspec_kw={'width_ratios': [2, 3]})

    # Scatter plot with jitter
    sns.boxplot(y=region_data, ax=ax[0], color='lightgray', showfliers=False)
    sns.stripplot(y=region_data, jitter=0.3, size=3, ax=ax[0], alpha=0.6)
    ax[0].scatter(x=0, y=subject_data, color='red', s=50, label=f'Subject {id_number}: Val={subject_data:.2f}')
    ax[0].set_title(f'Distribution of {region}')
    ax[0].set_ylabel(dtype)
    ax[0].set_xticks([])  # Hide x-axis ticks as they are not necessary in this context.
    ax[0].set_xlabel('Subjects')
    ax[0].legend()

    # Distribution with z-score
    sns.kdeplot(region_data, ax=ax[1], shade=True)
    z_val = (subject_data - region_data.mean()) / region_data.std()
    ax[1].axvline(x=subject_data, color='r', linestyle='--', label=f'Subject {id_number}: Z={z_val:.2f}')
    ax[1].set_title(f'Z-Score Distribution for {region}')
    ax[1].set_xlabel(dtype)
    ax[1].legend()
    
    with out_plot:
        out_plot.clear_output(wait=True)
        plt.tight_layout()
        plt.show()

def create_interactive_table(id_number, z_data, dtype, thresh):
    prominent_regions = [col for col in z_data.columns if z_data[col].abs().mean() > thresh]
    # Sort in descending order
    # prominent_regions.sort(key=lambda x: z_data[x].abs().mean(), reverse=True)
    prominent_regions = sorted(prominent_regions, key=lambda x: z_data[x].mean(), reverse=True)

    region_selector = widgets.Select(options=prominent_regions, description='Region:', rows=25)
    # make region_selector wider
    region_selector.layout.width = '400px'
  
    def on_region_selected(change):
        region = change['new']
        create_plot(id_number, dtype, region)

    region_selector.observe(on_region_selected, names='value')

    barplot = plot_bar_for_thresholded_regions(z_data, dtype, thresh)

    with out_bar:
        out_bar.clear_output(wait=True)
        display(barplot)

    with out_table:
        out_table.clear_output(wait=True)
        display(widgets.HBox([region_selector, out_bar]))

    # Trigger an initial plot when the table is first created.
    create_plot(id_number, dtype, prominent_regions[0])

def submit_id(b):
    out_brain.clear_output(wait=True)
    # Fetch the current selected value from the radio buttons inside the function
    dtype = data_type.value

    # get the id number from the input cell
    id_number = id.value.strip()
    # if the id number is not a number, raise an error
    if not id_number.isdigit():
        raise ValueError('Please enter a number for the ID number.')
    
    id_number = np.int64(id_number)

    z_data = zscore(id_number, dtype)
    plot_brain(z_data, dtype)

def submit_thresh(b):
    out_table.clear_output(wait=True)
    out_bar.clear_output(wait=True)
    out_plot.clear_output(wait=True)

    # Fetch the current selected value from the radio buttons inside the function
    dtype = data_type.value

    # get the id number from the input cell
    id_number = id.value.strip()
    # if the id number is not a number, raise an error
    if not id_number.isdigit():
        raise ValueError('Please enter a number for the ID number.')
    
    id_number = np.int64(id_number)
    thresh_value = thresh.value
    thresh_value = np.float64(thresh_value)

    z_data = zscore(id_number, dtype)

    # plot_bar_for_thresholded_regions(z_data, dtype, thresh_value)
    #interactive_plot_output = create_interactive_table(id_number, z_data, dtype, thresh_value)
    #display(interactive_plot_output)
    create_interactive_table(id_number, z_data, dtype, thresh_value)

# Add an input cell to enter an id_number
id = widgets.Text(
    value='',
    placeholder='Enter ID number',
    description='ID Number:',
    disabled=False  
    # make it larger

)

# Add an input cell for the threshold
thresh = widgets.FloatText(
    value=2,
    description='Z-Score Threshold:',
    disabled=False
)

# Add radio buttons to select the data type
data_type = widgets.RadioButtons(
    options=['Gray Matter Volume', 'Surface Area', 'Average Cortical Thickness', 'Mean Curvature'],
    description='Structural Metric:',
    disabled=False,
    value='Gray Matter Volume'
)

# create a submit button for the threshold that will update the table and plot
thresh_button = widgets.Button(description="Submit")
thresh_button.on_click(submit_thresh)  # Bind the button click to your function

# Create the submit button
submit_button = widgets.Button(description="Submit")
submit_button.on_click(submit_id)  # Bind the button click to your function


id.layout.width = '300px'
thresh.layout.width = '300x'
# make more space for the text in thresh
id.style.description_width = '70px'
thresh.style.description_width = '120px'

box_layout = widgets.Layout(display='flex',
                flex_flow='column',
                align_items='flex-start',
                width='100%')

# group_id = widgets.HBox([id, submit_button], layout=box_layout)
group_thresh = widgets.HBox([thresh, thresh_button], layout=box_layout)

display(Markdown('## Enter a subject ID number and select a structural metric to plot the z-scores on the brain.'))
# Display widgets for Z-score brains
display(widgets.VBox([id, data_type, submit_button, out_brain]))

display(Markdown('## Enter a z-score threshold to plot the regions with absolute value of z-scores above the threshold.'))
# Display widgets for threshold/plots
display(widgets.VBox([group_thresh, out_table, out_plot]))

# # Display widgets for Z-score brains
# display(id)
# display(data_type)
# display(submit_button)
# display(out_brain)

# #Display widgets for threshold/plots
# display(thresh)
# display(thresh_button)
# display(out_table)
# display(out_plot)

## Enter a subject ID number and select a structural metric to plot the z-scores on the brain.

VBox(children=(Text(value='', description='ID Number:', layout=Layout(width='300px'), placeholder='Enter ID nu…

## Enter a z-score threshold to plot the regions with absolute value of z-scores above the threshold.

VBox(children=(HBox(children=(FloatText(value=2.0, description='Z-Score Threshold:', layout=Layout(width='300x…

In [17]:
# Get the data for the regression model

#import cognitive data metrics
cogdata = pd.read_csv('data/cogtotals.csv')

#rename column to match imaging data file
cogdata.rename(columns={"Subject": "id"}, inplace = True)
cogdata.head()

# delete ids without cog score - assuming missing data is represented by NaN values
missing_data_rows = cogdata[cogdata.isnull().any(axis=1)]
clean_cog_data = cogdata.dropna()
clean_cog_data.to_csv('data/clean_cog_data.csv', index=False)
clean_cog_data = pd.read_csv('data/clean_cog_data.csv')
clean_cog_data.head()

#import ids for subjects with imaging metrics
imagingdata = pd.read_csv('data/freesurfer_diffusion_HCP.csv')

#find overlap in subjects
set1 = (clean_cog_data['id']).to_list()
set2 = (imagingdata['id']).to_list()
new_list = [value for value in set1 if value in set2]
df = pd.DataFrame(new_list)

#pull cog data for those that have overlapping structural data
selected_cog_data = clean_cog_data[clean_cog_data['id'].isin(new_list)]
#pull structural data for those that have overlapping cog data
structuraldata = pd.read_csv('data/all_structural_metrics_HCP.csv')
structuraldata.rename(columns={'id_number': 'id'}, inplace=True)
selected_struc_data = structuraldata[structuraldata['id'].isin(new_list)]
selected_struc_data = selected_struc_data.dropna()

#GRAY MATTER VOLUME
grayvol_columns = selected_struc_data.filter(like='GrayVol', axis=1)
grayvol_columns.insert(0, 'id', structuraldata['id'])
grayvol_columns['id'] = grayvol_columns['id'].astype(int)
grayvol_df = grayvol_columns.sort_values(by='id')
cognitive_gray_vol_data = pd.merge(selected_cog_data, grayvol_df)

#SURFACE AREA
surf_columns = selected_struc_data.filter(like='SurfArea', axis=1)
surf_columns.insert(0, 'id', structuraldata['id'])
surf_columns['id'] = surf_columns['id'].astype(int)
surf_df = surf_columns.sort_values(by='id')
cognitive_surf_data = pd.merge(selected_cog_data, surf_df)

#THICKNESS
thick_columns = selected_struc_data.filter(like='ThickAvg', axis=1)
thick_columns.insert(0, 'id', structuraldata['id'])
thick_columns['id'] = thick_columns['id'].astype(int)
thick_df = thick_columns.sort_values(by='id')

#merge cognitive data and surface area data
cognitive_thick_data = pd.merge(selected_cog_data, thick_df)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  grayvol_columns['id'] = grayvol_columns['id'].astype(int)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  surf_columns['id'] = surf_columns['id'].astype(int)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  thick_columns['id'] = thick_columns['id'].astype(int)


## Let's use this brain data to predict cognitive performance. We will try to predict the NIH Toolbox Cognitive Functioning Battery composite score adjusted for age.

In [18]:
out_plot_regression = widgets.Output()

def run_regression(dtype):
    if dtype == 'Volume':
        metrics_df = grayvol_df
        cog_df = cognitive_gray_vol_data
    elif dtype == 'Area':
        metrics_df = surf_df
        cog_df = cognitive_surf_data
    elif dtype == 'Thickness':
        metrics_df = thick_df
        cog_df = cognitive_thick_data

    X = metrics_df.drop(columns=['id'])
    y = cog_df['CogTotalComp_AgeAdj']

    (X_train, X_test, y_train, y_test)= train_test_split(X, y, test_size = 0.3, random_state = 42)

    ## Standardize the features using StandardScaler
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)

    # Initialize a regression model (e.g., Linear Regression)
    model = LinearRegression()

    # Train the model on the training data
    model.fit(X_train_scaled, y_train)

    # Predict cognitive scores on the test data
    y_pred = model.predict(X_test_scaled)
    #plot_regression(y_test, y_pred)
    return y_test, y_pred

def plot_regression(y_test, y_pred):
    # initialize subplots but make the right plot wider than the left
    fig, ax = plt.subplots(1, 2, figsize=(14, 7), gridspec_kw={'width_ratios': [1, 1]})

    # Plot the actual vs. predicted cognitive scores
    # plt.figure(figsize=(10,6))
    ax[0].scatter(y_test, y_pred, alpha=0.6)
    ax[0].set_title('Actual vs. Predicted Cognitive Scores')
    ax[0].set_xlabel('Actual Scores')
    ax[0].set_ylabel('Predicted Scores')

    # Plot the line of best fit
    slope, intercept = np.polyfit(y_test, y_pred, 1)
    ax[0].plot(y_test, slope*y_test + intercept, color='red', linewidth=2)

    # Annotate with MSE and R2 values
    mse = mean_squared_error(y_test, y_pred)
    r2 = r2_score(y_test, y_pred)
    ax[0].annotate(f'MSE: {mse:.2f}', xy=(0.05, 0.95), xycoords='axes fraction')
    ax[0].annotate(f'R^2: {r2:.2f}', xy=(0.05, 0.90), xycoords='axes fraction')

    # Plot residuals vs. predicted values
    residuals = y_test - y_pred
    # plt.figure(figsize=(10,6))
    ax[1].scatter(y_pred, residuals, alpha=0.6)
    ax[1].axhline(y=0, color='red', linestyle='--')
    ax[1].set_title('Residuals vs. Predicted Cognitive Scores')
    ax[1].set_xlabel('Predicted Scores')
    ax[1].set_ylabel('Residuals')
    
    with out_plot_regression:
        # out_plot_regression.clear_output(wait=True)
        # plt.tight_layout()
        plt.show()
        
button = widgets.RadioButtons(
    options=['Volume', 'Area', 'Thickness'],
    disabled=False,
    value='Volume',
    description='Structural Metric:',
)

def submit_regression(b):
    out_plot_regression.clear_output(wait=True)
    dtype = button.value
    y_test, y_pred = run_regression(dtype)
    plot_regression(y_test, y_pred)

# Create the submit button
submit_button = widgets.Button(description="Submit")
# run submit_regression when radio button is clicked
submit_button.on_click(submit_regression)

display(button)
display(submit_button)
display(out_plot_regression)

RadioButtons(description='Structural Metric:', options=('Volume', 'Area', 'Thickness'), value='Volume')

Button(description='Submit', style=ButtonStyle())

Output()