Skip to content

Commit

Permalink
Updated submodules to use new convolution.
Browse files Browse the repository at this point in the history
  • Loading branch information
frankong committed Aug 23, 2018
1 parent bf34b8e commit f0c25e0
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 45 deletions.
47 changes: 6 additions & 41 deletions sigpy/learn/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ def __init__(self, y_j, l, lamda=1,

self._get_params()
self.r_j = sp.util.empty(self.r_j_shape, dtype=self.dtype, device=device)
self._get_A_r_j()
self.A_r_j = sp.linop.ConvolveInput(self.r_j.shape, self.l, mode=self.mode,
input_multi_channel=True,
output_multi_channel=self.multi_channel)
proxg_r_j = sp.prox.L1Reg(self.A_r_j.ishape, lamda)

super().__init__(self.A_r_j, self.y_j, self.r_j, proxg=proxg_r_j, **kwargs)
Expand Down Expand Up @@ -79,25 +81,6 @@ def _get_params(self):
[i + min(i, f) - 1 for i, f in zip(self.y_j.shape[-self.data_ndim:],
self.l.shape[-self.data_ndim:])])

def _get_A_r_j(self):
if self.device != sp.util.cpu_device and sp.config.cudnn_enabled:
if self.multi_channel:
l_cudnn_shape = self.l.shape
else:
l_cudnn_shape = (1, ) + self.l.shape

self.A_r_j = sp.linop.CudnnConvolveData(
self.r_j.shape, self.l.reshape(l_cudnn_shape), mode=self.mode)

if not self.multi_channel:
R_y = sp.linop.Reshape(self.y_j.shape, self.A_r_j.oshape)
self.A_r_j = R_y * self.A_r_j
else:
C_r_j = sp.linop.Convolve(
self.r_j.shape, self.l, axes=range(-self.data_ndim, 0), mode=self.mode)
S_r_j = sp.linop.Sum(C_r_j.oshape, axes=[-(self.data_ndim + 1)])
self.A_r_j = S_r_j * C_r_j


class ConvSparseCoding(sp.app.App):
r"""Convolutional sparse coding application.
Expand Down Expand Up @@ -212,26 +195,6 @@ def _get_batch_vars(self):
self.j_idx = sp.index.ShuffledIndex(self.num_batches)
self.y_j = sp.util.empty((self.batch_size, ) + self.y.shape[1:],
dtype=self.dtype, device=self.device)

def _get_A_l(self):
if self.device != sp.util.cpu_device and sp.config.cudnn_enabled:
if self.multi_channel:
l_cudnn_shape = self.l.shape
else:
l_cudnn_shape = (1, ) + self.l.shape

R_l = sp.linop.Reshape(l_cudnn_shape, self.l.shape)
C_l = sp.linop.CudnnConvolveFilter(l_cudnn_shape, self.r_j, mode=self.mode)
self.A_l = C_l * R_l

if not self.multi_channel:
R_y = sp.linop.Reshape(self.y_j.shape, self.A_l.oshape)
self.A_l = R_y * self.A_l
else:
C_l = sp.linop.Convolve(
self.l.shape, self.r_j, axes=range(-self.data_ndim, 0), mode=self.mode)
S_l = sp.linop.Sum(C_l.oshape, axes=[1])
self.A_l = S_l * C_l

def _get_alg(self):
min_r_j_app = ConvSparseDecom(self.y_j, self.l, lamda=self.lamda,
Expand All @@ -240,7 +203,9 @@ def _get_alg(self):
max_iter=self.max_r_j_iter, device=self.device)
self.r_j = min_r_j_app.r_j

self._get_A_l()
self.A_l = sp.linop.ConvolveFilter(self.l_shape, self.r_j, mode=self.mode,
input_multi_channel=True,
output_multi_channel=self.multi_channel)
if self.multi_channel:
proxg_l = sp.prox.L2Proj(self.l_shape, 1, axes=[0] + list(range(-self.data_ndim, 0)))
else:
Expand Down
6 changes: 2 additions & 4 deletions sigpy/mri/linop.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@ def ConvSense(img_ker_shape, mps_ker, coord=None):
"""

ndim = len(img_ker_shape)
A = sp.linop.Convolve(
img_ker_shape, mps_ker, axes=range(-ndim, 0), mode='valid')
A = sp.linop.ConvolveInput(img_ker_shape, mps_ker, mode='valid', output_multi_channel=True)

if coord is not None:
num_coils = mps_ker.shape[0]
Expand All @@ -63,8 +62,7 @@ def ConvImage(mps_ker_shape, img_ker, coord=None):
"""
ndim = img_ker.ndim

A = sp.linop.Convolve(
mps_ker_shape, img_ker, axes=range(-ndim, 0), mode='valid')
A = sp.linop.ConvolveFilter(mps_ker_shape, img_ker, mode='valid', output_multi_channel=True)

if coord is not None:
num_coils = mps_ker_shape[0]
Expand Down

0 comments on commit f0c25e0

Please sign in to comment.