In [None]:
import itertools
import os.path as op
import os

from gradec.decode import LDADecoder
from gradec.utils import _rm_medial_wall, _decoding_filter
from gradec.plot import plot_surf_maps, plot_radar, plot_cloud
from gradec.fetcher import _fetch_features, _fetch_frequencies, _fetch_classification
import nibabel as nib
import numpy as np

### Define the surface space, density, and data paths

In [None]:
# Cortical surface template and mesh density shared by the maps and the decoder
SPACE = "fsaverage"
DENSITY = "164k"

# Dataset and model the decoder is built from
DSET = "neuroquery"
MODEL = "lda"

# Input/output locations, relative to the notebook's working directory
data_dir = op.join(".", "data")
neuromaps_dir = op.join(data_dir, "neuromaps")
figures_dir = op.join(data_dir, "figures")

# Factors whose Cartesian product defines every map that gets decoded.
# NOTE: the order of these lists fixes the `fig_i` prefix of each saved figure,
# which the figure-assembly cells reproduce by looping in the same order.
tracts = ["Arc", "SLF1And2", "CST"]
regions = ["RAS", "LPI"]
smths = ["", ".smooth_1"]  # "" selects the unsmoothed file, ".smooth_1" the smoothed one
thresholds = ["0", "0.15", "0.25"]

# Human-readable labels used in figure titles
TRACTS_DICT = dict(
    Arc="Arcuate",
    CST="Corticospinal",
    SLF1And2="SLF 1 & 2",
)
REGIONS_DICT = dict(
    LPI="Left-Posterior-Inferior",
    RAS="Right-Anterior-Superior",
)
SMTHS_DICT = {"": "Unsmoothed", ".smooth_1": "Smoothed"}

### Train an LDA-based decoder on the NeuroQuery database

In [None]:
# Train the decoder on the selected dataset; p-values are not computed.
decode = LDADecoder(
    space=SPACE,
    density=DENSITY,
    calc_pvals=False,
    data_dir=data_dir,
)
decode.fit(DSET)

# Features, their frequencies and their classification for this dataset/model.
# They are used to filter the decoding results and to draw the radar/word-cloud plots.
fetch_args = (DSET, MODEL)
features = _fetch_features(*fetch_args, data_dir=data_dir)
frequencies = _fetch_frequencies(*fetch_args, data_dir=data_dir)
classification, class_lst = _fetch_classification(*fetch_args, data_dir=data_dir)

### Run the decoder on each region separately

In [None]:
sep_figures_dir = op.join(figures_dir, "separated")
os.makedirs(sep_figures_dir, exist_ok=True)

# Decode each (threshold, tract, region, smoothing) map on its own. The running
# index `fig_i` prefixes every saved figure; the figure-assembly cells rebuild
# the same index by looping in the same order.
separated_results = {}
combos = itertools.product(thresholds, tracts, regions, smths)
for fig_i, (threshold, tract, region, smth) in enumerate(combos):
    # Common stem for the results key and all figure files of this map
    label = f"{tract}_{region}{smth}_thr-{threshold}"

    # Directory holding the per-hemisphere GIFTI maps for this threshold
    regions_dir = op.join(
        data_dir,
        "white-matter-atlas_thresholds",
        f"cortexmap_binarize_smooth-surf-1_threshold-{threshold}_dilate-0",
        "cortexmap",
        "func",
    )
    gii_tail = f"{tract}_box_1mm_{region}_FiberEndpoint{smth}.func.gii"
    map_lh = op.join(regions_dir, f"lh.left{gii_tail}")
    map_rh = op.join(regions_dir, f"rh.right{gii_tail}")
    map_arr_lh, map_arr_rh = (nib.load(path).agg_data() for path in (map_lh, map_rh))

    # Strip the medial wall; the result is the single map handed to the decoder
    map_arr = _rm_medial_wall(
        map_arr_lh,
        map_arr_rh,
        space=SPACE,
        density=DENSITY,
        neuromaps_dir=neuromaps_dir,
    )

    # Correlation-based decoding, followed by feature filtering
    corrs_df = decode.transform([map_arr], method="correlation")
    filtered_df, filtered_features, filtered_frequencies = _decoding_filter(
        corrs_df,
        features,
        classification,
        freq_by_topic=frequencies,
        class_by_topic=class_lst,
    )
    filtered_df.columns = ["r"]
    separated_results[label] = filtered_df.sort_values(by="r", ascending=False)

    # Surface rendering of the map that was decoded
    plot_surf_maps(
        map_arr_lh,
        map_arr_rh,
        space=SPACE,
        density=DENSITY,
        cmap="YlOrRd",
        color_range=(0, 1),
        title=f"{TRACTS_DICT[tract]} {REGIONS_DICT[region]}\n{SMTHS_DICT[smth]}. Threshold: {threshold}",
        data_dir=data_dir,
        out_fig=op.join(sep_figures_dir, f"{fig_i}-01_{label}_surf.png"),
    )

    # Radar and word-cloud plots need a non-empty, NaN-free set of correlations.
    # One of the CST regions does not provide one, so nothing more is drawn for it.
    corrs = filtered_df["r"].to_numpy()
    if np.isnan(corrs).any() or corrs.size == 0:
        continue

    plot_radar(
        corrs,
        filtered_features,
        MODEL,
        out_fig=op.join(sep_figures_dir, f"{fig_i}-02_{label}_radar.png"),
    )
    plot_cloud(
        corrs,
        filtered_features,
        MODEL,
        frequencies=filtered_frequencies,
        out_fig=op.join(sep_figures_dir, f"{fig_i}-03_{label}_wordcloud.png"),
    )

