In [1]:
import math
import torch
import gpytorch
import pandas as pd 
import numpy as np
from matplotlib import pyplot as plt

In [2]:
import sklearn.metrics as sk

### Fully supervised GPLVM (we should come up with a better acronym/name)

Observations:
$$Y = [y_1,...,y_N], \ y_n \in R^{p}$$
Model parameters
$$X = [x_1,...,x_N], \ x_n \in R^{p\times q}$$
Latent parameters/features
$$L = [l_1,...,l_N], \ l_n \in R^{d }$$
With $d<q$ and $L\sim N(0,I)$.

We want to find a function $f_\theta$ such that:
$$ Y = f_\theta(X,L)+\epsilon $$
Let's make $f$ a GP, as we like to, with mean $m_\theta(X,L)$ and covariance $K_\theta((X,L),(X',L'))$
We would like to learn $L$ and $\theta$ (the hyper-parameters). 

As a cost function we use the likelihood
$$p(Y,L|X) = p(Y|L,X)p(L) $$
and as $f$ is a GP,
$$Y \mid L,X \sim N(m_\theta(X,L), K_\theta((X,L),(X,L)))$$

Taking the log we have
$$\log p(Y,L|X) = \log p(Y|L,X)+\log p(L) $$
$$ = -\frac{Np}{2}\log(2\pi)-\frac{1}{2}\log|K_\theta|-\frac{1}{2}(Y-m_\theta(X,L))^\top K_\theta^{-1}(Y-m_\theta(X,L))-\frac{Nd}{2}\log(2\pi)-\frac{1}{2}\log|I|-\frac{1}{2}L^\top I^{-1}L $$


To calculate $\log |K|$ we use the fact that $\log |K| = 2\sum_i \log(\mathrm{diag}(C)_i)$, where $C$ is the lower-triangular Cholesky factor of $K$, i.e. $K = CC^\top$ (we write $C$ here to avoid a clash with the latent variables $L$).

In [63]:
def multivariate_ll(y,mu,K,K_inv=None):
    """Log-density of ``y`` under a multivariate normal ``N(mu, K)``.

    Parameters
    ----------
    y : torch.Tensor
        Observations; ``y.shape[0]`` is used as the dimension of the Gaussian.
    mu : torch.Tensor or scalar
        Mean, broadcastable against ``y``.
    K : torch.Tensor
        Covariance matrix. It must be positive definite because its Cholesky
        factor is used for the log-determinant.
    K_inv : torch.Tensor, optional
        Pre-computed inverse of ``K``. When given it is used for the quadratic
        term instead of solving a linear system with ``K``.

    Returns
    -------
    torch.Tensor
        The log-likelihood (shape ``(1, 1)`` when ``y`` is a column vector).
    """
    import math

    resid = y - mu
    # log|K| = 2 * sum_i log(C_ii), where C is the Cholesky factor of K.
    log_det = 2*torch.sum(torch.log(torch.diag(torch.linalg.cholesky(K))))
    if K_inv is None:
        # Solve K a = resid rather than forming the inverse explicitly.
        weighted = torch.linalg.solve(K, resid)
    else:
        weighted = torch.matmul(K_inv, resid)
    quad = torch.matmul(resid.T, weighted)
    ll = -(y.shape[0]/2)*math.log(2*math.pi) - 0.5*log_det - 0.5*quad
    return ll

In [64]:
def rbf(X,scaling,lengthscale):
    """Squared-exponential (RBF) kernel matrix of ``X`` with itself.

    ``K[i, j] = scaling * exp(-0.5 * ||x_i - x_j||^2 / lengthscale^2)``

    Parameters
    ----------
    X : torch.Tensor
        Inputs of shape ``(n, q)``, one row per point.
    scaling : torch.Tensor or float
        Output scale multiplying the kernel.
    lengthscale : torch.Tensor or float
        Lengthscale dividing the inputs (broadcast against ``X``).

    Returns
    -------
    torch.Tensor
        The ``(n, n)`` kernel matrix.
    """
    X_ = X/lengthscale
    sq_norm = torch.sum(X_ ** 2, dim=-1)
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b ; the norms must be broadcast as a
    # column and a row so that entry (i, j) combines point i with point j.
    sq_dist = sq_norm[:, None] + sq_norm[None, :] - 2 * torch.matmul(X_, X_.T)
    # Round-off can leave tiny negative distances; clamp them to zero.
    sq_dist = torch.clamp(sq_dist, min=0.0)
    return scaling * torch.exp(-0.5 * sq_dist)

For the linear kernel we don't really have any separation possibilities that are useful to us from an inversion perspective, as far as I can tell.
$$K_{ij}=x_i^\top x_j + l_i^\top l_j$$

So we don't get much in the way of help here.

