From 2179576d6a70abe6754e2c07f5bb53dd0a5a9227 Mon Sep 17 00:00:00 2001 From: mathiasg Date: Wed, 18 Dec 2024 11:40:20 -0500 Subject: [PATCH 1/3] FIX: Use nitransforms for most xfm handling --- nibabies/interfaces/resampling.py | 96 +------------------------ nibabies/utils/transforms.py | 34 +++++++++ nibabies/workflows/bold/registration.py | 6 +- pyproject.toml | 2 +- requirements.txt | 36 +++------- 5 files changed, 53 insertions(+), 121 deletions(-) create mode 100644 nibabies/utils/transforms.py diff --git a/nibabies/interfaces/resampling.py b/nibabies/interfaces/resampling.py index 99e56ff2..9a7220ca 100644 --- a/nibabies/interfaces/resampling.py +++ b/nibabies/interfaces/resampling.py @@ -4,10 +4,8 @@ import os from collections.abc import Callable from functools import partial -from pathlib import Path from typing import TypeVar -import h5py import nibabel as nb import nitransforms as nt import numpy as np @@ -19,12 +17,13 @@ traits, ) from nipype.utils.filemanip import fname_presuffix -from nitransforms.io.itk import ITKCompositeH5 from scipy import ndimage as ndi from scipy.sparse import hstack as sparse_hstack from sdcflows.transform import grid_bspline_weights from sdcflows.utils.tools import ensure_positive_cosines +from nibabies.utils.transforms import load_transforms + R = TypeVar('R') @@ -34,95 +33,6 @@ async def worker(job: Callable[[], R], semaphore: asyncio.Semaphore) -> R: return await loop.run_in_executor(None, job) -def load_transforms(xfm_paths: list[Path], inverse: list[bool]) -> nt.base.TransformBase: - """Load a series of transforms as a nitransforms TransformChain - - An empty list will return an identity transform - """ - if len(inverse) == 1: - inverse *= len(xfm_paths) - elif len(inverse) != len(xfm_paths): - raise ValueError('Mismatched number of transforms and inverses') - - chain = None - for path, inv in zip(xfm_paths[::-1], inverse[::-1], strict=False): - path = Path(path) - if path.suffix == '.h5': - xfm = load_ants_h5(path) - else: - xfm = nt.linear.load(path) - if inv: - xfm = ~xfm - if chain is None: - chain = xfm - else: - chain += xfm - if chain is None: - chain = nt.base.TransformBase() - return chain - - -FIXED_PARAMS = np.array([ - 193.0, 229.0, 193.0, # Size - 96.0, 132.0, -78.0, # Origin - 1.0, 1.0, 1.0, # Spacing - -1.0, 0.0, 0.0, # Directions - 0.0, -1.0, 0.0, - 0.0, 0.0, 1.0, -]) # fmt:skip - - -def load_ants_h5(filename: Path) -> nt.base.TransformBase: - """Load ANTs H5 files as a nitransforms TransformChain""" - # Borrowed from https://github.com/feilong/process - # process.resample.parse_combined_hdf5() - # - # Changes: - # * Tolerate a missing displacement field - # * Return the original affine without a round-trip - # * Always return a nitransforms TransformChain - # - # This should be upstreamed into nitransforms - h = h5py.File(filename) - xform = ITKCompositeH5.from_h5obj(h) - - # nt.Affine - transforms = [nt.Affine(xform[0].to_ras())] - - if '2' not in h['TransformGroup']: - return transforms[0] - - transform2 = h['TransformGroup']['2'] - - # Confirm these transformations are applicable - if transform2['TransformType'][:][0] not in ( - b'DisplacementFieldTransform_float_3_3', - b'DisplacementFieldTransform_double_3_3', - ): - msg = 'Unknown transform type [2]\n' - for i in h['TransformGroup'].keys(): - msg += f'[{i}]: {h["TransformGroup"][i]["TransformType"][:][0]}\n' - raise ValueError(msg) - - fixed_params = transform2['TransformFixedParameters'][:] - shape = tuple(fixed_params[:3].astype(int)) - # ITK stores warps in Fortran-order, where the vector components change fastest - # Nitransforms expects 3 volumes, not a volume of three-vectors, so transpose - warp = np.reshape( - transform2['TransformParameters'], - (3, *shape), - order='F', - ).transpose(1, 2, 3, 0) - - warp_affine = np.eye(4) - warp_affine[:3, :3] = fixed_params[9:].reshape((3, 3)) - warp_affine[:3, 3] = fixed_params[3:6] - lps_to_ras = np.eye(4) * np.array([-1, -1, 1, 1]) - warp_affine = lps_to_ras @ warp_affine - transforms.insert(0, nt.DenseFieldTransform(nb.Nifti1Image(warp, warp_affine))) - return nt.TransformChain(transforms) - - class ResampleSeriesInputSpec(TraitedSpec): in_file = File(exists=True, mandatory=True, desc='3D or 4D image file to resample') ref_file = File(exists=True, mandatory=True, desc='File to resample in_file to') @@ -788,7 +698,7 @@ def reconstruct_fieldmap( ) if not direct: - fmap_img = transforms.apply(fmap_img, reference=target) + fmap_img = nt.apply(fmap_img, reference=target) fmap_img.header.set_intent('estimate', name='fieldmap Hz') fmap_img.header.set_data_dtype('float32') diff --git a/nibabies/utils/transforms.py b/nibabies/utils/transforms.py new file mode 100644 index 00000000..b1049ebd --- /dev/null +++ b/nibabies/utils/transforms.py @@ -0,0 +1,34 @@ +"""Utilities for loading transforms for resampling""" + +from pathlib import Path + +import nitransforms as nt + + +def load_transforms(xfm_paths: list[Path], inverse: list[bool]) -> nt.base.TransformBase: + """Load a series of transforms as a nitransforms TransformChain + + An empty list will return an identity transform + """ + if len(inverse) == 1: + inverse *= len(xfm_paths) + elif len(inverse) != len(xfm_paths): + raise ValueError('Mismatched number of transforms and inverses') + + chain = None + for path, inv in zip(xfm_paths[::-1], inverse[::-1], strict=False): + path = Path(path) + if path.suffix == '.h5': + # Load as a TransformChain + xfm = nt.manip.load(path) + else: + xfm = nt.linear.load(path) + if inv: + xfm = ~xfm + if chain is None: + chain = xfm + else: + chain += xfm + if chain is None: + chain = nt.Affine() # Identity + return chain diff --git a/nibabies/workflows/bold/registration.py b/nibabies/workflows/bold/registration.py index e66a6727..2c49be1a 100644 --- a/nibabies/workflows/bold/registration.py +++ b/nibabies/workflows/bold/registration.py @@ -741,14 +741,16 @@ def _conditional_downsampling(in_file, in_mask, zoom_th=4.0): offset = old_center - newrot.dot((newshape - 1) * 0.5) newaffine = nb.affines.from_matvec(newrot, offset) + identity = nt.Affine() + newref = nb.Nifti1Image(np.zeros(newshape, dtype=np.uint8), newaffine) - nt.Affine(reference=newref).apply(img).to_filename(out_file) + nt.apply(identity, img, reference=newref).to_filename(out_file) mask = nb.load(in_mask) mask.set_data_dtype(float) mdata = gaussian_filter(mask.get_fdata(dtype=float), scaling) floatmask = nb.Nifti1Image(mdata, mask.affine, mask.header) - newmask = nt.Affine(reference=newref).apply(floatmask) + newmask = nt.apply(identity, floatmask, reference=newref) hdr = newmask.header.copy() hdr.set_data_dtype(np.uint8) newmaskdata = (newmask.get_fdata(dtype=float) > 0.5).astype(np.uint8) diff --git a/pyproject.toml b/pyproject.toml index cc0dbd51..5c5966c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ dependencies = [ "nipype >= 1.8.5", "nireports >= 23.2.0", "nitime", - "nitransforms >= 23.0.1", + "nitransforms >= 24.1.1", "niworkflows >= 1.12.1", "numpy >= 1.21.0", "packaging", diff --git a/requirements.txt b/requirements.txt index 8515238c..def11f4e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,7 +17,7 @@ annexremote==1.6.6 # datalad-osf astor==0.8.1 # via formulaic -attrs==24.2.0 +attrs==24.3.0 # via # jsonschema # niworkflows @@ -31,16 +31,14 @@ bidsschematools==1.0.0 # via bids-validator bokeh==3.5.2 # via tedana -boto3==1.35.80 +boto3==1.35.83 # via datalad -botocore==1.35.80 +botocore==1.35.83 # via # boto3 # s3transfer -certifi==2024.8.30 +certifi==2024.12.14 # via requests -cffi==1.17.1 - # via cryptography chardet==5.2.0 # via datalad charset-normalizer==3.4.0 @@ -58,11 +56,9 @@ contourpy==1.3.1 # via # bokeh # matplotlib -cryptography==44.0.0 - # via secretstorage cycler==0.12.1 # via matplotlib -datalad==1.1.4 +datalad==1.1.5 # via # datalad-next # datalad-osf @@ -87,8 +83,6 @@ formulaic==0.5.2 # via pybids fsspec==2024.10.0 # via universal-pathlib -greenlet==3.1.1 - # via sqlalchemy h5py==3.12.1 # via nitransforms humanize==4.11.0 @@ -123,10 +117,6 @@ jaraco-context==6.0.1 # keyrings-alt jaraco-functools==4.1.0 # via keyring -jeepney==0.8.0 - # via - # keyring - # secretstorage jinja2==3.1.4 # via # bokeh @@ -171,7 +161,7 @@ mapca==0.0.5 # via tedana markupsafe==3.0.2 # via jinja2 -matplotlib==3.9.3 +matplotlib==3.10.0 # via # nireports # nitime @@ -214,7 +204,7 @@ nilearn==0.10.4 # nireports # niworkflows # tedana -nipype==1.9.1 +nipype==1.9.2 # via # nibabies (pyproject.toml) # nireports @@ -225,7 +215,7 @@ nireports==24.0.3 # via nibabies (pyproject.toml) nitime==0.11 # via nibabies (pyproject.toml) -nitransforms==24.1.0 +nitransforms==24.1.1 # via # nibabies (pyproject.toml) # niworkflows @@ -235,7 +225,7 @@ niworkflows==1.12.1 # nibabies (pyproject.toml) # sdcflows # smriprep -num2words==0.5.13 +num2words==0.5.14 # via pybids numpy==2.1.1 # via @@ -327,8 +317,6 @@ pybtex==0.24.0 # via tedana pybtex-apa-style==1.3 # via tedana -pycparser==2.22 - # via cffi pydot==3.0.3 # via nipype pyparsing==3.2.0 @@ -343,7 +331,7 @@ python-dateutil==2.9.0.post0 # nipype # pandas # prov -python-gitlab==5.1.0 +python-gitlab==5.2.0 # via datalad pytz==2024.2 # via pandas @@ -384,7 +372,7 @@ rpds-py==0.22.3 # referencing s3transfer==0.10.4 # via boto3 -scikit-image==0.24.0 +scikit-image==0.25.0 # via # niworkflows # sdcflows @@ -415,8 +403,6 @@ seaborn==0.13.2 # via # nireports # niworkflows -secretstorage==3.3.3 - # via keyring simplejson==3.19.3 # via nipype six==1.17.0 From cfe8d1795051731783ac46c028960246efaa98b4 Mon Sep 17 00:00:00 2001 From: mathiasg Date: Wed, 18 Dec 2024 11:45:17 -0500 Subject: [PATCH 2/3] FIX: Use nitransforms for `compare_xforms` --- nibabies/workflows/bold/registration.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nibabies/workflows/bold/registration.py b/nibabies/workflows/bold/registration.py index 2c49be1a..14b9c6ba 100644 --- a/nibabies/workflows/bold/registration.py +++ b/nibabies/workflows/bold/registration.py @@ -704,11 +704,11 @@ def compare_xforms(lta_list, norm_threshold=15): second transform relative to the first (default: `15`) """ + import nitransforms as nt from nipype.algorithms.rapidart import _calc_norm_affine - from niworkflows.interfaces.surf import load_transform - bbr_affine = load_transform(lta_list[0]) - fallback_affine = load_transform(lta_list[1]) + bbr_affine = nt.linear.load(lta_list[0]).matrix + fallback_affine = nt.linear.load(lta_list[1]).matrix norm, _ = _calc_norm_affine([fallback_affine, bbr_affine], use_differences=True) From d338dd0b88dc23ec27f90d93106c31d03d471482 Mon Sep 17 00:00:00 2001 From: mathiasg Date: Wed, 18 Dec 2024 12:37:00 -0500 Subject: [PATCH 3/3] FIX: Missing parameter --- nibabies/interfaces/resampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nibabies/interfaces/resampling.py b/nibabies/interfaces/resampling.py index 9a7220ca..46781c55 100644 --- a/nibabies/interfaces/resampling.py +++ b/nibabies/interfaces/resampling.py @@ -698,7 +698,7 @@ def reconstruct_fieldmap( ) if not direct: - fmap_img = nt.apply(fmap_img, reference=target) + fmap_img = nt.apply(transforms, fmap_img, reference=target) fmap_img.header.set_intent('estimate', name='fieldmap Hz') fmap_img.header.set_data_dtype('float32')