### **BENCHMARKING**
**`_apply_matrix_to_single_state()`**: 
  1) only vmap in the core part;
  2) without vmap in the core part;
  3) jit everything (3a: same but just remove jit decorator);
  4) more straightforward implementation
  5) general case of implementation #4

For the one-qubit case, the fastest implementation is #4. Also for two-qubit gates, it is better to "hard code" the for loops. However, the general case seems to scale well with the number of affected qubits.
The library currently uses implementation #4.

In [1]:
import jax
import jax.numpy as jnp
import jax.random as jrand
from Qubitly.states import normalize_array
from Qubitly.states import WaveFunction, _NO_RANDOMNESS
from Qubitly.circuits import QuantumCircuit
from Qubitly.gates import Hadamard, CNOT

In [2]:
# IMPLEMENTATION 1
def _apply_matrix_to_single_site_1(matrix: jnp.ndarray, site: int, vector: jnp.ndarray):
    """Apply a 2x2 matrix to one qubit of a state vector (vmap only in the core part).

    Every non-zero matrix element is handled on its own: a vmapped helper maps
    each input basis state to the output basis state it feeds and to the
    amplitude it contributes; the contributions are then scatter-added into the
    result.

    Args:
        matrix: array indexed as ``matrix[s_out][s_in]`` with ``s_out, s_in`` in
            {0, 1}. Its entries are tested with a Python ``if``, so they must be
            concrete values.
        site: bit position of the affected qubit inside the basis-state index.
        vector: amplitudes indexed by basis state.

    Returns:
        An array shaped like ``vector`` holding the transformed amplitudes.

    Raises:
        AssertionError: if ``site`` is not a valid bit position for ``vector``.
    """
    dim = len(vector)
    # `site` is a bit position, so it is bounded by the number of qubits
    # (log2 of the vector length), not by the vector length itself. A larger
    # value would generate target indices beyond the end of `result`, which
    # JAX indexed updates skip by default, silently losing amplitude.
    n_qubits = dim.bit_length() - 1
    assert 0 <= site < n_qubits, f"site must be in [0, {n_qubits}), got {site}"

    result = jnp.zeros_like(vector)

    for s_out in range(2):
        for s_in in range(2):
            if matrix[s_out][s_in] != 0:
                # Compute the contribution of matrix element s_out, s_in (s_out, s_in ∈ {0, 1})
                def basis_state_matrix_element_contribution(p_in: int) -> tuple[int, jnp.complex64]:
                    # For basis state p_in, extract the bit representing the site
                    bit = (p_in >> site) & 1

                    def true_branch(p_in):
                        # Basis state that receives the contribution: p_in with
                        # the site bit replaced by s_out
                        mask = ~(1 << site)
                        p_out = (p_in & mask) | (s_out << site)
                        return p_out, matrix[s_out, s_in] * vector[p_in]

                    def false_branch(p_in):
                        # The element does not act on this basis state: add zero
                        # to slot 0 so that both branches return the same structure
                        return 0, 0.0 + 0.0j

                    return jax.lax.cond(bit == s_in, true_branch, false_branch, operand=p_in)

                idxs, amps = jax.vmap(basis_state_matrix_element_contribution)(jnp.arange(dim))
                result = result.at[idxs].add(amps)

    return result



