Skip to content

Commit

Permalink
[On dev branch] Tuple rewrite (#37)
Browse files Browse the repository at this point in the history
* Rename plot folders from diagnostics.

* Complete tuple rewrite.

* Remove inaccurate comments.

* Minor fixes.

* Fixes.

* Remove comment.

* Fix docstring.

* Fix noise type for problem.
  • Loading branch information
lxuechen committed Aug 25, 2020
1 parent 495ea3c commit 19062d5
Show file tree
Hide file tree
Showing 29 changed files with 532 additions and 1,057 deletions.
1 change: 1 addition & 0 deletions .gitignore
Expand Up @@ -12,3 +12,4 @@ benchmarks/plots/
CMakeLists.txt
restats
*-darwin.so
**.pyc
8 changes: 4 additions & 4 deletions diagnostics/ito_additive.py
Expand Up @@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

import argparse
import matplotlib.pyplot as plt
import numpy as np
import torch
import tqdm
import numpy as np
from scipy import stats

from tests.basic_sde import AdditiveSDE
Expand Down Expand Up @@ -51,7 +51,7 @@ def inspect_samples():
ts_, ys_em_, ys_srk_, ys_true_ = to_numpy(ts, ys_em, ys_srk, ys_true)

# Visualize sample path.
img_dir = os.path.join('.', 'diagnostics', 'plots', 'srk_additive')
img_dir = os.path.join('.', 'diagnostics', 'plots', 'ito_additive')
makedirs_if_not_found(img_dir)

for i, (ys_em_i, ys_srk_i, ys_true_i) in enumerate(zip(ys_em_, ys_srk_, ys_true_)):
Expand Down Expand Up @@ -105,7 +105,7 @@ def inspect_strong_order():
plt.yscale('log')
plt.legend()

img_dir = os.path.join('.', 'diagnostics', 'plots', 'srk_additive')
img_dir = os.path.join('.', 'diagnostics', 'plots', 'ito_additive')
makedirs_if_not_found(img_dir)
plt.savefig(os.path.join(img_dir, 'rate'))
plt.close()
Expand Down
12 changes: 7 additions & 5 deletions diagnostics/ito_diagonal.py
Expand Up @@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

import argparse
import matplotlib.pyplot as plt
import numpy as np
import torch
Expand Down Expand Up @@ -56,7 +56,7 @@ def inspect_sample():
ts, ys_euler, ys_milstein, ys_milstein_grad_free, ys_srk, ys_analytical)

# Visualize sample path.
img_dir = os.path.join('.', 'diagnostics', 'plots', 'srk_diagonal')
img_dir = os.path.join('.', 'diagnostics', 'plots', 'ito_diagonal')
makedirs_if_not_found(img_dir)

for i, (ys_euler_i, ys_milstein_i, ys_milstein_grad_free_i, ys_srk_i, ys_analytical_i) in enumerate(
Expand Down Expand Up @@ -92,7 +92,8 @@ def inspect_strong_order():
# Only take end value.
_, ys_euler = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='euler')
_, ys_milstein = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='milstein')
_, ys_milstein_grad_free = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='milstein', options={'grad_free': True})
_, ys_milstein_grad_free = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='milstein',
options={'grad_free': True})
_, ys_srk = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='srk')
_, ys_analytical = sde.analytical_sample(y0=y0, ts=ts, bm=bm)

Expand All @@ -101,7 +102,8 @@ def inspect_strong_order():
milstein_grad_free_mse = compute_mse(ys_milstein_grad_free, ys_analytical)
srk_mse = compute_mse(ys_srk, ys_analytical)

euler_mse_, milstein_mse_, milstein_grad_free_mse_, srk_mse_ = to_numpy(euler_mse, milstein_mse, milstein_grad_free_mse, srk_mse)
euler_mse_, milstein_mse_, milstein_grad_free_mse_, srk_mse_ = to_numpy(
euler_mse, milstein_mse, milstein_grad_free_mse, srk_mse)

euler_mses_.append(euler_mse_)
milstein_mses_.append(milstein_mse_)
Expand All @@ -125,7 +127,7 @@ def inspect_strong_order():
plt.yscale('log')
plt.legend()

img_dir = os.path.join('.', 'diagnostics', 'plots', 'srk_diagonal')
img_dir = os.path.join('.', 'diagnostics', 'plots', 'ito_diagonal')
makedirs_if_not_found(img_dir)
plt.savefig(os.path.join(img_dir, 'rate'))
plt.close()
Expand Down
12 changes: 6 additions & 6 deletions diagnostics/ito_scalar.py
Expand Up @@ -12,16 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

import argparse
import matplotlib.pyplot as plt
import numpy as np
import torch
import tqdm
from scipy import stats

from tests.problems import Ex2
from tests.problems import Ex2Scalar
from torchsde import sdeint, BrownianInterval
from torchsde.settings import LEVY_AREA_APPROXIMATIONS
from .utils import to_numpy, makedirs_if_not_found, compute_mse
Expand All @@ -34,7 +34,7 @@ def inspect_sample():
ts = torch.linspace(0., 5., steps=steps).to(device)
dt = 1e-1
y0 = torch.ones(batch_size, d).to(device)
sde = Ex2(d=d).to(device)
sde = Ex2Scalar(d=d).to(device)
sde.noise_type = "scalar"

