# DBN Structure Inference

The idea is to infer a posterior for the *structure* of a Dynamic Bayesian Network (DBN), given some data.

We formulate this task with the following model:

$$ P(G | X) \propto P(X | G) \cdot P(G) $$

* $P(G)$ is a prior distribution over DBN structures. We'll assume it has the form
$$P(G) \propto \exp \left( -\lambda |G \setminus G^\prime| \right)$$
where $|G \setminus G^\prime|$ denotes the number of edges in $G$ that are not present in some reference graph $G^\prime$.
* $P(X | G)$ is the marginal likelihood of the DBN structure. That is, it is the likelihood of the DBN after the network parameters have been integrated out, so it scores the network *structure* alone.
* If we assume suitable priors on the network parameters, $P(X|G)$ can be obtained in closed form. In this work, we'll use the following marginal likelihood:
    
    $$P(X | G) \propto \prod_{i=1}^p (1 + n)^{-(2^{|\pi(i)|} - 1)/2} \left( X_i^{+ T} X_i^+ - \frac{n}{n+1} X_i^{+ T} B_i (B_i^T B_i)^{-1} B_i^T X_i^+ \right)^{-\frac{n}{2}}$$ 
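    where $X_i^+$ is the vector of node $i$'s observed values at times $2, \ldots, T$, $B_i$ is a design matrix built from the values of its parents $\pi(i)$ at the preceding time steps, and $n$ is the number of observations (the length of $X_i^+$). This marginal likelihood results from an empirical prior over the regression coefficients (a Zellner $g$-prior with $g = n$, which produces the $(1+n)$ and $\frac{n}{n+1}$ factors) and an improper $\propto 1/\sigma^2$ prior on the regression noise variance.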
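As a minimal sketch of how one factor of this product can be evaluated, the Julia functions below compute the log of node $i$'s term for a single time series stored as a (variables × timesteps) matrix. They assume that $X_i^+$ holds node $i$'s values at times $2, \ldots, T$, that $n$ is its length, and that $B_i$ has one column per non-empty subset of $\pi(i)$ (the elementwise product of those parents' values at times $1, \ldots, T-1$), which gives the $2^{|\pi(i)|} - 1$ columns matching the exponent above. The helpers `design_matrix` and `log_marginal_term` are hypothetical, not part of the notebook.

```julia
using LinearAlgebra

# Hypothetical helper: lagged design matrix for one node's parents.
# Column `mask` is the elementwise product of the parents selected by the bits of
# `mask`, taken at timesteps 1..T-1, so B has 2^length(parents) - 1 columns.
function design_matrix(X::Matrix{Float64}, parents::Vector{Int})
    T = size(X, 2)
    k = length(parents)
    B = ones(T - 1, 2^k - 1)
    for mask in 1:(2^k - 1), (b, j) in enumerate(parents)
        if (mask >> (b - 1)) & 1 == 1
            B[:, mask] .*= X[j, 1:T-1]
        end
    end
    return B
end

# Hypothetical helper: log of node i's factor in P(X | G), up to a shared constant.
function log_marginal_term(X::Matrix{Float64}, i::Int, parents::Vector{Int})
    x_plus = X[i, 2:end]                  # X_i^+: node i at timesteps 2..T
    n = length(x_plus)
    B = design_matrix(X, parents)
    rss = dot(x_plus, x_plus)             # x'x
    if size(B, 2) > 0                     # minus n/(n+1) x'B (B'B)^{-1} B'x
        Bx = transpose(B) * x_plus
        rss -= (n / (n + 1)) * dot(Bx, pinv(transpose(B) * B) * Bx)
    end
    return -0.5 * size(B, 2) * log(1 + n) - 0.5 * n * log(rss)
end
```

Note how the shrinkage works: if $X_i^+$ lies entirely in the column space of $B_i$, the bracketed term is $X_i^{+T} X_i^+ / (n+1)$ rather than zero, so even a perfect fit yields a finite score, while the $(1+n)^{-(2^{|\pi(i)|}-1)/2}$ factor penalizes every extra column.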

In [1]:
include("DiGraph.jl")
using Gen
using PyPlot
using .DiGraphs
using LinearAlgebra
using CSV
using DataFrames



## Get some data

For now, we'll work with some data used by Hill et al. in their 2012 paper, _Bayesian Inference of Signaling Network Topology in a Cancer Cell Line_.

The dataset gives the differential phosphorylation levels of 20 proteins in a cancer cell line perturbed by EGF. This is a well-studied signaling pathway; the goal is to infer a graph describing the dependencies between the proteins in this pathway.

In [None]:
protein_names = CSV.read("data/protein_names.csv", DataFrame);
reference_adjacency = CSV.read("data/prior_graph.csv", DataFrame);
timesteps = CSV.read("data/time.csv", DataFrame)
timeseries_data = CSV.read("data/mukherjee_data.csv", DataFrame)

## Build the model

Implement the graph prior distribution:

$$P(G) \propto \exp \left( -\lambda |G \setminus G^\prime| \right)$$

In [None]:
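struct GraphPrior <: Gen.Distribution{DiGraph} end
const graphprior = GraphPrior()

# Sampling is a placeholder: it returns the reference graph,
# which then serves as the starting point for MCMC.
Gen.random(gp::GraphPrior, lambda::Float64, reference_graph::DiGraph) = reference_graph

"""
Number of edges in `g` that are not in `g_ref`.
Assumes each edge is stored as a row of the matrix `g.edges`.
"""
function graph_edge_diff(g::DiGraph, g_ref::DiGraph)
    e1 = Set([g.edges[i,:] for i=1:size(g.edges, 1)])
    e_ref = Set([g_ref.edges[i,:] for i=1:size(g_ref.edges, 1)])
    return length(setdiff(e1, e_ref))
end

# Unnormalized log prior: -lambda * |G \ G'|
Gen.logpdf(gp::GraphPrior, graph::DiGraph, lambda::Float64, reference_graph::DiGraph) = -lambda * graph_edge_diff(graph, reference_graph)

# Make the distribution callable, so that it can be used with @trace
(gp::GraphPrior)(lambda::Float64, reference_graph::DiGraph) = Gen.random(gp, lambda, reference_graph)

# No gradients are defined for a distribution over graphs
Gen.has_output_grad(::GraphPrior) = false
Gen.has_argument_grads(::GraphPrior) = (false, false);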
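The `logpdf` above omits the normalizing constant of $P(G)$. That is harmless while $\lambda$ is fixed, but the constant depends on $\lambda$, so it matters whenever $\lambda$ is itself sampled (as in `data_generator` below). Assuming each of the $p^2$ possible DBN edges (self-loops included) may be present independently, each edge of $G^\prime$ contributes a factor of 2 to the sum over graphs and every other edge contributes $1 + e^{-\lambda}$:

$$Z(\lambda) = 2^{|G^\prime|} \left(1 + e^{-\lambda}\right)^{p^2 - |G^\prime|}$$

The Julia sketch below checks this by brute-force enumeration on a two-vertex graph. It represents edges as plain `(parent, child)` tuples rather than the notebook's `DiGraph`, and all helper names are hypothetical.

```julia
# Hypothetical helpers: a graph is a Set of (parent, child) edges over vertices 1..p,
# and each of the p^2 ordered pairs (self-loops included) is a possible DBN edge.
log_prior_unnormalized(E, E_ref, λ) = -λ * length(setdiff(E, E_ref))

# log Z(λ) = |G'| log 2 + (p^2 - |G'|) log(1 + exp(-λ))
log_normalizer(p, E_ref, λ) = length(E_ref) * log(2) + (p^2 - length(E_ref)) * log1p(exp(-λ))

# Normalized log prior, usable when λ is itself random
log_prior(E, E_ref, p, λ) = log_prior_unnormalized(E, E_ref, λ) - log_normalizer(p, E_ref, λ)

# Brute-force check: sum the unnormalized prior over all 2^(p^2) graphs.
function brute_force_normalizer(p, E_ref, λ)
    all_edges = [(a, b) for a in 1:p for b in 1:p]
    total = 0.0
    for mask in 0:(2^length(all_edges) - 1)
        E = Set([all_edges[k] for k in eachindex(all_edges) if (mask >> (k - 1)) & 1 == 1])
        total += exp(log_prior_unnormalized(E, E_ref, λ))
    end
    return total
end
```

Brute force is only feasible for tiny $p$; the closed form costs $O(1)$ for any $p$.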

Implement the DBN's marginal likelihood:

$$P(X | G) \propto \prod_{i=1}^p (1 + n)^{-(2^{|\pi(i)|} - 1)/2} \left( X_i^{+ T} X_i^+ - \frac{n}{n+1} X_i^{+ T} B_i (B_i^T B_i)^{-1} B_i^T X_i^+ \right)^{-\frac{n}{2}}$$

Some things to note:
* Plugging a closed-form marginal likelihood into Gen is somewhat unidiomatic: the probabilistic programming ethos is to model the entire data-generating process. Integrating the parameters out analytically ought to make inference more efficient, though, since the sampler then only has to explore graph structures.

In [None]:
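struct DBNMarginal <: Gen.Distribution{Vector{Matrix{Float64}}} end
const dbnmarginal = DBNMarginal()

# Caches keyed by tuples of node indices, e.g. (i, parents...) or (parents...)
const TupleCache{V} = Dict{Tuple{Vararg{Int}},V}

"""
DBNMarginal's sampling method does nothing.
In our inference task, the Xs will always be observed.
"""
Gen.random(dbn::DBNMarginal, parents::Vector{Vector{Int}},
           lp_cache::TupleCache{Float64}, B_cache::TupleCache{Matrix{Float64}}) = Matrix{Float64}[]
(dbn::DBNMarginal)(args...) = Gen.random(dbn, args...)
Gen.has_output_grad(::DBNMarginal) = false
Gen.has_argument_grads(::DBNMarginal) = (false, false, false)

"""
Build (and cache) the design matrix B for a tuple of parent indices, stacking rows
across the series in `dataset` (each a variables × timesteps matrix). Assumes B has
one column per nonempty subset of parents: the product of those parents' values at
timesteps 1..T-1, i.e. 2^|π(i)| - 1 columns.
"""
function compute_B(idx::Tuple{Vararg{Int}}, dataset::Vector{Matrix{Float64}},
                   B_cache::TupleCache{Matrix{Float64}})
    haskey(B_cache, idx) && return B_cache[idx]
    k = length(idx)
    blocks = Matrix{Float64}[]
    for X in dataset
        T = size(X, 2)
        block = ones(T - 1, 2^k - 1)
        for mask in 1:(2^k - 1), (b, j) in enumerate(idx)
            if (mask >> (b - 1)) & 1 == 1
                block[:, mask] .*= X[j, 1:T-1]
            end
        end
        push!(blocks, block)
    end
    B_cache[idx] = reduce(vcat, blocks)
    return B_cache[idx]
end

"""
Log of node i's factor in the marginal likelihood, cached by (i, parents...).
"""
function compute_lp_term(dataset::Vector{Matrix{Float64}}, i::Int, parents::Vector{Int},
                         lp_cache::TupleCache{Float64}, B_cache::TupleCache{Matrix{Float64}})

    # Check whether the term has already been computed:
    k = (i, parents...)
    haskey(lp_cache, k) && return lp_cache[k]

    # x_plus holds node i's values at timesteps 2..T, stacked across series
    x_plus = reduce(vcat, [X[i, 2:end] for X in dataset])
    n = length(x_plus)
    B = compute_B(Tuple(parents), dataset, B_cache)

    # x'x - n/(n+1) x'B (B'B)^{-1} B'x; pinv guards against a singular B'B
    rss = dot(x_plus, x_plus)
    if size(B, 2) > 0
        Bx = transpose(B) * x_plus
        rss -= (n / (n + 1.0)) * dot(Bx, pinv(transpose(B) * B) * Bx)
    end

    lp_cache[k] = -0.5 * size(B, 2) * log(1.0 + n) - 0.5 * n * log(rss)
    return lp_cache[k]
end

"""
DBNMarginal's logpdf, in effect, returns a score for the
network topology. We use dictionaries to cache precomputed terms of the sum.
"""
function Gen.logpdf(dbn::DBNMarginal, dataset::Vector{Matrix{Float64}}, parents::Vector{Vector{Int}},
                    lp_cache::TupleCache{Float64}, B_cache::TupleCache{Matrix{Float64}})
    return sum(compute_lp_term(dataset, i, ps, lp_cache, B_cache) for (i, ps) in enumerate(parents))
end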

Implement a DBN generatively, with each node linear-Gaussian in its parents' values at the previous timestep:

In [None]:
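@gen function dbn_node(x_prev, parents::Vector{Int}, weights::Vector{Float64}, noise::Float64)
    # Linear-Gaussian node: the mean is a weighted sum of the parents' previous values
    mu = reduce(+, (weights[k] * x_prev[j] for (k, j) in enumerate(parents)); init = 0.0)
    x_new = @trace(Gen.normal(mu, noise), :x)
    return x_new
end

layer_nodes = Gen.Map(dbn_node)

@gen function dbn_layer(timestep::Int, x_prev, parents::Vector{Vector{Int}},
                        weights::Vector{Vector{Float64}}, noise::Vector{Float64})
    x_prev_repeat = fill(x_prev, length(parents))
    x_new = @trace(layer_nodes(x_prev_repeat, parents, weights, noise), :variables)
    return x_new
end

# Unfold iterates dbn_layer over timesteps, threading x through as the state
unfolded_layers = Gen.Unfold(dbn_layer)

@gen function dbn(x_init::Vector{Float64}, parents::Vector{Vector{Int}},
                  weights::Vector{Vector{Float64}}, noise::Vector{Float64}, T::Int)
    x = @trace(unfolded_layers(T - 1, x_init, parents, weights, noise), :timeseries)
    return x
end

# Map over independent time series
independent_series = Gen.Map(dbn)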
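As a plain-Julia sketch (no Gen) of what these generative functions simulate, the helper below draws one time series from a DBN whose nodes are linear-Gaussian in their parents' previous values. The per-node weights and noise standard deviations mirror the arguments of `dbn_node`; `simulate_dbn` itself is hypothetical.

```julia
using Random

# Hypothetical helper: simulate one series from a linear-Gaussian DBN.
# parents[i] lists node i's parents, weights[i] holds one coefficient per parent,
# and noise[i] is node i's noise standard deviation. Returns a (nodes × T) matrix.
function simulate_dbn(x_init::Vector{Float64}, parents::Vector{Vector{Int}},
                      weights::Vector{Vector{Float64}}, noise::Vector{Float64}, T::Int;
                      rng::AbstractRNG = Random.GLOBAL_RNG)
    p = length(x_init)
    X = zeros(p, T)
    X[:, 1] = x_init
    for t in 2:T, i in 1:p
        mu = 0.0
        for (k, j) in enumerate(parents[i])
            mu += weights[i][k] * X[j, t-1]   # parents act through their previous values
        end
        X[i, t] = mu + noise[i] * randn(rng)
    end
    return X
end
```

For example, with zero noise the series is deterministic: `simulate_dbn([1.0, 2.0], [[2], [1, 2]], [[0.5], [1.0, -1.0]], [0.0, 0.0], 3)` returns the matrix `[1.0 1.0 -0.5; 2.0 -1.0 2.0]`. Because each timestep depends only on the previous one, cycles and self-loops in the graph are allowed.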

Implement our overall model:

In [None]:
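function get_parent_vecs(G)
    return [sort(in_neighbors(G, v)) for v in sort(G.vertices)]
end

@gen function data_generator(reference_graph::DiGraph, Tvec::Vector{Int})

    lambda = @trace(Gen.gamma(1, 1), :lambda)

    G = @trace(graphprior(lambda, reference_graph), :G)
    parents = get_parent_vecs(G)
    p = length(parents)
    S = length(Tvec)   # number of independent time series

    # Placeholder priors (assumed for illustration, not settled modeling choices):
    # standard-normal initial states and regression coefficients, unit noise.
    x_init = Vector{Vector{Float64}}(undef, S)
    for s in 1:S
        x_init[s] = @trace(Gen.mvnormal(zeros(p), Matrix{Float64}(I, p, p)), (:x_init, s))
    end
    regression_coeffs = [zeros(length(ps)) for ps in parents]
    for i in 1:p, j in 1:length(parents[i])
        regression_coeffs[i][j] = @trace(Gen.normal(0.0, 1.0), (:beta, i, j))
    end
    regression_noise = ones(p)

    x = @trace(independent_series(x_init, fill(parents, S), fill(regression_coeffs, S),
                                  fill(regression_noise, S), Tvec), :data)
    return x
end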

## Inference

### Metropolis-Hastings over directed graphs

Proposal distribution:

In [2]:
"""
Proposal distribution for exploring the unconstrained space of
directed graphs.

`expected_indegree` guides the exploration: the higher a vertex's
in-degree relative to this value, the more likely the proposal is
to remove one of its incoming edges.
"""
@gen function digraph_proposal(tr, expected_indegree::Float64)

    G = copy(tr[:G])
    ordered_vertices = sort(G.vertices)
    V = length(G.vertices)

    # Choose a vertex u uniformly at random
    u_idx = @trace(Gen.categorical(fill(1.0 / V, V)), :u_idx)
    u = ordered_vertices[u_idx]
    ordered_inneighbors = sort(in_neighbors(G, u))
    in_deg = length(ordered_inneighbors)

    # Remove an incoming edge with a probability that grows with u's in-degree
    prob_remove = (in_deg / V) ^ log2(1.0 * V / expected_indegree)
    remove_edge = @trace(Gen.bernoulli(prob_remove), :remove_edge)

    if remove_edge
        v_idx = @trace(Gen.categorical(fill(1.0 / in_deg, in_deg)), :v_idx)
        v = ordered_inneighbors[v_idx]
        remove_edge!(G, v, u)
    else

        ordered_outneighbors = sort(out_neighbors(G, u))
        out_deg = length(ordered_outneighbors)
        neighbors = union(ordered_inneighbors, ordered_outneighbors)
        deg = length(neighbors)

        # Otherwise, reverse an outgoing edge u -> v ...
        prob_reverse = deg > 0 ? out_deg / deg : 0.0
        reverse_edge = @trace(Gen.bernoulli(prob_reverse), :reverse_edge)
        if reverse_edge
            v_idx = @trace(Gen.categorical(fill(1.0 / out_deg, out_deg)), :v_idx)
            v = ordered_outneighbors[v_idx]
            remove_edge!(G, u, v)
            add_edge!(G, v, u)
        else
            # ... or add an edge v -> u from a vertex that is not yet a parent of u
            nonparents = sort(setdiff(G.vertices, ordered_inneighbors))
            n_np = length(nonparents)
            v_idx = @trace(Gen.categorical(fill(1.0 / n_np, n_np)), :v_idx)
            v = nonparents[v_idx]
            add_edge!(G, v, u)
        end

    end

    return G

end;

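The removal probability $(d/V)^{\log_2(V/e)}$, for in-degree $d$, $V$ vertices, and `expected_indegree` $e$, rises from 0 at $d = 0$ to 1 at $d = V$. With this exponent, however, it equals $2^{-\log_2(V/e)^2}$ at $d = e$, which is roughly $5 \times 10^{-4}$ for $V = 20$ and $e = 2$. If the intent is for removals and additions to balance at the expected in-degree, an exponent of $1/\log_2(V/e)$ puts the probability at exactly $1/2$ there. The short Julia sketch below tabulates both; `prob_remove_half` is a hypothetical alternative, not the notebook's choice.

```julia
# Removal probability from digraph_proposal, as a function of a vertex's in-degree d,
# the number of vertices V, and expected_indegree e.
prob_remove(d, V, e) = (d / V)^log2(V / e)

# Hypothetical alternative exponent: equals 1/2 exactly when d == e.
prob_remove_half(d, V, e) = (d / V)^(1 / log2(V / e))

V, e = 20, 2.0
for d in (0, 1, 2, 5, 10, 20)
    println("in-degree ", d, ": ", round(prob_remove(d, V, e), digits = 4),
            " vs. ", round(prob_remove_half(d, V, e), digits = 4))
end
```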

Involution function:

In [None]:
function digraph_involution(cur_tr, fwd_choices, fwd_ret, prop_args)

    # Update the trace with the proposed graph; the model arguments are unchanged
    new_G = fwd_ret
    old_G = cur_tr[:G]
    update_choices = Gen.choicemap()
    update_choices[:G] = new_G
    args = Gen.get_args(cur_tr)
    argdiffs = map(_ -> Gen.NoChange(), args)
    new_tr, weight, retdiff, discard = Gen.update(cur_tr, args, argdiffs, update_choices)

    # figure out what has changed
    fwd_u_idx = fwd_choices[:u_idx]
    sorted_vertices = sort(old_G.vertices)
    fwd_u = sorted_vertices[fwd_u_idx]
    fwd_v_idx = fwd_choices[:v_idx]

    # Deduce the correct backward choices
    bwd_choices = Gen.choicemap()
    if fwd_choices[:remove_edge] # an edge v -> u was removed: we must add it back.

        fwd_v = sort(in_neighbors(old_G, fwd_u))[fwd_v_idx]
        bwd_nonparents = sort(setdiff(new_G.vertices, in_neighbors(new_G, fwd_u)))

        bwd_choices[:u_idx] = fwd_u_idx
        bwd_choices[:remove_edge] = false
        bwd_choices[:reverse_edge] = false
        bwd_choices[:v_idx] = indexin(fwd_v, bwd_nonparents)[1]

    elseif fwd_choices[:reverse_edge] # u -> v became v -> u: reverse it back from v.
        # (Assumes v -> u was not already in old_G; otherwise the move cannot be undone.)

        fwd_v = sort(out_neighbors(old_G, fwd_u))[fwd_v_idx]

        bwd_choices[:u_idx] = indexin(fwd_v, sorted_vertices)[1]
        bwd_choices[:remove_edge] = false
        bwd_choices[:reverse_edge] = true
        bwd_choices[:v_idx] = indexin(fwd_u, sort(out_neighbors(new_G, fwd_v)))[1]

    else # an edge v -> u was added: remove it.

        fwd_v = sort(setdiff(old_G.vertices, in_neighbors(old_G, fwd_u)))[fwd_v_idx]

        bwd_choices[:u_idx] = fwd_u_idx
        bwd_choices[:remove_edge] = true
        bwd_choices[:v_idx] = indexin(fwd_v, sort(in_neighbors(new_G, fwd_u)))[1]

    end

    return new_tr, bwd_choices, weight
end
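Because the involution must exactly undo every forward move, it is worth checking in isolation. The sketch below re-implements the proposal's three moves and the backward-choice bookkeeping on a plain edge set (a `Set` of `(parent, child)` tuples) instead of the notebook's `DiGraph` type; `parents_of`, `children_of`, `apply_move`, and `backward_choices` are hypothetical helpers, and `Dict`s stand in for Gen choice maps. One caveat it makes visible: reversing $u \to v$ when $v \to u$ is already present cannot be undone by a single reversal, so the test graph has no 2-cycles.

```julia
# Hypothetical stand-ins: a graph is a vertex list V plus a Set of (parent, child)
# edges, and a choice map is a Dict keyed like digraph_proposal's addresses.
parents_of(E, u) = sort([a for (a, b) in E if b == u])
children_of(E, u) = sort([b for (a, b) in E if a == u])

# Apply the move encoded by the choices `c`, mirroring digraph_proposal.
function apply_move(V, E, c)
    E = copy(E)
    u = sort(V)[c[:u_idx]]
    if c[:remove_edge]
        delete!(E, (parents_of(E, u)[c[:v_idx]], u))
    elseif c[:reverse_edge]
        v = children_of(E, u)[c[:v_idx]]
        delete!(E, (u, v))
        push!(E, (v, u))
    else
        v = sort(setdiff(V, parents_of(E, u)))[c[:v_idx]]
        push!(E, (v, u))
    end
    return E
end

# Backward choices that undo the forward move `c` (mirrors digraph_involution).
function backward_choices(V, E_old, E_new, c)
    Vs = sort(V)
    u = Vs[c[:u_idx]]
    if c[:remove_edge]        # v -> u was removed: add it back
        v = parents_of(E_old, u)[c[:v_idx]]
        nonparents = sort(setdiff(V, parents_of(E_new, u)))
        return Dict(:u_idx => c[:u_idx], :remove_edge => false, :reverse_edge => false,
                    :v_idx => findfirst(isequal(v), nonparents))
    elseif c[:reverse_edge]   # u -> v became v -> u: reverse it from v's side
        v = children_of(E_old, u)[c[:v_idx]]
        return Dict(:u_idx => findfirst(isequal(v), Vs), :remove_edge => false,
                    :reverse_edge => true, :v_idx => findfirst(isequal(u), children_of(E_new, v)))
    else                      # v -> u was added: remove it
        v = sort(setdiff(V, parents_of(E_old, u)))[c[:v_idx]]
        return Dict(:u_idx => c[:u_idx], :remove_edge => true,
                    :v_idx => findfirst(isequal(v), parents_of(E_new, u)))
    end
end
```

A round trip `apply_move(V, apply_move(V, E, c), backward_choices(V, E, E_new, c))` should return `E` for every valid forward choice `c`; enumerating all such `c` on a small graph is a cheap way to catch indexing mistakes before running MCMC.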

In [None]:
i = 10
parents = [1;2;3;4;5]

In [None]:
(i, parents...)

In [None]:
d = Dict{Tuple{Vararg{Int64,N} where N}, Float64}()

In [None]:
d[(4,5)] = 1.23

In [None]:
d[()] = 48.281

In [None]:
d

In [None]:
div(length((1,2,3)),2)

In [5]:
a = union([1;2;3],[4;1;8])

5-element Array{Int64,1}:
 1
 2
 3
 4
 8

In [11]:
indexin(8,a)[1]

5