Skip to content

Commit

Permalink
feat: add proj_lineq
Browse files Browse the repository at this point in the history
  • Loading branch information
nperraud committed Jan 22, 2021
1 parent b097204 commit 030ec25
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 1 deletion.
69 changes: 69 additions & 0 deletions pyunlocbox/functions.py
Expand Up @@ -36,6 +36,7 @@
proj_positive
proj_b2
proj_lineq
proj_spsd
**Miscellaneous**
Expand Down Expand Up @@ -1043,6 +1044,74 @@ def _prox(self, x, T):
return sol


class proj_lineq(proj):
r"""
Projection on the plane argmin_x || x - z||_2 s.t. Ax = b
This function is the indicator function :math:`i_S(z)` of the set S which
is zero if `z` is in the set and infinite otherwise. The set S is defined
by :math:`\left\{z \in \mathbb{R}^N \mid z \leq 0 \right\}`.
See generic attributes descriptions of the
:class:`pyunlocbox.functions.proj` base class. Note that the constructor
takes keyword-only parameters.
Parameters
----------
This projection requires A as a matrix or pinvA to be provided.
Notes
-----
* The evaluation of this function is zero.
Examples
--------
>>> from pyunlocbox import functions
>>> import numpy as np
>>> x = np.array([0,0])
>>> A = np.array([[1,1]])
>>> pinvA = np.linalg.pinv(A)
>>> y = np.array([1])
>>> f = functions.proj_lineq(A=A, pinvA=pinvA,y=y)
>>> sol = f.prox(x, 0)
>>> sol
array([0.5, 0.5])
>>> np.abs(A.dot(sol) - y)<1e-15
array([ True])
"""
def __init__(self, A=None, pinvA=None, **kwargs):
# Constructor takes keyword-only parameters to prevent user errors.
super(proj_lineq, self).__init__(A=A, **kwargs)
if pinvA is None:
if A is None:
print("Are you sure about the imput parameters?" +
"The projection will return y.")
self.pinvA = lambda x: x
else:
if callable(A):
raise ValueError(
"Please: provide A as a numpy array or provide pinv")
else:
# Transform matrix form to operator form.
self._pinvA = np.linalg.pinv(A)
self.pinvA = lambda x: self._pinvA.dot(x)

else:
if callable(pinvA):
self.pinvA = pinvA
else:
self.pinvA = lambda x: pinvA.dot(x)

def _prox(self, x, T):

# Applying the projection formula
# (for now, only the non scalable version)
residue = self.A(x) - self.y()
sol = x - self.pinvA(residue)
return sol


class structured_sparsity(func):
r"""
Structured sparsity (eval, prox).
Expand Down
59 changes: 58 additions & 1 deletion pyunlocbox/tests/test_functions.py
Expand Up @@ -51,8 +51,10 @@ def assert_equivalent(param1, param2):
assert_equivalent({'y': 3.2}, {'y': lambda: 3.2})
assert_equivalent({'A': None}, {'A': np.identity(3)})
A = np.array([[-4, 2, 5], [1, 3, -7], [2, -1, 0]])
pinvA = np.linalg.pinv(A)
assert_equivalent({'A': A}, {'A': A, 'At': A.T})
assert_equivalent({'A': lambda x: A.dot(x)}, {'A': A, 'At': A})
assert_equivalent({'A': lambda x: A.dot(x), 'pinvA': pinvA},
{'A': A, 'At': A})

def test_dummy(self):
"""
Expand Down Expand Up @@ -417,6 +419,61 @@ def test_proj_b2(self):
f.method = 'NOT_A_VALID_METHOD'
self.assertRaises(ValueError, f.prox, x, 0)

def test_proj_lineq(self):
"""
Test the projection on Ax = y
"""
x = np.zeros([10])
A = np.ones([1, 10])
y = np.array([10])
f = functions.proj_lineq(A=A, y=y)
sol = f.prox(x, 0)
np.testing.assert_allclose(sol, np.ones([10]))
np.abs(A.dot(sol) - y) < 1e-15

f = functions.proj_lineq(A=A)
sol = f.prox(x, 0)
np.testing.assert_allclose(sol, np.zeros([10]))

for i in range(1, 11):
x = np.random.randn(10)
A = np.random.randn(i, 10)
y = np.random.randn(i)
pinvA = np.linalg.pinv(A)
f1 = functions.proj_lineq(A=A, y=y)
f2 = functions.proj_lineq(A=lambda x: A.dot(x), pinvA=pinvA, y=y)
f3 = functions.proj_lineq(A=A, pinvA=lambda x: pinvA.dot(x), y=y)
f4 = functions.proj_lineq(A=A, pinvA=pinvA, y=y)
sol1 = f1.prox(x, 0)
sol2 = f2.prox(x, 0)
sol3 = f3.prox(x, 0)
sol4 = f4.prox(x, 0)
np.testing.assert_allclose(sol1, sol2)
np.testing.assert_allclose(sol1, sol3)
np.testing.assert_allclose(sol1, sol4)
np.testing.assert_allclose(A.dot(sol1), y)

for i in range(11, 15):
x = np.random.randn(10)
A = np.random.randn(i, 10)
y = np.random.randn(i)
pinvA = np.linalg.pinv(A)
f1 = functions.proj_lineq(A=A, y=y)
f2 = functions.proj_lineq(A=lambda x: A.dot(x), pinvA=pinvA, y=y)
f3 = functions.proj_lineq(A=A, pinvA=lambda x: pinvA.dot(x), y=y)
f4 = functions.proj_lineq(A=A, pinvA=pinvA, y=y)
sol1 = f1.prox(x, 0)
sol2 = f2.prox(x, 0)
sol3 = f3.prox(x, 0)
sol4 = f4.prox(x, 0)
np.testing.assert_allclose(sol1, sol2)
np.testing.assert_allclose(sol1, sol3)
np.testing.assert_allclose(sol1, sol4)
np.testing.assert_allclose(sol1, pinvA.dot(y))

self.assertRaises(ValueError, functions.proj_lineq, A=lambda x: x)

def test_proj_positive(self):
"""
Test the projection on the positive octant.
Expand Down

0 comments on commit 030ec25

Please sign in to comment.