with torch.no_grad():
Expand All @@ -54,7 +54,7 @@ def inspect_sample():
ts, ys_euler, ys_milstein, ys_srk, ys_analytical)

# Visualize sample path.
img_dir = os.path.join('.', 'diagnostics', 'plots', 'srk_scalar')
img_dir = os.path.join('.', 'diagnostics', 'plots', 'ito_scalar')
makedirs_if_not_found(img_dir)

for i, (ys_euler_i, ys_milstein_i, ys_srk_i, ys_analytical_i) in enumerate(
Expand All @@ -74,7 +74,7 @@ def inspect_strong_order():
ts = torch.tensor([0., 5.]).to(device)
dts = tuple(2 ** -i for i in range(1, 9))
y0 = torch.ones(batch_size, d).to(device)
sde = Ex2(d=d).to(device)
sde = Ex2Scalar(d=d).to(device)

euler_mses_ = []
milstein_mses_ = []
Expand Down Expand Up @@ -116,7 +116,7 @@ def inspect_strong_order():
plt.yscale('log')
plt.legend()

img_dir = os.path.join('.', 'diagnostics', 'plots', 'srk_scalar')
img_dir = os.path.join('.', 'diagnostics', 'plots', 'ito_scalar')
makedirs_if_not_found(img_dir)
plt.savefig(os.path.join(img_dir, 'rate'))
plt.close()
Expand Down
12 changes: 8 additions & 4 deletions diagnostics/stratonovich_diagonal.py
Expand Up @@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

import argparse
import matplotlib.pyplot as plt
import numpy as np
import torch
Expand Down Expand Up @@ -43,7 +43,8 @@ def inspect_sample():
ys_heun = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='heun', names={'drift': 'f_corr'})
ys_midpoint = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='midpoint', names={'drift': 'f_corr'})
ys_milstein_strat = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='milstein', names={'drift': 'f_corr'})
ys_mil_strat_grad_free = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='milstein', names={'drift': 'f_corr'}, options={'grad_free': True})
ys_mil_strat_grad_free = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='milstein', names={'drift': 'f_corr'},
options={'grad_free': True})
ys_analytical = sde.analytical_sample(y0=y0, ts=ts, bm=bm)

ys_heun = ys_heun.squeeze().t()
Expand Down Expand Up @@ -93,15 +94,18 @@ def inspect_strong_order():
_, ys_heun = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='heun', names={'drift': 'f_corr'})
_, ys_midpoint = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='midpoint', names={'drift': 'f_corr'})
_, ys_milstein_strat = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='milstein', names={'drift': 'f_corr'})
_, ys_mil_strat_grad_free = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='milstein', names={'drift': 'f_corr'}, options={'grad_free': True})
_, ys_mil_strat_grad_free = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='milstein',
names={'drift': 'f_corr'}, options={'grad_free': True})
_, ys_analytical = sde.analytical_sample(y0=y0, ts=ts, bm=bm)

heun_mse = compute_mse(ys_heun, ys_analytical)
midpoint_mse = compute_mse(ys_midpoint, ys_analytical)
milstein_strat_mse = compute_mse(ys_milstein_strat, ys_analytical)
mil_strat_grad_free_mse = compute_mse(ys_mil_strat_grad_free, ys_analytical)

heun_mse_, midpoint_mse_, milstein_strat_mse_, mil_strat_grad_free_mse_ = to_numpy(heun_mse, midpoint_mse, milstein_strat_mse, mil_strat_grad_free_mse)
heun_mse_, midpoint_mse_, milstein_strat_mse_, mil_strat_grad_free_mse_ = to_numpy(heun_mse, midpoint_mse,
milstein_strat_mse,
mil_strat_grad_free_mse)

heun_mses_.append(heun_mse_)
midpoint_mses_.append(midpoint_mse_)
Expand Down
5 changes: 1 addition & 4 deletions tests/basic_sde.py
Expand Up @@ -129,12 +129,9 @@ def h(self, t, y):
class ScalarSDE(AdditiveSDE):
def __init__(self, d=10, m=3):
super(ScalarSDE, self).__init__(d=d, m=m)
self.g_param = nn.Parameter(torch.sigmoid(torch.randn(1, d)), requires_grad=True)
self.g_param = nn.Parameter(torch.sigmoid(torch.randn(1, d, 1)), requires_grad=True)
self.noise_type = "scalar"

def g(self, t, y):
return self.g_param.repeat(y.size(0), 1)


class TupleSDE(SDEIto):
def __init__(self, d=10):
Expand Down
9 changes: 9 additions & 0 deletions tests/problems.py
Expand Up @@ -105,6 +105,15 @@ def nfe(self):
return self._nfe


class Ex2Scalar(Ex2):
def __init__(self, d=10, sde_type='ito'):
super(Ex2Scalar, self).__init__(d=d, sde_type=sde_type)
self.noise_type = "scalar"

def g(self, t, y):
return super(Ex2Scalar, self).g(t, y).unsqueeze(2)


