Skip to content

Commit

Permalink
Add test for filling unobserved values only
Browse files Browse the repository at this point in the history
  • Loading branch information
victorkristof committed Aug 11, 2020
1 parent 0c3e22e commit f19eb54
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
import pytest
import numpy as np
from numpy.testing import assert_almost_equal

from predikon import (Model, WeightedAveraging, MatrixFactorisation,
GaussianSubSVD, LogisticSubSVD, GaussianTensorSubSVD,
LogisticTensorSubSVD)


"""Setup methods"""


def susq(A):
# sum of squares
return np.sum(np.square(A))


def get_M_w_vec():
M = np.random.randn(3, 4)
w = np.abs(np.random.randn(3))
Expand All @@ -26,6 +30,7 @@ def get_M_w_mat():

"""Programmatical Tests"""


def test_observed_indexes_vec():
M, w = get_M_w_vec()
m = np.array([0, np.nan, 0])
Expand Down Expand Up @@ -93,8 +98,36 @@ def test_prediction_not_nan_mat_unreg():
assert not any(np.isnan(pred[-1]))


def test_prediction_fill_nan_only():
# Models = [GaussianSubSVD, LogisticSubSVD]
M, w = get_M_w_vec()
m = np.array([0.0, 0.3, np.nan])
w = np.array([2, 7, 2])

# Gaussian
model = GaussianSubSVD(M, w, n_dim=1, l2_reg=1e-5)
pred = model.fit_predict(m)
# assert not (np.isnan(pred[-1]))
# assert np.allclose(pred[:2], m[:2])
assert_almost_equal(pred[:2], m[:2])

# Bernoulli
model = LogisticSubSVD(M, w, n_dim=1, l2_reg=1e-5)
pred = model.fit_predict(m)
# assert not (np.isnan(pred[-1]))
# assert np.allclose(pred[:2], m[:2])
assert_almost_equal(pred[:2], m[:2])

# for MODEL in Models:
# model = MODEL(M, w, n_dim=1, l2_reg=1e-5)
# pred = model.fit_predict(m)
# assert not (np.isnan(pred[-1]))
# assert np.allclose(pred[:2], m[:2])


"""Methodological Tests"""


def test_averaging():
M, w = get_M_w_vec()
m = np.array([0.0, 0.3, np.nan])
Expand Down Expand Up @@ -150,6 +183,7 @@ def test_mf_converges():

"""Exception Tests"""


def test_unallowed_weighting_length():
with pytest.raises(ValueError, match=r".*dimension.*"):
M, w = get_M_w_vec()
Expand Down

0 comments on commit f19eb54

Please sign in to comment.