# Clade-out cross-validation

```sh
mkdir -p ../../../data/cross_validation_suppl/clade_out
```

In [None]:
import json
import random
import os
import math

import numpy as np
import seaborn as sns
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt

# Global matplotlib style shared by every figure in this notebook.
matplotlib.rcParams.update({
    'font.family': 'Arial',
    # Preference order for the generic "sans-serif" family.
    'font.sans-serif': ["Arial", "DejaVu Sans", "Lucida Grande", "Verdana"],
    'figure.figsize': [4, 3],
    'font.size': 10,
    "axes.labelcolor": "#000000",
    # Uniform 1.0 line width for the axes frame and major ticks.
    "axes.linewidth": 1.0,
    "xtick.major.width": 1.0,
    "ytick.major.width": 1.0,
})

# Qualitative colormaps kept available for later cells.
cmap1, cmap2 = plt.cm.tab20, plt.cm.Set3

## Prepare clade-out validation datasets

In [None]:
# Build clade-out splits: at each taxonomic level, the largest clades are
# shuffled and partitioned into `num_iterations` groups; each group is held out
# as the test set of one split. Per split this writes the test/reference taxid
# lists and the reference-only trait table.
random.seed(0)

# Trait table
trait_path = "../../../data/trait/condensed_species_NCBI.csv"
trait = pd.read_csv(trait_path, sep=",", dtype=str)

# SILVA taxonomy
silva_metadata_path = "../../../data/silva/SILVA_138.1_SSURef_Nr99.full_metadata"
silva_metadata = pd.read_csv(silva_metadata_path, sep="\t", dtype=str)
silva_metadata.fillna("", inplace=True)

# Attach trait data to SILVA records via the NCBI taxid (inner join: only
# records whose taxid has trait data are kept). On column-name clashes the
# trait table's column keeps the bare name and SILVA's gets "_silva".
silva_metadata = pd.merge(trait, silva_metadata, how='inner', left_on='species_tax_id', right_on='tax_xref_ncbi', suffixes=["", "_silva"])

tax_levels = ["order", "family", "genus"]
num_test_clades = 10   # clades held out per split
num_iterations = 10    # number of splits per taxonomic level

# Loop-invariant inputs, loaded/derived once rather than once per level.
trait_table = pd.read_csv("../../../data/ref_bac2feature/trait_bac2feature.tsv", sep="\t", dtype=str)
# NOTE(review): silva_metadata has one row per SILVA record, so a taxid may
# repeat here and in the header files below — confirm downstream tools accept that.
l_tax_id = silva_metadata["species_tax_id"].tolist()

for level in tax_levels:
    # The num_test_clades * num_iterations clades with the most records at this
    # level; records with no assignment ("") are ignored.
    dominant_clades = (
        silva_metadata[silva_metadata[level] != ""]
        .groupby(level).size()
        .sort_values(ascending=False)
        .index.tolist()[:num_test_clades * num_iterations]
    )
    # The RNG is seeded above, so the shuffle (and every split) is reproducible.
    dominant_clades_shuffled = random.sample(dominant_clades, len(dominant_clades))

    # Clade-out splits: each split holds out one group of whole clades.
    out_dir = f'../../../data/cross_validation_suppl/clade_out/{level}'

    for i, test_clades in enumerate(np.array_split(dominant_clades_shuffled, num_iterations)):
        # Vectorised membership test (one hash set for the whole column).
        silva_metadata["is_test_clades"] = silva_metadata[level].isin(test_clades)
        test_header = silva_metadata.loc[silva_metadata["is_test_clades"], "species_tax_id"].tolist()

        # Path
        split_dir = os.path.join(out_dir, f"split_{i}")
        os.makedirs(split_dir, exist_ok=True)
        test_header_path = os.path.join(split_dir, "nodeid_test.txt")
        ref_header_path = os.path.join(split_dir, "nodeid_ref.txt")
        ref_trait_path = os.path.join(split_dir, "traits.tsv")

        # Split header: every taxid not in the test set goes to the reference.
        set_test_header = set(test_header)
        ref_header = [h for h in l_tax_id if h not in set_test_header]
        with open(test_header_path, 'w') as f:
            f.write('\n'.join(test_header))
        with open(ref_header_path, 'w') as f:
            f.write('\n'.join(ref_header))

        # Split trait table: keep only reference (non-test) species.
        is_ref = ~trait_table["species_tax_id"].isin(set_test_header)
        trait_table[is_ref].to_csv(ref_trait_path, sep="\t", index=False)

