diff --git a/src/piqtree/_app/__init__.py b/src/piqtree/_app/__init__.py index ec73cae..c1d9ada 100644 --- a/src/piqtree/_app/__init__.py +++ b/src/piqtree/_app/__init__.py @@ -2,9 +2,10 @@ from collections.abc import Iterable -import cogent3 -import cogent3.app.typing as c3_types from cogent3.app import composable +from cogent3.core.alignment import Alignment +from cogent3.core.tree import PhyloNode +from cogent3.evolve.fast_distance import DistanceMatrix from cogent3.util.misc import extend_docstring_from from piqtree import ( @@ -39,8 +40,8 @@ def __init__( def main( self, - aln: c3_types.AlignedSeqsType, - ) -> cogent3.PhyloNode | cogent3.app.typing.SerialisableType: + aln: Alignment, + ) -> PhyloNode: tree = build_tree( aln, self._model, @@ -57,7 +58,7 @@ class piq_fit_tree: @extend_docstring_from(fit_tree) def __init__( self, - tree: cogent3.PhyloNode, + tree: PhyloNode, model: Model | str, *, num_threads: int | None = None, @@ -70,8 +71,8 @@ def __init__( def main( self, - aln: c3_types.AlignedSeqsType, - ) -> cogent3.PhyloNode | cogent3.app.typing.SerialisableType: + aln: Alignment, + ) -> PhyloNode: tree = fit_tree( aln, self._tree, @@ -89,7 +90,7 @@ def piq_random_tree( num_taxa: int, tree_mode: TreeGenMode, rand_seed: int | None = None, -) -> cogent3.PhyloNode: +) -> PhyloNode: return random_tree(num_taxa, tree_mode, rand_seed) @@ -104,8 +105,8 @@ def __init__( def main( self, - aln: c3_types.AlignedSeqsType, - ) -> c3_types.PairwiseDistanceType | cogent3.app.typing.SerialisableType: + aln: Alignment, + ) -> DistanceMatrix: dists = jc_distances( aln, num_threads=self._num_threads, @@ -117,10 +118,10 @@ def main( @composable.define_app @extend_docstring_from(nj_tree) def piq_nj_tree( - dists: c3_types.PairwiseDistanceType, + dists: DistanceMatrix, *, allow_negative: bool = False, -) -> cogent3.PhyloNode: +) -> PhyloNode: tree = nj_tree(dists, allow_negative=allow_negative) tree.params |= {"provenance": "piqtree"} tree.source = getattr(dists, "source", None) @@ -130,18 +131,18 @@ def piq_nj_tree( @composable.define_app @extend_docstring_from(model_finder) def piq_model_finder( - aln: c3_types.AlignedSeqsType, -) -> ModelFinderResult | c3_types.SerialisableType: + aln: Alignment, +) -> ModelFinderResult: return model_finder(aln) @composable.define_app @extend_docstring_from(consensus_tree) def piq_consensus_tree( - trees: Iterable[cogent3.PhyloNode], + trees: Iterable[PhyloNode], *, min_support: float = 0.5, -) -> cogent3.PhyloNode: +) -> PhyloNode: return consensus_tree(trees, min_support=min_support) diff --git a/src/piqtree/iqtree/_jc_distance.py b/src/piqtree/iqtree/_jc_distance.py index 265a20b..99f5301 100644 --- a/src/piqtree/iqtree/_jc_distance.py +++ b/src/piqtree/iqtree/_jc_distance.py @@ -1,8 +1,8 @@ from collections.abc import Sequence -import cogent3.app.typing as c3_types import numpy as np from _piqtree import iq_jc_distances +from cogent3.core.alignment import Alignment from cogent3.evolve.fast_distance import DistanceMatrix from piqtree.iqtree._decorator import iqtree_func @@ -13,7 +13,7 @@ def _dists_to_distmatrix( distances: np.ndarray, names: Sequence[str], -) -> c3_types.PairwiseDistanceType: +) -> DistanceMatrix: """Convert numpy representation of distance matrix into cogent3 pairwise distance matrix. Parameters @@ -25,7 +25,7 @@ def _dists_to_distmatrix( Returns ------- - c3_types.PairwiseDistanceType + DistanceMatrix Pairwise distance matrix. """ @@ -37,21 +37,21 @@ def _dists_to_distmatrix( def jc_distances( - aln: c3_types.AlignedSeqsType, + aln: Alignment, num_threads: int | None = None, -) -> c3_types.PairwiseDistanceType: +) -> DistanceMatrix: """Compute pairwise JC distances for a given alignment. Parameters ---------- - aln : c3_types.AlignedSeqsType + aln : Alignment alignment to compute pairwise JC distances for. num_threads: int | None, optional Number of threads for IQ-TREE to use, by default None (all available threads). Returns ------- - c3_types.PairwiseDistanceType + DistanceMatrix Pairwise JC distance matrix. """ diff --git a/src/piqtree/iqtree/_model_finder.py b/src/piqtree/iqtree/_model_finder.py index 516f9cd..e46c9c0 100644 --- a/src/piqtree/iqtree/_model_finder.py +++ b/src/piqtree/iqtree/_model_finder.py @@ -2,11 +2,11 @@ import dataclasses from collections.abc import Iterable -from typing import Any +from typing import Any, cast import yaml from _piqtree import iq_model_finder -from cogent3.app import typing as c3_types +from cogent3.core.alignment import Alignment from cogent3.util.misc import get_object_provenance from piqtree.iqtree._decorator import iqtree_func @@ -117,18 +117,18 @@ def from_rich_dict(cls, data: dict[str, Any]) -> "ModelFinderResult": def model_finder( - aln: c3_types.AlignedSeqsType, + aln: Alignment, model_set: Iterable[str] | None = None, freq_set: Iterable[str] | None = None, rate_set: Iterable[str] | None = None, rand_seed: int | None = None, num_threads: int | None = None, -) -> ModelFinderResult | c3_types.SerialisableType: +) -> ModelFinderResult: """Find the models of best fit for an alignment using ModelFinder. Parameters ---------- - aln : c3_types.AlignedSeqsType + aln : Alignment The alignment to find the model of best fit for. model_set : Iterable[str] | None, optional Search space for models. @@ -147,11 +147,11 @@ def model_finder( Returns ------- - ModelFinderResult | c3_types.SerialisableType + ModelFinderResult Collection of data returned from IQ-TREE's ModelFinder. """ - source = aln.info.source + source = cast("str", aln.info.source) if rand_seed is None: rand_seed = 0 # The default rand_seed in IQ-TREE diff --git a/src/piqtree/iqtree/_parse_tree_parameters.py b/src/piqtree/iqtree/_parse_tree_parameters.py new file mode 100644 index 0000000..5b78bd2 --- /dev/null +++ b/src/piqtree/iqtree/_parse_tree_parameters.py @@ -0,0 +1,195 @@ +from typing import Any + +from cogent3.core.tree import PhyloNode + +from piqtree.exceptions import ParseIqTreeError +from piqtree.model import Model, StandardDnaModel + +# the order defined in IQ-TREE +# assume UNREST model has 12 rates, GTR and simpler models always have 6 rates present +RATE_PARS = "A/C", "A/G", "A/T", "C/G", "C/T", "G/T" +RATE_PARS_UNREST = ( + "A/C", + "A/G", + "A/T", + "C/A", + "C/G", + "C/T", + "G/A", + "G/C", + "G/T", + "T/A", + "T/C", + "T/G", +) +MOTIF_PARS = "A", "C", "G", "T" + + +def _insert_edge_pars(tree: PhyloNode, **kwargs: dict) -> None: + # inserts the edge parameters into each edge to match the structure of + # PhyloNode + for node in tree.get_edge_vector(): + # skip the rate parameters when node is the root + if node.is_root(): + kwargs = {k: v for k, v in kwargs.items() if k == "mprobs"} + del node.params["edge_pars"] + node.params.update(kwargs) + + +def _edge_pars_for_cogent3(tree: PhyloNode, model: Model) -> None: + base_model = model.submod_type.base_model + + rate_pars = tree.params["edge_pars"]["rates"] + motif_pars = {"mprobs": tree.params["edge_pars"]["mprobs"]} + # renames parameters to conform to cogent3's naming conventions + if base_model in {StandardDnaModel.JC, StandardDnaModel.F81}: + # skip rate_pars since rate parameters are constant in JC and F81 + _insert_edge_pars( + tree, + **motif_pars, + ) + return + if base_model in {StandardDnaModel.K80, StandardDnaModel.HKY}: + rate_pars = {"kappa": rate_pars["A/G"]} + + elif base_model is StandardDnaModel.TN: + rate_pars = {"kappa_r": rate_pars["A/G"], "kappa_y": rate_pars["C/T"]} + + elif base_model is StandardDnaModel.GTR: + del rate_pars["G/T"] + + # applies global rate parameters to each edge + _insert_edge_pars( + tree, + **rate_pars, + **motif_pars, + ) + + +def _parse_nonlie_model(tree: PhyloNode, tree_yaml: dict) -> None: + # parse motif and rate parameters in the tree_yaml for non-Lie DnaModel + model_fits = tree_yaml.get("ModelDNA", {}) + + state_freq_str = model_fits.get("state_freq", "") + rate_str = model_fits.get("rates", "") + + # parse motif parameters, assign each to a name, and raise an error if not found + if state_freq_str: + # converts the strings of motif parameters into dictionary + state_freq_list = [ + float(value) for value in state_freq_str.replace(" ", "").split(",") + ] + tree.params["edge_pars"] = { + "mprobs": dict(zip(MOTIF_PARS, state_freq_list, strict=True)), + } + else: + msg = "IQ-TREE output malformated, motif parameters not found." + raise ParseIqTreeError(msg) + + # parse rate parameters, assign each to a name, and raise an error if not found + if rate_str: + rate_list = [float(value) for value in rate_str.replace(" ", "").split(",")] + tree.params["edge_pars"]["rates"] = dict( + zip(RATE_PARS, rate_list, strict=True), + ) + else: + msg = "IQ-TREE output malformated, rate parameters not found." + raise ParseIqTreeError(msg) + + +def _parse_lie_model( + tree: PhyloNode, + tree_yaml: dict, + lie_model_name: str, +) -> None: + # parse motif and rate parameters in the tree_yaml for Lie DnaModel + model_fits = tree_yaml.get(lie_model_name, {}) + + # parse motif parameters, assign each to a name, and raise an error if not found + state_freq_str = model_fits.get("state_freq", "") + if state_freq_str: + state_freq_list = [ + float(value) for value in state_freq_str.replace(" ", "").split(",") + ] + tree.params[lie_model_name] = { + "mprobs": dict(zip(MOTIF_PARS, state_freq_list, strict=True)), + } + else: + msg = "IQ-TREE output malformated, motif parameters not found." + raise ParseIqTreeError(msg) + + # parse rate parameters, skipping LIE_1_1 (aka JC69) since its rate parameter is constant thus absent + if "model_parameters" in model_fits: + model_parameters = model_fits["model_parameters"] + + # convert model parameters to a list of floats if they are a string + if isinstance(model_parameters, str): + tree.params[lie_model_name]["model_parameters"] = [ + float(value) for value in model_parameters.replace(" ", "").split(",") + ] + else: + # directly use the float + tree.params[lie_model_name]["model_parameters"] = model_parameters + + +def _parse_unrest_model(tree: PhyloNode, tree_yaml: dict) -> None: + model_fits = tree_yaml.get("ModelUnrest", {}) + + state_freq_str = model_fits.get("state_freq", "") + rate_str = model_fits.get("rates", "") + + # parse state frequencies + if state_freq_str: + state_freq_list = [ + float(value) for value in state_freq_str.replace(" ", "").split(",") + ] + tree.params["edge_pars"] = { + "mprobs": dict(zip(MOTIF_PARS, state_freq_list, strict=True)), + } + else: + msg = "IQ-TREE output malformated, motif parameters not found." + raise ParseIqTreeError(msg) + + # parse rates + if rate_str: + rate_list = [float(value) for value in rate_str.replace(" ", "").split(",")] + tree.params["edge_pars"]["rates"] = dict( + zip(RATE_PARS_UNREST, rate_list, strict=True), + ) + else: + msg = "IQ-TREE output malformated, rate parameters not found." + raise ParseIqTreeError(msg) + + +def parse_model_parameters( + tree: PhyloNode, + tree_yaml: dict[str, Any], + model: Model, +) -> None: + """Parse model parameters from the returned yaml format. + + Parameters + ---------- + tree : PhyloNode + The tree to attach model parameters to. + tree_yaml : dict[str, Any] + The yaml result returned from IQ-TREE. + """ + # parse non-Lie DnaModel parameters + if "ModelDNA" in tree_yaml: + _parse_nonlie_model(tree, tree_yaml) + + elif "ModelUnrest" in tree_yaml: + _parse_unrest_model(tree, tree_yaml) + + # parse Lie DnaModel parameters, handling various Lie model names + elif key := next( + (key for key in tree_yaml if key.startswith("ModelLieMarkov")), + None, + ): + _parse_lie_model(tree, tree_yaml, key) + + # for non-Lie models, populate parameters to each branch and + # rename them to mimic PhyloNode + if "edge_pars" in tree.params: + _edge_pars_for_cogent3(tree, model) diff --git a/src/piqtree/iqtree/_random_tree.py b/src/piqtree/iqtree/_random_tree.py index 886556a..6c57f67 100644 --- a/src/piqtree/iqtree/_random_tree.py +++ b/src/piqtree/iqtree/_random_tree.py @@ -2,8 +2,9 @@ from enum import Enum, auto -import cogent3 from _piqtree import iq_random_tree +from cogent3 import make_tree +from cogent3.core.tree import PhyloNode from piqtree.iqtree._decorator import iqtree_func @@ -25,7 +26,7 @@ def random_tree( num_taxa: int, tree_mode: TreeGenMode, rand_seed: int | None = None, -) -> cogent3.PhyloNode: +) -> PhyloNode: """Generate a random phylogenetic tree. Generates a random tree through IQ-TREE. @@ -41,7 +42,7 @@ def random_tree( Returns ------- - cogent3.PhyloNode + PhyloNode A random phylogenetic tree. """ @@ -49,4 +50,4 @@ def random_tree( rand_seed = 0 # The default rand_seed in IQ-TREE newick = iq_random_tree(num_taxa, tree_mode.name, 1, rand_seed).strip() - return cogent3.make_tree(newick) + return make_tree(newick) diff --git a/src/piqtree/iqtree/_robinson_foulds.py b/src/piqtree/iqtree/_robinson_foulds.py index a93fa16..ddab4f9 100644 --- a/src/piqtree/iqtree/_robinson_foulds.py +++ b/src/piqtree/iqtree/_robinson_foulds.py @@ -2,9 +2,9 @@ from collections.abc import Sequence -import cogent3 import numpy as np from _piqtree import iq_robinson_fould +from cogent3.core.tree import PhyloNode from piqtree.iqtree._decorator import iqtree_func from piqtree.util import get_newick @@ -12,7 +12,7 @@ iq_robinson_fould = iqtree_func(iq_robinson_fould) -def robinson_foulds(trees: Sequence[cogent3.PhyloNode]) -> np.ndarray: +def robinson_foulds(trees: Sequence[PhyloNode]) -> np.ndarray: """Pairwise Robinson-Foulds distance between a sequence of trees. For the given collection of trees, returns a numpy array containing @@ -20,7 +20,7 @@ def robinson_foulds(trees: Sequence[cogent3.PhyloNode]) -> np.ndarray: Parameters ---------- - trees : Sequence[cogent3.PhyloNode] + trees : Sequence[PhyloNode] The sequence of trees to calculate the pairwise Robinson-Foulds distances of. diff --git a/src/piqtree/iqtree/_tree.py b/src/piqtree/iqtree/_tree.py index 2b5256a..eaa0899 100644 --- a/src/piqtree/iqtree/_tree.py +++ b/src/piqtree/iqtree/_tree.py @@ -1,17 +1,20 @@ """Python wrappers to tree searching functions in the IQ-TREE library.""" from collections.abc import Iterable, Sequence +from typing import Any, cast -import cogent3 -import cogent3.app.typing as c3_types import numpy as np import yaml from _piqtree import iq_build_tree, iq_consensus_tree, iq_fit_tree, iq_nj_tree -from cogent3 import PhyloNode, make_tree +from cogent3 import make_tree +from cogent3.core.alignment import Alignment +from cogent3.core.tree import PhyloNode +from cogent3.evolve.fast_distance import DistanceMatrix from piqtree.exceptions import ParseIqTreeError from piqtree.iqtree._decorator import iqtree_func -from piqtree.model import Model, StandardDnaModel, make_model +from piqtree.iqtree._parse_tree_parameters import parse_model_parameters +from piqtree.model import Model, make_model from piqtree.util import get_newick iq_build_tree = iqtree_func(iq_build_tree, hide_files=True) @@ -19,167 +22,12 @@ iq_nj_tree = iqtree_func(iq_nj_tree, hide_files=True) iq_consensus_tree = iqtree_func(iq_consensus_tree, hide_files=True) -# the order defined in IQ-TREE -# assume UNREST model has 12 rates, GTR and simpler models always have 6 rates present -RATE_PARS = "A/C", "A/G", "A/T", "C/G", "C/T", "G/T" -RATE_PARS_UNREST = ( - "A/C", - "A/G", - "A/T", - "C/A", - "C/G", - "C/T", - "G/A", - "G/C", - "G/T", - "T/A", - "T/C", - "T/G", -) -MOTIF_PARS = "A", "C", "G", "T" - - -def _rename_iq_tree(tree: cogent3.PhyloNode, names: Sequence[str]) -> None: + +def _rename_iq_tree(tree: PhyloNode, names: Sequence[str]) -> None: for tip in tree.tips(): tip.name = names[int(tip.name)] -def _insert_edge_pars(tree: cogent3.PhyloNode, **kwargs: dict) -> None: - # inserts the edge parameters into each edge to match the structure of - # cogent3.PhyloNode - for node in tree.get_edge_vector(): - # skip the rate parameters when node is the root - if node.is_root(): - kwargs = {k: v for k, v in kwargs.items() if k == "mprobs"} - del node.params["edge_pars"] - node.params.update(kwargs) - - -def _edge_pars_for_cogent3(tree: cogent3.PhyloNode, model: Model) -> None: - base_model = model.submod_type.base_model - - rate_pars = tree.params["edge_pars"]["rates"] - motif_pars = {"mprobs": tree.params["edge_pars"]["mprobs"]} - # renames parameters to conform to cogent3's naming conventions - if base_model in {StandardDnaModel.JC, StandardDnaModel.F81}: - # skip rate_pars since rate parameters are constant in JC and F81 - _insert_edge_pars( - tree, - **motif_pars, - ) - return - if base_model in {StandardDnaModel.K80, StandardDnaModel.HKY}: - rate_pars = {"kappa": rate_pars["A/G"]} - - elif base_model is StandardDnaModel.TN: - rate_pars = {"kappa_r": rate_pars["A/G"], "kappa_y": rate_pars["C/T"]} - - elif base_model is StandardDnaModel.GTR: - del rate_pars["G/T"] - - # applies global rate parameters to each edge - _insert_edge_pars( - tree, - **rate_pars, - **motif_pars, - ) - - -def _parse_nonlie_model(tree: cogent3.PhyloNode, tree_yaml: dict) -> None: - # parse motif and rate parameters in the tree_yaml for non-Lie DnaModel - model_fits = tree_yaml.get("ModelDNA", {}) - - state_freq_str = model_fits.get("state_freq", "") - rate_str = model_fits.get("rates", "") - - # parse motif parameters, assign each to a name, and raise an error if not found - if state_freq_str: - # converts the strings of motif parameters into dictionary - state_freq_list = [ - float(value) for value in state_freq_str.replace(" ", "").split(",") - ] - tree.params["edge_pars"] = { - "mprobs": dict(zip(MOTIF_PARS, state_freq_list, strict=True)), - } - else: - msg = "IQ-TREE output malformated, motif parameters not found." - raise ParseIqTreeError(msg) - - # parse rate parameters, assign each to a name, and raise an error if not found - if rate_str: - rate_list = [float(value) for value in rate_str.replace(" ", "").split(",")] - tree.params["edge_pars"]["rates"] = dict( - zip(RATE_PARS, rate_list, strict=True), - ) - else: - msg = "IQ-TREE output malformated, rate parameters not found." - raise ParseIqTreeError(msg) - - -def _parse_lie_model( - tree: cogent3.PhyloNode, - tree_yaml: dict, - lie_model_name: str, -) -> None: - # parse motif and rate parameters in the tree_yaml for Lie DnaModel - model_fits = tree_yaml.get(lie_model_name, {}) - - # parse motif parameters, assign each to a name, and raise an error if not found - state_freq_str = model_fits.get("state_freq", "") - if state_freq_str: - state_freq_list = [ - float(value) for value in state_freq_str.replace(" ", "").split(",") - ] - tree.params[lie_model_name] = { - "mprobs": dict(zip(MOTIF_PARS, state_freq_list, strict=True)), - } - else: - msg = "IQ-TREE output malformated, motif parameters not found." - raise ParseIqTreeError(msg) - - # parse rate parameters, skipping LIE_1_1 (aka JC69) since its rate parameter is constant thus absent - if "model_parameters" in model_fits: - model_parameters = model_fits["model_parameters"] - - # convert model parameters to a list of floats if they are a string - if isinstance(model_parameters, str): - tree.params[lie_model_name]["model_parameters"] = [ - float(value) for value in model_parameters.replace(" ", "").split(",") - ] - else: - # directly use the float - tree.params[lie_model_name]["model_parameters"] = model_parameters - - -def _parse_unrest_model(tree: cogent3.PhyloNode, tree_yaml: dict) -> None: - model_fits = tree_yaml.get("ModelUnrest", {}) - - state_freq_str = model_fits.get("state_freq", "") - rate_str = model_fits.get("rates", "") - - # parse state frequencies - if state_freq_str: - state_freq_list = [ - float(value) for value in state_freq_str.replace(" ", "").split(",") - ] - tree.params["edge_pars"] = { - "mprobs": dict(zip(MOTIF_PARS, state_freq_list, strict=True)), - } - else: - msg = "IQ-TREE output malformated, motif parameters not found." - raise ParseIqTreeError(msg) - - # parse rates - if rate_str: - rate_list = [float(value) for value in rate_str.replace(" ", "").split(",")] - tree.params["edge_pars"]["rates"] = dict( - zip(RATE_PARS_UNREST, rate_list, strict=True), - ) - else: - msg = "IQ-TREE output malformated, rate parameters not found." - raise ParseIqTreeError(msg) - - def _tree_equal(node1: PhyloNode, node2: PhyloNode) -> bool: children_group1 = node1.children children_group2 = node2.children @@ -199,17 +47,18 @@ def _tree_equal(node1: PhyloNode, node2: PhyloNode) -> bool: def _process_tree_yaml( - tree_yaml: dict, + tree_yaml: dict[str, Any], names: Sequence[str], -) -> cogent3.PhyloNode: + model: Model, +) -> PhyloNode: newick = tree_yaml["PhyloTree"]["newick"] - tree = cogent3.make_tree(newick) + tree = make_tree(newick) candidates = tree_yaml["CandidateSet"] likelihood = None for candidate in candidates.values(): candidate_likelihood, candidate_newick = candidate.split(" ") - candidate_tree = cogent3.make_tree(candidate_newick) + candidate_tree = make_tree(candidate_newick) if _tree_equal(candidate_tree, tree): likelihood = float(candidate_likelihood) break @@ -219,19 +68,7 @@ def _process_tree_yaml( tree.params["lnL"] = likelihood - # parse non-Lie DnaModel parameters - if "ModelDNA" in tree_yaml: - _parse_nonlie_model(tree, tree_yaml) - - elif "ModelUnrest" in tree_yaml: - _parse_unrest_model(tree, tree_yaml) - - # parse Lie DnaModel parameters, handling various Lie model names - elif key := next( - (key for key in tree_yaml if key.startswith("ModelLieMarkov")), - None, - ): - _parse_lie_model(tree, tree_yaml, key) + parse_model_parameters(tree, tree_yaml, model) # parse rate model, handling various rate model names if key := next((key for key in tree_yaml if key.startswith("Rate")), None): @@ -245,19 +82,19 @@ def _process_tree_yaml( def build_tree( - aln: c3_types.AlignedSeqsType, + aln: Alignment, model: Model | str, rand_seed: int | None = None, bootstrap_replicates: int | None = None, num_threads: int | None = None, -) -> cogent3.PhyloNode: +) -> PhyloNode: """Reconstruct a phylogenetic tree. Given a sequence alignment, uses IQ-TREE to reconstruct a phylogenetic tree. Parameters ---------- - aln : c3_types.AlignedSeqsType + aln : Alignment The sequence alignment. model : Model | str The substitution model with base frequencies and rate heterogeneity. @@ -273,7 +110,7 @@ def build_tree( Returns ------- - cogent3.PhyloNode + PhyloNode The IQ-TREE maximum likelihood tree from the given alignment. """ @@ -302,23 +139,17 @@ def build_tree( num_threads, ), ) - tree = _process_tree_yaml(yaml_result, names) - - # for non-Lie models, populate parameters to each branch and - # rename them to mimic cogent3.PhyloNode - if "edge_pars" in tree.params: - _edge_pars_for_cogent3(tree, model) - return tree + return _process_tree_yaml(yaml_result, names, model) def fit_tree( - aln: c3_types.AlignedSeqsType, - tree: cogent3.PhyloNode, + aln: Alignment, + tree: PhyloNode, model: Model | str, num_threads: int | None = None, *, bl_fixed: bool = False, -) -> cogent3.PhyloNode: +) -> PhyloNode: """Fit branch lengths and likelihood for a tree. Given a sequence alignment and a fixed topology, @@ -326,9 +157,9 @@ def fit_tree( Parameters ---------- - aln : c3_types.AlignedSeqsType + aln : Alignment The sequence alignment. - tree : cogent3.PhyloNode + tree : PhyloNode The topology to fit branch lengths to. model : Model | str The substitution model with base frequencies and rate heterogeneity. @@ -343,7 +174,7 @@ def fit_tree( Returns ------- - cogent3.PhyloNode + PhyloNode A phylogenetic tree with same given topology fitted with branch lengths. """ @@ -368,25 +199,19 @@ def fit_tree( num_threads, ), ) - tree = _process_tree_yaml(yaml_result, names) - - # for non-Lie models, populate parameters to each branch and - # rename them to mimic cogent3.PhyloNode - if "edge_pars" in tree.params: - _edge_pars_for_cogent3(tree, model) - return tree + return _process_tree_yaml(yaml_result, names, model) def nj_tree( - pairwise_distances: c3_types.PairwiseDistanceType, + pairwise_distances: DistanceMatrix, *, allow_negative: bool = False, -) -> cogent3.PhyloNode: +) -> PhyloNode: """Construct a neighbour joining tree from a pairwise distance matrix. Parameters ---------- - pairwise_distances : c3_types.PairwiseDistanceType + pairwise_distances : DistanceMatrix Pairwise distances to construct neighbour joining tree from. allow_negative : bool, optional Whether to allow negative branch lengths in the output. @@ -394,7 +219,7 @@ def nj_tree( Returns ------- - cogent3.PhyloNode + PhyloNode The neigbour joining tree. See Also @@ -415,12 +240,12 @@ def nj_tree( if not allow_negative: for node in tree.preorder(include_self=False): - node.length = max(node.length, 0) + node.length = max(cast("float", node.length), 0) return tree -def _all_same_taxa_set(trees: Iterable[cogent3.PhyloNode]) -> bool: +def _all_same_taxa_set(trees: Iterable[PhyloNode]) -> bool: tree_it = iter(trees) try: taxa_set = set(next(tree_it).get_tip_names()) @@ -431,10 +256,10 @@ def _all_same_taxa_set(trees: Iterable[cogent3.PhyloNode]) -> bool: def consensus_tree( - trees: Iterable[cogent3.PhyloNode], + trees: Iterable[PhyloNode], *, min_support: float = 0.5, -) -> cogent3.PhyloNode: +) -> PhyloNode: """Build a consensus tree, defaults to majority-rule consensus tree. The min_support parameter represents the proportion of trees a clade @@ -445,7 +270,7 @@ def consensus_tree( Parameters ---------- - trees : Iterable[cogent3.PhyloNode] + trees : Iterable[PhyloNode] The trees to form a consensus tree from. min_support : float, optional The minimum support for a clade to appear @@ -453,7 +278,7 @@ def consensus_tree( Returns ------- - cogent3.PhyloNode + PhyloNode The constructed consensus tree. """ diff --git a/tests/test_iqtree/test_tree_yaml.py b/tests/test_iqtree/test_tree_yaml.py index a541aec..47989d7 100644 --- a/tests/test_iqtree/test_tree_yaml.py +++ b/tests/test_iqtree/test_tree_yaml.py @@ -4,6 +4,7 @@ import pytest from cogent3 import make_tree +from piqtree import make_model from piqtree.exceptions import ParseIqTreeError from piqtree.iqtree._tree import _process_tree_yaml, _tree_equal @@ -168,10 +169,7 @@ def test_newick_not_in_candidates( ParseIqTreeError, match=re.escape("IQ-TREE output malformated, likelihood not found."), ): - _ = _process_tree_yaml( - yaml, - ["a", "b", "c", "d"], - ) + _ = _process_tree_yaml(yaml, ["a", "b", "c", "d"], make_model("JC")) def test_non_lie_dna_with_rate_model( @@ -185,7 +183,6 @@ def test_non_lie_dna_with_rate_model( "A/T": 1.0, "C/G": 1.0, "C/T": 3.82025079, - "G/T": 1, }, "mprobs": { "A": 0.3628523161, @@ -195,8 +192,13 @@ def test_non_lie_dna_with_rate_model( }, } rate_params = {"gamma_shape": 1.698497993, "p_invar": 1.002841144e-06} - tree = _process_tree_yaml(non_lie_dna_with_rate_model, ["a", "b", "c", "d"]) - assert tree.params["edge_pars"] == edge_params + tree = _process_tree_yaml( + non_lie_dna_with_rate_model, + ["a", "b", "c", "d"], + make_model("GTR+I+G"), + ) + for rate, value in edge_params["rates"].items(): + assert tree[0].params[rate] == value assert tree.params["RateGammaInvar"] == rate_params @@ -208,7 +210,11 @@ def test_non_lie_dna_model_motif_absent( ParseIqTreeError, match=re.escape("IQ-TREE output malformated, motif parameters not found."), ): - _ = _process_tree_yaml(non_lie_dna_with_rate_model, ["a", "b", "c", "d"]) + _ = _process_tree_yaml( + non_lie_dna_with_rate_model, + ["a", "b", "c", "d"], + make_model("GTR+I+G"), + ) def test_non_lie_dna_model_rate_absent( @@ -219,7 +225,11 @@ def test_non_lie_dna_model_rate_absent( ParseIqTreeError, match=re.escape("IQ-TREE output malformated, rate parameters not found."), ): - _ = _process_tree_yaml(non_lie_dna_with_rate_model, ["a", "b", "c", "d"]) + _ = _process_tree_yaml( + non_lie_dna_with_rate_model, + ["a", "b", "c", "d"], + make_model("GTR+I+G"), + ) def test_lie_dna_model( @@ -230,7 +240,7 @@ def test_lie_dna_model( "model_parameters": 0.4841804549, "mprobs": {"A": 0.25, "C": 0.25, "G": 0.25, "T": 0.25}, } - tree = _process_tree_yaml(lie_dna_model, ["a", "b", "c", "d"]) + tree = _process_tree_yaml(lie_dna_model, ["a", "b", "c", "d"], make_model("RY2.2b")) assert tree.params["ModelLieMarkovRY2.2b"] == model_parameters @@ -242,15 +252,19 @@ def test_lie_dna_model_motif_absent( ParseIqTreeError, match=re.escape("IQ-TREE output malformated, motif parameters not found."), ): - _ = _process_tree_yaml(lie_dna_model, ["a", "b", "c", "d"]) + _ = _process_tree_yaml( + lie_dna_model, + ["a", "b", "c", "d"], + make_model("RY2.2b"), + ) def test_unrest_model( unrest_model: dict[str, Any], ) -> None: - tree = _process_tree_yaml(unrest_model, ["a", "b", "c", "d"]) - assert tree.params["edge_pars"]["rates"] - assert tree.params["edge_pars"]["mprobs"] + tree = _process_tree_yaml(unrest_model, ["a", "b", "c", "d"], make_model("UNREST")) + assert "A/C" in tree[0].params + assert tree.params["mprobs"] @pytest.mark.parametrize(