# In [19]:
import numpy as np

# -------------------------
# Block encoding (Halmos dilation)
# -------------------------

def qsp_U00_batch(xs, phases):
    """
    Evaluate the QSP [0, 0] matrix element at every point of ``xs``.

    For each x the 2x2 product is accumulated by left-multiplication: it
    starts from ``Rz(phases[0])`` and, for every remaining phase, applies
    ``W_signal(x)`` and then ``Rz(phi)``.

    Returns a 1-D NumPy array with one [0, 0] entry per element of ``xs``,
    in the same order.
    """
    def _top_left(signal):
        # Same recursion as the scalar evaluation, for one signal matrix.
        acc = Rz(phases[0])
        for angle in phases[1:]:
            acc = Rz(angle) @ (signal @ acc)
        return acc[0, 0]

    # All signal matrices are built first, so an out-of-range x raises
    # before any product is formed.
    signals = [W_signal(x) for x in xs]
    return np.array([_top_left(signal) for signal in signals])

def psd_sqrt(M):
    """
    Square root of a (numerically) positive semidefinite matrix.

    The input is converted to complex128 and replaced by its Hermitian
    part, then diagonalised with ``np.linalg.eigh``.  Eigenvalues below
    zero are clipped to zero before the square root is taken, so small
    negative round-off does not produce NaNs.

    Returns the matrix V diag(sqrt(w)) V^dagger as a complex128 array.
    """
    herm = np.array(M, dtype=np.complex128)
    herm = (herm + herm.conj().T) / 2
    eigvals, eigvecs = np.linalg.eigh(herm)
    root_vals = np.sqrt(np.clip(eigvals, 0.0, None))
    return eigvecs @ np.diag(root_vals) @ eigvecs.conj().T

def block_encode_halmos(A, alpha=None):
    """
    Build a Halmos unitary dilation of A / alpha.

    Returns (alpha, U) where U is a 2n x 2n unitary whose top-left n x n
    block is A / alpha:

        U = [[ A~,                 sqrt(I - A~ A~^dagger) ],
             [ sqrt(I - A~^dagger A~),   -A~^dagger       ]],   A~ = A / alpha

    A need not be Hermitian or unitary, but it must be square.

    Parameters
    ----------
    A : array_like, shape (n, n)
        Matrix to encode; converted to complex128.
    alpha : float, optional
        Normalisation.  Defaults to the spectral norm of A (or 1.0 when A
        is the zero matrix).  When given, it must be positive and at least
        the spectral norm of A, otherwise A / alpha is not a contraction
        and no unitary dilation exists.

    Raises
    ------
    ValueError
        If A is not a square 2-D array, or if a supplied alpha is not
        positive or is smaller than the spectral norm of A.
    """
    A = np.array(A, dtype=np.complex128)
    if A.ndim != 2 or A.shape[0] != A.shape[1]:
        raise ValueError("A must be square for this simple Halmos block-encoding.")
    n = A.shape[0]

    spectral_norm = float(np.linalg.norm(A, 2))
    if alpha is None:
        alpha = spectral_norm
        if alpha == 0:
            alpha = 1.0
    else:
        if not alpha > 0:
            raise ValueError("alpha must be positive.")
        # Without this check psd_sqrt would clip the negative eigenvalues of
        # I - A~ A~^dagger and the returned matrix would silently be non-unitary.
        if spectral_norm > alpha * (1.0 + 1e-9):
            raise ValueError(
                "alpha must be >= the spectral norm of A so that A/alpha is a contraction."
            )

    Atil = A / alpha
    I = np.eye(n, dtype=np.complex128)

    # Defect operators; both are Hermitian PSD because ||A~|| <= 1.
    B = psd_sqrt(I - Atil @ Atil.conj().T)
    C = psd_sqrt(I - Atil.conj().T @ Atil)

    U = np.block([[Atil, B],
                  [C,    -Atil.conj().T]])
    return alpha, U


# -------------------------
# QSP core (2x2 signal)
# -------------------------

