Add a LogDensityProblemAD extension so we can support Turing #95

Let's create a `LogDensityProblemAD` extension within `Taped.jl`. That should enable us to play with more Turing models before this package gets officially registered. See, e.g.: https://github.com/tpapp/LogDensityProblemsAD.jl/blob/master/ext/LogDensityProblemsADZygoteExt.jl
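For concreteness, here is a hedged sketch of what such an extension might look like, modelled on the linked Zygote extension. The module and struct names are mine, and the specific Tapir calls (`Tapir.build_rrule`, `Tapir.value_and_gradient!!`) and their return shapes are assumptions based on Tapir's documented interface, not the implementation that eventually landed:

```julia
# Hedged sketch of the proposed extension, modelled on the linked Zygote
# extension. Module/struct names and the Tapir calls (`build_rrule`,
# `value_and_gradient!!`) are assumptions, not the code that finally landed.
module LogDensityProblemsADTapirExt

import LogDensityProblems
import LogDensityProblemsAD
import Tapir

# Wrap the target log-density together with a precompiled Tapir rule.
struct TapirGradientLogDensity{L,R} <: LogDensityProblemsAD.ADGradientWrapper
    ℓ::L
    rule::R
end

# `ADgradient(:Tapir, ℓ)` forwards to `ADgradient(Val(:Tapir), ℓ)`, so this
# method is the extension's entry point.
function LogDensityProblemsAD.ADgradient(::Val{:Tapir}, ℓ)
    x = zeros(LogDensityProblems.dimension(ℓ))
    rule = Tapir.build_rrule(LogDensityProblems.logdensity, ℓ, x)
    return TapirGradientLogDensity(ℓ, rule)
end

function LogDensityProblems.logdensity_and_gradient(∇ℓ::TapirGradientLogDensity, x::AbstractVector)
    # `value_and_gradient!!` returns the value plus one gradient per argument;
    # the final entry is the gradient with respect to `x`.
    y, (_, _, gx) = Tapir.value_and_gradient!!(∇ℓ.rule, LogDensityProblems.logdensity, ∇ℓ.ℓ, x)
    return y, gx
end

end # module
```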
Fixed by #123
This actually isn't quite done -- we still need to specify how interface
This should be fixed in Turing
@yebai I think this is now sorted on Tapir v0.2.3. @yebai @torfjelde is there any reason to add an integration test that runs more of the Turing.jl pipeline than I'm currently doing in my integration tests? e.g. run sampling on a model using the interface that we expect users to play with?
I don't think we need to run sampling here, but it would be good to add an additional test for
Here is an example forcing Turing to use Tapir:

```julia
using Turing, AbstractMCMC, Tapir, ADTypes, LogDensityProblems, LogDensityProblemsAD

@model function demo(x)
    m ~ Normal()
    x ~ Normal(m, 1)
end

function AbstractMCMC.LogDensityModel(m::Turing.DynamicPPL.Model, adtype::ADTypes.AbstractADType)
    f = LogDensityFunction(m, DynamicPPL.SimpleVarInfo(m))
    adf = AbstractMCMC.LogDensityModel(LogDensityProblemsAD.ADgradient(adtype, f))
    return adf
end

f = AbstractMCMC.LogDensityModel(demo(1.0), AutoTapir());

# Compute the log density and its gradient at a random initial point.
initial_params = rand(LogDensityProblems.dimension(f.logdensity))
LogDensityProblems.logdensity_and_gradient(f.logdensity, initial_params)

# Sample using NUTS from AdvancedHMC.
using AdvancedHMC
n_samples, n_adapts, δ = 1_000, 2_000, 0.8
samples = AbstractMCMC.sample(
    f,
    AdvancedHMC.NUTS(δ),
    n_adapts + n_samples;
    nadapts = n_adapts,
    initial_params = initial_params,
);
```
I'll look at incorporating this in the near future.
I've been meaning to do a LogDensityProblemsAD PR that switches everything to ADTypes + DifferentiationInterface. Maybe this is a sign
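For context, DifferentiationInterface exposes a single generic gradient entry point, with the backend selected via an ADTypes type. A minimal sketch; the `target` function here is a stand-in of my own, not anything from the PR:

```julia
using DifferentiationInterface
using ADTypes: AutoForwardDiff
import ForwardDiff

# Stand-in for a log-density evaluation; purely illustrative.
target(x) = -sum(abs2, x) / 2

backend = AutoForwardDiff()
x = randn(3)

# DI's core call: one generic entry point, backend chosen via ADTypes.
y, g = DifferentiationInterface.value_and_gradient(target, backend, x)
```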
I, for one, would be in favour of having fewer things to maintain.
Well here you go: tpapp/LogDensityProblemsAD.jl#29
If either of you can spare the time to review, it might help Tamas who is not familiar with DI (and me who is not familiar with Turing and its inner workings ^^)
Does Tapir.jl not work with
I've only really tested it with
The PR to LogDensityProblemsAD is nearly complete, if you wanna help
When you're talking about
Even if we allow switching, it's also just a thing where

```julia
@model function demo()
    x = Vector{Float64}(undef, 2)
    x[1:2][:][1] ~ Normal()
    x[2] ~ Normal()
end
```

reconstructing this in a way that is compatible with
Other than the Julia compat issue mentioned by @devmotion, Turing.jl doesn't really need much here I think :) As in, it doesn't "really matter" for Turing.jl what LogDensityProblemsAD.jl uses under the hood, as long as ADTypes.jl is still the way to specify which backend to use 👍
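To make that contract concrete, here is a minimal sketch of the one entry point in question, `ADgradient(::ADTypes.AbstractADType, ℓ)`; the toy `Gauss` target and the choice of `AutoForwardDiff` are illustrative assumptions, not anything from Turing.jl:

```julia
using ADTypes, LogDensityProblems, LogDensityProblemsAD
import ForwardDiff

# A toy log-density target implementing the LogDensityProblems interface.
struct Gauss end
LogDensityProblems.logdensity(::Gauss, x) = -sum(abs2, x) / 2
LogDensityProblems.dimension(::Gauss) = 2
LogDensityProblems.capabilities(::Type{Gauss}) = LogDensityProblems.LogDensityOrder{0}()

# The contract Turing.jl relies on: pick the backend via an ADTypes type.
∇ℓ = LogDensityProblemsAD.ADgradient(AutoForwardDiff(), Gauss())
LogDensityProblems.logdensity_and_gradient(∇ℓ, randn(2))
```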
I know, and I fixed the Julia compat issue (with a lot of blood, sweat and tears). My main challenge now is convincing Turing people that the switch to DI is a good idea, so that it can get merged ^^ And there's a bug on Tracker that I still haven't figured out
Lovely ❤️
You mean the LogDensityProblems "people"? I think from our side, there are no other concerns?
As @torfjelde said, for Turing it does not matter what LogDensityProblemsAD is doing under the hood (if it does not cause any regressions) as long as the API and compatibilities are not broken.
I did my best to keep them working in that PR, so hopefully this won't be a problem
But come to think of it, it would be nice to have downstream tests. Maybe I can open a Turing PR that uses this new branch of LogDensityProblemsAD, just to see what might break?
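Something like the following would do that check locally; the branch name is a made-up placeholder, not the PR's actual branch:

```julia
using Pkg

# Point the active environment at the DI branch of LogDensityProblemsAD
# ("di-rewrite" is a hypothetical branch name used for illustration),
# then run Turing's test suite against it.
Pkg.add(url = "https://github.com/tpapp/LogDensityProblemsAD.jl", rev = "di-rewrite")
Pkg.test("Turing")
```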
@torfjelde @willtebbutt It would be good to focus efforts on making
This is probably best discussed in a separate issue. But I don't think Tapir.jl will work with the
What's the actual reason why Tapir.jl doesn't work with
The current reason (I think) is Tapir's lack of support for
I suspect so -- I'm guessing we don't get type stability with the
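As an aside, one quick way to check for this kind of instability, reusing `f` and `initial_params` from the example earlier in the thread:

```julia
using Test

# Throws an error if the log-density-and-gradient evaluation does not
# infer to a concrete return type.
@inferred LogDensityProblems.logdensity_and_gradient(f.logdensity, initial_params)
```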
Ah gotcha 👍 Lovely :)
Exactly. But once I've reworked some of the internals of
I'm closing this in favour of #132 as I believe we've covered all of the ground that is Tapir.jl-specific. Please do re-open if you think I'm missing something!