# 201215 unmatched species taxonomy trees

In [None]:
# Identifiers for this notebook; they are combined below to build the output
# directory and the report title.
datestr = '201215'
exptname = '201031-database-v1.1-software-version-migration'
nbname = '{}-unmatched-species-taxonomy-trees'.format(datestr)

In [None]:
import json
from pathlib import Path
import re
from zipfile import ZipFile
from gzip import GzipFile
from collections import Counter

In [3]:
from tqdm import tqdm
import pandas as pd

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

In [4]:
from midas.db.models import Genome

## File paths

In [5]:
# Input locations, keyed by the short names used when loading data below.
# The archive is kept as a plain string; everything else is a ``Path``.
infiles = {
    'v11_archive': '/home/jared/projects/midas/data/v1/archives/refseq_curated_1.1_beta_200525.midas-archive.gz',
    'taxa': Path('../../data/intermediate/201031-database-v1.1-software-version-migration/201102-download-taxa/'),
    'taxa_additional': Path('../../data/intermediate/201031-database-v1.1-software-version-migration/201205-download-additional-taxa/'),
    'taxonomy_original': Path('../../data/intermediate/201031-database-v1.1-software-version-migration/201113-original-genome-taxa/'),
    'matches': Path('../../data/intermediate/201031-database-v1.1-software-version-migration/201122-taxon-name-matching/'),
    'updated_taxids': Path('../../data/intermediate/201031-database-v1.1-software-version-migration/201201-download-updated-assembly-summaries/updated-assembly-taxids.json'),
}

In [6]:
# Directory the report is written to; created (with parents) if missing.
processed_out = Path('../../data/processed/', exptname, nbname)
processed_out.mkdir(parents=True, exist_ok=True)

## Load data

### Archive files

In [7]:
# The archive is a zip file wrapped in gzip compression: GzipFile decompresses
# the outer layer and ZipFile reads members from the resulting stream.
# The handle is left open because later cells read further members from it.
archive_v11 = ZipFile(GzipFile(infiles['v11_archive']))
# Display the archive's 'info' member as text.
archive_v11.read('info').decode()

'{"archive_version": "1.0"}'

In [8]:
# Parse the curated genome set (a JSON document stored inside the archive).
with archive_v11.open('genome_sets/midas/assembly/curated') as f:
    gset_data = json.loads(f.read())

In [9]:
# Group genome keys by their curated (genus, species) annotation.
genomes_by_species = {}

for key, adata in gset_data['annotations'].items():
    sp = (adata['tax_genus'], adata['tax_species'])
    if sp not in genomes_by_species:
        genomes_by_species[sp] = set()
    genomes_by_species[sp].add(key)

# Sorted (genus, species) tuples, and the sorted distinct genus names.
species_names = sorted(genomes_by_species)
genus_names = sorted({genus for genus, _ in species_names})

len(genus_names), len(species_names)

(419, 1438)

### Taxonomy

#### Current data

In [10]:
# Taxon records indexed by taxid. Records from the additional download are
# merged in afterwards and therefore win when a taxid appears in both files.
with open(infiles['taxa'] / 'taxa.json') as f:
    taxon_data = {entry['taxid']: entry for entry in json.load(f)}

with open(infiles['taxa_additional'] / 'taxa.json') as f:
    taxon_data.update((entry['taxid'], entry) for entry in json.load(f))

In [11]:
# Mapping consulted by ``resolve_alias``. JSON object keys are strings, so
# they are converted to int here; the values are kept as loaded.
with open(infiles['taxa'] / 'aka_taxids.json') as f:
    aka_taxids = {
        int(from_id): to_id
        for from_id, to_id in json.load(f).items()
    }

In [12]:
# Child taxid -> parent taxid. Taxa whose parent_taxid is 0 get no entry, so
# they act as roots when walking upwards.
parent_rels = {
    child_id: record['parent_taxid']
    for child_id, record in taxon_data.items()
    if record['parent_taxid'] != 0
}

#### Original 2016 data

