In [8]:
import numpy as np
import itertools
from scipy.optimize import linprog
from functools import lru_cache



# 误差计算函数（显式传入 target，避免引用全局变量）
def error(distribution, target):
    return np.linalg.norm(distribution - target)

# ========== 方法1：暴力枚举 ==========
def brute_force(nodes, target, S):
    best_set, best_err = None, float("inf")
    for comb in itertools.combinations(range(len(nodes)), S):
        avg_dist = nodes[list(comb)].mean(axis=0)
        err = error(avg_dist, target)
        if err < best_err:
            best_err = err
            best_set = comb
    return best_set, best_err

# ========== 方法2：贪心选择 ==========
def greedy(nodes, target, S):
    chosen = []
    remaining = list(range(len(nodes)))
    current_sum = np.zeros(nodes.shape[1])
    for _ in range(S):
        best_node, best_err = None, float("inf")
        for j in remaining:
            avg = (current_sum + nodes[j]) / (len(chosen) + 1)
            err = error(avg, target)
            if err < best_err:
                best_err = err
                best_node = j
        chosen.append(best_node)
        remaining.remove(best_node)
        current_sum += nodes[best_node]
    return chosen, error(current_sum / S, target)

# ========== 方法3：动态规划 ==========
def dynamic_programming(nodes, target, S):
    N, F = nodes.shape

    @lru_cache(None)
    def dp(i, s, mask):
        if s == 0:
            return np.zeros(F), 0
        if i == N:
            return None
        best_err = float("inf")
        best_sol = None

        # 不选第i个
        res = dp(i + 1, s, mask)
        if res is not None:
            vec, _ = res
            err = error(vec / s, target) if s > 0 else float("inf")
            if err < best_err:
                best_err, best_sol = err, (vec, mask)

        # 选第i个
        if s > 0:
            res = dp(i + 1, s - 1, mask | (1 << i))
            if res is not None:
                vec, _ = res
                vec = vec + nodes[i]
                err = error(vec / s, target)
                if err < best_err:
                    best_err, best_sol = err, (vec, mask | (1 << i))

        return best_sol

    sol = dp(0, S, 0)
    if sol is None:
        return [], float("inf")
    vec, mask = sol
    chosen = [i for i in range(N) if mask & (1 << i)]
    return chosen, error(vec / S, target)

# 新单元格：使用 PuLP 的 0/1 整数规划（最小化 L1 偏差）
def integer_programming(nodes, target, S, silent=True):
    try:
        import pulp
    except ImportError as e:
        raise ImportError("未找到 PuLP。请安装：pip install pulp 或 conda install -c conda-forge pulp") from e

    N, F = nodes.shape
    if not (0 < S <= N):
        return [], float("inf")

    prob = pulp.LpProblem("select_S_min_L1", pulp.LpMinimize)

    # x_i ∈ {0,1} 表示是否选择第 i 个节点
    x = pulp.LpVariable.dicts("x", range(N), lowBound=0, upBound=1, cat="Binary")
    # v_f ≥ 0 为 S * |(1/S)∑ x_i*nodes[i,f] - target[f]| 的线性化变量
    v = pulp.LpVariable.dicts("v", range(F), lowBound=0, cat="Continuous")

    # 目标：min ∑_f v_f （等价于 min ∑_f S*|avg_f - target[f]|，常数因子 S 可忽略）
    prob += pulp.lpSum(v[f] for f in range(F))

    # 选择恰好 S 个
    prob += pulp.lpSum(x[i] for i in range(N)) == S

    # 线性化 L1 约束：令 y_f = ∑ x_i * nodes[i,f]，则
    # y_f - S*target[f] ≤ v_f
    # -(y_f - S*target[f]) ≤ v_f
    for f in range(F):
        y_f = pulp.lpSum(nodes[i, f] * x[i] for i in range(N))
        prob += y_f - S * float(target[f]) <= v[f]
        prob += -(y_f - S * float(target[f])) <= v[f]

    solver = pulp.PULP_CBC_CMD(msg=(not silent))
    status = prob.solve(solver)

    if pulp.LpStatus[status] != "Optimal":
        # 回退：贪心
        return greedy(nodes, target, S)

    chosen = [i for i in range(N) if pulp.value(x[i]) >= 0.5]
    avg = nodes[chosen].mean(axis=0) if len(chosen) > 0 else np.zeros(F)
    return chosen, error(avg, target)

