# Product distribution

In [1]:
#nbx --fname="../src/product_dist.jl"
using Gen
# Common supertype of the product distributions defined in this file.
# A `ProductDistribution{T}` is a Gen `Distribution` over `Vector{T}`.
abstract type ProductDistribution{T} <: Distribution{Vector{T}} end

In [2]:
#nbx
# Insert a singleton dimension into `a` at position `d` of the result.
#
# - `a`: an array, or a zero-dimensional value such as a plain number.
# - `d`: position of the new dimension in the *result*. Negative values count
#   from the end of the result, so `-1` appends a trailing dimension and
#   `-(ndims(a) + 1)` prepends a leading one.
#
# Returns `[a]` for zero-dimensional input (regardless of `d`), otherwise a
# reshaped array with `ndims(a) + 1` dimensions that shares data with `a`.
function unsqueeze(a, d)
    # Zero-dimensional values have no axes to reshape; wrap them instead.
    if ndims(a) == 0
        return [a]
    end
    if d < 0
        # The result has ndims(a) + 1 dimensions, hence the offset of 2:
        # -1 -> ndims(a) + 1 (last), -2 -> ndims(a), and so on.
        d = ndims(a) + d + 2
    end
    return reshape(a, (size(a)[1:d-1]..., 1, size(a)[d:end]...))
end
        
# Concatenate the arrays in `xs` along dimension `dims`.
# A negative `dims` counts from the end, measured against the number of
# dimensions of the first array in `xs` (so `-1` is its last dimension).
function mycat(xs::Vector{T}; dims) where T
    axis = dims < 0 ? ndims(xs[1]) + dims + 1 : dims
    return cat(xs...; dims=axis)
end

mycat (generic function with 1 method)

## Homogeneous Product

In [27]:
#nbx
# Product distribution whose components all use the same base distribution.
struct HomogeneousProduct{T} <: ProductDistribution{T}
    # Base distribution used for every component.
    dist::Distribution{T}
    # Number of components (the sampling / scoring code iterates over 1:n).
    n::Int
    slicedim::Int  # Dimension along which each argument array is sliced
                   # to obtain the arguments of the individual components.
                   # Negative values count from the last dimension.
                   # The two-argument constructor defaults to the 1st
                   # dimension, which differs from Gen's mixtures.
end
# Build a `HomogeneousProduct` of `n` components that share `dist`.
# `s` is the slicing dimension of the arguments; it is 1 when left out.
function ProductDistribution(dist::Distribution{T}, n::Int, s::Int=1) where T
    return HomogeneousProduct{T}(dist, n, s)
end


# Return the `i`-th slice of `a` along dimension `d`.
# One-dimensional inputs are simply indexed (yielding the element itself);
# otherwise a `selectdim` view is returned. A non-positive `d` is counted
# from the end (`-1` is the last dimension).
function slicedim(a, d::Int, i::Int)
    if ndims(a) == 1
        return a[i]
    end
    axis = d > 0 ? d : ndims(a) + d + 1
    return selectdim(a, axis, i)
end

# Draw one sample from the product: component `i` is sampled from the base
# distribution using the `i`-th slice (along `Q.slicedim`) of every argument.
#
# When slicing along dimension 1 the per-component draws are returned as a
# plain vector. Otherwise they are stacked along the slicing dimension.
function Gen.random(Q::HomogeneousProduct, args...)
    base = Q.dist
    axis = Q.slicedim

    draws = [base((slicedim(a, axis, i) for a in args)...) for i in 1:Q.n]

    axis == 1 && return draws

    # Stacking is costly, but it gives the sample the layout that `logpdf`
    # slices along, so that drawn samples can be scored directly.
    expanded = [unsqueeze(y, axis) for y in draws]
    return mycat(expanded, dims=axis)
end

# Calling the product like a function draws a sample from it.
function (Q::HomogeneousProduct)(args...)
    return Gen.random(Q, args...)
end

# Log-density of `xs` under the product: the sum over components of the base
# distribution's log-density, evaluated on the `i`-th slice of `xs` with the
# `i`-th slice of every argument (all slices taken along `Q.slicedim`).
function Gen.logpdf(Q::HomogeneousProduct{T}, xs, args...) where T
    base = Q.dist
    axis = Q.slicedim
    terms = [
        Gen.logpdf(base, slicedim(xs, axis, i), (slicedim(a, axis, i) for a in args)...)
        for i in 1:Q.n
    ]
    return sum(terms)