# IMPLEMENTATION 2
def _apply_matrix_to_single_site_2(matrix: jnp.ndarray, site: int, vector: jnp.ndarray):
    """Apply a 2x2 matrix to one qubit of a state vector with plain Python loops.

    Same algorithm as the vmapped variant, but every basis state is visited in a
    Python ``for`` loop and its contribution is added with a separate indexed
    update. This is the "naive" baseline of the benchmark.

    Args:
        matrix: 2x2 array with concrete entries (they are tested with ``if``).
        site: bit position of the affected qubit inside the basis-state index.
        vector: amplitudes indexed by basis state.

    Returns:
        An array shaped like ``vector`` holding the transformed amplitudes.

    Raises:
        AssertionError: if ``matrix`` is not 2x2 or ``site`` is not a valid bit
            position for ``vector``.
    """
    dim = len(vector)
    assert matrix.shape == (2, 2)
    # `site` is a bit position, so it is bounded by the number of qubits
    # (log2 of the vector length), not by the vector length itself. A larger
    # value would generate target indices beyond the end of `result`.
    n_qubits = dim.bit_length() - 1
    assert 0 <= site < n_qubits, f"site must be in [0, {n_qubits}), got {site}"

    result = jnp.zeros_like(vector)

    for s_out in range(2):
        for s_in in range(2):
            if matrix[s_out][s_in] != 0:
                # Compute the contribution of matrix element s_out, s_in (s_out, s_in ∈ {0, 1})
                def basis_state_matrix_element_contribution(p_in: int) -> tuple[int, jnp.complex64]:
                    # For basis state p_in, extract the bit representing the site
                    bit = (p_in >> site) & 1
                    if bit == s_in:
                        # Basis state that receives the contribution: p_in with
                        # the site bit replaced by s_out
                        mask = ~(1 << site)
                        p_out = (p_in & mask) | (s_out << site)

                        return p_out, matrix[s_out, s_in] * vector[p_in]
                    else:
                        # No contribution: add zero to slot 0
                        return 0, 0.0 + 0.0j

                for i in range(dim):
                    idx, amp = basis_state_matrix_element_contribution(i)
                    result = result.at[idx].add(amp)

    return result


# IMPLEMENTATION 3
@jax.jit
def _apply_matrix_to_single_site_3(matrix: jnp.ndarray, site: int, vector: jnp.ndarray):
    """Apply a 2x2 matrix to one qubit of a state vector, fully jitted.

    The four matrix elements are visited in a Python loop that is unrolled at
    trace time. Elements whose magnitude does not exceed 1e-12 are skipped with
    ``jax.lax.cond``; for the others, a vmapped helper computes the target basis
    state and the amplitude contributed by every input basis state.

    Args:
        matrix: 2x2 array indexed as ``matrix[row][col]``.
        site: bit position of the affected qubit inside the basis-state index.
            It is traced under jit, so no range assertion is possible here.
        vector: amplitudes indexed by basis state.

    Returns:
        An array shaped like ``vector`` holding the transformed amplitudes.
    """
    basis_indices = jnp.arange(len(vector))
    accumulated = jnp.zeros_like(vector)

    for row, col in ((0, 0), (0, 1), (1, 0), (1, 1)):

        def scatter_element(acc) -> jnp.ndarray:
            # Add to `acc` everything produced by matrix element (row, col)
            def contribution(p_in):
                def matching(_):
                    # Target basis state: p_in with the site bit replaced by `row`
                    cleared = p_in & ~(1 << site)
                    return cleared | (row << site), matrix[row, col] * vector[p_in]

                def not_matching(_):
                    # Nothing to add: a zero amplitude aimed at slot 0
                    return 0, 0.0 + 0.0j

                site_bit = (p_in >> site) & 1
                return jax.lax.cond(site_bit == col, matching, not_matching, operand=None)

            targets, values = jax.vmap(contribution)(basis_indices)
            return acc.at[targets].add(values)

        accumulated = jax.lax.cond(
            jnp.abs(matrix[row][col]) > 1e-12,
            scatter_element,
            lambda acc: acc,
            operand=accumulated,
        )

    return accumulated


# IMPLEMENTATION 4
@jax.jit
def _apply_matrix_to_single_site_4(matrix: jnp.ndarray, site: int, vector: jnp.ndarray):
    """Apply a 2x2 matrix to one qubit of a state vector (straightforward version).

    Each input basis state ``p`` selects the matrix column given by its site bit
    and contributes ``matrix[row, column] * vector[p]`` to the two basis states
    obtained by writing ``row`` (0 or 1) into the site bit of ``p``. All
    contributions are computed with nested vmaps and scatter-added at once.

    Args:
        matrix: 2x2 array indexed as ``matrix[row, column]``.
        site: bit position of the affected qubit inside the basis-state index.
        vector: amplitudes indexed by basis state.

    Returns:
        An array shaped like ``vector`` holding the transformed amplitudes.
    """
    clear_site = ~(1 << site)

    def contributions_of(p: int):
        column = (p >> site) & 1
        untouched_bits = p & clear_site

        def for_row(row: int):
            return untouched_bits | (row << site), matrix[row, column] * vector[p]

        return jax.vmap(for_row)(jnp.arange(2))

    targets, values = jax.vmap(contributions_of)(jnp.arange(len(vector)))
    return jnp.zeros_like(vector).at[targets].add(values)


