Computational time for Brownian Interval #109

Open
qsh-zh opened this issue Jan 6, 2022 · 2 comments

qsh-zh commented Jan 6, 2022

I observed something strange about the computation time for the Brownian interval.
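For reference, the setup below is a minimal sketch along the lines of the Quick example in the README, with a hypothetical cnt counter added (where exactly the counter is incremented is my assumption); sde, y0, and th in the snippets refer to the following:

import torch as th
import torchsde

batch_size, state_size, brownian_size = 32, 3, 2

class SDE(th.nn.Module):
    # General-noise Ito SDE, as in the torchsde README quick example.
    noise_type = "general"
    sde_type = "ito"

    def __init__(self):
        super().__init__()
        self.mu = th.nn.Linear(state_size, state_size)
        self.sigma = th.nn.Linear(state_size, state_size * brownian_size)
        self.cnt = 0  # illustrative counter for vector-field evaluations

    def f(self, t, y):
        # Drift; output shape (batch_size, state_size).
        self.cnt += 1
        return self.mu(y)

    def g(self, t, y):
        # Diffusion; output shape (batch_size, state_size, brownian_size).
        return self.sigma(y).view(batch_size, state_size, brownian_size)

sde = SDE()
y0 = th.full((batch_size, state_size), 0.1)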

sde.cnt = 0
%timeit -n 2 -r 7 torchsde.sdeint(sde, y0, th.tensor([0.0, 1.0]), dt=0.01, method="euler")
print(sde.cnt)
# 1.87 s ± 60.6 ms per loop (mean ± std. dev. of 7 runs, 2 loops each)
# 1428

sde.cnt = 0
%timeit -n 2 -r 7 torchsde.sdeint(sde, y0, th.tensor([0.0, 5.0]), dt=0.05, method="euler")
print(sde.cnt)
# 57.3 ms ± 1.12 ms per loop (mean ± std. dev. of 7 runs, 2 loops each)
# 1414

sde.cnt = 0
%timeit -n 2 -r 7 torchsde.sdeint(sde, y0, th.tensor([0.0, 10.0]), dt=0.1, method="euler")
print(sde.cnt)
# 57.2 ms ± 1.05 ms per loop (mean ± std. dev. of 7 runs, 2 loops each)
# 1414

where sde is very similar to the one defined in the Quick example in the README. Across the three examples I scale ts and dt together, so I would expect them to take roughly the same computation time, but the times turn out to be very different. According to the paper, the worst case should be roughly O(log(T/dt)) if I understand correctly. Why is the first case so slow?
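(To spell out the expectation: all three runs take the same number of steps, T/dt = 100, so a per-query cost of O(log(T/dt)) should be identical across them:)

import math
for T, dt in [(1.0, 0.01), (5.0, 0.05), (10.0, 0.1)]:
    print(f"T={T}, dt={dt}: T/dt ~ {T/dt:.0f} steps, log2(T/dt) ~ {math.log2(T/dt):.2f}")
# Each case gives ~100 steps and log2(T/dt) ~ 6.64, so the expected costs match.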


qsh-zh commented Jan 6, 2022

When I change from th.tensor([0.0, 1.0]), dt=0.01 to th.tensor([0.0, 5.0]), dt=0.05, the difference is clearly visible in the attached screenshots.

[two screenshots attached in the original issue]


patrick-kidger commented Jan 6, 2022

So this is a bit weird! In particular because one tends to imagine a BrownianInterval being scale-invariant. I've been able to figure out why this happens, but I don't have a fix in mind yet.

The reason has to do with the binary tree heuristic built into the BrownianInterval. Once more than 100 queries have been made, the BrownianInterval averages the step sizes it has observed over those queries, and uses that average as an estimate of the step size for the rest of the SDE solve. This estimate is used to build up a binary tree, as per Section E.2, "Backward pass", of the paper. (Which I refer to since it sounds like you've read it.)
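As a rough sketch of the idea (illustrative only, not torchsde's actual code): estimate the average step size from the queries seen so far, then recursively bisect the solve interval down to roughly that resolution:

def build_tree(t0, t1, avg_step):
    # Recursively bisect [t0, t1] until each leaf is about avg_step wide.
    if t1 - t0 <= avg_step:
        return (t0, t1)
    tm = (t0 + t1) / 2
    return (build_tree(t0, tm, avg_step), build_tree(tm, t1, avg_step))

observed_steps = [0.01] * 101          # step sizes seen over the first >100 queries
avg_step = sum(observed_steps) / len(observed_steps)
tree = build_tree(0.0, 1.0, avg_step)  # dyadic tree over [t0, t1] = [0.0, 1.0]

Building such a tree touches on the order of (t1 - t0) / avg_step nodes, so it is not free.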

The dt=0.01 case makes 101 steps, which triggers the calculation of this heuristic. Evaluating that heuristic (building up the binary tree) is what takes up so much time. The other cases make only 100 steps, and the solve actually completes before the heuristic even triggers.

Why the difference in the number of steps between these apparently scale-invariant cases? Floating-point inaccuracies:

>>> import torch
>>> x = torch.tensor(0.01)
>>> sum = 0
>>> for _ in range(100):
...     sum = sum + x
...
>>> sum
tensor(1.0000)
>>> sum < 1
tensor(True)
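So in float32 the partial sums of dt=0.01 land just below t1 = 1.0 after 100 steps, and a fixed-step loop has to take a 101st step to reach t1. A toy stepping loop (my sketch, assuming time is accumulated as t + dt in float32; torchsde's internals may differ) shows the same effect:

import torch

def count_steps(t1, dt):
    # Step t forward by dt in float32 until t1 is reached.
    t = torch.tensor(0.0)
    dt = torch.tensor(dt)
    n = 0
    while t < t1:
        t = t + dt
        n += 1
    return n

print(count_steps(1.0, 0.01))  # 101: t is still just below 1.0 after 100 steps

whereas the dt=0.05 and dt=0.1 solves happen to stop at exactly 100 steps, below the heuristic's threshold.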

The real bug here is simply that the heuristic takes so much time to compute. I'll need to take a deeper look later to figure out what might be done to resolve this.
