# Reinforcement Learning Functions

In this notebook the step, reset, and reward functions will be defined and tested. Every function created in this notebook will be in `main.py` and the documentation can be seen there.

## Update Fisher Information 

This function will update the fisher information of each parameter depending on the root choice and protocol choice.

| Protocol | Numeric Value |
|----------|----------|
| Multicast   | 0  |
| Root Independent    | 1  |
| Independent Encoding X basis   | 2  |
| Independent Encoding Z basis   | 3  |
| Back and Forth Z basis    | 4  |
| Back and Forth X basis   | 5  |

In [5]:
def updated_fisher_information(root,protocol,fisher_info,theta):

    # ----- Multi cast Protocol -----
    if protocol == 0:
        theta_0 = theta[root]
        theta_1 = theta[(root + 1) % len(theta)]
        theta_2 = theta[(root + 2) % len(theta)]

        # Root channel information
        fisher_info[root] += (
        (theta_1 - theta_2)**2 / 
        (theta_0 * theta_1 + theta_2 - theta_0 * theta_2 - theta_1 * theta_2) + 

        (-theta_1 + theta_2)**2 / 
        (theta_1 - theta_0 * theta_1 + theta_0 * theta_2 - theta_1 * theta_2) + 

        (1 - theta_1 - theta_2)**2 / 
        (theta_0 - theta_0 * theta_1 - theta_0 * theta_2 + theta_1 * theta_2) + 

        (-1 + theta_1 + theta_2)**2 / 
        (1 - theta_0 - theta_1 + theta_0 * theta_1 - theta_2 + theta_0 * theta_2 + theta_1 * theta_2)
        )

        # Other Channel information
        fisher_info[(root + 1) % len(fisher_info)] += (
        (theta_0 - theta_2)**2 / 
        (theta_0 * theta_1 + theta_2 - theta_0 * theta_2 - theta_1 * theta_2) + 

        (1 - theta_0 - theta_2)**2 / 
        (theta_1 - theta_0 * theta_1 + theta_0 * theta_2 - theta_1 * theta_2) + 

        (-theta_0 + theta_2)**2 / 
        (theta_0 - theta_0 * theta_1 - theta_0 * theta_2 + theta_1 * theta_2) + 

        (-1 + theta_0 + theta_2)**2 / 
        (1 - theta_0 - theta_1 + theta_0 * theta_1 - theta_2 + theta_0 * theta_2 + theta_1 * theta_2)
        )

        fisher_info[(root + 2) % len(fisher_info)] += (
        (1 - theta_0 - theta_1)**2 / 
        (theta_0 * theta_1 + theta_2 - theta_0 * theta_2 - theta_1 * theta_2) + 

        (theta_0 - theta_1)**2 / 
        (theta_1 - theta_0 * theta_1 + theta_0 * theta_2 - theta_1 * theta_2) + 

        (-theta_0 + theta_1)**2 / 
        (theta_0 - theta_0 * theta_1 - theta_0 * theta_2 + theta_1 * theta_2) + 

        (-1 + theta_0 + theta_1)**2 / 
        (1 - theta_0 - theta_1 + theta_0 * theta_1 - theta_2 + theta_0 * theta_2 + theta_1 * theta_2)
        )


    # ----- Root Independent Protocol -----
    # In this protocol, the parameter nearest to the root does not receive any information. The other parameters receive information based on their own value.
    elif protocol == 1: # Root Independent
        for i in range(len(fisher_info)): # Loop through all parameters
            if i != root: 
                fisher_info[i]+=1/(theta[i]-theta[i]**2)

    # ----- Independent Encoding Protocol X basis -----
    # In this protocol, the root parameter receives information based on its own value, while the other parameters dont receive an information.
    elif protocol == 2: # Independent Encoding X basis
        fisher_info[root]+= 1/(theta[root]-theta[root]**2) # The root parameter receives information based on its own value.

    # ----- Independent Encoding Protocol Z basis -----
    # In this protocol, all parameters receive information based on their own value. Very similarly to the Root Independent Protocol, but without the root parameter.
    elif protocol == 3: # Independent Encoding Z basis
        for i in range(len(fisher_info)): # Loop through all parameters
            fisher_info[i]+=1/(theta[i]-theta[i]**2)

    # ----- Back and Forth Protocol Z basis -----
    elif protocol == 4: 
        theta_0 = theta[root]
        theta_1 = theta[(root + 1) % len(theta)]
        theta_2 = theta[(root + 2) % len(theta)]

        # Root channel information
        fisher_info[root] += (
        -((theta_1 - theta_2)**2 / 
          (-theta_0 * theta_2 + theta_1 * (-1 + theta_0 + theta_2))) 

        - (theta_1 - theta_2)**2 / 
          ((-1 + theta_1) * theta_2 + theta_0 * (-theta_1 + theta_2)) 

        + (-1 + theta_1 + theta_2)**2 / 
          (theta_1 * theta_2 - theta_0 * (-1 + theta_1 + theta_2)) 

        + (-1 + theta_1 + theta_2)**2 / 
          ((-1 + theta_1) * (-1 + theta_2) + theta_0 * (-1 + theta_1 + theta_2))
        ) 

        # Other channel information
        fisher_info[(root + 1) % len(fisher_info)] += (
        (theta_0 - theta_2)**2 / 
        (theta_0 * theta_1 + theta_2 - theta_0 * theta_2 - theta_1 * theta_2) + 

        (1 - theta_0 - theta_2)**2 / 
        (theta_1 - theta_0 * theta_1 + theta_0 * theta_2 - theta_1 * theta_2) + 

        (-theta_0 + theta_2)**2 / 
        (theta_0 - theta_0 * theta_1 - theta_0 * theta_2 + theta_1 * theta_2) + 

        (-1 + theta_0 + theta_2)**2 / 
        (1 - theta_0 - theta_1 + theta_0 * theta_1 - theta_2 + theta_0 * theta_2 + theta_1 * theta_2)
        )

        fisher_info[(root + 2) % len(fisher_info)] += (
        (1 - theta_0 - theta_1)**2 / 
        (theta_0 * theta_1 + theta_2 - theta_0 * theta_2 - theta_1 * theta_2) + 
    
        (theta_0 - theta_1)**2 / 
        (theta_1 - theta_0 * theta_1 + theta_0 * theta_2 - theta_1 * theta_2) + 
    
        (-theta_0 + theta_1)**2 / 
        (theta_0 - theta_0 * theta_1 - theta_0 * theta_2 + theta_1 * theta_2) + 
    
        (-1 + theta_0 + theta_1)**2 / 
        (1 - theta_0 - theta_1 + theta_0 * theta_1 - theta_2 + theta_0 * theta_2 + theta_1 * theta_2)
        )

    # ----- Back and Forth Protocol X basis -----
    # In this protocol, the root parameter receives information based on its own value, while the other parameters dont receive an information. This is exactly the same as the Independent Encoding Protocol X basis.
    elif protocol == 5:
        fisher_info[root]+= 1/(theta[root]-theta[root]**2) # The root parameter receives information based on its own value.

    return fisher_info

