# Inspect taxonomy of MarFERReT and MARMICRODB



## Setup

In [1]:
import os 
import gc
import re
import csv
import glob
import math
import umap
import json
import itertools
import numpy as np
import pandas as pd
import seaborn as sns
from time import time
from tqdm import tqdm
from scipy import stats
from collections import * 
from sklearn import cluster
from sklearn import decomposition
from ete4 import NCBITaxa, Tree
import matplotlib.pyplot as plt
import matplotlib.colors as pltc
from scipy.spatial import distance
from scipy.cluster import hierarchy
from matplotlib.lines import Line2D
import matplotlib.patches as mpatches


In [2]:
import sys
# Make the sibling repo's `functions` package importable (imported as fnf next).
# NOTE: this is a relative path, so it is resolved against the working directory
# at import time; that directory changes at the os.chdir(workdir) call below.
sys.path.append('../repo-armbrust-metat-search')

In [3]:
import functions.fn_metat_files as fnf

In [4]:
# ete4 NCBI taxonomy handle; used for every taxid/name/lineage lookup below.
ncbi = NCBITaxa()

In [5]:
# IPython magics: automatically reload edited modules (e.g. functions.fn_metat_files).
%load_ext autoreload
%autoreload 2

In [6]:
os.getcwd()

In [7]:
# Relative paths used later (e.g. "ko00001.json") are resolved against this directory.
workdir = '/scratch/bgrodner/iron_ko_contigs'
os.chdir(workdir)


In [8]:
os.getcwd()

In [9]:
os.listdir()

Plotting

In [10]:
def general_plot(
    xlabel="", ylabel="", ft=12, dims=(5, 3), col="k", lw=1, pad=0, tr_spines=True
):
    """Create a single-axes figure in the notebook's house style.

    Args:
        xlabel: X-axis label text.
        ylabel: Y-axis label text.
        ft: Font size for tick labels and axis labels.
        dims: (width, height) passed to plt.subplots as figsize.
        col: Color applied to spines, ticks, tick labels and axis labels.
        lw: Line width for every spine.
        pad: Padding passed to tight_layout.
        tr_spines: When False, hide the top and right spines; otherwise recolor them.

    Returns:
        (fig, ax) from plt.subplots, with a transparent axes background.
    """
    width, height = dims[0], dims[1]
    fig, ax = plt.subplots(figsize=(width, height), tight_layout={"pad": pad})

    for spine in ax.spines.values():
        spine.set_linewidth(lw)

    # Top/right spines are either hidden or recolored; bottom/left are always recolored.
    for side in ("top", "right"):
        if tr_spines:
            ax.spines[side].set_color(col)
        else:
            ax.spines[side].set_visible(False)
    for side in ("bottom", "left"):
        ax.spines[side].set_color(col)

    ax.tick_params(direction="in", labelsize=ft, color=col, labelcolor=col)
    for set_label, text in ((ax.set_xlabel, xlabel), (ax.set_ylabel, ylabel)):
        set_label(text, fontsize=ft, color=col)
    ax.patch.set_alpha(0)
    return fig, ax

def plot_umap(
    embedding,
    figsize=(10, 10),
    markersize=10,
    alpha=0.5,
    colors="k",
    xticks=None,
    yticks=None,
    markerstyle='o',
    cmap_name='tab20',
    cl_lab=False
):
    """Scatter a 2-D embedding (e.g. UMAP output) on a general_plot axes.

    Args:
        embedding: Array indexed as embedding[:, 0] (x) and embedding[:, 1] (y).
        figsize: Passed to general_plot as ``dims``.
        markersize: Scatter marker size.
        alpha: Marker opacity.
        colors: A single color string, or per-point colors/values.
        xticks: Tick positions for the x axis; None or empty keeps the defaults.
        yticks: Tick positions for the y axis; None or empty keeps the defaults.
        markerstyle: One marker string for all points, or a per-point sequence of
            markers (each point is then drawn with its own scatter call).
        cmap_name: Colormap for numeric ``colors``; used only with a single marker.
        cl_lab: Unused; kept for backward compatibility.

    Returns:
        (fig, ax).
    """
    fig, ax = general_plot(dims=figsize)
    if isinstance(markerstyle, str):
        ax.scatter(
            embedding[:, 0],
            embedding[:, 1],
            s=markersize,
            alpha=alpha,
            c=colors,
            edgecolors="none",
            marker=markerstyle,
            cmap=cmap_name
        )
    else:
        # A single color string would otherwise be iterated character by character
        # by zip(), silently truncating the plot to len(colors) points.
        if isinstance(colors, str):
            colors = [colors] * len(embedding)
        # NOTE(review): in per-point mode each point is its own scatter call, so
        # numeric colors are not mapped through cmap_name consistently.
        for e0, e1, c, m in zip(
            embedding[:, 0],
            embedding[:, 1],
            colors,
            markerstyle
        ):
            ax.scatter(
                e0,
                e1,
                s=markersize,
                alpha=alpha,
                c=c,
                edgecolors="none",
                marker=m
            )
    ax.set_aspect("equal")
    if xticks is not None and len(xticks) > 0:
        ax.set_xticks(xticks)
    if yticks is not None and len(yticks) > 0:
        ax.set_yticks(yticks)
    ax.set_xlabel("UMAP 1")
    ax.set_ylabel("UMAP 2")
    return fig, ax


