In [12]:
from qiskit import QuantumCircuit, Aer, execute, transpile
import threading
from qiskit.visualization import plot_histogram
from datetime import datetime
import matplotlib.pyplot as plt

def identify_search_pairs(solutions, n):
    """Cut every solution into the fixed pair/single-qubit layout.

    The layout is hard-coded for 20-bit solutions: six segments, each made
    of two spans (a two-qubit pair and/or a single qubit).

    Args:
        solutions: Sequence of bit strings, each at least 20 characters long.
        n: Accepted for interface compatibility; not used by the layout.

    Returns:
        A list of ``(start, end, pieces)`` tuples where ``start``/``end`` are
        inclusive qubit indices and ``pieces`` holds that span of every
        solution, in the order the solutions were given.

    Raises:
        IndexError: If a solution is too short for a single-qubit span.
    """
    # Inclusive (first, last) qubit spans, two per segment.
    layout = (
        (0, 1), (2, 2),        # first segment
        (3, 3), (4, 5),        # second segment
        (6, 7), (8, 9),        # third segment
        (10, 11), (12, 13),    # fourth segment
        (14, 15), (16, 16),    # fifth segment
        (17, 17), (18, 19),    # sixth segment
    )

    def cut(bits, first, last):
        # Single-qubit spans are indexed (so a short input raises), wider
        # spans are sliced.
        if first == last:
            return bits[first]
        return bits[first:last + 1]

    return [
        (first, last, [cut(bits, first, last) for bits in solutions])
        for first, last in layout
    ]

def setup_circuit(n, ancillas):
    """Create a fresh circuit with prepared ancilla qubits.

    Args:
        n: Number of data qubits; also the number of classical bits.
        ancillas: Qubit indices to prepare as ancillas.

    Returns:
        A ``QuantumCircuit`` with ``n + len(ancillas)`` qubits and ``n``
        classical bits, with X followed by H applied to every ancilla.
    """
    total_qubits = n + len(ancillas)
    prepared = QuantumCircuit(total_qubits, n)
    for ancilla_index in ancillas:
        # X then H takes |0> to |->, so a controlled-X targeting this qubit
        # shows up as a phase flip on the controls.
        prepared.x(ancilla_index)
        prepared.h(ancilla_index)
    return prepared

def partial_hadamard_transform(circuit, qubits):
    """Append one Hadamard gate per entry of ``qubits``, in the given order.

    Args:
        circuit: Circuit object exposing an ``h(qubit)`` method.
        qubits: Iterable of qubit indices to put through H.
    """
    for target in qubits:
        circuit.h(target)

def partial_oracle(circuit, qubits, ancilla, target_state):
    """Append an oracle marking ``target_state`` on ``qubits``.

    Qubits whose target bit is '0' are wrapped in X gates so that the
    multi-controlled X onto ``ancilla`` fires exactly for ``target_state``.
    Nothing is appended when fewer than two qubits are given.

    Args:
        circuit: Circuit object exposing ``x`` and ``mcx``.
        qubits: Control qubit indices.
        ancilla: Target qubit index of the multi-controlled X.
        target_state: Bit string aligned position-by-position with ``qubits``.
    """
    if len(qubits) <= 1:
        return

    zero_positions = [
        qubit for qubit, bit in zip(qubits, target_state) if bit == '0'
    ]

    for qubit in zero_positions:
        circuit.x(qubit)
    circuit.mcx(qubits, ancilla)
    # Undo the X wrapping so the data qubits are left in their own basis.
    for qubit in zero_positions:
        circuit.x(qubit)

def partial_diffuser(circuit, qubits):
    """Append the Grover diffusion operator restricted to ``qubits``.

    Nothing is appended when fewer than two qubits are given.

    Args:
        circuit: Circuit object exposing ``h``, ``x`` and ``mcx``.
        qubits: Sequence of qubit indices the diffuser acts on.
    """
    if len(qubits) <= 1:
        return

    controls = qubits[:-1]
    target = qubits[-1]

    for qubit in qubits:
        circuit.h(qubit)
        circuit.x(qubit)

    # H - MCX - H on the last qubit acts as a multi-controlled Z.
    circuit.h(target)
    circuit.mcx(controls, target)
    circuit.h(target)

    for qubit in qubits:
        circuit.x(qubit)
        circuit.h(qubit)