## Cross-validation
```sh
tax_levels=("order" "family" "genus")

# Per-split inputs (nodeid_*.txt, traits.tsv) are written here by the Python
# cell above, and the estimates are read back from here by the plotting cells.
base_dir=../../../data/cross_validation_suppl/clade_out
# NOTE(review): location of the taxid-labelled SILVA FASTA is carried over from
# an earlier directory layout — confirm it is still correct relative to this notebook.
silva_fasta=../../data/2024-06-06/intermediate_dir/SILVA_138.1_SSURef_NR99_tax_silva_taxid.fasta

for level in "${tax_levels[@]}"; do
  for i in $(seq 0 9); do
    split_dir="${base_dir}/${level}/split_${i}"
    # Split full 16S rRNA sequences into test and reference sets by taxid
    seqkit grep -nf "${split_dir}/nodeid_test.txt" "${silva_fasta}" \
      > "${split_dir}/test_seq_full.fasta"

    seqkit grep -nf "${split_dir}/nodeid_ref.txt" "${silva_fasta}" \
      > "${split_dir}/ref_seq_full.fasta"
  done

  # Make reference dataset for PICRUSt2 using header list
  for i in $(seq 0 9); do
    split_dir="${base_dir}/${level}/split_${i}"
    ./make_pro_ref.sh \
      -t ../../../data/ref_bac2feature/phylogeny/phylogeny.tre \
      -a ../../../data/ref_bac2feature/phylogeny/phylogeny.fasta \
      -n "${split_dir}/nodeid_ref.txt" \
      -o "${split_dir}"
  done

  # Conduct CV by Phylogenetic placement-based method
  for i in $(seq 0 9); do
    split_dir="${base_dir}/${level}/split_${i}"
    bac2feature \
      -s "${split_dir}/test_seq_full.fasta" \
      -o "${split_dir}/estimation_full.tsv" \
      -m phylogeny \
      --ref_dir_placement "${split_dir}" \
      --ref_trait "${split_dir}/traits.tsv" \
      --threads 1 --calculate_NSTI
  done

done
```

## Visualization

In [None]:
def get_pred_and_true_df(pred_vals_path, true_vals_path):
    """Join bac2feature estimates with reference trait values.

    Parameters
    ----------
    pred_vals_path : str
        Tab-separated estimates; the ``sequence`` column holds the query IDs.
    true_vals_path : str
        Tab-separated reference traits keyed by ``species_tax_id``
        (read entirely as strings).

    Returns
    -------
    pandas.DataFrame
        Inner join on ``sequence == species_tax_id``. Columns present in both
        files get the suffix ``_e`` (estimate) or ``_t`` (true value).
    """
    # Read the ID column as text up front. Letting pandas infer it and casting
    # afterwards yields IDs like "562.0" whenever the column has a missing
    # value (int -> float upcast), which silently empties the join.
    pred_vals = pd.read_csv(pred_vals_path, sep="\t", dtype={"sequence": str})
    pred_vals["sequence"] = pred_vals["sequence"].astype(str)
    # Reference
    true_vals = pd.read_csv(true_vals_path, sep="\t", dtype=str)
    cmp = pd.merge(pred_vals, true_vals,
                   left_on='sequence', right_on="species_tax_id", how='inner', suffixes=['_e', '_t'])
    return cmp

def remove_null_values(cmp, t, dtype):
    """Return paired (estimate, truth) values of trait ``t`` with missing rows dropped.

    Parameters
    ----------
    cmp : pandas.DataFrame
        Joined table containing ``<t>_e`` (estimate) and ``<t>_t`` (truth) columns.
    t : str
        Trait name (column prefix).
    dtype : str
        ``'float'`` or ``'int'`` to cast both series; any other value returns
        them uncast.

    Returns
    -------
    tuple of pandas.Series
        ``(estimates, truths)`` restricted to rows where both are non-null,
        keeping the original index.
    """
    est_col = t + '_e'
    true_col = t + '_t'
    both_known = cmp[true_col].notna() & cmp[est_col].notna()
    subset = cmp[both_known]
    estimates = subset[est_col]
    truths = subset[true_col]
    if dtype in ('float', 'int'):
        caster = float if dtype == 'float' else int
        estimates = estimates.astype(caster)
        truths = truths.astype(caster)
    return estimates, truths

In [None]:
# Display title for every trait column used in the figures below.
titles = {
    # Numeric traits
    'cell_diameter': 'Cell diameter',
    'cell_length': 'Cell length',
    'doubling_h': 'Doubling time',
    'growth_tmp': 'Growth temp.',
    'optimum_tmp': 'Optimum temp.',
    'optimum_ph': 'Optimum pH',
    'genome_size': 'Genome size',
    'gc_content': 'GC content',
    'coding_genes': 'Coding genes',
    'rRNA16S_genes': 'rRNA16S genes',
    'tRNA_genes': 'tRNA genes',
    # Categorical traits
    'gram_stain': 'Gram stain',
    'sporulation': 'Sporulation',
    'motility': 'Motility',
    'range_salinity': 'Halophile',
    'facultative_respiration': 'Facultative',
    'anaerobic_respiration': 'Anaerobe',
    'aerobic_respiration': 'Aerobe',
    'mesophilic_range_tmp': 'Mesophile',
    'thermophilic_range_tmp': 'Thermophile',
    'psychrophilic_range_tmp': 'Psychrophile',
    'bacillus_cell_shape': 'Bacillus',
    'coccus_cell_shape': 'Coccus',
    'filament_cell_shape': 'Filament',
    'coccobacillus_cell_shape': 'Coccobacillus',
    'vibrio_cell_shape': 'Vibrio',
    'spiral_cell_shape': 'Spiral',
}

