In [7]:
import torch
import torch.nn.functional as F

import torchsde

In [10]:
class Net(torch.nn.Module):
    """Fully connected MLP mapping R^input_dim -> R^input_dim.

    Architecture: Linear(input_dim, h1) -> ReLU -> Linear(h1, h2) -> ReLU
    -> Linear(h2, input_dim). Input and output share the same trailing size,
    so the network can be used as a drift field for an SDE on R^input_dim.

    Args:
        input_dim: size of the trailing (feature) dimension of inputs/outputs.
        hidden_dims: widths (h1, h2) of the two hidden layers; the default
            (120, 84) reproduces the original hard-coded sizes.
    """

    def __init__(self, input_dim=1, hidden_dims=(120, 84)):
        super(Net, self).__init__()
        self.input_dim = input_dim
        h1, h2 = hidden_dims
        self.fc1 = torch.nn.Linear(input_dim, h1)
        self.fc2 = torch.nn.Linear(h1, h2)
        self.fc3 = torch.nn.Linear(h2, input_dim)

    def forward(self, x):
        """Apply the MLP along the last dimension of x (leading dims are batch dims)."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


# Stand-alone instance of the MLP with a 1-D input/output (the constructor default).
net = Net(input_dim=1)

In [None]:
import torch

# Sizes for the demo simulation below.
batch_size = 32      # number of independent sample paths
state_size = 3       # dimension of the SDE state
brownian_size = 2    # dimension of the driving Brownian motion
t_size = 20          # number of output time points on [0, 1]

class SDE(torch.nn.Module):
    """Ito SDE dY = μ(Y) dt + σ dW with a learned drift and a constant diffusion matrix.

    Written for ``torchsde.sdeint`` with ``noise_type='general'``. In that mode,
    ``g(t, y)`` must return a tensor of shape (batch, state_size, brownian_size).

    Args:
        state_size: dimension of the state y.
        brownian_size: dimension of the driving Brownian motion.
        batch_size: kept for backward compatibility only; the batch dimension
            is now read from ``y`` itself.
        γ: scalar scale of the diffusion, σ = γ * I (rectangular identity of
            shape (state_size, brownian_size)).
            NOTE(review): a reference process with terminal law N(0, γI) would
            use σ = sqrt(γ); the two agree only at γ = 1 -- confirm intent.
    """
    noise_type = 'general'
    sde_type = 'ito'

    def __init__(self, state_size=1, brownian_size=1, batch_size=10, γ=1.0):
        super().__init__()

        self.state_size = state_size
        self.brownian_size = brownian_size
        self.batch_size = batch_size

        self.γ = torch.as_tensor(γ, dtype=torch.get_default_dtype())
        # Drift network R^state_size -> R^state_size (was `NN`, an undefined name).
        self.μ = Net(input_dim=state_size)
        # torch.eye takes (n, m) as separate ints, not a tuple. Match y's dtype/device.
        self.σ = lambda y: self.γ * torch.eye(state_size, brownian_size,
                                              dtype=y.dtype, device=y.device)

    # Drift
    def f(self, t, y):
        """Drift term; y is (batch, state_size) -> returns (batch, state_size)."""
        return self.μ(y)

    # Diffusion
    def g(self, t, y):
        """Diffusion term: the same σ for every batch element.

        The σ matrix has only state_size*brownian_size elements, so it is
        broadcast (not reshaped) to (batch, state_size, brownian_size), using
        the actual batch size of y.
        """
        return self.σ(y).expand(y.shape[0],
                                self.state_size,
                                self.brownian_size)

# Pass arguments by keyword. The constructor order is
# (state_size, brownian_size, batch_size, γ), so the earlier positional call
# silently swapped batch_size and state_size.
sde = SDE(state_size=state_size, brownian_size=brownian_size,
          batch_size=batch_size, γ=1.0)
y0 = torch.full((batch_size, state_size), 0.1)
ts = torch.linspace(0, 1, t_size)
# Initial state y0; the SDE is solved over the interval [ts[0], ts[-1]].
ys = torchsde.sdeint(sde, y0, ts)
# ys has shape (t_size, batch_size, state_size).
assert ys.shape == (t_size, batch_size, state_size)

In [None]:
def log_g(Θ, ln_prior, ln_like, γ=1.0):
    """
    Log of the terminal function G in the control objective.

    Returns ``ln_prior(Θ) + ln_like(Θ) - ln N(Θ; 0, γI)`` up to an additive
    constant (the Gaussian normalizer is dropped), i.e.
    ``ln_prior(Θ) + ln_like(Θ) + 0.5 * ||Θ||^2 / γ``.

    Args:
        Θ: tensor whose last dimension indexes parameter coordinates; any
            leading dimensions are treated as batch dimensions (a single 1-D
            sample is allowed).
        ln_prior: callable returning log-prior values with Θ's batch shape.
        ln_like: callable returning log-likelihood values with Θ's batch shape.
        γ: variance of the Gaussian reference N(0, γI).

    Returns:
        Tensor with Θ's batch shape (Θ.shape[:-1]).
    """
    # Unnormalized Gaussian log-density; reduce over the coordinate (last) dim.
    gaussian_log_density = -0.5 * (Θ ** 2).sum(dim=-1) / γ
    return ln_prior(Θ) + ln_like(Θ) - gaussian_log_density


def relative_entropy_control_cost(sde, Θ_0, ln_prior, ln_like, Δt=0.05, γ=1.0):
    """
    Monte Carlo estimate of the relative-entropy control objective on t in [0, 1].

    Simulates ``sde`` from ``Θ_0`` on a uniform time grid with spacing ``Δt``.
    Returns ``girsanov_factor - mean(log_g(Θ_1))``, where ``girsanov_factor``
    is ``0.5 * |μ(Θ_t)|^2`` averaged over grid times and the batch. Because the
    horizon has length 1, this average approximates ``0.5 * E[∫_0^1 |μ|^2 dt]``.

    Args:
        sde: torchsde-compatible module exposing a drift ``f(t, y)``.
        Θ_0: initial states, assumed (batch, state_size) -- as used with sdeint.
        ln_prior: log-prior callable forwarded to ``log_g``.
        ln_like: log-likelihood callable forwarded to ``log_g``.
        Δt: spacing of the time grid on [0, 1]; must be positive.
        γ: Gaussian reference variance forwarded to ``log_g``.

    Returns:
        Scalar tensor.

    Raises:
        ValueError: if ``Δt`` is not positive.
    """
    if Δt <= 0:
        raise ValueError(f"Δt must be positive, got {Δt}")
    # round(), not int(): int() would truncate e.g. 9.999999... down to 9.
    n_steps = max(1, round(1.0 / Δt))
    # n_steps intervals need n_steps + 1 points to get spacing 1 / n_steps (≈ Δt).
    ts = torch.linspace(0, 1, n_steps + 1)

    # sdeint output shape: (len(ts), batch, state_size).
    Θs = torchsde.sdeint(sde, Θ_0, ts)
    # Evaluate the drift on every (time, batch) state at once. This relies on
    # sde.f ignoring t and acting on the last dim -- TODO confirm for other drifts.
    μs = sde.f(ts, Θs)
    # Terminal state Θ_1: time is dim 0, so take the last time slice.
    ΘT = Θs[-1]
    lng = log_g(ΘT, ln_prior, ln_like, γ).mean()
    # NOTE(review): with diffusion sqrt(γ)·dW the Girsanov term would be
    # 0.5 / γ * |μ|^2; identical at the default γ = 1 -- confirm.
    girsanov_factor = 0.5 * (μs ** 2).sum(dim=-1).mean()

    return girsanov_factor - lng