In [1]:
import os
import sys
import torch
from torch import matmul as m
from sdes import TwoDimensionalSynDrift
from torch.autograd.functional import jacobian

from torch.distributions import MultivariateNormal

sys.path.insert(0,".")

from gps import MultivariateNormal,white_noise_kernel

# SDEs

## Diffusion Functions

In [2]:
class ConstantDiffusion:
    """Space-independent diffusion: every coordinate of every point gets the same value ``sigma``."""

    def __init__(self, sigma=0.5):
        # Scalar that is broadcast to the shape of whatever tensor is passed to __call__.
        self.sigma = sigma

    def __call__(self, x):
        """Return a tensor with the shape of ``x`` in which every entry equals ``sigma``."""
        unit = torch.ones_like(x)
        return unit.mul(self.sigma)

## Drift Functions

In [3]:
# Synthetic two-dimensional drift used as the ground-truth SDE in this notebook.
class TwoDimensionalSynDrift:
    """Drift f(x, y) = (x, y) * (1 - x^2 - y^2).

    The drift is radial: points inside the unit circle are pushed outwards and
    points outside it inwards. Every point of the circle x^2 + y^2 = 1 is a
    fixed point, and the origin is an unstable fixed point (its Jacobian is I).
    """

    def __init__(self):
        # No parameters: the drift is fully specified by the formula above.
        pass

    def __call__(self, X):
        """Evaluate the drift row by row.

        X: tensor [n, >=2] whose first two columns are (x, y).
        returns: tensor [n, 2] with columns (dx/dt, dy/dt).
        """
        x = X[:,0]
        y = X[:,1]
        dxdt = x*(1.-x**2 - y**2)
        dydt = y*(1.-x**2 - y**2)
        return torch.stack([dxdt, dydt], dim=1)

    def derivative(self,X):
        """Analytic Jacobian of the drift at every row of X.

        X: tensor [n, D] whose first two columns are (x, y).
        returns: tensor [n, D, D] with J[k, i, j] = d f_i / d x_j evaluated at X[k]
            (only the top-left 2x2 block is filled). The result has X's dtype and device.
        """
        x = X[:,0]
        y = X[:,1]
        A = (1.-x**2 - y**2)
        fxx = A -2.*x**2
        fxy = -2.*x*y
        fyx = -2.*x*y
        fyy = A - 2.*y**2

        # new_zeros follows X's dtype/device; plain torch.zeros would use the global
        # default dtype (float32), silently downcasting float64 inputs, and the CPU.
        J = X.new_zeros((X.size(0),X.size(1),X.size(1)))
        J[:,0,0] = fxx
        J[:,0,1] = fxy
        J[:,1,0] = fyx
        J[:,1,1] = fyy

        return J

# Numerics

In [4]:
def jacobian_of_drift(drift,inducing_points):
    """
    Per-point Jacobian of a row-wise drift, evaluated at every input point.

    parameters
    ----------
    drift: callable mapping a tensor [n, D] to a tensor [n, D]. It is assumed to act
        row by row (row k of the output depends only on row k of the input); this is
        not checked.
    inducing_points: [n, D]

    returns
    -------
    J: shape [n, D, D] with J[k, i, j] = d drift_i / d x_j evaluated at inducing_points[k]
       (the usual row = output, column = input layout)
    """
    # Because rows are independent, d(sum_m f(x_m)) / d x_k = d f(x_k) / d x_k, so one
    # Jacobian of the batch-summed drift yields every per-point Jacobian at once,
    # without materialising the [n, n, D] cross-point Jacobian.
    # The result has shape [D_out, n, D_in]; move the point index to the front.
    summed_jacobian = jacobian(lambda x: drift(x).sum(dim=0), inducing_points)
    return summed_jacobian.permute(1, 0, 2).contiguous()

# Derivatives

In [5]:
# Seed once so that a top-to-bottom run (f_inducing here, mc_times and the MC samples
# drawn later) is reproducible.
SEED = 0
torch.manual_seed(SEED)

drift = TwoDimensionalSynDrift()
diffusion = ConstantDiffusion(sigma=1.)

# Observation grid: 10 x 10 points on [-2, 2]^2
x_range = torch.linspace(-2.0, 2.0, 10)  # 10 points in x-axis from -2 to 2
y_range = torch.linspace(-2.0, 2.0, 10)  # 10 points in y-axis from -2 to 2

# Inducing-point grid: 5 x 5 points on [-2, 2]^2
x_range_i = torch.linspace(-2.0, 2.0, 5)  # 5 points in x-axis from -2 to 2
y_range_i = torch.linspace(-2.0, 2.0, 5)  # 5 points in y-axis from -2 to 2

# Create the 2D grids ('ij' indexing: X[a, b] = x_range[a], Y[a, b] = y_range[b])
X, Y = torch.meshgrid(x_range, y_range, indexing='ij')
X_i, Y_i = torch.meshgrid(x_range_i, y_range_i, indexing='ij')

observation_points = torch.stack([X, Y], dim=-1).view(-1, 2)  # (100, 2)
inducing_points = torch.stack([X_i, Y_i], dim=-1).view(-1, 2)  # (25, 2)

n_inducing = inducing_points.size(0)
# NOTE(review): `MultivariateNormal` is bound by the later `from gps import MultivariateNormal`
# in the first cell, which shadows torch.distributions.MultivariateNormal -- confirm which
# class is intended here.
f_inducing = MultivariateNormal(torch.zeros((2,)),torch.eye(2)).sample((n_inducing,))


# OU Bridge

\begin{equation}
dX_t = [f(z_k) - \Gamma_k(X_t-z_k)]dt + D^{1/2}dW
\end{equation}


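The marginal covariances of the linearised process follow from the matrix-fraction decomposition. Here $s$ is a time lag; the code uses $t - t_k$ and $t_{k+1} - t$.

\begin{equation}
\begin{bmatrix} A_s \\ B_s  \end{bmatrix} = 
\exp\left(\begin{bmatrix} \Gamma_k&D_k\\0&-\Gamma^T_k \end{bmatrix} s\right)\begin{bmatrix} 0 \\ I  \end{bmatrix},
\qquad S_s = A_s B_s^{-1}
\end{equation}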


In [6]:
def E_matrix(points,drift,diffusion):
    """
    Build the block matrices of the matrix-fraction decomposition used for the
    OU-bridge marginal covariances.

    parameters
    ----------
    points: [n_points, dimensions] linearisation points z_k
    drift: callable R^D -> R^D, differentiated with jacobian_of_drift to obtain Gamma_k
    diffusion: callable returning [n_points, dimensions], used as the diagonal of D_k

    returns
    -------
    OI: [2*dimensions, dimensions] the stacked matrix [0; I]
    E_k: [n_points, 2*dimensions, 2*dimensions] the block matrix
        [[Gamma_k, D_k], [0, -Gamma_k^T]]; exp(E_k * s) @ OI = [A_s; B_s]
    """
    n_points = points.size(0)
    dimensions = points.size(1)

    diagonal_k = diffusion(points)
    Gamma_k = jacobian_of_drift(drift,points)

    # Allocate with Gamma_k's dtype/device rather than the global defaults.
    D_k = Gamma_k.new_zeros((n_points,dimensions,dimensions))
    E_k = Gamma_k.new_zeros((n_points,2*dimensions,2*dimensions))
    OI = Gamma_k.new_zeros((2*dimensions,dimensions))
    OI[dimensions:,:] = torch.eye(dimensions, dtype=Gamma_k.dtype, device=Gamma_k.device)

    D_k[:,range(dimensions),range(dimensions)] = diagonal_k
    E_k[:,:dimensions,:dimensions] = Gamma_k
    E_k[:,:dimensions,dimensions:] = D_k
    # The lower-right block must be -Gamma_k^T (not +Gamma_k^T): only then is A_s B_s^{-1}
    # the symmetric covariance \int_0^s e^{Gamma u} D e^{Gamma^T u} du.
    E_k[:,dimensions:,dimensions:] = -Gamma_k.transpose(2,1)
    return OI,E_k

def where_mc_time(mc_times,observation_times):
    """
    Locate the observation interval containing each Monte Carlo time.

    mc_times: shape[nmc]
    observation_times: shape[n_observations], assumed sorted in increasing order
    returns: shape[nmc] int64 index k with observation_times[k] <= t <= observation_times[k+1].
        Times before the first / after the last observation are assigned to the
        first / last interval.
    raises: ValueError if fewer than two observation times are given (no interval exists).
    """
    n_observations = observation_times.size(0)
    if n_observations < 2:
        raise ValueError("where_mc_time needs at least two observation times")
    sorted_times = observation_times.contiguous()
    queries = mc_times.to(sorted_times.dtype).contiguous()
    # right=True places a time equal to observation_times[k] in interval k, so grid
    # points (including t = observation_times[0]) get their own interval.
    where_mc_time_index = torch.searchsorted(sorted_times, queries, right=True) - 1
    # The last observation time itself and out-of-range times are clamped to a valid interval.
    return where_mc_time_index.clamp(min=0, max=n_observations - 2)

In [7]:
t_nmc = 10  # number of Monte Carlo time points
x_nmc = 9   # number of Monte Carlo state samples drawn per time point

mc_times = torch.rand((t_nmc,))  # times in [0, 1) at which the bridge marginals are evaluated
points = observation_points  # linearisation points z_k: either observations or inducing points
observation_times = torch.linspace(0., 1., points.size(0))  # times at which the points were observed

In [8]:
n_points = points.size(0)
dimensions = points.size(1)

diffusion_diagonal = diffusion(points)     # [n_points, dimensions], diagonal of D
drift_at_points = drift(points)            # f(z_k), [n_points, dimensions]
Gamma = jacobian_of_drift(drift, points)   # [n_points, dimensions, dimensions]

D = torch.zeros((n_points, dimensions, dimensions))
E = torch.zeros((n_points, 2 * dimensions, 2 * dimensions))
OI = torch.zeros((2 * dimensions, dimensions))
OI[dimensions:, :] = torch.eye(dimensions)  # OI = [0; I]

D[:, range(dimensions), range(dimensions)] = diffusion_diagonal
# E = [[Gamma, D], [0, -Gamma^T]]. The lower-right block must be the negated transpose of
# the upper-left block for A_s B_s^{-1} (computed from exp(E s) OI) to be a symmetric covariance.
E[:, :dimensions, :dimensions] = Gamma
E[:, :dimensions, dimensions:] = D
E[:, dimensions:, dimensions:] = -Gamma.transpose(2, 1)

In [9]:
# Interval index k per Monte Carlo time: observation_times[k] <= t <= observation_times[k+1]
where_mc_time_index = where_mc_time(mc_times, observation_times)

z_k = points[where_mc_time_index]          # left end point, [t_nmc, D]
z_k1 = points[where_mc_time_index + 1]     # right end point, [t_nmc, D]

f_k = drift_at_points[where_mc_time_index, :]
E_k_raw = E[where_mc_time_index, :, :]     # un-exponentiated block matrix per MC time
Gamma_k = Gamma[where_mc_time_index, :, :]
D_k = D[where_mc_time_index, :, :]

# Lags to the next observation (t_{k+1} - t) and since the previous one (t - t_k), [t_nmc, 1, 1]
time_difference_k1 = observation_times[where_mc_time_index + 1, None, None] - mc_times[:, None, None]
time_difference_k = mc_times[:, None, None] - observation_times[where_mc_time_index, None, None]

# [A_s; B_s] = exp(E s) [0; I] for both lags. Both products must use the matrix exponential.
E_k1 = m(torch.matrix_exp(E_k_raw * time_difference_k1), OI)
E_k = m(torch.matrix_exp(E_k_raw * time_difference_k), OI)

A_s1 = E_k1[:, :dimensions, :]
B_s1 = E_k1[:, dimensions:, :]

A_s = E_k[:, :dimensions, :]
B_s = E_k[:, dimensions:, :]

# Covariances S = A B^{-1} for both lags, and their inverses
S_s1 = m(A_s1, torch.inverse(B_s1))
S_s1_inv = torch.inverse(S_s1)

S_s = m(A_s, torch.inverse(B_s))
S_s_inv = torch.inverse(S_s)

# Fixed point of the linearised drift f_k - Gamma_k (x - z_k): alpha_k = z_k + Gamma_k^{-1} f_k.
# solve() avoids forming Gamma_k^{-1}; squeeze(-1) keeps the batch axis even when t_nmc == 1.
alpha_k = z_k + torch.linalg.solve(Gamma_k, f_k.unsqueeze(-1)).squeeze(-1)

ME1 = torch.matrix_exp(-Gamma_k.transpose(2, 1) * time_difference_k1)  # Phi^T for lag t_{k+1} - t
ME2 = torch.matrix_exp(-Gamma_k * time_difference_k1)                  # Phi for lag t_{k+1} - t

# Bridge covariance: precision from the future end point plus precision of the forward marginal.
C_t = torch.inverse(m(ME1, m(S_s1_inv, ME2)) + S_s_inv)
C_t = 0.5 * (C_t + C_t.transpose(-1, -2))  # remove round-off asymmetry before building the Gaussian

ME3 = m(C_t, m(ME1, S_s1_inv))
a = z_k1[:, :, None] - alpha_k[:, :, None] + m(ME2, alpha_k[:, :, None])

ME4 = m(C_t, S_s_inv)
b = alpha_k[:, :, None] + m(torch.matrix_exp(-Gamma_k * time_difference_k), (z_k[:, :, None] - alpha_k[:, :, None]))

m_t = m(ME3, a) + m(ME4, b)  # bridge mean, [t_nmc, D, 1]

q_t = MultivariateNormal(m_t.squeeze(-1), C_t)
g_t = f_k

monte_carlo_points = q_t.sample((x_nmc,))  # [x_nmc, t_nmc, D]
# NOTE(review): A_x is a *log*-density, yet later cells use it as the weight A(x) in the
# EM integrals. Confirm whether exp(log_prob), or no weight at all (the samples are already
# drawn from q_t), is intended.
A_x = q_t.log_prob(monte_carlo_points).reshape(-1)

In [10]:
x_nmc, time_nmc, _ = monte_carlo_points.shape
# Flatten the samples to rows: row r = sample_index * time_nmc + time_index, i.e. the first
# (sample) index varies slowest. This is the same row-major order as A_x.reshape(-1).
monte_carlo_points_ = monte_carlo_points.reshape(x_nmc * time_nmc, -1)
full_mc = x_nmc * time_nmc

In [11]:
# Tile each per-time quantity x_nmc times along the leading axis so that row
# r = sample_index * t_nmc + time_index lines up with monte_carlo_points_[r].
# 3-D tensors [t_nmc, a, b] -> [x_nmc * t_nmc, a, b]
Gamma_k_, S_s1_inv_, time_difference_k1_, D_k_ = (
    per_time.repeat(x_nmc, 1, 1)
    for per_time in (Gamma_k, S_s1_inv, time_difference_k1, D_k)
)
# 2-D tensors [t_nmc, D] -> [x_nmc * t_nmc, D]
z_k_, z_k1_, alpha_k_, f_k_ = (
    per_time.repeat(x_nmc, 1)
    for per_time in (z_k, z_k1, alpha_k, f_k)
)

In [12]:
# Bridge drift at every flattened MC sample x (rows aligned with monte_carlo_points_):
#   g(x) = f_k - Gamma_k (x - z_k) + D_k Phi^T S_{s1}^{-1} [z_{k+1} - alpha_k - Phi (x - alpha_k)],
# with Phi = exp(-Gamma_k (t_{k+1} - t)).
# Descriptive names avoid clobbering, or being clobbered by, the unrelated `A` in the GP cell.
linear_drift = f_k_[:, :, None] - m(Gamma_k_, (monte_carlo_points_ - z_k_)[:, :, None])
phi_T = torch.matrix_exp(-Gamma_k_.transpose(2, 1) * time_difference_k1_)
phi = torch.matrix_exp(-Gamma_k_ * time_difference_k1_)
bridge_gain = m(D_k_, m(phi_T, S_s1_inv_))
end_point_residual = z_k1_[:, :, None] - alpha_k_[:, :, None] - m(phi, (monte_carlo_points_ - alpha_k_)[:, :, None])

g_x = linear_drift + m(bridge_gain, end_point_residual)  # [full_mc, D, 1]

# EM

\begin{equation}
\mathcal{L}_s(f,q) = \frac{1}{2}\int||E_0[f(x)|f_s]||^2 A(x) dx - \frac{1}{2}\int(E_0[f(x)|f_s],b(x))dx
\end{equation}

\begin{equation}
\Lambda_s = K^{-1}_s\left\{\int k_s(x)D^{-1}(x)A(x) k^T_s(x)dx\right\} K^{-1}_s
\end{equation}

\begin{equation}
y_s = K^{-1}_s\int D^{-1}(x)k_s(x)b(x)dx
\end{equation}

## GPS

In [13]:
from gpytorch.kernels import RBFKernel, ScaleKernel
import numpy as np

#==============================================
windows_size = dimensions  # kernel input dimension (one ARD lengthscale per dimension)
kernel_sigma = 1.          # intended output scale
kernel_l = 1.              # intended lengthscale

# ========================================================================
# DEFINE AND INITIALIZE KERNEL
kernel = ScaleKernel(RBFKernel(ard_num_dims=windows_size, requires_grad=True), requires_grad=True) + white_noise_kernel()
# Set the *constrained* hyper-parameters. gpytorch stores raw values that pass through a
# positivity transform (softplus by default), so raw_outputscale = 1 would give an actual
# outputscale of softplus(1) ~= 1.31 rather than 1 (likewise for the lengthscale).
hypers = {
    "outputscale": torch.tensor(kernel_sigma),
    "base_kernel.lengthscale": torch.tensor(np.repeat(kernel_l, windows_size)),
}
# NOTE(review): kernels[0] keeps only the ScaleKernel(RBF) term; the white_noise_kernel()
# summand is discarded here -- confirm that is intended.
kernel = kernel.kernels[0].initialize(**hypers)

In [24]:
dimension_index = 0  # drift coordinate regressed by the GP cells below
D_mc_diagonal = diffusion(monte_carlo_points_)  # per-sample diffusion diagonal, one row per MC sample
D_ind = diffusion(inducing_points)[:, dimension_index].diag()

In [25]:
# Gram matrix of the inducing points and its inverse.
K_ind_ind = kernel.forward(inducing_points, inducing_points)
K_ind_ind_inv = torch.linalg.inv(K_ind_ind)
# Cross-covariance between every flattened MC sample and every inducing point.
K_t_ind = kernel.forward(monte_carlo_points_, inducing_points)

In [47]:
# Lambda_s = K_s^{-1} [ \int k_s(x) D^{-1}(x) A(x) k_s(x)^T dx ] K_s^{-1}  (EM section formula),
# with the integral replaced by a Monte Carlo average over the full_mc flattened samples.
# The integrand is the outer product k_s(x) k_s(x)^T weighted by A(x) / D(x), so the
# integral is a full [n_inducing, n_inducing] matrix, not a scalar times K^{-1} K^{-1}.
# NOTE(review): D_mc_diagonal comes from diffusion(), which returns sigma; confirm whether
# D in the formula means sigma or sigma**2 (identical for sigma = 1).
lambda_weights = A_x / D_mc_diagonal[:, dimension_index]                      # [full_mc]
weighted_outer = m((K_t_ind * lambda_weights[:, None]).T, K_t_ind) / full_mc  # [n_inducing, n_inducing]
Lambda_s = m(K_ind_ind_inv, m(weighted_outer, K_ind_ind_inv))

torch.Size([25, 25])

In [34]:
# y_s = K_s^{-1} \int D^{-1}(x) k_s(x) b(x) dx, as a Monte Carlo average over the full_mc samples,
# with b(x) = A(x) g(x) for the selected drift coordinate.
b_x = A_x * g_x[:, dimension_index, 0]
b_integral = (K_t_ind * (b_x / D_mc_diagonal[:, dimension_index])[:, None]).mean(dim=0)
y_s = m(K_ind_ind_inv, b_integral[:, None])

# Posterior mean of drift coordinate `dimension_index` at random query points in [0, 1)^D.
# `x_query` / `posterior_operator` avoid clobbering the grid `X` and reusing the name `A`.
nx = 42
x_query = torch.rand((nx, dimensions))
kx = kernel.forward(x_query, inducing_points)

# NOTE(review): with Lambda_s already sandwiched by K_s^{-1}, the usual sparse-GP predictive
# mean is k(x)^T (I + Lambda_s K_s)^{-1} y_s; this uses K_s^{-1} instead -- confirm.
posterior_operator = torch.inverse(torch.eye(n_inducing) + m(Lambda_s, K_ind_ind_inv))
f_x = m(m(kx, posterior_operator), y_s)