In [None]:
using BayesNets
using QI
using LightGraphs
using GraphPlot

In [None]:
"""
    star(A, B)

Star product used throughout this notebook (operator z (#2)): `A` is sandwiched
between two copies of the matrix square root of `B`, i.e. `sqrtm(B) * A * sqrtm(B)`.
"""
function star(A, B)
    root_b = sqrtm(B)
    return root_b * A * root_b
end

In [None]:
"""
    nodes(bn)

Return the target name of every CPD in `bn.cpds`, in storage order.
"""
nodes(bn) = map(cpd -> cpd.target, bn.cpds)

In [None]:
# Network with two parentless three-valued nodes a and b (uniform weights 1/3 each)
# and a child c that depends on both.
# NOTE(review): HermitianMatrix(v) presumably builds the operator with v on its
# diagonal — confirm against its definition.
roA = HermitianMatrix([1/3, 1/3, 1/3])
roB = HermitianMatrix([1/3, 1/3, 1/3])
# Weights of C = 0, 1, 2 for every (A, B) configuration, A varying slowest.
# C never takes the value of A or of B and is uniform over the values left, which
# is the Monty Hall host's rule for choosing which door to open.
roC = HermitianMatrix([
        0,1/2,1/2, #A=0, B=0
        0,0,1, #A=0, B=1
        0,1,0, #A=0, B=2
        0,0,1, #A=1, B=0
        1/2,0,1/2, #A=1, B=1
        1,0,0, #A=1, B=2
        0,1,0, #A=2, B=0
        1,0,0, #A=2, B=1
        1/2,1/2,0 #A=2, B=2
        ]); #

# DiscreteQCPD arguments as used here: node name, parent names, parent
# cardinalities, number of categories, operator (inferred from these calls — confirm).
a_cpd = DiscreteQCPD(:a, [], [], 3, roA)
b_cpd = DiscreteQCPD(:b, [], [], 3, roB)
c_cpd = DiscreteQCPD(:c, [:a, :b], [3,3], 3, roC)
monty_bn = AcausalStructure()
# Parents are pushed before their child.
push!(monty_bn, a_cpd)
push!(monty_bn, b_cpd)
push!(monty_bn, c_cpd)

# Draw the network's DAG as an undirected graph, labelled with the node names.
gplot(Graph(monty_bn.dag), nodelabel=nodes(monty_bn))


In [None]:
# Quantum variant: a and b are merged into one 9-dimensional node ab prepared in the
# pure state (|00> + |11> + |22>)/sqrt(3). Basis indices 0, 4, 8 are 3a + b for
# a = b = 0, 1, 2 (assuming ket(k, 9)/bra(k, 9) are the 0-based computational basis
# ket/bra of dimension 9 — confirm against QI).
roAB =1/3*(ket(0,9)+ket(4,9)+ket(8,9))* (bra(0,9)+bra(4,9)+bra(8,9)) 

# Same table as in the Monty cell, redefined so this cell runs on its own: weights
# of C for each (A, B) configuration, A varying slowest, i.e. indexed by ab = 3A + B.
roC = HermitianMatrix([
        0,1/2,1/2, #A=0, B=0
        0,0,1, #A=0, B=1
        0,1,0, #A=0, B=2
        0,0,1, #A=1, B=0
        1/2,0,1/2, #A=1, B=1
        1,0,0, #A=1, B=2
        0,1,0, #A=2, B=0
        1,0,0, #A=2, B=1
        1/2,1/2,0 #A=2, B=2
        ]); #

ab_cpd = DiscreteQCPD(:ab, [], [], 9, roAB)
# c has the single 9-valued parent ab.
c_cpd = DiscreteQCPD(:c, [:ab], [9], 3, roC)
quantum_bn = AcausalStructure()
push!(quantum_bn, ab_cpd)
push!(quantum_bn, c_cpd)

gplot(Graph(quantum_bn.dag), nodelabel=nodes(quantum_bn))

In [None]:
# Eight-node binary example network a..h. Each ro vector holds two weights for the
# child per parent configuration (presumably the first parent varies slowest, as in
# the labelled Monty table — confirm).
roA = HermitianMatrix([.5, .5])
a_cpd = DiscreteQCPD(:a, [], [], 2, roA)

