Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolve #478 #481

Merged
merged 3 commits on Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 9 additions & 1 deletion scico/linop/_func.py
Expand Up @@ -49,6 +49,12 @@ def linop_from_function(f: Callable, classname: str, f_name: Optional[str] = Non
implements complex-valued operations, this must be a
complex dtype (typically :attr:`~numpy.complex64`) for
correct adjoint and gradient calculation.
output_shape: Shape of output array. Defaults to ``None``.
If ``None``, `output_shape` is determined by evaluating
`self.__call__` on an input array of zeros.
output_dtype: `dtype` for output argument. Defaults to
``None``. If ``None``, `output_dtype` is determined by
evaluating `self.__call__` on an input array of zeros.
jit: If ``True``, call :meth:`~.LinearOperator.jit` on this
:class:`LinearOperator` to jit the forward, adjoint, and
gram functions. Same as calling
Expand All @@ -62,12 +68,14 @@ def __init__(
input_shape: Union[Shape, BlockShape],
*args: Any,
input_dtype: DType = snp.float32,
output_shape: Optional[Union[Shape, BlockShape]] = None,
output_dtype: Optional[DType] = None,
jit: bool = True,
**kwargs: Any,
):
self._eval = lambda x: f(x, *args, **kwargs)
self.kwargs = kwargs
super().__init__(input_shape, input_dtype=input_dtype, jit=jit) # type: ignore
super().__init__(input_shape, input_dtype=input_dtype, output_shape=output_shape, output_dtype=output_dtype, jit=jit) # type: ignore

OpClass = type(classname, (LinearOperator,), {"__init__": __init__})
__class__ = OpClass # needed for super() to work
Expand Down
10 changes: 9 additions & 1 deletion scico/operator/_func.py
Expand Up @@ -53,6 +53,12 @@ def operator_from_function(f: Callable, classname: str, f_name: Optional[str] =
implements complex-valued operations, this must be a
complex dtype (typically :attr:`~numpy.complex64`) for
correct adjoint and gradient calculation.
output_shape: Shape of output array. Defaults to ``None``.
If ``None``, `output_shape` is determined by evaluating
`self.__call__` on an input array of zeros.
output_dtype: `dtype` for output argument. Defaults to
``None``. If ``None``, `output_dtype` is determined by
evaluating `self.__call__` on an input array of zeros.
jit: If ``True``, call :meth:`.Operator.jit` on this
`Operator` to jit the forward, adjoint, and gram
functions. Same as calling :meth:`.Operator.jit` after
Expand All @@ -65,11 +71,13 @@ def __init__(
input_shape: Union[Shape, BlockShape],
*args: Any,
input_dtype: DType = snp.float32,
output_shape: Optional[Union[Shape, BlockShape]] = None,
output_dtype: Optional[DType] = None,
jit: bool = True,
**kwargs: Any,
):
self._eval = lambda x: f(x, *args, **kwargs)
super().__init__(input_shape, input_dtype=input_dtype, jit=jit) # type: ignore
super().__init__(input_shape, input_dtype=input_dtype, output_shape=output_shape, output_dtype=output_dtype, jit=jit) # type: ignore

OpClass = type(classname, (Operator,), {"__init__": __init__})
__class__ = OpClass # needed for super() to work
Expand Down
2 changes: 1 addition & 1 deletion scico/operator/_operator.py
Expand Up @@ -66,7 +66,7 @@ def __repr__(self):
output_dtype : {self.output_dtype}
"""

# See https://docs.scipy.org/doc/numpy-1.10.1/user/c-info.beyond-basics.html#ndarray.__array_priority__
# See https://numpy.org/doc/stable/user/c-info.beyond-basics.html#ndarray.__array_priority__
__array_priority__ = 1

def __init__(
Expand Down
10 changes: 10 additions & 0 deletions scico/test/linop/test_func.py
Expand Up @@ -19,6 +19,16 @@ def test_transpose():
np.testing.assert_array_equal(H.T @ H @ x, x)


def test_transpose_ext_init():
    """Check that ``linop.Transpose`` accepts explicit output shape/dtype.

    Supplies ``output_shape`` and ``output_dtype`` directly rather than
    letting the operator infer them by evaluating on a zero array.
    """
    shape = (1, 2, 3, 4)
    perm = (1, 0, 3, 2)
    x, _ = randn(shape)
    # The output shape of a transpose is the input shape permuted by `perm`,
    # i.e. (2, 1, 4, 3) here — not the input shape itself. Passing the
    # unpermuted `shape` would declare an output_shape inconsistent with
    # the array that H @ x actually produces.
    out_shape = tuple(shape[ax] for ax in perm)
    H = linop.Transpose(
        shape, perm, input_dtype=snp.float32, output_shape=out_shape, output_dtype=snp.float32
    )
    np.testing.assert_array_equal(H @ x, x.transpose(perm))


def test_reshape():
shape = (1, 2, 3, 4)
newshape = (2, 12)
Expand Down
10 changes: 10 additions & 0 deletions scico/test/operator/test_operator.py
Expand Up @@ -233,6 +233,16 @@ def test_make_func_op():
np.testing.assert_array_equal(H(x), snp.abs(x))


def test_make_func_op_ext_init():
    """Check ``operator_from_function`` with explicitly supplied output metadata.

    The element-wise absolute value preserves shape and dtype, so the
    explicit ``output_shape``/``output_dtype`` match the input ones.
    """
    AbsVal = operator_from_function(snp.abs, "AbsVal")
    in_shape = (2,)
    u, _ = randn(in_shape, dtype=np.float32)
    A = AbsVal(
        input_shape=in_shape,
        output_shape=in_shape,
        input_dtype=np.float32,
        output_dtype=np.float32,
    )
    np.testing.assert_array_equal(A(u), snp.abs(u))


class TestJacobianProdReal:
def setup_method(self):
N = 7
Expand Down