k_means算法的简单实现

In [29]:
import numpy as np
"""X是所有的训练数据"""
"""k是预计分类的个数"""
"""maxIt是中心点变化的最多次数"""
def kmeans(X,k,maxIt):
    """Cluster the rows of X into k groups with the iterative k-means scheme.

    Args:
        X: 2-D array with one sample per row and one feature per column.
        k: Number of clusters; must satisfy 1 <= k <= number of rows in X.
        maxIt: Iteration limit, forwarded unchanged to shouldStop.

    Returns:
        A float array of shape (numPoints, numDim + 1): the samples of X
        with their cluster label (1..k) appended as the last column.

    Raises:
        ValueError: If k is smaller than 1 or larger than the number of
            samples, since k distinct starting centroids cannot be drawn.
    """
    numPoints, numDim = X.shape
    if k < 1 or k > numPoints:
        raise ValueError(
            "k must be between 1 and the number of samples (%d), got %r"
            % (numPoints, k)
        )

    # Working copy of the data with one extra trailing column for the label.
    dataSet = np.zeros((numPoints, numDim + 1))
    dataSet[:, :-1] = X

    # Draw k *distinct* rows as starting centroids. Sampling with
    # replacement could pick the same row twice, which leaves one cluster
    # empty right after the first assignment step.
    seedRows = np.random.choice(numPoints, size=k, replace=False)
    centroids = dataSet[seedRows, :]  # fancy indexing copies the rows
    centroids[:, -1] = range(1, k + 1)

    iterations = 0
    oldcentroids = None

    while not shouldStop(oldcentroids, centroids, iterations, maxIt):
        print("iterations:",str(iterations))
        print("dataSet:\n",dataSet)
        print("centroids:\n",centroids)

        # Remember the current centroids so the next check can detect
        # whether they moved.
        oldcentroids = np.copy(centroids)
        iterations += 1

        # Assignment step followed by the centroid update step.
        updateLabels(dataSet, centroids)
        centroids = getCentriods(dataSet, k)

    return dataSet


In [32]:
def shouldStop(oldcentroids,centroids,iterations,maxIt):
    """Decide whether the k-means loop should terminate.

    Args:
        oldcentroids: Centroid array from the previous iteration, or None
            before the first iteration has run.
        centroids: Current centroid array.
        iterations: Number of centroid updates performed so far.
        maxIt: Maximum number of centroid updates allowed.

    Returns:
        True when the update budget is used up or the centroids did not
        change since the previous iteration; False otherwise.
    """
    # `iterations` counts updates that already happened, so the budget is
    # exhausted as soon as it reaches maxIt. Using `>` here would allow
    # maxIt + 1 updates.
    if iterations >= maxIt:
        return True
    # No previous centroids yet: nothing to compare against, keep going.
    if oldcentroids is None:
        return False
    return np.array_equal(oldcentroids, centroids)

def updateLabels(dataSet,centroids):
    """Relabel every sample, in place, with the label of its nearest centroid.

    The last column of dataSet holds the label and is overwritten; all
    other columns are treated as the sample's coordinates.
    """
    rowCount, _ = dataSet.shape
    for rowIdx in range(rowCount):
        coords = dataSet[rowIdx, :-1]
        dataSet[rowIdx, -1] = getLabelFromClosestCentroid(coords, centroids)
        
def getLabelFromClosestCentroid(dataSetRow,centroids):
    """Return the label of the centroid closest (Euclidean) to dataSetRow.

    Each row of centroids holds the coordinates followed by the label in
    the last column. On a distance tie the earlier centroid wins. The
    winning distance is printed as a trace line.
    """
    # Start with the first centroid as the best candidate so far.
    bestLabel = centroids[0, -1]
    bestDist = np.linalg.norm(dataSetRow - centroids[0, :-1])

    # Challenge it with each remaining centroid; only a strictly smaller
    # distance replaces the current best.
    for candidate in centroids[1:]:
        candidateDist = np.linalg.norm(dataSetRow - candidate[:-1])
        if candidateDist < bestDist:
            bestLabel, bestDist = candidate[-1], candidateDist

    print("minDist:",bestDist)
    return bestLabel

def getCentriods(dataSet,k):
    """Compute the new centroid of each of the k clusters.

    Args:
        dataSet: 2-D array whose last column holds each sample's cluster
            label (1..k) and whose other columns hold its coordinates.
        k: Number of clusters.

    Returns:
        Array of shape (k, dataSet.shape[1]). Row i-1 holds the mean
        coordinates of the samples labelled i, followed by the label i.
        A cluster with no samples is re-seeded at the coordinates of a
        randomly chosen sample; if dataSet has no rows at all, the
        coordinates are left at zero.
    """
    numPoints = dataSet.shape[0]
    result = np.zeros((k, dataSet.shape[1]))
    for label in range(1, k + 1):
        members = dataSet[dataSet[:, -1] == label, :-1]
        if members.shape[0] > 0:
            result[label - 1, :-1] = members.mean(axis=0)
        elif numPoints > 0:
            # Averaging an empty cluster would yield NaN, which corrupts
            # every later distance comparison and keeps the convergence
            # check from ever succeeding. Restart the cluster at a random
            # sample instead.
            result[label - 1, :-1] = dataSet[np.random.randint(numPoints), :-1]
        result[label - 1, -1] = label
    return result

In [33]:
# Demo: four 2-D points forming two obvious groups, clustered with k = 2
# and an iteration limit of 10.
x1, x2, x3, x4 = (np.array(point) for point in ((1, 1), (2, 1), (4, 3), (5, 4)))
testX = np.vstack([x1, x2, x3, x4])
result = kmeans(X=testX, k=2, maxIt=10)
print("final result:")
print(result)

iterations: 0
dataSet:
 [[1. 1. 0.]
 [2. 1. 0.]
 [4. 3. 0.]
 [5. 4. 0.]]
centroids:
 [[1. 1. 1.]
 [2. 1. 2.]]
minDist: 0.0
minDist: 0.0
minDist: 2.8284271247461903
minDist: 4.242640687119285
iterations: 1
dataSet:
 [[1. 1. 1.]
 [2. 1. 2.]
 [4. 3. 2.]
 [5. 4. 2.]]
centroids:
 [[1.         1.         1.        ]
 [3.66666667 2.66666667 2.        ]]
minDist: 0.0
minDist: 1.0
minDist: 0.4714045207910319
minDist: 1.885618083164127
iterations: 2
dataSet:
 [[1. 1. 1.]
 [2. 1. 1.]
 [4. 3. 2.]
 [5. 4. 2.]]
centroids:
 [[1.5 1.  1. ]
 [4.5 3.5 2. ]]
minDist: 0.5
minDist: 0.5
minDist: 0.7071067811865476
minDist: 0.7071067811865476
final result:
[[1. 1. 1.]
 [2. 1. 1.]
 [4. 3. 2.]
 [5. 4. 2.]]
