## Bayesian Probabilistic Matrix Factorisation

In this notebook, we'll use Coin SVGD for Bayesian probabilistic matrix factorisation (PMF).

In [5]:
import numpy as np
import jax.numpy as jnp
from jax import grad, jit, vmap, random

import os
import itertools
from tqdm import tqdm
from copy import deepcopy
from functools import partial

import matplotlib.pyplot as plt

In [6]:
# Output locations for figures and saved results.
# exist_ok=True makes creation idempotent and avoids the check-then-create race.
plot_dir = "plots/SVGD/BayesPMF"
os.makedirs(plot_dir, exist_ok=True)

results_dir = "results/SVGD/BayesPMF"
os.makedirs(results_dir, exist_ok=True)

### Model
Now we're ready to define our model.

In [12]:
"""
Probabilistic Matrix Factorisation
"""

class BayesianPMF:
    """
    Bayesian probabilistic matrix factorisation with SGLD, SVGD and Coin SVGD samplers.

    A parameter set is the list ``[U, V, mu_U, lambda_U, mu_V, lambda_V]`` where
    ``U`` has shape (D, N), ``V`` has shape (D, M) and the remaining four entries
    have shape (D,). The ``lambda_*`` entries are exponentiated wherever they are
    used as precisions. The "flat" form used by the particle samplers is the
    concatenation of the ravelled entries (see ``pack_params`` / ``unpack_params``).

    Each data row is read as ``(i, j, r_ij)`` and ``i - 1`` / ``j - 1`` are used as
    column indices into ``U`` / ``V``.
    """

    def __init__(self, data, batchsize=100, alpha=3, mu0=0, a0=4, b0=5, D=20, N=943, M=1682):
        """
        :param data: input data
        :param batchsize: batchsize for stochastic gradients
        :param alpha: precision of observation noise (hyper-parameter)
        :param mu0: mean of mu_W (hyper-parameter)
        :param a0: location parameter of lambda_W (hyper-parameter)
        :param b0: scale parameter of lambda_W (hyper-parameter)
        :param D: latent dimension
        :param N: number of users
        :param M: number of movies
        """
        self.data = data
        self.totalsize, _ = data.shape
        self.batchsize = batchsize

        self.alpha = alpha
        self.mu0 = mu0
        self.a0 = a0
        self.b0 = b0

        self.D = D
        self.N = N
        self.M = M

        # minibatch cursor and the fixed order in which rows are visited
        self.iter = 0
        self.permutation = np.random.permutation(self.totalsize)

    @partial(jit, static_argnums=(0,))
    def log_lik(self, params, data):
        """
        log likelihood for a single data point (up to an additive constant)
        """
        U, V, _, _, _, _ = params
        i, j, r_ij = data
        # indices in the data are 1-based
        u_i = U[:, (i - 1).astype(int)]
        v_j = V[:, (j - 1).astype(int)]
        return (-self.alpha / 2) * (r_ij - jnp.dot(u_i.T, v_j)) ** 2

    def batch_log_lik(self, params, data):
        """
        log likelihood for batch of data (one value per row of data)
        """
        return vmap(self.log_lik, in_axes=(None, 0))(params, data)

    @partial(jit, static_argnums=(0,))
    def log_pri_m(self, M, mu_M, lambda_M):
        """
        log prior for single vectors U_i and V_j

        Diagonal Gaussian with mean mu_M and precisions exp(lambda_M).
        """
        diff = M - mu_M
        # equals diff^T diag(exp(lambda_M)) diff without building the D x D matrix
        return -0.5 * jnp.sum(jnp.exp(lambda_M) * diff ** 2) + 0.5 * jnp.sum(lambda_M)

    def batch_log_pri_m(self, M, mu_M, lambda_M):
        """
        log prior for batch of vectors U_i and V_j (the columns of M)
        """
        # log_pri_m is already jitted; re-wrapping in jit on every call is not needed
        return vmap(self.log_pri_m, in_axes=(1, None, None))(M, mu_M, lambda_M)

    def log_pri_mu(self, mu_M, lambda_M):
        """
        log prior of mu

        Diagonal Gaussian centred on the hyper-parameter mu0 with precisions
        exp(lambda_M).
        """
        diff = mu_M - self.mu0
        return -0.5 * jnp.sum(jnp.exp(lambda_M) * diff ** 2) + 0.5 * jnp.sum(lambda_M)

    def log_pri_lambda(self, log_lambda_i):
        """
        log prior of lambda (precision), evaluated at its logarithm

        The trailing ``+ log_lambda_i`` is the log-Jacobian of the exp transform.
        """
        return (self.a0 - 1) * log_lambda_i - self.b0 * jnp.exp(log_lambda_i) + log_lambda_i

    def batch_log_pri_lambda(self, log_lambda_i):
        """
        batch log prior of lambda
        """
        return vmap(self.log_pri_lambda, in_axes=(0))(log_lambda_i)

    @partial(jit, static_argnums=(0,))
    def log_prior(self, params):
        """
        Prior for U, V, mu_U, lambda_U, mu_V, lambda_V
        """
        U, V, mu_U, lambda_U, mu_V, lambda_V = params

        term_U = jnp.sum(self.batch_log_pri_m(U, mu_U, lambda_U))
        term_V = jnp.sum(self.batch_log_pri_m(V, mu_V, lambda_V))

        term_muU = self.log_pri_mu(mu_U, lambda_U)
        term_muV = self.log_pri_mu(mu_V, lambda_V)

        term_lambdaU = jnp.sum(self.batch_log_pri_lambda(lambda_U))
        term_lambdaV = jnp.sum(self.batch_log_pri_lambda(lambda_V))

        return term_U + term_V + term_muU + term_muV + term_lambdaU + term_lambdaV

    @partial(jit, static_argnums=(0,))
    def log_post(self, params, data):
        """
        Stochastic estimate of the log posterior density for unpacked parameters

        The mean minibatch log likelihood is rescaled by the number of rows in the
        training data (self.totalsize) rather than a hard-coded dataset size.
        """
        return self.log_prior(params) + self.totalsize * jnp.mean(self.batch_log_lik(params, data))

    def log_post_flat(self, params, data):
        """
        Log posterior density for a flat parameter vector
        """
        return self.log_post(self.unpack_params(params), data)

    def pack_params(self, params):
        """
        Pack all parameters into a single flat vector
        """
        return jnp.concatenate([x.ravel() for x in params])

    def unpack_params(self, params):
        """
        Unpack a flat vector into [U, V, mu_U, lambda_U, mu_V, lambda_V]
        """
        sizes = [self.D * self.N, self.D * self.M, self.D, self.D, self.D, self.D]
        # plain Python ints keep the slice bounds static, so this also works under jit
        b0, b1, b2, b3, b4, b5 = (int(b) for b in np.cumsum(sizes))
        U = jnp.reshape(params[0:b0], (self.D, self.N))
        V = jnp.reshape(params[b0:b1], (self.D, self.M))
        mu_U = params[b1:b2]
        lambda_U = params[b2:b3]
        mu_V = params[b3:b4]
        lambda_V = params[b4:b5]
        return [U, V, mu_U, lambda_U, mu_V, lambda_V]

    def predict_loss(self, U, V, data, mean_rating):
        """
        Squared prediction error for a single data point and a single sample
        """
        return self._predict_loss(U, V, data, mean_rating)

    def _next_batch(self):
        """
        Return the next minibatch of rows and advance the minibatch cursor

        With batchsize > 0, consecutive windows of the fixed permutation are used,
        wrapping around at the end of the data. Otherwise all rows are returned in
        a fresh random order.
        """
        if self.batchsize > 0:
            start = self.iter * self.batchsize
            window = np.arange(start, start + self.batchsize) % self.totalsize
            ridx = self.permutation[window]
            self.iter += 1
        else:
            ridx = np.random.permutation(self.data.shape[0])
        return self.data[ridx, :]

    def sgld(self, params0, dt=1e-5, n_iter=1000, key=None):
        """
        Stochastic Gradient Langevin Dynamics (SGLD)

        :param params0: initial parameters as an unpacked list
        :param dt: step size
        :param n_iter: number of iterations
        :param key: PRNG key for the injected noise; defaults to random.PRNGKey(1)
        :return: final parameters as an unpacked list
        """
        # initial theta
        params = self.pack_params(params0)

        # grad log pdf
        grad_log_pdf = grad(self.log_post_flat)

        # noise, drawn up-front for all iterations
        if key is None:
            key = random.PRNGKey(1)
        w = jnp.sqrt(2 * dt) * random.normal(key=key, shape=(n_iter,) + params.shape)

        for t in tqdm(range(n_iter)):
            batch_data = self._next_batch()
            grad_theta = grad_log_pdf(params, batch_data)
            params = params + dt * grad_theta + w[t, :]

        return self.unpack_params(params)

    def gram(self, kernel, xs):
        """
        Matrix of kernel(x, y) for every pair of rows x, y of xs
        """
        return vmap(lambda x: vmap(lambda y: kernel(x, y))(xs))(xs)

    def rbf(self, x, y):
        """
        Kernel exp(-||x - y||^2)
        """
        return jnp.exp(-jnp.sum((x - y) ** 2))

    def svgd_kernel(self, theta):
        """
        Kernel matrix and summed kernel gradients for a set of particles

        :param theta: particles, one per row
        :return: (Kxy, dxkxy) where Kxy[i, j] = rbf(theta[i], theta[j]) and
                 dxkxy[i] = sum_j grad_{theta[j]} rbf(theta[j], theta[i])
        """
        Kxy = self.gram(self.rbf, theta)
        # grad_x rbf(x, y) = -2 (x - y) rbf(x, y); summing over particles j gives
        # -2 (Kxy @ theta - theta * rowsum(Kxy)). Differentiating rbf(theta, theta)
        # directly would always yield zero because the two arguments coincide.
        row_sums = jnp.sum(Kxy, axis=1, keepdims=True)
        dxkxy = -2.0 * (jnp.matmul(Kxy, theta) - theta * row_sums)
        return Kxy, dxkxy

    def _particle_grads(self, theta, batch_data):
        """
        Gradient of the flat log posterior at every particle (row of theta)
        """
        return vmap(grad(self.log_post_flat), in_axes=(0, None))(theta, batch_data)

    def _svgd_direction(self, theta, batch_data):
        """
        SVGD update direction for all particles on one minibatch
        """
        ln_p_grad = self._particle_grads(theta, batch_data)
        kxy, dxkxy = self.svgd_kernel(theta)
        return (jnp.matmul(kxy, ln_p_grad) + dxkxy) / theta.shape[0]

    def svgd(self, params0, dt=1e-5, n_iter=1000, adagrad=True, alpha=0.9):
        """
        Stein Variational Gradient Descent

        :param params0: initial particles, one flat parameter vector per row
        :param dt: step size
        :param n_iter: number of iterations
        :param adagrad: if True, rescale steps by a running average of squared gradients
        :param alpha: decay rate of that running average
        :return: final particles, same shape as params0
        """
        # initial theta
        params = jnp.asarray(params0)

        # for adagrad with momentum
        fudge_factor = 1e-6
        historical_grad = 0

        for t in tqdm(range(n_iter)):
            batch_data = self._next_batch()
            grad_params = self._svgd_direction(params, batch_data)

            if adagrad:
                squared = grad_params ** 2
                if t == 0:
                    historical_grad = historical_grad + squared
                else:
                    historical_grad = alpha * historical_grad + (1 - alpha) * squared
                adj_grad = grad_params / (fudge_factor + jnp.sqrt(historical_grad))
            else:
                adj_grad = grad_params

            params = params + dt * adj_grad

        return params

    def svgd_param_free(self, theta0, n_iter=1000):
        """
        Coin Stein Variational Gradient Descent (Coin SVGD)

        :param theta0: initial particles, one flat parameter vector per row
        :param n_iter: number of iterations
        :return: final particles, same shape as theta0
        """
        # initial theta
        theta0 = jnp.asarray(theta0)
        theta = theta0

        # initialise other vars
        L = 0
        grad_theta_sum = 0
        reward = 0
        abs_grad_theta_sum = 0

        for _ in tqdm(range(n_iter)):
            batch_data = self._next_batch()
            grad_theta = self._svgd_direction(theta, batch_data)

            # |gradient|
            abs_grad_theta = jnp.abs(grad_theta)

            # constant (running elementwise maximum of |gradient|)
            L = jnp.maximum(abs_grad_theta, L)

            # sum of gradients
            grad_theta_sum = grad_theta_sum + grad_theta
            abs_grad_theta_sum = abs_grad_theta_sum + abs_grad_theta

            # 'reward'
            reward = jnp.maximum(reward + (theta - theta0) * grad_theta, 0)

            # theta update; coordinates whose gradients have all been zero so far
            # have a zero denominator and are left at theta0 instead of becoming NaN
            denom = L * (abs_grad_theta_sum + L)
            positive = denom > 0
            safe_denom = jnp.where(positive, denom, 1.0)
            step = jnp.where(positive, grad_theta_sum / safe_denom * (L + reward), 0.0)
            theta = theta0 + step

        return theta

    def _predict_loss(self, U, V, data, mean_rating):
        """
        Squared prediction error for a single data point and a single sample

        The prediction is U_i . V_j + mean_rating, clipped to [1, 5].
        """
        i, j, r_ij = data
        pred = jnp.dot(U[:, (i - 1).astype(int)].T, V[:, (j - 1).astype(int)]) + mean_rating
        pred = jnp.where(pred > 5, 5, pred)
        pred = jnp.where(pred < 1, 1, pred)
        return (pred - r_ij) ** 2

    def rmse_1sample(self, params, test_data, mean_rating):
        """
        RMSE of a single sample over all rows of test_data

        :param params: unpacked parameters; only params[0] (U) and params[1] (V) are used
        """
        batch_predict_loss = vmap(self._predict_loss, in_axes=(None, None, 0, None))
        squared_errors = batch_predict_loss(params[0], params[1], test_data, mean_rating)
        # root of the mean squared error (without the root this would be the MSE)
        return jnp.sqrt(jnp.mean(squared_errors))

    def rmse_batch(self, params, test_data, mean_rating):
        """
        Average over particles of the per-particle RMSE

        :param params: particles, one flat parameter vector per row
        """
        n_particles = params.shape[0]
        per_particle = [
            self.rmse_1sample(self.unpack_params(params[k, :]), test_data, mean_rating)
            for k in range(n_particles)
        ]
        return sum(per_particle) / n_particles

    def init_params(self, key):
        """
        Draw random initial parameters as an unpacked list

        Every entry is a standard normal draw scaled by 0.1; lambda_U and lambda_V
        carry an additional factor of 0.1.
        """
        subkey1, subkey2, subkey3, subkey4, subkey5, subkey6 = random.split(key, 6)
        U = random.normal(subkey1, shape=(self.D, self.N))
        V = random.normal(subkey2, shape=(self.D, self.M))
        mu_U = random.normal(subkey3, shape=(self.D,))
        lambda_U = 0.1 * random.normal(subkey4, shape=(self.D,))
        mu_V = random.normal(subkey5, shape=(self.D,))
        lambda_V = 0.1 * random.normal(subkey6, shape=(self.D,))
        params = [U, V, mu_U, lambda_U, mu_V, lambda_V]
        return [0.1 * p for p in params]

    def batch_init_params(self, key, n_particles):
        """
        Draw n_particles independent initialisations as rows of a flat array
        """
        total_dim = self.D * (self.N + self.M + 1 + 1 + 1 + 1)
        particles = []
        for _ in range(n_particles):
            # chain the key so every particle gets its own subkey
            key, subkey = random.split(key)
            particles.append(self.pack_params(self.init_params(subkey)))
        if not particles:
            return jnp.zeros((0, total_dim))
        return jnp.stack(particles)

    def init_PMF_zeros(self):
        """
        All-zero parameters as an unpacked list
        """
        U = jnp.zeros(shape=(self.D, self.N))
        V = jnp.zeros(shape=(self.D, self.M))
        mu_U = jnp.zeros(shape=(self.D,))
        lambda_U = jnp.zeros(shape=(self.D,))
        mu_V = jnp.zeros(shape=(self.D,))
        lambda_V = jnp.zeros(shape=(self.D,))
        params = [U, V, mu_U, lambda_U, mu_V, lambda_V]
        return params

