In [1]:
import pyro
from pyro import distributions as dist
from pyro.nn import PyroSample, PyroParam, PyroModule, pyro_method
import pyro.distributions.constraints as constraints
import torch
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, config_enumerate, infer_discrete
import pyro.contrib.gp as gp
from sklearn.datasets import fetch_olivetti_faces
from tqdm.autonotebook import tqdm
from sklearn.decomposition import NMF
from pyro.distributions.util import eye_like
from pyro.contrib.gp import Parameterized

In [2]:
pyro.clear_param_store()

In [3]:
x_train, y_train = fetch_olivetti_faces(return_X_y=True, shuffle=True, random_state=97)


In [4]:
Y = (x_train)*255
Y = Y.astype(int)
Y = torch.tensor(Y)

In [5]:
idx = torch.linspace(-5, 5, 64, dtype=torch.float)

X_coordinates = torch.cartesian_prod(idx, idx)

U_len = 20

idz = torch.linspace(-5, 5, U_len, dtype=torch.float)
Xu = torch.cartesian_prod(idz, idz)



In [75]:
def conditional(
    Xnew,
    X,
    kernel,
    f_loc,
    f_scale_tril=None,
    Lff=None,
    jitter=1e-6,
):

    N = X.size(0)
    M = Xnew.size(0)
    latent_shape = f_loc.shape[:-1]

    if Lff is None:
        Kff = kernel(X).contiguous()
        for i in range(latent_shape[0]):
            Kff[i].view(-1)[:: N + 1] += jitter

        Lff = torch.linalg.cholesky(Kff)
    Kfs = kernel(X, Xnew)

    # convert f_loc_shape from latent_shape x N to N x latent_shape
    f_loc = f_loc.permute(-1, *range(len(latent_shape)))
    # convert f_loc to 2D tensor for packing
    f_loc_2D = f_loc.reshape(N, -1)

    Kfs = Kfs.permute(-2, -1, *range(len(latent_shape)))
    Kfs_2D = Kfs.reshape(N, -1)

    if f_scale_tril is not None:
        # convert f_scale_tril_shape from latent_shape x N x N to N x N x latent_shape
        f_scale_tril = f_scale_tril.permute(-2, -1, *range(len(latent_shape)))
        # convert f_scale_tril to 2D tensor for packing
        f_scale_tril_2D = f_scale_tril.reshape(N, -1)


    

    pack = torch.cat((f_loc_2D, Kfs_2D), dim=1)
    if f_scale_tril is not None:
        pack = torch.cat((pack, f_scale_tril_2D), dim=1)

    Lffinv_pack = torch.linalg.solve_triangular(Lff, pack, upper=False)
    print(Lffinv_pack.shape)
    # unpack
    v_2D = Lffinv_pack[:, :f_loc_2D.size(1)]
    print(v_2D.shape)
    W = Lffinv_pack[:, f_loc_2D.size(1) : f_loc_2D.size(1) + M].t()
    if f_scale_tril is not None:
        S_2D = Lffinv_pack[:, -f_scale_tril_2D.size(1) :]

    loc_shape = latent_shape + (M,)
    loc = W.matmul(v_2D).t().reshape(loc_shape)


    '''see this later'''

    Kssdiag = kernel(Xnew, diag=True)
    Qssdiag = W.pow(2).sum(dim=-1)
    # Theoretically, Kss - Qss is non-negative; but due to numerical
    # computation, that might not be the case in practice.
    var = (Kssdiag - Qssdiag).clamp(min=0)

    if f_scale_tril is not None:
        W_S_shape = (Xnew.size(0),) + f_scale_tril.shape[1:]
        W_S = W.matmul(S_2D).reshape(W_S_shape)
        # convert W_S_shape from M x N x latent_shape to latent_shape x M x N
        W_S = W_S.permute(list(range(2, W_S.dim())) + [0, 1])
        Kdiag = W_S.pow(2).sum(dim=-1)
        var = var + Kdiag
    else:
        
        var = var.expand(latent_shape + (M,))

    return loc, var

In [76]:
class NSF(PyroModule):
    def __init__(self, X, y, kernel, Xu, components=10, jitter=1e-4):
        super(NSF, self).__init__()
        self.X = X
        self.y = y
        self.jitter = jitter
        self.kernel = kernel
        self.components=torch.Size([components])

        self.Xu = PyroParam(Xu)


        M = self.Xu.size(0)

        u_loc = self.Xu.new_zeros(self.components + (M,))
        self.u_loc = PyroParam(u_loc)


        identity = eye_like(self.Xu, M)
        u_scale_tril = identity.repeat(self.components + (1, 1))
        self.u_scale_tril = PyroParam(u_scale_tril, constraints.lower_cholesky)

    @pyro_method
    def model(self):
        # self.set_mode("model")
        M = self.Xu.size(0)
        Kuu = self.kernel(self.Xu).contiguous()
        for i in range(self.components[0]):
            Kuu[i].view(-1)[:: M + 1] += self.jitter

        Luu = torch.linalg.cholesky(Kuu)
        zero_loc = self.Xu.new_zeros(self.u_loc.shape)

        pyro.sample(
                self._pyro_get_fullname("u"),
                dist.MultivariateNormal(zero_loc, scale_tril=Luu).to_event(
                    zero_loc.dim() - 1
                ),
        )

        f_loc, f_var = conditional(
            self.X,
            self.Xu,
            self.kernel,
            self.u_loc,
            self.u_scale_tril,
            Luu,
            jitter=self.jitter
        )

    


In [77]:
kernel = gp.kernels.IndependentRBF(input_dim=2, components=10)
model = NSF(X_coordinates, Y, kernel, Xu)

In [78]:
model.model()

torch.Size([10, 400, 44970])
torch.Size([10, 400, 10])


RuntimeError: t() expects a tensor with <= 2 dimensions, but self is 3D

In [39]:
model.u_loc.shape

torch.Size([10, 400])

In [37]:
cov = kernel.forward(Xu, Xu)

for i in range(10):
    cov[i].view(-1)[:: Xu.size(0) + 1] += 1e-3
    
cov_cholesky = torch.linalg.cholesky(cov)

In [33]:
dist.MultivariateNormal(torch.zeros(10, Xu.size(0)), scale_tril=cov_cholesky)().shape

torch.Size([10, 400])

In [60]:
model.u_loc.shape

torch.Size([10, 400])

In [56]:
model.u_scale_tril.shape

torch.Size([10, 400, 400])