This notebook creates large multipanel tractography visualizations across different datasets. Keeping the plotting functions in one shared notebook is easier than copying them into each dataset-specific notebook.

Paths are currently specific to the BIDS organization of the Human Connectome Project (HCP) data.

In [7]:
import concurrent
import concurrent.futures
import itertools
import math
import os
import subprocess
import warnings

import dill
import dipy.io.streamline
import matplotlib as mpl
import matplotlib.pyplot as plt
import nibabel as nib
import pandas as pd
import seaborn as sns
import svgutils
from dipy.viz import colormap as cmap
from fury import actor, window
from IPython.display import HTML, clear_output
from joblib import Parallel, delayed
from joblib.pool import has_shareable_memory
from matplotlib.animation import FuncAnimation
from matplotlib.patches import Rectangle
from tqdm import tqdm

warnings.filterwarnings("ignore")

# Set number of processors for multiproccessing
n_proc = 12

# Define a function to handle importing database to avoid rerunning cells
def handle_session(save=False):
    if save:
        print("Saving notebook session")
        dill.dump_session(os.path.realpath("./0_figures/connectivity_viz.db"))
    else:
        print("Loading notebook session")
        dill.load_session(os.path.realpath("./0_figures/connectivity_viz.db"))


# Pandas settings
pd.set_option("display.max_rows", None)

# Restore a previously saved session if one exists. A missing database is
# expected on a fresh checkout, so report it instead of raising - raising here
# would stop "Run All" before the cells that rebuild the session. Any other
# error (e.g. a corrupt database) still propagates with its real traceback.
try:
    handle_session()
except FileNotFoundError:
    print(
        "No database found - please run all of the following cells to "
        "set up custom functions and perform analysis"
    )

Loading notebook session


In [2]:
# Plot settings
# Color friendly color cycle - gist.github.com/thriveth/8560036
cb_color_cycle = [
    "#377eb8",
    "#ff7f00",
    "#4daf4a",
    "#f781bf",
    "#a65628",
    "#984ea3",
    "#999999",
    "#e41a1c",
    "#dede00",
]

# NOTE: this rebinds `cmap`, shadowing `dipy.viz.colormap` imported as `cmap`.
# `mpl.cm.get_cmap` was deprecated in matplotlib 3.7 and removed in 3.9;
# `plt.get_cmap(...).copy()` works across versions and never mutates the
# globally registered "viridis" colormap.
cmap = plt.get_cmap("viridis").copy()
# NaN / masked values render black wherever this colormap object is passed
cmap.set_bad("black")
sns.set(
    style="ticks",
    context="poster",
    rc={
        "image.cmap": "viridis",
        "axes.prop_cycle": plt.cycler(color=cb_color_cycle),
        "font.sans-serif": "Liberation Sans",
        "font.monospace": "Liberation Sans",
        "axes.titlesize": 16,
        "axes.titleweight": "bold",
        "axes.labelsize": 14,
        "axes.labelweight": "bold",
        "xtick.labelsize": 12,
        "ytick.labelsize": 12,
        "legend.fontsize": 12,
    },
)

# Styling kwargs, presumably passed to boxplot calls elsewhere (not used in this notebook's visible cells)
boxprops = dict(edgecolor="white", alpha=0.5)
whiskerprops = dict(color="black", linestyle="--", alpha=0.5)
capprops = dict(color="black", alpha=0.5)
medianprops = dict(color="white", linewidth=2)

In [3]:
# Check intra vs inter hemispheric connectivity
def check_hemi(node: int):
    """Classify a node label by parity, returning "Even" or "Odd".

    viz_tract compares the parity of two nodes: equal parity is drawn as an
    intra-hemispheric connection, differing parity as inter-hemispheric.
    """
    if node % 2 == 0:
        return "Even"
    return "Odd"


