In [None]:
# Variational Recovery Map for 5 qubits [ONLY VAR PART]

# beta[0] -> RX(red); beta[1] -> RZ(blue); beta[2] -> RZZ(black)
# L = num repetition -> 3 specified in paper

def decode_variational_5(beta, L = 3, wires=5):
    
    rzz_pairs = [
        (0, 1), (1, 2), (2, 3), (3, 4),
        (0, 2), (2, 4),
        (1, 3), 
        (0, 4), 
        (0, 3),
        (1, 4)
    ]

    for i in range(wires):
        qml.RZ(beta[1], wires=i)

    for l in range(L):
        for i in range(wires):
            qml.RX(beta[0], wires=i)
            qml.RZ(beta[1], wires=i)        

        for (i, j) in rzz_pairs:
            qml.IsingZZ(beta[2], wires=[i, j])

    for i in range(wires):
        qml.RX(beta[0], wires=i)
        qml.RZ(beta[1], wires=i)           

        
        
# Variational Recovery Map for 3 qubits [ONLY VAR PART]

# beta[0] -> RX(red); beta[1] -> RZ(blue); beta[2] -> CZ(black)
# L = num repetition -> 2 specified in paper

def decode_variational_3(beta, L = 2, wires=3):
    
    for i in range(wires):
        qml.RZ(beta[1], wires=i)

    for l in range(L):
        for i in range(wires):
            qml.RX(beta[0], wires=i)
            qml.RZ(beta[1], wires=i)        
    
        qml.CZ(beta[2], wires=[0, 1])
        qml.CZ(beta[2], wires=[1, 2])
        qml.CZ(beta[2], wires=[0, 2])
        

    for i in range(wires):
        qml.RX(beta[0], wires=i)
        qml.RZ(beta[1], wires=i)

In [None]:
## - - WORK IN PROGRESS -- ##

In [67]:
import pennylane as qml
import numpy as np

# --- DEVICES ---
dev_analytic = qml.device("default.qubit", wires=9, shots=None)  # For qml.state()
dev_sampling = qml.device("default.mixed", wires=9, shots=1)  # Changed to None shots for deterministic sampling

# --- ENCODING CIRCUIT ---
def encoding_circuit(logical_bit=0):
    if logical_bit == 1:
        qml.PauliX(wires=0)

    qml.Hadamard(0)
    qml.CNOT(wires=[0, 1])
    qml.CNOT(wires=[0, 2])
    qml.CNOT(wires=[0, 3])
    qml.CNOT(wires=[0, 4])

    qml.Hadamard(1)
    qml.CNOT(wires=[1, 2])
    qml.CNOT(wires=[1, 3])
    qml.CNOT(wires=[1, 4])

    qml.Hadamard(2)
    qml.CNOT(wires=[2, 3])
    qml.CNOT(wires=[2, 4])

    qml.Hadamard(3)
    qml.CNOT(wires=[3, 4])

# --- NOISE LAYER ---
def apply_noise_layer(gamma=0.1, noise_type="bit_flip"):
    for wire in range(5):
        if noise_type == "amplitude_damping":
            qml.AmplitudeDamping(gamma, wires=wire)
        elif noise_type == "bit_flip":
            qml.BitFlip(gamma, wires=wire)
        elif noise_type == "depolarizing":
            qml.DepolarizingChannel(gamma, wires=wire)

