# In [1]:
import numpy as np
import matplotlib.pyplot as plt

# In [2]:
from sklearn.decomposition import KernelPCA
import numpy as np
import plotly.graph_objects as go
from scipy.spatial import KDTree, cKDTree
from dash import Dash, dcc, html, Input, Output
from sklearn.metrics.pairwise import euclidean_distances

# Sampling setup for the points on the sphere and on the ellipsoid.
# Seeding the global NumPy generator makes every run draw the same points.
np.random.seed(42)

num_points_per_cluster = 300  # samples drawn for each cluster
radius = 1                    # sphere radius

# Ellipsoid semi-axis lengths applied to the x, y and z coordinates.
a = 1.5
b = 1
c = 0.8

def generate_spherical_points(center, r, num_points, sphere_radius=None):
    """Sample points on a sphere inside an angular window around ``center``.

    The two spherical angles are drawn uniformly and independently:
    ``theta`` from ``[center[0] - r, center[0] + r]`` and ``phi`` from
    ``[center[1] - r, center[1] + r]``.  Each pair is then mapped to
    Cartesian coordinates::

        x = R * sin(theta) * cos(phi)
        y = R * sin(theta) * sin(phi)
        z = R * cos(theta)

    Args:
        center: ``(theta, phi)`` centre of the cluster, in radians.
        r: Half-width of the sampling window, applied to both angles.  This
            is an angular spread, not the radius of the sphere.
        num_points: Number of points to draw; values below one give an
            empty result.
        sphere_radius: Radius ``R`` of the sphere.  ``None`` (the default)
            falls back to the module-level ``radius``, which is the value
            this function used before the parameter existed.

    Returns:
        Array of shape ``(num_points, 3)`` with one ``(x, y, z)`` row per
        point.  An empty request yields shape ``(0, 3)`` so column slicing
        keeps working downstream.
    """
    if sphere_radius is None:
        sphere_radius = radius

    theta_center, phi_center = center[0], center[1]

    # One (theta, phi) pair per row.  The (n, 2) array is filled row by row,
    # so the global random stream is consumed in the same theta-then-phi
    # order as drawing the two angles point by point.
    angles = np.random.uniform(
        low=[theta_center - r, phi_center - r],
        high=[theta_center + r, phi_center + r],
        size=(max(num_points, 0), 2),
    )
    theta, phi = angles[:, 0], angles[:, 1]
    sin_theta = np.sin(theta)

    return np.column_stack((
        sphere_radius * sin_theta * np.cos(phi),
        sphere_radius * sin_theta * np.sin(phi),
        sphere_radius * np.cos(theta),
    ))

def convert_to_ellipsoid(points, semi_axes=None):
    """Stretch 3-D points axis by axis, mapping a unit sphere to an ellipsoid.

    Args:
        points: Array-like of shape ``(n, 3)`` holding ``(x, y, z)`` rows.
        semi_axes: Three scale factors for the x, y and z coordinates.
            ``None`` (the default) uses the module-level ``(a, b, c)``,
            which is what this function applied before the parameter
            existed.

    Returns:
        A new float array of shape ``(n, 3)`` with each column multiplied by
        its scale factor.  The input is not modified, and empty input gives
        shape ``(0, 3)``.

    Raises:
        ValueError: If ``semi_axes`` does not contain exactly three values
            or non-empty ``points`` is not of shape ``(n, 3)``.
    """
    if semi_axes is None:
        semi_axes = (a, b, c)

    scale = np.asarray(semi_axes, dtype=float)
    if scale.shape != (3,):
        raise ValueError("semi_axes must contain exactly three values")

    coords = np.asarray(points, dtype=float)
    if coords.size == 0:
        return np.empty((0, 3))
    # Reject other shapes explicitly: broadcasting would otherwise silently
    # accept e.g. an (n, 1) array and produce a meaningless result.
    if coords.ndim != 2 or coords.shape[1] != 3:
        raise ValueError("points must have shape (n, 3)")

    return coords * scale

