diff --git a/cosineknn.m b/cosineknn.m new file mode 100644 index 0000000..b0f7ed9 --- /dev/null +++ b/cosineknn.m @@ -0,0 +1,94 @@ +function S = cosineknn(A,K) +% COSINEKNN Compute k-nearest neighbors with a cosine distance metric +% +% S = cosineknn(A,k) computes a cosine similarity metric between the +% vertices of A or the upper half of a bipartite graph A +% + +% David F. Gleich +% Copyright, Stanford University, 2009 + +% History +% 2009-08-12: Initial version + +% TODO support struct input +% TODO support other similarity functions +% TODO allow option for diagonal similarity too +% TODO allow option for triplet output + + +if isstruct(A), + error('cosineknn:notSupported','struct input is not supported yet'); +else + [rp ci ai]=sparse_to_csr(A); check=1; + [rpt cit ait]=sparse_to_csr(A'); + [n,m] = size(A); +end + +% compute a row norm (which means we accumlate the column indices of At) +rn = accumarray(cit,ait.^2); +rn = sqrt(rn); +rn(rn>eps(1)) = 1./rn(rn>eps(1)); % do division once + +% the output is a n-by-n matrix, allocate storage for it +si = zeros(n*K,1); +sj = zeros(n*K,1); +sv = zeros(n*K,1); +nused = 0; + +dwork = zeros(n,1); +iwork = zeros(n,1); +iused = zeros(n,1); +for i=1:n + % for each row of A, compute an inner product against all rows of A. + % to do this efficiently, we know that the inner-product will + % only be non-zero if two rows of A overlap in a column. Thus, in + % row i of A, we look at all non-zero columns, and then look at all + % rows touched by those columns in the A' matrix. We compute + % the similarity metric, then sort and only store the top k. + + curused = 0; % track how many non-zeros we compute during this operation + + for ri=rp(i):rp(i+1)-1 + % find all columns j in row i + j = ci(ri); + aij = ai(ri)*rn(i); + for rit=rpt(j):rpt(j+1)-1 + % find all rows k in column j + k = cit(rit); + if k==i, continue; end % skip over the standard entries + akj = ait(rit)*rn(k); + if iwork(k)>0 + % we already have a non-zero between row i and row k + dwork(iwork(k)) = dwork(iwork(k)) + aij*akj; + else + % we need to add a non-zero betwen row i and row k + curused = curused + 1; + iwork(k) = curused; + dwork(curused) = aij*akj; + iused(curused) = k; + end + end + end + % don't sort if fewer than K elements + if curused < K, + sperm = 1:curused; + else + [sdwork sperm] = sort(dwork(1:curused),1,'descend'); + end + for k=1:min(K,length(sperm)) + nused = nused + 1; + si(nused) = i; + sj(nused) = iused(sperm(k)); + sv(nused) = dwork(sperm(k)); + end + % reset the iwork array, no need to reset dwork, as it's just + % overwritten + for k=1:curused + iwork(iused(k)) = 0; + end +end +si = si(1:nused); +sj = sj(1:nused); +sv = sv(1:nused); +S = sparse(si,sj,sv,n,n); \ No newline at end of file