In [1]:
import numpy as np
from random import shuffle
import matplotlib.pyplot as plt
%matplotlib inline

Create some sample data:

In [2]:
sample_size = 50
obs_cond = np.random.choice([0,1,2],[sample_size],p=[.7,.1,.2]) # cold = 0, flu = 1, allergies = 2

runny_nose_cold = np.random.choice([0,1],[sample_size],p=[.1,.9])
headache_cold = np.random.choice([0,1],[sample_size],p=[.5,.5])
fever_cold = np.random.choice([0,1],[sample_size],p=[.99,.01])

runny_nose_flu = np.random.choice([0,1],[sample_size],p=[.5,.5])
headache_flu = np.random.choice([0,1],[sample_size],p=[.5,.5])
fever_flu = np.random.choice([0,1],[sample_size],p=[.5,.5])

runny_nose_al = np.random.choice([0,1],[sample_size],p=[.1,.9])
headache_al = np.random.choice([0,1],[sample_size],p=[.99,.01])
fever_al = np.random.choice([0,1],[sample_size],p=[.99,.01])

runny_nose = np.stack([runny_nose_cold,runny_nose_flu,runny_nose_al])
headache = np.stack([headache_cold,headache_flu,headache_al])
fever = np.stack([fever_cold,fever_flu,fever_al])

In [3]:
obs_runny_nose = np.array([runny_nose[j,i] for i,j in enumerate(obs_cond)])
obs_headache = np.array([headache[j,i] for i,j in enumerate(obs_cond)])
obs_fever = np.array([fever[j,i] for i,j in enumerate(obs_cond)])

In [4]:
ddxs = []
for i,cond in enumerate(obs_cond):
    rem_cond = [j for j in range(3) if j != cond]
    shuffle(rem_cond)
    cur_ord = [cond]+rem_cond
    ddxs.append(cur_ord)

#print(obs_cond)
# obs_order = np.zeros([sample_size,3,3],dtype=np.int)
# trunc_obs_order = np.zeros([sample_size,(3*(3-1))/2],dtype=np.int)
# ddxs = []
# for i,cond in enumerate(obs_cond):
#     rem_cond = [j for j in range(3) if j != cond]
#     shuffle(rem_cond)
#     cur_ord = [cond]+rem_cond
#     #print(cur_ord)
#     for j,cond_j in enumerate(cur_ord):
#         for k,cond_k in enumerate(cur_ord):
#             if j<k:
#                 obs_order[i,cond_j,cond_k]=1
                
#     trunc_obs_order[i,0]=obs_order[i,1,2]
#     trunc_obs_order[i,1]=obs_order[i,2,0]
#     trunc_obs_order[i,2]=obs_order[i,0,1]
    
#     ddxs.append(cur_ord)
        

    #print(trunc_obs_order[i,:])


Hide some of the sample data:

In [5]:
msk_cond = np.ma.masked_where(np.random.rand(sample_size) > .9, obs_cond)
#msk_ord = np.ma.masked_where(np.random.rand(sample_size,3) > .5, trunc_obs_order)
msk_runny_nose = np.ma.masked_where(np.random.rand(sample_size) > .9, obs_runny_nose)
msk_headache = np.ma.masked_where(np.random.rand(sample_size) > .9, obs_headache)
msk_fever = np.ma.masked_where(np.random.rand(sample_size) > .9, obs_fever)

In [6]:
#msk_ord.shape

Code the probabilisitic model in pymc3:

In [7]:
msk_findings = np.ma.vstack([msk_runny_nose,msk_headache,msk_fever]).T
obs_findings = np.vstack([obs_runny_nose,obs_headache,obs_fever]).T

In [8]:
ddxs_ar = np.array(ddxs)

diff_tensor = np.tile(np.eye(3)[:, :, np.newaxis] - np.eye(3)[:, np.newaxis, :], [sample_size, 1, 1, 1])

reordered_diff_tensor = diff_tensor[np.arange(sample_size)[:, np.newaxis, np.newaxis, np.newaxis],
                                            np.arange(3)[np.newaxis, :, np.newaxis, np.newaxis],
                                            ddxs_ar[:, np.newaxis, :, np.newaxis],
                                            ddxs_ar[:, np.newaxis, np.newaxis, :]]

indicator_array = np.ma.ones([sample_size,3,3])
for i in range(0,3):
    for j in range(i):
        indicator_array[:,i,j] = 0
    indicator_array[:,i,i] = np.ma.masked

In [9]:
from copy import copy
msk_indicator_array = copy(indicator_array)
for n,k in enumerate(np.random.randint(low=0,high=4,size=sample_size)):
    for i in range(k,3):
        msk_indicator_array[n,i,:] = np.ma.masked

