Skip to content

Commit

Permalink
fixing test
Browse files Browse the repository at this point in the history
  • Loading branch information
jumutc committed Sep 19, 2015
1 parent 7f5dfd5 commit 2d1fef3
Showing 1 changed file with 6 additions and 10 deletions.
16 changes: 6 additions & 10 deletions src/algorithms/stochastic_rk_means.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,9 @@ run_algorithm(::Type{DROP_OUT}, X, Y, dfunc::Function, alg_params::Vector, k::In
run_algorithm(::Type{ADA_L1RDA}, X, Y, dfunc::Function, alg_params::Vector, k::Int, max_iter::Int, tolerance::Float64, online_pass, train_idx) =
adaptive_l1rda_alg(dfunc, X, Y, alg_params..., k, max_iter, tolerance, online_pass, train_idx)

function At_mul_B!(C::Array{Float64,2}, A::SparseMatrixCSC, B::SparseMatrixCSC)
assert(size(A,1) == size(B,1))
C[:,:] = At_mul_B(A,B)
end

function At_mul_B!(C::Array{Float64,2}, A::SparseMatrixCSC, B::Array{Float64,2})
At_mul_B!(C,full(A),B)
end
At_mul_B!(C::Array{Float64,2}, A::SparseMatrixCSC, B::SparseMatrixCSC) = begin C[:,:] = At_mul_B(A,B) end
At_mul_B!(C::Array{Float64,2}, A::SparseMatrixCSC, B::Array{Float64,2}) = At_mul_B!(C,full(A),B)
At_mul_B!(C::Array{Float64,2}, A::Array{Float64,2}, B::SparseMatrixCSC) = At_mul_B!(C,A,full(B))

# core algorithmic part
function stochastic_rk_means{A <: Algorithm}(X, rk_means::RK_MEANS{A}, alg_params::Vector, k::Int, max_iter::Int,
Expand Down Expand Up @@ -64,14 +59,15 @@ function stochastic_rk_means{A <: Algorithm}(X, rk_means::RK_MEANS{A}, alg_param

result = @parallel (hcat) for cluster_id in unique(mappings)
cluster_idx = find(mappings .== cluster_id)
run_algorithm(rk_means.support_alg,X,Y,dfunc,alg_params,k,max_iter,tolerance,online_pass,cluster_idx)[1]
r = run_algorithm(rk_means.support_alg,X,Y,dfunc,alg_params,k,max_iter,tolerance,online_pass,cluster_idx)[1]
size(r) == (d,) ? r'' : r
end

# assign and check the result of parallel execution
if all(result .== 0) || all(isnan(result))
failed_mapping = true
w = rand(d,rk_means.k_clusters)
elseif size(result,2) != size(w,2)
elseif size(result) != size(w)
failed_mapping = true
diff = size(w,2) - size(result,2)
w = [result rand(d,diff)]
Expand Down

0 comments on commit 2d1fef3

Please sign in to comment.