In [50]:
push!(LOAD_PATH, "../src");
using Gen
push!(LOAD_PATH, ENV["probcomp"]*"/Gen-Distribution-Zoo/src")
using GenDistributionZoo: ProductDistribution, diagnormal
using BenchmarkTools
using MyUtils

In [43]:
#nbx
# Todo: Maybe call it PerformanceWrapper?
struct PerformanceWrapper{T,A,B,C} <: Distribution{T}
    d::Distribution{T}
    argtransform::A
    logpdf::B
    random::C
end
PerformanceWrapper(d) = PerformanceWrapper(d, nothing, nothing, nothing)

function Gen.random(Q::PerformanceWrapper, args...)
    args = Q.argtransform === nothing ?  args : Q.argtransform(args...);
    return Q.random === nothing ? Gen.random(Q.d, args...) : Q.random(args...);
end
(Q::PerformanceWrapper)(args...) = Gen.random(Q, args...)

function Gen.logpdf(Q::PerformanceWrapper, x, args...)
    args = Q.argtransform === nothing ?  args : Q.argtransform(args...);
    return Q.logpdf === nothing ? Gen.logpdf(x, args...) : Q.logpdf(x, args...);
end
#
# Todo: `logpdf_grad` has to be implemented correctly, 
#       applying the Jacobian of argtransform...
#
function Gen.logpdf_grad(Q::PerformanceWrapper, x, args...)
    args = Q.argtransform === nothing ?  args : Q.argtransform(args...);
    return Gen.logpdf_grad(Q.d, x, Q.argtransform(args)...)
end
Gen.has_output_grad(Q::PerformanceWrapper)    = Gen.has_output_grad(Q.d)
Gen.has_argument_grads(Q::PerformanceWrapper) = Tuple(false for _ in Gen.has_argument_grads(Q.d))

In [27]:
@btime Gen.logpdf($normal, 0.0, 0, 1) samples=5 evals=5;

d = PerformanceWrapper(normal);
@btime Gen.logpdf($d, 0.0, 0, 1)      samples=5 evals=5;

d = PerformanceWrapper(normal, (mu, std) -> (mu - 1.0, 2*std), (args...) -> - Inf, nothing);
@btime Gen.logpdf($d, 0.0, 0, 1)      samples=5 evals=5;

  124.400 ns (0 allocations: 0 bytes)
  185.200 ns (2 allocations: 32 bytes)
  123.600 ns (0 allocations: 0 bytes)


In [39]:
dnormal = ProductDistribution(normal)

@btime diagnormal([0;0],[1,1]) samples=5 evals=5;
@btime dnormal([0;0],[1,1]) samples=5 evals=5;

@btime logpdf(diagnormal, [0.0;0.0], [0;0],[1,1]) samples=5 evals=5;
@btime logpdf(dnormal, [0.0;0.0], [0;0],[1,1])    samples=5 evals=5;

  409.600 ns (3 allocations: 240 bytes)
  1.151 μs (10 allocations: 480 bytes)
  334.400 ns (3 allocations: 240 bytes)
  1.059 μs (9 allocations: 512 bytes)


In [52]:
gm             = HomogeneousMixture(diagnormal, [1, 1])
outlier_dist   = diagnormal
sensor_mix     = HeterogeneousMixture([gm, outlier_dist])
sensor_product = ProductDistribution(sensor_mix)

GenDistributionZoo.HomogeneousProduct{Vector{Float64}}(HeterogeneousMixture{Vector{Float64}}(2, Distribution{Vector{Float64}}[HomogeneousMixture{Vector{Float64}}(GenDistributionZoo.DiagonalNormal(), [1, 1]), GenDistributionZoo.DiagonalNormal()], true, (true, true, true, true, true, true), false, [3, 2], [1, 4]), 1)

In [62]:
include("src/sensor_distribution.jl")

function performant_logpdf(x, ỹ, sig, outlier, outlier_vol, zmax)
    x_ = CuArray(stack(x))

    ỹ_ = CuArray(ỹ)
    n = size(ỹ_, 1)
    m = size(ỹ_, 2)
    ỹ_ = reshape(ỹ_, 1, n, m, 2)

    log_p, = sensor_logpdf(x_, ỹ_, sig, outlier, outlier_vol) # CuArray of length 1
    return CUDA.@allowscalar log_p[1]
end

function sensor_product_args(ỹ, sig, outlier, outlier_vol, zmax) 
    n,m, = size(ỹ)
    ỹ_perm = permutedims(ỹ, (1,3,2))
    args = (
        fill([1-outlier, outlier], n), 
        fill(1/m, n, m), 
        ỹ_perm, 
        fill(sig, n, 2,m), 
        fill(0.0, n,2), 
        fill(zmax, n,2)
    )
    return args
end

performant_sensor_product = PerformanceWrapper(ProductDistribution(sensor_mix), nothing, performant_logpdf, nothing)


ỹ = rand(361, 21, 2)
sig         = 0.1
outlier     = 0.1
outlier_vol = 100.
zmax        = 100.
args = (ỹ, sig, outlier, outlier_vol, zmax) 
x = sensor_product(sensor_product_args(args...)...)

trans_formed_args = sensor_product_args(args...)

@btime logpdf(sensor_product, x, trans_formed_args...) samples=3 evals=3;
@btime logpdf(performant_sensor_product, x, args...) samples=3 evals=3;

@btime performant_logpdf(x, ỹ, sig, outlier, outlier_vol, zmax) samples=3 evals=3;

  16.075 ms (163182 allocations: 5.77 MiB)
  440.169 μs (556 allocations: 48.62 KiB)
  452.092 μs (550 allocations: 48.52 KiB)


In [42]:


ỹ = rand(361, 21, 2)
sig         = 0.1
outlier     = 0.1
outlier_vol = 100.
zmax        = 100.
args = sensor_product_args(ỹ, sig, outlier, outlier_vol, zmax) 
x = sensor_product(args...)

@btime logpdf(sensor_product, x, args...) samples=3 evals=3;

  16.374 ms (163185 allocations: 5.77 MiB)
