# Compare variations in taxonomy-based prediction

## Taxonomic classifications

```sh
conda activate qiime2-2023.5

# Input splits and output root, relative to the directory this is run from.
cv_dir="../../../data/cross_validation"
suppl_dir="../../../data/cross_validation_suppl"

# Prepare directories: one per classifier and split.
# `mkdir -p` also creates the parent directories and is safe to re-run.
for method in q2_consensus_blast q2_consensus_vsearch; do
  for i in {0..9}; do
    mkdir -p "${suppl_dir}/${method}/split_${i}"
  done
done

# Cross-validation: BLAST-consensus-classifier
# Run once per split; without the loop `$i` would be unset or stale.
for i in {0..9}; do
  qiime feature-classifier classify-consensus-blast \
    --i-query "${cv_dir}/split_${i}/test_seq_full.qza" \
    --i-reference-reads "${cv_dir}/split_${i}/ref_seq_full.qza" \
    --i-reference-taxonomy "${cv_dir}/split_${i}/ref_taxonomy.qza" \
    --o-classification "${suppl_dir}/q2_consensus_blast/split_${i}/test_taxonomy.qza" \
    --o-search-results "${suppl_dir}/q2_consensus_blast/split_${i}/search_results.qza"
done

# Cross-validation: VSEARCH-consensus-classifier
# Uses classify-consensus-vsearch so the q2_consensus_vsearch outputs really
# come from VSEARCH rather than from a second BLAST run.
for i in {0..9}; do
  qiime feature-classifier classify-consensus-vsearch \
    --i-query "${cv_dir}/split_${i}/test_seq_full.qza" \
    --i-reference-reads "${cv_dir}/split_${i}/ref_seq_full.qza" \
    --i-reference-taxonomy "${cv_dir}/split_${i}/ref_taxonomy.qza" \
    --o-classification "${suppl_dir}/q2_consensus_vsearch/split_${i}/test_taxonomy.qza" \
    --o-search-results "${suppl_dir}/q2_consensus_vsearch/split_${i}/search_results.qza"
done

# Export: both artifacts of both classifiers, written next to the .qza files.
for method in q2_consensus_blast q2_consensus_vsearch; do
  for i in {0..9}; do
    for artifact in test_taxonomy search_results; do
      qiime tools export \
        --input-path "${suppl_dir}/${method}/split_${i}/${artifact}.qza" \
        --output-path "${suppl_dir}/${method}/split_${i}/"
    done
  done
done
```

In [None]:
import decimal
import json
import os
import math

import numpy as np
import seaborn as sns
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt

from scipy.stats import mannwhitneyu, wilcoxon, friedmanchisquare
from statsmodels.stats.multitest import multipletests

# Global matplotlib defaults, applied in one call instead of key by key.
matplotlib.rcParams.update({
    'font.family': 'Arial',
    'font.sans-serif': ["Arial", "DejaVu Sans", "Lucida Grande", "Verdana"],
    'figure.figsize': [4, 3],
    'font.size': 10,
    "axes.labelcolor": "#000000",
    "axes.linewidth": 1.0,
    "xtick.major.width": 1.0,
    "ytick.major.width": 1.0,
})

# Colormap handles; not referenced by the cells visible in this part of the
# notebook, kept because other cells may rely on them.
cmap1 = plt.cm.tab20
cmap2 = plt.cm.Set3

from matplotlib.ticker import MultipleLocator, FormatStrFormatter

