# Reinforcement Learning Functions

This notebook defines and tests the step, reset, and reward functions. Every function created here also lives in `main.py`, where its documentation can be found.

## Update Fisher Information 

This function computes the Fisher information contribution for each parameter, depending on the root choice and protocol choice.

| Protocol | Numeric Value |
|----------|----------|
| Multicast   | 0  |
| Root Independent    | 1  |
| Independent Encoding X basis   | 2  |
| Independent Encoding Z basis   | 3  |
| Back and Forth Z basis    | 4  |
| Back and Forth X basis   | 5  |

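The numeric codes in the table can be given names so that calls read more clearly. The sketch below is one possible way to do that, using a hypothetical `Protocol` enum that is not part of `main.py`. Because it subclasses `IntEnum`, its members compare equal to the plain integers that the function expects.

```python
from enum import IntEnum


class Protocol(IntEnum):
    """Hypothetical names for the protocol codes in the table above."""
    MULTICAST = 0
    ROOT_INDEPENDENT = 1
    INDEPENDENT_ENCODING_X = 2
    INDEPENDENT_ENCODING_Z = 3
    BACK_AND_FORTH_Z = 4
    BACK_AND_FORTH_X = 5


# IntEnum members behave like their integer values in comparisons.
print(Protocol.MULTICAST == 0)      # True
print(Protocol(4).name)             # BACK_AND_FORTH_Z
print([int(p) for p in Protocol])   # [0, 1, 2, 3, 4, 5]
```

A call such as `updated_fisher_information(root, Protocol.MULTICAST, theta)` would then take the same branch as passing `0`.
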
In [67]:
def updated_fisher_information(root, protocol, theta):
    """
    Calculates the Fisher Information matrix for quantum network parameters based on the chosen protocol.
    
    This function computes Fisher Information contributions for each parameter depending on 
    the quantum measurement protocol being used in a 3-node quantum flip star network.
    
    Parameters:
    -----------
    root : int
        The index of the root node/channel (0, 1, or 2) that serves as the central measurement point
    protocol : int
        The quantum protocol to use:
        - 0: Multicast Protocol - All channels receive complex Fisher information updates
        - 1: Root Independent Protocol - All channels except root receive 1/(θ-θ²) information
        - 2: Independent Encoding X basis - Only root receives 1/(θ-θ²) information  
        - 3: Independent Encoding Z basis - All channels receive 1/(θ-θ²) information
        - 4: Back and Forth Z basis - Complex Fisher information updates for all channels
        - 5: Back and Forth X basis - Only root receives 1/(θ-θ²) information
    theta : list
        Quantum channel error parameters [θ₀, θ₁, θ₂] representing flip probabilities.
        Each value must lie strictly between 0 and 1 to avoid division by zero.
        
    Returns:
    --------
    list
        Diagonal Fisher Information values [F₀, F₁, F₂]: the contribution calculated
        for each parameter under the specified protocol
        
    Notes:
    ------
    - Fisher Information quantifies how much information a measurement provides about parameters
    - Higher Fisher Information indicates better parameter estimation capability
    - The function uses modular arithmetic to handle circular indexing of the 3-node network
    - Protocols 0 and 4 use complex multivariate Fisher information formulas derived from 
      quantum network tomography theory
    - Protocols 1, 2, 3, 5 use the simpler univariate formula 1/(θᵢ(1-θᵢ))
    - The function initializes a new Fisher Information list with zeros for each call
    
    Raises:
    -------
    ZeroDivisionError
        If a theta value equals 0 or 1 in a protocol that uses 1/(θ-θ²). Such values
        are not physically possible.
    """
  
    # Start from a fresh list of zeros, one entry per parameter
    fisher_info = [0.0] * len(theta)
    # ----- Multicast Protocol -----
    if protocol == 0:

        theta_0 = theta[root]
        theta_1 = theta[(root + 1) % len(theta)]
        theta_2 = theta[(root + 2) % len(theta)]

        #print(f"Root: {root}, Theta values: {theta_0}, {theta_1}, {theta_2}")

        # Root channel information
        fisher_info[root] += (
        (theta_1 - theta_2)**2 / 
        (theta_0 * theta_1 + theta_2 - theta_0 * theta_2 - theta_1 * theta_2) + 

        (-theta_1 + theta_2)**2 / 
        (theta_1 - theta_0 * theta_1 + theta_0 * theta_2 - theta_1 * theta_2) + 

        (1 - theta_1 - theta_2)**2 / 
        (theta_0 - theta_0 * theta_1 - theta_0 * theta_2 + theta_1 * theta_2) + 

        (-1 + theta_1 + theta_2)**2 / 
        (1 - theta_0 - theta_1 + theta_0 * theta_1 - theta_2 + theta_0 * theta_2 + theta_1 * theta_2)
        )
        #print(f"Fisher info for root {root}: {fisher_info[root]}")

        # Other Channel information
        fisher_info[(root + 1) % len(fisher_info)] += (
        (theta_0 - theta_2)**2 / 
        (theta_0 * theta_1 + theta_2 - theta_0 * theta_2 - theta_1 * theta_2) + 

        (1 - theta_0 - theta_2)**2 / 
        (theta_1 - theta_0 * theta_1 + theta_0 * theta_2 - theta_1 * theta_2) + 

        (-theta_0 + theta_2)**2 / 
        (theta_0 - theta_0 * theta_1 - theta_0 * theta_2 + theta_1 * theta_2) + 

        (-1 + theta_0 + theta_2)**2 / 
        (1 - theta_0 - theta_1 + theta_0 * theta_1 - theta_2 + theta_0 * theta_2 + theta_1 * theta_2)
        )

        fisher_info[(root + 2) % len(fisher_info)] += (
        (1 - theta_0 - theta_1)**2 / 
        (theta_0 * theta_1 + theta_2 - theta_0 * theta_2 - theta_1 * theta_2) + 

        (theta_0 - theta_1)**2 / 
        (theta_1 - theta_0 * theta_1 + theta_0 * theta_2 - theta_1 * theta_2) + 

        (-theta_0 + theta_1)**2 / 
        (theta_0 - theta_0 * theta_1 - theta_0 * theta_2 + theta_1 * theta_2) + 

        (-1 + theta_0 + theta_1)**2 / 
        (1 - theta_0 - theta_1 + theta_0 * theta_1 - theta_2 + theta_0 * theta_2 + theta_1 * theta_2)
        )


    # ----- Root Independent Protocol -----
    # In this protocol, the root parameter does not receive any information. The other parameters receive information based on their own value.
    elif protocol == 1: # Root Independent
        for i in range(len(fisher_info)): # Loop through all parameters
            if i != root: 
                fisher_info[i]+=1/(theta[i]-theta[i]**2)

    # ----- Independent Encoding Protocol X basis -----
    # In this protocol, the root parameter receives information based on its own value, while the other parameters do not receive any information.
    elif protocol == 2: # Independent Encoding X basis
        fisher_info[root]+= 1/(theta[root]-theta[root]**2) # The root parameter receives information based on its own value.

    # ----- Independent Encoding Protocol Z basis -----
    # In this protocol, all parameters receive information based on their own value. This is very similar to the Root Independent Protocol, except that the root parameter is included.
    elif protocol == 3: # Independent Encoding Z basis
        for i in range(len(fisher_info)): # Loop through all parameters
            fisher_info[i]+=1/(theta[i]-theta[i]**2)

    # ----- Back and Forth Protocol Z basis -----
    elif protocol == 4: 
        theta_0 = theta[root]
        theta_1 = theta[(root + 1) % len(theta)]
        theta_2 = theta[(root + 2) % len(theta)]

        # Root channel information
        fisher_info[root] += (
        -((theta_1 - theta_2)**2 / 
          (-theta_0 * theta_2 + theta_1 * (-1 + theta_0 + theta_2)))    

        - (theta_1 - theta_2)**2 / 
          ((-1 + theta_1) * theta_2 + theta_0 * (-theta_1 + theta_2)) 

        + (-1 + theta_1 + theta_2)**2 / 
          (theta_1 * theta_2 - theta_0 * (-1 + theta_1 + theta_2)) 

        + (-1 + theta_1 + theta_2)**2 / 
          ((-1 + theta_1) * (-1 + theta_2) + theta_0 * (-1 + theta_1 + theta_2))
        ) 

        # Other channel information
        fisher_info[(root + 1) % len(fisher_info)] += (
        (theta_0 - theta_2)**2 / 
        (theta_0 * theta_1 + theta_2 - theta_0 * theta_2 - theta_1 * theta_2) + 

        (1 - theta_0 - theta_2)**2 / 
        (theta_1 - theta_0 * theta_1 + theta_0 * theta_2 - theta_1 * theta_2) + 

        (-theta_0 + theta_2)**2 / 
        (theta_0 - theta_0 * theta_1 - theta_0 * theta_2 + theta_1 * theta_2) + 

        (-1 + theta_0 + theta_2)**2 / 
        (1 - theta_0 - theta_1 + theta_0 * theta_1 - theta_2 + theta_0 * theta_2 + theta_1 * theta_2)
        )

        fisher_info[(root + 2) % len(fisher_info)] += (
        (1 - theta_0 - theta_1)**2 / 
        (theta_0 * theta_1 + theta_2 - theta_0 * theta_2 - theta_1 * theta_2) + 
    
        (theta_0 - theta_1)**2 / 
        (theta_1 - theta_0 * theta_1 + theta_0 * theta_2 - theta_1 * theta_2) + 
    
        (-theta_0 + theta_1)**2 / 
        (theta_0 - theta_0 * theta_1 - theta_0 * theta_2 + theta_1 * theta_2) + 
    
        (-1 + theta_0 + theta_1)**2 / 
        (1 - theta_0 - theta_1 + theta_0 * theta_1 - theta_2 + theta_0 * theta_2 + theta_1 * theta_2)
        )

    # ----- Back and Forth Protocol X basis -----
    # In this protocol, the root parameter receives information based on its own value, while the other parameters do not receive any information. This is exactly the same as the Independent Encoding Protocol X basis.
    elif protocol == 5:
        fisher_info[root]+= 1/(theta[root]-theta[root]**2) # The root parameter receives information based on its own value.

    return fisher_info

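The docstring requires every θ to lie strictly between 0 and 1, but the function itself relies on a `ZeroDivisionError` to signal a bad value, and only in the branches that divide by θ-θ². A caller could check the inputs up front instead. The helper below, `validate_theta`, is a hypothetical addition and not part of `main.py`; it is a minimal sketch of such a guard.

```python
def validate_theta(theta):
    """Hypothetical guard: raise ValueError unless every theta is strictly in (0, 1)."""
    for i, value in enumerate(theta):
        if not 0.0 < value < 1.0:
            raise ValueError(f"theta[{i}] = {value} must lie strictly between 0 and 1")
    return list(theta)


print(validate_theta([0.1, 0.2, 0.3]))  # [0.1, 0.2, 0.3]

try:
    validate_theta([0.1, 1.0, 0.3])
except ValueError as err:
    print(f"Rejected: {err}")
```

Raising `ValueError` before any arithmetic gives the same failure for every protocol, rather than only for those that happen to divide by θ-θ².
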
### Testing Update Fisher Information

#### Multicast Protocol

In [61]:
## Testing Multicast Protocol
print("Testing updated_fisher_information function with different protocols...")
print("\n")
print("Testing Multicast Protocol...")
print("----------------------")

# Test 1
# General testing of the multicast protocol
theta = [.1, .2, .3]
root = 0
protocol = 0

fisher_info = updated_fisher_information(root, protocol, theta)

print(f"Test 1: {fisher_info}")  # Output: [2.87307, 3.64343,3.24879]
print("Should be approximately: [2.87307, 3.64343,3.24879]")
print("\n")

# Test 2
# Testing the indexing and modular arithmetic of the multicast protocol
theta = [.2, .1, .1]
root = 0
protocol = 0
fisher_info = updated_fisher_information(root, protocol, theta)
print(f"Test 2: {fisher_info}")  # Output: [4.74932, 6.36823,6.36823]
print("Should be approximately: [4.74932, 6.36823,6.36823]")
print("\n")

# Test 3
# Testing the indexing and modular arithmetic of the multicast protocol with different root value
theta = [.2, .1, .1]
root = 1
protocol = 0
fisher_info = updated_fisher_information(root, protocol, theta)
print(f"Test 3: {fisher_info}")  # Output: [4.74932, 6.36823,6.36823]
print("Should be approximately: [4.74932, 6.36823,6.36823]")
print("\n")

# This has been verified against the original equations: the multicast result does not depend on the root.
# The output is the same for every root value, because the multicast protocol is designed
# to provide the same information regardless of which node is the root.



Testing updated_fisher_information function with different protocols...


Testing Multicast Protocol...
----------------------
Test 1: [2.8730682786948774, 3.643431760055802, 3.2487948539099425]
Should be approximately: [2.87307, 3.64343,3.24879]


Test 2: [4.749321266968327, 6.368225238813475, 6.368225238813475]
Should be approximately: [4.74932, 6.36823,6.36823]


Test 3: [4.749321266968327, 6.368225238813474, 6.368225238813473]
Should be approximately: [4.74932, 6.36823,6.36823]

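One way to see why the multicast output does not depend on the root: in the multicast branch the four denominators sum to 1, and each numerator is the squared partial derivative of its denominator, so every entry has the form Σₖ (∂pₖ/∂θᵢ)² / pₖ. The sketch below rebuilds the entries from that observation alone. Reading the denominators as outcome probabilities is an interpretation of the algebra rather than something the notebook states, and `outcome_probabilities` and `diagonal_fisher` are hypothetical names.

```python
def outcome_probabilities(t0, t1, t2):
    """The four denominators of the multicast branch, in the order they appear there."""
    return [
        t0 * t1 + t2 - t0 * t2 - t1 * t2,
        t1 - t0 * t1 + t0 * t2 - t1 * t2,
        t0 - t0 * t1 - t0 * t2 + t1 * t2,
        1 - t0 - t1 + t0 * t1 - t2 + t0 * t2 + t1 * t2,
    ]


def diagonal_fisher(theta, h=1e-6):
    """Sum over outcomes of (dp/dtheta_i)**2 / p, for each parameter i."""
    probs = outcome_probabilities(*theta)
    info = []
    for i in range(len(theta)):
        up = list(theta)
        down = list(theta)
        up[i] += h
        down[i] -= h
        # Central difference; each probability is linear in each theta,
        # so this is exact up to floating-point rounding.
        derivs = [
            (p_up - p_down) / (2 * h)
            for p_up, p_down in zip(outcome_probabilities(*up), outcome_probabilities(*down))
        ]
        info.append(sum(d * d / p for d, p in zip(derivs, probs)))
    return info


theta = [0.1, 0.2, 0.3]
print(sum(outcome_probabilities(*theta)))   # the four values sum to 1 (up to rounding)
print(diagonal_fisher(theta))               # compare with Test 1 above

# Treating channel 1 as the root only relabels the inputs: rotate theta,
# compute, then rotate the result back to the original channel order.
rotated = theta[1:] + theta[:1]
rotated_info = diagonal_fisher(rotated)
restored = rotated_info[-1:] + rotated_info[:-1]
print(restored)                             # same values as the root-0 result
```

Relabeling the channels only permutes the four probabilities among themselves, which is why the rotated computation lands on the same three values.
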



#### Root Independent Protocol

In [62]:
## Testing Root Independent Protocol
print("Testing Root Independent Protocol...")
print("----------------------")

# Test 1
# General testing of the root independent protocol
theta = [.1, .2, .3]
root = 0
protocol = 1  # Root Independent Protocol

fisher_info = updated_fisher_information(root, protocol, theta)

print(f"Test 1: {fisher_info}") 
print("Should be approximately: [0, 6.25,4.7619]")
print("\n")

# Test 2
# Testing the indexing and modular arithmetic of the root independent protocol
theta=[.1,.2,.3]
root = 1
protocol = 1  # Root Independent Protocol

fisher_info = updated_fisher_information(root, protocol, theta)

print(f"Test 2: {fisher_info}") 
print("Should be approximately: [11.1111, 0,4.7619]")
print("\n")

# Test 3
# Testing changing the root again
theta=[.1,.2,.3]
root = 2
protocol = 1  # Root Independent Protocol
fisher_info = updated_fisher_information(root, protocol, theta)
print(f"Test 3: {fisher_info}")
print("Should be approximately: [11.1111,6.25,0]")
print("\n")


Testing Root Independent Protocol...
----------------------
Test 1: [0.0, 6.25, 4.761904761904762]
Should be approximately: [0, 6.25,4.7619]


Test 2: [11.11111111111111, 0.0, 4.761904761904762]
Should be approximately: [11.1111, 0,4.7619]


Test 3: [11.11111111111111, 6.25, 0.0]
Should be approximately: [11.1111,6.25,0]

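The expected values in these tests all come from the univariate formula 1/(θ-θ²). That expression coincides with the Fisher information of a single Bernoulli trial with probability θ, which is 1/θ + 1/(1-θ). Whether the protocol's measurement is modeled exactly this way is an assumption here, not something the notebook states; the sketch only checks the algebraic identity and reproduces the expected numbers. `bernoulli_fisher` is a hypothetical helper name.

```python
import math


def bernoulli_fisher(theta):
    """Fisher information of one Bernoulli(theta) trial, written as a sum over outcomes."""
    # Outcome probabilities theta and 1 - theta have derivatives +1 and -1.
    return 1.0 / theta + 1.0 / (1.0 - theta)


for theta in (0.1, 0.2, 0.3):
    closed_form = 1.0 / (theta - theta**2)  # expression used in the notebook
    assert math.isclose(bernoulli_fisher(theta), closed_form)
    print(f"theta={theta}: {closed_form:.4f}")
```

The three printed values are the 11.1111, 6.25, and 4.7619 that appear in the expected lists.
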



#### Independent Encoding X basis

In [63]:
## Testing Independent Encoding X basis Protocol
print("Testing Independent Encoding X basis Protocol...")
print("----------------------")
# Test 1
# General testing of the independent encoding X basis protocol
theta = [.1, .2, .3]
root = 0
protocol = 2  # Independent Encoding X basis
fisher_info = updated_fisher_information(root, protocol, theta)
print(f"Test 1: {fisher_info}")
print("Should be approximately: [11.1111,0,0]")
print("\n")

# Test 2
# Testing the indexing and modular arithmetic of the independent encoding X basis protocol
theta = [.1, .2, .3]
root = 1
protocol = 2  # Independent Encoding X basis
fisher_info = updated_fisher_information(root, protocol, theta)
print(f"Test 2: {fisher_info}")
print("Should be approximately: [0,6.25,0]")
print("\n")

# Test 3
# Testing changing the root again
theta = [.1, .2, .3]
root = 2
protocol = 2  # Independent Encoding X basis
fisher_info = updated_fisher_information(root, protocol, theta)
print(f"Test 3: {fisher_info}")
print("Should be approximately: [0,0,4.7619]")
print("\n")



Testing Independent Encoding X basis Protocol...
----------------------
Test 1: [11.11111111111111, 0.0, 0.0]
Should be approximately: [11.1111,0,0]


Test 2: [0.0, 6.25, 0.0]
Should be approximately: [0,6.25,0]


Test 3: [0.0, 0.0, 4.761904761904762]
Should be approximately: [0,0,4.7619]




#### Independent Encoding Z basis

In [64]:
## Testing Independent Encoding Z basis Protocol
print("Testing Independent Encoding Z basis Protocol...")
print("----------------------")

# Test 1
# General testing of the independent encoding Z basis protocol
theta = [.1, .2, .3]
root = 0
protocol = 3  # Independent Encoding Z basis
fisher_info = updated_fisher_information(root, protocol, theta)
print(f"Test 1: {fisher_info}")
print("Should be approximately: [11.1111, 6.25,4.7619]")
print("\n")

# Test 2
# Testing the indexing and modular arithmetic of the independent encoding Z basis protocol
theta = [.1, .2, .3]
root = 1
protocol = 3  # Independent Encoding Z basis
fisher_info = updated_fisher_information(root, protocol, theta)
print(f"Test 2: {fisher_info}")
print("Should be approximately: [11.1111, 6.25,4.7619]")
print("\n")

# Test 3
# Testing changing the root again
theta = [.1, .2, .3]
root = 2
protocol = 3  # Independent Encoding Z basis
fisher_info = updated_fisher_information(root, protocol, theta)
print(f"Test 3: {fisher_info}")
print("Should be approximately: [11.1111, 6.25,4.7619]")
print("\n")


Testing Independent Encoding Z basis Protocol...
----------------------
Test 1: [11.11111111111111, 6.25, 4.761904761904762]
Should be approximately: [11.1111, 6.25,4.7619]


Test 2: [11.11111111111111, 6.25, 4.761904761904762]
Should be approximately: [11.1111, 6.25,4.7619]


Test 3: [11.11111111111111, 6.25, 4.761904761904762]
Should be approximately: [11.1111, 6.25,4.7619]




#### Back and Forth Z basis

In [65]:
## Testing Back and Forth Z basis Protocol
print("Testing Back and Forth Z basis Protocol...")
print("----------------------")
# Test 1
# General testing of the back and forth Z basis protocol
theta = [.1, .2, .3]
root = 0
protocol = 4  # Back and Forth Z basis
fisher_info = updated_fisher_information(root, protocol, theta)
print(f"Test 1: {fisher_info}")
print("Should be approximately: [2.87307, 3.64343,3.24879]")
print("\n")

# Test 2
# Testing the indexing and modular arithmetic of the back and forth Z basis protocol
theta = [.1, .2, .3]
root = 1
protocol = 4  # Back and Forth Z basis
fisher_info = updated_fisher_information(root, protocol, theta)
print(f"Test 2: {fisher_info}")
print("Should be approximately: [2.87307, 3.64343,3.24879]")
print("\n")



Testing Back and Forth Z basis Protocol...
----------------------
Test 1: [2.8730682786948774, 3.643431760055802, 3.2487948539099425]
Should be approximately: [2.87307, 3.64343,3.24879]


Test 2: [2.873068278694877, 3.643431760055801, 3.248794853909944]
Should be approximately: [2.87307, 3.64343,3.24879]

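Test 1 and Test 2 above agree only up to the last few floating-point digits, so comparing the printed lists by eye (or with `==`) is fragile. A tolerance-based check makes the "Should be approximately" lines executable. The helper below is a hypothetical sketch built on `math.isclose`; the sample values are copied from the output above.

```python
import math


def assert_close(actual, expected, rel_tol=1e-5):
    """Hypothetical helper: fail unless the lists agree element-wise within rel_tol."""
    assert len(actual) == len(expected), "length mismatch"
    for i, (a, e) in enumerate(zip(actual, expected)):
        assert math.isclose(a, e, rel_tol=rel_tol), f"index {i}: {a} != {e}"


test_1 = [2.8730682786948774, 3.643431760055802, 3.2487948539099425]
test_2 = [2.873068278694877, 3.643431760055801, 3.248794853909944]
expected = [2.87307, 3.64343, 3.24879]

print(test_1 == test_2)  # False: the lists differ in the last digits
assert_close(test_1, expected)
assert_close(test_2, expected)
print("both runs match the expected values within tolerance")
```

The relative tolerance of `1e-5` is an assumed value chosen to match the six significant digits given in the expected lists.
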



#### Back and Forth X basis

In [66]:
## Back and Forth X basis Protocol
print("Testing Back and Forth X basis Protocol...")
print("----------------------")
# Test 1
# General testing of the back and forth X basis protocol
theta = [.1, .2, .3]
root = 0
protocol = 5  # Back and Forth X basis
fisher_info = updated_fisher_information(root, protocol, theta)
print(f"Test 1: {fisher_info}")
print("Should be approximately: [11.1111,0,0]")
print("\n")

# Test 2
# Testing the indexing and modular arithmetic of the back and forth X basis protocol
theta = [.1, .2, .3]
root = 1
protocol = 5  # Back and Forth X basis
fisher_info = updated_fisher_information(root, protocol, theta)
print(f"Test 2: {fisher_info}")
print("Should be approximately: [0,6.25,0]")
print("\n")

# Test 3
# Testing changing the root again
theta = [.1, .2, .3]
root = 2
protocol = 5  # Back and Forth X basis
fisher_info = updated_fisher_information(root, protocol, theta)
print(f"Test 3: {fisher_info}")
print("Should be approximately: [0,0,4.7619]")
print("\n")

Testing Back and Forth X basis Protocol...
----------------------
Test 1: [11.11111111111111, 0.0, 0.0]
Should be approximately: [11.1111,0,0]


Test 2: [0.0, 6.25, 0.0]
Should be approximately: [0,6.25,0]


Test 3: [0.0, 0.0, 4.761904761904762]
Should be approximately: [0,0,4.7619]




## Reward Function

In [85]:
def reward_function(old_fisher_info, new_fisher_info):
    """
    Calculates the reward from the previous and current Fisher Information values.
    This function computes a reward value based on the Fisher Information matrix diagonal values
    from the previous and current states. This reward function is designed to encourage
    exploration of quantum network parameters by balancing the mean of the new Fisher Information
    against the variance of the combined (old + new) Fisher Information contributions.
    
    Parameters:
    -----------
    old_fisher_info : list
        The Fisher Information matrix diagonal values [F₀, F₁, F₂] from the previous state
    new_fisher_info : list
        The Fisher Information matrix diagonal values [F₀, F₁, F₂] from the current state
    
    Returns:
    --------
    float
        The calculated reward value: mean(new) - p * variance(old + new).
    """

    p = .05  # Example penalty factor on the variance, adjust as needed

    # Mean of the new Fisher information values
    mean_info = sum(new_fisher_info) / len(new_fisher_info)

    # Element-wise addition of the two lists
    combined_info = [old + new for old, new in zip(old_fisher_info, new_fisher_info)]
    total_mean = sum(combined_info) / len(combined_info)

    # Calculate the variance among the combined values of the fisher information matrix diagonal
    variance = sum((x - total_mean) ** 2 for x in combined_info) / len(combined_info)

    # Example reward: mean of the new information minus the variance penalty
    reward = mean_info - p * variance
    
    return reward


### Test Function

In [86]:
old_fisher_info = [0.0, 0.0, 0.0]
new_fisher_info = [1.0, 2.0, 3.0]
reward = reward_function(old_fisher_info, new_fisher_info)
print(f"Reward: {reward}")  # Output: 1.9666666666666666


old_fisher_info = [200.0, 200.0, 205.0]
new_fisher_info = [1.0, 2.0, 3.0]
reward = reward_function(old_fisher_info, new_fisher_info)
print(f"Reward: {reward}")  # Output: 1.5222222222222221

Reward: 1.9666666666666666
Reward: 1.5222222222222221

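The reward combines two standard statistics: the mean of the new values and the population variance of the combined values. The sketch below restates it with the standard-library `statistics` module and exposes the penalty as an argument. `reward_with_penalty` and its `penalty` parameter are hypothetical names, not part of `main.py`; the default of 0.05 mirrors the value of `p` above.

```python
from statistics import fmean, pvariance


def reward_with_penalty(old_fisher_info, new_fisher_info, penalty=0.05):
    """Hypothetical variant: mean of new info minus penalty * variance of old + new."""
    combined = [old + new for old, new in zip(old_fisher_info, new_fisher_info)]
    return fmean(new_fisher_info) - penalty * pvariance(combined)


# Same inputs as the two tests above; the results should agree with the printed rewards.
print(reward_with_penalty([0.0, 0.0, 0.0], [1.0, 2.0, 3.0]))
print(reward_with_penalty([200.0, 200.0, 205.0], [1.0, 2.0, 3.0]))

# With no penalty the reward reduces to the mean of the new values.
print(reward_with_penalty([200.0, 200.0, 205.0], [1.0, 2.0, 3.0], penalty=0.0))
```

In the second test the combined values are [201, 202, 208], whose spread comes almost entirely from the uneven old values, so the penalty lowers the reward from about 1.97 to about 1.52 even though the new information is identical.
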