# Numeric traits (figure S7a).
nt = [
    'cell_diameter', 'cell_length', 'doubling_h',
    'growth_tmp', 'optimum_tmp', 'optimum_ph',
    'genome_size', 'gc_content',
    'coding_genes', 'rRNA16S_genes', 'tRNA_genes',
]

# Categorical traits (figure S7b).
ct = [
    'gram_stain', 'sporulation', 'motility', 'range_salinity',
    # respiration
    'facultative_respiration', 'anaerobic_respiration', 'aerobic_respiration',
    # temperature range
    'mesophilic_range_tmp', 'thermophilic_range_tmp', 'psychrophilic_range_tmp',
    # cell shape
    'bacillus_cell_shape', 'coccus_cell_shape', 'filament_cell_shape',
    'coccobacillus_cell_shape', 'vibrio_cell_shape', 'spiral_cell_shape',
]

In [None]:
# Per-split Pearson correlation between bac2feature estimates and reference
# values for every numeric trait in `nt`, for each hold-out scheme.
true_vals_path = "../../../data/ref_bac2feature/trait_bac2feature.tsv"

# Result column -> estimates path template for split `i`. Dict order fixes the
# column order of the table (Genus, Family, Order, Random).
result_path_templates = {
    "Genus": "../../../data/cross_validation_suppl/clade_out/genus/split_{i}/estimation_full.tsv",
    "Family": "../../../data/cross_validation_suppl/clade_out/family/split_{i}/estimation_full.tsv",
    "Order": "../../../data/cross_validation_suppl/clade_out/order/split_{i}/estimation_full.tsv",
    # Standard 10-fold cross-validation, labelled "Random".
    "Random": "../../../data/cross_validation/split_{i}/estimation_full.tsv",
}

split_tables = []
for i in range(0, 10):
    tmp = pd.DataFrame({"trait": nt})
    for method, path_template in result_path_templates.items():
        cmp = get_pred_and_true_df(pred_vals_path=path_template.format(i=i), true_vals_path=true_vals_path)
        correlations = []
        for t in nt:
            pred_vals, true_vals = remove_null_values(cmp=cmp, t=t, dtype='float')
            correlations.append(pred_vals.corr(true_vals, method="pearson"))
        tmp[method] = correlations
    tmp["split"] = i
    split_tables.append(tmp)

# Concatenate once at the end instead of growing the frame inside the loop.
res_all_nt = pd.concat(split_tables, axis=0)

res_all_melt_nt = res_all_nt.melt(id_vars=['trait', 'split'], var_name='method', value_name='accuracy')

In [None]:
# Figure S7a: one panel per numeric trait (`nt`), one box per hold-out scheme.
# NOTE: the plotted values are Pearson's r (column "accuracy" of
# res_all_melt_nt), although the y-axis is labelled "Accuracy".
method_order = ["Order", "Family", "Genus", "Random"]
# Tick labels matching method_order; "Random" is the standard 10-fold CV.
method_labels = ["Order", "Family", "Genus", "10-fold"]

ntraits = len(nt)
ncols = 4
nrows = math.ceil(ntraits / ncols)
# squeeze=False keeps `axes` 2-D for any grid shape, so flatten() always works.
fig, axes = plt.subplots(nrows, ncols, figsize=(1.5*ncols, 1.5*nrows), squeeze=False)
flat_axes = axes.flatten()

for i, t in enumerate(nt):
    ax = flat_axes[i]
    trait_df = res_all_melt_nt[res_all_melt_nt["trait"] == t]
    sns.boxplot(data=trait_df,
                x="method", y="accuracy", hue="method", ax=ax, order=method_order,
                palette="Set2", fill=False, linewidth=1.3, showfliers=False)
    sns.stripplot(data=trait_df,
                  x="method", y="accuracy",
                  ax=ax, order=method_order, color="black", size=3, jitter=0.2)
    ax.axhline(0.5, color='gray', linestyle='--', linewidth=0.5)
    ax.set_ylim(-0.1, 1)
    ax.set_title(titles[t], fontsize=10)
    ax.set_xlabel("")
    ax.set_xticks(list(range(len(method_order))))
    # Label the x axis on the lowest panel of each column (no panel below it),
    # so a column that ends above an empty grid slot still gets labels.
    if i + ncols >= ntraits:
        ax.set_xticklabels(method_labels, rotation=90)
    else:
        ax.set_xticklabels([])
    ax.set_ylabel("Accuracy" if i % ncols == 0 else "")

