# Code for the functional neuroimaging schizophrenia classification case study

## Load the necessary Python and R packages

In [1]:
import pandas as pd
import numpy as np

from sklearn import svm
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import RepeatedStratifiedKFold, cross_validate
from sklearn.metrics import balanced_accuracy_score,make_scorer

%load_ext rpy2.ipython

In [2]:
%%R 

# R packages used by the R cells of this notebook. Attach order matters:
# packages attached later mask same-named functions from earlier ones
# (e.g. lubridate::stamp() masks cowplot::stamp(), per the conflict report below).
# correctR: corrected resampled t-test (resampled_ttest)
library(correctR)
# cowplot: ggplot2 theme used as the default below
library(cowplot)
# ggpubr: geom_bracket() for significance annotations on plots
library(ggpubr)
# tidyverse: dplyr/tidyr for data wrangling, ggplot2 for plotting
library(tidyverse)

# Set cowplot theme as the default for all subsequent ggplot2 plots
theme_set(theme_cowplot())

── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
✔ dplyr     1.1.4     ✔ readr     2.1.5
✔ forcats   1.0.0     ✔ stringr   1.5.1
✔ lubridate 1.9.3     ✔ tibble    3.2.1
✔ purrr     1.0.2     ✔ tidyr     1.3.1
── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
✖ dplyr::filter()    masks stats::filter()
✖ dplyr::lag()       masks stats::lag()
✖ lubridate::stamp() masks cowplot::stamp()
ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors


This is version 0.2.1 of correctR. All functions now use two-tailed hypothesis tests by default instead of one-tailed.
One-tailed tests can be manually specified through the new 'tailed' and 'greater' arguments.
Please consult the help files and vignette for more information.
Loading required package: ggplot2

Attaching package: ‘ggpubr’

The following object is masked from ‘package:cowplot’:

    get_legend



## Load the FTM and catch22 time-series features for each participant:

In [44]:
# Participant-level metadata, joined onto the feature table below
sample_metadata = pd.read_csv("../fMRI_analysis_data/fMRI_analysis_metadata.csv")

# Long-format FTM + catch22 feature values annotated with participant metadata
# (DataFrame.merge with no `on=` joins on every column name the tables share)
SCZ_FTM_plus_catch22_features = (
    pd.read_csv("../fMRI_analysis_data/fMRI_time_series_features.csv")
    .merge(sample_metadata)
)

# FTM subset: keep only the DN_Mean and DN_Spread_Std feature rows and tag them
SCZ_FTM_features = (
    SCZ_FTM_plus_catch22_features
    .loc[lambda df: df["names"].isin(["DN_Mean", "DN_Spread_Std"])]
    .assign(feature_set="FTM")
)

## Define our classification pipeline and input datasets:

In [45]:
# Classification pipeline: z-score every feature, then fit a linear SVM.
# Because the scaler lives inside the pipeline, cross_validate re-fits it on each
# training fold, so test-fold data never influence the normalisation.
pipe = Pipeline([('scaler', StandardScaler()),  # Z-score normalisation
                 ('model', svm.SVC(kernel="linear", C=1, class_weight="balanced"))])  # Linear SVM, C=1, balanced class weights

# 10-repeat 10-fold stratified cross-validation. With an integer random_state,
# every call to cv.split() yields the same splits, so both feature sets are
# evaluated on identical train/test partitions (needed for the paired comparison).
cv = RepeatedStratifiedKFold(n_splits=10, n_repeats=10, random_state=127)

# Balanced-accuracy scorer. NOTE(review): the cross_validate calls in this
# notebook pass scoring="balanced_accuracy" directly; these names are kept in
# case other code depends on them.
scorers = [make_scorer(balanced_accuracy_score)]
scoring_names = ["Balanced_Accuracy"]


def _pivot_region_features(long_features):
    """Pivot long-format features to one row per (Sample_ID, Diagnosis).

    Each resulting column is one feature in one brain region, named
    "<feature name>__<brain region>"; cell values come from the `values` column.
    """
    return (long_features
            .assign(Region_Feature=lambda x: x.names + "__" + x.Brain_Region)
            .pivot(index=["Sample_ID", "Diagnosis"], columns="Region_Feature", values="values"))


