In [None]:
import jax
import jax.numpy as jnp
from mmdagg import *

def categorical_sample(probs, key, num_samples):
    """Draw one-hot encoded samples from a categorical distribution.

    Args:
        probs: 1-D probability vector of length d (one entry per category).
        key: JAX PRNG key used for the draw.
        num_samples: number of samples to draw (``jax.random.choice`` samples
            with replacement by default).

    Returns:
        Array of shape (num_samples, d); row i is the one-hot encoding of the
        i-th sampled category index.
    """
    num_categories = len(probs)
    # Row j of the identity matrix is the one-hot vector for category j.
    one_hot_table = jnp.eye(num_categories)
    sampled = jax.random.choice(
        key,
        num_categories,
        shape=(num_samples,),
        p=probs,
    )
    return one_hot_table[sampled]

def main():
    """Run MMDAgg on two synthetic categorical samples and report MMD statistics.

    Draws 1000 one-hot samples from each of two 8-category distributions,
    calls ``mmdagg`` with ``return_dictionary=True``, prints the bandwidth and
    MMD statistic of every "Single test" entry in the returned dictionary, and
    prints a weighted sum of those MMDs. The weights are built with
    ``create_weights(number_bandwidths, "uniform") / number_kernels``, which is
    intended to mirror MMDAgg's own weighting.

    Raises:
        ValueError: if a distribution does not have ``d`` categories, or the
            number of single tests is not a multiple of the number of kernels.
        RuntimeError: if the returned dictionary has no "Single test" entries.
        KeyError: if a single-test entry has no "Kernel ..." key.
    """
    num_samples = 1000

    # Two discrete distributions over d categories.
    d = 8
    p_true = jnp.array([0.15, 0.1, 0.05, 0.1, 0.25, 0.1, 0.15, 0.1])
    p_fake = jnp.array([0.2, 0.05, 0.1, 0.05, 0.25, 0.05, 0.2, 0.1])
    if p_true.shape != (d,) or p_fake.shape != (d,):
        raise ValueError(f"both distributions must have exactly {d} categories")

    # Split the root key BEFORE sampling so X and Y get independent keys;
    # a JAX key must not be reused after it has been consumed by a draw.
    key = jax.random.PRNGKey(42)
    key_x, key_y = jax.random.split(key)
    X = categorical_sample(p_true, key_x, num_samples)  # (num_samples, d) one-hot
    Y = categorical_sample(p_fake, key_y, num_samples)  # (num_samples, d) one-hot

    # With return_dictionary=True, mmdagg returns (test output, details dict).
    # The first element is an array (calling .items() on it raises
    # AttributeError), so the dictionary must be unpacked second.
    _, results = mmdagg(X, Y, kernel="laplace_gaussian", return_dictionary=True)
    print("--- MMD stats for each kernel/bandwidth ---")

    # 1. Collect the MMD statistic and kernel name of every single test,
    #    in the dictionary's insertion order.
    mmd_values = []
    kernel_names = []
    for test_name, test_info in results.items():
        if not test_name.startswith("Single test"):
            continue
        kernel_name = next(
            (field for field in test_info if field.startswith("Kernel")), None
        )
        if kernel_name is None:
            raise KeyError(f"{test_name} has no 'Kernel ...' entry")
        kernel_names.append(kernel_name)
        bandwidth = float(test_info["Bandwidth"])
        mmd = float(test_info["MMD"])
        print(f"{test_name}:  {kernel_name},  bw={bandwidth:.3f},  MMD={mmd:.6f}")
        mmd_values.append(mmd)
    if not mmd_values:
        raise RuntimeError("mmdagg returned no 'Single test' entries; nothing to aggregate")
    mmd_values = jnp.array(mmd_values)

    # 2. Build the weights: per-kernel bandwidth weights divided by the
    #    number of kernels.
    # NOTE(review): "uniform" assumes mmdagg ran with its default weights
    # type; keep these in sync if that changes.
    number_kernels = len(set(kernel_names))  # e.g. 2 for laplace + gaussian
    if len(mmd_values) % number_kernels:
        raise ValueError(
            f"{len(mmd_values)} single tests cannot be split evenly across "
            f"{number_kernels} kernels"
        )
    number_bandwidths = len(mmd_values) // number_kernels
    weights = create_weights(number_bandwidths, "uniform") / number_kernels
    # Repeat the bandwidth weights once per kernel. This assumes the single
    # tests are ordered kernel-major (all bandwidths of one kernel, then the
    # next). TODO confirm. The order is irrelevant for uniform weights.
    weights = jnp.tile(weights, number_kernels)

    # 3. Weighted aggregation.
    mmd_loss = jnp.sum(weights * mmd_values)
    print("Aggregate MMD (weighted sum over all kernels and bandwidths):", float(mmd_loss))


if __name__ == "__main__":
    main()

E0716 15:24:31.504529   28488 buffer_comparator.cc:145] Difference at 16: 0, expected 513.414
E0716 15:24:31.504554   28488 buffer_comparator.cc:145] Difference at 17: 0, expected 516.857
E0716 15:24:31.504557   28488 buffer_comparator.cc:145] Difference at 18: 0, expected 506.085
E0716 15:24:31.504559   28488 buffer_comparator.cc:145] Difference at 19: 0, expected 500.791
E0716 15:24:31.504561   28488 buffer_comparator.cc:145] Difference at 20: 0, expected 500.779
E0716 15:24:31.504562   28488 buffer_comparator.cc:145] Difference at 21: 0, expected 511.443
E0716 15:24:31.504564   28488 buffer_comparator.cc:145] Difference at 22: 0, expected 512.05
E0716 15:24:31.504566   28488 buffer_comparator.cc:145] Difference at 23: 0, expected 516.21
E0716 15:24:31.504567   28488 buffer_comparator.cc:145] Difference at 24: 0, expected 504.591
E0716 15:24:31.504569   28488 buffer_comparator.cc:145] Difference at 25: 0, expected 510.732
2025-07-16 15:24:31.505822: E external/xla/xla/service/gpu/aut

--- MMD stats for each kernel/bandwidth ---


AttributeError: 'jaxlib.xla_extension.ArrayImpl' object has no attribute 'items'

: 