Skip to content

Commit

Permalink
Added milstein_grad_free, milstein_strat and milstein_strat_grad_free
Browse files Browse the repository at this point in the history
  • Loading branch information
mtsokol committed Aug 18, 2020
1 parent 603358d commit c6e9fcf
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 9 deletions.
1 change: 1 addition & 0 deletions diagnostics/ito_diagonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def inspect_sample():

ys_euler = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='euler')
ys_milstein = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='milstein')
ys_milstein_grad_free = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='milstein', options={'grad_free': True})
ys_srk = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='srk')
ys_analytical = sde.analytical_sample(y0=y0, ts=ts, bm=bm)

Expand Down
22 changes: 16 additions & 6 deletions diagnostics/stratonovich_diagonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,26 +45,30 @@ def inspect_sample():
ys_euler = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='euler')
ys_heun = sdeint(sde_strat, y0=y0, ts=ts, dt=dt, bm=bm, method='heun', names={'drift': 'f_corr'})
ys_midpoint = sdeint(sde_strat, y0=y0, ts=ts, dt=dt, bm=bm, method='midpoint', names={'drift': 'f_corr'})
# TODO add milstein strat with grad when fix is merged
ys_milstein_strat = sdeint(sde_strat, y0=y0, ts=ts, dt=dt, bm=bm, method='milstein_strat', names={'drift': 'f_corr'}, options={'grad_free': True})
ys_analytical = sde.analytical_sample(y0=y0, ts=ts, bm=bm)

ys_euler = ys_euler.squeeze().t()
ys_heun = ys_heun.squeeze().t()
ys_midpoint = ys_midpoint.squeeze().t()
ys_milstein_strat = ys_milstein_strat.squeeze().t()
ys_analytical = ys_analytical.squeeze().t()

ts_, ys_euler_, ys_heun_, ys_midpoint_, ys_analytical_ = to_numpy(
ts, ys_euler, ys_heun, ys_midpoint, ys_analytical)
ts_, ys_euler_, ys_heun_, ys_midpoint_, ys_milstein_strat_, ys_analytical_ = to_numpy(
ts, ys_euler, ys_heun, ys_midpoint, ys_milstein_strat, ys_analytical)

# Visualize sample path.
img_dir = os.path.join('.', 'diagnostics', 'plots', 'stratonovich_diagonal')
makedirs_if_not_found(img_dir)

for i, (ys_euler_i, ys_heun_i, ys_midpoint_i, ys_analytical_i) in enumerate(
zip(ys_euler_, ys_heun_, ys_midpoint_, ys_analytical_)):
for i, (ys_euler_i, ys_heun_i, ys_midpoint_i, ys_milstein_strat_i, ys_analytical_i) in enumerate(
zip(ys_euler_, ys_heun_, ys_midpoint_, ys_milstein_strat_, ys_analytical_)):
plt.figure()
plt.plot(ts_, ys_euler_i, label='euler')
plt.plot(ts_, ys_heun_i, label='heun')
plt.plot(ts_, ys_midpoint_i, label='midpoint')
plt.plot(ts_, ys_milstein_strat_i, label='milstein_strat')
plt.plot(ts_, ys_analytical_i, label='analytical')
plt.legend()
plt.savefig(os.path.join(img_dir, f'{i}'))
Expand All @@ -83,6 +87,7 @@ def inspect_strong_order():
euler_mses_ = []
heun_mses_ = []
midpoint_mses_ = []
milstein_strat_mses_ = []

with torch.no_grad():
bm = BrownianInterval(t0=ts[0], t1=ts[-1], shape=y0.shape, dtype=y0.dtype, device=device,
Expand All @@ -93,29 +98,34 @@ def inspect_strong_order():
_, ys_euler = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='euler')
_, ys_heun = sdeint(sde_strat, y0=y0, ts=ts, dt=dt, bm=bm, method='heun', names={'drift': 'f_corr'})
_, ys_midpoint = sdeint(sde_strat, y0=y0, ts=ts, dt=dt, bm=bm, method='midpoint', names={'drift': 'f_corr'})
_, ys_milstein_strat = sdeint(sde_strat, y0=y0, ts=ts, dt=dt, bm=bm, method='milstein_strat', names={'drift': 'f_corr'}, options={'grad_free': True})
_, ys_analytical = sde.analytical_sample(y0=y0, ts=ts, bm=bm)

