Sparsity detection fallback method #263

Closed
ElOceanografo opened this issue May 16, 2024 · 9 comments · Fixed by #297
Labels
core Related to the core utilities of the package

Comments

@ElOceanografo

Both SparseConnectivityTracer and Symbolics fail to detect sparsity through functions that don't accept generic number types, calls into external linear algebra libraries being one example. I don't know what the ultimate solution is for those packages, but it would be good to have some kind of fallback method here, even if it's just calculating a dense Jacobian/Hessian and checking which elements are zero. (I guess that's less than ideal, but maybe better than nothing? Or maybe there's a more clever algorithm I don't know about?)
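The dense fallback described above can be sketched with only the Julia standard library. This is a sketch under assumptions, not an existing API: the helpers `fd_hessian` and `dense_hessian_sparsity` are hypothetical, and central finite differences stand in for an AD backend. Because `f` is only ever evaluated on `Vector{Float64}` inputs, the approach also works for functions that reject generic number types, such as `logdet(diagm(exp.(u)))`. The last lines show its main weakness: the pattern is only valid at the point where it was computed.

```julia
using LinearAlgebra, SparseArrays

# Hypothetical helper: dense Hessian of a scalar function f at x via central
# finite differences (standing in for an AD backend). f is only evaluated on
# Float64 vectors, so it never sees a tracer or dual number type.
function fd_hessian(f, x::Vector{Float64}; h = 1e-3)
    n = length(x)
    H = zeros(n, n)
    for j in 1:n, i in 1:n
        ei = zeros(n); ei[i] = h
        ej = zeros(n); ej[j] = h
        # Mixed central difference approximating d^2 f / dx_i dx_j
        H[i, j] = (f(x + ei + ej) - f(x + ei - ej) - f(x - ei + ej) + f(x - ei - ej)) / (4h^2)
    end
    return H
end

# Hypothetical fallback detector: keep the entries of the dense Hessian whose
# magnitude exceeds atol. The pattern is only valid at x (see the pitfall below).
dense_hessian_sparsity(f, x; atol = 1e-5) = sparse(abs.(fd_hessian(f, x)) .> atol)

# Usage on a function shaped like the example in this issue
y = collect(1.0:4.0)
f(u) = logdet(diagm(exp.(u))) - y' * diagm(exp.(u)) * y
P = dense_hessian_sparsity(f, zeros(4))       # diagonal pattern

# Pitfall: u[1]^2 * u[2]^2 has a fully dense Hessian in general,
# but every entry of it vanishes at the origin
g(u) = u[1]^2 * u[2]^2
P0 = dense_hessian_sparsity(g, [0.0, 0.0])    # no entries kept
P1 = dense_hessian_sparsity(g, [1.0, 1.0])    # all four entries kept
```

Evaluating at a random point rather than a special one (like the origin) makes such accidental zeros unlikely, but not impossible.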

A minimal (non-)working example:

using DifferentiationInterface, ReverseDiff, SparseConnectivityTracer, LinearAlgebra

const y = randn(10)

function f(u)
    Q = diagm(exp.(u))
    return logdet(Q) - y' * Q * y
end

backend = AutoSparse(AutoReverseDiff(),
    sparsity_detector=TracerSparsityDetector(),
    coloring_algorithm=GreedyColoringAlgorithm())

hessian(f, backend, randn(10))  # errors during sparsity detection
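For reference, the Hessian of this particular `f` can be worked out by hand, which gives the pattern a working detector should return. Since `Q` is diagonal, both terms separate over the coordinates of `u`:

```latex
\begin{aligned}
Q &= \operatorname{diag}\left(e^{u_1}, \dots, e^{u_n}\right), \qquad
\log\det Q = \sum_{i=1}^{n} u_i, \qquad
y^\top Q\, y = \sum_{i=1}^{n} y_i^2 e^{u_i}, \\
f(u) &= \sum_{i=1}^{n} \left( u_i - y_i^2 e^{u_i} \right), \qquad
\frac{\partial f}{\partial u_i} = 1 - y_i^2 e^{u_i}, \\
\frac{\partial^2 f}{\partial u_i \, \partial u_j} &=
\begin{cases}
  -\,y_i^2 e^{u_i} & \text{if } i = j, \\
  0 & \text{if } i \neq j.
\end{cases}
\end{aligned}
```

So the Hessian is diagonal, with n nonzeros out of n^2 (10 out of 100 here).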
@gdalle
Member

gdalle commented May 16, 2024

Hi @ElOceanografo, thanks for reporting this!

I have opened adrhill/SparseConnectivityTracer.jl#68 on the SparseConnectivityTracer side to track it. Did you encounter the same error message? I think that particular behavior is not a limitation of SCT but just a plain old bug; we'll fix it soon.

As for your more general question, "what to do when I can't trace through my function?", the existing options are:

  • Use a dense backend like AutoReverseDiff()
  • Don't specify the sparsity detector, which makes ADTypes default to NoSparsityDetector() and essentially does the same as a dense backend
julia> AutoSparse(AutoReverseDiff(); coloring_algorithm=GreedyColoringAlgorithm())
AutoSparse{AutoReverseDiff, ADTypes.NoSparsityDetector, GreedyColoringAlgorithm}(AutoReverseDiff(false), ADTypes.NoSparsityDetector(), GreedyColoringAlgorithm())
  • Define a new AbstractSparsityDetector following the ADTypes.jl interface, which uses a given backend to compute one dense Jacobian / Hessian, and then compares its elements to zero with a fixed tolerance. Here's an implementation:
using ADTypes
using DifferentiationInterface
using SparseArrays

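# Fallback detector: compute one dense Jacobian/Hessian with `backend`,
# then keep the entries whose magnitude exceeds `atol`.
# The pattern can depend on the input x: entries that happen to vanish there are dropped.
struct DenseSparsityDetector{B} <: ADTypes.AbstractSparsityDetector
    backend::B
    atol::Float64
end

# Out-of-place function y = f(x)
function ADTypes.jacobian_sparsity(f, x, detector::DenseSparsityDetector)
    J = jacobian(f, detector.backend, x)
    return sparse(abs.(J) .> detector.atol)
end

# In-place function f!(y, x)
function ADTypes.jacobian_sparsity(f!, y, x, detector::DenseSparsityDetector)
    J = jacobian(f!, y, detector.backend, x)
    return sparse(abs.(J) .> detector.atol)
end

# Scalar function f(x)
function ADTypes.hessian_sparsity(f, x, detector::DenseSparsityDetector)
    H = hessian(f, detector.backend, x)
    return sparse(abs.(H) .> detector.atol)
end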

If you think this belongs in DifferentiationInterface, I can add it, or at least document it.
Once you have this detector, you can plug it into the AutoSparse backend as usual.

using DifferentiationInterface
using ForwardDiff: ForwardDiff
using Zygote: Zygote

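# `f` is the function from the example above; ForwardDiff over Zygote
dense_backend = SecondOrder(AutoForwardDiff(), AutoZygote())

sparse_backend = AutoSparse(
    dense_backend;
    sparsity_detector=DenseSparsityDetector(dense_backend, 1e-5),
    coloring_algorithm=GreedyColoringAlgorithm(),
)

# Preparation runs the dense Hessian (sparsity detection) and the coloring once
extras = prepare_hessian(f, sparse_backend, randn(10))

# Later calls reuse `extras` and only compute the compressed sparse Hessian
hessian(f, sparse_backend, randn(10), extras)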

A few remarks:

  • Don't forget to prepare the sparse Hessian computation, so that you only run the dense Hessian and the subsequent coloring once. As long as the control flow in your function doesn't change, you can reuse them for as many sparse Hessians as you need.
  • Right now your best bet for Hessians is ForwardDiff over Zygote. The sparse Hessian computation doesn't follow the same code path as the dense one, so you cannot take advantage of the caching mechanisms used by e.g. ReverseDiff. In addition, pullbacks with ReverseDiff are slow (and not even implemented in the public API).
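To illustrate why preparation pays off, here is a standard-library sketch of the step that runs after detection: color the columns of the pattern once, then recover every later Hessian from one Hessian-vector product per color instead of one per column. This is a simplified illustration under assumptions, not DifferentiationInterface's implementation: it uses a plain greedy column coloring that ignores symmetry (colorings designed for Hessians can exploit it), the function names are hypothetical, and `hvp` is any closure computing `H * v`, which in practice would be a forward-over-reverse product such as ForwardDiff over Zygote.

```julia
using LinearAlgebra, SparseArrays

# Hypothetical greedy column coloring: two columns may share a color only if
# they have no nonzero in a common row (structurally orthogonal columns).
function greedy_column_coloring(P::SparseMatrixCSC{Bool})
    n = size(P, 2)
    colors = zeros(Int, n)
    rows_of(j) = rowvals(P)[nzrange(P, j)]
    for j in 1:n
        forbidden = Set{Int}()
        for k in 1:j-1
            isempty(intersect(rows_of(j), rows_of(k))) || push!(forbidden, colors[k])
        end
        c = 1
        while c in forbidden
            c += 1
        end
        colors[j] = c
    end
    return colors
end

# Hypothetical decompression: one product H * seed per color. Within a color,
# each row i meets at most one column j of the pattern, so (H * seed)[i] == H[i, j].
function sparse_hessian_from_hvp(hvp, P::SparseMatrixCSC{Bool}, colors)
    rows, cols, _ = findnz(P)
    vals = zeros(length(rows))
    for c in 1:maximum(colors)
        seed = Float64.(colors .== c)   # sum of basis vectors e_j with colors[j] == c
        Hs = hvp(seed)                  # sum of the columns H[:, j] with colors[j] == c
        for t in eachindex(rows)
            if colors[cols[t]] == c
                vals[t] = Hs[rows[t]]
            end
        end
    end
    return sparse(rows, cols, vals, size(P)...)
end

# Usage: a 6x6 tridiagonal Hessian needs 3 products instead of 6
H = SymTridiagonal(fill(2.0, 6), fill(-1.0, 5))
P = sparse(Matrix(H) .!= 0)          # pattern, computed once
colors = greedy_column_coloring(P)   # coloring, computed once
hvp(v) = H * v                       # stand-in for a forward-over-reverse HVP
Hrec = sparse_hessian_from_hvp(hvp, P, colors)
```

Only the last line has to be repeated for each new input, which is why the dense detection and the coloring are worth doing once up front.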

@gdalle
Member

gdalle commented May 16, 2024

Out of curiosity, what are you using sparse Hessians for? We're currently writing a paper on sparse autodiff and are looking for cool application examples.

@adrhill
Collaborator

adrhill commented May 16, 2024

Thanks for raising the issue, this should be a simple fix.
We're in the midst of a large refactor of SCT in adrhill/SparseConnectivityTracer.jl#65, so this might take a couple of days to trickle down to a tagged DI release.

@adrhill
Collaborator

adrhill commented May 17, 2024

I just tagged SparseConnectivityTracer v0.4.0. Let me know whether this fixes the issue for you.

@ElOceanografo
Author

Thanks for the suggestions!

Maybe this problem is more relevant to SparseConnectivityTracer. @adrhill, is it your goal for SCT to be able to trace through external calls to BLAS/LAPACK/SuiteSparse/etc.? If so, that's fantastic, though based on my experience with the rest of the AD ecosystem, I imagine that's no small task. Even if that capability is expected in SCT, I do think it would be great to have a fallback DenseSparsityDetector available here or in ADTypes to handle the inevitable corner cases.

I am using sparse Hessians to integrate out random effects in large statistical models via the Laplace approximation, in MarginalLogDensities. Basically, I'm trying to create a Julia implementation of Template Model Builder (TMB; R package, paper). I've been working on it on and off for a few years, but I've never had the time (or, frankly, the AD expertise) to get performance on par with TMB. The likelihood functions I'm interested in differentiating have a structure similar to my MWE above, except Q is a much bigger unstructured sparse matrix.
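The Laplace approximation replaces the integral over the random effects u by a Gaussian integral around the inner optimum û: log ∫ exp(g(u)) du ≈ g(û) + (n/2) log(2π) - (1/2) log det(-H(û)), where H is the Hessian of g. The expensive parts are therefore Newton steps and the log-determinant of a large, sparse Hessian. The following minimal sketch uses only the standard library (sparse Cholesky from SparseArrays); the function `laplace_logmarginal` and its `grad`/`hess` arguments are hypothetical and not the MarginalLogDensities API, and it assumes g is strictly concave so that -H is positive definite.

```julia
using LinearAlgebra, SparseArrays

# Hypothetical sketch: Laplace approximation of log ∫ exp(g(u)) du, assuming g is
# strictly concave and the caller supplies its gradient and *sparse* Hessian.
function laplace_logmarginal(g, grad, hess, u0; maxiter = 50, tol = 1e-10)
    u = copy(u0)
    for _ in 1:maxiter
        # Newton step: solve (-H) Δ = ∇g with a sparse Cholesky factorization
        Δ = cholesky(Symmetric(-hess(u))) \ grad(u)
        u += Δ
        norm(Δ) < tol && break
    end
    F = cholesky(Symmetric(-hess(u)))   # factor -H at the optimum û
    n = length(u)
    # g(û) + (n/2) log(2π) - (1/2) log det(-H(û))
    return g(u) + n / 2 * log(2π) - logdet(F) / 2, u
end

# Usage on a Gaussian g, where the Laplace approximation is exact
n = 8
Q = spdiagm(-1 => fill(-1.0, n - 1), 0 => fill(2.0, n), 1 => fill(-1.0, n - 1))
b = ones(n)
g(u) = -dot(u, Q * u) / 2 + dot(b, u)
grad(u) = b - Q * u
hess(u) = -Q
val, uhat = laplace_logmarginal(g, grad, hess, zeros(n))

# Closed form: log ∫ exp(g) = b'Q⁻¹b / 2 + (n/2) log(2π) - log det(Q) / 2
exact = dot(b, Matrix(Q) \ b) / 2 + n / 2 * log(2π) - logdet(Matrix(Q)) / 2
```

With a Gaussian g the approximation is exact, which makes it a convenient correctness check; for a real model, `hess` would be the sparse Hessian obtained through automatic sparsity detection and coloring.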

Seeing DifferentiationInterface and ADTypes 1.0 come together over the past couple of months has been really exciting, since they solve a lot of the design problems related to sparse Hessians and the various AD backends that I haven't been able to figure out in an elegant way myself. So thank you! Once these packages stabilize a bit I plan to do a significant refactor of MarginalLogDensities to take advantage of them. If you want some applications for a paper, I'd be happy to chat.

@ElOceanografo
Author

> I just tagged SparseConnectivityTracer v0.4.0. Let me know whether this fixes the issue for you.

Unfortunately, it's still not working for me. I followed up on adrhill/SparseConnectivityTracer.jl#68

@gdalle
Member

gdalle commented May 18, 2024

Thanks for the additional details!

I'll add DenseSparsityDetector to DifferentiationInterface next week since that's where it naturally belongs.

How cool that your application is Laplace approximation, I just added it to the list of motivations for our paper 2 days ago ;) Do you have any insights on the sparse autodiff capabilities of R? Is TMB something I should add to our SoTA overview? Are there other such packages?

@ElOceanografo
Author

As far as I know, there isn't any other software package, in R or elsewhere, that is quite comparable to TMB. It lets the user supply an arbitrary likelihood function and specify which parameters should be integrated out via the Laplace approximation, doing sparsity detection for the Hessian automatically. It also implements a derivative rule for the Laplace approximation itself, to enable gradient-based optimization of the parameters that are not marginalized. TMB has enabled a lot of advances in my field (quantitative ecology/fisheries science/stock assessment) over the past 5-10 years, particularly in fitting hierarchical spatio-temporal models.

However, all the AD and sparsity detection is implemented in C++, not R, so to define your model you need to write it in C++, which is a significant barrier for the mostly-R-using target audience. ADMB is an older C++ package for AD-based optimization, minus the Laplace approximations. And Stan is the other big way to leverage autodiff in R, though that also requires learning a variant of C++ to build your models. I also just came across the autodiffr package, but it seems to be a wrapper for Julia's ForwardDiff and ReverseDiff ;)

Basically, the autodiff (let alone sparse autodiff) capabilities of R all rely on other languages, particularly C++. It's a classic two-language problem that Julia should be able to solve. Fast, robust sparsity detection and forward-over-reverse sparse Hessians are two of the remaining pieces needed to replicate TMB, so I'm hopeful it can be done in the near future...

A few older threads from Discourse:
https://discourse.julialang.org/t/is-there-any-glmmtmb-package-for-julia/19868
https://discourse.julialang.org/t/laplace-approximation-for-mixed-models-in-julia-tmb-like-functionality/61796
https://discourse.julialang.org/t/ann-announcing-marginallogdensities-jl/99958

This has drifted pretty far from the original issue at this point; if you've got more questions on this, feel free to message me on Slack!

@gdalle
Member

gdalle commented May 18, 2024

Sorry for the detour, but thanks a lot; this is very useful information for positioning our paper!
Leaving this issue open until I implement the DenseSparsityDetector.

@gdalle gdalle added the core Related to the core utilities of the package label May 27, 2024
@gdalle gdalle linked a pull request Jun 4, 2024 that will close this issue
3 participants