In [1]:
import nimare
from nimare.extract import download_abstracts, fetch_neuroquery, fetch_neurosynth
from nimare.dataset import Dataset
from nimare.decode import gclda_decode_roi
from nilearn.image import load_img, math_img
from nimare.io import convert_neurosynth_to_dataset
from nilearn.plotting import plot_roi
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import nibabel as nib
from nimare.decode import discrete
import glob
from wordcloud import WordCloud, get_single_color_func
import os
from pprint import pprint

from collections import defaultdict
import nltk
from nltk.stem import WordNetLemmatizer

In [2]:
img_vol = nib.load('./data/network_all_035.nii.gz')
img_data = img_vol.get_fdata()
print(np.unique(img_data))
print(img_data.shape)

n_networks = img_data.shape[-1]
print(f'Number of networks: {n_networks}')

[0. 1.]
(182, 218, 182, 23)
Number of networks: 23


In [3]:
out_dir = os.path.abspath("./data/supportive_data/")

files = fetch_neurosynth(
    data_dir=out_dir,
    version="7",
    source = "abstract",
    vocab="terms",
    overwrite=False
)
# Note that the files are saved to a new folder within "out_dir" named "neurosynth".
pprint(files)
neurosynth_db = files[0]

INFO:nimare.extract.utils:Dataset found in /data/wheelock/data1/people/Chenyan/NiMare_test/data/supportive_data/neurosynth

INFO:nimare.extract.extract:Searching for any feature files matching the following criteria: [('source-abstract', 'vocab-terms', 'data-neurosynth', 'version-7')]