euler_mse = compute_mse(ys_euler, ys_analytical)
heun_mse = compute_mse(ys_heun, ys_analytical)
midpoint_mse = compute_mse(ys_midpoint, ys_analytical)
milstein_strat_mse = compute_mse(ys_milstein_strat, ys_analytical)

euler_mse_, heun_mse_, midpoint_mse_ = to_numpy(euler_mse, heun_mse, midpoint_mse)
euler_mse_, heun_mse_, midpoint_mse_, milstein_strat_mse_ = to_numpy(euler_mse, heun_mse, midpoint_mse, milstein_strat_mse)

euler_mses_.append(euler_mse_)
heun_mses_.append(heun_mse_)
midpoint_mses_.append(midpoint_mse_)
del euler_mse_, heun_mse_, midpoint_mse_
milstein_strat_mses_.append(milstein_strat_mse_)
del euler_mse_, heun_mse_, midpoint_mse_, milstein_strat_mse_

# Divide the log-error by 2, since textbook strong orders are represented so.
log = lambda x: np.log(np.array(x))
euler_slope, _, _, _, _ = stats.linregress(log(dts), log(euler_mses_) / 2)
heun_slope, _, _, _, _ = stats.linregress(log(dts), log(heun_mses_) / 2)
midpoint_slope, _, _, _, _ = stats.linregress(log(dts), log(midpoint_mses_) / 2)
milstein_strat_slope, _, _, _, _ = stats.linregress(log(dts), log(milstein_strat_mses_) / 2)

plt.figure()
plt.plot(dts, euler_mses_, label=f'euler(k={euler_slope:.4f})')
plt.plot(dts, heun_mses_, label=f'heun(k={heun_slope:.4f})')
plt.plot(dts, midpoint_mses_, label=f'midpoint(k={midpoint_slope:.4f})')
plt.plot(dts, milstein_strat_mses_, label=f'milstein_strat(k={milstein_strat_slope:.4f})')
plt.xscale('log')
plt.yscale('log')
plt.legend()
Expand Down
3 changes: 3 additions & 0 deletions torchsde/_core/methods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .milstein import Milstein
from .srk import SRK
from .heun import Heun
from .milstein_strat import MilsteinStrat


def select(method):
Expand All @@ -32,5 +33,7 @@ def select(method):
return Midpoint
elif method == METHODS.heun:
return Heun
elif method == METHODS.milstein_strat:
return MilsteinStrat
else:
raise ValueError(f"Method '{method}' does not match any known method.")
28 changes: 25 additions & 3 deletions torchsde/_core/methods/milstein.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from ...settings import SDE_TYPES, NOISE_TYPES, LEVY_AREA_APPROXIMATIONS
import math

import torch

from ...settings import SDE_TYPES, NOISE_TYPES, LEVY_AREA_APPROXIMATIONS, METHOD_OPTIONS as opt

from .. import base_solver

Expand All @@ -27,12 +31,30 @@ class Milstein(base_solver.BaseSDESolver):
def step(self, t0, y0, dt):
assert dt > 0, 'Underflow in dt {}'.format(dt)

I_k = self.bm(t0, t0 + dt)
t1 = t0 + dt

I_k = self.bm(t0, t1)
v = [delta_bm_ ** 2. - dt for delta_bm_ in I_k]

f_eval = self.sde.f(t0, y0)
g_prod_eval = self.sde.g_prod(t0, y0, I_k)
gdg_prod_eval = self.sde.gdg_prod(t0, y0, v)