def handle_odd_qubit(circuit, qubit, solution_bit):
    """Drive a lone qubit towards ``solution_bit`` using H, optional Z, H.

    H-Z-H is equivalent to X, so a '1' bit flips the qubit; for any other
    value the two Hadamards cancel.

    Args:
        circuit: Circuit object exposing ``h`` and ``z``.
        qubit: Index of the qubit to act on.
        solution_bit: Target bit for this qubit; only '1' adds the Z gate.
    """
    wants_one = solution_bit == '1'
    circuit.h(qubit)
    if wants_one:
        circuit.z(qubit)
    circuit.h(qubit)

def apply_grover_segment(circuit, start_qubit, end_qubit, solution, ancilla):
    """Append a left-to-right Grover search over one segment of qubits.

    Walks ``[start_qubit, end_qubit)`` two qubits at a time. Each full pair
    gets Hadamards, the oracle for its two solution bits and the diffuser; a
    trailing unpaired qubit is handled by ``handle_odd_qubit``.

    Args:
        circuit: Circuit the gates are appended to.
        start_qubit: First qubit index of the segment (inclusive).
        end_qubit: End of the paired range (exclusive).
        solution: Bit string indexed by qubit number.
        ancilla: Ancilla qubit index used by the oracle.
    """
    for step in range(start_qubit, end_qubit, 2):
        if step + 1 < end_qubit:  # Pair of qubits
            pair = [step, step + 1]
            partial_hadamard_transform(circuit, pair)
            partial_oracle(circuit, pair, ancilla, solution[step:step + 2])
            partial_diffuser(circuit, pair)
        else:  # Single qubit
            handle_odd_qubit(circuit, step, solution[step])
        if end_qubit == 2:
            # Special case kept from the original: when the range ends at 2,
            # qubit 2 itself is also handled here. It must be driven by its
            # own solution bit; the previous code read solution[start_qubit],
            # which set qubit 2 from the wrong position.
            handle_odd_qubit(circuit, end_qubit, solution[end_qubit])

def apply_grover_segment_1(circuit, start_qubit, end_qubit, solution, ancilla):
    """Run ``apply_grover_segment`` and print how long the call took.

    All arguments are forwarded unchanged to ``apply_grover_segment``; the
    elapsed wall-clock time is printed as a ``timedelta``.
    """
    started_at = datetime.now()
    apply_grover_segment(circuit, start_qubit, end_qubit, solution, ancilla)
    elapsed = datetime.now() - started_at
    print(elapsed)

def apply_grover_segment_2(circuit, start_qubit, end_qubit, solution, ancilla):
    """Append a right-to-left Grover search over one segment and time it.

    Starts at the pair ending just before ``end_qubit`` and steps down two
    qubits at a time. Once a step falls below ``start_qubit`` the remaining
    qubit at ``start_qubit`` is handled on its own. The elapsed wall-clock
    time is printed as a ``timedelta``.

    Args:
        circuit: Circuit the gates are appended to.
        start_qubit: First qubit index of the segment (inclusive).
        end_qubit: End of the segment (exclusive).
        solution: Bit string indexed by qubit number.
        ancilla: Ancilla qubit index used by the oracle.
    """
    started_at = datetime.now()
    for step in range(end_qubit - 2, start_qubit - 2, -2):
        if step < start_qubit:
            # The pair would start before the segment: only the leftmost
            # qubit of the segment is left over.
            handle_odd_qubit(circuit, start_qubit, solution[start_qubit])
            continue
        pair = [step, step + 1]
        partial_hadamard_transform(circuit, pair)
        partial_oracle(circuit, pair, ancilla, solution[step:step + 2])
        partial_diffuser(circuit, pair)
    elapsed = datetime.now() - started_at
    print(elapsed)

