# Compare homology search tools (BLAST+, VSEARCH, SortMeRNA) for homology-based trait prediction

## Homology search

```sh
conda activate qiime2-2023.5

# Prepare output directories. `mkdir -p` creates parents and does not fail
# when the directory already exists, so the script can be re-run safely.
for i in $(seq 0 9); do mkdir -p ../../../data/cross_validation_suppl/sortmerna/split_$i; done

# VSEARCH: export previous results
for i in $(seq 0 9); do
    qiime tools export \
        --input-path ../../../data/cross_validation_suppl/q2_consensus_vsearch/split_$i/search_results.qza \
        --output-path ../../../data/cross_validation_suppl/q2_consensus_vsearch/split_$i/
done

# SortMeRNA: Make reference index for each split
for i in $(seq 0 9); do
    indexdb_rna \
        --ref ../../../data/cross_validation/split_$i/ref_seq_full.fasta,../../../data/cross_validation_suppl/sortmerna/split_$i/ref_seq_full.idx
done

# SortMeRNA: Cross-validation (BLAST-like tabular output, up to 10 best hits per read)
for i in $(seq 0 9); do
    sortmerna \
        --ref ../../../data/cross_validation/split_$i/ref_seq_full.fasta,../../../data/cross_validation_suppl/sortmerna/split_$i/ref_seq_full.idx \
        --reads ../../../data/cross_validation/split_$i/test_seq_full.fasta \
        --blast 1 \
        --other ../../../data/cross_validation_suppl/sortmerna/split_$i/sortmerna_out \
        --fastx \
        --best 10 \
        --min_lis 10 \
        --aligned ../../../data/cross_validation_suppl/sortmerna/split_$i/aligned.out \
        --log
done
```

In [None]:
import json
import os
import math

import numpy as np
import seaborn as sns
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator, FormatStrFormatter

from scipy.stats import mannwhitneyu, wilcoxon, friedmanchisquare
from statsmodels.stats.multitest import multipletests

# Global figure styling shared by every plot in this notebook.
matplotlib.rcParams.update({
    'font.family': 'Arial',
    'font.sans-serif': ["Arial", "DejaVu Sans", "Lucida Grande", "Verdana"],
    'figure.figsize': [4, 3],
    'font.size': 10,
    "axes.labelcolor": "#000000",
    "axes.linewidth": 1.0,
    "xtick.major.width": 1.0,
    "ytick.major.width": 1.0,
})
# Qualitative colormaps kept available for plotting.
cmap1, cmap2 = plt.cm.tab20, plt.cm.Set3

In [None]:
# Reference trait table; taxon IDs are kept as strings so they join cleanly
# with the string-typed subject IDs parsed from the hit tables.
trait = pd.read_csv(
    "../../../data/ref_bac2feature/trait_bac2feature.tsv", sep="\t"
).astype({"species_tax_id": "str"})

## Prediction from homology search results

In [None]:
def preprocess_blast_result(blast_result: pd.DataFrame, perc_identity: "float | None"):
    """Filter BLAST-style tabular hits by percent identity and alignment length.

    Parameters
    ----------
    blast_result : pd.DataFrame
        Hit table with at least ``pident`` and ``length`` columns (any dtype
        castable to float). The caller's frame is not modified.
    perc_identity : float or None
        Minimum percent identity a hit must reach. ``None`` disables this filter.

    Returns
    -------
    pd.DataFrame
        Hits passing both filters, with ``pident``/``length`` cast to float and
        a fresh ``0..n-1`` index.
    """
    # Work on a copy so the caller's DataFrame is not mutated by the cast below.
    blast_result = blast_result.copy()
    blast_result[['pident', 'length']] = blast_result[['pident', 'length']].astype(float)

    # Filter by percentage identity, if provided
    if perc_identity is not None:
        blast_result = blast_result[blast_result['pident'] >= perc_identity]

    # Exclude very short alignments: below half the mean length of the remaining hits
    min_length = blast_result['length'].mean() * 0.5
    blast_result = blast_result[blast_result['length'] >= min_length]

    # reset_index returns a new frame; assign it so the index is actually renumbered.
    return blast_result.reset_index(drop=True)