In [None]:
def predict_from_emp(tax_dict, emp_trait, t, is_binary=False):
    """Look up a trait value for a lineage in the empirical trait table.

    The entries of ``tax_dict`` are visited from the last one back to the
    first, and the first entry whose value is present in
    ``emp_trait[t][rank]`` with something other than the string ``"NA"``
    provides the result.

    Parameters
    ----------
    tax_dict : dict
        Ordered mapping of rank name -> assigned taxon name.
    emp_trait : dict
        Nested mapping ``trait -> rank -> taxon -> value``.
    t : str
        Trait to look up.
    is_binary : bool, default False
        When true, the value is rounded to an integer ``decimal.Decimal``
        using ROUND_HALF_DOWN (so exactly 0.5 becomes 0).

    Returns
    -------
    The stored value (a ``decimal.Decimal`` when ``is_binary``), or ``None``
    when no entry of ``tax_dict`` yields a usable value.

    Raises
    ------
    KeyError
        If ``t`` is not in ``emp_trait`` or a visited rank is not in
        ``emp_trait[t]``.
    """
    trait_table = emp_trait[t]
    value = None
    # Last entry of tax_dict first.
    for rank, taxon in list(tax_dict.items())[::-1]:
        by_taxon = trait_table[rank]
        if taxon in by_taxon and by_taxon[taxon] != "NA":
            value = by_taxon[taxon]
            break
    if value is None:
        return None
    if not is_binary:
        return value
    return decimal.Decimal(str(value)).quantize(
        decimal.Decimal('1'), rounding=decimal.ROUND_HALF_DOWN
    )

In [None]:
# Taxonomic ranks in the order they appear in a lineage string, and the
# prefix stripped from the taxon name at the same position.
clades = [
    "superkingdom",
    "phylum",
    "class",
    "order",
    "family",
    "genus",
    "species",
]
prefix = [
    "k__",
    "p__",
    "c__",
    "o__",
    "f__",
    "g__",
    "s__",
]

# Traits predicted without rounding (predict_from_emp with is_binary=False).
nt = [
    'cell_diameter', 'cell_length', 'doubling_h',
    'growth_tmp', 'optimum_tmp', 'optimum_ph',
    'genome_size', 'gc_content', 'coding_genes',
    'rRNA16S_genes', 'tRNA_genes',
]

# Traits predicted with rounding to an integer (is_binary=True).
ct = [
    'gram_stain',
    'sporulation',
    'motility',
    'range_salinity',
    'facultative_respiration',
    'anaerobic_respiration',
    'aerobic_respiration',
    'mesophilic_range_tmp',
    'thermophilic_range_tmp',
    'psychrophilic_range_tmp',
    'bacillus_cell_shape',
    'coccus_cell_shape',
    'filament_cell_shape',
    'coccobacillus_cell_shape',
    'vibrio_cell_shape',
    'spiral_cell_shape',
]

## Prediction from taxonomic classification results

In [None]:
# q2-consensus-blast

for i in range(10):
    # Empirical trait distribution for this split. The context manager closes
    # the file even when JSON parsing raises.
    with open(f'../../../data/cross_validation/split_{i}/empirical_dict.json', 'r') as emp_file:
        emp_trait = json.load(emp_file)

    # Exported taxonomy of the BLAST consensus classifier. The variable keeps
    # the name `naive_bayes_result` although it holds BLAST consensus output.
    result_path = f"../../../data/cross_validation_suppl/q2_consensus_blast/split_{i}/taxonomy.tsv"
    naive_bayes_result = pd.read_csv(result_path, sep="\t")

    # One column per rank: split the "; "-separated lineage, name the columns
    # after the leading ranks, and blank out ranks that were not assigned.
    x = naive_bayes_result["Taxon"].str.split("; ", expand=True)
    assigned_clades = clades[:x.shape[1]]
    x = x.set_axis(assigned_clades, axis="columns")
    x = x.fillna("")
    for c, p in zip(assigned_clades, prefix):
        # regex=False: the prefix is a literal string, independent of the
        # pandas version's default for `regex`.
        x[c] = x[c].str.replace(p, "", regex=False)
    naive_bayes_result = pd.concat([naive_bayes_result, x], axis=1)

    # Per-row {rank: taxon} dict in rank order, as consumed by predict_from_emp.
    naive_bayes_result["tax_dict"] = naive_bayes_result[assigned_clades].apply(
        lambda row: row.to_dict(), axis=1
    )

    # Numeric traits: value taken as stored in the empirical table.
    for t in nt:
        naive_bayes_result[t] = naive_bayes_result["tax_dict"].apply(
            lambda tax_dict, t=t: predict_from_emp(tax_dict, emp_trait, t)
        )

    # Categorical traits: value rounded to an integer by predict_from_emp.
    for t in ct:
        naive_bayes_result[t] = naive_bayes_result["tax_dict"].apply(
            lambda tax_dict, t=t: predict_from_emp(tax_dict, emp_trait, t, is_binary=True)
        )

    # Save
    estimation_result_path = f"../../../data/cross_validation_suppl/q2_consensus_blast/split_{i}/estimation_naive.tsv"
    naive_bayes_result.rename(columns={'Feature ID': 'sequence'}, inplace=True)
    naive_bayes_result[["sequence"]+nt+ct].to_csv(estimation_result_path, sep="\t", index=False)

