In [1]:
import pymc3 as pm
import numpy as np
from pymc3.step_methods.hmc import quadpotential

  from ._conv import register_converters as _register_converters


In [43]:
n_chains = 4

# Phase 1: let pm.sample choose and tune the sampler on a 10-d Normal, then
# (still inside the same model context) collect what is needed to resume.
with pm.Model() as m:
    x = pm.Normal('x', shape=10)
    trace1 = pm.sample(1000, tune=1000, cores=n_chains)

    # Fixed mass matrix built from the covariance of the first trace.
    cov = np.atleast_1d(pm.trace_cov(trace1))
    potential = quadpotential.QuadPotentialFull(cov)

    # Last recorded value of the 'step_size_bar' sampler stat.
    step_size = trace1.get_sampler_stats('step_size_bar')[-1]
    # Presumably undoes a division by ndim ** 0.25 that NUTS applies to
    # `step_scale` internally -- TODO confirm against the pymc3 HMC code.
    step_scale = (m.ndim ** 0.25) * step_size

    # One starting point per chain, drawn at random from the first trace.
    start = list(np.random.choice(trace1, n_chains))

# Phase 2: a fresh but identically specified model, sampled with the frozen
# potential and step size and with no tuning iterations.
with pm.Model() as m2:
    x = pm.Normal('x', shape=10)
    step = pm.NUTS(
        potential=potential,
        adapt_step_size=False,
        step_scale=step_scale,
    )
    step.tune = False
    trace2 = pm.sample(
        draws=100, step=step, tune=0, cores=n_chains, start=start
    )

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [x]
Sampling 4 chains: 100%|██████████| 8000/8000 [00:01<00:00, 4871.39draws/s]
Only 100 samples in chain.
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [x]
Sampling 4 chains: 100%|██████████| 400/400 [00:00<00:00, 1790.08draws/s]


In [44]:
# For every sampler statistic recorded in trace1, show the last ten values
# from the first run followed by the last ten from the resumed run.
for statname in trace1.stat_names:
    print(statname)
    for run in (trace1, trace2):
        tail = run.get_sampler_stats(statname)[-10:]
        print(tail)
    print('\n')

mean_tree_accept
[0.89587444 0.91788297 1.         0.792088   0.81323281 0.95371789
 0.88839392 0.67369024 0.6651117  0.66842588]
[0.90408799 0.72660679 0.87747766 1.         0.71760797 0.63849806
 0.71510783 0.67571846 0.99262431 0.91960052]


tune
[False False False False False False False False False False]
[False False False False False False False False False False]


tree_size
[55. 15.  3. 11.  3.  7.  3.  3.  3.  7.]
[ 3. 15.  3.  3.  3.  3. 15. 35. 55. 11.]


energy
[20.99568768 22.80919304 19.75207315 19.62360533 19.63791659 16.84965487
 16.53761229 17.67888639 19.10123826 23.73517854]
[16.06908503 17.46064728 15.98365309 13.59640393 15.59073766 17.37116411
 21.7217679  20.70956196 19.52095084 20.16836644]


model_logp
[-16.31327237 -17.09536222 -14.11266778 -14.18745579 -13.67020211
 -13.72625956 -11.90427792 -12.60583753 -16.26384699 -16.21436448]
[-11.91270832 -12.49459893 -12.73302614 -11.36274526 -11.36274526
 -14.04997837 -13.08764464 -15.9212822  -15.25789445 -15.754902

In [82]:
# Imported at the top of the cell (instead of in the middle of a `with` body)
# so the dependency is visible and the cell still runs on its own.
from pymc3.step_methods import step_sizes

n_chains = 4

with pm.Model() as m:
    x = pm.Normal('x', shape=10)

    # Hand-rolled equivalent of init == 'jitter+adapt_diag': one starting
    # point per chain, each a copy of the model's test point shifted by
    # uniform noise in [-1, 1).
    start = []
    for _ in range(n_chains):
        point = {name: val.copy() for name, val in m.test_point.items()}
        for val in point.values():
            val[...] += 2 * np.random.rand(*val.shape) - 1
        start.append(point)

    # Adaptive diagonal potential, initialised at the average of the jittered
    # points with unit variance.
    mean = np.mean([m.dict_to_array(vals) for vals in start], axis=0)
    var = np.ones_like(mean)
    # The trailing 10 is passed through unchanged; presumably the initial
    # weight of the potential -- TODO confirm against QuadPotentialDiagAdapt.
    potential = quadpotential.QuadPotentialDiagAdapt(
        m.ndim, mean, var, 10)
    step = pm.NUTS(potential=potential)

    # Fix: hand the jittered points to the sampler. Previously `start` only
    # fed the potential's initial mean and was never passed to pm.sample, so
    # the chains did not begin from the points computed above.
    trace1 = pm.sample(1000, step=step, tune=1000, cores=n_chains,
                       start=start)

with m: # need to be the same model
    # Last recorded value of the 'step_size_bar' sampler stat.
    step_size = trace1.get_sampler_stats('step_size_bar')[-1]
    step.tune = False
    # Replace the step-size adaptation object with a new one seeded with
    # that step size. The remaining positional arguments are kept exactly as
    # they were; presumably gamma, k and t0 -- TODO confirm against
    # DualAverageAdaptation's signature.
    step.step_adapt = step_sizes.DualAverageAdaptation(
            step_size, step.target_accept, 0.05, .75, 10
        )
    # NOTE(review): no `start` is passed here, whereas the resumed run in the
    # earlier QuadPotentialFull experiment starts from points drawn from
    # trace1 -- confirm that this difference is intended.
    trace2 = pm.sample(draws=100, step=step, tune=0, cores=n_chains)

Multiprocess sampling (4 chains in 4 jobs)
NUTS: [x]
Sampling 4 chains: 100%|██████████| 8000/8000 [00:01<00:00, 4212.24draws/s]
Only 100 samples in chain.
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [x]
Sampling 4 chains: 100%|██████████| 400/400 [00:00<00:00, 2352.17draws/s]


In [83]:
# For every sampler statistic recorded in trace1, show the last ten values
# from the first run followed by the last ten from the resumed run.
for statname in trace1.stat_names:
    print(statname)
    for run in (trace1, trace2):
        tail = run.get_sampler_stats(statname)[-10:]
        print(tail)
    print('\n')

mean_tree_accept
[0.95838238 0.90337408 0.94474752 0.9286075  0.9078535  0.94806194
 0.93944191 1.         0.84998695 0.95804629]
[0.83527192 0.85464282 0.91902447 0.43155576 0.61987476 1.
 0.83740787 1.         0.96985489 0.51769917]


tune
[False False False False False False False False False False]
[False False False False False False False False False False]


tree_size
[ 7. 31.  7. 15. 15.  3. 11.  3. 11.  3.]
[ 3.  3. 63.  3.  7.  3. 63.  3. 63.  3.]


energy
[15.29930491 16.45092966 17.21767406 18.12509899 17.81795815 17.28109965
 16.80782001 16.03812175 16.22917819 14.32148612]
[18.85887842 16.42948418 12.82849541 18.38710381 25.07190653 23.13877454
 22.06830794 18.4932948  16.13431252 21.9525738 ]


model_logp
[-12.83171756 -13.65914299 -13.77937626 -13.67703766 -14.43413986
 -13.39781605 -14.29177167 -12.03995525 -12.91810537 -10.99933172]
[-13.6437218  -10.75863195 -10.91238854 -14.70549444 -19.56929197
 -15.21123197 -16.86906708 -12.91152229 -13.07063074 -14.6886793 ]


en