In [None]:
import torch
from torch.autograd.functional import hvp as torch_hvp
from torch.autograd.functional import jvp as torch_jvp
from torch.autograd import grad as torch_grad

class ComputationMonitor:
    """Count autograd work: backward passes, HVP calls and JVP calls.

    Three independent integer counters are kept on the instance:

    * ``backward_passes_count`` -- incremented whenever the backward of the
      identity op inserted by :meth:`monitor_scalar` runs, and once per
      :meth:`counted_vjp` call.
    * ``hvp_calls_count`` -- incremented once per :meth:`counted_hvp` call.
    * ``jvp_calls_count`` -- incremented once per :meth:`counted_jvp` call.
    """

    class _CounterAutograd(torch.autograd.Function):
        """Identity op whose backward increments the owning monitor's counter.

        The monitor is passed as the first (non-tensor) argument of
        ``apply``, so one Function class can serve every monitor instance.
        """

        @staticmethod
        def forward(ctx, monitor_instance, scalar_tensor):
            # Keep a handle on the monitor so backward() can reach its counter.
            ctx.monitor_instance = monitor_instance
            # clone() yields a fresh tensor, so this op becomes its own node
            # in the autograd graph instead of aliasing the input.
            return scalar_tensor.clone()

        @staticmethod
        def backward(ctx, grad_output):
            ctx.monitor_instance._increment_backward_passes()
            # One gradient slot per forward() argument: the monitor is not
            # differentiable (None); the tensor's gradient passes unchanged.
            return None, grad_output

    def __init__(self):
        # reset_counts() creates the counters (and prints its notice).
        self.reset_counts()
        # Kept as an instance attribute for backward compatibility with code
        # that accesses ``monitor._CounterAutogradFn`` directly.
        self._CounterAutogradFn = self._CounterAutograd

    def _increment_backward_passes(self) -> None:
        """Add one to the backward-pass counter."""
        self.backward_passes_count += 1

    def _increment_hvp_calls(self) -> None:
        """Add one to the HVP-call counter."""
        self.hvp_calls_count += 1

    def _increment_jvp_calls(self) -> None:
        """Add one to the JVP-call counter."""
        self.jvp_calls_count += 1

    def reset_counts(self) -> None:
        """Set all three counters to zero and print ``Counters reset.``.

        Also called from ``__init__``, so constructing a monitor prints the
        same message.
        """
        self.backward_passes_count = 0
        self.hvp_calls_count = 0
        self.jvp_calls_count = 0
        print("Counters reset.")

    def get_counts(self) -> dict:
        """Return a snapshot dict with keys ``backward_passes``,
        ``hvp_calls`` and ``jvp_calls``."""
        return {
            "backward_passes": self.backward_passes_count,
            "hvp_calls": self.hvp_calls_count,
            "jvp_calls": self.jvp_calls_count,
        }

    def print_counts(self) -> None:
        """Print the current value of all three counters on one line."""
        counts = self.get_counts()
        print(f"Current Counts: Backward Pass = {counts['backward_passes']}, "
              f"HVP = {counts['hvp_calls']}, JVP = {counts['jvp_calls']}")

    def monitor_scalar(self, scalar_tensor: torch.Tensor) -> torch.Tensor:
        """Return a copy of ``scalar_tensor`` whose backward is counted.

        The tensor is routed through an identity autograd op; each time
        autograd runs that op's backward, ``backward_passes_count`` grows by
        one. Use the *returned* tensor downstream (e.g. call ``.backward()``
        on it) -- the original tensor is not instrumented.

        A tensor that does not require grad is accepted as-is; nothing is
        counted for it unless a backward actually runs through the op.

        Args:
            scalar_tensor: A 0-dimensional tensor, typically a loss.

        Returns:
            A cloned tensor carrying the counting node.

        Raises:
            ValueError: If ``scalar_tensor`` is not 0-dimensional.
        """
        if scalar_tensor.ndim != 0:
            raise ValueError("Monitored tensor must be a scalar.")
        return self._CounterAutogradFn.apply(self, scalar_tensor)

    def counted_hvp(self, func, inputs, v, *args, **kwargs):
        """Increment the HVP counter, then call
        ``torch.autograd.functional.hvp(func, inputs, v, *args, **kwargs)``
        and return its result unchanged.

        The counter is incremented before the call, so a call that raises is
        still counted.
        """
        self._increment_hvp_calls()
        return torch_hvp(func, inputs, v, *args, **kwargs)

    def counted_jvp(self, func, inputs, v, *args, **kwargs):
        """Increment the JVP counter, then call
        ``torch.autograd.functional.jvp(func, inputs, v, *args, **kwargs)``
        and return its result unchanged.

        The counter is incremented before the call, so a call that raises is
        still counted.
        """
        self._increment_jvp_calls()
        return torch_jvp(func, inputs, v, *args, **kwargs)

    def counted_vjp(self, func_output_scalar, inputs, v=None, *args, **kwargs):
        """Count one backward pass and compute a VJP with ``torch.autograd.grad``.

        Intended for scalar outputs that were NOT passed through
        :meth:`monitor_scalar`. If the output was monitored, the identity
        op's backward also fires during this call and the pass is counted
        twice -- use one mechanism or the other.

        Args:
            func_output_scalar: Tensor to differentiate (expected to be a
                scalar).
            inputs: Tensor or sequence of tensors to differentiate with
                respect to.
            v: Vector for the vector-Jacobian product (``grad_outputs``).
                Defaults to a scalar ``1.0`` with the output's dtype and
                device.
            *args: Extra positional arguments for ``torch.autograd.grad``,
                placed after ``grad_outputs`` (i.e. starting at
                ``retain_graph``).
            **kwargs: Extra keyword arguments for ``torch.autograd.grad``.

        Returns:
            Whatever ``torch.autograd.grad`` returns (a tuple of gradients).
        """
        self._increment_backward_passes()
        if v is None:
            v = torch.tensor(
                1.0,
                dtype=func_output_scalar.dtype,
                device=func_output_scalar.device,
            )
        # ``v`` is passed positionally: with ``grad_outputs=v`` any extra
        # positional argument would also land in the grad_outputs slot and
        # raise "got multiple values for argument 'grad_outputs'".
        return torch_grad(func_output_scalar, inputs, v, *args, **kwargs)