roBwA = HermitianMatrix([.5, .5, .4, .6])
b_cpd = DiscreteQCPD(:b, [:a], [2], 2, roBwA)

roCwA = HermitianMatrix([.7, .3, .2, .8])
c_cpd = DiscreteQCPD(:c, [:a], [2], 2, roCwA)

roDwB = HermitianMatrix([.9, .1, .5, .5])
d_cpd = DiscreteQCPD(:d, [:b], [2], 2, roDwB)

roEwC = HermitianMatrix([.3, .7, .6, .4])
e_cpd = DiscreteQCPD(:e, [:c], [2], 2, roEwC)

roFwDE = HermitianMatrix([.01, .99, .01, .99, .01, .99, .99, .01])
f_cpd = DiscreteQCPD(:f, [:d, :e], [2, 2], 2, roFwDE)

roGwC = HermitianMatrix([.8, .2, .1, .9])
g_cpd = DiscreteQCPD(:g, [:c], [2], 2, roGwC)

roHwEG = HermitianMatrix([.05, .95, .95, .05, .95, .05, .95, .05])
h_cpd = DiscreteQCPD(:h, [:e, :g], [2, 2], 2, roHwEG)

# Add the CPDs parents-first, in alphabetical order.
example_bn = AcausalStructure()
for cpd in (a_cpd, b_cpd, c_cpd, d_cpd, e_cpd, f_cpd, g_cpd, h_cpd)
    push!(example_bn, cpd)
end

gplot(example_bn.dag, nodelabel=nodes(example_bn))


In [None]:
"""
    moral_graph(as::BayesNet)

Return the moral graph of `as`: its DAG with edge directions dropped, plus an
undirected edge between every two distinct parents of a common child.
`as` itself is not modified.
"""
function moral_graph(as::BayesNet)
    moral = Graph(deepcopy(as.dag))
    for cpd in as.cpds
        parent_inds = [as.name_to_index[p] for p in cpd.parents]
        k = length(parent_inds)
        # Undirected edges: each unordered pair of parents is enough.
        for i = 1:k, j = (i + 1):k
            u, w = parent_inds[i], parent_inds[j]
            # A parent listed twice must not produce a self-loop.
            u == w || add_edge!(moral, u, w)
        end
    end
    return moral
end

In [None]:
# Moralise the example network, print the number of edges of the moral graph and
# draw it labelled with the network's node names.
bn = example_bn
moral_bn = moral_graph(bn)
print(ne(moral_bn))
gplot(moral_bn, nodelabel=nodes(bn))

In [None]:
"""
    is_subset(s1::Set, s2::Set)

Return `true` when every element of `s1` is also in `s2` (an empty `s1` is a
subset of anything). Delegates to `issubset`, which avoids building the
intersection just to compare it.
"""
is_subset(s1::Set, s2::Set) = issubset(s1, s2)

