
Commit

port sparse_coding module to new API.
PiperOrigin-RevId: 398689820
marcocuturi authored and JAXopt authors committed Sep 24, 2021
1 parent 523edc3 commit ef169f9
Showing 3 changed files with 222 additions and 120 deletions.
224 changes: 104 additions & 120 deletions examples/sparse_coding.py
@@ -12,37 +12,34 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Implementation of sparse coding using jaxopt.
=============================================
"""
"""Implementation of sparse coding using jaxopt."""

import functools
from typing import Optional
from typing import Type
from typing import Mapping
from typing import Any
from typing import Callable
from typing import Tuple
from typing import Mapping
from typing import Optional

from flax import optim
import jax
import jax.numpy as jnp
from jaxopt import OptaxSolver
from jaxopt import projection
from jaxopt import prox
from jaxopt import proximal_gradient
from jaxopt import ProximalGradient


def dictionary_loss(
codes: jnp.ndarray,
params: Tuple[jnp.ndarray, jnp.ndarray],
reconstruction_loss_fun: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray] = None
):
dictionary: jnp.ndarray,
data: jnp.ndarray,
reconstruction_loss_fun: Callable[[jnp.ndarray, jnp.ndarray],
jnp.ndarray] = None):
"""Computes reconstruction loss between data and dict/codes using loss fun.
Args:
codes: a samples x components jnp.ndarray of codes.
params: Tuple containing dictionary and data matrix.
codes: an n_samples x components array of codes.
dictionary: a components x dimension array.
data: an n_samples x dimension array.
reconstruction_loss_fun: a callable loss(x, y) -> a real number, where
x and y are either entries, slices or the matrices themselves.
Set to 1/2 squared L2 norm of difference by default.
@@ -52,175 +49,163 @@ def dictionary_loss(
"""
if reconstruction_loss_fun is None:
reconstruction_loss_fun = lambda x, y: 0.5 * jnp.sum((x - y)**2)

dic, X = params
X_pred = codes @ dic
return reconstruction_loss_fun(X, X_pred)
pred = codes @ dictionary
return reconstruction_loss_fun(data, pred)
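
# A minimal usage sketch (illustrative only, not part of this commit): evaluates
# the default 1/2 squared-L2 reconstruction loss on random arrays whose shapes
# follow the docstring above; the variable names and sizes here are made up.
_key = jax.random.PRNGKey(0)
_codes = jax.random.normal(_key, (8, 3))        # n_samples x components
_dictionary = jax.random.normal(_key, (3, 5))   # components x dimension
_data = jax.random.normal(_key, (8, 5))         # n_samples x dimension
_loss = dictionary_loss(_codes, _dictionary, _data)  # 0.5 * ||_data - _codes @ _dictionary||^2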


def make_task_driven_dictionary_learner(
task_loss_fun: Optional[Callable[[Any, Any, Any, Any], float]] = None,
reconstruction_loss_fun: Optional[Callable[[jnp.ndarray, jnp.ndarray],
jnp.ndarray]] = None,
optimizer_cls: Optional[Type[optim.Optimizer]] = None,
optimizer_kw: Mapping[str, Any] = None,
sparse_coding_kw: Mapping[str, Any] = None):
"""Makes a task driven sparse dictionary learning solver.
optimizer = None,
sparse_coding_kw: Mapping[str, Any] = None,
**kwargs):
"""Makes a task-driven sparse dictionary learning solver.
Returns a jaxopt solver, using either an optax optimizer or jaxopt prox
gradient optimizer, to compute, given data, a dictionary whose corresponding
codes minimize a given task loss. The solver is defined through the task loss
function, a reconstruction loss function, and an optimizer. Additional
parameters can be passed on to lower level functions, notably the computation
of sparse codes and optimizer parameters.
Args:
task_loss_fun: loss as specified on (codes, dict, task_vars, params) that
supplements the usual reconstruction loss formulation. If None, only
dictionary learning is carried out, i.e. that term is assumed to be 0.
reconstruction_loss_fun: entry (or slice-) wise loss function, set to be
the Frobenius norm, || . - . ||^2 by default.
optimizer_cls: Optimizer to solve for dictionary and task_vars (if auxiliary
task is given). Either None, in which case Jaxopt proximal gradient
(with sphere projection on dictionary) is used, or a flax
optimizer class specifying projection on the sphere explicitly for dic.
optimizer_kw: Arguments to be passed to the optimizer class above, or to
jaxopt proximal gradient descent.
the Frobenius norm between matrices, || . - . ||^2 by default.
optimizer: optax optimizer. Falls back on jaxopt proximal gradient if None.
sparse_coding_kw: Jaxopt arguments to be passed to the proximal descent
algorithm computing codes, sparse_coding.
**kwargs: passed on to the _task_sparse_dictionary_learning function.
Returns:
Function that learns a dictionary from data, given a number of components and
an elastic net regularization, with optional initializations for the
dictionary and task variables, and auxiliary parameters for the task.
"""
def learner(X: jnp.ndarray,
def learner(data: jnp.ndarray,
n_components: int,
regularization: float,
elastic_penalty: float,
dict_init: Optional[jnp.ndarray] = None,
task_vars_init: jnp.ndarray = None,
task_params: jnp.ndarray = None,
task_vars_init: jnp.ndarray = None):
dic_init: Optional[jnp.ndarray] = None):

return _task_sparse_dictionary_learning(X, n_components, regularization,
elastic_penalty, dict_init,
task_params, task_vars_init,
return _task_sparse_dictionary_learning(data, n_components, regularization,
elastic_penalty, task_vars_init,
optimizer,
dic_init, task_params,
reconstruction_loss_fun,
task_loss_fun,
optimizer_cls, optimizer_kw,
sparse_coding_kw)
sparse_coding_kw, **kwargs)

return learner
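
# Illustrative sketch (not part of this commit): with the defaults (no task
# loss, no optax optimizer) the returned learner runs plain dictionary learning
# through the jaxopt ProximalGradient path defined below; the sizes and
# hyperparameter values here are arbitrary.
_learn = make_task_driven_dictionary_learner()
_data = jax.random.normal(jax.random.PRNGKey(1), (32, 10))
_dic = _learn(_data, n_components=4, regularization=0.1, elastic_penalty=0.05)
# _dic is the learned n_components x dimension (here 4 x 10) dictionary.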


def _task_sparse_dictionary_learning(
X: jnp.ndarray,
data: jnp.ndarray,
n_components: int,
regularization: float,
elastic_penalty: float,
dict_init: Optional[jnp.ndarray] = None,
task_vars_init: jnp.ndarray,
optimizer=None,
dic_init: Optional[jnp.ndarray] = None,
task_params: jnp.ndarray = None,
task_vars_init: jnp.ndarray = None,
reconstruction_loss_fun: Callable[[jnp.ndarray, jnp.ndarray],
jnp.ndarray] = None,
task_loss_fun: Callable[[Any, Any, Any, Any], float] = None,
optimizer_cls: Optional[Type[optim.Optimizer]] = None,
optimizer_kw: Mapping[str, Any] = None,
sparse_coding_kw: Mapping[str, Any] = None):
"""Computes task driven dictionary, w. implicitly defined sparse codes.
Given a N x d data matrix X, solves a bilevel optimization problem by seeking
a dictionary dic of size n_components x d such that, defining implicitly
codes = sparse_coding(dic, (X, regularization, elastic_penalty))
one has that dic minimizes
task_loss(codes, dic, task_var, task_params)
if such as task_loss was passed on. If None, then task_loss is replaced by
dictionary_loss(codes, (dic, X)).
sparse_coding_kw: Mapping[str, Any] = None,
maxiter: int = 100):
r"""Computes task driven dictionary, w. implicitly defined sparse codes.
Given a N x d ``data`` matrix, solves a bilevel optimization problem by
seeking a dictionary ``dic`` of size ``n_components`` x ``d`` such that,
defining implicitly
``codes = sparse_coding(dic, (data, regularization, elastic_penalty))``
one has that ``dic`` minimizes
``task_loss(codes, dic, task_var, task_params)``,
if such a ``task_loss`` was passed on. If ``task_loss`` is ``None``, it
defaults to ``dictionary_loss(codes, dic, data)``.
Args:
X: N x d jnp.ndarray, data matrix with N samples of d features.
data: N x d jnp.ndarray, data matrix with N samples of d features.
n_components: int, number of atoms in dictionary.
regularization: regularization strength of elastic penalty.
elastic_penalty: strength of L2 penalty relative to L1.
task_params: auxiliary parameters to define task loss, typically data.
dict_init: initialization for dictionary; that returned by SVD by default.
task_vars_init: initializer for task related optimization variables.
optimizer: If None, falls back on jaxopt proximal gradient (with sphere
projection for ``dic``). If not ``None``, that optimizer is used, with the
dictionary normalized onto the sphere within the loss and on output.
dic_init: initialization for dictionary; that returned by SVD by default.
reconstruction_loss_fun: loss to be applied to compute reconstruction error.
task_params: auxiliary parameters to define task loss, typically data.
task_loss_fun: task driven loss for codes and dictionary using task_vars and
task_params.
optimizer_cls: flax optimizer class. If None, falls back on jaxopt projected
gradient (with sphere normalization constraints). If not None, instantiate
that optimizer.
optimizer_kw: parameters passed on to optimizer
sparse_coding_kw: parameters passed on to jaxopt prox gradient solver.
sparse_coding_kw: parameters passed on to jaxopt prox gradient solver to
compute codes.
maxiter: maximal number of iterations of the outer loop.
Returns:
the n_components x d dictionary solution found by the algorithm, as well as
codes.
An ``n_components x d`` matrix, the ``dic`` solution found by the algorithm,
as well as task variables if a task was provided.
"""

if dict_init is None:
_, _, dict_init = jax.scipy.linalg.svd(X, False)
dict_init = dict_init[:n_components, :]
if dic_init is None:
_, _, dic_init = jax.scipy.linalg.svd(data, False)
dic_init = dic_init[:n_components, :]

has_task = task_loss_fun is not None

# Loss function, dictionary learning in addition to task driven loss
def loss_fun(variables, params):
dic, task_vars = variables
coding_params, task_params = params
def loss_fun(params, hyper_params):
dic, task_vars = params
coding_params, task_params = hyper_params
codes = sparse_coding(
dic,
coding_params,
reconstruction_loss_fun=reconstruction_loss_fun,
sparse_coding_kw=sparse_coding_kw)
if optimizer is not None:
dic = projection.projection_l2_sphere(dic)

if has_task: # if there is a task, drop loss, replace it with proper value
if has_task:
loss = task_loss_fun(codes, dic, task_vars, task_params)
else:
loss = dictionary_loss(codes, (dic, X), reconstruction_loss_fun)
loss = dictionary_loss(codes, dic, data, reconstruction_loss_fun)
return loss, codes

init = (dict_init, task_vars_init)

optimizer_kw = {} if optimizer_kw is None else optimizer_kw
def prox_dic(params, hyper, step):
# Here projection/prox is only applied on the dictionary.
del hyper, step
dic, task_vars = params
return projection.projection_l2_sphere(dic), task_vars

proj_sphere = lambda x: jax.vmap(projection.projection_l2_sphere)(x)
if optimizer_cls is None:
# If no optimizer, use jaxopt projected gradient descent.
if optimizer is None:
solver = ProximalGradient(fun=loss_fun, prox=prox_dic, has_aux=True)
params, state = solver.init((dic_init, task_vars_init), None)

# Define projection-prox, here normalize each dict atom by its norm.
for _ in range(maxiter):
params, state = solver.update(
params, state, None,
((data, regularization, elastic_penalty), task_params))

prox_vars = lambda dic_vars, par, s : (
proj_sphere(dic_vars[0]), dic_vars[1])

solver = proximal_gradient.make_solver_fun(
fun=loss_fun, prox=prox_vars, has_aux=True,
init=init, **optimizer_kw)
dic, task_vars = solver(((X, regularization, elastic_penalty), task_params))
# Normalize dictionary before returning it.
dic, task_vars = prox_dic(params, None, None)

else:
maxiter = optimizer_kw.pop('maxiter', 500) # Pop'd to set loop size.
optimizer = optimizer_cls(**optimizer_kw)
optimizer = optimizer.create(init)

# Use implicit jaxopt gradients to inform optimizer's steps.
loss_normalized = lambda dic_vars, params: loss_fun(
(proj_sphere(dic_vars[0]), dic_vars[1]), params)
grad_fn = jax.value_and_grad(loss_normalized, has_aux=True)

def train_step(optimizer, params):
(loss, codes), grad = grad_fn(optimizer.target, params)
new_optimizer = optimizer.apply_gradient(grad)
return new_optimizer, loss

# Training body fun.
def body_fun(iteration, in_vars):
del iteration
optimizer, pars = in_vars
optimizer, _ = train_step(optimizer, pars)
return (optimizer, pars)

init_val = (optimizer, ((X, regularization, elastic_penalty), task_params))
solver = OptaxSolver(opt=optimizer, fun=loss_fun, has_aux=True)
params, state = solver.init((dic_init, task_vars_init))

# Run fori_loop, this will be converted to a scan.
optimizer, _ = jax.lax.fori_loop(0, maxiter, body_fun, init_val)
for _ in range(maxiter):
params, state = solver.update(
params, state,
((data, regularization, elastic_penalty), task_params))

dic, task_vars = optimizer.target
# Normalize dictionary before returning it.
dic = proj_sphere(dic)
# Normalize dictionary before returning it.
dic, task_vars = prox_dic(params, None, None)

if has_task:
return dic, task_vars
@@ -229,23 +214,22 @@ def body_fun(iteration, in_vars):

def sparse_coding(dic, params, reconstruction_loss_fun=None,
sparse_coding_kw=None, codes_init=None):
"""Computes optimal codes for data X given a dictionary dic."""
"""Computes optimal codes for data given a dictionary dic using params."""
sparse_coding_kw = {} if sparse_coding_kw is None else sparse_coding_kw
loss_fun = functools.partial(dictionary_loss,
reconstruction_loss_fun=reconstruction_loss_fun)
X, regularization, elastic_penalty = params
data, regularization, elastic_penalty = params
n_components, _ = dic.shape
N, _ = X.shape
n_points, _ = data.shape

if codes_init is None:
codes_init = jnp.zeros((N, n_components))
codes_init = jnp.zeros((n_points, n_components))

solver = proximal_gradient.make_solver_fun(
solver = ProximalGradient(
fun=loss_fun,
prox=prox.prox_elastic_net,
init=codes_init,
**sparse_coding_kw)

codes = solver(params_fun=(dic, X),
params_prox=[regularization, elastic_penalty])
codes = solver.run(codes_init, [regularization, elastic_penalty],
dic, data).params
return codes
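
# Illustrative sketch (not part of this commit): computing codes directly for a
# fixed dictionary. The params tuple packs the data with the elastic-net
# hyperparameters, matching the unpacking at the top of sparse_coding; array
# sizes and values here are arbitrary.
_key = jax.random.PRNGKey(2)
_dic = jax.random.normal(_key, (4, 10))           # n_components x dimension
_data = jax.random.normal(_key, (32, 10))         # n_points x dimension
_codes = sparse_coding(_dic, (_data, 0.1, 0.05))  # n_points x n_components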

1 comment on commit ef169f9

@marcocuturi
Contributor Author


Porting the existing sparse_coding module to the most recent API.
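
For readers comparing against the previous module, the core call-pattern change is summarized below
(both snippets are taken from the diff above; X in the old code and data in the new code denote the
same data matrix):

    # Old API (removed by this commit):
    solver = proximal_gradient.make_solver_fun(
        fun=loss_fun, prox=prox.prox_elastic_net, init=codes_init, **sparse_coding_kw)
    codes = solver(params_fun=(dic, X), params_prox=[regularization, elastic_penalty])

    # New API (used by this commit):
    solver = ProximalGradient(fun=loss_fun, prox=prox.prox_elastic_net, **sparse_coding_kw)
    codes = solver.run(codes_init, [regularization, elastic_penalty], dic, data).params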
