In [1]:
import numpy as np
import pymc3 as pm

from aesara import tensor as at
from scipy import stats as st

You are running the v4 development version of PyMC3 which currently still lacks key features. You probably want to use the stable v3 instead which you can either install via conda or find on the v3 GitHub branch: https://github.com/pymc-devs/pymc3/tree/v3


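As recommended by Michael, every array dimension in this notebook is a prime number (e.g. N = N' = 5 and K = 19), so that an accidental broadcast or a transposed axis is more likely to surface as a shape error than to pass silently.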

In [2]:
# Seed a dedicated generator so the simulated observations are identical on
# every Restart & Run All (the global np.random state was previously unseeded).
SEED = 42
rng = np.random.default_rng(SEED)

µ, σ = -0.5, 3  # true data-generating parameters

# Observed sample: 5 draws from Normal(µ, σ).
Xs = rng.normal(loc=µ, scale=σ, size=[5,])
# Points at which the CDF is evaluated later on.
Xnew = np.array([-3, -1, 0.5, 3.2, 4])  # N' = 5

# Append a trailing axis so both arrays are column vectors of shape (5, 1);
# the model cell relies on this to broadcast Xnew against Xs.T.
Xs = Xs[..., np.newaxis]
Xnew = Xnew[..., np.newaxis]

# Number of base-distribution draws ("G0") and length of the weight vector
# used in the model cell.
K = 19

In [3]:
# Reference values: CDF of the true Normal(µ, σ) at each point of Xnew,
# flattened to 1-D so the result is a plain vector.
true_dist = st.norm(loc=µ, scale=σ)
true_dist.cdf(Xnew.ravel())

array([0.20232838, 0.43381617, 0.63055866, 0.89127429, 0.9331928 ])

In [4]:
with pm.Model() as model:
    N = Xs.shape[0]  # number of observed data points

    # dirac[i] = number of observations that are <= Xnew[i].
    # Xnew is (N', 1) and Xs.T is (1, N), so the comparison broadcasts to
    # (N', N); summing over axis 1 leaves one count per evaluation point.
    # at.sum already returns an Aesara tensor, so the former
    # at.as_tensor_variable(dirac) call was a no-op and has been dropped.
    dirac = at.sum(at.ge(Xnew, Xs.T), axis=1)  # shape = (N',)

    base_dist = pm.Normal("G0", 0, 3, shape=(K, 1))  # K draws
    weights = pm.Dirichlet(
        name="sticks",
        a=np.ones(shape=(K,)),
    )

    # atom_below[k, i] is true when the k-th G0 draw is <= Xnew[i];
    # (K, 1) against (1, N') broadcasts to (K, N').
    atom_below = at.le(base_dist, Xnew.T)
    # Weight each indicator by its Dirichlet weight and sum over the K draws,
    # giving a weighted step-function CDF evaluated at Xnew; shape = (N',).
    empirical_base_cdf = at.sum(at.mul(atom_below.T, weights), axis=1)

    # Mix the base-draw CDF with the data counts, both over 1 + N.
    # NOTE(review): the constant 1 presumably acts as the Dirichlet-process
    # concentration parameter (alpha = 1) -- confirm before changing it.
    posterior_dp = pm.Deterministic(
        name="posterior-dp",
        var=empirical_base_cdf/(1 + N) + dirac/(1 + N),
    )

    # random_seed pins the sampler so the draws shown below can be reproduced.
    # A single chain is kept as in the original run; note that PyMC warns
    # that some convergence checks are impossible with one chain.
    trace = pm.sample(
        draws=1000,
        chains=1,
        random_seed=42,
    )

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Sequential sampling (1 chains in 1 job)
NUTS: [G0, sticks]


Sampling 1 chain for 1_000 tune and 1_000 draw iterations (1_000 + 1_000 draws total) took 11 seconds.
Only one chain was sampled, this makes it impossible to run some convergence checks


In [5]:
# Convert the sampling result to nested dicts, take the "posterior" group and
# display the "posterior-dp" values at index 0 of the leading axis
# (presumably the single sampled chain -- one row per draw, one column per Xnew point).
posterior_group = trace.to_dict()["posterior"]
posterior_group["posterior-dp"][0]



array([[0.17990969, 0.38615559, 0.76698384, 0.80085845, 0.97027592],
       [0.17852485, 0.3495069 , 0.71265889, 0.80426872, 0.97093539],
       [0.19908908, 0.38885517, 0.7573241 , 0.78406105, 0.95072771],
       ...,
       [0.2090297 , 0.3906765 , 0.73080114, 0.82236342, 0.99212991],
       [0.20347505, 0.38909596, 0.75364982, 0.79642596, 0.97903922],
       [0.24127624, 0.41206544, 0.75196055, 0.78147412, 0.97720056]])

In [6]:
# Record the run date plus the Python, IPython and imported-package versions
# used for this run, so the results can be tied to a specific environment.
%load_ext watermark
%watermark -n -u -v -iv -w

Last updated: Thu Jul 08 2021

Python implementation: CPython
Python version       : 3.8.10
IPython version      : 7.25.0

pymc3 : 4.0
scipy : 1.7.0
numpy : 1.21.0
aesara: 2.0.12

Watermark: 2.2.0

