Add efficient gdg_jvp term for log-ODE schemes. #20

Merged · 6 commits · Aug 16, 2020
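For orientation, here is a sketch of the quantity under test, inferred from the brute-force reference in tests/test_strat.py below (my reading of the test, not an authoritative statement of the implementation). With a diffusion g(t, y) taking values in R^{d x m} for each batch element and an antisymmetric matrix a in R^{m x m}, the tests check

    \sum_{l=1}^{m} \frac{\partial g_{\cdot l}(t, y)}{\partial y} \, \bigl( g(t, y)\, a \bigr)_{\cdot l}

that is, each column of g(t, y) a is pushed through the Jacobian of the corresponding column of g, and the results are summed over the m noise dimensions. This is the correction term that log-ODE schemes for Stratonovich SDEs consume. The check_efficiency harness below compares a loop-based implementation (cost linear in m) against a batched one (cost almost constant in m).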
128 changes: 128 additions & 0 deletions tests/test_strat.py
@@ -0,0 +1,128 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Temporary test for Stratonovich stuff.

This should be eventually refactored and the file should be removed.
"""

import time

import torch
from torch import nn

from torchsde import settings
from torchsde._core.base_sde import ForwardSDE  # noqa

torch.set_default_dtype(torch.float64)
cpu, gpu = torch.device('cpu'), torch.device('cuda')
device = gpu if torch.cuda.is_available() else cpu


def _column_wise_func(y, t, i):
    # This function is designed so that there are mixed partials.
    return (torch.cos(y ** 2 * i + t * 0.1) +
            torch.tan(y[..., 0:1] * y[..., -2:-1]) +
            torch.sum(y ** 2, dim=-1, keepdim=True))


class SDE(nn.Module):

    def __init__(self):
        super(SDE, self).__init__()
        self.noise_type = settings.NOISE_TYPES.general
        self.sde_type = settings.SDE_TYPES.stratonovich

    def f(self, t, y):
        return [torch.sin(y_) + t for y_ in y]

    def g(self, t, y):
        return [
            torch.stack([_column_wise_func(y_, t, i) for i in range(m)], dim=-1)
            for y_ in y
        ]


batch_size, d, m = 3, 5, 12  # Batch size, state dimension, number of Brownian motions.


def _batch_jacobian(output, input_):
    # Create batch of Jacobians for output of size (batch_size, d_o) and input of size (batch_size, d_i).
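    # Returns a tensor of shape (batch_size, d_o, d_i); gradients that come back as None
    # are replaced with zeros.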
    assert output.dim() == input_.dim() == 2
    assert output.size(0) == input_.size(0)
    jacs = []
    for i in range(output.size(0)):  # batch_size.
        jac = []
        for j in range(output.size(1)):  # d_o.
            grad, = torch.autograd.grad(output[i, j], input_, retain_graph=True, allow_unused=True)
            grad = torch.zeros_like(input_[i]) if grad is None else grad[i].detach()
            jac.append(grad)
        jac = torch.stack(jac, dim=0)
        jacs.append(jac)
    return torch.stack(jacs, dim=0)


def _gdg_jvp_brute_force(sde, t, y, a):
    # Only returns the value for the first input-output pair.
    with torch.enable_grad():
        y = [y_.detach().requires_grad_(True) if not y_.requires_grad else y_ for y_ in y]
        g_eval = sde.g(t, y)
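        # v = g(t, y) @ a has shape (batch_size, d, m); column l is the direction for the
        # JVP against the Jacobian of column l of g.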
        v = [torch.bmm(g_eval_, a_) for g_eval_, a_ in zip(g_eval, a)]

        y0, g_eval0, v0 = y[0], g_eval[0], v[0]
        num_brownian = g_eval0.size(-1)
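        # For each noise dimension l, form the batch of Jacobians of g[..., l] w.r.t. y,
        # each of shape (batch_size, d, d).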
        jacobians_by_column = [_batch_jacobian(g_eval0[..., l], y0) for l in range(num_brownian)]
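        # Contract each Jacobian with the matching column of v, then sum over the m noise
        # dimensions.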
        return [
            sum(torch.bmm(jacobians_by_column[l], v0[..., l].unsqueeze(-1)).squeeze(-1) for l in range(num_brownian))
        ]
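# Note: the reference above makes batch_size * d calls to autograd.grad per noise
# dimension, so it is only suitable as a correctness check, not as a timing baseline.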


def _make_inputs():
    t = torch.rand(()).to(device)
    y = [torch.randn(batch_size, d).to(device)]
    a = torch.randn(batch_size, m, m).to(device)
    a = [a - a.transpose(1, 2)]  # Anti-symmetric, as for the Lévy area terms used by log-ODE schemes.
    sde = ForwardSDE(SDE())
    return sde, t, y, a


def test_gdg_jvp():
    sde, t, y, a = _make_inputs()
    outs_brute_force = _gdg_jvp_brute_force(sde, t, y, a)  # Reference.
    outs = sde.gdg_jvp_column_sum(t, y, a)
    outs_v2 = sde.gdg_jvp_column_sum_v2(t, y, a)
    for out_brute_force, out, out_v2 in zip(outs_brute_force, outs, outs_v2):
        assert torch.allclose(out_brute_force, out)
        assert torch.allclose(out_brute_force, out_v2)


def _time_function(func, reps=10):
    now = time.perf_counter()
    for _ in range(reps):
        func()
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # CUDA kernels launch asynchronously; wait before reading the clock.
    return time.perf_counter() - now


def check_efficiency():
    sde, t, y, a = _make_inputs()

    func1 = lambda: sde.gdg_jvp_column_sum_v1(t, y, a)  # Linear in m.
    time_elapsed = _time_function(func1)
    print(f'Time elapsed for loop: {time_elapsed:.4f}')

    func2 = lambda: sde.gdg_jvp_column_sum_v2(t, y, a)  # Almost constant in m.
    time_elapsed = _time_function(func2)
    print(f'Time elapsed for duplicate: {time_elapsed:.4f}')


test_gdg_jvp()
check_efficiency()
13 changes: 2 additions & 11 deletions torchsde/_core/adjoint_sdes/additive.py
@@ -20,7 +20,7 @@
from .. import misc


-class AdjointSDEAdditive(base_sde.AdjointSDEIto):
+class AdjointSDEAdditive(base_sde.AdjointSDE):

    def __init__(self, sde, params):
        super(AdjointSDEAdditive, self).__init__(sde, noise_type="general")
@@ -36,8 +36,6 @@ def f(self, t, y_aug):

        f_eval = sde.f(-t, y)
        f_eval = [-f_eval_ for f_eval_ in f_eval]
-        f_eval = misc.make_seq_requires_grad(f_eval)
-
        vjp_y_and_params = misc.grad(
            outputs=f_eval,
            inputs=y + params,
@@ -61,8 +59,6 @@ def g_prod(self, t, y_aug, noise):
        adj_y = [adj_y_.detach() for adj_y_ in adj_y]

        g_eval = [-g_ for g_ in sde.g(-t, y)]
-        g_eval = misc.make_seq_requires_grad(g_eval)
-
        vjp_y_and_params = misc.grad(
            outputs=g_eval, inputs=y + params,
            grad_outputs=[
@@ -90,7 +86,7 @@ def gdg_prod(self, t, y, v):
        raise NotImplementedError("This method shouldn't be called.")


-class AdjointSDEAdditiveLogqp(base_sde.AdjointSDEIto):
+class AdjointSDEAdditiveLogqp(base_sde.AdjointSDE):
    def __init__(self, sde, params):
        super(AdjointSDEAdditiveLogqp, self).__init__(sde, noise_type="general")
        self.params = params
@@ -106,8 +102,6 @@ def f(self, t, y_aug):

        f_eval = sde.f(-t, y)
        f_eval = [-f_eval_ for f_eval_ in f_eval]
-        f_eval = misc.make_seq_requires_grad(f_eval)
-
        vjp_y_and_params = misc.grad(
            outputs=f_eval,
            inputs=y + params,
@@ -128,7 +122,6 @@ def f(self, t, y_aug):
        u_eval = misc.seq_sub(f_eval, h_eval)
        u_eval = [torch.bmm(g_inv_eval_, u_eval_) for g_inv_eval_, u_eval_ in zip(g_inv_eval, u_eval)]
        log_ratio_correction = [.5 * torch.sum(u_eval_ ** 2., dim=1) for u_eval_ in u_eval]
-        log_ratio_correction = misc.make_seq_requires_grad(log_ratio_correction)
        corr_vjp_y_and_params = misc.grad(
            outputs=log_ratio_correction, inputs=y + params,
            grad_outputs=adj_l,
@@ -154,8 +147,6 @@ def g_prod(self, t, y_aug, noise):
        adj_y = [adj_y_.detach() for adj_y_ in adj_y]

        g_eval = [-g_ for g_ in sde.g(-t, y)]
-        g_eval = misc.make_seq_requires_grad(g_eval)
-
        vjp_y_and_params = misc.grad(
            outputs=g_eval, inputs=y + params,
            grad_outputs=[-noise_.unsqueeze(1) * adj_y_.unsqueeze(2) for noise_, adj_y_ in zip(noise, adj_y)],
24 changes: 2 additions & 22 deletions torchsde/_core/adjoint_sdes/diagonal.py
@@ -19,7 +19,7 @@
from .. import misc


-class AdjointSDEDiagonal(base_sde.AdjointSDEIto):
+class AdjointSDEDiagonal(base_sde.AdjointSDE):

    def __init__(self, sde, params):
        super(AdjointSDEDiagonal, self).__init__(sde, noise_type="diagonal")
@@ -34,8 +34,6 @@ def f(self, t, y_aug):
        adj_y = [adj_y_.detach() for adj_y_ in adj_y]

        g_eval = sde.g(-t, y)
-        g_eval = misc.make_seq_requires_grad(g_eval)
-
        gdg = misc.grad(
            outputs=g_eval, inputs=y,
            grad_outputs=g_eval,
@@ -45,10 +43,7 @@ def f(self, t, y_aug):
        gdg = misc.convert_none_to_zeros(gdg, y)

        f_eval = sde.f(-t, y)
-
        f_eval_corrected = misc.seq_sub(gdg, f_eval)  # Stratonovich correction for reverse-time.
-        f_eval_corrected = misc.make_seq_requires_grad(f_eval_corrected)
-
        vjp_y_and_params = misc.grad(
            outputs=f_eval_corrected,
            inputs=y + params,
@@ -95,7 +90,6 @@ def g_prod(self, t, y_aug, noise):
        adj_y = [adj_y_.detach() for adj_y_ in adj_y]

        g_eval = [-g_ for g_ in sde.g(-t, y)]
-        g_eval = misc.make_seq_requires_grad(g_eval)
        vjp_y_and_params = misc.grad(
            outputs=g_eval, inputs=y + params,
            grad_outputs=[-noise_ * adj_y_ for noise_, adj_y_ in zip(noise, adj_y)],
@@ -119,7 +113,6 @@ def gdg_prod(self, t, y_aug, noise):
        adj_y = [adj_y_.detach().requires_grad_(True) for adj_y_ in adj_y]

        g_eval = sde.g(-t, y)
-        g_eval = misc.make_seq_requires_grad(g_eval)
        gdg_times_v = misc.grad(
            outputs=g_eval, inputs=y,
            grad_outputs=misc.seq_mul(g_eval, noise),
@@ -154,8 +147,6 @@ def gdg_prod(self, t, y_aug, noise):
            allow_unused=True, create_graph=True
        )
        gdg_v = misc.convert_none_to_zeros(gdg_v, y)
-        gdg_v = misc.make_seq_requires_grad(gdg_v)
-
        mixed_partials_adj_y_and_params = misc.grad(
            outputs=gdg_v, inputs=y + params,
            grad_outputs=[torch.ones_like(p) for p in gdg_v],
@@ -180,7 +171,7 @@ def h(self, t, y):
        raise NotImplementedError("This method shouldn't be called.")


-class AdjointSDEDiagonalLogqp(base_sde.AdjointSDEIto):
+class AdjointSDEDiagonalLogqp(base_sde.AdjointSDE):

    def __init__(self, sde, params):
        super(AdjointSDEDiagonalLogqp, self).__init__(sde, noise_type="diagonal")
@@ -196,8 +187,6 @@ def f(self, t, y_aug):
        adj_y = [adj_y_.detach() for adj_y_ in adj_y]

        g_eval = sde.g(-t, y)
-        g_eval = misc.make_seq_requires_grad(g_eval)
-
        gdg = misc.grad(
            outputs=g_eval, inputs=y,
            grad_outputs=g_eval,
@@ -208,8 +197,6 @@ def f(self, t, y_aug):

        f_eval = sde.f(-t, y)
        f_eval_corrected = misc.seq_sub(gdg, f_eval)
-        f_eval_corrected = misc.make_seq_requires_grad(f_eval_corrected)
-
        vjp_y_and_params = misc.grad(
            outputs=f_eval_corrected, inputs=y + params,
            grad_outputs=[-adj_y_ for adj_y_ in adj_y],
@@ -244,8 +231,6 @@ def f(self, t, y_aug):
        h_eval = sde.h(-t, y)
        u_eval = misc.seq_sub_div(f_eval, h_eval, g_eval)
        log_ratio_correction = [.5 * torch.sum(u_eval_ ** 2., dim=1) for u_eval_ in u_eval]
-
-        log_ratio_correction = misc.make_seq_requires_grad(log_ratio_correction)
        corr_vjp_y_and_params = misc.grad(
            outputs=log_ratio_correction, inputs=y + params,
            grad_outputs=adj_l,
@@ -271,7 +256,6 @@ def g_prod(self, t, y_aug, noise):
        adj_y = [adj_y_.detach() for adj_y_ in adj_y]

        g_eval = sde.g(-t, y)
-        g_eval = misc.make_seq_requires_grad(g_eval)
        minus_g_eval = [-g_ for g_ in g_eval]
        minus_g_prod_eval = misc.seq_mul(minus_g_eval, noise)

@@ -297,8 +281,6 @@ def gdg_prod(self, t, y_aug, noise):
        adj_y = [adj_y_.detach().requires_grad_(True) for adj_y_ in adj_y]

        g_eval = sde.g(-t, y)
-        g_eval = misc.make_seq_requires_grad(g_eval)
-
        gdg_times_v = misc.grad(
            outputs=g_eval, inputs=y,
            grad_outputs=misc.seq_mul(g_eval, noise),
@@ -333,8 +315,6 @@ def gdg_prod(self, t, y_aug, noise):
            create_graph=True,
        )
        gdg_v = misc.convert_none_to_zeros(gdg_v, y)
-        gdg_v = misc.make_seq_requires_grad(gdg_v)
-
        gdg_v = [gdg_v_.sum() for gdg_v_ in gdg_v]
        mixed_partials_adj_y_and_params = misc.grad(
            outputs=gdg_v, inputs=y + params,
4 changes: 2 additions & 2 deletions torchsde/_core/adjoint_sdes/scalar.py
@@ -17,7 +17,7 @@
from .. import base_sde


-class AdjointSDEScalar(base_sde.AdjointSDEIto):
+class AdjointSDEScalar(base_sde.AdjointSDE):

    def __init__(self, sde, params):
        super(AdjointSDEScalar, self).__init__(sde, noise_type="scalar")
@@ -39,7 +39,7 @@ def gdg_prod(self, t, y, v):
        raise NotImplementedError("This method shouldn't be called.")


-class AdjointSDEScalarLogqp(base_sde.AdjointSDEIto):
+class AdjointSDEScalarLogqp(base_sde.AdjointSDE):

    def __init__(self, sde, params):
        super(AdjointSDEScalarLogqp, self).__init__(sde, noise_type="scalar")