# Visualization: compare_hsp_method.ipynb

## Environment settings
```sh
# Working directory: every relative path used in this notebook is resolved from here.
cd Bac2fFeature/scripts/09_cross_validation_suppl/093_phylogeny
# Output directories. The results directory is where the figure cells of this
# notebook write figS4a.pdf / figS4b.pdf, so it has to exist before plotting.
directories=(
    "../../../data/cross_validation_suppl"
    "../../../results/09_cross_validation_suppl"
)
# mkdir -p creates missing parents and is a no-op for directories that already exist.
mkdir -p "${directories[@]}"
```

In [None]:
import json
import os

import numpy as np
import seaborn as sns
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt

# Global figure styling shared by every plot in this notebook.
matplotlib.rcParams.update({
    'font.family':       'Arial',
    'font.sans-serif':   ["Arial", "DejaVu Sans", "Lucida Grande", "Verdana"],
    'figure.figsize':    [4, 3],
    'font.size':         10,
    "axes.labelcolor":   "#000000",
    "axes.linewidth":    1.0,
    "xtick.major.width": 1.0,
    "ytick.major.width": 1.0,
})

# Qualitative colormaps kept available for the plotting cells.
cmap1 = plt.cm.tab20
cmap2 = plt.cm.Set3

import math

In [None]:
def get_xy_vals(cmp, t):
    """Return the paired (estimated, true) values of trait ``t``.

    ``cmp`` must hold the columns ``<t>_e`` (estimated) and ``<t>_t`` (true).
    Rows in which either of the two columns is missing are dropped, so both
    returned Series share the same index.

    Returns:
        tuple: ``(estimated_vals, true_vals)`` as pandas Series.
    """
    est_col, true_col = t + '_e', t + '_t'
    both_present = cmp[true_col].notnull() & cmp[est_col].notnull()
    paired = cmp[both_present]
    return paired[est_col], paired[true_col]

def calc_cor(df, t):
    """Correlate the estimated and true values of trait ``t`` in ``df``.

    Only rows where both values are present are used (see ``get_xy_vals``).
    The coefficient is whatever ``Series.corr`` computes with its default
    method, i.e. Pearson's r.
    """
    predicted, observed = get_xy_vals(df, t)
    r = predicted.corr(observed)
    return r

def calc_macro_f1(df, t):
    """Macro-averaged F1 score of the estimated vs. true labels of trait ``t``.

    Only rows where both ``<t>_e`` and ``<t>_t`` are present are scored
    (see ``get_xy_vals``).  The per-class F1 is ``2*TP / (n_true + n_estimated)``
    and the macro average is the unweighted mean over every class that occurs
    in either the true or the estimated labels.  A class that is only ever
    estimated (never true) therefore contributes an F1 of 0 instead of raising.

    Note: the previous implementation returned the mean diagonal of the
    row-normalised confusion matrix, which is macro *recall*, not macro-F1.

    Returns:
        float: macro-F1 in [0, 1], or NaN when there is no scorable row.
    """
    estimated_vals, true_vals = get_xy_vals(df, t)
    if len(true_vals) == 0:
        return float("nan")

    # Confusion matrix: rows = true class, columns = estimated class.
    # Plain arrays are passed so that pairing is positional and cannot be
    # affected by the index of the input frame.
    cnt = pd.crosstab(true_vals.to_numpy(), estimated_vals.to_numpy())

    # Make the matrix square over the union of labels so the diagonal lines up
    # even when a label appears on one side only.
    category = cnt.index.union(cnt.columns)
    cnt = cnt.reindex(index=category, columns=category, fill_value=0)

    true_positive = np.diag(cnt.to_numpy()).astype(float)
    n_true = cnt.sum(axis=1).to_numpy()
    n_estimated = cnt.sum(axis=0).to_numpy()

    # Every class in the union occurs at least once, so the denominator is > 0.
    f1_per_class = 2.0 * true_positive / (n_true + n_estimated)
    macro_f1 = f1_per_class.mean()
    return macro_f1

In [None]:
# Panel titles: trait column name -> human-readable label.
titles = dict([
    # continuous traits
    ('cell_diameter', 'Cell diameter'),
    ('cell_length', 'Cell length'),
    ('doubling_h', 'Doubling time'),
    ('growth_tmp', 'Growth temp.'),
    ('optimum_tmp', 'Optimum temp.'),
    ('optimum_ph', 'Optimum pH'),
    ('genome_size', 'Genome size'),
    ('gc_content', 'GC content'),
    ('coding_genes', 'Coding genes'),
    ('rRNA16S_genes', 'rRNA16S genes'),
    ('tRNA_genes', 'tRNA genes'),
    # categorical traits
    ('gram_stain', 'Gram stain'),
    ('sporulation', 'Sporulation'),
    ('motility', 'Motility'),
    ('range_salinity', 'Halophile'),
    ('facultative_respiration', 'Facultative'),
    ('anaerobic_respiration', 'Anaerobe'),
    ('aerobic_respiration', 'Aerobe'),
    ('mesophilic_range_tmp', 'Mesophile'),
    ('thermophilic_range_tmp', 'Thermophile'),
    ('psychrophilic_range_tmp', 'Psychrophile'),
    ('bacillus_cell_shape', 'Bacillus'),
    ('coccus_cell_shape', 'Coccus'),
    ('filament_cell_shape', 'Filament'),
    ('coccobacillus_cell_shape', 'Coccobacillus'),
    ('vibrio_cell_shape', 'Vibrio'),
    ('spiral_cell_shape', 'Spiral'),
])

## Continuous traits

