From dc5890b01a29528c8bb6ff52b60772c5459a2d83 Mon Sep 17 00:00:00 2001 From: nahid18 Date: Mon, 14 Oct 2024 15:05:47 -0700 Subject: [PATCH 1/3] chore: organize --- src/traceax/_trace.py | 86 ++++++++++++++++++------------------------- src/traceax/_utils.py | 7 ++++ 2 files changed, 42 insertions(+), 51 deletions(-) diff --git a/src/traceax/_trace.py b/src/traceax/_trace.py index f6296aa..ca6fe9a 100644 --- a/src/traceax/_trace.py +++ b/src/traceax/_trace.py @@ -36,6 +36,9 @@ _assert_false, _check_operator, _clip_k, + _is_none, + _is_undefined, + _remove_undefined_primal, _to_shapedarray, _to_struct, _vmap_mv, @@ -355,57 +358,6 @@ def _estimate_trace_abstract_eval(key, operator, state, k, estimator): return out -def _is_none(x): - return x is None - - -@eqxi.filter_primitive_jvp -def _estimate_trace_jvp(primals, tangents): - key, operator, state, k, estimator = primals - # t_operator := V - t_key, t_operator, t_state, t_k, t_estimator = tangents - jtu.tree_map(_assert_false, (t_key, t_state, t_k, t_estimator)) - del t_key, t_state, t_k, t_estimator - - # primal problem of t = tr(A) - result, stats = eqxi.filter_primitive_bind(_estimate_trace_p, key, operator, state, k, estimator) - out = result, stats - - # inner prodct in linear operator space => = tr(A @ B) - # d tr(A) / dA = I - # t' = = tr(I @ V) = tr(V) - # tangent problem => tr(V) - # TODO: should we reuse key or split? both seem confusing options - key, t_key = rdm.split(key) - if any(t is not None for t in jtu.tree_leaves(t_operator, is_leaf=_is_none)): - t_operator = jtu.tree_map(eqxi.materialise_zeros, operator, t_operator, is_leaf=_is_none) - t_operator = lx.TangentLinearOperator(operator, t_operator) - - t_state = estimator.init(t_key, t_operator) - t_result, _ = eqxi.filter_primitive_bind(_estimate_trace_p, t_key, t_operator, t_state, k, estimator) - t_out = ( - t_result, - jtu.tree_map(lambda _: None, stats), - ) - - return out, t_out - - -def _is_undefined(x): - return isinstance(x, ad.UndefinedPrimal) - - -def _assert_defined(x): - assert not _is_undefined(x) - - -def _remove_undefined_primal(x): - if _is_undefined(x): - return x.aval - else: - return x - - @ft.singledispatch def _make_identity(op: lx.AbstractLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: raise ValueError("Unsupported type!") @@ -477,6 +429,38 @@ def _(op: lx.ComposedLinearOperator, ct_result: float) -> lx.AbstractLinearOpera return lx.ComposedLinearOperator(inner_op1, inner_op2) +@eqxi.filter_primitive_jvp +def _estimate_trace_jvp(primals, tangents): + key, operator, state, k, estimator = primals + # t_operator := V + t_key, t_operator, t_state, t_k, t_estimator = tangents + jtu.tree_map(_assert_false, (t_key, t_state, t_k, t_estimator)) + del t_key, t_state, t_k, t_estimator + + # primal problem of t = tr(A) + result, stats = eqxi.filter_primitive_bind(_estimate_trace_p, key, operator, state, k, estimator) + out = result, stats + + # inner prodct in linear operator space => = tr(A @ B) + # d tr(A) / dA = I + # t' = = tr(I @ V) = tr(V) + # tangent problem => tr(V) + # TODO: should we reuse key or split? both seem confusing options + key, t_key = rdm.split(key) + if any(t is not None for t in jtu.tree_leaves(t_operator, is_leaf=_is_none)): + t_operator = jtu.tree_map(eqxi.materialise_zeros, operator, t_operator, is_leaf=_is_none) + t_operator = lx.TangentLinearOperator(operator, t_operator) + + t_state = estimator.init(t_key, t_operator) + t_result, _ = eqxi.filter_primitive_bind(_estimate_trace_p, t_key, t_operator, t_state, k, estimator) + t_out = ( + t_result, + jtu.tree_map(lambda _: None, stats), + ) + + return out, t_out + + @eqxi.filter_primitive_transpose(materialise_zeros=True) # pyright: ignore def _estimate_trace_transpose(inputs, cts_out): # the jacobian, for the trace is just the identity matrix, i.e. J = I diff --git a/src/traceax/_utils.py b/src/traceax/_utils.py index ffae138..b5e50b3 100644 --- a/src/traceax/_utils.py +++ b/src/traceax/_utils.py @@ -81,3 +81,10 @@ def _keep_undefined(v, ct): return ct else: return None + + +def _remove_undefined_primal(x): + if _is_undefined(x): + return x.aval + else: + return x \ No newline at end of file From 7f32881997925f8814cccfa5a95a2d5e662aac2e Mon Sep 17 00:00:00 2001 From: nahid18 Date: Thu, 14 Nov 2024 13:22:56 -0800 Subject: [PATCH 2/3] chore: rearrange code --- src/traceax/_trace.py | 40 ++++++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/src/traceax/_trace.py b/src/traceax/_trace.py index ca6fe9a..a396500 100644 --- a/src/traceax/_trace.py +++ b/src/traceax/_trace.py @@ -363,6 +363,24 @@ def _make_identity(op: lx.AbstractLinearOperator, ct_result: float) -> lx.Abstra raise ValueError("Unsupported type!") +@_make_identity.register +def _(op: lx.MatrixLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: + operator_struct = jtu.tree_map(_remove_undefined_primal, op, is_leaf=_is_undefined) + in_size, out_size = eqx.filter_eval_shape(lambda o: (o.in_size(), o.out_size()), operator_struct) + if in_size != out_size: + raise ValueError("`_make_identity` only supports square matrices.") + diag = jnp.full(in_size, ct_result) + out = lx.MatrixLinearOperator(jnp.diag(diag), tags=operator_struct.tags) + return out + + +@_make_identity.register +def _(op: lx.MulLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: + inner_op = _make_identity(op.operator, ct_result) + scalar = jnp.array(1.0) + return lx.MulLinearOperator(inner_op, scalar*ct_result) + + @_make_identity.register def _(op: lx.TangentLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: p_op = op.primal @@ -370,14 +388,6 @@ def _(op: lx.TangentLinearOperator, ct_result: float) -> lx.AbstractLinearOperat return lx.TangentLinearOperator(p_op, _make_identity(t_op, ct_result)) -@_make_identity.register -def _(op: lx.MatrixLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: - operator_struct = jtu.tree_map(_remove_undefined_primal, op, is_leaf=_is_undefined) - in_size = eqx.filter_eval_shape(lambda o: o.in_size(), operator_struct) - diag = jnp.full(in_size, ct_result) - return lx.MatrixLinearOperator(jnp.diag(diag), tags=operator_struct.tags) - - @_make_identity.register def _(op: lx.DiagonalLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: operator_struct = jtu.tree_map(_remove_undefined_primal, op, is_leaf=_is_undefined) @@ -386,13 +396,6 @@ def _(op: lx.DiagonalLinearOperator, ct_result: float) -> lx.AbstractLinearOpera return lx.DiagonalLinearOperator(diag) -@_make_identity.register -def _(op: lx.MulLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: - inner_op = _make_identity(op.operator, ct_result) - scalar = jnp.array(1.0) - return lx.MulLinearOperator(inner_op, scalar) - - @_make_identity.register def _(op: lx.TridiagonalLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: operator_struct = jtu.tree_map(_remove_undefined_primal, op, is_leaf=_is_undefined) @@ -447,12 +450,13 @@ def _estimate_trace_jvp(primals, tangents): # tangent problem => tr(V) # TODO: should we reuse key or split? both seem confusing options key, t_key = rdm.split(key) - if any(t is not None for t in jtu.tree_leaves(t_operator, is_leaf=_is_none)): - t_operator = jtu.tree_map(eqxi.materialise_zeros, operator, t_operator, is_leaf=_is_none) - t_operator = lx.TangentLinearOperator(operator, t_operator) + + t_operator = jtu.tree_map(eqxi.materialise_zeros, operator, t_operator, is_leaf=_is_none) + t_operator = lx.TangentLinearOperator(operator, t_operator) t_state = estimator.init(t_key, t_operator) t_result, _ = eqxi.filter_primitive_bind(_estimate_trace_p, t_key, t_operator, t_state, k, estimator) + # t_result = jnp.trace(t_operator.as_matrix()) t_out = ( t_result, jtu.tree_map(lambda _: None, stats), From ee1b3d7c1e467a23891de9d389a4c6bbcfabe8da Mon Sep 17 00:00:00 2001 From: nahid18 Date: Thu, 13 Mar 2025 02:58:35 -0700 Subject: [PATCH 3/3] fix: scale bug in XTrace --- src/traceax/_trace.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/traceax/_trace.py b/src/traceax/_trace.py index a396500..7e5688e 100644 --- a/src/traceax/_trace.py +++ b/src/traceax/_trace.py @@ -219,7 +219,7 @@ def estimate(self, state: _BasicTraceState, k: int) -> tuple[PyTree[Array], dict term3 = jnp.conjugate(SW_d) * jnp.sum(S * (R - HW), axis=0) if self.improved: - scale = _get_scale(W, SW_d, n, k) + scale = _get_scale(W, SW_d, n, m) else: scale = 1