In [None]:
import json
import pennylane as qml
import pennylane.numpy as np

# Write any helper functions you need here

def GHZ_perfect(n_qubits):
    """
    Reference (ideal) GHZ preparation using textbook Hadamard and CNOT gates.

    Qubit 0 is put into an equal superposition and then used as the control of
    a CNOT onto every other qubit.

    Args:
        - n_qubits (int): Number of qubits in the GHZ state.
    Returns:
        - The state returned by ``qml.state()`` on the executing device.
    """
    qml.Hadamard(wires=0)

    targets = range(1, n_qubits)
    for target in targets:
        qml.CNOT(wires=[0, target])

    return qml.state()

def Hadamard(wire):
    """
    Apply a Hadamard gate to ``wire`` using only native RX/RY rotations.

    RY(pi/2) followed by RX(pi) equals H up to a global phase of -i, which has
    no effect on the resulting density matrix.

    Args:
        - wire (int): The wire to act on.
    """
    half_turn = np.pi
    qml.RY(half_turn / 2, wires=wire)
    qml.RX(half_turn, wires=wire)

def GHZ_circuit(noise_param, n_qubits):
    """
    Prepare an imperfect GHZ state from gates native to a neutral-atom device.

    Every CNOT of the textbook GHZ circuit is replaced by H . CZ . H on the
    target qubit, with each Hadamard synthesised from RX/RY rotations by the
    ``Hadamard`` helper. The imperfect CZ is modelled by a depolarizing channel
    applied to the target qubit immediately after the CZ.

    Args:
        - noise_param (float): Parameter of the depolarizing channel (PennyLane
        convention) that models the noise of each CZ gate on its target qubit.
        - n_qubits (int): The number of qubits in the prepared GHZ state.
    Returns:
        - (np.tensor): A density matrix, as returned by `qml.state`, representing the imperfect GHZ state.
    """
    # Put the control qubit into (|0> + |1>)/sqrt(2), up to a global phase.
    Hadamard(0)

    targets = range(1, n_qubits)
    for target in targets:
        # H - CZ - H on the target acts as CNOT(0 -> target).
        Hadamard(target)
        qml.CZ(wires=[0, target])
        qml.DepolarizingChannel(noise_param, wires=target)  # noisy CZ
        Hadamard(target)

    return qml.state()

def GHZ_fidelity(noise_param, n_qubits):

    """
    Calculates the fidelity between the imperfect GHZ state returned by GHZ_circuit and the ideal GHZ state.

    Both states come from the same ``GHZ_circuit`` QNode on a ``default.mixed``
    device. The ideal state is obtained with ``noise_param = 0``, where every
    depolarizing channel reduces to the identity.

    Args:
        - noise_param (float): Parameter that quantifies the noise in the CZ gate, modelled as a 
        depolarizing channel on the target qubit. noise_param is the parameter of the depolarizing channel
        following the PennyLane convention.
        - n_qubits (int): The number of qubits in the GHZ state.
    Returns:
        - (float): The fidelity between the noisy and ideal GHZ states, as returned
        by ``qml.math.fidelity``.
    """
    dev = qml.device('default.mixed', wires=n_qubits)
    GHZ_QNode = qml.QNode(GHZ_circuit, dev)

    # With noise_param = 0 the depolarizing channel's only non-zero Kraus
    # operator is the identity, so the same native-gate circuit yields the
    # ideal GHZ density matrix. Reusing GHZ_circuit (instead of a separate
    # reference circuit or a hand-built matrix) keeps both states on identical
    # gate sequences, so they differ only by the injected noise.
    perfect = GHZ_QNode(0, n_qubits)
    noisy = GHZ_QNode(noise_param, n_qubits)

    return qml.math.fidelity(perfect, noisy)


# These functions are responsible for testing the solution.

def run(test_case_input: str) -> str:
    """Decode a JSON argument list, pass it to ``GHZ_fidelity`` and stringify the result."""
    args = json.loads(test_case_input)
    return str(GHZ_fidelity(*args))

def check(solution_output: str, expected_output: str) -> None:
    """
    Grade a solution.

    Raises AssertionError when ``GHZ_circuit`` uses a gate outside the native
    set (RX, RY, CZ, DepolarizingChannel), or when the decoded solution value
    is not within rtol=1e-4 of the expected value. Returns None otherwise.
    """
    have = json.loads(solution_output)
    want = json.loads(expected_output)

    # Execute the circuit once so its tape can be inspected for forbidden gates.
    circuit = qml.QNode(GHZ_circuit, qml.device('default.mixed', wires=4))
    circuit(0.05, 3)

    native_ops = (qml.RX, qml.RY, qml.CZ, qml.DepolarizingChannel)
    for op in circuit.tape.operations:
        assert isinstance(op, native_ops), "You are using forbidden gates!"

    assert np.isclose(have, want, rtol = 1e-4)


# These are the public test cases
test_cases = [
    ('[0.05, 3]', '0.9027779255467782'),
    ('[0.01, 5]', '0.9606614879634601')
]

# This will run the public test cases locally.
# `check` reports failure by raising AssertionError (it always returns None),
# so the verdict comes from whether it raises, not from its return value.
if __name__ == "__main__":
    for i, (input_, expected_output) in enumerate(test_cases):
        print(f"Running test case {i} with input '{input_}'...")

        try:
            output = run(input_)
        except Exception as exc:
            print(f"Runtime Error. {exc}")
            continue

        try:
            check(output, expected_output)
        except (AssertionError, json.JSONDecodeError) as err:
            # JSONDecodeError: the solution output is not a parsable number.
            reason = f" {err}" if str(err) else ""
            print(f"Wrong Answer. Have: '{output}'. Want: '{expected_output}'.{reason}")
        else:
            print("Correct!")

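## Circuit design

The Quirk sketch below compares the two circuits. The top three qubits run the ideal GHZ circuit (Hadamard + CNOTs). The bottom three run a native-gate version built from RX/RY rotations and CZ gates. Density-matrix displays show both final states: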
https://algassert.com/quirk#circuit={%22cols%22:[[%22H%22,1,1,%22%E2%80%A6%22,%22%E2%80%A6%22,%22%E2%80%A6%22],[%22%E2%80%A2%22,%22X%22],[1,%22%E2%80%A2%22,%22X%22],[1,1,1,%22~4l1a%22,%22~4l1a%22,%22~4l1a%22],[1,1,1,%22~7uh8%22],[1,1,1,%22%E2%80%A2%22,1,%22Z%22],[1,1,1,%22%E2%80%A2%22,%22Z%22],[1,1,1,1,%22~hoja%22,%22~hoja%22],[%22Density3%22,1,1,%22Density3%22]],%22gates%22:[{%22id%22:%22~7uh8%22,%22name%22:%22RX90%22,%22matrix%22:%22{{%E2%88%9A%C2%BD,-%E2%88%9A%C2%BDi},{-%E2%88%9A%C2%BDi,%E2%88%9A%C2%BD}}%22},{%22id%22:%22~4l1a%22,%22name%22:%22RY90%22,%22matrix%22:%22{{%E2%88%9A%C2%BD,-%E2%88%9A%C2%BD},{%E2%88%9A%C2%BD,%E2%88%9A%C2%BD}}%22},{%22id%22:%22~hoja%22,%22name%22:%22RY270%22,%22matrix%22:%22{{-%E2%88%9A%C2%BD,-%E2%88%9A%C2%BD},{%E2%88%9A%C2%BD,-%E2%88%9A%C2%BD}}%22},{%22id%22:%22~ou3s%22,%22name%22:%22RX270%22,%22matrix%22:%22{{-%E2%88%9A%C2%BD,-%E2%88%9A%C2%BDi},{-%E2%88%9A%C2%BDi,-%E2%88%9A%C2%BD}}%22}]}