Skip to content

Commit

Permalink
work on wrappers
Browse files Browse the repository at this point in the history
  • Loading branch information
joshday committed Oct 4, 2021
1 parent 4b17c09 commit 982f05a
Show file tree
Hide file tree
Showing 4 changed files with 207 additions and 75 deletions.
55 changes: 29 additions & 26 deletions src/OnlineStatsBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ export
# Weights
EqualWeight, ExponentialWeight, LearningRate, LearningRate2, HarmonicWeight, McclainWeight,
# Stats
CircBuff, Counter, CountMap, CountMissing, CovMatrix, Extrema, FTSeries, Group, GroupBy, Mean,
Moments, Series, SkipMissing, Sum, Variance
CircBuff, Counter, CountMap, CountMissing, CovMatrix, Extrema, FilterTransform, FTSeries,
Group, GroupBy, Mean, Moments, Series, SkipMissing, Sum, TryCatch, Variance

@static if VERSION < v"1.1.0"
eachrow(A::AbstractVecOrMat) = (view(A, i, :) for i in axes(A, 1))
Expand All @@ -28,13 +28,6 @@ nobs(o::OnlineStat) = o.n

# Wrap the stat in a `Ref` so broadcasting treats it as a single value
# instead of trying to iterate over it.
function Broadcast.broadcastable(o::OnlineStat)
    return Ref(o)
end

# Stats that hold a single stat
abstract type StatWrapper{T} <: OnlineStat{T} end
nobs(o::StatWrapper) = nobs(o.stat)
value(o::StatWrapper) = value(o.stat)
_merge!(a::StatWrapper{T}, b::StatWrapper{T}) where {T} = _merge!(a.stat, b.stat)
name(o::T, args...) where {T<:StatWrapper} = name(typeof(o), args...) * "($(name(o.stat, args...)))"

# Abstract supertype for stats that are made up of several other stats.
abstract type StatCollection{T} <: OnlineStat{T} end

# Collections are displayed as a tree via AbstractTrees rather than on one line.
function Base.show(io::IO, o::StatCollection)
    return AbstractTrees.print_tree(io, o)
end
Expand Down Expand Up @@ -79,10 +72,13 @@ end
# Catch-all used when no specific `_merge!` method exists for the pair of stats:
# nothing is merged and a warning is logged instead.
function _merge!(o, o2)
    return @warn("Merging $(name(o2)) into $(name(o)) is not well-defined. No merging occurred.")
end

# Non-mutating merge: works on a copy of `o`, leaving the original untouched.
function Base.merge(o::OnlineStat, o2::OnlineStat)
    target = copy(o)
    return merge!(target, o2)
end

#-----------------------------------------------------------------------# Show
#-----------------------------------------------------------------------# Base.show
function Base.show(io::IO, o::OnlineStat)
    # Layout: "<name>: n=<nobs>[ | key=val ...] | value=<value>".
    # `name(o, false, false)` asks for the name without module and without parameters.
    print(io, name(o, false, false), ": ")
    print(io, "n=", nobs(o))
    # Optional per-stat extras; the default `additional_info` yields nothing to print.
    foreach(pairs(additional_info(o))) do (k, v)
        print(io, " | $k=$v")
    end
    print(io, " | value=")
    # The value itself is rendered in compact form.
    compact_io = IOContext(io, :compact => true)
    return show(compact_io, value(o))
end
Expand All @@ -98,6 +94,8 @@ function name(T::Type, withmodule = false, withparams = true)
end
# Instance fallback: derive the display name from the object's type.
function name(o, args...)
    return name(typeof(o), args...)
end

# Default: no extra key/value pairs to display. Stats may add methods that
# return a NamedTuple of extras, which `show` prints between `n=` and `value=`.
function additional_info(o)
    return ()
end

#-----------------------------------------------------------------------# fit!
"""
fit!(stat::OnlineStat, data)
Expand All @@ -107,12 +105,15 @@ the type of a single observation for the provided `stat`, `fit!` will attempt to
through and `fit!` each item in `data`. Therefore, `fit!(Mean(), 1:10)` translates
roughly to:
```
o = Mean()
for x in 1:10
fit!(o, x)
end
```
# Example

This comment has been minimized.

Copy link
@brucala

brucala Oct 5, 2021

Contributor

I believe the code block was meant to illustrate how fit iterates for non single observations, rather than to illustrate an example. Note the text right before the code block: "Therefore, fit!(Mean(), 1:10) translates roughly to:"

This comment has been minimized.

Copy link
@joshday

joshday Oct 5, 2021

Author Owner

Oh, yep! Wasn't reading my own words...

o = Mean()
for x in 1:10
fit!(o, x)
end
fit!(o, 11:20)
"""
function fit!(o::OnlineStat{T}, yi::T) where {T}
    # `yi` is a single observation of the stat's input type: update in place
    # and hand the stat back so calls can be chained or used in reductions.
    _fit!(o, yi)
    return o