### Run the decoder on the combined LPI + RAS regions of each tract

In [None]:
com_figures_dir = op.join(figures_dir, "combined")
os.makedirs(com_figures_dir, exist_ok=True)

# Decode, for each (threshold, tract, smoothing), the union of the tract's LPI and
# RAS endpoint maps. `fig_i` prefixes every saved figure; the figure-assembly cell
# rebuilds the same index by looping in the same order (threshold > tract > smoothing).
combined_results = {}
for fig_i, (threshold, tract, smth) in enumerate(itertools.product(thresholds, tracts, smths)):
    # Stem shared by all figure files of this combination
    file_stem = f"{tract}_LPI+RAS{smth}_thr-{threshold}"

    # Directory holding the per-hemisphere GIFTI maps for this threshold
    regions_dir = op.join(
        data_dir,
        "white-matter-atlas_thresholds",
        f"cortexmap_binarize_smooth-surf-1_threshold-{threshold}_dilate-0",
        "cortexmap",
        "func",
    )

    # Read the LPI and RAS endpoint maps of each hemisphere. The generator yields
    # in the order LPI-lh, LPI-rh, RAS-lh, RAS-rh, matching the unpacking targets.
    map_lpi_arr_lh, map_lpi_arr_rh, map_ras_arr_lh, map_ras_arr_rh = (
        nib.load(
            op.join(regions_dir, f"{hemi}{tract}_box_1mm_{region}_FiberEndpoint{smth}.func.gii")
        ).agg_data()
        for region in ("LPI", "RAS")
        for hemi in ("lh.left", "rh.right")
    )

    # Merge both regions per hemisphere. The element-wise maximum keeps vertices
    # shared by the two regions at their original value instead of adding them up.
    map_arr_lh = np.maximum(map_lpi_arr_lh, map_ras_arr_lh)
    map_arr_rh = np.maximum(map_lpi_arr_rh, map_ras_arr_rh)

    # Strip the medial wall; the result is the single map handed to the decoder
    map_arr = _rm_medial_wall(
        map_arr_lh,
        map_arr_rh,
        space=SPACE,
        density=DENSITY,
        neuromaps_dir=neuromaps_dir,
    )

    # Correlation-based decoding, followed by feature filtering
    corrs_df = decode.transform([map_arr], method="correlation")
    filtered_df, filtered_features, filtered_frequencies = _decoding_filter(
        corrs_df,
        features,
        classification,
        freq_by_topic=frequencies,
        class_by_topic=class_lst,
    )
    filtered_df.columns = ["r"]
    combined_results[f"{tract}{smth}_thr-{threshold}"] = filtered_df.sort_values(by="r", ascending=False)

    # Surface rendering of the combined map (the return value was never used)
    plot_surf_maps(
        map_arr_lh,
        map_arr_rh,
        space=SPACE,
        density=DENSITY,
        cmap="YlOrRd",
        color_range=(0, 1),
        title=f"{TRACTS_DICT[tract]} LPI+RAS\n{SMTHS_DICT[smth]}. Threshold: {threshold}",
        data_dir=data_dir,
        out_fig=op.join(com_figures_dir, f"{fig_i}-01_{file_stem}_surf.png"),
    )

    # Radar and word-cloud plots need a non-empty, NaN-free set of correlations.
    # One of the CST regions does not provide one, so nothing more is drawn for it.
    corrs = filtered_df["r"].to_numpy()
    if not np.any(np.isnan(corrs)) and corrs.size > 0:
        plot_radar(
            corrs,
            filtered_features,
            MODEL,
            out_fig=op.join(com_figures_dir, f"{fig_i}-02_{file_stem}_radar.png"),
        )
        plot_cloud(
            corrs,
            filtered_features,
            MODEL,
            frequencies=filtered_frequencies,
            out_fig=op.join(com_figures_dir, f"{fig_i}-03_{file_stem}_wordcloud.png"),
        )
    

