In [1]:
using LightGraphs, SimpleWeightedGraphs
using CPUTime
using DelimitedFiles
using Base
using DataStructures
using Base.Threads


In [2]:
# Number of threads available to this Julia session; the `@threads` loop in
# findMinCostVertex can only run in parallel when this is greater than 1.
nthreads()

4

# Input

In [29]:
# Diagnostic cell only: builds a small hand-written graph for an initial sanity check.
numVertex = 6
g = SimpleWeightedGraph(numVertex)
# Manual edge list; every tuple is (src, dst, weight).
Edges = [
    (1, 2, 1), (1, 3, 3), (2, 3, 1), (2, 4, 7), (3, 5, 1),
    (2, 5, 3), (4, 5, 1), (4, 6, 1), (5, 6, 6),
];
for (src, dst, weight) in Edges
    add_edge!(g, src, dst, weight)
end

In [13]:
"""
    readDataFromFile(filename="in7.txt")

Read a whitespace-separated edge list from `filename` and return the graph built by
`SimpleWeightedGraph(sources, destinations, weights)`.

# File format
- First line: the vertex count `n` as its first token. Any further tokens on that line
  (e.g. an edge count) are ignored.
- Every following non-blank line: `src dst weight`. `src` and `dst` are integers,
  `weight` is any number `parse(Float64, ...)` accepts. Tokens after the third are ignored.

# Input rewriting
- A vertex id of `0` is replaced by `n`, so ids `0..n-1` land in the 1-based range `1..n`.
- A weight of `0` is stored as `0.01`.
  NOTE(review): presumably because a stored weight of exactly zero would not survive as
  an edge in the weighted graph; confirm against SimpleWeightedGraphs.

# Throws
- `ArgumentError` if the header line is empty or an edge line has fewer than three tokens.
- Whatever `parse` throws for a token that is not a number.
"""
function readDataFromFile(filename="in7.txt")
    open(filename, "r") do f
        header = split(readline(f))
        isempty(header) &&
            throw(ArgumentError("missing header line (vertex count) in $(repr(filename))"))
        n = parse(Int, header[1])

        sources = Vector{Int}()
        destinations = Vector{Int}()
        weights = Vector{Float64}()
        for ln in eachline(f)
            fields = split(ln)
            # Tolerate blank lines, e.g. a trailing newline at the end of the file.
            isempty(fields) && continue
            length(fields) >= 3 ||
                throw(ArgumentError("expected `src dst weight`, got $(repr(ln))"))

            a = parse(Int, fields[1])
            b = parse(Int, fields[2])
            c = parse(Float64, fields[3])

            push!(sources, a == 0 ? n : a)
            push!(destinations, b == 0 ? n : b)
            push!(weights, c == 0 ? 0.01 : c)
        end
        return SimpleWeightedGraph(sources, destinations, weights)
    end
end
# Load the graph used for the run below. The input file in_n_10e5_m_5e6.in has to be
# generated first with graph_generator.ipynb.
g = readDataFromFile("in_n_10e5_m_5e6.in")
# Alternative input file:
# g = readDataFromFile("in7.txt")

{100000, 5000000} undirected simple Int64 graph with Float64 weights

# Initialize Internal Variables

In [143]:
# Disjoint-set (union-find) structure over the vertices 1:nv(g). Each vertex starts out
# as its own set; contractVertex merges sets as edges are accepted.
connected_vs = IntDisjointSets(nv(g))

IntDisjointSets([1, 2, 3, 4, 5, 6, 7, 8, 9, 10  …  99991, 99992, 99993, 99994, 99995, 99996, 99997, 99998, 99999, 100000], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0  …  0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 100000)

In [144]:
# joined_nodes[root] lists the vertices whose disjoint-set root in connected_vs is `root`.
joined_nodes = Dict{Int, Vector{Int}}()
for vertex in 1:nv(g)
    root = find_root(connected_vs, vertex)
    # Fetch the member list for this root, creating an empty one on first sight.
    members = get!(() -> Int[], joined_nodes, root)
    push!(members, vertex)
end

In [145]:
# Edges accepted so far. boruvka_MST stops once this holds nv(g) - 1 edges, so that
# capacity is reserved up front.
mst = Vector{edgetype(g)}()
sizehint!(mst, nv(g) - 1)
# Running sum of the (integer) costs of the edges pushed to `mst`; updated by contractVertex.
res =0

