In [9]:
import numpy as np
import matplotlib.pyplot as plt
import sklearn.datasets as ds
import matplotlib.colors
from sklearn.cluster import AffinityPropagation
from sklearn.metrics import euclidean_distances
%matplotlib auto

Using matplotlib backend: Qt5Agg


In [5]:
N = 400
centers = [[1, 2], [-1, -1], [1, -1], [-1, 1]]
data, y = ds.make_blobs(N,
                        n_features=2,
                        centers=centers,
                        cluster_std=[0.5, 0.25, 0.7, 0.5],
                        random_state=0)
m=euclidean_distances(data,squared=True)
preference=-np.median(m)
print('preference: ',preference)

preference:  -5.29914553034


In [46]:
matplotlib.rcParams['font.sans-serif'] = [u'SimHei']
matplotlib.rcParams['axes.unicode_minus'] = False
plt.figure(figsize=(12, 9), facecolor='w')
for i, mul in enumerate(np.linspace(1, 4, 9)):
    print(mul)
    p = mul * preference
    model = AffinityPropagation(affinity='euclidean', preference=p)
    af = model.fit(data)
    center_indices = af.cluster_centers_indices_
    n_clusters = len(center_indices)
    print('p=%1.f' % (p), 'qua of cluster :', n_clusters)
    y_hat = af.labels_

    plt.subplot(3, 3, i + 1)
    plt.title('preference:%.2f,qua of cluster: %d' % (p, n_clusters))
    clrs = []

    for c in np.linspace(16711680, 255, n_clusters):
        clrs.append('#%06x' % int(c))

    for k, clr in enumerate(clrs):
        cur = (y_hat == k)
        plt.scatter(data[cur, 0], data[cur, 1], c=clr, edgecolors='none')
        center = data[center_indices[k]]
        for x in data[cur]:
            plt.plot([x[0], center[0]], [x[1], center[1]],
                     color=clr,
                     zorder=1)
    plt.scatter(data[center_indices, 0],
                data[center_indices, 1],
                s=100,
                c=clrs,
                marker='*',
                edgecolors='k',
                zorder=2)
plt.tight_layout()
plt.suptitle(u'AP聚类', fontsize=20)
plt.subplots_adjust(top=0.92)
plt.show()

1.0
p=-5 qua of cluster : 16
1.375
p=-7 qua of cluster : 12
1.75
p=-9 qua of cluster : 11
2.125
p=-11 qua of cluster : 10
2.5
p=-13 qua of cluster : 9
2.875
p=-15 qua of cluster : 8
3.25
p=-17 qua of cluster : 55
3.625
p=-19 qua of cluster : 107
4.0
p=-21 qua of cluster : 7


In [14]:
for i,mul in enumerate(np.linspace(1,4,9)):
    print(mul)

1.0
1.375
1.75
2.125
2.5
2.875
3.25
3.625
4.0


In [15]:
preference*2

-10.598291060678285

In [30]:
model1 = AffinityPropagation(affinity='euclidean', preference=preference)
af1 = model1.fit(data)
center_indices1 = af1.cluster_centers_indices_
center_indices1

array([ 20,  29,  77,  83,  93, 134, 136, 148, 238, 257, 262, 263, 312,
       326, 362, 376], dtype=int64)

In [31]:
y_hat1=af1.labels_

In [32]:
len(af1.labels_)

400

In [33]:
data[:5]

array([[-0.69743996,  1.44777799],
       [-1.97285154,  0.54360825],
       [ 1.16948257, -1.62210418],
       [ 0.75098377,  2.96476603],
       [ 0.48287858,  2.34079726]])

In [34]:
clrs1=[]
for c in np.linspace(16711680,255,len(center_indices1)):
    clrs1.append('#%06x' % int(c))

In [39]:
clrs1

['#ff0000',
 '#ee0011',
 '#dd0022',
 '#cc0033',
 '#bb0044',
 '#aa0055',
 '#990066',
 '#880077',
 '#770088',
 '#660099',
 '#5500aa',
 '#4400bb',
 '#3300cc',
 '#2200dd',
 '#1100ee',
 '#0000ff']

In [49]:
for k,clr in enumerate(clrs1):
    cur1=(y_hat1==k)
    plt.scatter(data[cur1,0],data[cur1,1],c=clr,edgecolors='none')
    center1=data[center_indices1[k]]
    for x in data[cur1]:
        plt.plot([x[0],center1[0]],[x[1],center1[1]],color=clr,zorder=1)
    plt.scatter(data[center_indices1,0],
                data[center_indices1,1],
                s=100,c=clr,
               marker='*',edgecolors='k',zorder=2)
    
plt.title('p=%.1f ,clusters qua: %d' % (preference,len(center_indices1)))
plt.show()
        

In [81]:
for x in data[cur1]:
    print([x[0],center1[0]],[x[1],center1[1]])

[-1.0117115525107261, -1.2663513959897772] [1.5395973640562446, 1.54537486721725]
[-1.2256515185512631, -1.2663513959897772] [1.132843987483118, 1.54537486721725]
[-1.4663945207532192, -1.2663513959897772] [1.6216596922275774, 1.54537486721725]
[-1.3003287788288942, -1.2663513959897772] [1.7761215900242804, 1.54537486721725]
[-1.0451910036384735, -1.2663513959897772] [1.6837986199033568, 1.54537486721725]
[-0.834711621862813, -1.2663513959897772] [1.4746232367791179, 1.54537486721725]
[-0.90101635502751631, -1.2663513959897772] [1.5409676092386326, 1.54537486721725]
[-1.394334627254683, -1.2663513959897772] [1.5473191873560457, 1.54537486721725]
[-1.2957013339040553, -1.2663513959897772] [1.5622095922551842, 1.54537486721725]
[-1.2495083189970915, -1.2663513959897772] [1.0106756119217744, 1.54537486721725]
[-1.0178840359301105, -1.2663513959897772] [2.1903726756098751, 1.54537486721725]
[-0.77590235778834371, -1.2663513959897772] [1.8480907864140803, 1.54537486721725]
[-1.0074288516773

In [85]:
for x in data[cur1]:
    plt.plot([x[0],center1[0]],[x[1],center1[1]])
#     plt.plot([x[0],center1[0]])
plt.show()

In [75]:
plt.plot([[1,2],[0,2]],[[3,6],[7,6]])
plt.show()
