In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import manify

In [12]:
"""Product‑Space SVM (Tabaghi et al. 2021)
------------------------------------------------
Robust one‑vs‑rest soft‑margin SVM for product manifolds.

Changes in this revision
* **Auto‑regularises** all quadratic‑form matrices to be PSD by shifting them by ‑λ_min + δ.
* Skips constraints whose matrix is (near‑)zero, preventing useless DCP checks.
* Keeps the streamlined predict path from last iteration.
"""

from __future__ import annotations

import math
from beartype.typing import Dict, List, Optional, Tuple

import cvxpy as cp
import numpy as np
import torch
from jaxtyping import Float
from sklearn.base import BaseEstimator, ClassifierMixin

# ----------------------------------------------------------------------------
#  Helper utilities
# ----------------------------------------------------------------------------

def _sym(a: np.ndarray) -> np.ndarray:
    return 0.5 * (a + a.T)


def _make_psd(a: np.ndarray, delta: float = 1e-8) -> np.ndarray:
    """Return a symmetric positive semi-definite version of ``a``.

    The matrix is first symmetrised with ``_sym`` (plain transpose, no complex
    conjugation). If its smallest eigenvalue is negative, the whole spectrum is
    shifted up by ``-lambda_min + delta`` so that the smallest eigenvalue
    becomes ``delta``; otherwise the symmetrised matrix is returned as is.

    Args:
        a: Square matrix.
        delta: Margin added on top of the shift when a shift is needed.

    Returns:
        A new array; ``a`` itself is not modified.
    """
    a_sym = _sym(a)
    if a_sym.size == 0:
        # A 0x0 matrix has no eigenvalues; taking ``min()`` of them would raise.
        return a_sym
    # Only the smallest eigenvalue is needed, so skip computing eigenvectors.
    lam_min = np.linalg.eigvalsh(a_sym).min()
    if lam_min < 0:
        # In-place add keeps the dtype produced by ``_sym``.
        a_sym += (-lam_min + delta) * np.eye(a_sym.shape[0])
    return a_sym


def _psd_split(a: np.ndarray, tol: float = 1e-8) -> Tuple[np.ndarray, np.ndarray]:
    """Split a matrix into positive and negative spectral parts.

    The symmetrised matrix is eigendecomposed. The first output is rebuilt from
    the eigenvalues floored at ``tol``; the second is rebuilt from the negated
    eigenvalues floored at ``tol``. Both are passed through ``_make_psd`` (with
    ``tol`` as its margin) before being returned.

    Because of the flooring, ``K_pos - K_neg`` only approximates the
    symmetrised input.

    Returns:
        ``(K_pos, K_neg)``.
    """
    eigvals, eigvecs = np.linalg.eigh(_sym(a))

    def _rebuild(spectrum: np.ndarray) -> np.ndarray:
        # Floor the spectrum, reassemble the matrix, then clean up round-off.
        floored = np.clip(spectrum, tol, None)
        return _make_psd(eigvecs @ np.diag(floored) @ eigvecs.T, tol)

    return _rebuild(eigvals), _rebuild(-eigvals)

# ----------------------------------------------------------------------------
#  Kernels
# ----------------------------------------------------------------------------

def _kernel_manifold(manifold, xs: torch.Tensor, xt: torch.Tensor):
    ip = manifold.inner(xs, xt) * manifold.scale
    if manifold.type == "E":
        k = ip
        norm = 1.0
    elif manifold.type == "S":
        c = manifold.curvature
        k = torch.asin(torch.clamp(ip * c, -0.999999, 0.999999)) * math.sqrt(c)
        norm = math.sqrt(c)
    elif manifold.type == "H":
        c = abs(manifold.curvature)
        r2 = torch.max(torch.diag(manifold.inner(xs, xs))).item()
        k = torch.asinh(torch.clamp(ip / (r2 + 1e-7), -0.999999, 0.999999)) * math.sqrt(c)
        norm = math.asinh(-r2 * c)
    else:
        raise ValueError(manifold.type)
    return k, norm


def product_kernel(pm, Xs: torch.Tensor, Xt: Optional[torch.Tensor] = None):
    """Compute one kernel (and its norm) per component of a product manifold.

    ``Xs`` and ``Xt`` are each split with ``pm.factorize`` and paired with the
    components in ``pm.P``; ``_kernel_manifold`` is evaluated on each triple.
    When ``Xt`` is omitted the kernels are computed between ``Xs`` and itself.

    Returns:
        ``(kernels, norms)``: two lists with one entry per component, in the
        order of ``pm.P``.
    """
    targets = Xs if Xt is None else Xt
    per_factor = [
        _kernel_manifold(factor, src, dst)
        for factor, src, dst in zip(pm.P, pm.factorize(Xs), pm.factorize(targets))
    ]
    kernels = [pair[0] for pair in per_factor]
    norms = [pair[1] for pair in per_factor]
    return kernels, norms

# ----------------------------------------------------------------------------
#  SVM model
# ----------------------------------------------------------------------------

