In [20]:
from sklearn.cluster import DBSCAN
import numpy as np
import random
from typing import List, Dict, Callable
from collections import defaultdict
from easydict import EasyDict
import math
import torch
from math import atan2, sin
from tqdm import tqdm


class Segment(EasyDict):
    def __init__(self, segment_id, points, emb):
        super().__init__()
        self.id = segment_id
        self.points = points
        self.emb = emb
        self.neighbor_count = 0
        self.local_neighbor = []


class Cluster:
    def __init__(self, segments):
        self.items = segments
        self.size = len(segments)
        self.centroid = self._calculate_centroid()
        self.radius = self._calculate_radius()
        self.merged = False

    def _get_segment_midpoint(self, seg):
        start, end = seg.points[0], seg.points[-1]
        mid_x = (start[0] + end[0]) / 2
        mid_y = (start[1] + end[1]) / 2
        return (mid_x, mid_y)

    def _calculate_centroid(self):
        total_x = 0
        total_y = 0
        for seg in self.items:
            mid_x, mid_y = self._get_segment_midpoint(seg)
            total_x += mid_x
            total_y += mid_y
        centroid_x = total_x / self.size
        centroid_y = total_y / self.size
        return (centroid_x, centroid_y)

    def _calculate_radius(self):
        max_distance = 0
        for seg in self.items:
            midpoint = self._get_segment_midpoint(seg)
            distance = self.compute_point_distance(midpoint, self.centroid)
            if distance > max_distance:
                max_distance = distance
        return max_distance

    @staticmethod
    def compute_point_distance(p1, p2):
        return math.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2)


def compute_angular_distance(seg1, seg2):
    start1, end1 = seg1.points[0], seg1.points[-1]
    start2, end2 = seg2.points[0], seg2.points[-1]
    vector1 = (end1[0] - start1[0], end1[1] - start1[1])
    vector2 = (end2[0] - start2[0], end2[1] - start2[1])
    angle1 = atan2(vector1[1], vector1[0])
    angle2 = atan2(vector2[1], vector2[0])
    angle_diff = abs(angle1 - angle2)
    if angle_diff > np.pi:
        angle_diff = 2 * np.pi - angle_diff
    len1 = Cluster.compute_point_distance(start1, end1)
    len2 = Cluster.compute_point_distance(start2, end2)
    return abs(sin(angle_diff)) * max(len1, len2)


def compute_vector_distance(r1, r2):
    r1 = torch.tensor(r1) if not isinstance(r1, torch.Tensor) else r1
    r2 = torch.tensor(r2) if not isinstance(r2, torch.Tensor) else r2
    sum_square = torch.sum((r1 - r2) ** 2)
    return torch.sqrt(sum_square).item()


def calculate_distance(seg1, seg2, alpha, beta, gamma):
    d1 = Cluster.compute_point_distance(
        seg1.points[0], seg2.points[0]
    ) + Cluster.compute_point_distance(seg1.points[-1], seg2.points[-1])
    # d2 = compute_angular_distance(seg1, seg2)
    d3 = compute_vector_distance(seg1.emb, seg2.emb)
    return alpha * d1 + gamma * d3


def dbscan(
    trajs: List[Segment], eps: float, min_pts: int, distance_func: Callable
) -> Dict:
    if not trajs:
        return {}

    # 创建 tqdm 进度条
    num_segments = len(trajs)
    total_iterations = num_segments * (num_segments - 1) // 2
    progress_bar = tqdm(total=total_iterations, desc="计算距离矩阵")
    count = 0
    distance_matrix = np.zeros((num_segments, num_segments))
    for i in range(num_segments):
        for j in range(i + 1, num_segments):
            dist = distance_func(trajs[i], trajs[j])
            distance_matrix[i, j] = dist
            distance_matrix[j, i] = dist
            # 更新进度条
            count += 1
            progress_bar.update(1)

    db = DBSCAN(eps=eps, min_samples=min_pts, metric="precomputed")
    labels = db.fit_predict(distance_matrix)

    result = {}
    for seg in trajs:
        seg.neighbor_count = 0
        seg.local_neighbor = []

    # 根据距离确定邻居点
    for i, seg in enumerate(trajs):
        # 只考虑 j > i 的情况，避免重复计算
        for j in range(i + 1, len(trajs)):
            other_seg = trajs[j]
            if distance_func(seg, other_seg) < eps:
                # 两个线段互为邻居，分别增加邻居计数
                seg.neighbor_count += 1
                other_seg.neighbor_count += 1
                # 分别将对方添加到自己的邻居列表中
                seg.local_neighbor.append(other_seg)
                other_seg.local_neighbor.append(seg)

    for seg, label in zip(trajs, labels):
        result[seg.id] = (seg.neighbor_count, label)

    return result