def summarize_traits(res_trait, trait_cols):
    """Summarize traits for each query sequence from its ordered hit list.

    For every trait column, the first hit (in row order) with a non-null value
    for that trait is taken as the prediction, together with that hit's
    ``pident`` (stored as ``<trait>_pident``).

    Parameters
    ----------
    res_trait : pd.DataFrame
        One row per hit with a ``sequence`` (query ID) column, ``pident``, and
        the trait columns. Rows for a query are assumed to be ordered best-first.
    trait_cols : iterable of str
        Trait columns to summarize.

    Returns
    -------
    pd.DataFrame
        Indexed by query sequence, with a ``sequence`` column followed by
        ``<trait>, <trait>_pident`` pairs; NaN where no hit carries the trait.
    """
    # Index by query ID here (keeping the column) so the per-trait tables align
    # on the sequence regardless of how the caller indexed the frame.
    indexed = res_trait.set_index('sequence', drop=False)

    bests = []
    for col in trait_cols:
        has_value = indexed[col].notnull().to_numpy()
        best = (
            indexed.loc[has_value]
            .drop_duplicates(subset='sequence', keep='first')[[col, 'pident']]
            .rename(columns={'pident': f'{col}_pident'})
        )
        bests.append(best)

    sequences = indexed['sequence'].drop_duplicates(keep='first')
    if not bests:
        # pd.concat([]) raises; with no traits only the query IDs remain.
        return sequences.to_frame()

    out_trait = pd.concat(bests, axis=1)
    return pd.concat([sequences, out_trait], axis=1)

In [None]:
# VSEARCH: convert each split's hit table into per-sequence trait predictions.
blast6_columns = ["sequence", "species_tax_id", "pident", "length", "mismatch", "gapopen",
                  "qstart", "qend", "sstart", "send", "evalue", "bitscore"]
for split_no in range(10):
    split_dir = f"../../../data/cross_validation_suppl/q2_consensus_vsearch/split_{split_no}"

    hits = pd.read_csv(f"{split_dir}/blast6.tsv", sep="\t", header=None)
    hits.columns = blast6_columns
    hits["species_tax_id"] = hits["species_tax_id"].astype("str")

    # Alignment-length filter only: perc_identity=None disables the identity threshold.
    hits = preprocess_blast_result(hits, perc_identity=None)

    # Attach reference traits to every hit, then index hits by query sequence.
    hits_with_traits = pd.merge(hits, trait, how='left', on='species_tax_id')
    hits_with_traits = hits_with_traits.set_index('sequence', drop=False)

    predictions = summarize_traits(hits_with_traits, trait.columns[1:])
    predictions.to_csv(f"{split_dir}/prediction_homology.tsv", sep="\t", index=False)

In [None]:
# SortMeRNA: convert each split's BLAST-like output into per-sequence trait predictions.
sortmerna_columns = ["sequence", "species_tax_id", "pident", "length", "mismatch", "gapopen",
                     "qstart", "qend", "sstart", "send", "evalue", "bitscore"]
for split_no in range(10):
    split_dir = f"../../../data/cross_validation_suppl/sortmerna/split_{split_no}"

    hits = pd.read_csv(f"{split_dir}/aligned.out.blast", sep="\t", header=None)
    hits.columns = sortmerna_columns
    hits["species_tax_id"] = hits["species_tax_id"].astype("str")

    # Alignment-length filter only: perc_identity=None disables the identity threshold.
    hits = preprocess_blast_result(hits, perc_identity=None)

    # Attach reference traits to every hit, then index hits by query sequence.
    hits_with_traits = pd.merge(hits, trait, how='left', on='species_tax_id')
    hits_with_traits = hits_with_traits.set_index('sequence', drop=False)

    predictions = summarize_traits(hits_with_traits, trait.columns[1:])
    predictions.to_csv(f"{split_dir}/prediction_homology.tsv", sep="\t", index=False)

## Visualization

