In [2]:
import numpy as np
from pathlib import Path
import shutil
import warnings

from dipy.core.gradients import gradient_table

from eddymotion import dmri
from eddymotion.viz import plot_dwi

%matplotlib inline

In [3]:
import os

# Root of the ds000206 dataset. Point the DS000206_DIR environment variable at
# your local copy instead of editing this cell; the fallback is the original
# author's checkout.
base_dir = Path(os.environ.get("DS000206_DIR", "/Users/michael/projects/datasets/ds000206"))
bids_dir = base_dir / "bids"
derivatives_dir = base_dir / "dmriprep"

# All raw inputs live in the same session folder and share one BIDS prefix.
dwi_dir = bids_dir / "sub-05" / "ses-JHU1" / "dwi"
dwi_prefix = "sub-05_ses-JHU1_acq-GD72_dwi"

dwi_file = dwi_dir / f"{dwi_prefix}.nii.gz"
bvec_file = dwi_dir / f"{dwi_prefix}.bvec"
bval_file = dwi_dir / f"{dwi_prefix}.bval"

In [4]:
# Gradient table read directly from the BIDS .bval/.bvec text files.
# NOTE(review): `gtab` is not used by any later cell (the model below is built
# from `data.gradients`); consider removing this cell.
gtab = gradient_table(str(bval_file), bvecs=str(bvec_file))

In [5]:
# Gradient table in the `.tsv` format consumed by `dmri.load` below (presumably
# RAS+B, judging by the name). It was generated once with dmriprep's
# `CheckGradientTable` interface (inputs: `dwi_file`, `bvec_file`, `bval_file`)
# and the resulting `sub-05_ses-JHU1_acq-GD72_dwi.tsv` was copied next to the
# DWI in the BIDS tree.
rasb_file = bids_dir / "sub-05" / "ses-JHU1" / "dwi" / "sub-05_ses-JHU1_acq-GD72_dwi.tsv"

if not rasb_file.exists():
    raise FileNotFoundError(
        f"Gradient table {rasb_file} not found; regenerate it with "
        "dmriprep's CheckGradientTable and copy it into the BIDS dwi folder."
    )

In [6]:
# dMRIPrep derivatives for the same run: the b=0 reference and the brain mask.
b0_file = derivatives_dir.joinpath(
    "sub-05", "ses-JHU1", "dwi", "sub-05_ses-JHU1_acq-GD72_desc-b0_dwi.nii.gz"
)
brainmask_file = derivatives_dir.joinpath(
    "sub-05", "ses-JHU1", "dwi", "sub-05_ses-JHU1_acq-GD72_desc-brain_mask.nii.gz"
)

In [7]:
# Load the DWI series together with its b=0 reference, brain mask and
# gradient table.
data = dmri.load(
    str(dwi_file),
    b0_file=str(b0_file),
    brainmask_file=str(brainmask_file),
    gradients_file=str(rasb_file),
)

In [8]:
def _rasb2dipy(gradient):
    gradient = np.asanyarray(gradient)
    if gradient.ndim == 1:
        if gradient.size != 4:
            raise ValueError("Missing gradient information.")
        gradient = gradient[..., np.newaxis]

    if gradient.shape[0] != 4:
        gradient = gradient.T
    elif gradient.shape == (4, 4):
        print("Warning: make sure gradient information is not transposed!")

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=UserWarning)
        retval = gradient_table(gradient[3, :], gradient[:3, :].T)
    return retval

