In [1]:
from pathlib import Path

import pandas as pd
import nibabel as nib
from tqdm import tqdm
import numpy as np
from sklearn.model_selection import train_test_split

In [2]:
# Project root, relative to the notebook's working directory; the data/ and
# metadata/ paths used below are all built from it.
ROOT = Path("../")

In [3]:
def parse_metadata(path: Path):
    """Parse BIDS-style entities out of a dataset-relative BOLD path.

    Example input:
        Brown/sub-0026001/ses-1/func/sub-0026001_ses-1_task-rest_run-1_bold.nii.gz

    Args:
        path: Relative path whose first component is the dataset directory and
            whose file name is made of ``key-value`` entities joined by ``_``,
            followed by a trailing suffix (e.g. ``bold``) and the extension.

    Returns:
        Dict with ``site``, ``dataset``, every ``key-value`` entity found in the
        file name (values kept as strings), and ``suffix``. The ``task``, ``run``
        and ``acq`` keys are always present and are ``None`` when the file name
        does not carry them.
    """
    dataset = path.parts[0]
    site = dataset.split("_")[0]  # Peking_1 -> Peking

    # Everything before the first "." is the stem; the extension is not needed.
    stem = path.name.split(".", 1)[0]
    # The last "_"-separated token is the suffix. rpartition tolerates a name
    # without any "_" (the entity part is then empty) instead of raising.
    stem, _, suffix = stem.rpartition("_")

    # Split on the FIRST hyphen only, so a value that itself contains "-"
    # still yields a (key, value) pair rather than breaking dict().
    meta = dict(item.split("-", 1) for item in stem.split("_") if "-" in item)

    # set sometimes missing keys
    for k in ["task", "run", "acq"]:
        meta.setdefault(k, None)

    meta = {"site": site, "dataset": dataset, **meta, "suffix": suffix}
    return meta


def read_header(path: Path):
    """Read timing information from an image header.

    Args:
        path: Image file to open with nibabel (memory-mapped).

    Returns:
        Dict with ``tr`` (header ``pixdim[4]`` rounded to 2 decimals; presumably
        seconds — TODO confirm units), ``num_trs`` (size of the last axis) and
        ``dur`` (``num_trs * tr``). All three are ``None`` when the image is not
        4-dimensional.
    """
    img = nib.load(path, mmap=True)
    dims = img.shape

    # Only 4D images have a time axis to report on.
    if len(dims) != 4:
        return {"tr": None, "num_trs": None, "dur": None}

    repetition_time = round(float(img.header["pixdim"][4]), 2)
    n_volumes = dims[-1]
    return {
        "tr": repetition_time,
        "num_trs": n_volumes,
        "dur": n_volumes * repetition_time,
    }

In [4]:
# index all the completed fmriprep data
fmriprep_root = ROOT / "data/fmriprep"

# File-name tail that identifies a finished MNI-space output; defined once so the
# glob pattern and the two name substitutions below cannot drift apart.
MNI_SUFFIX = "_space-MNI152NLin6Asym_res-2_desc-preproc_bold.nii.gz"

mni_path_list = sorted(fmriprep_root.rglob(f"*{MNI_SUFFIX}"))
print(f"num fmriprep mni paths: {len(mni_path_list)}")

bids_index_path = ROOT / "metadata/ADHD200_BIDS_index.csv"

# The index is cached on disk: it is only rebuilt when the CSV is missing, so
# delete the CSV to force a re-scan after the fmriprep outputs change.
if not bids_index_path.exists():
    records = []
    for mni_path in tqdm(mni_path_list):
        # Name the raw BOLD file would have, relative to the fmriprep root;
        # used both for entity parsing and as the "path" column.
        original_path = mni_path.with_name(
            mni_path.name.replace(MNI_SUFFIX, "_bold.nii.gz")
        ).relative_to(fmriprep_root)
        meta = parse_metadata(original_path)

        info = read_header(mni_path)

        # Sibling CIFTI output expected next to the MNI volume.
        cifti_path = mni_path.with_name(
            mni_path.name.replace(MNI_SUFFIX, "_space-fsLR_den-91k_bold.dtseries.nii")
        )

        record = {
            **meta,
            **info,
            # mni_path came from the glob above, so this is true by construction;
            # kept as a column because the run filter further down reads it.
            "has_mni": mni_path.exists(),
            "has_cifti": cifti_path.exists(),
            "path": str(original_path),
        }
        records.append(record)

    bids_df = pd.DataFrame.from_records(records)
    # Make sure the output directory exists; to_csv does not create it.
    bids_index_path.parent.mkdir(parents=True, exist_ok=True)
    bids_df.to_csv(bids_index_path, index=False)