def segmented_grover_forward(circuit, pairs, solution, ancilla, previous_solution):
    """Re-search, in list order, only the spans whose bits changed.

    For every ``(start, end, _)`` span in ``pairs`` the bits of ``solution``
    are compared with ``previous_solution``. Unchanged spans are skipped;
    changed single qubits go through ``handle_odd_qubit`` and changed wider
    spans get Hadamards, oracle and diffuser. The elapsed wall-clock time is
    printed as a ``timedelta``.

    Args:
        circuit: Circuit the gates are appended to.
        pairs: Iterable of ``(start, end, target_states)`` with inclusive
            indices; the third element is not used here.
        solution: Bit string being searched for.
        ancilla: Ancilla qubit index used by the oracle.
        previous_solution: Bit string of the preceding search.
    """
    started_at = datetime.now()
    for first, last, _ in pairs:
        if first == last:
            if solution[first] != previous_solution[first]:
                handle_odd_qubit(circuit, first, solution[first])
            continue

        wanted = solution[first:last + 1]
        if wanted == previous_solution[first:last + 1]:
            continue

        span = list(range(first, last + 1))
        partial_hadamard_transform(circuit, span)
        partial_oracle(circuit, span, ancilla, wanted)
        partial_diffuser(circuit, span)
    elapsed = datetime.now() - started_at
    print(elapsed)

def segmented_grover_backward(circuit, pairs, solution, ancilla, previous_solution):
    """Re-search, in reverse list order, only the spans whose bits changed.

    Same per-span treatment as the forward variant, but ``pairs`` is walked
    from its last entry to its first. The elapsed wall-clock time is printed
    as a ``timedelta``.

    Args:
        circuit: Circuit the gates are appended to.
        pairs: Sequence of ``(start, end, target_states)`` with inclusive
            indices; the third element is not used here.
        solution: Bit string being searched for.
        ancilla: Ancilla qubit index used by the oracle.
        previous_solution: Bit string of the preceding search.
    """
    started_at = datetime.now()
    for first, last, _ in pairs[::-1]:
        if first == last:
            if solution[first] != previous_solution[first]:
                handle_odd_qubit(circuit, first, solution[first])
            continue

        wanted = solution[first:last + 1]
        if wanted == previous_solution[first:last + 1]:
            continue

        span = list(range(first, last + 1))
        partial_hadamard_transform(circuit, span)
        partial_oracle(circuit, span, ancilla, wanted)
        partial_diffuser(circuit, span)
    elapsed = datetime.now() - started_at
    print(elapsed)

def _run_threads(threads):
    """Start every thread in order, then wait for all of them to finish."""
    for worker in threads:
        worker.start()
    for worker in threads:
        worker.join()


