In [2]:
import copy
import unittest
from unittest import TestCase

from src.multiple_trees.compare_trees import get_distances_by_files
from src.single_tree.development_tree_reader import read_all_trees
from src.single_tree.global_params import GlobalParams

global_params = GlobalParams(max_level=11, param_a=0.6, g_weight=0.1, chain_length_weight=0.1)


class TestDistance(TestCase):
    # expected_matr = top-right matr (diagonal is excluded)
    def compare(self, path, is_reducing, expected_matr, g_weight=0.0):
        [_trees, matr] = get_distances_by_files(f"test/test_input/{path}",
                                                GlobalParams(max_level=11, is_test_nodes=True, g_weight=g_weight),
                                                is_reducing=is_reducing,
                                                is_test_nodes=True)
        for (i, expected_row) in enumerate(expected_matr):
            for (j, expected_value) in enumerate(expected_row):
                actual_value = matr[i][j + i + 1]
                self.assertAlmostEqual(expected_value, actual_value)

    def test_chain(self):
        # if reduce - all trees are the same
        # distance between chain_13 and chain_13_with_division_at_12 = 0
        distance_13_13_with_div_at_12 = 0
        self.compare("chains/test_chain*.xtg", True,  [[0, 0, 0],
                                                           [0, 0],
                                                              [distance_13_13_with_div_at_12]])

        # if NO reduce
        dist_level_10 = pow(0.5, 9)  # d(Leave, Growth)
        dist_level_11 = pow(0.5, 10) # d(Leave, Null)
        dist_chain_10_chain_11 = dist_level_10 + dist_level_11
        dist_chain_10_chain_13 = dist_chain_10_chain_11 # equal because of cut at level 11
        dist_chain_11_chain_13 = 0 # equal because of cut at level 11
        self.compare("chains/test_chain*.xtg", False, [[dist_chain_10_chain_11, dist_chain_10_chain_13, dist_chain_10_chain_13],
                                                                                [dist_chain_11_chain_13, dist_chain_11_chain_13],
                                                                                                         [0]])

    def test_growth_chain(self):
        # if growth > 0, then distance[0][1] > 0
        # then distance[1][2] tests producing growths in the chain
        d_growth_and_no_growth = pow(0.5, 1) * (256 - 1)
        self.compare("chains/test_*chain10.xtg", True,  [[d_growth_and_no_growth, d_growth_and_no_growth],
                                                                                   [0]],
                     g_weight=1.0)

        # if growth = 0, then distance[0][1] = 0
        self.compare("chains/test_*chain10.xtg", True,  [[0, 0],
                                                             [0]],
                     g_weight=0.0)

    def test_m2(self):
        # cannot reduce, because shown null nodes aren't completely null - they're shown for test purposes
        #self.compare("paper_m/M2_*.xtg", True, [[0]])
        self.compare("paper_m/M2_*.xtg", False, [[0]])

    def test_m3(self):
        self.compare("paper_m/M3_*.xtg", True,  [[0, 0],
                                                    [0]])
        self.compare("paper_m/M3_*.xtg", False, [[0, 0],
                                                    [0]])

    def test_m4(self):
        self.compare("paper_m/M4_*.xtg", True,  [[0.5, 2.0],
                                                      [2.0]])
        self.compare("paper_m/M4_*.xtg", False, [[0.5, 2.0],
                                                      [2.0]])

    def test_m5(self):
        self.compare("paper_m/M5_*.xtg", False, [[0.00, 0.50, 2.00],
                                                       [0.50, 2.00],
                                                             [2.00]])

    def test_sofa_reduce(self):
        self.compare("sofa/test_reduce*.xtg", True,  [[0.00]])
        self.compare("sofa/test_reduce*.xtg", False, [[1.00]])

    # def test_m6(self):
    #     self.compare("paper_m/M6_*.xtg", True,  [[1.25]])
    #     self.compare("paper_m/M6_*.xtg", False, [[1.25]])
    #
    # def test_patt(self):
    #     self.compare("patt_*.xtg", True,  [[0.75]])
    #     self.compare("patt_*.xtg", False, [[1.00]])

    def test_to_standard_form_growth(self):
        # read trees from *.xtg files in xtg folder
        src_trees = read_all_trees(pattern="test/test_input/test_standard_form_growth_*.xtg", is_test_nodes=True)

        # create a copy of trees to modify
        trees = [copy.deepcopy(src_tree) for src_tree in src_trees]

        self.assertEqual("XGN X XGN".replace(" ", ""), trees[0].to_string(3))
        self.assertEqual("XGN X XGN".replace(" ", ""), trees[1].to_string(3))

        trees[0].to_standard_form(3)
        trees[1].to_standard_form(3)

        self.assertEqual(trees[1].root.to_array(3), trees[0].root.to_array(3))  # equal after standartization
        self.assertNotEqual(src_trees[1].root.to_array(3), trees[1].root.to_array(3))  # changed

        self.assertEqual(src_trees[0].root.to_array(3), trees[0].root.to_array(3))  # not changed during standartization

    def test_to_standard_form_completion(self):
        # read trees from *.xtg files in xtg folder
        src_trees = read_all_trees(pattern="test/test_input/test_standard_form_compl_*.xtg", is_test_nodes=True)

        # create a copy of trees to modify
        trees = [copy.deepcopy(src_tree) for src_tree in src_trees]

        self.assertEqual("ZDG X NAN".replace(" ", ""), trees[0].to_string(3))
        self.assertEqual("NAN X GDZ".replace(" ", ""), trees[1].to_string(3))

        trees[0].to_standard_form(3)
        trees[1].to_standard_form(3)

        self.assertEqual("ZDG X NAN".replace(" ", ""), trees[0].to_string(3))

        self.assertEqual(trees[1].root.to_array(3), trees[0].root.to_array(3))  # equal after standartization
        self.assertNotEqual(src_trees[1].root.to_array(3), trees[1].root.to_array(3))  # changed

        self.assertEqual(src_trees[0].root.to_array(3), trees[0].root.to_array(3))  # not changed during standartization


unittest.main(argv=[''], verbosity=2, exit=False)

test_chain (__main__.TestDistance) ... ok
test_growth_chain (__main__.TestDistance) ... ERROR
test_m2 (__main__.TestDistance) ... ok
test_m3 (__main__.TestDistance) ... ok
test_m4 (__main__.TestDistance) ... ok
test_m5 (__main__.TestDistance) ... ok
test_sofa_reduce (__main__.TestDistance) ... FAIL
test_to_standard_form_completion (__main__.TestDistance) ... ok
test_to_standard_form_growth (__main__.TestDistance) ... ok

ERROR: test_growth_chain (__main__.TestDistance)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "<ipython-input-2-5f9bdb57fab7>", line 48, in test_growth_chain
    g_weight=1.0)
  File "<ipython-input-2-5f9bdb57fab7>", line 21, in compare
    actual_value = matr[i][j + i + 1]
IndexError: list index out of range

FAIL: test_sofa_reduce (__main__.TestDistance)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "<ipython-input-2-5f9bdb57fab7>", l

<unittest.main.TestProgram at 0x7f96f8b3fc90>