In [8]:
from qiskit import QuantumCircuit, QuantumRegister, ClassicalRegister, transpile
from qiskit_ibm_runtime import QiskitRuntimeService, Session, Sampler
from collections import Counter
import numpy as np

class DynamicConfigRecovery:
    """Iteratively repair and filter sampled ``(i, j)`` index pairs.

    Each call to :meth:`repair_and_update` converts raw bitstring counts into
    counts of ``(i, j)`` pairs, reroutes samples that fall outside the current
    valid set onto their nearest valid pair, normalizes the result into a
    probability distribution, and rebuilds the valid set from the pairs whose
    probability reaches ``prob_threshold``.
    """

    def __init__(self, n, prob_threshold, max_distance):
        # Number of bits used to encode each of the two indices of a pair.
        self.n = n
        # Minimum probability a pair needs to be kept in the valid set.
        self.prob_threshold = prob_threshold
        # Largest Manhattan distance over which an invalid sample is rerouted.
        self.max_distance = max_distance
        # Valid pairs from the most recent update; empty before the first call.
        self.valid_pairs = set()

    @staticmethod
    def _manhattan(a, b):
        """Return the Manhattan (L1) distance between two index pairs."""
        return abs(a[0] - b[0]) + abs(a[1] - b[1])

    def repair_and_update(self, raw_counts):
        """Repair invalid samples and refresh the valid set.

        Args:
            raw_counts: Mapping from bitstring to count. Spaces inside a
                bitstring are ignored. After removing them, every bitstring
                must be exactly ``2 * n`` characters long: the first ``n``
                bits are decoded as ``i`` and the remaining ``n`` as ``j``.

        Returns:
            A tuple ``(valid_pairs, probs)``. ``probs`` maps each surviving
            ``(i, j)`` pair to its normalized frequency (empty when no counts
            survive). ``valid_pairs`` is the set of pairs whose probability is
            at least ``prob_threshold``; it is also stored on the instance.

        Raises:
            ValueError: If a bitstring does not have ``2 * n`` bits or
                contains characters other than ``0``/``1``.
        """
        # 1) raw_counts -> counts keyed by (i, j).
        width = 2 * self.n
        counts = Counter()
        for bitstring, cnt in raw_counts.items():
            bits = bitstring.replace(" ", "")
            # Without this check a short or long key would be sliced at the
            # wrong position and silently decoded into the wrong pair.
            if len(bits) != width:
                raise ValueError(
                    f"expected a {width}-bit string, got {bitstring!r}"
                )
            pair = (int(bits[:self.n], 2), int(bits[self.n:], 2))
            counts[pair] += cnt

        # 2) Repair: only possible once a valid set exists from an earlier call.
        if self.valid_pairs:
            repaired = Counter()
            for pair, cnt in counts.items():
                if pair in self.valid_pairs:
                    repaired[pair] += cnt
                    continue
                # Reroute to the closest valid pair. On ties, min() keeps the
                # first candidate met while iterating the set.
                nearest = min(
                    self.valid_pairs,
                    key=lambda vp: self._manhattan(pair, vp),
                )
                # Samples farther away than max_distance are discarded.
                if self._manhattan(pair, nearest) <= self.max_distance:
                    repaired[nearest] += cnt
            counts = repaired

        # 3) Normalize into probabilities (empty when nothing survived).
        total = sum(counts.values())
        if total == 0:
            probs = {}
        else:
            probs = {pair: cnt / total for pair, cnt in counts.items()}

        # 4) Keep only the pairs that reach the probability threshold.
        self.valid_pairs = {
            pair for pair, pr in probs.items() if pr >= self.prob_threshold
        }
        return self.valid_pairs, probs

# Amplitude-embed the given OD distribution into a quantum state via initialize().
def build_od_embedding_circuit(n, distribution):
    """Build a circuit that amplitude-encodes ``distribution`` and measures it.

    Each ``(i, j)`` key is mapped to the basis-state index ``(i << n) | j``
    (``i`` in the high ``n`` bits, ``j`` in the low ``n`` bits) and given the
    amplitude ``sqrt(p)``. The amplitude vector is then normalized, loaded
    with ``initialize`` and every qubit is measured into the classical
    register.

    Args:
        n: Number of bits per index; the circuit uses ``2 * n`` qubits.
        distribution: Mapping ``(i, j) -> weight``. Weights must be
            non-negative; they do not need to sum to one.

    Returns:
        The ``QuantumCircuit`` with state preparation and measurement.

    Raises:
        ValueError: If an index does not fit in ``n`` bits, a weight is
            negative, or the weights do not yield a finite non-zero norm.
    """
    # Validate and build the amplitude vector before touching any circuit
    # objects, so bad input fails early with a clear message.
    side = 1 << n
    amps = np.zeros(side * side, dtype=complex)
    for (i, j), p in distribution.items():
        # An out-of-range j would spill into i's bits and alias onto another
        # basis state; an out-of-range i would index past the vector.
        if not (0 <= i < side and 0 <= j < side):
            raise ValueError(
                f"pair {(i, j)!r} does not fit in {n} bits per index"
            )
        if p < 0:
            raise ValueError(f"negative weight {p!r} for pair {(i, j)!r}")
        amps[(i << n) | j] = np.sqrt(p)

    norm = np.linalg.norm(amps)
    # Dividing by a zero or non-finite norm would leave NaN amplitudes.
    if not np.isfinite(norm) or norm == 0:
        raise ValueError(
            "distribution must contain at least one positive, finite weight"
        )
    amps /= norm

    qr = QuantumRegister(2 * n, 'q')
    cr = ClassicalRegister(2 * n, 'c')
    qc = QuantumCircuit(qr, cr)

    qc.initialize(amps, qr)
    qc.measure(qr, cr)
    return qc