In [65]:
def _augment_inputs(X, L, X_length, L_length, p):
    """Scale ``X`` and ``L`` and append each system's latent to its ``p`` rows of ``X``."""
    X_scaled = X / torch.sqrt(torch.as_tensor(X_length, dtype=X.dtype))
    L_scaled = L / torch.sqrt(torch.as_tensor(L_length, dtype=L.dtype))
    if L_scaled.dim() < 2:
        # Treat a flat vector of latents as one scalar latent per system.
        L_scaled = L_scaled.reshape(-1, 1)
    # Repeat every latent row p times so that it lines up with the observations.
    L_rows = torch.repeat_interleave(L_scaled, p, dim=0)
    return torch.cat((X_scaled, L_rows), dim=1)


def linear(X1,L1,X2,L2,X_length,L_length,p1,p2):
    """Linear kernel on the concatenated (input, latent) features.

    ``K[i, j] = x_i^T x_j / X_length + l_i^T l_j / L_length``

    Parameters
    ----------
    X1, X2 : torch.Tensor
        Inputs of shape ``(n1, q)`` and ``(n2, q)``.
    L1, L2 : torch.Tensor
        Latents, one row per system; every row is repeated ``p1`` / ``p2`` times
        and must then match the number of rows of ``X1`` / ``X2``.
    X_length, L_length : torch.Tensor or float
        Positive scales; inputs and latents are divided by their square roots.
    p1, p2 : int
        Number of observations per system for the first / second argument.

    Returns
    -------
    torch.Tensor
        The ``(n1, n2)`` kernel matrix.
    """
    # Use the scales that were passed in; the function must not depend on
    # module-level variables that happen to share a name with them.
    feats_1 = _augment_inputs(X1, L1, X_length, L_length, p1)
    feats_2 = _augment_inputs(X2, L2, X_length, L_length, p2)
    return torch.matmul(feats_1, feats_2.T)

sample $L$ and $X$ from $U[-1,1]$

Here we do $x_i \neq x_j$, just for the sake of it

In [66]:
# Problem size and the coefficients of the synthetic ground-truth function.
N = 10 #number of systems
p = 40 #observations per system
alpha=10   # coefficient of X in true_func
beta=-70   # coefficient of the latent L in true_func
zeta=500   # constant offset in true_func

In [67]:
# One scalar latent per system, drawn uniformly from [-1, 1].
# NOTE(review): no RNG seed is set, so these values differ from run to run.
L_true=torch.FloatTensor(N, 1).uniform_(-1, 1)

In [68]:
# N*p scalar inputs (p consecutive rows per system), drawn uniformly from [-1, 1].
X=torch.FloatTensor(N*p, 1).uniform_(-1, 1)

In [69]:
def true_func(X,L,alpha,beta,zeta):
    """Ground-truth response: ``alpha*X + beta*L + zeta``.

    ``X`` and ``L`` are combined with ordinary broadcasting, so ``L`` may be a
    single latent shared by every row of ``X``.
    """
    input_term = alpha*X
    latent_term = beta*L
    return input_term + latent_term + zeta

In [70]:
# Evaluate the true function system by system: system i owns rows
# i*p .. (i+1)*p - 1 of X and shares the single latent L_true[i].
y=[]
for i in range(N):
    y.append(true_func(X[i*p:(i+1)*p],L_true[i],alpha,beta,zeta))

In [71]:
# Stack the per-system outputs into one tensor, in the same row order as X.
Y=torch.cat(y)

In [72]:
# Sanity check: one row per observation (N*p rows).
Y.shape

torch.Size([400, 1])

In [73]:
# Statistics of the raw targets, kept so predictions can be mapped back later.
Y_mean, Y_std = Y.mean(), Y.std()

In [74]:
# Standardise the targets to zero mean and unit standard deviation.
# NOTE(review): this rebinds Y, so the cell is not idempotent; re-running it
# without re-running the cells above standardises the data a second time.
Y=(Y-Y_mean)/Y_std

In [75]:
def cost(y,X,L,X_scale,L_scale,sigma2,p,mean=0):
    """Negative joint log-likelihood ``-(log p(y | L, X) + log p(L))``.

    The data term is a Gaussian with mean ``mean`` and covariance given by the
    linear kernel on ``(X, L)`` plus ``sigma2`` on the diagonal. The prior term
    is a zero-mean Gaussian with identity covariance on ``L``.
    """
    n_obs = X.shape[0]
    n_latent = L.shape[0]

    noise = sigma2*torch.eye(n_obs)
    K = linear(X,L,X,L,X_scale,L_scale,p,p) + noise

    data_term = multivariate_ll(y,mean,K)
    prior_term = multivariate_ll(L,0,torch.eye(n_latent))
    return -(data_term + prior_term)

In [76]:
# NOTE(review): this uniform initialisation of L is overwritten by the
# normal initialisation in the next cell before L is used anywhere.
L = torch.rand(L_true.shape[0],1).clone().detach().requires_grad_(True)

In [77]:
# Randomly initialised leaf tensors that autograd tracks.
# L: one latent per system, drawn from N(0, 1) to match the prior.
L = torch.normal(0, torch.ones(L_true.shape[0], 1)).requires_grad_(True)
# Kernel scales for the inputs and the latents.
X_scale = torch.rand(1).requires_grad_(True)
L_scale = torch.rand(1).requires_grad_(True)
# Small variance added to the diagonal of the covariance.
sigma2 = torch.tensor([0.00001], requires_grad=True)
# Constant mean of the GP.
mean = torch.rand(1).requires_grad_(True)