class ProductSpaceSVM(BaseEstimator, ClassifierMixin):
    """One-vs-rest soft-margin SVM on a product manifold, solved with CVXPY.

    For every class a convex problem is solved over coefficients ``beta`` (one
    per training point), slacks ``zeta`` and a margin ``eps_m``. The objective
    is ``-eps_m + sum(zeta)``, subject to a margin constraint on the combined
    kernel and one or two quadratic constraints per component manifold.

    Parameters:
        pm: Product manifold exposing ``P`` (its components) and ``factorize``.
        weights: One weight per component; defaults to all ones.
        alpha_E: Euclidean components are constrained by
            ``beta' K beta <= alpha_E ** 2``.
        theta_S: Spherical components are constrained by
            ``beta' K beta <= theta_S``.
        r_H: Hyperbolic components are constrained by
            ``beta' K+ beta <= r_H + norm`` and ``beta' K- beta <= r_H``.
        epsilon_slack: Passed to the solver as its absolute and relative
            tolerance (despite the name, it is not used as a slack bound).
        solver: Name of the CVXPY solver to use.

    Attributes set by ``fit``:
        classes_: Sorted unique labels (Python list).
        beta_: Mapping from class label to the fitted coefficient vector.
        X_train_: Copy of the training points, needed to build test kernels.
    """

    # Names of the absolute/relative tolerance options for solvers that do not
    # use the ``eps_abs`` / ``eps_rel`` spelling. ECOS in particular rejects
    # unknown keyword arguments, so the names must match the back end.
    _TOLERANCE_OPTIONS = {
        "ECOS": ("abstol", "reltol"),
        "CLARABEL": ("tol_gap_abs", "tol_gap_rel"),
    }
    # Fallback (SCS, OSQP, ...): the option names the code always used.
    _DEFAULT_TOLERANCE_OPTIONS = ("eps_abs", "eps_rel")

    def __init__(
        self,
        pm,
        weights: Optional[Float[torch.Tensor, "n_manifolds"]] = None,
        *,
        alpha_E: float = 1.0,
        theta_S: float = math.pi / 2,
        r_H: float = 1.0,
        epsilon_slack: float = 1e-6,
        solver: str = "ECOS",
    ) -> None:
        self.pm = pm
        self.alpha_E = alpha_E
        self.theta_S = theta_S
        self.r_H = r_H
        # Stored under the constructor argument's own name: BaseEstimator's
        # get_params()/clone()/repr() read every __init__ argument back via
        # getattr(self, <argument name>) and fail if it is missing.
        self.epsilon_slack = epsilon_slack
        self.solver = solver
        self.weights = torch.ones(len(pm.P), dtype=torch.float32) if weights is None else weights
        if len(self.weights) != len(pm.P):
            raise ValueError("weights length mismatch with product factors")

    @property
    def eps(self) -> float:
        """Backward-compatible alias for ``epsilon_slack``."""
        return self.epsilon_slack

    @eps.setter
    def eps(self, value: float) -> None:
        self.epsilon_slack = value

    # ------------------------------------------------------------------
    #  Fit
    # ------------------------------------------------------------------

    def _solver_tolerances(self) -> Dict[str, float]:
        """Return ``epsilon_slack`` keyed by the selected solver's tolerance option names."""
        names = self._TOLERANCE_OPTIONS.get(str(self.solver).upper(), self._DEFAULT_TOLERANCE_OPTIONS)
        return {name: self.epsilon_slack for name in names}

    def _quadratic_constraints(self, Ks, norms) -> List[Tuple[np.ndarray, float]]:
        """Build the ``(matrix, bound)`` pairs of the per-manifold quadratic constraints.

        These depend only on the kernels, not on the class being fitted, so
        they are computed once per ``fit`` call.

        Raises:
            ValueError: if a component with a non-zero kernel has an unknown type.
        """
        pairs: List[Tuple[np.ndarray, float]] = []
        for M, k_m, w, n in zip(self.pm.P, Ks, self.weights, norms):
            mat = _sym((w * k_m).cpu().numpy())
            if np.allclose(mat, 0, atol=1e-10):
                continue  # nothing to constrain
            Kp, Kn = _psd_split(mat)
            if M.type == "E":
                pairs.append((Kp, self.alpha_E ** 2))
            elif M.type == "S":
                pairs.append((Kp, self.theta_S))
            elif M.type == "H":
                pairs.append((Kp, self.r_H + n))
                pairs.append((Kn, self.r_H))
            else:
                raise ValueError(M.type)
        return pairs

    def fit(self, X: torch.Tensor, y: torch.Tensor):
        """Fit one binary problem per unique label in ``y``.

        Args:
            X: Training points on the product manifold.
            y: Labels, one per training point.

        Returns:
            ``self``.

        Raises:
            ValueError: if a component manifold has an unknown type.
            RuntimeError: if the solver does not report an (inaccurate) optimum.
        """
        Ks, norms = product_kernel(self.pm, X)
        # The all-ones kernel is the starting term of the combined kernel.
        K_full = torch.ones_like(Ks[0])
        for k_m, w in zip(Ks, self.weights):
            K_full += w * k_m
        K_np = K_full.cpu().numpy()
        n_train = K_np.shape[0]

        # Loop-invariant work: eigendecompositions and solver options.
        quad_bounds = self._quadratic_constraints(Ks, norms)
        tolerances = self._solver_tolerances()

        self.classes_ = torch.unique(y).tolist()
        self.beta_: Dict[int, np.ndarray] = {}

        for c in self.classes_:
            y_bin = torch.where(y == c, 1, -1).cpu().numpy()
            Y = np.diag(y_bin)
            beta = cp.Variable(n_train)
            zeta = cp.Variable(n_train)
            eps_m = cp.Variable()

            cons: List[cp.Constraint] = [eps_m >= 0, zeta >= 0, Y @ (K_np @ beta) >= eps_m - zeta]
            for mat, bound in quad_bounds:
                # The matrices come out of _psd_split, so CVXPY's own PSD
                # check is skipped with assume_PSD=True.
                cons.append(cp.quad_form(beta, cp.Constant(mat), assume_PSD=True) <= bound)

            prob = cp.Problem(cp.Minimize(-eps_m + cp.sum(zeta)), cons)
            prob.solve(solver=self.solver, verbose=False, **tolerances)
            if prob.status not in (cp.OPTIMAL, cp.OPTIMAL_INACCURATE):
                raise RuntimeError(f"Solver failed for class {c}: {prob.status}")
            self.beta_[c] = beta.value.astype(float)

        self.X_train_ = X.clone()
        return self

    # ------------------------------------------------------------------
    #  Inference
    # ------------------------------------------------------------------

    def _kernel_test(self, X: torch.Tensor) -> np.ndarray:
        """Return the combined kernel between the stored training points and ``X``.

        The training points are passed as the first argument so that the
        kernel helper sees the same source points as during ``fit``.
        """
        Ks_test, _ = product_kernel(self.pm, self.X_train_, X)
        K_test = torch.ones_like(Ks_test[0])
        for k_m, w in zip(Ks_test, self.weights):
            K_test += w * k_m
        return K_test.cpu().numpy()

    def decision_function(self, X: torch.Tensor):
        """Return per-class scores, one column per entry of ``classes_``."""
        K_test = self._kernel_test(X)
        return np.column_stack([K_test.T @ self.beta_[c] for c in self.classes_])

    def predict(self, X: torch.Tensor):
        """Return, for each point, the label with the highest decision score."""
        return np.take(self.classes_, np.argmax(self.decision_function(X), axis=1))

    def predict_proba(self, X: torch.Tensor):
        """Return a softmax over the decision scores (rows sum to one)."""
        scores = self.decision_function(X)
        # Subtract the row maximum before exponentiating for numerical stability.
        exp_s = np.exp(scores - scores.max(axis=1, keepdims=True))
        return exp_s / exp_s.sum(axis=1, keepdims=True)