In [None]:
def get_pred_and_true_df(pred_vals_path, true_vals_path):
    """Load a prediction table and the reference trait table, joined per taxon.

    The prediction's ``sequence`` column is cast to ``str`` and inner-joined
    to the reference's ``species_tax_id`` (read entirely as strings).
    Overlapping column names get ``_e`` (estimated) and ``_t`` (true) suffixes.
    """
    predicted = pd.read_csv(pred_vals_path, sep="\t")
    predicted = predicted.assign(sequence=predicted["sequence"].astype(str))

    reference = pd.read_csv(true_vals_path, sep="\t", dtype=str)

    return predicted.merge(
        reference,
        how="inner",
        left_on="sequence",
        right_on="species_tax_id",
        suffixes=("_e", "_t"),
    )

def remove_null_values(cmp, t, dtype):
    """Return predicted and true values of trait *t* on rows where both exist.

    Parameters
    ----------
    cmp : pd.DataFrame
        Joined table holding ``<t>_e`` (predicted) and ``<t>_t`` (true) columns.
    t : str
        Trait name.
    dtype : str
        ``'float'`` or ``'int'`` casts both Series; any other value leaves
        them with their loaded dtype.

    Returns
    -------
    tuple of (pd.Series, pd.Series)
        The predicted and the true values, in that order.
    """
    est_col = t + '_e'
    ref_col = t + '_t'
    both_known = cmp[ref_col].notnull() & cmp[est_col].notnull()
    kept = cmp[both_known]
    est, ref = kept[est_col], kept[ref_col]

    if dtype == 'float':
        caster = float
    elif dtype == 'int':
        caster = int
    else:
        return est, ref
    return est.astype(caster), ref.astype(caster)

# Numeric traits (compared via Pearson correlation below).
nt = [
    'cell_diameter', 'cell_length', 'doubling_h',
    'growth_tmp', 'optimum_tmp', 'optimum_ph',
    'genome_size', 'gc_content',
    'coding_genes', 'rRNA16S_genes', 'tRNA_genes',
]

# Categorical traits (plotted as Matthews correlation coefficients below).
ct = [
    'gram_stain', 'sporulation', 'motility', 'range_salinity',
    # respiration
    'facultative_respiration', 'anaerobic_respiration', 'aerobic_respiration',
    # temperature range
    'mesophilic_range_tmp', 'thermophilic_range_tmp', 'psychrophilic_range_tmp',
    # cell shape
    'bacillus_cell_shape', 'coccus_cell_shape', 'filament_cell_shape',
    'coccobacillus_cell_shape', 'vibrio_cell_shape', 'spiral_cell_shape',
]

In [None]:
true_vals_path = "../../../data/ref_bac2feature/trait_bac2feature.tsv"

# Prediction table for each homology-search method; "{i}" is the CV split index.
# Dict order fixes the method column order (and therefore the melt order).
homology_pred_paths = {
    "BLAST": "../../../data/cross_validation/split_{i}/estimation_blast.tsv",
    "VSEARCH": "../../../data/cross_validation_suppl/q2_consensus_vsearch/split_{i}/prediction_homology.tsv",
    "SortMeRNA": "../../../data/cross_validation_suppl/sortmerna/split_{i}/prediction_homology.tsv",
}

per_split_nt = []
for i in range(10):
    # Pearson correlation between predicted and true values, per trait and method.
    corr_by_method = {}
    for method, path_template in homology_pred_paths.items():
        cmp = get_pred_and_true_df(pred_vals_path=path_template.format(i=i),
                                   true_vals_path=true_vals_path)
        corrs = []
        for t in nt:
            pred_vals, actual_vals = remove_null_values(cmp=cmp, t=t, dtype='float')
            corrs.append(pred_vals.corr(actual_vals, method="pearson"))
        corr_by_method[method] = corrs

    tmp = pd.DataFrame({"trait": nt, **corr_by_method})
    tmp["split"] = i
    per_split_nt.append(tmp)

# Concatenate once instead of growing the frame inside the loop.
res_all_nt = pd.concat(per_split_nt, axis=0)
res_all_melt_nt = res_all_nt.melt(id_vars=['trait', 'split'], var_name='method', value_name='accuracy')