In [13]:
# Original taxonomy summaries, re-keyed from JSON string keys to int taxids.
with open(infiles['taxonomy_original'] / 'original-tax-summaries.json') as f:
    orig_tax_summaries = {
        int(raw_id): tax_summary
        for raw_id, tax_summary in json.load(f).items()
    }

### Name matches

In [14]:
# (curated genus, curated species) -> remaining match record, or None when the
# record has no matched taxid. The two name fields are popped off the record.
species_name_matches = {}

with open(infiles['matches'] / 'species-name-matches.json') as f:
    for d in json.load(f):
        sp = (d.pop('curated_genus'), d.pop('curated_species'))
        if d['matched_taxid'] is None:
            species_name_matches[sp] = None
        else:
            species_name_matches[sp] = d

### Database connection

In [15]:
# SQLite database in the working directory, plus a session factory bound to it.
engine = create_engine('sqlite:///db.sqlite')
Session = sessionmaker(bind=engine)

In [16]:
# Single session, reused by the Genome queries in the cells below.
session = Session()

### Assembly taxids

In [17]:
# Genome key -> Entrez ID, for every Genome row in the database.
assembly_ids = {
    genome.key: genome.entrez_id
    for genome in session.query(Genome)
}

In [18]:
# Entrez ID -> the 'ncbi_taxid' stored in each Genome row's ``extra`` field.
orig_assembly_taxids = {
    genome.entrez_id: genome.extra['ncbi_taxid']
    for genome in session.query(Genome)
}

In [19]:
# Updated taxid per assembly; JSON string keys are converted to int.
with open(infiles['updated_taxids']) as f:
    updated_assembly_taxids = {
        int(raw_aid): new_taxid
        for raw_aid, new_taxid in json.load(f).items()
    }

In [20]:
# Sanity check: the updated and original mappings must cover exactly the same
# set of assemblies, since later code looks each assembly up in both.
assert updated_assembly_taxids.keys() == orig_assembly_taxids.keys()

## Func defs

In [21]:
def only(it):
    """Return the single element of ``it``.

    Raises ``ValueError`` (from the unpacking) when ``it`` is empty or has
    more than one element.
    """
    [single] = it
    return single

In [22]:
def resolve_alias(tid):
    """Look ``tid`` up in ``aka_taxids``; return it unchanged if it has no entry."""
    try:
        return aka_taxids[tid]
    except KeyError:
        return tid

In [23]:
def gettaxon(tid):
    """Return the taxon record for ``tid`` after resolving it with ``resolve_alias``.

    Raises ``KeyError`` if the resolved ID is not in ``taxon_data``.
    """
    resolved = resolve_alias(tid)
    return taxon_data[resolved]

In [24]:
def getparent(taxon):
    """Return the parent record of a taxon, or None if it cannot be looked up.

    ``taxon`` may be a taxon record or an integer ID (fetched with
    ``gettaxon``). The parent ID goes through ``gettaxon`` as well, so alias
    IDs are resolved.
    """
    record = gettaxon(taxon) if isinstance(taxon, int) else taxon

    try:
        parent = gettaxon(record['parent_taxid'])
    except KeyError:
        # No 'parent_taxid' key, or no record for the parent ID.
        parent = None
    return parent

In [25]:
def iter_ancestors(taxon, incself=False):
    """Yield successive parents of a taxon until ``getparent`` returns None.

    ``taxon`` may be a taxon record or an integer ID (fetched with
    ``gettaxon``). When ``incself`` is true the taxon itself is yielded first.
    """
    node = gettaxon(taxon) if isinstance(taxon, int) else taxon
    node = node if incself else getparent(node)

    while True:
        if node is None:
            return
        yield node
        node = getparent(node)

In [26]:
def taxon_url(taxid):
    """Return the NCBI Taxonomy Browser URL for an integer taxonomy ID."""
    template = 'https://www.ncbi.nlm.nih.gov/Taxonomy/Browser/wwwtax.cgi?mode=Info&id=%d'
    return template % taxid

