Permalink
Please sign in to comment.
94
cosineknn.m
| @@ -0,0 +1,94 @@ | ||
| +function S = cosineknn(A,K) | ||
| +% COSINEKNN Compute k-nearest neighbors with a cosine distance metric | ||
| +% | ||
| +% S = cosineknn(A,k) computes a cosine similarity metric between the | ||
| +% vertices of A or the upper half of a bipartite graph A | ||
| +% | ||
| + | ||
| +% David F. Gleich | ||
| +% Copyright, Stanford University, 2009 | ||
| + | ||
| +% History | ||
| +% 2009-08-12: Initial version | ||
| + | ||
| +% TODO support struct input | ||
| +% TODO support other similarity functions | ||
| +% TODO allow option for diagonal similarity too | ||
| +% TODO allow option for triplet output | ||
| + | ||
| + | ||
| +if isstruct(A), | ||
| + error('cosineknn:notSupported','struct input is not supported yet'); | ||
| +else | ||
| + [rp ci ai]=sparse_to_csr(A); check=1; | ||
| + [rpt cit ait]=sparse_to_csr(A'); | ||
| + [n,m] = size(A); | ||
| +end | ||
| + | ||
| +% compute a row norm (which means we accumlate the column indices of At) | ||
| +rn = accumarray(cit,ait.^2); | ||
| +rn = sqrt(rn); | ||
| +rn(rn>eps(1)) = 1./rn(rn>eps(1)); % do division once | ||
| + | ||
| +% the output is a n-by-n matrix, allocate storage for it | ||
| +si = zeros(n*K,1); | ||
| +sj = zeros(n*K,1); | ||
| +sv = zeros(n*K,1); | ||
| +nused = 0; | ||
| + | ||
| +dwork = zeros(n,1); | ||
| +iwork = zeros(n,1); | ||
| +iused = zeros(n,1); | ||
| +for i=1:n | ||
| + % for each row of A, compute an inner product against all rows of A. | ||
| + % to do this efficiently, we know that the inner-product will | ||
| + % only be non-zero if two rows of A overlap in a column. Thus, in | ||
| + % row i of A, we look at all non-zero columns, and then look at all | ||
| + % rows touched by those columns in the A' matrix. We compute | ||
| + % the similarity metric, then sort and only store the top k. | ||
| + | ||
| + curused = 0; % track how many non-zeros we compute during this operation | ||
| + | ||
| + for ri=rp(i):rp(i+1)-1 | ||
| + % find all columns j in row i | ||
| + j = ci(ri); | ||
| + aij = ai(ri)*rn(i); | ||
| + for rit=rpt(j):rpt(j+1)-1 | ||
| + % find all rows k in column j | ||
| + k = cit(rit); | ||
| + if k==i, continue; end % skip over the standard entries | ||
| + akj = ait(rit)*rn(k); | ||
| + if iwork(k)>0 | ||
| + % we already have a non-zero between row i and row k | ||
| + dwork(iwork(k)) = dwork(iwork(k)) + aij*akj; | ||
| + else | ||
| + % we need to add a non-zero betwen row i and row k | ||
| + curused = curused + 1; | ||
| + iwork(k) = curused; | ||
| + dwork(curused) = aij*akj; | ||
| + iused(curused) = k; | ||
| + end | ||
| + end | ||
| + end | ||
| + % don't sort if fewer than K elements | ||
| + if curused < K, | ||
| + sperm = 1:curused; | ||
| + else | ||
| + [sdwork sperm] = sort(dwork(1:curused),1,'descend'); | ||
| + end | ||
| + for k=1:min(K,length(sperm)) | ||
| + nused = nused + 1; | ||
| + si(nused) = i; | ||
| + sj(nused) = iused(sperm(k)); | ||
| + sv(nused) = dwork(sperm(k)); | ||
| + end | ||
| + % reset the iwork array, no need to reset dwork, as it's just | ||
| + % overwritten | ||
| + for k=1:curused | ||
| + iwork(iused(k)) = 0; | ||
| + end | ||
| +end | ||
| +si = si(1:nused); | ||
| +sj = sj(1:nused); | ||
| +sv = sv(1:nused); | ||
| +S = sparse(si,sj,sv,n,n); |
0 comments on commit
ca8170d