In [3]:
from qiskit import QuantumCircuit, Aer, execute, transpile
import threading
from qiskit.visualization import plot_histogram
from datetime import datetime
import matplotlib.pyplot as plt

def identify_search_pairs(solutions):
    """Find the adjacent qubit pairs on which the solutions disagree.

    The bit strings are split into consecutive two-character segments
    (0-1, 2-3, ...). A segment is reported when at least one solution
    differs from the first solution on it. A trailing unpaired bit (odd
    length) is never reported.

    Args:
        solutions: Sequence of equal-length bit strings.

    Returns:
        A list of ``(i, i + 1, segments)`` tuples, where ``segments`` holds
        the two-character slice ``solution[i:i+2]`` of every solution, in
        input order. Empty when ``solutions`` is empty.
    """
    # Nothing to compare: avoid indexing solutions[0] on an empty input.
    if not solutions:
        return []

    num_qubits = len(solutions[0])
    pairs = []

    # Stopping at num_qubits - 1 guarantees that i + 1 is a valid index.
    for i in range(0, num_qubits - 1, 2):
        segments = [solution[i:i + 2] for solution in solutions]
        reference = segments[0]
        if any(segment != reference for segment in segments):
            pairs.append((i, i + 1, segments))

    return pairs

def setup_circuit(n, ancilla, ancilla2):
    """Create a fresh circuit with both ancilla qubits prepared.

    The circuit has ``n + 2`` qubits and ``n`` classical bits. Each ancilla
    receives an X followed by an H.

    Args:
        n: Number of data qubits (and classical bits).
        ancilla: Index of the first ancilla qubit.
        ancilla2: Index of the second ancilla qubit.

    Returns:
        The newly built ``QuantumCircuit``.
    """
    prepared = QuantumCircuit(n + 2, n)
    for helper_qubit in (ancilla, ancilla2):
        prepared.x(helper_qubit)
        prepared.h(helper_qubit)
    return prepared

def partial_hadamard_transform(circuit, qubits):
    """Apply a Hadamard gate to every qubit in ``qubits``, in order.

    Args:
        circuit: Circuit object exposing an ``h(qubit)`` method.
        qubits: Iterable of qubit indices to act on.
    """
    for index in qubits:
        circuit.h(index)

def partial_oracle(circuit, qubits, ancilla, target_state):
    """Append an oracle that fires on ``target_state`` over ``qubits``.

    Qubits whose target bit is '0' are wrapped in X gates so that the
    multi-controlled X onto ``ancilla`` triggers exactly on the target
    pattern; the X gates are then undone.

    Args:
        circuit: Circuit object exposing ``x`` and ``mcx``.
        qubits: Qubit indices the oracle acts on.
        ancilla: Target qubit of the multi-controlled X.
        target_state: Bit string, one character per entry of ``qubits``.
    """
    # Positions that must be inverted so the controls see all ones.
    zero_positions = [
        qubit for qubit, bit in zip(qubits, target_state) if bit == '0'
    ]

    for qubit in zero_positions:
        circuit.x(qubit)

    circuit.mcx(qubits, ancilla)

    # Undo the inversions in the same order they were applied.
    for qubit in zero_positions:
        circuit.x(qubit)

def partial_diffuser(circuit, qubits):
    """Append the Grover diffusion step restricted to ``qubits``.

    Sequence: H then X on every qubit, a multi-controlled Z on the last
    qubit (built as H - MCX - H with the remaining qubits as controls),
    then X then H on every qubit.

    Args:
        circuit: Circuit object exposing ``h``, ``x`` and ``mcx``.
        qubits: Non-empty sequence of qubit indices.
    """
    for qubit in qubits:
        circuit.h(qubit)
        circuit.x(qubit)

    pivot = qubits[-1]
    controls = qubits[:-1]

    # H - MCX - H on the pivot acts as a multi-controlled Z.
    circuit.h(pivot)
    circuit.mcx(controls, pivot)
    circuit.h(pivot)

    for qubit in qubits:
        circuit.x(qubit)
        circuit.h(qubit)