# IMPLEMENTATION 5 (which is for the general case)
@jax.jit
def _apply_matrix_to_sites(matrix: jnp.ndarray, sites: jnp.array, vector: jnp.ndarray):
    """Apply a matrix to an arbitrary set of qubits of a state vector.

    ``sites[k]`` is the bit position, inside the basis-state index, of the qubit
    that corresponds to bit ``k`` of the matrix row/column index.

    NOTE: ``sites`` cannot be a regular list. The index of a ``fori_loop`` is a
    tracer and cannot be used to pick an element of a Python list
    (TracerIntegerConversionError is raised).

    Args:
        matrix: array indexed as ``matrix[row, column]`` with
            ``2**len(sites)`` rows and columns.
        sites: array with the bit positions of the affected qubits.
        vector: amplitudes indexed by basis state.

    Returns:
        An array shaped like ``vector`` holding the transformed amplitudes.
    """
    n_sites = len(sites)

    def contributions_of(p: int):
        def gather_bit(k, acc):
            # Move the bit of p stored at position sites[k] to position k
            return acc | (((p >> sites[k]) & 1) << k)

        def mark_site(k, acc):
            return acc | (1 << sites[k])

        # Matrix column selected by the affected bits of p
        column = jax.lax.fori_loop(0, n_sites, gather_bit, 0)
        # p with every affected bit cleared
        untouched_bits = p & ~jax.lax.fori_loop(0, n_sites, mark_site, 0)

        def for_row(row: int):
            def scatter_bit(k, acc):
                # Move bit k of `row` to position sites[k]
                return acc | (((row >> k) & 1) << sites[k])

            target = untouched_bits | jax.lax.fori_loop(0, n_sites, scatter_bit, 0)
            return target, matrix[row, column] * vector[p]

        return jax.vmap(for_row)(jnp.arange(2**n_sites))

    targets, values = jax.vmap(contributions_of)(jnp.arange(len(vector)))
    return jnp.zeros_like(vector).at[targets].add(values)

In [3]:
# 2x2 matrix that swaps the two basis states of a qubit (Pauli X)
sigma_x_arr = jnp.array([[0.0, 1.0], [1.0, 0.0]], dtype=jnp.complex64)

# Benchmark inputs: all-ones vectors for 3, 5 and 10 qubits, passed through
# normalize_array
_low_dim_state = normalize_array(jnp.ones(2**3, dtype=jnp.complex64))

_mid_dim_state = normalize_array(jnp.ones(2**5, dtype=jnp.complex64))

_high_dim_state = normalize_array(jnp.ones(2**10, dtype=jnp.complex64))

**Comparison between implementation #1 and implementation #2 for vectors of increasing size.**
Turns out that the naive implementation performs better for up to two spins. For 10 spins, the vectorized implementation is around 50 times faster.

In [4]:
print("Implementation 1, low dimensional vector")
%timeit tr_low_dim_state = _apply_matrix_to_single_site_1(sigma_x_arr, 2, _low_dim_state)
print("Implementation 1, mid dimensional vector")
%timeit tr_mid_dim_state = _apply_matrix_to_single_site_1(sigma_x_arr, 4, _mid_dim_state)
print("Implementation 1, high dimensional vector")
%timeit tr_high_dim_state = _apply_matrix_to_single_site_1(sigma_x_arr, 9, _high_dim_state)

print()

print("Implementation 2, low dimensional vector")
%timeit tr_low_dim_state = _apply_matrix_to_single_site_2(sigma_x_arr, 2, _low_dim_state)
print("Implementation 2, mid dimensional vector")
%timeit tr_mid_dim_state = _apply_matrix_to_single_site_2(sigma_x_arr, 4, _mid_dim_state)
print("Implementation 2, high dimensional vector")
%timeit tr_high_dim_state = _apply_matrix_to_single_site_2(sigma_x_arr, 9, _high_dim_state)

Implementation 1, low dimensional vector
28.5 ms ± 1.03 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Implementation 1, mid dimensional vector
26.1 ms ± 894 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)
Implementation 1, high dimensional vector
32.9 ms ± 662 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)

