In [22]:
from __future__ import annotations
import json
import pandas as pd
import json
import numpy as np
from typing import Iterable
from dataclasses import dataclass, field

# http://etetoolkit.org/download/
# http://etetoolkit.org/docs/latest/tutorial/tutorial_ncbitaxonomy.html#setting-up-a-local-copy-of-the-ncbi-taxonomy-database
# warning: the ETE toolkit writes its own files into ~/.local/share/ete/
from ete4 import NCBITaxa
ncbi = NCBITaxa()

from local.utils import regex

In [23]:
# Load the headerless, tab-separated hit table; each row is unpacked below as
# (query, subject, annotation, pident, bitscore).
_df = pd.read_csv("../../data/scadc_nr.tsv", header=None, sep="\t")
print(_df.shape)

hits = {}  # query -> [(subject, pident, bitscore), ...]
meta = {}  # subject -> annotation taken from the first row mentioning that subject
for query, subject, annotation, pident, bitscore in _df.itertuples(index=False, name=None):
    # Append in place: rebuilding the list with `+` for every row copies it each
    # time, which is quadratic in the number of hits per query.
    hits.setdefault(query, []).append((subject, pident, bitscore))
    # setdefault keeps the first annotation seen for a subject and ignores later ones.
    meta.setdefault(subject, annotation)

# Order every query's hits from highest to lowest bitscore. list.sort is stable,
# so hits with equal bitscore keep their file order.
for _query_hits in hits.values():
    _query_hits.sort(key=lambda hit: hit[2], reverse=True)

len(hits), len(meta)

(98034, 5)


(4127, 93071)

In [24]:
# Keep one row per query: the first hit stored for it, plus the bracketed part of
# that hit's annotation.
# NOTE(review): the first hit is presumably the best-scoring one because the lists
# are sorted by bitscore where `hits` is built -- confirm if that cell changes.
_rows = []
for query_id, ranked_hits in hits.items():
    top_subject, top_pident, top_bitscore = ranked_hits[0]
    annotation = meta[top_subject]

    # assumes `regex` yields the matched substrings for the pattern -- TODO confirm.
    # The first result is taken and its first and last characters (the brackets) removed.
    bracketed = next(regex(r"\[.+\]", annotation))
    taxon = bracketed[1:-1]

    _rows.append((query_id, taxon, top_pident, top_bitscore))

_df = pd.DataFrame(_rows, columns=["query", "tax", "pident", "bitscore"])

In [25]:
def tax_from_sci_name(sci_name):
    """Resolve a scientific name to a taxid and the rank of the matched taxon.

    The full name is looked up first with ``ncbi.get_name_translator``. If that
    finds nothing, the lookup is retried with the first whitespace-separated
    token of the name (the genus, for a binomial name).

    Args:
        sci_name: scientific name string to look up.

    Returns:
        ``(taxid, match_type)``: ``taxid`` is the first taxid listed for the
        matched name and ``match_type`` is the rank reported for it by
        ``ncbi.get_rank`` (``None`` if no rank is reported). Returns
        ``(None, None)`` when neither the full name nor its first token matches.
    """
    tax_data = ncbi.get_name_translator([sci_name])
    if len(tax_data) == 0:  # retry with genus
        # split() without an argument ignores leading/repeated whitespace, so the
        # fallback token is never the empty string for a padded name.
        tokens = sci_name.split()
        genus = tokens[0] if tokens else ""
        # Skip the retry when it would repeat the lookup that just failed.
        if genus and genus != sci_name:
            tax_data = ncbi.get_name_translator([genus])
    if len(tax_data) == 0:
        return None, None

    # The translator maps each matched name to a list of taxids; use the first
    # taxid of the first (only) entry.
    taxid = next(iter(tax_data.values()))[0]
    # Default to None instead of leaking StopIteration if no rank comes back.
    match_type = next(iter(ncbi.get_rank([taxid]).values()), None)
    return taxid, match_type