In [27]:
def make_tree_child_map(leaf_taxids, root_taxid=None):
    """Build a parent -> children mapping for the tree spanning a set of taxa.

    Starting from each ID in ``leaf_taxids``, parents are followed through
    ``parent_rels`` until ``root_taxid`` (if given) or a taxon without a
    parent entry is reached.

    Args:
        leaf_taxids: Iterable of taxonomy IDs the tree must contain. Any
            iterable is accepted, including one-shot iterators.
        root_taxid: Optional ID at which to stop walking upwards. If omitted,
            the returned root is the lowest common ancestor of the leaves.

    Returns:
        ``(children, root)`` where ``children`` maps a taxid to the set of its
        child taxids within the tree and ``root`` is the root taxid. With no
        leaves and an explicit ``root_taxid`` the result is ``({}, root_taxid)``.

    Raises:
        ValueError: If there are no leaves and no ``root_taxid``, if the
            leaves lead to more than one root, or if they do not all descend
            from ``root_taxid``.
    """
    # Materialize once: the argument is both iterated here and membership-
    # tested further down, which would silently misbehave for an iterator.
    leaves = set(leaf_taxids)

    if not leaves:
        if root_taxid is None:
            raise ValueError('No leaf taxids given')
        # Nothing below the requested root: a tree consisting of the root only.
        return dict(), root_taxid

    children = dict()
    roots = set()

    # Work queue of taxa whose parent link has not been recorded yet.
    heads = set(leaves)

    while heads:
        taxid = heads.pop()

        # Is root
        if taxid == root_taxid or taxid not in parent_rels:
            roots.add(taxid)
            continue

        ptaxid = parent_rels[taxid]

        if ptaxid in children:
            children[ptaxid].add(taxid)

        else:
            # First time this parent is seen, so it still has to be walked.
            children[ptaxid] = {taxid}
            heads.add(ptaxid)

    if len(roots) > 1:
        raise ValueError('More than one root found')

    # ``leaves`` is non-empty, so exactly one root remains at this point.
    (root,) = roots

    if root_taxid is not None:
        if root != root_taxid:
            raise ValueError(
                'Leaf taxids are not descendants of root taxid %r' % (root_taxid,)
            )

    else:
        # Find LCA: drop single-child ancestors above the leaves.
        while root not in leaves and len(children[root]) == 1:
            (root,) = children.pop(root)

    return children, root

In [28]:
def find_lca(taxids):
    """Return the root that ``make_tree_child_map`` computes for ``taxids``.

    No explicit root is passed, so that root is the lowest common ancestor.
    """
    return make_tree_child_map(taxids)[1]

In [29]:
def sum_counts_recursive(child_map, root, counts):
    """Compute subtree totals for every node reachable from ``root``.

    ``child_map`` maps a node to its children and ``counts`` maps a node to
    its own count (missing nodes count as 0). Returns a dict mapping each
    visited node to its own count plus the totals of all its descendants.
    """
    totals = {}

    def _visit(node):
        subtotal = counts.get(node, 0)
        subtotal += sum(_visit(kid) for kid in child_map.get(node, []))
        totals[node] = subtotal
        return subtotal

    _visit(root)
    return totals

In [30]:
def count_taxon_reassignments(aids, species=None):
    """Count new taxid -> old taxid reassignments for a set of genomes.

    Both the updated and the original taxonomy ID of each assembly are passed
    through ``resolve_alias`` before counting.

    Args:
        aids: Iterable of assembly IDs; each must be a key of both
            ``updated_assembly_taxids`` and ``orig_assembly_taxids``.
        species: ``(genus, species)`` tuple compared against the original
            taxonomy summary of each genome. When omitted, the module-level
            variable ``sp`` is used, which is what this function implicitly
            read before the parameter existed. Pass it explicitly whenever
            possible.

    Returns:
        ``(counts, counts_skipped)``: two dicts mapping a new taxid to a
        ``Counter`` of old taxid -> number of genomes. Genomes whose original
        summary genus/species equals ``species`` are tallied in ``counts``,
        all others in ``counts_skipped``.
    """
    counts = dict()
    counts_skipped = dict()

    for aid in aids:
        if species is None:
            # Legacy fallback, resolved lazily so empty input never needs it.
            species = sp

        newtaxid = resolve_alias(updated_assembly_taxids[aid])
        oldtaxid = resolve_alias(orig_assembly_taxids[aid])

        summary = orig_tax_summaries_by_canonical_taxid[oldtaxid]
        if (summary['genus'], summary['species']) != species:
            c = counts_skipped
        else:
            c = counts

        c.setdefault(newtaxid, Counter())[oldtaxid] += 1

    return counts, counts_skipped