# --- Example Usage ---
if __name__ == "__main__":
    # Demo of ComputationMonitor: each "Test" resets or inspects the counters
    # around one kind of autograd operation and prints the resulting counts.
    monitor = ComputationMonitor()

    # --- Test 1: Backward Pass ---
    print("\n--- Test 1: Backward Pass ---")
    x = torch.randn(3, requires_grad=True)
    y = torch.randn(3, requires_grad=True)

    # Some computation leading to a scalar loss
    output = (x * y).sum() * 2
    loss = output.sin()  # loss is a scalar

    # Monitor the loss; backward must be called on the returned tensor.
    monitored_loss = monitor.monitor_scalar(loss)

    print(f"Initial loss: {loss.item()}, Monitored loss: {monitored_loss.item()}")
    monitor.print_counts()  # Expect 0 for all

    # Perform backward pass
    monitored_loss.backward()
    monitor.print_counts()  # Expect backward_passes = 1

    # Second backward pass, on a fresh graph so retain_graph is not needed.
    x_new = torch.randn(3, requires_grad=True)
    loss_new = (x_new**2).sum()
    monitored_loss_new = monitor.monitor_scalar(loss_new)
    monitored_loss_new.backward()
    monitor.print_counts()  # Expect backward_passes = 2

    # --- Test 2: HVP ---
    print("\n--- Test 2: HVP ---")
    monitor.reset_counts()

    def scalar_func_hvp(inp_x, inp_y):
        """Scalar function of two vectors used as the HVP target."""
        return ((inp_x * inp_y).sum() * 2).sin()

    inputs_hvp = (x.clone().detach().requires_grad_(True), y.clone().detach().requires_grad_(True))
    v_hvp = (torch.randn_like(x), torch.randn_like(y))

    # torch.autograd.functional.hvp returns (func_output, hvp) -- the
    # function value comes first, the Hessian-vector product second.
    func_value, hvp_result = monitor.counted_hvp(scalar_func_hvp, inputs_hvp, v_hvp)
    monitor.print_counts()  # Expect hvp_calls = 1

    monitor.counted_hvp(scalar_func_hvp, inputs_hvp, v_hvp)
    monitor.print_counts()  # Expect hvp_calls = 2

    # --- Test 3: JVP ---
    print("\n--- Test 3: JVP ---")
    monitor.reset_counts()

    def func_jvp(inp_x):  # JVP often used with functions returning non-scalars
        """Element-wise vector function used as the JVP target."""
        return inp_x * 2 + inp_x.sin()

    inputs_jvp = x.clone().detach().requires_grad_(True)
    v_jvp = torch.randn_like(x)

    # torch.autograd.functional.jvp returns (func_output, jvp).
    func_value_jvp, jvp_result = monitor.counted_jvp(func_jvp, (inputs_jvp,), (v_jvp,))
    monitor.print_counts()  # Expect jvp_calls = 1

    monitor.counted_jvp(func_jvp, (inputs_jvp,), (v_jvp,))
    monitor.print_counts()  # Expect jvp_calls = 2

    # --- Test 4: Monitored scalar involved in HVP/JVP function (indirect) ---
    print("\n--- Test 4: Monitored scalar in HVP/JVP func ---")
    monitor.reset_counts()

    p = torch.tensor([2.0, 3.0], requires_grad=True)
    q = torch.tensor([0.5, 1.5], requires_grad=True)

    def outer_func_for_hvp(param_p, param_q):
        """Scalar function that routes an intermediate through the monitor.

        The HVP differentiates *through* the monitored intermediate, so the
        counting op's backward runs during the HVP's internal backward
        passes even though .backward() is never called explicitly here.
        """
        intermediate_scalar = (param_p * param_q).sum()
        monitored_intermediate = monitor.monitor_scalar(intermediate_scalar)
        return monitored_intermediate**2

    v_hvp2 = (torch.randn_like(p), torch.randn_like(q))

    print("Before HVP call:")
    monitor.print_counts()

    # An HVP needs second derivatives: one backward pass produces the
    # gradient, and differentiating (gradient . v) runs backward again, so
    # the counting op may fire more than once for a single HVP call.
    _, hvp_res = monitor.counted_hvp(outer_func_for_hvp, (p, q), v_hvp2)

    print("After HVP call:")
    monitor.print_counts()  # Expect hvp_calls = 1, backward_passes > 0 (2 in the recorded run)

    # --- Test 5: VJP via counted_vjp ---
    print("\n--- Test 5: VJP via counted_vjp ---")
    monitor.reset_counts()
    a = torch.tensor(2.0, requires_grad=True)
    b = torch.tensor(3.0, requires_grad=True)

    def my_scalar_output_func(in_a, in_b):
        """Scalar function of two scalars used as the VJP target."""
        return (in_a * in_b).sin()

    # Case 1: Func output is NOT monitored by monitor_scalar
    scalar_out = my_scalar_output_func(a, b)
    # We want to count this torch.autograd.grad call as a "backward pass" like event
    grads_ab = monitor.counted_vjp(scalar_out, (a, b))
    monitor.print_counts()  # Expect backward_passes = 1 (due to counted_vjp)

    # Case 2: Func output IS monitored by monitor_scalar.
    # monitor_scalar is the primary mechanism for counting .backward()-style
    # passes; counted_vjp is for explicit torch.autograd.grad calls on
    # non-monitored scalars. Combining both would count the same pass twice.
    monitor.reset_counts()
    scalar_out_monitored = monitor.monitor_scalar(my_scalar_output_func(a, b))
    # Plain torch.autograd.grad triggers the counting op's backward.
    grads_ab_monitored = torch_grad(scalar_out_monitored, (a, b))
    monitor.print_counts()  # Expect backward_passes = 1 (due to monitor_scalar's effect)
                            # and JVP/HVP = 0.

    # Calling monitor.counted_vjp(scalar_out_monitored, (a, b)) here instead
    # would make backward_passes = 2, so use one mechanism or the other.

Counters reset.

--- Test 1: Backward Pass ---
Initial loss: 0.10916007310152054, Monitored loss: 0.10916007310152054
Current Counts: Backward Pass = 0, HVP = 0, JVP = 0
Current Counts: Backward Pass = 1, HVP = 0, JVP = 0
Current Counts: Backward Pass = 2, HVP = 0, JVP = 0

--- Test 2: HVP ---
Counters reset.
Current Counts: Backward Pass = 0, HVP = 1, JVP = 0
Current Counts: Backward Pass = 0, HVP = 2, JVP = 0

--- Test 3: JVP ---
Counters reset.
Current Counts: Backward Pass = 0, HVP = 0, JVP = 1
Current Counts: Backward Pass = 0, HVP = 0, JVP = 2

--- Test 4: Monitored scalar in HVP/JVP func ---
Counters reset.
Before HVP call:
Current Counts: Backward Pass = 0, HVP = 0, JVP = 0
After HVP call:
Current Counts: Backward Pass = 2, HVP = 1, JVP = 0

--- Test 5: VJP via counted_vjp ---
Counters reset.
Current Counts: Backward Pass = 1, HVP = 0, JVP = 0
Counters reset.
Current Counts: Backward Pass = 1, HVP = 0, JVP = 0
