In [1]:
from PIL import Image
import numpy as np
from sklearn.cluster import KMeans
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# Use IPython's default GUI backend so figures open in separate windows
# (this session reported Qt5Agg). Comments must stay off the magic line itself:
# IPython passes trailing text on that line to the magic as arguments.
%matplotlib auto
# The plot titles/axis labels later in the notebook are Chinese, so a CJK-capable
# sans-serif font is needed; assumes SimHei is installed on this machine.
matplotlib.rcParams['font.sans-serif'] = [u'SimHei']
# Draw the minus sign as an ASCII hyphen instead of U+2212; CJK fonts such as
# SimHei commonly lack the Unicode minus glyph.
matplotlib.rcParams['axes.unicode_minus'] = False

Using matplotlib backend: Qt5Agg


In [2]:
def restore_image(cb, cluster, shape):
    """Rebuild an image by replacing each pixel's cluster label with its cluster centre.

    Parameters
    ----------
    cb : array-like of shape (n_clusters, 3)
        Codebook of colours, one row per cluster (e.g. ``KMeans.cluster_centers_``).
        A 1-D or single-column codebook is broadcast to all 3 output channels.
    cluster : array-like of int
        Label of every pixel in row-major order (row by row, left to right).
        Only the first ``rows * cols`` entries are used.
    shape : tuple
        Target image shape. Only the first two entries (rows, cols) are read, so
        both ``(rows, cols)`` and ``(rows, cols, channels)`` are accepted.

    Returns
    -------
    numpy.ndarray of shape (rows, cols, 3), dtype float64

    Raises
    ------
    IndexError
        If ``cluster`` holds fewer than ``rows * cols`` labels, or a label is out
        of range for ``cb``.
    """
    row, col = shape[0], shape[1]
    n_pixels = row * col
    image = np.empty((row, col, 3))
    if n_pixels == 0:
        return image
    labels = np.asarray(cluster)[:n_pixels]
    if labels.shape[0] < n_pixels:
        raise IndexError('cluster has %d labels, but shape %r needs %d'
                         % (labels.shape[0], tuple(shape), n_pixels))
    # A single fancy-indexing lookup replaces the per-pixel Python double loop.
    # Labels are in row-major order, so reshape restores the (row, col) grid.
    # Assigning into the float64 buffer keeps the output dtype float64 and
    # broadcasts a 1-channel codebook across the 3 channels.
    image[...] = np.asarray(cb)[labels].reshape(row, col, -1)
    return image

In [3]:
def show_scatter(a, n_bins=10):
    """Plot how the colours of a pixel array are distributed in RGB space.

    Prints ``a`` (numpy abbreviates large arrays). It then bins the colours into an
    ``n_bins`` x ``n_bins`` x ``n_bins`` histogram over [0, 1]^3, normalized so the
    most populated cell equals 1.

    * Figure 1: 3-D scatter with one marker per RGB cell. Marker area is
      proportional to the cell's normalized count. The axes are the red, green
      and blue bin indices.
    * Figure 2: the normalized counts of the non-empty cells, sorted in
      descending order.

    Parameters
    ----------
    a : array-like of shape (n_pixels, 3)
        Pixel colours with every channel scaled to [0, 1]. Values outside that
        range are not counted.
    n_bins : int, optional
        Number of bins per colour channel. Defaults to 10, the previously
        hard-coded value.

    Notes
    -----
    Figures 1 and 2 are reused across calls and cleared first. Re-running the
    cell therefore redraws them instead of stacking new axes on old plots.
    """
    print('original image:\n', a)
    density, _ = np.histogramdd(a, bins=[n_bins] * 3, range=[(0, 1)] * 3)
    peak = density.max()
    if peak > 0:
        # With no in-range pixels every count is 0, and 0/0 would fill the grid with NaN.
        density /= peak

    # indexing='ij' gives red[i, j, k] == i, green[...] == j and blue[...] == k.
    # This matches the (red, green, blue) bin order of density[i, j, k]; the
    # default 'xy' indexing would swap the first two axes.
    bin_idx = np.arange(n_bins)
    red, green, blue = np.meshgrid(bin_idx, bin_idx, bin_idx, indexing='ij')

    fig3d = plt.figure(1, facecolor='w')
    fig3d.clf()
    ax3d = fig3d.add_subplot(111, projection='3d')
    ax3d.scatter(red, green, blue, c='r', s=100 * density, marker='o',
                 depthshade=True)
    ax3d.set_xlabel(u'红色分量')
    ax3d.set_ylabel(u'绿色分量')
    ax3d.set_zlabel(u'蓝色分量')
    ax3d.set_title(u'图像颜色三维频数分布', fontsize=20)

    fig_rank = plt.figure(2, facecolor='w')
    fig_rank.clf()
    ax_rank = fig_rank.add_subplot(111)
    den = np.sort(density[density > 0])[::-1]
    t = np.arange(len(den))
    ax_rank.plot(t, den, 'r-', t, den, 'go', lw=2)
    ax_rank.set_title(u'图像颜色频数分布', fontsize=18)
    ax_rank.grid(True)

    plt.show()
    