In [None]:
# q2-consensus-vsearch

for i in range(10):
    # Empirical trait distribution for this split. The context manager closes
    # the file even when JSON parsing raises.
    with open(f'../../../data/cross_validation/split_{i}/empirical_dict.json', 'r') as emp_file:
        emp_trait = json.load(emp_file)

    # Exported taxonomy from the q2_consensus_vsearch directory. The variable
    # keeps the name `naive_bayes_result` although it does not hold
    # naive-Bayes output.
    result_path = f"../../../data/cross_validation_suppl/q2_consensus_vsearch/split_{i}/taxonomy.tsv"
    naive_bayes_result = pd.read_csv(result_path, sep="\t")

    # One column per rank: split the "; "-separated lineage, name the columns
    # after the leading ranks, and blank out ranks that were not assigned.
    x = naive_bayes_result["Taxon"].str.split("; ", expand=True)
    assigned_clades = clades[:x.shape[1]]
    x = x.set_axis(assigned_clades, axis="columns")
    x = x.fillna("")
    for c, p in zip(assigned_clades, prefix):
        # regex=False: the prefix is a literal string, independent of the
        # pandas version's default for `regex`.
        x[c] = x[c].str.replace(p, "", regex=False)

    naive_bayes_result = pd.concat([naive_bayes_result, x], axis=1)

    # Per-row {rank: taxon} dict in rank order, as consumed by predict_from_emp.
    naive_bayes_result["tax_dict"] = naive_bayes_result[assigned_clades].apply(
        lambda row: row.to_dict(), axis=1
    )

    # Numeric traits: value taken as stored in the empirical table.
    for t in nt:
        naive_bayes_result[t] = naive_bayes_result["tax_dict"].apply(
            lambda tax_dict, t=t: predict_from_emp(tax_dict, emp_trait, t)
        )

    # Categorical traits: value rounded to an integer by predict_from_emp.
    for t in ct:
        naive_bayes_result[t] = naive_bayes_result["tax_dict"].apply(
            lambda tax_dict, t=t: predict_from_emp(tax_dict, emp_trait, t, is_binary=True)
        )

    # Save
    estimation_result_path = f"../../../data/cross_validation_suppl/q2_consensus_vsearch/split_{i}/estimation_naive.tsv"
    naive_bayes_result.rename(columns={'Feature ID': 'sequence'}, inplace=True)
    naive_bayes_result[["sequence"]+nt+ct].to_csv(estimation_result_path, sep="\t", index=False)

## Visualization

In [None]:
def get_pred_and_true_df(pred_vals_path, true_vals_path):
    """Load predictions and reference values and join them row by row.

    Both files are read as tab-separated tables. The prediction table's
    ``sequence`` column is cast to ``str`` and the reference table is read
    entirely as strings; the two are inner-joined on
    ``sequence == species_tax_id``. Columns present in both tables get the
    suffix ``_e`` (prediction) or ``_t`` (reference).

    Parameters
    ----------
    pred_vals_path : str
        Path to the prediction TSV (must contain a ``sequence`` column).
    true_vals_path : str
        Path to the reference TSV (must contain a ``species_tax_id`` column).

    Returns
    -------
    pandas.DataFrame
        The merged table; only rows whose key occurs in both inputs are kept.
    """
    predicted = pd.read_csv(pred_vals_path, sep="\t").astype({"sequence": str})
    reference = pd.read_csv(true_vals_path, sep="\t", dtype=str)
    return predicted.merge(
        reference,
        how='inner',
        left_on='sequence',
        right_on="species_tax_id",
        suffixes=['_e', '_t'],
    )

