# 201102 Download taxa

In [None]:
import json
from pathlib import Path
import xml.etree.ElementTree as ET
from datetime import datetime, timedelta
from functools import wraps
from time import monotonic
from time import sleep
from urllib.error import HTTPError

from tqdm import tqdm
from Bio import Entrez

In [None]:
# Contact email that Biopython attaches to Entrez (NCBI E-utilities) requests.
# NOTE(review): hard-coded personal address — consider reading it from config/env.
Entrez.email = 'mjlumpe@gmail.com'

## Paths

In [None]:
# Local scratch directory (relative to the working directory). It holds the
# genome_taxids.json input, the taxa/ XML cache and the genome_true_taxids.json output.
tmpdir = Path('tmp')

In [None]:
# Destination for this step's JSON outputs (taxa.json, aka_taxids.json); created if absent.
outdir = (
    Path('../../data/intermediate')
    / '201031-database-v1.1-software-version-migration'
    / '201102-download-taxa'
)
outdir.mkdir(parents=True, exist_ok=True)

## Function definitions

In [None]:
def throttle(min_interval, max_attempts=5):
    """Decorator to throttle and retry an API request function.

    Consecutive attempts are spaced at least ``min_interval`` seconds apart,
    measured from the end of the previous attempt. Attempts that raise
    ``HTTPError`` are retried, up to ``max_attempts`` attempts per call.
    Any other exception propagates immediately.

    Parameters
    ----------
    min_interval : float
        Minimum number of seconds between consecutive attempts.
    max_attempts : int
        Maximum number of attempts per call.

    Raises
    ------
    RuntimeError
        If every attempt raised ``HTTPError``. The last such error is chained
        as ``__cause__``.
    """

    def decorator(f):
        # Use a monotonic clock. Wall-clock time (datetime.now()) can jump
        # backwards (DST changes, NTP corrections). That would make the
        # computed interval negative and cause a very long sleep.
        last_request = monotonic() - min_interval

        @wraps(f)
        def rate_limited(*args, **kwargs):
            nonlocal last_request

            last_exc = None

            for _ in range(max_attempts):

                # Sleep until min interval reached if necessary
                elapsed = monotonic() - last_request
                if elapsed < min_interval:
                    sleep(min_interval - elapsed)

                # Attempt request, catching HTTPError's
                try:
                    return f(*args, **kwargs)

                except HTTPError as e:
                    last_exc = e

                finally:
                    # Update last request time whether we succeeded or not
                    last_request = monotonic()

            # Reached max # of attempts
            raise RuntimeError('Exceeded %d attempts' % max_attempts) from last_exc

        return rate_limited

    return decorator

In [None]:
def parse_date(datestr):
    """Parse a 'YYYY/MM/DD HH:MM:SS' timestamp (as found in taxon XML date fields) into a naive datetime."""
    fmt = '%Y/%m/%d %H:%M:%S'
    return datetime.strptime(datestr, fmt)

def taxon_xml_to_json(txml):
    """Flatten a parsed ``<Taxon>`` XML element into a plain dict.

    Integer IDs are converted to ``int`` and the three date fields to
    ``datetime``. ``aka_taxids`` lists the alternate IDs found under
    ``AkaTaxIds/TaxId`` (empty if there are none).
    """
    assert txml.tag == 'Taxon'

    def text(path):
        return txml.findtext(path)

    # Key order is preserved in the JSON output, so keep it stable.
    return {
        'taxid': int(text('TaxId')),
        'parent_taxid': int(text('ParentTaxId')),
        'scientific_name': text('ScientificName'),
        'rank': text('Rank'),
        'division': text('Division'),
        'create_date': parse_date(text('CreateDate')),
        'update_date': parse_date(text('UpdateDate')),
        'pub_date': parse_date(text('PubDate')),
        'aka_taxids': [int(alt.text) for alt in txml.iterfind('./AkaTaxIds/TaxId')],
    }

