Skip to content

Commit

Permalink
add test_moments
Browse files Browse the repository at this point in the history
  • Loading branch information
itamarfaran committed Jul 2, 2021
1 parent 7297ef9 commit 2cea469
Showing 1 changed file with 43 additions and 3 deletions.
46 changes: 43 additions & 3 deletions packages/back-end/src/python/test/test_dists.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from unittest import TestCase
from functools import partial
from unittest import TestCase, main as unittest_main

import numpy as np
import pandas as pd
Expand All @@ -8,6 +9,7 @@


DECIMALS = 5
round_ = partial(np.round, decimals=DECIMALS)


def roundsum(x, decimals=DECIMALS):
Expand Down Expand Up @@ -41,7 +43,25 @@ def test_posterior(self):
pd.testing.assert_series_equal(res, out)

def test_moments(self):
self.fail()
pars = 12, 745
result = Beta.moments(*pars)
expected = beta.mean(*pars), beta.var(*pars)
for res, out in zip(result, expected):
self.assertEqual(round_(res), round_(out))

pars = 12, 745
result = Beta.moments(*pars, log=True)
mean = beta.expect(np.log, pars)
var = beta.expect(lambda x: np.log(x) ** 2, pars) - mean ** 2
expected = mean, var
for res, out in zip(result, expected):
self.assertEqual(round_(res), round_(out))

pars = np.array([12, 745]), np.array([745, 12])
result = Beta.moments(*pars)
expected = beta.mean(*pars), beta.var(*pars)
for res, out in zip(result, expected):
np.testing.assert_array_almost_equal(res, out)

def test_gq(self):
test_cases = zip([10, 100, 500, 1000, 10000],
Expand Down Expand Up @@ -80,7 +100,23 @@ def test_posterior(self):
pd.testing.assert_series_equal(res, out)

def test_moments(self):
self.fail()
pars = 10, 100
result = Norm.moments(*pars)
expected = norm.mean(*pars), norm.var(*pars)
for res, out in zip(result, expected):
self.assertEqual(round_(res), round_(out))

pars = 100, 10
result = Norm.moments(*pars, log=True)
expected = np.log(100), (10 / 100) ** 2
for res, out in zip(result, expected):
self.assertEqual(round_(res), round_(out))

pars = np.array([10, 100]), np.array([100, 10])
result = Norm.moments(*pars)
expected = norm.mean(*pars), norm.var(*pars)
for res, out in zip(result, expected):
np.testing.assert_array_almost_equal(res, out)

def test_gq(self):
test_cases = zip([0, -2, 2, 10],
Expand All @@ -89,3 +125,7 @@ def test_gq(self):
x, w = Norm.gq(24, loc, scale)
for p in range(8):
self.assertEqual(roundsum(x ** p * w), roundsum(norm.moment(p, loc, scale)))


if __name__ == '__main__':
unittest_main()

0 comments on commit 2cea469

Please sign in to comment.