end
    
# Gradients of the product log-density with respect to the sample and every
# argument.
#
# Returns a vector with `length(args) + 1` entries: entry 1 belongs to `xs`,
# entry `j + 1` to `args[j]`. Each entry is built by stacking the matching
# per-component gradients of the base distribution along `Q.slicedim`.
function Gen.logpdf_grad(Q::HomogeneousProduct{T}, xs, args...) where T
    base = Q.dist
    ncomp = Q.n
    axis = Q.slicedim
    nquantities = length(args) + 1  # the sample plus one entry per argument

    # One gradient collection per component.
    per_component = [
        Gen.logpdf_grad(base, slicedim(xs, axis, i), (slicedim(a, axis, i) for a in args)...)
        for i in 1:ncomp
    ]

    # Regroup by quantity: give every component gradient a singleton slicing
    # dimension and concatenate the components along it.
    return [
        mycat([unsqueeze(per_component[i][j], axis) for i in 1:ncomp], dims=axis)
        for j in 1:nquantities
    ]
end

# The product supports an output gradient exactly when its base distribution does.
Gen.has_output_grad(Q::HomogeneousProduct)    = Gen.has_output_grad(Q.dist)
# The product takes the same number of arguments as the base distribution
# (each one stacked over the components), so it reports one flag per
# argument, i.e. the base distribution's flags, not `n` copies of them.
Gen.has_argument_grads(Q::HomogeneousProduct) = Gen.has_argument_grads(Q.dist)

In [28]:
using ForwardDiff: gradient
import Distributions

# Product of `n` normal components, sliced along the default 1st dimension.
n = 5
Q = ProductDistribution(normal, n)
mus  = rand(n)
stds = rand(n)

# Draw a sample, then evaluate its log-density and the gradients at that sample.
xs = Q(mus, stds)
log_p = logpdf(Q, xs, mus, stds)
log_p_grad = logpdf_grad(Q, xs, mus, stds)

println("sample: \n\t", xs);
println("log_p: \n\t", log_p);

# Gradient reality check: pack the sample and both parameter vectors into a
# single vector so that ForwardDiff's `gradient` differentiates the
# log-density with respect to all of them at once.
v = [xs;mus;stds]
func = v -> Gen.logpdf(Q, v[1:n],v[n+1:2*n],v[2*n+1:3*n])
gr = gradient(func, v)
# Split the flat gradient back into its (sample, mus, stds) parts.
gr = [gr[1:n],gr[n+1:2*n],gr[2*n+1:3*n]]

println("Gradient OK? \n\t"  ,isapprox(log_p_grad , gr))

sample: 
	[-0.5657392545685711, 0.4002711146745374, -0.01405309766505402, 0.39396560148883614, 0.010278143989848665]
log_p: 
	2.888574361097689
Gradient OK? 
	true


In [38]:
# Two multivariate normal components, sliced along the last dimension
# (slicedim = -1): column i of `means` and page i of `covs` are the
# parameters of component i.
Q = ProductDistribution(mvnormal, 2, -1)
means = cat([-10.0, -10.0], [10.0, 10.0], dims=2)
covs  = cat([1.0 0.0; 0.0 1.0], [10.0 0.0; 0.0 10.0], dims=3)
samples = Q(means, covs)
gr = logpdf_grad(Q, samples, means, covs)

# Print the shape of the sample and of each gradient entry.
println(size(samples))
println([size(g) for g in gr]...)

(2, 2)
(2, 2)(2, 2)(2, 2, 2)


In [39]:
# Same two components with the default slicing dimension (1): the arguments
# are plain vectors holding one mean vector / covariance matrix per component.
Q = ProductDistribution(mvnormal, 2)
means = [[-10.0, -10.0], [10.0, 10.0]]
covs  = [[1.0 0.0; 0.0 1.0], [10.0 0.0; 0.0 10.0]]
samples = Q(means, covs)
gr = logpdf_grad(Q, samples, means, covs)

# Print the shape of the sample and of each gradient entry.
println(size(samples))
println([size(g) for g in gr]...)

(2,)
(2, 2)(2, 2)(2, 2, 2)