In [6]:
# Vector-quantize the image's colours: fit k-means on a random sample of pixels,
# map every pixel to its nearest centre, and show the original next to the result.
# Colour counts used for other sample images: 16.son.bmp (100),
# 16.flower2.png (200), 16.son.png (60), 16.lena.png (50).
num_vq = 300           # number of clusters = number of colours in the quantized image
N_FIT_SAMPLES = 1000   # pixels sampled to fit k-means; fitting on every pixel is slow
SEED = 0               # fixes the pixel sample and the k-means initialisation
# Machine-specific absolute path -- adjust for your environment.
IMAGE_PATH = r'F:\study\ml\DoctorZou\16KNN_II\16.代码\16.Clustering\wx.jpg'

# convert('RGB') yields a (rows, cols, 3) array for greyscale, palette and RGBA
# inputs alike; any alpha channel is dropped. The `with` block closes the file
# once the pixels have been copied out. np.float64 replaces the np.float alias,
# which NumPy 1.24 removed.
with Image.open(IMAGE_PATH) as im:
    image = np.asarray(im.convert('RGB'), dtype=np.float64) / 255
image_v = image.reshape((-1, 3))
show_scatter(image_v)

# Seeded sampling (with replacement) and a seeded k-means keep the palette
# identical across re-runs.
rng = np.random.RandomState(SEED)
N = image_v.shape[0]
idx = rng.randint(0, N, size=N_FIT_SAMPLES)
image_sample = image_v[idx]
model = KMeans(n_clusters=num_vq, random_state=SEED)
model.fit(image_sample)
c = model.predict(image_v)
print('聚类结果：\n', c)
print('聚类中心：\n', model.cluster_centers_)

vq_image = restore_image(model.cluster_centers_, c, image.shape)

fig, (ax_orig, ax_vq) = plt.subplots(1, 2, figsize=(15, 8), facecolor='w')
ax_orig.axis('off')
ax_orig.set_title(u'原始图片', fontsize=18)
ax_orig.imshow(image)
ax_vq.axis('off')
ax_vq.set_title(u'矢量量化后图片：%d色' % num_vq, fontsize=18)
ax_vq.imshow(vq_image)
# To save the quantized pixels themselves, convert to 8-bit first, e.g.
# Image.fromarray((vq_image * 255).round().astype(np.uint8)).save(path)
fig.tight_layout(pad=1.2)
plt.show()

original image:
 [[ 0.39215686  0.43529412  0.45098039]
 [ 0.39607843  0.43921569  0.45490196]
 [ 0.36078431  0.40392157  0.41960784]
 ..., 
 [ 0.22745098  0.35686275  0.49411765]
 [ 0.21176471  0.34117647  0.47843137]
 [ 0.23921569  0.36862745  0.49803922]]
聚类结果：
 [294 294 209 ...,  46 141  46]
聚类中心：
 [[ 0.36862745  0.45098039  0.53333333]
 [ 0.09281046  0.07189542  0.06405229]
 [ 0.77254902  0.7254902   0.67058824]
 [ 0.45490196  0.44392157  0.41490196]
 [ 0.64117647  0.6254902   0.59019608]
 [ 0.26960784  0.27156863  0.27058824]
 [ 0.90980392  0.94901961  0.98431373]
 [ 0.52941176  0.49411765  0.45882353]
 [ 0.78823529  0.8         0.83529412]
 [ 0.56694678  0.53893557  0.51204482]
 [ 0.15735294  0.15784314  0.16470588]
 [ 0.36078431  0.37254902  0.3372549 ]
 [ 0.50588235  0.3751634   0.2745098 ]
 [ 0.02436975  0.01988796  0.02268908]
 [ 0.65686275  0.60588235  0.42352941]
 [ 0.67581699  0.66339869  0.62875817]
 [ 0.43137255  0.31372549  0.20392157]
 [ 0.99926471  1.          0.9995

In [7]:
# Save the quantized image from the previous cell on its own, freshly created figure.
# Drawing with a bare `plt.subplot(111)` would target whatever figure is current
# (the two-panel comparison). Whether the overlapping panels are then removed
# depends on the matplotlib version, so the saved file would depend on hidden state.
# Machine-specific absolute path -- adjust for your environment.
VQ_FIGURE_PATH = r'F:\study\ml\DoctorZou\16KNN_II\16.代码\16.Clustering\33.jpg'
fig_vq, ax_vq = plt.subplots()
ax_vq.imshow(vq_image)
fig_vq.savefig(VQ_FIGURE_PATH)