In [None]:
## Working version with 3 unknown cells (quantum computing: Grover search)

from qiskit import QuantumCircuit, QuantumRegister
from qiskit.quantum_info import Statevector
import numpy as np

# Known cells (fixed values): cell index -> value in 0..3.
# Laid out four per line so each line is one row of the 4x4 grid
# (cells are numbered row-major, 0..15).
known_cells = dict(enumerate([
    2, 0, 1, 3,
    1, 3, 2, 0,
    0, 1, 3, 2,
    3,
]))
# Cells whose values the Grover search has to find (2 qubits each).
unknown_cells = [13, 14, 15]  # Three unknown cells

# Reduced RC groups for constraints: only columns 1, 2 and 3 of the 4x4 grid
# (cells numbered row-major). Each of these columns contains exactly one of
# the unknown cells 13, 14, 15.
rcs = [[col + 4 * row for row in range(4)] for col in (1, 2, 3)]

# Full 4x4 constraint set (rows, columns, 2x2 boxes), kept for reference:
#rcs = [
#    [0,1,2,3], [4,5,6,7], [8,9,10,11], [12,13,14,15],
#    [0,4,8,12], [1,5,9,13], [2,6,10,14], [3,7,11,15],
#    [0,1,4,5], [2,3,6,7], [8,9,12,13], [10,11,14,15]
#]

# Two qubits per unknown cell (values 0..3): 6 qubits for 3 unknown cells.
num_vars = 2 * len(unknown_cells)
var_qubits = QuantumRegister(num_vars, 'v')
# One conflict-flag ancilla per constraint group in rcs (3 in total here).
rc_ancillas = QuantumRegister(len(rcs), 'anc_rc')
# Scratch pool for the pairwise equality checks. mark_conflict_rc uses
# indices 0 (pair-equality flag), 1 and 2 (helpers); index 3 is spare.
temp_ancillas = QuantumRegister(4, 'anc_temp')
# Oracle target; prepared in |-> later so its flip becomes a phase.
output = QuantumRegister(1, 'out')

# var_qubits is passed first, so its qubits occupy circuit indices
# 0..num_vars-1.
qc = QuantumCircuit(var_qubits, rc_ancillas, temp_ancillas, output)

def cell_to_qubits(cell_idx):
    """Return the pair of var_qubits indices that encode an unknown cell.

    Unknown cells are packed two qubits each, in the order they appear in
    ``unknown_cells``. The pair is listed in the same [high, low] bit order
    that ``known_cell_value_qubits`` uses for known cells. Raises ValueError
    (from ``list.index``) if ``cell_idx`` is not an unknown cell.
    """
    first = 2 * unknown_cells.index(cell_idx)
    return [first, first + 1]

def known_cell_value_qubits(cell_idx):
    """Return a known cell's fixed value as two classical bits [high, low].

    The entries are plain Python ints (0 or 1), not qubits. Raises KeyError
    if ``cell_idx`` is not a key of ``known_cells``.
    """
    high, low = divmod(known_cells[cell_idx], 2)
    return [high & 1, low]

