Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into thilo/svmbir_weight_fix
Browse files Browse the repository at this point in the history
  • Loading branch information
tbalke committed Nov 12, 2021
2 parents a4e259c + f1a138c commit 449d0b2
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ jobs:
pip install -r dev_requirements.txt
pip install -e .
- name: Run tests with pytest
run: pytest --cov=./ --cov-report=xml
run: pytest --cov=scico --cov-report=xml
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v2
with:
Expand Down
59 changes: 58 additions & 1 deletion scico/functional/_denoiser.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from bm3d import bm3d, bm3d_rgb

import scico.numpy as snp
from scico.blockarray import BlockArray
from scico.data import _flax_data_path
from scico.flax import DnCNNNet, load_weights
Expand Down Expand Up @@ -54,7 +55,15 @@ def __init__(self, is_rgb: Optional[bool] = False):
super().__init__()

def prox(self, x: JaxArray, lam: float = 1) -> JaxArray:
r"""Apply BM3D denoiser with noise level ``lam``"""
r"""Apply BM3D denoiser with noise level ``lam``.
Args:
x : input image.
lam : noise level.
Returns:
BM3D denoised output.
"""

# BM3D only works on (NxN) or (NxNxC) arrays
# In future, may want to apply prox along an "axis"
Expand Down Expand Up @@ -112,6 +121,10 @@ class DnCNN(FlaxMap):
def __init__(self, variant: Optional[str] = "6M"):
"""Initialize a :class:`DnCNN` object.
Note that all DnCNN models are trained for single-channel image
input. Multi-channel input is supported via independent denoising
of each channel.
Args:
variant : Identify the DnCNN model to be used. Options are
'6L', '6M' (default), '6H', '17L', '17M', and '17H',
Expand All @@ -130,3 +143,47 @@ def __init__(self, variant: Optional[str] = "6M"):
model = DnCNNNet(depth=nlayer, channels=1, num_filters=64, dtype=np.float32)
variables = load_weights(_flax_data_path("dncnn%s.npz" % variant))
super().__init__(model, variables)

def prox(self, x: JaxArray, lam: float = 1) -> JaxArray:
r"""Apply DnCNN denoiser.
*Warning*: The ``lam`` parameter is ignored, and has no effect on
the output.
Args:
x : input.
lam : noise estimate (ignored).
Returns:
DnCNN denoised output.
"""
if np.iscomplexobj(x):
raise TypeError(f"DnCNN requries real-valued inputs, got {x.dtype}")

if x.ndim < 2:
raise ValueError(
f"DnCNN requires two dimensional (M, N) or three dimensional (M, N, C)"
" inputs; got ndim = {x.ndim}"
)

x_in_shape = x.shape
if x.ndim > 3:
if all(k == 1 for k in x.shape[3:]):
x = x.squeeze()
else:
raise ValueError(
"Arrays with more than three axes are only supported when "
" the additional axes are singletons"
)

if x.ndim == 3:
# swap channel axis to batch axis and add singleton axis at end
y = super().prox(snp.swapaxes(x, 0, -1)[..., np.newaxis], lam)
# drop singleton axis and swap axes back to original positions
y = snp.swapaxes(y[..., 0], 0, -1)
else:
y = super().prox(x, lam)

y = y.reshape(x_in_shape)

return y
24 changes: 15 additions & 9 deletions scico/functional/_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

from typing import Any, Callable

import scico.numpy as snp
from flax import linen as nn
from scico.blockarray import BlockArray
from scico.typing import JaxArray
Expand All @@ -33,7 +32,7 @@ def __init__(self, model: Callable[..., nn.Module], variables: PyTree):
Args:
model : Flax model to apply.
variables : parameters and batch stats of trained model.
variables : Parameters and batch stats of trained model.
"""
self.model = model
self.variables = variables
Expand All @@ -42,21 +41,28 @@ def __init__(self, model: Callable[..., nn.Module], variables: PyTree):
def prox(self, x: JaxArray, lam: float = 1) -> JaxArray:
r"""Apply trained flax model.
*Warning*: The ``lam`` parameter is ignored, and has no effect on
the output.
Args:
x : input.
lam : noise estimate (not used).
lam : noise estimate (ignored).
Returns:
Output of flax model.
"""
if isinstance(x, BlockArray):
raise NotImplementedError
else:
# add input singleton
# scico works on (NxN) or (NxNxC) arrays
# flax works on (KxNxNxC) arrays
# (generally KxHxWxC arrays)
# K: input dim
# Add singleton to input as necessary:
# scico typically works with (HxW) or (HxWxC) arrays
# flax expects (KxHxWxC) arrays
# H: spatial height W: spatial width
# K: batch size C: channel size
x_shape = x.shape
if x.ndim == 2:
x = x.reshape((1,) + x.shape + (1,))
elif x.ndim == 3:
x = x.reshape((1,) + x.shape)
y = self.model.apply(self.variables, x, train=False, mutable=False)
return snp.squeeze(y)
return y.reshape(x_shape)
42 changes: 42 additions & 0 deletions scico/test/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,3 +470,45 @@ def test_prox_bad_inputs(self):
z, key = randn((32, 32), key=None, dtype=np.complex64)
with pytest.raises(TypeError):
self.f.prox(z, 1.0)


class TestDnCNN:
def setup(self):
key = None
N = 32
self.x, key = randn((N, N), key=key, dtype=np.float32)
self.x_mltchn, key = randn((N, N, 5), key=key, dtype=np.float32)

self.f = functional.DnCNN()

def test_prox(self):
no_jit = self.f.prox(self.x, 1.0)
jitted = jax.jit(self.f.prox)(self.x, 1.0)
np.testing.assert_allclose(no_jit, jitted, rtol=1e-3)
assert no_jit.dtype == np.float32
assert jitted.dtype == np.float32

def test_prox_mltchn(self):
no_jit = self.f.prox(self.x_mltchn, 1.0)
jitted = jax.jit(self.f.prox)(self.x_mltchn, 1.0)
np.testing.assert_allclose(no_jit, jitted, rtol=1e-3)
assert no_jit.dtype == np.float32
assert jitted.dtype == np.float32

def test_prox_bad_inputs(self):

x, key = randn((32,), key=None, dtype=np.float32)
with pytest.raises(ValueError):
self.f.prox(x, 1.0)

x, key = randn((12, 12, 4, 3), key=None, dtype=np.float32)
with pytest.raises(ValueError):
self.f.prox(x, 1.0)

x_b, key = randn(((2, 3), (3, 4, 5)), key=None, dtype=np.float32)
with pytest.raises(ValueError):
self.f.prox(x, 1.0)

z, key = randn((32, 32), key=None, dtype=np.complex64)
with pytest.raises(TypeError):
self.f.prox(z, 1.0)

0 comments on commit 449d0b2

Please sign in to comment.