end

Expand All @@ -124,16 +125,15 @@ Alias for `merge!`. Merges `stat2` into `stat1`.
Useful for reductions of OnlineStats using `fit!`.
# Example
```
julia> v = [reduce(fit!, [1, 2, 3], init=Mean()) for _ in 1:3]
3-element Vector{Mean{Float64, EqualWeight}}:
Mean: n=3 | value=2.0
Mean: n=3 | value=2.0
Mean: n=3 | value=2.0
julia> reduce(fit!, v, init=Mean())
Mean: n=9 | value=2.0
```
julia> v = [reduce(fit!, [1, 2, 3], init=Mean()) for _ in 1:3]
3-element Vector{Mean{Float64, EqualWeight}}:
Mean: n=3 | value=2.0
Mean: n=3 | value=2.0
Mean: n=3 | value=2.0
julia> reduce(fit!, v, init=Mean())
Mean: n=9 | value=2.0
"""
function fit!(o::OnlineStat, o2::OnlineStat)
    # Fitting one stat with another is a merge of `o2` into `o`.
    return merge!(o, o2)
end

Expand All @@ -150,6 +150,8 @@ end
smooth(a, b, γ)
Weighted average of `a` and `b` with weight `γ`.
``(1 - γ) * a + γ * b``
"""
function smooth(a, b, γ)
    # Move from `a` toward `b` by the fraction `γ` of their difference.
    step = γ * (b - a)
    return a + step
end

Expand Down Expand Up @@ -189,4 +191,5 @@ neighbors(x) = @inbounds ((x[i], x[i+1]) for i in eachindex(x)[1:end-1])
#-----------------------------------------------------------------------# includes
include("weight.jl")
include("stats.jl")
include("wrappers.jl")
end
55 changes: 7 additions & 48 deletions src/stats.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,30 +124,6 @@ function Base.delete!(o::CountMap, level)
o
end

#-----------------------------------------------------------------------------# CountMissing
"""
CountMissing(stat)
Calculate a `stat` along with the count of `missing` values.
# Example
o = CountMissing(Mean())
fit!(o, [1, missing, 3])
"""
mutable struct CountMissing{T, O<:OnlineStat{T}} <: StatWrapper{Union{Missing,T}}
stat::O
nmissing::Int
end
CountMissing(stat::OnlineStat) = CountMissing(stat, 0)
value(o::CountMissing) = (nmissing=o.nmissing, stat=o.stat)
nobs(o::CountMissing) = nobs(o.stat) + o.nmissing

_fit!(o::CountMissing, x) = _fit!(o.stat, x)
_fit!(o::CountMissing, ::Missing) = (o.nmissing += 1)

_merge!(a::CountMissing, b::CountMissing) = (merge!(a.stat, b.stat); a.nmissing += b.nmissing)

#-----------------------------------------------------------------------# CovMatrix
"""
CovMatrix(p=0; weight=EqualWeight())
Expand Down Expand Up @@ -536,6 +512,8 @@ _merge!(o::Series, o2::Series) = map(_merge!, o.stats, o2.stats)

#-----------------------------------------------------------------------# FTSeries
"""
Deprecated! See [`FilterTransform`](@ref).
FTSeries(stats...; filter=x->true, transform=identity)
Track multiple stats for one data stream that is filtered and transformed before being
Expand Down Expand Up @@ -570,11 +548,13 @@ mutable struct FTSeries{IN, OS, F, T} <: StatCollection{IN}
transform::T
nfiltered::Int
end
function FTSeries(stats::OnlineStat...; filter=x->true, transform=identity)
IN, OS = Union{map(input, stats)...}, typeof(stats)
FTSeries{IN, OS, typeof(filter), typeof(transform)}(stats, filter, transform, 0)
function FTSeries(stats::OnlineStat...; kw...)
IN = Union{map(input, stats)...}
FTSeries(IN, stats...; kw...)
end
function FTSeries(T::Type, stats::OnlineStat...; filter=x->true, transform=identity)
Base.depwarn("`FTSeries(args...; kw...)` is deprecated. Use `FilterTransform(Series(args...; kw...))` instead.",
:FTSeries; force=true)
FTSeries{T, typeof(stats), typeof(filter), typeof(transform)}(stats, filter, transform, 0)
end
# The value of an FTSeries is the value of each tracked stat, element by element.
function value(o::FTSeries)
    return value.(o.stats)