### Experiments
We can now run all of the methods. 
* We'll run all of the algorithms for T=1000 iterations, using N=50 particles. 
* For the learning-rate dependent methods (SVGD and SGLD), we'll consider a grid of candidate learning rates, and run them for each of these. Note: running this grid search over learning rates can take some time!

In [None]:
# load the dataset
R_train = np.genfromtxt("data/MovieLens/train.dat")
R_test = np.genfromtxt("data/MovieLens/test.dat")

# centre the training ratings; the mean is added back when predicting
mean_rating = jnp.mean(R_train[:, 2])
R_train[:, 2] = R_train[:, 2] - mean_rating
R_train = jnp.array(R_train)

# model parameters
D = 20
N = 943
M = 1682

alpha = 3
mu0 = 0
a0 = 1
b0 = 5

# algorithm parameters
n_iter = 1000
batchsize = 1000
n_particles = 50

# random seeds and candidate learning rates
all_key_int = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
all_dt = list(np.logspace(-7, 0, 30))

all_vars = list(itertools.product(all_key_int, all_dt))

# final test errors: one row per seed, one column per learning rate
sgld_rmse = np.zeros((len(all_key_int), len(all_dt)))
svgd_rmse = np.zeros((len(all_key_int), len(all_dt)))
coin_svgd_rmse = np.zeros((len(all_key_int)))