# FTM-only design matrix; its row order defines the participant order used for y
X_FTM_wide = _pivot_region_features(SCZ_FTM_features)
X_index = X_FTM_wide.index.to_frame().reset_index(drop=True)
y = [int(i == "SCZ") for i in X_index.Diagnosis.tolist()]  # 1 = SCZ, 0 = any other diagnosis
X_FTM = X_FTM_wide.to_numpy()

# FTM+catch22 design matrix, explicitly reindexed to the FTM row order so each row
# is guaranteed to describe the same participant (and label) as in y, instead of
# relying on the two independent pivots happening to produce identical indexes.
X_FTM_plus_catch22 = (_pivot_region_features(SCZ_FTM_plus_catch22_features)
                      .reindex(X_FTM_wide.index)
                      .to_numpy())

## Run the linear classifier for just the FTM first:

In [54]:
# Cross-validated balanced accuracy of the linear SVM using only the FTM features.
# cross_validate returns one test score per split; RepeatedStratifiedKFold yields
# all folds of repeat 0, then all folds of repeat 1, and so on.
FTM_classification_results = pd.DataFrame({
    "Balanced_Accuracy": cross_validate(pipe, X_FTM, y,
                                        cv=cv,
                                        scoring="balanced_accuracy",
                                        n_jobs=1)["test_score"]
})

# Assign fold and repeat numbers. Folds-per-repeat is read from the CV object
# (total splits / repeats) so the labels stay correct if the CV scheme changes.
n_splits_per_repeat = cv.get_n_splits() // cv.n_repeats
FTM_classification_results["Fold_Number"] = FTM_classification_results.index % n_splits_per_repeat
FTM_classification_results["Repeat_Number"] = FTM_classification_results.index // n_splits_per_repeat
FTM_classification_results["feature_set"] = "FTM"
FTM_classification_results.head()

Unnamed: 0,Balanced_Accuracy,Fold_Number,Repeat_Number,feature_set
0,0.816667,0,0,FTM
1,0.616667,1,0,FTM
2,0.675,2,0,FTM
3,0.675,3,0,FTM
4,0.791667,4,0,FTM


## Then run for the FTM plus catch22 feature set:

In [55]:
# Cross-validated balanced accuracy of the linear SVM using FTM + catch22 features.
# Uses the same `cv` object as the FTM-only run, so the splits are identical.
FTM_plus_catch22_classification_results = pd.DataFrame({
    "Balanced_Accuracy": cross_validate(pipe, X_FTM_plus_catch22, y,
                                        cv=cv,
                                        scoring="balanced_accuracy",
                                        n_jobs=1)["test_score"]
})

# Assign fold and repeat numbers. Scores are ordered repeat by repeat, and the
# folds-per-repeat count is read from the CV object rather than hard-coded.
n_splits_per_repeat = cv.get_n_splits() // cv.n_repeats
FTM_plus_catch22_classification_results["Fold_Number"] = (
    FTM_plus_catch22_classification_results.index % n_splits_per_repeat)