0

In [146]:
# Sentinel cost meaning "no candidate edge recorded for this slot".
MAX_INT = 2_000_000_000;
# Per-slot scratch arrays for one round: the cheapest outgoing edge found so far is
# (minNodeSrc[k], minNodeTgt[k]) with cost minCost[k].
# NOTE(review): the arrays have nv(g) + 1 slots, but initMinCostArray only resets
# 1:nv(g); the purpose of the extra slot is not visible here.
minCost = fill(MAX_INT, nv(g) + 1);
minNodeTgt = collect(Int, 1:nv(g) + 1);
minNodeSrc = collect(Int, 1:nv(g) + 1);

# Internal Functions

In [109]:
"""
    initMinCostArray()

Reset the per-round scratch arrays for slots `1:nv(g)`: `minCost` goes back to the
`MAX_INT` sentinel, and `minNodeTgt` / `minNodeSrc` go back to the slot's own index.
Operates on the notebook-level globals and returns `nothing`.
"""
function initMinCostArray()
    slots = 1:nv(g)
    minCost[slots] .= MAX_INT
    minNodeTgt[slots] .= slots
    minNodeSrc[slots] .= slots
    return nothing
end

initMinCostArray (generic function with 1 method)

In [135]:
"""
    findMinCostVertex()

For every set listed in `joined_nodes`, scan the edges of its member vertices and record
the cheapest edge whose endpoints lie in different sets of `connected_vs`.

The result is written to the notebook-level arrays, indexed by the root of the source
vertex: `minCost[root]` (edge weight floored to `Int`), `minNodeSrc[root]` and
`minNodeTgt[root]`. On equal cost the edge found first is kept. Returns `nothing`.

The sets are processed in parallel with `@threads`. Roots are resolved serially before
the parallel section, so the threaded loop only reads the disjoint-set information and
never calls `find_root` concurrently.
"""
function findMinCostVertex()
    # `find_root` may rewrite parent links inside `connected_vs` (path compression), so
    # calling it from several threads at once is a data race. No unions happen during
    # this function, so a snapshot of all roots taken up front stays valid throughout.
    roots = Int[find_root(connected_vs, v) for v in 1:nv(g)]

    sets = collect(keys(joined_nodes))
    @threads for i in sets
        for src in joined_nodes[i]
            root_src = roots[src]
            for dst in neighbors(g, src)
                # Edges inside one component can never be selected.
                root_src == roots[dst] && continue
                weight = floor(Int, get_weight(g, src, dst))
                if minCost[root_src] > weight
                    minCost[root_src] = weight
                    minNodeTgt[root_src] = dst
                    minNodeSrc[root_src] = src
                end
            end
        end
    end
    return nothing
end

findMinCostVertex (generic function with 1 method)

In [99]:
# contractVertex: apply the candidate edges recorded in minCost / minNodeSrc / minNodeTgt.
# For every slot that holds a candidate whose endpoints are still in different sets, the
# two sets are united in `connected_vs`, the edge is appended to `mst`, its cost is added
# to the global `res`, and the member lists in `joined_nodes` are merged under the new root.
# NOTE: a later cell defines contractVertex again with the same statements; when the
# cells are run in order, that later definition replaces this one.
function contractVertex()
    for i in 1:nv(g)
        # Skip slots without a candidate, and candidates made redundant by an earlier
        # union in this same pass.
        if(minCost[i]!= MAX_INT && !in_same_set(connected_vs, minNodeSrc[i], minNodeTgt[i]))
            # Connect the vertices, add mst to answer
            # Remember both roots before the union so the absorbed one can be identified after.
            set1 = find_root(connected_vs, minNodeSrc[i])
            set2 = find_root(connected_vs, minNodeTgt[i])
            union!(connected_vs, minNodeSrc[i], minNodeTgt[i])
            global res+=minCost[i]
            push!(mst, SimpleWeightedEdge(minNodeSrc[i], minNodeTgt[i], Float64(minCost[i])))
            # Merge the member lists of the sets that have just been connected
            merge_target = find_root(connected_vs, minNodeSrc[i])
