Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 82 additions & 16 deletions src/traceax/_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,22 +400,87 @@ def _remove_undefined_primal(x):
return


def _build_diagonal(ct_result: float, op: lx.AbstractLinearOperator) -> lx.AbstractLinearOperator:
# def _build_diagonal(ct_result: float, op: lx.AbstractLinearOperator) -> lx.AbstractLinearOperator:
# operator_struct = jtu.tree_map(_remove_undefined_primal, op, is_leaf=_is_undefined)
# if isinstance(op, lx.MatrixLinearOperator):
# in_size = eqx.filter_eval_shape(lambda o: o.in_size(), operator_struct)
# diag = ct_result * jnp.ones(in_size)
# return lx.MatrixLinearOperator(jnp.diag(diag), tags=operator_struct.tags)
# elif isinstance(op, lx.DiagonalLinearOperator):
# in_size = eqx.filter_eval_shape(lambda o: o.in_size(), operator_struct)
# diag = ct_result * jnp.ones(in_size)
# return lx.DiagonalLinearOperator(diag)
# elif isinstance(op, lx.MulLinearOperator):
# inner_op = _build_diagonal(ct_result, op.operator)
# scalar = op.scalar
# return scalar * inner_op # type: ignore
# else:
# raise ValueError("Unsupported type!")


# replaces _build_diagonal
@ft.singledispatch
def _make_identity(op: lx.AbstractLinearOperator, ct_result: float) -> lx.AbstractLinearOperator:
raise ValueError("Unsupported type!")


@_make_identity.register
def _(op: lx.MatrixLinearOperator, ct_result: float) -> lx.AbstractLinearOperator:
operator_struct = jtu.tree_map(_remove_undefined_primal, op, is_leaf=_is_undefined)
if isinstance(op, lx.MatrixLinearOperator):
in_size = eqx.filter_eval_shape(lambda o: o.in_size(), operator_struct)
diag = ct_result * jnp.ones(in_size)
return lx.MatrixLinearOperator(jnp.diag(diag), tags=operator_struct.tags)
elif isinstance(op, lx.DiagonalLinearOperator):
in_size = eqx.filter_eval_shape(lambda o: o.in_size(), operator_struct)
diag = ct_result * jnp.ones(in_size)
return lx.DiagonalLinearOperator(diag)
elif isinstance(op, lx.MulLinearOperator):
inner_op = _build_diagonal(ct_result, op.operator)
scalar = op.scalar
return scalar * inner_op # type: ignore
else:
raise ValueError("Unsupported type!")
in_size = eqx.filter_eval_shape(lambda o: o.in_size(), operator_struct)
diag = ct_result * jnp.ones(in_size)
return lx.MatrixLinearOperator(jnp.diag(diag), tags=operator_struct.tags)


@_make_identity.register
def _(op: lx.DiagonalLinearOperator, ct_result: float) -> lx.AbstractLinearOperator:
operator_struct = jtu.tree_map(_remove_undefined_primal, op, is_leaf=_is_undefined)
in_size = eqx.filter_eval_shape(lambda o: o.in_size(), operator_struct)
diag = ct_result * jnp.ones(in_size)
return lx.DiagonalLinearOperator(diag)


@_make_identity.register
def _(op: lx.MulLinearOperator, ct_result: float) -> lx.AbstractLinearOperator:
inner_op = _make_identity(op.operator, ct_result)
scalar = op.scalar
return lx.MulLinearOperator(inner_op, scalar)


@_make_identity.register
def _(op: lx.TridiagonalLinearOperator, ct_result: float) -> lx.AbstractLinearOperator:
operator_struct = jtu.tree_map(_remove_undefined_primal, op, is_leaf=_is_undefined)
in_size = eqx.filter_eval_shape(lambda o: o.in_size(), operator_struct)
diag = ct_result * jnp.ones(in_size)
off_diag = jnp.zeros(in_size - 1)
return lx.TridiagonalLinearOperator(diag, off_diag, off_diag)


@_make_identity.register
def _(op: lx.AddLinearOperator, ct_result: float) -> lx.AbstractLinearOperator:
inner_op1 = _make_identity(op.operator1, ct_result)
inner_op2 = _make_identity(op.operator2, ct_result)
return lx.AddLinearOperator(inner_op1, inner_op2)


@_make_identity.register
def _(op: lx.NegLinearOperator, ct_result: float) -> lx.AbstractLinearOperator:
inner_op = _make_identity(op.operator, ct_result)
return lx.NegLinearOperator(inner_op)


@_make_identity.register
def _(op: lx.DivLinearOperator, ct_result: float) -> lx.AbstractLinearOperator:
inner_op = _make_identity(op.operator, ct_result)
scalar = op.scalar
return lx.DivLinearOperator(inner_op, scalar)


@_make_identity.register
def _(op: lx.ComposedLinearOperator, ct_result: float) -> lx.AbstractLinearOperator:
inner_op1 = _make_identity(op.operator1, ct_result)
inner_op2 = _make_identity(op.operator2, ct_result)
return lx.ComposedLinearOperator(inner_op1, inner_op2)


@eqxi.filter_primitive_transpose(materialise_zeros=True) # pyright: ignore
Expand All @@ -431,7 +496,8 @@ def _estimate_trace_transpose(inputs, cts_out):

# the internals of the operator are UndefinedPrimal leaves so
# we need to rely on abstract values to pull structure info
op_t = _build_diagonal(cts_result, operator)
# op_t = _build_diagonal(cts_result, operator)
op_t = _make_identity(operator, cts_result)

key_none = jtu.tree_map(lambda _: None, key)
# state_none = jtu.tree_map(lambda _: None, state)
Expand Down