def W_signal(x):
    """
    2x2 signal unitary for QSP.

    Returns [[x, i*s], [i*s, x]] with s = sqrt(1 - x^2) as a complex128
    array.

    Parameters
    ----------
    x : float
        Signal value in [-1, 1].  Values that overshoot the interval by at
        most 1e-12 (round-off) are accepted and clamped to +/-1 so that the
        returned matrix is still unitary.

    Raises
    ------
    ValueError
        If x is NaN or lies outside [-1, 1] by more than the tolerance.
    """
    x = float(x)
    # NaN compares False against everything, so it needs an explicit check.
    if np.isnan(x) or abs(x) > 1 + 1e-12:
        raise ValueError("x must be in [-1,1]")
    # Clamp round-off overshoot; otherwise x*x + s*s would exceed 1.
    x = min(1.0, max(-1.0, x))
    s = np.sqrt(1.0 - x*x)
    return np.array([[x, 1j*s],
                     [1j*s, x]], dtype=np.complex128)

def Rz(phi):
    """Return the 2x2 diagonal phase matrix diag(e^{i*phi}, e^{-i*phi}) as complex128."""
    upper = np.exp(1j*phi)
    lower = np.exp(-1j*phi)
    return np.array([[upper, 0.0],
                     [0.0, lower]], dtype=np.complex128)

def qsp_U00(x, phases):
    """
    Evaluate the QSP polynomial at a single point.

    The 2x2 product is accumulated by left-multiplication, i.e. the matrix
    built is

        Rz(phases[-1]) W(x) ... W(x) Rz(phases[1]) W(x) Rz(phases[0])

    and its [0, 0] entry is returned (complex in general).  Both factor
    types are symmetric matrices, so the [0, 0] entry is the same as for
    the product written in the opposite order,
    Rz(phases[0]) W(x) ... W(x) Rz(phases[-1]).
    """
    acc = Rz(phases[0])
    signal = W_signal(x)
    for angle in phases[1:]:
        acc = Rz(angle) @ (signal @ acc)
    return acc[0, 0]


# -------------------------
# Target function: truncated inverse on [-1,1]
# -------------------------

def truncated_inverse_target(x, kappa, x0=None, x1=None):
    """
    Odd target function for QSP: a truncated, rescaled inverse.

      - |x| >= x1 : g(x) = 1 / (kappa * x)
      - |x| <= x0 : g(x) = 0
      - x0 < |x| < x1 : 1 / (kappa * x) multiplied by a cubic smoothstep
        that rises from 0 at x0 to 1 at x1

    Defaults are x1 = 1/kappa and x0 = 0.5/kappa, so nothing is forced on
    the function very close to 0, where the inverse blows up.

    Raises ValueError unless 0 < x0 < x1.
    """
    x = float(x)
    sign = 1.0 if x >= 0 else -1.0
    magnitude = abs(x)

    x1 = 1.0 / kappa if x1 is None else x1
    x0 = 0.5 / kappa if x0 is None else x0
    if x0 <= 0 or x1 <= x0:
        raise ValueError("Need 0 < x0 < x1.")

    # Dead zone around the origin.
    if magnitude <= x0:
        return 0.0

    inverse = 1.0 / (kappa * magnitude)
    if magnitude >= x1:
        return sign * inverse

    # Transition zone: cubic smoothstep weight 3u^2 - 2u^3 with u in (0, 1).
    u = (magnitude - x0) / (x1 - x0)
    weight = u*u*(3 - 2*u)
    return sign * weight * inverse

def clipped_inverse_target(x, kappa, x_min=None):
    """
    Saturating inverse target for QSP.

    Instead of zeroing out small |x|, the inverse is saturated there:
    g(x) = sign(x) / (kappa * max(|x|, x_min)), additionally capped so
    that |g(x)| <= 1.

    x_min is the smallest magnitude used when forming the reciprocal; it
    defaults to 1/kappa (a design choice; a smaller value is possible).
    An input of exactly 0 is treated as +1e-12.
    """
    x = float(x)
    if x == 0.0:
        # Avoid an exact zero.
        x = 1e-12

    if x_min is None:
        x_min = 1.0 / kappa

    sign = 1.0 if x >= 0 else -1.0
    # Magnitudes below x_min are clipped up to x_min before inverting.
    floor_mag = max(abs(x), x_min)
    # Cap at 1 to keep |g(x)| <= 1 as QSP requires.
    capped = min(1.0 / (kappa * floor_mag), 1.0)
    return sign * capped


# -------------------------
# Phase fitter (robust-ish)
# -------------------------

