Skip to content

Commit

Permalink
Merge pull request yeatmanlab#484 from 36000/dice_coeff
Browse files Browse the repository at this point in the history
ENH: Dice coeff
  • Loading branch information
arokem committed Oct 2, 2020
2 parents 1f5d228 + 9a20e60 commit d778f8a
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 3 deletions.
48 changes: 45 additions & 3 deletions AFQ/utils/tests/test_volume.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,55 @@
import numpy as np
import numpy.testing as npt
import pytest
import os.path as op

import AFQ.utils.volume as AFV
import nibabel as nib

from dipy.io.stateful_tractogram import Space
from dipy.io.streamline import StatefulTractogram

import AFQ.utils.volume as afv
import AFQ.data as afd


def test_patch_up_roi():
roi_bad = np.zeros((10, 10, 10))
roi_good = np.ones((10, 10, 10))

AFV.patch_up_roi(roi_good)
afv.patch_up_roi(roi_good)
with pytest.raises(ValueError):
AFV.patch_up_roi(roi_bad)
afv.patch_up_roi(roi_bad)


def test_density_map():
file_dict = afd.read_stanford_hardi_tractography()

# subsample even more
subsampled_tractography = file_dict["tractography_subsampled.trk"][441:444]
sft = StatefulTractogram(
subsampled_tractography,
file_dict["mapping.nii.gz"],
Space.VOX)
density_map = afv.density_map(sft)
npt.assert_equal(int(np.sum(density_map.get_fdata())), 69)


def test_dice_coeff():
affine = np.asarray([
[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 0]])
img1 = nib.Nifti1Image(
np.asarray([
[0.8, 0.9, 0],
[0, 0, 0],
[0, 0, 0]]),
affine)
img2 = nib.Nifti1Image(
np.asarray([
[0.5, 0, 0],
[0.6, 0, 0],
[0, 0, 0]]),
affine)
npt.assert_equal(afv.dice_coeff(img1, img2), 0.5)
78 changes: 78 additions & 0 deletions AFQ/utils/volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@
from skimage.filters import gaussian
from skimage.morphology import convex_hull_image
from scipy.spatial.qhull import QhullError
from scipy.spatial.distance import dice

import nibabel as nib

from dipy.io.utils import (create_nifti_header, get_reference_info)
from dipy.tracking.streamline import select_random_set_of_streamlines
import dipy.tracking.utils as dtu


def patch_up_roi(roi, bundle_name="ROI"):
Expand Down Expand Up @@ -42,3 +49,74 @@ def patch_up_roi(roi, bundle_name="ROI"):
return convex_hull_image(hole_filled)
except QhullError:
return hole_filled


def density_map(tractogram, n_sls=None, to_vox=False):
"""
Create a streamline density map.
based on:
https://dipy.org/documentation/1.1.1./examples_built/streamline_formats/
Parameters
----------
tractogram : StatefulTractogram
Stateful tractogram whose streamlines are used to make
the density map.
n_sls : int or None
n_sls to randomly select to make the density map.
If None, all streamlines are used.
Default: None
to_vox : bool
Whether to put the stateful tractogram in VOX space before making
the density map.
Returns
-------
Nifti1Image containing the density map.
"""
if to_vox:
tractogram.to_vox()

sls = tractogram.streamlines
if n_sls is not None:
sls = select_random_set_of_streamlines(sls, n_sls)

affine, vol_dims, voxel_sizes, voxel_order = get_reference_info(tractogram)
tractogram_density = dtu.density_map(sls, np.eye(4), vol_dims)
nifti_header = create_nifti_header(affine, vol_dims, voxel_sizes)
density_map_img = nib.Nifti1Image(tractogram_density, affine, nifti_header)

return density_map_img


def dice_coeff(arr1, arr2):
"""
Compute Dice's coefficient between two images.
Parameters
----------
arr1 : Nifti1Image, str, ndarray
One ndarray to compare. Can be a path or image, which will be
converted to an ndarray.
arr2 : Nifti1Image, str, ndarray
The other ndarray to compare. Can be a path or image, which will be
converted to an ndarray.
Returns
-------
The dice similarity between the images.
"""
if isinstance(arr1, str):
arr1 = nib.load(arr1)
if isinstance(arr2, str):
arr2 = nib.load(arr2)

if isinstance(arr1, nib.Nifti1Image):
arr1 = arr1.get_fdata()
if isinstance(arr2, nib.Nifti1Image):
arr2 = arr2.get_fdata()

# scipy's dice function returns the dice *dissimilarity*
return 1 - dice(
arr1.flatten().astype(bool),
arr2.flatten().astype(bool))

0 comments on commit d778f8a

Please sign in to comment.