dg_ga_jvp_column_sum_v2 can break #46

Closed
patrick-kidger opened this issue Sep 4, 2020 · 7 comments

@patrick-kidger
Collaborator

dg_ga_jvp_column_sum_v2 assumes that the batch dimension can be modified, but something like

z = ...  # static tensor of shape (batch, channel)

def g(self, t, y):
    # y must have the same shape as z for the stack to succeed
    return torch.stack([z, y], dim=2)

will then throw an error because of the different batch dimensions. (This being a scenario that I've actually run into.)
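To make the failure concrete, here is a minimal sketch of the shape mismatch. It assumes that v2 evaluates g on a state whose batch dimension has been tiled m times, while the captured z keeps its original batch size. The names batch, channel, m and y_tiled are illustrative and not taken from the library.

```python
import torch

batch, channel, m = 4, 3, 2  # illustrative sizes; m stands in for the assumed tiling factor
z = torch.randn(batch, channel)  # static context captured outside g


def g(t, y):
    # Works only while y and z have the same shape.
    return torch.stack([z, y], dim=2)


y = torch.randn(batch, channel)
out = g(0.0, y)  # shape (batch, channel, 2)

# Assumed v2-style call: the batch dimension of y is tiled, but z is not.
y_tiled = y.repeat(m, 1)  # shape (batch * m, channel)
try:
    g(0.0, y_tiled)
    mismatch_error = None
except RuntimeError as err:
    # torch.stack refuses tensors of different sizes.
    mismatch_error = err
```

The first call succeeds because y and z agree on the batch size; the second raises because only y was enlarged.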

Incidentally I'm also concerned (but have not tested) that multiplying the batch dimension will put a peak in our memory usage. Having a dramatically non-rectangular memory profile could reduce the maximum possible batch size and thus slow things down overall.

I'm wondering about switching to v1 for now? At least on test_strat.py I don't find it that much slower. (And as per above, may even be faster across a batch.)
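The trade-off between the two variants can be sketched with a generic per-column gradient computation. This is not the library's dg_ga_jvp_column_sum code; the toy diffusion g, the cotangents a, and the helper names loop_vjps and tiled_vjps are assumptions made for illustration. The loop-based version runs m backward passes at the original batch size, while the tiled version runs one pass over m * B samples, which is where the memory peak would come from.

```python
import torch

torch.manual_seed(0)
B, d, m = 4, 3, 2  # batch size, state size, Brownian dimension (illustrative)
W = torch.randn(d, d * m)


def g(y):
    # Toy per-sample diffusion: (batch, d) -> (batch, d, m).
    return torch.tanh(y @ W).reshape(y.shape[0], d, m)


y = torch.randn(B, d)
a = torch.randn(B, d, m)  # one cotangent per column of g


def loop_vjps(y, a):
    # Loop-based: m backward passes, each at the original batch size.
    y = y.detach().requires_grad_()
    out = g(y)
    grads = []
    for j in range(m):
        loss_j = (out[:, :, j] * a[:, :, j]).sum()
        (grad_j,) = torch.autograd.grad(loss_j, y, retain_graph=True)
        grads.append(grad_j)
    return torch.stack(grads)  # (m, B, d)


def tiled_vjps(y, a):
    # Batch-tiled: one backward pass, but activations for m * B samples at once.
    y_tiled = y.detach().repeat(m, 1).requires_grad_()  # (m * B, d)
    out = g(y_tiled).reshape(m, B, d, m)
    # Copy j of the batch only contributes through column j of g.
    picked = torch.stack([out[j, :, :, j] for j in range(m)])  # (m, B, d)
    (grad,) = torch.autograd.grad((picked * a.permute(2, 0, 1)).sum(), y_tiled)
    return grad.reshape(m, B, d)


same = torch.allclose(loop_vjps(y, a), tiled_vjps(y, a), atol=1e-5)
```

The two agree here only because g treats each sample independently and captures nothing with a fixed batch size, which is exactly the assumption the example above violates.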

@lxuechen lxuechen self-assigned this Sep 4, 2020
@lxuechen
Collaborator

lxuechen commented Sep 4, 2020

The example seems interesting. I will make v1 the default for now. If you could provide some more detail on how z is passed in, then I might be able to improve the usage of v2. This seems related to something I remember running into when doing latent SDEs, where I had to "contextualize" the SDE based on a representation produced by GRUs.

> Incidentally I'm also concerned (but have not tested) that multiplying the batch dimension will put a peak in our memory usage.

This could be true when the Brownian motion dimension is large. That said, with adjoints the issue might be less prominent, as long as the same model without this term could already be fit with backprop through the solver.

lxuechen added a commit that referenced this issue Sep 4, 2020
* Add log-ODE scheme and simplify typing.

* Register log-ODE method.

* Refactor diagnostics and examples.

* Refactor plotting.

* Move btree profile to benchmarks.

* Refactor all ito diagnostics.

* Refactor.

* Split imports.

* Refactor the Stratonovich diagnostics.

* Fix documentation.

* Minor typing fix.

* Remove redundant imports.

* Fixes from comment.

* Simplify.

* Simplify.

* Fix typo caused bug.

* Fix directory issue.

* Fix order issue.

* Change back weak order.

* Fix test problem.

* Add weak order inspection.

* Bugfixes for log-ODE (#45)

* fixed rate diagnostics

* tweak

* adjusted test_strat

* fixed logODE default.

* Fix typo.

Co-authored-by: Xuechen Li <12689993+lxuechen@users.noreply.github.com>

* Default to loop-based. Fixes #46.

* Minor tweak of settings.

* Fix directory structure.

* Speed up experiments.

* Cycle through the possible line styles.

Co-authored-by: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com>
@lxuechen lxuechen closed this as completed Sep 4, 2020
@patrick-kidger
Collaborator Author

The context is that z is some additional static (not time evolving) information that is passed as additional information to the drift and diffusion.

The way I'm doing this is a bit ugly:

class SDE(torch.nn.Module):
    sde_type = ...
    noise_type = ...

    def set_data(self, z):
        # Stash the static context so that f and g can read it.
        self._z = z

    def f(self, t, y):
        # use both y and self._z
        ...

    ...

def somefunction(sde: SDE):
    sde.set_data(z)
    torchsde.sdeint(sde, ...)

I'm aware that z could be included in the state with zero drift/diffusion but that's even uglier IMO. (+inefficient)

Thinking about it, we could perhaps include an additional argument to sdeint, sdeint_adjoint corresponding to such static information? This would neaten the above code a lot. (And allow for v2 if we do want it over v1.)
Additionally, the above code can't reset z after calling sdeint because it still needs to be there for the backward pass; if we instead capture it as an argument then that's another wart removed.

Obviously that is departing a little further from our basic duties of solving an SDE, but I'd be happy to offer a PR on that if you're interested.
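Until such an argument exists, one way to avoid the set_data step is to bind z when the SDE object is built. The following is only a sketch: ContextualSDE, drift_net, diffusion_net and the sizes are hypothetical and not part of torchsde, and the solver call is left out so that the snippet runs with PyTorch alone.

```python
import torch


class ContextualSDE(torch.nn.Module):
    # Hypothetical wrapper (not part of torchsde): binds static context z at construction.
    sde_type = "ito"
    noise_type = "diagonal"

    def __init__(self, drift_net, diffusion_net, z):
        super().__init__()
        self.drift_net = drift_net
        self.diffusion_net = diffusion_net
        self.z = z  # plain attribute, kept alive for as long as the wrapper is

    def f(self, t, y):
        # Drift sees both the state and the static context.
        return self.drift_net(torch.cat([y, self.z], dim=-1))

    def g(self, t, y):
        # Diagonal diffusion, same shape as y.
        return self.diffusion_net(torch.cat([y, self.z], dim=-1))


batch, d, d_ctx = 4, 3, 2  # illustrative sizes
drift_net = torch.nn.Linear(d + d_ctx, d)
diffusion_net = torch.nn.Linear(d + d_ctx, d)


def somefunction(z, y0):
    # A fresh wrapper per call, so there is nothing to reset afterwards.
    sde = ContextualSDE(drift_net, diffusion_net, z)
    # In real use, sde would be handed to the solver at this point.
    return sde.f(0.0, y0), sde.g(0.0, y0)


z = torch.randn(batch, d_ctx, requires_grad=True)
drift, diffusion = somefunction(z, torch.randn(batch, d))
```

This keeps the calling code tidy, but z still carries a fixed batch dimension, so it does not by itself make a batch-tiling approach such as v2 safe.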

@lxuechen
Collaborator

lxuechen commented Sep 4, 2020

> Thinking about it, we could perhaps include an additional argument to sdeint, sdeint_adjoint corresponding to such static information? This would neaten the above code a lot. (And allow for v2 if we do want it over v1.)

Now that I'm starting to remember the hairy issues with latent SDE contextualization, this really makes sense. Especially when using adjoints, the example you presented poses an additional challenge: the gradients w.r.t. z won't be recorded at all. Back in the day, I hacked the solver to make this work.

Off the top of my head, a potential modification to fix this would be to allow sdeint and sdeint_adjoint to take in additional_ys and additional_params. More explicitly, something like

sde = ...
additional_ys = ...
additional_params = ...
ys = sdeint(sde, y0, ts, bm, additional_ys=additional_ys, additional_params=additional_params)
ys_from_adjoint = sdeint_adjoint(sde, y0, ts, bm, additional_ys=additional_ys, additional_params=additional_params)

The only thing that I'm not too certain about is the format of additional_ys. Having it be a tuple of tensors of size (batch_size, d') makes sense. Though, it would be more useful if it could take in tensors of size (T, batch_size, d') (or (T - 1, batch_size, d')).
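One possible reading of the two formats is sketched below. Everything here is an assumption for illustration: the helper context_at, the rule that a 2-D tensor is static, and the piecewise-constant lookup over ts are not part of any proposed API.

```python
import bisect

import torch


def context_at(t, ts, additional_y):
    # Hypothetical helper: return the context in effect at time t.
    if additional_y.dim() == 2:
        # Static case: (batch_size, d') is the same at every time.
        return additional_y
    # Time-indexed case: (T, batch_size, d') or (T - 1, batch_size, d'),
    # treated as piecewise constant on [ts[i], ts[i + 1]).
    i = bisect.bisect_right(ts, t) - 1
    i = min(max(i, 0), additional_y.shape[0] - 1)
    return additional_y[i]


ts = [0.0, 0.5, 1.0]
T, batch_size, d_ctx = len(ts), 4, 2  # illustrative sizes
static_ctx = torch.randn(batch_size, d_ctx)
# Slice i is filled with the value i so that the lookup is easy to inspect.
timed_ctx = torch.arange(T, dtype=torch.float32).reshape(T, 1, 1).expand(T, batch_size, d_ctx)

ctx_mid = context_at(0.7, ts, timed_ctx)  # falls in [0.5, 1.0), so slice 1
ctx_end = context_at(1.0, ts, timed_ctx)  # last slice
```

The clamp on the index is what lets the same lookup accept a tensor with T - 1 slices.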

@patrick-kidger
Collaborator Author

You're thinking that additional_ys represents this additional static state, and whilst we're at it we could add additional_params to augment SDE.parameters() for the adjoint?

If so I'd note that additional_params would only be needed in the adjoint case. We could follow torchdiffeq for consistency on this - there we called it adjoint_params, and if passed then it is used instead of the parameters of the vector field, rather than as well.

On the format of additional_ys: I'm quite keen to avoid explicitly encoding a single batch dimension.
I'd suggest essentially following what autograd.Function does on this: accept a tuple of Python objects; and if they're gradient-requiring tensors then compute gradients wrt them. Allow tensors to be of any shape.
This does mean that we can't really use v2, as we don't expect to have access to a batch dimension, but I think this kind of batch dimension hacking is quite fragile to the variety of things a user can throw at it anyway.
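As a rough illustration of the autograd.Function convention being referred to, the toy function below takes extra positional arguments of any type, ignores the non-tensors, and returns gradients only for tensors that require them, whatever their shape. AddContext is a made-up example and says nothing about how the solver would actually use such arguments.

```python
import torch


class AddContext(torch.autograd.Function):
    # Toy op: adds the sum of every tensor in `extras` to y; non-tensors are ignored.

    @staticmethod
    def forward(ctx, y, *extras):
        # Remember the shape of each tensor argument; None marks a non-tensor.
        ctx.extra_shapes = [e.shape if isinstance(e, torch.Tensor) else None for e in extras]
        total = sum(e.sum() for e in extras if isinstance(e, torch.Tensor))
        return y + total

    @staticmethod
    def backward(ctx, grad_out):
        grads = [grad_out]  # d(out)/dy is the identity
        for i, shape in enumerate(ctx.extra_shapes, start=1):
            if shape is not None and ctx.needs_input_grad[i]:
                # Every output element depends on every element of this tensor with weight 1.
                grads.append(grad_out.new_ones(shape) * grad_out.sum())
            else:
                grads.append(None)
        return tuple(grads)


y = torch.randn(4, 3, requires_grad=True)
z = torch.randn(2, 5, requires_grad=True)  # gradient-requiring tensor of unrelated shape
w = torch.randn(7)  # tensor that does not require gradients
out = AddContext.apply(y, z, "not a tensor", w, 3)
out.sum().backward()
```

Note that z has no batch dimension in common with y, which is the point of not encoding one.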

For speeding up v1, there is this: pytorch/pytorch#42368 which mentions the possibility of a torch.vmap, in particular with a view to batch-vjps. I don't know the state of it but it might be interesting to us.

@patrick-kidger
Collaborator Author

Actually, thinking about it: with the above proposal we wouldn't need an adjoint_params. Whatever extra tensors we need gradients with respect to can just be included in additional_ys and ignored in the drift/diffusion.

@lxuechen
Collaborator

lxuechen commented Sep 4, 2020

Taking a step back, I think having sdeint take in additional_ys is likely going to overcomplicate the solver code. I'm not too inclined to do this at the moment.

I do feel a need to support back-propagating gradients towards non-parameter nodes with adjoints. I am fully aware of adjoint_params in torchdiffeq, and I can send in a PR on this.
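The "instead of, rather than as well as" convention mentioned earlier in the thread could look roughly like the sketch below. resolve_adjoint_params is a hypothetical helper written for illustration, not code from torchsde or torchdiffeq, and a Linear layer stands in for the SDE module.

```python
import torch


def resolve_adjoint_params(sde, adjoint_params=None):
    # Hypothetical helper: decide which tensors the adjoint pass computes gradients for.
    if adjoint_params is None:
        # Default: the parameters of the module itself.
        adjoint_params = tuple(sde.parameters())
    else:
        # If given, used instead of the module parameters, not in addition to them.
        adjoint_params = tuple(adjoint_params)
    # Tensors that do not require gradients have nothing to accumulate.
    return tuple(p for p in adjoint_params if p.requires_grad)


sde = torch.nn.Linear(3, 3)  # stand-in for an SDE module
z = torch.randn(4, 2, requires_grad=True)  # non-parameter node, e.g. static context

default = resolve_adjoint_params(sde)
explicit = resolve_adjoint_params(sde, (z,) + tuple(sde.parameters()))
```

Under this convention a caller who wants gradients for z and for the module has to pass both explicitly, as in the second call.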

@lxuechen
Collaborator

lxuechen commented Sep 4, 2020

Re: torch.vmap

I'm not entirely sure this will make our lives easier. Given that there's not much documentation on what's going on there, much of this discussion seems rather like speculation in my opinion.

lxuechen added a commit that referenced this issue Oct 22, 2020
* Added BrownianInterval

* Unified base solvers

* Updated solvers to use interval interface

* Unified solvers.

* Required Python 3.6. Bumped version.

* Updated benchmarks. Fixed several bugs.

* Tweaked BrownianInterval to accept queries outside its range

* tweaked benchmark

* Added midpoint. Tweaked and fixed things.

* Tided up adjoint.

* Bye bye Python 2; fixes #14.

* tweaks from feedback

* Fix typing.

* changed version

* Rename.

* refactored settings up a level. Fixed bug in BTree.

* fixed bug with non-srk methods

* Fixed? srk noise

* Fixed SRK properly, hopefully

* fixed mistake in adjoint

* Fix broken tests and refresh documentation.

* Output type annotation.

* Rename to reverse bm.

* Fix typo.

* minor refactors in response to feedback

* Tided solvers a little further.

* Fixed strong order for midpoint

* removed unused code

* Dev kidger2 (#19)

* Many fixes.

Updated diagnostics.
Removed trapezoidal_approx
Fixed error messages for wrong methods etc.
Can now call BrownianInterval with a single value.
Fixed bug in BrownianInterval that it was always returning 0!
There's now a 2-part way of getting levy area: it has to be set as
available during __init__, and then specified that it's wanted during
__call__. This allows one to create general Brownian motions that can be
used in multiple solvers, and have each solver call just the bits it
wants.
Bugfix spacetime -> space-time
Improved checking in check_contract
Various minor tidy-ups
Can use Brownian* with any Levy area sufficient for the solver, rather
than just the minimum the solver needs.
Fixed using bm=None in sdeint and sdeint_adjoint, so that it creates an
appropriate BrownianInterval. This also makes method='srk' easy.

* Fixed ReverseBrownian

* bugfix for midpoint

* Tided base SDE classes slightly.

* spacetime->space-time; small tidy up; fix latent_sde.py example

* Add efficient gdg_jvp term for log-ODE schemes. (#20)

* Add efficient jvp for general noise and refactor surrounding.

* Add test for gdg_jvp.

* Simplify requires grad logic.

* Add more rigorous numerical tests.

* Fix all issues

* Heun's method (#24)

* Implemented Heun method

* Refactor after review

* Added docstring

* Updated heun docstring

* BrownianInterval tests + bugfixes (#28)

* In progress commit on branch dev-kidger3.

* Added tests+bugfixes for BrownianInterval

* fixed typo in docstring

* Corrections from review

* Refactor tests for

* Refactor tests for BrownianInterval.

* Refactor tests for Brownian path and Brownian tree.

* use default CPU

* Remove loop.

Co-authored-by: Xuechen Li <12689993+lxuechen@users.noreply.github.com>

* bumped numpy version (#32)

* Milstein (Strat), Milstein grad-free (Ito + Strat) (#31)

* Added milstein_grad_free, milstein_strat and milstein_strat_grad_free

* Refactor after first review

* Changes after second review

* Formatted imports

* Changed used Ex. Reversed g_prod

* Add support for Stratonovich adjoint (#21)

* Add efficient jvp for general noise and refactor surrounding.

* Add test for gdg_jvp.

* Simplify requires grad logic.

* Add more rigorous numerical tests.

* Minor refactor.

* Simplify adjoints.

* Add general noise version.

* Refactor adjoint code.

* Fix new interface.

* Add adjoint method checking.

* Fix bug in not indexing the dict.

* Fix broken tests for sdeint.

* Fix bug in selection.

* Fix flatten bug in adjoint.

* Fix zero filling bug in jvp.

* Fix bug.

* Refactor.

* Simplify tuple logic in modified Brownian.

* Remove np.searchsorted in BrownianPath.

* Make init more consistent.

* Replace np.searchsorted with bisect for speed; also fixes #29.

* Prepare levy area support for BrownianPath.

* Use torch.Generator to move to torch 1.6.0.

* Prepare space-time Levy area support for BrownianPath.

* Enable all levy area approximations for BrownianPath.

* Fix for test_sdeint.

* Fix all broken tests; all tests pass.

* Add numerical test for gradient using midpoint method for Strat.

* Support float/int time list.

* Fixes from comments.

* Additional fixes from comments.

* Fix documentation.

* Remove to for BrownianPath.

* Fix insert.

* Use none for default levy area.

* Refactor check tensor info to reduce boilerplate.

* Add a todo regarding get noise.

* Remove type checks in adjoint.

* Fixes from comments.

* Added BrownianReturn (#34)

* Added BrownianReturn

* Update utils.py

* Binterval improvements (#35)

* Tweaked to not hang on adaptive solvers.

* Updated adaptive fix

* Several fixes for tests and adjoint.

Removed some broken tests.
Added error-raising `g` to the adjoint SDE.
Fixed Milstein for adjoint.
Fixed running adjoint at all.

* fixed bug in SRK

* tided up BInterval

* variable name tweak

* Improved heuristic for BrownianInterval's dependency tree. (#40)

* [On dev branch] Tuple rewrite (#37)

* Rename plot folders from diagnostics.

* Complete tuple rewrite.

* Remove inaccurate comments.

* Minor fixes.

* Fixes.

* Remove comment.

* Fix docstring.

* Fix noise type for problem.

* Binterval recursion fix (#42)

* Improved heuristic for BrownianInterval's dependency tree.

* Inlined the recursive code to reduce number of stack frames

* Add version number.

Co-authored-by: Xuechen <12689993+lxuechen@users.noreply.github.com>

* Refactor.

* Euler-Heun method (#39)

* Implemented euler-heun

* After refactor

* Applied refactor. Added more diagnostics

* Refactor after review

* Corrected order

* Formatting

* Formatting

* BInterval - U fix (#44)

* Improved heuristic for BrownianInterval's dependency tree.

* fixed H aggregation

* Added consistency test

* test fixes

* put seed back

* from comments

* Add log-ODE scheme and simplify typing. (#43)

* Add log-ODE scheme and simplify typing.

* Register log-ODE method.

* Refactor diagnostics and examples.

* Refactor plotting.

* Move btree profile to benchmarks.

* Refactor all ito diagnostics.

* Refactor.

* Split imports.

* Refactor the Stratonovich diagnostics.

* Fix documentation.

* Minor typing fix.

* Remove redundant imports.

* Fixes from comment.

* Simplify.

* Simplify.

* Fix typo caused bug.

* Fix directory issue.

* Fix order issue.

* Change back weak order.

* Fix test problem.

* Add weak order inspection.

* Bugfixes for log-ODE (#45)

* fixed rate diagnostics

* tweak

* adjusted test_strat

* fixed logODE default.

* Fix typo.

Co-authored-by: Xuechen Li <12689993+lxuechen@users.noreply.github.com>

* Default to loop-based. Fixes #46.

* Minor tweak of settings.

* Fix directory structure.

* Speed up experiments.

* Cycle through the possible line styles.

Co-authored-by: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com>

* Simplify and fix documentation.

* Minor fixes.

- Simplify strong order assignment for euler.
- Fix bug with "space_time".

* Simplify strong order assignment for euler.

* Fix bug with space-time naming.

* Make tensors for grad for adjoint specifiable. (#52)

* Copy of #55 | Created pyproject.toml (#56)

* Skip tests if the optional C++ implementations don't compile; fixes #51.

* Create pyproject.toml

* Version add 1.6.0 and up

Co-authored-by: Xuechen <12689993+lxuechen@users.noreply.github.com>

* Latent experiment (#48)

* Latent experiment

* Refactor after review

* Fixed y0

* Added stable div

* Minor refactor

* Simplify latent sde even further.

* Added double adjoint (#49)

* Added double adjoint

* tweaks

* Updated adjoint tests

* Dev adjoint double test (#57)

* Add gradgrad check for adjoint.

* Relax tolerance.

* Refactor numerical tests.

* Remove unused import.

* Fix bug.

* Fixes from comments.

* Rename for consistency.

* Refactor comment.

* Minor refactor.

* Add adjoint support for general/scalar noise in the Ito case. (#58)

* adjusted requires_grad

Co-authored-by: Xuechen Li <12689993+lxuechen@users.noreply.github.com>

* Dev minor (#63)

* Add requirements and update latent sde.

* Fix requirements.

* Fix.

* Update documentation.

* Use split to speed things up slightly.

* Remove jit standalone.

* Enable no value arguments.

* Fix bug in args.

* Dev adjoint strat (#67)

* Remove logqp test.

* Tide examples.

* Refactor to class attribute.

* Fix gradcheck.

* Reenable adjoints.

* Typo.

* Simplify tests

* Deprecate this test.

* Add back f ito and strat.

* Simplify.

* Skip more.

* Simplify.

* Disable adaptive.

* Refactor due to change of problems.

* Reduce problem size to prevent general noise test case run for ever.

* Continuous Integration.  (#68)

* Skip tests if the optional C++ implementations don't compile; fixes #51.

* Continuous integration.

* Fix os.

* Install package before test.

* Add torch to dependency list.

* Reduce trials.

* Restrict max number of parallel runs.

* Add scipy.

* Fixes from comment.

* Reduce frequency.

* Fixes.

* Make sure run installed package.

* Add check version on pr towards master.

* Separate with blank lines.

* Loosen tolerance.

* Add badge.

* Brownian unification (#61)

* Added tol. Reduced number of generator creations. Spawn keys now of
finite length. Tidied code.

* Added BrownianPath and BrownianTree as BrownianInterval wrappers

* added trampolining

* Made Path and Tree wrappers on Interval.

* Updated tests. Fixed BrownianTree determinism. Allowed cache_size=0

* done benchmarks. Fixed adjoint bug. Removed C++ from setup.py

* fixes for benchmark

* added base brownian

* BrownianPath/Tree now with the same interface as before

* BInterval(shape->size), changed BPath and BTree to composition-over-inheritance.

* tweaks

* Fixes for CI. (#69)

* Fixes for CI.

* Tweaks to support windows.

* Patch for windows.

* Update patch for windows.

* Fix flaky tests of BInterval.

* Add fail-fast: false (#72)

* Dev methods fixes (#73)

* Fixed adaptivity checks. Improved default method selection.

* Fixes+updated sdeint tests

* adjoint method fixes

* Fixed for Py3.6

* assert->ValueError; tweaks

* Dev logqp (#75)

* Simplify.

* Add stable div utility.

* Deprecate.

* Refactor problems.

* Sync adjoint tests.

* Fix style.

* Fix import style.

* Add h to test problems.

* Add logqp.

* Logqp backwards compatibility.

* Add type annotation.

* Better documentation.

* Fixes.

* Fix notebook. (#74)

* Fix notebook.

* Remove trivial stuff.

* Fixes from comments.

* Fixes.

* More fixes.

* Outputs.

* Clean up.

* Fixes.

* fixed BInterval flakiness+slowness (#76)

* Added documentation (#71)

* Added documentation

* tweaks

* Fix significance level.

* Fix check version.

* Skip confirmation.

* Fix indentation errors.

* Update README.md

Co-authored-by: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com>
Co-authored-by: Mateusz Sokół <8431159+mtsokol@users.noreply.github.com>
Co-authored-by: Sayantan Das <36279638+ucalyptus@users.noreply.github.com>