end
Expand All @@ -596,24 +576,3 @@ function _merge!(o::FTSeries, o2::FTSeries)
o.nfiltered += o2.nfiltered
_merge!.(o.stats, o2.stats)
end


#-----------------------------------------------------------------------------# SkipMissing
"""
SkipMissing(stat)
Wrapper around an OnlineStat that will skip over `missing` values.
# Example
o = SkipMissing(Mean())
fit!(o, [1, missing, 3])
"""
struct SkipMissing{T, O<:OnlineStat{T}} <: StatWrapper{Union{Missing,T}}
stat::O
SkipMissing(stat::OnlineStat{T}) where {T} = new{T, typeof(stat)}(stat)
end
_fit!(o::SkipMissing, x::Missing) = nothing
_fit!(o::SkipMissing, x) = _fit!(o.stat, x)
Base.skipmissing(o::OnlineStat) = SkipMissing(o)
151 changes: 151 additions & 0 deletions src/wrappers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
#-----------------------------------------------------------------------------# StatWrapper
# Abstract supertype for stats that wrap a single inner stat. The fallback methods
# below all read the wrapped stat from a field named `stat`.
abstract type StatWrapper{T} <: OnlineStat{T} end

# Observation count is forwarded to the wrapped stat.
function nobs(o::StatWrapper)
    return nobs(o.stat)
end

# Value is forwarded to the wrapped stat.
function value(o::StatWrapper)
    return value(o.stat)
end

# Merging two wrappers with the same input type merges their wrapped stats.
function _merge!(a::StatWrapper{T}, b::StatWrapper{T}) where {T}
    return _merge!(a.stat, b.stat)
end

# Display name has the form "Wrapper(Inner)".
function name(o::T, args...) where {T<:StatWrapper}
    outer = name(typeof(o), args...)
    return outer * "($(name(o.stat, args...)))"
end

#-----------------------------------------------------------------------------# CountMissing
"""
    CountMissing(stat)

Calculate a `stat` along with the count of `missing` values.

# Example

    o = CountMissing(Mean())
    fit!(o, [1, missing, 3])
"""
mutable struct CountMissing{T, O<:OnlineStat{T}} <: StatWrapper{Union{Missing,T}}
    stat::O        # wrapped stat; fitted with every non-missing observation
    nmissing::Int  # number of `missing` observations seen so far
end

# Start counting from zero missing values.
function CountMissing(stat::OnlineStat)
    return CountMissing(stat, 0)
end

# The value pairs the missing count with the wrapped stat object itself.
function value(o::CountMissing)
    return (nmissing=o.nmissing, stat=o.stat)
end

# Missing observations count toward the total number of observations.
function nobs(o::CountMissing)
    return nobs(o.stat) + o.nmissing
end

# Non-missing observations go straight to the wrapped stat.
function _fit!(o::CountMissing, x)
    return _fit!(o.stat, x)
end

# `missing` observations only bump the counter.
function _fit!(o::CountMissing, ::Missing)
    return o.nmissing += 1
end

This comment has been minimized.

Copy link
@brucala

brucala Oct 5, 2021

Contributor

An additional_info method for CountMissing would be nice:

additional_info(o::CountMissing) = (; nmissing=o.nmissing)
function _merge!(a::CountMissing, b::CountMissing)
    # Merge the wrapped stats first, then pool the missing counts.
    merge!(a.stat, b.stat)
    return a.nmissing += b.nmissing
end

#-----------------------------------------------------------------------------# FilterTransform
"""
FilterTransform(stat::OnlineStat{S}, T = S; filter = x->true, transform = identity)
FilterTransform(T => filter => transform => stat)
Wrapper around an OnlineStat that filters and transforms its input. Note that, depending on
your transformation, you may need to specify the type of a single observation (`T`).
# Examples
o = FilterTransform(Mean(), Union{Missing,Number}, filter=!ismissing)
fit!(o, [1, missing, 3])
o = FilterTransform(String => (x->true) => (x->parse(Int,x)) => Mean())
fit!(o, "1")
"""
struct FilterTransform{S, T, O<:OnlineStat{T},F,F2} <: StatWrapper{S}
stat::O
filter::F
transform::F2

This comment has been minimized.

Copy link
@brucala

brucala Oct 5, 2021

Contributor

I think the nfiltered field of FTSeries was nice. Any reason to remove it here?

This comment has been minimized.

Copy link
@joshday

joshday Oct 5, 2021