#### Get KO dict

Get dataframe

In [11]:
# Flatten the (presumably KEGG BRITE) ko00001 hierarchy JSON into one row per
# leaf entry: Level_A > Level_B > Level_C > Level_D.
ko_fn = "ko00001.json"
database = list()
# Assumes the file is {"name": ..., "children": [...]}, so each row's "children"
# cell holds one top-level (Level_A) node. TODO confirm against the file.
for _, v in pd.read_json(ko_fn).iterrows():
    d = v["children"]
    cat_1 = d["name"]
    for child_1 in d["children"]:
        cat_2 = child_1["name"] # Level_B. NOTE(review): originally annotated "Module?"; confirm what this level is.
        for child_2 in child_1["children"]:
            cat_3 = child_2["name"]
            # Only level-C nodes are checked for "children"; a level-C node
            # without children contributes no rows.
            if "children" in child_2:
                for child_3 in child_2["children"]:
                    cat_4 = child_3["name"]
                    fields = [cat_1, cat_2, cat_3, cat_4]
                    database.append(fields)
df_kegg = pd.DataFrame(database, columns=["Level_A", "Level_B", "Level_C", "Level_D"])
df_kegg.shape


In [12]:
ld = df_kegg['Level_D'].values
ld[:5]

In [13]:
# Map each KO identifier (the leading \w+ run of the Level_D string) to its full
# Level_D string. If a KO appears several times, the last occurrence wins.
dict_ko_name = {}
for name in ld:
    ko = re.search(r"^\w+",name)[0]
    dict_ko_name[ko] = name

## Inspect marmicrodb table

Load file   

In [None]:
fn_marmicro = '/scratch/bgrodner/iron_ko_contigs/MARMICRODB_catalog.tsv'

# Earlier attempt, kept for reference. The "Skipping line ..." warnings pasted in
# the next cell presumably came from this call. It uses read_csv's default ','
# separator on a tab-separated file, which would explain "expected 1 fields".
# marmicro = pd.read_csv(fn_marmicro, on_bad_lines='warn')
# marmicro.shape

Check out the lines pandas skipped

In [25]:
# Warnings pandas printed while reading the catalog with on_bad_lines='warn'.
text_skipped = '''
    Skipping line 15696: expected 1 fields, saw 3
    Skipping line 15712: expected 1 fields, saw 5
    Skipping line 15713: expected 1 fields, saw 5
    Skipping line 15714: expected 1 fields, saw 5
    Skipping line 15715: expected 1 fields, saw 5
    Skipping line 15716: expected 1 fields, saw 5
    Skipping line 15717: expected 1 fields, saw 5
    Skipping line 15718: expected 1 fields, saw 5
    Skipping line 18585: expected 1 fields, saw 2
    Skipping line 18587: expected 1 fields, saw 2
    Skipping line 18605: expected 1 fields, saw 2
    Skipping line 18642: expected 1 fields, saw 2
    Skipping line 18646: expected 1 fields, saw 2
    Skipping line 18657: expected 1 fields, saw 2
    Skipping line 18658: expected 1 fields, saw 2
    Skipping line 18660: expected 1 fields, saw 2
    Skipping line 18661: expected 1 fields, saw 2
    Skipping line 18664: expected 1 fields, saw 2
    Skipping line 18677: expected 1 fields, saw 2
    Skipping line 18701: expected 1 fields, saw 2
    Skipping line 18712: expected 1 fields, saw 2
    Skipping line 18713: expected 1 fields, saw 2
    Skipping line 18732: expected 1 fields, saw 2
    Skipping line 18741: expected 1 fields, saw 2
    Skipping line 18742: expected 1 fields, saw 2
    Skipping line 18743: expected 1 fields, saw 2
    Skipping line 18744: expected 1 fields, saw 2
    Skipping line 18760: expected 1 fields, saw 2
'''
set_skipped = {int(l) for l in re.findall(r'(?<=line\s)\d+', text_skipped)}
# NOTE(review): added by hand. Under the old 0-based counter it selected file
# line 15696, which is already in the set, so it is probably redundant now.
set_skipped.add(15694)