In [13]:
# Build a product manifold with three components. The signature entries are
# presumably (curvature, dimension) pairs — TODO confirm against manify's docs.
pm = manify.manifolds.ProductManifold(signature=[(-1, 2), (0, 2), (1, 2)])
# Draw a synthetic labelled sample (100 is presumably the number of points — verify).
X, y = pm.gaussian_mixture(100)

# NOTE(review): this instantiates the packaged manify.predictors.svm.ProductSpaceSVM,
# not the ProductSpaceSVM class defined in the previous cell. The DCPError recorded
# as this cell's output is raised by this fit call.
svm = manify.predictors.svm.ProductSpaceSVM(pm=pm)
svm.fit(X, y)

DCPError: Problem does not follow DCP rules. Specifically:
The following constraints are not DCP:
QuadForm(var170, [[-0.00 -0.00 ... -0.00 -0.00]
 [-0.00 -0.00 ... -0.00 -0.00]
 ...
 [-0.00 -0.00 ... -0.00 -0.00]
 [-0.00 -0.00 ... -0.00 -0.00]]) <= 1e-05 , because the following subexpressions are not:
|--  QuadForm(var170, [[-0.00 -0.00 ... -0.00 -0.00]
 [-0.00 -0.00 ... -0.00 -0.00]
 ...
 [-0.00 -0.00 ... -0.00 -0.00]
 [-0.00 -0.00 ... -0.00 -0.00]])
QuadForm(var170, [[0.11 0.45 ... -0.45 0.60]
 [0.45 1.87 ... -1.71 2.52]
 ...
 [-0.45 -1.71 ... 3.22 -2.07]
 [0.60 2.52 ... -2.07 3.44]]) <= 1.0 , because the following subexpressions are not:
|--  QuadForm(var170, [[0.11 0.45 ... -0.45 0.60]
 [0.45 1.87 ... -1.71 2.52]
 ...
 [-0.45 -1.71 ... 3.22 -2.07]
 [0.60 2.52 ... -2.07 3.44]])