In [31]:
using BenchmarkTools

# Benchmark sampling and log-density evaluation for a homogeneous product of
# 1000 normal components.
n = 1_000
Q = ProductDistribution(normal, n)
mus  = rand(n)
stds = rand(n)
xs   = rand(n)

# `$` interpolates the globals so that the benchmark does not measure
# global-variable access.
@btime Q($mus, $stds);
@btime Gen.logpdf($Q, $xs,  $mus, $stds);

  5.370 μs (8 allocations: 8.17 KiB)
  8.992 μs (6 allocations: 8.12 KiB)


In [32]:
# Baseline for comparison: one draw and one log-density evaluation of a
# single `normal`, using the first entries of the vectors defined above.
@btime normal($mus[1], $stds[1]);
@btime Gen.logpdf($normal, $xs[1],  $mus[1], $stds[1]);

  5.949 ns (0 allocations: 0 bytes)
  8.546 ns (0 allocations: 0 bytes)


## Heterogeneous Product

In [40]:
# Product distribution whose components may each use a different
# distribution over values of type `T`.
struct HeterogeneousProduct{T} <: ProductDistribution{T}
    # One distribution per component.
    dists::Vector{D where D <: Distribution{T}}
    # Number of components (the convenience constructors pass `length(ds)`).
    n::Int
    # Dimension along which arguments that are not one-dimensional are sliced.
    slicedim::Int
end
# Build a `HeterogeneousProduct` from a vector of distributions. The number of
# components is `length(ds)`; the slicing dimension `s` is 1 when left out.
function ProductDistribution(ds::Vector{D}, s::Int=1) where {T, D <: Distribution{T}}
    return HeterogeneousProduct{T}(ds, length(ds), s)
end


# Return the `i`-th slice of `a` along dimension `d`; a non-positive `d`
# counts from the end (`-1` is the last dimension).
#
# This method has the same signature as the `slicedim` defined for the
# homogeneous product and therefore replaces it when this cell runs. It keeps
# the one-dimensional branch (plain indexing, returning the element itself)
# so that the homogeneous code keeps receiving scalars rather than
# zero-dimensional views for vector arguments.
function slicedim(a, d::Int, i::Int)
    ndims(a) == 1 && return a[i]
    return selectdim(a, d > 0 ? d : ndims(a) + d + 1, i)
end
# Draw one sample per component and return them as a vector. Component `i`
# is sampled from `Q.dists[i]`; a one-dimensional argument contributes its
# `i`-th element, any other argument its `i`-th slice along `Q.slicedim`.
function Gen.random(Q::HeterogeneousProduct{T}, args...) where T
    axis = Q.slicedim
    pick(a, i) = ndims(a) == 1 ? a[i] : slicedim(a, axis, i)
    return [Q.dists[i]((pick(a, i) for a in args)...) for i in 1:Q.n]
end

# Calling the product like a function draws a sample from it.
function (Q::HeterogeneousProduct)(args...)
    return Gen.random(Q, args...)
end


# Log-density of `xs` under the product: the sum over components of
# `logpdf(Q.dists[i], xs[i], ...)`. A one-dimensional argument contributes
# its `i`-th element, any other argument its `i`-th slice along `Q.slicedim`.
function Gen.logpdf(Q::HeterogeneousProduct{T}, xs::Vector{T}, args...) where T
    components = Q.dists
    axis = Q.slicedim
    pick(a, i) = ndims(a) == 1 ? a[i] : slicedim(a, axis, i)
    terms = [
        Gen.logpdf(components[i], xs[i], (pick(a, i) for a in args)...)
        for i in 1:Q.n
    ]
    return sum(terms)
end

    
# Gradients of the product log-density with respect to the sample and every
# argument.
#
# Each component's `logpdf_grad` yields one gradient for its own sample and
# one per argument. Those per-component results must be laid out side by
# side rather than added together: component `i` only depends on `xs[i]` and
# on the `i`-th slice of each argument, and gradient tuples cannot be summed.
#
# Returns a vector with `length(args) + 1` entries: entry 1 belongs to `xs`,
# entry `j + 1` to `args[j]`. Every entry stacks the per-component gradients
# along `Q.slicedim`, the same layout `HomogeneousProduct` produces.
function Gen.logpdf_grad(Q::HeterogeneousProduct{T}, xs::Vector{T}, args...) where T
    p = Q.dists
    n = Q.n
    d = Q.slicedim
    k = length(args) + 1  # the sample plus one entry per argument

    # One gradient collection per component.
    grads = [
        Gen.logpdf_grad(p[i], xs[i], (ndims(a) == 1 ? a[i] : slicedim(a, d, i) for a in args)...)
        for i in 1:n
    ]

    # Regroup by quantity and stack the components along the slicing dimension.
    return [
        mycat([unsqueeze(grads[i][j], d) for i in 1:n], dims=d)
        for j in 1:k
    ]