In [None]:
"""
    triangulate(g::Graph, as::AcausalStructure)

Triangulate `g` (normally the moral graph of `as`) by greedy vertex elimination.
Return `(triangulated_graph, cliques)`; `g` itself is not modified.

Each step works as follows:
- A vertex's cluster is the vertex plus its not-yet-eliminated neighbours.
- The remaining vertex whose cluster needs the fewest fill-in edges is eliminated.
- Ties go to the cluster with the smallest product of `ncategories`; among equal
  products, the later vertex wins.
- The chosen cluster is made complete.
- Its node names are recorded as a clique (a `Set`) unless an earlier clique
  already contains them.

Vertex `n` of `g` is taken to be the node `as.cpds[n]`, the same correspondence
`nodes(as)` gives.
"""
function triangulate(g::Graph, as::AcausalStructure)
    eliminated = [false for _ in vertices(g)]
    g = deepcopy(g)
    nl = nodes(as)
    cliques = Vector{Set}()
    while !all(eliminated)
        least_edges_to_be_added = Inf
        chosen_vertex = 0
        chosen_cluster = Set()
        for v = 1:length(eliminated)
            eliminated[v] && continue
            # Seed with v itself: a vertex without remaining neighbours still forms a
            # singleton cluster instead of an empty one (otherwise it would never
            # appear in any clique).
            cluster = Set{Any}([v])
            for u in neighbors(g, v)
                eliminated[u] || push!(cluster, u)
            end
            # Fill-in cost: missing edges among the cluster's vertices.
            edges_todo = 0
            for v1 in cluster, v2 in cluster
                if v1 != v2 && !has_edge(g, v1, v2)
                    edges_todo += 1
                end
            end
            edges_todo /= 2   # every missing edge was counted once per direction

            if edges_todo < least_edges_to_be_added ||
                ((edges_todo == least_edges_to_be_added) &&
                    (prod([as.cpds[n].ncategories for n in cluster]) <=
                     prod([as.cpds[n].ncategories for n in chosen_cluster])))
                least_edges_to_be_added = edges_todo
                chosen_vertex = v
                chosen_cluster = cluster
            end
        end
        chosen_nodes = Set([nl[n] for n in chosen_cluster])
        # Keep only maximal clusters as cliques.
        if !any([issubset(chosen_nodes, clique) for clique in cliques])
            push!(cliques, chosen_nodes)
        end
        eliminated[chosen_vertex] = true
        # Fill in: make the eliminated vertex's cluster a complete subgraph.
        for v1 in chosen_cluster, v2 in chosen_cluster
            if v1 != v2 && !has_edge(g, v1, v2)
                add_edge!(g, v1, v2)
            end
        end
    end
    return g, cliques
end

In [None]:
# Number of edges before triangulation.
println(ne(moral_bn))

tri_moral_bn, cliques = triangulate(moral_bn,bn)
# Number of edges after adding the fill-in edges.
println(ne(tri_moral_bn))

# triangulate returns every clique as a Set of node names. Turn each one into a
# Vector sorted by the node's index in bn.name_to_index, so that the variable order
# inside a clique follows the order of the network.
cliques = [sort([c for c in clique], by=c -> bn.name_to_index[c]) for clique in cliques]
println(cliques)
gplot(tri_moral_bn ,nodelabel=nodes(bn))


In [None]:
"""
    JoinTree

Join tree over the clusters (cliques) of a triangulated moral graph.

The struct itself is immutable, but its two `Dict`s are updated in place after
`deepcopy` by the initialisation and message-passing routines.
"""
struct JoinTree
    # Undirected graph (a tree once built); vertex k stands for clusters[k].
    graph::Graph
    # Node names of each cluster.
    clusters::Vector{Vector}
    # Potential (operator) of each cluster, keyed by vertex index.
    vertex_to_num::Dict{Int64, Union{Float64, Complex, Matrix}}
    # Potential of each sepset, keyed by Set([i, j]) of the edge's endpoints.
    edge_to_num::Dict{Set{Int}, Union{Float64, Complex, Matrix}}
    
end

In [None]:
"""
    sepset_cost(sepset::Set, as::AcausalStructure)

Return the weight of `sepset`: the product of `ncategories` over its variables,
i.e. the dimension of the sepset's joint state space. An empty sepset has weight 1.

The previous version computed these per-variable weights but discarded them and
always returned 0.
"""
function sepset_cost(sepset::Set, as::AcausalStructure)
    weight = 1
    for v in sepset
        weight *= as.cpds[as.name_to_index[v]].ncategories
    end
    return weight
end

In [None]:
"""
    sepset_comparator(c1, c2)

Sort key for candidate sepsets: minus the number of variables shared by clusters
`c1` and `c2`, so that sorting ascending puts the largest sepsets first.
"""
sepset_comparator(c1, c2) = -length(intersect(c1, c2))

In [None]:
"""
    cluster_size(cluster::Vector{Symbol}, as::AcausalStructure)

Return the dimension of the joint state space of `cluster`: the product of the
`ncategories` of its variables in `as`. An empty cluster (e.g. an empty sepset)
has size 1. An explicit accumulator is used because `prod` over an empty,
untyped comprehension can throw instead of returning 1.
"""
function cluster_size(cluster::Vector{Symbol}, as::AcausalStructure)
    dim = 1
    for v in cluster
        dim *= as.cpds[as.name_to_index[v]].ncategories
    end
    return dim
end