for ii, key_int in enumerate(all_key_int):

    for jj, dt in enumerate(all_dt):

        print(f"Key: {ii + 1}/{len(all_key_int)}")
        print(f"LR:{jj + 1}/{len(all_dt)}")

        key = random.PRNGKey(key_int)
        PMF = BayesianPMF(data=R_train, batchsize=batchsize, alpha=alpha, mu0=mu0, a0=a0, b0=b0, D=D, N=N, M=M)

        # SGLD (single chain)
        theta0_sgld = list(PMF.init_params(key))
        theta_sgld = PMF.sgld(theta0_sgld, n_iter=n_iter, dt=dt)
        init_rmse_sgld = PMF.rmse_1sample(theta0_sgld, R_test, mean_rating)
        final_rmse_sgld = PMF.rmse_1sample(theta_sgld, R_test, mean_rating)
        sgld_rmse[ii, jj] = final_rmse_sgld

        # SVGD
        theta0_svgd = PMF.batch_init_params(key, n_particles)
        theta_svgd = PMF.svgd(theta0_svgd, n_iter=n_iter, dt=dt)
        init_rmse_svgd = PMF.rmse_batch(theta0_svgd, R_test, mean_rating)
        final_rmse_svgd = PMF.rmse_batch(theta_svgd, R_test, mean_rating)
        svgd_rmse[ii, jj] = final_rmse_svgd

        # Coin SVGD has no learning rate, so it is run once per seed
        if jj == 0:
            theta0_coin_svgd = PMF.batch_init_params(key, n_particles)
            theta_coin_svgd = PMF.svgd_param_free(theta0_coin_svgd, n_iter=n_iter)
            init_rmse_coin_svgd = PMF.rmse_batch(theta0_coin_svgd, R_test, mean_rating)
            final_rmse_coin_svgd = PMF.rmse_batch(theta_coin_svgd, R_test, mean_rating)
            coin_svgd_rmse[ii] = final_rmse_coin_svgd

