# Product distribution

In [None]:
#nbx --fname="../src/product_dist.jl"
using Gen
abstract type ProductDistribution{T} <: Distribution{Vector{T}} end

## Homogeneous Product

In [8]:
#nbx
struct HomogeneousProduct{T} <: ProductDistribution{T}
    dist::Distribution{T}
    n::Int
    slicedim::Int  # indicates the dimension specifying 
                   # the arguments for each dist
                   # Default is along 1st dimension - differs
                   # from Gen's mixtures.
end
ProductDistribution(dist::Distribution{T}, n::Int, s::Int) where T = HomogeneousProduct{T}(dist, n, s)
ProductDistribution(dist::Distribution{T}, n::Int)         where T = HomogeneousProduct{T}(dist, n, 1)


slicedim(a, d::Int, i::Int) = selectdim(a, d > 0 ? d : ndims(a)+d+1, i)
function Gen.random(Q::HomogeneousProduct, args...)
    p = Q.dist
    n = Q.n
    d = Q.slicedim
    return [p((ndims(a) == 1 ? a[i] : slicedim(a, d, i) for a in args)...) for i=1:n] 
end

(Q::HomogeneousProduct)(args...) = Gen.random(Q, args...)


function Gen.logpdf(Q::HomogeneousProduct{T}, xs::Vector{T}, args...) where T
    p = Q.dist
    n = Q.n
    d = Q.slicedim
    return sum([
        Gen.logpdf(p, xs[i], (ndims(a) == 1 ? a[i] : slicedim(a, d, i) for a in args)...) for i=1:n
    ])
end

Gen.has_output_grad(Q::HomogeneousProduct)    = has_output_grad(Q.dist)
Gen.has_argument_grads(Q::HomogeneousProduct) = Tuple([Gen.has_argument_grads(Q.dist) for i=1:Q.n])

In [4]:
using BenchmarkTools

n = 1_000
Q = ProductDistribution(normal, n)
mus  = rand(n)
stds = rand(n)
xs   = rand(n)

@btime Q($mus, $stds);
@btime Gen.logpdf($Q, $xs,  $mus, $stds);

  6.168 μs (8 allocations: 8.17 KiB)
  9.450 μs (6 allocations: 8.12 KiB)


In [13]:
@btime normal($mus[1], $stds[1]);
@btime Gen.logpdf($normal, $xs[1],  $mus[1], $stds[1]);

  5.748 ns (0 allocations: 0 bytes)
  8.804 ns (0 allocations: 0 bytes)


In [7]:
Q = ProductDistribution(mvnormal, 2, -1)
means = [0.0 1.0; 0.0 1.0] # or, cat([0.0, 0.0], [1.0, 1.0], dims=2)
covs = cat([1.0 0.0; 0.0 1.0], [10.0 0.0; 0.0 10.0], dims=3)
size(means), size(covs)
Q(means, covs)

2-element Vector{Vector{Float64}}:
 [-1.6943002714596043, -0.2264648547757579]
 [0.14278251405722397, -1.2397686587242691]

## Heterogeneous Product

In [9]:
#nbx
struct HeterogeneousProduct{T} <: ProductDistribution{T}
    dists::Vector{D where D <: Distribution{T}}
    n::Int
    slicedim::Int
end
ProductDistribution(ds::Vector{D}, s::Int) where {T, D <: Distribution{T}} = HeterogeneousProduct{T}(ds, length(ds), s)
ProductDistribution(ds::Vector{D})         where {T, D <: Distribution{T}} = HeterogeneousProduct{T}(ds, length(ds), 1)


slicedim(a, d::Int, i::Int) = selectdim(a, d > 0 ? d : ndims(a)+d+1, i)
function Gen.random(Q::HeterogeneousProduct{T}, args...) where T
    n = Q.n
    s = Q.slicedim
    return [Q.dists[i]((ndims(a) == 1 ? a[i] : slicedim(a,s, i) for a in args)...) for i=1:n]
end

(Q::HeterogeneousProduct)(args...) = Gen.random(Q, args...)


function Gen.logpdf(Q::HeterogeneousProduct{T}, xs::Vector{T}, args...) where T
    p = Q.dists
    n = Q.n
    d = Q.slicedim
    return sum([
        Gen.logpdf(p[i], xs[i], (ndims(a) == 1 ? a[i] : slicedim(a, d, i) for a in args)...) for i=1:n
    ])
end

Gen.has_output_grad(Q::HeterogeneousProduct)    = has_output_grad(Q.dist)
Gen.has_argument_grads(Q::HeterogeneousProduct) = Tuple([Gen.has_argument_grads(Q.dist) for i=1:Q.n])

In [11]:
using BenchmarkTools

n = 1_000
Q = ProductDistribution([normal for i=1:n])
mus  = rand(n)
stds = rand(n)
xs   = rand(n)

@btime Q($mus, $stds);
@btime Gen.logpdf($Q, $xs,  $mus, $stds);

  151.080 μs (6006 allocations: 164.38 KiB)
  45.289 μs (4004 allocations: 70.53 KiB)


## Old Version

In [1]:
using Gen

struct ProductDistribution{T} <: Distribution{Vector{T}}
    dist::Distribution{T}
    n::Int
end


# Todo: Specify that args needs to be "zippable"
function Gen.random(Q::ProductDistribution, args_vec...)
    p = Q.dist
    return [p(args...) for args in zip(args_vec...)]
end


(Q::ProductDistribution)(args_vec::AbstractVector...) = Gen.random(Q::ProductDistribution, args_vec...)

function Gen.logpdf(Q::ProductDistribution{T}, xs::Vector{T}, args_vec::AbstractVector...) where T
    p = Q.dist
    # "for loop" implementation seems slower, according to benchmarktools
    # Don't know why really, doesn't seem to be that way in `diagnorm`
    return sum([Gen.logpdf(p, x, args...) for (x, args...) in zip(xs, args_vec...)])
end



Gen.has_output_grad(Q::ProductDistribution)    = has_output_grad(Q.dist)
Gen.has_argument_grads(Q::ProductDistribution) = Tuple([Gen.has_argument_grads(Q.dist) for i=1:Q.n])

In [2]:
using BenchmarkTools

n = 1_000
Q = ProductDistribution(normal, n)
mus  = rand(n)
stds = rand(n)
xs   = rand(n)

@btime Q($mus, $stds);
@btime Gen.logpdf($Q, $xs,  $mus, $stds);

  5.037 μs (6 allocations: 8.09 KiB)
  9.204 μs (4 allocations: 8.02 KiB)