In [None]:
def plot_significance_bar(x1, x2, y, h, pval, ax):
    """Draw a bracket from x1 to x2 at height y with a '*' when pval < 0.05.

    ``h`` is the bracket's arm length (negative values draw it downward); the
    asterisk is placed at ``y + 2*h``. Nothing is drawn otherwise, including
    when ``pval`` is NaN.
    """
    if not pval < 0.05:
        return
    bracket_x = [x1, x1, x2, x2]
    bracket_y = [y, y + h, y + h, y]
    ax.plot(bracket_x, bracket_y, lw=1, c='black')
    midpoint = float(x1 + x2) / 2
    ax.text(midpoint, y + 2 * h, '*', ha='center', va='top')

def calc_wilcoxon_signed(res_all_melt, t, permutation=False):
    """Paired two-sided Wilcoxon signed-rank tests between methods for trait *t*.

    Observations are paired by the ``split`` column (the CV split), not by
    row order, so the result does not depend on how ``res_all_melt`` is sorted.

    Parameters
    ----------
    res_all_melt : pd.DataFrame
        Long table with ``trait``, ``split``, ``method`` and ``accuracy``
        columns. Methods ``'BLAST'``, ``'VSEARCH'`` and ``'SortMeRNA'`` must
        each have one row per split for trait *t*.
    t : str
        Trait to test.
    permutation : bool
        Unused; kept for backward compatibility of the signature.

    Returns
    -------
    tuple of float
        p-values for (BLAST vs VSEARCH, VSEARCH vs SortMeRNA, SortMeRNA vs BLAST).
    """
    res_trait = res_all_melt[res_all_melt['trait'] == t]
    # One row per split, one column per method -> explicit pairing by split.
    wide = res_trait.pivot(index='split', columns='method', values='accuracy')
    y1 = wide['BLAST'].to_numpy()
    y2 = wide['VSEARCH'].to_numpy()
    y3 = wide['SortMeRNA'].to_numpy()

    def _paired_p(a, b):
        return wilcoxon(a, b, zero_method='wilcox', alternative='two-sided', method='exact').pvalue

    return _paired_p(y1, y2), _paired_p(y2, y3), _paired_p(y3, y1)

def prep_xy_for_sig_bar(x_num, y, h, s):
    """Build bracket coordinates for every cyclic pair of adjacent positions.

    Pairs are (x_num[k], x_num[k+1]) with the last position wrapping to the
    first, e.g. ``[0, 1, 2] -> (0,1), (1,2), (2,0)``. Each successive bracket
    is offset from ``y`` by another ``s * h`` so they do not overlap.

    Parameters
    ----------
    x_num : sequence
        x positions; any sequence (list, tuple, numpy array) is accepted.
    y : float
        Height of the first bracket.
    h : float
        Base step; its sign sets the stacking direction.
    s : float
        Multiplier applied to ``h`` between consecutive brackets.

    Returns
    -------
    tuple of (list, list, list of float)
        Left x positions, right x positions and bracket heights.
    """
    # Convert to a list so `+` concatenates; on a numpy array it would add elementwise.
    x1s = list(x_num)
    x2s = x1s[1:] + x1s[:1]
    # One height per pair (the original hard-coded three levels).
    ys = [y + k * s * h for k in range(len(x1s))]
    return x1s, x2s, ys

In [None]:
# Panel titles for every trait (numeric and categorical).
titles = {
    'cell_diameter': 'Cell diameter',
    'cell_length': 'Cell length',
    'doubling_h': 'Doubling time',
    'growth_tmp': 'Growth temp.',
    'optimum_tmp': 'Optimum temp.',
    'optimum_ph': 'Optimum pH',
    'genome_size': 'Genome size',
    'gc_content': 'GC content',
    'coding_genes': 'Coding genes',
    'rRNA16S_genes': 'rRNA16S genes',
    'tRNA_genes': 'tRNA genes',
    'gram_stain': 'Gram stain',
    'sporulation': 'Sporulation',
    'motility': 'Motility',
    'range_salinity': 'Halophile',
    'facultative_respiration': 'Facultative',
    'anaerobic_respiration': 'Anaerobe',
    'aerobic_respiration': 'Aerobe',
    'mesophilic_range_tmp': 'Mesophile',
    'thermophilic_range_tmp': 'Thermophile',
    'psychrophilic_range_tmp': 'Psychrophile',
    'bacillus_cell_shape': 'Bacillus',
    'coccus_cell_shape': 'Coccus',
    'filament_cell_shape': 'Filament',
    'coccobacillus_cell_shape': 'Coccobacillus',
    'vibrio_cell_shape': 'Vibrio',
    'spiral_cell_shape': 'Spiral',
}

