In [9]:
from qiskit import QuantumCircuit, Aer, execute, transpile
import threading
from qiskit.visualization import plot_histogram
from datetime import datetime
import matplotlib.pyplot as plt

def identify_search_pairs(solutions, n):
    pairs = [
        # Segments of 2 qubits each
        (0, 1, [solution[0:2] for solution in solutions]),
        (2, 3, [solution[2:4] for solution in solutions]),
        (4, 5, [solution[4:6] for solution in solutions]),
        (6, 7, [solution[6:8] for solution in solutions]),
        (8, 9, [solution[8:10] for solution in solutions]),
        (10, 11, [solution[10:12] for solution in solutions]),
        (12, 13, [solution[12:14] for solution in solutions]),
        (14, 15, [solution[14:16] for solution in solutions]),
        (16, 17, [solution[16:18] for solution in solutions]),
        (18, 19, [solution[18:20] for solution in solutions])
    ]
    return pairs

def setup_circuit(n, ancillas):
    circuit = QuantumCircuit(n + len(ancillas), n)
    for ancilla in ancillas:
        circuit.x(ancilla)
        circuit.h(ancilla)
    return circuit

def partial_hadamard_transform(circuit, qubits):
    for qubit in qubits:
        circuit.h(qubit)

def partial_oracle(circuit, qubits, ancilla, target_state):
    for qubit, bit in zip(qubits, target_state):
        if bit == '0':
            circuit.x(qubit)
    circuit.mcx(qubits, ancilla)
    for qubit, bit in zip(qubits, target_state):
        if bit == '0':
            circuit.x(qubit)

def partial_diffuser(circuit, qubits):
    for qubit in qubits:
        circuit.h(qubit)
        circuit.x(qubit)
    circuit.h(qubits[-1])
    circuit.mcx(qubits[:-1], qubits[-1])
    circuit.h(qubits[-1])
    for qubit in qubits:
        circuit.x(qubit)
        circuit.h(qubit)

def apply_grover_segment(circuit, start_qubit, end_qubit, solution, ancilla):
    time1 = datetime.now()
    for step in range(start_qubit, end_qubit, 2):
        qubits_to_superpose = list(range(step, step + 2))
        partial_hadamard_transform(circuit, qubits_to_superpose)
        partial_oracle(circuit, qubits_to_superpose, ancilla, solution[step:step+2])
        partial_diffuser(circuit, qubits_to_superpose)
    time2 = datetime.now()
    print(time2-time1)

def segmented_grover_forward(circuit, pairs, solution, ancilla, previous_solution):
    time1 = datetime.now()
    for start, end, target_states in pairs:
        if solution[start:end+1] != previous_solution[start:end+1]:
            target_state = solution[start:end+1]
            qubits_to_reset = list(range(start, end + 1))
            partial_hadamard_transform(circuit, qubits_to_reset)
            partial_oracle(circuit, qubits_to_reset, ancilla, target_state)
            partial_diffuser(circuit, qubits_to_reset)
    time2 = datetime.now()
    print(time2-time1)

def segmented_grover_backward(circuit, pairs, solution, ancilla, previous_solution):
    time1 = datetime.now()
    for start, end, target_states in pairs[::-1]:
        if solution[start:end+1] != previous_solution[start:end+1]:
            target_state = solution[start:end+1]
            qubits_to_reset = list(range(start, end + 1))
            partial_hadamard_transform(circuit, qubits_to_reset)
            partial_oracle(circuit, qubits_to_reset, ancilla, target_state)
            partial_diffuser(circuit, qubits_to_reset)
    time2 = datetime.now()
    print(time2-time1)

