In [1]:
import numpy as np
import pymc3 as pm
import linmix
import matplotlib.pyplot as plt
%matplotlib notebook

In [2]:
%matplotlib notebook

In [3]:
np.random.seed(4)
x = np.linspace(-10,10,50) + np.random.normal(scale=1,size=50)
x_er = abs(np.random.normal(scale=1,size=50))
y = np.linspace(-10,10,50) + np.random.normal(scale=1,size=50)
y_er = abs(np.random.normal(scale=1,size=50))

In [4]:
linear_model = pm.Model()
with linear_model:
    beta = pm.Uniform('beta',lower=-100,upper=100)
    alpha = pm.Uniform('alpha',lower=-100,upper=100)
    X = pm.Normal('X', mu=x, sd=x_er, shape=len(x))
    Y = pm.Normal('Y', mu=beta*X+alpha, sd=y_er, observed=y)

    step_method = pm.NUTS()
    trace1 = pm.sample(5000, chains=12, cores=12, tune=5000, step=step_method)

pm.traceplot(trace1,figsize=(9,3))
pm.summary(trace1,alpha=0.1)

Multiprocess sampling (12 chains in 12 jobs)
NUTS: [X, alpha, beta]
Sampling 12 chains: 100%|██████████| 120000/120000 [00:36<00:00, 3274.16draws/s]


<IPython.core.display.Javascript object>

Unnamed: 0,mean,sd,mc_error,hpd_5,hpd_95,n_eff,Rhat
X__0,-9.927895,0.203094,0.000938,-10.261552,-9.593755,38649.013995,1.000034
X__1,-9.578878,0.274598,0.001092,-10.031027,-9.127614,55621.887989,1.000061
X__2,-10.174831,0.037263,9.7e-05,-10.234395,-10.111863,148971.279844,0.999936
X__3,-8.083614,0.036083,8.4e-05,-8.142842,-8.024884,145772.17689,0.999941
X__4,-8.508681,0.702587,0.002026,-9.666124,-7.346071,145735.483214,0.999919
X__5,-9.484452,0.172274,0.000528,-9.770877,-9.203435,141509.270228,0.999931
X__6,-8.109806,0.718521,0.001851,-9.317903,-6.956263,131443.515748,0.999988
X__7,-7.612873,0.312267,0.000796,-8.126387,-7.098825,104548.796495,0.999967
X__8,-7.072494,0.203209,0.000731,-7.40903,-6.742541,60187.763887,0.999973
X__9,-7.1876,0.187063,0.000735,-7.496958,-6.883068,68655.281821,0.999967


In [5]:
lm = linmix.LinMix(x, y, x_er, y_er, nchains=12)
lm.run_mcmc(silent=True)

In [7]:
plt.figure()
plt.hist(lm.chain['beta'],label='linmix',edgecolor='w',alpha=0.5,density=True)
plt.hist(trace1['beta'],label='pymc',edgecolor='w',alpha=0.5,density=True)
plt.legend()

plt.figure()
plt.hist(lm.chain['alpha'],label='linmix',edgecolor='w',alpha=0.5,density=True)
plt.hist(trace1['alpha'],label='pymc',edgecolor='w',alpha=0.5,density=True)
plt.legend()

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<matplotlib.legend.Legend at 0x7f36a08e5860>

In [16]:
plt.figure()
plt.errorbar(x, y, xerr=x_er, yerr=y_er, fmt='.k', alpha=0.5)
for i in range(0, len(lm.chain), 100):
    xs = np.arange(-10,11)
    ys = lm.chain[i]['alpha'] + xs * lm.chain[i]['beta']
    plt.plot(xs, ys, color='r', alpha=0.02)


<IPython.core.display.Javascript object>

In [9]:
plt.figure(figsize=(6,4))
plt.errorbar(x,y,y_er,fmt='.k',alpha=0.5,capsize=2,elinewidth=0.7)
#test_x = np.linspace(-20,20,100)
#pymc_y_best = test_x

#plt.plot(test_x,test_x*pm.summary(trace1,alpha=0.1)['mean']['slope'],'-r',lw=0.5)
#plt.plot(test_x,test_x*pm.summary(trace1,alpha=0.1)['hpd_5']['slope'],'--r',lw=0.5)
#plt.plot(test_x,test_x*pm.summary(trace1,alpha=0.1)['hpd_95']['slope'],'--r',lw=0.5)

#plt.plot(x,x,'-b',lw=1.0,alpha=0.5)

<IPython.core.display.Javascript object>

<ErrorbarContainer object of 3 artists>

In [8]:
lm.chain

array([(-0.62336626, 0.90603296, 7.50804922, [0.4094876 , 0.10875802, 0.48175438], [-0.44345617, -7.26016681,  5.39854707], [16.05191115, 21.34284047, 21.99138277],  3.39632082, 12.28417581, 25.55397589,  1.62958248, 6.05695745, 0.89467613),
       (-0.22862605, 0.96478468, 6.90281021, [0.39346247, 0.21326095, 0.39327659], [ 0.29246619, -5.79532005,  4.38995444], [18.66025753, 14.12612588, 16.53112233], -0.0389868 , 10.1624209 ,  8.71127569,  0.60562532, 5.59146426, 0.89904226),
       ( 0.51393624, 1.06683201, 4.48828282, [0.44448957, 0.23747931, 0.31803113], [-1.45960403, -7.37996304,  4.78225976], [19.98722782, 12.02507006, 11.00143034], -4.02222852,  8.69179348, 20.34272003, -0.88045983, 5.96806338, 0.94885083),
       ...,
       ( 0.89482907, 0.94290756, 3.34252135, [0.3033141 , 0.16742704, 0.52925886], [ 1.32154691,  3.5094838 , -2.72832851], [13.44133289, 22.34824517, 17.71197477],  4.47855925, 21.5618318 , 34.03457533, -0.45556573, 4.84941805, 0.92853023),
       ( 1.01224006,