Skip to content

Commit

Permalink
fixed BInterval flakiness+slowness (#76)
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Oct 19, 2020
1 parent 34f65e3 commit b591c6a
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions tests/test_brownian_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@
REPS = 2
MEDIUM_REPS = 25
LARGE_REPS = 500
ALPHA = 0.0001
POOL_SIZE = 48
ALPHA = 0.00001

devices = [cpu, gpu] = [torch.device('cpu'), torch.device('cuda')]

Expand All @@ -53,7 +52,7 @@ def _setup(device, levy_area_approximation, shape):
tb = torch.rand([], device=device)
ta, tb = min(ta, tb), max(ta, tb)
bm = torchsde.BrownianInterval(t0=t0, t1=t1, size=shape, device=device,
levy_area_approximation=levy_area_approximation, pool_size=POOL_SIZE)
levy_area_approximation=levy_area_approximation)
return ta, tb, bm


Expand Down Expand Up @@ -205,7 +204,7 @@ def test_normality_conditional(device, levy_area_approximation):
t0, t1 = 0.0, 1.0
for _ in range(REPS):
bm = torchsde.BrownianInterval(t0=t0, t1=t1, size=(LARGE_BATCH_SIZE,), device=device,
levy_area_approximation=levy_area_approximation, pool_size=POOL_SIZE)
levy_area_approximation=levy_area_approximation)

for _ in range(MEDIUM_REPS):
ta, t_, tb = sorted(npr.uniform(low=t0, high=t1, size=(3,)))
Expand Down Expand Up @@ -268,7 +267,7 @@ def test_consistency(device, levy_area_approximation):
t0, t1 = 0.0, 1.0
for _ in range(REPS):
bm = torchsde.BrownianInterval(t0=t0, t1=t1, size=(LARGE_BATCH_SIZE,), device=device,
levy_area_approximation=levy_area_approximation, pool_size=POOL_SIZE)
levy_area_approximation=levy_area_approximation)

for _ in range(MEDIUM_REPS):
ta, t_, tb = sorted(npr.uniform(low=t0, high=t1, size=(3,)))
Expand Down

0 comments on commit b591c6a

Please sign in to comment.