def random_select_traj(lres_i: Dict, min_pts: int, num: int) -> List:
    candidates = [
        seg_id for seg_id, (count, _) in lres_i.items() if count >= min_pts - 1
    ]
    return random.sample(candidates, min(num, len(candidates))) if candidates else []


def update(
    traj: Segment,
    c: int,
    lres: Dict,
    all_local_segments: List[Segment],
    min_pts: int,
    eps: float,
    distance_func: Callable,
    updated_segments: set = None,
):
    if updated_segments is None:
        updated_segments = set()
    if traj.id in updated_segments:
        return
    updated_segments.add(traj.id)

    # 更新轨迹段所属簇编号
    lres[traj.id] = (lres[traj.id][0], c)

    # 更新簇内其他成员
    if lres[traj.id][1] is not None:
        same_cluster_segs = [
            seg_id for seg_id, (_, label) in lres.items() if label == lres[traj.id][1]
        ]
        for seg_id in same_cluster_segs:
            lres[seg_id] = (lres[seg_id][0], c)

    # 使用本地邻居信息进行加速
    for other_seg in traj.local_neighbor:
        if other_seg.neighbor_count >= min_pts:
            update(
                other_seg,
                c,
                lres,
                all_local_segments,
                min_pts,
                eps,
                distance_func,
                updated_segments,
            )


def federated_clustering(
    fed_trajs: List[List[Segment]],
    eps: float,
    min_pts: int,
    k: int,
    num: int,
    alpha: float,
    beta: float,
    gamma: float,
) -> List[Cluster]:
    from functools import partial

    distance_func = partial(calculate_distance, alpha=alpha, beta=beta, gamma=gamma)
    local_results = [
        dbscan(trajs, eps, min_pts, distance_func=distance_func) for trajs in fed_trajs
    ]

    no_merge_count = 0  # 初始化连续未合并的轮数为 0

    while no_merge_count < k:  # 当连续未合并的轮数小于 k 时继续迭代
        merged = False
        selected_segments = []
        for lres in local_results:
            if lres:
                selected_segments.append(random_select_traj(lres, min_pts, num))
            else:
                selected_segments.append([])

        merged_index = []
        for i in range(len(fed_trajs)):
            for j in range(i + 1, len(fed_trajs)):
                for p in range(len(selected_segments[i])):
                    for q in range(len(selected_segments[j])):
                        seg_i_id, seg_j_id = (
                            selected_segments[i][p],
                            selected_segments[j][q],
                        )
                        try:
                            seg_i = next(
                                seg for seg in fed_trajs[i] if seg.id == seg_i_id
                            )
                            seg_j = next(
                                seg for seg in fed_trajs[j] if seg.id == seg_j_id
                            )
                        except StopIteration:
                            print(f"no such id {i, j}")
                            continue
                        # 获取两个轨迹段的簇标签，跳过已经同簇了的轨迹段
                        label_i = local_results[i][seg_i_id][1]
                        label_j = local_results[j][seg_j_id][1]
                        if label_i == label_j:
                            continue

                        distance = distance_func(seg_i, seg_j)

                        # 判断簇标签不同且距离小于阈值
                        if distance < eps:
                            # 更新邻居数
                            seg_i.neighbor_count += 1
                            seg_j.neighbor_count += 1
                            local_results[i][seg_i_id] = (
                                seg_i.neighbor_count,
                                local_results[i][seg_i_id][1],
                            )
                            local_results[j][seg_j_id] = (
                                seg_j.neighbor_count,
                                local_results[j][seg_j_id][1],
                            )

                            # 根据三角关系更新两个轨迹段各自的本地邻居ntraj
                            for ntraj in seg_i.local_neighbor:
                                if distance_func(ntraj, seg_j) < eps:
                                    ntraj.neighbor_count += 1
                                    local_results[i][ntraj.id] = (
                                        ntraj.neighbor_count,
                                        local_results[i][ntraj.id][1],
                                    )

                            for ntraj in seg_j.local_neighbor:
                                if distance_func(ntraj, seg_i) < eps:
                                    ntraj.neighbor_count += 1
                                    local_results[j][ntraj.id] = (
                                        ntraj.neighbor_count,
                                        local_results[j][ntraj.id][1],
                                    )

                            # 递归更新各自的本地邻居的聚类标签
                            new_cluster_id = (
                                local_results[i][seg_i_id][1]
                                if local_results[i][seg_i_id][1] != -1
                                else local_results[j][seg_j_id][1]
                            )
                            update(
                                seg_i,
                                new_cluster_id,
                                local_results[i],
                                fed_trajs[i],
                                min_pts,
                                eps,
                                distance_func,
                            )
                            update(
                                seg_j,
                                new_cluster_id,
                                local_results[j],
                                fed_trajs[j],
                                min_pts,
                                eps,
                                distance_func,
                            )
                            merged = True
                            merged_index.append(((i, p), (j, p)))
        if merged:
            no_merge_count = 0  # 发生合并，重置计数器
        else:
            no_merge_count += 1  # 未发生合并，计数器加 1
        print(no_merge_count, merged_index)

    cluster_mapping = defaultdict(list)
    for i, lres in enumerate(local_results):
        for seg_id, (_, cluster_id) in lres.items():
            if cluster_id != -1:
                try:
                    seg = next(seg for seg in fed_trajs[i] if seg.id == seg_id)
                    cluster_mapping[cluster_id].append(seg)
                except StopIteration:
                    continue

    return [Cluster(segs) for segs in cluster_mapping.values()]