def search_and_reset(n, solutions, pairs):
    """Build, run and measure one circuit per solution.

    The first solution is searched from scratch with six per-segment
    builders. Every later solution is compared with its predecessor: spans
    that changed are re-searched by the segmented forward/backward builders,
    and spans that did not change get plain X gates on their '1' bits. Each
    circuit is measured on the ``aer_simulator`` backend with 1024 shots.

    The segment boundaries are hard-coded for 20 data qubits and six
    ancillas placed at indices ``n .. n + 5``.

    Args:
        n: Number of data qubits (and classical bits).
        solutions: Indexable sequence of bit strings, one per search.
        pairs: Twelve ``(start, end, target_states)`` spans as produced by
            ``identify_search_pairs``; consecutive pairs of entries form one
            segment.

    Returns:
        Dict mapping each solution string to its measurement counts. A
        solution that appears more than once keeps only its last result.
    """
    backend = Aer.get_backend('aer_simulator')
    ancillas = [n + offset for offset in range(6)]

    # (builder, start qubit, end qubit) used for the very first solution.
    initial_jobs = [
        (apply_grover_segment_1, 0, 2),
        (apply_grover_segment_2, 3, 6),
        (apply_grover_segment_1, 6, 10),
        (apply_grover_segment_2, 10, 14),
        (apply_grover_segment_1, 14, 17),
        (apply_grover_segment_2, 17, 20),
    ]
    # (builder, spans of that segment) used for every later solution.
    update_jobs = [
        (segmented_grover_forward, pairs[0:2]),
        (segmented_grover_backward, pairs[2:4]),
        (segmented_grover_forward, pairs[4:6]),
        (segmented_grover_backward, pairs[6:8]),
        (segmented_grover_forward, pairs[8:10]),
        (segmented_grover_backward, pairs[10:12]),
    ]

    results = {}
    for i, solution in enumerate(solutions):
        # A brand-new circuit per solution, so nothing has to be undone
        # after measuring.
        circuit = setup_circuit(n, ancillas)

        # NOTE(review): all six threads append to the same circuit object.
        # They touch disjoint qubits, but whether concurrent appends are safe
        # depends on QuantumCircuit internals - confirm.
        if i == 0:
            threads = [
                threading.Thread(
                    target=build,
                    args=(circuit, start, end, solution, ancilla),
                )
                for (build, start, end), ancilla in zip(initial_jobs, ancillas)
            ]
        else:
            previous_solution = solutions[i - 1]
            threads = [
                threading.Thread(
                    target=build,
                    args=(circuit, segment_pairs, solution, ancilla,
                          previous_solution),
                )
                for (build, segment_pairs), ancilla in zip(update_jobs, ancillas)
            ]

        _run_threads(threads)
        print("\n")

        if i > 0:
            # Spans identical to the previous solution were skipped by the
            # builders; set their '1' bits directly on the fresh circuit.
            for start, end, _ in pairs:
                if solution[start:end + 1] == previous_solution[start:end + 1]:
                    for j in range(start, end + 1):
                        if solution[j] == '1':
                            circuit.x(j)

        data_qubits = list(range(n))
        circuit.measure(data_qubits, data_qubits)
        transpiled_circuit = transpile(circuit, backend)
        result = execute(transpiled_circuit, backend, shots=1024).result()
        results[solution] = result.get_counts()

    return results

# Driver: search three 20-bit targets in sequence and plot each histogram.
n, segment_size = 20, 3
# Alternative input set:
# solutions = ['00000000000000000000', '10101010101010101010', '11111111111111111111']
solutions = [
    '00000000000000000000',
    '00101100111100001011',
    '11111111111111111111',
]
pairs = identify_search_pairs(solutions, segment_size)
result_counts = search_and_reset(n, solutions, pairs)
print(pairs)

for solution in result_counts:
    counts = result_counts[solution]
    print(f"Results for {solution}: {counts}")
    plot_histogram(counts)
    plt.show()


0:00:00.000555
0:00:00.000603
0:00:00.000893
0:00:00.000829
0:00:00.000474
0:00:00.000425


0:00:00.0001090:00:00.000784

0:00:00.000686
0:00:00.000716
0:00:00.000094
0:00:00.000665


0:00:00.000759
0:00:00.000098
0:00:00.000674
0:00:00.000633
0:00:00.000615
0:00:00.000074


[(0, 1, ['00', '00', '11']), (2, 2, ['0', '1', '1']), (3, 3, ['0', '0', '1']), (4, 5, ['00', '11', '11']), (6, 7, ['00', '00', '11']), (8, 9, ['00', '11', '11']), (10, 11, ['00', '11', '11']), (12, 13, ['00', '00', '11']), (14, 15, ['00', '00', '11']), (16, 16, ['0', '1', '1']), (17, 17, ['0', '0', '1']), (18, 19, ['00', '11', '11'])]
Results for 00000000000000000000: {'00000000000000000000': 1024}
Results for 00101100111100001011: {'11010000111100110100': 1024}
Results for 11111111111111111111: {'11111111111111111111': 1024}
