# Annotate taxonomy at tree nodes

This notebook attempts to annotate the internal nodes of a tree using NCBI taxonomy. It is not limited to the seven common ranks; any taxon can be a candidate.

## Preparation

### Dependencies

In [1]:
import pandas as pd
from skbio import TreeNode

### Input files

Tree of genomes

In [2]:
tree_fp = '../trees/astral.e5p50.nwk'

Genome ID to taxonomy ID map

In [3]:
g2tid_fp = '../../genomes/in_tree/g2tid.txt'

NCBI taxdump directory

In [4]:
taxdump_dir = '../taxdump'

### Parameters

Ranks to consider, in order from highest to lowest:

In [5]:
ranks = ['phylum', 'class', 'order', 'family', 'genus', 'species']

Drop taxa whose fraction at a node is below this threshold:

In [6]:
min_frac_th = 0.01

Consider a taxon whose fraction at a node is at or above this threshold as dominant:

In [7]:
dom_frac_th = 0.95

Exclude these TaxIDs from each genome's list of ancestral TaxIDs:

In [8]:
tids_to_exclude = {'1', '131567'}  # "root" and "cellular organisms"

## Pre-processing

### Read input files

In [9]:
tree = TreeNode.read(tree_fp)
n, m = tree.count(), tree.count(tips=True)
print('Tree has %d tips and %d internal nodes.' % (m, n - m))

Tree has 10575 tips and 10028 internal nodes.


In [10]:
with open(g2tid_fp, 'r') as f:
    g2tid = dict(x.split('\t') for x in f.read().splitlines())
print('Map has %d genome IDs corresponding to %d taxIDs.' % (len(g2tid), len(set(g2tid.values()))))

Map has 10575 genome IDs corresponding to 9887 taxIDs.


In [11]:
taxdump = {}
with open('%s/nodes.dmp' % taxdump_dir, 'r') as f:
    for line in f:
        x = line.rstrip('\r\n').replace('\t|', '').split('\t')
        taxdump[x[0]] = {'parent': x[1], 'rank': x[2], 'name': ''}
with open('%s/names.dmp' % taxdump_dir, 'r') as f:
    for line in f:
        x = line.rstrip('\r\n').replace('\t|', '').split('\t')
        if x[3] == 'scientific name':
            taxdump[x[0]]['name'] = x[1]
print('NCBI taxdump has %d TaxIDs.' % len(taxdump))

NCBI taxdump has 1573975 TaxIDs.
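In `nodes.dmp` and `names.dmp`, fields are separated by `\t|\t` and each line ends with `\t|`, so removing every `\t|` and splitting on tabs (as the cell above does) recovers the fields while keeping empty fields in place. A minimal sketch on two hypothetical, abbreviated sample lines; the helper name `parse_dmp_line` is illustrative:

```python
def parse_dmp_line(line):
    r"""Split one line of an NCBI taxdump .dmp file into its fields.

    Fields are separated by "\t|\t" and each line ends with "\t|", so
    removing every "\t|" leaves plain tab-separated fields.
    """
    return line.rstrip('\r\n').replace('\t|', '').split('\t')


# Hypothetical, abbreviated sample lines in the nodes.dmp / names.dmp layout
# (real nodes.dmp lines carry many more fields after the rank).
nodes_line = '773\t|\t772\t|\tgenus\t|\t\t|\n'
names_line = '773\t|\tBartonella\t|\t\t|\tscientific name\t|\n'

tid, parent, rank = parse_dmp_line(nodes_line)[:3]
fields = parse_dmp_line(names_line)
print(tid, parent, rank)           # → 773 772 genus
print(fields[1], '|', fields[3])   # → Bartonella | scientific name
```

Because the empty "unique name" field survives as `''`, the name class stays at index 3, which is what the `x[3] == 'scientific name'` check relies on.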


### Generate maps

Associate each genome with all of its ancestral TaxIDs.

For example, `G000046705` is mapped to TaxID 38323 (*Bartonella henselae*), so it will be associated with TaxIDs 38323, 773 (*Bartonella*), 772 (Bartonellaceae), 356 (Rhizobiales), 28211 (Alphaproteobacteria), 1224 (Proteobacteria) and 2 (Bacteria).

In [12]:
tid2gs, g2tids = {}, {}
for g in tree.subset():
    cid = g2tid[g]
    g2tids[g] = []
    while True:
        if cid not in taxdump:
            raise ValueError('Invalid TaxID: %s.' % cid)

        if cid not in tids_to_exclude:
            g2tids[g].append(cid)

        tid2gs.setdefault(cid, set()).add(g)
        pid = taxdump[cid]['parent']
        if cid == pid:
            break
        cid = pid
print('TaxIDs associated with one or more genomes: %d.' % len(tid2gs))

TaxIDs associated with one or more genomes: 16388.
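The walk above climbs from a genome's TaxID to the root, which `nodes.dmp` lists as its own parent. A standalone sketch of the same walk, using a hypothetical parent map abbreviated from the *Bartonella henselae* example (the real NCBI lineage includes extra intermediate nodes, and the helper name `lineage` is illustrative):

```python
# Hypothetical, abbreviated taxdump: TaxID -> parent TaxID.
parents = {
    '1': '1',            # root (listed as its own parent)
    '131567': '1',       # cellular organisms
    '2': '131567',       # Bacteria
    '1224': '2',         # Proteobacteria
    '28211': '1224',     # Alphaproteobacteria
    '356': '28211',      # Rhizobiales
    '772': '356',        # Bartonellaceae
    '773': '772',        # Bartonella
    '38323': '773',      # Bartonella henselae
}


def lineage(tid, parents, exclude=frozenset()):
    """Return tid followed by all of its ancestors, skipping excluded ones."""
    result = []
    cid = tid
    while True:
        if cid not in parents:
            raise ValueError('Invalid TaxID: %s.' % cid)
        if cid not in exclude:
            result.append(cid)
        pid = parents[cid]
        if cid == pid:  # the root is its own parent
            break
        cid = pid
    return result


print(lineage('38323', parents, exclude={'1', '131567'}))
# → ['38323', '773', '772', '356', '28211', '1224', '2']
```

In the notebook cell, excluded TaxIDs are skipped only when building `g2tids`; they are still recorded in `tid2gs`, which is why node N1 can receive 131567 (cellular organisms) in Solution I.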


### Export maps

In [13]:
with open('tid2gs.txt', 'w') as f:
    for tid, gs in sorted(tid2gs.items(), key=lambda x: int(x[0])):
        f.write('%s\t%s\n' % (tid, ','.join(sorted(gs))))

In [14]:
with open('g2tids.txt', 'w') as f:
    for g, tids in sorted(g2tids.items(), key=lambda x: x[0]):
        f.write('%s\t%s\n' % (g, ','.join(sorted(tids, key=int))))

In [15]:
with open('tid_info.tsv', 'w') as f:
    f.write('taxID\trank\ttaxon\n')
    for tid in sorted(tid2gs, key=int):
        f.write('%s\t%s\t%s\n'
                % (tid, taxdump[tid]['rank'], taxdump[tid]['name']))

## Solution I: Strict monophyly

### Background

A node is annotated by a TaxID only when:
 1. All descendant genomes belong to this TaxID.
 2. No other genomes in the tree belong to this TaxID.
 3. If multiple TaxIDs meet criteria 1 and 2, the lowest (most derived) one is kept.
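Criteria 1 and 2 together require that the set of genomes under the node equals the set of genomes assigned to the TaxID; the workflow below tests this by comparing the tip count of each TaxID's LCA with its number of genomes. A minimal pure-Python sketch of the set-equality test on a hypothetical four-genome tree (the dict-based tree and the helper names are illustrative; the notebook itself uses scikit-bio's `TreeNode`):

```python
# Hypothetical toy tree: internal node -> children; names that are not
# keys of the dict are tips (genomes).
children = {
    'N1': ['N2', 'G4'],
    'N2': ['G1', 'N3'],
    'N3': ['G2', 'G3'],
}


def tips_under(node):
    """Return the set of tip names descending from node."""
    if node not in children:
        return {node}
    tips = set()
    for child in children[node]:
        tips |= tips_under(child)
    return tips


def strict_node(genomes):
    """Return the internal node whose tips are exactly `genomes`, or None.

    Criterion 1: every tip under the node belongs to the taxon.
    Criterion 2: no genome of the taxon lies outside the node.
    """
    for node in children:
        if tips_under(node) == set(genomes):
            return node
    return None


print(strict_node({'G2', 'G3'}))  # → N3
print(strict_node({'G1', 'G2'}))  # → None (not a clade)
```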

### Workflow

Identify the lowest common ancestor of genomes represented by each TaxID.

In [16]:
tid2nid = {}
for tid, gs in tid2gs.items():
    tid2nid[tid] = tree.lca(list(gs)).name
print('Nodes representing LCAs per TaxID: %d.' % len(tid2nid))

Nodes representing LCAs per TaxID: 16388.


Extract TaxIDs that represent monophyletic groups in the tree, i.e., those whose LCA has exactly as many descendant tips as the TaxID has genomes.

In [17]:
tid2nid_mono = {}
for tid, nid in tid2nid.items():
    if tree.find(nid).count(tips=True) == len(tid2gs[tid]):
        tid2nid_mono[tid] = nid
print('TaxIDs that are monophyletic: %d.' % len(tid2nid_mono))

TaxIDs that are monophyletic: 1216.


Identify TaxID(s) assigned to each node.

Note that one node may receive multiple TaxIDs that are nested along one lineage of the taxonomic hierarchy. For example, 38323 (*Bartonella henselae*), 773 (*Bartonella*) and 772 (Bartonellaceae) can all be assigned to one node if the only Bartonellaceae members in the tree are two strains of *B. henselae*.

In [18]:
nid2tids = {}
for tid, nid in tid2nid_mono.items():
    nid2tids.setdefault(nid, set()).add(tid)
print('Nodes with monophyletic TaxIDs assigned: %d.' % len(nid2tids))

Nodes with monophyletic TaxIDs assigned: 970.


Assign a unique TaxID to each node.

If multiple TaxIDs are assigned to one node, the lowest-level (most derived) one is used. In the aforementioned example, 38323 (*Bartonella henselae*) will be selected.

In [19]:
def lowest_tid(tids):
    """Return the most derived TaxID from a set of TaxIDs on one lineage."""
    def depth(tid):
        # count the ancestors between this TaxID and the root
        n, cid = 0, tid
        while True:
            pid = taxdump[cid]['parent']
            if cid == pid:
                return n
            n += 1
            cid = pid

    # TaxID with the longest lineage is the most derived
    return max(tids, key=depth)

In [20]:
nid2tid_uniq = {k: lowest_tid(v) for k, v in nid2tids.items()}
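The candidate TaxIDs at a node lie on a single lineage, so the most derived one is simply the one farthest from the root. A sketch of that rule on a hypothetical, abbreviated taxdump; the helper names `depth` and `lowest` are illustrative and mirror what `lowest_tid` does:

```python
# Hypothetical, abbreviated taxdump: TaxID -> parent TaxID.
toy_parents = {'1': '1', '2': '1', '772': '2', '773': '772', '38323': '773'}


def depth(tid):
    """Number of ancestors between tid and the root."""
    n = 0
    while toy_parents[tid] != tid:
        tid = toy_parents[tid]
        n += 1
    return n


def lowest(tids):
    """Most derived TaxID: the one with the longest lineage."""
    return max(tids, key=depth)


print(lowest({'772', '773', '38323'}))  # → 38323
```

Lineage length is only a meaningful proxy for specificity among TaxIDs on the same lineage; it says nothing useful when comparing TaxIDs on different branches.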

Generate result table.

In [21]:
def nid2tid_to_df(nid2tid):
    df = pd.DataFrame.from_dict(nid2tid, orient='index')
    df.columns = ['taxID']
    df.index.name = 'node'
    df['rank'] = df['taxID'].apply(lambda x: taxdump[x]['rank'])
    df['taxon'] = df['taxID'].apply(lambda x: taxdump[x]['name'])
    indices = sorted(nid2tid, key=lambda x: int(x.split(':')[-1][1:]))
    return df.reindex(indices)
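Node IDs are sorted numerically rather than as strings so that the table follows the node numbering. A small illustration with hypothetical IDs (reading `split(':')[-1]` as tolerating a prefix, such as a support value, before the node ID is an assumption):

```python
node_ids = ['N10', 'N2', 'N1', 'N25']

# string sorting compares character by character
print(sorted(node_ids))  # → ['N1', 'N10', 'N2', 'N25']


def node_key(nid):
    """Numeric sort key for IDs like 'N12', ignoring any 'prefix:' part."""
    return int(nid.split(':')[-1][1:])


print(sorted(node_ids, key=node_key))  # → ['N1', 'N2', 'N10', 'N25']
```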

In [22]:
df = nid2tid_to_df(nid2tid_uniq)
df.head()

| node | taxID | rank | taxon |
|------|-------|------|-------|
| N1 | 131567 | no rank | cellular organisms |
| N2 | 2157 | superkingdom | Archaea |
| N3 | 2 | superkingdom | Bacteria |
| N17 | 1801634 | no rank | unclassified Candidatus Micrarchaeota |
| N25 | 1936272 | no rank | Candidatus Heimdallarchaeota |


In [23]:
df.to_csv('strict_assignments.tsv', sep='\t')

A version without pandas (disabled here by the `%%script false` cell magic):

In [24]:
%%script false
with open('strict_assignments.tsv', 'w') as f:
    f.write('%s\n' % '\t'.join(('#nodeID', 'taxID', 'rank', 'taxon')))
    for nid in sorted(nid2tid_uniq, key=lambda x: int(x.split(':')[-1][1:])):
        tid = nid2tid_uniq[nid]
        f.write('%s\t%s\t%s\t%s\n' % (nid.split(':')[-1], tid,
                taxdump[tid]['rank'], taxdump[tid]['name']))

## Solution II: Fractions

### Background

This solution annotates a node using the fractions of the taxonomic groups found among its descendants. These fractions reflect the tree topology rather than a simple count of descendant genomes: at each node, every child clade carries equal weight, regardless of how many genomes it contains. For example, in the following tree:

In [25]:
print(TreeNode.read(['((A,A,A,A),B)C;']).ascii_art())

                    /-A
                   |
                   |--A
          /--------|
         |         |--A
-C-------|         |
         |          \-A
         |
          \-B


The taxonomic composition at node "C" is 50% A and 50% B.
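With equal weighting per child, node C is 50% A and 50% B, whereas counting tips would give 80% A and 20% B. A minimal sketch computing both, using nested Python lists as a hypothetical stand-in for the tree (the notebook itself works on scikit-bio `TreeNode` objects; the helper names are illustrative):

```python
from collections import Counter

# Hypothetical stand-in for the tree ((A,A,A,A),B)C: a list is an
# internal node, a string is a tip labeled with its taxon.
toy_tree = [['A', 'A', 'A', 'A'], 'B']


def equal_weight_fractions(node):
    """Taxon fractions where every child clade carries equal weight."""
    if isinstance(node, str):
        return {node: 1.0}
    total = {}
    for child in node:
        for taxon, frac in equal_weight_fractions(child).items():
            total[taxon] = total.get(taxon, 0.0) + frac
    return {taxon: frac / len(node) for taxon, frac in total.items()}


def flatten(node):
    """List all tip labels under a node."""
    if isinstance(node, str):
        return [node]
    return [tip for child in node for tip in flatten(child)]


def tip_fractions(node):
    """Taxon fractions from a simple count of tips."""
    counts = Counter(flatten(node))
    n_tips = sum(counts.values())
    return {taxon: count / n_tips for taxon, count in counts.items()}


print(equal_weight_fractions(toy_tree))  # → {'A': 0.5, 'B': 0.5}
print(tip_fractions(toy_tree))           # → {'A': 0.8, 'B': 0.2}
```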

At each node, fractions are calculated for all available TaxIDs, including each genome's own TaxID and all of its ancestral TaxIDs. For example, in the following tree:

In [26]:
print(TreeNode.read(["('R. leguminosarum',('B. bacilliformis',('B. henselae','B. quintana')A)B)C;"]).ascii_art())

          /-R. leguminosarum
-C-------|
         |          /-B. bacilliformis
          \B-------|
                   |          /-B. henselae
                    \A-------|
                              \-B. quintana


- Node A will be 50% species *B. henselae*, 50% species *B. quintana*, 100% genus *Bartonella*, 100% family Bartonellaceae, ...
- Node B will be 50% *B. bacilliformis*, 25% *B. henselae*, 25% *B. quintana*, 100% *Bartonella*, 100% Bartonellaceae, ...
- Node C will be:
  - species: 50% *R. leguminosarum*, 25% *B. bacilliformis*, 12.5% *B. henselae*, 12.5% *B. quintana*;
  - genus: 50% *Bartonella*, 50% *Rhizobium*;
  - family: 50% Bartonellaceae, 50% Rhizobiaceae;
  - order / class / phylum: 100% Rhizobiales / Alphaproteobacteria / Proteobacteria.

Note that taxonomy is NOT limited to the seven most common ranks. In the aforementioned example, node C will also be assigned as 50% "Rhizobium/Agrobacterium group" (TaxID: 227290).
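The numbers above can be checked by propagating lineage fractions bottom-up. A sketch using hypothetical lineages abbreviated to the order level (real NCBI lineages carry more ranks and intermediate nodes) and tuples as stand-ins for the internal nodes A, B and C; the helper `composition` is illustrative:

```python
# Hypothetical, abbreviated lineages (by name, most derived first).
lineages = {
    'R. leguminosarum': ['R. leguminosarum', 'Rhizobium',
                         'Rhizobium/Agrobacterium group',
                         'Rhizobiaceae', 'Rhizobiales'],
    'B. bacilliformis': ['B. bacilliformis', 'Bartonella',
                         'Bartonellaceae', 'Rhizobiales'],
    'B. henselae': ['B. henselae', 'Bartonella',
                    'Bartonellaceae', 'Rhizobiales'],
    'B. quintana': ['B. quintana', 'Bartonella',
                    'Bartonellaceae', 'Rhizobiales'],
}


def composition(node):
    """Fractions of all lineage taxa; each child clade weighs equally."""
    if isinstance(node, str):
        # a genome counts fully toward every taxon in its lineage
        return {taxon: 1.0 for taxon in lineages[node]}
    total = {}
    for child in node:
        for taxon, frac in composition(child).items():
            total[taxon] = total.get(taxon, 0.0) + frac
    return {taxon: frac / len(node) for taxon, frac in total.items()}


# tuples stand in for the internal nodes of the example tree
node_a = ('B. henselae', 'B. quintana')
node_b = ('B. bacilliformis', node_a)
node_c = ('R. leguminosarum', node_b)

comp_c = composition(node_c)
print(comp_c['B. henselae'])                      # → 0.125
print(comp_c['Bartonella'], comp_c['Rhizobium'])  # → 0.5 0.5
print(comp_c['Rhizobium/Agrobacterium group'])    # → 0.5
print(comp_c['Rhizobiales'])                      # → 1.0
```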

### Workflow

Identify TaxIDs and calculate their fractions at each node.

In a post-order traversal, the compositions of a node's children are summed and then divided by the number of children, so each child clade contributes equally to the composition of the current node.

In [27]:
fractions = {}
for node in tree.postorder(include_self=True):
    if node.is_tip():
        # a genome counts fully (1.0) toward every TaxID in its lineage
        fractions[node.name] = {x: 1.0 for x in g2tids[node.name]}
    else:
        # sum the compositions of the child clades ...
        comp = {}
        for child in node.children:
            for tid, frac in fractions[child.name].items():
                comp[tid] = comp.get(tid, 0.0) + frac

        # ... then divide by the number of children so each clade weighs equally
        n_children = len(node.children)
        fractions[node.name] = {tid: frac / n_children
                                for tid, frac in comp.items()}

For each internal node, drop TaxIDs whose fraction is below `min_frac_th`.

In [28]:
if (min_frac_th or 0) > 0:
    for node in tree.non_tips(include_self=True):
        for tid in [k for k, v in fractions[node.name].items() if v < min_frac_th]:
            fractions[node.name].pop(tid)

In the output files, the taxonomic composition of a genome or a node is recorded as:

- `genome <tab> taxID1,taxID2,taxID3,taxID4,...`
- `node <tab> taxID1,taxID2:0.45,taxID3:0.005,...`

Note: a TaxID without a fraction has a fraction of 1.0.

In [29]:
def frac_str(f):
    """Convert a dict of fractions into a string."""
    return ','.join(['%s:%.4g' % (k, v) if v < 1.0 else k for k, v in
                     sorted(f.items(), key=lambda x: x[1], reverse=True)])

In [30]:
def node_id_list(d):
    return sorted([x for x in d if x.startswith('N')],
                  key=lambda x: (x[0], int(x[1:])))

In [31]:
with open('fractions.txt', 'w') as f:
    for nid in node_id_list(fractions):
        f.write('%s\t%s\n' % (nid, frac_str(fractions[nid])))
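Downstream scripts may need to read these strings back. A minimal sketch of a hypothetical inverse, `parse_frac_str`, next to the notebook's `frac_str`; because `frac_str` writes four significant digits (`%.4g`), the round trip is exact only for fractions that need no more precision:

```python
def frac_str(f):
    """Convert a dict of fractions into a string (as in the notebook)."""
    return ','.join(['%s:%.4g' % (k, v) if v < 1.0 else k for k, v in
                     sorted(f.items(), key=lambda x: x[1], reverse=True)])


def parse_frac_str(s):
    """Hypothetical inverse of frac_str: a TaxID without ':' means 1.0."""
    if not s:
        return {}
    result = {}
    for item in s.split(','):
        tid, _, frac = item.partition(':')
        result[tid] = float(frac) if frac else 1.0
    return result


s = frac_str({'2': 1.0, '1224': 0.45, '28211': 0.005})
print(s)                  # → 2,1224:0.45,28211:0.005
print(parse_frac_str(s))  # → {'2': 1.0, '1224': 0.45, '28211': 0.005}
```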

Calculate fractions at each rank per node.

Note: at each rank, the fractions are rescaled to sum to one. For example, if a node has two descendants, one *Bartonella henselae* and the other "environmental sample A1" (which has no genus-level TaxID), then at the genus level this node will be assigned 100% *Bartonella* instead of 50%.

This prevents unclassified organisms from diluting the fractions of classified taxa.

In [32]:
frac_at_ranks = {}
for rank in ranks:
    fracs = {}
    for nid in fractions:
        fracs[nid] = {}
        for tid, frac in fractions[nid].items():
            if taxdump[tid]['rank'] == rank:
                fracs[nid][tid] = frac

        if len(fracs[nid]) > 0:
            s = sum(fracs[nid].values())
            for tid in fracs[nid]:
                fracs[nid][tid] /= s

    frac_at_ranks[rank] = fracs
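The per-rank rescaling can be illustrated with the *Bartonella* example from the note above. A sketch on hypothetical, abbreviated fractions at a node whose two children are a *B. henselae* genome and an environmental sample without a genus; the helper `frac_at_rank` mirrors the inner loop of the cell above but is illustrative:

```python
# Hypothetical, abbreviated fractions at a node whose two children are a
# Bartonella henselae genome and an environmental sample without a genus.
node_fracs = {'2': 1.0, '772': 0.5, '773': 0.5, '38323': 0.5}
tid_ranks = {'2': 'superkingdom', '772': 'family', '773': 'genus',
             '38323': 'species'}


def frac_at_rank(fracs, tid_ranks, rank):
    """Keep TaxIDs at one rank and rescale their fractions to sum to one."""
    sub = {tid: f for tid, f in fracs.items() if tid_ranks[tid] == rank}
    total = sum(sub.values())
    if total == 0:
        return {}
    return {tid: f / total for tid, f in sub.items()}


print(frac_at_rank(node_fracs, tid_ranks, 'genus'))    # → {'773': 1.0}
print(frac_at_rank(node_fracs, tid_ranks, 'species'))  # → {'38323': 1.0}
```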

In [33]:
with open('frac_at_ranks.tsv', 'w') as f:
    f.write('node\t%s\n' % '\t'.join(ranks))
    for nid in node_id_list(fractions):
        row = [nid]
        for rank in ranks:
            row.append(frac_str(frac_at_ranks[rank][nid]))
        f.write('%s\n' % '\t'.join(row))

Identify "dominant" TaxIDs, i.e., those whose fraction at a node is at or above `dom_frac_th`.

In [34]:
dominants = {}
for node in tree.non_tips(include_self=True):
    for tid, frac in fractions[node.name].items():
        if frac >= dom_frac_th:
            dominants.setdefault(node.name, set()).add(tid)
print('Nodes with at least one dominant taxon assigned: %d.' % len(dominants))

Nodes with at least one dominant taxon assigned: 10027.
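Because each genome has a single lineage, the fractions of two taxa that do not contain each other sum to at most 1 at any node. With a threshold above 0.5, all dominant TaxIDs at a node therefore lie on one lineage, which is what allows `lowest_tid` to choose among them by lineage length. A sketch of the selection on hypothetical fractions at one node:

```python
threshold = 0.95  # same value as dom_frac_th in the Parameters section

# Hypothetical fractions at one node: Bacteria, Proteobacteria,
# Alphaproteobacteria, Rhizobiales and Gammaproteobacteria.
node_fracs = {'2': 1.0, '1224': 0.97, '28211': 0.96, '356': 0.6,
              '1236': 0.01}

dominant = {tid for tid, frac in node_fracs.items() if frac >= threshold}
print(sorted(dominant, key=int))  # → ['2', '1224', '28211']
```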


Find the lowest-level (most derived) taxon for each node.

In [35]:
lowests = {k: lowest_tid(v) for k, v in dominants.items()}
print('Unique TaxIDs being the lowest: %d.' % len(set(lowests.values())))

Unique TaxIDs being the lowest: 1641.


In [36]:
df = nid2tid_to_df(lowests)
df.head()

| node | taxID | rank | taxon |
|------|-------|------|-------|
| N2 | 2157 | superkingdom | Archaea |
| N3 | 2 | superkingdom | Bacteria |
| N4 | 2157 | superkingdom | Archaea |
| N5 | 2157 | superkingdom | Archaea |
| N6 | 1783234 | no rank | Bacteria candidate phyla |


In [37]:
df.to_csv('p%s_assignments.tsv' % int(dom_frac_th * 100), sep='\t')