def search_and_reset(n, solutions, pairs):
    results = {}
    backend = Aer.get_backend('aer_simulator')
    first_solution = True
    
    ancillas = [n, n + 1, n + 2, n + 3, n + 4, n + 5, n + 6, n + 7, n + 8, n + 9]
    
    pairs_segment_1 = pairs[0:1]
    pairs_segment_2 = pairs[1:2]
    pairs_segment_3 = pairs[2:3]
    pairs_segment_4 = pairs[3:4]
    pairs_segment_5 = pairs[4:5]
    pairs_segment_6 = pairs[5:6]
    pairs_segment_7 = pairs[6:7]
    pairs_segment_8 = pairs[7:8]
    pairs_segment_9 = pairs[8:9]
    pairs_segment_10 = pairs[9:10]

    previous_solution = None

    for i, solution in enumerate(solutions):
        circuit = setup_circuit(n, ancillas)  # Reset the entire circuit for each solution
        
        if first_solution:
            threads = [
                threading.Thread(target=apply_grover_segment, args=(circuit, 0, 2, solution, ancillas[0])),
                threading.Thread(target=apply_grover_segment, args=(circuit, 2, 4, solution, ancillas[1])),
                threading.Thread(target=apply_grover_segment, args=(circuit, 4, 6, solution, ancillas[2])),
                threading.Thread(target=apply_grover_segment, args=(circuit, 6, 8, solution, ancillas[3])),
                threading.Thread(target=apply_grover_segment, args=(circuit, 8, 10, solution, ancillas[4])),
                threading.Thread(target=apply_grover_segment, args=(circuit, 10, 12, solution, ancillas[5])),
                threading.Thread(target=apply_grover_segment, args=(circuit, 12, 14, solution, ancillas[6])),
                threading.Thread(target=apply_grover_segment, args=(circuit, 14, 16, solution, ancillas[7])),
                threading.Thread(target=apply_grover_segment, args=(circuit, 16, 18, solution, ancillas[8])),
                threading.Thread(target=apply_grover_segment, args=(circuit, 18, 20, solution, ancillas[9]))
            ]

            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()
            
            print("\n")
            
            first_solution = False
        else:
            previous_solution = solutions[i - 1]
            threads = [
                threading.Thread(target=segmented_grover_forward, args=(circuit, pairs_segment_1, solution, ancillas[0], previous_solution)),
                threading.Thread(target=segmented_grover_backward, args=(circuit, pairs_segment_2, solution, ancillas[1], previous_solution)),
                threading.Thread(target=segmented_grover_forward, args=(circuit, pairs_segment_3, solution, ancillas[2], previous_solution)),
                threading.Thread(target=segmented_grover_backward, args=(circuit, pairs_segment_4, solution, ancillas[3], previous_solution)),
                threading.Thread(target=segmented_grover_forward, args=(circuit, pairs_segment_5, solution, ancillas[4], previous_solution)),
                threading.Thread(target=segmented_grover_backward, args=(circuit, pairs_segment_6, solution, ancillas[5], previous_solution)),
                threading.Thread(target=segmented_grover_forward, args=(circuit, pairs_segment_7, solution, ancillas[6], previous_solution)),
                threading.Thread(target=segmented_grover_backward, args=(circuit, pairs_segment_8, solution, ancillas[7], previous_solution)),
                threading.Thread(target=segmented_grover_forward, args=(circuit, pairs_segment_9, solution, ancillas[8], previous_solution)),
                threading.Thread(target=segmented_grover_backward, args=(circuit, pairs_segment_10, solution, ancillas[9], previous_solution))
            ]

            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()
            
            print("\n")
            
            # Adjust common qubits to match the previous solution
            for start, end, target_states in pairs:
                if solution[start:end+1] == previous_solution[start:end+1]:
                    for j in range(start, end + 1):
                        if solution[j] == '1':
                            circuit.x(j)

        #print(circuit)
        
        circuit.measure(list(range(n)), list(range(n)))
        transpiled_circuit = transpile(circuit, backend)
        result = execute(transpiled_circuit, backend, shots=1024).result()
        counts = result.get_counts()
        results[solution] = counts

        # Remove the measurement instructions to continue using the same circuit for the next iteration
        circuit.data = circuit.data[:-n]
    
    return results

n = 20
#solutions = ['00000000000000000000', '10101010101010101010', '11111111111111111111']
solutions = ['00000000000000000000', '00110011001100110011', '11111111111111111111']
pairs = identify_search_pairs(solutions, n)
result_counts = search_and_reset(n, solutions, pairs)
print(pairs)

for solution, counts in result_counts.items():
    print(f"Results for {solution}: {counts}")
    plot_histogram(counts)
    plt.show()


0:00:00.0004980:00:00.000485

0:00:00.000484
0:00:00.000481
0:00:00.000431
0:00:00.000416
0:00:00.000540
0:00:00.000411
0:00:00.000406
0:00:00.000446


0:00:00.000005
0:00:00.000407
0:00:00.000002
0:00:00.000384
0:00:00.000003
0:00:00.000366
0:00:00.000002
0:00:00.000450
0:00:00.000002
0:00:00.000375


0:00:00.000740
0:00:00.000005
0:00:00.000655
0:00:00.000006
0:00:00.000676
0:00:00.000004
0:00:00.000612
0:00:00.000004
0:00:00.000648
0:00:00.000008


[(0, 1, ['00', '00', '11']), (2, 3, ['00', '11', '11']), (4, 5, ['00', '00', '11']), (6, 7, ['00', '11', '11']), (8, 9, ['00', '00', '11']), (10, 11, ['00', '11', '11']), (12, 13, ['00', '00', '11']), (14, 15, ['00', '11', '11']), (16, 17, ['00', '00', '11']), (18, 19, ['00', '11', '11'])]
Results for 00000000000000000000: {'00000000000000000000': 1024}
Results for 00110011001100110011: {'11001100110011001100': 1024}
Results for 11111111111111111111: {'11111111111111111111': 1024}
