Skip to content

Commit

Permalink
nearly ready to release, lookg at doc
Browse files Browse the repository at this point in the history
  • Loading branch information
horta committed Mar 14, 2019
1 parent 0142cb5 commit 81ba312
Show file tree
Hide file tree
Showing 8 changed files with 116 additions and 9 deletions.
2 changes: 2 additions & 0 deletions glimix_core/_util/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .eigen import economic_qs_zeros
from .format import format_function
from .hsolve import hsolve
from ._assert import assert_interface

log2pi = 1.837877066409345339081937709124758839607238769531250

Expand All @@ -17,4 +18,5 @@
"vec",
"unvec",
"format_function",
"assert_interface",
]
16 changes: 16 additions & 0 deletions glimix_core/_util/_assert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
def assert_interface(cls, req_attrs):
attrs = dir(cls)
private = set(a for a in attrs if a.startswith("_"))
public = set(attrs) - private

missing = public - set(req_attrs)
if missing:
msg = "The following public attributes exist but have not been asserted: "
msg += "{}".format(", ".join(list(missing)))
raise AssertionError(msg)

missing = set(req_attrs) - public
if missing:
msg = "The following attributes have not been found: "
msg += "{}".format(", ".join(list(missing)))
raise AssertionError(msg)
2 changes: 1 addition & 1 deletion glimix_core/cov/kron2sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@

from optimix import Function

from .._util import format_function, unvec
from .free import FreeFormCov
from .lrfree import LRFreeFormCov
from .._util import format_function, unvec


class Kron2SumCov(Function):
Expand Down
5 changes: 3 additions & 2 deletions glimix_core/lmm/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,13 +216,14 @@ def _yTQDiQTm(self):

@property
def X(self):
r"""Covariates set by the user.
"""
Covariates set by the user.
It has to be a matrix of number-of-samples by number-of-covariates.
Returns
-------
:class:`numpy.ndarray`
ndarray
Covariates.
"""
from numpy_sugar.linalg import ddot
Expand Down
6 changes: 3 additions & 3 deletions glimix_core/lmm/_rkron2sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def gradient(self):
"""
Gradient of the log of the marginal likelihood.
"""
return self.lml_gradient()
return self._lml_gradient()

@property
@lru_cache(maxsize=None)
Expand Down Expand Up @@ -367,7 +367,7 @@ def lml(self):

return lml / 2

def lml_gradient(self):
def _lml_gradient(self):
"""
Gradient of the log of the marginal likelihood.
Expand Down Expand Up @@ -546,7 +546,7 @@ def fit(self, verbose=True):
``True`` for progress output; ``False`` otherwise.
Defaults to ``True``.
"""
self._maximize(verbose=verbose)
self._maximize(verbose=verbose, pgtol=1e-6)


def _dot(a, b):
Expand Down
33 changes: 33 additions & 0 deletions glimix_core/lmm/test/test_lmm_lmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from numpy.random import RandomState
from numpy.testing import assert_allclose

from glimix_core._util import assert_interface
from glimix_core.cov import EyeCov, LinearCov, SumCov
from glimix_core.lik import DeltaProdLik
from glimix_core.lmm import LMM
Expand Down Expand Up @@ -54,3 +55,35 @@ def test_lmm_lmm_prediction():
K = dot(X, X.T)
pm = lmm.predictive_mean(ones((n, 1)), K, K.diagonal())
assert_allclose(corrcoef(y, pm)[0, 1], 0.8358820971891354)


def test_lmm_lmm_public_attrs():
assert_interface(
LMM,
[
"lml",
"X",
"beta",
"delta",
"scale",
"mean_star",
"variance_star",
"covariance_star",
"covariance",
"predictive_covariance",
"mean",
"isfixed",
"fixed_effects_variance",
"gradient",
"v0",
"v1",
"fit",
"copy",
"value",
"get_fast_scanner",
"predictive_mean",
"name",
"unfix",
"fix",
],
)
56 changes: 53 additions & 3 deletions glimix_core/lmm/test/test_lmm_rkron2sum.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from numpy.random import RandomState
from numpy.testing import assert_allclose
from numpy.testing import assert_allclose, assert_equal

from glimix_core._util import assert_interface
from glimix_core.lmm import RKron2Sum


def test_lmm_reml_rkron2sum():
def test_lmm_rkron2sum():
random = RandomState(0)
n = 5
Y = random.randn(n, 3)
Expand All @@ -15,8 +16,10 @@ def test_lmm_reml_rkron2sum():
lmm = RKron2Sum(Y, A, F, G)

assert_allclose(lmm.lml(), -16.580821931417656)
# assert_allclose(lmm._check_grad(step=1e-7), 0, atol=1e-3)
assert_allclose(lmm._check_grad(step=1e-7), 0, atol=1e-4)
assert_equal(lmm.nsamples, n)
assert_equal(lmm.ntraits, 3)
assert_equal(lmm.ncovariates, 2)

n = 5
Y = random.randn(n, 1)
Expand All @@ -25,6 +28,53 @@ def test_lmm_reml_rkron2sum():
F = random.randn(n, 2)
G = random.randn(n, 4)
lmm = RKron2Sum(Y, A, F, G)
lmm.name = "KronSum"

assert_allclose(lmm.lml(), -4.582089407009583)
assert_allclose(lmm._check_grad(step=1e-7), 0, atol=1e-4)
assert_allclose(
[lmm.mean.value()[0], lmm.mean.value()[1]],
[0.0497438970225256, 0.5890598193072355],
)

assert_allclose(
[
lmm.cov.value()[0, 0],
lmm.cov.value()[0, 1],
lmm.cov.value()[1, 0],
lmm.cov.value()[1, 1],
],
[
4.3712532668348185,
-0.07239366121399138,
-0.07239366121399138,
2.7242131674614862,
],
)

assert_equal(lmm.nsamples, n)
assert_equal(lmm.ntraits, 1)
assert_equal(lmm.name, "KronSum")
lmm.fit(verbose=False)
grad = lmm.gradient()
assert_allclose(grad["C0.Lu"], [0], atol=1e-5)
assert_allclose(grad["C1.Lu"], [0], atol=1e-5)
assert_allclose(lmm.lml(), -0.6930197958236421)


def test_lmm_rkron2sum_public_attrs():
assert_interface(
RKron2Sum,
[
"fit",
"ntraits",
"lml",
"mean",
"nsamples",
"value",
"cov",
"ncovariates",
"name",
"gradient",
],
)
5 changes: 5 additions & 0 deletions glimix_core/lmm/test/test_lmm_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from numpy.random import RandomState
from numpy.testing import assert_allclose

from glimix_core._util import assert_interface
from glimix_core.cov import EyeCov, LinearCov, SumCov
from glimix_core.lik import DeltaProdLik
from glimix_core.lmm import LMM, FastScanner
Expand Down Expand Up @@ -429,3 +430,7 @@ def test_lmm_scan_interface():

with pytest.raises(ValueError):
FastScanner(y, X, QS, nan)


def test_lmm_scan_public_attrs():
assert_interface(FastScanner, ["unset_scale", "null_lml", "set_scale", "fast_scan"])

0 comments on commit 81ba312

Please sign in to comment.