# Qubit Mover 1

## Import Libraries

In [2]:
# Regular Python Libraries
import numpy as np

# Quantum Libraries
import qutip as qt
from qutip import Qobj

# Qutip quantum circuit libraries
from qutip_qip.circuit import QubitCircuit
from qutip_qip.operations import (Gate,cnot,rx,ry,rz,snot,phasegate)


## Create Helper Functions

### Initialize Qubit, Gate, and Error Objects

#### Initialize Qubit Objects

In [4]:
# Quantum Gates
# ===========================
# Two-qubit controlled-NOT, taken from qutip_qip.
CNOT = cnot().copy()

# Single-qubit gates are spelled out as explicit matrices (instead of being
# pulled from a library) so that their global phase is exactly the textbook one.
I = Qobj([[1, 0], [0, 1]])
X = Qobj([[0, 1], [1, 0]])
Y = Qobj([[0, -1j], [1j, 0]])
Z = Qobj([[1, 0], [0, -1]])
H = Qobj([[1, 1], [1, -1]]) / np.sqrt(2)
S = Qobj([[1, 0], [0, 1j]])
S_DAG = S.dag().copy()  # Hermitian conjugate of S

# Quantum States
# ===========================
# Projectors onto the eigenstates of the three Pauli bases, all built from
# the two computational-basis kets below.
_KET_0 = qt.basis(2, 0)
_KET_1 = qt.basis(2, 1)

# Z basis: |0><0| and |1><1|
RHO_ZERO = _KET_0.proj().copy()
RHO_ONE = _KET_1.proj().copy()

# X basis: |+><+| and |-><-|
RHO_PLUS = (_KET_0 + _KET_1).unit().proj().copy()
RHO_MINUS = (_KET_0 - _KET_1).unit().proj().copy()

# Y basis: |R><R| and |L><L|
RHO_R = (_KET_0 + 1j * _KET_1).unit().proj().copy()
RHO_L = (_KET_0 - 1j * _KET_1).unit().proj().copy()

def get_all_gates() -> dict[str, Qobj]:
    """Return independent copies of every gate available to the circuit builder agent.

    RETURNS:
        dict[str, Qobj]: Maps each gate name to a ``.copy()`` of the corresponding
        module-level gate, so callers can modify the returned operators without
        touching the shared constants. Keys: 'I', 'X', 'Y', 'Z', 'H', 'S',
        'S_DAG', 'CNOT'.
    """
    return {
        'I': I.copy(),
        'X': X.copy(),
        'Y': Y.copy(),
        'Z': Z.copy(),
        'H': H.copy(),
        'S': S.copy(),
        'S_DAG': S_DAG.copy(),
        'CNOT': CNOT.copy(),
    }

def get_all_basis_states() -> dict[str, Qobj]:
    """Return independent copies of every basis-state density matrix for the agent.

    RETURNS:
        dict[str, Qobj]: Maps each state name to a ``.copy()`` of the corresponding
        module-level projector (Z, X and Y basis eigenstates), so callers can
        modify the returned matrices without touching the shared constants.
        Keys: 'RHO_ZERO', 'RHO_ONE', 'RHO_PLUS', 'RHO_MINUS', 'RHO_R', 'RHO_L'.
    """
    return {
        'RHO_ZERO': RHO_ZERO.copy(),
        'RHO_ONE': RHO_ONE.copy(),
        'RHO_PLUS': RHO_PLUS.copy(),
        'RHO_MINUS': RHO_MINUS.copy(),
        'RHO_R': RHO_R.copy(),
        'RHO_L': RHO_L.copy(),
    }


def bit_flip_error(rho: Qobj, p: float) -> Qobj:
    """
        Apply a bit flip (Pauli-X) channel with probability p to a density matrix.

        The result is the mixed state  p * X rho X + (1 - p) * rho.

        PARAMETERS:
        -----------
        rho: Qobj
            The input density matrix to apply the bit flip error to.
        p: float
            The probability of a bit flip error occurring. Physically 0 <= p <= 1,
            but no range check is done: finite-difference callers intentionally
            pass p +/- h, which may step just outside that interval.

        RETURNS:
        --------
        Qobj
            The resulting density matrix (a new object; rho is not modified).
    """
    # Use the module-level Pauli-X directly. Fetching it via get_all_gates()
    # copied all eight gate operators on every call just to read one of them.
    return p * X * rho * X + (1 - p) * rho

### Calculate the QFIM 

Since we only deal with diagonal density matrices (classical states) whose eigenbasis does not depend on the parameters, the QFIM reduces to the classical Fisher information of the eigenvalue distribution:
$$F_{ij} = \sum_k \frac{1}{\lambda_k}\frac{\partial \lambda_k}{\partial \theta_i}\frac{\partial \lambda_k}{\partial \theta_j}$$

where $\lambda_k$ are the eigenvalues of the density matrix (the sum runs only over $\lambda_k > 0$) and $\theta_i$ are the parameters being estimated. In our case, the parameters are the bit flip probabilities at each node.

Since this formula requires derivatives, we approximate them with central finite differences. This means that, in addition to the density matrix representing the state, we also need 6 additional density matrices: one positive and one negative perturbation for each of the 3 parameters.

$$ \frac{\partial \lambda_k}{\partial \theta_i} \approx \frac{\lambda_k(\theta + h\,e_i) - \lambda_k(\theta - h\,e_i)}{2h} $$

Here $e_i$ is the unit vector along parameter $\theta_i$, so the extra states needed are $\rho(\theta \pm h\,e_i)$ for $i = 1, 2, 3$.