In [78]:
# Objective at the random initialisation. `mean` is not passed here, so the
# default mean=0 is used.
cost(Y,X,L,X_scale,L_scale,sigma2,p)

tensor([[1901562.8750]], grad_fn=<NegBackward0>)

In [97]:
# Minimise the negative joint log-likelihood with Adam (despite the name `gd`).
# sigma2 is not in the parameter list, so it keeps its initial value.
# NOTE(review): X_scale and L_scale are unconstrained here although `linear`
# takes their square root; a negative value would produce NaNs.
gd = torch.optim.Adam([L,X_scale,L_scale,mean], lr=1e-3)
history_gd = []

for i in range(100000):
    gd.zero_grad()
    objective = cost(Y,X,L,X_scale,L_scale,sigma2,p,mean)
    objective.backward()
    gd.step()
    history_gd.append(objective.item())
    # Progress report every 1000 iterations.
    if i%1000 ==0:
        print(objective)
    # Stop once two consecutive objective values differ by less than 1e-7.
    if (i>1) and (np.abs(history_gd[-1] - history_gd[-2]) < .0000001):
        print("Convergence achieved in ", i+1, " iterations")
        print("-LogL Value: ", objective.item())
        break


tensor([[-1442.7107]], grad_fn=<NegBackward0>)
Convergence achieved in  766  iterations
-LogL Value:  -1444.8563232421875


In [98]:
# Latents used to generate the data, for comparison with the learned L.
L_true

tensor([[ 0.9738],
        [ 0.1013],
        [ 0.5664],
        [-0.1560],
        [-0.3757],
        [-0.9250],
        [-0.7929],
        [-0.5063],
        [ 0.2551],
        [-0.3836]])

In [99]:
# Learned latents after optimisation.
L

tensor([[-1.5983],
        [-0.2337],
        [-0.9611],
        [ 0.1687],
        [ 0.5123],
        [ 1.3713],
        [ 1.1646],
        [ 0.7165],
        [-0.4742],
        [ 0.5247]], requires_grad=True)

In [100]:
# Objective at the learned parameters.
# NOTE(review): `mean` is not passed, so this is evaluated with mean=0, whereas
# the training loop passed the learned `mean`. The value is therefore not
# directly comparable with the training objective.
cost(Y,X,L,X_scale,L_scale,sigma2,p)

tensor([[33516.0586]], grad_fn=<NegBackward0>)

In [101]:
# Objective with the generating latents L_true substituted for the learned L.
# NOTE(review): as in the previous cell, `mean` is not passed (mean=0 is used).
cost(Y,X,L_true,X_scale,L_scale,sigma2,p)

tensor([[88186.8672]], grad_fn=<NegBackward0>)

In [102]:
# Learned input scale.
X_scale

tensor([1.2776], requires_grad=True)

In [103]:
# Learned latent scale.
L_scale

tensor([0.8284], requires_grad=True)

In [104]:
# Learned constant mean (in standardised-Y units).
mean

tensor([-0.1336], requires_grad=True)

Posterior mean: $$m(X^*,L^*)+K((X,L),(X^*,L^*))^\top(K((X,L),(X,L))+\sigma^2I)^{-1}(Y-m(X,L))$$

In [105]:
# Settings for the prediction check below.
ref=6       # NOTE(review): not referenced in the cells shown here
l_ref = 0   # index of the system whose latent is used for prediction
dif = 0.1   # offset added to the training inputs to create test inputs

In [106]:
# Training covariance: linear kernel on (X, L) plus sigma2 on the diagonal.
K=linear(X,L,X,L,X_scale,L_scale,p,p)+sigma2*torch.eye(X.shape[0])

In [107]:
# Cross-covariance between the training points and 10 test points: the first
# 10 inputs shifted by `dif`, paired with the learned latent of system l_ref.
# The final argument (10) is the repeat count for that single latent and must
# equal the number of test rows in X[0:10].
K_s = linear(X,L,X[0:10]+dif,L[[l_ref]],X_scale,L_scale,p,10)

In [108]:
# Posterior mean at the test points, mapped back to the original units of Y:
# (mean + K_s^T K^{-1} (Y - mean)) * Y_std + Y_mean, written out term by term.
mean*Y_std +torch.matmul(K_s.T,torch.linalg.solve(K,Y-mean))*Y_std + Y_mean

tensor([[422.9970],
        [430.5042],
        [438.8089],
        [437.0302],
        [438.5528],
        [439.8081],
        [426.0893],
        [442.2871],
        [435.7614],
        [425.7455]], grad_fn=<AddBackward0>)

In [109]:
# Ground truth at the same test points, for comparison with the prediction above.
true_func(X[0:10]+dif,L_true[l_ref],alpha,beta,zeta)

tensor([[422.9973],
        [430.5046],
        [438.8094],
        [437.0305],
        [438.5530],
        [439.8083],
        [426.0897],
        [442.2878],
        [435.7616],
        [425.7458]])