def apply_grover_to_segment_forward(circuit, start_qubit, mid_qubit, solution, ancilla, n):
    """Run two-qubit Grover steps over the first half, left to right.

    When ``n / 2`` is odd, the first qubit is handled on its own with
    H - (Z if its target bit is '1') - H, which nets an X for a '1' bit and
    the identity otherwise; the pairwise sweep then starts one qubit later.
    Each pair gets a Hadamard layer, an oracle for its two target bits and
    a diffuser. The elapsed wall-clock time is printed at the end.

    Args:
        circuit: Circuit to append gates to.
        start_qubit: Index of the first qubit of the segment.
        mid_qubit: Exclusive upper bound of the segment.
        solution: Full target bit string.
        ancilla: Ancilla qubit used by the oracle.
        n: Total number of data qubits.
    """
    began = datetime.now()

    half_is_odd = (n / 2) % 2 == 1
    if half_is_odd:
        # Peel off one qubit so the rest of the segment splits into pairs.
        circuit.h(start_qubit)
        if solution[start_qubit] == '1':
            circuit.z(start_qubit)
        circuit.h(start_qubit)
        start_qubit += 1

    low = start_qubit
    while low < mid_qubit - 1:
        pair = list(range(low, min(low + 2, mid_qubit)))
        partial_hadamard_transform(circuit, pair)
        partial_oracle(circuit, pair, ancilla, solution[low:low + 2])
        partial_diffuser(circuit, pair)
        low += 2

    finished = datetime.now()
    print(finished - began)

def apply_grover_to_segment_backward(circuit, mid_qubit, end_qubit, solution, ancilla2, n):
    """Run two-qubit Grover steps over the second half, right to left.

    When ``n`` is odd or ``n / 2`` is odd, the last qubit is handled on its
    own with H - (Z if its target bit is '1') - H, which nets an X for a
    '1' bit and the identity otherwise; the pairwise sweep then ends one
    qubit earlier. Each pair gets a Hadamard layer, an oracle for its two
    target bits and a diffuser. The elapsed wall-clock time is printed at
    the end.

    Args:
        circuit: Circuit to append gates to.
        mid_qubit: Lower bound of the segment.
        end_qubit: Exclusive upper bound of the segment.
        solution: Full target bit string.
        ancilla2: Ancilla qubit used by the oracle.
        n: Total number of data qubits.
    """
    began = datetime.now()

    needs_single_qubit = n % 2 == 1 or (n / 2) % 2 == 1
    if needs_single_qubit:
        # Peel off the last qubit so the remainder splits into pairs.
        tail = end_qubit - 1
        circuit.h(tail)
        if solution[tail] == '1':
            circuit.z(tail)
        circuit.h(tail)
        end_qubit -= 1

    low = end_qubit - 2
    while low > mid_qubit - 2:
        pair = list(range(low, low + 2))
        partial_hadamard_transform(circuit, pair)
        partial_oracle(circuit, pair, ancilla2, solution[low:low + 2])
        partial_diffuser(circuit, pair)
        low -= 2

    finished = datetime.now()
    print(finished - began)

def segmented_grover_forward(circuit, pairs, solution, ancilla, previous_solution):
    """Apply a Grover step to each pair whose bits changed, in list order.

    A pair is processed only when ``solution`` and ``previous_solution``
    differ on it. The elapsed wall-clock time is printed at the end.

    Args:
        circuit: Circuit to append gates to.
        pairs: Iterable of ``(start, end, target_states)`` tuples; only
            ``start`` and ``end`` are used here.
        solution: Full target bit string.
        ancilla: Ancilla qubit used by the oracle.
        previous_solution: Bit string of the preceding solution.
    """
    began = datetime.now()

    for first, last, _ in pairs:
        window = slice(first, last + 1)
        wanted = solution[window]
        if wanted == previous_solution[window]:
            continue  # unchanged segment: nothing to search for

        segment = list(range(first, last + 1))
        partial_hadamard_transform(circuit, segment)
        partial_oracle(circuit, segment, ancilla, wanted)
        partial_diffuser(circuit, segment)

    finished = datetime.now()
    print(finished - began)

def segmented_grover_backward(circuit, pairs, solution, ancilla2, previous_solution):
    """Apply a Grover step to each pair whose bits changed, last pair first.

    A pair is processed only when ``solution`` and ``previous_solution``
    differ on it. The elapsed wall-clock time is printed at the end.

    Args:
        circuit: Circuit to append gates to.
        pairs: Sequence of ``(start, end, target_states)`` tuples; only
            ``start`` and ``end`` are used here.
        solution: Full target bit string.
        ancilla2: Ancilla qubit used by the oracle.
        previous_solution: Bit string of the preceding solution.
    """
    began = datetime.now()

    # Only the order of the pairs is reversed; qubits inside a pair keep
    # their ascending order.
    for first, last, _ in pairs[::-1]:
        window = slice(first, last + 1)
        wanted = solution[window]
        if wanted == previous_solution[window]:
            continue  # unchanged segment: nothing to search for

        segment = list(range(first, last + 1))
        partial_hadamard_transform(circuit, segment)
        partial_oracle(circuit, segment, ancilla2, wanted)
        partial_diffuser(circuit, segment)

    finished = datetime.now()
    print(finished - began)
            