In [34]:
# ========== 数据构造 ==========
def generate_data(N, F, alpha=0.1):
    return np.random.dirichlet([alpha] * F, size=N)

N, F, S = 30, 5, 10  # 小规模方便暴力枚举 & 动规运行
alpha = 0.1
np.random.seed(42)

nodes = generate_data(N, F, alpha)
target = np.random.dirichlet([1.0] * F)
# ========== 测试 ==========
print("目标分布:", target)

目标分布: [0.61296198 0.12209785 0.2234295  0.03158455 0.00992612]


In [14]:
# 1. 暴力枚举（规模过大时跳过）
import math
comb_cnt = math.comb(N, S)
if comb_cnt > 2_000_000:
    print(f"\n方法1: 暴力枚举\n组合数 {comb_cnt:,} 过大，已跳过")
else:
    chosen, err = brute_force(nodes, target, S)
    print("\n方法1: 暴力枚举")
    print("误差:", err)


方法1: 暴力枚举
组合数 30,045,015 过大，已跳过


In [15]:
chosen, err = brute_force(nodes, target, S)
print("\n方法1: 暴力枚举")
print("误差:", err)


方法1: 暴力枚举
误差: 0.13217582792730395


In [35]:
# 2. 贪心
chosen, err = greedy(nodes, target, S)
print("\n方法2: 贪心选择")
print("误差:", err)


方法2: 贪心选择
误差: 0.13238864143049828


In [None]:
# 3. 动态规划
chosen, err = dynamic_programming(nodes, target, S)
print("\n方法3: 动态规划")
print("误差:", err)


方法3: 动态规划
误差: 0.14660495376944568


In [36]:
# 4. 整数规划
chosen, err = integer_programming(nodes, target, S)
print("\n方法4b: 整数规划")
print("误差:", err)


方法4b: 整数规划
误差: 0.1324293348366095


In [37]:
# 新单元格：四种算法的时间复杂度基准测试
import time, math
import numpy as np