# save inside results_dir (plain string concatenation dropped the path separator)
np.save(os.path.join(results_dir, "svgd"), svgd_rmse)
np.save(os.path.join(results_dir, "coin_svgd"), coin_svgd_rmse)
np.save(os.path.join(results_dir, "sgld"), sgld_rmse)

rmses = [svgd_rmse, coin_svgd_rmse, sgld_rmse]
fnames = ["svgd", "coin_svgd", "sgld"]
names = ["SVGD", "Coin SVGD", "SGLD"]

# average over random seeds
average_rmses = [np.mean(rmse, 0) for rmse in rmses]

# plot results
for ii, (rmse, average_rmse, name, fname) in enumerate(zip(rmses, average_rmses, names, fnames)):
    colour = "C" + str(ii)
    if name == "SGLD" or name == "SVGD":
        # band between the best and worst seed, with the seed-average on top
        plt.fill_between(all_dt, np.amax(rmse, 0), np.amin(rmse, 0), color=colour, alpha=0.4)
        plt.plot(all_dt, average_rmse, ".-", label=name, color=colour, alpha=1)
    if name == "Coin SVGD":
        # horizontal lines: the seed-average plus one faint line per seed
        plt.axhline(average_rmse, label=name, color=colour)
        for jj in range(len(all_key_int)):
            plt.axhline(rmse[jj], color=colour, alpha=0.3)
plt.xscale("log")
plt.grid(visible=True, color="whitesmoke", ls='-')
plt.gca().set_axisbelow(True)
plt.ylim(0.8, 1.4)
plt.legend(prop={'size':18})
plt.xlabel("Learning Rate", fontsize=18)
plt.ylabel("Test RMSE", fontsize=18)
plt.xticks(fontsize=18)
plt.yticks(fontsize=18)
plt.gcf().subplots_adjust(bottom=0.15)
fname = plot_dir + "/" + "lr_vs_rmse_" + ".pdf"
plt.savefig(fname, format="pdf")
plt.show()