class QSPPhaseFitter:
    """
    Fits QSP phases so that the real part of the QSP polynomial p(x)
    approximates a target function g(x) on [-1,1] (or a subset of it),
    with a soft penalty on |p(x)| > 1 and on the imaginary part of p.

    Parameters
    ----------
    degree : int
        Number of signal applications; degree + 1 phases are fitted.
    kappa : float
        Positive scale parameter; by default only points with
        |x| >= 1/kappa take part in the fit.
    seed : int
        Seed of the private random generator used for the random starts.

    Raises
    ------
    ValueError
        If degree < 0 or kappa is not positive.
    """

    def __init__(self, degree, kappa, seed=0):
        self.degree = int(degree)
        self.kappa = float(kappa)
        if self.degree < 0:
            raise ValueError("degree must be >= 0")
        if not self.kappa > 0:
            raise ValueError("kappa must be > 0")
        self.rng = np.random.default_rng(seed)

    def _make_grids(self, n_cheb=128, n_dense=1024, focus_min=None):
        """
        Return (x_cheb, x_dense): Chebyshev nodes and a uniform grid on
        [-1, 1], both restricted to |x| >= focus_min (default 1/kappa).
        Either array can be empty if focus_min excludes every point.
        """
        # Chebyshev nodes on [-1,1]
        j = np.arange(n_cheb)
        x_cheb = np.cos((2*j + 1) * np.pi / (2*n_cheb))

        # Dense uniform grid for safety checks
        x_dense = np.linspace(-1.0, 1.0, n_dense)

        if focus_min is None:
            # Focus the loss on |x| >= 1/kappa by default
            focus_min = 1.0 / self.kappa
        x_cheb = x_cheb[np.abs(x_cheb) >= focus_min]
        x_dense = x_dense[np.abs(x_dense) >= focus_min]

        return x_cheb, x_dense

    def fit(self,
            target_fn,
            n_cheb=16,
            n_dense=16,
            focus_min=None,
            iters=10,
            lr=0.05,
            penalty_w=50.0,
            bound_eps=1e-6,
            multi_start=6):
        """
        Fit the phases by Adam on a finite-difference gradient.

        Parameters
        ----------
        target_fn : callable
            Scalar real target g(x), evaluated point by point.
        n_cheb, n_dense : int
            Sizes (before the |x| >= focus_min filter) of the Chebyshev
            fitting grid and of the uniform grid used for the bound penalty.
        focus_min : float, optional
            Lower bound on |x| for both grids; defaults to 1/kappa.
        iters : int
            Adam iterations per start.
        lr : float
            Adam learning rate.
        penalty_w : float
            Weight of the soft |p(x)| <= 1 + bound_eps penalty.
        bound_eps : float
            Slack allowed above 1 before the penalty applies.
        multi_start : int
            Number of random initialisations; the one with the lowest
            score is returned.

        Returns
        -------
        (best_phases, info) where info holds "score", "loss", "max_abs",
        "max_err_on_[1/k,1]" and "max_imag" for the selected start.

        Raises
        ------
        ValueError
            If multi_start < 1, or if the focus filter leaves no points on
            either grid (the loss would be NaN).
        """
        if multi_start < 1:
            raise ValueError("multi_start must be >= 1")

        d = self.degree
        best = None
        best_info = None

        x_cheb, x_dense = self._make_grids(n_cheb=n_cheb, n_dense=n_dense, focus_min=focus_min)
        if x_cheb.size == 0 or x_dense.size == 0:
            # A mean over an empty grid is NaN, which would make every
            # start look equally (in)valid and return arbitrary phases.
            raise ValueError(
                "No grid points satisfy |x| >= focus_min; "
                "increase n_cheb/n_dense or lower focus_min."
            )
        y_cheb = np.array([target_fn(x) for x in x_cheb], dtype=np.float64)

        # Verification grid and targets do not depend on the phases, so they
        # are built once.  The threshold is capped at 1 so that the endpoints
        # are always included, even when kappa < 1.
        xs_check = np.linspace(-1.0, 1.0, 2048)
        check_mask = np.abs(xs_check) >= min(1.0 / self.kappa, 1.0)
        y_check = np.array([target_fn(x) for x in xs_check[check_mask]])

        def loss(phases):
            # Primary fit on Chebyshev nodes
            vals = qsp_U00_batch(x_cheb, phases)
            # Target is real: fit the real part and penalise imaginary leakage
            diff_re = (vals.real - y_cheb)
            diff_im = vals.imag
            mse = np.mean(diff_re*diff_re) + 0.2*np.mean(diff_im*diff_im)

            # Bound penalty on dense grid.
            # NOTE(review): the values come from an entry of a product of 2x2
            # factors that look unitary, so |p| <= 1 should already hold and
            # this term is expected to be ~0 — confirm before relying on it.
            vd = qsp_U00_batch(x_dense, phases)
            abs_v = np.abs(vd)
            viol = np.clip(abs_v - (1.0 + bound_eps), 0.0, None)
            pen = penalty_w * np.mean(viol*viol)

            return mse + pen

        def finite_diff_grad(phases, eps=1e-4):
            # Central differences: two loss evaluations per phase.
            g = np.zeros_like(phases)
            base = loss(phases)
            for i in range(len(phases)):
                p1 = phases.copy(); p1[i] += eps
                p2 = phases.copy(); p2[i] -= eps
                g[i] = (loss(p1) - loss(p2)) / (2*eps)
            return base, g

        def optimize_from(init):
            phases = init.copy()
            # Adam moments
            m = np.zeros_like(phases)
            v = np.zeros_like(phases)
            b1, b2 = 0.9, 0.999
            eps_adam = 1e-8

            # Needed as the return value when iters == 0.
            cur = loss(phases)
            for t in range(1, iters+1):
                cur, g = finite_diff_grad(phases)
                m = b1*m + (1-b1)*g
                v = b2*v + (1-b2)*(g*g)
                mhat = m / (1 - b1**t)
                vhat = v / (1 - b2**t)
                step = lr * mhat / (np.sqrt(vhat) + eps_adam)

                # Try the full step; if it does not improve the loss, take a
                # shortened step instead.  The shortened step is accepted
                # unconditionally, so the loss is not guaranteed to decrease.
                new_phases = phases - step
                new_loss = loss(new_phases)
                if new_loss <= cur:
                    phases = new_phases
                    cur = new_loss
                else:
                    phases = phases - 0.3*step
                    cur = loss(phases)

                # Wrap phases into [-pi, pi)
                phases = (phases + np.pi) % (2*np.pi) - np.pi

            return phases, cur

        def verify(phases):
            # Max |p(x)| and max |Im p(x)| on the full grid over [-1,1];
            # max real-part error on the region |x| >= 1/kappa.
            vals = np.array([qsp_U00(x, phases) for x in xs_check])
            max_abs = float(np.max(np.abs(vals)))
            err = float(np.max(np.abs(vals.real[check_mask] - y_check)))
            imag_max = float(np.max(np.abs(vals.imag)))
            return max_abs, err, imag_max

        for _ in range(multi_start):
            init = self.rng.normal(scale=0.2, size=d+1)
            phases, final_loss = optimize_from(init)
            max_abs, err, imag_max = verify(phases)

            # Lower is better: loss + bound violation + max error + imaginary part
            score = final_loss + 10.0*max(0.0, max_abs-1.0) + err + 0.1*imag_max
            info = {"loss": final_loss, "max_abs": max_abs, "max_err_on_[1/k,1]": err, "max_imag": imag_max}

            if best is None or score < best_info["score"]:
                best = phases
                best_info = {"score": score, **info}

        return best, best_info


