# Sampling from a distribution over graphs

We will define a distribution over directed graphs, assuming the vertex set is known a priori.

We'll focus on a class of distributions proposed by Mukherjee et al.:

$$ P(G) \propto \exp \left( - \beta \cdot D(G, G^\prime) \right) $$

i.e., the probability decreases exponentially with the distance from some reference graph $G^\prime$, at a rate controlled by $\beta > 0$.

The distance $D$ could be any nonnegative measure of difference, but a useful default is the Hamming distance between the graphs' edge sets: the number of edges present in one graph but not the other.
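To make the default concrete, here is a minimal sketch of the Hamming distance and the resulting unnormalized log-density in plain Julia. It assumes edges are stored as a `Set` of `(source, target)` tuples rather than in the notebook's `DiGraph` type; `EdgeSet`, `hamming_distance`, and `unnormalized_logdensity` are hypothetical names used only for illustration.

```julia
# Hypothetical representation: a graph's edges as a Set of (source, target) tuples.
const EdgeSet = Set{Tuple{String,String}}

# Hamming distance: number of edges present in exactly one of the two graphs.
hamming_distance(a::EdgeSet, b::EdgeSet) = length(symdiff(a, b))

# log P(G) up to an additive constant: -beta * D(G, G').
unnormalized_logdensity(g::EdgeSet, ref::EdgeSet, beta::Real) = -beta * hamming_distance(g, ref)

ref = EdgeSet([("cat", "mammal"), ("dog", "mammal"), ("lizard", "reptile"),
               ("mammal", "animal"), ("reptile", "animal")])
g = setdiff(ref, EdgeSet([("reptile", "animal")]))  # drop one reference edge
push!(g, ("lizard", "animal"))                       # add one edge not in the reference

println(hamming_distance(g, ref))              # 2
println(unnormalized_logdensity(g, ref, 1.5))  # -3.0
```

Each edge that disagrees with the reference costs a factor of $e^{-\beta}$ in probability, which is why the log-density is linear in the distance.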

In [1]:
include("DiGraph.jl")
using Gen
using .DiGraphs

## Define a density over graphs

In [2]:
# In Gen, a distribution is a subtype of `Distribution{T}`; we define a singleton type and a constant instance of it:
struct GraphPriorDist <: Distribution{DiGraph} end
const graphpriordist = GraphPriorDist()

GraphPriorDist()

In [31]:
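# Hamming distance between graphs: the number of edges present in exactly one of them.
function edge_hamming(g1::DiGraph{T}, g2::DiGraph{T}) where T
    edges1 = Set(g1.edges[i, :] for i in 1:size(g1.edges, 1))
    edges2 = Set(g2.edges[i, :] for i in 1:size(g2.edges, 1))
    return length(symdiff(edges1, edges2))
end

# Sampling method for the graph distribution.
# NOTE: placeholder only. Exact sampling is not implemented, so this
# simply returns the reference graph unchanged.
function Gen.random(gpd::GraphPriorDist, reference_graph::DiGraph{T}, beta::Number, distance_func::Function) where T
    return reference_graph
end

# If distance isn't specified, use Hamming distance.
function Gen.random(gpd::GraphPriorDist, reference_graph::DiGraph{T}, beta::Number) where T
    return Gen.random(gpd, reference_graph, beta, edge_hamming)
end

# Make the distribution callable, so `@trace(graphpriordist(g, beta), :graph)` works.
# (`graphpriordist` is a constant instance, so we add a call method to its type.)
(gpd::GraphPriorDist)(reference_graph::DiGraph{T}, beta::Number) where T = Gen.random(gpd, reference_graph, beta)

# Unnormalized log-density: log P(G) = -beta * D(G, G') + const.
# It must return the log (not exp(...)), since Gen adds and subtracts these scores.
function Gen.logpdf(gpd::GraphPriorDist, dg::DiGraph{T}, reference_graph::DiGraph{T}, beta::Number, distance_func::Function) where T
    return -beta * distance_func(dg, reference_graph)
end

# Again: use Hamming distance by default.
function Gen.logpdf(gpd::GraphPriorDist, dg::DiGraph{T}, reference_graph::DiGraph{T}, beta::Number) where T
    return Gen.logpdf(gpd, dg, reference_graph, beta, edge_hamming)
end

# Remaining methods of Gen's Distribution interface: graphs are discrete, with no gradients.
Gen.is_discrete(::GraphPriorDist) = true
Gen.has_output_grad(::GraphPriorDist) = false
Gen.has_argument_grads(::GraphPriorDist) = (false, false)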

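There are some issues with our construction:
1. This distribution is _unnormalized_: we never compute the normalizing constant $Z = \sum_G \exp\left( -\beta \cdot D(G, G^\prime) \right)$.
2. This distribution is _not generative_: it doesn't prescribe a sampling method (our `random` is only a placeholder that returns the reference graph).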


However, our implementation should still let Gen's Metropolis-Hastings (MH) sampler 'play nicely' with this distribution:
1. Because the distribution is unnormalized, `logpdf` differs from the true log-probability by an unknown additive constant, $\log Z$, where $Z$ is the normalizing constant. What _really_ matters in MH sampling, though, is the **change** in log-probability between the current and the proposed graph. The constant cancels in that difference, so our `logpdf` computes it correctly.
2. The distribution isn't generative, but MH doesn't require us to sample from it directly. The whole point of MH is that it lets us sample from (unnormalized) densities that have no generative description: we only need a proposal distribution and the ability to evaluate the density.
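A small plain-Julia sketch of why the constant doesn't matter: the MH acceptance probability depends only on differences of log-densities, so subtracting the same $\log Z$ from both leaves it unchanged. `log_accept_prob` is a hypothetical helper, and the proposal log-probabilities below are made-up values for a move assumed to be symmetric; nothing here calls Gen.

```julia
# Log of the MH acceptance probability:
#   log alpha = min(0, [log p(new) - log p(old)] + [log q(old | new) - log q(new | old)])
# where p may be unnormalized and q is the proposal.
function log_accept_prob(logp_new::Real, logp_old::Real, logq_fwd::Real, logq_bwd::Real)
    return min(0.0, (logp_new - logp_old) + (logq_bwd - logq_fwd))
end

beta = 1.0
logp_old = -beta * 1   # current graph is at Hamming distance 1 from G'
logp_new = -beta * 2   # proposed graph is at distance 2
logq = log(0.5)        # assumed: forward and backward proposal probabilities are equal

log_alpha = log_accept_prob(logp_new, logp_old, logq, logq)

# Shifting both log-densities by an arbitrary constant (standing in for -log Z)
# leaves the acceptance probability unchanged.
logZ = 42.0
log_alpha_shifted = log_accept_prob(logp_new - logZ, logp_old - logZ, logq, logq)

println(exp(log_alpha))                 # ≈ 0.368, i.e. exp(-beta)
println(log_alpha == log_alpha_shifted) # true
```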

## Write a probabilistic program

Practically a one-liner. It samples from our distribution -- that's all for now.

In [38]:
@gen function graph_generator(reference_graph::DiGraph{String}, beta::Float64)
    g = @trace(graphpriordist(reference_graph, beta), :graph)
    return g
end;

DynamicDSLFunction{Any}(Dict{Symbol,Any}(), Dict{Symbol,Any}(), Type[DiGraph{String}, Float64], ##graph_generator#371, Bool[0, 0], false)

To sample from this model, we'll need a proposal distribution that explores the space of directed graphs.

## Give a proposal distribution

As a first pass, we'll define operations that allow us to explore the space of directed graphs in a mostly unconstrained fashion. 

That is, we won't enforce most graph properties: connectedness, acyclicity, in-degree, etc.
The one property we will enforce is no 2-cycles (mostly for aesthetic reasons).
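The no-2-cycles property can be stated as a simple check. This is a hypothetical helper in plain Julia, again assuming edges are a `Set` of `(source, target)` tuples; self-loops are not counted as 2-cycles.

```julia
# Hypothetical helper: true if some pair of distinct vertices has edges in both directions.
function has_two_cycle(edges::Set{Tuple{String,String}})
    return any(e -> e[1] != e[2] && (e[2], e[1]) in edges, edges)
end

println(has_two_cycle(Set([("cat", "mammal"), ("mammal", "animal")])))  # false
println(has_two_cycle(Set([("cat", "mammal"), ("mammal", "cat")])))     # true
```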

### First (naive) proposal distribution
Our first proposal distribution is sort of naive.
It samples uniformly over the $N(N-1)$ possible directed edges.
If the edge already exists then it flips a coin, deciding whether to remove or reverse that edge.
If the edge doesn't exist but its reverse _does_, then we do similarly for that edge.
If neither the edge nor its reverse exists, then we propose an edge there.
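Before the Gen version, here is a plain-Julia sketch of the same move on an edge set, useful for checking that it changes one or two edges and never creates a 2-cycle. `naive_move` is a hypothetical name; it draws from Julia's `Random` standard library rather than making Gen random choices, so it is not itself usable as a Gen proposal.

```julia
using Random

# Hypothetical sketch of the naive move; `vertices` is a Vector so indices are well defined.
function naive_move(edges::Set{Tuple{String,String}}, vertices::Vector{String}, rng::AbstractRNG)
    n = length(vertices)
    i = rand(rng, 1:n)
    j = rand(rng, 1:n-1)   # sample j != i (no self-loops)
    if j >= i
        j += 1
    end
    u, v = vertices[i], vertices[j]
    new_edges = copy(edges)             # never mutate the current state
    for (a, b) in ((u, v), (v, u))      # check the edge first, then its reverse
        if (a, b) in new_edges
            delete!(new_edges, (a, b))  # remove ...
            if rand(rng, Bool)          # ... or, on a fair coin flip, reverse instead
                push!(new_edges, (b, a))
            end
            return new_edges
        end
    end
    push!(new_edges, (u, v))            # neither direction present: add the edge
    return new_edges
end

rng = MersenneTwister(1)
verts = ["animal", "cat", "mammal"]
edges = Set([("cat", "mammal"), ("mammal", "animal")])
proposed = naive_move(edges, verts, rng)
println(length(symdiff(proposed, edges)))  # 1 (add or remove) or 2 (reverse)
```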

We can expect this proposal to sample inefficiently. For a sparse graph, it will usually propose adding an edge; if the reference graph is itself sparse, most of these proposals will probably be rejected.
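To put a number on "usually propose adding an edge": if the graph has no self-loops and no 2-cycles, each of its $E$ edges occupies 2 of the $N(N-1)$ ordered pairs (the edge and its reverse), so the naive proposal picks an addition with probability $(N(N-1) - 2E)/(N(N-1))$. A quick check in Julia, with `add_move_probability` a hypothetical helper:

```julia
# Probability that the naive proposal draws an ordered pair (u, v) where neither
# u -> v nor v -> u is an edge, i.e. that it proposes adding an edge.
# Assumes no self-loops and no 2-cycles.
function add_move_probability(n_vertices::Integer, n_edges::Integer)
    pairs = n_vertices * (n_vertices - 1)
    return (pairs - 2 * n_edges) / pairs
end

println(add_move_probability(6, 5))     # 0.6666666666666666
println(add_move_probability(100, 99))  # 0.98
```

For a graph the size of the 6-vertex, 5-edge reference graph used below, two out of three proposals are additions; for a tree-like graph on 100 vertices the share rises to 98%.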

In [51]:
# First (inefficient) proposal distribution.
# Returns a proposed graph. Gen traces are immutable, so we build a modified
# copy of the current graph instead of assigning into `cur_trace`.

# true if the directed edge u -> v is present in G
edge_exists(G::DiGraph, u, v) = any((G.edges[:, 1] .== u) .& (G.edges[:, 2] .== v))

@gen function inefficient_digraph_proposal(cur_trace)

    G = copy(cur_trace[:graph])        # copy: never mutate the current trace's graph
    verts = sort(collect(G.vertices))  # fixed ordering, so indices map to vertex names
    N = length(verts)

    # sample an ordered pair of distinct vertices (no self-loops)
    i = @trace(uniform_discrete(1, N), :v1)
    j = @trace(uniform_discrete(1, N-1), :v2)
    if j >= i
        j += 1
    end
    v1, v2 = verts[i], verts[j]

    # If the edge already exists: remove or reverse it
    if edge_exists(G, v1, v2)
        remove = @trace(bernoulli(0.5), :remove)
        remove_edge!(G, v1, v2)
        if !remove
            add_edge!(G, v2, v1)  # reverse
        end
    # If the edge's reverse exists: remove or reverse that.
    elseif edge_exists(G, v2, v1)
        remove = @trace(bernoulli(0.5), :remove)
        remove_edge!(G, v2, v1)
        if !remove
            add_edge!(G, v1, v2)  # reverse
        end
    # If neither the edge nor its reverse exists, add the edge!
    else
        add_edge!(G, v1, v2)
    end

    return G
end;

DynamicDSLFunction{Any}(Dict{Symbol,Any}(), Dict{Symbol,Any}(), Type[Any], ##digraph_unconstrained_proposal#372, Bool[0], false)

### An alternative (smarter) proposal distribution

In [52]:
#TODO

## MCMC sampling

In [11]:
ref_graph = DiGraph([["cat" "mammal"];
        ["dog" "mammal"];
        ["lizard" "reptile"];
        ["mammal" "animal"];
        ["reptile" "animal"]])
ref_graph_p = copy(ref_graph)
remove_edge!(ref_graph_p, "reptile", "animal")

4×2 Array{String,2}:
 "cat"     "mammal" 
 "dog"     "mammal" 
 "lizard"  "reptile"
 "mammal"  "animal" 

In [7]:
random(graphpriordist, ref_graph, 1.0)

DiGraph{String}(Set(["mammal", "reptile", "animal", "cat", "lizard", "dog"]), ["cat" "mammal"; "dog" "mammal"; … ; "mammal" "animal"; "reptile" "animal"])

In [23]:
Set([ ref_graph_p.edges[i,:] for i=1:size(ref_graph_p.edges)[1] ])

Set(Array{String,1}[["dog", "mammal"], ["lizard", "reptile"], ["mammal", "animal"], ["cat", "mammal"]])

In [17]:
ref_graph_p.edges

4×2 Array{String,2}:
 "cat"     "mammal" 
 "dog"     "mammal" 
 "lizard"  "reptile"
 "mammal"  "animal" 

In [25]:
symdiff([ref_graph_p.edges[i,:] for i=1:size(ref_graph_p.edges)[1]],
        [ref_graph.edges[i,:] for i=1:size(ref_graph.edges)[1]])

1-element Array{Array{String,1},1}:
 ["reptile", "animal"]

In [32]:
graphpriordist(ref_graph, 1.0)

DiGraph{String}(Set(["mammal", "reptile", "animal", "cat", "lizard", "dog"]), ["cat" "mammal"; "dog" "mammal"; … ; "mammal" "animal"; "reptile" "animal"])