In [18]:
class DKIModel:
    """A wrapper of :obj:`dipy.reconst.dki.DiffusionKurtosisModel`.

    Normalizes an optional reference (b=0) volume, derives a voxel mask from
    it when none is given, fits only the masked voxels, and scatters masked
    predictions back into a full volume.
    """

    __slots__ = ("_model", "_S0", "_mask")

    def __init__(self, gtab, S0=None, mask=None, **kwargs):
        """Instantiate the wrapped diffusion kurtosis model.

        Parameters
        ----------
        gtab : array-like or dipy gradient table
            Gradient information. Objects without a ``bvals`` attribute are
            converted with ``_rasb2dipy``; a ready-made dipy table is used as is.
        S0 : numpy.ndarray, optional
            Reference signal, rescaled by its maximum and clipped to ``[1e-5, 1]``.
        mask : array-like, optional
            Voxels to model. When omitted but ``S0`` is given, voxels whose
            normalized ``S0`` exceeds its 35th percentile are kept.
        **kwargs
            Only ``min_signal``, ``return_S0_hat``, ``fit_method``,
            ``weighting``, ``sigma`` and ``jac`` are forwarded to dipy;
            anything else is dropped.
        """
        from dipy.reconst.dki import DiffusionKurtosisModel

        self._S0 = None
        if S0 is not None:
            self._S0 = np.clip(
                S0.astype("float32") / S0.max(),
                a_min=1e-5,
                a_max=1.0,
            )

        self._mask = None
        if mask is not None:
            # Keep a boolean copy so fit/predict always use boolean indexing
            # (an integer mask would otherwise trigger fancy indexing).
            self._mask = np.asanyarray(mask).astype(bool)
        elif self._S0 is not None:
            self._mask = self._S0 > np.percentile(self._S0, 35)

        # Restrict S0 to the modelled voxels; only possible when S0 was given.
        if self._mask is not None and self._S0 is not None:
            self._S0 = self._S0[self._mask]

        kwargs = {
            k: v
            for k, v in kwargs.items()
            if k
            in (
                "min_signal",
                "return_S0_hat",
                "fit_method",
                "weighting",
                "sigma",
                "jac",
            )
        }
        # Accept the gradient array used throughout this notebook, like DTIModel does.
        if not hasattr(gtab, "bvals"):
            gtab = _rasb2dipy(gtab)
        self._model = DiffusionKurtosisModel(gtab, **kwargs)

    def fit(self, data, **kwargs):
        """Fit the wrapped model on the masked voxels of ``data``.

        Extra keyword arguments are accepted for interface compatibility and
        ignored. The dipy model is replaced by the object its ``fit`` returns.
        """
        # With no mask, fit the whole array (``data[None]`` would add an axis).
        voxels = data if self._mask is None else data[self._mask, ...]
        self._model = self._model.fit(voxels)

    def predict(self, gradient, **kwargs):
        """Predict the signal for ``gradient`` from the fitted parameters.

        The squeezed dipy prediction is returned directly when it is already
        3D or when no mask is set; otherwise it is scattered into a zero-filled
        float32 array shaped like the mask.
        """
        predicted = np.squeeze(
            self._model.predict(
                _rasb2dipy(gradient),
                S0=self._S0,
            )
        )
        if predicted.ndim == 3 or self._mask is None:
            return predicted

        retval = np.zeros_like(self._mask, dtype="float32")
        retval[self._mask, ...] = predicted
        return retval

In [11]:
class DTIModel:
    """A wrapper of :obj:`dipy.reconst.dti.TensorModel`.

    Normalizes an optional reference (b=0) volume, derives a voxel mask from
    it when none is given, fits only the masked voxels, and scatters masked
    predictions back into a full volume.
    """

    __slots__ = ("_model", "_S0", "_mask")

    def __init__(self, gtab, S0=None, mask=None, **kwargs):
        """Instantiate the wrapped tensor model.

        Parameters
        ----------
        gtab : array-like or dipy gradient table
            Gradient information. Objects without a ``bvals`` attribute are
            converted with ``_rasb2dipy``; a ready-made dipy table is used as is.
        S0 : numpy.ndarray, optional
            Reference signal, rescaled by its maximum and clipped to ``[1e-5, 1]``.
        mask : array-like, optional
            Voxels to model. When omitted but ``S0`` is given, voxels whose
            normalized ``S0`` exceeds its 35th percentile are kept.
        **kwargs
            Only ``min_signal``, ``return_S0_hat``, ``fit_method``,
            ``weighting``, ``sigma`` and ``jac`` are forwarded to dipy;
            anything else is dropped.
        """
        from dipy.reconst.dti import TensorModel as DipyTensorModel

        self._S0 = None
        if S0 is not None:
            self._S0 = np.clip(
                S0.astype("float32") / S0.max(),
                a_min=1e-5,
                a_max=1.0,
            )

        self._mask = None
        if mask is not None:
            # Keep a boolean copy so fit/predict always use boolean indexing
            # (an integer mask would otherwise trigger fancy indexing).
            self._mask = np.asanyarray(mask).astype(bool)
        elif self._S0 is not None:
            self._mask = self._S0 > np.percentile(self._S0, 35)

        # Restrict S0 to the modelled voxels; only possible when S0 was given.
        if self._mask is not None and self._S0 is not None:
            self._S0 = self._S0[self._mask]

        kwargs = {
            k: v
            for k, v in kwargs.items()
            if k
            in (
                "min_signal",
                "return_S0_hat",
                "fit_method",
                "weighting",
                "sigma",
                "jac",
            )
        }
        # Convert gradient arrays; pass dipy gradient tables through untouched.
        if not hasattr(gtab, "bvals"):
            gtab = _rasb2dipy(gtab)
        self._model = DipyTensorModel(gtab, **kwargs)

    def fit(self, data, **kwargs):
        """Fit the wrapped model on the masked voxels of ``data``.

        Extra keyword arguments are accepted for interface compatibility and
        ignored. The dipy model is replaced by the object its ``fit`` returns.
        """
        # With no mask, fit the whole array (``data[None]`` would add an axis).
        voxels = data if self._mask is None else data[self._mask, ...]
        self._model = self._model.fit(voxels)

    def predict(self, gradient, **kwargs):
        """Predict the signal for ``gradient`` from the fitted parameters.

        The squeezed dipy prediction is returned directly when it is already
        3D or when no mask is set; otherwise it is scattered into a zero-filled
        float32 array shaped like the mask.
        """
        predicted = np.squeeze(
            self._model.predict(
                _rasb2dipy(gradient),
                S0=self._S0,
            )
        )
        if predicted.ndim == 3 or self._mask is None:
            return predicted

        retval = np.zeros_like(self._mask, dtype="float32")
        retval[self._mask, ...] = predicted
        return retval