# pandas reports 1-based file line numbers that include the header, so the
# first data row is line 2. Enumerate from 2 after skipping the header so `i`
# matches the numbers above. The old counter started at 0 and was off by two.
with open(fn_marmicro, 'r') as f:
    _ = next(f)
    for i, row in enumerate(f, start=2):
        if i in set_skipped:
            print(i, len(row.split('\t')))


Load with csv reader

In [176]:
# Count tab-separated fields in each data row to detect ragged rows.
# csv.DictReader cannot show this: it pads short rows with None and gathers
# extra fields under a single None key, so len(row) is always len(header) or
# len(header) + 1. Use csv.reader, which returns the actual field list.
lens = []
with open(fn_marmicro, 'r', newline='') as f:
    reader = csv.reader(f, delimiter='\t')
    _ = next(reader)  # header row
    for row in reader:
        lens.append(len(row))

set(lens)


In [30]:
# Build a column -> list-of-strings dict by hand. All values stay strings, so
# missing taxids remain the literal 'NA' that later cells compare against.
dict_marmicro_col_row = {}
with open(fn_marmicro, 'r') as f:
    reader = csv.reader(f, delimiter='\t')
    header = next(reader)
    for h in header:
        dict_marmicro_col_row[h] = []
    for row in reader:
        # NOTE(review): zip() silently drops fields beyond len(header). A short
        # row leaves some columns shorter, and pd.DataFrame would then raise.
        for c, v in zip(header, row):
            dict_marmicro_col_row[c].append(v)

marmicro = pd.DataFrame(dict_marmicro_col_row)
marmicro.shape, marmicro[:10]

Get a list of taxids not in NCBI

In [44]:
# Probe: check that an unknown taxid gives a falsy result, which the
# `if not d` checks below rely on.
not ncbi.get_taxid_translator([12345647278763])

In [170]:
# Row indices whose NCBI `taxid` field does not parse as an integer (e.g. 'NA').
tax_missing = []
taxids = marmicro['taxid'].values
for idx, t in enumerate(taxids):
    try:
        int(t)
    except (ValueError, TypeError):
        # Narrowed from a bare `except:`, which also swallowed KeyboardInterrupt.
        tax_missing.append(idx)

marmicro.iloc[tax_missing, :]

In [47]:
# Row indices whose NCBI `taxid` is non-integer or unknown to the NCBI database.
tax_missing = []
taxids = marmicro['taxid'].values
for idx, t in enumerate(taxids):
    try:
        d = ncbi.get_taxid_translator([t])
    except (ValueError, TypeError):
        # Non-integer taxid such as 'NA'. This replaces a bare `except:`
        # that would also hide unrelated errors.
        tax_missing.append(idx)
        continue
    if not d:
        tax_missing.append(idx)

marmicro.iloc[tax_missing, :]


In [169]:
# Row indices whose MARMICRODBtaxid does not parse as an integer.
tax_missing = []
taxids = marmicro['MARMICRODBtaxid'].values
for idx, t in enumerate(taxids):
    try:
        int(t)
    except (ValueError, TypeError):
        # Narrowed from a bare `except:`, which also swallowed KeyboardInterrupt.
        tax_missing.append(idx)

marmicro.iloc[tax_missing, :]

Look up NCBI taxids for rows whose taxid is NA

In [128]:
# Hand-curated rewrites of MARMICRODB full names (for rows with taxid == 'NA')
# so they can be looked up with ncbi.get_name_translator in the cells below.
names_na = marmicro.loc[marmicro['taxid'] == 'NA', 'full_name'].values
names_na_fix = []
for n in names_na:
    # Insert a space after every 'MIT'.
    if "MIT" in n:
        n = re.sub('MIT', 'MIT ', n)
    elif '150SLHB' in n:
        n = 'Prochlorococcus sp. P1344'
    elif '150SLHA' in n:
        # NOTE(review): no space after 'sp.', unlike the P1344/P1361 names. Confirm
        # this is the exact NCBI name, otherwise the lookup finds nothing.
        n = 'Prochlorococcus sp.P1363'    
    elif '150NLHA' in n:
        n = 'Prochlorococcus sp. P1361'
    # Overrides any rewrite above.
    if '1418' in n:
        n = 'Prochlorococcus sp.'
    matches = [s in n for s in ['1013','1214','0918','0919',]]
    if any(matches):
        # NOTE: this is a regex, so the unescaped '.' matches any character.
        n = re.sub('coccus sp.', 'coccus marinus str.',n)

    names_na_fix.append(n)