In [None]:
"""
    make_join_tree(clusters::Vector, as::AcausalStructure)

Connect `clusters` (the cliques of a triangulated moral graph, each a vector of
node names) into a join tree.

Candidate edges are all unordered cluster pairs, tried in order of decreasing
sepset size (`sepset_comparator`). A candidate is accepted when its two clusters
are still on different trees. Accepting stops once `length(clusters) - 1` edges
exist, so the result is a spanning tree. Different edges may share the same sepset
variables, so sepset contents are never used to reject a candidate.

Initial potentials:
- every cluster potential is an identity matrix of dimension `cluster_size`;
- every sepset potential is an identity matrix of the sepset's dimension, keyed by
  `Set([i, j])` for the edge between clusters `i` and `j`.
"""
function make_join_tree(clusters::Vector, as::AcausalStructure)
    n = length(clusters)
    result = JoinTree(
        Graph(n),
        clusters,
        Dict([v => eye(cluster_size(clusters[v], as)) for v = 1:n]),
        Dict()
    )

    # (i, j) and (j, i) name the same undirected edge, so only i < j is considered.
    # Julia's default sort is stable for these tuples, so equally sized sepsets keep
    # their (i, j) order.
    candidate_sepsets = [(i, j) for i = 1:n for j = (i + 1):n]
    candidate_sepsets = sort(candidate_sepsets, by=c -> sepset_comparator(clusters[c[1]], clusters[c[2]]))

    # tree_id[k] labels the tree currently containing cluster k; joining two clusters
    # with the same label would close a cycle.
    tree_id = collect(1:n)
    n_edges = 0
    for (i1, i2) in candidate_sepsets
        n_edges >= n - 1 && break
        t1, t2 = tree_id[i1], tree_id[i2]
        t1 == t2 && continue
        # Merge the trees by relabelling every member of the second one,
        # not just cluster i2.
        for k = 1:n
            if tree_id[k] == t2
                tree_id[k] = t1
            end
        end
        sepset = intersect(clusters[i1], clusters[i2])
        add_edge!(result.graph, i1, i2)
        push!(result.edge_to_num, Set([i1, i2]) => eye(cluster_size(sepset, as)))
        n_edges += 1
    end
    return result
end

In [None]:
# Build the join tree from the sorted cliques and show its clusters.
jt = make_join_tree(cliques, bn)
jt.clusters

In [None]:
# Draw the join tree; each vertex is labelled with its cluster's node names concatenated.
gplot(jt.graph, nodelabel=[join([string(v) for v in c]) for c in jt.clusters])

In [None]:
"""
    family(v_cpd::DiscreteQCPD)

Return the family of the node described by `v_cpd`: a set containing the node
itself together with all of its parents.
"""
function family(v_cpd::DiscreteQCPD)
    members = Set(v_cpd.parents)
    push!(members, v_cpd.target)
    return members
end

In [None]:
# Inspect cluster 1 and its potential; make_join_tree starts every potential as an identity matrix.
print(jt.clusters[1])
jt.vertex_to_num[1]

In [None]:
# Reorder the tensor factors of the square matrix `m`, whose subsystems (in `kron`
# order, first = most significant) have dimensions `dims`: subsystem k of the
# result is subsystem order[k] of `m`.
function _reorder_systems(m, dims, order)
    n = length(dims)
    # A column-major reshape turns the LAST kron factor into the FIRST array axis,
    # so row axis a (1 <= a <= n) holds subsystem n - a + 1; same for columns.
    rdims = reverse(dims)
    tensor = reshape(collect(m), (rdims..., rdims...))
    row_perm = [n - order[n - a + 1] + 1 for a = 1:n]
    d = prod(dims)
    return reshape(permutedims(tensor, vcat(row_perm, row_perm .+ n)), d, d)
end

