The purpose of this notebook is to look at specific images for which the brain mask ended up looking wrong (by visual inspection) and to tweak parameters to get it right. The script `05.1_dti_fit.py` was run first to generate a family of brain masks, DTI fits, and FA images.

This notebook concludes that it makes sense to use only the b=0 images when computing the brain mask.

This notebook also looks into the option of using my GPU implementation of brain masking. The memory limitation makes it hard to use a large filter, but the speed makes it cheap to run the filter many times. This approach seems to have the best results in the end.

It's also faster in one sense, though not necessarily in practice. Computing a single brain mask is much faster on the GPU than on the CPU, but with a large dataset you can typically process many images in parallel on the CPU, whereas GPU memory limits make that harder to do on the GPU.
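
To make that trade-off concrete, here is a back-of-the-envelope sketch. The helper `batch_time` is hypothetical, and the timings and worker counts are made-up placeholders, not measurements from this notebook.

```python
import math

def batch_time(n_images, seconds_per_image, n_parallel):
    """Wall-clock seconds if images are processed n_parallel at a time."""
    return math.ceil(n_images / n_parallel) * seconds_per_image

# Placeholder numbers, for illustration only
gpu_total = batch_time(n_images=1000, seconds_per_image=5, n_parallel=1)
cpu_total = batch_time(n_images=1000, seconds_per_image=60, n_parallel=16)
print(gpu_total, cpu_total)
```

With these placeholder values the parallel CPU run finishes first even though each individual mask is slower, which is the sense in which the GPU speedup may not carry over to a full dataset.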

In [None]:
import os
import glob
import random
import json
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import dipy.io.image
import dipy.io
import dipy.core.gradients
import dipy.reconst.dti
import dipy.segment.mask

In [None]:
data_dir = 'DMRI_EXTRACTED_NONTEST/'
img_dirs = glob.glob(os.path.join(data_dir,'*ABCD-MPROC-DTI*/sub-*/ses-*/dwi/'))

In [None]:
sampled_fmriresults01_df = pd.read_csv('01.1_abcd_sample2/sampled_nontest_fmriresults01.csv')
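sampled_fmriresults01_df['dirname'] = (
    sampled_fmriresults01_df.derived_files
    .str.split('/').str[-1]
    # Remove only the '.tgz' suffix; str.strip('.tgz') would strip any of the
    # characters '.', 't', 'g', 'z' from both ends of the name
    .str.replace(r'\.tgz$', '', regex=True)
)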
dirname_to_full_path = {img_dir.split('/')[-5]:img_dir for img_dir in img_dirs}

In [None]:
data = []
for (subjectkey,interview_age),df in sampled_fmriresults01_df.groupby(['subjectkey', 'interview_age']):
    paths = []
    for _,row in df.iterrows():
        if row.dirname not in dirname_to_full_path.keys():
            raise FileNotFoundError(f"Could not find a directory for fmriresults01 id {row.fmriresults01_id}")
        img_dir = dirname_to_full_path[row.dirname]
        dwi_path = glob.glob(os.path.join(img_dir, '*.nii'))[0]
        bval_path = glob.glob(os.path.join(img_dir, '*.bval'))[0]
        bvec_path = glob.glob(os.path.join(img_dir, '*.bvec'))[0]
        paths.append({
            'img_dir' : img_dir,
            'dwi_path' : dwi_path,
            'bval_path' : bval_path,
            'bvec_path' : bvec_path,
        })
    data.append({
        'paths' : paths,
        # Use the groupby keys rather than the leftover `row` from the inner loop
        'subjectkey' : subjectkey,
        'interview_age' : interview_age,
    })

In [None]:
# Run this to look at the number of bvals for each image, alongside subject id and interview age

for d in data:
    for i,p in enumerate(d['paths']):
        bvals, bvecs = dipy.io.read_bvals_bvecs(p['bval_path'], p['bvec_path'])
        print(f"{len(bvals)} \t {d['subjectkey']} \t {d['interview_age']}",
              f"(part {i+1} of {len(d['paths'])})" if len(d['paths'])>1 else "")

In [None]:
# Function to load data from one of the dictionaries listed in the object "data" defined above
def load_data(d):
    img_data_list = []
    bvals_list = []
    bvecs_list = []
    prev_affine_transform = None

    for p in d['paths']:
        img_data, affine = dipy.io.image.load_nifti(p['dwi_path'])
        # All parts of one acquisition are expected to share the same affine
        assert (prev_affine_transform is None) or (affine == prev_affine_transform).all()
        prev_affine_transform = affine
        bvals, bvecs = dipy.io.read_bvals_bvecs(p['bval_path'], p['bvec_path'])
        img_data_list.append(img_data)
        bvals_list.append(bvals)
        bvecs_list.append(bvecs)
    # Concatenate the parts once, after the loop (volumes are on the last image axis)
    img_data = np.concatenate(img_data_list, axis=-1)
    bvals = np.concatenate(bvals_list)
    bvecs = np.concatenate(bvecs_list, axis=0)
    gtab = dipy.core.gradients.gradient_table(bvals, bvecs=bvecs)
    return img_data, affine, gtab

In [None]:
data_indexed_by_subject = {d['subjectkey']:d for d in data}

d_random = random.choice(data) # Pick a random one

# Pick subjects whose brain mask looked wrong on visual inspection
d_badmask1 = data_indexed_by_subject['NDAR_INV87T95RHP']
d_badmask2 = data_indexed_by_subject['NDAR_INVE0KZKF5V']
d_badmask3 = data_indexed_by_subject['NDAR_INVGL5PNTK7']
d_badmask4 = data_indexed_by_subject['NDAR_INVWAC9RH98']

img_data, affine, gtab = load_data(d_badmask1) # Load one to demonstrate how we process it below

Generate brain mask and preview it for the loaded image:

In [None]:
img_data_masked, mask = dipy.segment.mask.median_otsu(img_data, vol_idx = range(img_data.shape[-1]))

In [None]:
def preview(img):
    fig,axs = plt.subplots(1,3,figsize=(10,5))
    axs[0].imshow(img[62,:,:].T, origin='lower', cmap='gray')
    axs[1].imshow(img[:,:,80].T, origin='lower', cmap='gray')
    axs[2].imshow(img[:,75,:].T, origin='lower', cmap='gray')
    plt.show()

num_bvals = img_data.shape[3]
i = random.randint(0,num_bvals-1)
preview(img_data[:,:,:,i])
preview(mask)
preview(img_data_masked[:,:,:,i])

Try again with different parameters for `median_otsu`, this time using only the first volume:

In [None]:
img_data_masked, mask = dipy.segment.mask.median_otsu(img_data, vol_idx = [0], median_radius=4, numpass=4)

In [None]:
num_bvals = img_data.shape[3]
i = random.randint(0,num_bvals-1)
preview(img_data[:,:,:,i])
preview(mask)
preview(img_data_masked[:,:,:,i])

In the end, rather than tweaking the median filtering, what helped the most was basing the mask on the images for one specific b-value instead of all b-values. Picking the b-value with the least noisy images for mask generation seems to be the way to go. Let's check whether there is a consistently best b-value for this purpose:

In [None]:
img_data, affine, gtab = load_data(random.choice(data)) # Load random subject
num_bvals = img_data.shape[3]

In [None]:
for b in np.unique(gtab.bvals):
    i = random.choice(np.where(gtab.bvals==b)[0])
    print(f"Image with bvalue {gtab.bvals[i]}:")
    preview(img_data[:,:,:,i])

Running this cell a few times, the b=0 images (i.e. the ones that aren't diffusion weighted) are clearly the best to use for masking. This makes sense, because

> image contrast increases at higher b-values, albeit at the cost of reduced SNR

(from https://doi.org/10.1016/B978-0-12-817057-1.00022-6)

and while SNR and contrast both matter for the accuracy of Otsu thresholding, here it's SNR that is our limiting factor, rather than contrast.
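
As a rough illustration of why noise can hurt Otsu thresholding even when contrast is unchanged, here is a small sketch on synthetic data. It uses a plain NumPy version of Otsu's method (`otsu_threshold` is a hypothetical helper, not DIPY's implementation), and the intensities and noise levels are arbitrary placeholders rather than values taken from the ABCD images.

```python
import numpy as np

def otsu_threshold(values, nbins=256):
    """Otsu's method: pick the threshold that maximizes between-class variance."""
    hist, edges = np.histogram(np.ravel(values), bins=nbins)
    hist = hist.astype(float)
    centers = 0.5 * (edges[:-1] + edges[1:])
    w_lo = np.cumsum(hist)                  # total count in bins 0..k
    w_hi = np.cumsum(hist[::-1])[::-1]      # total count in bins k..end
    mean_lo = np.cumsum(hist * centers) / w_lo
    mean_hi = (np.cumsum((hist * centers)[::-1]) / w_hi[::-1])[::-1]
    # Between-class variance for a split between bin k and bin k+1
    between = w_lo[:-1] * w_hi[1:] * (mean_lo[:-1] - mean_hi[1:]) ** 2
    return centers[np.argmax(between)]

rng = np.random.default_rng(0)
labels = np.repeat([False, True], 20000)    # background vs. "brain" (synthetic)
clean = np.where(labels, 100.0, 0.0)        # fixed contrast of 100 (arbitrary)

def misclassified_fraction(noise_sigma):
    """Fraction of samples put in the wrong class after adding Gaussian noise."""
    noisy = clean + rng.normal(0.0, noise_sigma, size=clean.shape)
    return np.mean((noisy > otsu_threshold(noisy)) != labels)

err_low_noise = misclassified_fraction(10.0)
err_high_noise = misclassified_fraction(60.0)
print(err_low_noise, err_high_noise)
```

In this toy setup the two classes have the same contrast in both runs, and only the noise level changes how many samples end up on the wrong side of the threshold.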

Let's now try using bval 0 only in the mask generation:

In [None]:
img_data, affine, gtab = load_data(d_badmask1) # Load one to demonstrate how we process it below
img_data_masked, mask = dipy.segment.mask.median_otsu(img_data, vol_idx = np.where(gtab.bvals==0)[0])
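
One caveat with selecting volumes by `gtab.bvals==0`: if a dataset stores its non-diffusion-weighted volumes with a small nonzero b-value, an exact comparison would miss them. A tolerance-based selection is one way around that. The sketch below uses a hypothetical helper `b0_indices` and made-up b-values; the threshold of 50 is an assumption, not something checked against this dataset. DIPY gradient tables also expose a `b0s_mask` attribute that serves a similar purpose.

```python
import numpy as np

def b0_indices(bvals, b0_threshold=50):
    """Indices of volumes whose b-value is at or below b0_threshold."""
    return np.where(np.asarray(bvals) <= b0_threshold)[0]

# Made-up b-values, for illustration only
bvals_example = np.array([0, 5, 500, 1000, 0, 2000, 3000])

exact = np.where(bvals_example == 0)[0]     # misses the volume with b=5
tolerant = b0_indices(bvals_example)        # includes it
print(exact, tolerant)
```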

In [None]:
num_bvals = img_data.shape[3]
i = random.randint(0,num_bvals-1)
preview(img_data[:,:,:,i])
preview(mask)
preview(img_data_masked[:,:,:,i])

Ah, but here's one where the b=0 based masking seems to perform worse:

In [None]:
img_data, affine, gtab = load_data(data_indexed_by_subject['NDAR_INV761E1JVD']) # Load a subject where the b=0 based mask looks worse
img_data_masked, mask = dipy.segment.mask.median_otsu(img_data, vol_idx = np.where(gtab.bvals==0)[0])

In [None]:
num_bvals = img_data.shape[3]
i = random.randint(0,num_bvals-1)
preview(img_data[:,:,:,i])
preview(mask)
preview(img_data_masked[:,:,:,i])

In [None]:
for b in np.unique(gtab.bvals):
    i = random.choice(np.where(gtab.bvals==b)[0])
    print(f"Image with bvalue {gtab.bvals[i]}:")
    preview(img_data[:,:,:,i])

In this case the higher SNR of the b=0 images is harmful for Otsu thresholding, because it emphasizes a non-brain structure.

Perhaps a larger median filter can ignore such structures?

In [None]:
img_data, affine, gtab = load_data(data_indexed_by_subject['NDAR_INV761E1JVD']) # Same subject, now with a larger median filter
img_data_masked, mask = dipy.segment.mask.median_otsu(img_data, vol_idx = np.where(gtab.bvals==0)[0], median_radius=7)

In [None]:
num_bvals = img_data.shape[3]
i = random.randint(0,num_bvals-1)
preview(img_data[:,:,:,i])
preview(mask)
preview(img_data_masked[:,:,:,i])

But a radius-7 median filter takes a very long time to run.
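
A quick way to see why the cost grows so fast with radius is to count the voxels in the filter window. This sketch assumes a cubic window of side `2 * radius + 1`; `median_window_voxels` is a hypothetical helper for the arithmetic only.

```python
def median_window_voxels(radius, ndim=3):
    """Number of voxels in a cubic median-filter window of the given radius."""
    return (2 * radius + 1) ** ndim

for r in (2, 3, 4, 7):
    print(f"radius {r}: {median_window_voxels(r)} voxels per window")

# Ratio of per-voxel work between radius 7 and radius 4
ratio = median_window_voxels(7) / median_window_voxels(4)
print(f"radius 7 vs radius 4: about {ratio:.1f}x more values per output voxel")
```

Under that assumption each output voxel needs the median of several thousand values at radius 7, and the same cubic growth applies to any buffer that has to hold the neighbourhoods in memory.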

Below is a different approach where I set up a GPU version of median filtering.
This runs a LOT faster, but is limited by GPU memory. With my 8GB card the highest `median_radius` I can do is 3. However, we can increase `numpass` very cheaply, because each run of the filter is extremely fast. Check it out:

In [None]:
from brainmask_with_gpu import median_otsu_gpu
img_data, affine, gtab = load_data(data_indexed_by_subject['NDAR_INV761E1JVD']) # Same subject, now with the GPU median filter
img_data_masked, mask = median_otsu_gpu(img_data, vol_idx = np.where(gtab.bvals==0)[0], median_radius=2, numpass=30)

In [None]:
num_bvals = img_data.shape[3]
i = 0 # Fixed to the first volume here instead of a random one
preview(img_data[:,:,:,i])
preview(mask)
preview(img_data_masked[:,:,:,i])
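
The idea that many passes of a small filter can stand in for one pass of a large filter can be illustrated on a toy volume. This is a CPU sketch in plain NumPy, not the `median_otsu_gpu` code used above; the helper names, the edge-replicating padding, and the synthetic volume are all assumptions made for the illustration.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def median_filter_3d(vol, radius):
    """One pass of a cubic median filter with edge-replicating padding."""
    size = 2 * radius + 1
    padded = np.pad(vol, radius, mode="edge")
    windows = sliding_window_view(padded, (size, size, size))
    return np.median(windows, axis=(-3, -2, -1))

def multi_pass_median(vol, radius, numpass):
    """Apply the same median filter numpass times in a row."""
    for _ in range(numpass):
        vol = median_filter_3d(vol, radius)
    return vol

# Synthetic volume: a large bright cube ("brain") and a small bright blob
vol = np.zeros((20, 20, 20))
vol[4:14, 4:14, 4:14] = 1.0
vol[16:19, 16:19, 16:19] = 1.0

one_pass = multi_pass_median(vol, radius=1, numpass=1)
three_passes = multi_pass_median(vol, radius=1, numpass=3)

print("blob voxels left after 1 pass:  ", one_pass[16:19, 16:19, 16:19].sum())
print("blob voxels left after 3 passes:", three_passes[16:19, 16:19, 16:19].sum())
print("cube centre after 3 passes:     ", three_passes[9, 9, 9])
```

In this toy case a single radius-1 pass only erodes the small blob, while repeated passes remove it and leave the centre of the large cube intact. Real images are noisier and less regular, so this shows the mechanism rather than predicting how a particular subject will behave.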