def get_tract(subjid, dataset, node1, node2, tract_type):
    """
    Return the path to a subject's tract file between two nodes, extracting it if needed.

    If the .tck file is not already on disk, only that single file is extracted
    from the subject's `{subjid}_edge.tar` archive into its `tractography` directory.

    INPUTS:
        subjid - subject ID, e.g. "sub-103818"
        dataset - top-level directory of the dataset
        node1 - terminal node of the tract
        node2 - other terminal node of the tract
        tract_type - tractography sub-directory holding the file (e.g. "edge")

    RETURNS:
        Path of the .tck file, which exists when this function returns.

    RAISES:
        FileNotFoundError - the file is missing and could not be extracted
        (archive missing, file not in archive, or `tar` not available).
    """
    tract_dir = f"{dataset}/derivatives/mrtpipelines_0.1.6/mrtpipelines/{subjid}/tractography"
    tract_name = f"{subjid}_space-dwi_desc-from_{node1}-{node2}.tck"
    tract_path = f"{tract_dir}/{tract_type}/{tract_name}"

    if not os.path.exists(tract_path):
        archive = f"{tract_dir}/{subjid}_edge.tar"
        # Argument list (no shell): the wildcard reaches tar unexpanded, and
        # matching the full file name stops e.g. "1-14" from also extracting
        # "11-14", "21-14", ...
        try:
            result = subprocess.run(
                ["tar", "-xf", archive, "-C", tract_dir, "--wildcards", f"*{tract_name}"],
                capture_output=True,
                text=True,
            )
        except FileNotFoundError as err:  # `tar` executable not found
            raise FileNotFoundError(
                f"Tract cannot be found in dataset: {tract_path}"
            ) from err
        # os.system never raised on failure; check the outcome explicitly
        if result.returncode != 0 or not os.path.exists(tract_path):
            raise FileNotFoundError(
                f"Tract cannot be found in dataset: {tract_path} ({result.stderr.strip()})"
            )

    return tract_path


def viz_tract(
    subjid,
    dataset,
    node1,
    node2,
    interactive=False,
    opacity=0.5,
    size=(600, 600),
):
    """
    Visualize the tract(s) connecting two subcortical nodes using `fury`.

    Nodes with the same parity (see `check_hemi`) are drawn as an
    intra-hemispheric connection: the odd-labelled pair as the "left" tract
    (green) and the matching even-labelled pair (each label + 1) as the "right"
    tract (orange). Nodes with different parity are drawn as a single
    inter-hemispheric tract (teal). Terminal ROIs are drawn as grey contours.

    The anatomical reference and ROI segmentation are read from `dataset`,
    except that an `hcp_retest` directory is mapped to its sibling `hcp_test`.

    Eg. Interactive visualization of the tract(s) between nodes 13 and 69 in the retest data
    viz_tract('sub-103818', f'{hcp_dir}/hcp_retest', 13, 69, interactive=True)

    INPUTS:
        subjid - ID of subject to visualize
        dataset - top-level directory of dataset to visualize
        node1 - terminal node for tract
        node2 - other terminal node for tract
        interactive - open an interactive window instead of taking snapshots
        opacity - opacity of tractography
        size - figure window dimensions

    RETURNS:
        (coronal, sagittal, axial) images from `window.snapshot` when
        `interactive` is False, otherwise None.

    NOTE: every tract file returned by `get_tract` is deleted afterwards
    (also on error), whether or not this call extracted it.
    """
    # Anatomical reference & ROI segmentation (retest shares the test anatomy)
    dataset_base = dataset
    if os.path.basename(dataset) == "hcp_retest":
        dataset_base = os.path.join(os.path.dirname(dataset), "hcp_test")

    anat_ref = nib.load(
        f"{dataset_base}/{subjid}/anat/{subjid}_acq-procHCP_T1w.nii.gz"
    )
    anat_affine = anat_ref.affine
    subconn_roi = nib.load(
        f"{dataset_base}/derivatives/zona_bb_subcortex/{subjid}/anat/{subjid}_space-T1w_desc-ZonaBBSubCorSeg.nii.gz"
    ).get_fdata()

    tract_type = "edge"

    # Each entry: ((terminal node a, terminal node b), streamline colour)
    if check_hemi(node1) == check_hemi(node2):
        # Intra-hemispheric: odd labels form the "left" pair, odd + 1 the "right" pair
        shift = 0 if check_hemi(node1) == "Odd" else 1
        left_pair = (node1 - shift, node2 - shift)
        right_pair = (left_pair[0] + 1, left_pair[1] + 1)
        tract_specs = [(left_pair, (0, 0.6, 0)), (right_pair, (0.8, 0.5, 0))]
    else:
        # Inter-hemispheric: a single tract between the nodes as given
        tract_specs = [((node1, node2), (0, 0.7, 0.7))]

    # Setup actors & scene
    scene = window.Scene()
    scene.background([1, 1, 1])
    tract_width = 2
    roi_color = [0.2, 0.2, 0.2]
    roi_opacity = 0.2

    tract_paths = []
    try:
        roi_masks = []
        for (node_a, node_b), color in tract_specs:
            tract_path = get_tract(subjid, dataset, node_a, node_b, tract_type)
            tract_paths.append(tract_path)
            tck = dipy.io.streamline.load_tractogram(tract_path, anat_ref)
            scene.add(
                actor.line(
                    tck.streamlines,
                    color,
                    opacity=opacity,
                    linewidth=tract_width,
                    fake_tube=True,
                )
            )
            del tck
            roi_masks += [subconn_roi == node_a, subconn_roi == node_b]

        # Terminal ROI contours are added after all tract actors
        for mask in roi_masks:
            scene.add(actor.contour_from_roi(mask, anat_affine, roi_color, roi_opacity))

        if interactive:
            # Coronal anterior
            scene.yaw(90)
            scene.pitch(-90)
            window.show(scene, size=size)
            return None

        # Axial Superior
        scene.roll(180)
        scene.reset_camera()
        axial_scene = window.snapshot(scene, size=size, offscreen=True)
        scene.roll(-180)

        # Coronal Anterior
        scene.yaw(90)
        scene.pitch(-90)
        scene.reset_camera()
        coronal_scene = window.snapshot(scene, size=size, offscreen=True)
        scene.pitch(90)
        scene.yaw(-90)

        # Sagittal - NOTE(review): originally commented "Sagittal Right" but
        # titled "Sagittal-left" in create_figure; confirm which side is shown
        scene.pitch(90)
        scene.reset_camera()
        sagittal_scene = window.snapshot(scene, size=size, offscreen=True)
        scene.pitch(-90)

        return (coronal_scene, sagittal_scene, axial_scene)
    finally:
        # Remove tract files fetched via get_tract so they don't accumulate on disk
        for path in tract_paths:
            if os.path.exists(path):
                os.remove(path)