def remove_null_values(cmp, t, dtype):
    """Return the predicted and true columns of trait ``t`` without nulls.

    Rows are kept only when both ``t + '_e'`` (prediction) and ``t + '_t'``
    (reference) are non-null.

    Parameters
    ----------
    cmp : pandas.DataFrame
        Table holding the ``<t>_e`` and ``<t>_t`` columns.
    t : str
        Trait name, without suffix.
    dtype : str
        ``'float'`` or ``'int'`` casts both columns to that type; any other
        value leaves them as they are.

    Returns
    -------
    tuple of pandas.Series
        ``(pred_vals, true_vals)`` restricted to the same rows.
    """
    pred_col = t + '_e'
    true_col = t + '_t'
    both_known = cmp[true_col].notnull() & cmp[pred_col].notnull()
    known_rows = cmp[both_known]
    pred_vals = known_rows[pred_col]
    true_vals = known_rows[true_col]

    target = float if dtype == 'float' else int if dtype == 'int' else None
    if target is not None:
        pred_vals = pred_vals.astype(target)
        true_vals = true_vals.astype(target)
    return pred_vals, true_vals

In [None]:
def plot_significance_bar(x1, x2, y, h, pval, ax):
    """Draw a bracket with an asterisk between ``x1`` and ``x2`` if significant.

    Nothing is drawn unless ``pval < 0.05``. The bracket runs from ``y`` to
    ``y + h`` at both ends; the ``'*'`` is placed at the midpoint of ``x1``
    and ``x2`` at height ``y + 2*h``.

    Parameters
    ----------
    x1, x2 : number
        Horizontal positions of the two bracket ends.
    y : number
        Vertical position of the bracket feet.
    h : number
        Bracket height (may be negative to draw downwards).
    pval : float
        P-value compared against 0.05.
    ax : object
        Axes-like object providing ``plot`` and ``text``.

    Returns
    -------
    None
    """
    significant = pval < 0.05
    if not significant:
        return
    bar_level = y + h
    ax.plot([x1, x1, x2, x2], [y, bar_level, bar_level, y], lw=1, c='black')
    ax.text(float((x1+x2))/2, y+2*h, '*', ha='center', va='top')
    return

def calc_wilcoxon_signed(res_all_melt, t, permutation=False):
    """Pairwise Wilcoxon signed-rank tests between the three methods.

    The rows of ``res_all_melt`` with ``trait == t`` are split by ``method``
    into ``'Naive-Bayes'``, ``'BLAST+'`` and ``'VSEARCH'``, and their
    ``accuracy`` values are compared pairwise with a two-sided exact test.
    The values are passed in the row order of ``res_all_melt``; nothing here
    re-aligns them, so the pairing depends on that order.

    Parameters
    ----------
    res_all_melt : pandas.DataFrame
        Long-format table with ``trait``, ``method`` and ``accuracy`` columns.
    t : str
        Trait to test.
    permutation : bool, default False
        Not used by this function.

    Returns
    -------
    tuple of float
        P-values for (Naive-Bayes vs BLAST+, BLAST+ vs VSEARCH,
        VSEARCH vs Naive-Bayes).
    """
    per_trait = res_all_melt[res_all_melt['trait'] == t]

    def _accuracy_of(method_name):
        return per_trait[per_trait['method'] == method_name]['accuracy']

    def _pvalue(first, second):
        outcome = wilcoxon(first, second, zero_method='wilcox',
                           alternative='two-sided', method='exact')
        return outcome.pvalue

    naive_bayes = _accuracy_of('Naive-Bayes')
    blast = _accuracy_of('BLAST+')
    vsearch = _accuracy_of('VSEARCH')

    p_nb_blast = _pvalue(naive_bayes, blast)
    p_blast_vsearch = _pvalue(blast, vsearch)
    p_vsearch_nb = _pvalue(vsearch, naive_bayes)
    return p_nb_blast, p_blast_vsearch, p_vsearch_nb