# Generate points on the spiral-shaped manifold.
def generate_spiral_points(num_points, turns=3, radius=1, height=5):
    """Sample random points along a vertical spiral.

    For every point an angle ``t`` is drawn uniformly from
    ``[0, 2*pi*turns]`` and a distance from the axis ``r`` uniformly from
    ``[0, radius]``.  The point is placed at
    ``(r*cos(t), r*sin(t), t / (2*pi) * height)``, so ``z`` grows by
    ``height`` for each full turn.

    Args:
        num_points: Number of points to draw; values below one give an
            empty result.
        turns: Number of full revolutions covered by ``t``.
        radius: Largest distance from the z axis.
        height: Increase of ``z`` per full revolution.

    Returns:
        Array of shape ``(num_points, 3)`` with one ``(x, y, z)`` row per
        point; an empty request yields shape ``(0, 3)``.
    """
    # One (t, r) pair per row.  The (n, 2) array is filled row by row, so
    # the global random stream is consumed in the same t-then-r order as
    # drawing the two values point by point.
    draws = np.random.uniform(
        low=[0.0, 0.0],
        high=[2 * np.pi * turns, radius],
        size=(max(num_points, 0), 2),
    )
    t, r = draws[:, 0], draws[:, 1]

    return np.column_stack((
        r * np.cos(t),
        r * np.sin(t),
        t / (2 * np.pi) * height,
    ))

# Centres (theta, phi) of the four clusters: every pairing of the two angles
# pi/3 and 2*pi/3, with theta varying slowest.
_cluster_angles = (np.pi / 3, 2 * np.pi / 3)
centers = [(theta, phi) for theta in _cluster_angles for phi in _cluster_angles]

# Build the cluster samples; the same points are used for both the sphere
# and the ellipsoid.
colors = ['red', 'blue', 'green', 'orange']
all_indices = []
point_colors = []
cluster_samples = []

for cluster_id, cluster_center in enumerate(centers):
    samples = generate_spherical_points(cluster_center, 0.6, num_points_per_cluster)
    cluster_samples.append(samples)

    # Continue the running index sequence and tag each new point with the
    # colour of its cluster.
    first_index = len(all_indices)
    all_indices += range(first_index, first_index + len(samples))
    point_colors += [colors[cluster_id]] * len(samples)

# Stack the per-cluster arrays into one array, then derive the ellipsoid
# from it and draw an equally sized spiral sample.
all_points = np.vstack(cluster_samples)
ellipsoid_points = convert_to_ellipsoid(all_points)
spiral_points = generate_spiral_points(num_points_per_cluster * len(centers))

# One KD-tree per point set, so nearest neighbours can be looked up
# separately on the sphere, the ellipsoid and the spiral.
sphere_tree, ellipsoid_tree, spiral_tree = (
    cKDTree(cloud) for cloud in (all_points, ellipsoid_points, spiral_points)
)

