# Reinforcement Learning Functions

In this notebook, the step, reset, and reward functions are defined and tested. Every function defined here is also in `main.py`, where its full documentation can be found.

## Update Fisher Information 

This function computes the Fisher information for each parameter, depending on the root choice and the protocol choice. The protocols are encoded as integers:

| Protocol | Numeric Value |
|----------|----------|
| Multicast   | 0  |
| Root Independent    | 1  |
| Independent Encoding X basis   | 2  |
| Independent Encoding Z basis   | 3  |
| Back and Forth Z basis    | 4  |
| Back and Forth X basis   | 5  |
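To avoid passing bare integers around, the table above could be mirrored by an enum. This is a minimal sketch; `Protocol` and its member names are hypothetical and are not defined in `main.py`. Because `IntEnum` members compare equal to plain integers, checks such as `protocol == 0` inside `updated_fisher_information` would still work unchanged.

```python
from enum import IntEnum


class Protocol(IntEnum):
    """Hypothetical enum mirroring the protocol table (not part of main.py)."""
    MULTICAST = 0
    ROOT_INDEPENDENT = 1
    INDEPENDENT_ENCODING_X = 2
    INDEPENDENT_ENCODING_Z = 3
    BACK_AND_FORTH_Z = 4
    BACK_AND_FORTH_X = 5


# IntEnum members compare equal to plain ints
print(Protocol.BACK_AND_FORTH_Z == 4)  # True
# Look up a protocol by its numeric value
print(Protocol(2).name)                # INDEPENDENT_ENCODING_X
```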

In [92]:
import numpy as np

In [None]:
def updated_fisher_information(root, protocol, theta):
    """
    Calculates the Fisher Information matrix for quantum network parameters based on the chosen protocol.
    
    This function computes Fisher Information contributions for each parameter depending on 
    the quantum measurement protocol being used in a 3-node quantum flip star network.
    
    Parameters:
    -----------
    root : int
        The index of the root node/channel (0, 1, or 2) that serves as the central measurement point
    protocol : int
        The quantum protocol to use:
        - 0: Multicast Protocol - All channels receive multivariate (non-diagonal) Fisher information
        - 1: Root Independent Protocol - All channels except root receive 1/(θ-θ²) information
        - 2: Independent Encoding X basis - Only root receives 1/(θ-θ²) information  
        - 3: Independent Encoding Z basis - All channels receive 1/(θ-θ²) information
        - 4: Back and Forth Z basis - Multivariate (non-diagonal) Fisher information for all channels
        - 5: Back and Forth X basis - Only root receives 1/(θ-θ²) information
    theta : list of float
        Quantum channel error parameters [θ₀, θ₁, θ₂] representing flip probabilities.
        Each value must lie strictly inside (0, 1) to avoid division by zero.
        
    Returns:
    --------
    numpy.ndarray
        New symmetric 3x3 Fisher Information matrix
            [[f_00, f_01, f_02],
             [f_10, f_11, f_12],
             [f_20, f_21, f_22]]
        containing the calculated contributions for each parameter under the specified protocol.
        
    Notes:
    ------
    - Fisher Information quantifies how much information a measurement provides about parameters
    - Higher Fisher Information indicates better parameter estimation capability
    - The function uses modular arithmetic to handle circular indexing of the 3-node network
    - Protocols 0 and 4 use multivariate Fisher information formulas derived from 
      quantum network tomography theory
    - Protocols 1, 2, 3, 5 use the simpler univariate formula 1/(θᵢ(1-θᵢ)) and have zero off-diagonal elements
    - A new zero-initialized Fisher Information matrix is created on every call
    - The matrix is symmetric: only one entry of each off-diagonal pair is computed
      (upper or lower triangle, depending on the root) and it is then mirrored

    Raises:
    -------
    ZeroDivisionError
        If any theta value equals 0 or 1 in protocols that use 1/(θ-θ²), which is not physically possible.
    """
  
    # Start from an all-zero n x n Fisher information matrix (n = number of channels)
    fisher_info = np.zeros((len(theta), len(theta)), dtype=float)


    # ----- Multi cast Protocol -----
    if protocol == 0:

        theta_0 = theta[root]
        theta_1 = theta[(root + 1) % len(theta)]
        theta_2 = theta[(root + 2) % len(theta)]

      
        # -------------------------------------------------------------------#
        ## Calculate Diagonals ##
        # -------------------------------------------------------------------#

        # Root channel information
        fisher_info[root,root] += (
        (theta_1 - theta_2)**2 / 
        (theta_0 * theta_1 + theta_2 - theta_0 * theta_2 - theta_1 * theta_2) + 

        (-theta_1 + theta_2)**2 / 
        (theta_1 - theta_0 * theta_1 + theta_0 * theta_2 - theta_1 * theta_2) + 

        (1 - theta_1 - theta_2)**2 / 
        (theta_0 - theta_0 * theta_1 - theta_0 * theta_2 + theta_1 * theta_2) + 

        (-1 + theta_1 + theta_2)**2 / 
        (1 - theta_0 - theta_1 + theta_0 * theta_1 - theta_2 + theta_0 * theta_2 + theta_1 * theta_2)
        )
        #print(f"Fisher info for root {root}: {fisher_info[root]}")

        # Other Channel information
        fisher_info[(root + 1) % len(fisher_info), (root + 1) % len(fisher_info)] += (
        (theta_0 - theta_2)**2 / 
        (theta_0 * theta_1 + theta_2 - theta_0 * theta_2 - theta_1 * theta_2) + 

        (1 - theta_0 - theta_2)**2 / 
        (theta_1 - theta_0 * theta_1 + theta_0 * theta_2 - theta_1 * theta_2) + 

        (-theta_0 + theta_2)**2 / 
        (theta_0 - theta_0 * theta_1 - theta_0 * theta_2 + theta_1 * theta_2) + 

        (-1 + theta_0 + theta_2)**2 / 
        (1 - theta_0 - theta_1 + theta_0 * theta_1 - theta_2 + theta_0 * theta_2 + theta_1 * theta_2)
        )

        fisher_info[(root + 2) % len(fisher_info), (root + 2) % len(fisher_info)] += (

        (1 - theta_0 - theta_1)**2 / 
        (theta_0 * theta_1 + theta_2 - theta_0 * theta_2 - theta_1 * theta_2) + 

        (theta_0 - theta_1)**2 / 
        (theta_1 - theta_0 * theta_1 + theta_0 * theta_2 - theta_1 * theta_2) + 

        (-theta_0 + theta_1)**2 / 
        (theta_0 - theta_0 * theta_1 - theta_0 * theta_2 + theta_1 * theta_2) + 

        (-1 + theta_0 + theta_1)**2 / 
        (1 - theta_0 - theta_1 + theta_0 * theta_1 - theta_2 + theta_0 * theta_2 + theta_1 * theta_2)

        )

        # -------------------------------------------------------------------#
        ## Calculate Off-Diagonals ##
        # -------------------------------------------------------------------#

        # F_01 (Assuming root is 0)
        fisher_info[root, (root + 1) % len(fisher_info)] += (
          ((theta_0 - theta_2) * (theta_1 - theta_2)) / 
          (theta_0 * theta_1 + theta_2 - theta_0 * theta_2 - theta_1 * theta_2) + 

          ((1 - theta_0 - theta_2) * (-theta_1 + theta_2)) / 
          (theta_1 - theta_0 * theta_1 + theta_0 * theta_2 - theta_1 * theta_2) + 

          ((1 - theta_1 - theta_2) * (-theta_0 + theta_2)) / 
          (theta_0 - theta_0 * theta_1 - theta_0 * theta_2 + theta_1 * theta_2) + 

          ((-1 + theta_0 + theta_2) * (-1 + theta_1 + theta_2)) / 
          (1 - theta_0 - theta_1 + theta_0 * theta_1 - theta_2 + theta_0 * theta_2 + theta_1 * theta_2)
        )   

        # F_02 (Assuming root is 0)
        fisher_info[root, (root + 2) % len(fisher_info)] += (
            ((1 - theta_0 - theta_1) * (theta_1 - theta_2)) / 
            (theta_0 * theta_1 + theta_2 - theta_0 * theta_2 - theta_1 * theta_2) + 

            ((theta_0 - theta_1) * (-theta_1 + theta_2)) / 
            (theta_1 - theta_0 * theta_1 + theta_0 * theta_2 - theta_1 * theta_2) + 

            ((-theta_0 + theta_1) * (1 - theta_1 - theta_2)) / 
            (theta_0 - theta_0 * theta_1 - theta_0 * theta_2 + theta_1 * theta_2) + 

            ((-1 + theta_0 + theta_1) * (-1 + theta_1 + theta_2)) / 
            (1 - theta_0 - theta_1 + theta_0 * theta_1 - theta_2 + theta_0 * theta_2 + theta_1 * theta_2)
          )

        # F_12 (Assuming root is 0)
        fisher_info[(root + 1) % len(fisher_info), (root + 2) % len(fisher_info)] += (
            ((1 - theta_0 - theta_1) * (theta_0 - theta_2)) / 
            (theta_0 * theta_1 + theta_2 - theta_0 * theta_2 - theta_1 * theta_2) + 

            ((theta_0 - theta_1) * (1 - theta_0 - theta_2)) / 
            (theta_1 - theta_0 * theta_1 + theta_0 * theta_2 - theta_1 * theta_2) + 

            ((-theta_0 + theta_1) * (-theta_0 + theta_2)) / 
            (theta_0 - theta_0 * theta_1 - theta_0 * theta_2 + theta_1 * theta_2) + 

            ((-1 + theta_0 + theta_1) * (-1 + theta_0 + theta_2)) / 
            (1 - theta_0 - theta_1 + theta_0 * theta_1 - theta_2 + theta_0 * theta_2 + theta_1 * theta_2)
          )

    # ----- Root Independent Protocol -----
    # In this protocol, the root parameter receives no information. Every other parameter receives information based on its own value. All off-diagonal elements are zero.
    elif protocol == 1: # Root Independent
        for i in range(len(theta)): # Loop through all parameters
            if i != root: 
                fisher_info[i,i]+=1/(theta[i]-theta[i]**2)



    # ----- Independent Encoding Protocol X basis -----
    # In this protocol, the root parameter receives information based on its own value, while the other parameters receive none. All off-diagonal elements are zero.
    elif protocol == 2: # Independent Encoding X basis
        fisher_info[root, root] += 1/(theta[root]-theta[root]**2) # The root parameter receives information based on its own value.


    # ----- Independent Encoding Protocol Z basis -----
    # In this protocol, every parameter receives information based on its own value. It matches the Root Independent Protocol except that the root parameter is also included. All off-diagonal elements are zero.
    elif protocol == 3: # Independent Encoding Z basis
        for i in range(len(fisher_info)): # Loop through all parameters
            fisher_info[i,i]+=1/(theta[i]-theta[i]**2)



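    # ----- Back and Forth Protocol Z basis -----
    # Like the multicast protocol, this protocol yields a multivariate Fisher information matrix; however, it uses 3 qubits instead of 2.
    # Note: as written, the root-diagonal expression below is an algebraic rearrangement of the multicast one and every other
    # entry is identical, so this branch currently returns the same matrix as protocol 0.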
    elif protocol == 4:
        theta_0 = theta[root]
        theta_1 = theta[(root + 1) % len(theta)]
        theta_2 = theta[(root + 2) % len(theta)]

        # -------------------------------------------------------------------#
        ## Calculate Diagonals ##
        # -------------------------------------------------------------------#

        # Root channel information
        fisher_info[root,root] += (
        -((theta_1 - theta_2)**2 / 
          (-theta_0 * theta_2 + theta_1 * (-1 + theta_0 + theta_2)))    

        - (theta_1 - theta_2)**2 / 
          ((-1 + theta_1) * theta_2 + theta_0 * (-theta_1 + theta_2)) 

        + (-1 + theta_1 + theta_2)**2 / 
          (theta_1 * theta_2 - theta_0 * (-1 + theta_1 + theta_2)) 

        + (-1 + theta_1 + theta_2)**2 / 
          ((-1 + theta_1) * (-1 + theta_2) + theta_0 * (-1 + theta_1 + theta_2))
        ) 

        # Other channel information
        fisher_info[(root + 1) % len(fisher_info), (root + 1) % len(fisher_info)] += (
        (theta_0 - theta_2)**2 / 
        (theta_0 * theta_1 + theta_2 - theta_0 * theta_2 - theta_1 * theta_2) + 

        (1 - theta_0 - theta_2)**2 / 
        (theta_1 - theta_0 * theta_1 + theta_0 * theta_2 - theta_1 * theta_2) + 

        (-theta_0 + theta_2)**2 / 
        (theta_0 - theta_0 * theta_1 - theta_0 * theta_2 + theta_1 * theta_2) + 

        (-1 + theta_0 + theta_2)**2 / 
        (1 - theta_0 - theta_1 + theta_0 * theta_1 - theta_2 + theta_0 * theta_2 + theta_1 * theta_2)
        )

        fisher_info[(root + 2) % len(fisher_info), (root + 2) % len(fisher_info)] += (
        (1 - theta_0 - theta_1)**2 / 
        (theta_0 * theta_1 + theta_2 - theta_0 * theta_2 - theta_1 * theta_2) + 
    
        (theta_0 - theta_1)**2 / 
        (theta_1 - theta_0 * theta_1 + theta_0 * theta_2 - theta_1 * theta_2) + 
    
        (-theta_0 + theta_1)**2 / 
        (theta_0 - theta_0 * theta_1 - theta_0 * theta_2 + theta_1 * theta_2) + 
    
        (-1 + theta_0 + theta_1)**2 / 
        (1 - theta_0 - theta_1 + theta_0 * theta_1 - theta_2 + theta_0 * theta_2 + theta_1 * theta_2)
        )

        # -------------------------------------------------------------------#
        ## Calculate Off-Diagonals ##
        # -------------------------------------------------------------------#

        # F_01 (Assuming root is 0)
        fisher_info[root, (root + 1) % len(fisher_info)] += (
            ((theta_0 - theta_2) * (theta_1 - theta_2)) / 
            (theta_0 * theta_1 + theta_2 - theta_0 * theta_2 - theta_1 * theta_2) + 

            ((1 - theta_0 - theta_2) * (-theta_1 + theta_2)) / 
            (theta_1 - theta_0 * theta_1 + theta_0 * theta_2 - theta_1 * theta_2) + 

            ((1 - theta_1 - theta_2) * (-theta_0 + theta_2)) / 
            (theta_0 - theta_0 * theta_1 - theta_0 * theta_2 + theta_1 * theta_2) + 

            ((-1 + theta_0 + theta_2) * (-1 + theta_1 + theta_2)) / 
            (1 - theta_0 - theta_1 + theta_0 * theta_1 - theta_2 + theta_0 * theta_2 + theta_1 * theta_2)
        )

        # F_02 (Assuming root is 0)
        fisher_info[root, (root + 2) % len(fisher_info)] += (
            ((1 - theta_0 - theta_1) * (theta_1 - theta_2)) / 
            (theta_0 * theta_1 + theta_2 - theta_0 * theta_2 - theta_1 * theta_2) + 

            ((theta_0 - theta_1) * (-theta_1 + theta_2)) / 
            (theta_1 - theta_0 * theta_1 + theta_0 * theta_2 - theta_1 * theta_2) + 

            ((-theta_0 + theta_1) * (1 - theta_1 - theta_2)) / 
            (theta_0 - theta_0 * theta_1 - theta_0 * theta_2 + theta_1 * theta_2) + 

            ((-1 + theta_0 + theta_1) * (-1 + theta_1 + theta_2)) / 
            (1 - theta_0 - theta_1 + theta_0 * theta_1 - theta_2 + theta_0 * theta_2 + theta_1 * theta_2)
        )
        
        # F_12 (Assuming root is 0)
        fisher_info[(root + 1) % len(fisher_info), (root + 2) % len(fisher_info)] += (
            ((1 - theta_0 - theta_1) * (theta_0 - theta_2)) / 
            (theta_0 * theta_1 + theta_2 - theta_0 * theta_2 - theta_1 * theta_2) + 

            ((theta_0 - theta_1) * (1 - theta_0 - theta_2)) / 
            (theta_1 - theta_0 * theta_1 + theta_0 * theta_2 - theta_1 * theta_2) + 

            ((-theta_0 + theta_1) * (-theta_0 + theta_2)) / 
            (theta_0 - theta_0 * theta_1 - theta_0 * theta_2 + theta_1 * theta_2) + 

            ((-1 + theta_0 + theta_1) * (-1 + theta_0 + theta_2)) / 
            (1 - theta_0 - theta_1 + theta_0 * theta_1 - theta_2 + theta_0 * theta_2 + theta_1 * theta_2)
        )


    # ----- Back and Forth Protocol X basis -----
    # In this protocol, the root parameter receives information based on its own value, while the other parameters receive none. This is identical to the Independent Encoding Protocol X basis.
    elif protocol == 5:
        fisher_info[root, root] += 1/(theta[root]-theta[root]**2) # The root parameter receives information based on its own value.


    # Each off-diagonal pair was written on only one side (upper or lower,
    # depending on the root), so adding the transpose mirrors it. The diagonal
    # is subtracted once so it is not counted twice.
    fisher_info = fisher_info + fisher_info.T - np.diag(np.diag(fisher_info))

    return fisher_info
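The multicast closed forms can be cross-checked against the standard Fisher information of a categorical distribution, F_ij = Σ_k (∂p_k/∂θ_i)(∂p_k/∂θ_j) / p_k. This sketch assumes that the four denominators in the protocol-0 branch are the outcome probabilities p_k (they sum to 1) and that the numerators are their partial derivatives, in root-relative labels θ₀ = root, θ₁ = root+1, θ₂ = root+2. `multicast_fisher_sketch` is a hypothetical helper, not a function from `main.py`.

```python
import numpy as np


def multicast_fisher_sketch(root, theta):
    """Hypothetical check: F = J^T diag(1/p) J for the assumed 4-outcome model."""
    n = len(theta)
    order = [root, (root + 1) % n, (root + 2) % n]
    t0, t1, t2 = (theta[i] for i in order)  # root-relative labels

    # Assumed outcome probabilities (the four denominators); they sum to 1
    p = np.array([
        t0 * t1 + t2 * (1 - t0 - t1),
        t1 * (1 - t0) + t2 * (t0 - t1),
        t0 * (1 - t1 - t2) + t1 * t2,
        (1 - t0) * (1 - t1) * (1 - t2) + t0 * t1 * t2,
    ])

    # Jacobian J[k, i] = dp_k / dt_i
    J = np.array([
        [t1 - t2,      t0 - t2,      1 - t0 - t1],
        [t2 - t1,      1 - t0 - t2,  t0 - t1],
        [1 - t1 - t2,  t2 - t0,      t1 - t0],
        [t1 + t2 - 1,  t0 + t2 - 1,  t0 + t1 - 1],
    ])

    # Fisher information in root-relative order
    F_rel = J.T @ (J / p[:, None])

    # Map back to absolute channel indices
    F = np.zeros((n, n))
    F[np.ix_(order, order)] = F_rel
    return F


F = multicast_fisher_sketch(0, [0.1, 0.2, 0.3])
print(np.round(F, 5))  # diagonal ≈ [2.87307, 3.64343, 3.24879], as in multicast Test 1
```

Relabeling the channels only permutes the four outcome probabilities, so under this model the result does not depend on `root`. That is the root independence that multicast Test 3 checks.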

### Testing Update Fisher Information

#### Multicast Protocol

In [106]:
## Testing Multicast Protocol
print("TESTING MULTICAST PROTOCOL (Protocol 0)")
print("=" * 60)
print("This protocol provides complex Fisher information updates for all channels")
print("Expected behavior: Fisher information distributed across all parameters\n")

# Test 1: General testing of the multicast protocol
print("Test 1: Basic Multicast Test")
print("-" * 30)
theta = [0.1, 0.2, 0.3]
root = 0
protocol = 0

fisher_info = updated_fisher_information(root, protocol, theta)

print(f"Input Parameters:")
print(f"  θ = {theta}")
print(f"  Root node: {root}")
print(f"  Protocol: {protocol} (Multicast)")
print(f"\nFisher Information Matrix (3x3):")
print(f"  Diagonal:     [{fisher_info[0,0]:.5f}, {fisher_info[1,1]:.5f}, {fisher_info[2,2]:.5f}]")
print(f"  Off-diagonal: F₀₁={fisher_info[0,1]:.5f}, F₀₂={fisher_info[0,2]:.5f}, F₁₂={fisher_info[1,2]:.5f}")
print(f"Expected diagonal: [2.87307, 3.64343, 3.24879]")
print(f"Expected off-diagonal: F₀₁=1.98428, F₀₂=0.769805, F₁₂=-0.00334806")

# Check if results match expected values
expected = [2.87307, 3.64343, 3.24879]
actual_diagonal = [fisher_info[0,0], fisher_info[1,1], fisher_info[2,2]]

expected_off_diagonal = [1.98428,0.769805,-0.00334806]
actual_off_diagonal = [fisher_info[0,1], fisher_info[0,2], fisher_info[1,2]]


tolerance = 0.001
off_diagonal_matches = all(abs(a - e) < tolerance for a, e in zip(actual_off_diagonal, expected_off_diagonal))
matches = all(abs(a - e) < tolerance for a, e in zip(actual_diagonal, expected))
print(f"PASS: Test 1 PASSED" if matches and off_diagonal_matches else f"FAIL: Test 1 FAILED")
print()

# -------------------------------------------------------------------

# Test 2: Testing with different theta values
print("Test 2: Different Theta Values")
print("-" * 30)
theta = [0.2, 0.1, 0.1]
root = 0
protocol = 0
fisher_info = updated_fisher_information(root, protocol, theta)

print(f"Input Parameters:")
print(f"  θ = {theta}")
print(f"  Root node: {root}")
print(f"  Protocol: {protocol} (Multicast)")
print(f"\nFisher Information Matrix (3x3):")
print(f"  Diagonal:     [{fisher_info[0,0]:.5f}, {fisher_info[1,1]:.5f}, {fisher_info[2,2]:.5f}]")
print(f"  Off-diagonal: F₀₁={fisher_info[0,1]:.5f}, F₀₂={fisher_info[0,2]:.5f}, F₁₂={fisher_info[1,2]:.5f}")
print(f"Expected diagonal: [4.74932, 6.36823, 6.36823]")
print(f"Expected off-diagonal: F₀₁=0.39095, F₀₂=0.39095, F₁₂=2.36823")

expected = [4.74932, 6.36823, 6.36823]
expected_off_diagonal = [0.39095,0.39095, 2.36823]
actual_diagonal = [fisher_info[0,0], fisher_info[1,1], fisher_info[2,2]]
matches = all(abs(a - e) < tolerance for a, e in zip(actual_diagonal, expected))
actual_off_diagonal = [fisher_info[0,1], fisher_info[0,2], fisher_info[1,2]]
off_diagonal_matches = all(abs(a - e) < tolerance for a, e in zip(actual_off_diagonal, expected_off_diagonal))
print(f"PASS: Test 2 PASSED" if matches and off_diagonal_matches else f"FAIL: Test 2 FAILED")
print()

# -------------------------------------------------------------------

# Test 3: Testing root independence (changing root should give same pattern)
print("Test 3: Root Independence Verification")
print("-" * 30)
theta = [0.2, 0.1, 0.1]
root = 1
protocol = 0
fisher_info = updated_fisher_information(root, protocol, theta)

print(f"Input Parameters:")
print(f"  θ = {theta}")
print(f"  Root node: {root} (changed from 0)")
print(f"  Protocol: {protocol} (Multicast)")
print(f"\nFisher Information Matrix (3x3):")
print(f"  Diagonal:     [{fisher_info[0,0]:.5f}, {fisher_info[1,1]:.5f}, {fisher_info[2,2]:.5f}]")
print(f"  Off-diagonal: F₀₁={fisher_info[0,1]:.5f}, F₀₂={fisher_info[0,2]:.5f}, F₁₂={fisher_info[1,2]:.5f}")
print(f"Note: Values should match Test 2 pattern due to multicast root independence")
print(f"  Expected diagonal: [4.74932, 6.36823, 6.36823]")
print(f"  Expected off-diagonal: F₀₁=0.39095, F₀₂=0.39095, F₁₂=2.36823")

expected = [4.74932, 6.36823, 6.36823]
expected_off_diagonal = [0.39095,0.39095, 2.36823]
actual_diagonal = [fisher_info[0,0], fisher_info[1,1], fisher_info[2,2]]
matches = all(abs(a - e) < tolerance for a, e in zip(actual_diagonal, expected))
actual_off_diagonal = [fisher_info[0,1], fisher_info[0,2], fisher_info[1,2]]
off_diagonal_matches = all(abs(a - e) < tolerance for a, e in zip(actual_off_diagonal, expected_off_diagonal))
print(f"PASS: Test 3 PASSED" if matches and off_diagonal_matches else f"FAIL: Test 3 FAILED")
print()


print("\n" + "=" * 60)
print("MULTICAST PROTOCOL TESTING COMPLETE")
print("Key observations:")
print("• All parameters receive Fisher information (no zeros)")
print("• Off-diagonal elements are non-zero (parameter correlations)")
print("• The Fisher information couples all three parameters (full 3x3 matrix)")
print("=" * 60)

TESTING MULTICAST PROTOCOL (Protocol 0)
This protocol provides complex Fisher information updates for all channels
Expected behavior: Fisher information distributed across all parameters

Test 1: Basic Multicast Test
------------------------------
Input Parameters:
  θ = [0.1, 0.2, 0.3]
  Root node: 0
  Protocol: 0 (Multicast)

Fisher Information Matrix (3x3):
  Diagonal:     [2.87307, 3.64343, 3.24879]
  Off-diagonal: F₀₁=1.98428, F₀₂=0.76981, F₁₂=-0.00335
Expected diagonal: [2.87307, 3.64343, 3.24879]
Expected off-diagonal: F₀₁=1.98428, F₀₂=0.769805, F₁₂=-0.00334806
PASS: Test 1 PASSED

Test 2: Different Theta Values
------------------------------
Input Parameters:
  θ = [0.2, 0.1, 0.1]
  Root node: 0
  Protocol: 0 (Multicast)

Fisher Information Matrix (3x3):
  Diagonal:     [4.74932, 6.36823, 6.36823]
  Off-diagonal: F₀₁=0.39095, F₀₂=0.39095, F₁₂=2.36823
Expected diagonal: [4.74932, 6.36823, 6.36823]
Expected off-diagonal: F₀₁=0.39095, F₀₂=0.39095, F₁₂=2.36823
PASS: Test 2 PASSED
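Each test cell repeats the same diagonal and off-diagonal comparison. A small helper could factor that out. This is a sketch: `check_fisher` is a hypothetical name, not part of `main.py`, and it uses `np.allclose` with `rtol=0` to approximate the absolute `tolerance = 0.001` check used in the cells.

```python
import numpy as np


def check_fisher(fisher_info, expected_diagonal, expected_off_diagonal=(0.0, 0.0, 0.0), tol=1e-3):
    """Hypothetical helper: compare a 3x3 Fisher matrix with expected values.

    expected_off_diagonal is ordered (F01, F02, F12); the default checks for all zeros.
    """
    fisher_info = np.asarray(fisher_info, dtype=float)
    actual_off = [fisher_info[0, 1], fisher_info[0, 2], fisher_info[1, 2]]
    # Absolute tolerance only, matching the original abs(a - e) < tolerance checks
    diag_ok = np.allclose(np.diag(fisher_info), expected_diagonal, rtol=0.0, atol=tol)
    off_ok = np.allclose(actual_off, expected_off_diagonal, rtol=0.0, atol=tol)
    return bool(diag_ok and off_ok)


# Example with a diagonal matrix like the Independent Encoding Z basis output
demo = np.diag([11.11111, 6.25, 4.76190])
print("PASS" if check_fisher(demo, [11.11111, 6.25, 4.76190]) else "FAIL")  # PASS
```

With the notebook's function, a multicast check could then read `check_fisher(updated_fisher_information(0, 0, [0.1, 0.2, 0.3]), [2.87307, 3.64343, 3.24879], (1.98428, 0.769805, -0.00334806))`.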



#### Root Independent Protocol

In [107]:
## Testing Root Independent Protocol
print("TESTING ROOT INDEPENDENT PROTOCOL (Protocol 1)")
print("=" * 60)
print("This protocol excludes the root parameter from Fisher information updates")
print("Expected behavior: All parameters except root receive 1/(θ-θ²) information\n")

# Test 1: Basic root independent test
print("Test 1: Basic Root Independent Test")
print("-" * 30)
theta = [0.1, 0.2, 0.3]
root = 0
protocol = 1

fisher_info = updated_fisher_information(root, protocol, theta)

print(f"Input Parameters:")
print(f"  θ = {theta}")
print(f"  Root node: {root}")
print(f"  Protocol: {protocol} (Root Independent)")
print(f"\nFisher Information Matrix (3x3):")
print(f"  Diagonal:     [{fisher_info[0,0]:.5f}, {fisher_info[1,1]:.5f}, {fisher_info[2,2]:.5f}]")
print(f"  Off-diagonal: F₀₁={fisher_info[0,1]:.5f}, F₀₂={fisher_info[0,2]:.5f}, F₁₂={fisher_info[1,2]:.5f}")
print(f"Expected diagonal: [0.00000, 6.25000, 4.76190]")
print(f"Expected off-diagonal: All zeros (no parameter correlations)")

# Check if results match expected values
expected = [0.0, 6.25, 4.76190]
actual_diagonal = [fisher_info[0,0], fisher_info[1,1], fisher_info[2,2]]
tolerance = 0.001
matches = all(abs(a - e) < tolerance for a, e in zip(actual_diagonal, expected))

# Check that off-diagonal elements are zero
off_diagonal_zero = (abs(fisher_info[0,1]) < tolerance and 
                     abs(fisher_info[0,2]) < tolerance and 
                     abs(fisher_info[1,2]) < tolerance)

print(f"PASS: Test 1 PASSED" if matches and off_diagonal_zero else f"FAIL: Test 1 FAILED")
print()

# Test 2: Different root selection
print("Test 2: Root Selection Verification (root=1)")
print("-" * 30)
theta = [0.1, 0.2, 0.3]
root = 1
protocol = 1

fisher_info = updated_fisher_information(root, protocol, theta)

print(f"Input Parameters:")
print(f"  θ = {theta}")
print(f"  Root node: {root} (excludes θ₁)")
print(f"  Protocol: {protocol} (Root Independent)")
print(f"\nFisher Information Matrix (3x3):")
print(f"  Diagonal:     [{fisher_info[0,0]:.5f}, {fisher_info[1,1]:.5f}, {fisher_info[2,2]:.5f}]")
print(f"  Off-diagonal: F₀₁={fisher_info[0,1]:.5f}, F₀₂={fisher_info[0,2]:.5f}, F₁₂={fisher_info[1,2]:.5f}")
print(f"Expected diagonal: [11.11111, 0.00000, 4.76190]")
print(f"Expected off-diagonal: All zeros (no parameter correlations)")

expected = [11.11111, 0.0, 4.76190]
actual_diagonal = [fisher_info[0,0], fisher_info[1,1], fisher_info[2,2]]
matches = all(abs(a - e) < tolerance for a, e in zip(actual_diagonal, expected))

off_diagonal_zero = (abs(fisher_info[0,1]) < tolerance and 
                     abs(fisher_info[0,2]) < tolerance and 
                     abs(fisher_info[1,2]) < tolerance)

print(f"PASS: Test 2 PASSED" if matches and off_diagonal_zero else f"FAIL: Test 2 FAILED")
print()

# Test 3: Third root selection
print("Test 3: Root Selection Verification (root=2)")
print("-" * 30)
theta = [0.1, 0.2, 0.3]
root = 2
protocol = 1

fisher_info = updated_fisher_information(root, protocol, theta)

print(f"Input Parameters:")
print(f"  θ = {theta}")
print(f"  Root node: {root} (excludes θ₂)")
print(f"  Protocol: {protocol} (Root Independent)")
print(f"\nFisher Information Matrix (3x3):")
print(f"  Diagonal:     [{fisher_info[0,0]:.5f}, {fisher_info[1,1]:.5f}, {fisher_info[2,2]:.5f}]")
print(f"  Off-diagonal: F₀₁={fisher_info[0,1]:.5f}, F₀₂={fisher_info[0,2]:.5f}, F₁₂={fisher_info[1,2]:.5f}")
print(f"Expected diagonal: [11.11111, 6.25000, 0.00000]")
print(f"Expected off-diagonal: All zeros (no parameter correlations)")

expected = [11.11111, 6.25, 0.0]
actual_diagonal = [fisher_info[0,0], fisher_info[1,1], fisher_info[2,2]]
matches = all(abs(a - e) < tolerance for a, e in zip(actual_diagonal, expected))

off_diagonal_zero = (abs(fisher_info[0,1]) < tolerance and 
                     abs(fisher_info[0,2]) < tolerance and 
                     abs(fisher_info[1,2]) < tolerance)

print(f"PASS: Test 3 PASSED" if matches and off_diagonal_zero else f"FAIL: Test 3 FAILED")

print("\n" + "=" * 60)
print("ROOT INDEPENDENT PROTOCOL TESTING COMPLETE")
print("Key observations:")
print("• Root parameter always receives zero Fisher information")
print("• Non-root parameters receive 1/(θᵢ(1-θᵢ)) information")
print("• All off-diagonal elements are zero (no parameter correlations)")
print("• Choosing the root only determines which parameter is excluded")
print("=" * 60)

TESTING ROOT INDEPENDENT PROTOCOL (Protocol 1)
This protocol excludes the root parameter from Fisher information updates
Expected behavior: All parameters except root receive 1/(θ-θ²) information

Test 1: Basic Root Independent Test
------------------------------
Input Parameters:
  θ = [0.1, 0.2, 0.3]
  Root node: 0
  Protocol: 1 (Root Independent)

Fisher Information Matrix (3x3):
  Diagonal:     [0.00000, 6.25000, 4.76190]
  Off-diagonal: F₀₁=0.00000, F₀₂=0.00000, F₁₂=0.00000
Expected diagonal: [0.00000, 6.25000, 4.76190]
Expected off-diagonal: All zeros (no parameter correlations)
PASS: Test 1 PASSED

Test 2: Root Selection Verification (root=1)
------------------------------
Input Parameters:
  θ = [0.1, 0.2, 0.3]
  Root node: 1 (excludes θ₁)
  Protocol: 1 (Root Independent)

Fisher Information Matrix (3x3):
  Diagonal:     [11.11111, 0.00000, 4.76190]
  Off-diagonal: F₀₁=0.00000, F₀₂=0.00000, F₁₂=0.00000
Expected diagonal: [11.11111, 0.00000, 4.76190]
Expected off-diagonal: All z
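Protocols 1, 2, 3 and 5 differ only in which channels receive the univariate term 1/(θᵢ(1-θᵢ)), so they can be expressed as a boolean mask over the channels. This is a vectorized sketch, not the code in `main.py`; `diagonal_fisher_sketch` is a hypothetical name, and protocols 0 and 4 are deliberately not covered.

```python
import numpy as np


def diagonal_fisher_sketch(root, protocol, theta):
    """Hypothetical sketch of protocols 1, 2, 3 and 5: F = diag(mask * 1/(θ(1-θ)))."""
    theta = np.asarray(theta, dtype=float)
    per_param = 1.0 / (theta * (1.0 - theta))  # univariate information per channel
    is_root = np.arange(len(theta)) == root
    masks = {
        1: ~is_root,                          # Root Independent: every channel except the root
        2: is_root,                           # Independent Encoding X: root only
        3: np.ones(len(theta), dtype=bool),   # Independent Encoding Z: every channel
        5: is_root,                           # Back and Forth X: root only
    }
    # Off-diagonal elements stay zero for these protocols
    return np.diag(np.where(masks[protocol], per_param, 0.0))


F = diagonal_fisher_sketch(0, 1, [0.1, 0.2, 0.3])
print(np.round(np.diag(F), 5))  # diagonal ≈ [0, 6.25, 4.7619], as in Root Independent Test 1
```

One design difference: because the division is done on a NumPy array, a θ of exactly 0 or 1 yields `inf` with a NumPy `RuntimeWarning` instead of raising `ZeroDivisionError`.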

#### Independent Encoding X basis

In [108]:
## Testing Independent Encoding X basis Protocol
print("TESTING INDEPENDENT ENCODING X BASIS PROTOCOL (Protocol 2)")
print("=" * 60)
print("This protocol gives Fisher information only to the root parameter")
print("Expected behavior: Only root receives 1/(θ-θ²) information, others get zero\n")

# Test 1: Basic independent encoding X basis test
print("Test 1: Basic Independent Encoding X Test")
print("-" * 30)
theta = [0.1, 0.2, 0.3]
root = 0
protocol = 2

fisher_info = updated_fisher_information(root, protocol, theta)

print(f"Input Parameters:")
print(f"  θ = {theta}")
print(f"  Root node: {root}")
print(f"  Protocol: {protocol} (Independent Encoding X basis)")
print(f"\nFisher Information Matrix (3x3):")
print(f"  Diagonal:     [{fisher_info[0,0]:.5f}, {fisher_info[1,1]:.5f}, {fisher_info[2,2]:.5f}]")
print(f"  Off-diagonal: F₀₁={fisher_info[0,1]:.5f}, F₀₂={fisher_info[0,2]:.5f}, F₁₂={fisher_info[1,2]:.5f}")
print(f"Expected diagonal: [11.11111, 0.00000, 0.00000]")
print(f"Expected off-diagonal: All zeros (no parameter correlations)")

# Check if results match expected values
expected = [11.11111, 0.0, 0.0]
actual_diagonal = [fisher_info[0,0], fisher_info[1,1], fisher_info[2,2]]
tolerance = 0.001
matches = all(abs(a - e) < tolerance for a, e in zip(actual_diagonal, expected))

# Check that off-diagonal elements are zero
off_diagonal_zero = (abs(fisher_info[0,1]) < tolerance and 
                     abs(fisher_info[0,2]) < tolerance and 
                     abs(fisher_info[1,2]) < tolerance)

print(f"PASS: Test 1 PASSED" if matches and off_diagonal_zero else f"FAIL: Test 1 FAILED")
print()

# Test 2: Different root selection
print("Test 2: Root Selection Verification (root=1)")
print("-" * 30)
theta = [0.1, 0.2, 0.3]
root = 1
protocol = 2

fisher_info = updated_fisher_information(root, protocol, theta)

print(f"Input Parameters:")
print(f"  θ = {theta}")
print(f"  Root node: {root} (only θ₁ receives information)")
print(f"  Protocol: {protocol} (Independent Encoding X basis)")
print(f"\nFisher Information Matrix (3x3):")
print(f"  Diagonal:     [{fisher_info[0,0]:.5f}, {fisher_info[1,1]:.5f}, {fisher_info[2,2]:.5f}]")
print(f"  Off-diagonal: F₀₁={fisher_info[0,1]:.5f}, F₀₂={fisher_info[0,2]:.5f}, F₁₂={fisher_info[1,2]:.5f}")
print(f"Expected diagonal: [0.00000, 6.25000, 0.00000]")
print(f"Expected off-diagonal: All zeros (no parameter correlations)")

expected = [0.0, 6.25, 0.0]
actual_diagonal = [fisher_info[0,0], fisher_info[1,1], fisher_info[2,2]]
matches = all(abs(a - e) < tolerance for a, e in zip(actual_diagonal, expected))

off_diagonal_zero = (abs(fisher_info[0,1]) < tolerance and 
                     abs(fisher_info[0,2]) < tolerance and 
                     abs(fisher_info[1,2]) < tolerance)

print(f"PASS: Test 2 PASSED" if matches and off_diagonal_zero else f"FAIL: Test 2 FAILED")
print()

# Test 3: Third root selection
print("Test 3: Root Selection Verification (root=2)")
print("-" * 30)
theta = [0.1, 0.2, 0.3]
root = 2
protocol = 2

fisher_info = updated_fisher_information(root, protocol, theta)

print(f"Input Parameters:")
print(f"  θ = {theta}")
print(f"  Root node: {root} (only θ₂ receives information)")
print(f"  Protocol: {protocol} (Independent Encoding X basis)")
print(f"\nFisher Information Matrix (3x3):")
print(f"  Diagonal:     [{fisher_info[0,0]:.5f}, {fisher_info[1,1]:.5f}, {fisher_info[2,2]:.5f}]")
print(f"  Off-diagonal: F₀₁={fisher_info[0,1]:.5f}, F₀₂={fisher_info[0,2]:.5f}, F₁₂={fisher_info[1,2]:.5f}")
print(f"Expected diagonal: [0.00000, 0.00000, 4.76190]")
print(f"Expected off-diagonal: All zeros (no parameter correlations)")

expected = [0.0, 0.0, 4.76190]
actual_diagonal = [fisher_info[0,0], fisher_info[1,1], fisher_info[2,2]]
matches = all(abs(a - e) < tolerance for a, e in zip(actual_diagonal, expected))

off_diagonal_zero = (abs(fisher_info[0,1]) < tolerance and 
                     abs(fisher_info[0,2]) < tolerance and 
                     abs(fisher_info[1,2]) < tolerance)

print(f"PASS: Test 3 PASSED" if matches and off_diagonal_zero else f"FAIL: Test 3 FAILED")

print("\n" + "=" * 60)
print("INDEPENDENT ENCODING X BASIS PROTOCOL TESTING COMPLETE")
print("Key observations:")
print("• Only root parameter receives Fisher information")
print("• Root parameter receives 1/(θᵣₒₒₜ(1-θᵣₒₒₜ)) information")
print("• All non-root parameters receive zero information")
print("• All off-diagonal elements are zero (no parameter correlations)")
print("• Each run informs exactly one parameter (the root)")
print("=" * 60)

TESTING INDEPENDENT ENCODING X BASIS PROTOCOL (Protocol 2)
This protocol gives Fisher information only to the root parameter
Expected behavior: Only root receives 1/(θ-θ²) information, others get zero

Test 1: Basic Independent Encoding X Test
------------------------------
Input Parameters:
  θ = [0.1, 0.2, 0.3]
  Root node: 0
  Protocol: 2 (Independent Encoding X basis)

Fisher Information Matrix (3x3):
  Diagonal:     [11.11111, 0.00000, 0.00000]
  Off-diagonal: F₀₁=0.00000, F₀₂=0.00000, F₁₂=0.00000
Expected diagonal: [11.11111, 0.00000, 0.00000]
Expected off-diagonal: All zeros (no parameter correlations)
PASS: Test 1 PASSED

Test 2: Root Selection Verification (root=1)
------------------------------
Input Parameters:
  θ = [0.1, 0.2, 0.3]
  Root node: 1 (only θ₁ receives information)
  Protocol: 2 (Independent Encoding X basis)

Fisher Information Matrix (3x3):
  Diagonal:     [0.00000, 6.25000, 0.00000]
  Off-diagonal: F₀₁=0.00000, F₀₂=0.00000, F₁₂=0.00000
Expected diagonal: [0.0
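The univariate entry 1/(θ(1-θ)) used by protocols 1, 2, 3 and 5 is the Fisher information of a single Bernoulli outcome. Assuming each informative measurement reports a flip (x = 1) with probability θ and no flip (x = 0) otherwise, the derivation is:

$$
\begin{aligned}
\ell(\theta; x) &= x \log\theta + (1 - x)\log(1 - \theta), \qquad x \in \{0, 1\} \\
\frac{\partial \ell}{\partial \theta} &= \frac{x}{\theta} - \frac{1 - x}{1 - \theta} = \frac{x - \theta}{\theta(1 - \theta)} \\
F(\theta) &= \mathbb{E}\!\left[\left(\frac{\partial \ell}{\partial \theta}\right)^{2}\right]
 = \frac{\mathbb{E}\big[(x - \theta)^2\big]}{\theta^2 (1 - \theta)^2}
 = \frac{\theta(1 - \theta)}{\theta^2 (1 - \theta)^2}
 = \frac{1}{\theta(1 - \theta)}
\end{aligned}
$$

For θ = 0.1 this gives 1/(0.1 · 0.9) ≈ 11.11111, the root entry in Test 1 above. By the Cramér–Rao bound, N independent rounds then give Var(θ̂) ≥ θ(1-θ)/N for any unbiased estimator.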

#### Independent Encoding Z basis

In [114]:
## Testing Independent Encoding Z basis Protocol
print("TESTING INDEPENDENT ENCODING Z BASIS PROTOCOL (Protocol 3)")
print("=" * 60)
print("This protocol gives Fisher information to every parameter, regardless of the root")
print("Expected behavior: All parameters receive 1/(θ-θ²) information, root-independent\n")

# Test 1: Basic independent encoding Z basis test
print("Test 1: Basic Independent Encoding Z Test")
print("-" * 30)
theta = [0.1, 0.2, 0.3]
root = 0
protocol = 3

fisher_info = updated_fisher_information(root, protocol, theta)

print(f"Input Parameters:")
print(f"  θ = {theta}")
print(f"  Root node: {root}")
print(f"  Protocol: {protocol} (Independent Encoding Z basis)")
print(f"\nFisher Information Matrix (3x3):")
print(f"  Diagonal:     [{fisher_info[0,0]:.5f}, {fisher_info[1,1]:.5f}, {fisher_info[2,2]:.5f}]")
print(f"  Off-diagonal: F₀₁={fisher_info[0,1]:.5f}, F₀₂={fisher_info[0,2]:.5f}, F₁₂={fisher_info[1,2]:.5f}")
print(f"Expected diagonal: [11.11111, 6.25000, 4.76190]")
print(f"Expected off-diagonal: All zeros (no parameter correlations)")

# Check if results match expected values
expected = [11.11111, 6.25, 4.76190]
actual_diagonal = [fisher_info[0,0], fisher_info[1,1], fisher_info[2,2]]
tolerance = 0.001
matches = all(abs(a - e) < tolerance for a, e in zip(actual_diagonal, expected))

# Check that off-diagonal elements are zero
off_diagonal_zero = (abs(fisher_info[0,1]) < tolerance and 
                     abs(fisher_info[0,2]) < tolerance and 
                     abs(fisher_info[1,2]) < tolerance)

print(f"PASS: Test 1 PASSED" if matches and off_diagonal_zero else f"FAIL: Test 1 FAILED")
print()

# Test 2: Root independence verification
print("Test 2: Root Independence Verification (root=1)")
print("-" * 30)
theta = [0.1, 0.2, 0.3]
root = 1
protocol = 3

fisher_info = updated_fisher_information(root, protocol, theta)

print(f"Input Parameters:")
print(f"  θ = {theta}")
print(f"  Root node: {root} (should not affect results)")
print(f"  Protocol: {protocol} (Independent Encoding Z basis)")
print(f"\nFisher Information Matrix (3x3):")
print(f"  Diagonal:     [{fisher_info[0,0]:.5f}, {fisher_info[1,1]:.5f}, {fisher_info[2,2]:.5f}]")
print(f"  Off-diagonal: F₀₁={fisher_info[0,1]:.5f}, F₀₂={fisher_info[0,2]:.5f}, F₁₂={fisher_info[1,2]:.5f}")
print(f"Expected diagonal: [11.11111, 6.25000, 4.76190] (same as Test 1)")
print(f"Expected off-diagonal: All zeros (no parameter correlations)")

expected = [11.11111, 6.25, 4.76190]
actual_diagonal = [fisher_info[0,0], fisher_info[1,1], fisher_info[2,2]]
matches = all(abs(a - e) < tolerance for a, e in zip(actual_diagonal, expected))

off_diagonal_zero = (abs(fisher_info[0,1]) < tolerance and 
                     abs(fisher_info[0,2]) < tolerance and 
                     abs(fisher_info[1,2]) < tolerance)

print(f"PASS: Test 2 PASSED" if matches and off_diagonal_zero else f"FAIL: Test 2 FAILED")
print()

# Test 3: Complete root independence verification
print("Test 3: Complete Root Independence (root=2)")
print("-" * 30)
theta = [0.1, 0.2, 0.3]
root = 2
protocol = 3

fisher_info = updated_fisher_information(root, protocol, theta)

print(f"Input Parameters:")
print(f"  θ = {theta}")
print(f"  Root node: {root} (third root option)")
print(f"  Protocol: {protocol} (Independent Encoding Z basis)")
print(f"\nFisher Information Matrix (3x3):")
print(f"  Diagonal:     [{fisher_info[0,0]:.5f}, {fisher_info[1,1]:.5f}, {fisher_info[2,2]:.5f}]")
print(f"  Off-diagonal: F₀₁={fisher_info[0,1]:.5f}, F₀₂={fisher_info[0,2]:.5f}, F₁₂={fisher_info[1,2]:.5f}")
print(f"Expected diagonal: [11.11111, 6.25000, 4.76190] (same as Tests 1 & 2)")
print(f"Expected off-diagonal: All zeros (no parameter correlations)")

expected = [11.11111, 6.25, 4.76190]
actual_diagonal = [fisher_info[0,0], fisher_info[1,1], fisher_info[2,2]]
matches = all(abs(a - e) < tolerance for a, e in zip(actual_diagonal, expected))

off_diagonal_zero = (abs(fisher_info[0,1]) < tolerance and 
                     abs(fisher_info[0,2]) < tolerance and 
                     abs(fisher_info[1,2]) < tolerance)

print(f"PASS: Test 3 PASSED" if matches and off_diagonal_zero else f"FAIL: Test 3 FAILED")

print("\n" + "=" * 60)
print("INDEPENDENT ENCODING Z BASIS PROTOCOL TESTING COMPLETE")
print("Key observations:")
print("• All parameters receive Fisher information (unlike X basis)")
print("• Each parameter receives 1/(θᵢ(1-θᵢ)) information based on its own value")
print("• Root selection has no effect on Fisher information distribution")
print("• All off-diagonal elements are zero (no parameter correlations)")
print("• Protocol is completely root-independent and parameter-isolated")
print("=" * 60)

TESTING INDEPENDENT ENCODING Z BASIS PROTOCOL (Protocol 3)
This protocol gives Fisher information to every parameter, regardless of root
Expected behavior: All parameters receive 1/(θ-θ²) information, root-independent

Test 1: Basic Independent Encoding Z Test
------------------------------
Input Parameters:
  θ = [0.1, 0.2, 0.3]
  Root node: 0
  Protocol: 3 (Independent Encoding Z basis)

Fisher Information Matrix (3x3):
  Diagonal:     [11.11111, 6.25000, 4.76190]
  Off-diagonal: F₀₁=0.00000, F₀₂=0.00000, F₁₂=0.00000
Expected diagonal: [11.11111, 6.25000, 4.76190]
Expected off-diagonal: All zeros (no parameter correlations)
PASS: Test 1 PASSED

Test 2: Root Independence Verification (root=1)
------------------------------
Input Parameters:
  θ = [0.1, 0.2, 0.3]
  Root node: 1 (should not affect results)
  Protocol: 3 (Independent Encoding Z basis)

Fisher Information Matrix (3x3):
  Diagonal:     [11.11111, 6.25000, 4.76190]
  Off-diagonal: F₀₁=0.00000, F₀₂=0.00000, F₁₂=0.00000
Expected diagonal:
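
The three root-by-root cells can be collapsed into a loop. The sketch below uses a hypothetical oracle, `reference_diagonal_fisher`, which encodes the behaviour the docstring states for the diagonal protocols:

- Protocol 1: every channel except the root.
- Protocols 2 and 5: the root only.
- Protocol 3: every channel.

Each receiving channel gets 1/(θᵢ(1-θᵢ)). The oracle is written from that description and is not the implementation in `main.py`. In the notebook, `updated_fisher_information(root, protocol, theta)` would be compared against it.

```python
import numpy as np

def reference_diagonal_fisher(root, protocol, theta):
    """Hypothetical test oracle for the diagonal protocols (1, 2, 3, 5), built from the docstring."""
    theta = np.asarray(theta, dtype=float)
    per_channel = 1.0 / (theta * (1.0 - theta))  # 1/(θ - θ²) for each channel
    channels = np.arange(len(theta))
    if protocol == 1:            # Root Independent: all channels except the root
        mask = channels != root
    elif protocol in (2, 5):     # X-basis protocols: root only
        mask = channels == root
    elif protocol == 3:          # Independent Encoding Z basis: all channels
        mask = np.ones(len(theta), dtype=bool)
    else:
        raise ValueError(f"protocol {protocol} is not a diagonal protocol")
    return np.diag(np.where(mask, per_channel, 0.0))

theta = [0.1, 0.2, 0.3]
# Protocol 3 should give the same matrix for every root choice
z_matrices = [reference_diagonal_fisher(r, 3, theta) for r in range(3)]
print(all(np.allclose(m, z_matrices[0]) for m in z_matrices))  # True

# In the notebook, the real function could be checked the same way, e.g.:
# for r in range(3):
#     np.testing.assert_allclose(updated_fisher_information(r, 3, theta),
#                                reference_diagonal_fisher(r, 3, theta), atol=1e-3)
```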

#### Back and Forth Z basis

In [113]:
## Testing Back and Forth Z basis Protocol
print("TESTING BACK AND FORTH Z BASIS PROTOCOL (Protocol 4)")
print("=" * 60)
print("This protocol uses complex Fisher information formulas with 3 qubits")
print("Expected behavior: Complex multivariate Fisher information for all parameters\n")

# Test 1: Basic back and forth Z basis test
print("Test 1: Basic Back and Forth Z Test")
print("-" * 30)
theta = [0.1, 0.2, 0.3]
root = 0
protocol = 4

fisher_info = updated_fisher_information(root, protocol, theta)

print(f"Input Parameters:")
print(f"  θ = {theta}")
print(f"  Root node: {root}")
print(f"  Protocol: {protocol} (Back and Forth Z basis)")
print(f"\nFisher Information Matrix (3x3):")
print(f"  Diagonal:     [{fisher_info[0,0]:.5f}, {fisher_info[1,1]:.5f}, {fisher_info[2,2]:.5f}]")
print(f"  Off-diagonal: F₀₁={fisher_info[0,1]:.5f}, F₀₂={fisher_info[0,2]:.5f}, F₁₂={fisher_info[1,2]:.5f}")
print(f"Expected diagonal: [2.87307, 3.64343, 3.24879]")
print(f"Expected off-diagonal: [1.98428,0.769805,-0.00334806]")

# Check if results match expected values
expected = [2.87307, 3.64343, 3.24879]
actual_diagonal = [fisher_info[0,0], fisher_info[1,1], fisher_info[2,2]]
tolerance = 0.001
matches = all(abs(a - e) < tolerance for a, e in zip(actual_diagonal, expected))
off_diagonal_expected = [1.98428, 0.769805, -0.00334806]
actual_off_diagonal = [fisher_info[0,1], fisher_info[0,2], fisher_info[1,2]]
off_diagonal_matches = all(abs(a - e) < tolerance
                           for a, e in zip(actual_off_diagonal, off_diagonal_expected))

print(f"PASS: Test 1 PASSED" if matches and off_diagonal_matches else f"FAIL: Test 1 FAILED")
print()

# Test 2: Root behavior verification
print("Test 2: Root Behavior Verification (root=1)")
print("-" * 30)
theta = [0.1, 0.2, 0.3]
root = 1
protocol = 4

fisher_info = updated_fisher_information(root, protocol, theta)

print(f"Input Parameters:")
print(f"  θ = {theta}")
print(f"  Root node: {root} (different root selection)")
print(f"  Protocol: {protocol} (Back and Forth Z basis)")
print(f"\nFisher Information Matrix (3x3):")
print(f"  Diagonal:     [{fisher_info[0,0]:.5f}, {fisher_info[1,1]:.5f}, {fisher_info[2,2]:.5f}]")
print(f"  Off-diagonal: F₀₁={fisher_info[0,1]:.5f}, F₀₂={fisher_info[0,2]:.5f}, F₁₂={fisher_info[1,2]:.5f}")
print(f"Expected diagonal: [2.87307, 3.64343, 3.24879]")
print(f"Expected off-diagonal: [1.98428,0.769805,-0.00334806]")

expected = [2.87307, 3.64343, 3.24879]
expected_off_diagonal = [1.98428, 0.769805, -0.00334806]
actual_diagonal = [fisher_info[0,0], fisher_info[1,1], fisher_info[2,2]]
actual_off_diagonal = [fisher_info[0,1], fisher_info[0,2], fisher_info[1,2]]
matches = all(abs(a - e) < tolerance for a, e in zip(actual_diagonal, expected))

off_diagonal_matches = all(abs(a - e) < tolerance for a, e in zip(actual_off_diagonal, expected_off_diagonal))

print(f"PASS: Test 2 PASSED" if matches and off_diagonal_matches else f"FAIL: Test 2 FAILED")
print()

# Test 3: Verify total Fisher information
print("Test 3: Total Fisher Information Analysis")
print("-" * 30)
theta = [0.1, 0.2, 0.3]
root = 2
protocol = 4

fisher_info = updated_fisher_information(root, protocol, theta)

print(f"Input Parameters:")
print(f"  θ = {theta}")
print(f"  Root node: {root} (third root option)")
print(f"  Protocol: {protocol} (Back and Forth Z basis)")
print(f"\nFisher Information Matrix (3x3):")
print(f"  Diagonal:     [{fisher_info[0,0]:.5f}, {fisher_info[1,1]:.5f}, {fisher_info[2,2]:.5f}]")
print(f"  Off-diagonal: F₀₁={fisher_info[0,1]:.5f}, F₀₂={fisher_info[0,2]:.5f}, F₁₂={fisher_info[1,2]:.5f}")

# Calculate total Fisher information (trace of matrix)
total_fisher_info = fisher_info[0,0] + fisher_info[1,1] + fisher_info[2,2]
print(f"Total Fisher Information (trace): {total_fisher_info:.5f}")

# Verify protocol produces meaningful Fisher information
meaningful_info = total_fisher_info > 1.0  # Heuristic cutoff for non-trivial total information
off_diagonal_present = (abs(fisher_info[0,1]) > tolerance or 
                       abs(fisher_info[0,2]) > tolerance or 
                       abs(fisher_info[1,2]) > tolerance)

print(f"PASS: Test 3 PASSED" if meaningful_info and off_diagonal_present else f"FAIL: Test 3 FAILED")

print("\n" + "=" * 60)
print("BACK AND FORTH Z BASIS PROTOCOL TESTING COMPLETE")
print("Key observations:")
print("• Uses complex multivariate Fisher information formulas")
print("• All parameters receive Fisher information contributions")
print("• Off-diagonal elements present (parameter correlations)")
print("• Root selection does not change the matrix (root=0 and root=1 give identical results)")
print("• Uses 3 qubits instead of 2 (more complex than multicast)")
print("=" * 60)

TESTING BACK AND FORTH Z BASIS PROTOCOL (Protocol 4)
This protocol uses complex Fisher information formulas with 3 qubits
Expected behavior: Complex multivariate Fisher information for all parameters

Test 1: Basic Back and Forth Z Test
------------------------------
Input Parameters:
  θ = [0.1, 0.2, 0.3]
  Root node: 0
  Protocol: 4 (Back and Forth Z basis)

Fisher Information Matrix (3x3):
  Diagonal:     [2.87307, 3.64343, 3.24879]
  Off-diagonal: F₀₁=1.98428, F₀₂=0.76981, F₁₂=-0.00335
Expected diagonal: [2.87307, 3.64343, 3.24879]
Expected off-diagonal: [1.98428,0.769805,-0.00334806]
PASS: Test 1 PASSED

Test 2: Root Behavior Verification (root=1)
------------------------------
Input Parameters:
  θ = [0.1, 0.2, 0.3]
  Root node: 1 (different root selection)
  Protocol: 4 (Back and Forth Z basis)

Fisher Information Matrix (3x3):
  Diagonal:     [2.87307, 3.64343, 3.24879]
  Off-diagonal: F₀₁=1.98428, F₀₂=0.76981, F₁₂=-0.00335
Expected diagonal: [2.87307, 3.64343, 3.24879]
Expecte
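
Test 3 only checks that the trace exceeds a heuristic threshold of 1.0. A stricter sanity check is that a Fisher information matrix must be symmetric and positive semidefinite. When it is positive definite, the Cholesky factorisation succeeds and tr(F⁻¹) is finite.

The sketch below applies that check to the Back and Forth Z matrix printed above. The values are copied from the output, so they are rounded to the printed precision. `is_positive_definite_fisher` is an illustrative helper, not part of `main.py`. It tests strict positive definiteness, so the singular matrices from Protocols 1, 2 and 5 would be reported as `False`.

```python
import numpy as np

def is_positive_definite_fisher(F, atol=1e-8):
    """Return True if F is symmetric and positive definite."""
    F = np.asarray(F, dtype=float)
    # np.linalg.cholesky only reads one triangle, so check symmetry explicitly
    if not np.allclose(F, F.T, atol=atol):
        return False
    try:
        np.linalg.cholesky(F)  # raises LinAlgError unless F is positive definite
    except np.linalg.LinAlgError:
        return False
    return True

# Back and Forth Z basis matrix for θ = [0.1, 0.2, 0.3], copied from the printed output
F_bf_z = np.array([
    [2.87307,   1.98428,     0.769805],
    [1.98428,   3.64343,    -0.00334806],
    [0.769805, -0.00334806,  3.24879],
])

print(is_positive_definite_fisher(F_bf_z))       # True
print(round(np.trace(np.linalg.inv(F_bf_z)), 4))  # 1.4233
```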

#### Back and Forth X basis

In [117]:
## Testing Back and Forth X basis Protocol
print("TESTING BACK AND FORTH X BASIS PROTOCOL (Protocol 5)")
print("=" * 60)
print("This protocol behaves identically to Independent Encoding X basis")
print("Expected behavior: Only root receives 1/(θ-θ²) information, others get zero\n")

# Test 1: Basic back and forth X basis test
print("Test 1: Basic Back and Forth X Test")
print("-" * 30)
theta = [0.1, 0.2, 0.3]
root = 0
protocol = 5

fisher_info = updated_fisher_information(root, protocol, theta)

print(f"Input Parameters:")
print(f"  θ = {theta}")
print(f"  Root node: {root}")
print(f"  Protocol: {protocol} (Back and Forth X basis)")
print(f"\nFisher Information Matrix (3x3):")
print(f"  Diagonal:     [{fisher_info[0,0]:.5f}, {fisher_info[1,1]:.5f}, {fisher_info[2,2]:.5f}]")
print(f"  Off-diagonal: F₀₁={fisher_info[0,1]:.5f}, F₀₂={fisher_info[0,2]:.5f}, F₁₂={fisher_info[1,2]:.5f}")
print(f"Expected diagonal: [11.11111, 0.00000, 0.00000]")
print(f"Expected off-diagonal: All zeros (no parameter correlations)")

# Check if results match expected values
expected = [11.11111, 0.0, 0.0]
actual_diagonal = [fisher_info[0,0], fisher_info[1,1], fisher_info[2,2]]
tolerance = 0.001
matches = all(abs(a - e) < tolerance for a, e in zip(actual_diagonal, expected))

# Check that off-diagonal elements are zero
off_diagonal_zero = (abs(fisher_info[0,1]) < tolerance and 
                     abs(fisher_info[0,2]) < tolerance and 
                     abs(fisher_info[1,2]) < tolerance)

print(f"PASS: Test 1 PASSED" if matches and off_diagonal_zero else f"FAIL: Test 1 FAILED")
print()

# Test 2: Different root selection
print("Test 2: Root Selection Verification (root=1)")
print("-" * 30)
theta = [0.1, 0.2, 0.3]
root = 1
protocol = 5

fisher_info = updated_fisher_information(root, protocol, theta)

print(f"Input Parameters:")
print(f"  θ = {theta}")
print(f"  Root node: {root} (only θ₁ receives information)")
print(f"  Protocol: {protocol} (Back and Forth X basis)")
print(f"\nFisher Information Matrix (3x3):")
print(f"  Diagonal:     [{fisher_info[0,0]:.5f}, {fisher_info[1,1]:.5f}, {fisher_info[2,2]:.5f}]")
print(f"  Off-diagonal: F₀₁={fisher_info[0,1]:.5f}, F₀₂={fisher_info[0,2]:.5f}, F₁₂={fisher_info[1,2]:.5f}")
print(f"Expected diagonal: [0.00000, 6.25000, 0.00000]")
print(f"Expected off-diagonal: All zeros (no parameter correlations)")

expected = [0.0, 6.25, 0.0]
actual_diagonal = [fisher_info[0,0], fisher_info[1,1], fisher_info[2,2]]
matches = all(abs(a - e) < tolerance for a, e in zip(actual_diagonal, expected))

off_diagonal_zero = (abs(fisher_info[0,1]) < tolerance and 
                     abs(fisher_info[0,2]) < tolerance and 
                     abs(fisher_info[1,2]) < tolerance)

print(f"PASS: Test 2 PASSED" if matches and off_diagonal_zero else f"FAIL: Test 2 FAILED")
print()

# Test 3: Third root selection
print("Test 3: Root Selection Verification (root=2)")
print("-" * 30)
theta = [0.1, 0.2, 0.3]
root = 2
protocol = 5

fisher_info = updated_fisher_information(root, protocol, theta)

print(f"Input Parameters:")
print(f"  θ = {theta}")
print(f"  Root node: {root} (only θ₂ receives information)")
print(f"  Protocol: {protocol} (Back and Forth X basis)")
print(f"\nFisher Information Matrix (3x3):")
print(f"  Diagonal:     [{fisher_info[0,0]:.5f}, {fisher_info[1,1]:.5f}, {fisher_info[2,2]:.5f}]")
print(f"  Off-diagonal: F₀₁={fisher_info[0,1]:.5f}, F₀₂={fisher_info[0,2]:.5f}, F₁₂={fisher_info[1,2]:.5f}")
print(f"Expected diagonal: [0.00000, 0.00000, 4.76190]")
print(f"Expected off-diagonal: All zeros (no parameter correlations)")

expected = [0.0, 0.0, 4.76190]
actual_diagonal = [fisher_info[0,0], fisher_info[1,1], fisher_info[2,2]]
matches = all(abs(a - e) < tolerance for a, e in zip(actual_diagonal, expected))

off_diagonal_zero = (abs(fisher_info[0,1]) < tolerance and 
                     abs(fisher_info[0,2]) < tolerance and 
                     abs(fisher_info[1,2]) < tolerance)

print(f"PASS: Test 3 PASSED" if matches and off_diagonal_zero else f"FAIL: Test 3 FAILED")

print("\n" + "=" * 60)
print("BACK AND FORTH X BASIS PROTOCOL TESTING COMPLETE")
print("Key observations:")
print("• Identical behavior to Independent Encoding X basis protocol")
print("• Only root parameter receives Fisher information")
print("• Root parameter receives 1/(θᵣₒₒₜ(1-θᵣₒₒₜ)) information")
print("• All non-root parameters receive zero information")
print("• All off-diagonal elements are zero (no parameter correlations)")
print("• Despite the name, this protocol uses the same simple formula as Protocol 2")
print("=" * 60)

TESTING BACK AND FORTH X BASIS PROTOCOL (Protocol 5)
This protocol behaves identically to Independent Encoding X basis
Expected behavior: Only root receives 1/(θ-θ²) information, others get zero

Test 1: Basic Back and Forth X Test
------------------------------
Input Parameters:
  θ = [0.1, 0.2, 0.3]
  Root node: 0
  Protocol: 5 (Back and Forth X basis)

Fisher Information Matrix (3x3):
  Diagonal:     [11.11111, 0.00000, 0.00000]
  Off-diagonal: F₀₁=0.00000, F₀₂=0.00000, F₁₂=0.00000
Expected diagonal: [11.11111, 0.00000, 0.00000]
Expected off-diagonal: All zeros (no parameter correlations)
PASS: Test 1 PASSED

Test 2: Root Selection Verification (root=1)
------------------------------
Input Parameters:
  θ = [0.1, 0.2, 0.3]
  Root node: 1 (only θ₁ receives information)
  Protocol: 5 (Back and Forth X basis)

Fisher Information Matrix (3x3):
  Diagonal:     [0.00000, 6.25000, 0.00000]
  Off-diagonal: F₀₁=0.00000, F₀₂=0.00000, F₁₂=0.00000
Expected diagonal: [0.00000, 6.25000, 0.00000]
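
The Protocol 5 cells repeat the same print-and-compare pattern as the Protocol 2 cells. One way to cut the repetition is a small table-driven helper. In the sketch below, `check_fisher` and `CASES` are illustrative names; the helper compares a matrix's diagonal and upper off-diagonal entries against expected values with an absolute tolerance, as the original cells do.

The matrices are stand-ins built from the documented root-only behaviour. In the notebook, `updated_fisher_information(root, 5, theta)` would be passed instead.

```python
import numpy as np

def check_fisher(F, expected_diag, expected_offdiag=(0.0, 0.0, 0.0), tol=1e-3):
    """Compare diag(F) and (F01, F02, F12) with expected values within an absolute tolerance."""
    F = np.asarray(F, dtype=float)
    offdiag = [F[0, 1], F[0, 2], F[1, 2]]
    return (np.allclose(np.diag(F), expected_diag, atol=tol, rtol=0.0)
            and np.allclose(offdiag, expected_offdiag, atol=tol, rtol=0.0))

theta = np.array([0.1, 0.2, 0.3])
root_info = 1.0 / (theta * (1.0 - theta))  # 1/(θ - θ²) per channel

# (root, expected diagonal) for Protocol 5, taken from the expectations above
CASES = [
    (0, [11.11111, 0.0, 0.0]),
    (1, [0.0, 6.25, 0.0]),
    (2, [0.0, 0.0, 4.76190]),
]

for root, expected in CASES:
    # Stand-in for updated_fisher_information(root, 5, theta): root-only diagonal matrix
    F = np.diag(np.where(np.arange(3) == root, root_info, 0.0))
    status = "PASS" if check_fisher(F, expected) else "FAIL"
    print(f"Protocol 5, root={root}: {status}")
```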


## Reward Function

In [132]:
def reward_function(fisher_matrix):
    """
    Calculates the reward based on the Fisher Information matrix.
    The reward decreases monotonically with the trace of the inverse Fisher
    Information matrix. That trace is the sum of the Cramér-Rao lower bounds on the
    parameter variances, so minimizing it improves overall parameter estimation precision.
    
    Parameters:
    -----------
    fisher_matrix : numpy.ndarray
        The Fisher Information matrix (3x3) containing information contributions 
        for each parameter and their correlations
    
    Returns:
    --------
    float
        The calculated reward, log(1 + 1/(tr(F⁻¹) + ε)), which decreases as tr(F⁻¹) grows.
        Higher Fisher information (lower estimation uncertainty) gives a higher reward.
        
    Notes:
    ------
    - The trace of the inverse Fisher matrix lower-bounds the sum of the variances of any
      unbiased estimator (Cramér-Rao bound)
    - Lower trace means better parameter estimation precision
    - Reward = log(1 + 1 / (trace(F⁻¹) + ε)) where ε prevents division by zero
    - If the matrix is singular, parameters with zero diagonal information are dropped and
      the reward is computed from the remaining submatrix; if no invertible submatrix
      exists, a small penalty (-1e-6) is returned
    """
    
    try:
        # Compute the inverse of the Fisher Information matrix
        fisher_inverse = np.linalg.inv(fisher_matrix)
        
        # Calculate the trace of the inverse matrix
        trace_inverse = np.trace(fisher_inverse)
        
        # Small epsilon to prevent division by zero and ensure numerical stability
        epsilon = 1e-8
        
        # Reward decreases monotonically with the trace of the inverse:
        # higher Fisher information (lower trace) gives a higher reward
        reward = 1.0 / (trace_inverse + epsilon)
        
        # Apply log(1 + x), which compresses large rewards while preserving their ordering
        return np.log(reward + 1)
        
    except np.linalg.LinAlgError:
        # Matrix is singular (non-invertible): keep only the rows and columns whose diagonal
        # entry is non-zero. This happens when some parameters receive zero information
        # (e.g. the root under Protocol 1, or the non-root channels under Protocols 2 and 5)
        try:
            # Find indices of diagonal elements that are non-zero (above tolerance)
            tolerance = 1e-10
            diagonal = np.diag(fisher_matrix)
            non_zero_indices = np.where(np.abs(diagonal) > tolerance)[0]
            
            if len(non_zero_indices) == 0:
                # All diagonal elements are zero, return penalty
                return -1e-6
            
            # Create submatrix by selecting only rows and columns with non-zero diagonal elements
            submatrix = fisher_matrix[np.ix_(non_zero_indices, non_zero_indices)]
            
            # Compute inverse of the reduced submatrix
            submatrix_inverse = np.linalg.inv(submatrix)
            
            # Calculate trace of the submatrix inverse
            trace_inverse = np.trace(submatrix_inverse)
            
            # Small epsilon to prevent division by zero
            epsilon = 1e-8
            
            # Reward based on reduced submatrix (partial parameter estimation)
            reward = 1.0 / (trace_inverse + epsilon)
            
            return np.log(reward + 1)
            
        except np.linalg.LinAlgError:
            # Even reduced submatrix is singular, return small penalty reward
            return -1e-6
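
`reward_function` returns log(1 + 1/(tr(F⁻¹) + ε)). This value decreases monotonically in tr(F⁻¹) but is not strictly inversely proportional to it.

The sketch below factors the full-matrix and submatrix logic into one hypothetical helper, `trace_inverse_with_fallback`, which the QCRB printout in Test 5 could also reuse. It reproduces the rewards printed in Tests 1 and 2. Like the original, it assumes that singular Fisher matrices come from parameters whose diagonal information is zero.

```python
import numpy as np

def trace_inverse_with_fallback(F, tol=1e-10):
    """Return (tr(F^-1), kept_indices), mirroring reward_function's fallback logic.

    Tries the full inverse first; if F is singular, drops the parameters whose
    diagonal entry is (near) zero and inverts the remaining submatrix.
    Returns (None, kept_indices) when no invertible submatrix is found.
    """
    F = np.asarray(F, dtype=float)
    try:
        return np.trace(np.linalg.inv(F)), np.arange(F.shape[0])
    except np.linalg.LinAlgError:
        keep = np.where(np.abs(np.diag(F)) > tol)[0]
        if keep.size == 0:
            return None, keep
        try:
            return np.trace(np.linalg.inv(F[np.ix_(keep, keep)])), keep
        except np.linalg.LinAlgError:
            return None, keep

def reward_from_fisher(F, eps=1e-8, penalty=-1e-6):
    """log(1 + 1/(tr(F^-1) + eps)), or a small penalty if no inverse exists."""
    trace_inv, _ = trace_inverse_with_fallback(F)
    if trace_inv is None:
        return penalty
    return np.log1p(1.0 / (trace_inv + eps))  # log1p(x) == log(1 + x)

print(round(reward_from_fisher(np.diag([1.0, 2.0, 3.0])), 5))     # 0.43532
print(round(reward_from_fisher(np.diag([10.0, 20.0, 30.0])), 5))  # 1.86478
# Singular case: θ₀ has zero information, so only the 2x2 block is inverted
print(round(reward_from_fisher(np.diag([0.0, 6.25, 1 / 0.21])), 5))
```

For diag(1, 2, 3) the trace of the inverse is 1 + 1/2 + 1/3 = 1.83333, and log(1 + 1/1.83333) = 0.43532. This matches the recorded output.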


### Test Function

In [134]:
## Testing New Reward Function
print("TESTING REWARD FUNCTION")
print("=" * 50)
print("Testing reward function with Fisher Information matrices")
print("Reward = log(1 + 1/(trace(F⁻¹) + ε)): higher Fisher info gives higher reward\n")

# Test 1: Simple diagonal Fisher matrix
print("Test 1: Simple Diagonal Fisher Matrix")
print("-" * 30)
fisher_matrix_1 = np.array([[1.0, 0.0, 0.0],
                           [0.0, 2.0, 0.0], 
                           [0.0, 0.0, 3.0]])

reward_1 = reward_function(fisher_matrix_1)
print(f"Fisher Matrix 1 (diagonal):")
print(f"  {fisher_matrix_1}")
print(f"  Trace of inverse: {np.trace(np.linalg.inv(fisher_matrix_1)):.5f}")
print(f"  Reward: {reward_1:.5f}")
print()

# Test 2: Higher Fisher information matrix (should give higher reward)
print("Test 2: Higher Fisher Information Matrix")
print("-" * 30)
fisher_matrix_2 = np.array([[10.0, 0.0, 0.0],
                           [0.0, 20.0, 0.0], 
                           [0.0, 0.0, 30.0]])

reward_2 = reward_function(fisher_matrix_2)
print(f"Fisher Matrix 2 (higher values):")
print(f"  {fisher_matrix_2}")
print(f"  Trace of inverse: {np.trace(np.linalg.inv(fisher_matrix_2)):.5f}")
print(f"  Reward: {reward_2:.5f}")
print(f"  Reward ratio (2/1): {reward_2/reward_1:.2f} (higher Fisher info → higher reward)")
print()

# Test 3: Real Fisher matrix from quantum protocol
print("Test 3: Real Fisher Matrix from Multicast Protocol")
print("-" * 30)
theta = [0.1, 0.2, 0.3]
root = 0
protocol = 0
fisher_matrix_3 = updated_fisher_information(root, protocol, theta)

reward_3 = reward_function(fisher_matrix_3)
print(f"Fisher Matrix 3 (Multicast Protocol):")
print(f"  Diagonal: [{fisher_matrix_3[0,0]:.3f}, {fisher_matrix_3[1,1]:.3f}, {fisher_matrix_3[2,2]:.3f}]")
print(f"  Off-diagonal: F₀₁={fisher_matrix_3[0,1]:.3f}, F₀₂={fisher_matrix_3[0,2]:.3f}, F₁₂={fisher_matrix_3[1,2]:.3f}")
print(f"  Trace of inverse: {np.trace(np.linalg.inv(fisher_matrix_3)):.5f}")
print(f"  Reward: {reward_3:.5f}")
print()

# Test 4: Comparison with Root Independent Protocol
print("Test 4: Root Independent Protocol Comparison")
print("-" * 30)
theta = [0.1, 0.2, 0.3]
root = 0
protocol = 1
fisher_matrix_4 = updated_fisher_information(root, protocol, theta)

reward_4 = reward_function(fisher_matrix_4)
print(f"Fisher Matrix 4 (Root Independent Protocol):")
print(f"  Diagonal: [{fisher_matrix_4[0,0]:.3f}, {fisher_matrix_4[1,1]:.3f}, {fisher_matrix_4[2,2]:.3f}]")
#print(f"  Trace of inverse: {np.trace(np.linalg.inv(fisher_matrix_4)):.5f}")
print(f"  Reward: {reward_4:.5f}")
print(f"  Reward comparison (Multicast vs Root Independent): {reward_3:.3f} vs {reward_4:.3f}")

# Test 5: Comprehensive Protocol Testing with Random Theta Values
print("\nTest 5: All Protocols with Random Theta Values")
print("-" * 50)
print("Testing all 6 quantum protocols with random θ values in [0.05, 0.4]")

# Set random seed for reproducible results
np.random.seed(42)

# Generate random theta values in the specified range
theta_random = np.random.uniform(0.05, 0.4, 3)
print(f"Random θ values: [{theta_random[0]:.3f}, {theta_random[1]:.3f}, {theta_random[2]:.3f}]")
print()

# Protocol names for reference
protocol_names = [
    "Multicast", 
    "Root Independent", 
    "Independent Encoding X basis",
    "Independent Encoding Z basis", 
    "Back and Forth Z basis",
    "Back and Forth X basis"
]

# Test all protocols with root = 0
root = 0
rewards = []

for protocol in range(6):
    print(f"Protocol {protocol}: {protocol_names[protocol]}")
    
    # Get Fisher matrix for this protocol
    fisher_matrix = updated_fisher_information(root, protocol, theta_random)
    
    # Calculate reward
    reward = reward_function(fisher_matrix)
    rewards.append(reward)
    
    # Display results
    diagonal = [fisher_matrix[i,i] for i in range(3)]
    print(f"  Diagonal: [{diagonal[0]:.3f}, {diagonal[1]:.3f}, {diagonal[2]:.3f}]")
    
    # Check if matrix has off-diagonal elements
    has_off_diagonal = (abs(fisher_matrix[0,1]) > 1e-8 or 
                       abs(fisher_matrix[0,2]) > 1e-8 or 
                       abs(fisher_matrix[1,2]) > 1e-8)
    
    if has_off_diagonal:
        print(f"  Off-diagonal: F₀₁={fisher_matrix[0,1]:.3f}, F₀₂={fisher_matrix[0,2]:.3f}, F₁₂={fisher_matrix[1,2]:.3f}")
    else:
        print(f"  Off-diagonal: All zeros (diagonal protocol)")
    
    print(f"  Reward: {reward:.5f}")
    
    # Calculate and display the QCRB (Quantum Cramér-Rao Bound): the trace of the
    # inverse Fisher matrix, a lower bound on the sum of parameter estimation variances
    try:
        fisher_inverse = np.linalg.inv(fisher_matrix)
        qcrb_value = np.trace(fisher_inverse)
        print(f"  QCRB: {qcrb_value:.5f} (trace of F⁻¹)")
        matrix_status = "Full 3x3 invertible"
    except np.linalg.LinAlgError:
        # Matrix is singular, calculate QCRB using submatrix approach
        tolerance = 1e-10
        diagonal = np.diag(fisher_matrix)
        non_zero_indices = np.where(np.abs(diagonal) > tolerance)[0]
        diagonal_nonzero = len(non_zero_indices)
        
        if diagonal_nonzero > 0:
            # Use submatrix to calculate QCRB
            submatrix = fisher_matrix[np.ix_(non_zero_indices, non_zero_indices)]
            try:
                submatrix_inverse = np.linalg.inv(submatrix)
                qcrb_value = np.trace(submatrix_inverse)
                print(f"  QCRB: {qcrb_value:.5f} (trace of {diagonal_nonzero}x{diagonal_nonzero} submatrix F⁻¹)")
                matrix_status = f"Singular, using {diagonal_nonzero}x{diagonal_nonzero} submatrix"
            except np.linalg.LinAlgError:
                print(f"  QCRB: N/A (submatrix also singular)")
                matrix_status = f"Singular, {diagonal_nonzero}x{diagonal_nonzero} submatrix also singular"
        else:
            print(f"  QCRB: N/A (all diagonal elements zero)")
            matrix_status = "Completely singular (all zeros)"
    
    print(f"  Matrix: {matrix_status}")
    
    print()

# Summary comparison
print("=" * 50)
print("PROTOCOL REWARD COMPARISON:")
max_reward_idx = np.argmax(rewards)
min_reward_idx = np.argmin(rewards)

for i, (protocol_name, reward) in enumerate(zip(protocol_names, rewards)):
    marker = " ★" if i == max_reward_idx else " ☆" if i == min_reward_idx else ""
    print(f"  {i}: {protocol_name:25} → {reward:.5f}{marker}")

print(f"\nBest Protocol: {protocol_names[max_reward_idx]} (Reward: {rewards[max_reward_idx]:.5f})")
print(f"Worst Protocol: {protocol_names[min_reward_idx]} (Reward: {rewards[min_reward_idx]:.5f})")
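if min(rewards) > 0:
    print(f"Reward Range: {max(rewards)/min(rewards):.2f}x difference")
else:
    print("Reward Range: N/A (at least one protocol received the penalty reward)")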

print("\n" + "=" * 50)
print("REWARD FUNCTION TESTING COMPLETE")
print("Key observations:")
print("• Higher Fisher information matrices give higher rewards")
print("• Reward = log(1 + 1/(trace(F⁻¹) + ε)), decreasing monotonically in trace(F⁻¹)")
print("• Lower estimation uncertainty → higher reward")
print("• Function handles both diagonal and non-diagonal matrices")
print("• Dynamic submatrix selection works for singular matrices")
print("• All 6 quantum protocols evaluated")
print("=" * 50)

TESTING REWARD FUNCTION
Testing reward function with Fisher Information matrices
Reward = log(1 + 1/(trace(F⁻¹) + ε)): higher Fisher info gives higher reward

Test 1: Simple Diagonal Fisher Matrix
------------------------------
Fisher Matrix 1 (diagonal):
  [[1. 0. 0.]
 [0. 2. 0.]
 [0. 0. 3.]]
  Trace of inverse: 1.83333
  Reward: 0.43532

Test 2: Higher Fisher Information Matrix
------------------------------
Fisher Matrix 2 (higher values):
  [[10.  0.  0.]
 [ 0. 20.  0.]
 [ 0.  0. 30.]]
  Trace of inverse: 0.18333
  Reward: 1.86478
  Reward ratio (2/1): 4.28 (higher Fisher info → higher reward)

Test 3: Real Fisher Matrix from Multicast Protocol
------------------------------
Fisher Matrix 3 (Multicast Protocol):
  Diagonal: [2.873, 3.643, 3.249]
  Off-diagonal: F₀₁=1.984, F₀₂=0.770, F₁₂=-0.003
  Trace of inverse: 1.42327
  Reward: 0.53216

Test 4: Root Independent Protocol Comparison
------------------------------
Fisher Matrix 4 (Root Independent Protocol):
  Diagonal: [0.000, 6.250, 4.