def _median(xs):
    xs = sorted(xs)
    return xs[len(xs)//2]

def _bench_once(fn, nodes, target, S, repeat=3):
    # 预热
    fn(nodes, target, S)
    ts = []
    for _ in range(repeat):
        t0 = time.perf_counter()
        fn(nodes, target, S)
        ts.append(time.perf_counter() - t0)
    return _median(ts)

def _safe_comb(n, k):
    try:
        return math.comb(n, k)
    except Exception:
        # 粗略上界
        return float('inf')

def _fit_loglog(xs, ts):
    xs_fit, ts_fit = [], []
    for x, t in zip(xs, ts):
        if t is not None and t > 0 and x > 0:
            xs_fit.append(x); ts_fit.append(t)
    if len(xs_fit) >= 2:
        coef = np.polyfit(np.log(xs_fit), np.log(ts_fit), 1)
        slope = coef[0]
        return slope
    return None

def _print_table(header, rows):
    colw = [max(len(str(h)), max(len(str(r[i])) for r in rows)) for i, h in enumerate(header)]
    fmt = "  ".join("{:<" + str(w) + "}" for w in colw)
    print(fmt.format(*header))
    for r in rows:
        print(fmt.format(*r))

def bench_scale_N(F=5, S=10, N_list=(10,12,14,16,18,20,22,24,26,28,30), seed=42, repeat=3):
    print("\n=== 随 N 变化（固定 F={}, S={}）===".format(F, S))
    header = ["N", "brute_force(ms)", "greedy(ms)", "dp(ms)", "ip(ms)"]
    rows = []
    tN_bf, tN_gr, tN_dp, tN_ip = [], [], [], []
    for N in N_list:
        if S > N:
            continue
        np.random.seed(seed)
        nodes = generate_data(N, F, alpha=0.1)
        target = np.random.dirichlet([1.0] * F)

        # 暴力：组合数过大时跳过
        comb_cnt = _safe_comb(N, S)
        if comb_cnt <= 2_000_000:
            try:
                t = _bench_once(brute_force, nodes, target, S, repeat)
                bf_ms = f"{t*1000:.2f}"
                tN_bf.append((N, t))
            except Exception as e:
                bf_ms = "err"
        else:
            bf_ms = "skip"
        # 贪心
        try:
            t = _bench_once(greedy, nodes, target, S, repeat)
            gr_ms = f"{t*1000:.2f}"; tN_gr.append((N, t))
        except Exception as e:
            gr_ms = "err"
        # 动规：N 大时指数膨胀，设置保守上限
        if N <= 22:
            try:
                t = _bench_once(dynamic_programming, nodes, target, S, max(1, repeat//2))
                dp_ms = f"{t*1000:.2f}"; tN_dp.append((N, t))
            except Exception as e:
                dp_ms = "err"
        else:
            dp_ms = "skip"
        # 整数规划（PuLP）：规模适中一般很快
        try:
            t = _bench_once(integer_programming, nodes, target, S, repeat)
            ip_ms = f"{t*1000:.2f}"; tN_ip.append((N, t))
        except Exception as e:
            ip_ms = "err"
        rows.append([N, bf_ms, gr_ms, dp_ms, ip_ms])

    _print_table(header, rows)

    # 估计斜率（log-log）
    for name, data in [("greedy vs N", tN_gr), ("ip vs N", tN_ip)]:
        xs, ts = zip(*data) if data else ([], [])
        slope = _fit_loglog(xs, ts) if xs else None
        if slope is not None:
            print(f"估计增长阶 {name}: ~ N^{slope:.2f}")
        else:
            print(f"估计增长阶 {name}: 数据不足")

def bench_scale_S(N=30, F=5, S_list=(2,4,6,8,10,12), seed=42, repeat=3):
    print("\n=== 随 S 变化（固定 N={}, F={}）===".format(N, F))
    header = ["S", "brute_force(ms)", "greedy(ms)", "dp(ms)", "ip(ms)"]
    rows = []
    tS_gr, tS_ip = [], []
    for S in S_list:
        if S <= 0 or S > N:
            continue
        np.random.seed(seed)
        nodes = generate_data(N, F, alpha=0.1)
        target = np.random.dirichlet([1.0] * F)

        comb_cnt = _safe_comb(N, S)
        if comb_cnt <= 2_000_000:
            try:
                t = _bench_once(brute_force, nodes, target, S, repeat)
                bf_ms = f"{t*1000:.2f}"
            except Exception:
                bf_ms = "err"
        else:
            bf_ms = "skip"
        # 贪心
        try:
            t = _bench_once(greedy, nodes, target, S, repeat)
            gr_ms = f"{t*1000:.2f}"; tS_gr.append((S, t))
        except Exception:
            gr_ms = "err"
        # 动规：N 固定，S 增大时也会加重；仍限制为 N<=22 才跑
        if N <= 22:
            try:
                t = _bench_once(dynamic_programming, nodes, target, S, max(1, repeat//2))
                dp_ms = f"{t*1000:.2f}"
            except Exception:
                dp_ms = "err"
        else:
            dp_ms = "skip"
        # 整数规划
        try:
            t = _bench_once(integer_programming, nodes, target, S, repeat)
            ip_ms = f"{t*1000:.2f}"; tS_ip.append((S, t))
        except Exception:
            ip_ms = "err"

        rows.append([S, bf_ms, gr_ms, dp_ms, ip_ms])

    _print_table(header, rows)
    for name, data in [("greedy vs S", tS_gr), ("ip vs S", tS_ip)]:
        xs, ts = zip(*data) if data else ([], [])
        slope = _fit_loglog(xs, ts) if xs else None
        if slope is not None:
            print(f"估计增长阶 {name}: ~ S^{slope:.2f}")
        else:
            print(f"估计增长阶 {name}: 数据不足")

def bench_scale_F(N=30, S=10, F_list=(3,4,5,6,8,10), seed=42, repeat=3):
    print("\n=== 随 F 变化（固定 N={}, S={}）===".format(N, S))
    header = ["F", "brute_force(ms)", "greedy(ms)", "dp(ms)", "ip(ms)"]
    rows = []
    tF_gr, tF_ip = [], []
    for F in F_list:
        np.random.seed(seed)
        nodes = generate_data(N, F, alpha=0.1)
        target = np.random.dirichlet([1.0] * F)

        # 暴力与 DP 的主要爆炸在 N/S 维度，F 只影响常数因子；这里直接跳过暴力避免重复开销
        bf_ms = "skip"

        # 贪心
        try:
            t = _bench_once(greedy, nodes, target, S, repeat)
            gr_ms = f"{t*1000:.2f}"; tF_gr.append((F, t))
        except Exception:
            gr_ms = "err"

        # 动规：仅在 N 较小（≤22）时测
        if N <= 22:
            try:
                t = _bench_once(dynamic_programming, nodes, target, S, max(1, repeat//2))
                dp_ms = f"{t*1000:.2f}"
            except Exception:
                dp_ms = "err"
        else:
            dp_ms = "skip"

        # 整数规划
        try:
            t = _bench_once(integer_programming, nodes, target, S, repeat)
            ip_ms = f"{t*1000:.2f}"; tF_ip.append((F, t))
        except Exception:
            ip_ms = "err"

        rows.append([F, bf_ms, gr_ms, dp_ms, ip_ms])

    _print_table(header, rows)
    for name, data in [("greedy vs F", tF_gr), ("ip vs F", tF_ip)]:
        xs, ts = zip(*data) if data else ([], [])
        slope = _fit_loglog(xs, ts) if xs else None
        if slope is not None:
            print(f"估计增长阶 {name}: ~ F^{slope:.2f}")
        else:
            print(f"估计增长阶 {name}: 数据不足")

# 运行三组基准
bench_scale_N(F=5, S=10, N_list=(10,12,14,16,18,20,22,24,26,28,30), repeat=3)
bench_scale_S(N=30, F=5, S_list=(2,4,6,8,10,12), repeat=3)
bench_scale_F(N=30, S=10, F_list=(3,4,5,6,8,10), repeat=3)


=== 随 N 变化（固定 F=5, S=10）===
N   brute_force(ms)  greedy(ms)  dp(ms)    ip(ms)
10  0.04             0.32        1.12      182.23
12  1.17             0.44        8.19      50.29 
14  17.64            0.55        40.40     65.20 
16  233.07           1.35        347.37    51.59 
18  782.39           0.76        1099.28   57.16 
20  3179.62          0.81        4066.21   86.57 
22  11244.01         0.97        13508.41  422.15
24  33704.89         1.10        skip      343.67
26  skip             1.20        skip      225.93
28  skip             1.33        skip      117.23
30  skip             1.55        skip      123.64
估计增长阶 greedy vs N: ~ N^1.27
估计增长阶 ip vs N: ~ N^0.87

=== 随 S 变化（固定 N=30, F=5）===
S   brute_force(ms)  greedy(ms)  dp(ms)  ip(ms)
2   7.28             0.34        skip    120.06
4   453.57           0.64        skip    116.01
6   9991.59          0.94        skip    115.46
8   skip             1.21        skip    123.81
10  skip             1.44        skip    122.82
12