## Tree report generation

In [31]:
# Original summaries re-keyed by the ID that ``resolve_alias`` returns, so
# they can be looked up with alias-resolved taxids.
orig_tax_summaries_by_canonical_taxid = {
    resolve_alias(tid): tax_summary
    for tid, tax_summary in orig_tax_summaries.items()
}

### Config

In [32]:
# Species without a name match are the ones covered by the report.
report_species = [
    name for name in species_names
    if species_name_matches[name] is None
]

# HTML anchor ID per species: lower-cased "genus species", whitespace runs
# replaced by '-', then everything except a-z and '-' removed.
report_species_ids = {
    name: re.sub(r'[^a-z-]', '', re.sub(r'\s+', '-', ' '.join(name).lower()))
    for name in report_species
}

In [33]:
# Report-wide settings: page title and per-level indentation of tree rows.
report_attrs = {
    'title': datestr + ' unmatched species taxonomy trees',
    'tree_indent_px': 15,
}

### Page template

In [34]:
# Stylesheet embedded in the report's <style> element.
# The doubled backslashes emit CSS escapes (\2714, \2718) for the check and
# cross marks. Disabled declarations use /* ... */ because CSS has no `//`
# comment syntax.
REPORT_CSS = '''
body {
    margin: 24px;
    font-size: 1.4em;
}

h1 {
    font-size: 4rem !important;
}

h2 {
    font-size: 2.5rem !important;
    margin-top: 6rem;
}

h3 {
    font-size: 2rem !important;
    margin-top: 3rem;
}

table {
    border-collapse: collapse;
}

table > tbody > tr.row-even {
    background: #eeeeee;
}

td, th {
    padding: 6px 15px !important;
}

.text-gray {
    color: #999;
}

td.checkmark, td.xmark {
    padding: 0 !important;
    text-align: center;
    vertical-align: middle;
    font-size: 150%;
}

td.checkmark::before {
    color: green;
    content: "\\2714";
    /* font-weight: bold; */
}

td.xmark::before {
    color: red;
    content: "\\2718";
    /* font-weight: bold; */
}
'''

In [35]:
# Closing markup, written after all report sections.
REPORT_POST = '''
    </body>
</html>
'''

# Opening markup up to and including the page heading. The template only
# refers to {title} and {css}, so those are the two values supplied.
REPORT_PRE = '''
<!DOCTYPE HTML>
<html lang="en">
    <head>
        <meta name="author" content="Jared Lumpe">
        <title>{title}</title>
        <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/skeleton/2.0.4/skeleton.min.css"/>
        <style>{css}</style>
    </head>
    <body>
        <h1>{title}</h1>
'''.format(title=report_attrs['title'], css=REPORT_CSS)

In [36]:
# Header written at the top of every taxonomy tree table.
# The first row groups the columns into a 2020 half and a 2016 half, five
# columns each; the second row labels the individual columns ("Genome Count"
# spans two of them). Body rows must therefore supply ten cells.
TABLE_HEAD = '''
<thead>
    <tr>
        <th colspan=5>2020 Taxon</th>
        <th colspan=5>2016 Taxon</th>
    </tr>
    <tr>
        <th>Name</th>
        <th>ID</th>
        <th>Rank</th>
        <th colspan="2">Genome Count</th>
        <th>Same?</th>
        <th>Name</th>
        <th>ID</th>
        <th>Rank</th>
        <th>2016 Name</th>
    </tr>
</thead>
'''