def find_neighbors(points, tree, index, k=5, metric='euclidean', gamma=0.1):
    """Return the indices of the ``k`` points closest to ``points[index]``.

    Two notions of closeness are supported: plain Euclidean distance,
    answered by the KD-tree, and the KernelPCA-based distance produced by
    ``compute_diffusion_distances``.

    Args:
        points: Array of data points, shape ``(n_samples, n_features)``.
        tree: ``KDTree``/``cKDTree`` built on ``points``; only used for the
            Euclidean metric.
        index: Position of the query point in ``points``.  Negative values
            count from the end.
        k: Number of neighbours wanted.  At most ``n_samples - 1`` can be
            returned; a value below one gives an empty result.
        metric: ``'euclidean'`` or ``'diffusion'``.
        gamma: RBF kernel parameter forwarded to
            ``compute_diffusion_distances`` (diffusion metric only).

    Returns:
        1-D array of neighbour indices ordered from nearest to farthest.
        The query point itself is never included.

    Raises:
        ValueError: If ``metric`` is not one of the supported names.
        IndexError: If ``index`` is outside ``points``.
    """
    if metric not in ('euclidean', 'diffusion'):
        raise ValueError("Unsupported metric. Choose 'euclidean' or 'diffusion'.")

    n_points = len(points)
    query_point = points[index]  # fails early on an out-of-range index
    self_index = index + n_points if index < 0 else index

    # There are only n - 1 other points.  Asking the tree for more makes it
    # pad the answer with the out-of-range index n.
    k = min(k, n_points - 1)
    if k < 1:
        return np.empty(0, dtype=int)

    if metric == 'euclidean':
        # Ask for one extra hit because the query point is part of the tree.
        _, candidates = tree.query(query_point, k=k + 1)
    else:
        # NOTE: the whole pairwise distance matrix is recomputed on every call.
        dist_row = compute_diffusion_distances(points, gamma=gamma)[index]
        candidates = np.argsort(dist_row)

    # Drop the query point by identity rather than by position: with
    # duplicate points or tied distances it is not guaranteed to come first.
    candidates = np.asarray(candidates)
    return candidates[candidates != self_index][:k]

def compute_diffusion_distances(points, n_components=2, kernel='rbf', gamma=0.1):
    """Return pairwise distances between points in a KernelPCA embedding.

    The points are projected onto ``n_components`` kernel principal
    components and the Euclidean distance between every pair of projected
    points is returned.  This embedded-space distance is what the rest of
    the file refers to as the "diffusion distance".

    Args:
        points: Array of data points, shape ``(n_samples, n_features)``.
        n_components: Dimension of the KernelPCA embedding.
        kernel: Kernel name passed to ``KernelPCA`` (``'rbf'`` by default).
        gamma: Kernel coefficient passed to ``KernelPCA``.

    Returns:
        Square array of shape ``(n_samples, n_samples)`` whose entry
        ``[i, j]`` is the distance between the embedded points ``i`` and
        ``j``.
    """
    # Only the forward projection is used here, so the inverse transform is
    # not fitted; fitting it costs extra work per call and does not affect
    # the projected coordinates.
    kpca = KernelPCA(n_components=n_components, kernel=kernel, gamma=gamma)
    embedded = kpca.fit_transform(points)

    # Euclidean distances measured in the reduced space.
    return euclidean_distances(embedded)
# Dash application
app = Dash(__name__)

# Controls: index of the point to inspect, number of neighbours, and metric.
_index_input = dcc.Input(
    id='index-input',
    type='number',
    placeholder='Enter Point Index',
    debounce=True,
)

_neighbor_slider = dcc.Slider(
    id='neighbor-slider',
    min=1,
    max=10,
    step=1,
    value=5,
    marks={tick: str(tick) for tick in range(1, 11)},
    tooltip={'placement': 'bottom'},
)

_metric_dropdown = dcc.Dropdown(
    id='distance-metric',
    options=[
        {'label': 'Euclidean Distance', 'value': 'euclidean'},
        {'label': 'Diffusion Distance', 'value': 'diffusion'},
    ],
    value='euclidean',
    clearable=False,
)

# One panel per point set: a 3-D graph with the list of neighbours below it.
_panels = [
    html.Div(
        [
            dcc.Graph(id=f'{shape}-plot', style={'width': '30vw', 'height': '600px'}),
            html.Ul(id=f'{shape}-neighbors'),
        ],
        style={'flex-grow': '1', 'padding': '10px'},
    )
    for shape in ('sphere', 'ellipsoid', 'spiral')
]

# The controls are stacked on top; the three panels sit side by side.
app.layout = html.Div([
    _index_input,
    _neighbor_slider,
    _metric_dropdown,
    html.Div(
        _panels,
        style={'display': 'flex', 'justify-content': 'space-between', 'align-items': 'center'},
    ),
])

