Skip to content

Commit

Permalink
n large for win
Browse files Browse the repository at this point in the history
  • Loading branch information
horta committed Apr 23, 2019
1 parent bfbec74 commit b2ef401
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 20 deletions.
30 changes: 18 additions & 12 deletions doc/lmm.rst
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ B, C₀, and C₁.
>>> from glimix_core.lmm import Kron2Sum
>>>
>>> random = RandomState(0)
>>> n = 5
>>> n = 15
>>> p = 2
>>> c = 3
>>> Y = random.randn(n, p)
Expand All @@ -118,7 +118,7 @@ B, C₀, and C₁.
>>> mlmm = Kron2Sum(Y, A, X, G, restricted=False)
>>> mlmm.fit(verbose=False)
>>> mlmm.lml() # doctest: +FLOAT_CMP
-6.520026228479136
-42.08309530985927

We also provide :class:`.KronFastScanner` for performing an even faster
inference across several (millions, for example) covariates independently.
Expand Down Expand Up @@ -174,21 +174,27 @@ The parameters 𝚩ⱼ, 𝚨ⱼ, and sⱼ are found via maximum likelihood.
.. doctest::

>>> mscanner = mlmm.get_fast_scanner()
>>> A = random.randn(2, 5)
>>> X = random.randn(5, 3)
>>> A = random.randn(2, n)
>>> X = random.randn(n, 3)
>>> r = mscanner.scan(A, X)
>>> r["lml"] # doctest: +FLOAT_CMP
83.08864898305367
-42.74668875515792
>>> r["effsizes0"] # doctest: +FLOAT_CMP
array([[ 0.01482133, 0.45189275],
[ 0.43706748, -0.71162517],
[ 0.52595486, -1.59740035]])
array([[ 0.00277822, -0.01476164],
[-0.0005451 , 0.00290053],
[-0.00990904, 0.05266315]])
>>> r["effsizes1"] # doctest: +FLOAT_CMP
array([[ 0.03868156, -0.77199913, -0.09643554, -0.53973775, 1.03149564],
[ 0.05780863, -0.24744739, -0.11882984, -0.19331759, 0.74964805],
[ 0.01051071, -1.61751886, -0.0654883 , -1.09931899, 1.51034738]])
array([[-0.0127383 , -0.03796125, 0.04740337, -0.04064709, -0.03945676,
0.00239382, -0.01167387, 0.06761218, 0.04603321, -0.00731968,
-0.05142721, 0.03228656, -0.02494051, -0.06615618, 0.03947441],
[-0.03729972, -0.05256696, 0.01844337, 0.00221303, 0.01784714,
0.07925216, 0.03037916, -0.0247654 , 0.04081066, -0.0442502 ,
-0.01950785, 0.00037021, 0.0548026 , -0.03010318, -0.02419531],
[ 0.01722535, 0.0408958 , -0.04265975, 0.03336806, 0.02959425,
-0.01610653, 0.00428473, -0.05174828, -0.04550623, 0.01396271,
0.04619166, -0.02688382, 0.01095342, 0.06031332, -0.02847803]])
>>> r["scale"] # doctest: +FLOAT_CMP
5.238689482212067e-11
1.0452327204999015

API
===
Expand Down
16 changes: 8 additions & 8 deletions glimix_core/lmm/test/test_kron2sum_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

def test_lmm_kron_scan():
random = RandomState(0)
n = 5
n = 20
Y = random.randn(n, 3)
A = random.randn(3, 3)
A = A @ A.T
Expand All @@ -37,7 +37,7 @@ def func(scale):
F1 = random.randn(n, 4)

r = scan.scan(A1, F1)
assert_allclose(r["scale"], 0.3000748879939645, rtol=1e-3)
assert_allclose(r["scale"], 0.7365021111700154, rtol=1e-3)

m = kron(A, F) @ vec(r["effsizes0"]) + kron(A1, F1) @ vec(r["effsizes1"])

Expand All @@ -50,13 +50,13 @@ def func(scale):
assert_allclose(r["lml"], st.multivariate_normal(m, s * K).logpdf(vec(Y)))

r = scan.scan(empty((3, 0)), F1)
assert_allclose(r["lml"], -10.96414417860732, rtol=1e-4)
assert_allclose(r["scale"], 0.5999931720566452, rtol=1e-3)
assert_allclose(r["lml"], -85.36667704747371, rtol=1e-4)
assert_allclose(r["scale"], 0.8999995537936586, rtol=1e-3)
assert_allclose(
r["effsizes0"],
[
[1.411082677273241, 0.41436234081257045, -1.5337251391408189],
[-0.6753042112998789, -0.20299590400182352, 0.6723874047807074],
[0.21489119796865844, 0.6412947101778663, -0.7176143380221816],
[0.8866722740598517, -0.18731140321348416, -0.26118052682069],
],
rtol=1e-2,
atol=1e-2,
Expand All @@ -66,7 +66,7 @@ def func(scale):

def test_lmm_kron_scan_with_lmm():
random = RandomState(0)
n = 5
n = 15
Y = random.randn(n, 3)
A = random.randn(3, 3)
A = A @ A.T
Expand Down Expand Up @@ -105,7 +105,7 @@ def test_lmm_kron_scan_with_lmm():

def test_lmm_kron_scan_unrestricted():
random = RandomState(0)
n = 5
n = 15
Y = random.randn(n, 3)
A = random.randn(3, 3)
A = A @ A.T
Expand Down

0 comments on commit b2ef401

Please sign in to comment.