Implementation 2, low dimensional vector
19.1 ms ± 292 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Implementation 2, mid dimensional vector
68.3 ms ± 1.37 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Implementation 2, high dimensional vector
2.17 s ± 198 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


**Performance of implementation #3**


In [5]:
print("Implementation 3, low dimensional vector")
%timeit tr_low_dim_state = _apply_matrix_to_single_site_3(sigma_x_arr, 2, _low_dim_state)
print("Implementation 3, mid dimensional vector")
%timeit tr_mid_dim_state = _apply_matrix_to_single_site_3(sigma_x_arr, 4, _mid_dim_state)
print("Implementation 3, high dimensional vector")
%timeit tr_high_dim_state = _apply_matrix_to_single_site_3(sigma_x_arr, 8, _high_dim_state)

# WITHOUT JITTING
# Implementation 3, low dimensional vector
# 184 ms ± 1.95 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# Implementation 3, mid dimensional vector
# 192 ms ± 2.67 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# Implementation 3, high dimensional vector
# 193 ms ± 905 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Implementation 3, low dimensional vector
5.25 μs ± 174 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
Implementation 3, mid dimensional vector
5.39 μs ± 198 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
Implementation 3, high dimensional vector
19 μs ± 827 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


**Performance of implementation #4**

In [6]:
print("Implementation 4, low dimensional vector")
%timeit tr_low_dim_state = _apply_matrix_to_single_site_4(sigma_x_arr, 2, _low_dim_state)
print("Implementation 4, mid dimensional vector")
%timeit tr_mid_dim_state = _apply_matrix_to_single_site_4(sigma_x_arr, 4, _mid_dim_state)
print("Implementation 4, high dimensional vector")
%timeit tr_high_dim_state = _apply_matrix_to_single_site_4(sigma_x_arr, 8, _high_dim_state)

# WITHOUT JITTING
# Implementation 4, low dimensional vector
# 7.63 ms ± 322 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)
# Implementation 4, mid dimensional vector
# 7.58 ms ± 421 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)
# Implementation 4, high dimensional vector
# 7.44 ms ± 318 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)

Implementation 4, low dimensional vector
4.61 μs ± 293 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
Implementation 4, mid dimensional vector
4.63 μs ± 22.6 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
Implementation 4, high dimensional vector
12.9 μs ± 183 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


**Performance of implementation #5 (general case)**

In [7]:
print("Implementation 5, low dimensional vector")
%timeit tr_low_dim_state = _apply_matrix_to_sites(sigma_x_arr, jnp.array([2]), _low_dim_state)
print("Implementation 5, mid dimensional vector")
%timeit tr_mid_dim_state = _apply_matrix_to_sites(sigma_x_arr, jnp.array([4]), _mid_dim_state)
print("Implementation 5, high dimensional vector")
%timeit tr_high_dim_state = _apply_matrix_to_sites(sigma_x_arr, jnp.array([8]), _high_dim_state)

# WITHOUT JITTING (NO DIFFERENCE IN THIS CASE!!!)
# Implementation 5, low dimensional vector
# 62.8 μs ± 288 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
# Implementation 5, mid dimensional vector
# 62.7 μs ± 413 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
# Implementation 5, high dimensional vector
# 71.7 μs ± 662 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

Implementation 5, low dimensional vector
63.3 μs ± 111 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
Implementation 5, mid dimensional vector
63.6 μs ± 73 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
Implementation 5, high dimensional vector
72 μs ± 110 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


**`_apply_matrix_to_two_sites()`**: 
  1) adapted from single site implementation #3
  2) adapted from single site implementation #4
  3) generalization of single site implementation #4 (same function as single site case #5)

The library currently uses implementation #2