### Testing Update Fisher IInformation

In [18]:
## Testing Multicast Protocol
print("Testing updated_fisher_information function with different protocols...")
print("\n")
print("Testing Multicast Protocol...")
# Test 1
theta = [1, 2, 3]
fisher_info = [0, 0, 0]
root = 0
protocol = 0

fisher_info = updated_fisher_information(root, protocol, fisher_info, theta)

print(f"Test 1: {fisher_info}")  # Output: [121/12, -1/2,-1/6]
print("Should be approximately: [10.083333333333334, -0.5, -0.16666666666666666]")

# Test 2
theta = [1, 2, 3]
fisher_info = [0, 0, 0]
root = 1
protocol = 0

fisher_info = updated_fisher_information(root, protocol, fisher_info, theta)

print(f"Test 2: {fisher_info}")  # Output: [121/12, -1/2,-1/6]
print("Should be approximately: [10.083333333333334, -0.5, -0.16666666666666666]")

print("\n")
print("Testing Root Independent Protocol...")
# Test 3
theta = [1, 2, 3]
fisher_info = [0, 0, 0]
root = 0
protocol = 1
fisher_info = updated_fisher_information(root, protocol, fisher_info, theta)
print(f"Test 3: {fisher_info}")  # Output: [0, -1/2, -1/6]
print("Should be approximately: [0.0, -0.5, -0.16666666666666666]")

# Test 4
theta = [1, 2, 3]
fisher_info = [0, 0, 0]
root = 1
protocol = 1
fisher_info = updated_fisher_information(root, protocol, fisher_info, theta)
print(f"Test 4: {fisher_info}")  # Output: [0, -1/2, -1/6]
print("Should be approximately: [0.0, -0.5, -0.16666666666666666]")


Testing updated_fisher_information function with different protocols...


Testing Multicast Protocol...
Test 1: [10.083333333333334, -0.5, -0.16666666666666663]
Should be approximately: [10.083333333333334, -0.5, -0.16666666666666666]
Test 2: [10.083333333333334, -0.5, -0.16666666666666663]
Should be approximately: [10.083333333333334, -0.5, -0.16666666666666666]


Testing Root Independent Protocol...
Test 3: [0, -0.5, -0.16666666666666666]
Should be approximately: [0.0, -0.5, -0.16666666666666666]


ZeroDivisionError: division by zero