# Diagonal Gaussian - `diagnormal`

In [6]:
#nbx --fname=../src/diag_normal.jl
using Gen

"""
    DiagonalNormal

Normal distribution over vectors whose coordinates are independent (diagonal
covariance). Samples are `Vector{Float64}`; use it through the singleton `diagnormal`.
"""
struct DiagonalNormal <: Gen.Distribution{Vector{Float64}} end

const diagnormal = DiagonalNormal()

"""
    Gen.random(::DiagonalNormal, mus, stds)
    Gen.random(::DiagonalNormal, mus, std)

Draw a vector whose `i`-th entry is `normal(mus[i], stds[i])`, or `normal(mus[i], std)`
when one scalar standard deviation is shared by every coordinate.

These methods extend `Gen.random` (an unqualified `random` would be a new, unrelated
function in this module) so that Gen itself can sample from `diagnormal`, e.g. inside
`@trace(diagnormal(mus, stds), :x)`.

Throws `DimensionMismatch` if `mus` and `stds` differ in length.
"""
function Gen.random(::DiagonalNormal, mus::AbstractVector{U},
                    stds::AbstractVector{V}) where {U <: Real, V <: Real}
    # `zip` alone would silently truncate to the shorter vector.
    if length(mus) != length(stds)
        throw(DimensionMismatch("mus and stds must have equal lengths; got " *
                                "$(length(mus)) and $(length(stds))"))
    end
    return [normal(mu, std) for (mu, std) in zip(mus, stds)]
end

function Gen.random(::DiagonalNormal, mus::AbstractVector{U},
                    std::V) where {U <: Real, V <: Real}
    return [normal(mu, std) for mu in mus]
end

# Calling the distribution object draws a sample, e.g. `diagnormal(mus, stds)`.
function (d::DiagonalNormal)(mus::AbstractVector{U},
                             stds::AbstractVector{V}) where {U <: Real, V <: Real}
    return Gen.random(d, mus, stds)
end

function (d::DiagonalNormal)(mus::AbstractVector{U}, std::V) where {U <: Real, V <: Real}
    return Gen.random(d, mus, std)
end


"""
    Gen.logpdf(::DiagonalNormal, xs, mus, stds)
    Gen.logpdf(::DiagonalNormal, xs, mus, std)

Log density of `xs`: the sum over `i` of `logpdf(normal, xs[i], mus[i], stds[i])`, or of
`logpdf(normal, xs[i], mus[i], std)` when one scalar standard deviation is shared.
Returns `0.0` for empty vectors.

Throws `DimensionMismatch` if the vector arguments differ in length; a bare `zip` would
otherwise drop the unmatched tail and return the density of a prefix only.
"""
function Gen.logpdf(::DiagonalNormal, xs::AbstractVector{T}, mus::AbstractVector{U},
                    stds::AbstractVector{V}) where {T <: Real, U <: Real, V <: Real}
    if !(length(xs) == length(mus) == length(stds))
        throw(DimensionMismatch("xs, mus and stds must have equal lengths; got " *
                                "$(length(xs)), $(length(mus)) and $(length(stds))"))
    end
    log_p = 0.0
    for (x, mu, std) in zip(xs, mus, stds)
        log_p += Gen.logpdf(normal, x, mu, std)
    end
    return log_p
end

function Gen.logpdf(::DiagonalNormal, xs::AbstractVector{T}, mus::AbstractVector{U},
                    std::V) where {T <: Real, U <: Real, V <: Real}
    if length(xs) != length(mus)
        throw(DimensionMismatch("xs and mus must have equal lengths; got " *
                                "$(length(xs)) and $(length(mus))"))
    end
    log_p = 0.0
    for (x, mu) in zip(xs, mus)
        log_p += Gen.logpdf(normal, x, mu, std)
    end
    return log_p
end


# Gradient interface. These must extend Gen's own generic functions (note the `Gen.`
# prefix): unqualified definitions would create unrelated functions in this module that
# Gen never calls. The flags are only truthful because `Gen.logpdf_grad` is defined below.
Gen.has_output_grad(::DiagonalNormal) = true
Gen.has_argument_grads(::DiagonalNormal) = (true, true)

"""
    Gen.logpdf_grad(::DiagonalNormal, xs, mus, stds)
    Gen.logpdf_grad(::DiagonalNormal, xs, mus, std)

Gradient of the diagonal-normal log density as the tuple `(∂/∂xs, ∂/∂mus, ∂/∂stds)`.

With `z = (x - mu) / std` per coordinate:
`∂/∂x = -z / std`, `∂/∂mu = z / std`, `∂/∂std = (z^2 - 1) / std`.
For a shared scalar `std`, its gradient is the scalar sum of the per-coordinate terms.

Throws `DimensionMismatch` if the vector arguments differ in length.
"""
function Gen.logpdf_grad(::DiagonalNormal, xs::AbstractVector{T}, mus::AbstractVector{U},
                         stds::AbstractVector{V}) where {T <: Real, U <: Real, V <: Real}
    if !(length(xs) == length(mus) == length(stds))
        throw(DimensionMismatch("xs, mus and stds must have equal lengths; got " *
                                "$(length(xs)), $(length(mus)) and $(length(stds))"))
    end
    z = (xs .- mus) ./ stds
    return (-z ./ stds, z ./ stds, (z .^ 2 .- 1) ./ stds)
end

function Gen.logpdf_grad(::DiagonalNormal, xs::AbstractVector{T}, mus::AbstractVector{U},
                         std::V) where {T <: Real, U <: Real, V <: Real}
    if length(xs) != length(mus)
        throw(DimensionMismatch("xs and mus must have equal lengths; got " *
                                "$(length(xs)) and $(length(mus))"))
    end
    z = (xs .- mus) ./ std
    return (-z ./ std, z ./ std, (sum(abs2, z) - length(z)) / std)
end

has_argument_grads (generic function with 1 method)

## Benchmarks

In [2]:
using BenchmarkTools

# Inputs reused by the benchmark cells below: 1000 coordinates, random means and
# evaluation points, unit standard deviations.
n = 1000
mus = rand(n)
stds = fill(1.0, n)
xs = rand(n)

# `$`-interpolation keeps the global-variable lookup out of the timed expression.
@btime diagnormal($mus, $stds);
@btime Gen.logpdf($diagnormal, $xs, $mus, $stds);

  4.615 μs (1 allocation: 7.94 KiB)
  6.159 μs (0 allocations: 0 bytes)


In [49]:
using LinearAlgebra: diagm

# Dense-covariance baseline. `mvnormal` takes a covariance matrix, so its diagonal must
# hold variances (stds squared), not standard deviations. With `stds = ones(n)` the two
# coincide, but squaring keeps the comparison equivalent to `diagnormal` for any `stds`.
D = diagm(stds .^ 2)
@btime mvnormal($mus, $D);
@btime Gen.logpdf($mvnormal, $xs, $mus, $D);

  10.599 ms (6 allocations: 15.27 MiB)
  10.652 ms (7 allocations: 15.27 MiB)


In [50]:
# Comparison against `broadcasted_normal` (assumed to be Gen's built-in elementwise
# normal; it is not defined in this section of the notebook).
@btime broadcasted_normal($mus, $stds);
@btime Gen.logpdf($broadcasted_normal, $xs, $mus, $stds);

  2.827 μs (2 allocations: 15.88 KiB)
  7.831 μs (5 allocations: 39.69 KiB)
