Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 17 additions & 16 deletions src/piqtree/_app/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

from collections.abc import Iterable

import cogent3
import cogent3.app.typing as c3_types
from cogent3.app import composable
from cogent3.core.alignment import Alignment
from cogent3.core.tree import PhyloNode
from cogent3.evolve.fast_distance import DistanceMatrix
from cogent3.util.misc import extend_docstring_from

from piqtree import (
Expand Down Expand Up @@ -39,8 +40,8 @@ def __init__(

def main(
self,
aln: c3_types.AlignedSeqsType,
) -> cogent3.PhyloNode | cogent3.app.typing.SerialisableType:
aln: Alignment,
) -> PhyloNode:
tree = build_tree(
aln,
self._model,
Expand All @@ -57,7 +58,7 @@ class piq_fit_tree:
@extend_docstring_from(fit_tree)
def __init__(
self,
tree: cogent3.PhyloNode,
tree: PhyloNode,
model: Model | str,
*,
num_threads: int | None = None,
Expand All @@ -70,8 +71,8 @@ def __init__(

def main(
self,
aln: c3_types.AlignedSeqsType,
) -> cogent3.PhyloNode | cogent3.app.typing.SerialisableType:
aln: Alignment,
) -> PhyloNode:
tree = fit_tree(
aln,
self._tree,
Expand All @@ -89,7 +90,7 @@ def piq_random_tree(
num_taxa: int,
tree_mode: TreeGenMode,
rand_seed: int | None = None,
) -> cogent3.PhyloNode:
) -> PhyloNode:
return random_tree(num_taxa, tree_mode, rand_seed)


Expand All @@ -104,8 +105,8 @@ def __init__(

def main(
self,
aln: c3_types.AlignedSeqsType,
) -> c3_types.PairwiseDistanceType | cogent3.app.typing.SerialisableType:
aln: Alignment,
) -> DistanceMatrix:
dists = jc_distances(
aln,
num_threads=self._num_threads,
Expand All @@ -117,10 +118,10 @@ def main(
@composable.define_app
@extend_docstring_from(nj_tree)
def piq_nj_tree(
dists: c3_types.PairwiseDistanceType,
dists: DistanceMatrix,
*,
allow_negative: bool = False,
) -> cogent3.PhyloNode:
) -> PhyloNode:
tree = nj_tree(dists, allow_negative=allow_negative)
tree.params |= {"provenance": "piqtree"}
tree.source = getattr(dists, "source", None)
Expand All @@ -130,18 +131,18 @@ def piq_nj_tree(
@composable.define_app
@extend_docstring_from(model_finder)
def piq_model_finder(
aln: c3_types.AlignedSeqsType,
) -> ModelFinderResult | c3_types.SerialisableType:
aln: Alignment,
) -> ModelFinderResult:
return model_finder(aln)


@composable.define_app
@extend_docstring_from(consensus_tree)
def piq_consensus_tree(
trees: Iterable[cogent3.PhyloNode],
trees: Iterable[PhyloNode],
*,
min_support: float = 0.5,
) -> cogent3.PhyloNode:
) -> PhyloNode:
return consensus_tree(trees, min_support=min_support)


Expand Down
14 changes: 7 additions & 7 deletions src/piqtree/iqtree/_jc_distance.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from collections.abc import Sequence

import cogent3.app.typing as c3_types
import numpy as np
from _piqtree import iq_jc_distances
from cogent3.core.alignment import Alignment
from cogent3.evolve.fast_distance import DistanceMatrix

from piqtree.iqtree._decorator import iqtree_func
Expand All @@ -13,7 +13,7 @@
def _dists_to_distmatrix(
distances: np.ndarray,
names: Sequence[str],
) -> c3_types.PairwiseDistanceType:
) -> DistanceMatrix:
"""Convert numpy representation of distance matrix into cogent3 pairwise distance matrix.

Parameters
Expand All @@ -25,7 +25,7 @@ def _dists_to_distmatrix(

Returns
-------
c3_types.PairwiseDistanceType
DistanceMatrix
Pairwise distance matrix.

"""
Expand All @@ -37,21 +37,21 @@ def _dists_to_distmatrix(


def jc_distances(
aln: c3_types.AlignedSeqsType,
aln: Alignment,
num_threads: int | None = None,
) -> c3_types.PairwiseDistanceType:
) -> DistanceMatrix:
"""Compute pairwise JC distances for a given alignment.

Parameters
----------
aln : c3_types.AlignedSeqsType
aln : Alignment
alignment to compute pairwise JC distances for.
num_threads: int | None, optional
Number of threads for IQ-TREE to use, by default None (all available threads).

Returns
-------
c3_types.PairwiseDistanceType
DistanceMatrix
Pairwise JC distance matrix.

"""
Expand Down
14 changes: 7 additions & 7 deletions src/piqtree/iqtree/_model_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

import dataclasses
from collections.abc import Iterable
from typing import Any
from typing import Any, cast

import yaml
from _piqtree import iq_model_finder
from cogent3.app import typing as c3_types
from cogent3.core.alignment import Alignment
from cogent3.util.misc import get_object_provenance

from piqtree.iqtree._decorator import iqtree_func
Expand Down Expand Up @@ -117,18 +117,18 @@ def from_rich_dict(cls, data: dict[str, Any]) -> "ModelFinderResult":


def model_finder(
aln: c3_types.AlignedSeqsType,
aln: Alignment,
model_set: Iterable[str] | None = None,
freq_set: Iterable[str] | None = None,
rate_set: Iterable[str] | None = None,
rand_seed: int | None = None,
num_threads: int | None = None,
) -> ModelFinderResult | c3_types.SerialisableType:
) -> ModelFinderResult:
"""Find the models of best fit for an alignment using ModelFinder.

Parameters
----------
aln : c3_types.AlignedSeqsType
aln : Alignment
The alignment to find the model of best fit for.
model_set : Iterable[str] | None, optional
Search space for models.
Expand All @@ -147,11 +147,11 @@ def model_finder(

Returns
-------
ModelFinderResult | c3_types.SerialisableType
ModelFinderResult
Collection of data returned from IQ-TREE's ModelFinder.

"""
source = aln.info.source
source = cast("str", aln.info.source)
if rand_seed is None:
rand_seed = 0 # The default rand_seed in IQ-TREE

Expand Down
195 changes: 195 additions & 0 deletions src/piqtree/iqtree/_parse_tree_parameters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
from typing import Any

from cogent3.core.tree import PhyloNode

from piqtree.exceptions import ParseIqTreeError
from piqtree.model import Model, StandardDnaModel

# Parameter names below follow the order defined in IQ-TREE.

# Nucleotide states that the motif (state frequency) values are matched to.
MOTIF_PARS = ("A", "C", "G", "T")

# GTR and simpler models are assumed to always have these 6 rates present.
RATE_PARS = ("A/C", "A/G", "A/T", "C/G", "C/T", "G/T")

# The UNREST model is assumed to have 12 rates: one per ordered pair of
# distinct nucleotides, i.e. A/C, A/G, A/T, C/A, C/G, C/T, G/A, G/C, G/T,
# T/A, T/C, T/G.
RATE_PARS_UNREST = tuple(
    f"{src}/{dst}" for src in MOTIF_PARS for dst in MOTIF_PARS if src != dst
)


def _insert_edge_pars(tree: PhyloNode, **kwargs: Any) -> None:
    """Copy model parameters into the ``params`` of every node of ``tree``.

    Non-root nodes receive every keyword argument. The root receives only
    ``"mprobs"`` (rate parameters are skipped there) and has its staged
    ``"edge_pars"`` entry removed.

    Parameters
    ----------
    tree : PhyloNode
        Tree whose nodes, as yielded by ``get_edge_vector``, are updated
        in place.
    **kwargs
        Parameter name to value, e.g. rate parameters and ``mprobs``.
    """
    # Computed once, up front, and kept separate from ``kwargs`` so that the
    # result does not depend on where the root appears in the traversal:
    # rebinding ``kwargs`` at the root would strip the rate parameters from
    # every node visited after it.
    root_pars = {name: value for name, value in kwargs.items() if name == "mprobs"}

    for node in tree.get_edge_vector():
        if node.is_root():
            # the staged parameters have been distributed; drop the staging key
            node.params.pop("edge_pars", None)
            node.params.update(root_pars)
        else:
            node.params.update(kwargs)


def _edge_pars_for_cogent3(tree: PhyloNode, model: Model) -> None:
    """Rename the staged IQ-TREE parameters to cogent3 names and apply them.

    Reads the ``"rates"`` and ``"mprobs"`` entries staged under
    ``tree.params["edge_pars"]``, selects and renames the rate parameters
    according to the base model, and hands the result to
    ``_insert_edge_pars`` so it is written onto the nodes of the tree.

    Parameters
    ----------
    tree : PhyloNode
        Tree carrying the staged ``"edge_pars"`` entry in its ``params``.
    model : Model
        Model whose ``submod_type.base_model`` decides which rate parameters
        are kept and how they are named.
    """
    base_model = model.submod_type.base_model

    staged = tree.params["edge_pars"]
    iqtree_rates = staged["rates"]
    mprobs = staged["mprobs"]

    if base_model in {StandardDnaModel.JC, StandardDnaModel.F81}:
        # rate parameters are constant in JC and F81, so none are attached
        edge_rates = {}
    elif base_model in {StandardDnaModel.K80, StandardDnaModel.HKY}:
        edge_rates = {"kappa": iqtree_rates["A/G"]}
    elif base_model is StandardDnaModel.TN:
        edge_rates = {
            "kappa_r": iqtree_rates["A/G"],
            "kappa_y": iqtree_rates["C/T"],
        }
    else:
        if base_model is StandardDnaModel.GTR:
            # G/T is dropped before the rates are applied
            # (presumably it is the reference rate -- confirm)
            del iqtree_rates["G/T"]
        # every remaining rate keeps its IQ-TREE name
        edge_rates = iqtree_rates

    # applies the global parameters to each edge
    _insert_edge_pars(tree, **edge_rates, mprobs=mprobs)


def _parse_nonlie_model(tree: PhyloNode, tree_yaml: dict) -> None:
    """Stage the "ModelDNA" state frequencies and rates on ``tree``.

    The comma-separated ``"state_freq"`` and ``"rates"`` strings are
    converted to floats and stored under ``tree.params["edge_pars"]`` as
    ``"mprobs"`` (keyed by ``MOTIF_PARS``) and ``"rates"`` (keyed by
    ``RATE_PARS``).

    Parameters
    ----------
    tree : PhyloNode
        Tree whose ``params`` receives the ``"edge_pars"`` entry.
    tree_yaml : dict
        The yaml result returned from IQ-TREE.

    Raises
    ------
    ParseIqTreeError
        If the state frequencies or the rates are missing or empty. The
        motif probabilities are stored before the rates are checked, so a
        missing-rates error leaves ``"edge_pars"`` holding only ``"mprobs"``.
    """
    fitted = tree_yaml.get("ModelDNA", {})
    freq_text = fitted.get("state_freq", "")
    rate_text = fitted.get("rates", "")

    # motif parameters: required, one value per name in MOTIF_PARS
    if not freq_text:
        msg = "IQ-TREE output malformated, motif parameters not found."
        raise ParseIqTreeError(msg)
    freqs = [float(token) for token in freq_text.replace(" ", "").split(",")]
    tree.params["edge_pars"] = {
        "mprobs": dict(zip(MOTIF_PARS, freqs, strict=True)),
    }

    # rate parameters: required, one value per name in RATE_PARS
    if not rate_text:
        msg = "IQ-TREE output malformated, rate parameters not found."
        raise ParseIqTreeError(msg)
    rates = [float(token) for token in rate_text.replace(" ", "").split(",")]
    tree.params["edge_pars"]["rates"] = dict(zip(RATE_PARS, rates, strict=True))


def _parse_lie_model(
    tree: PhyloNode,
    tree_yaml: dict,
    lie_model_name: str,
) -> None:
    """Store the parameters of a Lie Markov model on ``tree``.

    The parameters are written to ``tree.params[lie_model_name]``:
    ``"mprobs"`` always, and ``"model_parameters"`` only when the yaml
    provides it.

    Parameters
    ----------
    tree : PhyloNode
        Tree whose ``params`` receives the parsed values.
    tree_yaml : dict
        The yaml result returned from IQ-TREE.
    lie_model_name : str
        Key of the Lie model section in ``tree_yaml``; also used as the key
        under which the values are stored in ``tree.params``.

    Raises
    ------
    ParseIqTreeError
        If the state frequencies are missing or empty.
    """
    fitted = tree_yaml.get(lie_model_name, {})

    # motif parameters: required, one value per name in MOTIF_PARS
    freq_text = fitted.get("state_freq", "")
    if not freq_text:
        msg = "IQ-TREE output malformated, motif parameters not found."
        raise ParseIqTreeError(msg)
    freqs = [float(token) for token in freq_text.replace(" ", "").split(",")]
    tree.params[lie_model_name] = {
        "mprobs": dict(zip(MOTIF_PARS, freqs, strict=True)),
    }

    # rate parameters are optional: LIE_1_1 (aka JC69) has a constant rate
    # parameter, so it is absent from the output
    if "model_parameters" not in fitted:
        return

    parameters = fitted["model_parameters"]
    if isinstance(parameters, str):
        # several values arrive as one comma-separated string
        parameters = [
            float(token) for token in parameters.replace(" ", "").split(",")
        ]
    # anything that is not a string (a lone float) is stored unchanged
    tree.params[lie_model_name]["model_parameters"] = parameters


def _parse_unrest_model(tree: PhyloNode, tree_yaml: dict) -> None:
    """Stage the "ModelUnrest" state frequencies and rates on ``tree``.

    The comma-separated ``"state_freq"`` and ``"rates"`` strings are
    converted to floats and stored under ``tree.params["edge_pars"]`` as
    ``"mprobs"`` (keyed by ``MOTIF_PARS``) and ``"rates"`` (keyed by
    ``RATE_PARS_UNREST``).

    Parameters
    ----------
    tree : PhyloNode
        Tree whose ``params`` receives the ``"edge_pars"`` entry.
    tree_yaml : dict
        The yaml result returned from IQ-TREE.

    Raises
    ------
    ParseIqTreeError
        If the state frequencies or the rates are missing or empty. The
        motif probabilities are stored before the rates are checked, so a
        missing-rates error leaves ``"edge_pars"`` holding only ``"mprobs"``.
    """
    fitted = tree_yaml.get("ModelUnrest", {})
    freq_text = fitted.get("state_freq", "")
    rate_text = fitted.get("rates", "")

    # state frequencies: required, one value per name in MOTIF_PARS
    if not freq_text:
        msg = "IQ-TREE output malformated, motif parameters not found."
        raise ParseIqTreeError(msg)
    freqs = [float(token) for token in freq_text.replace(" ", "").split(",")]
    tree.params["edge_pars"] = {
        "mprobs": dict(zip(MOTIF_PARS, freqs, strict=True)),
    }

    # rates: required, one value per name in RATE_PARS_UNREST
    if not rate_text:
        msg = "IQ-TREE output malformated, rate parameters not found."
        raise ParseIqTreeError(msg)
    rates = [float(token) for token in rate_text.replace(" ", "").split(",")]
    tree.params["edge_pars"]["rates"] = dict(
        zip(RATE_PARS_UNREST, rates, strict=True),
    )


def parse_model_parameters(
    tree: PhyloNode,
    tree_yaml: dict[str, Any],
    model: Model,
) -> None:
    """Parse model parameters from the returned yaml and attach them to the tree.

    The section to parse is picked by the first match among: a "ModelDNA"
    key, a "ModelUnrest" key, or any key starting with "ModelLieMarkov".
    When none is present, nothing is parsed.

    Parameters
    ----------
    tree : PhyloNode
        The tree to attach model parameters to. Modified in place.
    tree_yaml : dict[str, Any]
        The yaml result returned from IQ-TREE.
    model : Model
        The model used to decide how staged rate parameters are renamed
        when they are applied to the tree.
    """
    if "ModelDNA" in tree_yaml:
        # non-Lie DnaModel parameters
        _parse_nonlie_model(tree, tree_yaml)
    elif "ModelUnrest" in tree_yaml:
        _parse_unrest_model(tree, tree_yaml)
    else:
        # Lie DnaModel parameters; the section name varies between Lie models
        lie_key = next(
            (name for name in tree_yaml if name.startswith("ModelLieMarkov")),
            None,
        )
        if lie_key:
            _parse_lie_model(tree, tree_yaml, lie_key)

    # Only the non-Lie parsers stage "edge_pars"; when present, populate the
    # parameters to each branch and rename them to mimic PhyloNode.
    if "edge_pars" in tree.params:
        _edge_pars_for_cogent3(tree, model)
Loading