# Example usage of rank_preserving_calibration

This notebook demonstrates how to use the `admm_rank_preserving_simplex_marginals` function from the `rank_preserving_calibration` package to adjust multiclass probabilities while preserving ranking and matching target class totals.

In [7]:
import numpy as np
from rank_preserving_calibration import admm_rank_preserving_simplex_marginals

# -----------------------------
# Demo on synthetic data
# -----------------------------


def softmax(logits: np.ndarray, axis: int = 1) -> np.ndarray:
    """Return the softmax of ``logits`` taken along ``axis``.

    The maximum along ``axis`` is subtracted before exponentiating, which
    leaves the result unchanged mathematically but keeps ``np.exp`` from
    overflowing on large inputs. The output has the same shape as ``logits``
    and sums to 1 along ``axis``.
    """
    # Shift so the largest entry in each slice becomes 0 (exp(0) == 1).
    peak = logits.max(axis=axis, keepdims=True)
    unnormalized = np.exp(logits - peak)
    # keepdims=True keeps the reduced axis so the division broadcasts.
    total = unnormalized.sum(axis=axis, keepdims=True)
    return unnormalized / total

# Seed NumPy's global RNG so the synthetic data (and the output shown below
# the cell) is reproducible from run to run.
np.random.seed(42)
# N rows (samples) and J columns (classes) in the probability matrix.
N, J = 150, 3

# Generate base probabilities P
# Standard-normal noise scaled by 0.7 plus a fixed per-column offset; the
# offset makes column 1 the most likely class on average and column 2 the least.
logits = np.random.randn(N, J) * 0.7 + np.array([0.0, 0.4, -0.2])  # induce a mild prior
# Row-wise softmax: P has shape (N, J) and each row sums to 1.
P = softmax(logits, axis=1)

# Target marginals M: shift mass to class 1 while keeping total N
# NOTE(review): "class 1" reads as 1-based here -- the largest target share
# (0.50) is on column 0, whereas the logit offset above favours column 1.
# Column sums of P, i.e. the class totals before adjustment.
# NOTE(review): `base` is not used anywhere in the visible part of this cell;
# confirm whether a later cell reads it.
base = P.sum(axis=0)
target_frac = np.array([0.50, 0.30, 0.20])  # desired share across classes
# Convert shares into per-class totals that add up to N.
M = target_frac * N
# The product above is already a float array, so this does not change the dtype.
M = M.astype(float)
# Ensure sums match N
assert np.isclose(M.sum(), N)

# Run ADMM solver
# The result is unpacked as (Q, info). Presumably Q is the adjusted (N, J)
# probability matrix and info holds solver diagnostics -- confirm against the
# package documentation.
Q, info = admm_rank_preserving_simplex_marginals(P, M, max_iters=3000, verbose=False)

# Bare expression on the last line of the cell so the notebook displays it.
Q, info

(array([[ 5.41877997e-01,  2.29787947e-01,  2.28337201e-01],
        [ 7.69876853e-01,  1.62522725e-01,  6.76005344e-02],
        [ 6.73100449e-01,  3.10881529e-01,  1.60181544e-02],
        [ 6.52501263e-01,  2.43732910e-01,  1.03765827e-01],
        [ 8.34798404e-01,  1.17291931e-01,  4.79096889e-02],
        [ 4.62046038e-01,  2.07144313e-01,  3.30810523e-01],
        [ 3.39957608e-01,  6.55364245e-02,  5.94505990e-01],
        [ 4.99554241e-01,  4.74359199e-01,  2.60865596e-02],
        [ 4.41301518e-01,  5.05677318e-01,  5.30211124e-02],
        [ 6.27687256e-01,  2.29787947e-01,  1.42524662e-01],
        [ 2.95272983e-01,  6.60043227e-01,  4.46830964e-02],
        [ 3.23941038e-01,  6.60041831e-01,  1.60181544e-02],
        [ 7.91728892e-01,  1.08551427e-01,  9.97184807e-02],
        [ 4.36651106e-01,  4.46418828e-01,  1.16930066e-01],
        [ 5.63427526e-01,  3.91889488e-01,  4.46830964e-02],
        [ 3.61474831e-01,  2.08273309e-01,  4.30252131e-01],
        [ 6.51083962e-01