ntraits = len(nt)
ncols = 4
nrows = math.ceil(float(ntraits) / ncols)
pvals_list = []

fig, axes = plt.subplots(nrows, ncols, figsize=(1.5*ncols, 1.5*nrows))
flat_axes = axes.flatten()

for i, t in enumerate(nt):
    ax = flat_axes[i]
    x = res_all_melt_nt[res_all_melt_nt["trait"]==t]
    # hue mirrors x only to colour each method's box; no per-panel legend.
    sns.boxplot(data=x, x="method", y="accuracy", hue="method", ax=ax,
                palette="Set2", fill=False, linewidth=1.3, showfliers=False, legend=False)
    sns.stripplot(data=x, x="method", y="accuracy",
                  ax=ax, color="black", size=3, jitter=0.2)
    ax.set_title(titles[t])
    ax.set_xlabel("")
    ax.set_xticks(range(len(x["method"].unique())))
    # Label ticks on the lowest populated panel of each column; when ntraits is
    # not a multiple of ncols that panel is not on the last row.
    if i + ncols >= ntraits:
        ax.set_xticklabels(["BLAST+", "VSEARCH", "SortMeRNA"], rotation=90)
    else:
        ax.set_xticklabels([])
    if i % ncols == 0:
        ax.set_ylabel("Accuracy")
    else:
        ax.set_ylabel("")
    ax.set_ylim(-.05, 1.05)
    ax.yaxis.set_major_locator(MultipleLocator(0.50))
    ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
    p1, p2, p3 = calc_wilcoxon_signed(res_all_melt=res_all_melt_nt, t=t)
    pvals_list += [[p1, p2, p3]]

# Benjamini-Hochberg over all non-NaN p-values, adjusted in place inside a
# (ntraits, 3) array so each corrected value stays attached to its own trait
# and comparison (dropping NaN rows before zipping with the axes misaligns them).
pvals_arr = np.asarray(pvals_list, dtype=float)
pvals_adj = np.full_like(pvals_arr, np.nan)
tested = ~np.isnan(pvals_arr)
if tested.any():
    pvals_adj[tested] = multipletests(pvals=pvals_arr[tested], alpha=0.05, method="fdr_bh")[1]

x1s, x2s, ys = prep_xy_for_sig_bar(x_num=list(np.arange(3)), y=0.25, h=-0.01, s=8)
for pvals, ax in zip(pvals_adj, flat_axes):
    # NaN never satisfies pval < 0.05, so untested comparisons draw no bar.
    for x1, x2, y, pval in zip(x1s, x2s, ys, pvals):
        plot_significance_bar(x1=x1, x2=x2, y=y, h=-0.02, pval=pval, ax=ax)

# Hide every grid cell that has no trait.
for ax in flat_axes[ntraits:]:
    ax.set_axis_off()
fig.supxlabel("Variation of homology search methods")
fig.supylabel("Pearson Correlation Coefficients of predicted and actual trait values")

plt.tight_layout()
out_dir = "../../../results/09_cross_validation_suppl"
os.makedirs(out_dir, exist_ok=True)
plt.savefig(f"{out_dir}/figS2a.pdf", format="pdf", dpi=300, facecolor="white", bbox_inches="tight", pad_inches=0.1)

In [None]:
true_vals_path = "../../../data/ref_bac2feature/trait_bac2feature.tsv"

# Prediction table for each homology-search method; "{i}" is the CV split index.
# Dict order fixes the method column order (and therefore the melt order).
homology_pred_paths = {
    "BLAST": "../../../data/cross_validation/split_{i}/estimation_blast.tsv",
    "VSEARCH": "../../../data/cross_validation_suppl/q2_consensus_vsearch/split_{i}/prediction_homology.tsv",
    "SortMeRNA": "../../../data/cross_validation_suppl/sortmerna/split_{i}/prediction_homology.tsv",
}