# Always reload from the CSV so a fresh build and a cached run yield the same
# dtypes; "sub" is forced to str to keep the zero-padded subject ids intact.
bids_df = pd.read_csv(bids_index_path, dtype={"sub": str})

num fmriprep mni paths: 1390


100%|██████████| 1390/1390 [00:02<00:00, 513.62it/s]


In [5]:
# sanity check: table dimensions, then a peek at the first two rows
print(bids_df.shape)
bids_df.iloc[:2]

(1390, 14)


Unnamed: 0,site,dataset,sub,ses,task,run,acq,suffix,tr,num_trs,dur,has_mni,has_cifti,path
0,Brown,Brown,26001,1,rest,1,,bold,2.0,251,502.0,True,True,Brown/sub-0026001/ses-1/func/sub-0026001_ses-1...
1,Brown,Brown,26002,1,rest,1,,bold,2.0,251,502.0,True,True,Brown/sub-0026002/ses-1/func/sub-0026002_ses-1...


In [6]:
# look at trs: number of runs per (dataset, repetition time)
tr_counts = bids_df.groupby(["dataset", "tr"])[["path"]].count()
print(tr_counts)

                 path
dataset    tr        
Brown      2.00    26
KKI        2.50    83
NYU        2.00   435
NeuroIMAGE 1.96    73
OHSU       2.50   268
Peking_1   2.00   136
Peking_2   2.00    67
Peking_3   2.00    41
Pittsburgh 1.50    88
           3.00     8
WashU      2.50   165


In [7]:
# look at durations: number of runs per (dataset, run duration)
# most longer than 5 min
dur_counts = bids_df.groupby(["dataset", "dur"])[["path"]].count()
print(dur_counts)

                   path
dataset    dur         
Brown      502.00    26
KKI        310.00    54
           380.00    29
NYU        350.00    41
           352.00   394
NeuroIMAGE 509.60     5
           511.56    68
OHSU       147.50     1
           190.00     3
           192.50    34
           195.00   230
Peking_1   470.00    51
           472.00    85
Peking_2   472.00    67
Peking_3   472.00    41
Pittsburgh 294.00    88
           369.00     7
           588.00     1
WashU      190.00    24
           330.00   112
           332.50    29


In [8]:
# include runs with complete preprocessed outputs (MNI volume and CIFTI)
# that last strictly longer than 5 min
min_dur = 5 * 60
has_both_outputs = bids_df["has_mni"] & bids_df["has_cifti"]
bids_mask = has_both_outputs & bids_df["dur"].gt(min_dur)
print(f"num valid bold runs: {bids_mask.sum()} / {len(bids_mask)}")
bids_df_clean = bids_df[bids_mask]

num valid bold runs: 1010 / 1390


In [9]:
# get phenotype data
# downloaded from s3://fcp-indi/data/Projects/ADHD200/RawDataBIDS/*_phenotypic.csv
pheno_dir = ROOT / "data/phenotypic"
pheno_paths = sorted(pheno_dir.glob("*_phenotypic.csv"))
print(f"num pheno csvs: {len(pheno_paths)}")
# exactly 9 phenotypic files are expected; stop early if the download is incomplete
assert len(pheno_paths) == 9

num pheno csvs: 9


In [10]:
# read each phenotypic csv and stack them into one table with a fresh 0..n-1 index
pheno_frames = []
for path in pheno_paths:
    pheno_frames.append(pd.read_csv(path))
pheno_df = pd.concat(pheno_frames, ignore_index=True)

print(pheno_df.shape)
# first three rows, first six columns
print(pheno_df.iloc[:3, :6])

(717, 24)
   ScanDir ID  Site  Gender    Age Handedness       DX