In [6]:
class SparseFascicleModel:
    """A wrapper of :obj:`dipy.reconst.sfm.SparseFascicleModel`.

    Normalizes an optional reference (b=0) volume, derives a voxel mask from
    it when none is given, fits only the masked voxels, and scatters masked
    predictions back into a full volume.
    """

    __slots__ = ("_model", "_S0", "_mask", "_solver")

    def __init__(self, gtab, S0=None, mask=None, solver=None, **kwargs):
        """Instantiate the wrapped sparse fascicle model.

        Parameters
        ----------
        gtab : array-like or dipy gradient table
            Gradient information. Objects without a ``bvals`` attribute are
            converted with ``_rasb2dipy``; a ready-made dipy table is used as is.
        S0 : numpy.ndarray, optional
            Reference signal, rescaled by its maximum and clipped to ``[1e-5, 1]``.
        mask : array-like, optional
            Voxels to model. When omitted but ``S0`` is given, voxels whose
            normalized ``S0`` exceeds its 35th percentile are kept.
        solver : optional
            Passed to dipy as ``solver``; defaults to ``"ElasticNet"``.
        **kwargs
            Only ``sphere``, ``response``, ``l1_ratio``, ``alpha`` and
            ``isotropic`` are forwarded to dipy; anything else is dropped.
        """
        from dipy.reconst.sfm import SparseFascicleModel as DipySparseFascicleModel

        self._S0 = None
        if S0 is not None:
            self._S0 = np.clip(
                S0.astype("float32") / S0.max(),
                a_min=1e-5,
                a_max=1.0,
            )

        self._mask = None
        if mask is not None:
            # Keep a boolean copy so fit/predict always use boolean indexing
            # (an integer mask would otherwise trigger fancy indexing).
            self._mask = np.asanyarray(mask).astype(bool)
        elif self._S0 is not None:
            self._mask = self._S0 > np.percentile(self._S0, 35)

        # Restrict S0 to the modelled voxels; only possible when S0 was given.
        if self._mask is not None and self._S0 is not None:
            self._S0 = self._S0[self._mask]

        self._solver = "ElasticNet" if solver is None else solver

        # ``solver`` is an explicit parameter, so it can never show up in
        # ``kwargs``: it must be forwarded explicitly or it is silently lost.
        sfm_kwargs = {
            k: v
            for k, v in kwargs.items()
            if k in ("sphere", "response", "l1_ratio", "alpha", "isotropic")
        }
        if not hasattr(gtab, "bvals"):
            gtab = _rasb2dipy(gtab)
        self._model = DipySparseFascicleModel(gtab, solver=self._solver, **sfm_kwargs)

    def fit(self, data, **kwargs):
        """Fit the wrapped model on the masked voxels of ``data``.

        Extra keyword arguments are accepted for interface compatibility and
        ignored. The dipy model is replaced by the object its ``fit`` returns.
        """
        # With no mask, fit the whole array (``data[None]`` would add an axis).
        voxels = data if self._mask is None else data[self._mask, ...]
        self._model = self._model.fit(voxels)

    def predict(self, gradient, **kwargs):
        """Predict the signal for ``gradient`` from the fitted parameters.

        The squeezed dipy prediction is returned directly when it is already
        3D or when no mask is set; otherwise it is scattered into a zero-filled
        float32 array shaped like the mask.
        """
        predicted = np.squeeze(
            self._model.predict(
                _rasb2dipy(gradient),
                S0=self._S0,
            )
        )
        if predicted.ndim == 3 or self._mask is None:
            return predicted

        retval = np.zeros_like(self._mask, dtype="float32")
        retval[self._mask, ...] = predicted
        return retval

In [12]:
# Tensor model built from the loaded gradient table, normalized by the b=0
# reference and restricted to the brain mask.
model = DTIModel(gtab=data.gradients, S0=data.bzero, mask=data.brainmask)

In [None]:
# Hold out one volume (index 10) and fit the model on the rest.
# NOTE(review): `logo_split` presumably returns (data, gradients) pairs for the
# train and test sets; element 0 of the training pair is what gets fitted.
LEFT_OUT_INDEX = 10
data_train, data_test = data.logo_split(LEFT_OUT_INDEX)
model.fit(data_train[0])

In [None]:
# Predict the held-out volume from its gradient (second element of the test split).
predicted = model.predict(gradient=data_test[1])

In [None]:
# Show the predicted volume annotated with the held-out gradient.
plot_dwi(
    predicted,
    data.affine,
    gradient=data_test[1],
);