Author Owner

No reason (other than forgetting about it). I'll add it back.

end
# Type-first form: `FilterTransform(T, stat; filter, transform)`.
# `intype` may be any type object, including a `Union` such as `Union{Missing,Number}`
# (whose `typeof` is `Union`, not `DataType`), so dispatch is on `Type`.
FilterTransform(intype::Type, stat::OnlineStat; kw...) = FilterTransform(stat, intype; kw...)

# Primary constructor.
#   stat      - the wrapped stat, whose own input type is `T`
#   intype    - type of a single raw observation; defaults to `T`
#   filter    - predicate on the raw observation; defaults to accepting everything
#   transform - maps an accepted raw observation to the wrapped stat's input
function FilterTransform(stat::OnlineStat{T}, intype=T; filter=always_true, transform=identity) where {T}
    return FilterTransform{intype, T, typeof(stat), typeof(filter), typeof(transform)}(stat, filter, transform)
end

# Pair form: `T => filter => transform => stat`. `=>` is right-associative, so the
# argument nests as `T => (filter => (transform => stat))` and is unpacked in that shape.
# The first element is accepted as any `Type` for the same reason as the type-first form.
function FilterTransform(p::Pair{<:Type, <:Pair{<:Function, <:Pair{<:Function, <:OnlineStat}}})
    intype, (keep, (trans, stat)) = p
    return FilterTransform(intype, stat; filter=keep, transform=trans)
end

# Fit the wrapped stat with the transformed observation, but only when the filter
# accepts the raw observation; a rejected observation yields `false`.
function _fit!(o::FilterTransform, y)
    o.filter(y) || return false
    return _fit!(o.stat, o.transform(y))
end

# Extras shown by `show`: the filter and transform functions themselves.
function additional_info(o::FilterTransform)
    return (; filter=o.filter, transform=o.transform)
end

# Default filter: accept every observation.
function always_true(x)
    return true
end


#-----------------------------------------------------------------------------# SkipMissing
"""
    SkipMissing(stat)

Wrapper around an OnlineStat that will skip over `missing` values.

# Example

    o = SkipMissing(Mean())
    fit!(o, [1, missing, 3])
"""
struct SkipMissing{T, O<:OnlineStat{T}} <: StatWrapper{Union{Missing,T}}
    stat::O
    # Inner constructor: take `T` from the wrapped stat so callers never have to
    # spell out the type parameters.
    function SkipMissing(stat::OnlineStat{T}) where {T}
        return new{T, typeof(stat)}(stat)
    end
end

# `missing` observations are dropped without touching the wrapped stat.
function _fit!(o::SkipMissing, ::Missing)
    return nothing
end

# Everything else is forwarded unchanged.
function _fit!(o::SkipMissing, x)
    return _fit!(o.stat, x)
end

# `skipmissing(stat)` wraps `stat` in a `SkipMissing`.
function Base.skipmissing(o::OnlineStat)
    return SkipMissing(o)
end

#-----------------------------------------------------------------------------# TryCatch
"""
TryCatch(stat; error_limit=1000, error_message_limit=90)
Wrap each call to `fit!` in a `try`-`catch` block and track the errors encountered (via [`CountMap`](@ref)). Errors will stop
being tracked after `error_limit` unique errors are encountered. Only the first `error_message_limit`
characters of each error message will be recorded.
# Example
o = TryCatch(Mean())
fit!(o, [1, missing, 3])
OnlineStatsBase.errors(o)
"""
struct TryCatch{T, O<:OnlineStat{T}} <: StatWrapper{T}
stat::O
errors::CountMap{String}

This comment has been minimized.

Copy link
@brucala

brucala Oct 5, 2021

Contributor

Just an idea to consider. The number of existing exceptions is finite and low, while we can have a large number of different error messages. Would it make sense to have something like the following?

errors::typeof(GroupBy(Exception, CountMap(String)))

where it keeps track of all exceptions encountered, but would stop adding new errors messages after error_limit distinct message are found (and after that counts them as "other" as proposed in my other suggestion).

error_limit::Int
error_message_limit::Int
end
# Keyword constructor: starts with an empty `CountMap` of error-message strings.
#   error_limit         - passed through to the `error_limit` field
#   error_message_limit - passed through to the `error_message_limit` field
function TryCatch(stat::OnlineStat; error_limit=1000, error_message_limit=90)
    error_counts = CountMap(String)
    return TryCatch(stat, error_counts, error_limit, error_message_limit)
end

# The recorded errors: the value of the internal `CountMap` of error messages.
function errors(o::TryCatch)
    return value(o.errors)
