Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 54 additions & 66 deletions src/traceax/_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@
_assert_false,
_check_operator,
_clip_k,
_is_none,
_is_undefined,
_remove_undefined_primal,
_to_shapedarray,
_to_struct,
_vmap_mv,
Expand Down Expand Up @@ -216,7 +219,7 @@ def estimate(self, state: _BasicTraceState, k: int) -> tuple[PyTree[Array], dict
term3 = jnp.conjugate(SW_d) * jnp.sum(S * (R - HW), axis=0)

if self.improved:
scale = _get_scale(W, SW_d, n, k)
scale = _get_scale(W, SW_d, n, m)
else:
scale = 1

Expand Down Expand Up @@ -355,60 +358,27 @@ def _estimate_trace_abstract_eval(key, operator, state, k, estimator):
return out


def _is_none(x):
return x is None


@eqxi.filter_primitive_jvp
def _estimate_trace_jvp(primals, tangents):
key, operator, state, k, estimator = primals
# t_operator := V
t_key, t_operator, t_state, t_k, t_estimator = tangents
jtu.tree_map(_assert_false, (t_key, t_state, t_k, t_estimator))
del t_key, t_state, t_k, t_estimator

# primal problem of t = tr(A)
result, stats = eqxi.filter_primitive_bind(_estimate_trace_p, key, operator, state, k, estimator)
out = result, stats

# inner prodct in linear operator space => <A, B> = tr(A @ B)
# d tr(A) / dA = I
# t' = <tr'(A), V> = tr(I @ V) = tr(V)
# tangent problem => tr(V)
# TODO: should we reuse key or split? both seem confusing options
key, t_key = rdm.split(key)
if any(t is not None for t in jtu.tree_leaves(t_operator, is_leaf=_is_none)):
t_operator = jtu.tree_map(eqxi.materialise_zeros, operator, t_operator, is_leaf=_is_none)
t_operator = lx.TangentLinearOperator(operator, t_operator)

t_state = estimator.init(t_key, t_operator)
t_result, _ = eqxi.filter_primitive_bind(_estimate_trace_p, t_key, t_operator, t_state, k, estimator)
t_out = (
t_result,
jtu.tree_map(lambda _: None, stats),
)

return out, t_out


def _is_undefined(x):
return isinstance(x, ad.UndefinedPrimal)


def _assert_defined(x):
assert not _is_undefined(x)
@ft.singledispatch
def _make_identity(op: lx.AbstractLinearOperator, ct_result: float) -> lx.AbstractLinearOperator:
raise ValueError("Unsupported type!")


def _remove_undefined_primal(x):
if _is_undefined(x):
return x.aval
else:
return x
@_make_identity.register
def _(op: lx.MatrixLinearOperator, ct_result: float) -> lx.AbstractLinearOperator:
operator_struct = jtu.tree_map(_remove_undefined_primal, op, is_leaf=_is_undefined)
in_size, out_size = eqx.filter_eval_shape(lambda o: (o.in_size(), o.out_size()), operator_struct)
if in_size != out_size:
raise ValueError("`_make_identity` only supports square matrices.")
diag = jnp.full(in_size, ct_result)
out = lx.MatrixLinearOperator(jnp.diag(diag), tags=operator_struct.tags)
return out


@ft.singledispatch
def _make_identity(op: lx.AbstractLinearOperator, ct_result: float) -> lx.AbstractLinearOperator:
raise ValueError("Unsupported type!")
@_make_identity.register
def _(op: lx.MulLinearOperator, ct_result: float) -> lx.AbstractLinearOperator:
inner_op = _make_identity(op.operator, ct_result)
scalar = jnp.array(1.0)
return lx.MulLinearOperator(inner_op, scalar*ct_result)


@_make_identity.register
Expand All @@ -418,14 +388,6 @@ def _(op: lx.TangentLinearOperator, ct_result: float) -> lx.AbstractLinearOperat
return lx.TangentLinearOperator(p_op, _make_identity(t_op, ct_result))


@_make_identity.register
def _(op: lx.MatrixLinearOperator, ct_result: float) -> lx.AbstractLinearOperator:
operator_struct = jtu.tree_map(_remove_undefined_primal, op, is_leaf=_is_undefined)
in_size = eqx.filter_eval_shape(lambda o: o.in_size(), operator_struct)
diag = jnp.full(in_size, ct_result)
return lx.MatrixLinearOperator(jnp.diag(diag), tags=operator_struct.tags)


@_make_identity.register
def _(op: lx.DiagonalLinearOperator, ct_result: float) -> lx.AbstractLinearOperator:
operator_struct = jtu.tree_map(_remove_undefined_primal, op, is_leaf=_is_undefined)
Expand All @@ -434,13 +396,6 @@ def _(op: lx.DiagonalLinearOperator, ct_result: float) -> lx.AbstractLinearOpera
return lx.DiagonalLinearOperator(diag)


@_make_identity.register
def _(op: lx.MulLinearOperator, ct_result: float) -> lx.AbstractLinearOperator:
inner_op = _make_identity(op.operator, ct_result)
scalar = jnp.array(1.0)
return lx.MulLinearOperator(inner_op, scalar)


@_make_identity.register
def _(op: lx.TridiagonalLinearOperator, ct_result: float) -> lx.AbstractLinearOperator:
operator_struct = jtu.tree_map(_remove_undefined_primal, op, is_leaf=_is_undefined)
Expand Down Expand Up @@ -477,6 +432,39 @@ def _(op: lx.ComposedLinearOperator, ct_result: float) -> lx.AbstractLinearOpera
return lx.ComposedLinearOperator(inner_op1, inner_op2)


@eqxi.filter_primitive_jvp
def _estimate_trace_jvp(primals, tangents):
key, operator, state, k, estimator = primals
# t_operator := V
t_key, t_operator, t_state, t_k, t_estimator = tangents
jtu.tree_map(_assert_false, (t_key, t_state, t_k, t_estimator))
del t_key, t_state, t_k, t_estimator

# primal problem of t = tr(A)
result, stats = eqxi.filter_primitive_bind(_estimate_trace_p, key, operator, state, k, estimator)
out = result, stats

# inner prodct in linear operator space => <A, B> = tr(A @ B)
# d tr(A) / dA = I
# t' = <tr'(A), V> = tr(I @ V) = tr(V)
# tangent problem => tr(V)
# TODO: should we reuse key or split? both seem confusing options
key, t_key = rdm.split(key)

t_operator = jtu.tree_map(eqxi.materialise_zeros, operator, t_operator, is_leaf=_is_none)
t_operator = lx.TangentLinearOperator(operator, t_operator)

t_state = estimator.init(t_key, t_operator)
t_result, _ = eqxi.filter_primitive_bind(_estimate_trace_p, t_key, t_operator, t_state, k, estimator)
# t_result = jnp.trace(t_operator.as_matrix())
t_out = (
t_result,
jtu.tree_map(lambda _: None, stats),
)

return out, t_out


@eqxi.filter_primitive_transpose(materialise_zeros=True) # pyright: ignore
def _estimate_trace_transpose(inputs, cts_out):
# the jacobian, for the trace is just the identity matrix, i.e. J = I
Expand Down
7 changes: 7 additions & 0 deletions src/traceax/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,10 @@ def _keep_undefined(v, ct):
return ct
else:
return None


def _remove_undefined_primal(x):
if _is_undefined(x):
return x.aval
else:
return x