In [10]:
from pymc3 import Model, Categorical, Bernoulli, Normal, Laplace, Dirichlet, Uniform, find_MAP
import theano.tensor as tt
#from pymc3.math import sigmoid
import theano
theano.config.floatX = 'float32'

In [11]:
def invlogit(x):
    return 1. / (1 + tt.exp(-x))

In [12]:
def _build_graphical_model(dx_order_indicator_array_data, diff_tensor_data, findings_data):
    """
    Builds the graphical model and ties the observable variables to the data. You can pass a batch of 
    :class:`int` `size` data; with :class:`int` `num_dxs` possible diagnoses, :class:`int` `num_findings`
    possible findings, and :class:`int` `ddx_max_length` maximum diagnoses in a given differential.

    Parameters:
        dx_order_indicator_array_data (:class:`np.array`): should be of shape 
            [`size`,`ddx_max_length`,`num_dxs`].
        diff_tensor_data (:class:`np.array`): should be of shape 
            [`size`,`num_dxs`,`ddx_max_length`,`num_dxs`].
        findings_data (:class:`np.array`): should be of shape 
            [`size`,`num_findings`]. 

    Returns:
        4-:tuple: containing (x, dx_order, W, findings), where
            x (:class:`Normal`): are hidden variables
            dx_order (:class:`Bernoulli`): indicate the order of diagnoses in the differential
            W (:class:`Normal`): is a matrix of parameters relating diagnoses to findings
            findings (:class:`Bernoulli`): indicate the presence of findings

    """
    size = dx_order_indicator_array_data.shape[0]
    x = Normal("x", mu=0, sd=10, shape=[size, 3])
    dx_order_p = invlogit(tt.batched_tensordot(x, diff_tensor_data,axes=[1,1]))
    dx_order = Bernoulli("dx_order", p=dx_order_p, observed=dx_order_indicator_array_data)

    W = Normal('W', mu=0., sd=10., shape=[3, 3])

    findings = Bernoulli("findings", invlogit(tt.tensordot(x, W, axes=[1,0])), observed=findings_data)

    return (x, dx_order, W, findings)

In [13]:
with Model() as med_model:
    x, dx_order, W, findings = _build_graphical_model(msk_indicator_array,
                                                      reordered_diff_tensor,
                                                      msk_findings)
#     x = Normal('x',mu=0,sd=10,shape=[sample_size,3])
#     dx_order = Bernoulli('dx_order',p=invlogit(tt.dot(x,diff_ar)),observed=msk_ord)
    
    
#     W = Normal('W',mu=0.,sd=10.,shape=[3,3])
    
#     runny_nose = Bernoulli('runny_nose',invlogit(tt.dot(x,W[:,0])), observed=msk_runny_nose)
#     headache = Bernoulli('headache',invlogit(tt.dot(x,W[:,1])),observed=msk_headache)
#     fever = Bernoulli('fever',invlogit(tt.dot(x,W[:,2])),observed=msk_fever)
    
    



Fit the model (using a MAP estimate - really this is EM, but ok...):

In [14]:
map_estimate = find_MAP(model=med_model)

         Current function value: 730.304458
         Iterations: 34
         Function evaluations: 122
         Gradient evaluations: 112


