-
Notifications
You must be signed in to change notification settings - Fork 35
/
kmeans.py
72 lines (59 loc) · 2.58 KB
/
kmeans.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# -*- coding: utf-8 -*-
# @Date : 2020/5/21
# @Author: Luokun
# @Email : olooook@outlook.com
import random
import numpy as np
from matplotlib import pyplot as plt
from numpy import linalg as LA
class KMeans:
    """K-means clustering.

    Alternates between assigning every sample to its nearest center and moving
    every center to the mean of the samples assigned to it.
    """

    def __init__(self, k: int, iterations=100, eps=1e-3):
        """
        Args:
            k (int): Number of clusters.
            iterations (int, optional): Maximum number of iterations. Defaults to 100.
            eps (float, optional): Convergence threshold; fitting stops as soon as
                the largest coordinate change of any center is below this value.
                Defaults to 1e-3.
        """
        self.k, self.eps, self.iterations = k, eps, iterations
        self.centers = None  # cluster centers, one row per cluster; set by fit()

    def fit(self, X: np.ndarray):
        """Estimate the cluster centers from the samples in ``X``.

        Args:
            X (np.ndarray): Samples, one per row (2-D, samples x features).
                Any array-like accepted by ``np.asarray`` works.

        Raises:
            ValueError: If ``k`` is smaller than 1 or larger than the number of samples.
        """
        X = np.asarray(X)
        n_samples = len(X)
        if not 1 <= self.k <= n_samples:
            raise ValueError(
                f'k must be between 1 and the number of samples ({n_samples}), got {self.k}'
            )

        # Start from k distinct samples picked at random.
        self.centers = X[random.sample(range(n_samples), self.k)]
        for _ in range(self.iterations):  # stop at the latest after `iterations` rounds
            labels = self(X)  # reassign every sample to its nearest center
            means = []
            for i in range(self.k):
                members = X[labels == i]
                # An empty cluster has no mean; re-seed it with a random sample instead.
                means.append(members.mean(axis=0) if len(members) else random.choice(X))
            means = np.stack(means)
            # Converged: no center coordinate would move by eps or more.
            if np.max(np.abs(self.centers - means)) < self.eps:
                break
            self.centers = means

    def __call__(self, X: np.ndarray):
        """Assign every sample in ``X`` to its nearest center.

        Args:
            X (np.ndarray): Samples, one per row (2-D, samples x features).

        Returns:
            np.ndarray: 1-D array with the index of the closest center for each
            sample; ties go to the lowest index.

        Raises:
            RuntimeError: If the model has not been fitted yet.
        """
        if self.centers is None:
            raise RuntimeError('KMeans must be fitted before it can assign labels')
        X = np.asarray(X)
        centers = np.asarray(self.centers)
        # Broadcast to (samples, centers, features) so that all sample-to-center
        # Euclidean distances are computed in a single vectorised call.
        distances = LA.norm(X[:, None, :] - centers[None, :, :], axis=2)
        return np.argmin(distances, axis=1)
def load_data():
    """Generate three Gaussian blobs of 200 two-dimensional points each.

    The blobs are standard-normal samples shifted to (2, 2), (0, 0) and (2, -2).

    Returns:
        np.ndarray: Array of shape (3, 200, 2); the first axis indexes the blob.
    """
    # Draw order matters for reproducibility under a fixed NumPy seed.
    upper_blob = np.random.randn(200, 2) + np.array([2, 2])
    origin_blob = np.random.randn(200, 2)
    lower_blob = np.random.randn(200, 2) + np.array([2, -2])
    return np.stack([upper_blob, origin_blob, lower_blob])
if __name__ == '__main__':
    # Demo: cluster three synthetic blobs and plot the ground truth next to
    # the K-means assignment.
    palette = ['r', 'g', 'b']
    x = load_data()
    plt.figure(figsize=[12, 6])

    # Left panel: one colour per true blob.
    plt.subplot(1, 2, 1)
    plt.title('Truth')
    for blob, colour in zip(x, palette):
        plt.scatter(blob[:, 0], blob[:, 1], color=colour, marker='.')

    # Flatten the blobs into a single (n_samples, 2) array and cluster it.
    x = x.reshape(-1, 2)
    kmeans = KMeans(3)
    kmeans.fit(x)
    pred = kmeans(x)

    # Right panel: one colour per predicted cluster, centers drawn as stars.
    # Cluster indices are arbitrary, so colours need not match the left panel.
    plt.subplot(1, 2, 2)
    plt.title('Prediction')
    for label, colour in enumerate(palette):
        members = x[pred == label]
        plt.scatter(members[:, 0], members[:, 1], color=colour, marker='.')
    plt.scatter(kmeans.centers[:, 0], kmeans.centers[:, 1], color=palette, marker='*', s=100)
    plt.show()