In [None]:
from matplotlib import pyplot as plt
import matplotlib.pyplot as plt
from nilearn.signal import clean
from numpy import pi
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.cross_decomposition import PLSCanonical
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.decomposition import PCA
from sklearn.utils import resample

In [None]:
# Initialize: data location plus the preprocessing objects shared by later cells.
PATH = '...'  # Path to the data
label_encoder = LabelEncoder()  # turns category names into integer codes (re-fit where used)
scaler = StandardScaler()  # column-wise z-scoring (re-fit on each matrix where used)

Define functions

In [None]:
def map_divergent(X, cmap, n_colors=100):
    """Map each value in ``X`` to a colour sampled from the seaborn palette ``cmap``.

    Values are min-max scaled onto ``n_colors`` evenly spaced palette entries:
    the smallest value gets the first colour and the largest the last. The
    scaling is not centred on zero, so 0 only lands on the middle of a
    divergent palette when ``X`` is symmetric around it.

    Parameters
    ----------
    X : array-like of numbers
        1-D values to colour (lists are accepted as well as arrays).
    cmap : str or palette spec
        Anything ``sns.color_palette`` accepts as its first argument.
    n_colors : int, default 100
        Number of palette entries to sample.

    Returns
    -------
    list
        One palette colour per element of ``X``, in the same order. An empty
        input gives an empty list; if all values are equal, every element gets
        the middle palette entry.
    """
    colors_div = sns.color_palette(cmap, n_colors=n_colors)
    values = np.asarray(X, dtype=float)
    if values.size == 0:
        return []
    x_min = np.min(values)
    x_range = np.max(values) - x_min
    if x_range == 0:
        # All values are identical: min-max scaling would divide by zero and give
        # NaN indices, so use the middle palette entry for every element.
        idx = np.full(values.shape, n_colors // 2)
    else:
        # astype(int) truncates, so the maximum maps exactly to index n_colors - 1.
        idx = ((values - x_min) / x_range * (n_colors - 1)).astype(int)
    return [colors_div[k] for k in idx]


In [45]:
class SimpleGroupedColorFunc:
    """Callable that returns a fixed colour for each known word.

    Parameters
    ----------
    color_to_words : dict
        Maps a colour to an iterable of words that should be drawn in it. When a
        word appears under several colours, the colour iterated last wins.
    default_color
        Colour returned for any word not listed in ``color_to_words``.
    """

    def __init__(self, color_to_words, default_color):
        lookup = {}
        for color, words in color_to_words.items():
            lookup.update(dict.fromkeys(words, color))
        self.word_to_color = lookup

        self.default_color = default_color

    def __call__(self, word, **kwargs):
        """Return the colour for ``word``; extra keyword arguments are ignored."""
        if word in self.word_to_color:
            return self.word_to_color[word]
        return self.default_color


In [None]:
def manhattan_plot(df, colors, mask=None, ylabel=None,
                   ylim=None, title=None, export_figures=False,
                   fig_name=None):
    """Manhattan-style scatter of per-phenotype coefficients grouped by category.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns ``index`` (x position), ``beta`` (y value),
        ``phen_cat`` (category name) and ``phen_cat_num`` (integer category
        code used to index ``colors``).
    colors : list
        One colour per category code; ``colors[k]`` colours category ``k``.
    mask : optional
        Accepted for backward compatibility; not used.
    ylabel : str, optional
        Y-axis label; defaults to ``'PLS coefficient'``.
    ylim : (float, float), optional
        Y-axis limits.
    title : str, optional
        Axes title.
    export_figures : bool, default False
        When True, save the figure to ``fig_name`` before showing it.
    fig_name : str or path-like, optional
        Output path for ``export_figures``.

    Raises
    ------
    ValueError
        If ``export_figures`` is True but no ``fig_name`` is given.
    """
    if export_figures and not fig_name:
        raise ValueError("fig_name is required when export_figures=True")

    sns.set_theme(style="ticks", palette="bright", font_scale=1.2,
                  font='Helvetica Neue', rc={"axes.spines.right": False,
                                             "axes.spines.top": False,
                                             "axes.spines.left": False})
    fig, ax = plt.subplots(1, figsize=(12, 6))
    sns.scatterplot(data=df, x='index', y='beta',
                    hue='phen_cat_num', palette=colors,
                    linewidth=0.1, edgecolor='w', legend=False, ax=ax)

    # Per-category summaries, all indexed by the same sorted category codes so that
    # tick positions, tick labels and tick colours stay aligned whatever the row order.
    by_cat = df.groupby('phen_cat_num')
    centers = by_cat['index'].median()  # tick position: middle of each category
    ends = by_cat['index'].max()  # dashed separator after the last point of each category
    names = by_cat['phen_cat'].first()  # category name for each code

    ax.set_ylabel(ylabel if ylabel else 'PLS coefficient')
    ax.set_xlabel('')
    ax.set_xticks(centers.values)
    tick_labels = ax.set_xticklabels(names.values, rotation=60, ha='right', fontsize=14)
    # set_xticklabels returns the labels in tick order, i.e. in centers.index order.
    for tick_label, cat_num in zip(tick_labels, centers.index):
        tick_label.set_color(colors[cat_num])
    if ylim:
        ax.set_ylim(ylim)
    if title:
        ax.set_title(title)
    ax.tick_params(axis='x', bottom=False)
    for x_end in ends.values:
        ax.axvline(x=x_end, color='grey', linestyle='--', linewidth=0.5)
    ax.set_xlim(left=-20, right=len(df.index) + 20)
    ax.axhline(y=0, linewidth=2, color='k')
    if export_figures:
        # Save before show(): some backends release the figure once it is shown.
        fig.savefig(fig_name, bbox_inches='tight')
    plt.show()

In [None]:
def radar_barplot(data=None, title="", axis_range=None, label=None, color=None, width=0.25):
    """Draw one bar per category around a polar (radar) axis and show the figure.

    Parameters
    ----------
    data : array-like
        One value per category (1-D, or a single row). Bars show absolute values.
    title : str, optional
        Figure title; omitted when empty.
    axis_range : (float, float), optional
        Radial limits. Defaults to the min and max of the plotted absolute values.
    label : sequence of str, optional
        Category names, one per value, placed around the circle.
    color : colour or sequence of colours, optional
        Passed to ``Axes.bar``.
    width : float, default 0.25
        Angular width of each bar in radians.
    """
    # Bars are drawn with absolute values, so everything below (including the default
    # radial range) is derived from them rather than from the signed input.
    values = np.abs(np.ravel(np.asarray(data, dtype=float)))
    n_cat = len(values)
    # Evenly spaced angles, one per category, starting at 0 rad.
    angles = [k / float(n_cat) * 2 * pi for k in range(n_cat)]

    plt.figure(figsize=(16, 8))
    ax = plt.subplot(111, polar=True)
    ax.spines["polar"].set_visible(False)
    # One angular tick per category, labelled with its name.
    plt.xticks(angles, label, color="black", size=16)
    # Hand-tuned alignment for specific label slots (label 0 sits at angle 0, on the
    # right-hand side); each tweak is skipped when there are not enough labels.
    tick_labels = ax.xaxis.get_majorticklabels()
    if len(tick_labels) > 0:
        tick_labels[0].set_horizontalalignment("left")
    if len(tick_labels) > 2:
        tick_labels[2].set_verticalalignment("bottom")
    if len(tick_labels) > 3:
        tick_labels[3].set_verticalalignment("top")

    # Radial tick labels along the 0 rad direction.
    ax.set_rlabel_position(0)
    if axis_range is None:
        axis_range = (np.min(values), np.max(values))
    inc = (axis_range[1] - axis_range[0]) / 4
    # NOTE(review): the fourth tick is 0 rather than axis_range[1]; with a positive
    # lower radial limit it presumably falls outside the view — confirm intent.
    newinc = [
        axis_range[0] + inc,
        axis_range[0] + inc * 2,
        axis_range[0] + inc * 3,
        0,
    ]
    plt.yticks(newinc, [f"{elem:.2f}" for elem in newinc], color="black", size=12)
    plt.ylim(axis_range)
    if title:
        plt.title(title)
    ax.bar(angles, values, alpha=1, width=width, linewidth=1, edgecolor='k', color=color)
    ax.yaxis.zorder = 1
    plt.show()

Load data

In [None]:
# Genetic (CNV) data: keep only control rows, index them by sample ID and keep
# just the listed columns.
df_cnv = pd.read_csv(PATH + '....csv', low_memory=False)
df_cnv = df_cnv.loc[df_cnv['TYPE'] == 'CTRL']
idx_cnv = df_cnv['SampleID'].values  # sample IDs of the retained control rows
df_cnv = df_cnv.set_index('SampleID')[['TYPE', 'sum_loeuf_inv', 'n_genes', 'gene_id']]

In [None]:
# Cleaned inputs, each using its first CSV column as the row index:
# structural brain data, phenotypic data and covariates (in that order).
df_brain, df_phens, df_cov = (
    pd.read_csv(PATH + csv_name, index_col=0)
    for csv_name in ('....csv', '....csv', '....csv')
)

Subset to common subjects

In [None]:
# Keep only subjects present in all four tables.
idx_str = df_brain.index
idx_phens = df_phens.index
idx_genetic = df_cnv.index
idx_cov = df_cov.index

# sorted(): iteration order of a set of string IDs follows Python's hash
# randomisation, so without it the row order (and therefore the position-based,
# seeded bootstrap resampling later on) would change between kernel sessions.
idx_all = sorted(set(idx_str) & set(idx_phens) & set(idx_genetic) & set(idx_cov))

df_cnv = df_cnv.loc[idx_all, :]
df_brain = df_brain.loc[idx_all, :]
df_phens = df_phens.loc[idx_all, :]
df_cov = df_cov.loc[idx_all, :]

# Confound regression and PLS pair rows of these three frames by position, so they
# must line up exactly; a duplicated subject ID in any of them would break that.
assert df_brain.index.equals(df_phens.index) and df_brain.index.equals(df_cov.index), \
    'brain / phenotype / covariate rows are not aligned (duplicate subject IDs?)'

Clean data for PLS

In [None]:
# Regress age, sex, volume and scanner out of every brain feature. nilearn standardises
# the confounds (standardize_confounds=True) but leaves the cleaned signal unscaled
# (standardize=False); no detrending is applied.
# NOTE(review): 'scanner' enters as a single numeric column, i.e. a linear confound;
# if it holds more than two site codes, one-hot encoding is probably intended — confirm.
brain_confounds = df_cov[['interview_age', 'sex', 'volume', 'scanner']].values
df_brain[:] = clean(
    df_brain.values,
    confounds=brain_confounds,
    detrend=False,
    standardize=False,
    standardize_confounds=True,
    extrapolate=False,
)

In [None]:
# Regress age and sex out of every phenotype: confounds standardised by nilearn,
# cleaned signal left unscaled, no detrending.
phen_confounds = df_cov[['interview_age', 'sex']].values
df_phens[:] = clean(
    df_phens.values,
    confounds=phen_confounds,
    detrend=False,
    standardize=False,
    standardize_confounds=True,
    extrapolate=False,
)

In [None]:
# Z-score every column. The single StandardScaler is re-fit on each matrix in turn,
# so after this cell `scaler` holds the brain-data statistics.
for frame in (df_phens, df_brain):
    frame[:] = scaler.fit_transform(frame.values)

# Float copies used as the model inputs.
x_ctrl_ss = df_brain.values.astype('float')
y_ctrl_ss = df_phens.values.astype('float')

In [None]:
# Reduce each z-scored matrix to its leading principal components before PLS.
# random_state pins the result whenever scikit-learn picks a randomized or ARPACK
# solver (svd_solver='auto' can do so for large inputs), keeping re-runs reproducible.
n_pcs = 50  # principal components kept for X (brain) and for Y (phenotypes)
pca_x = PCA(n_components=n_pcs, random_state=0)
pca_y = PCA(n_components=n_pcs, random_state=0)

x_ctrl_pca = pca_x.fit_transform(x_ctrl_ss)
y_ctrl_pca = pca_y.fit_transform(y_ctrl_ss)

PLS machinery

In [None]:
# Fit PLS between the PCA-reduced brain (X) and phenotype (Y) data, then express the
# solution in the original feature space.
ndim = 3  # Number of PLS components
pls = PLSCanonical(n_components=ndim, scale=False, max_iter=10000)
pls.fit(x_ctrl_pca, y_ctrl_pca)

# Map the rotations (one column per component, in PC space) back to feature space.
# PCA.components_ is (n_pcs, n_features), so components_.T @ rotations is
# (n_features, ndim). PCA.inverse_transform would also add PCA.mean_, which is right
# for data points but not for direction vectors (it is only ~0 here because the
# inputs were z-scored).
x_rotations = pca_x.components_.T @ pls.x_rotations_
y_rotations = pca_y.components_.T @ pls.y_rotations_

# Component scores for every subject.
x_score_ctrl = x_ctrl_ss @ x_rotations
y_score_ctrl = y_ctrl_ss @ y_rotations

# Loadings: Pearson correlation of each original feature with each component score,
# shaped (n_features, ndim).
x_loadings = np.array([[np.corrcoef(feature, score)[0, 1] for score in x_score_ctrl.T]
                       for feature in x_ctrl_ss.T])
y_loadings = np.array([[np.corrcoef(feature, score)[0, 1] for score in y_score_ctrl.T]
                       for feature in y_ctrl_ss.T])

Bootstrap test

In [None]:
# Bootstrap the loadings: refit PLS on resampled subjects, project the full data onto
# each bootstrap solution, and align its components with the original ones.
nboot = 1000  # Number of accepted bootstrap draws
n_candidates = ndim + 2  # bootstrap components searched when matching the original ndim
min_match_corr = 0.5  # minimum |r| between matched original and bootstrap x-loadings
max_draws = 100 * nboot  # cap on attempts so a degenerate setup cannot loop forever


def _column_corr(a, b):
    """Pearson correlation of every column of ``a`` with every column of ``b``.

    Returns an array of shape (a.shape[1], b.shape[1]) whose [j, k] entry equals
    ``np.corrcoef(a[:, j], b[:, k])[0, 1]`` up to floating-point rounding
    (NaN, with a RuntimeWarning, for a constant column).
    """
    a_c = a - a.mean(axis=0)
    b_c = b - b.mean(axis=0)
    return (a_c.T @ b_c) / np.outer(np.linalg.norm(a_c, axis=0), np.linalg.norm(b_c, axis=0))


# Only the first n_candidates components are ever compared, so fit exactly that many
# (previously ndim + 3 were fitted and the last one discarded). PLS extracts components
# one at a time with deflation, so the leading components do not depend on how many
# are requested.
pls_boot = PLSCanonical(n_components=n_candidates, scale=False, max_iter=10000)
x_loadings_boot = []
y_loadings_boot = []
i = 0  # draws attempted so far; also the resampling seed of the current draw
cct = 0  # draws rejected because a best match had |r| below min_match_corr (or NaN)
ccq = 0  # draws rejected because two original components matched the same bootstrap one
while len(x_loadings_boot) < nboot:
    if i >= max_draws:
        raise RuntimeError(
            f'only {len(x_loadings_boot)} of {nboot} bootstrap draws accepted after '
            f'{i} attempts ({ccq} ambiguous matches, {cct} weak matches)')
    i += 1
    # Resample subjects with replacement; one call draws a single index set and
    # applies it to both arrays, keeping the X and Y rows paired.
    x_boot_pca, y_boot_pca = resample(x_ctrl_pca, y_ctrl_pca, replace=True,
                                      n_samples=len(x_ctrl_pca), random_state=i)
    pls_boot.fit(x_boot_pca, y_boot_pca)

    # Rotations in feature space (components_.T projects PC-space directions back
    # without adding the PCA mean), then scores and loadings of the full,
    # non-resampled data on this bootstrap solution.
    x_rotations_boot = pca_x.components_.T @ pls_boot.x_rotations_
    y_rotations_boot = pca_y.components_.T @ pls_boot.y_rotations_
    x_score_boot = x_ctrl_ss @ x_rotations_boot
    y_score_boot = y_ctrl_ss @ y_rotations_boot
    tmp_x_loadings_boot = _column_corr(x_ctrl_ss, x_score_boot)  # (n_x_features, n_candidates)
    tmp_y_loadings_boot = _column_corr(y_ctrl_ss, y_score_boot)  # (n_y_features, n_candidates)

    # Match each original component j to the bootstrap component whose x-loadings
    # correlate most strongly with it (in absolute value), remembering the sign.
    match = _column_corr(x_loadings, tmp_x_loadings_boot)  # (ndim, n_candidates)
    order = np.argmax(np.abs(match), axis=1)
    flag = np.sign(match[np.arange(ndim), order])
    max_corr = np.abs(match[np.arange(ndim), order])
    if np.unique(order).size != ndim:  # two components matched the same bootstrap one
        ccq += 1
        continue
    if not np.all(max_corr >= min_match_corr):  # weak match; NaN also fails this test
        cct += 1
        continue
    # Reorder to the original component order and flip signs to agree with it; the
    # sign found from X is applied to Y too, so each X/Y component pair stays consistent.
    x_loadings_boot.append(tmp_x_loadings_boot[:, order] * flag)
    y_loadings_boot.append(tmp_y_loadings_boot[:, order] * flag)
x_loadings_boot = np.array(x_loadings_boot)  # (nboot, n_x_features, ndim)
y_loadings_boot = np.array(y_loadings_boot)  # (nboot, n_y_features, ndim)


In [None]:
# Bootstrap significance test: for each loading take the 2.5th and 97.5th percentiles
# over the bootstrap draws; the loading is flagged significant (1) when both bounds
# have the same non-zero sign, i.e. the 95% interval excludes zero.
n_x_feat = np.shape(x_ctrl_ss)[1]
n_y_feat = np.shape(y_ctrl_ss)[1]
x_loadings_low = np.zeros((n_x_feat, ndim))
x_loadings_high = np.zeros((n_x_feat, ndim))
y_loadings_low = np.zeros((n_y_feat, ndim))
y_loadings_high = np.zeros((n_y_feat, ndim))
x_loadings_sig = []
y_loadings_sig = []
for i in range(ndim):
    for boot, low_arr, high_arr, sig_list in (
            (x_loadings_boot, x_loadings_low, x_loadings_high, x_loadings_sig),
            (y_loadings_boot, y_loadings_low, y_loadings_high, y_loadings_sig)):
        low, high = np.percentile(boot[:, :, i], [2.5, 97.5], axis=0)
        low_arr[:, i] = low
        high_arr[:, i] = high
        sig_list.append(((np.sign(high) * np.sign(low)) > 0) * 1)

# Stack to (n_features, ndim), matching the *_low / *_high arrays.
x_loadings_sig = np.array(x_loadings_sig).T
y_loadings_sig = np.array(y_loadings_sig).T

PheWAS loadings

In [None]:
# PheWAS-style view of the phenotype (Y) loadings: a Manhattan plot and a radar
# bar plot of category means for each PLS component.
for i in range(ndim):
    # One row per phenotype holding its loading on component i, indexed 0..n_phen-1.
    df_ph = pd.DataFrame(data=y_loadings[:, i], columns=['beta'], index=range(np.shape(y_loadings)[0]))
    # NOTE(review): df_phens is the subjects x phenotypes matrix cleaned above (indexed
    # by subject), so these lookups assume it also carries phen_name / instrument_x /
    # phen_cat metadata. Assignment aligns on the index, and df_ph's 0..n-1 index will
    # generally not match subject IDs, so these columns may come out all-NaN.
    # Presumably a per-phenotype metadata table was meant here — TODO confirm.
    df_ph['phen_name'] = df_phens.phen_name
    df_ph['phen_code'] = df_phens.instrument_x
    df_ph['phen_cat'] = df_phens.phen_cat
    # Integer code per category; manhattan_plot uses it as the hue / colour index.
    df_ph['phen_cat_num'] = label_encoder.fit_transform(df_ph['phen_cat'])
    df_ph['abs beta'] = np.abs(df_ph['beta'])

    # Move the 0..n-1 row position into an 'index' column (manhattan_plot's x axis).
    df_ph.reset_index(inplace=True, drop=False)

    # Plot PheWAS
    # NOTE(review): the remaining arguments (at least `colors`) are elided with a
    # literal Ellipsis here; fill them in before running.
    manhattan_plot(df_ph,...)
    # Mean absolute loading per phenotype category
    df = df_ph.loc[:, ['phen_cat', 'abs beta']].groupby(['phen_cat']).mean()

    # Radar bar plot of the category means, radial range padded by 0.01 on both sides.
    radar_barplot(df['abs beta'].values, label=df.index,
                 axis_range=(np.min(df['abs beta'].values)-0.01,
                             np.max(df['abs beta'].values)+0.01))


Plot the highest brain loadings

In [None]:
c2 = '#73818C'  # colour of right-hemisphere bars
c1 = '#9F5358'  # colour of left-hemisphere bars
n_top = 21  # number of top-ranked regions left in view by the y limits
for i in range(ndim):
    # x_loadings rows: first half left-hemisphere regions, second half the matching
    # right-hemisphere regions (same pairing as the 'hem' column below).
    assert x_loadings.shape[0] % 2 == 0, 'expected paired left/right hemisphere rows'
    n_half = x_loadings.shape[0] // 2

    # Rank regions by the absolute value of their hemisphere-averaged loading.
    mean_reg = (x_loadings[:n_half, i] + x_loadings[n_half:, i]) / 2
    rank = np.argsort(-np.abs(mean_reg))
    order = np.tile(rank, 2)
    # x_loadings rows in ranked order: left-hemisphere rows, then right-hemisphere rows.
    row_idx = np.concatenate((rank, rank + n_half))

    # Create dataframe for plotting; 'data' now follows the same ranked order as
    # 'region' (it was previously left in raw row order, misaligned with the names).
    # NOTE(review): `order` only holds indices 0..n_half-1, so the right-hemisphere
    # rows reuse the first n_half entries of my_regions_hem; correct if those names
    # are hemisphere-agnostic, otherwise index with row_idx — TODO confirm.
    df = pd.DataFrame({'region': np.array(my_regions_hem)[order],
                       'data': x_loadings[row_idx, i],
                       'data_abs': np.abs(x_loadings[row_idx, i]),
                       'hem': np.repeat(['L', 'R'], n_half)})

    # Horizontal bars (left at y=j, right at y=j+0.4) with the 95% bootstrap interval
    # drawn as a line through each bar. The c1 bar is drawn first so legend(['Left',
    # 'Right']) labels the first two artists correctly; keep this call order.
    fig, ax = plt.subplots(1, figsize=(4, 8))
    for j, r in enumerate(rank):
        ax.barh(j, height=0.3, width=x_loadings[r, i], color=c1)
        ax.barh(j + 0.4, height=0.3, width=x_loadings[r + n_half, i], color=c2)
        ax.plot([x_loadings_low[r, i], x_loadings_high[r, i]],
                [j, j], color=c1)
        ax.plot([x_loadings_low[r + n_half, i], x_loadings_high[r + n_half, i]],
                [j + 0.4, j + 0.4], color=c2)

    ax.set_title('', fontsize=22)
    ax.set_xlabel('Brain loading (95% CI)')
    ax.set_ylabel('')
    ax.set_yticks(np.arange(n_half) + 0.2)
    ax.set_yticklabels(np.array(my_regions)[rank])
    ax.legend(['Left', 'Right'], loc='upper right')
    ax.axvline(x=0, color='black', linestyle='--', linewidth=1)
    # Inverted y axis: highest-ranked region at the top, only the top n_top visible.
    ax.set_ylim((n_top - 0.2, -0.2))
    plt.show()