In [21]:
# 生成模拟数据
def generate_synthetic_data(num_clients, num_segments_per_client, embedding_dim):
    fed_trajs = []
    segment_id_counter = 0
    for client in range(num_clients):
        client_trajs = []
        for _ in range(num_segments_per_client):
            # 随机生成轨迹段的起点和终点
            start_point = (random.random(), random.random())
            end_point = (random.random(), random.random())
            points = [start_point, end_point]
            # 随机生成嵌入向量
            emb = torch.tensor([random.random() for _ in range(embedding_dim)])
            segment = Segment(segment_id_counter, points, emb)
            client_trajs.append(segment)
            segment_id_counter += 1
        fed_trajs.append(client_trajs)
    return fed_trajs


# 设置参数
num_clients = 3
num_segments_per_client = 100
embedding_dim = 5
eps = 1.0
min_pts = 2
k = 5
num = 3
alpha = 0.5
beta = 0.
gamma = 0.5

# 生成模拟数据
fed_trajs = generate_synthetic_data(num_clients, num_segments_per_client, embedding_dim)

# 进行联邦聚类
clusters = federated_clustering(fed_trajs, eps, min_pts, k, num, alpha, beta, gamma)

# 打印聚类结果
for i, cluster in enumerate(clusters):
    print(f"Cluster {i}:")
    print(f"  Size: {cluster.size}")
    print(f"  Centroid: {cluster.centroid}")
    print(f"  Radius: {cluster.radius}")
    print("  Segment IDs:", [seg.id for seg in cluster.items])
    print()

计算距离矩阵:   0%|          | 0/4950 [00:00<?, ?it/s]

计算距离矩阵: 100%|██████████| 4950/4950 [00:00<00:00, 22380.57it/s]
计算距离矩阵: 100%|██████████| 4950/4950 [00:00<00:00, 22761.99it/s]
计算距离矩阵: 100%|██████████| 4950/4950 [00:00<00:00, 23090.76it/s]

1 []
2 []
3 []
4 []
5 []
Cluster 0:
  Size: 300
  Centroid: (0.4878262172993846, 0.4866084334946415)
  Radius: 0.589437510464947
  Segment IDs: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 