In [8]:
# IMPLEMENTATION 1
@jax.jit
def _apply_matrix_to_two_sites_1(matrix: jnp.ndarray, sites: list[int], vector: jnp.ndarray):
    """Apply a 4x4 matrix to two qubits of a state vector, one element at a time.

    Bit 0 of a matrix row/column index refers to ``sites[0]`` and bit 1 to
    ``sites[1]``. The sixteen matrix elements are visited in a Python loop that
    is unrolled at trace time; elements whose magnitude does not exceed 1e-12
    are skipped with ``jax.lax.cond``.

    Args:
        matrix: 4x4 array indexed as ``matrix[row][column]``.
        sites: bit positions of the two affected qubits. They are traced under
            jit, so no range assertion is possible here.
        vector: amplitudes indexed by basis state.

    Returns:
        An array shaped like ``vector`` holding the transformed amplitudes.
    """
    site_array = jnp.array(sites)
    clear_sites = ~((1 << sites[0]) | (1 << sites[1]))
    basis_indices = jnp.arange(len(vector))
    accumulated = jnp.zeros_like(vector)

    for out_pair in range(4):
        for in_pair in range(4):
            # Values the two site bits must have for column `in_pair` to apply
            expected_bits = jnp.array([in_pair & 1, in_pair >> 1])

            def scatter_element(acc) -> jnp.ndarray:
                # Add to `acc` everything produced by matrix element
                # (out_pair, in_pair), with out_pair, in_pair ∈ {0, 1, 2, 3}
                def contribution(p_in):
                    def matching(_):
                        # Target basis state: p_in with the two site bits
                        # replaced by the bits of out_pair
                        p_out = (
                            (p_in & clear_sites)
                            | ((out_pair & 1) << sites[0])
                            | ((out_pair >> 1) << sites[1])
                        )
                        return p_out, matrix[out_pair, in_pair] * vector[p_in]

                    def not_matching(_):
                        # Nothing to add: a zero amplitude aimed at slot 0
                        return 0, 0.0 + 0.0j

                    site_bits = (p_in >> site_array) & 1
                    return jax.lax.cond(
                        jnp.all(site_bits == expected_bits), matching, not_matching, operand=None
                    )

                targets, values = jax.vmap(contribution)(basis_indices)
                return acc.at[targets].add(values)

            accumulated = jax.lax.cond(
                jnp.abs(matrix[out_pair][in_pair]) > 1e-12,
                scatter_element,
                lambda acc: acc,
                operand=accumulated,
            )

    return accumulated


# IMPLEMENTATION 2
@jax.jit
def _apply_matrix_to_two_sites_2(matrix: jnp.ndarray, sites: list[int], vector: jnp.ndarray):
    """Apply a 4x4 matrix to two qubits of a state vector (straightforward version).

    Bit 0 of a matrix row/column index refers to ``sites[0]`` and bit 1 to
    ``sites[1]``. Each input basis state ``p`` selects the matrix column given
    by its two site bits and contributes ``matrix[row, column] * vector[p]`` to
    the four basis states obtained by writing the bits of ``row`` into the two
    sites of ``p``.

    Args:
        matrix: 4x4 array indexed as ``matrix[row, column]``.
        sites: bit positions of the two affected qubits.
        vector: amplitudes indexed by basis state.

    Returns:
        An array shaped like ``vector`` holding the transformed amplitudes.
    """
    low_site = sites[0]
    high_site = sites[1]
    clear_sites = ~((1 << low_site) | (1 << high_site))

    def contributions_of(p: int):
        low_bit = (p >> low_site) & 1
        high_bit = (p >> high_site) & 1
        column = (high_bit << 1) | low_bit
        untouched_bits = p & clear_sites

        def for_row(row: int):
            target = untouched_bits | ((row & 1) << low_site) | ((row >> 1) << high_site)
            return target, matrix[row, column] * vector[p]

        return jax.vmap(for_row)(jnp.arange(4))

    targets, values = jax.vmap(contributions_of)(jnp.arange(len(vector)))
    return jnp.zeros_like(vector).at[targets].add(values)

In [9]:
# 4x4 permutation matrix that swaps basis states 1 and 3 and leaves 0 and 2
# untouched
CNOT_arr = jnp.array(
    [
        [1.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 1.0],
        [0.0, 0.0, 1.0, 0.0],
        [0.0, 1.0, 0.0, 0.0],
    ],
    dtype=jnp.complex64,
)

# Benchmark inputs: all-ones vectors for 3, 5 and 10 qubits, passed through
# normalize_array
_low_dim_state = normalize_array(jnp.ones(2**3, dtype=jnp.complex64))

_mid_dim_state = normalize_array(jnp.ones(2**5, dtype=jnp.complex64))

_high_dim_state = normalize_array(jnp.ones(2**10, dtype=jnp.complex64))

