In [19]:
import numpy as np

In [155]:
class LBG:
    """Linde-Buzo-Gray vector quantizer (codebook splitting + Lloyd refinement).

    Training starts from a one-word codebook (the data mean). While the codebook
    is smaller than ``n_clusters`` it is split in two: every codeword ``c`` becomes
    ``c * (1 + epsilon)`` and ``c * (1 - epsilon)``. The split codebook is then
    refined with Lloyd (k-means) iterations until the relative decrease in
    distortion drops to ``epsilon`` or below.

    Because each split doubles the codebook, the final number of codewords is
    the smallest power of two that is >= ``n_clusters`` (1 if ``n_clusters <= 1``).

    Parameters
    ----------
    n_clusters : int
        Minimum requested codebook size.
    epsilon : float
        Used both as the split perturbation factor and as the convergence threshold.

    Attributes set by ``fit``
    -------------------------
    N : number of training samples.
    n_features : dimensionality of each sample.
    cluster_centers : ndarray of shape (n_codewords, n_features), the codebook.
    """

    def __init__(self, n_clusters, epsilon=0.00001):
        self.n_clusters = n_clusters
        self.epsilon = epsilon

    def _get_init_cb(self, data):
        """Return a one-row codebook containing the mean of ``data``."""
        return np.array([np.mean(data, axis=0)])

    def _create_codebook(self, data, e):
        """Return ``data`` scaled by ``1 + e`` (one half of a codebook split).

        The split is multiplicative, so an all-zero codeword cannot be separated.
        """
        return data * (1.0 + e)

    def _average_distortion(self, data, codebook, idx):
        """Mean squared Euclidean distance between each sample and its codeword.

        ``idx`` holds one codeword index per row of ``data``, or a single index
        (any shape with one element) that is applied to every row.
        """
        assigned = codebook[np.asarray(idx).ravel()]
        return np.mean(np.sum((data - assigned) ** 2, axis=1))

    def _split_codebook(self, data, codebook, init_distortion):
        """Double ``codebook`` by splitting, then refine it with Lloyd iterations.

        Parameters
        ----------
        data : float ndarray (n_samples, n_features)
        codebook : ndarray (k, n_features), the current codebook.
        init_distortion : distortion of ``codebook`` on ``data``. It is the
            reference for the first relative-improvement check.

        Returns
        -------
        (codebook, labels, distortion) : the refined (2k, n_features) codebook,
        the per-sample codeword index, and the final average distortion.
        """
        codebook = np.vstack([
            self._create_codebook(codebook, self.epsilon),
            self._create_codebook(codebook, -self.epsilon),
        ])
        n_codewords = len(codebook)
        distortion = init_distortion
        err = 1.0 + self.epsilon

        while err > self.epsilon:
            # Squared distances, shape (n_codewords, n_samples); assign each sample
            # to its nearest codeword.
            sq_dists = np.array([np.sum((data - c) ** 2, axis=1) for c in codebook])
            labels = np.argmin(sq_dists, axis=0)

            # Lloyd update: move each codeword to the centroid of its members.
            counts = np.bincount(labels, minlength=n_codewords)
            sums = np.zeros_like(codebook)
            np.add.at(sums, labels, data)
            nonempty = counts > 0
            # Codewords with no members keep their position instead of collapsing
            # to the origin.
            codebook[nonempty] = sums[nonempty] / counts[nonempty, None]

            prev_distortion = distortion
            distortion = self._average_distortion(data, codebook, labels)
            if distortion == 0:
                # Perfect fit; the relative-improvement ratio would divide by zero.
                break
            err = (prev_distortion - distortion) / distortion

        return codebook, labels, distortion

    def fit(self, data):
        """Train the codebook on ``data`` and return the training-sample labels.

        ``data`` is a 2-D array-like of shape (n_samples, n_features). Integer
        inputs (e.g. uint8 pixels) are converted to float. Sets ``N``,
        ``n_features`` and ``cluster_centers``.

        Raises
        ------
        ValueError
            If ``data`` contains no samples.
        """
        data = np.asarray(data, dtype=float)
        if data.shape[0] == 0:
            raise ValueError("cannot fit LBG on an empty array")
        self.N = data.shape[0]
        self.n_features = data.shape[1]

        codebook = self._get_init_cb(data)
        distortion = self._average_distortion(data, codebook, np.array([[0]]))
        # With a single codeword, every sample belongs to cluster 0.
        labels = np.zeros(self.N, dtype=np.intp)

        while len(codebook) < self.n_clusters:
            # Carry the latest distortion forward so each split is compared with
            # the codebook it came from, not with the initial one-word codebook.
            codebook, labels, distortion = self._split_codebook(data, codebook, distortion)

        self.cluster_centers = codebook
        return labels

    def predict(self, data):
        """Return the index of the nearest codeword (squared Euclidean) for each row."""
        data = np.asarray(data, dtype=float)
        sq_dists = np.array([np.sum((data - c) ** 2, axis=1) for c in self.cluster_centers])
        return np.argmin(sq_dists, axis=0)

In [156]:
from skimage.io import imread, imsave
from sklearn.utils import shuffle
import time

In [157]:
# Input image and quantizer settings.
img_path = 'face.png'
n_clusters = 64   # requested codebook size (LBG doubles the codebook, so it rounds up to a power of two)
n_random = 1_000  # number of pixels sampled to train the codebook
img = imread(img_path)

In [158]:
# Flatten the image to one row per pixel, then draw a random training subset.
rows, cols, depth = img.shape
img_vect = img.reshape(rows * cols, depth)
# Fixed random_state so the sampled pixels (and therefore the codebook) are
# reproducible under Restart & Run All.
img_train = shuffle(img_vect, random_state=0)[:n_random]

In [159]:
lbg = LBG(n_clusters)  # default epsilon: split perturbation and convergence threshold

In [160]:
# Train the codebook on the sampled pixels, then quantize every pixel of the image.
# perf_counter is monotonic, so the measured duration is unaffected by clock changes.
start = time.perf_counter()
lbg.fit(img_train)
labels = lbg.predict(img_vect).reshape(rows, cols)
end = time.perf_counter()
elapsed = end - start
elapsed  # seconds for fit + predict

In [165]:
# Persist the compressed representation: a label image plus the codebook.
# argmin yields int64 labels, but PNG stores only 8/16-bit integers. Without an
# explicit cast, the writer's lossy conversion may rescale the label range to
# 0..255, which corrupts the codeword indices. Cast to the smallest type that
# fits every index.
label_dtype = np.uint8 if len(lbg.cluster_centers) <= 256 else np.uint16
imsave('cmprs_lbls.png', labels.astype(label_dtype), check_contrast=False)
np.save('cmprs_cntrs.npy', lbg.cluster_centers)

In [166]:
# Decompress: replace each pixel's label with its codeword.
centers = np.load('cmprs_cntrs.npy')
labels = imread('cmprs_lbls.png')

# Fancy indexing maps every label to its codeword in one step, giving shape
# (rows, cols, n_features). Round and clip before casting so float centroids are
# not truncated toward zero. Assumes the source image was 8-bit — TODO confirm
# for other inputs.
image = np.clip(np.rint(centers[labels]), 0, 255).astype(np.uint8)

imsave('reconstructed.png', image)