0     26001.0     2     1.0  16.92          1  pending
1     26002.0     2     1.0  15.68          1  pending
2     26004.0     2     0.0  14.99          1  pending


In [11]:
# distribution of the raw codes in the columns that get remapped below
for column in ("Site", "Gender", "DX"):
    print(pheno_df[column].value_counts())

Site
5    222
1    136
6    113
7     89
3     83
4     48
2     26
Name: count, dtype: int64
Gender
1.0    401
0.0    315
Name: count, dtype: int64
DX
0          430
1          154
3           94
pending     26
2           13
Name: count, dtype: int64


In [12]:
# remap integer ids to text labels
# https://fcon_1000.projects.nitrc.org/indi/adhd200/general/ADHD-200_PhenotypicKey.pdf

# site codes 1..8, listed in code order
site_map = dict(
    enumerate(
        [
            "Peking",
            "Brown",
            "KKI",
            "NeuroIMAGE",
            "NYU",
            "OHSU",
            "Pittsburgh",
            "WashU",
        ],
        start=1,
    )
)

# gender codes: 0 -> "F", 1 -> "M"
gender_map = dict(enumerate(["F", "M"]))

# merge all adhd diagnosis categories
#   0: Typically Developing Children
#   1: ADHD-Combined
#   2: ADHD-Hyperactive/Impulsive
#   3: ADHD-Inattentive
dx_map = {0: "Control", **dict.fromkeys((1, 2, 3), "ADHD")}

In [13]:
# clean up the phenotype table: keep the five columns of interest and drop
# any row that is missing one of them
pheno_columns = ["ScanDir ID", "Site", "Gender", "Age", "DX"]
pheno_df_clean = pheno_df[pheno_columns].dropna()

# drop subs whose dx is the literal string "pending"
# NOTE(review): presumably diagnoses that were never released — confirm
is_pending = pheno_df_clean["DX"] == "pending"
pheno_df_clean = pheno_df_clean[~is_pending]

# format columns: zero-padded 7-digit subject id, text labels for site/gender/dx,
# then keep only the tidy columns
pheno_df_clean = pheno_df_clean.assign(
    sub=[f"{int(subid):07d}" for subid in pheno_df_clean["ScanDir ID"]],
    site=pheno_df_clean["Site"].map(site_map),
    gender=pheno_df_clean["Gender"].map(gender_map),
    age=pheno_df_clean["Age"],
    dx=[dx_map[int(dx)] for dx in pheno_df_clean["DX"]],
)[["sub", "site", "gender", "dx", "age"]]

# remove subs that are missing image data
has_images = pheno_df_clean["sub"].isin(bids_df_clean["sub"])
pheno_df_clean = pheno_df_clean[has_images]

# drop sites with too few subjects (fewer than 10)
site_counts = pheno_df_clean["site"].value_counts()
print(site_counts)
include_sites = site_counts[site_counts >= 10].index.tolist()
pheno_df_clean = (
    pheno_df_clean[pheno_df_clean["site"].isin(include_sites)]
    .sort_values("sub")
    .reset_index(drop=True)
)

print(f"num valid subs: {len(pheno_df_clean)} / {len(pheno_df)}")
print(pheno_df_clean.shape)
print(pheno_df_clean.head())

site
NYU           214
Peking         85
KKI            83
NeuroIMAGE     48
Pittsburgh      1
Name: count, dtype: int64
num valid subs: 430 / 717
(430, 5)
       sub site gender       dx    age
0  0010001  NYU      F     ADHD  11.17
1  0010002  NYU      F     ADHD  13.24
2  0010003  NYU      F  Control   9.29
3  0010004  NYU      F  Control  13.75
4  0010005  NYU      M     ADHD  11.92


In [14]:
# splits

# Splits are assigned by row position, so start from a deterministic row order:
# sorted by subject, with a fresh 0..n-1 index and any "split" column left over
# from a previous run of this cell removed. Without this, re-running the cell
# (after the frame has been re-sorted by split below) would hand the same
# positions to different subjects and silently change the split.
split_base = (
    pheno_df_clean.drop(columns="split", errors="ignore")
    .sort_values("sub")
    .reset_index(drop=True)
)