# -------------------------
# Reliable-ish solver: build block encoding + fit QSP phases + apply via the QSVT unitary (an SVD-based SVT is kept for comparison)
# -------------------------

class ReliableQSVTInverseApprox:
    """
    Dense-matrix simulator of a QSVT-based approximate inverse.

    - builds a Halmos block encoding U_block (2n x 2n) of A / alpha,
    - fits QSP phases so that Re p(x) approximates 1 / (kappa * x),
    - applies p to the block-encoded A with a QSVT phase sequence and reads
      the approximation of A^{-1} off the top-left block.

    Parameters
    ----------
    A : array_like, shape (n, n)
        Square matrix to invert (converted to complex128).
    kappa : float
        Scale of the target 1/(kappa*x); must be >= 1.
    degree : int
        Degree passed to the phase fitter.
        NOTE(review): the target is odd, so an odd degree is presumably
        required for a meaningful fit — confirm.
    seed : int
        Seed passed to the phase fitter.

    Raises
    ------
    ValueError
        If A is not square or kappa < 1.
    """

    def __init__(self, A, kappa, degree=9, seed=0):
        self.A = np.array(A, dtype=np.complex128)
        n, m = self.A.shape
        if n != m:
            raise ValueError("This minimal solver assumes square A.")
        self.n = n
        self.kappa = float(kappa)
        if self.kappa < 1:
            raise ValueError("kappa must be >= 1")

        # Block encoding: U_block is a 2n x 2n unitary whose top-left block is A/alpha
        self.alpha, self.U_block = block_encode_halmos(self.A)

        # QSP phase fitter (scalar level)
        self.fitter = QSPPhaseFitter(degree=degree, kappa=self.kappa, seed=seed)

        # Target function: clipped inverse
        self.target_fn = lambda x: clipped_inverse_target(x, self.kappa)

        self.phases = None
        self.fit_info = None

        # Caches for the matrix-level reflection Z and signal operator W_A
        self._Z_big = None
        self._W_A = None

    # -------------------------
    # Phase fitting (scalar QSP)
    # -------------------------
    def fit_phases(self,
                   n_cheb=64,
                   n_dense=128,
                   iters=600,
                   lr=0.05,
                   penalty_w=80.0,
                   multi_start=8):
        """
        Run the scalar phase fit against the clipped-inverse target and
        store the result in self.phases / self.fit_info.

        Returns (phases, info) exactly as produced by the fitter.
        """
        phases, info = self.fitter.fit(
            target_fn=self.target_fn,
            n_cheb=n_cheb,
            n_dense=n_dense,
            focus_min=None,
            iters=iters,
            lr=lr,
            penalty_w=penalty_w,
            multi_start=multi_start
        )
        self.phases = phases
        self.fit_info = info
        return phases, info

    def p(self, x):
        """QSP polynomial value (scalar, complex in general) at x."""
        if self.phases is None:
            raise RuntimeError("Call fit_phases() first.")
        return qsp_U00(x, self.phases)

    # -------------------------
    # Matrix-level QSVT builders
    # -------------------------
    def _build_Z_big(self):
        """
        Z_big = 2P - I with P = diag(I_n, 0_n), i.e. diag(I_n, -I_n)
        (2n x 2n).  Cached after the first call.
        """
        if self._Z_big is not None:
            return self._Z_big
        n = self.n
        Z = np.zeros((2*n, 2*n), dtype=np.complex128)
        Z[:n, :n] = np.eye(n, dtype=np.complex128)
        Z[n:, n:] = -np.eye(n, dtype=np.complex128)
        self._Z_big = Z
        return Z

    def _build_W_A(self):
        """
        Matrix-level signal operator W_A = Z_big @ U_block, the analogue
        of the scalar signal matrix W(x).  Cached after the first call.
        """
        if self._W_A is not None:
            return self._W_A
        Z = self._build_Z_big()
        self._W_A = Z @ self.U_block
        return self._W_A

    def _R_big(self, phi):
        """
        R_Z(phi) = exp(i * phi * Z_big)
                 = blockdiag(e^{i phi} I_n, e^{-i phi} I_n)
        """
        n = self.n
        R = np.zeros((2*n, 2*n), dtype=np.complex128)
        e_p = np.exp(1j * phi)
        e_m = np.exp(-1j * phi)
        R[:n, :n] = e_p * np.eye(n, dtype=np.complex128)
        R[n:, n:] = e_m * np.eye(n, dtype=np.complex128)
        return R

    def _build_qsvt_unitary(self, phases=None):
        """
        Lift the scalar QSP sequence to the 2n x 2n space:

            U = R(phi_0)
            for k = 1 .. d:
                U = S_k @ U          # S_k = Z @ U_block^dagger (k odd)
                                     #       Z @ U_block        (k even)
                U = R(phi_k) @ U

        The block encoding and its adjoint alternate because U_block maps
        the right-singular-vector subspaces of A to the left ones and the
        adjoint maps them back; applying the same operator every time is a
        singular value transform only when U_block is Hermitian (Hermitian
        A), in which case both operators coincide.  The sequence starts
        with the adjoint so that, for odd d, the top-left block is
        sum_k p(sigma_k) v_k w_k^dagger for A/alpha = sum_k sigma_k w_k v_k^dagger,
        which is the orientation of A^{-1}.

        Parameters
        ----------
        phases : sequence of float, optional
            Phase list to use; defaults to the fitted self.phases.

        Raises
        ------
        RuntimeError
            If no phases are given and fit_phases() has not been called.
        """
        if phases is None:
            if self.phases is None:
                raise RuntimeError("Call fit_phases() first.")
            phases = self.phases

        Z = self._build_Z_big()
        W_fwd = self._build_W_A()
        W_adj = Z @ self.U_block.conj().T

        # Initial R(phi_0)
        U = self._R_big(phases[0])

        # Remaining phases, alternating adjoint / forward signal operators
        for k, phi in enumerate(phases[1:], start=1):
            signal = W_adj if k % 2 == 1 else W_fwd
            U = self._R_big(phi) @ (signal @ U)

        return U

    # -------------------------
    # Approximate A^{-1} via QSVT
    # -------------------------
    def apply_inverse_approx(self, b):
        """
        Approximate A^{-1} b as (kappa / alpha) * Q @ b, where Q is the
        real-part polynomial transform read off the QSVT top-left block.

        The fit targets Re p(x) ~ 1/(kappa*x), so the real part of the
        polynomial is extracted by averaging the top-left blocks of the
        sequences built with phases +Phi and -Phi (negating the phases
        complex-conjugates the polynomial).  With this, the result agrees
        with apply_inverse_approx_via_svd up to round-off.

        Parameters
        ----------
        b : array_like
            Right-hand side; flattened to length n.

        Returns
        -------
        complex128 array of length n.

        Raises
        ------
        RuntimeError
            If fit_phases() has not been called.
        ValueError
            If b does not have n entries.
        """
        if self.phases is None:
            raise RuntimeError("Call fit_phases() first.")

        b = np.array(b, dtype=np.complex128).reshape(-1)
        if b.shape[0] != self.n:
            raise ValueError("b has wrong dimension")

        n = self.n
        phases = np.asarray(self.phases, dtype=np.float64)

        # Top-left n x n blocks carry p(A/alpha) and conj(p)(A/alpha)
        Q_plus = self._build_qsvt_unitary(phases)[:n, :n]
        Q_minus = self._build_qsvt_unitary(-phases)[:n, :n]
        Q = 0.5 * (Q_plus + Q_minus)

        # Re p(x) ~ 1/(kappa x), hence (kappa/alpha) * Re p(A/alpha) ~ A^{-1}
        scale = self.kappa / self.alpha
        return scale * (Q @ b)

    # Optional diagnostic: reference result computed directly from the SVD
    def apply_inverse_approx_via_svd(self, b):
        """
        Reference implementation: apply Re p to the singular values of
        A / alpha directly through an SVD of A.  For debugging/comparison.

        Raises RuntimeError if fit_phases() has not been called and
        ValueError if b does not have n entries.
        """
        if self.phases is None:
            raise RuntimeError("Call fit_phases() first.")

        b = np.array(b, dtype=np.complex128).reshape(-1)
        if b.shape[0] != self.n:
            raise ValueError("b has wrong dimension")

        U, s, Vh = np.linalg.svd(self.A, full_matrices=False)
        x = s / self.alpha
        px = np.array([self.p(xi).real for xi in x], dtype=np.float64)
        inv_diag = (self.kappa / self.alpha) * px

        return Vh.conj().T @ (inv_diag * (U.conj().T @ b))

    def diagnostics(self):
        """Return the block-encoding normalisation and the stored fit info."""
        return {
            "alpha": self.alpha,
            "fit_info": self.fit_info,
        }




