Skip to content

Commit

Permalink
Add tests for BlockArray variables
Browse files Browse the repository at this point in the history
  • Loading branch information
bwohlberg committed Jan 6, 2022
1 parent b2da7c8 commit 28eeb7e
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 44 deletions.
87 changes: 65 additions & 22 deletions scico/test/optimize/test_ladmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import scico.numpy as snp
from scico import functional, linop, loss, random
from scico.blockarray import BlockArray
from scico.optimize import LinearizedADMM


Expand All @@ -12,42 +13,51 @@ def setup_method(self, method):
np.random.seed(12345)
self.y = jax.device_put(np.random.randn(32, 33).astype(np.float32))
self.λ = 1e0

def test_ladmm(self):
maxiter = 2
μ = 1e-1
ν = 1e-1
A = linop.Identity(self.y.shape)
f = loss.SquaredL2Loss(y=self.y, A=A)
g = (self.λ / 2) * functional.BM3D()
C = linop.Identity(self.y.shape)

self.maxiter = 2
self.μ = 1e-1
self.ν = 1e-1
self.A = linop.Identity(self.y.shape)
self.f = loss.SquaredL2Loss(y=self.y, A=self.A)
self.g = (self.λ / 2) * functional.BM3D()
self.C = linop.Identity(self.y.shape)

def test_itstat(self):
itstat_fields = {"Iter": "%d", "Time": "%8.2e"}

def itstat_func(obj):
return (obj.itnum, obj.timer.elapsed())

ladmm_ = LinearizedADMM(
f=f,
g=g,
C=C,
mu=μ,
nu=ν,
maxiter=maxiter,
f=self.f,
g=self.g,
C=self.C,
mu=self.μ,
nu=self.ν,
maxiter=self.maxiter,
)
assert len(ladmm_.itstat_object.fieldname) == 4
assert snp.sum(ladmm_.x) == 0.0

ladmm_ = LinearizedADMM(
f=f,
g=g,
C=C,
mu=μ,
nu=ν,
maxiter=maxiter,
f=self.f,
g=self.g,
C=self.C,
mu=self.μ,
nu=self.ν,
maxiter=self.maxiter,
itstat_options={"fields": itstat_fields, "itstat_func": itstat_func, "display": False},
)
assert len(ladmm_.itstat_object.fieldname) == 2

def test_callback(self):
ladmm_ = LinearizedADMM(
f=self.f,
g=self.g,
C=self.C,
mu=self.μ,
nu=self.ν,
maxiter=self.maxiter,
)
ladmm_.test_flag = False

def callback(obj):
Expand All @@ -57,6 +67,39 @@ def callback(obj):
assert ladmm_.test_flag


class TestBlockArray:
def setup_method(self, method):
np.random.seed(12345)
self.y = BlockArray.array(
(
np.random.randn(32, 33).astype(np.float32),
np.random.randn(
17,
).astype(np.float32),
)
)
self.λ = 1e0
self.maxiter = 1
self.μ = 1e-1
self.ν = 1e-1
self.A = linop.Identity(self.y.shape)
self.f = loss.SquaredL2Loss(y=self.y, A=self.A)
self.g = (self.λ / 2) * functional.L2Norm()
self.C = linop.Identity(self.y.shape)

def test_blockarray(self):
ladmm_ = LinearizedADMM(
f=self.f,
g=self.g,
C=self.C,
mu=self.μ,
nu=self.ν,
maxiter=self.maxiter,
)
x = ladmm_.solve()
assert isinstance(x, BlockArray)


class TestReal:
def setup_method(self, method):
np.random.seed(12345)
Expand Down
87 changes: 65 additions & 22 deletions scico/test/optimize/test_pdhg.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import scico.numpy as snp
from scico import functional, linop, loss, random
from scico.blockarray import BlockArray
from scico.optimize import PDHG


Expand All @@ -12,42 +13,51 @@ def setup_method(self, method):
np.random.seed(12345)
self.y = jax.device_put(np.random.randn(32, 33).astype(np.float32))
self.λ = 1e0

def test_pdhg(self):
maxiter = 2
τ = 1e-1
σ = 1e-1
A = linop.Identity(self.y.shape)
f = loss.SquaredL2Loss(y=self.y, A=A)
g = (self.λ / 2) * functional.BM3D()
C = linop.Identity(self.y.shape)

self.maxiter = 2
self.τ = 1e-1
self.σ = 1e-1
self.A = linop.Identity(self.y.shape)
self.f = loss.SquaredL2Loss(y=self.y, A=self.A)
self.g = (self.λ / 2) * functional.BM3D()
self.C = linop.Identity(self.y.shape)

def test_itstat(self):
itstat_fields = {"Iter": "%d", "Time": "%8.2e"}

def itstat_func(obj):
return (obj.itnum, obj.timer.elapsed())

pdhg_ = PDHG(
f=f,
g=g,
C=C,
tau=τ,
sigma=σ,
maxiter=maxiter,
f=self.f,
g=self.g,
C=self.C,
tau=self.τ,
sigma=self.σ,
maxiter=self.maxiter,
)
assert len(pdhg_.itstat_object.fieldname) == 4
assert snp.sum(pdhg_.x) == 0.0

pdhg_ = PDHG(
f=f,
g=g,
C=C,
tau=τ,
sigma=σ,
maxiter=maxiter,
f=self.f,
g=self.g,
C=self.C,
tau=self.τ,
sigma=self.σ,
maxiter=self.maxiter,
itstat_options={"fields": itstat_fields, "itstat_func": itstat_func, "display": False},
)
assert len(pdhg_.itstat_object.fieldname) == 2

def test_callback(self):
pdhg_ = PDHG(
f=self.f,
g=self.g,
C=self.C,
tau=self.τ,
sigma=self.σ,
maxiter=self.maxiter,
)
pdhg_.test_flag = False

def callback(obj):
Expand All @@ -57,6 +67,39 @@ def callback(obj):
assert pdhg_.test_flag


class TestBlockArray:
def setup_method(self, method):
np.random.seed(12345)
self.y = BlockArray.array(
(
np.random.randn(32, 33).astype(np.float32),
np.random.randn(
17,
).astype(np.float32),
)
)
self.λ = 1e0
self.maxiter = 1
self.τ = 1e-1
self.σ = 1e-1
self.A = linop.Identity(self.y.shape)
self.f = loss.SquaredL2Loss(y=self.y, A=self.A)
self.g = (self.λ / 2) * functional.L2Norm()
self.C = linop.Identity(self.y.shape)

def test_blockarray(self):
pdhg_ = PDHG(
f=self.f,
g=self.g,
C=self.C,
tau=self.τ,
sigma=self.σ,
maxiter=self.maxiter,
)
x = pdhg_.solve()
assert isinstance(x, BlockArray)


class TestReal:
def setup_method(self, method):
np.random.seed(12345)
Expand Down

0 comments on commit 28eeb7e

Please sign in to comment.