per_split_ct = []
for i in range(10):
    corr_by_method = {}
    for method, path_template in homology_pred_paths.items():
        cmp = get_pred_and_true_df(pred_vals_path=path_template.format(i=i),
                                   true_vals_path=true_vals_path)
        # Pearson correlation of the categorical traits. For two 0/1 variables this
        # equals the Matthews correlation coefficient reported in the figure
        # (assumes these traits are 0/1-coded -- confirm against the trait table).
        corrs = []
        for t in ct:
            pred_vals, actual_vals = remove_null_values(cmp=cmp, t=t, dtype='float')
            corrs.append(pred_vals.corr(actual_vals, method="pearson"))
        corr_by_method[method] = corrs

    tmp = pd.DataFrame({"trait": ct, **corr_by_method})
    tmp["split"] = i
    per_split_ct.append(tmp)

# Concatenate once instead of growing the frame inside the loop.
res_all_ct = pd.concat(per_split_ct, axis=0)
res_all_melt_ct = res_all_ct.melt(id_vars=['trait', 'split'], var_name='method', value_name='accuracy')

In [None]:
ntraits = len(ct)
ncols = 4
nrows = math.ceil(float(ntraits) / ncols)
pvals_list = []

fig, axes = plt.subplots(nrows, ncols, figsize=(1.5*ncols, 1.5*nrows))
flat_axes = axes.flatten()

for i, t in enumerate(ct):
    ax = flat_axes[i]
    x = res_all_melt_ct[res_all_melt_ct["trait"]==t]
    # hue mirrors x only to colour each method's box; no per-panel legend.
    sns.boxplot(data=x, x="method", y="accuracy", hue="method", ax=ax,
                palette="Set2", fill=False, linewidth=1.3, showfliers=False, legend=False)
    sns.stripplot(data=x, x="method", y="accuracy",
                  ax=ax, color="black", size=3, jitter=0.2)
    ax.set_title(titles[t])
    ax.set_xlabel("")
    ax.set_xticks(range(len(x["method"].unique())))
    # Label ticks on the lowest populated panel of each column; when ntraits is
    # not a multiple of ncols that panel is not on the last row.
    if i + ncols >= ntraits:
        ax.set_xticklabels(["BLAST+", "VSEARCH", "SortMeRNA"], rotation=90)
    else:
        ax.set_xticklabels([])
    if i % ncols == 0:
        ax.set_ylabel("Accuracy")
    else:
        ax.set_ylabel("")
    ax.set_ylim(-.05, 1.05)
    ax.yaxis.set_major_locator(MultipleLocator(0.50))
    ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
    p1, p2, p3 = calc_wilcoxon_signed(res_all_melt=res_all_melt_ct, t=t)
    pvals_list += [[p1, p2, p3]]

# Benjamini-Hochberg over all non-NaN p-values, adjusted in place inside a
# (ntraits, 3) array so each corrected value stays attached to its own trait
# and comparison (dropping NaN rows before zipping with the axes misaligns them).
pvals_arr = np.asarray(pvals_list, dtype=float)
pvals_adj = np.full_like(pvals_arr, np.nan)
tested = ~np.isnan(pvals_arr)
if tested.any():
    pvals_adj[tested] = multipletests(pvals=pvals_arr[tested], alpha=0.05, method="fdr_bh")[1]

x1s, x2s, ys = prep_xy_for_sig_bar(x_num=list(np.arange(3)), y=0.25, h=-0.01, s=8)
for pvals, ax in zip(pvals_adj, flat_axes):
    # NaN never satisfies pval < 0.05, so untested comparisons draw no bar.
    for x1, x2, y, pval in zip(x1s, x2s, ys, pvals):
        plot_significance_bar(x1=x1, x2=x2, y=y, h=-0.02, pval=pval, ax=ax)

# Hide every grid cell that has no trait (none when ntraits fills the grid).
for ax in flat_axes[ntraits:]:
    ax.set_axis_off()
fig.supxlabel("Variation of homology search methods")
fig.supylabel("Matthews Correlation Coefficients of predicted and actual trait values")

plt.tight_layout()
out_dir = "../../../results/09_cross_validation_suppl"
os.makedirs(out_dir, exist_ok=True)
plt.savefig(f"{out_dir}/figS2b.pdf", format="pdf", dpi=300, facecolor="white", bbox_inches="tight", pad_inches=0.1)