diff --git a/src/traceax/_trace.py b/src/traceax/_trace.py index 8dd2f45..2b73b2d 100644 --- a/src/traceax/_trace.py +++ b/src/traceax/_trace.py @@ -400,22 +400,87 @@ def _remove_undefined_primal(x): return -def _build_diagonal(ct_result: float, op: lx.AbstractLinearOperator) -> lx.AbstractLinearOperator: +# def _build_diagonal(ct_result: float, op: lx.AbstractLinearOperator) -> lx.AbstractLinearOperator: +# operator_struct = jtu.tree_map(_remove_undefined_primal, op, is_leaf=_is_undefined) +# if isinstance(op, lx.MatrixLinearOperator): +# in_size = eqx.filter_eval_shape(lambda o: o.in_size(), operator_struct) +# diag = ct_result * jnp.ones(in_size) +# return lx.MatrixLinearOperator(jnp.diag(diag), tags=operator_struct.tags) +# elif isinstance(op, lx.DiagonalLinearOperator): +# in_size = eqx.filter_eval_shape(lambda o: o.in_size(), operator_struct) +# diag = ct_result * jnp.ones(in_size) +# return lx.DiagonalLinearOperator(diag) +# elif isinstance(op, lx.MulLinearOperator): +# inner_op = _build_diagonal(ct_result, op.operator) +# scalar = op.scalar +# return scalar * inner_op # type: ignore +# else: +# raise ValueError("Unsupported type!") + + +# replaces _build_diagonal +@ft.singledispatch +def _make_identity(op: lx.AbstractLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: + raise ValueError("Unsupported type!") + + +@_make_identity.register +def _(op: lx.MatrixLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: operator_struct = jtu.tree_map(_remove_undefined_primal, op, is_leaf=_is_undefined) - if isinstance(op, lx.MatrixLinearOperator): - in_size = eqx.filter_eval_shape(lambda o: o.in_size(), operator_struct) - diag = ct_result * jnp.ones(in_size) - return lx.MatrixLinearOperator(jnp.diag(diag), tags=operator_struct.tags) - elif isinstance(op, lx.DiagonalLinearOperator): - in_size = eqx.filter_eval_shape(lambda o: o.in_size(), operator_struct) - diag = ct_result * jnp.ones(in_size) - return lx.DiagonalLinearOperator(diag) - elif isinstance(op, lx.MulLinearOperator): - inner_op = _build_diagonal(ct_result, op.operator) - scalar = op.scalar - return scalar * inner_op # type: ignore - else: - raise ValueError("Unsupported type!") + in_size = eqx.filter_eval_shape(lambda o: o.in_size(), operator_struct) + diag = ct_result * jnp.ones(in_size) + return lx.MatrixLinearOperator(jnp.diag(diag), tags=operator_struct.tags) + + +@_make_identity.register +def _(op: lx.DiagonalLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: + operator_struct = jtu.tree_map(_remove_undefined_primal, op, is_leaf=_is_undefined) + in_size = eqx.filter_eval_shape(lambda o: o.in_size(), operator_struct) + diag = ct_result * jnp.ones(in_size) + return lx.DiagonalLinearOperator(diag) + + +@_make_identity.register +def _(op: lx.MulLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: + inner_op = _make_identity(op.operator, ct_result) + scalar = op.scalar + return lx.MulLinearOperator(inner_op, scalar) + + +@_make_identity.register +def _(op: lx.TridiagonalLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: + operator_struct = jtu.tree_map(_remove_undefined_primal, op, is_leaf=_is_undefined) + in_size = eqx.filter_eval_shape(lambda o: o.in_size(), operator_struct) + diag = ct_result * jnp.ones(in_size) + off_diag = jnp.zeros(in_size - 1) + return lx.TridiagonalLinearOperator(diag, off_diag, off_diag) + + +@_make_identity.register +def _(op: lx.AddLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: + inner_op1 = _make_identity(op.operator1, ct_result) + inner_op2 = _make_identity(op.operator2, ct_result) + return lx.AddLinearOperator(inner_op1, inner_op2) + + +@_make_identity.register +def _(op: lx.NegLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: + inner_op = _make_identity(op.operator, ct_result) + return lx.NegLinearOperator(inner_op) + + +@_make_identity.register +def _(op: lx.DivLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: + inner_op = _make_identity(op.operator, ct_result) + scalar = op.scalar + return lx.DivLinearOperator(inner_op, scalar) + + +@_make_identity.register +def _(op: lx.ComposedLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: + inner_op1 = _make_identity(op.operator1, ct_result) + inner_op2 = _make_identity(op.operator2, ct_result) + return lx.ComposedLinearOperator(inner_op1, inner_op2) @eqxi.filter_primitive_transpose(materialise_zeros=True) # pyright: ignore @@ -431,7 +496,8 @@ def _estimate_trace_transpose(inputs, cts_out): # the internals of the operator are UndefinedPrimal leaves so # we need to rely on abstract values to pull structure info - op_t = _build_diagonal(cts_result, operator) + # op_t = _build_diagonal(cts_result, operator) + op_t = _make_identity(operator, cts_result) key_none = jtu.tree_map(lambda _: None, key) # state_none = jtu.tree_map(lambda _: None, state)