names_na_fix

In [129]:
# MARMICRODB's own taxids for the same 'NA' rows, in the same order as names_na.
mmdbtaxid_na = marmicro.loc[marmicro['taxid'] == 'NA', 'MARMICRODBtaxid'].values
mmdbtaxid_na

In [140]:
# Translation table (MMDB taxid/name -> NCBI taxid/name) for the 'NA' rows.
# Names that NCBI cannot resolve get '' in both NCBI columns.
trans = {
    'MMDB taxid':[],
    'NCBI taxid':[],
    'MMDB name':[],
    'NCBI name':[],
}
for t, n, nf in zip(mmdbtaxid_na, names_na, names_na_fix):
    d = ncbi.get_name_translator([nf])
    if d:
        # Take the first taxid when a name maps to several.
        tf = d[nf][0]
    else:
        tf = ''
        nf = ''
    trans['MMDB taxid'].append(t)
    trans['NCBI taxid'].append(tf)
    trans['MMDB name'].append(n)
    trans['NCBI name'].append(nf)
trans_df = pd.DataFrame(trans)
# NOTE(review): to_csv without index=False also writes the row index as an unnamed first column.
trans_df.to_csv('/scratch/bgrodner/Resources/MARMICRODB_prochlorococcus_taxids.csv')
trans_df

In [57]:
# Rows where MARMICRODB's taxid differs from the NCBI `taxid` column. This is a
# string comparison, so the 'NA' rows are included.
marmicro_mismatch = marmicro[marmicro['MARMICRODBtaxid'] != marmicro['taxid']]
marmicro_mismatch

In [58]:
marmicro_mismatch[marmicro_mismatch['sequence_type'] == 'isolate']

In [59]:
# Mismatched rows that are Prochlorococcus isolates.
bool_isolate = marmicro_mismatch['sequence_type'] == 'isolate'
bool_pro = marmicro_mismatch['taxgroup'] == 'prochlorococcus'
marmicro_mismatch[bool_isolate & bool_pro]

In [174]:
# Spot-check a single genome.
bool_test = marmicro['genome'] == 'MMP03755233'
marmicro[bool_test]

In [182]:
# Scratch arithmetic (the origin of these two numbers is not recorded here).
28164865-27841030

In [179]:
# Spot-check a single MARMICRODB taxid.
bool_test = marmicro['MARMICRODBtaxid'] == '1577725'
marmicro[bool_test]

Compare the NCBI names of the MARMICRODB taxids with the NCBI names of the listed NCBI taxids

In [165]:
def _ncbi_sci_name(taxid, missing='NA'):
    """Return the NCBI scientific name for a taxid string, or ``missing``.

    ``missing`` is returned for the literal 'NA', for values that do not parse
    as integers, and for integer taxids absent from the NCBI database.
    """
    if taxid == 'NA':
        return missing
    try:
        taxid_int = int(taxid)
    except (ValueError, TypeError):
        return missing
    # .get: indexing with [taxid_int] raised KeyError for taxids unknown to NCBI.
    return ncbi.get_taxid_translator([taxid_int]).get(taxid_int, missing)


# Print the NCBI name of each mismatched row's `taxid` next to the NCBI name of
# its MARMICRODBtaxid. Both sides now use the same 'NA' placeholder (the
# MMDB side used 0 before).
for i, row in marmicro_mismatch.iterrows():
    t = row.taxid
    n = _ncbi_sci_name(t)
    mt = row.MARMICRODBtaxid
    mn = _ncbi_sci_name(mt)
    print(n, '\t', mn)

Compare written lineage to ncbi lineage

In [69]:
# Show MARMICRODB's lineage string for one row next to an NCBI lineage (taxid -> name dict).
# NOTE(review): the NCBI lineage is for the hard-coded taxid 2162565, not 2182669.
# Presumably that is this row's NCBI `taxid`; confirm, or read it from the row.
bool_test = marmicro['MARMICRODBtaxid'] == '2182669'
marmicro.loc[bool_test, 'lineage_assignment'].values[0], ncbi.get_taxid_translator(ncbi.get_lineage(2162565))