def mark_equal_2qubits(qc, var_qubits, cellA, cellB, target, anc1, anc2):
    """
    Toggle ``target`` iff cellA and cellB hold the same 2-bit value.

    Each cell is either unknown or known:

    * Unknown cells keep their two bits on ``var_qubits``, located via
      ``cell_to_qubits``.
    * Known cells have their two bits given as Python ints by
      ``known_cell_value_qubits``.

    For each bit position whose equality depends on qubits, ``a == b`` is
    XOR-ed into a helper (``anc1`` for bit 0, ``anc2`` for bit 1). The
    helpers then control the flip of ``target`` and are uncomputed.

    Contract relied on by callers that call this twice to uncompute:

    * ``anc1`` and ``anc2`` must be |0> on entry and are |0> again on exit.
    * The net action is ``target ^= (value(cellA) == value(cellB))``, so two
      identical consecutive calls cancel.

    Args:
        qc: circuit to append gates to.
        var_qubits: register holding the unknown cells' qubits.
        cellA, cellB: cell indices (keys of known_cells or members of
            unknown_cells).
        target: qubit toggled when the two cells are equal.
        anc1, anc2: clean helper qubits.
    """
    def get_bits(cell):
        # Qubits for unknown cells, fixed 0/1 ints for known cells.
        if cell in unknown_cells:
            return [var_qubits[i] for i in cell_to_qubits(cell)]
        return known_cell_value_qubits(cell)

    def is_classical(bit):
        return isinstance(bit, int)

    def toggle_bit_equality(bit_a, bit_b, anc):
        # anc ^= (bit_a == bit_b), i.e. anc ^= bit_a XOR bit_b XOR 1.
        # Quantum operands are XOR-ed in with CX. Classical operands fold
        # into a constant that decides whether the final X is needed.
        # Every gate is an X-type operation targeting anc, so replaying the
        # same sequence undoes it.
        constant = 1
        for bit in (bit_a, bit_b):
            if is_classical(bit):
                constant ^= bit
            else:
                qc.cx(bit, anc)
        if constant:
            qc.x(anc)

    bit_pairs = list(zip(get_bits(cellA), get_bits(cellB)))

    # A classically known mismatch in any bit means the cells are never
    # equal, so leave target alone. Checking this before touching any
    # helper guarantees nothing is left dirty.
    if any(is_classical(a) and is_classical(b) and a != b for a, b in bit_pairs):
        return

    # Bit positions that still depend on qubits, each with its own helper.
    quantum_positions = [
        (a, b, anc)
        for (a, b), anc in zip(bit_pairs, (anc1, anc2))
        if not (is_classical(a) and is_classical(b))
    ]

    # Compute the per-bit equality flags.
    for a, b, anc in quantum_positions:
        toggle_bit_equality(a, b, anc)

    controls = [anc for _, _, anc in quantum_positions]
    if not controls:
        qc.x(target)  # both cells fully known and equal
    elif len(controls) == 1:
        qc.cx(controls[0], target)
    else:
        qc.ccx(controls[0], controls[1], target)

    # Uncompute the helpers so they return to |0>.
    for a, b, anc in reversed(quantum_positions):
        toggle_bit_equality(a, b, anc)

def mark_conflict_rc(qc, var_qubits, rc_cells, ancilla, ancillas_pool):
    """
    Toggle ``ancilla`` once for every pair of equal cells in ``rc_cells``.

    For each unordered pair of cells:

    1. ``mark_equal_2qubits`` computes the pair's equality into
       ``ancillas_pool[0]``.
    2. A CX copies that equality onto ``ancilla``.
    3. A second identical ``mark_equal_2qubits`` call uncomputes it.

    ``ancillas_pool[1]`` and ``ancillas_pool[2]`` serve as that function's
    helpers.

    NOTE: the accumulation is XOR, so ``ancilla`` ends up flagging an *odd*
    number of equal pairs. That matches "the group has a conflict" only
    when at most one pair can be equal. One example is a group with a single
    unknown cell whose known cells are pairwise distinct. Groups with
    several unknown cells would need an OR over per-pair flags instead.

    Assumes mark_equal_2qubits leaves its helpers clean and is its own
    inverse, so the pool qubits return to |0> without any reset.

    Args:
        qc: circuit to append gates to.
        var_qubits: register holding the unknown cells' qubits.
        rc_cells: cell indices of one constraint group.
        ancilla: conflict flag for this group.
        ancillas_pool: at least three qubits that are |0> on entry.
    """
    temp_anc = ancillas_pool[0]
    helper_a, helper_b = ancillas_pool[1], ancillas_pool[2]

    pairs = [
        (cell_a, cell_b)
        for i, cell_a in enumerate(rc_cells)
        for cell_b in rc_cells[i + 1:]
    ]

    for cA, cB in pairs:
        # No qc.reset(temp_anc) here. Reset is non-unitary, and under
        # Statevector simulation it collapses the superposition Grover relies
        # on. The compute/uncompute pair already returns temp_anc to |0>.
        mark_equal_2qubits(qc, var_qubits, cA, cB, temp_anc, helper_a, helper_b)
        qc.cx(temp_anc, ancilla)
        mark_equal_2qubits(qc, var_qubits, cA, cB, temp_anc, helper_a, helper_b)