In [4]:
def create_figure(node, dataset, dataset_fname):
    """
    Render one node pair's tract(s) for every participant and save a multipanel SVG.

    Each participant in the global `participants_df` (column 0) gets a sub-figure
    holding the three snapshots returned by `viz_tract`. Panels are laid out 6
    per row; the default 7-row grid grows when there are more than 42 participants.

    INPUTS:
        node - node pair formatted "<node1>-<node2>", e.g. "1-14"
        dataset - dataset directory name under the global `hcp_dir`, e.g. "hcp_test"
        dataset_fname - label used in the output file name, e.g. "HCPTest"

    OUTPUT:
        0_figures/{dataset}/sub-{dataset_fname}_desc-{node1}-{node2}_tracts.svg
    """
    # Grab nodes
    print(f"Node: {node}")
    node1, node2 = (int(label) for label in node.split("-"))

    subjects = participants_df[0].tolist()
    n_cols = 6
    n_rows = max(7, math.ceil(len(subjects) / n_cols))

    # Create multipanel figure (12 x 7 inches for the default 7-row grid)
    fig = plt.figure(constrained_layout=True, figsize=(12, n_rows))
    subfigs = fig.subfigures(n_rows, n_cols, wspace=0.25, hspace=0.05)

    try:
        # Create individual subject panels
        for ix, subj in enumerate(subjects):
            print(f"Creating scenes for {subj}")
            scenes = viz_tract(
                subj, f"{hcp_dir}/{dataset}", node1, node2, interactive=False, size=(2400, 2400)
            )

            subfig = subfigs[ix // n_cols][ix % n_cols]
            ax = subfig.subplots(1, 3, sharey=True, gridspec_kw={"wspace": 0.0})
            for i in range(3):
                ax[i].axis("off")
                ax[i].imshow(scenes[i], origin="lower")
            ax[0].set_title("Coronal-anterior", size=4)
            ax[1].set_title("Sagittal-left", size=4)
            ax[2].set_title("Axial-superior", size=4)
            subfig.suptitle(f"{subj}", size=10)

            del scenes

        # e.g. "hcp_test" -> "HCP Test" (dataset names use "_", not "-")
        prefix, _, suffix = dataset.partition("_")
        fig.suptitle(f"{prefix.upper()} {suffix.capitalize()}".strip())

        # Save figure
        out_dir = f"0_figures/{dataset}"
        os.makedirs(out_dir, exist_ok=True)
        fig.savefig(
            f"{out_dir}/sub-{dataset_fname}_desc-{node1}-{node2}_tracts.svg",
            dpi=960,
            bbox_inches="tight",
            facecolor="white",
        )
    finally:
        # Always release the figure, even if a panel failed to render
        plt.close(fig)

    del fig, subfigs

    clear_output(wait=True)

In [30]:
# Interactive spot-check of an intra-hemispheric tract pair (nodes 13 & 69) in the retest data
viz_tract(
    subjid="sub-103818",
    dataset="/home/ROBARTS/tkai/graham/scratch/Zona/data/hcp1200_3T_2/hcp_retest",
    node1=13,
    node2=69,
    interactive=True,
)

In [5]:
# Dataset directories
hcp_dir = "/home/ROBARTS/tkai/graham/scratch/Zona/data/hcp1200_3T_2"

# Thresholded nodes ("Nodes" column holds "<node1>-<node2>" pairs)
nodes_csv = "/home/ROBARTS/tkai/graham/scratch/Zona/notebooks/hcp1200_3T/hcp1200_3T_TestvRetest_Thresholded.csv"
nodes_table = pd.read_csv(nodes_csv, sep=",")
nodes = list(nodes_table["Nodes"])

## HCP Test vs Retest

In [6]:
# Subjects dropped from analysis due to missing data or non matching sequences
drop_subj = [
    f"sub-{subj}"
    for subj in [601127, 192439, 137128, 135528, 169343, 151526, 660951]
]

# participants.tsv is read without a header row, so subject IDs sit in column 0
participants_all = pd.read_csv(f"{hcp_dir}/participants.tsv", sep="\t", header=None)
participants_df = participants_all[~participants_all[0].isin(drop_subj)].reset_index(drop=True)
assert not participants_df.empty, "No participants left after dropping subjects"

handle_session(save=True)

Saving notebook session


In [12]:
# Spawn a fresh child process per node to deal with the memory leak from fury.
# Reading each future's result re-raises any exception from the child, which
# would otherwise be discarded silently; failures are recorded and reported.
failed_test_nodes = {}
for node in tqdm(nodes):
    with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
        future = executor.submit(create_figure, node, "hcp_test", "HCPTest")
        try:
            future.result()
        except Exception as err:  # keep going so one bad node doesn't abort the batch
            failed_test_nodes[node] = repr(err)

clear_output(wait=False)
print("Finished creating multi-panel figures for HCP Test")
if failed_test_nodes:
    print(f"{len(failed_test_nodes)} node(s) failed:")
    for failed_node, error in failed_test_nodes.items():
        print(f"  {failed_node}: {error}")

Finished creating multi-panel figures for HCP Test


In [9]:
# Spawn a fresh child process per node to deal with the memory leak from fury.
# Reading each future's result re-raises any exception from the child, which
# would otherwise be discarded silently; failures are recorded and reported.
failed_retest_nodes = {}
for node in tqdm(nodes):
    with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
        future = executor.submit(create_figure, node, "hcp_retest", "HCPRetest")
        try:
            future.result()
        except Exception as err:  # keep going so one bad node doesn't abort the batch
            failed_retest_nodes[node] = repr(err)

clear_output(wait=False)
print("Finished creating multi-panel figures for HCP Retest")
if failed_retest_nodes:
    print(f"{len(failed_retest_nodes)} node(s) failed:")
    for failed_node, error in failed_retest_nodes.items():
        print(f"  {failed_node}: {error}")

  0%|          | 0/186 [00:00<?, ?it/s]

Node: 1-14
Creating scenes for sub-103818
Creating scenes for sub-105923


  1%|          | 1/186 [00:04<13:23,  4.34s/it]

Node: 1-2
Creating scenes for sub-103818
Creating scenes for sub-105923


  1%|          | 2/186 [00:08<12:45,  4.16s/it]

Node: 1-4
Creating scenes for sub-103818
Creating scenes for sub-105923


  2%|▏         | 3/186 [00:12<12:52,  4.22s/it]

Node: 11-12
Creating scenes for sub-103818


  2%|▏         | 3/186 [00:13<13:21,  4.38s/it]


KeyboardInterrupt: 