Check lineage assignment

In [87]:
# Collect rows whose NCBI scientific name contains none of the ';'-separated
# MARMICRODB lineage entries (substring match). Rows with taxid 'NA' are skipped.
tax_mismatch = []
taxids = marmicro['taxid'].values
linassgns = marmicro['lineage_assignment'].values
for t, l in zip(taxids, linassgns):
    if t == 'NA':
        continue
    # .get: a taxid missing from the NCBI database raised KeyError before.
    # An empty name matches nothing, so such rows are reported as mismatches.
    t_ncbi = ncbi.get_taxid_translator([t]).get(int(t), '')
    # Drop empty/blank tokens (e.g. from a trailing ';' or an empty lineage).
    # '' is a substring of every string, so it always counted as a match.
    lin_marmicro = [tm.strip() for tm in l.split(';') if tm.strip()]
    match = any(tm in t_ncbi for tm in lin_marmicro)
    if not match:
        tax_mismatch.append([t, t_ncbi, l])

len(tax_mismatch)

In [86]:
# Quick check of the substring test used by the lineage comparison.
'Bathyarchaeota' in 'Candidatus Bathyarchaeota archaeon UBA185'

In [88]:
tax_mismatch

In [77]:
# `lin_ncbi` is not assigned anywhere in this notebook, so this cell raised a
# NameError. Build it here as the NCBI lineage names (in get_lineage order) for
# whatever taxid `t` is currently bound to. `t` and `lin_marmicro` are left over
# from the lineage-mismatch loop. This uses the same
# get_taxid_translator(get_lineage(...)) pattern as the earlier lineage cell.
lin_ncbi_taxids = ncbi.get_lineage(int(t))
lin_ncbi_names = ncbi.get_taxid_translator(lin_ncbi_taxids)
lin_ncbi = [lin_ncbi_names.get(x, str(x)) for x in lin_ncbi_taxids]
lin_ncbi, lin_marmicro[-1], ncbi.get_taxid_translator([t])[int(t)]

List of taxids that failed to fetch

In [158]:
f2f = '''2182663
2182826
2182663
2182663
2182826
2182826
2182863
2182863
2182663
2183026
2183027
2183026
2183026
2183026
2182863
2183014'''
f2f = set(f2f.split('\n'))

for t in f2f:
    print(t)
    bool_test = marmicro['MARMICRODBtaxid'] == t
    print(marmicro[bool_test].values)
    bool_test = marmicro['taxid'] == t
    print('\n',marmicro[bool_test].values)
    print('\n')

How many mmdbtaxids don't have ncbi taxids?

In [159]:
# List the catalog's column names.
marmicro.columns

In [160]:
# Count rows whose MARMICRODBtaxid is non-integer or unknown to the NCBI database.
idx_noncbi = []
taxids = marmicro['MARMICRODBtaxid'].values
for idx, t in enumerate(taxids):
    try:
        d = ncbi.get_taxid_translator([t])
    except (ValueError, TypeError):
        # Non-integer taxid. This replaces a bare `except:` that would also
        # hide unrelated errors.
        idx_noncbi.append(idx)
        continue
    if not d:
        idx_noncbi.append(idx)

marmicro.iloc[idx_noncbi, :].shape

## Compare old diamond run to fixed run

Test files

In [183]:
# Judging by the paths: one sample's contig -> taxid assignments and its per-contig read counts.
fn_tax = '/mnt/nfs/projects/armbrust-metat/gradients2/g2_station_ns_metat/assemblies/MarMicro_MarFerr_Diamond_2024_04_14/G2NS.S02C1.15m.0_2um.MarFer_MMDB.tab'
fn_ec = '/mnt/nfs/projects/armbrust-metat/gradients2/g2_station_ns_metat/assemblies/ReadCounts/G2NS.S02C1.15m.0_2um.A/G2NS.S02C1.15m.0_2um.A.tsv'

Get taxa dict

In [184]:
# Group contig ids by assigned taxid (taxid kept as a string key).
dict_tax_contigs = defaultdict(list)
with open(fn_tax, 'r') as f:
    for row in f:
        # Expects exactly three tab-separated fields. The trailing newline stays
        # on the discarded third field.
        contig, taxid, _ = row.split('\t')
        dict_tax_contigs[taxid].append(contig)