class Ex3(BaseSDE):
def __init__(self, d=10, sde_type='ito'):
super(Ex3, self).__init__(noise_type="diagonal", sde_type=sde_type)
Expand Down
8 changes: 8 additions & 0 deletions tests/test_adjoint_logqp.py
Expand Up @@ -43,18 +43,26 @@
class TestAdjointLogqp(TorchTestCase):

def test_basic_sde1(self):
self.skipTest("Temporarily deprecating logqp.")

sde = BasicSDE1(d).to(device)
_test_forward_and_backward(sde)

def test_basic_sde2(self):
self.skipTest("Temporarily deprecating logqp.")

sde = BasicSDE2(d).to(device)
_test_forward_and_backward(sde)

def test_basic_sde3(self):
self.skipTest("Temporarily deprecating logqp.")

sde = BasicSDE3(d).to(device)
_test_forward_and_backward(sde)

def test_basic_sde4(self):
self.skipTest("Temporarily deprecating logqp.")

sde = BasicSDE4(d).to(device)
_test_forward_and_backward(sde)

Expand Down
36 changes: 26 additions & 10 deletions tests/test_sdeint.py
Expand Up @@ -61,11 +61,14 @@
class TestSdeint(TorchTestCase):

def test_rename_methods(self):
# Test renaming works with a subset of names when `logqp=False`.
# Test renaming works with a subset of names.
sde = basic_sde.CustomNamesSDE().to(device)
ans = sdeint(sde, y0, ts, dt=dt, names={'drift': 'forward'})
self.assertEqual(ans.shape, (T, batch_size, d))

def test_rename_methods_logqp(self):
self.skipTest("Temporarily deprecating logqp.")

# Test renaming works with a subset of names when `logqp=True`.
sde = basic_sde.CustomNamesSDELogqp().to(device)
ans = sdeint(sde, y0, ts, dt=dt, names={'drift': 'forward', 'prior_drift': 'w'}, logqp=True)
Expand All @@ -76,18 +79,36 @@ def test_sdeint_general(self):
sde = basic_sde.GeneralSDE(d=d, m=m).to(device)
for method in ('euler',):
self._test_sdeint(sde, bm=bm_general, adaptive=False, method=method, dt=dt)

def test_sdeint_general_logqp(self):
self.skipTest("Temporarily deprecating logqp.")

sde = basic_sde.GeneralSDE(d=d, m=m).to(device)
for method in ('euler',):
self._test_sdeint_logqp(sde, bm=bm_general, adaptive=False, method=method, dt=dt)

def test_sdeint_additive(self):
sde = basic_sde.AdditiveSDE(d=d, m=m).to(device)
for method in ('euler', 'milstein', 'srk'):
self._test_sdeint(sde, bm=bm_general, adaptive=False, method=method, dt=dt)

def test_sdeint_additive_logqp(self):
self.skipTest("Temporarily deprecating logqp.")

sde = basic_sde.AdditiveSDE(d=d, m=m).to(device)
for method in ('euler', 'milstein', 'srk'):
self._test_sdeint_logqp(sde, bm=bm_general, adaptive=False, method=method, dt=dt)

def test_sde_scalar(self):
sde = basic_sde.ScalarSDE(d=d, m=m).to(device)
for method in ('euler', 'milstein', 'srk'):
self._test_sdeint(sde, bm=bm_scalar, adaptive=False, method=method, dt=dt)

def test_sde_scalar_logqp(self):
self.skipTest("Temporarily deprecating logqp.")

sde = basic_sde.ScalarSDE(d=d, m=m).to(device)
for method in ('euler', 'milstein', 'srk'):
self._test_sdeint_logqp(sde, bm=bm_scalar, adaptive=False, method=method, dt=dt)

def test_srk_determinism(self):
Expand Down Expand Up @@ -116,24 +137,19 @@ def test_sdeint_adaptive(self):
self._test_sdeint(sde, bm_diagonal, adaptive=True, method=method, dt=dt)

def test_sdeint_logqp_fixed(self):
self.skipTest("Temporarily deprecating logqp.")

for sde in basic_sdes:
for method in ('euler', 'milstein', 'srk'):
self._test_sdeint_logqp(sde, bm_diagonal, adaptive=False, method=method, dt=dt)

def test_sdeint_logqp_adaptive(self):
self.skipTest("Temporarily deprecating logqp.")

for sde in basic_sdes:
for method in ('milstein', 'srk'):
self._test_sdeint_logqp(sde, bm_diagonal, adaptive=True, method=method, dt=dt)

def test_sdeint_tuple_sde(self):
y0_ = (y0,) # Make tuple input.
sde = basic_sde.TupleSDE(d=d).to(device)

for method in ('euler', 'milstein', 'srk'):
ans = sdeint(sde, y0_, ts, method=method, dt=dt)
self.assertTrue(isinstance(ans, tuple))
self.assertEqual(ans[0].size(), (T, batch_size, d))

def _test_sdeint(self, sde, bm, adaptive, method, dt):
# Using `f` as drift.
with torch.no_grad():
Expand Down

0 comments on commit 19062d5

Please sign in to comment.