In [10]:
# Basis state 1: the bit at position 0 is set, the bit at position 1 is clear
_state = jnp.array([0, 1, 0, 0], dtype=jnp.complex64)

_result = _apply_matrix_to_two_sites_1(matrix=CNOT_arr, sites=[0, 1], vector=_state)

print(_result)

# NOTE: the first element of the list is the control bit. CNOT_arr swaps basis
# states 1 and 3, i.e. it flips bit 1 when bit 0 is set, so the control is the
# least significant of the two bits of the matrix index.
# Hence the first element of the list corresponds to the least significant
# (i.e. the rightmost) bit of the matrix index.

[0.+0.j 0.+0.j 0.+0.j 1.+0.j]


**Performance of implementation #1 (adapted from single-site implementation #3)**

In [11]:
print("Low dimensional vector")
%timeit _apply_matrix_to_two_sites_1(CNOT_arr, [0,2], _low_dim_state)
print("Mid dimensional vector")
%timeit _apply_matrix_to_two_sites_1(CNOT_arr, [1,4], _mid_dim_state)
print("High dimensional vector")
%timeit _apply_matrix_to_two_sites_1(CNOT_arr, [2,9], _high_dim_state)

# WITHOUT JITTING
# Low dimensional vector
# 855 ms ± 53.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# Mid dimensional vector
# 870 ms ± 9.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# High dimensional vector
# 845 ms ± 20.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Low dimensional vector
13.6 μs ± 7.8 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)
Mid dimensional vector
14.6 μs ± 7.12 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)
High dimensional vector
55.3 μs ± 6.73 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)


**Performance of implementation #2 (adapted from single-site implementation #4)**

In [12]:
print("Low dimensional vector")
%timeit _apply_matrix_to_two_sites_2(CNOT_arr, [0,2], _low_dim_state)
print("Mid dimensional vector")
%timeit _apply_matrix_to_two_sites_2(CNOT_arr, [1,4], _mid_dim_state)
print("High dimensional vector")
%timeit _apply_matrix_to_two_sites_2(CNOT_arr, [2,9], _high_dim_state)

# WITHOUT JITTING
# Low dimensional vector
# 9.61 ms ± 562 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)
# Mid dimensional vector
# 9.57 ms ± 273 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)
# High dimensional vector
# 9.62 ms ± 398 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)

Low dimensional vector
6.05 μs ± 95.1 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
Mid dimensional vector
5.92 μs ± 15.8 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
High dimensional vector
22.7 μs ± 46.5 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


**Performance of the general function (implementation #3 - see single-site implementation #5)**

In [13]:
print("Low dimensional vector")
%timeit _apply_matrix_to_sites(CNOT_arr, jnp.array([0,2]), _low_dim_state)
print("Mid dimensional vector")
%timeit _apply_matrix_to_sites(CNOT_arr, jnp.array([1,4]), _mid_dim_state)
print("High dimensional vector")
%timeit _apply_matrix_to_sites(CNOT_arr, jnp.array([2,9]), _high_dim_state)

Low dimensional vector
70.1 μs ± 177 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
Mid dimensional vector
70.7 μs ± 246 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
High dimensional vector
92.9 μs ± 7.68 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


**`_measure_computational_basis()`**: 
  1) keeps an array of the original size masking out bits that are irrelevant for the measurement (1b: same but without jitting)
  2) computes probabilities extracting only relevant bits into a smaller array (current implementation)

