Skip to content

Commit

Permalink
Removed SenseMaps.
Browse files Browse the repository at this point in the history
  • Loading branch information
frankong committed Jul 21, 2018
1 parent a6de53c commit 1f8b45f
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 150 deletions.
109 changes: 12 additions & 97 deletions sigpy/mri/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class SenseRecon(sp.app.LinearLeastSquares):
Args:
ksp (array): k-space measurements.
mps (array or SenseMaps): sensitivity maps.
mps (array): sensitivity maps.
lamda (float): regularization parameter.
weights (float or array): weights for data consistency.
coord (None or array): coordinates.
Expand Down Expand Up @@ -87,7 +87,7 @@ class SenseConstrainedRecon(sp.app.L2ConstrainedMinimization):
Args:
ksp (array): k-space measurements.
mps (array or SenseMaps): sensitivity maps.
mps (array): sensitivity maps.
eps (float): constraint parameter.
weights (float or array): weights for data consistency.
coord (None or array): coordinates.
Expand Down Expand Up @@ -124,7 +124,7 @@ class L1WaveletRecon(sp.app.LinearLeastSquares):
Args:
ksp (array): k-space measurements.
mps (array or SenseMaps): sensitivity maps.
mps (array): sensitivity maps.
lamda (float): regularization parameter.
weights (float or array): weights for data consistency.
coord (None or array): coordinates.
Expand Down Expand Up @@ -176,7 +176,7 @@ class L1WaveletConstrainedRecon(sp.app.L2ConstrainedMinimization):
Args:
ksp (array): k-space measurements.
mps (array or SenseMaps): sensitivity maps.
mps (array): sensitivity maps.
eps (float): constraint parameter.
wave_name (str): wavelet name.
weights (float or array): weights for data consistency.
Expand Down Expand Up @@ -215,7 +215,7 @@ class TotalVariationRecon(sp.app.LinearLeastSquares):
Args:
ksp (array): k-space measurements.
mps (array or SenseMaps): sensitivity maps.
mps (array): sensitivity maps.
lamda (float): regularization parameter.
weights (float or array): weights for data consistency.
coord (None or array): coordinates.
Expand Down Expand Up @@ -264,7 +264,7 @@ class TotalVariationConstrainedRecon(sp.app.L2ConstrainedMinimization):
Args:
ksp (array): k-space measurements.
mps (array or SenseMaps): sensitivity maps.
mps (array): sensitivity maps.
eps (float): constraint parameter.
weights (float or array): weights for data consistency.
coord (None or array): coordinates.
Expand Down Expand Up @@ -424,99 +424,14 @@ def _output(self):
# Coil by coil to save memory
with self.device:
mps_rss = 0
mps = []
for mps_ker_c in self.mps_ker:
mps_c = sp.fft.ifft(sp.util.resize(mps_ker_c, self.img_shape))
mps.append(sp.util.move(mps_c))
mps_rss += xp.abs(mps_c)**2

mps_rss = mps_rss**0.5
mps_rss = sp.util.move(mps_rss**0.5)
mps = np.stack(mps)
mps /= mps_rss

img = xp.abs(sp.fft.ifft(
sp.util.resize(self.img_ker, self.img_shape)))
img *= mps_rss

img_weights = 1 / mps_rss
img_weights *= img > self.thresh * img.max()

return SenseMaps(self.mps_ker, img_weights)


class SenseMaps(object):
"""Sensitivity maps class.
Implicitly stored as sensitvity map kernels in k-space and an image mask.
Can be sliced like an array.
Args:
mps_ker (array): sensitivity map kernels.
img_mask (array): image mask.
device (Device): device to store mps_ker and img_mask.
"""

def __init__(self, mps_ker, img_mask, conj=False, device=sp.util.cpu_device):
self.num_coils = len(mps_ker)
self.shape = (self.num_coils, ) + img_mask.shape
self.ndim = len(self.shape)
self.mps_ker = mps_ker
self.img_mask = img_mask
self.use_device(device)
self.dtype = self.mps_ker.dtype
self.conj = conj

def use_device(self, device):
self.device = sp.util.Device(device)
self.mps_ker = sp.util.move(self.mps_ker, device)
self.img_mask = sp.util.move(self.img_mask, device)

def __getitem__(self, slc):

xp = self.device.xp
with self.device:
if isinstance(slc, int):
mps_c = sp.fft.ifft(self.mps_ker[slc], oshape=self.img_mask.shape)
mps_c *= self.img_mask
if self.conj:
return xp.conj(mps_c)
else:
return mps_c

elif isinstance(slc, slice):
return SenseMaps(self.mps_ker[slc], self.img_mask,
conj=self.conj, device=self.device)

elif isinstance(slc, tuple) or isinstance(slc, list):
if isinstance(slc[0], int):
mps = sp.fft.ifft(self.mps_ker[slc[0]], oshape=self.img_mask.shape)
mps *= self.img_mask
if self.conj:
return xp.conj(mps[slc[1:]])
else:
return mps[slc[1:]]

def asarray(self):
ndim = self.img_mask.ndim
xp = self.device.xp
with self.device:
mps = sp.fft.ifft(self.mps_ker, oshape=self.shape, axes=range(-ndim, 0))
mps *= self.img_mask

if self.conj:
return xp.conj(mps)
else:
return mps

def __mul__(self, input):
mps = self.asarray()
return mps * input

def __rmul__(self, input):
return self.__mul__(input)

def conjugate(self):
return SenseMaps(self.mps_ker, self.img_mask,
conj=not self.conj, device=self.device)

def save(self, filename):
self.use_device(sp.util.cpu_device)
with open(filename, "wb") as f:
pickle.dump(self, f)
return mps
56 changes: 3 additions & 53 deletions sigpy/mri/linop.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@ def Sense(mps, coord=None):
"""Sense linear operator.
Args:
mps (array or SenseMaps): sensitivity maps of length = number of channels.
mps (array): sensitivity maps of length = number of channels.
coord (None or array): coordinates.
"""

ndim = mps.ndim - 1
img_shape = mps.shape[1:]

S = SenseMultiply(mps)
S = sp.linop.Multiply(img_shape, mps)

if coord is None:
F = sp.linop.FFT(S.oshape, axes=range(-ndim, 0))
Expand All @@ -28,57 +29,6 @@ def Sense(mps, coord=None):
return A


class SenseMultiply(sp.linop.Linop):
"""Sense multiply linear operator.
Args:
mps (array or SenseMaps): sensitivity maps of length = number of channels.
"""

def __init__(self, mps):
self.mps = mps
ishape = self.mps.shape[1:]
oshape = self.mps.shape

super().__init__(oshape, ishape)

def _apply(self, input):
device = sp.util.get_device(input)

with device:
return self.mps * input

def _adjoint_linop(self):

return SenseCombine(self.mps)


class SenseCombine(sp.linop.Linop):
"""Sense combine linear operator.
Args:
mps (array or SenseMaps): sensitivity maps of length = number of channels.
"""

def __init__(self, mps):
self.mps = mps
oshape = self.mps.shape[1:]
ishape = self.mps.shape

super().__init__(oshape, ishape)

def _apply(self, input):
device = sp.util.get_device(input)
xp = device.xp

with device:
return xp.sum(self.mps.conjugate() * input, axis=0)

def _adjoint_linop(self):

return SenseMultiply(self.mps)


def ConvSense(img_ker_shape, mps_ker, coord=None):
"""Convolution linear operator with sensitivity maps kernel in k-space.
Expand Down

0 comments on commit 1f8b45f

Please sign in to comment.