# In [None]:
# Demo: approximate the solution of A x = b with the QSVT simulator and
# compare it against a direct solve.

# Symmetric 3x3 test matrix.
A = np.array([[1,0,3],
              [0,1,0],
              [3,0,1]], dtype=float)

# kappa is set 10% above the 2-norm condition number of A.
condA = np.linalg.cond(A, 2)
solver = ReliableQSVTInverseApprox(A, kappa=condA*1.1, degree=15, seed=0)

# Fit the degree + 1 = 16 QSP phases (3 random starts, 100 iterations each).
phases, info = solver.fit_phases(n_cheb=32, n_dense=32,
                                 iters=100, lr=0.05,
                                 penalty_w=60.0, multi_start=3)

# NOTE: drawn from NumPy's global RNG, which is not seeded here, so b (and
# every number printed below) changes from run to run.
b = np.random.randn(3)

x_hat_qsvt = solver.apply_inverse_approx(b)           # block encoding + QSVT version
x_hat_svd  = solver.apply_inverse_approx_via_svd(b)   # SVD-based SVT (for comparison)
x_true     = np.linalg.solve(A, b)

# Print both solutions and the Euclidean error of each approximation.
print(x_hat_qsvt)
print(x_true)
print("||x_true - x_hat_qsvt|| =", np.linalg.norm(x_true - x_hat_qsvt))
print("||x_true - x_hat_svd || =", np.linalg.norm(x_true - x_hat_svd))