In [14]:
@jax.jit
def _measure_computational_basis_1(key, amplitudes: jnp.ndarray, qubits_to_measure: jnp.ndarray):
    """Measure a subset of qubits, keeping a probability array of the original size.

    Every basis-state index is masked so that only the bits of the measured
    qubits survive; the probabilities of all basis states that share a masked
    index are accumulated in the slot of that masked index. One masked index is
    then sampled and the amplitudes incompatible with it are zeroed.

    Args:
        key: JAX random key; it is split and the new key is returned.
        amplitudes: amplitudes indexed by basis state. They do not need to be
            normalized: the random draw is scaled by the total probability.
        qubits_to_measure: array with the bit positions of the measured qubits.

    Returns:
        ``(key, measured_amplitudes)``: the new key and the amplitudes after the
        measurement, passed through ``normalize_array``.
    """

    def get_single_qubit_mask(qubit: int):
        return 1 << qubit
    single_qubit_masks = jax.vmap(get_single_qubit_mask)(qubits_to_measure)
    measure_mask = jax.lax.fori_loop(0, len(single_qubit_masks), lambda i, measure_mask: measure_mask | single_qubit_masks[i], 0)

    def apply_measure_mask(p: int) -> tuple[int, jnp.complex64]:
        idx = p & measure_mask
        amp = amplitudes[p]
        return idx, amp
    idxs, amps = jax.vmap(apply_measure_mask)(jnp.arange(amplitudes.shape[0]))
    # Probabilities are real: accumulate them in a real-valued buffer instead of
    # one that inherits the complex dtype of the amplitudes
    contributions = jnp.abs(amps) ** 2
    prob_of_masked_basis_states = jnp.zeros_like(contributions)
    prob_of_masked_basis_states = prob_of_masked_basis_states.at[idxs].add(contributions)

    # Sample masked basis state
    key, subkey = jrand.split(key)
    cum_of_masked_basis_states = jnp.cumsum(prob_of_masked_basis_states)
    # Scale the draw from [0, 1) by the total probability, so that the outcome
    # distribution is also correct when the squared norm of `amplitudes` is not
    # exactly 1 (unnormalized input or floating point rounding)
    r = jrand.uniform(subkey) * cum_of_masked_basis_states[-1]
    sampled_masked_basis_state_idx = jnp.searchsorted(cum_of_masked_basis_states, r)

    # Build state from those amplitudes whose corresponding basis state index equals the sampled one when masked
    # We call those indices "compliant"
    def is_basis_state_idx_compliant(p: int):
        return (p & measure_mask) == sampled_masked_basis_state_idx
    compliant_idxs_mask = jax.vmap(is_basis_state_idx_compliant)(jnp.arange(amplitudes.shape[0]))

    measured_amplitudes = jnp.where(compliant_idxs_mask, amplitudes, 0.0)
    measured_amplitudes = normalize_array(measured_amplitudes)

    return key, measured_amplitudes


@jax.jit
def _measure_computational_basis_2(key, amplitudes: jnp.ndarray, qubits_to_measure: jnp.ndarray):
    """Measure a subset of qubits using a probability array over the measured subspace.

    For every basis state, the bits of the measured qubits are packed into a
    small integer (bit ``i`` of that integer is the bit of
    ``qubits_to_measure[i]``). Probabilities are accumulated per packed value in
    an array of size ``2 ** len(qubits_to_measure)``, one value is sampled, and
    the amplitudes incompatible with it are zeroed.

    Args:
        key: JAX random key; it is split and the new key is returned.
        amplitudes: amplitudes indexed by basis state. They do not need to be
            normalized: the random draw is scaled by the total probability.
        qubits_to_measure: array with the bit positions of the measured qubits.

    Returns:
        ``(key, result_amplitudes)``: the new key and the amplitudes after the
        measurement, passed through ``normalize_array``.
    """

    # Build probability vector for the measured subspace
    def extract_amp_and_state_num(p: int):
        amp = amplitudes[p]

        def add_qubit_contribution_to_state_num(i, old_state_num):
            bit_value = (p >> qubits_to_measure[i]) & 1
            new_state_num = old_state_num | (bit_value << i)
            return new_state_num
        state_num = jax.lax.fori_loop(0, len(qubits_to_measure), add_qubit_contribution_to_state_num, 0)

        return amp, state_num

    amps, state_nums = jax.vmap(extract_amp_and_state_num)(jnp.arange(amplitudes.shape[0]))
    probabilities = jnp.zeros(2 ** qubits_to_measure.shape[0])
    probabilities = probabilities.at[state_nums].add(jnp.abs(amps)**2)

    # Sample state in measured subspace
    key, subkey = jrand.split(key)
    cumulative = jnp.cumsum(probabilities)
    # Scale the draw from [0, 1) by the total probability, so that the outcome
    # distribution is also correct when the squared norm of `amplitudes` is not
    # exactly 1 (unnormalized input or floating point rounding)
    r = jrand.uniform(subkey) * cumulative[-1]
    sampled_state_num = jnp.searchsorted(cumulative, r)

    # Compute amplitudes after the measurement by retaining only amplitudes that are tied to the measure outcome
    result_amplitudes = jnp.where(state_nums == sampled_state_num, amplitudes, 0.0)
    result_amplitudes = normalize_array(result_amplitudes)

    return key, result_amplitudes

