-
Notifications
You must be signed in to change notification settings - Fork 43
/
dcomp.py
119 lines (101 loc) · 3.6 KB
/
dcomp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from typing import Optional, Sequence, Union
import torch
import torchkbnufft.functional as tkbnF
from torch import Tensor
from .utils import init_fn
def calc_density_compensation_function(
ktraj: Tensor,
im_size: Sequence[int],
num_iterations: int = 10,
grid_size: Optional[Sequence[int]] = None,
numpoints: Union[int, Sequence[int]] = 6,
n_shift: Optional[Sequence[int]] = None,
table_oversamp: Union[int, Sequence[int]] = 2 ** 10,
kbwidth: float = 2.34,
order: Union[float, Sequence[float]] = 0.0,
) -> Tensor:
"""Numerical density compensation estimation.
This function has optional parameters for initializing a NUFFT object. See
:py:class:`~torchkbnufft.KbInterp` for details.
* :attr:`ktraj` should be of size ``(len(grid_size), klength)`` or
``(N, len(grid_size), klength)``, where ``klength`` is the length of the
k-space trajectory.
Based on the `method of Pipe
<https://doi.org/10.1002/(SICI)1522-2594(199901)41:1%3C179::AID-MRM25%3E3.0.CO;2-V>`_.
This code was contributed by Chaithya G.R.
Args:
ktraj: k-space trajectory (in radians/voxel).
im_size: Size of image with length being the number of dimensions.
num_iterations: Number of iterations.
grid_size: Size of grid to use for interpolation, typically 1.25 to 2
times ``im_size``. Default: ``2 * im_size``
numpoints: Number of neighbors to use for interpolation in each
dimension. Default: ``6``
n_shift: Size for fftshift. Default: ``im_size // 2``.
table_oversamp: Table oversampling factor.
kbwidth: Size of Kaiser-Bessel kernel.
order: Order of Kaiser-Bessel kernel.
Returns:
The density compensation coefficients for ``ktraj``.
Examples:
>>> data = torch.randn(1, 1, 12) + 1j * torch.randn(1, 1, 12)
>>> omega = torch.rand(2, 12) * 2 * np.pi - np.pi
>>> dcomp = tkbn.calc_density_compensation_function(omega, (8, 8))
>>> adjkb_ob = tkbn.KbNufftAdjoint(im_size=(8, 8))
>>> image = adjkb_ob(data * dcomp, omega)
"""
device = ktraj.device
batch_size = 1
if ktraj.ndim not in (2, 3):
raise ValueError("ktraj must have 2 or 3 dimensions")
if ktraj.ndim == 3:
if ktraj.shape[0] == 1:
ktraj = ktraj[0]
else:
batch_size = ktraj.shape[0]
# init nufft variables
(
tables,
_,
grid_size_t,
n_shift_t,
numpoints_t,
offsets_t,
table_oversamp_t,
_,
_,
) = init_fn(
im_size=im_size,
grid_size=grid_size,
numpoints=numpoints,
n_shift=n_shift,
table_oversamp=table_oversamp,
kbwidth=kbwidth,
order=order,
dtype=ktraj.dtype,
device=device,
)
test_sig = torch.ones(
[batch_size, 1, ktraj.shape[-1]], dtype=tables[0].dtype, device=device
)
for _ in range(num_iterations):
new_sig = tkbnF.kb_table_interp(
image=tkbnF.kb_table_interp_adjoint(
data=test_sig,
omega=ktraj,
tables=tables,
n_shift=n_shift_t,
numpoints=numpoints_t,
table_oversamp=table_oversamp_t,
offsets=offsets_t,
grid_size=grid_size_t,
),
omega=ktraj,
tables=tables,
n_shift=n_shift_t,
numpoints=numpoints_t,
table_oversamp=table_oversamp_t,
offsets=offsets_t,
)
test_sig = test_sig / torch.abs(new_sig)
return test_sig