Skip to content

Commit

Permalink
Enforce type consistency (#43)
Browse files Browse the repository at this point in the history
* enforce type consistency
  • Loading branch information
getzze committed Jul 25, 2023
1 parent 3b03a0b commit f2c7d97
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 12 deletions.
28 changes: 16 additions & 12 deletions src/robustlinearmodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ function StatsAPI.fit(
end

# Make sure X and y have the same float eltype
T = promote_type(float(eltype(X)), float(eltype(y)))
return fit(M, convert.(T, X), convert.(T, y), args...; kwargs...)
pX, py = promote_to_same_float(X, y)
return fit(M, pX, py, args...; kwargs...)
end

## Convert from formula-data to modelmatrix-response calling form
Expand All @@ -119,7 +119,10 @@ function StatsAPI.fit(
# Extract arrays from data using formula
f, y, X, extra = modelframe(f, data, contrasts, dropmissing, M; wts=wts)
# Call the `fit` method with arrays
return fit(M, X, y, args...; wts=extra.wts, contrasts=contrasts, __formula=f, kwargs...)
pX, py = promote_to_same_float(X, y)
return fit(
M, pX, py, args...; wts=extra.wts, contrasts=contrasts, __formula=f, kwargs...
)
end


Expand Down Expand Up @@ -1021,11 +1024,11 @@ function pirls!(
devold = deviance(m)
absdev = abs(devold)
dev = devold
Δdev = 0
Δdev = zero(T)

verbose && println("initial deviance: $(@sprintf("%.4g", devold))")
for i in 1:maxiter
f = 1.0 # line search factor
f = one(T) # line search factor
# local dev
absdev = abs(devold)

Expand Down Expand Up @@ -1124,12 +1127,12 @@ function pirls_Sestimate!(
sigold = scale(
setη!(m; updatescale=true, verbose=verbose, sigma0=sigma0, fallback=maxσ)
)
installbeta!(p, 1)
installbeta!(p, one(T))
r.σ = sigold

verbose && println("initial iteration scale: $(@sprintf("%.4g", sigold))")
for i in 1:maxiter
f = 1.0 # line search factor
f = one(T) # line search factor
local sig

# Compute the change to β, update μ and compute deviance
Expand Down Expand Up @@ -1242,11 +1245,11 @@ function pirls_τestimate!(

# Compute initial τ-scale
tauold = tauscale(setη!(m; updatescale=true); verbose=verbose)
installbeta!(p, 1)
installbeta!(p, one(T))

verbose && println("initial iteration τ-scale: $(@sprintf("%.4g", tauold))")
for i in 1:maxiter
f = 1.0 # line search factor
f = one(T) # line search factor
local tau

# Compute the change to β, update μ and compute deviance
Expand Down Expand Up @@ -1366,6 +1369,7 @@ function resampling_best_estimate(
## Hubert2015 - The DetS and DetMM estimators for multivariate location and scatter
## (https://www.sciencedirect.com/science/article/abs/pii/S0167947314002175)
M = length(coef(m))
T = eltype(coef(m))

if isnothing(Nsamples)
Nsamples = resampling_minN(M, 0.05, propoutliers)
Expand All @@ -1377,8 +1381,8 @@ function resampling_best_estimate(


verbose && println("Start $(Nsamples) subsamples...")
σis = zeros(Nsamples)
βis = zeros(M, Nsamples)
σis = zeros(T, Nsamples)
βis = zeros(T, M, Nsamples)
for i in 1:Nsamples
# TODO: to parallelize, make a deepcopy of m
inds = sample(rng, axes(response(m), 1), Npoints; replace=false, ordered=false)
Expand All @@ -1393,7 +1397,7 @@ function resampling_best_estimate(
# Initialize σ as mad(residuals)
setinitσ!(m)

σi = 0
σi = zero(T)
for k in 1:Nsteps_β
setη!(
m;
Expand Down
11 changes: 11 additions & 0 deletions src/tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,17 @@
## Missing values
################################################

function promote_to_same_float(X::AbstractMatrix, y::AbstractVector)
T = promote_type(float(eltype(X)), float(eltype(y)))
if !(T <: AbstractFloat)
msg = "promoting X and y arrays to float types"
throw(TypeError(:fit, msg, Type{<:AbstractFloat}, T))
end
MT = AbstractMatrix{T}
VT = AbstractVector{T}
return convert.(T, X)::MT, convert.(T, y)::VT
end

_missing_omit(x::AbstractArray{T}) where {T} = copyto!(similar(x, nonmissingtype(T)), x)

function StatsModels.missing_omit(X::AbstractMatrix, y::AbstractVector)
Expand Down

0 comments on commit f2c7d97

Please sign in to comment.