end

This comment has been minimized.

Copy link
@brucala

brucala Oct 5, 2021

Contributor

An additional nerrors method would be cool:

nerrors(o::TryCatch) = sum(values(value(o.errors)))

function additional_info(o::TryCatch)
ex = errors(o)
nex = length(ex)
msg = length(ex) ≥ o.error_limit ? "$nex (limit reached)" : nex
nex == 0 ? () : (; errors=msg)

This comment has been minimized.

Copy link
@brucala

brucala Oct 5, 2021

Contributor

This message could easily become huge. A suggestion is to consider replacing by (; nerrors=nerrors(o)) (based on the nerrors method proposed in my previous comment)

end

function handle_error!(o::TryCatch, ex)
io = IOBuffer()
Base.showerror(io, ex)
s = String(take!(io))
lim = o.error_message_limit
s = length(s) > lim ? s[1:lim] * "..." : s
length(value(o.errors)) < o.error_limit && _fit!(o.errors, s)

This comment has been minimized.

Copy link
@brucala

brucala Oct 5, 2021

Contributor

I really like the overall logic of this approach 👍. However, I see two low hanging fruit limitations:

  1. If the current error s was already found, this wouldn't increment its counter
  2. This doesn't keep track of the total number of errors found

Suggestion:

if length(value(o.errors)) < o.error_limit || s in keys(value(o.errors))
    _fit!(o.errors, s)
else
   _fit!(o.errors, "other errors")
end

This comment has been minimized.

Copy link
@joshday

joshday Oct 5, 2021

Author Owner

Ah, right. Good catch on 1. I'll add your second suggestion as well!

end

function fit!(o::TryCatch{T}, y::T) where {T}
    # Single observation of the expected input type: any exception thrown while
    # updating the wrapped stat is handed to `handle_error!` instead of propagating.
    try
        _fit!(o.stat, y)
    catch err
        handle_error!(o, err)
    end
    return o
end

function fit!(o::TryCatch{I}, y::T) where {I, T}
    # `y` is not a single observation of type `I`: treat it as a collection and fit
    # each item. Exceptions raised here (including the guard below) are handed to
    # `handle_error!` rather than thrown.
    try
        # NOTE(review): presumably guards against unbounded recursion when iterating
        # `y` yields items of `y`'s own type — confirm.
        if T == eltype(y)
            error("The input for $(name(o,false,false)) is $I. Found $T.")
        end
        foreach(y) do yi
            fit!(o, yi)
        end
    catch err
        handle_error!(o, err)
    end
    return o
end
21 changes: 20 additions & 1 deletion test/test_stats.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,14 @@ println(" > Extrema")
@test o.nmin == length(x) - sum(x)
@test o.nmax == sum(x)
end
#-----------------------------------------------------------------------------# FilterTransform
println(" > FilterTransform")
@testset "FilterTransform" begin
o = FilterTransform(String => (x->true) => (x -> parse(Int,x)) => Mean())
fit!(o, ["1", "3", "5"])
@test value(o) ≈ 3

This comment has been minimized.

Copy link
@brucala

brucala Oct 5, 2021

Contributor

Simple suggestion to test that the filter is behaving as expected:

    o = FilterTransform(String => (x -> x != "1") => (x -> parse(Int,x)) => Mean())
    fit!(o, ["1", "3", "5"])
    @test value(o) ≈ 4
end

#-----------------------------------------------------------------------# Group
println(" > Group")
@testset "Group" begin
Expand Down Expand Up @@ -255,6 +263,17 @@ println(" > Sum")
@test ≈(mergevals(Sum(), y, y2)...)
@test ==(mergevals(Sum(Int), z, z2)...)
end

#-----------------------------------------------------------------------------# TryCatch
println(" > TryCatch")
@testset "TryCatch" begin
    o = TryCatch(Mean())
    # The `missing` entry cannot be fit by `Mean`; TryCatch should absorb the failure,
    # leaving the mean of the two valid observations.
    fit!(o, [1, missing, 3])
    @test value(o) ≈ 2
    # After merging, the wrapped means cover 1, 3 and 5.
    merge!(o, fit!(TryCatch(Mean()), [missing, 5, missing]))
    @test value(o) ≈ 3
end

#-----------------------------------------------------------------------# Variance
println(" > Variance")
@testset "Variance" begin
Expand All @@ -275,4 +294,4 @@ println(" > Variance")
@test value(fit!(Variance(Float32), randn(Float32, 10))) isa Float32
end

end # end "Test Stats"
end # end "Test Stats"

0 comments on commit 982f05a

Please sign in to comment.