Skip to content

Commit

Permalink
correctly use QR
Browse files Browse the repository at this point in the history
  • Loading branch information
getzze committed Jun 2, 2023
1 parent b3ef64a commit 348603b
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 26 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Missings = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down
32 changes: 24 additions & 8 deletions src/RobustModels.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
module RobustModels

using Pkg: Pkg

include("compat.jl")

# Use README as the docstring of the module and doctest README
Expand All @@ -10,10 +12,10 @@ end RobustModels

# Import with `using` to use the module names to prefix the methods
# that are extended from these modules
using GLM
using StatsAPI
using StatsBase
using StatsModels
using GLM: GLM
using StatsAPI: StatsAPI
using StatsBase: StatsBase
using StatsModels: StatsModels

## Import to implement new methods
import Base:
Expand All @@ -32,14 +34,27 @@ import StatsModels:
using Distributions: ccdf, pdf, quantile, Normal, Chisq, TDist, FDist
using SparseArrays: SparseMatrixCSC, spdiagm
using LinearAlgebra: dot, tr, I, UniformScaling, rmul!, lmul!, mul!, BlasReal, Hermitian, transpose,
inv, diag, diagm, ldiv!
inv, diag, diagm, Diagonal, rank, qr, ldiv!

using Random: AbstractRNG, GLOBAL_RNG
using Printf: @printf, @sprintf
using GLM: FPVector, lm, SparsePredChol, DensePredChol, DensePredQR
using StatsBase: AbstractWeights, CoefTable, ConvergenceException, median, mad, mad_constant, sample
using StatsModels: @delegate, @formula, RegressionModel, FormulaTerm, ModelFrame, modelcols,
apply_schema, schema, checknamesexist, checkcol, termvars
using StatsBase:
AbstractWeights, CoefTable, ConvergenceException, median, mad, mad_constant, sample
using StatsModels:
@delegate,
@formula,
formula,
RegressionModel,
FormulaTerm,
InterceptTerm,
ModelFrame,
modelcols,
apply_schema,
schema,
checknamesexist,
checkcol,
termvars
using IterativeSolvers: cg!
using Tables
using Roots: find_zero, Order1, ConvergenceFailed
Expand Down Expand Up @@ -186,6 +201,7 @@ abstract type AbstractRegularizedPred{T} end
Base.broadcastable(m::T) where {T<:AbstractEstimator} = Ref(m)
Base.broadcastable(m::T) where {T<:LossFunction} = Ref(m)


include("tools.jl")
include("losses.jl")
include("estimators.jl")
Expand Down
83 changes: 65 additions & 18 deletions src/linpred.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,27 +65,74 @@ A `LinPred` type with a dense, unpivoted QR decomposition of `X`
"""
DensePredQR

PRED_QR_WARNING_ISSUED = false

function qrpred(X::AbstractMatrix, pivot::Bool=false)
p = try
DensePredCG(Matrix(X), pivot)
catch e
if e isa MethodError
# GLM.DensePredCG(X::AbstractMatrix, pivot::Bool) is not defined
global PRED_QR_WARNING_ISSUED
if !PRED_QR_WARNING_ISSUED
@warn(
"GLM.DensePredCG(X::AbstractMatrix, pivot::Bool) is not defined, " *
"fallback to unpivoted QR. GLM version should be >= 1.9."
)
PRED_QR_WARNING_ISSUED = true
end
DensePredCG(Matrix(X))
function get_pkg_version(m::Module)
toml = Pkg.TOML.parsefile(joinpath(pkgdir(m), "Project.toml"))
return VersionNumber(toml["version"])
end


@static if get_pkg_version(GLM) < v"1.9"
@warn(
"GLM.DensePredQR(X::AbstractMatrix, pivot::Bool) is not defined, " *
"fallback to unpivoted QR. GLM version should be >= 1.9."
)

# GLM.DensePredQR(X::AbstractMatrix, pivot::Bool) is not defined
function qrpred(X::AbstractMatrix, pivot::Bool=false)
DensePredQR(Matrix(X))
end

# GLM.delbeta!(p::DensePredQR{T}, r::Vector{T}, wt::Vector{T}) is not defined
function delbeta!(p::DensePredQR{T}, r::Vector{T}, wt::Vector{T}) where T<:BlasReal
rnk = rank(p.qr.R)
X = p.X
W = Diagonal(wt)
sqrtW = Diagonal(sqrt.(wt))
scratchm1 = similar(X, T)
mul!(scratchm1, sqrtW, X)

n, m = size(X)
if n >= m
# W½ X = Q R , with Q'Q = I
# X'WX β = X'y => R'Q'QR β = X'y
# => β = R⁻¹ R⁻ᵀ X'y
qnr = qr(scratchm1)
Rinv = inv(qnr.R)

scratchm2 = similar(X, T)
mul!(scratchm2, W, X)
mul!(p.delbeta, transpose(scratchm2), r)

p.delbeta = Rinv * Rinv' * p.delbeta
else
# (W½ X)' = Q R , with Q'Q = I
# W½X β = W½y => R'Q' β = y
# => β = Q . [R⁻ᵀ y; 0]
qnr = qr(scratchm1')
RTinv = inv(qnr.R)'
@assert 1 <= n <= size(p.delbeta, 1)
mul!(view(p.delbeta, 1:n), RTinv, r)
p.delbeta = zeros(size(p.delbeta))
p.delbeta[1:n] .= RTinv * r
lmul!(qnr.Q, p.delbeta)

Check warning on line 117 in src/linpred.jl

View check run for this annotation

Codecov / codecov/patch

src/linpred.jl#L111-L117

Added lines #L111 - L117 were not covered by tests
end
return p
end

# GLM.delbeta!(p::DensePredQR{T}, r::Vector{T}) is ill-defined
function delbeta!(p::DensePredQR{T}, r::Vector{T}) where T<:BlasReal
n, m = size(p.X)
if n >= m
p.delbeta = p.qr \ r

Check warning on line 126 in src/linpred.jl

View check run for this annotation

Codecov / codecov/patch

src/linpred.jl#L123-L126

Added lines #L123 - L126 were not covered by tests
else
rethrow()
qnrT = qr(p.X')
p.delbeta = qnrT' \ r

Check warning on line 129 in src/linpred.jl

View check run for this annotation

Codecov / codecov/patch

src/linpred.jl#L128-L129

Added lines #L128 - L129 were not covered by tests
end
return p

Check warning on line 131 in src/linpred.jl

View check run for this annotation

Codecov / codecov/patch

src/linpred.jl#L131

Added line #L131 was not covered by tests
end

else
qrpred(X::AbstractMatrix, pivot::Bool=false) = DensePredQR(Matrix(X), pivot)

Check warning on line 135 in src/linpred.jl

View check run for this annotation

Codecov / codecov/patch

src/linpred.jl#L135

Added line #L135 was not covered by tests
end


Expand Down

0 comments on commit 348603b

Please sign in to comment.