Downloading data-neurosynth_version-7_coordinates.tsv.gz
File exists and overwrite is False. Skipping.
Downloading data-neurosynth_version-7_metadata.tsv.gz
File exists and overwrite is False. Skipping.
Downloading data-neurosynth_version-7_vocab-terms_source-abstract_type-tfidf_features.npz
File exists and overwrite is False. Skipping.
Downloading data-neurosynth_version-7_vocab-terms_vocabulary.txt
File exists and overwrite is False. Skipping.
[{'coordinates': '/data/wheelock/data1/people/Chenyan/NiMare_test/data/supportive_data/neurosynth/data-neurosynth_version-7_coordinates.tsv.gz',
  'features': [{'features': '/data/wheelock/data1/people/Chenyan/NiMare_test/data/supportive_data/neurosynth/data-neurosynth_version-7_vocab-terms_source-abstract_type-tfidf_features.npz',
                'vocabulary': '/data/wheelock/data1/people/Chenyan/NiMare_test/data/supportive_data/neurosynth/data-neurosynth_version-7_vocab-terms_vocabulary.txt'}],
  'metadata': '/data/wheelock/data1/people/Cheny

In [4]:
neurosynth_dset = convert_neurosynth_to_dataset(
    coordinates_file=neurosynth_db["coordinates"],
    metadata_file=neurosynth_db["metadata"],
    annotations_files=neurosynth_db["features"],
)



In [5]:
def brain_network_decoding(img, neurosynth_dset, rotation_frame, network_value):
    """
    This function takes an image file, loads it, checks if it's 4D and extracts the appropriate frame,
    then applies a transformation based on a scan mask using the Neurosynth dataset.
    
    Parameters:
    img_path (str): Path to the image file (NIfTI format).
    neurosynth_dset: The Neurosynth dataset to use for decoding.
    rotation_frame (int): Frame index to use if the image is 4D (default is 0).
    scan_value (float or int): The value to use for creating the mask from the image (default is 1.5).
    
    Returns:
    pd.DataFrame: DataFrame sorted by 'r' values (association strength).
    """
    
    # Load the image using nibabel
    
    # Check the shape of the image to determine if it is 4D
    img_data = img.get_fdata()
    
    if img_data.ndim == 4:
        # If 4D, extract the frame specified by rotation_frame
        img_3d_data = img_data[:, :, :, rotation_frame]
    else:
        # If not 4D, use the original 3D data
        img_3d_data = img_data
    
    # Create a new Nifti1Image with the extracted data and original affine
    img_3d = nib.Nifti1Image(img_3d_data, img.affine)
    
    # Create a scan mask based on a dynamic condition (e.g., values equal to scan_value)
    network_mask = math_img(f"img == {network_value}", img=img_3d)
    
    # Initialize the decoder
    decoder = discrete.ROIAssociationDecoder(network_mask)
    
    # Fit the decoder with the Neurosynth dataset
    decoder.fit(neurosynth_dset)
    
    # Transform the dataset (decoder.transform does not take any parameters)
    decoder_df = decoder.transform()
    
    # Sort the resulting DataFrame by 'r' column in descending order and return the result
    return decoder_df

In [6]:
# decoder_network_df = brain_network_decoding(
#     img=img_vol,
#     neurosynth_dset=neurosynth_dset,
#     rotation_frame=0,  # Assuming we want the first frame if it's 4D
#     network_value=1.0  # The value to use for the mask
# )

# decoder_network_df.to_csv('./results/network_decoded.csv')

In [7]:
for i in range(n_networks):
    # img_network = np.squeeze(img_vol[:, :, :, i])
    decoded_network_df = brain_network_decoding(
        img=img_vol,
        neurosynth_dset=neurosynth_dset,
        rotation_frame=i,  # extract the i-th network
        network_value=1.0  # The value to use for the mask
    )
    decoded_network_df.to_csv(f'./results/23networks_035/network{i+1}_decoded.csv')

In [3]:
# def plot_top_k_wordcloud(weights, k=10):
#     """
#     Plot a word cloud of the top-k terms based on their weights.

#     Parameters:
#     - weights (dict): Dictionary where keys are terms and values are r values (float).
#     - k (int): Number of top terms to include.
#     - title (str): Title for the plot.
#     """
#     if not weights or k <= 0:
#         print("Empty input or invalid value of k.")
#         return

#     # Sort weights and select top-k
#     top_weights = dict(sorted(weights.items(), key=lambda item: item[1], reverse=True)[:k])

#     # Generate word cloud
#     wordcloud = WordCloud(background_color='white', width=800, height=400).generate_from_frequencies(top_weights)

#     # Plot
#     plt.figure(figsize=(10, 5))
#     plt.imshow(wordcloud, interpolation='bilinear')
#     plt.axis('off')
#     plt.show()


def plot_top_k_wordcloud(weights, network_idx, p_threshold, k=10, prefix_to_strip="terms_abstract_tfidf__"):
    """
    Plot a word cloud of the top-k terms based on their weights, stripping a common prefix.

    Parameters:
    - weights (dict): Dictionary where keys are terms and values are r values (float).
    - k (int): Number of top terms to include.
    - title (str): Title for the plot.
    - prefix_to_strip (str): Prefix to remove from each term.
    """
    if not weights or k <= 0:
        print("Empty input or invalid value of k.")
        return

    # # Remove prefix and sort by r
    # cleaned_weights = {
    #     key.replace(prefix_to_strip, ''): val
    #     for key, val in weights.items()
    #     if key.startswith(prefix_to_strip)
    # }

    # Select top-k
    # top_weights = dict(sorted(cleaned_weights.items(), key=lambda item: item[1], reverse=True)[:k])
    top_weights = dict(sorted(weights.items(), key=lambda item: item[1], reverse=True)[:k])

    color_func = get_single_color_func('darkblue')

    # Generate word cloud
    wordcloud = WordCloud(
        background_color='white', 
        width=800, height=400, 
        color_func=color_func
        ).generate_from_frequencies(top_weights)
    
    # Generate and create output directory if needed
    output_dir = f'./results/figures_{int(p_threshold * 100)}_k{k}'
    os.makedirs(output_dir, exist_ok=True)

    # Plot
    plt.figure(figsize=(10, 5))
    plt.imshow(wordcloud, interpolation='bilinear')
    plt.axis('off')
    plt.savefig(f'./results/figures_{int(p_threshold*100)}_k{k}/wordcloud_network{network_idx}.png', format='png', dpi=300, bbox_inches='tight')
    # plt.show()
    plt.close()



In [4]:
# network1_df = pd.read_csv('./results/23networks_035/network1_decoded.csv')
# weights1 = dict(zip(network1_df['feature'], network1_df['r']))

# plot_top_k_wordcloud(weights1, k=25, prefix_to_strip="terms_abstract_tfidf__")

In [5]:
# network2_df = pd.read_csv('./results/23networks_035/network2_decoded.csv')
# weights2 = dict(zip(network2_df['feature'], network2_df['r']))

# plot_top_k_wordcloud(weights2, k=25, prefix_to_strip="terms_abstract_tfidf__")

In [74]:
redundant_suffixes = {
    'network', 'networks', 'cortex', 'area', 'areas', 'region', 'regions',
    'system', 'systems', 'lobe', 'lobes', 'activation', 'activations',
    'brain', 'functional', 'related', 'associated', 'associated_with',
    'healthy', 'task', 'tasks', 'connectivity', 'mode'
}

stop_terms = {
    'network', 'networks', 'cortex', 'region', 'system', 'brain',
    'activation', 'task', 'mode', 'connectivity'
}

nltk.download('wordnet', quiet=True)
nltk.download('omw-1.4', quiet=True)

True

In [75]:
def clean_and_merge_terms(weights, prefix_to_strip, redundant_suffixes, stop_terms):
    lemmatizer = WordNetLemmatizer()
    cleaned = defaultdict(float)

    for term, val in weights.items():
        # Strip prefix and normalize spaces
        if not term.startswith(prefix_to_strip):
            continue
        cleaned_term = term.replace(prefix_to_strip, '')

        # Split into words
        words = cleaned_term.split(' ')
        # Lemmatize each word (singularize, normalize tense)
        lemmatized = [lemmatizer.lemmatize(w.lower()) for w in words]

        # Remove redundant suffixes like 'network', 'networks', 'cortex'
        lemmatized = [w for w in lemmatized if w not in redundant_suffixes]

        if not lemmatized:
            continue

        merged_term = ' '.join(lemmatized)

        # Skip generic one-word terms like 'network'
        if merged_term in stop_terms:
            continue

        cleaned[merged_term] = max(val, cleaned[merged_term])

    return dict(cleaned)


# def suppress_subcomponents(merged_weights):
#     final_weights = {}
#     multiword_phrases = {k for k in merged_weights if ' ' in k}

#     all_suppressed = set()
#     for phrase in multiword_phrases:
#         components = phrase.split(' ')
#         all_suppressed.update(components)

#     for term, val in merged_weights.items():
#         # If it's a single word that also appears in a multiword phrase, drop it
#         if term in all_suppressed and term not in multiword_phrases:
#             continue
#         final_weights[term] = val

#     return final_weights

def suppress_lower_components(term_scores):
    """
    Suppress single-word terms that are part of a multi-word phrase
    if the multi-word phrase has a higher score.

    Parameters:
    - term_scores (dict): Dictionary of terms and their scores.

    Returns:
    - dict: Filtered term_scores with redundant components removed.
    """
    phrases = {term for term in term_scores if ' ' in term}
    suppressed = set()

    for phrase in phrases:
        phrase_score = term_scores[phrase]
        parts = phrase.split(' ')

        for part in parts:
            # Only suppress if the part exists and has a strictly lower score
            if part in term_scores and term_scores[part] < phrase_score:
                suppressed.add(part)

    # Filter out suppressed terms
    filtered_scores = {
        term: score
        for term, score in term_scores.items()
        if term not in suppressed
    }

    return filtered_scores


def suppress_bidirectional_overlaps(term_scores, top_k=None):
    """
    Suppress overlapping terms based on relative scores, keeping top-k protected.

    - If a phrase is stronger: suppress its parts (unless protected)
    - If parts are stronger: suppress the phrase (unless protected)

    Parameters:
    - term_scores (dict): term -> score
    - top_k (int or None): number of top scoring terms to protect

    Returns:
    - dict: filtered term scores
    """
    safeguard_terms = set()
    if top_k is not None:
        safeguard_terms = {
            term for term, _ in sorted(term_scores.items(), key=lambda x: x[1], reverse=True)[:top_k]
        }

    suppress_set = set()
    phrases = {t for t in term_scores if ' ' in t}

    for phrase in phrases:
        phrase_score = term_scores[phrase]
        parts = phrase.split(' ')

        # If phrase is stronger than all parts → suppress parts (unless protected)
        if all(
            part in term_scores and
            term_scores[part] < phrase_score and
            part not in safeguard_terms
            for part in parts
        ):
            suppress_set.update(parts)

        # If any parts exist and are each stronger than the phrase → suppress phrase
        elif any(
            part in term_scores and
            term_scores[part] > phrase_score
            for part in parts
        ):
            if phrase not in safeguard_terms:
                suppress_set.add(phrase)

    return {
        term: score
        for term, score in term_scores.items()
        if term not in suppress_set
    }

In [77]:
for kvals in [10, 15, 25]:
    for i in range(n_networks):

        print(f'Processing network {i+1}')
        network_p35_df = pd.read_csv(f'./results/23networks_035/network{i+1}_decoded.csv')
        network_p50_df = pd.read_csv(f'./results/23networks_05/network{i+1}_decoded.csv')

        weights35 = dict(zip(network_p35_df['feature'], network_p35_df['r']))
        weights50 = dict(zip(network_p50_df['feature'], network_p50_df['r']))

        merged_weights35 = clean_and_merge_terms(weights35, "terms_abstract_tfidf__", redundant_suffixes, stop_terms)
        merged_weights50 = clean_and_merge_terms(weights50, "terms_abstract_tfidf__", redundant_suffixes, stop_terms)

        # merged_weights35 = suppress_bidirectional_overlaps(merged_weights35)
        # merged_weights50 = suppress_bidirectional_overlaps(merged_weights50)

        plot_top_k_wordcloud(merged_weights35, i+1, 0.35, k=kvals)
        plot_top_k_wordcloud(merged_weights50, i+1, 0.5, k=kvals)

Processing network 1
Processing network 2
Processing network 3
Processing network 4
Processing network 5
Processing network 6
Processing network 7
Processing network 8
Processing network 9
Processing network 10
Processing network 11
Processing network 12
Processing network 13
Processing network 14
Processing network 15
Processing network 16
Processing network 17
Processing network 18
Processing network 19
Processing network 20
Processing network 21
Processing network 22
Processing network 23
Processing network 1
Processing network 2
Processing network 3
Processing network 4
Processing network 5
Processing network 6
Processing network 7
Processing network 8
Processing network 9
Processing network 10
Processing network 11
Processing network 12
Processing network 13
Processing network 14
Processing network 15
Processing network 16
Processing network 17
Processing network 18
Processing network 19
Processing network 20
Processing network 21
Processing network 22
Processing network 23
Proces

In [None]:
for i in range(n_networks):

    print(f'Processing network {i+1}')
    network_p35_df = pd.read_csv(f'./results/23networks_035/network{i+1}_decoded.csv')
    network_p50_df = pd.read_csv(f'./results/23networks_05/network{i+1}_decoded.csv')

    weights35 = dict(zip(network_p35_df['feature'], network_p35_df['r']))
    weights50 = dict(zip(network_p50_df['feature'], network_p50_df['r']))

    merged_weights35 = clean_and_merge_terms(weights35, "terms_abstract_tfidf__", redundant_suffixes, stop_terms)
    merged_weights50 = clean_and_merge_terms(weights50, "terms_abstract_tfidf__", redundant_suffixes, stop_terms)

    # cleaned_weights35 = suppress_subcomponents(merged_weights35)
    # cleaned_weights50 = suppress_subcomponents(merged_weights50)

    plot_top_k_wordcloud(merged_weights35, i+1, 0.35, k=25)
    plot_top_k_wordcloud(merged_weights50, i+1, 0.5, k=25)

In [41]:
for i in range(n_networks):

    print(f'Processing network {i+1}')
    network_p35_df = pd.read_csv(f'./results/23networks_035/network{i+1}_decoded.csv')
    network_p50_df = pd.read_csv(f'./results/23networks_05/network{i+1}_decoded.csv')

    weights35 = dict(zip(network_p35_df['feature'], network_p35_df['r']))
    weights50 = dict(zip(network_p50_df['feature'], network_p50_df['r']))

    plot_top_k_wordcloud(weights35, i+1, 0.35, k=15)
    plot_top_k_wordcloud(weights50, i+1, 0.5, k=15)

Processing network 1
Processing network 2
Processing network 3
Processing network 4
Processing network 5
Processing network 6
Processing network 7
Processing network 8
Processing network 9
Processing network 10
Processing network 11
Processing network 12
Processing network 13
Processing network 14
Processing network 15
Processing network 16
Processing network 17
Processing network 18
Processing network 19
Processing network 20
Processing network 21
Processing network 22
Processing network 23