"""
    initialize(jt::JoinTree, as::AcausalStructure)

Return a copy of `jt` with every CPD of `as` folded into a cluster potential
(`jt` itself is not modified).

For each CPD, in reverse order of `as.cpds`:
1. Pick the first cluster containing the CPD's family.
2. Pad the CPD operator, which acts on (parents..., target), with identities for
   the cluster's other variables.
3. Reorder the subsystems into the cluster's variable order.
4. Combine the result with the cluster's potential via `star`.

The chosen cluster of every node is printed.

Throws `ArgumentError` if no cluster contains a node's family.
"""
function initialize(jt::JoinTree, as::AcausalStructure)
    jt = deepcopy(jt)
    for v1 in reverse(as.cpds)
        fam = family(v1)
        hosts = [c for c = 1:length(jt.clusters) if issubset(fam, Set(jt.clusters[c]))]
        isempty(hosts) && throw(ArgumentError("no cluster contains the family of $(v1.target)"))
        parent_cluster_ind = hosts[1]
        cluster = jt.clusters[parent_cluster_ind]
        println(v1.target, " chooses ", cluster)

        # Build the operator in the order (parents..., target, other cluster
        # variables...). Placing the CPD block at the target's position instead
        # would misorder factors whenever a non-parent variable sits between a
        # parent and the target in the cluster.
        factor_syms = vcat(v1.parents, v1.target)
        mul_elem = kron(eye(1), v1.conditional_distribution.p)
        for v in cluster
            if !(v in factor_syms)
                mul_elem = kron(mul_elem, eye(as.cpds[as.name_to_index[v]].ncategories))
                push!(factor_syms, v)
            end
        end
        dims = [as.cpds[as.name_to_index[s]].ncategories for s in factor_syms]
        order = [findfirst(x -> x == s, factor_syms) for s in cluster]
        mul_elem = _reorder_systems(mul_elem, dims, order)

        jt.vertex_to_num[parent_cluster_ind] = star(jt.vertex_to_num[parent_cluster_ind], mul_elem)
    end
    println()
    return jt
end

In [None]:
# Fold every CPD of bn into the join tree's cluster potentials.
init_jt = initialize(jt, bn)
# NOTE(review): JoinTree has no `cliques` field (it is `clusters`).
# print(init_jt.cliques)
# real(init_jt.vertex_to_num[5])

In [None]:
init_jt.clusters

In [None]:
# Trace out subsystems 1 and 2 of cluster 1's potential.
# NOTE(review): dims [3,3,3] describe three 3-valued variables (as in monty_bn), but
# bn = example_bn has only 2-valued variables, so these dims do not match it.
ABC = init_jt.vertex_to_num[1]
real(ptrace(ABC, [3,3,3], [1,2]))

In [None]:
# Trace out subsystems 2 and 3 of cluster 5's potential, assuming cluster 5 holds
# three binary variables (confirm against init_jt.clusters).
ADC = init_jt.vertex_to_num[5]
ptrace(ADC, [2,2,2], [2,3])

In [None]:
init_jt.vertex_to_num[5]