if opt.grad_free in self.options and self.options[opt.grad_free]:
g_eval = self.sde.g(t0, y0)
g_prod_eval_v = self.sde.g_prod(t0, y0, v)
sqrt_dt = torch.sqrt(dt) if isinstance(dt, torch.Tensor) else math.sqrt(dt)
y0_prime = [
y0_ + dt * f_eval_ + g_eval_ * sqrt_dt
for y0_, f_eval_, g_eval_ in zip(y0, f_eval, g_eval)
]
g_prod_eval_prime = self.sde.g_prod(t1, y0_prime, v)
gdg_prod_eval = [
(g_prod_eval_prime_ - g_prod_eval_v_) / sqrt_dt
for g_prod_eval_prime_, g_prod_eval_v_ in zip(g_prod_eval_prime, g_prod_eval_v)
]
else:
gdg_prod_eval = self.sde.gdg_prod(t0, y0, v)

y1 = [
y0_i + f_eval_i * dt + g_prod_eval_i + .5 * gdg_prod_eval_i
for y0_i, f_eval_i, g_prod_eval_i, gdg_prod_eval_i in zip(y0, f_eval, g_prod_eval, gdg_prod_eval)
Expand Down
63 changes: 63 additions & 0 deletions torchsde/_core/methods/milstein_strat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import torch

from ...settings import SDE_TYPES, NOISE_TYPES, LEVY_AREA_APPROXIMATIONS, METHOD_OPTIONS as opt

from .. import base_solver


class MilsteinStrat(base_solver.BaseSDESolver):
strong_order = 1.0
weak_order = 1.0
sde_type = SDE_TYPES.stratonovich
noise_types = (NOISE_TYPES.additive, NOISE_TYPES.diagonal, NOISE_TYPES.scalar)
levy_area_approximations = LEVY_AREA_APPROXIMATIONS.all()

def step(self, t0, y0, dt):
assert dt > 0, 'Underflow in dt {}'.format(dt)

t1 = t0 + dt

I_k = self.bm(t0, t1)
v = [delta_bm_ ** 2. for delta_bm_ in I_k]

f_eval = self.sde.f(t0, y0)
g_prod_eval = self.sde.g_prod(t0, y0, I_k)

if opt.grad_free in self.options and self.options[opt.grad_free]:
g_eval = self.sde.g(t0, y0)
g_prod_eval_v = self.sde.g_prod(t0, y0, v)
sqrt_dt = torch.sqrt(dt) if isinstance(dt, torch.Tensor) else math.sqrt(dt)
y0_prime = [
y0_ + dt * f_eval_ + g_eval_ * sqrt_dt
for y0_, f_eval_, g_eval_ in zip(y0, f_eval, g_eval)
]
g_prod_eval_prime = self.sde.g_prod(t1, y0_prime, v)
gdg_prod_eval = [
(g_prod_eval_prime_ - g_prod_eval_v_) / sqrt_dt
for g_prod_eval_prime_, g_prod_eval_v_ in zip(g_prod_eval_prime, g_prod_eval_v)
]
else:
gdg_prod_eval = self.sde.gdg_prod(t0, y0, v)

y1 = [
y0_i + f_eval_i * dt + g_prod_eval_i + .5 * gdg_prod_eval_i
for y0_i, f_eval_i, g_prod_eval_i, gdg_prod_eval_i in zip(y0, f_eval, g_prod_eval, gdg_prod_eval)
]
t1 = t0 + dt
return t1, y1
5 changes: 5 additions & 0 deletions torchsde/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class METHODS(metaclass=ContainerMeta):
srk = 'srk'
midpoint = 'midpoint'
heun = 'heun'
milstein_strat = 'milstein_strat'


class NOISE_TYPES(metaclass=ContainerMeta): # noqa
Expand All @@ -52,3 +53,7 @@ class LEVY_AREA_APPROXIMATIONS(metaclass=ContainerMeta): # noqa
space_time = 'space-time' # Only compute an (exact) space-time Levy area
davie = 'davie' # Compute Davie's approximation to Levy area
foster = 'foster' # Compute Foster's correction to Davie's approximation to Levy area


class METHOD_OPTIONS(metaclass=ContainerMeta):
grad_free = 'grad_free'

0 comments on commit c6e9fcf

Please sign in to comment.