In [4]:
import numpy as np
import pandas as pd

import arviz as az
import pymc as pm
import aesara.tensor as at

from lifetimes.datasets import load_cdnow_summary
from lifetimes import BetaGeoFitter

In [9]:
df = load_cdnow_summary(index_col=0)
df.head(10)

Unnamed: 0_level_0,frequency,recency,T
ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
1,2,30.43,38.86
2,1,1.71,38.86
3,0,0.0,38.86
4,0,0.0,38.86
5,0,0.0,38.86
6,7,29.43,38.86
7,1,5.0,38.86
8,0,0.0,38.86
9,2,35.71,38.86
10,0,0.0,38.86


In [12]:
n = df.shape[0]
x = df["frequency"].to_numpy()
t_x = df["recency"].to_numpy()
T = df["T"].to_numpy()

x_zero = (x > 0).astype(int)

In [16]:
bgf = BetaGeoFitter()
bgf.fit(frequency=x, recency=t_x, T=T)

bgf.summary

Unnamed: 0,coef,se(coef),lower 95% bound,upper 95% bound
r,0.242593,0.012557,0.217981,0.267205
alpha,4.413532,0.378221,3.672218,5.154846
a,0.792886,0.185719,0.428877,1.156895
b,2.425752,0.705345,1.043276,3.808229


### Full Bayesian Model

In [26]:
with pm.Model() as model_full:
    
    a = pm.HalfNormal(name="a", sigma=10)
    b = pm.HalfNormal(name="b", sigma=10)
    
    alpha = pm.HalfNormal(name="alpha", sigma=10)
    r = pm.HalfNormal(name="r", sigma=10)
    
    lam = pm.Gamma(name="lam", alpha=r, beta=alpha, shape=n)
    p = pm.Beta(name="p", alpha=a, beta=b, shape=n)
    
    def logp(x, t_x, T, x_zero):
        log_term_a = x*at.log(1 - p) + x*at.log(lam) - t_x*lam
        term_b_1 = -lam*(T - t_x)
        term_b_2 = at.log(p) - at.log(1 - p)
        log_term_b = at.switch(
            x_zero,
            at.logsumexp(term_b_1, term_b_2),
            term_b_1,
        )
        
        return at.sum(log_term_a) + at.sum(log_term_b)
    
    likelihood = pm.DensityDist(
        name="likelihood",
        logp=lambda value: logp(x, t_x, T, x_zero),
    )

In [27]:
with model_full:
    trace = pm.sample(chains=1,)

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...


ValueError: Length of Elemwise{sub,no_inplace}.0 cannot be determined

In [30]:
from aesara.tensor.random.op import RandomVariable

isinstance(pm.Normal.rv_op, RandomVariable)

True