def prep_xy_for_sig_bar(x_num, y, h, s):
    """Build bracket endpoints and heights for cyclic pairwise comparisons.

    Each position in ``x_num`` is paired with the next one, the last with the
    first. Three heights are produced, ``s * h`` apart, starting at ``y``.

    Parameters
    ----------
    x_num : list
        Positions of the groups on the x axis.
    y : number
        Height of the first bracket.
    h : number
        Base step; multiplied by ``s`` to get the spacing between brackets.
    s : number
        Spacing multiplier.

    Returns
    -------
    tuple
        ``(x1s, x2s, ys)`` where ``x1s`` is ``x_num`` itself, ``x2s`` is
        ``x_num`` rotated left by one, and ``ys`` holds the three heights.
    """
    rotated = x_num[1:] + x_num[:1]
    heights = [y, y + s * h, y + 2 * s * h]
    return x_num, rotated, heights

In [None]:
# Reference trait table. It is loaded inside get_pred_and_true_df, so no
# separate read is needed here.
true_vals_path = "../../../data/ref_bac2feature/trait_bac2feature.tsv"

# Per-split prediction table of each taxonomic classification method.
# The dict order defines the column order of the result table.
method_paths = {
    "Naive-Bayes": "../../../data/cross_validation/split_{i}/estimation_naive.tsv",
    "BLAST+": "../../../data/cross_validation_suppl/q2_consensus_blast/split_{i}/estimation_naive.tsv",
    "VSEARCH": "../../../data/cross_validation_suppl/q2_consensus_vsearch/split_{i}/estimation_naive.tsv",
}

per_split_tables = []
for i in range(0, 10):
    columns = {"trait": nt}
    for method, path_template in method_paths.items():
        nb_result_path = path_template.format(i=i)
        cmp = get_pred_and_true_df(pred_vals_path=nb_result_path, true_vals_path=true_vals_path)
        # Pearson correlation between predicted and reference values, one per
        # numeric trait, over the rows where both values are known.
        correlations = []
        for t in nt:
            pred_vals, true_vals = remove_null_values(cmp=cmp, t=t, dtype='float')
            correlations.append(pred_vals.corr(true_vals, method="pearson"))
        columns[method] = correlations

    tmp = pd.DataFrame(columns)
    tmp["split"] = i
    per_split_tables.append(tmp)

# Concatenate once, after the loop, instead of growing the frame per split.
res_all_nt = pd.concat(per_split_tables, axis=0)

# Long format: one row per (trait, split, method) with its correlation.
res_all_melt_nt = res_all_nt.melt(id_vars=['trait', 'split'], var_name='method', value_name='accuracy')

In [None]:
# Panel titles for every trait (numeric and categorical).
titles = {
    'cell_diameter': 'Cell diameter', 'cell_length': 'Cell length',
    'doubling_h': 'Doubling time', 'growth_tmp': 'Growth temp.',
    'optimum_tmp': 'Optimum temp.', 'optimum_ph': 'Optimum pH',
    'genome_size': 'Genome size', 'gc_content': 'GC content',
    'coding_genes': 'Coding genes', 'rRNA16S_genes': 'rRNA16S genes',
    'tRNA_genes': 'tRNA genes', 'gram_stain': 'Gram stain',
    'sporulation': 'Sporulation', 'motility': 'Motility',
    'range_salinity': 'Halophile', 'facultative_respiration': 'Facultative',
    'anaerobic_respiration': 'Anaerobe', 'aerobic_respiration': 'Aerobe',
    'mesophilic_range_tmp': 'Mesophile', 'thermophilic_range_tmp': 'Thermophile',
    'psychrophilic_range_tmp': 'Psychrophile', 'bacillus_cell_shape': 'Bacillus',
    'coccus_cell_shape': 'Coccus', 'filament_cell_shape': 'Filament',
    'coccobacillus_cell_shape': 'Coccobacillus', 'vibrio_cell_shape': 'Vibrio',
    'spiral_cell_shape': 'Spiral',
}

ntraits = len(nt)
ncols = 4
nrows = math.ceil(float(ntraits) / ncols)
# Fixed category order so the boxes always match the hard-coded tick labels.
method_order = ["Naive-Bayes", "BLAST+", "VSEARCH"]
pvals_list = []