In [None]:
# Hidden-state-prediction result table for each method compared on continuous traits.
dic_method_result = {"WSCP": "hsp_results/hsp_result_wscp.tsv",
                     "PIC":  "hsp_results/hsp_result_pic.tsv",
                     "SA":   "hsp_results/hsp_result_sa.tsv"}

# Continuous (numeric) traits, one panel each.
nt = ['cell_diameter', 'cell_length', 'doubling_h', 'genome_size', 'gc_content', 'coding_genes', 'optimum_tmp', 'optimum_ph', 'growth_tmp', 'rRNA16S_genes', 'tRNA_genes']

fig, axes = plt.subplots(3, 4, figsize=(8, 6))
panels = axes.flatten()

trait_path = "../../../data/ref_bac2feature/trait_bac2feature.tsv"
true_vals = pd.read_csv(trait_path, sep="\t")

for k, hsp_result_path in dic_method_result.items():
    # The estimates and the merge depend only on the method, so they are built
    # once per method rather than once per trait.
    estimated_vals = pd.read_csv(hsp_result_path, sep="\t")
    cmp = pd.merge(estimated_vals, true_vals, on="species_tax_id", how="inner", suffixes=["_e", "_t"])

    for i, t in enumerate(nt):
        score = cmp.groupby("threshold").apply(lambda df: calc_cor(df, t))

        # Use the groupby index as x: it is guaranteed to be paired with the
        # scores, whereas cmp["threshold"].unique() is in order of appearance
        # and only matches when the input file happens to be sorted.
        panels[i].plot(score.index, score.values, marker='+', label=k)

# Axis decoration is independent of the method, so it is applied once per panel.
for ax, t in zip(panels, nt):
    ax.set_ylim(0, 1)
    ax.set_xlabel("Phylogenetic distance")
    ax.set_ylabel("Pearson's r")
    ax.set_title(titles[t])

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

# Lines carry their method name as label; the legend itself is left off as in
# the original figure. Call panels[0].legend() to display it.

# Hide every panel of the grid that has no trait assigned to it.
for ax in panels[len(nt):]:
    ax.set_visible(False)

plt.tight_layout()

fig_path = "../../../results/09_cross_validation_suppl/figS4a.pdf"
# savefig does not create missing directories.
os.makedirs(os.path.dirname(fig_path), exist_ok=True)
plt.savefig(fig_path, format="pdf", dpi=300, facecolor="white", bbox_inches="tight", pad_inches=0.1)

## Categorical traits

In [None]:
# Hidden-state-prediction result table for each method compared on categorical traits.
dic_method_result = {
                     "EMP": "hsp_results//hsp_result_emp.tsv",
                     "Mk-SYM":   "hsp_results/hsp_result_mk_SYM.tsv",
                     "Mk-ARD":  "hsp_results/hsp_result_mk_ARD.tsv",
                     "MP-no-edge":  "hsp_results/hsp_result_mp_no_edge.tsv",
                     "MP-edge":  "hsp_results/hsp_result_mp_inversed_edge.tsv"
                     }

# Categorical traits, one panel each.
ct = ["gram_stain", "sporulation", "motility", "range_salinity", "facultative_respiration", "anaerobic_respiration", "aerobic_respiration", "mesophilic_range_tmp", "thermophilic_range_tmp", "psychrophilic_range_tmp", "bacillus_cell_shape", "coccus_cell_shape", "filament_cell_shape", "coccobacillus_cell_shape", "vibrio_cell_shape", "spiral_cell_shape"]

ncols = 4
nrows = math.ceil(len(ct) / ncols)
fig, axes = plt.subplots(nrows, ncols, figsize=(2*ncols, 2*nrows))
panels = axes.flatten()

trait_path = "../../../data/ref_bac2feature/trait_bac2feature.tsv"
true_vals = pd.read_csv(trait_path, sep="\t")

for k, hsp_result_path in dic_method_result.items():
    # The estimates and the merge depend only on the method, so they are built
    # once per method rather than once per trait.
    estimated_vals = pd.read_csv(hsp_result_path, sep="\t")
    cmp = pd.merge(estimated_vals, true_vals, on="species_tax_id", how="inner", suffixes=["_e", "_t"])

    for i, t in enumerate(ct):
        # NOTE(review): the axis is labelled MCC but the score is Pearson's r from
        # calc_cor; the two coincide only for two-valued numeric columns.
        # Presumably these traits are 0/1 encoded - TODO confirm.
        score = cmp.groupby("threshold").apply(lambda df: calc_cor(df, t))

        # Use the groupby index as x: it is guaranteed to be paired with the
        # scores, whereas cmp["threshold"].unique() is in order of appearance
        # and only matches when the input file happens to be sorted.
        panels[i].plot(score.index, score.values, marker='+', label=k)

# Axis decoration is independent of the method, so it is applied once per panel.
for ax, t in zip(panels, ct):
    ax.set_ylim(-0.3, 1)
    ax.set_xlabel("Phylogenetic distance")
    ax.set_ylabel("MCC")
    ax.set_title(titles[t])

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

# Lines carry their method name as label; the legend itself is left off as in
# the original figure. Call panels[0].legend() to display it.

# Hide only the panels without a trait. With 16 traits on a 4x4 grid there are
# none; unconditionally hiding the last panel removed the 'Spiral' plot.
for ax in panels[len(ct):]:
    ax.set_visible(False)

plt.tight_layout()

fig_path = "../../../results/09_cross_validation_suppl/figS4b.pdf"
# savefig does not create missing directories.
os.makedirs(os.path.dirname(fig_path), exist_ok=True)
plt.savefig(fig_path, format="pdf", dpi=300, facecolor="white", bbox_inches="tight", pad_inches=0.1)