Skip to content

Commit

Permalink
Fix device placement for PDHG in LinearLeastSquares
Browse files Browse the repository at this point in the history
  • Loading branch information
frankong committed Oct 30, 2019
1 parent 53624f6 commit f681464
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
3 changes: 1 addition & 2 deletions sigpy/alg.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,14 +348,13 @@ def __init__(self, proxfc, proxg, A, AH, x, u,
super().__init__(max_iter)

def _update(self):
x_old = self.x.copy()

# Update dual.
util.axpy(self.u, self.sigma, self.A(self.x_ext))
backend.copyto(self.u, self.proxfc(self.sigma, self.u))

# Update primal.
with self.x_device:
x_old = self.x.copy()
util.axpy(self.x, -self.tau, self.AH(self.u))
backend.copyto(self.x, self.proxg(self.tau, self.x))

Expand Down
21 changes: 11 additions & 10 deletions sigpy/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,16 +324,17 @@ def _get_PrimalDualHybridGradient(self):
else:
proxg = self.proxg

if self.G is None:
proxfc = prox.L2Reg(self.y.shape, 1, y=-self.y)
gamma_dual = 1
else:
A = linop.Vstack([A, self.G])
proxf1c = prox.L2Reg(self.y.shape, 1, y=-self.y)
proxf2c = prox.Conj(proxg)
proxfc = prox.Stack([proxf1c, proxf2c])
proxg = prox.NoOp(self.x.shape)
gamma_dual = 0
with self.y_device:
if self.G is None:
proxfc = prox.L2Reg(self.y.shape, 1, y=-self.y)
gamma_dual = 1
else:
A = linop.Vstack([A, self.G])
proxf1c = prox.L2Reg(self.y.shape, 1, y=-self.y)
proxf2c = prox.Conj(proxg)
proxfc = prox.Stack([proxf1c, proxf2c])
proxg = prox.NoOp(self.x.shape)
gamma_dual = 0

if self.tau is None:
if self.sigma is None:
Expand Down

0 comments on commit f681464

Please sign in to comment.