# Hide every grid slot that has no trait panel (and never a real panel).
for ax in flat_axes[ntraits:]:
    ax.axis('off')

plt.tight_layout()
fig_dir = "../../../results/09_cross_validation_suppl"
os.makedirs(fig_dir, exist_ok=True)
plt.savefig(os.path.join(fig_dir, "figS7a.pdf"), format="pdf", dpi=300, facecolor="white", bbox_inches="tight", pad_inches=0.1)

In [None]:
# Per-split Pearson correlation between bac2feature estimates and reference
# values for every categorical trait in `ct` (values cast to float), for each
# hold-out scheme.
true_vals_path = "../../../data/ref_bac2feature/trait_bac2feature.tsv"

# Result column -> estimates path template for split `i`. Dict order fixes the
# column order of the table (Genus, Family, Order, Random).
result_path_templates = {
    "Genus": "../../../data/cross_validation_suppl/clade_out/genus/split_{i}/estimation_full.tsv",
    "Family": "../../../data/cross_validation_suppl/clade_out/family/split_{i}/estimation_full.tsv",
    "Order": "../../../data/cross_validation_suppl/clade_out/order/split_{i}/estimation_full.tsv",
    # Standard 10-fold cross-validation, labelled "Random".
    "Random": "../../../data/cross_validation/split_{i}/estimation_full.tsv",
}

split_tables = []
for i in range(0, 10):
    tmp = pd.DataFrame({"trait": ct})
    for method, path_template in result_path_templates.items():
        cmp = get_pred_and_true_df(pred_vals_path=path_template.format(i=i), true_vals_path=true_vals_path)
        correlations = []
        for t in ct:
            pred_vals, true_vals = remove_null_values(cmp=cmp, t=t, dtype='float')
            correlations.append(pred_vals.corr(true_vals, method="pearson"))
        tmp[method] = correlations
    tmp["split"] = i
    split_tables.append(tmp)

# Concatenate once at the end instead of growing the frame inside the loop.
res_all_ct = pd.concat(split_tables, axis=0)

res_all_melt_ct = res_all_ct.melt(id_vars=['trait', 'split'], var_name='method', value_name='accuracy')

In [None]:
# Figure S7b: one panel per categorical trait (`ct`), one box per hold-out scheme.
# NOTE: the plotted values are Pearson's r (column "accuracy" of
# res_all_melt_ct), although the y-axis is labelled "Accuracy".
method_order = ["Order", "Family", "Genus", "Random"]
# Tick labels matching method_order; "Random" is the standard 10-fold CV.
method_labels = ["Order", "Family", "Genus", "10-fold"]

ntraits = len(ct)
ncols = 4
nrows = math.ceil(ntraits / ncols)
# squeeze=False keeps `axes` 2-D for any grid shape, so flatten() always works.
fig, axes = plt.subplots(nrows, ncols, figsize=(1.5*ncols, 1.5*nrows), squeeze=False)
flat_axes = axes.flatten()

for i, t in enumerate(ct):
    ax = flat_axes[i]
    trait_df = res_all_melt_ct[res_all_melt_ct["trait"] == t]
    sns.boxplot(data=trait_df,
                x="method", y="accuracy", hue="method", ax=ax, order=method_order,
                palette="Set2", fill=False, linewidth=1.3, showfliers=False)
    sns.stripplot(data=trait_df,
                  x="method", y="accuracy", order=method_order,
                  ax=ax, color="black", size=3, jitter=0.2)
    ax.axhline(0.5, color='gray', linestyle='--', linewidth=0.5)
    ax.set_ylim(-0.1, 1)
    ax.set_title(titles[t], fontsize=10)
    ax.set_xlabel("")
    ax.set_xticks(list(range(len(method_order))))
    # Label the x axis on the lowest panel of each column (no panel below it),
    # so a column that ends above an empty grid slot still gets labels.
    if i + ncols >= ntraits:
        ax.set_xticklabels(method_labels, rotation=90)
    else:
        ax.set_xticklabels([])
    ax.set_ylabel("Accuracy" if i % ncols == 0 else "")

# Hide any grid slot that has no trait panel.
for ax in flat_axes[ntraits:]:
    ax.axis('off')

plt.tight_layout()
fig_dir = "../../../results/09_cross_validation_suppl"
os.makedirs(fig_dir, exist_ok=True)
plt.savefig(os.path.join(fig_dir, "figS7b.pdf"), format="pdf", dpi=300, facecolor="white", bbox_inches="tight", pad_inches=0.1)