### Contents

In [37]:
def write_contents_table(f):
    """Write the "Contents" heading and an ordered list of links to ``f``.

    One entry is written per species in ``report_species``, linking to the
    anchor ID stored for it in ``report_species_ids``.
    """
    entry = '<li><a href="#%s">%s %s</a></li>\n'

    f.write('\n    <h2>Contents</h2>\n    \n    <ol>\n    ')
    f.writelines(
        entry % (report_species_ids[species], *species)
        for species in report_species
    )
    f.write('</ol>\n')

### Sections

In [38]:
def write_old_taxon_cells(f, ntid, otid, count):
    """Write the five "2016 Taxon" cells of one table row to ``f``.

    Args:
        f: Writable text file object.
        ntid: Taxonomy ID of the new (2020) taxon shown in this row group.
        otid: Taxonomy ID of the old (2016) taxon.
        count: Number of genomes with this old taxon. Currently not rendered;
            kept so the signature stays unchanged for callers.

    Raises:
        KeyError: If ``otid`` has no taxon record or no entry in
            ``orig_tax_summaries_by_canonical_taxid``.
    """
    # Local import keeps this function self-contained.
    from html import escape

    taxon = gettaxon(otid)

    # Old taxon same as new
    if otid == ntid:
        f.write('''
            <td class="checkmark"></td>
            <td></td>
            <td></td>
            <td></td>
        ''')

    # Old taxon different
    else:
        cells = dict(taxon)
        # Names are data, not markup: escape them before embedding in HTML.
        cells['scientific_name'] = escape(str(taxon['scientific_name']), quote=False)
        f.write('''
            <td class="xmark"></td>
            <td><a href="{url}" target="_blank">
                {scientific_name}
            </a></td>
            <td>{taxid}</td>
            <td>{rank}</td>
        '''.format(**cells, url=taxon_url(otid)))

    # Index directly: a missing summary should fail with a KeyError naming the
    # taxid rather than a TypeError from subscripting None.
    summary = orig_tax_summaries_by_canonical_taxid[otid]
    orig_name = summary['scientificname']

    if orig_name != taxon['scientific_name']:
        f.write('<td>%s</td>\n' % escape(str(orig_name), quote=False))
    else:
        f.write('<td class="text-gray">(same)</td>\n')

In [39]:
def write_tree_row_group(f, ntid, depth, count, rcount, old_counts, iseven):
    """Write the table row(s) for one new taxon and its old taxa to ``f``.

    The five "2020 Taxon" cells are written once, with a ``rowspan`` covering
    one row per entry of ``old_counts`` (at least one row).

    Args:
        f: Writable text file object.
        ntid: Taxonomy ID of the new taxon.
        depth: Depth in the tree; scaled by ``report_attrs['tree_indent_px']``
            to indent the name.
        count: Number of genomes assigned directly to this taxon. A falsy
            value is rendered as an empty cell.
        rcount: Count shown in parentheses next to ``count``.
        old_counts: Mapping of old taxid -> genome count for this taxon; may
            be empty, in which case blank cells are written instead.
        iseven: Selects the ``row-even`` or ``row-odd`` CSS class.
    """
    # Local import keeps this function self-contained.
    from html import escape

    data = dict(gettaxon(ntid))
    # Names are data, not markup: escape before embedding in HTML.
    data['scientific_name'] = escape(str(data['scientific_name']), quote=False)
    data['url'] = taxon_url(ntid)
    data['indent'] = depth * report_attrs['tree_indent_px']
    data['count'] = count or ''
    data['rcount'] = rcount
    # Tag name plus rowspan, substituted as the opening tag of each cell.
    data['td'] = 'td rowspan=%d' % max(len(old_counts), 1)
    tr = '<tr class="%s">\n' % ('row-even' if iseven else 'row-odd')

    f.write(tr)
    f.write('''
        <{td}><a href="{url}" target="_blank" style="margin-left: {indent}px; display: inline-block;">
            {scientific_name}
        </a></td>
        <{td}>{taxid}</td>
        <{td}>{rank}</td>
        <{td} style="text-align: center">{count}</td>
        <{td} style="text-align: center">({rcount})</td>
    '''.format(**data))

    # Old taxa
    for i, (otid, ocount) in enumerate(old_counts.items()):
        # The first old taxon shares the row opened above; the rest need
        # their own <tr>.
        if i > 0:
            f.write(tr)

        write_old_taxon_cells(f, ntid, otid, ocount)
        f.write('</tr>\n')

    # No old taxa, write blank cells
    if not old_counts:
        f.write('''
            <td></td>
            <td></td>
            <td></td>
            <td></td>
            <td></td>
        </tr>
        ''')