In [None]:
def efetch_taxa(taxids):
    """Fetch taxonomy records for a batch of NCBI taxon IDs in one efetch call.

    Parameters
    ----------
    taxids : iterable of int
        Taxon IDs to request. These may include alternate IDs, which the
        returned record lists under ``AkaTaxIds``.

    Returns
    -------
    dict
        Maps each requested taxid that was matched to a ``(txml, tdata)``
        tuple. ``txml`` is the ``Taxon`` XML element and ``tdata`` is its dict
        form from ``taxon_xml_to_json``. The key is the ID that was *requested*,
        which may be one of ``tdata['aka_taxids']`` rather than ``tdata['taxid']``.
        Requested IDs with no returned record are simply absent.

    Raises
    ------
    RuntimeError
        If the response root is not ``TaxaSet``, or a returned record matches
        none of the requested IDs.
    """
    taxids = list(taxids)
    requested = set(taxids)  # O(1) membership for the matching below

    resp = Entrez.efetch(db='taxonomy', id=taxids)
    try:
        root = ET.parse(resp).getroot()
    finally:
        # Release the HTTP connection even if parsing fails
        resp.close()

    if root.tag != 'TaxaSet':
        raise RuntimeError('Unexpected efetch response root element: %r' % root.tag)

    taxa = dict()

    for txml in root.findall('./Taxon'):
        tdata = taxon_xml_to_json(txml)

        # Key by the ID that was passed in: prefer the primary ID, else the
        # first alternate ID that was requested.
        candidates = [tdata['taxid'], *tdata['aka_taxids']]
        taxid = next((tid for tid in candidates if tid in requested), None)
        if taxid is None:
            raise RuntimeError('Could not determine requested taxid for taxon %d' % tdata['taxid'])

        taxa[taxid] = (txml, tdata)

    return taxa

efetch_taxa_throttled = throttle(1/3)(efetch_taxa)

## Download taxonomy data

### Setup

In [None]:
# Taxon IDs of the genomes whose lineages should be downloaded.
genome_taxids = json.loads((tmpdir / 'genome_taxids.json').read_text())

In [None]:
# On-disk cache of downloaded records: one '<taxid>.xml' file per taxon.
tax_dir = tmpdir.joinpath('taxa')
tax_dir.mkdir(exist_ok=True)

In [None]:
# Download state, updated by record_taxon():
taxon_data = {}                        # canonical taxid -> parsed taxon dict
aka_taxids = {}                        # alternate taxid -> canonical taxid
taxa_to_download = set(genome_taxids)  # IDs still to fetch; seeded with the genome taxa

In [11]:
# Parent IDs that end lineage climbing. When a taxon's parent_taxid is in
# this list, record_taxon() does not queue that parent for download.
ignore_parenttaxids = [0]  # Stop climbing lineage at these IDs

In [12]:
def record_taxon(taxid, txml, tdata=None):
    """Store one downloaded taxon and queue its parent for download.

    ``taxid`` is the ID that was requested. It must be either the record's own
    ID or one of its alternate IDs, and it is removed from
    ``taxa_to_download``. ``tdata`` is parsed from ``txml`` when not supplied.
    The record is stored in ``taxon_data`` under its canonical ID, and its
    alternate IDs are mapped in ``aka_taxids``.
    """
    if tdata is None:
        tdata = taxon_xml_to_json(txml)
    canonical = tdata['taxid']
    alt_ids = tdata['aka_taxids']

    # The requested ID must be this record's own ID or one of its aliases
    assert taxid == canonical or taxid in alt_ids

    taxon_data[canonical] = tdata
    taxa_to_download.remove(taxid)

    # Map every alias back to the canonical ID; an alias may be seen only once
    for alias in alt_ids:
        assert alias not in aka_taxids
        aka_taxids[alias] = canonical

    # Queue the parent unless it is a stop ID or already downloaded
    parent = tdata['parent_taxid']
    if not (parent in ignore_parenttaxids or parent in taxon_data):
        taxa_to_download.add(parent)

