In [1]:
from src.models.pfa import PFAHMCSampler
from src.data import generate_data

import numpy as np

In [3]:
num_topic = 5
vocab_size = 100
n = 50
depend_prob = 1.
document = generate_data(num_topic, vocab_size, n, partial_depend=True)

In [5]:
hparams = { 
    'e0': 1.,
    'f0': 0.001,
    'c0': np.full((num_topic,), 0.5),
    'pn': np.full((n, 1), 0.5),
    'word_dist': np.full((num_topic, vocab_size), 10.),
    }

model = PFAHMCSampler(vocab_size, num_topic, hparams)
model.set_data(document)
model.set_model(n)

event = model.model.event_shape
sample = model.model.sample()

for k, v in model.model.batch_shape.items():
    print(f'{k}: batch: {v} | event: {event[k]}')
    print(f'---{sample[k].shape}\n')

gamma0: batch: () | event: ()
---()

gamma: batch: () | event: (5,)
---(5,)

theta: batch: (50,) | event: (5,)
---(50, 5)

phi: batch: () | event: (5, 100)
---(5, 100)

document: batch: (50,) | event: (100,)
---(50, 100)



In [6]:
n_states = 100
n_burnin = 3000
model.sample_states(document,
                  n_states=n_states,
                  n_burnin=n_burnin,
                  step_size=0.05)

Finish sampling...


In [7]:
draw_samples = model.predict()
pred_low, pred_high = np.quantile(draw_samples, [0.025, 0.975], axis=0)
is_cover = np.all((pred_low < document, pred_high > document), axis=0)

In [8]:
pred_low

array([[  6.   ,   3.   ,   7.   , ..., 476.425, 947.325, 169.   ],
       [  6.475,   3.475,   7.   , ..., 465.   , 963.475, 169.425],
       [  6.475,   2.   ,   6.475, ..., 474.95 , 964.95 , 162.95 ],
       ...,
       [  6.   ,   4.   ,   8.   , ..., 469.475, 952.9  , 165.9  ],
       [  7.   ,   3.475,   7.475, ..., 477.   , 965.   , 174.475],
       [  4.95 ,   4.   ,   7.   , ..., 471.9  , 972.95 , 172.425]])

In [9]:
pred_high

array([[  19.   ,   14.   ,   19.   , ...,  561.575, 1070.525,  229.1  ],
       [  19.525,   13.   ,   21.05 , ...,  566.2  , 1063.575,  217.525],
       [  18.525,   14.   ,   21.   , ...,  546.   , 1079.525,  220.   ],
       ...,
       [  18.525,   15.   ,   20.   , ...,  553.725, 1083.575,  216.05 ],
       [  17.525,   15.   ,   19.   , ...,  560.575, 1078.05 ,  224.   ],
       [  18.525,   15.   ,   20.   , ...,  558.   , 1086.575,  229.   ]])

In [11]:
document

<tf.Tensor: shape=(50, 100), dtype=float32, numpy=
array([[  0.,  11.,  26., ..., 532., 446., 458.],
       [ 20.,  10.,   1., ..., 609., 646., 685.],
       [  6.,  10.,   2., ..., 473., 397., 488.],
       ...,
       [ 25.,  21.,  23., ..., 799., 742., 773.],
       [  1.,  13.,  28., ..., 321., 300., 330.],
       [  6.,   7.,  18., ..., 288., 289., 312.]], dtype=float32)>

In [13]:
is_cover.mean()

0.193