In [None]:
import os
import numpy
import numpy.random

import tensorflow as tf
import edward as ed
from edward.models import Bernoulli, Exponential, Normal, Gamma, Empirical

import matplotlib.pyplot as plt
%matplotlib inline

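## Simulated data

Session lengths (pages viewed per session) are drawn around an exponentially decaying trend, rounded, and clipped to the range 1 to `PAGE_COUNT`.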

In [None]:
PAGE_COUNT = 10
PAGES_PER_SESSION_PRIOR = 5
NUMBER_OF_SESSIONS = 100

DECAY = 2.

# Fix both RNGs so the simulated data and every sampled result below are
# reproducible under Restart & Run All.
SEED = 42
numpy.random.seed(SEED)
tf.set_random_seed(SEED)

# Expected pages per session: starts at PAGES_PER_SESSION_PRIOR and decays
# exponentially over the sessions.
TREND = PAGES_PER_SESSION_PRIOR * numpy.exp(- numpy.arange(NUMBER_OF_SESSIONS) * DECAY
                                            / NUMBER_OF_SESSIONS)

# Sample one session length per session from an exponential around the trend,
# rounded and clipped to the valid range [1, PAGE_COUNT].
SIMULATED_PAGES = numpy.clip(numpy.round(numpy.random.exponential(TREND)),
                             1, PAGE_COUNT).astype(numpy.int32)
assert SIMULATED_PAGES.min() >= 1 and SIMULATED_PAGES.max() <= PAGE_COUNT

print("Trend from {:.2f} to {:.2f}".format(TREND[0], TREND[-1]))
print("Data:", SIMULATED_PAGES)

# The model consumes the observations as an int32 tensor.
DATA = tf.convert_to_tensor(SIMULATED_PAGES, numpy.int32)
with tf.Session() as s:
    print(s.run(DATA))

## Model

Next, we define the model of how visitors click through the pages of a campaign.

In [None]:
def update_beliefs(beliefs, i, j, bandwidth):
    """Add one observation to a belief matrix and cap the updated row's evidence.

    Args:
        beliefs: 2-D float tensor of pseudo-counts with a fully defined static
            shape (it is used as the ``tf.scatter_nd`` output shape).
        i: int32 scalar tensor, index of the row that receives the observation.
        j: int32 scalar tensor, column of row ``i`` that gets +1.
        bandwidth: positive scalar (float, tensor or random variable). If, after
            the +1, row ``i`` sums to more than ``bandwidth``, every entry of
            that row is multiplied by ``bandwidth / row_sum`` so the row sums to
            ``bandwidth``. All other rows are never changed.

    Returns:
        A new tensor with the same shape and dtype as ``beliefs``.
    """
    # +1 pseudo-count at (i, j), in the same dtype as the beliefs.
    update = tf.scatter_nd(tf.stack([tf.stack([i, j])]),
                           tf.ones([1], dtype=beliefs.dtype),
                           beliefs.shape)
    beliefs = beliefs + update

    # Total evidence now held by the updated row.
    evidence = tf.reduce_sum(beliefs[i, :])

    # Shrink factor for row i, clamped at 1: rows within the bandwidth are
    # multiplied by exactly 1 and so stay unchanged. This replaces a tf.cond +
    # exp(log(scale)) round-trip that only touched columns 0 and 1.
    scale = tf.minimum(tf.cast(bandwidth, beliefs.dtype) / evidence, 1.)

    # Per-row factor: `scale` for row i, exactly 1 for every other row. Built
    # with scatter_nd over the static row count so the result keeps a fully
    # known static shape (required inside tf.while_loop).
    row_mask = tf.scatter_nd(tf.reshape(i, [1, 1]),
                             tf.ones([1], dtype=beliefs.dtype),
                             beliefs.shape[:1])
    row_factor = 1. + (scale - 1.) * row_mask
    return beliefs * tf.expand_dims(row_factor, 1)

In [None]:
# Sanity check: adding one count at (row 1, column 0) turns row 1 into [4, 3]
# (evidence 7). With bandwidth 1 that row must be rescaled to [4/7, 3/7], while
# row 0 stays untouched.
checked_beliefs = update_beliefs(tf.constant([[3., 4.], [3., 3.]]),
                                 tf.constant(1), tf.constant(0), tf.constant(1.))
with tf.Session() as s:
    checked_values = s.run(checked_beliefs)
print(checked_values)
numpy.testing.assert_allclose(checked_values,
                              [[3., 4.], [4. / 7., 3. / 7.]],
                              rtol=1e-5)

In [None]:
def model(bandwidth, page_count, number_of_sessions, data,
          pages_per_session_prior=None, observation_scale=0.5):
    """Simulate pages viewed per session from evolving per-page leave beliefs.

    Per-page pseudo-counts (column 0: "left on this page", column 1:
    "continued past this page") start from a prior. They are updated by
    replaying the observed ``data`` session by session through
    ``update_beliefs``. For every session, a Bernoulli "leave" flag is drawn
    per page from that session's belief snapshot. The simulated session length
    is the first page whose flag is set.

    Args:
        bandwidth: positive scalar (float, tensor or Edward random variable),
            forwarded to ``update_beliefs`` as the per-page evidence cap.
        page_count: number of pages; a session never goes past the last one.
        number_of_sessions: number of sessions of ``data`` to replay.
        data: int32 tensor of observed pages viewed, indexed by session.
        pages_per_session_prior: prior mean session length. The initial leave
            probability is ``1 / pages_per_session_prior``. ``None`` (default)
            uses the notebook constant ``PAGES_PER_SESSION_PRIOR``.
        observation_scale: standard deviation of the Normal observation noise.

    Returns:
        Normal random variable of shape ``[number_of_sessions]``, centred on
        the simulated number of pages viewed (an integer in 1..page_count).

    NOTE(review): the snapshot used to simulate session ``s`` already includes
    session ``s``'s own observed length. If the intent is one-step-ahead
    prediction, the pre-update beliefs should be emitted instead -- TODO confirm.
    """
    if pages_per_session_prior is None:
        pages_per_session_prior = PAGES_PER_SESSION_PRIOR
    # Float division, so the probability is never truncated to 0 (Python 2).
    churn_probability = 1. / pages_per_session_prior
    # Prior pseudo-counts per page, with a total weight of 2 per page.
    initial_beliefs = tf.stack([2 * churn_probability * tf.ones(page_count),
                                2 * (1 - churn_probability) * tf.ones(page_count)],
                               axis=1)

    def over_sessions(state, isession):
        """Replay session ``isession`` through the beliefs, page by page."""
        session_length = data[isession]

        def continues(beliefs, ipage, last_page):
            # Stop right after processing the session's last page.
            return tf.logical_not(last_page)

        def over_pages(beliefs, ipage, last_page):
            # The session ends on its observed last page, or on the final page
            # of the site, whichever comes first.
            last_page = tf.logical_or(tf.equal(ipage, session_length - 1),
                                      tf.equal(ipage, page_count - 1))
            # Column 0 ("left") on the last page, column 1 ("continued") otherwise.
            column = tf.cast(tf.logical_not(last_page), tf.int32)
            beliefs = update_beliefs(beliefs, ipage, column, bandwidth)
            return beliefs, ipage + 1, last_page

        beliefs, _ = state
        beliefs, _, _ = tf.while_loop(continues, over_pages,
                                      (beliefs, 0, tf.constant(False)))
        # Carry the beliefs forward and also emit them as this session's output.
        return beliefs, beliefs

    # One belief snapshot per session, stacked along axis 0.
    _, belief_history = tf.scan(over_sessions, tf.range(number_of_sessions),
                                (initial_beliefs, initial_beliefs))

    leave_probs = belief_history[:, :, 0] / (belief_history[:, :, 0] +
                                             belief_history[:, :, 1])
    leave_flags = Bernoulli(probs=leave_probs)

    # Column k of `padded` flags "left after viewing k pages":
    #   k = 0              -> prepended 0 (never left after 0 pages);
    #   1 <= k < page_count -> the draw for page k - 1 (0-based);
    #   k = page_count     -> appended 1 (always left on reaching the end),
    #                         replacing the last page's own draw, which is dropped.
    padded = tf.concat([tf.zeros_like(leave_flags[:, :1]),
                        leave_flags[:, :-1],
                        tf.ones_like(leave_flags[:, :1])],
                       axis=1)
    # Pages viewed = index of the first 1 = number of leading zeros. Counting
    # them explicitly avoids relying on tf.argmax's tie-breaking, which the
    # TensorFlow documentation leaves unspecified.
    leading_zeros = tf.equal(tf.cumsum(padded, axis=1), 0)
    lefts = tf.reduce_sum(tf.cast(leading_zeros, tf.int32), axis=1)

    return Normal(tf.cast(lefts, dtype=tf.float32), scale=observation_scale)

In [None]:
# Sanity check with a fixed bandwidth of 20: compare the average simulated
# pages per session in the first and the last quarter of the sessions.
fixed_bandwidth_model = model(20., PAGE_COUNT, NUMBER_OF_SESSIONS, DATA)
N_DRAWS = 100
with tf.Session() as s:
    early_mean = 0.
    late_mean = 0.
    for _ in range(N_DRAWS):
        draw = s.run(fixed_bandwidth_model)
        quarter = len(draw) // 4
        early_mean += draw[:quarter].mean()
        # Explicit start index: `-len(draw)//4` would parse as
        # `(-len(draw)) // 4`, which rounds away from zero and selects a
        # larger window than the first quarter when len(draw) % 4 != 0.
        late_mean += draw[len(draw) - quarter:].mean()
    early_mean /= N_DRAWS
    late_mean /= N_DRAWS
    print(early_mean, late_mean)

## Full model

In [None]:
# Prior over the bandwidth: Exponential with rate 0.05, i.e. mean 1 / 0.05 = 20.
bandwidth = Exponential(rate=0.05)
# Simulated pages viewed per session, with the bandwidth drawn from its prior.
pps = model(bandwidth,
            PAGE_COUNT,
            NUMBER_OF_SESSIONS,
            DATA)

Let's check that the full model still runs, by averaging a few draws from it and plotting a moving average:

In [None]:
N_MODEL_DRAWS = 10
SMOOTHING_WINDOW = 10
with tf.Session() as s:
    # Average several independent draws of the full model's simulated
    # pages per session (each run also resamples the bandwidth from its prior).
    mean_pages = numpy.mean([s.run(pps) for _ in range(N_MODEL_DRAWS)], axis=0)
print(mean_pages)

# Moving average over complete windows only, so every plotted point averages
# exactly SMOOTHING_WINDOW sessions (no shrinking windows at the tail).
smoothed = numpy.convolve(mean_pages,
                          numpy.ones(SMOOTHING_WINDOW) / SMOOTHING_WINDOW,
                          mode="valid")
fig, ax = plt.subplots()
ax.plot(smoothed)
ax.set_xlabel("session")
ax.set_ylabel("pages per session ({}-session moving average)".format(SMOOTHING_WINDOW))
plt.show()

## Metropolis–Hastings (MCMC)

This is slow, but it should work.

In [None]:
N = 1000
with tf.variable_scope("metropolis_hastings", reuse=tf.AUTO_REUSE):
    # Chain storage: one slot per MH iteration, all initialised to 10.
    sampled_bandwidth = Empirical(params=tf.get_variable(
        "sampled_bandwidth",
        [N],
        initializer=tf.constant_initializer(10.)))
    # Random-walk proposal g(bandwidth' | bandwidth). It must be written in
    # terms of the latent variable `bandwidth`, because MetropolisHastings
    # substitutes the chain's current state for that variable when it builds
    # the proposal. Conditioning on `sampled_bandwidth` instead would centre
    # each proposal on a random draw from the stored chain, not on the
    # current state.
    proposed_bandwidth = Normal(loc=bandwidth, scale=1.)

    mh_inference = ed.MetropolisHastings({bandwidth: sampled_bandwidth},
                                         {bandwidth: proposed_bandwidth},
                                         data={pps: tf.cast(DATA, tf.float32)})

In [None]:
mh_inference.run()

sess = ed.get_session()
chain = sess.run(sampled_bandwidth.params)
# Discard the first 20% of the chain as burn-in: it starts from the initial
# value 10 and needs some iterations to move towards the posterior.
burn_in = len(chain) // 5
kept_samples = chain[burn_in:]
mean, stddev = kept_samples.mean(), kept_samples.std()
print("posterior: mean={:.4f}, stddev={:.4f}".format(mean, stddev))

In [None]:
# Histogram of 1000 draws from the MH Empirical posterior approximation.
fig, ax = plt.subplots()
ax.hist(sess.run(sampled_bandwidth.sample(1000)), density=True)
ax.set_xlabel("bandwidth");


## Variational Inference


In [None]:
with tf.variable_scope("variational", reuse=tf.AUTO_REUSE):
    # Unconstrained variational parameters; softplus maps them to positives.
    shape = tf.get_variable("shape", (),
                            initializer=tf.constant_initializer(1.))
    scale = tf.get_variable("scale", (),
                            initializer=tf.constant_initializer(20.))
    # Gamma takes (concentration, rate); the rate is the reciprocal of the
    # positive scale.
    positive_shape = tf.nn.softplus(shape)
    positive_scale = tf.nn.softplus(scale)
    qbandwidth = Gamma(concentration=positive_shape, rate=1. / positive_scale)

observed = {pps: tf.cast(DATA, tf.float32)}
variational_inference = ed.KLqp({bandwidth: qbandwidth}, data=observed)

In [None]:
# Optimise the variational parameters of qbandwidth for a fixed number of steps.
N_VI_ITERATIONS = 1000
variational_inference.run(n_iter=N_VI_ITERATIONS)

In [None]:
# Summary of the fitted Gamma(alpha, beta) approximation, where beta is the
# rate: mean = alpha / beta and variance = alpha / beta**2 = mean / beta.
alpha, beta = (param.eval() for param in (qbandwidth.concentration,
                                          qbandwidth.rate))
mean = alpha / beta
stddev = numpy.sqrt(mean / beta)
print("posterior: mean={:.4f} stddev={:.4f}".format(mean, stddev))

In [None]:
# Histogram of 1000 draws from the fitted Gamma approximation.
fig, ax = plt.subplots()
ax.hist(qbandwidth.sample(1000).eval(), density=True)
ax.set_xlabel("bandwidth");