def sample_qpu_once(qc, backend, shots):
    """Run ``qc`` once on ``backend`` and return the measured counts.

    The circuit is transpiled for ``backend`` at optimization level 3 and
    submitted as a single Sampler job inside a runtime session.

    Args:
        qc: Circuit to execute; it must contain measurements.
        backend: Backend to transpile for and run on.
        shots: Number of shots requested for the job.

    Returns:
        The counts obtained by joining the data of the first (and only)
        result entry, as returned by ``get_counts()``.
    """
    qc_t = transpile(qc, backend=backend, optimization_level=3)
    with Session(backend=backend) as session:
        # Bind the sampler to the session that was just opened; passing the
        # backend here left the session object unused.
        sampler = Sampler(mode=session)
        job = sampler.run([qc_t], shots=shots)
        return job.result()[0].join_data().get_counts()

# Repeated QPU sampling is not feasible, so later rounds sample locally (hybrid approach).
def sample_local(distribution, shots, n):
    """Draw ``shots`` multinomial samples from ``distribution`` locally.

    Args:
        distribution: Mapping ``(i, j) -> probability``.
        shots: Total number of samples to draw.
        n: Number of bits used to format each index of a pair.

    Returns:
        A dict mapping each pair's bitstring (``i`` zero-padded to ``n`` bits
        followed by ``j`` zero-padded to ``n`` bits) to its sampled count.
        Every pair of ``distribution`` gets an entry, including zero counts.
    """
    od_pairs = list(distribution)
    weights = np.array([distribution[pair] for pair in od_pairs])
    draws = np.random.multinomial(shots, weights)
    return {
        f"{src:0{n}b}{dst:0{n}b}": hits
        for (src, dst), hits in zip(od_pairs, draws)
    }

def main():
    """Run the demo: one QPU sample, then local recovery/sharpening rounds.

    The OD demand table is normalized into a distribution, embedded into a
    circuit and sampled once on the backend. The resulting counts are then
    fed through ``DynamicConfigRecovery`` for a fixed number of iterations;
    between iterations the recovered distribution is sharpened with an
    increasing exponent and re-sampled locally.
    """
    # 1) OD demand (3x3 example), keyed by (origin, destination) index pair.
    raw_demand = {
        (0, 0): 5, (0, 1): 3, (0, 2): 2,
        (1, 0): 1, (1, 1): 4, (1, 2): 5,
        (2, 0): 2, (2, 1): 2, (2, 2): 6,
    }
    total = sum(raw_demand.values())
    od_dist = {p: d / total for p, d in raw_demand.items()}

    # 2) Parameters.
    # Derive the bits per index from the demand table instead of hard-coding
    # its size; for the 3x3 example this gives n = 2.
    num_zones = 1 + max(max(pair) for pair in raw_demand)
    n = max(1, (num_zones - 1).bit_length())
    prob_threshold = 0.05
    max_distance = 1
    shots = 512
    num_iterations = 5

    # 3) Sample the QPU once.
    service = QiskitRuntimeService()
    backend = service.backend(name="ibm_aachen")
    emb_qc = build_od_embedding_circuit(n, od_dist)
    raw_counts = sample_qpu_once(emb_qc, backend, shots)

    # 4) Configuration recovery + adaptive sampling.
    dcr = DynamicConfigRecovery(n, prob_threshold, max_distance)
    current_counts = raw_counts

    for itr in range(1, num_iterations + 1):
        valid_set, probs = dcr.repair_and_update(current_counts)
        print(f"[Iteration {itr}] Valid pairs: {valid_set}")

        if not probs:
            print("  더 이상 살아남은 쌍이 없습니다.")
            break

        # Nothing consumes a sample drawn after the last update, so skip it.
        if itr == num_iterations:
            break

        # 5) Sharpen the distribution: p -> p^gamma, with gamma growing
        #    each iteration, then renormalize.
        gamma = 1 + 0.5 * itr
        mod_dist = {p: pr ** gamma for p, pr in probs.items()}
        s = sum(mod_dist.values())
        current_dist = {p: v / s for p, v in mod_dist.items()}

        # 6) Draw the next round of counts locally instead of on the QPU.
        current_counts = sample_local(current_dist, shots, n)

    print("\n▶ 최종 Valid Set:", dcr.valid_pairs)

# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()




[Iteration 1] Valid pairs: {(0, 1), (1, 2), (2, 1), (0, 0), (1, 1), (2, 0), (0, 2), (2, 2), (1, 0)}
[Iteration 2] Valid pairs: {(0, 1), (1, 2), (2, 1), (0, 0), (1, 1), (2, 0), (2, 2)}
[Iteration 3] Valid pairs: {(0, 1), (1, 2), (0, 0), (1, 1), (2, 0), (2, 2)}
[Iteration 4] Valid pairs: {(2, 2), (0, 0)}
[Iteration 5] Valid pairs: {(0, 0)}

▶ 최종 Valid Set: {(0, 0)}