end
        
# The struct stores its components in `dists` (there is no `dist` field).
# An output gradient is available only if every component provides one.
Gen.has_output_grad(Q::HeterogeneousProduct) = all(Gen.has_output_grad, Q.dists)

# One flag per argument of the product: argument `j` has a gradient only if
# every component has a gradient for its `j`-th argument. All components are
# called with the same number of arguments, so their flag tuples line up.
function Gen.has_argument_grads(Q::HeterogeneousProduct)
    isempty(Q.dists) && return ()
    flags = map(Gen.has_argument_grads, Q.dists)
    return Tuple(all(column) for column in zip(flags...))
end

In [11]:
using BenchmarkTools

# Benchmark sampling and log-density evaluation for a heterogeneous product
# built from a vector of 1000 `normal` distributions.
n = 1_000
Q = ProductDistribution([normal for i=1:n])
mus  = rand(n)
stds = rand(n)
xs   = rand(n)

# `$` interpolates the globals so that the benchmark does not measure
# global-variable access.
@btime Q($mus, $stds);
@btime Gen.logpdf($Q, $xs,  $mus, $stds);

  151.080 μs (6006 allocations: 164.38 KiB)
  45.289 μs (4004 allocations: 70.53 KiB)


## Old Version

In [1]:
using Gen

# Old version: a concrete product distribution over `Vector{T}` built from a
# single base distribution.
struct ProductDistribution{T} <: Distribution{Vector{T}}
    # Base distribution used for every component.
    dist::Distribution{T}
    # Number of components (used by `has_argument_grads` below).
    n::Int
end


# Draw one sample per component: the argument collections are zipped and the
# base distribution is called once with each resulting tuple of arguments.
# Todo: Specify that the argument collections need to be "zippable".
function Gen.random(Q::ProductDistribution, args_vec...)
    base = Q.dist
    return map(args -> base(args...), zip(args_vec...))
end


# Calling the product with one vector per argument draws a sample from it.
function (Q::ProductDistribution)(args_vec::AbstractVector...)
    return Gen.random(Q, args_vec...)
end

# Log-density of `xs` under the product: the sum of the base distribution's
# log-density over the components, pairing `xs[i]` with the `i`-th entry of
# every argument vector.
function Gen.logpdf(Q::ProductDistribution{T}, xs::Vector{T}, args_vec::AbstractVector...) where T
    base = Q.dist
    # Summing a comprehension is kept on purpose: an explicit "for loop"
    # version benchmarked slower (BenchmarkTools), for reasons that are not
    # understood; it does not seem to be that way in `diagnorm`.
    terms = [Gen.logpdf(base, x, args...) for (x, args...) in zip(xs, args_vec...)]
    return sum(terms)
end



# The product supports an output gradient exactly when its base distribution does.
Gen.has_output_grad(Q::ProductDistribution)    = Gen.has_output_grad(Q.dist)
# The product takes the same number of arguments as the base distribution
# (one vector per argument), so it reports one flag per argument, i.e. the
# base distribution's flags, not `n` copies of them.
Gen.has_argument_grads(Q::ProductDistribution) = Gen.has_argument_grads(Q.dist)

In [2]:
using BenchmarkTools

# Benchmark sampling and log-density evaluation for the old implementation
# with 1000 normal components.
n = 1_000
Q = ProductDistribution(normal, n)
mus  = rand(n)
stds = rand(n)
xs   = rand(n)

# `$` interpolates the globals so that the benchmark does not measure
# global-variable access.
@btime Q($mus, $stds);
@btime Gen.logpdf($Q, $xs,  $mus, $stds);

  5.037 μs (6 allocations: 8.09 KiB)
  9.204 μs (4 allocations: 8.02 KiB)