# Create a Combined Stratification Column (gender x dx; site is not stratified)
strat_key = (
    split_base["gender"].astype(str) + "_" + split_base["dx"].astype(str)
).values

# Perform the Split (70% Train, 15% Val, 15% Test)
# First split: 70% Train, 30% Temp (which will be Val + Test)
train_ids, temp_ids = train_test_split(
    np.arange(len(split_base)), test_size=0.30, random_state=42, stratify=strat_key
)

# Second split: Divide the 30% Temp into 50/50 Val and Test
val_ids, test_ids = train_test_split(
    temp_ids, test_size=0.50, random_state=42, stratify=strat_key[temp_ids]
)

print(f"Train: {len(train_ids)} | Val: {len(val_ids)} | Test: {len(test_ids)}")

# Positional label array; every position is covered by exactly one id set.
splits = np.full(len(split_base), None, dtype=object)
splits[train_ids] = "train"
splits[val_ids] = "validation"
splits[test_ids] = "test"

# Ordered categorical so sorting (and later groupby output) follows
# train -> validation -> test rather than alphabetical order.
pheno_df_clean = split_base.assign(
    split=pd.Categorical(
        splits, categories=["train", "validation", "test"], ordered=True
    )
).sort_values(["split", "sub"])

Train: 301 | Val: 64 | Test: 65


In [15]:
def _count_of(label):
    """Return an aggregator that counts the entries equal to `label`."""
    return lambda values: (values == label).sum()


# per-split summary: total subjects plus dx and gender breakdowns; each entry is
# (display name, aggregator) for the column it is listed under
split_summary_spec = {
    "sub": [("Total", "count")],
    "dx": [("ADHD", _count_of("ADHD")), ("Control", _count_of("Control"))],
    "gender": [("Female", _count_of("F")), ("Male", _count_of("M"))],
}

combined_counts = (
    pheno_df_clean.groupby(["split"], observed=False)
    .agg(split_summary_spec)
    # drop the source-column level so only the display names remain
    .droplevel(0, axis=1)
    .reset_index()
)
print(combined_counts.to_markdown(index=False))

| split      |   Total |   ADHD |   Control |   Female |   Male |
|:-----------|--------:|-------:|----------:|---------:|-------:|
| train      |     301 |    131 |       170 |      128 |    173 |
| validation |      64 |     28 |        36 |       27 |     37 |
| test       |      65 |     28 |        37 |       28 |     37 |


In [16]:
# attach the image index rows to each subject's phenotype row
merged_df = pd.merge(pheno_df_clean, bids_df_clean, how="inner", on=["sub", "site"])

# only keep one run per sub (the first one in merge order)
merged_df = merged_df.drop_duplicates(subset="sub", keep="first")
# every curated subject must have survived the merge
assert len(merged_df) == len(pheno_df_clean)

print(merged_df.shape)
print(merged_df.head())

(430, 18)
       sub site gender       dx    age  split dataset  ses  task  run  acq  \
0  0010001  NYU      F     ADHD  11.17  train     NYU    1  rest    1  NaN   
2  0010002  NYU      F     ADHD  13.24  train     NYU    1  rest    1  NaN   
4  0010003  NYU      F  Control   9.29  train     NYU    1  rest    1  NaN   
5  0010006  NYU      F  Control  11.18  train     NYU    1  rest    1  NaN   
6  0010007  NYU      F     ADHD  11.41  train     NYU    1  rest    1  NaN   

  suffix   tr  num_trs    dur  has_mni  has_cifti  \
0   bold  2.0      176  352.0     True       True   
2   bold  2.0      176  352.0     True       True   
4   bold  2.0      176  352.0     True       True   
5   bold  2.0      176  352.0     True       True   
6   bold  2.0      176  352.0     True       True   

                                                path  
0  NYU/sub-0010001/ses-1/func/sub-0010001_ses-1_t...  
2  NYU/sub-0010002/ses-1/func/sub-0010002_ses-1_t...  
4  NYU/sub-0010003/ses-1/func/sub-001

In [17]:
# save the curated table without the row index
curated_csv_path = ROOT / "metadata/ADHD200_curated.csv"
merged_df.to_csv(curated_csv_path, index=False)