#### Create Perturbed Density Matrix Object

In [None]:
class PQOBJ:
    """
    PQOBJ
    ------------------

    A class used to represent a perturbed density matrix object. This class holds a central density matrix and
    six perturbed density matrices corresponding to small perturbations in three parameters (positive and negative).

    The perturbed copies are evolved alongside the central state by ``apply_error``. The copy
    associated with parameter i receives that channel's probability shifted by +h or -h, so
    central finite differences with step h can be taken afterwards.
    """

    def __init__(self, central: Qobj, perturbed: "list[Qobj] | None" = None, h=1e-6):
        """
        Initializes the PQOBJ with a central density matrix and a list of six perturbed density matrices.

        PARAMETERS:
        -----------
        central: Qobj
            The central density matrix.
        perturbed: list[Qobj] | None
            A list of six perturbed density matrices in the order:
            [rho(theta1 + h), rho(theta1 - h), rho(theta2 + h), rho(theta2 - h), rho(theta3 + h), rho(theta3 - h)]
            If None or empty (the default), all six start out equal to ``central``.
        h: float
            The perturbation size (default is 1e-6).

        RAISES:
        -------
        ValueError
            If ``perturbed`` is non-empty but does not contain exactly six entries.
        """
        self._h = h
        self._base = central  # Central density matrix
        if not perturbed:
            # No perturbed states supplied: every copy starts at the central state.
            # The +/- h shifts are introduced later by apply_error.
            perturbed = [central] * 6
        elif len(perturbed) != 6:
            raise ValueError(
                f"perturbed must contain exactly 6 density matrices, got {len(perturbed)}"
            )
        (self._1_plus, self._1_minus,
         self._2_plus, self._2_minus,
         self._3_plus, self._3_minus) = perturbed

    # Define Getters
    # ----------------------------
    @property
    def h(self) -> float:
        """ Returns the perturbation size. """
        return self._h

    @property
    def base(self) -> Qobj:
        """ Returns the central density matrix. """
        return self._base

    @property
    def one_plus(self) -> Qobj:
        """ Returns the density matrix perturbed positively in parameter 1. """
        return self._1_plus

    @property
    def one_minus(self) -> Qobj:
        """ Returns the density matrix perturbed negatively in parameter 1. """
        return self._1_minus

    @property
    def two_plus(self) -> Qobj:
        """ Returns the density matrix perturbed positively in parameter 2. """
        return self._2_plus

    @property
    def two_minus(self) -> Qobj:
        """ Returns the density matrix perturbed negatively in parameter 2. """
        return self._2_minus

    @property
    def three_plus(self) -> Qobj:
        """ Returns the density matrix perturbed positively in parameter 3. """
        return self._3_plus

    @property
    def three_minus(self) -> Qobj:
        """ Returns the density matrix perturbed negatively in parameter 3. """
        return self._3_minus

    # Define Setters
    # ----------------------------
    @h.setter
    def h(self, value: float):
        self._h = value

    @base.setter
    def base(self, value: Qobj):
        self._base = value

    @one_plus.setter
    def one_plus(self, value: Qobj):
        self._1_plus = value

    @one_minus.setter
    def one_minus(self, value: Qobj):
        self._1_minus = value

    @two_plus.setter
    def two_plus(self, value: Qobj):
        self._2_plus = value

    @two_minus.setter
    def two_minus(self, value: Qobj):
        self._2_minus = value

    @three_plus.setter
    def three_plus(self, value: Qobj):
        self._3_plus = value

    @three_minus.setter
    def three_minus(self, value: Qobj):
        self._3_minus = value

    # Apply error channel
    # ----------------------------
    def apply_error(self, params: list[float], channel: int):
        """
        apply_error
        ------------------
        Applies the bit flip error channel number ``channel`` to all density matrices in the PQOBJ.

        The central state and every perturbed copy receive probability p = params[channel - 1].
        The exception is the two copies belonging to parameter ``channel``, which receive
        p + h and p - h respectively.

        PARAMETERS:
        -----------
        params: list[float]
            The error probability for each channel (index 0 -> channel 1, etc.).
        channel: int
            Which channel to apply: 1, 2, or 3.

        RAISES:
        -------
        ValueError
            If ``channel`` is not 1, 2, or 3. Without this check, channel 0 would silently
            read params[-1], and the perturbed copies would not be updated at all.
        """
        if channel not in (1, 2, 3):
            raise ValueError(f"channel must be 1, 2, or 3, got {channel!r}")

        p = params[channel - 1]  # Probability of error for the specified channel

        # Only the parameter being applied gets a finite-difference shift.
        shift_1 = self._h if channel == 1 else 0.0
        shift_2 = self._h if channel == 2 else 0.0
        shift_3 = self._h if channel == 3 else 0.0

        self._base = bit_flip_error(self._base, p)  # Unshifted error for the central state

        self._1_plus = bit_flip_error(self._1_plus, p + shift_1)
        self._1_minus = bit_flip_error(self._1_minus, p - shift_1)
        self._2_plus = bit_flip_error(self._2_plus, p + shift_2)
        self._2_minus = bit_flip_error(self._2_minus, p - shift_2)
        self._3_plus = bit_flip_error(self._3_plus, p + shift_3)
        self._3_minus = bit_flip_error(self._3_minus, p - shift_3)
        




    

## Create RLA and Training Environment

## Benchmark RLA Against Fixed Protocols (QFIM Calculation)

## Create Estimator Algorithm

## Benchmark RLA Against Fixed Protocols (Estimator Accuracy)