In [40]:
def write_tree_table(f, child_map, root_taxid, n2o_counts):
    """Write a complete taxonomy tree table to ``f``.

    The tree described by ``child_map`` is walked depth-first from
    ``root_taxid``, parent before children, with one row group per taxon and
    alternating ``row-odd`` / ``row-even`` styling starting with odd.
    ``n2o_counts`` maps a new taxid to its old taxid -> genome count mapping.
    """
    counts = {}
    for ntid, ocounts in n2o_counts.items():
        counts[ntid] = sum(ocounts.values())

    rcounts = sum_counts_recursive(child_map, root_taxid, counts)

    def _preorder(taxid, depth):
        # Yield (taxid, depth) for the node, then for each of its subtrees.
        yield taxid, depth
        for child_taxid in child_map.get(taxid, []):
            yield from _preorder(child_taxid, depth + 1)

    f.write('<table class="taxonomy-tree">\n')
    f.write(TABLE_HEAD)
    f.write('<tbody>\n')

    for position, (taxid, depth) in enumerate(_preorder(root_taxid, 0)):
        write_tree_row_group(
            f,
            taxid,
            depth,
            counts.get(taxid, 0),
            rcounts[taxid],
            n2o_counts.get(taxid, dict()),
            position % 2 == 1,
        )

    f.write('</tbody></table>')

In [41]:
def write_report_section(f, sp):
    """Write the report section for one ``(genus, species)`` tuple to ``f``.

    The section has a heading with the genome total and a taxonomy tree table.
    If any genomes were counted as skipped, a "Skipped genomes" sub-heading and
    a second table follow. Both trees are rooted at the lowest common ancestor
    of all new taxids involved.
    """
    def _genome_total(n2o):
        # Total number of genomes tallied in a new -> old count mapping.
        return sum(sum(ocounts.values()) for ocounts in n2o.values())

    assembly_list = [assembly_ids[key] for key in genomes_by_species[sp]]
    matched, skipped = count_taxon_reassignments(assembly_list)

    lca = find_lca(set(matched) | set(skipped))

    matched_children = make_tree_child_map(matched.keys(), lca)[0]
    matched_total = _genome_total(matched)

    f.write('<h2 id="%s" class="section-header">%s %s (%d genomes)</h2>\n\n' % (report_species_ids[sp], *sp, matched_total))
    write_tree_table(f, matched_children, lca, matched)
    f.write('\n\n')

    if not skipped:
        return

    skipped_children = make_tree_child_map(skipped.keys(), lca)[0]
    skipped_total = _genome_total(skipped)

    f.write('<h3>Skipped genomes (%d)</h3>\n\n' % skipped_total)
    write_tree_table(f, skipped_children, lca, skipped)
    f.write('\n\n')

### Generate report

In [42]:
# Write the HTML report: page header, table of contents, one section per
# unmatched species, then the closing markup.
with open(processed_out / (nbname + '-report.html'), 'w') as f:
    f.write(REPORT_PRE)
    
    write_contents_table(f)
    f.write('\n\n')
    
    # NOTE: the loop variable must stay named `sp`. count_taxon_reassignments
    # reads the module-level `sp` when deciding which genomes are skipped.
    for sp in report_species:
        write_report_section(f, sp)
            
    f.write(REPORT_POST)