ValueError: Optimization error: max, logp or dlogp at max have non-finite values. Some values may be outside of distribution support. max: {'x': array([[ 1.53048122, -8.59916592,  7.06272936],
       [ 7.82212019, -2.12785316, -5.67922688],
       [ 5.87954187, -3.02247024, -2.97464681],
       [ 6.77617025, -6.66719913,  0.04102489],
       [ 6.77617025, -6.66719913,  0.04102489],
       [ 6.77617025, -6.66719913,  0.04102489],
       [-0.41313726, -1.07215679,  0.24827646],
       [ 0.39875245,  0.16497053,  0.42914799],
       [ 5.87954187, -3.02247024, -2.97464681],
       [-3.78090692, -2.97595668,  4.39409399],
       [ 7.82212019, -2.12785316, -5.67922688],
       [-0.41313726, -1.07215679,  0.24827646],
       [-3.78048158, -7.04222298,  9.65276527],
       [-1.85337031,  2.59345913, -1.96480584],
       [ 3.11474895, -0.49887502, -0.78341556],
       [ 2.25297284,  1.64694715, -2.26725388],
       [ 0.65088832,  0.35075122,  0.03039607],
       [ 5.87954187, -3.02247024, -2.97464681],
       [-4.27728796, -3.77311182,  7.88271093],
       [ 0.39875245,  0.16497053,  0.42914799],
       [-0.03843694, -0.28859112, -0.68311083],
       [ 5.87954187, -3.02247024, -2.97464681],
       [ 7.82212019, -2.12785316, -5.67922688],
       [-4.27728796, -3.77311182,  7.88271093],
       [-2.8498106 ,  6.62837362, -2.87352037],
       [ 1.9705776 ,  1.07861543, -1.50921237],
       [ 7.56644249, -0.62464482, -6.93630075],
       [ 0.39875245,  0.16497053,  0.42914799],
       [ 7.56644249, -0.62464482, -6.93630075],
       [ 0.25534698,  0.12502252,  0.59280962],
       [-0.03843694, -0.28859112, -0.68311083],
       [ 5.87954187, -3.02247024, -2.97464681],
       [ 0.39875245,  0.16497053,  0.42914799],
       [ 7.56644249, -0.62464482, -6.93630075],
       [ 6.77617025, -6.66719913,  0.04102489],
       [ 7.82212019, -2.12785316, -5.67922688],
       [ 3.11474895, -0.49887502, -0.78341556],
       [ 5.87954187, -3.02247024, -2.97464681],
       [-0.41313726, -1.07215679,  0.24827646],
       [ 1.90335381, -0.67221081, -1.88308764],
       [-0.03843694, -0.28859112, -0.68311083],
       [-0.03843694, -0.28859112, -0.68311083],
       [-3.78090692, -2.97595668,  4.39409399],
       [ 6.77617025, -6.66719913,  0.04102489],
       [-0.03843694, -0.28859112, -0.68311083],
       [-3.78090692, -2.97595668,  4.39409399],
       [ 7.82212019, -2.12785316, -5.67922688],
       [ 5.87954187, -3.02247024, -2.97464681],
       [ 7.82212019, -2.12785316, -5.67922688],
       [ 5.87954187, -3.02247024, -2.97464681]], dtype=float32), 'dx_order_missing': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), 'W': array([[ -1.64382541,  -3.64445305, -11.9736948 ],
       [ -2.44264197,  13.44161797,  13.9010458 ],
       [ -2.56581664,  -9.85146523,  -1.32095528]], dtype=float32), 'findings_missing': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])} logp: array(-730.304457618121) dlogp: array([        nan,         nan,         nan,         nan,         nan,
               nan,         nan,         nan,         nan,         nan,
               nan,         nan,         nan,         nan,         nan,
               nan,         nan,         nan, -0.10174997,  0.74729097,
       -1.06768513,  0.71793592, -0.53917181,  0.47361004,         nan,
               nan,         nan,  0.81211948, -1.4007318 , -0.12420107,
               nan,         nan,         nan, -0.10174997,  0.74729097,
       -1.06768513,         nan,         nan,         nan, -0.30301723,
       -0.34271729, -0.32430875,  0.27669069,  0.13092834,  0.44604996,
        1.93630052, -1.87309182,  0.62629837, -0.08810818,  1.13282454,
       -0.30150583,         nan,         nan,         nan,         nan,
               nan,         nan,  0.71793592, -0.53917181,  0.47361004,
        0.05839998, -0.52402848, -0.06191105,         nan,         nan,
               nan,         nan,         nan,         nan,         nan,
               nan,         nan,  0.04242726, -0.02744102,  0.0809239 ,
        0.83330697, -1.18032491,  1.12265921,         nan,         nan,
               nan,  0.71793592, -0.53917181,  0.47361004,         nan,
               nan,         nan, -0.41759706,  0.93191588,  0.04676917,
        0.05839998, -0.52402848, -0.06191105,         nan,         nan,
               nan,  0.71793592, -0.53917181,  0.47361004,         nan,
               nan,         nan,         nan,         nan,         nan,
               nan,         nan,         nan,  0.27669069,  0.13092834,
        0.44604996,         nan,         nan,         nan, -0.10174997,
        0.74729097, -1.06768513, -0.14619614,  0.18891448, -0.26663321,
        0.05839998, -0.52402848, -0.06191105,  0.05839998, -0.52402848,
       -0.06191105,  0.81211948, -1.4007318 , -0.12420107,         nan,
               nan,         nan,  0.05839998, -0.52402848, -0.06191105,
        0.81211948, -1.4007318 , -0.12420107,         nan,         nan,
               nan,         nan,         nan,         nan,         nan,
               nan,         nan,         nan,         nan,         nan,
       -0.90189821,         nan,         nan, -3.52708197,         nan,
               nan,  1.70310712,         nan,         nan], dtype=float32)Check that 1) you don't have hierarchical parameters, these will lead to points with infinite density. 2) your distribution logp's are properly specified. Specific issues: 


Compare the estimate to the generating parameters...

In [None]:
plt.hist(1./(1+np.exp(-map_estimate['x'].dot(map_estimate['W'])))
         -np.vstack([obs_runny_nose,obs_headache,obs_fever]).T)

In [None]:
comparison = 1./(1+np.exp(-np.einsum('ij,ijkl->ikl',map_estimate['x'],reordered_diff_tensor)))-indicator_array
plt.hist(comparison.reshape([-1]))
plt.show()

In [None]:
comparison