Skip to content

Commit

Permalink
Patch GLM (#41)
Browse files Browse the repository at this point in the history
* ensure same eltype for X and y (JuliaStats/GLM.jl#369)

* correctly use QR

* use :cholesky and dropcollinear like in GLM

* test method :cholesky

* create own DensePredQR

* cleanup
  • Loading branch information
getzze committed Jun 2, 2023
1 parent 70b6ecb commit 35f270c
Show file tree
Hide file tree
Showing 11 changed files with 216 additions and 86 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Missings = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down
16 changes: 11 additions & 5 deletions src/RobustModels.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
module RobustModels

using Pkg: Pkg

include("compat.jl")

# Use README as the docstring of the module and doctest README
Expand All @@ -10,10 +12,10 @@ end RobustModels

# Import with `using` to use the module names to prefix the methods
# that are extended from these modules
using GLM
using StatsAPI
using StatsBase
using StatsModels
using GLM: GLM
using StatsAPI: StatsAPI
using StatsBase: StatsBase
using StatsModels: StatsModels

## Import to implement new methods
import Base: show, broadcastable, convert, ==
Expand Down Expand Up @@ -73,18 +75,21 @@ using LinearAlgebra:
inv,
diag,
diagm,
rank,
ldiv!

using Random: AbstractRNG, GLOBAL_RNG
using Printf: @printf, @sprintf
using GLM: FPVector, lm, SparsePredChol, DensePredChol, DensePredQR
using GLM: FPVector, lm, SparsePredChol, DensePredChol
using StatsBase:
AbstractWeights, CoefTable, ConvergenceException, median, mad, mad_constant, sample
using StatsModels:
@delegate,
@formula,
formula,
RegressionModel,
FormulaTerm,
InterceptTerm,
ModelFrame,
modelcols,
apply_schema,
Expand Down Expand Up @@ -238,6 +243,7 @@ abstract type AbstractRegularizedPred{T} end
Base.broadcastable(m::T) where {T<:AbstractEstimator} = Ref(m)
Base.broadcastable(m::T) where {T<:LossFunction} = Ref(m)


include("tools.jl")
include("losses.jl")
include("estimators.jl")
Expand Down
16 changes: 15 additions & 1 deletion src/compat.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,24 @@
using LinearAlgebra: cholesky!
using LinearAlgebra: cholesky!, qr!

function get_pkg_version(m::Module)
toml = Pkg.TOML.parsefile(joinpath(pkgdir(m), "Project.toml"))
return VersionNumber(toml["version"])
end


## Compatibility layers

# https://github.com/JuliaStats/GLM.jl/pull/459
@static if VERSION < v"1.8.0-DEV.1139"
pivoted_cholesky!(A; kwargs...) = cholesky!(A, Val(true); kwargs...)
else
using LinearAlgebra: RowMaximum
pivoted_cholesky!(A; kwargs...) = cholesky!(A, RowMaximum(); kwargs...)
end

@static if VERSION < v"1.7.0"
pivoted_qr!(A; kwargs...) = qr!(A, Val(true); kwargs...)
else
using LinearAlgebra: ColumnNorm
pivoted_qr!(A; kwargs...) = qr!(A, ColumnNorm(); kwargs...)
end
176 changes: 130 additions & 46 deletions src/linpred.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,45 +50,143 @@ leverage_weights(p::LinPred, wt::AbstractVector) = sqrt.(1 .- leverage(p, wt))
# beta0
#end

"""
DensePredQR

A `LinPred` type with a dense, unpivoted QR decomposition of `X`
##########################################
###### DensePredQR
##########################################

# Members
@static if get_pkg_version(GLM) < v"1.9"
@warn(
"GLM.DensePredQR(X::AbstractMatrix, pivot::Bool=true) is not defined, " *
"fallback to unpivoted RobustModels.DensePredQR definition. " *
"To use pivoted QR, GLM version should be greater than or equal to v1.9."
)

- `X`: Model matrix of size `n` × `p` with `n ≥ p`. Should be full column rank.
- `beta0`: base coefficient vector of length `p`
- `delbeta`: increment to coefficient vector, also of length `p`
- `scratchbeta`: scratch vector of length `p`, used in `linpred!` method
- `qr`: a `QRCompactWY` object created from `X`, with optional row weights.
"""
DensePredQR

PRED_QR_WARNING_ISSUED = false

function qrpred(X::AbstractMatrix, pivot::Bool=false)
try
return DensePredCG(Matrix(X), pivot)
catch e
if e isa MethodError
# GLM.DensePredCG(X::AbstractMatrix, pivot::Bool) is not defined
global PRED_QR_WARNING_ISSUED
if !PRED_QR_WARNING_ISSUED
@warn(
"GLM.DensePredCG(X::AbstractMatrix, pivot::Bool) is not defined, " *
"fallback to unpivoted QR. GLM version should be >= 1.9."
)
PRED_QR_WARNING_ISSUED = true
using LinearAlgebra: QRCompactWY, QRPivoted, Diagonal, qr!, qr

"""
DensePredQR
A `LinPred` type with a dense QR decomposition of `X`
# Members
- `X`: Model matrix of size `n` × `p` with `n ≥ p`. Should be full column rank.
- `beta0`: base coefficient vector of length `p`
- `delbeta`: increment to coefficient vector, also of length `p`
- `scratchbeta`: scratch vector of length `p`, used in `linpred!` method
- `qr`: a `QRCompactWY` object created from `X`, with optional row weights.
- `scratchm1`: scratch Matrix{T} of the same size as `X`
- `scratchm2`: scratch Matrix{T} of the same size as `X`
- `scratchR`: scratch Matrix{T} of the same size as `qr.R`, a square matrix.
"""
mutable struct DensePredQR{T<:BlasReal,Q<:Union{QRCompactWY,QRPivoted}} <: DensePred
X::Matrix{T} # model matrix
beta0::Vector{T} # base coefficient vector
delbeta::Vector{T} # coefficient increment
scratchbeta::Vector{T}
qr::Q
scratchm1::Matrix{T}
scratchm2::Matrix{T}
scratchR::Matrix{T}

function DensePredQR(X::AbstractMatrix, pivot::Bool=false)
n, p = size(X)
T = typeof(float(zero(eltype(X))))

if false
# if pivot
F = pivoted_qr!(copy(X))
else
if n >= p
F = qr(X)
else
# adjoint of X so R is square
# cannot use in-place qr!
F = qr(X)
end
end
return DensePredCG(Matrix(X))

return new{T,typeof(F)}(
Matrix{T}(X),
zeros(T, p),
zeros(T, p),
zeros(T, p),
F,
similar(X, T),
similar(X, T),
zeros(T, size(F.R)),
)
end
end

# GLM.DensePredQR(X::AbstractMatrix, pivot::Bool) is not defined
function qrpred(X::AbstractMatrix, pivot::Bool=false)
return DensePredQR(Matrix(X))
end

# GLM.delbeta!(p::DensePredQR{T}, r::Vector{T}) is ill-defined
function delbeta!(p::DensePredQR{T,<:QRCompactWY}, r::Vector{T}) where {T<:BlasReal}
n, m = size(p.X)
if n >= m
p.delbeta = p.qr \ r
else
p.delbeta = p.qr' \ r
end
return p
end

# GLM.delbeta!(p::DensePredQR{T}, r::Vector{T}, wt::Vector{T}) is not defined
function delbeta!(
p::DensePredQR{T,<:QRCompactWY}, r::Vector{T}, wt::Vector{T}
) where {T<:BlasReal}
rnk = rank(p.qr.R)
X = p.X
W = Diagonal(wt)
sqrtW = Diagonal(sqrt.(wt))
scratchm1 = p.scratchm1 = similar(X, T)
mul!(scratchm1, sqrtW, X)

n, m = size(X)
if n >= m
# W½ X = Q R , with Q'Q = I
# X'WX β = X'y => R'Q'QR β = X'y
# => β = R⁻¹ R⁻ᵀ X'y
qnr = p.qr = qr(scratchm1)
Rinv = p.scratchR = inv(qnr.R)

scratchm2 = p.scratchm2 = similar(X, T)
mul!(scratchm2, W, X)
mul!(p.delbeta, transpose(scratchm2), r)

p.delbeta = Rinv * Rinv' * p.delbeta
else
rethrow()
# (W½ X)' = Q R , with Q'Q = I
# W½X β = W½y => R'Q' β = y
# => β = Q . [R⁻ᵀ y; 0]
qnrT = p.qr = qr(scratchm1')
RTinv = p.scratchR = inv(qnrT.R)'
@assert 1 <= n <= size(p.delbeta, 1)
mul!(view(p.delbeta, 1:n), RTinv, r)
p.delbeta = zeros(size(p.delbeta))
p.delbeta[1:n] .= RTinv * r
lmul!(qnrT.Q, p.delbeta)
end
return p
end


## Use DensePredQR from GLM
else
using GLM: DensePredQR
import GLM: qrpred
end


##########################################
###### [Dense/Sparse]PredCG
##########################################

"""
DensePredCG
Expand All @@ -109,20 +207,8 @@ mutable struct DensePredCG{T<:BlasReal} <: DensePred
scratchbeta::Vector{T}
scratchm1::Matrix{T}
scratchr1::Vector{T}
function DensePredCG{T}(X::Matrix{T}, beta0::Vector{T}) where {T}
n, p = size(X)
length(beta0) == p || throw(DimensionMismatch("length(β0) ≠ size(X,2)"))
return new{T}(
X,
beta0,
zeros(T, p),
zeros(T, (p, p)),
zeros(T, p),
zeros(T, (n, p)),
zeros(T, n),
)
end
function DensePredCG{T}(X::Matrix{T}) where {T}

function DensePredCG(X::Matrix{T}) where {T<:BlasReal}
n, p = size(X)
return new{T}(
X,
Expand All @@ -135,10 +221,8 @@ mutable struct DensePredCG{T<:BlasReal} <: DensePred
)
end
end
DensePredCG(X::Matrix, beta0::Vector) = DensePredCG{eltype(X)}(X, beta0)
DensePredCG(X::Matrix{T}) where {T} = DensePredCG{T}(X, zeros(T, size(X, 2)))
function Base.convert(::Type{DensePredCG{T}}, X::Matrix{T}) where {T}
return DensePredCG{T}(X, zeros(T, size(X, 2)))
return DensePredCG(X)
end

# Compatibility with cholpred(X, pivot)
Expand Down
3 changes: 3 additions & 0 deletions src/regularizedpred.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ function postupdate_λ!(r::RidgePred)
# Update the extended model matrix with the new value
GG = r.sqrtλ * r.G
@views r.pred.X[(n + 1):(n + m), :] .= GG

# Update other fields
# TODO: update DensePredQR
if isa(r.pred, DensePredChol)
# Recompute the cholesky decomposition
X = r.pred.X
Expand Down
Loading

0 comments on commit 35f270c

Please sign in to comment.