In [26]:
print(_df.shape)


def _lineage_for(sci_name):
    """Return ``(taxid, lineage_json)`` for ``sci_name``, or ``None`` if no taxid is found.

    ``lineage_json`` is a compact JSON list of ``[rank, name]`` pairs in the
    order given by ``ncbi.get_lineage``, with "no rank" entries removed.
    """
    taxid, _match_type = tax_from_sci_name(sci_name)
    if taxid is None:
        return None
    id_lin = ncbi.get_lineage(taxid)
    assert id_lin is not None
    names = ncbi.get_taxid_translator(id_lin)
    ranks = ncbi.get_rank(id_lin)
    lineage = [(ranks[tid], names[tid]) for tid in id_lin if ranks[tid] != "no rank"]
    return taxid, json.dumps(lineage, separators=(',', ':'))


# Many queries share the same taxon name, so each name is resolved once and the
# result (including a failed resolution, stored as None) is reused.
_lineage_cache = {}
_rows = []
for q, tax, pident, bitscore in _df.itertuples(index=False, name=None):
    if tax not in _lineage_cache:
        _lineage_cache[tax] = _lineage_for(tax)
    resolved = _lineage_cache[tax]
    if resolved is None: continue  # rows whose taxon cannot be resolved are dropped
    taxid, lineage_json = resolved

    _rows.append((q, tax, taxid, pident, bitscore, lineage_json))
_df = pd.DataFrame(_rows, columns=['query', 'sci_name', 'taxid', 'pident', 'bitscore', 'lineage_json'])
print(_df.shape)

(4127, 4)
(4096, 6)


In [None]:
@dataclass
class Node:
    """A named node of a `Tree`, linked to its parent.

    Only `name` and `level` are required; the counters start at zero and
    `parent` defaults to None.
    """
    name: str
    level: str
    # Reset to 0.0 by Tree.BestLineage; not otherwise written in this notebook.
    entropy: float = 0.0
    # Number of observations whose lineage ended at this node (see Tree.NewObservation).
    count: int = 0
    # Totals accumulated by Tree.BestLineage from this node and the nodes below it.
    sum_entropy: float = 0.0
    sum_count: int = 0
    # None for a node without a parent (the tree root).
    parent: Node|None = None

    def __str__(self) -> str:
        return f"<{self.name}, entropy={self.sum_entropy}, count={self.sum_count}>"

    def __repr__(self) -> str:
        # Same compact form as str(); it does not follow `parent` links.
        return self.__str__()

class Lineage:
    """Ordered, appendable sequence of (name, level) pairs.

    Iterating yields the pairs in insertion order. Every `iter()` call returns an
    independent iterator, so the same Lineage can be iterated repeatedly, nested,
    or by several consumers at once.
    """

    def __init__(self, lineage: Iterable[tuple[str, str]]|None = None) -> None:
        """@lineage is iterable of (name, level) tuples"""
        self._i = 0  # cursor used only by direct next(lineage) calls
        self._lineage: list[tuple[str, str]] = [] if lineage is None else list(lineage)

    def Add(self, name: str, level: str):
        """Append one (name, level) pair to the end of the lineage."""
        self._lineage.append((name, level))

    def __iter__(self):
        # Rewind the cursor, as before, for callers that follow iter() with next(lineage).
        self._i = 0
        # Return a fresh iterator instead of `self`: with a shared cursor a nested
        # loop over the same object would exhaust the outer loop as well.
        return iter(self._lineage)

    def __next__(self):
        # Kept so that calling next() directly on a Lineage keeps working.
        if self._i >= len(self._lineage): raise StopIteration
        self._i += 1
        return self._lineage[self._i-1]