In [None]:
"""
    single_message_pass(from_ind::Int, to_ind::Int, jt::JoinTree, as::AcausalStructure)

Pass one message from cluster `from_ind` to cluster `to_ind` on a copy of `jt` and
return the copy (`jt` itself is not modified). If the two clusters are not adjacent
in `jt.graph`, the unchanged copy is returned.

1. Projection: the sender's potential is partially traced down to the sepset
   variables. The result replaces the sepset potential.
2. Absorption: the new sepset potential is right-divided (`/`) by the old one.
   It is then padded with identities for the receiver's other variables, permuted
   into network order and combined with the receiver's potential via `star`.

Intermediate values are printed for inspection.
"""
function single_message_pass(from_ind::Int, to_ind::Int, jt::JoinTree, as::AcausalStructure)
    jt = deepcopy(jt)
    # The join-tree graph is undirected, so has_edge gives the same answer for
    # either direction of the message.
    if has_edge(jt.graph, from_ind, to_ind)
        cluster_from = jt.clusters[from_ind]
        cluster_to = jt.clusters[to_ind]
        sepset = intersect(cluster_from, cluster_to)
        println(cluster_from, " ", sepset, " ", cluster_to)
        # Positions (in the sender's variable order) of the variables to trace out.
        to_trace_out_sym = setdiff(cluster_from, sepset)
        to_trace_out_ind = [findfirst(x -> x == s, cluster_from) for s in to_trace_out_sym]
        println(to_trace_out_ind)
        from_variables_sizes = [as.cpds[as.name_to_index[v]].ncategories for v in cluster_from]
        println(from_variables_sizes)
        cluster_from_num = jt.vertex_to_num[from_ind]
        old_sepset_num = jt.edge_to_num[Set([from_ind, to_ind])]
        new_sepset_num = ptrace(cluster_from_num, from_variables_sizes, to_trace_out_ind)
        println(new_sepset_num)
        println(old_sepset_num)

        jt.edge_to_num[Set([from_ind, to_ind])] = new_sepset_num

        cluster_to_num = jt.vertex_to_num[to_ind]

        # NOTE(review): `/` computes new * inv(old). For non-commuting operators this
        # differs from a star-style division sqrt(inv(old)) * new * sqrt(inv(old));
        # confirm which one is intended.
        message = new_sepset_num / old_sepset_num

        # Extend the message with identities on the receiver's non-sepset
        # variables, tracking the resulting variable order in message_sym.
        message_sym = copy(sepset)
        println(message_sym)
        for v in cluster_to
            if !in(v, message_sym)
                push!(message_sym, v)
                message = kron(message, eye(as.cpds[as.name_to_index[v]].ncategories))
            end
        end
        println(message_sym)
        # For each network node present in the message (in as.cpds order), its
        # current position in message_sym.
        message_sorted_inds = [findfirst(x -> x == s.target, message_sym) for s in as.cpds if s.target in message_sym]
        println(message_sorted_inds)
        message_dims = [as.cpds[as.name_to_index[s]].ncategories for s in message_sym]
        message_sorted = permute_systems(message, message_dims, message_sorted_inds)
        jt.vertex_to_num[to_ind] = star(cluster_to_num, message_sorted)
    end
    return jt
end

In [None]:
# Only the last expression of a cell is displayed, so the first line here has no
# visible effect.
init_jt.edge_to_num
init_jt.vertex_to_num[5]

In [None]:
init_jt.vertex_to_num[6]

In [None]:
# Send one message from cluster 5 to cluster 6 (if they are not adjacent, the
# returned copy is unchanged).
passed = single_message_pass(5, 6, init_jt, bn);

In [None]:
passed.vertex_to_num[5]

In [None]:
passed.vertex_to_num[6]

In [None]:
# Trace out subsystems 2 and 3 of the receiver's potential, assuming cluster 6
# holds three binary variables (confirm against init_jt.clusters).
ptrace(passed.vertex_to_num[6], [2,2,2], [2, 3])

In [None]:
"""
    global_propagation(jt::JoinTree, as::AcausalStructure)

Run global propagation on a copy of `jt` and return it (`jt` itself is not
modified). The previous version was a stub that returned the copy unpropagated.

For every connected component of `jt.graph`, the lowest-numbered cluster is the
root. Messages are sent with `single_message_pass` in two sweeps:
1. Collect evidence: from the leaves towards the root.
2. Distribute evidence: from the root back to the leaves.
"""
function global_propagation(jt::JoinTree, as::AcausalStructure)
    jt = deepcopy(jt)
    n = length(jt.clusters)
    # (cluster, parent) pairs in visit order; each cluster comes after its parent.
    # Parent 0 marks a root.
    order = Tuple{Int,Int}[]
    visited = falses(n)
    for root = 1:n
        visited[root] && continue
        stack = [(root, 0)]
        while !isempty(stack)
            v, parent = pop!(stack)
            visited[v] && continue
            visited[v] = true
            push!(order, (v, parent))
            for u in neighbors(jt.graph, v)
                visited[u] || push!(stack, (u, v))
            end
        end
    end
    # Collect evidence. In reverse visit order, a cluster sends to its parent only
    # after all of its children have sent to it.
    for (v, parent) in reverse(order)
        if parent != 0
            jt = single_message_pass(v, parent, jt, as)
        end
    end
    # Distribute evidence. In visit order, a parent has already been updated before
    # it sends to its child.
    for (v, parent) in order
        if parent != 0
            jt = single_message_pass(parent, v, jt, as)
        end
    end
    return jt
end

In [None]:
# Propagate over the tree whose potentials were filled in by `initialize`; the raw
# `jt` still carries only the identity potentials from make_join_tree.
global_propagation(init_jt, bn)