In [61]:
import torch
import torch.linalg as linalg
import torch.nn.functional as F
import numpy as np
class PILCO:
    def __init__(self, X, Y, lengthscales, variance, noise, batch_size):
        self.X = torch.tensor(X)
        self.Y = torch.tensor(Y)
        self.lengthscales = torch.tensor(lengthscales).float()
        self.variance = torch.tensor(variance)
        self.noise = torch.tensor(noise)
        self.num_outputs = Y.shape[1]
        self.num_dims = X.shape[1]
        self.batch_size=batch_size
        
    def predict_on_noisy_inputs(self, m, s):
        iK, beta = self.calculate_factorizations()
        return self.predict_given_factorizations(m, s, iK, beta)

    def calculate_factorizations(self):
        K = self.K(self.X)
        batched_eye = torch.eye(self.num_dims, dtype=torch.float64).repeat(self.num_outputs, 1, 1)
        batched_eye = batched_eye.unsqueeze(0).repeat(self.batch_size, 1, 1, 1)
        L = linalg.cholesky(K + self.noise[:, None, None] * batched_eye)
        iK = torch.cholesky_solve(batched_eye, L)
        Y_ = self.Y[:, :, None]
        beta = torch.cholesky_solve(Y_, L)[:, :, 0]
        return iK, beta

    def predict_given_factorizations(self, m, s, iK, beta):
        s = s.repeat(self.num_outputs, self.num_outputs, 1, 1).float()
        inp = self.centralized_input(m).repeat(self.num_outputs, 1, 1).float()
        print(inp.shape)
        iL = (torch.diag(1 / self.lengthscales)).float()
        iN = inp @ iL
        B = iL @ s[0, ...] @ iL + torch.eye(self.num_dims, dtype=torch.float)

        t = (linalg.solve(B, iN.transpose(-2, -1).transpose(-1, -3)).transpose(-1, -3)).transpose(-1, -2)
        lb = (torch.exp(-torch.sum(iN * t, -1) / 2) * beta).float()
        tiL = t.float() @ iL.float()
        c = (self.variance / torch.sqrt(torch.det(B))).float()

        M = torch.sum(lb, -1) * c[:, None]
        V = tiL @ lb[..., None] * c[..., None]

        R = s @ torch.diag(1 / self.lengthscales ** 2) + torch.eye(self.num_dims, dtype=torch.float)
        X = inp[:, :, :] / self.lengthscales[:, None, None] ** 2
#         print(X.shape)
        X2 = -inp[:, :, :] / self.lengthscales[:, None, None] ** 2
        Q = linalg.solve(R, s) / 2
        Xs = torch.sum(X @ Q * X, -1)
        X2s = torch.sum(X2 @ Q * X2, -1)
        maha = -2 * (X @ Q) @ X2.transpose(-2, -1) + Xs[..., :, None] + X2s[..., None, :]

        k = torch.log(self.variance)[:, None] - torch.sum(iN ** 2, -1) / 2
        L = torch.exp(k[:, None, :, None] + k[None, :, None, :] + maha)
        S = (beta[:, None, None, :] @ L @ beta[None, :, :, None]).squeeze(dim=-1).squeeze(dim=-1)

        diagL = torch.diagonal(L.transpose(-1, -2), dim1=-2, dim2=-1).transpose(-1, -2)
        S = S - torch.diagonal(iK @ diagL @ iK.transpose(-1, -2), dim1=-2, dim2=-1).transpose(-1, -2)
        S = S / torch.sqrt(torch.det(R))
        S = S + torch.diag(self.variance)
        S = S - M @ M.transpose(-1, -2)

        return M.transpose(-1, -2), S, V.transpose(-1, -2)

    def K(self, x1, x2=None):
        # Kernel function
        if x2 is None:
            x2 = x1
        dist_sq = torch.cdist(x1, x2, p=2, compute_mode="donot_use_mm_for_euclid_dist")
        K = self.variance * torch.exp(-0.5 * dist_sq / self.lengthscales ** 2)
        return K

    def centralized_input(self, x):
        # Centralize the input by subtracting the training set mean
        return x - torch.mean(self.X, dim=0)

In [62]:
# Example usage
X = np.array([[0.0, 0.0], [1.0, 1.0]])
Y = np.array([[1.0, 2.0], [3.0, 4.0]])
lengthscales = np.array([0.5, 0.5])
variance = np.array([1.0])
noise = np.array([0.1])

model = PILCO(X, Y, lengthscales, variance, noise, 1)
m = torch.tensor([0.5, 0.5])
s = torch.tensor([[0.1, 0.2], [0.2, 0.3]])

prediction_mean, prediction_cov, prediction_var = model.predict_on_noisy_inputs(m, s)
print("Prediction Mean:", prediction_mean)
print("Prediction Covariance:", prediction_cov)
print("Prediction Variance:", prediction_var)

torch.Size([2, 1, 2])


RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [2, 2] but got: [2, 1].

In [54]:
model.lengthscales[:, None, None].shape

torch.Size([2, 1, 1])