<a href="https://colab.research.google.com/github/jcandane/StochasticPhysics/blob/main/gptorch_play.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import math

###
try:
    import gpytorch
except:
    !pip install gpytorch
    import gpytorch

###
try:
    import pyro
    from pyro.infer.mcmc import NUTS, MCMC, HMC
except:
    !pip install pyro-ppl
    import pyro
    from pyro.infer.mcmc import NUTS, MCMC, HMC

from matplotlib import pyplot as plt
import plotly.graph_objects as go

Collecting gpytorch
  Downloading gpytorch-1.11-py3-none-any.whl (266 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m266.1/266.1 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
Collecting linear-operator>=0.5.0 (from gpytorch)
  Downloading linear_operator-0.5.2-py3-none-any.whl (175 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m175.6/175.6 kB[0m [31m10.5 MB/s[0m eta [36m0:00:00[0m
Collecting jaxtyping>=0.2.9 (from linear-operator>=0.5.0->gpytorch)
  Downloading jaxtyping-0.2.28-py3-none-any.whl (40 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.7/40.7 kB[0m [31m894.0 kB/s[0m eta [36m0:00:00[0m
[?25hCollecting typeguard~=2.13.3 (from linear-operator>=0.5.0->gpytorch)
  Downloading typeguard-2.13.3-py3-none-any.whl (17 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=1.11->linear-operator>=0.5.0->gpytorch)
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
[2K    

In [None]:
# Training data: n_train regularly spaced points in [0, 1] inclusive
n_train = 4
train_x = torch.linspace(0, 1, n_train)
# True function is sin(2*pi*x) plus Gaussian noise with standard deviation 0.2
train_y = torch.sin(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2

# We will use the simplest form of GP model, exact inference
class ExactGPModel(gpytorch.models.ExactGP):
    """Exact-inference GP with a constant mean and a scaled RBF covariance.

    Args:
        train_x: training inputs, forwarded to ``gpytorch.models.ExactGP``.
        train_y: training targets, forwarded to ``gpytorch.models.ExactGP``.
        likelihood: likelihood object, forwarded to ``gpytorch.models.ExactGP``.
    """

    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        rbf = gpytorch.kernels.RBFKernel()
        # ScaleKernel multiplies the RBF kernel by a learnable output scale.
        self.covar_module = gpytorch.kernels.ScaleKernel(rbf)

    def forward(self, x):
        """Return the ``MultivariateNormal`` defined by the mean and covariance modules at ``x``."""
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x),
            self.covar_module(x),
        )



In [None]:
k = gpytorch.kernels.RBFKernel()  # standalone RBF kernel module

In [None]:
covar_module = gpytorch.kernels.LinearKernel()
x1, x2 = torch.randn(8, 3), torch.randn(7, 3)  # 8 and 7 random points in 3-D

# Calling a kernel gives a lazily evaluated linear operator, not a plain tensor
# (the original note called it a RootLinearOperator).
lazy_covar_matrix = covar_module(x1)
# Materialise it as an ordinary dense tensor.
covariance_matrix = lazy_covar_matrix.to_dense()

covariance_matrix

In [None]:
torch.rand(size=(2, 3))  # a 2x3 tensor of uniform samples in [0, 1), shown as the cell output

In [None]:
# Rebind `k` to an RBF kernel wrapped with a learnable output scale.
k = gpytorch.kernels.ScaleKernel(base_kernel=gpytorch.kernels.RBFKernel())

type(covar_module)  # cell output: the class of the LinearKernel created earlier

In [None]:
# Display the lazy kernel operator returned by covar_module(x1) (it is not yet a dense tensor).
lazy_covar_matrix

In [None]:
# Cross-covariance shapes of x1 against x2 for a LinearKernel and for `k`.
# A notebook only echoes a cell's *last* expression, so the first shape used to be
# computed and silently discarded; returning a tuple shows both.
# Expected: one row per point of x1 and one column per point of x2 in each case.
(
    gpytorch.kernels.LinearKernel()(x1, x2).to_dense().shape,
    k(x1, x2).to_dense().shape,
)

In [None]:
# Evaluate the ScaleKernel(RBF) `k` on x1; this shows the lazy operator (uncomment .to_dense() for the matrix).
k(x1) #.to_dense()

# RCF (Random Continuous Function) in PyTorch

In [None]:
# Make float64 the default floating dtype
# (replaces the deprecated torch.set_default_tensor_type(torch.DoubleTensor)).
torch.set_default_dtype(torch.float64)
torch.manual_seed(80)

Domain = torch.tensor([[0, 10.], [-3, 4.]])  ### one [low, high] row per input dimension; 3-D example: [[0,10.],[-3,4.],[-8,-2]]
N      = 3   ### number of defining points
MO     = 1   ### int (dimension of OUT)

kernel = gpytorch.kernels.RBFKernel()
μ_i    = torch.zeros(N, dtype=torch.float64)  ### mean added to the defining-point values below
seed   = 137  ### NOTE(review): not used in this cell; the RNG above is seeded with 80


### N defining points drawn uniformly in the box: low + U[0,1) * (high - low)
R_ix  = torch.rand(N, Domain.shape[0], dtype=torch.float64)
R_ix *= torch.diff(Domain, dim=1).reshape(-1)
R_ix += Domain[:, 0]

### Cholesky factor of the kernel matrix at the defining points.
### This is a gpytorch linear operator; .to_dense() gives the concrete tensor.
L_ij  = torch.linalg.cholesky(kernel(R_ix))

### OUT-space values at the defining points.
# NOTE(review): a textbook GP prior draw is mu + L @ z.  Here z is first scaled row-wise by
# diag(L), and mu is added *before* multiplying by L.  That gives covariance L diag(L)^2 L^T and
# mean L @ mu (harmless while mu_i is zero).  Confirm that the scaling is intended.
D_iX  = torch.normal(0, 1, size=(N, MO))
D_iX *= torch.diag(L_ij.to_dense()).reshape(-1, 1)
D_iX += μ_i.reshape(-1, 1)
D_iX  = torch.matmul(L_ij, D_iX)

### S = (L L^T)^{-1} D = K^{-1} D, so K(x, R_ix) @ S reproduces D_iX at x = R_ix
S_jX  = torch.cholesky_solve(D_iX, L_ij.to_dense())


#####################
### N random query points in the same box, and the interpolated function at them
D_ax  = torch.rand(N, Domain.shape[0])
D_ax *= torch.diff(Domain, dim=1).reshape(-1)
D_ax += Domain[:, 0]

D_ay  = torch.matmul(kernel(D_ax, R_ix), S_jX)

In [None]:
# Show the RBF kernel's registered parameters/buffers as a state dict.
kernel.state_dict()

In [None]:
from torch.autograd import grad

# The original call grad(outputs=D_ay, inputs=R_ix) always raises, for two reasons:
#   1. R_ix was created without requires_grad, so it is not part of any autograd graph.
#   2. D_ay is not a scalar, so grad() would also need `grad_outputs`.
# Here the interpolant is re-evaluated at a gradient-tracking copy of the query points,
# and the summed output is differentiated.  Each output row depends only on its own query
# point, so row a of d_loss_dx is the gradient at D_ax[a] (summed over the MO outputs).
# NOTE(review): if the derivative w.r.t. the defining points R_ix was intended, rebuild the
# whole pipeline (cholesky, solve) from an R_ix that requires grad.
D_ax_req = D_ax.detach().clone().requires_grad_(True)
D_ay_req = torch.matmul(kernel(D_ax_req, R_ix), S_jX)
(d_loss_dx,) = grad(outputs=D_ay_req.sum(), inputs=D_ax_req)

In [None]:
#### generate mesh to plot: a 0.33-spaced grid over every [low, high] row of Domain.
# indexing="ij" matches torch.meshgrid's historical default; spelling it out silences
# the deprecation warning.  reshape(-1, dim) flattens the grid to one point per row.
R_ax = torch.stack(
    torch.meshgrid(*[torch.arange(lo, hi, 0.33) for lo, hi in Domain], indexing="ij"),
    dim=-1,
)
R_ax = R_ax.reshape(-1, R_ax.shape[-1])

R_ay = torch.matmul(kernel(R_ax, R_ix), S_jX).detach().numpy()
R_ax = R_ax.detach().numpy()

#### the plot (uses input columns 0 and 1, i.e. assumes a 2-D Domain):
#### interpolated grid, defining points (R_ix, D_iX), random query points (D_ax, D_ay)
fig = go.Figure(data=[
    go.Scatter3d(x=R_ax[:, 0], y=R_ax[:, 1], z=R_ay[:, 0], mode='markers'),
    go.Scatter3d(x=R_ix.detach().numpy()[:, 0], y=R_ix.detach().numpy()[:, 1],
                 z=D_iX.detach().numpy()[:, 0], mode='markers'),
    go.Scatter3d(x=D_ax.detach().numpy()[:, 0], y=D_ax.detach().numpy()[:, 1],
                 z=D_ay.detach().numpy()[:, 0], mode='markers'),
])
fig.show()

In [None]:
import torch
import gpytorch

# Make float64 the default floating dtype
# (replaces the deprecated torch.set_default_tensor_type(torch.DoubleTensor)).
torch.set_default_dtype(torch.float64)

#@torch.jit.script
class RCF:
    """ built: 3/5/2024
    samples a Random-Continuous-Function (RCF) with respect to a GP kernel
    RCF : IN -> OUT
    N defining points are drawn uniformly inside the input box `Domain`, random
    OUT-values are attached to them, and every other point is interpolated with
    the kernel.  With zero IN/OUT noise, the function therefore reproduces its
    defining values at the defining points (up to numerical error).
    """

    def __init__(self, Domain, N:int, MO:int=1, seed:int=777,
                 IN_noise=None, OUT_noise=None,
                 kernel=None):
        """
        GIVEN >
            Domain   : 2d-torch.Tensor, one [low, high] row per input dimension
                 N   : int (number of defining points)
                MO   : int (Multiple-Output dimension)
        ** IN_noise  : 1d-torch.Tensor, per-input-dimension std-dev of the Gaussian
                       noise added to query points in evalulate (default zeros)
        ** OUT_noise : 1d-torch.Tensor, per-output std-dev of the Gaussian noise
                       added to outputs in evalulate (default zeros)
        **    seed   : int, passed to torch.manual_seed (this reseeds the GLOBAL RNG)
        **  kernel   : gpytorch kernel module; when None, a fresh float64
                       RBFKernel is created for this instance
        GET   >
            None
        RAISES >
            TypeError if `kernel` lacks nn.Module's register_load_state_dict_post_hook
        """

        self.IN = Domain.double() ### 2d-torch.Tensor
        self.N  = N      ### number of defining points
        self.MO = MO     ### int (dimension of OUT)

        # A default of RBFKernel() in the signature would be evaluated once and shared
        # (including its hyperparameters) by every RCF instance; build one per instance.
        if kernel is None:
            kernel = gpytorch.kernels.RBFKernel().double()
        # Same duck-typed check as before, but it now raises a real exception.
        # (`raise "..."` raised a TypeError about str not being an exception.)
        if not hasattr(kernel, "register_load_state_dict_post_hook"):
            raise TypeError("kernel must be of class gpytorch.kernels")
        self.kernel = kernel

        ### seed the (global) random number generator for reproducible sampling
        self.seed = seed
        torch.manual_seed(self.seed)

        ### anisotropic i.i.d. white noise, zero unless given
        if IN_noise is None:
            self.IN_noise = torch.zeros(self.IN.shape[0], dtype=torch.float64)
        else:
            self.IN_noise = IN_noise
        if OUT_noise is None:
            self.OUT_noise = torch.zeros(self.MO, dtype=torch.float64)
        else:
            self.OUT_noise = OUT_noise

        ### N defining points drawn uniformly in the box: low + U[0,1) * (high - low)
        self.R_ix  = torch.rand(N, self.IN.shape[0], dtype=torch.float64)
        self.R_ix *= torch.diff(self.IN, dim=1).reshape(-1)
        self.R_ix += self.IN[:, 0]

        ### dense Cholesky factor L of the kernel matrix K at the defining points
        L_ij = torch.linalg.cholesky(self.kernel(self.R_ix)).to_dense()

        ### OUT-space values at the defining points
        # NOTE(review): a textbook GP prior draw is L @ z; scaling z by diag(L) first
        # gives covariance L diag(L)^2 L^T instead of K.  Confirm that this is intended.
        D_iX  = torch.normal(0, 1, size=(self.N, self.MO))
        D_iX *= torch.diag(L_ij).reshape(-1, 1)
        D_iX  = torch.matmul(L_ij, D_iX)

        ### S = (L L^T)^{-1} D = K^{-1} D, the weights used to interpolate arbitrary points
        self.S_jX = torch.cholesky_solve(D_iX, L_ij)

    def evalulate(self, D_ax):
        """ evaluate the RCF at arbitrary points of IN
        GIVEN > D_ax : 2d-torch.Tensor, one query point per row (not modified)
                uses self.{R_ix, S_jX, kernel, IN_noise, OUT_noise}
        GET   > D_aX : 2d-torch.Tensor, one row of MO outputs per query point
        """
        # Out-of-place on purpose.  The former `D_ax +=` mutated the caller's tensor,
        # e.g. it perturbed self.R_ix when called as f.evalulate(f.R_ix) with input noise.
        D_ax = D_ax + self.IN_noise * torch.normal(0, 1, size=D_ax.shape)
        D_aX = torch.matmul(self.kernel(D_ax, self.R_ix), self.S_jX)
        D_aX = D_aX + self.OUT_noise * torch.normal(0, 1, size=D_aX.shape)
        return D_aX

    # Correctly spelled alias; `evalulate` is kept for existing callers.
    evaluate = evalulate

f = RCF(Domain, 23, seed=1287)

#### generate mesh to plot: a 0.33-spaced grid over every [low, high] row of Domain.
# indexing="ij" matches torch.meshgrid's historical default and silences its deprecation
# warning; reshape(-1, dim) flattens the grid to one point per row.
R_ax = torch.stack(
    torch.meshgrid(*[torch.arange(lo, hi, 0.33) for lo, hi in Domain], indexing="ij"),
    dim=-1,
)
R_ax = R_ax.reshape(-1, R_ax.shape[-1])

R_ay = f.evalulate(R_ax).detach().numpy()
R_ax = R_ax.detach().numpy()

#### the plot (uses input columns 0 and 1, i.e. assumes a 2-D Domain):
#### the interpolated grid and the RCF evaluated at its own defining points.
# Pass a clone so evalulate can never perturb the model's defining points in place.
R_iy = f.evalulate(f.R_ix.clone()).detach().numpy()
fig = go.Figure(data=[
    go.Scatter3d(x=R_ax[:, 0], y=R_ax[:, 1], z=R_ay[:, 0], mode='markers'),
    go.Scatter3d(x=f.R_ix.detach().numpy()[:, 0], y=f.R_ix.detach().numpy()[:, 1],
                 z=R_iy[:, 0], mode='markers'),
])
fig.show()