Skip to content

Commit

Permalink
add moar tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Autoplectic committed Nov 8, 2018
1 parent 151fc3e commit a611599
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -450,16 +450,14 @@ class InnerTwoPartIntrinsicMutualInformation(BaseAuxVarOptimizer):

name = ""

def __init__(self, dist, measure, rvs=None, crvs=None, j=None, bound_u=None, bound_v=None, rv_mode=None):
def __init__(self, dist, rvs=None, crvs=None, j=None, bound_u=None, bound_v=None, rv_mode=None):
"""
Initialize the optimizer.
Parameters
----------
dist : Distribution
The distribution to compute the intrinsic mutual information of.
measure : func
The appropriate multivariate mutual information.
rvs : list, None
A list of lists. Each inner list specifies the indexes of the random
variables used to calculate the intrinsic mutual information. If
Expand Down Expand Up @@ -491,8 +489,6 @@ def __init__(self, dist, measure, rvs=None, crvs=None, j=None, bound_u=None, bou

super(InnerTwoPartIntrinsicMutualInformation, self).__init__(dist, rvs + [j], crvs, rv_mode=rv_mode)

self.measure = measure

theoretical_bound_u = prod(self._shape[rv] for rv in self._rvs)
bound_u = min([bound_u, theoretical_bound_u]) if bound_u else theoretical_bound_u

Expand Down Expand Up @@ -641,7 +637,6 @@ def objective(self, x):
dist = Distribution(outcomes, pmf)

inner = InnerTwoPartIntrinsicMutualInformation(dist=dist,
measure=self.measure,
rvs=[[rv] for rv in self._rvs],
crvs=self._crvs,
j=self._j,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(self, dist, rv_x=None, rv_y=None, rv_z=None, rounds=2, bound_func=N
self._z = self._crvs

@staticmethod
def bound(self, i, x, y):
def bound(i, x, y):
"""
"""
i += 1
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""
Tests for dit.multivariate.secret_key_agreement.interactive_intrinsic_mutual_information
"""

import pytest

from dit.example_dists import n_mod_m
from dit.multivariate.secret_key_agreement import interactive_intrinsic_mutual_information


def test_iimi1():
"""
Test against known value.
"""
iimi = interactive_intrinsic_mutual_information(n_mod_m(3, 2), rvs=[[0], [1]], crvs=[2], rounds=1)
assert iimi == pytest.approx(0.0)
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
import pytest

from dit.example_dists import giant_bit, n_mod_m
from dit.exceptions import ditException
from dit.multivariate.secret_key_agreement import (
two_part_intrinsic_total_correlation,
two_part_intrinsic_dual_total_correlation,
two_part_intrinsic_CAEKL_mutual_information
)
from dit.multivariate.secret_key_agreement.base_skar_optimizers import InnerTwoPartIntrinsicMutualInformation


@pytest.mark.flaky(reruns=5)
Expand Down Expand Up @@ -46,3 +48,19 @@ def test_tpicaekl1(dist, value):
"""
tpicaekl = two_part_intrinsic_CAEKL_mutual_information(dist, [[0], [1]], [2], bound_j=2, bound_u=2, bound_v=2)
assert tpicaekl == pytest.approx(value)


def test_tpimi_fail1():
"""
Ensure an exception is raised if no conditional variable is supplied.
"""
with pytest.raises(ditException):
two_part_intrinsic_total_correlation(n_mod_m(3, 2), [[0], [1]])


def test_tpimi_fail2():
"""
Ensure an exception is raised if no conditional variable is supplied.
"""
with pytest.raises(ditException):
InnerTwoPartIntrinsicMutualInformation(n_mod_m(3, 2), [[0], [1]])
2 changes: 1 addition & 1 deletion dit/other/tests/test_disequilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def test_LMPR_complexity4(n):


@pytest.mark.parametrize('n', range(2, 11))
def test_LMPR_complexity4(n):
def test_LMPR_complexity5(n):
"""
Test that peaked Distributions have zero complexity.
"""
Expand Down
2 changes: 1 addition & 1 deletion dit/profiles/tests/test_entropy_triangle.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_et2_1(d, val):


@pytest.mark.parametrize('val', [(1, 0, 0), (0, 2/3, 1/3), (1/3, 1/3, 1/3), (0, 1/3, 2/3)])
def test_et_2(val):
def test_et2_2(val):
"""
Test EntropyTriangle against known values.
"""
Expand Down

0 comments on commit a611599

Please sign in to comment.