-
Notifications
You must be signed in to change notification settings - Fork 227
/
perform_kmeans.m
534 lines (472 loc) · 15.5 KB
/
perform_kmeans.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
function [B,seeds,E] = perform_kmeans(X,nbCluster,options)
% perform_kmeans - perform the k-means clustering algorithm.
%
% [B,seeds] = perform_kmeans(X,nbCluster,options);
%
% 'X' is a [d,n] matrix where d is the dimension of the space
% and n is the number of points (that live in R^d).
% 'nbCluster' is the number of wanted clusters.
%
% 'B' is a vector that contain the class membership
% of each point.
% 'seeds' is the current center of each region.
% 'E' is the energy of the segmentation (the lowest, the better).
%
% You can provide optional parameter in a structure options:
% - options.seeds : You can provide initial guess for the centers of the regions
% via a [d,nbCluster] matrix 'seeds', where seeds(:,i)
% is the ith center point in R^d.
% - options.intialization : if you don't provide 'seeds', then
% the algorithm will use a initialization depending on 'options.intialization':
% * If options.intialization='random' : random choice of centers.
% * If options.intialization='greedy' : greedy intialization via
% farthest point sampling.
% - options.nb_iter is the number of iterations of the Lloyd algorithm.
%
% Copyright (c) 2004 Gabriel Peyr?
if size(X,1)>size(X,2)
X = X';
end
if nargin<2
nbCluster = 4;
end
options.null = 1;
if isfield(options,'nb_iter')
nb_iter = options.nb_iter;
else
nb_iter = 5;
end
if isfield(options,'kmeans_code')
kmeans_code = options.kmeans_code;
else
kmeans_code = 2;
end
if isfield(options,'etol')
etol = options.etol;
else
etol = 0.02;
end
if kmeans_code==1
[B,seeds,E] = perform_kmeans_old(X,nbCluster,options);
elseif kmeans_code==2
[B,seeds,E] = kmeansML(nbCluster,X,'maxiter',nb_iter, 'etol', etol);
else
[B,seeds,E] = kmeans_light(X', nbCluster, etol);
seeds = seeds';
end
[membership,E] = computeMembership(X,seeds);
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [B,seeds,E] = perform_kmeans_old(X,nbCluster,options);
options.null = 1;
if isfield(options,'nb_iter')
nb_iter = options.nb_iter;
else
nb_iter = 5;
end
if isfield(options,'initialization')
initialization = options.initialization;
else
initialization = 'random';
end
n = size(X,2);
d = size(X,1); % = (2*k+1)^2
if isfield(options,'seeds') && not(isempty(options.seeds))
seeds = options.seeds;
else
seeds = zeros(d,nbCluster);
if strcmp(initialization,'random')
% original centers (selected at random)
seed_num = floor(rand(nbCluster,1)*n)+1;
seeds = X(:,seed_num);
elseif strcmp(initialization,'greedy')
% select first point at random
seeds(:,1) = X(:,floor(rand*n)+1);
% replace by farthest point
D = compute_distance_to_points(X,seeds(:,1));
[tmp,I] = max(D); I = I(1);
seeds(:,1) = X(:,I);
for i=2:nbCluster
D = compute_distance_to_points(X,seeds(:,1:i-1));
if i>2
D = sum(D);
end
[tmp,I] = max(D); I = I(1);
seeds(:,i) = X(:,I);
end
else
error('Unknown intialization type (should be random/greedy).');
end
end
D = compute_distance_to_points(X,seeds);
% compute region of influence
[tmp,B] = min(D);
for i=1:nb_iter
% compute region center
for k=1:nbCluster
I = find(B==k); % points belonging to cluster k
% geometric barycenter
Xk = X(:,I);
if ~isempty(I)
seeds(:,k) = sum(Xk,2)/length(I);
else
% warning('Empty cluster created.');
end
end
% update distance to seed
D = compute_distance_to_points(X,seeds);
% compute region of influence
[tmp,B] = min(D);
end
if nargout==3
% compute the energy
E = 0;
for k=1:nbCluster
I = find(B==k); % points belonging to cluster k
D = compute_distance_to_points(X(:,I),seeds(:,k));
E = E+sum(D);
end
E = E/( n*d );
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Auxiliary function
function D = compute_distance_to_points(X,seeds)
nbCluster = size(seeds,2);
n = size(X,2);
D = zeros(nbCluster,n);
d = size(X,1);
for k=1:nbCluster
% distance to seed
D(k,:) = sum( (X - repmat(seeds(:,k),1,n)).^2 );
end
% kmeans_light: K-means clustering using euclid distance.
%
% [dataCluster codebook] = kmeans_light(data, K, stopIter)
%
% Input and output arguments ([]'s are optional):
% data (matrix) of size NxD. N is the number of data (classifiee)
% vectors,and D is the dimension of each vector.
% K (scalar) The number of clusters.
% stopIter (scalar) The threshold [0, 1] to stop learning iterations.
% Default is .05, and smaller makes continue interations.
% dataCluster (matrix) of size Nx1: integers indicating the cluster indicies.
% dataCluster(i) is the cluster id for data item i.
% codebook (matrix) of size KxD: set of cluster centroids, VQ codewords.
%
% See also: autolabel.m
%
% Author : Naotoshi Seo
% Date : April, 2006
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [dataCluster codebook distortion] = kmeans_light(data, K, stopIter)
if nargin < 3,
stopIter = .05;
end
[N dim] = size(data);
if K > N,
error('K must be less than or equal to the # of data');
end
% Initial codebook
codebook = data(randsample(N, K), :);
improvedRatio = Inf;
distortion = Inf;
iter = 0;
while true
% Calculate euclidean distances between each sample and each codeword
d = eucdist2(data, codebook);
% Assign each sample to the nearest codeword (centroid)
[dataNearClusterDist, dataCluster] = min(d, [], 2);
% distortion. If centroids are unchanged, distortion is also unchanged.
% smaller distortion is better
old_distortion = distortion;
distortion = mean(dataNearClusterDist);
% If no more improved, break;
improvedRatio = 1 - (distortion / old_distortion);
% fprintf('%d: improved ratio = %f\n', iter, improvedRatio); iter = iter + 1;
if improvedRatio <= stopIter, break, end;
% Renew codebook
for i=1:K
% Get the id of samples which were clusterd into cluster i.
idx = find(dataCluster == i);
% Calculate centroid of each cluter, and replace codebook
codebook(i, :) = mean(data(idx, :));
end
end
%%%% Euclidean distance matrix between row vectors in X and Y %%%%%%%
% Input and output arguments
% X: NxDim matrix
% Y: PxDim matrix
% d: distance matrix of size NxP
function d=eucdist2(X,Y);
U=~isnan(Y); Y(~U)=0;
V=~isnan(X); X(~V)=0;
d=abs(X.^2*U'+V*Y'.^2-2*X*Y');
function y = randsample(n, k, replace, w)
%RANDSAMPLE Random sample, with or without replacement.
% Y = RANDSAMPLE(N,K) returns Y as a 1-by-K vector of values sampled
% uniformly at random, without replacement, from the integers 1:N.
%
% Y = RANDSAMPLE(POPULATION,K) returns K values sampled uniformly at
% random, without replacement, from the values in the vector POPULATION.
%
% Y = RANDSAMPLE(...,REPLACE) returns a sample taken with replacement if
% REPLACE is true, or without replacement if REPLACE is false (the default).
%
% Y = RANDSAMPLE(...,true,W) returns a weighted sample, using positive
% weights W, taken with replacement. W is often a vector of probabilities.
% This function does not support weighted sampling without replacement.
%
% Example: Generate a random sequence of the characters ACGT, with
% replacement, according to specified probabilities.
%
% R = randsample('ACGT',48,true,[0.15 0.35 0.35 0.15])
%
% See also RAND, RANDPERM.
% Copyright 1993-2004 The MathWorks, Inc.
% $Revision: 1.1.4.1 $ $Date: 2003/11/01 04:28:51 $
if nargin < 2
error('stats:randsample:TooFewInputs','Requires two input arguments.');
elseif numel(n) == 1
population = [];
else
population = n;
n = numel(population);
if length(population)~=n
error('stats:randsample:BadPopulation','POPULATION must be a vector.');
end
end
if nargin < 3
replace = false;
end
if nargin < 4
w = [];
elseif ~isempty(w)
if length(w) ~= n
if isempty(population)
error('stats:randsample:InputSizeMismatch',...
'W must have length equal to N.');
else
error('stats:randsample:InputSizeMismatch',...
'W must have the same length as the population.');
end
else
p = w(:)' / sum(w);
end
end
switch replace
% Sample with replacement
case {true, 'true', 1}
if isempty(w)
y = ceil(n .* rand(k,1));
else
[dum, y] = histc(rand(k,1),[0 cumsum(p)]);
end
% Sample without replacement
case {false, 'false', 0}
if k > n
if isempty(population)
error('stats:randsample:SampleTooLarge',...
'K must be less than or equal to N for sampling without replacement.');
else
error('stats:randsample:SampleTooLarge',...
'K must be less than or equal to the population size.');
end
end
if isempty(w)
% If the sample is a sizeable fraction of the population,
% just randomize the whole population (which involves a full
% sort of n random values), and take the first k.
if 4*k > n
rp = randperm(n);
y = rp(1:k);
% If the sample is a small fraction of the population, a full sort
% is wasteful. Repeatedly sample with replacement until there are
% k unique values.
else
x = zeros(1,n); % flags
sumx = 0;
while sumx < k
x(ceil(n * rand(1,k-sumx))) = 1; % sample w/replacement
sumx = sum(x); % count how many unique elements so far
end
y = find(x > 0);
y = y(randperm(k));
end
else
error('stats:randsample:NoWeighting',...
'Weighted sampling without replacement is not supported.');
end
otherwise
error('stats:randsample:BadReplaceValue',...
'REPLACE must be either true or false.');
end
if ~isempty(population)
y = population(y);
else
y = y(:);
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [membership,means,rms] = kmeansML(k,data,varargin)
% [membership,means,rms] = kmeansML(k,data,...)
%
% Multi-level kmeans.
% Tries very hard to always return k clusters.
%
% INPUT
% k Number of clusters
% data dxn matrix of data points
% 'maxiter' Max number of iterations. [30]
% 'dtol' Min change in center locations. [0]
% 'etol' Min percent change in RMS error. [0]
% 'ml' Multi-level? [true]
% 'verbose' Verbose level. [0]
% 0 = none
% 1 = textual
% 2 = visual
%
% OUTPUT
% membership 1xn cluster membership vector
% means dxk matrix of cluster centroids
% rms RMS error of model
%
% October 2002
% David R. Martin <dmartin@eecs.berkeley.edu>
% process options
maxIter = 30;
dtol = 0;
etol = 0;
ml = true;
verbose = 0;
for i = 1:2:numel(varargin),
opt = varargin{i};
if ~ischar(opt), error('option names not a string'); end
if i==numel(varargin), error(sprintf('option ''%s'' has no value',opt)); end
val = varargin{i+1};
switch opt,
case 'maxiter', maxIter = max(1,val);
case 'dtol', dtol = max(0,val);
case 'etol', etol = max(0,val);
case 'ml', ml = val;
case 'verbose', verbose = val;
otherwise, error(sprintf('invalid option ''%s''',opt));
end
end
[membership,means,rms] = ...
kmeansInternal(k,data,maxIter,dtol,etol,ml,verbose,1);
function [membership,means,rms] = kmeansInternal(...
k,data,maxIter,dtol,etol,ml,verbose,retry)
[d,n] = size(data);
perm = randperm(n);
% compute initial means
rate = 3;
minN = 50;
coarseN = round(n/rate);
if ~ml | coarseN < k | coarseN < minN,
% pick random points as means
means = data(:,perm(1:k));
else
% recurse on random subsample to get means
coarseData = data(:,perm(1:coarseN));
[coarseMem,means] = ...
kmeansInternal(k,coarseData,maxIter,dtol,etol,ml,verbose,0);
end
% Iterate.
iter = 0;
rms = inf;
if verbose>0, fwrite(2,sprintf('kmeansML: n=%d d=%d k=%d [',n,d,k)); end
while iter < maxIter,
if verbose>0, fwrite(2,'.'); end
iter = iter + 1;
% Compute cluster membership and RMS error.
rmsPrev = rms;
[membership,rms] = computeMembership(data,means);
% Compute new means and cluster counts.
prevMeans = means;
[means,counts] = computeMeans(k,data,membership);
% The error should always decrease.
if rms > rmsPrev, error('bug: rms > rmsPrev'); end
% Check for convergence.
rmsPctChange = 2 * (rmsPrev - rms) / (rmsPrev + rms + eps);
maxMoved = sqrt(max(sum((prevMeans-means).^2)));
if rmsPctChange <= etol & maxMoved <= dtol,
break;
end
% Visualize.
if verbose>1,
kmeansVis(data,membership,means);
end
end
[membership,rms] = computeMembership(data,means);
if verbose>0, fwrite(2,sprintf('] rms=%.3g\n',rms)); end
% If there's an empty cluster, then re-run kmeans.
% Retry a fixed number of times.
maxRetries = 3;
if find(counts==0),
if retry < maxRetries,
disp('Warning: Re-runing kmeans due to empty cluster.');
[membership,means] = kmeansInternal( ...
k,data,maxIter,dtol,etol,ml,verbose,retry+1);
else
disp('Warning: There is an empty cluster.');
end
end
function [membership,rms] = computeMembership(data,means)
z = distSqr(data,means);
[d2,membership] = min(z,[],2);
rms = sqrt(mean(d2));
function [means,counts] = computeMeans(k,data,membership)
[d,n] = size(data);
means = zeros(d,k);
counts = zeros(1,k);
for i = 1:k,
ind = find(membership==i);
counts(i) = length(ind);
means(:,i) = sum(data(:,ind),2) / max(1,counts(i));
end
% for i = 1:n,
% j = membership(i);
% means(:,j) = means(:,j) + data(:,i);
% counts(j) = counts(j) + 1;
% end
% for j = 1:k,
% means(:,j) = means(:,j) / max(1,counts(j));
% end
function z = distSqr(x,y)
% function z = distSqr(x,y)
%
% Return matrix of all-pairs squared distances between the vectors
% in the columns of x and y.
%
% INPUTS
% x dxn matrix of vectors
% y dxm matrix of vectors
%
% OUTPUTS
% z nxm matrix of squared distances
%
% This routine is faster when m<n than when m>n.
%
% David Martin <dmartin@eecs.berkeley.edu>
% March 2003
% Based on dist2.m code,
% Copyright (c) Christopher M Bishop, Ian T Nabney (1996, 1997)
if size(x,1)~=size(y,1),
error('size(x,1)~=size(y,1)');
end
[d,n] = size(x);
[d,m] = size(y);
% z = repmat(sum(x.^2)',1,m) ...
% + repmat(sum(y.^2),n,1) ...
% - 2*x'*y;
z = x'*y;
x2 = sum(x.^2)';
y2 = sum(y.^2);
for i = 1:m,
z(:,i) = x2 + y2(i) - 2*z(:,i);
end