In [187]:
# Number of contigs assigned to two example taxids.
len(dict_tax_contigs['131567']), len(dict_tax_contigs['35679'])

Get estcounts dict

In [188]:
# Map contig id -> estimated count (4th column, kept as a string).
# Assumes five tab-separated columns per row; TODO confirm against the file's header.
dict_contig_ec = {}
with open(fn_ec, 'r') as f:
    _ = next(f)  # skip header
    for row in f:
        contig, _, _, ec, _ = row.split('\t')
        dict_contig_ec[contig] = ec

Get estcounts for each taxon

In [202]:
# Sum estimated counts per taxid. Contigs absent from the count table contribute nothing.
dict_tax_ec = {}
for tax, contigs in dict_tax_contigs.items():
    ec_sum = 0
    for c in contigs:
        ec = dict_contig_ec.get(c)
        if ec:
            ec_sum += float(ec)
    dict_tax_ec[tax] = ec_sum



Build tree

In [206]:
# NCBI topology over every assigned taxid > 0. Each node gets its own summed
# estcounts as the 'estcounts' prop, or None when its taxid has no entry in dict_tax_ec.
taxids = [t for t in dict_tax_ec.keys() if int(t) > 0]
tree = ncbi.get_topology(taxids)
for n in tree.traverse():
    ec = dict_tax_ec.get(n.name)
    n.add_props(estcounts=ec)

print(tree.to_str(props=['sci_name','estcounts'], compact=True))

Get fraction of reads at each node

In [223]:
# Cumulative estcounts per node: the node's own estcounts plus those of every
# node below it. Also stores that total as a percentage of all estcounts in
# dict_tax_ec, which includes taxids left out of the tree (e.g. those <= 0).
# The prop keeps its historical name `ec_descendants`, but it now includes the
# node itself. n.descendants() excluded it, so leaves always showed 0 and a
# node's directly assigned reads were lost. One postorder pass replaces the
# per-node rescan of all descendants.
total_ec = sum(dict_tax_ec.values())
subtree_ec = {}
for n in tree.traverse('postorder'):
    ec_sum = n.props.get('estcounts') or 0
    for child in n.children:
        ec_sum += subtree_ec[id(child)]
    subtree_ec[id(n)] = ec_sum
    pct = ec_sum / total_ec * 100 if total_ec else 0.0
    n.add_props(ec_descendants=round(ec_sum), pct_ec=str(round(pct, 4)) + '%')

print(tree.to_str(props=['sci_name','ec_descendants', 'pct_ec'], compact=True))

In [238]:
# Keep every node whose ec_descendants count is > 0. Rebuild the NCBI topology
# over just those taxids and copy the count/percent props onto the new tree.
dict_tax_ecstuff = {}
taxids_trim = []
for n in tree.traverse():
    # NOTE(review): despite its name, this holds the ec_descendants count, not a percentage.
    pct_ec = float(n.props['ec_descendants'])
    if pct_ec > 0:
        taxids_trim.append(n.name)
        dict_tax_ecstuff[n.name] = [n.props['ec_descendants'], n.props['pct_ec']]

tree_trim = ncbi.get_topology(taxids_trim)
for n in tree_trim.traverse():
    # Assumes every node in tree_trim is one of taxids_trim. A node that
    # get_topology added would raise KeyError here. TODO confirm.
    ed, pe = dict_tax_ecstuff[n.name]
    n.add_props(ec_descendants=ed, pct_ec=pe)

print(tree_trim.to_str(props=['sci_name','ec_descendants', 'pct_ec'], compact=True))

In [239]:
# Same trim as the previous cell but with a stricter cutoff: keep nodes whose
# pct_ec (parsed back from its 'x.xxxx%' string) exceeds 0.1%.
dict_tax_ecstuff = {}
taxids_trim = []
for n in tree.traverse():
    pct_ec = float(n.props['pct_ec'].strip('%'))
    if pct_ec > 0.1:
        taxids_trim.append(n.name)
        dict_tax_ecstuff[n.name] = [n.props['ec_descendants'], n.props['pct_ec']]

tree_trim = ncbi.get_topology(taxids_trim)
for n in tree_trim.traverse():
    # Assumes every node in tree_trim is one of taxids_trim (KeyError otherwise). TODO confirm.
    ed, pe = dict_tax_ecstuff[n.name]
    n.add_props(ec_descendants=ed, pct_ec=pe)

print(tree_trim.to_str(props=['sci_name','ec_descendants', 'pct_ec'], compact=True))