def _resolve_point_index(raw_index, n_points):
    """Return ``raw_index`` as a position in ``range(n_points)``, or ``None``.

    ``None``, booleans, non-numeric values, non-integral or non-finite
    numbers and positions outside ``[-n_points, n_points)`` all give
    ``None``.  Negative positions count from the end, like ordinary
    sequence indexing.
    """
    if raw_index is None or isinstance(raw_index, bool):
        return None
    try:
        position = int(raw_index)
    except (TypeError, ValueError, OverflowError):
        return None
    if position != raw_index or not -n_points <= position < n_points:
        return None
    return position % n_points


def _highlight_neighbors(label, points, tree, index, k, metric):
    """Return (marker colours, list items) highlighting a point's neighbours.

    Every marker is light gray except the query point and its neighbours,
    which keep their entry from ``point_colors``.
    """
    neighbors = find_neighbors(points, tree, index, k=k, metric=metric)
    marker_colors = ['lightgray'] * len(points)
    for idx in [index, *neighbors]:
        # The modulo only matters for a point set longer than point_colors.
        marker_colors[idx] = point_colors[idx % len(point_colors)]
    items = [html.Li(f'{label} Neighbor {i}: {n}') for i, n in enumerate(neighbors)]
    return marker_colors, items


def _scatter_figure(label, points, marker_colors):
    """Return a 3-D scatter figure of ``points``; hover text is the index."""
    scatter = go.Scatter3d(
        x=points[:, 0], y=points[:, 1], z=points[:, 2],
        mode='markers',
        marker=dict(size=6, color=marker_colors, opacity=1.0),
        text=[str(i) for i in range(len(points))],
        hoverinfo='text',
        name=f'{label} Points'
    )
    return go.Figure(data=[scatter])


@app.callback(
    [Output('sphere-plot', 'figure'), Output('sphere-neighbors', 'children'),
     Output('ellipsoid-plot', 'figure'), Output('ellipsoid-neighbors', 'children'),
     Output('spiral-plot', 'figure'), Output('spiral-neighbors', 'children')],
    [Input('neighbor-slider', 'value'), Input('index-input', 'value'), Input('distance-metric', 'value')]
)
def update_plots(neighbor_count, selected_index, metric):
    """Redraw the three plots and neighbour lists for the current controls.

    Without a usable point index every point is drawn in its cluster colour
    and the neighbour lists are empty.  With one, the selected point and its
    nearest neighbours keep their colour, all other points turn light gray,
    and the neighbours are listed below each plot.

    Args:
        neighbor_count: Number of neighbours to highlight (slider value).
        selected_index: Index typed by the user, or ``None``.  A value that
            is not a whole number inside the data set is treated like
            ``None`` instead of raising inside the callback.
        metric: ``'euclidean'`` or ``'diffusion'`` (dropdown value).

    Returns:
        Tuple ``(figure, list items)`` repeated for the sphere, the
        ellipsoid and the spiral, in the order of the callback outputs.
    """
    index = _resolve_point_index(selected_index, len(all_points))
    datasets = (
        ('Sphere', all_points, sphere_tree),
        ('Ellipsoid', ellipsoid_points, ellipsoid_tree),
        ('Spiral', spiral_points, spiral_tree),
    )

    outputs = []
    for label, points, tree in datasets:
        marker_colors = point_colors[:len(points)]
        items = []
        if index is not None and len(points):
            # The index is validated against the sphere data; wrap it so it
            # stays valid if another point set has a different size.
            marker_colors, items = _highlight_neighbors(
                label, points, tree, index % len(points), neighbor_count, metric
            )
        outputs += [_scatter_figure(label, points, marker_colors), items]

    return tuple(outputs)

if __name__ == '__main__':
    # Newer Dash releases start the development server with ``app.run`` and
    # no longer provide ``run_server``; older ones only have ``run_server``.
    # Prefer ``run`` and fall back, so the script starts on either.
    _run_app = getattr(app, 'run', None) or app.run_server
    _run_app(debug=True)