def apply_oracle(qc, var_qubits, output, rc_ancillas, temp_ancillas, rcs):
    """
    Phase-mark assignments of ``var_qubits`` that satisfy every group in ``rcs``.

    1. For each group i, ``mark_conflict_rc`` toggles ``rc_ancillas[i]``
       when it detects a conflict in that group.
    2. ``output[0]`` is toggled when *all* conflict flags are 0. With
       ``output`` prepared in |->, this toggle acts as a -1 phase on valid
       assignments.
    3. Step 1 is replayed in reverse to uncompute the conflict flags. The
       ancillas then return to |0> and are no longer entangled with
       ``var_qubits`` when the diffuser runs. Without this step, Grover
       interference is destroyed.

    Assumes ``rc_ancillas`` and ``temp_ancillas`` are |0> on entry. This
    holds on the first call and is preserved by step 3. It also assumes
    ``mark_conflict_rc`` is its own inverse. Resets are deliberately not
    used: they are non-unitary and would collapse the search state.
    """
    def toggle_conflict_flags(order):
        # Each call XORs group i's conflict indicator into rc_ancillas[i].
        for i in order:
            mark_conflict_rc(qc, var_qubits, rcs[i], rc_ancillas[i], temp_ancillas)

    group_order = range(len(rcs))
    toggle_conflict_flags(group_order)

    # Flip output iff every conflict flag is 0. The X conjugation turns
    # mcx's all-ones control condition into an all-zeros condition.
    qc.x(rc_ancillas)
    qc.mcx(rc_ancillas, output[0])
    qc.x(rc_ancillas)

    toggle_conflict_flags(reversed(group_order))

def diffuser(nqubits):
    """
    Build the Grover diffusion operator on ``nqubits`` qubits.

    The circuit applies a phase flip of |0...0> conjugated by Hadamards,
    i.e. a reflection about the uniform superposition (up to global phase).
    The |0...0> flip is an X-conjugated multi-controlled Z. The Z is written
    as H-MCX-H on the last qubit.
    """
    circuit = QuantumCircuit(nqubits)
    all_qubits = range(nqubits)
    target = nqubits - 1
    controls = list(range(nqubits - 1))

    # Basis change: H layer, then X layer.
    for layer in (circuit.h, circuit.x):
        layer(all_qubits)

    # Multi-controlled Z on |1...1>.
    circuit.h(target)
    circuit.mcx(controls, target)
    circuit.h(target)

    # Undo the basis change in reverse order.
    for layer in (circuit.x, circuit.h):
        layer(all_qubits)
    return circuit

# Initialize a uniform superposition over all variable-qubit assignments
qc.h(var_qubits)

# Prepare the output qubit in |-> = (|0>-|1>)/sqrt(2), so the oracle's flip
# of it kicks back as a -1 phase. X followed by H does this with unitary
# gates only; Initialize would insert a reset instruction.
qc.x(output)
qc.h(output)

# Apply Grover iterations.
# N counts every basis state of the variable register (4 values per unknown
# cell). The pi/4 * sqrt(N) count assumes exactly one satisfying assignment.
# With M solutions, use pi/4 * sqrt(N / M) instead.
N = 2 ** len(var_qubits)
num_iterations = int(np.floor(np.pi / 4 * np.sqrt(N)))  # 6 for N = 64
for _ in range(num_iterations):
    apply_oracle(qc, var_qubits, output, rc_ancillas, temp_ancillas, rcs)
    qc.append(diffuser(len(var_qubits)), var_qubits[:])

# Final statevector and probabilities
final_state = Statevector.from_instruction(qc)
# Look up the circuit-wide positions of the variable qubits instead of
# hard-coding [0..5], so the marginal stays correct if the registers change.
var_indices = [qc.qubits.index(q) for q in var_qubits]
probs = final_state.probabilities(qargs=var_indices)

# --- Plotting ---
# NOTE(review): `plt` is not imported in this cell. It presumably comes from
# an earlier `import matplotlib.pyplot as plt`; confirm before running this
# standalone.
num_var_qubits = len(var_qubits)
# Probabilities are little-endian in qargs: in each label the rightmost
# character is var qubit 0 and the leftmost is var qubit num_var_qubits - 1.
labels = [format(i, f'0{num_var_qubits}b') for i in range(2**num_var_qubits)]
plt.bar(labels, probs)
plt.xlabel('Binary State (var_qubits)')
plt.ylabel('Probability')
plt.title('Measurement Probabilities (No Classical Bits)')
plt.xticks(rotation=90)
plt.tight_layout()
plt.show()