diff --git a/src/robustlinearmodel.jl b/src/robustlinearmodel.jl index a45b13c..09a2733 100644 --- a/src/robustlinearmodel.jl +++ b/src/robustlinearmodel.jl @@ -98,14 +98,8 @@ function StatsAPI.fit( end # Make sure X and y have the same float eltype - T = promote_type(float(eltype(X)), float(eltype(y))) - if !(T <: AbstractFloat) - msg = "promoting X and y arrays to float types" - throw(TypeError(:fit, msg, Type{<:AbstractFloat}, T)) - end - MT = AbstractMatrix{T} - VT = AbstractVector{T} - return fit(M, convert.(T, X)::MT, convert.(T, y)::VT, args...; kwargs...) + pX, py = promote_to_same_float(X, y) + return fit(M, pX, py, args...; kwargs...) end ## Convert from formula-data to modelmatrix-response calling form @@ -125,14 +119,8 @@ function StatsAPI.fit( # Extract arrays from data using formula f, y, X, extra = modelframe(f, data, contrasts, dropmissing, M; wts=wts) # Call the `fit` method with arrays - T = promote_type(float(eltype(X)), float(eltype(y))) - if !(T <: AbstractFloat) - msg = "promoting X and y arrays to float types" - throw(TypeError(:fit, msg, Type{<:AbstractFloat}, T)) - end - MT = AbstractMatrix{T} - VT = AbstractVector{T} - return fit(M, X::MT, y::VT, args...; wts=extra.wts, contrasts=contrasts, __formula=f, kwargs...) + pX, py = promote_to_same_float(X, y) + return fit(M, pX, py, args...; wts=extra.wts, contrasts=contrasts, __formula=f, kwargs...) end diff --git a/src/tools.jl b/src/tools.jl index cc99859..0356d02 100644 --- a/src/tools.jl +++ b/src/tools.jl @@ -4,6 +4,20 @@ ## Missing values ################################################ +function promote_to_same_float( + X::AbstractMatrix, + y::AbstractVector, +) + T = promote_type(float(eltype(X)), float(eltype(y))) + if !(T <: AbstractFloat) + msg = "promoting X and y arrays to float types" + throw(TypeError(:fit, msg, Type{<:AbstractFloat}, T)) + end + MT = AbstractMatrix{T} + VT = AbstractVector{T} + return convert.(T, X)::MT, convert.(T, y)::VT +end + _missing_omit(x::AbstractArray{T}) where {T} = copyto!(similar(x, nonmissingtype(T)), x) function StatsModels.missing_omit(X::AbstractMatrix, y::AbstractVector)