### Download

In [13]:
chunk_size = 100
max_fetch_attempts = 3  # Give up on an ID after this many fetches that do not return it
fetch_misses = dict()   # taxid -> number of fetches that did not return it

with tqdm(total=len(taxa_to_download) + len(taxon_data), initial=len(taxon_data)) as pbar:
    while taxa_to_download:

        # Find the next chunk of IDs to download. Taxa already cached in
        # tax_dir are loaded from disk instead. Iterate over a snapshot
        # because record_taxon() mutates taxa_to_download.
        next_chunk = []

        for taxid in list(taxa_to_download):
            dst = tax_dir / ('%d.xml' % taxid)

            if dst.is_file():
                # Parse the file directly (binary) so the XML parser handles decoding
                record_taxon(taxid, ET.parse(dst).getroot())

            else:
                next_chunk.append(taxid)
                if len(next_chunk) >= chunk_size:
                    break

        if next_chunk:

            # Fetch
            chunk_taxa = efetch_taxa_throttled(next_chunk)

            # Add results and cache each record on disk
            for taxid, (txml, tdata) in chunk_taxa.items():
                record_taxon(taxid, txml, tdata)

                dst = tax_dir / ('%d.xml' % taxid)
                assert not dst.is_file()
                dst.write_bytes(ET.tostring(txml))

            # A requested ID with no returned record stays in taxa_to_download
            # and would be re-requested forever. Allow a few retries, then fail
            # loudly. Records fetched above are already cached, so a rerun resumes.
            for taxid in set(next_chunk).difference(chunk_taxa):
                fetch_misses[taxid] = fetch_misses.get(taxid, 0) + 1
                if fetch_misses[taxid] >= max_fetch_attempts:
                    raise RuntimeError(
                        'No taxonomy record returned for taxid %d after %d attempts'
                        % (taxid, max_fetch_attempts)
                    )

        # Update progress bar
        total = len(taxon_data) + len(taxa_to_download)
        if pbar.total != total:
            pbar.total = total
        pbar.n = len(taxon_data)
        pbar.refresh()

100%|██████████| 22474/22474 [00:03<00:00, 7214.92it/s]


### Consistency checking

In [14]:
# Check we have one entry per set of aliases
# Each alias group must be stored exactly once: an alternate ID must never be
# stored as a taxon itself, and must resolve to a taxon that is stored.
for alias, canonical in aka_taxids.items():
    assert alias not in taxon_data
    assert canonical in taxon_data

## Output to JSON format

In [15]:
class TaxonEncoder(json.JSONEncoder):
    """JSON encoder that also writes ``datetime`` values as ISO 8601 strings."""

    def default(self, obj):
        # Anything other than a datetime gets the base class's behavior (TypeError)
        if not isinstance(obj, datetime):
            return super().default(obj)
        return obj.isoformat()

In [16]:
# Write all taxon records as one JSON list; datetimes become ISO strings via TaxonEncoder.
taxa_records = list(taxon_data.values())
(outdir / 'taxa.json').write_text(json.dumps(taxa_records, cls=TaxonEncoder))

In [17]:
# NOTE: JSON object keys are always strings, so the integer alternate IDs are
# written as strings (e.g. "12345"). Readers must convert them back to int.
with (outdir / 'aka_taxids.json').open('w') as out:
    json.dump(aka_taxids, out)

In [18]:
# Resolve each genome taxid to its canonical ID (aliases map via aka_taxids)
# and deduplicate. Sort so the output file is deterministic and diffable,
# rather than following arbitrary set iteration order.
genome_true_taxids = sorted({aka_taxids.get(tid, tid) for tid in genome_taxids})

with open(tmpdir / 'genome_true_taxids.json', 'w') as f:
    json.dump(genome_true_taxids, f)