class Tree:
    """Tree of named nodes built from observed lineages.

    Nodes are keyed by name in `self.nodes`, so a name can occupy only one
    position in the tree. `self.sum` counts the recorded observations.
    """

    def __init__(self) -> None:
        self.root = Node("root", "root")
        self.nodes: dict[str, Node] = {self.root.name:self.root}
        self.sum = 0

    def NewObservation(self, lineage: Lineage):
        """Record one observation whose lineage is an iterable of (name, level) pairs.

        Nodes missing from the tree are created along the way and the count of the
        last node on the path is incremented (the root, for an empty lineage).

        Raises:
            AssertionError: if a name already in the tree appears under a
                different parent than before. The observation is then not
                counted, though nodes created earlier on the path remain.
        """
        parent = self.root
        for k, level in lineage:
            existing = self.nodes.get(k)
            if existing is None:
                self.nodes[k] = Node(k, level, parent=parent)
            else:
                # Identity is what is meant here: `==` on a dataclass compares every
                # field, recursing up the whole parent chain.
                assert existing.parent is parent, f"invalid lineage: the parent of [{k}] was previously said to be [{existing.parent}], now [{parent.name}]"
            parent = self.nodes[k]
        leaf = parent # last node assigned in the loop
        # Count only after the lineage validated, so a rejected observation does not
        # inflate the denominator used in BestLineage.
        self.sum += 1
        leaf.count += 1

    def BestLineage(self) -> list[Node]:
        """Return the path below the root that follows, at each step, the child with the lowest `sum_entropy`.

        For every node with a non-zero count, `p * log10(p)` (p = count / self.sum)
        and the count are added to that node and all of its ancestors. The
        totals are recomputed from scratch on each call. Returns an empty list
        when the root has no children.
        """
        # Reset every accumulator so repeated calls do not stack up totals.
        for node in self.nodes.values():
            node.entropy = 0.0
            node.sum_entropy = 0.0
            node.sum_count = 0

        for counted in self.nodes.values():
            if counted.count == 0: continue

            r_of_x = counted.count/self.sum
            term = r_of_x*np.log10(r_of_x)  # always <= 0 because r_of_x <= 1
            ancestor = counted
            while ancestor is not None:
                ancestor.sum_entropy += term
                ancestor.sum_count += counted.count
                ancestor = ancestor.parent

        # Index children by the name of their parent.
        all_children: dict[str, list[Node]] = {}
        for node in self.nodes.values():
            if node.parent is None: continue
            all_children.setdefault(node.parent.name, []).append(node)

        # NOTE(review): picking the lowest sum of p*log10(p) selects the most negative
        # total, which is not necessarily the child with the highest sum_count
        # (the earlier, commented-out criterion) -- confirm this is intended.
        lineage: list[Node] = []
        children = all_children.get(self.root.name)
        while children:
            # min() returns the first minimal child, the same element sorted(...)[0] gave.
            best = min(children, key=lambda x: x.sum_entropy)
            lineage.append(best)
            children = all_children.get(best.name)

        return lineage
        
# Build one Tree per entry, where the entry is the part of the query id before the
# first "_". Each row of _df contributes one observation.
trees: dict[str, Tree] = {}
ranks = {}  # taxon name -> rank, overwritten by the most recent lineage mentioning it
for query, _sci_name, _taxid, _pident, _bitscore, lineage_json in _df.itertuples(index=False, name=None):
    entry = query.split("_")[0]
    rank_name_pairs = json.loads(lineage_json)  # stored as [rank, name] pairs
    ranks.update((name, rank) for rank, name in rank_name_pairs)
    # Lineage expects (name, level) order, so each pair is flipped.
    observation = Lineage((name, rank) for rank, name in rank_name_pairs)
    if entry not in trees:
        trees[entry] = Tree()
    trees[entry].NewObservation(observation)

# For every entry print the selected lineage (without "clade" levels), followed by
# the per-node counts and the entropy sums rounded to two decimals.
for k, tree in trees.items():
    lin = tree.BestLineage()
    print(k)
    print([(n.level, n.name) for n in lin if n.level != "clade"])
    print([n.sum_count for n in lin])
    print([round(n.sum_entropy*100)/100 for n in lin])
    print()