In [1]:
import shelve
from typing import Union
import marisa_trie


class WrappedTrie:
    def __init__(self, trie: marisa_trie.RecordTrie, stringify: bool):
        self.trie = trie
        self.stringify = stringify

    def _extract(self, raw):
        if self.stringify:
            head = [elem.decode() if isinstance(elem, bytes) else str(elem) for elem in raw[0]]
        else:
            head = list(raw[0])
        if len(head) == 1:
            return head[0]
        return head

    def __getitem__(self, key: str):
        return self._extract(self.trie[str(key)])

    def __contains__(self, key):
        return key in self.trie

    def __enter__(self):
        return self

    def __exit__(self, type, value, tb):
        pass

    def get(self, name, default=None):
        r = self.trie.get(str(name))
        if r:
            return self._extract(r)
        return default


def open_file_db_by_extension(db_path: str, fmt: Union[str, None] = None, stringify=True):
    if db_path.endswith(".db"):
        return shelve.open(db_path.replace('.db', ''), 'r')
    assert fmt, "fmt is required for loading marisa trie key value stores"
    return WrappedTrie(marisa_trie.RecordTrie(fmt).mmap(db_path), stringify)

In [7]:
accession_to_taxid_path = "./accession2taxid.marisa"
with open_file_db_by_extension(accession_to_taxid_path, "L") as accession_to_taxid:
    # don't include version when querying (ex query for AP009048 instead of AP009048.1)
    taxid = accession_to_taxid.get("AP009048", "NA")
    print(taxid)


316407


In [12]:
lineage_path = "./taxid-lineages.marisa"
with open_file_db_by_extension(accession_to_taxid_path, "lll") as lineages:
    # use taxid (go from accession to taxid using above mapping)
    lineage = lineages.get("37124", "NA")
    print(taxid)

316407


In [17]:
loc_path = "./nt_loc.marisa"
with open_file_db_by_extension(loc_path, "QII") as loc_mapping:
    # use accession ID with version
    loc = loc_mapping.get("AP009048.1", "NA")
    print(loc)

['633633307925', '74', '4646333']


In [18]:
info_path = "./nt_info.marisa"
with open_file_db_by_extension(info_path, "256pI") as info_mapping:
    # use accession ID with version
    info = info_mapping.get("AP009048.1", "NA")
    print(info)

['Escherichia coli str. K-12 substr. W3110 DNA, complete genome', '4646332']
