In [22]:
import numpy as np
from itertools import combinations

# ==== Your confusion matrices ====
#
# Two 10x10 matrices of integer counts. compute_group_metrics reads row i as
# the true class and column j as the predicted class, so the diagonal holds
# the correctly classified counts.

# NOTE(review): presumably measured on clean (unperturbed) inputs -- confirm
# against the evaluation run that produced these numbers.
clean_matrix = np.array([
    [793, 5, 23, 4, 18, 8, 18, 11, 97, 23],
    [4, 904, 4, 5, 1, 5, 5, 1, 21, 50],
    [55, 5, 550, 64, 81, 70, 115, 28, 17, 15],
    [26, 11, 72, 365, 48, 234, 154, 42, 18, 30],
    [16, 1, 80, 38, 568, 48, 133, 90, 18, 8],
    [14, 3, 52, 86, 47, 662, 65, 53, 9, 9],
    [10, 5, 34, 31, 28, 28, 842, 5, 10, 7],
    [18, 0, 20, 23, 42, 47, 19, 809, 8, 14],
    [37, 14, 5, 5, 5, 4, 5, 2, 900, 23],
    [20, 63, 6, 10, 7, 5, 7, 13, 33, 836]
])

# NOTE(review): presumably the same classifier evaluated on adversarial
# inputs (the report below labels it "Adversarial") -- confirm.
adv_matrix = np.array([
    [492, 39, 72, 12, 36, 18, 30, 32, 214, 55],
    [21, 611, 16, 19, 8, 21, 23, 14, 87, 180],
    [92, 22, 217, 65, 116, 117, 204, 87, 38, 42],
    [50, 36, 92, 96, 83, 206, 190, 120, 54, 73],
    [42, 21, 127, 41, 145, 99, 288, 161, 39, 37],
    [31, 19, 95, 104, 50, 367, 137, 108, 49, 40],
    [26, 25, 92, 59, 134, 83, 479, 56, 16, 30],
    [35, 11, 36, 48, 91, 108, 84, 500, 27, 60],
    [165, 58, 24, 17, 19, 22, 22, 10, 587, 76],
    [50, 189, 20, 18, 10, 21, 37, 56, 105, 494]
])

def compute_group_metrics(matrix, group1):
    """Score a two-way split of the classes against a confusion matrix.

    The classes are partitioned into ``group1`` and its complement
    (``group2``). The score multiplies the accuracy of the coarse binary
    decision "which group does the sample belong to" by the size-weighted
    mean of the per-class accuracies inside each group.

    Args:
        matrix: Square confusion matrix of counts (array-like). Row ``i`` is
            the true class and column ``j`` the predicted class. The number
            of classes is taken from the matrix shape.
        group1: Iterable of class indices forming the first group; every
            other class index goes to the second group.

    Returns:
        Tuple ``(score, root_binary_acc, acc_group1, acc_group2, w1, w2)``:
        ``root_binary_acc`` is the fraction of samples whose predicted class
        lies in the same group as the true class; ``acc_group1`` and
        ``acc_group2`` are the unweighted means of the per-class accuracies
        in each group (0.0 for an empty group); ``w1`` and ``w2`` are the
        fractions of classes in each group; and
        ``score = root_binary_acc * (w1 * acc_group1 + w2 * acc_group2)``.

    Raises:
        ValueError: If ``matrix`` is not a non-empty square 2-D array.
    """
    matrix = np.asarray(matrix)
    if matrix.ndim != 2 or matrix.shape[0] != matrix.shape[1] or matrix.shape[0] == 0:
        raise ValueError(
            f"matrix must be a non-empty square 2-D array, got shape {matrix.shape}"
        )

    num_classes = matrix.shape[0]
    group1 = set(group1)
    group2 = set(range(num_classes)) - group1

    # Per-class accuracy: diagonal count over the row total. A class with no
    # samples contributes 0 instead of dividing by zero.
    row_totals = matrix.sum(axis=1)
    acc = [
        matrix[i, i] / row_totals[i] if row_totals[i] > 0 else 0
        for i in range(num_classes)
    ]

    # Root coarse binary accuracy: a sample counts as correct when its true
    # and predicted classes fall on the same side of the split. Summing the
    # matching cells directly avoids building one list entry per sample.
    in_group1 = np.array([i in group1 for i in range(num_classes)])
    same_group = in_group1[:, None] == in_group1[None, :]
    total = matrix.sum()
    root_binary_acc = matrix[same_group].sum() / total if total > 0 else 0.0

    # Average per-class accuracy inside each group. An empty group has weight
    # 0 below, so report 0.0 rather than letting np.mean([]) inject NaN.
    acc_group1 = np.mean([acc[i] for i in group1]) if group1 else 0.0
    acc_group2 = np.mean([acc[i] for i in group2]) if group2 else 0.0

    # Combination weights: proportional to the number of classes per group.
    w1 = len(group1) / num_classes
    w2 = 1 - w1

    score = root_binary_acc * (w1 * acc_group1 + w2 * acc_group2)
    return score, root_binary_acc, acc_group1, acc_group2, w1, w2

results_clean = []
results_adv = []

# Try every choice of group1 containing 3, 4, 5 or 6 of the 10 classes and
# score it on both confusion matrices.
for group_size in (3, 4, 5, 6):
    for group1 in combinations(range(10), group_size):
        clean_metrics = compute_group_metrics(clean_matrix, group1)
        adv_metrics = compute_group_metrics(adv_matrix, group1)
        # Record layout: (group1, score, root_acc, acc1, acc2, w1, w2).
        # Both records carry the weights returned by the clean-matrix call.
        results_clean.append((group1, *clean_metrics))
        results_adv.append((group1, *adv_metrics[:4], *clean_metrics[4:]))

# Highest-scoring grouping for each matrix (the first one wins on ties).
best_clean = max(results_clean, key=lambda record: record[1])
best_adv = max(results_adv, key=lambda record: record[1])


def _print_best(title, record):
    """Print the title line followed by the two summary lines of one record."""
    group, score, root_acc, acc1, acc2, w1, w2 = record
    print(title)
    print(f"group1: {group} ({len(group)}+{10-len(group)})")
    print(f"score: {score:.4f} | root_acc: {root_acc:.4f} | acc1: {acc1:.4f} | acc2: {acc2:.4f} | w1: {w1:.2f} | w2: {w2:.2f}")


_print_best("==== Best Group (Clean) ====", best_clean)
_print_best("\n==== Best Group (Adversarial) ====", best_adv)


==== Best Group (Clean) ====
group1: (1, 8, 9) (3+7)
score: 0.6890 | root_acc: 0.9531 | acc1: 0.8800 | acc2: 0.6556 | w1: 0.30 | w2: 0.70

==== Best Group (Adversarial) ====
group1: (0, 1, 8, 9) (4+6)
score: 0.3393 | root_acc: 0.8508 | acc1: 0.5460 | acc2: 0.3007 | w1: 0.40 | w2: 0.60
