Skip to content

Commit

Permalink
TST: Add (failing) unittests for LBA.
Browse files Browse the repository at this point in the history
  • Loading branch information
twiecki committed Jan 14, 2014
1 parent fa4e949 commit f058ebf
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 30 deletions.
75 changes: 45 additions & 30 deletions hddm/tests/test_likelihoods.py
Expand Up @@ -263,38 +263,53 @@ def test_failure_mode(self):
st = 0.1


def test_lba():
try:
import rpy2
import rpy2.robjects as robjects
from rpy2.robjects.packages import importr

except ImportError:
return
import hddm
import numpy as np

robjects.r("source('lba-math.r')")
for i in range(100):
x = np.random.randn() * 3
A = np.random.rand() * 2
b = A + np.random.rand() * 2
v1 = np.random.rand() * 2
v2 = np.random.rand() * 2
sv = np.random.rand() * 2 + .5
pdf = robjects.r("n1PDF") #t,x0max,chi,drift,sdI
if x > 0:
drifts = rpy2.robjects.vectors.FloatVector([v1, v2])
else:
drifts = rpy2.robjects.vectors.FloatVector([v2, v1])

r_result = pdf(abs(x), A, b, drifts, sv)
hddm_result = np.exp(hddm.lba.lba_like(np.array([x]), A, b, 0, sv, v1, v2))

np.testing.assert_almost_equal(r_result, hddm_result, 6,
"Parameters x=%f, A=%f, b=%f, v1=%f, v2=%f, sv=%f" %(x, A, b, v1, v2, sv))
class TestLBA(unittest.TestCase):
def test_lba(self):
try:
import rpy2
import rpy2.robjects as robjects
from rpy2.robjects.packages import importr

except ImportError:
return
import hddm
import numpy as np

robjects.r("source('lba-math.r')")
for i in range(100):
x = np.random.randn() * 3
A = np.random.rand() * 2
b = A + np.random.rand() * 2
v1 = np.random.rand() * 2
v2 = np.random.rand() * 2
sv = np.random.rand() * 2 + .5
pdf = robjects.r("n1PDF") #t,x0max,chi,drift,sdI
if x > 0:
drifts = rpy2.robjects.vectors.FloatVector([v1, v2])
else:
drifts = rpy2.robjects.vectors.FloatVector([v2, v1])

r_result = pdf(abs(x), A, b, drifts, sv)
hddm_result = np.exp(hddm.lba.lba_like(np.array([x]), A, b, 0, sv, v1, v2))

np.testing.assert_almost_equal(r_result, hddm_result, 6,
"Parameters x=%f, A=%f, b=%f, v1=%f, v2=%f, sv=%f" %(x, A, b, v1, v2, sv))


@SkipTest
def test_pdf_integrate_to_one(self):
np.random.seed(123)
for tests in range(5):
x = np.random.randn() * 3
A = np.random.rand() * 2
b = A + np.random.rand() * 2
v1 = np.random.rand() * 2
v2 = np.random.rand() * 2
sv = np.random.rand() * 2 + .5
func = lambda x: np.exp(hddm.lba.lba_like(np.array([x]), A, b, 0, sv, v1, v2))
integ, error = sp.integrate.quad(func, a=-np.inf, b=np.inf)

np.testing.assert_almost_equal(integ, 1, 2)

if __name__=='__main__':
print "Run nosetest."
54 changes: 54 additions & 0 deletions hddm/tests/test_models.py
Expand Up @@ -348,5 +348,59 @@ def test_posterior_plots_breakdown():
ppc = hddm.utils.post_pred_gen(m, samples=10)
hddm.utils.post_pred_stats(data, ppc)


class TestRecovery(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(TestRecovery, self).__init__(*args, **kwargs)
self.iter = 2000
self.burn = 20
np.random.seed(1)

def runTest(self):
return

@SkipTest
def test_HLBA_flat(self):
params_true = {'A': .5, 'b': 1, 't': .3, 's': .5, 'v': [.5, .6, .7, .8]}
extended_params, merged_params = extend_params(params_true)

data, params_true = hddm.models.hlba_truncated.gen_rand_data(extended_params, size=100, subjs=1)
model = hddm.models.HLBA(data, depends_on={'v': 'condition'})
model.find_starting_values()
model.sample(self.iter, burn=self.burn)
model.gen_stats()
model.print_stats()
for param, true_val in merged_params.iteritems():
np.testing.assert_almost_equal(true_val, model.nodes_db.ix[param]['mean'])

@SkipTest
def test_HLBA_hierarchical(self):
params = hddm.models.hlba_truncated.gen_rand_params(cond_dict={'v': [.5, .6, .75, .8]})
data, params_true = hddm.models.hlba_truncated.gen_rand_data(params, size=100, subjs=10)
model = hddm.models.HLBA(data, depends_on={'v': 'condition'})
model.find_starting_values()
model.sample(self.iter, burn=self.burn)


def extend_params(params):
# Find list
extend_param = [param for param, val in params.iteritems() if isinstance(val, (list, tuple))]
if len(extend_param) > 1:
raise ValueError('Only one parameter can be extended')
extend_param = extend_param[0]

fixed_params = [param for param, val in params.iteritems() if not isinstance(val, (list, tuple))]

out_extended = {}
out_merged = {k: params[k] for k in fixed_params}
for i_cond, extend_val in enumerate(params[extend_param]):
cond_params = {k: params[k] for k in fixed_params}
cond_params[extend_param] = extend_val
out_extended['cond%i' % i_cond] = cond_params

out_merged['%s(cond%i)' % (extend_param, i_cond)] = extend_val

return out_extended, out_merged

if __name__=='__main__':
print "Run nosetest.py"

0 comments on commit f058ebf

Please sign in to comment.