Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LevenbergMarquardt implementation does not accept PyTree parameters #579

Open
Joshuaalbert opened this issue Feb 6, 2024 · 0 comments
Open

Comments

@Joshuaalbert
Copy link

Description

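The LevenbergMarquardt implementation does not accept PyTree parameters. `solver.run` fails during `init_state` with `TypeError: primal and tangent arguments to jax.jvp must have the same tree structure`. The error is raised from `jax.jvp` in `_jtj_op` (levenberg_marquardt.py, line 528), which is called from `_jtj_diag_op` (lines 534–535). There, the tangent vectors are rows of `jnp.eye(len(params))`: plain arrays rather than PyTrees matching the structure of the parameters.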

MVCE

from dataclasses import dataclass
from typing import Literal, NamedTuple, Tuple

import jaxopt
from jax import numpy as jnp


class CalibrationParams(NamedTuple):
    """Optimisation variables: the gains split into real and imaginary float arrays.

    An instance of this NamedTuple is the PyTree handed to the solver as
    ``init_params``.
    """

    # Real part of the gains; layout [source, time, ant, chan, 2, 2]
    gains_real: jnp.ndarray
    # Imaginary part of the gains; same layout as gains_real
    gains_imag: jnp.ndarray


class CalibrationData(NamedTuple):
    """Target gains that the fitted parameters are compared against.

    Mirrors the field layout of the parameter tuple so the residual can be
    formed field by field.
    """

    # Observed real part of the gains; layout [source, time, ant, chan, 2, 2]
    gains_real: jnp.ndarray
    # Observed imaginary part of the gains; same layout as gains_real
    gains_imag: jnp.ndarray


@dataclass(eq=False)
class Calibration:
    """Fit gain parameters to data with jaxopt's Levenberg-Marquardt solver.

    Note: ``convention``, ``chunksize`` and ``unroll`` are stored on the
    instance but are not read by any method of this class; in particular
    ``solve`` passes ``unroll=False`` to the solver regardless of ``unroll``.
    """

    convention: Literal['fourier', 'casa'] = 'casa'
    # Assumed to be a complex dtype; float_dtype derives the matching real dtype.
    dtype: jnp.dtype = jnp.complex64
    chunksize: int = 1
    unroll: int = 1

    def _residual_fun(self, params: CalibrationParams, data: CalibrationData) -> jnp.ndarray:
        """Return one flat residual vector.

        The real-part differences come first, followed by the imaginary-part
        differences, each flattened.
        """
        real_diff = params.gains_real - data.gains_real
        imag_diff = params.gains_imag - data.gains_imag
        return jnp.concatenate([real_diff.ravel(), imag_diff.ravel()])

    @property
    def float_dtype(self):
        """Dtype of the real part of a ``self.dtype`` value (e.g. complex64 -> float32)."""
        zero = jnp.zeros((), dtype=self.dtype)
        return jnp.real(zero).dtype

    def get_init_params(self, shape) -> CalibrationParams:
        """
        Build the solver's starting point.

        Args:
            shape: shape shared by gains_real and gains_imag

        Returns:
            CalibrationParams with all-ones real gains and all-zeros imaginary
            gains, both in ``self.float_dtype``.
        """
        real_dtype = self.float_dtype
        ones = jnp.ones(shape, real_dtype)
        zeros = jnp.zeros(shape, real_dtype)
        return CalibrationParams(gains_real=ones, gains_imag=zeros)

    def solve(self, init_params: CalibrationParams, data: CalibrationData) -> Tuple[CalibrationParams, jaxopt.OptStep]:
        """Run Levenberg-Marquardt from ``init_params`` against ``data``.

        Returns:
            ``(params, opt_result)`` where ``params`` is ``opt_result.params``.

        NOTE: with NamedTuple ``init_params`` this currently raises
        ``TypeError`` (jax.jvp tree-structure mismatch) inside ``solver.run``.
        """
        solver_options = dict(
            maxiter=100,
            jit=True,
            unroll=False,
            materialize_jac=False,
            geodesic=False,
            implicit_diff=False,
        )
        solver = jaxopt.LevenbergMarquardt(residual_fun=self._residual_fun, **solver_options)
        opt_result = solver.run(init_params=init_params, data=data)
        return opt_result.params, opt_result


if __name__ == '__main__':
    calibration = Calibration()
    # A tiny shape is enough to reproduce the bug. The TypeError is a PyTree
    # tree-structure mismatch raised in jaxopt's init_state (tangents built from
    # jnp.eye(len(params)), where len(params) is the number of NamedTuple
    # fields), so it does not depend on array size. Large arrays only add
    # memory pressure to the reproducer.
    # Axes follow the [source, time, ant, chan, 2, 2] layout.
    shape = (2, 3, 4, 5, 2, 2)
    init_params = calibration.get_init_params(shape)
    data = CalibrationData(
        gains_real=jnp.ones(shape, calibration.float_dtype),
        gains_imag=jnp.zeros(shape, calibration.float_dtype)
    )
    params, opt_results = calibration.solve(init_params=init_params, data=data)
    print(params)
    print(opt_results)
File "/home/albert/miniconda3/envs/dsa_py/lib/python3.10/site-packages/jaxopt/_src/base.py", line 359, in run
    return run(init_params, *args, **kwargs)
  File "/home/albert/miniconda3/envs/dsa_py/lib/python3.10/site-packages/jaxopt/_src/base.py", line 301, in _run
    state = self.init_state(init_params, *args, **kwargs)
  File "/home/albert/miniconda3/envs/dsa_py/lib/python3.10/site-packages/jaxopt/_src/levenberg_marquardt.py", line 216, in init_state
    jtj_diag = self._jtj_diag_op(init_params, *args, **kwargs)
  File "/home/albert/miniconda3/envs/dsa_py/lib/python3.10/site-packages/jaxopt/_src/levenberg_marquardt.py", line 535, in _jtj_diag_op
    return jax.vmap(diag_op)(jnp.eye(len(params))).T
  File "/home/albert/miniconda3/envs/dsa_py/lib/python3.10/site-packages/jaxopt/_src/levenberg_marquardt.py", line 534, in <lambda>
    diag_op = lambda v: v.T @ self._jtj_op(params, v, *args, **kwargs)
  File "/home/albert/miniconda3/envs/dsa_py/lib/python3.10/site-packages/jaxopt/_src/levenberg_marquardt.py", line 528, in _jtj_op
    _, jvp_val = jax.jvp(fun_with_args, (params,), (vec,))
TypeError: primal and tangent arguments to jax.jvp must have the same tree structure; primals have tree structure PyTreeDef((CustomNode(namedtuple[CalibrationParams], [*, *]),)) whereas tangents have tree structure PyTreeDef((*,)).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant