Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/piqtree/iqtree/_model_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,18 @@ def __post_init__(self, raw_data: dict[str, Any]) -> None:
def to_rich_dict(self) -> dict[str, Any]:
import piqtree

result = {"version": piqtree.__version__, "type": get_object_provenance(self)}
result: dict[str, Any] = {
"version": piqtree.__version__,
"type": get_object_provenance(self),
}

raw_data = {
str(model_): f"{stats.lnL} {stats.nfp} {stats.tree_length}"
for model_, stats in self.model_stats.items()
}
for attr in ("best_model_AIC", "best_model_AICc", "best_model_BIC"):
raw_data[attr] = str(getattr(self, attr.replace("_model", "").lower()))

result["init_kwargs"] = {"raw_data": raw_data, "source": self.source}
return result

Expand Down
28 changes: 15 additions & 13 deletions tests/test_iqtree/test_fit_tree.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import cast

import numpy as np
import pytest
from cogent3 import get_app, make_tree
Expand All @@ -14,18 +16,18 @@ def check_likelihood(got: PhyloNode, expected: model_result) -> None:


def check_motif_probs(got: PhyloNode, expected: PhyloNode) -> None:
expected = expected.params["mprobs"]
got = got.params["mprobs"]
expected_mprobs = expected.params["mprobs"]
got_mprobs = got.params["mprobs"]

expected_keys = set(expected.keys())
got_keys = set(got.keys())
expected_keys = set(expected_mprobs.keys())
got_keys = set(got_mprobs.keys())

# Check that the base characters are the same
assert expected_keys == got_keys

# Check that the probs are the same
expected_values = [expected[key] for key in expected_keys]
got_values = [got[key] for key in expected_keys]
expected_values = [expected_mprobs[key] for key in expected_keys]
got_values = [got_mprobs[key] for key in expected_keys]
assert all(
got == pytest.approx(exp)
for got, exp in zip(got_values, expected_values, strict=True)
Expand Down Expand Up @@ -54,16 +56,16 @@ def check_rate_parameters(got: PhyloNode, expected: PhyloNode) -> None:


def check_branch_lengths(got: PhyloNode, expected: PhyloNode) -> None:
got = got.tip_to_tip_distances()
expected = expected.tip_to_tip_distances()
got_dists = got.tip_to_tip_distances()
expected_dists = expected.tip_to_tip_distances()
# make sure the distance matrices have the same name order
# so we can just compare entire numpy arrays
expected = expected.take_dists(got.names)
expected_dists = expected_dists.take_dists(got_dists.names)
# Check that the keys of branch lengths are the same
assert set(got.names) == set(expected.names)
assert set(got_dists.names) == set(expected_dists.names)

# Check that the branch lengths are the same
np.testing.assert_allclose(got.array, expected.array, atol=1e-4)
np.testing.assert_allclose(got_dists.array, expected_dists.array, atol=1e-4)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -178,7 +180,7 @@ def test_fit_tree_paramaterisation(three_otu: Alignment, model_str: str) -> None

assert isinstance(tree.params["lnL"], float)
for node in tree.preorder(include_self=False):
assert node.length > 0
assert cast("float", node.length) > 0


def test_special_characters(three_otu: Alignment) -> None:
Expand All @@ -194,6 +196,6 @@ def _renamer(before: str) -> str:

assert isinstance(tree.params["lnL"], float)
for node in tree.preorder(include_self=False):
assert node.length > 0
assert cast("float", node.length) > 0

assert set(three_otu.names) == set(tree.get_tip_names())
9 changes: 7 additions & 2 deletions tests/test_iqtree/test_nj_tree.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import re
from typing import cast

import numpy as np
import pytest
Expand All @@ -23,11 +24,15 @@ def test_nj_tree_allow_negative(all_otu: Alignment) -> None:

# check that all branch lengths are non-negative, by default
tree1 = nj_tree(dists)
assert all(node.length >= 0 for node in tree1.preorder(include_self=False))
assert all(
cast("float", node.length) >= 0 for node in tree1.preorder(include_self=False)
)

# check that some branch lengths are negative when allow_negative=True
tree2 = nj_tree(dists, allow_negative=True)
assert any(node.length < 0 for node in tree2.preorder(include_self=False))
assert any(
cast("float", node.length) < 0 for node in tree2.preorder(include_self=False)
)


def test_nj_tree_nan(four_otu: Alignment) -> None:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_iqtree/test_tree_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,5 +277,5 @@ def test_unrest_model(
)
def test_tree_equal(candidate: str, expected: bool) -> None:
tree = make_tree("((a:1.0,b:0.9),c:0.8);")
candidate = make_tree(candidate)
assert _tree_equal(tree, candidate) == expected
candidate_tree = make_tree(candidate)
assert _tree_equal(tree, candidate_tree) == expected
Loading