def _run_in_threads(*jobs):
    """Start one thread per ``(target, args)`` job, then wait for all."""
    threads = [threading.Thread(target=target, args=args) for target, args in jobs]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()


def search_and_reset(n, solutions, pairs):
    """Build, run and measure one circuit per solution.

    The first solution is prepared with a full forward/backward sweep over
    both halves of the register. Every later solution starts from a fresh
    circuit, runs Grover steps only on the pairs that differ from the
    previous solution, and sets the unchanged '1' bits directly with X
    gates.

    Args:
        n: Number of data qubits; two extra ancilla qubits are added.
        solutions: Sequence of target bit strings of length ``n``.
        pairs: ``(start, end, target_states)`` tuples as produced by
            ``identify_search_pairs``.

    Returns:
        Dict mapping each solution string to the measurement counts of its
        circuit (1024 shots on the ``aer_simulator`` backend).

    Note:
        Count keys use Qiskit's bit ordering (classical bit 0 is the
        rightmost character), so for a non-palindromic solution the key is
        expected to read as the reverse of ``solution`` -- TODO confirm.
    """
    results = {}
    backend = Aer.get_backend('aer_simulator')

    ancilla = n
    ancilla2 = n + 1

    # First half is the larger one when n is odd.
    mid_qubit = n // 2 if n % 2 == 0 else (n // 2) + 1

    forward_pairs = [(start, end, states) for start, end, states in pairs if start < mid_qubit]
    backward_pairs = [(start, end, states) for start, end, states in pairs if start >= mid_qubit]

    for i, solution in enumerate(solutions):
        # Every solution gets its own fresh circuit.
        circuit = setup_circuit(n, ancilla, ancilla2)

        # NOTE(review): both threads append to the same circuit object;
        # confirm QuantumCircuit tolerates concurrent appends, or run the
        # two halves sequentially.
        if i == 0:
            _run_in_threads(
                (apply_grover_to_segment_forward, (circuit, 0, mid_qubit, solution, ancilla, n)),
                (apply_grover_to_segment_backward, (circuit, mid_qubit, n, solution, ancilla2, n)),
            )
        else:
            previous_solution = solutions[i - 1]
            _run_in_threads(
                (segmented_grover_forward, (circuit, forward_pairs, solution, ancilla, previous_solution)),
                (segmented_grover_backward, (circuit, backward_pairs, solution, ancilla2, previous_solution)),
            )

            # Pairs that did not change were skipped above; since the
            # circuit starts from all zeros, set their '1' bits directly.
            for j in range(0, n - 1, 2):
                if solution[j:j + 2] == previous_solution[j:j + 2]:
                    if solution[j] == '1':
                        circuit.x(j)
                    if solution[j + 1] == '1':
                        circuit.x(j + 1)

            # With an odd n the last qubit belongs to no pair, so neither
            # the segmented search nor the loop above ever touches it.
            if n % 2 == 1 and solution[n - 1] == '1':
                circuit.x(n - 1)

        circuit.measure(list(range(n)), list(range(n)))
        transpiled_circuit = transpile(circuit, backend)
        result = execute(transpiled_circuit, backend, shots=1024).result()
        results[solution] = result.get_counts()

    return results

# Demo run: three 20-bit targets searched one after another.
solutions = ['00000000000000000000', '00000011111111000000', '11111111111111111111']
# Take the register width from the data so it cannot drift out of sync
# with the solution strings.
n = len(solutions[0])
pairs = identify_search_pairs(solutions)
result_counts = search_and_reset(n, solutions, pairs)
print(pairs)

# Report the measurement counts and show one histogram per solution.
for solution, counts in result_counts.items():
    print(f"Results for {solution}: {counts}")
    plot_histogram(counts)
    plt.show()



0:00:00.002594
0:00:00.002255
0:00:00.001280
0:00:00.001314
0:00:00.001897
0:00:00.001789
[(0, 1, ['00', '00', '11']), (2, 3, ['00', '00', '11']), (4, 5, ['00', '00', '11']), (6, 7, ['00', '11', '11']), (8, 9, ['00', '11', '11']), (10, 11, ['00', '11', '11']), (12, 13, ['00', '11', '11']), (14, 15, ['00', '00', '11']), (16, 17, ['00', '00', '11']), (18, 19, ['00', '00', '11'])]
Results for 00000000000000000000: {'00000000000000000000': 1024}
Results for 00000011111111000000: {'00000011111111000000': 1024}
Results for 11111111111111111111: {'11111111111111111111': 1024}