#             println("Set ",set1,"and set ",set2," will be joined to", merge_target)
            # Whichever old root is not the root of the united set hands its members over
            # and is removed from joined_nodes.
            if merge_target!=set1
                for j in joined_nodes[set1]
                    push!(joined_nodes[merge_target],j)
                end
                delete!(joined_nodes,set1)
            end
            if merge_target!=set2
                for j in joined_nodes[set2]
                    push!(joined_nodes[merge_target],j)
                end
                delete!(joined_nodes,set2)
            end
        end
    end
end

contractVertex (generic function with 1 method)

In [100]:
"""
    contractVertex()

Apply the candidate edges recorded in `minCost`, `minNodeSrc` and `minNodeTgt`.

Slots `1:nv(g)` are visited in order. A slot is skipped when it holds the `MAX_INT`
sentinel or when its two endpoints already share a root in `connected_vs` (for instance
because an earlier slot in this pass united them). Otherwise:

- the two sets are united in `connected_vs`,
- the edge is pushed to `mst` and its cost is added to the global `res`,
- the member list of each old root that is not the root of the united set is appended to
  the list of the new root in `joined_nodes`, and that old key is deleted.

Returns `nothing`.
"""
function contractVertex()
    for slot in 1:nv(g)
        cost = minCost[slot]
        cost == MAX_INT && continue

        src = minNodeSrc[slot]
        tgt = minNodeTgt[slot]
        src_root = find_root(connected_vs, src)
        tgt_root = find_root(connected_vs, tgt)
        src_root == tgt_root && continue

        # Connect the two components and record the edge.
        union!(connected_vs, src, tgt)
        global res += cost
        push!(mst, SimpleWeightedEdge(src, tgt, Float64(cost)))

        # Keep joined_nodes keyed by current roots only.
        survivor = find_root(connected_vs, src)
        for absorbed in (src_root, tgt_root)
            absorbed == survivor && continue
            append!(joined_nodes[survivor], joined_nodes[absorbed])
            delete!(joined_nodes, absorbed)
        end
    end
    return nothing
end

contractVertex (generic function with 1 method)

# Main Function

In [141]:
"""
    boruvka_MST(maxItr = round(log2(nv(g)) + 1))

Run Borůvka rounds on the notebook-level graph `g`. Each round calls
`initMinCostArray()`, `findMinCostVertex()` and `contractVertex()`. Accepted edges
accumulate in `mst` and their total cost in `res`. Returns `nothing`.

The loop ends as soon as one of these holds:
- `mst` contains `nv(g) - 1` edges;
- a round adds no edge. No component has an outgoing edge left (disconnected graph), so
  further rounds could not change `mst` or `res`;
- the round counter reaches `maxItr`.

The counter starts at 1 and is incremented before it is printed, so at most
`maxItr - 1` rounds are executed and the first progress line is labelled `2`. Each
progress line shows the counter and the number of edges in `mst` before that round.

For an empty graph the default `maxItr` is `0`, because `log2(0)` is `-Inf` and cannot be
converted to an integer.
"""
function boruvka_MST(maxItr = nv(g) > 0 ? convert(Int64, round(log2(nv(g))+1, digits=0)) : 0)
    i = 1
    println("Max iteration: ", maxItr)
    while i < maxItr && length(mst) < nv(g) - 1
        i += 1
        println(i, " ", length(mst))
        edges_before = length(mst)
        initMinCostArray()
        findMinCostVertex()
        contractVertex()
        # A round without progress leaves every structure unchanged, so repeating it
        # would only redo the full edge scan.
        length(mst) == edges_before && break
    end
    return nothing
end

boruvka_MST (generic function with 2 methods)

In [None]:
# Run the algorithm and report elapsed time (@time) together with CPU time (@CPUtime).
# The first call in a session also pays for compilation, so re-run for a cleaner timing.
@time @CPUtime boruvka_MST()

Max iteration: 18
2 0
3 75092
4 94903
5 99130
6 99868


In [None]:
# Total cost of the edges accepted into `mst` (accumulated by contractVertex).
res

In [None]:
# Number of accepted edges; boruvka_MST stops early once this reaches nv(g) - 1.
length(mst)

In [None]:
# NOTE(review): bare literal with no recorded output; presumably a reference total
# weight to compare `res` against. Confirm where this value comes from.
6018132840