# --- Syndrome ---
@qml.qnode(dev_sampling)
def syndrome_only(logical_bit=0, gamma=0.1):
    encoding_circuit(logical_bit)
    apply_noise_layer(gamma, "bit_flip")

    # Stabilizer S1 = X Z Z X I
    qml.Hadamard(5)
    qml.Hadamard(0); qml.Hadamard(3)
    for i in [0, 1, 2, 3]:
        qml.CNOT(wires=[i, 5])
    qml.Hadamard(0); qml.Hadamard(3)

    # Stabilizer S2 = I X Z Z X
    qml.Hadamard(6)
    qml.Hadamard(1); qml.Hadamard(4)
    for i in [1, 2, 3, 4]:
        qml.CNOT(wires=[i, 6])
    qml.Hadamard(1); qml.Hadamard(4)

    # Stabilizer S3 = X I X Z Z
    qml.Hadamard(7)
    qml.Hadamard(0); qml.Hadamard(2)
    for i in [0, 2, 3, 4]:
        qml.CNOT(wires=[i, 7])
    qml.Hadamard(0); qml.Hadamard(2)

    # Stabilizer S4 = Z X I X Z
    qml.Hadamard(8)
    qml.Hadamard(1); qml.Hadamard(3)
    for i in [0, 1, 3, 4]:
        qml.CNOT(wires=[i, 8])
    qml.Hadamard(1); qml.Hadamard(3)

    return qml.sample(wires=[5, 6, 7, 8])

@qml.qnode(dev_sampling)
def noisy_rho_only(logical_bit=0, gamma=0.1):
    encoding_circuit(logical_bit)
    apply_noise_layer(gamma, "bit_flip")
    return qml.density_matrix(wires=[0,1,2,3,4])


# --- SYNDROME DECODER ---
def decode_and_correct(syndrome):
    syndrome_map = {
        (0, 0, 0, 0): None,
        (0, 0, 0, 1): (0, 'X'),
        (1, 0, 0, 0): (1, 'X'),
        (1, 1, 0, 0): (2, 'X'),
        (0, 1, 1, 0): (3, 'X'),
        (0, 0, 1, 1): (4, 'X'),

        (1, 0, 1, 0): (0, 'Z'),
        (0, 1, 0, 1): (1, 'Z'),
        (0, 0, 1, 0): (2, 'Z'),
        (1, 0, 0, 1): (3, 'Z'),
        (0, 1, 1, 1): (4, 'Z'),

        (1, 0, 1, 1): (0, 'Y'),
        (1, 1, 0, 1): (1, 'Y'),
        (1, 1, 1, 0): (2, 'Y'),
        (1, 1, 1, 1): (3, 'Y'),
        (0, 1, 0, 0): (4, 'Y'),
    }

    key = tuple(syndrome)
    if key not in syndrome_map or syndrome_map[key] is None:
        return []
    qubit, op = syndrome_map[key]
    if op == 'X':
        return [qml.PauliX(wires=qubit)]
    elif op == 'Z':
        return [qml.PauliZ(wires=qubit)]
    elif op == 'Y':
        return [qml.PauliY(wires=qubit)]
    return []


# Test
if __name__ == "__main__":
    gamma = 0.1
    logical_bit = 0

    # Step 1: Measure syndrome (same noise path)
    syndrome_bits = list(syndrome_only(logical_bit, gamma))
    print("Syndrome bits:", syndrome_bits)

    # Step 2: Get corresponding correction ops
    correction_ops = decode_and_correct(syndrome_bits)
    print("Correction operations:", correction_ops)

    # Step 3: Get noisy state from same encoding + noise path
    noisy_state = noisy_rho_only(logical_bit, gamma)

    # Step 4: Apply correction
    @qml.qnode(dev_sampling)
    def apply_correction(state, correction_ops):
        qml.QubitDensityMatrix(state, wires=[0,1,2,3,4])
        for op in correction_ops:
            op.queue()
        return qml.density_matrix(wires=[0,1,2,3,4])

    corrected_rho = apply_correction(noisy_state, correction_ops)

    # Step 5: Compare to ideal encoded state
    ideal_state = get_encoded_state(logical_bit=logical_bit)[:32]
    ideal_rho = np.outer(ideal_state, np.conj(ideal_state))
    fidelity_val = qml.math.fidelity(ideal_rho, corrected_rho)

    print(f"Fidelity after recovery: {fidelity_val:.6f}")


Syndrome bits: [0, 1, 0, 1]
Correction operations: [Z(1)]
Fidelity after recovery: 0.002593
