In [None]:
%matplotlib inline

In [None]:
import os

import numpy as np
import edward as ed
import tensorflow as tf
import matplotlib.pyplot as plt

In [None]:
# Path of the ratings matrix that the Data section reads with np.loadtxt.
fname = "data/ratings.txt.gz"

# Plot styling and reproducible random number generation for the whole run.
plt.style.use("ggplot")
seed = 42
ed.set_seed(seed)

### Data

In [None]:
data = np.loadtxt(fname, dtype=np.float32)
idx = data.nonzero()
tidx = tf.constant(np.column_stack(idx))
y = data[idx]
n,m = data.shape
%xdel data
y = np.ceil(y)

### Model

In [None]:
from edward.models import Poisson, Gamma
from edward.models import PointMass, Empirical

In [None]:
k = 20        # number of latent factors
n_iter = 500  # number of alternating Gibbs/MAP iterations
t = 500       # number of Gibbs draws each Empirical approximation can store

# Each Gibbs update fills one of the `t` rows of the Empirical params, so the
# training loop must not run more iterations than there are rows.
assert n_iter <= t, "n_iter must not exceed the number of Empirical samples t"

#### Priors ####

act = Gamma(1.0, 1.0, sample_shape=n) # Users activity, shape [n]
pref = Gamma(1.0, act, sample_shape=k) # Users preference, shape [k, n]

pop = Gamma(0.1, 0.1, sample_shape=m) # Items popularity, shape [m]
attr = Gamma(1.0, pop, sample_shape=k) # Items attribute, shape [k, m]

# pref^T @ attr gives an [n, m] matrix of Poisson rates; keep only the
# entries at the observed (row, col) positions in `tidx`.
like = Poisson(tf.gather_nd(tf.matmul(pref, attr, transpose_a=True), tidx))


#### Posteriors ####

# Empirical approximations updated by ed.Gibbs. Per Edward's Gibbs docs, each
# Empirical must be parameterized directly by a tf.Variable, because draws are
# written into that variable. Wrapping the variable in softplus would
# transform every stored draw. Positivity of the *initial* values is kept by
# applying softplus to the initializer instead.
qact = Empirical(
    tf.Variable(tf.nn.softplus(tf.random_normal([t,n]))),
)
# Point estimates fitted by MAP. The optimizer moves the unconstrained
# variable, and softplus keeps the estimate positive (Gamma support).
qpref = PointMass(
    tf.nn.softplus(tf.Variable(tf.random_normal([k,n]))),
)
qpop = Empirical(
    tf.Variable(tf.nn.softplus(tf.random_normal([t,m]))),
)
qattr = PointMass(
    tf.nn.softplus(tf.Variable(tf.random_normal([k,m]))),
)

### Inference

In [None]:
# Gibbs step: sample user activity and item popularity, conditioning on the
# ratings and on the current PointMass estimates of preferences/attributes.
latent_e = {act: qact, pop: qpop}
observed_e = {like: y, pref: qpref, attr: qattr}
inference_e = ed.Gibbs(latent_e, data=observed_e)

# MAP step: fit preferences/attributes, conditioning on the ratings and on
# the Empirical approximations of activity/popularity.
latent_m = {pref: qpref, attr: qattr}
observed_m = {like: y, act: qact, pop: qpop}
inference_m = ed.MAP(latent_m, data=observed_m)

inference_e.initialize()
inference_m.initialize(n_iter=n_iter, optimizer="rmsprop")

tf.global_variables_initializer().run()

In [None]:
# One entry per iteration; every slot is written inside the loop.
loss = np.empty(n_iter, dtype=np.float32)

for step in range(n_iter):
    # Alternate one Gibbs update of (act, pop) with one MAP step on (pref, attr).
    inference_e.update()
    info_dict_m = inference_m.update()

    # Only the MAP step reports a loss, so that is what gets tracked and printed.
    loss[step] = info_dict_m["loss"]
    inference_m.print_progress(info_dict_m)

In [None]:
fig, ax = plt.subplots(figsize=(15, 6))
# Scale by the largest magnitude rather than the max, so the curve keeps its
# sign and shape even if the MAP loss is negative. Dividing by a negative max
# would flip the curve.
ax.plot(loss / np.abs(loss).max())
ax.set_title("Loss")
ax.set_xlabel("Iteration")
ax.set_ylabel("Loss / max |loss|")
# savefig does not create missing directories.
if not os.path.isdir("images"):
    os.makedirs("images")
fig.savefig("images/loss.png", transparent=True)

### Save

In [None]:
sess = ed.get_session()

In [None]:
np.save("data/loss", loss)

In [None]:
np.savez("data/act-pop", act=sess.run(qact), pop=sess.run(qpop))

In [None]:
np.savez("data/pref-attr", pref=sess.run(qpref), attr=sess.run(qattr))