/
io.py
90 lines (69 loc) · 2.16 KB
/
io.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
# -*- coding: utf-8 -*-
"""
pyrobex.io
load an image with either antspy or nibabel
Author: Jacob Reinhold (jcreinhold@gmail.com)
Created on: May 6, 2021
"""
# Names exported by ``from pyrobex.io import *``.
__all__ = ["NiftiImage", "NiftiImagePair"]
import logging
import os
from typing import Tuple, Type, TypeVar, Union
import numpy as np
logger = logging.getLogger(__name__)

# antspy is opt-in: it is used only when the USE_ANTSPY environment variable
# equals "true" (case-insensitive). nibabel is the default and the fallback.
ants_flag = os.environ.get("USE_ANTSPY")
use_ants = ants_flag is not None and ants_flag.lower() == "true"

if use_ants:
    try:
        import ants
    except ImportError:  # also covers ModuleNotFoundError (a subclass)
        import nibabel as nib

        msg = "USE_ANTSPY set to true, but could not import antspy.\n"
        msg += "Using nibabel as fallback."
        logger.warning(msg)
        use_ants = False
    else:
        logger.info("Using antspy as the backend.")
else:
    import nibabel as nib  # noqa
    logger.info("Using nibabel as the backend.")
NI = TypeVar("NI", bound="NiftiImage")


class NiftiImage:
    """Helper class to work with nibabel and antspy images.

    Stores a voxel array together with the header, affine, and ``extra``
    metadata of a nibabel image so it can be written back out via nibabel.
    """

    def __init__(
        self,
        data: np.ndarray,
        header=None,  # type: nib.Nifti1Header
        affine: Union[np.ndarray, None] = None,
        extra: Union[dict, None] = None,
    ):
        self.data = data
        self.header = header
        self.affine = affine
        self.extra = extra

    @classmethod
    def load(cls: Type[NI], filename: str) -> Union[NI, "ants.ANTsImage"]:
        """Load ``filename`` using the backend selected at import time.

        With the antspy backend, the image returned by ``ants.image_read`` is
        returned unchanged (not wrapped in this class). Otherwise the file is
        read with ``nib.load`` and its data, header, affine, and ``extra`` are
        wrapped in ``cls``.
        """
        if use_ants:
            # if ants, don't need to bother with this class
            return ants.image_read(str(filename))
        image = nib.load(filename)
        # np.asarray turns a possible np.memmap into a plain ndarray
        data = np.asarray(image.get_fdata())
        return cls(data, image.header, image.affine, image.extra)

    def to_filename(self, filename: str) -> None:
        """Write this image to ``filename`` through :meth:`to_nibabel`."""
        self.to_nibabel().to_filename(filename)

    # The return annotation is a string so it is not evaluated at class
    # definition time: with the antspy backend active, the module-level name
    # ``nib`` is never bound and an unquoted annotation raises NameError.
    def to_nibabel(self) -> "nib.Nifti1Image":
        """Build a ``nibabel.Nifti1Image`` from the stored data and metadata."""
        # Local import so this also works when antspy is the active backend
        # (in that case the module-level ``nib`` name does not exist).
        import nibabel as nib

        return nib.Nifti1Image(self.data, self.affine, self.header, self.extra)

    def get_fdata(self) -> np.ndarray:
        """Return the stored voxel array (mirrors nibabel's ``get_fdata``)."""
        return self.data
# Type alias for a 2-tuple of NiftiImage objects.
NiftiImagePair = Tuple[NiftiImage, NiftiImage]