# Timing and Profiling in NetKet

NetKet provides a built-in timing system in {ref}`netket_utils_api` that helps you profile your code and find out where time is being spent. This is particularly useful when optimizing performance or debugging slow computations.

The timing system is designed to work with JAX and reports hierarchical timing information as a nested tree, showing the time spent in each scope and its share of the enclosing scope.

## Basic Usage

### Context Managers

The simplest way to time code is using the {class}`~netket.utils.timing.Timer` class as a context manager:

In [1]:
import netket as nk
from netket.utils import timing
import time
import jax
import jax.numpy as jnp

# Basic timing with Timer
with timing.Timer() as timer:
    time.sleep(0.1)  # Simulate some work
    
    # Nested timing with timed_scope
    with timing.timed_scope("matrix multiplication"):
        a = jnp.ones((100, 100))
        b = jnp.ones((100, 100))
        result = a @ b
        # Important for JAX: block until computation is done
        timer.block_until_ready(result)
    
    with timing.timed_scope("more work"):
        time.sleep(0.05)

print(timer)

╭────────────────────────────────────────────── Timing Information ───────────────────────────────────────────────╮
│ Total: 0.230                                                                                                    │
│ ├── (30.5%) | matrix multiplication : 0.070 s                                                                   │
│ └── (23.9%) | more work : 0.055 s                                                                               │
╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯



### Hierarchical Timing

You can create nested timing structures by combining multiple timing scopes:

In [2]:
# Nested timing example
with timing.Timer() as timer:
    with timing.timed_scope("setup"):
        time.sleep(0.02)
    
    with timing.timed_scope("computation"):
        with timing.timed_scope("part 1"):
            time.sleep(0.01)
        with timing.timed_scope("part 2"):
            time.sleep(0.01)

print(timer)

╭────────────────────────────────────────────── Timing Information ───────────────────────────────────────────────╮
│ Total: 0.050                                                                                                    │
│ ├── (50.0%) | setup : 0.025 s                                                                                   │
│ └── (50.0%) | computation : 0.025 s                                                                             │
│     ├── (50.0%) | part 1 : 0.013 s                                                                              │
│     └── (50.0%) | part 2 : 0.013 s                                                                              │
╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
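The percentages in this tree are relative to the enclosing scope: `part 1` is reported as 50% of `computation`, not of the total. To illustrate how such a tree can be built, here is a toy re-implementation in plain Python. It is a sketch under our own assumptions, not NetKet's implementation, and the `ToyTimer`, `Node` and `scope` names are hypothetical. A stack holds the currently open scopes, each scope adds its elapsed time to a node under its parent, and re-entering a scope with the same name accumulates into the same node.

```python
import time
from contextlib import contextmanager


class Node:
    """A named scope: accumulated seconds plus its child scopes (toy version)."""

    def __init__(self, name):
        self.name = name
        self.total = 0.0
        self.children = {}  # insertion-ordered: name -> Node

    def child(self, name):
        # Re-entering a scope with the same name accumulates into the same node
        if name not in self.children:
            self.children[name] = Node(name)
        return self.children[name]


class ToyTimer:
    def __enter__(self):
        self.root = Node("Total")
        self._stack = [self.root]  # currently open scopes, innermost last
        self._start = time.perf_counter()
        return self

    def __exit__(self, *exc_info):
        self.root.total = time.perf_counter() - self._start
        return False  # do not swallow exceptions

    @contextmanager
    def scope(self, name):
        node = self._stack[-1].child(name)
        self._stack.append(node)
        start = time.perf_counter()
        try:
            yield node
        finally:
            node.total += time.perf_counter() - start
            self._stack.pop()

    def report(self, node=None, depth=0):
        node = self.root if node is None else node
        lines = []
        for child in node.children.values():
            # Percentages are relative to the parent scope, not to the grand total
            pct = 100 * child.total / node.total if node.total > 0 else 0.0
            lines.append(f"{'    ' * depth}({pct:.1f}%) {child.name}: {child.total:.3f} s")
            lines.extend(self.report(child, depth + 1))
        return lines


with ToyTimer() as t:
    with t.scope("setup"):
        time.sleep(0.02)
    with t.scope("computation"):
        with t.scope("part 1"):
            time.sleep(0.01)
        with t.scope("part 2"):
            time.sleep(0.01)

print(f"Total: {t.root.total:.3f} s")
print("\n".join(t.report()))
```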



### Using `timed_scope` with `force=True`

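Normally, {func}`~netket.utils.timing.timed_scope` only records time when it runs inside an active timer, which makes it convenient for marking specific sections of code within a larger timing context. Passing `force=True` lets it start timing on its own, without a parent {class}`~netket.utils.timing.Timer`; the returned object can then be used like a `Timer`, both to call `block_until_ready` and to print the results: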

In [3]:
# Using timed_scope with force=True to enable timing even without a parent timer
with timing.timed_scope("main computation", force=True) as timer:
    # Some initial setup
    key = jax.random.key(42)
    data = jax.random.normal(key, (1000, 1000))
    timer.block_until_ready(data)
    
    with timing.timed_scope("eigenvalue decomposition"):
        eigenvals = jnp.linalg.eigvals(data)
        timer.block_until_ready(eigenvals)
    
    with timing.timed_scope("statistical analysis"):
        mean_val = jnp.mean(eigenvals)
        std_val = jnp.std(eigenvals)
        timer.block_until_ready((mean_val, std_val))

print(timer)

╭────────────────────────────────────────────── Timing Information ───────────────────────────────────────────────╮
│ Total: 0.538                                                                                                    │
│ ├── (67.2%) | eigenvalue decomposition : 0.362 s                                                                │
│ └── (11.9%) | statistical analysis : 0.064 s                                                                    │
╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯



### Using the `@timed` Decorator

The {func}`~netket.utils.timing.timed` decorator times every call to a function. Calls made while a timer is active are recorded, and repeated calls within the same scope are accumulated into a single entry, which makes it convenient for functions that are called many times:

In [4]:
@timing.timed(name="expensive_computation")
def expensive_function(x):
    """A function that does some expensive computation."""
    result = jnp.sin(x) * jnp.cos(x) + jnp.exp(-x**2)
    return jnp.sum(result)

@timing.timed(name="data_processing")
def process_data(data):
    """Process some data."""
    return jnp.fft.fft(data)

# The decorated functions will only be timed when inside a timing context
with timing.Timer() as timer:
    # Call the functions multiple times
    for i in range(3):
        x = jnp.linspace(0, 10, 1000)
        result1 = expensive_function(x)
        
        key = jax.random.key(i)
        data = jax.random.normal(key, (512,))
        result2 = process_data(data)
        
        # Block until JAX computations are complete
        timer.block_until_ready((result1, result2))

print(timer)

╭────────────────────────────────────────────── Timing Information ───────────────────────────────────────────────╮
│ Total: 0.269                                                                                                    │
│ ├── (53.7%) | expensive_computation : 0.144 s                                                                   │
│ └── (16.2%) | data_processing : 0.043 s                                                                         │
╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
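The output above folds the three calls into a single `expensive_computation` entry. A toy decorator with the same accumulate-by-name behaviour might look like the sketch below. `toy_timed` and `TIMINGS` are hypothetical names; unlike NetKet's `timed`, this sketch records every call into a global table rather than only calls made inside an active timer, and it does not block on JAX results, so for JAX functions it would measure dispatch rather than computation.

```python
import functools
import time
from collections import defaultdict

# name -> [accumulated seconds, number of calls]
TIMINGS = defaultdict(lambda: [0.0, 0])


def toy_timed(name=None):
    def decorator(fun):
        label = name or fun.__qualname__

        @functools.wraps(fun)
        def wrapper(*args, **kwargs):
            start = time.perf_counter()
            try:
                return fun(*args, **kwargs)
            finally:
                # Accumulate under the same label on every call
                entry = TIMINGS[label]
                entry[0] += time.perf_counter() - start
                entry[1] += 1

        return wrapper

    return decorator


@toy_timed(name="expensive_computation")
def expensive_function(n):
    return sum(i * i for i in range(n))


for _ in range(3):
    expensive_function(100_000)

total, calls = TIMINGS["expensive_computation"]
print(f"expensive_computation: {calls} calls, {total:.4f} s")
```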



## JAX Gotchas: JIT and Block-Until-Ready

### JIT Compilation

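Timing decorators and context managers cannot measure code inside JIT-compiled functions. The Python body of a `jax.jit` function runs only while JAX traces it; the compiled computation then executes without running any Python code. A `timed_scope` inside a jitted function therefore captures at most the tracing on the first call (which is why it reports about 0 s in the example below), never the actual execution. Time the call to the jitted function from outside instead.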
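A quick way to see when the Python body of a jitted function runs is to count how often it executes. In the sketch below (plain JAX; the `traces` counter is only for illustration), the counter increases on the first call, because that is when JAX traces the function, and again only when a new input shape forces a retrace:

```python
import jax
import jax.numpy as jnp

traces = 0


@jax.jit
def f(x):
    global traces
    traces += 1  # Python code: runs while JAX traces f, not when the compiled code executes
    return jnp.sum(x ** 2)


x = jnp.ones((1000,))
for _ in range(3):
    f(x).block_until_ready()
print(traces)  # 1

f(jnp.ones((10,))).block_until_ready()  # new shape -> retrace
print(traces)  # 2
```

A timing scope in place of the counter would likewise only capture tracing, and only on the calls that trigger it.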

In [5]:
# DON'T: put timing scopes inside a JIT function; they only run at trace time
@jax.jit
def bad_jitted_function(x):
    with timing.timed_scope("inside jit"):  # runs while tracing, never during execution
        return jnp.sum(x**2)

# DO: Time the JIT function from outside
@timing.timed(name="jitted_computation")
@jax.jit
def good_jitted_function(x):
    return jnp.sum(x**2)

# Or time the call to the JIT function
@jax.jit
def my_jitted_function(x):
    return jnp.sum(x**2)

x = jnp.ones((1000, 1000))

# The scope inside the JIT function only records tracing time (about 0 s)
print("Bad example (no timing info):")
with timing.Timer() as timer_bad:
    result = bad_jitted_function(x)
    timer_bad.block_until_ready(result)
print(timer_bad)

# This will properly time the JIT function
print("\nGood example (proper timing):")
with timing.Timer() as timer_good:
    with timing.timed_scope("jitted function call"):
        result = my_jitted_function(x)
        timer_good.block_until_ready(result)
print(timer_good)

Bad example (no timing info):


╭────────────────────────────────────────────── Timing Information ───────────────────────────────────────────────╮
│ Total: 0.030                                                                                                    │
│ └── (1.5%) | inside jit : 0.000 s                                                                               │
╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯


Good example (proper timing):


╭────────────────────────────────────────────── Timing Information ───────────────────────────────────────────────╮
│ Total: 0.031                                                                                                    │
│ └── (99.9%) | jitted function call : 0.031 s                                                                    │
╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯



### Block-Until-Ready

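When timing JAX code, call `timer.block_until_ready()` on the results before a scope closes. JAX dispatches computations asynchronously: a call returns as soon as the work has been queued, often before the result has actually been computed. Without blocking, the scope can close before the computation finishes, and the reported time is too small.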

In [6]:
# Demonstration of why block_until_ready is important
print("Without block_until_ready (inaccurate timing):")
with timing.Timer() as timer_bad:
    with timing.timed_scope("jax computation"):
        large_matrix = jnp.ones((2000, 2000))
        result = jnp.linalg.inv(large_matrix)
        # Not blocking - timing will be inaccurate!

print(timer_bad)
print("\nWith block_until_ready (accurate timing):")
with timing.Timer() as timer_good:
    with timing.timed_scope("jax computation"):
        large_matrix = jnp.ones((2000, 2000))
        result = jnp.linalg.inv(large_matrix)
        # Properly blocking until computation is done
        timer_good.block_until_ready(result)

print(timer_good)

Without block_until_ready (inaccurate timing):


╭────────────────────────────────────────────── Timing Information ───────────────────────────────────────────────╮
│ Total: 0.068                                                                                                    │
│ └── (100.0%) | jax computation : 0.068 s                                                                        │
╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯


With block_until_ready (accurate timing):


╭────────────────────────────────────────────── Timing Information ───────────────────────────────────────────────╮
│ Total: 0.205                                                                                                    │
│ └── (100.0%) | jax computation : 0.205 s                                                                        │
╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
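The same effect can be reproduced with plain JAX and `time.perf_counter`, independently of NetKet. In this sketch the function (the name `invert` is only illustrative) is compiled once before measuring, so the gap between the two timings reflects dispatch versus actual computation rather than compilation; the exact numbers depend on the backend and the machine.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def invert(m):
    return jnp.linalg.inv(m)


m = 2.0 * jnp.eye(1000)
invert(m).block_until_ready()  # warm-up: compile outside the measurement

start = time.perf_counter()
out = invert(m)  # returns once the work is dispatched; the result may not be ready yet
t_dispatch = time.perf_counter() - start
out.block_until_ready()  # wait until the result has actually been computed
t_total = time.perf_counter() - start

print(f"after dispatch: {t_dispatch:.4f} s, after block_until_ready: {t_total:.4f} s")
```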



## Interaction with NetKet Drivers

### Real-World Example: Timing a NetKet Simulation

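NetKet's drivers and variational states are instrumented with the same timing system, so wrapping a driver run in a `Timer` shows where time is spent inside the optimization. Passing `timeit=True` to `driver.run` additionally prints the driver's own timing report when the run finishes: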

In [7]:
# Create a simple quantum system
L = 8
g = nk.graph.Chain(length=L, pbc=True)
hi = nk.hilbert.Spin(s=0.5, N=g.n_nodes)

# Define the Hamiltonian
ha = nk.operator.Ising(hilbert=hi, graph=g, h=1.0)

# Define the variational ansatz
model = nk.models.RBM(alpha=1)
sampler = nk.sampler.MetropolisLocal(hi)
optimizer = nk.optimizer.Sgd(learning_rate=0.1)

# Create the variational state
vs = nk.vqs.MCState(sampler, model, n_samples=1000)

# Time the creation and execution of a VMC driver
with timing.Timer() as timer:
    with timing.timed_scope("driver setup"):
        driver = nk.VMC(ha, optimizer, variational_state=vs)
    
    with timing.timed_scope("optimization"):
        # Run a few optimization steps with timing enabled
        driver.run(n_iter=5, timeit=True)

print("\nTotal timing breakdown:")
print(timer)


╭────────────────────────────────────────────── Timing Information ───────────────────────────────────────────────╮
│ Total: 0.728                                                                                                    │
│ └── (85.1%) | VMC._forward_and_backward : 0.619 s                                                               │
│     └── (97.6%) | MCState.expect_and_grad : 0.604 s                                                             │
│         └── (69.3%) | MCState.sample : 0.419 s                                                                  │
│             └── (46.2%) | sampling n_discarded samples : 0.193 s                                                │
╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯


Total timing breakdown:


╭────────────────────────────────────────────── Timing Information ───────────────────────────────────────────────╮
│ Total: 0.843                                                                                                    │
│ └── (100.0%) | optimization : 0.843 s                                                                           │
│     └── (86.3%) |                                                                                               │
│         /Users/filippo.vicentini/Nextcloud/Codes/Python/netket/netket/driver/abstract_variational_driver.py:336 │
│         : 0.728 s                                                                                               │
│         └── (85.1%) | VMC._forward_and_backward : 0.619 s                                                       │
│             └── (97.6%) | MCState.expect_and_grad : 0.604 s                                                     │
│                 └── (69.3%) | MCState.sample : 0.419 s                

### Timing Custom Observable Calculations

You can also use the timing system to profile custom observable calculations within your NetKet workflows:

In [8]:
# Example of how timing is used internally (simplified version)
@timing.timed(name="estimate observables")
def estimate_observables(state, observables):
    """This mimics how NetKet drivers time observable estimation."""
    results = {}
    for name, obs in observables.items():
        with timing.timed_scope(f"observable: {name}"):
            results[name] = state.expect(obs)
    return results

# Demonstrate the pattern
observables = {"energy": ha, "magnetization": nk.operator.spin.sigmax(hi, 0)}

with timing.Timer() as timer:
    for i in range(3):
        with timing.timed_scope(f"iteration {i}"):
            estimates = estimate_observables(vs, observables)

print(timer)

╭────────────────────────────────────────────── Timing Information ───────────────────────────────────────────────╮
│ Total: 0.516                                                                                                    │
│ └── (99.1%) | iteration 0 : 0.511 s                                                                             │
│     └── (100.0%) | estimate observables : 0.511 s                                                               │
│         ├── (19.0%) | observable: energy : 0.097 s                                                              │
│         │   └── (100.0%) | MCState.expect : 0.097 s                                                             │
│         │       └── (3.2%) | MCState.sample : 0.003 s                                                           │
│         │           └── (26.0%) | sampling n_discarded samples : 0.001 s                                        │
│         └── (81.0%) | observable: magnetization : 0.414 s             

Running the same measurement a second time gives much smaller numbers. The first run is dominated by one-time costs, mainly compiling the functions used by `MCState.expect` for each operator; on the second run the compiled functions are reused, and no `MCState.sample` entry appears in iteration 0 because the state reuses the samples it has already drawn. When profiling, warm up the code once or discard the first iteration.

In [9]:
# Re-run the same measurement: compiled functions and samples are now cached
with timing.Timer() as timer:
    for i in range(3):
        with timing.timed_scope(f"iteration {i}"):
            estimates = estimate_observables(vs, observables)

print(timer)

╭────────────────────────────────────────────── Timing Information ───────────────────────────────────────────────╮
│ Total: 0.007                                                                                                    │
│ ├── (39.0%) | iteration 0 : 0.003 s                                                                             │
│ │   └── (97.9%) | estimate observables : 0.003 s                                                                │
│ │       ├── (42.3%) | observable: energy : 0.001 s                                                              │
│ │       │   └── (99.3%) | MCState.expect : 0.001 s                                                              │
│ │       └── (55.7%) | observable: magnetization : 0.002 s                                                       │
│ │           └── (99.4%) | MCState.expect : 0.002 s                                                              │
│ ├── (29.4%) | iteration 1 : 0.002 s                                   