@jax.jit
def _measure_all_computational_basis(key, amplitudes: jnp.ndarray):
    """Measure every qubit: sample one basis state and collapse onto it.

    Args:
        key: JAX random key; it is split and the new key is returned.
        amplitudes: amplitudes indexed by basis state. They do not need to be
            normalized: the random draw is scaled by the total probability.

    Returns:
        ``(key, sampled_basis_state, measured_amplitudes)``: the new key, the
        index of the sampled basis state, and a complex64 array shaped like
        ``amplitudes`` that is 1 at the sampled index and 0 elsewhere.
    """
    key, subkey = jrand.split(key)

    probabilities = jnp.abs(amplitudes) ** 2
    cumulative = jnp.cumsum(probabilities)
    # Scale the draw from [0, 1) by the total probability. Comparing an unscaled
    # draw against a cumulative sum whose total differs from 1 biases the
    # outcome and, when the total is smaller than the draw, yields an index past
    # the end of the array (the indexed update below is then skipped and an
    # all-zero state is returned).
    r = jrand.uniform(subkey) * cumulative[-1]
    # Index of the first entry of the cumulative sum that is not smaller than r
    sampled_basis_state = jnp.searchsorted(cumulative, r)

    measured_amplitudes = jnp.zeros_like(amplitudes, dtype=jnp.complex64)
    measured_amplitudes = measured_amplitudes.at[sampled_basis_state].set(1.0)

    return key, sampled_basis_state, measured_amplitudes

In [15]:
amplitudes = jnp.array([1, 1, 0, 1], dtype=jnp.complex64)
# i. e. |00> + |01> + |11>
key = jrand.key(2)


# Measure qubit 0
%timeit _measure_computational_basis_1(key, amplitudes, jnp.array([0]))
# Possible outcomes: |00> with probability 1/3, |01> + |11> with probability 2/3
# It would make more sense to use a different random key at every iteration

# WITHOUT JITTING
# 19.5 ms ± 86 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


# Measure qubit 0
%timeit _measure_computational_basis_2(key, amplitudes, jnp.array([0]))
# Possible outcomes: |00> with probability 1/3, |01> + |11> with probability 2/3
# It would make more sense to use a different random key at every iteration

# WITHOUT JITTING
# 25.2 ms ± 183 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)

386 μs ± 80 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)
320 μs ± 1.1 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


**`_measure_all_computational_basis()`**: 
  1) current implementation (1b: same but without jitting)

In [16]:
amplitudes = jnp.array([1, 1, 0, 1], dtype=jnp.complex64)
# i. e. |00> + |01> + |11>
key = jrand.key(2)


# Measure qubit 0
%timeit _measure_all_computational_basis(key, amplitudes)
# It would make more sense to use a different random key at every iteration

# WITHOUT JITTING
# 1.46 ms ± 35.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

226 μs ± 1.2 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


**Jitting QuantumCircuit or not?**

Jitting QuantumCircuit seems indeed a good idea. The speedup is bigger for longer circuits, smaller for big n_qubits.
Jitting an instance of QuantumCircuit results in slightly better performance than using method `jit_call()`, which in turn calls a jitted version of the call function.

In [2]:
low_dim = 3
mid_dim = 8
high_dim = 15

for dim in [low_dim, mid_dim, high_dim]:
    _0 = WaveFunction.from_string("0"*dim)
    GHZ_preparation = QuantumCircuit(
        Hadamard(0),
        *[CNOT(control=i-1, target=i) for i in range(1, dim)],
    )
    %timeit GHZ_preparation(_0)
    %timeit GHZ_preparation.jit_call(_0)
    
    GHZ_preparation_jit = jax.jit(GHZ_preparation)
    %timeit GHZ_preparation_jit(_0)
    
    print()

34.6 μs ± 8.45 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
5.83 μs ± 1.02 μs per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
4.76 μs ± 83.1 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

103 μs ± 264 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
27.2 μs ± 39.4 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
26.7 μs ± 96.4 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

10.8 ms ± 7.83 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
9.9 ms ± 119 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)
9.95 ms ± 189 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)