FTM_plus_catch22_classification_results["Repeat_Number"] = (
    FTM_plus_catch22_classification_results.index // n_splits_per_repeat)
FTM_plus_catch22_classification_results["feature_set"] = "FTM+catch22"
FTM_plus_catch22_classification_results.head()


Unnamed: 0,Balanced_Accuracy,Fold_Number,Repeat_Number,feature_set
0,0.6,0,0,FTM+catch22
1,0.658333,1,0,FTM+catch22
2,0.658333,2,0,FTM+catch22
3,0.675,3,0,FTM+catch22
4,0.666667,4,0,FTM+catch22


In [56]:
# Stack the per-split results of both feature sets into one long table
all_classification_results = pd.concat(
    [FTM_classification_results, FTM_plus_catch22_classification_results],
    ignore_index=True)

# Average balanced accuracy over the folds within each repeat, giving one value
# per (feature_set, repeat) to match the other time-series case-study
# classification problems. Selecting the column explicitly avoids averaging the
# fold index and any other columns that may be added later.
all_classification_results = (all_classification_results
                              .groupby(["feature_set", "Repeat_Number"])["Balanced_Accuracy"]
                              .mean()
                              .reset_index())

# Save to a CSV (columns: feature_set, Repeat_Number, Balanced_Accuracy)
all_classification_results.to_csv("../fMRI_analysis_data/fMRI_classification_results.csv", index=False)

In [57]:
# Mean and SD of the repeat-level balanced accuracy for each feature set,
# expressed as percentages rounded to one decimal place.
(all_classification_results
 .groupby("feature_set")
 .agg({"Balanced_Accuracy": ["mean", "std"]})
 .mul(100)
 .round(1))

Unnamed: 0_level_0,Balanced_Accuracy,Balanced_Accuracy
Unnamed: 0_level_1,mean,std
feature_set,Unnamed: 1_level_2,Unnamed: 2_level_2
FTM,70.0,4.2
FTM+catch22,65.2,1.7


Next, we compare the performance of FTM versus FTM+catch22 on the repeat-level balanced accuracies using a two-tailed corrected resampled t-test:

In [58]:
%%R -i all_classification_results,sample_metadata -o data_for_correctR

# Two-tailed corrected resampled t-test (correctR) comparing FTM against
# FTM+catch22 on the repeat-level balanced accuracies computed in Python.

# Approximate training/test set sizes for one fold of 10-fold CV, used by the
# variance correction. NOTE(review): assumes every participant in the metadata
# was classified -- confirm against len(y) on the Python side.
num_samples <- length(unique(sample_metadata$Sample_ID))
training_size <- ceiling(0.9 * num_samples)
test_size <- floor(0.1 * num_samples)

# One row per repeat, with the two feature sets as paired columns x and y
data_for_correctR <- all_classification_results %>%
  pivot_wider(id_cols = c(Repeat_Number),
              names_from = feature_set,
              values_from = Balanced_Accuracy) %>%
  dplyr::rename("x" = "FTM", "y" = "FTM+catch22")

# n = number of paired resamples (repeats), derived from the data rather than
# hard-coded so it stays correct if the number of CV repeats changes
resampled_ttest(x = data_for_correctR$x, y = data_for_correctR$y,
                n = nrow(data_for_correctR), n1 = training_size, n2 = test_size)


  statistic    p.value
1  2.656313 0.02620366


## Visualize the repeat-wise balanced accuracy distributions for the two feature sets:

In [13]:
%%R -i all_classification_results

# Repeat-level balanced accuracy per feature set: a boxplot, one point per
# repeat, and faint lines pairing the two feature sets within each repeat.

# Format y-axis tick labels with no decimal places
scaleFUN <- function(x) sprintf("%.0f", x)

balacc_plot <- all_classification_results %>%
  mutate(Balanced_Accuracy = 100 * Balanced_Accuracy) %>%
  ggplot(data = .) +
  geom_boxplot(mapping = aes(x = feature_set,
                             y = Balanced_Accuracy,
                             color = feature_set),
               fill = NA, outlier.shape = NA) +
  geom_point(aes(x = feature_set,
                 y = Balanced_Accuracy,
                 color = feature_set), alpha = 0.7, size = 1.25) +
  geom_line(aes(x = feature_set,
                y = Balanced_Accuracy,
                group = Repeat_Number), alpha = 0.3) +
  # Significance of the corrected resampled t-test (p ~= 0.026). "*" marks
  # p < 0.05 under ggpubr's default cutpoints ("**" would mean p < 0.01).
  # Update by hand if the test result changes.
  geom_bracket(xmin = "FTM", xmax = "FTM+catch22", y.position = 75.5,
               label = "*", label.size = 9) +
  scale_color_brewer(palette = "Dark2") +
  ylab("Resample Balanced Accuracy (%)") +
  scale_y_continuous(expand = c(0.05, 0, 0.1, 0),
                     labels = scaleFUN) +
  xlab("Feature Set") +
  theme(legend.position = "none",
        axis.text = element_text(size = 14),
        axis.title = element_text(size = 16),
        plot.title = element_text(hjust = 0.5, size = 13))

# Display inline, then save the same plot object explicitly (not last_plot())
print(balacc_plot)

ggsave("../output/Schizophrenia_BalAcc_Combo_Feature_Set.png", plot = balacc_plot,
       width = 4, height = 4.25, units = "in", dpi = 300)