fig, axes = plt.subplots(nrows, ncols, figsize=(1.5*ncols, 1.5*nrows))

for i, t in enumerate(nt):
    ax = axes.flatten()[i]
    x = res_all_melt_nt[res_all_melt_nt["trait"]==t]
    sns.boxplot(data=x, x="method", y="accuracy", hue="method", ax=ax, order=method_order,
                palette="Set2", fill=False, linewidth=1.3, showfliers=False)
    sns.stripplot(data=x, x="method", y="accuracy", order=method_order,
                  ax=ax, color="black", size=3, jitter=0.2)
    ax.set_title(titles[t])
    ax.set_xlabel("")
    ax.set_xticks(range(len(x["method"].unique())))
    # Method names only under the bottom row of panels.
    if i // ncols == nrows - 1:
        ax.set_xticklabels(["Naïve-Bayes", "BLAST+", "VSEARCH"], rotation=90)
    else:
        ax.set_xticklabels([])
    # Y label only on the left-most column.
    if i % ncols == 0:
        ax.set_ylabel("Accuracy")
    else:
        ax.set_ylabel("")
    ax.set_ylim(-.05, 1.05)
    ax.yaxis.set_major_locator(MultipleLocator(0.50))
    ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
    p1, p2, p3 = calc_wilcoxon_signed(res_all_melt=res_all_melt_nt, t=t)
    pvals_list += [[p1, p2, p3]]

# Benjamini-Hochberg correction over every non-NaN p-value. The adjusted
# values are written back into an array with one row per trait, so row k always
# belongs to panel k even when some traits produced NaN p-values.
pvals_raw = np.asarray(pvals_list, dtype=float).reshape(-1, 3)
testable = ~np.isnan(pvals_raw)
pvals_modified_2d = np.full(pvals_raw.shape, np.nan)
if testable.any():
    fdr_result = multipletests(pvals=pvals_raw[testable], alpha=0.05, method="fdr_bh")
    pvals_modified_2d[testable] = fdr_result[1]

for pvals, ax in zip(pvals_modified_2d, axes.flatten()):
    # A trait with any undefined p-value gets no significance bars.
    if np.isnan(pvals).any():
        continue
    x1s, x2s, ys = prep_xy_for_sig_bar(x_num=list(np.arange(3)), y=0.25, h=-0.01, s=8)
    for x1, x2, y, pval in zip(x1s, x2s, ys, pvals):
        plot_significance_bar(x1=x1, x2=x2, y=y, h=-0.02, pval=pval, ax=ax)

# Hide the grid cells that have no trait to show.
for unused_ax in axes.flatten()[ntraits:]:
    unused_ax.set_axis_off()
fig.supxlabel("Variation of taxonomic classification methods")
fig.supylabel("Pearson Correlation Coefficients of predicted and actual trait values")

plt.tight_layout()
plt.savefig("../../../results/09_cross_validation_suppl/figS3a.pdf", format="pdf", dpi=300, facecolor="white", bbox_inches="tight", pad_inches=0.1)

In [None]:
# Reference trait table. It is loaded inside get_pred_and_true_df, so no
# separate read is needed here.
true_vals_path = "../../../data/ref_bac2feature/trait_bac2feature.tsv"

# Per-split prediction table of each taxonomic classification method.
# The dict order defines the column order of the result table.
method_paths = {
    "Naive-Bayes": "../../../data/cross_validation/split_{i}/estimation_naive.tsv",
    "BLAST+": "../../../data/cross_validation_suppl/q2_consensus_blast/split_{i}/estimation_naive.tsv",
    "VSEARCH": "../../../data/cross_validation_suppl/q2_consensus_vsearch/split_{i}/estimation_naive.tsv",
}

