diff --git a/src/piqtree/iqtree/_model_finder.py b/src/piqtree/iqtree/_model_finder.py index a187983..afd14ab 100644 --- a/src/piqtree/iqtree/_model_finder.py +++ b/src/piqtree/iqtree/_model_finder.py @@ -101,7 +101,10 @@ def __post_init__(self, raw_data: dict[str, Any]) -> None: def to_rich_dict(self) -> dict[str, Any]: import piqtree - result = {"version": piqtree.__version__, "type": get_object_provenance(self)} + result: dict[str, Any] = { + "version": piqtree.__version__, + "type": get_object_provenance(self), + } raw_data = { str(model_): f"{stats.lnL} {stats.nfp} {stats.tree_length}" @@ -109,6 +112,7 @@ def to_rich_dict(self) -> dict[str, Any]: } for attr in ("best_model_AIC", "best_model_AICc", "best_model_BIC"): raw_data[attr] = str(getattr(self, attr.replace("_model", "").lower())) + result["init_kwargs"] = {"raw_data": raw_data, "source": self.source} return result diff --git a/tests/test_iqtree/test_fit_tree.py b/tests/test_iqtree/test_fit_tree.py index 1ef00cf..14fe653 100644 --- a/tests/test_iqtree/test_fit_tree.py +++ b/tests/test_iqtree/test_fit_tree.py @@ -1,3 +1,5 @@ +from typing import cast + import numpy as np import pytest from cogent3 import get_app, make_tree @@ -14,18 +16,18 @@ def check_likelihood(got: PhyloNode, expected: model_result) -> None: def check_motif_probs(got: PhyloNode, expected: PhyloNode) -> None: - expected = expected.params["mprobs"] - got = got.params["mprobs"] + expected_mprobs = expected.params["mprobs"] + got_mprobs = got.params["mprobs"] - expected_keys = set(expected.keys()) - got_keys = set(got.keys()) + expected_keys = set(expected_mprobs.keys()) + got_keys = set(got_mprobs.keys()) # Check that the base characters are the same assert expected_keys == got_keys # Check that the probs are the same - expected_values = [expected[key] for key in expected_keys] - got_values = [got[key] for key in expected_keys] + expected_values = [expected_mprobs[key] for key in expected_keys] + got_values = [got_mprobs[key] for key in expected_keys] assert all( got == pytest.approx(exp) for got, exp in zip(got_values, expected_values, strict=True) @@ -54,16 +56,16 @@ def check_rate_parameters(got: PhyloNode, expected: PhyloNode) -> None: def check_branch_lengths(got: PhyloNode, expected: PhyloNode) -> None: - got = got.tip_to_tip_distances() - expected = expected.tip_to_tip_distances() + got_dists = got.tip_to_tip_distances() + expected_dists = expected.tip_to_tip_distances() # make sure the distance matrices have the same name order # so we can just compare entire numpy arrays - expected = expected.take_dists(got.names) + expected_dists = expected_dists.take_dists(got_dists.names) # Check that the keys of branch lengths are the same - assert set(got.names) == set(expected.names) + assert set(got_dists.names) == set(expected_dists.names) # Check that the branch lengths are the same - np.testing.assert_allclose(got.array, expected.array, atol=1e-4) + np.testing.assert_allclose(got_dists.array, expected_dists.array, atol=1e-4) @pytest.mark.parametrize( @@ -178,7 +180,7 @@ def test_fit_tree_paramaterisation(three_otu: Alignment, model_str: str) -> None assert isinstance(tree.params["lnL"], float) for node in tree.preorder(include_self=False): - assert node.length > 0 + assert cast("float", node.length) > 0 def test_special_characters(three_otu: Alignment) -> None: @@ -194,6 +196,6 @@ def _renamer(before: str) -> str: assert isinstance(tree.params["lnL"], float) for node in tree.preorder(include_self=False): - assert node.length > 0 + assert cast("float", node.length) > 0 assert set(three_otu.names) == set(tree.get_tip_names()) diff --git a/tests/test_iqtree/test_nj_tree.py b/tests/test_iqtree/test_nj_tree.py index 33c1416..2685d7a 100644 --- a/tests/test_iqtree/test_nj_tree.py +++ b/tests/test_iqtree/test_nj_tree.py @@ -1,4 +1,5 @@ import re +from typing import cast import numpy as np import pytest @@ -23,11 +24,15 @@ def test_nj_tree_allow_negative(all_otu: Alignment) -> None: # check that all branch lengths are non-negative, by default tree1 = nj_tree(dists) - assert all(node.length >= 0 for node in tree1.preorder(include_self=False)) + assert all( + cast("float", node.length) >= 0 for node in tree1.preorder(include_self=False) + ) # check that some branch lengths are negative when allow_negative=True tree2 = nj_tree(dists, allow_negative=True) - assert any(node.length < 0 for node in tree2.preorder(include_self=False)) + assert any( + cast("float", node.length) < 0 for node in tree2.preorder(include_self=False) + ) def test_nj_tree_nan(four_otu: Alignment) -> None: diff --git a/tests/test_iqtree/test_tree_yaml.py b/tests/test_iqtree/test_tree_yaml.py index 47989d7..e10d542 100644 --- a/tests/test_iqtree/test_tree_yaml.py +++ b/tests/test_iqtree/test_tree_yaml.py @@ -277,5 +277,5 @@ def test_unrest_model( ) def test_tree_equal(candidate: str, expected: bool) -> None: tree = make_tree("((a:1.0,b:0.9),c:0.8);") - candidate = make_tree(candidate) - assert _tree_equal(tree, candidate) == expected + candidate_tree = make_tree(candidate) + assert _tree_equal(tree, candidate_tree) == expected