Possible type instability in OnlineStatsBase.jl #265

Closed
nic-barbara opened this issue Aug 1, 2023 · 1 comment

Comments

@nic-barbara

In OnlineStatsBase.jl, why are some statistics types subtyped with OnlineStat{Number}? For example:

mutable struct Mean{T,W} <: OnlineStat{Number}
    μ::T
    weight::W
    n::Int
end
Mean(T::Type{<:Number} = Float64; weight = EqualWeight()) = Mean(zero(T), weight, 0)

Is there a reason we can't have mutable struct Mean{T,W} <: OnlineStat{T} instead? With the current definition, calling input() on a statistic like Mean() always returns Number rather than the actual value type (e.g. Float32). The issue appears to affect Mean, Moments, Sum, and Variance.
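
To make the question concrete, here is a minimal self-contained sketch of the two subtyping patterns. It does not use OnlineStatsBase at all: `Stat`, `FixedMean`, `ParamMean`, and `input_type` are hypothetical stand-ins chosen for illustration, and the real package definitions may differ.

```julia
# Hypothetical stand-ins, NOT the OnlineStatsBase definitions.
abstract type Stat{T} end

# The "input type" is read off the parameter of the abstract supertype.
input_type(::Stat{T}) where {T} = T

# Mirrors the pattern in the issue: the supertype parameter is fixed to Number,
# regardless of the type used to store the running value.
mutable struct FixedMean{T} <: Stat{Number}
    μ::T
    n::Int
end

# The alternative asked about: forward the value type to the supertype.
mutable struct ParamMean{T} <: Stat{T}
    μ::T
    n::Int
end

fixed = FixedMean(0.0f0, 0)
param = ParamMean(0.0f0, 0)

input_type(fixed)  # Number
input_type(param)  # Float32
```

In this sketch the stored value is a Float32 in both cases; only the type reported through the supertype parameter differs.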


I noticed this while playing around with a Mean/Stdev filter. My original code is as follows (and feel free to offer any suggestions on better/more efficient ways to do this, I'm new to this package).

using BenchmarkTools
using OnlineStatsBase

mutable struct MeanStdFilter{T}
    nu::Int
    tracker::OnlineStat
end

function MeanStdFilter(nu::Int; T::DataType=Float32)
    s = [Series(Mean(T), Variance(T)) for _ in 1:nu]
    return MeanStdFilter{T}(nu, Group(s...))
end

function _get_mean_var(m::MeanStdFilter{T}) where T
    vals = value.(value(m.tracker))
    return reinterpret(reshape, T, collect(vals))
end

function (m::MeanStdFilter)(x::AbstractVector)
    fit!(m.tracker, x)
    μσ2 = _get_mean_var(m)
    return (x .- μσ2[1,:]) ./ sqrt.(μσ2[2,:])
end

# Test runtime
nu = 4
T = Float32
m = MeanStdFilter(nu; T)

# @btime m(randn(T,nu));
@btime _get_mean_var(m);

Running with T = Float32 I get:

1.014 μs (18 allocations: 608 bytes)

and with T = Float64 it drops to:

549.342 ns (6 allocations: 480 bytes)

I suspect this comes from a Float64-to-Float32 conversion somewhere in the pipeline, caused by the issue raised above.
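
A separate factor that may be worth ruling out is the abstractly typed `tracker::OnlineStat` field in `MeanStdFilter`, since Julia cannot infer the concrete type of such a field at compile time. The sketch below only illustrates that general point with hypothetical stand-in types (`AbstractTracker`, `DummyTracker`, `LooseFilter`, `TightFilter`); it is not a measured explanation of the timings above.

```julia
# Hypothetical stand-ins used only to illustrate field typing.
abstract type AbstractTracker end

struct DummyTracker <: AbstractTracker
    total::Float32
end

# Field annotated with an abstract type, as in `tracker::OnlineStat`.
struct LooseFilter
    nu::Int
    tracker::AbstractTracker
end

# Field type carried as a type parameter, so it is concrete per instance.
struct TightFilter{S<:AbstractTracker}
    nu::Int
    tracker::S
end

loose = LooseFilter(4, DummyTracker(0.0f0))
tight = TightFilter(4, DummyTracker(0.0f0))

# Ask whether the declared field type is concrete for each instance.
isconcretetype(fieldtype(typeof(loose), :tracker))  # false
isconcretetype(fieldtype(typeof(tight), :tracker))  # true
```

Applied to the filter above, this would mean adding a second type parameter for the tracker. When benchmarking a non-constant global, interpolating it as `@btime _get_mean_var($m)` is also the usual BenchmarkTools practice, so that global-variable overhead is not included in the measurement.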

Thanks in advance for any help!

@nic-barbara
Author

Actually, given this is an issue with OnlineStatsBase.jl, I'll move the discussion over there. Apologies for the inconvenience.