per_split_tables = []
for i in range(0, 10):
    columns = {"trait": ct}
    for method, path_template in method_paths.items():
        nb_result_path = path_template.format(i=i)
        cmp = get_pred_and_true_df(pred_vals_path=nb_result_path, true_vals_path=true_vals_path)
        # Correlation between predicted and true values, one per categorical
        # trait, over the rows where both values are known.
        # NOTE(review): the figure built from this table labels the score as
        # Matthews correlation; Pearson's r matches it only if both columns
        # hold 0/1 labels - confirm for the reference table.
        correlations = []
        for t in ct:
            pred_vals, true_vals = remove_null_values(cmp=cmp, t=t, dtype='float')
            correlations.append(pred_vals.corr(true_vals, method="pearson"))
        columns[method] = correlations

    tmp = pd.DataFrame(columns)
    tmp["split"] = i
    per_split_tables.append(tmp)

# Concatenate once, after the loop, instead of growing the frame per split.
res_all_ct = pd.concat(per_split_tables, axis=0)

# Long format: one row per (trait, split, method) with its correlation.
res_all_melt_ct = res_all_ct.melt(id_vars=['trait', 'split'], var_name='method', value_name='accuracy')

In [None]:
ntraits = len(ct)
ncols = 4
nrows = math.ceil(float(ntraits) / ncols)
# Fixed category order so the boxes always match the hard-coded tick labels.
method_order = ["Naive-Bayes", "BLAST+", "VSEARCH"]
pvals_list = []

fig, axes = plt.subplots(nrows, ncols, figsize=(1.5*ncols, 1.5*nrows))

for i, t in enumerate(ct):
    ax = axes.flatten()[i]
    x = res_all_melt_ct[res_all_melt_ct["trait"]==t]
    sns.boxplot(data=x, x="method", y="accuracy", hue="method", ax=ax, order=method_order,
                palette="Set2", fill=False, linewidth=1.3, showfliers=False)
    sns.stripplot(data=x, x="method", y="accuracy", order=method_order,
                  ax=ax, color="black", size=3, jitter=0.2)
    ax.set_title(titles[t])
    ax.set_xlabel("")
    ax.set_xticks(range(len(x["method"].unique())))
    # Method names only under the bottom row of panels.
    if i // ncols == nrows - 1:
        ax.set_xticklabels(["Naïve-Bayes", "BLAST+", "VSEARCH"], rotation=90)
    else:
        ax.set_xticklabels([])
    # Y label only on the left-most column.
    if i % ncols == 0:
        ax.set_ylabel("Accuracy")
    else:
        ax.set_ylabel("")
    ax.set_ylim(-.05, 1.05)
    ax.yaxis.set_major_locator(MultipleLocator(0.50))
    ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
    p1, p2, p3 = calc_wilcoxon_signed(res_all_melt=res_all_melt_ct, t=t)
    pvals_list += [[p1, p2, p3]]

# Benjamini-Hochberg correction over every non-NaN p-value. The adjusted
# values are written back into an array with one row per trait, so row k always
# belongs to panel k even when some traits produced NaN p-values.
pvals_raw = np.asarray(pvals_list, dtype=float).reshape(-1, 3)
testable = ~np.isnan(pvals_raw)
pvals_modified_2d = np.full(pvals_raw.shape, np.nan)
if testable.any():
    fdr_result = multipletests(pvals=pvals_raw[testable], alpha=0.05, method="fdr_bh")
    pvals_modified_2d[testable] = fdr_result[1]

for pvals, ax in zip(pvals_modified_2d, axes.flatten()):
    # A trait with any undefined p-value gets no significance bars.
    if np.isnan(pvals).any():
        continue
    x1s, x2s, ys = prep_xy_for_sig_bar(x_num=list(np.arange(3)), y=0.25, h=-0.01, s=8)
    for x1, x2, y, pval in zip(x1s, x2s, ys, pvals):
        plot_significance_bar(x1=x1, x2=x2, y=y, h=-0.02, pval=pval, ax=ax)

# Hide grid cells without a trait (none when the traits fill the grid exactly).
for unused_ax in axes.flatten()[ntraits:]:
    unused_ax.set_axis_off()
fig.supxlabel("Variation of taxonomic classification methods")
fig.supylabel("Matthews Correlation Coefficients of predicted and actual trait values")

plt.tight_layout()
plt.savefig("../../../results/09_cross_validation_suppl/figS3b.pdf", format="pdf", dpi=300, facecolor="white", bbox_inches="tight", pad_inches=0.1)