### Assemble summary figures (one per threshold)

In [None]:
from matplotlib import pyplot as plt
from matplotlib.gridspec import GridSpec
import matplotlib.image as mpimg

In [None]:
# Figure sizes in inches (matplotlib `figsize`) for the assembled panels.
# NOTE: "hight" (sic) is kept because the figure cells use these names.
comb_width = 25
comb_hight = 13
sep_width = 25
sep_hight = 25

# GridSpec layout of the assembled panels:
#   combined:  one row per tract, 3 images (surface, radar, cloud) per smoothing level
#   separated: one row per (tract, region) pair, 3 images per smoothing level
n_comb_rows = 3
n_comb_cols = 6
n_sep_rows = 6
n_sep_cols = 6

In [None]:
# Assemble one summary figure per threshold from the per-map images of the
# "separated" decoding run. Rows are (tract, region) pairs. Each smoothing level
# contributes three columns: surface map, radar plot, word cloud.
# `fig_i` must advance in the same order as the decoding loop that wrote the
# files (threshold > tract > region > smoothing).
fig_i = 0
for threshold in thresholds:
    fig = plt.figure(figsize=(sep_width, sep_hight))
    fig.subplots_adjust(wspace=0.1, hspace=0.1)
    gs = GridSpec(nrows=n_sep_rows, ncols=n_sep_cols, figure=fig)

    for trc_i, tract in enumerate(tracts):
        for reg_i, region in enumerate(regions):
            row = trc_i * len(regions) + reg_i
            for smth_i, smth in enumerate(smths):
                stem = f"{tract}_{region}{smth}_thr-{threshold}"
                img_files = [
                    op.join(sep_figures_dir, f"{fig_i}-{step}_{stem}_{kind}.png")
                    for step, kind in (("01", "surf"), ("02", "radar"), ("03", "wordcloud"))
                ]

                for img_i, img_file in enumerate(img_files):
                    ax = fig.add_subplot(gs[row, smth_i * 3 + img_i], aspect="equal")
                    # Radar/word-cloud images are not written for maps the decoding
                    # loop skipped; leave those cells blank.
                    if op.exists(img_file):
                        ax.imshow(mpimg.imread(img_file))
                    ax.set_axis_off()

                fig_i += 1

    # Save and release the figure once per threshold, after every tract is drawn.
    # Previously this ran inside the tract loop, re-rendering partial figures.
    out_file = op.join(figures_dir, f"results-separated_thr-{float(threshold):.2f}.png")
    fig.savefig(out_file, bbox_inches="tight", dpi=300)
    plt.close(fig)

In [None]:
# Assemble one summary figure per threshold from the images of the "combined"
# (LPI+RAS) decoding run. There is one row per tract, and each smoothing level
# contributes three columns: surface map, radar plot, word cloud.
# `fig_i` advances in the same order as the decoding loop that wrote the files.
fig_i = 0
for thr_i, threshold in enumerate(thresholds):
    fig = plt.figure(figsize=(comb_width, comb_hight))
    fig.subplots_adjust(wspace=0.1, hspace=0.1)
    gs = GridSpec(nrows=n_comb_rows, ncols=n_comb_cols, figure=fig)

    for trc_i, tract in enumerate(tracts):
        for smth_i, smth in enumerate(smths):
            stem = f"{tract}_LPI+RAS{smth}_thr-{threshold}"
            panel_names = (
                f"{fig_i}-01_{stem}_surf.png",
                f"{fig_i}-02_{stem}_radar.png",
                f"{fig_i}-03_{stem}_wordcloud.png",
            )

            for img_i, panel_name in enumerate(panel_names):
                ax = fig.add_subplot(gs[trc_i, smth_i * 3 + img_i], aspect="equal")
                img_file = op.join(com_figures_dir, panel_name)
                # Missing images (skipped maps) leave an empty, axis-less cell
                if op.exists(img_file):
                    ax.imshow(mpimg.imread(img_file))
                ax.set_axis_off()

            fig_i += 1

    out_file = op.join(figures_dir, f"results-combined_thr-{float(threshold):.2f}.png")
    fig.savefig(out_file, bbox_inches="tight", dpi=300)
    plt.close(fig)