# Identify samples which contain certain taxa

## Platform dependent part
- Resolve the platform setup.
- The difference from local imports should be resolved by configuring the Blue Cloud VRE properly; Colab will still be an issue.

In [1]:
import sys
import os
import logging
from IPython import get_ipython
logger = logging.getLogger(name="Taxonomic finder")

# On Google Colab the demo repository and the momics package have to be
# fetched before the momics imports below can succeed.
if 'google.colab' in str(get_ipython()):
    print('Setting Google colab, you will need a ngrok account to make the dashboard display over the tunnel. \
    https://ngrok.com/')
    # clone the momics-demos repository to use it to load data
    # os.system() reports a failed command through its return value instead of
    # raising, so the exit status is checked explicitly.
    clone_status = os.system('git clone https://github.com/palec87/momics-demos.git')
    if clone_status == 0:
        logger.info("Repository cloned")
    else:
        # Logged as a warning: reconfig_logger() is only called further down,
        # so an INFO record emitted here may not be shown.
        logger.warning(
            f"An error occurred while cloning the repository: exit status {clone_status}"
        )

    sys.path.insert(0,'/content/momics-demos')

    # this step takes time because of many dependencies
    install_status = os.system('pip install marine-omics')
    if install_status != 0:
        logger.warning(
            f"An error occurred while installing marine-omics: exit status {install_status}"
        )

from momics.utils import (
    memory_load, reconfig_logger,
    init_setup, get_notebook_environment,
)

# Set up logging
reconfig_logger()

# Determine the notebook environment
env = get_notebook_environment()

init_setup()
logger.info(f"Environment: {env}")

INFO | root | Logging.basicConfig completed successfully
INFO | Taxonomic finder | Environment: vscode
INFO | Taxonomic finder | Environment: vscode


## Imports

In [6]:
import warnings
import holoviews as hv
from skbio.stats.ordination import pcoa

warnings.filterwarnings('ignore')

import pandas as pd
import panel as pn

# from mgo.udal import UDAL

# All low level functions are imported from the momics package
import momics.plotting as pl
from momics.panel_utils import (
    diversity_select_widgets, create_indicators_diversity,
    serve_app, close_server,
)

from momics.diversity import (
    beta_diversity_parametrized,
)
from momics.utils import load_and_clean

### Future momics methods

In [None]:
from typing import Union, Tuple, Dict, List
from bokeh.models import CategoricalColorMapper, ContinuousColorMapper
from bokeh.palettes import Category20, viridis
# Hex colour, presumably meant as the plot background; it is not referenced by
# the code visible in this notebook.
PLOT_FACE_COLOR = "#e6e6e6"
# Marker size passed as the "size" plot option in hvplot_plot_pcoa_black.
MARKER_SIZE = 16


def find_taxa_in_table(
        table: pd.DataFrame,
        tax_level: str,
        search_term: Union[str, int],
        ncbi_tax_id: bool=False,
        exact_match:bool=False,
    ) -> pd.DataFrame:
    """
    Find taxa in the given table at the specified taxonomic level matching the search term.

    The returned frame is always a row subset of ``table``: the index and the
    columns are left untouched and the rows keep their original order.

    args:
        table (pd.DataFrame): DataFrame containing taxonomic data.
        tax_level (str): Taxonomic level (column name) to search, or 'all' to
            search every rank listed in TAXONOMY_RANKS. Ignored when
            ``ncbi_tax_id`` is True.
        search_term (str|int): Term to search for. It is converted to ``str``
            before matching, so an integer is accepted for name searches too.
        ncbi_tax_id (bool): If True, search by NCBI taxonomic ID. The ID is
            read from the 'ncbi_tax_id' index level if there is one, otherwise
            from the 'ncbi_tax_id' column, and compared as a string.
        exact_match (bool): If True, perform a case-insensitive exact match;
            otherwise a case-insensitive literal substring match.

    returns:
        pd.DataFrame: DataFrame containing matching taxa. With ``tax_level='all'``
        a row matching at several ranks is returned once.

    raises:
        ValueError: If ``ncbi_tax_id`` is True and the table has neither an
            'ncbi_tax_id' column nor an 'ncbi_tax_id' index level.
    """
    if ncbi_tax_id:
        index_names = list(getattr(table.index, "names", []))
        if 'ncbi_tax_id' in index_names:
            tax_ids = table.index.get_level_values('ncbi_tax_id')
        elif 'ncbi_tax_id' in table.columns:
            tax_ids = table['ncbi_tax_id']
        else:
            raise ValueError("The table does not contain 'ncbi_tax_id' column or index level.")

        # A positional boolean mask keeps the original index intact, whether the
        # IDs live in a column or in an index level (no reset/set_index needed).
        id_mask = (pd.Series(tax_ids).astype(str) == str(search_term)).to_numpy()
        return table[id_mask]

    term = str(search_term)
    ranks = TAXONOMY_RANKS if tax_level == 'all' else [tax_level]

    # OR the per-rank masks together instead of concatenating per-rank results,
    # so that a row matching at several ranks is not duplicated.
    combined = None
    for rank in ranks:
        names = table[rank]
        if exact_match:
            hits = names.str.lower().fillna('') == term.lower()
        else:
            # regex=False: the term is a plain substring, so characters such as
            # "(" or "." typed by a user are matched literally.
            hits = names.str.contains(term, case=False, na=False, regex=False)
        combined = hits if combined is None else (combined | hits)

    if combined is None:
        # no rank to search at all
        return table.iloc[0:0]
    return table[combined]


def beta_plot_abund_taxa(
    table: pd.DataFrame,
    metadata: pd.DataFrame,
    found_taxa: pd.DataFrame,
    taxon: str = "ncbi_tax_id",
) -> Tuple[hv.element.Scatter, Tuple[float, float]]:
    """
    Creates a beta diversity PCoA plot coloured by the abundance of the found taxa.

    Bray-Curtis beta diversity is computed from ``table``, ordinated with PCoA,
    and every sample is coloured by the summed abundance of the rows of
    ``found_taxa`` that belong to it (0 for samples without any found taxa).

    Args:
        table (pd.DataFrame): DataFrame containing species abundances.
        metadata (pd.DataFrame): A DataFrame containing metadata, indexed by the
            same sample IDs as the PCoA result.
        found_taxa (pd.DataFrame): Rows of the taxonomy table selected by a search.
            Must have an 'abundance' column and a 'source material ID' column or
            index level.
        taxon (str, optional): The taxon level for beta diversity calculation. Defaults to "ncbi_tax_id".

    Returns:
        Tuple[hv.element.Scatter, Tuple[float, float]]: A tuple containing the beta diversity PCoA plot and the explained variance for PC1 and PC2.

    Raises:
        ValueError: If the sample IDs of the PCoA result and the metadata index differ.
    """
    beta = beta_diversity_parametrized(
        table, taxon=taxon, metric="braycurtis"
    )
    pcoa_result = pcoa(beta, method="eigh")
    # .iloc: positional access; plain [0] on a label-indexed Series relies on a
    # deprecated positional fallback.
    explained_variance = (
        pcoa_result.proportion_explained.iloc[0],
        pcoa_result.proportion_explained.iloc[1],
    )
    if not set(pcoa_result.samples.index) == set(metadata.index):
        raise ValueError("Metadata index name does not match PCoA result.")

    pcoa_df = pd.merge(
        pcoa_result.samples,
        metadata,
        left_index=True,
        right_index=True,
        how="inner",
    )

    abundance_sum = found_taxa.groupby('source material ID')['abundance'].sum()
    # Align the per-sample sums to the PCoA samples in one step: samples without
    # any found taxa get 0.0, and samples unknown to the PCoA are ignored rather
    # than appended as new rows.
    pcoa_df['found_abundance'] = (
        abundance_sum.reindex(pcoa_df.index, fill_value=0.0).astype(float).to_numpy()
    )

    return (
        hvplot_plot_pcoa_black(
            pcoa_df, color_by='found_abundance', explained_variance=explained_variance
        ),
        explained_variance,
    )


def _pcoa_palette(values: pd.Series):
    """Return the colour palette used for the values that colour the PCoA points."""
    if values.dtype == "object":
        # Categorical data: one colour per distinct non-missing category.
        n_categories = values.dropna().nunique()
        if 2 < n_categories <= 20:
            return Category20[n_categories]
        return viridis(n_categories)

    # Numerical data: continuous palette.
    from bokeh.palettes import Turbo256

    return Turbo256


def hvplot_plot_pcoa_black(
    pcoa_df: pd.DataFrame,
    color_by: Union[str, None] = None,
    explained_variance: Union[Tuple[float, float], None] = None,
) -> hv.element.Scatter:
    """
    Plots a PCoA plot with optional coloring using hvplot.

    Points are drawn in black when ``color_by`` is None or when the ``color_by``
    column holds no valid (non-missing) value.

    NOTE(review): relies on the ``.hvplot`` DataFrame accessor; the import that
    registers it is not visible in this notebook - confirm it is loaded.

    Args:
        pcoa_df (pd.DataFrame): A DataFrame containing PCoA results in the
            columns "PC1" and "PC2".
        color_by (str, optional): The column name to color the points by. Defaults to None.
        explained_variance (Tuple[float, float], optional): Fractions of variance
            explained by PC1 and PC2; when given they are shown as percentages in
            the axis labels. Defaults to None.

    Returns:
        hv.element.Scatter: The PCoA plot.

    Raises:
        ValueError: If the "PC1" or "PC2" column is missing.
    """
    # Validate before plotting (and with a real exception: asserts are stripped
    # when Python runs with -O).
    for axis in ("PC1", "PC2"):
        if axis not in pcoa_df.columns:
            raise ValueError(f"Missing '{axis}' column in PCoA DataFrame")

    index_name = pcoa_df.index.name if pcoa_df.index.name else "sample"
    # reset_index() names the column of an unnamed index "index"; label it first
    # so that the hover column below really exists.
    if pcoa_df.index.nlevels == 1 and index_name not in pcoa_df.columns:
        pcoa_df = pcoa_df.rename_axis(index_name)
    pcoa_df = pcoa_df.reset_index()  # Ensure index is a column for hvplot
    hover_cols = [index_name, "PC1", "PC2"]

    color = "black"
    color_palette = None
    if color_by is None:
        title = "PCoA (no coloring applied)"
    else:
        values = pcoa_df[color_by]
        n_valid = values.count()
        valid_perc = n_valid / len(values) * 100 if len(values) else 0.0
        title = f"PCoA colored by {color_by}, valid values: ({valid_perc:.2f}%)"
        if n_valid > 0:
            color = color_by  # Use the factor column for coloring
            color_palette = _pcoa_palette(values)

    fig = pcoa_df.hvplot.scatter(
        x="PC1",
        y="PC2",
        color=color,
        hover_cols=hover_cols,
    )

    if explained_variance:
        var_perc = explained_variance[0] * 100, explained_variance[1] * 100
        xlabel = f"PC1 ({var_perc[0]:.2f}%)"
        ylabel = f"PC2 ({var_perc[1]:.2f}%)"
    else:
        xlabel, ylabel = "PC1", "PC2"

    opts = {
        "xlabel": xlabel,
        "ylabel": ylabel,
        "title": title,
        "size": MARKER_SIZE,
        "fill_alpha": 0.5,
        "show_legend": False,
        "backend_opts": {"plot.toolbar.autohide": True},
    }

    if color_palette is not None:
        opts["cmap"] = color_palette

    return fig.opts(**opts)

In [5]:
## other part of selectors
def tax_finder_selector1():
    """
    Create the input widgets of the taxonomy finder.

    Returns:
        tuple: ``(select_table_tax, tax_level, search_term, checkbox_exact_match)``:
            the table selector ("ssu"/"lsu"), the taxonomic level selector, the
            free-text search field and the exact-match checkbox.
    """
    select_table_tax = pn.widgets.Select(
        name="Taxonomic table",
        value="ssu",
        options=["ssu", "lsu"],
        description="Select a table for taxonomic search",
    )

    tax_level = pn.widgets.Select(
        name="Taxonomic level",
        # The initial value has to be one of the options below ("ssu" is a table
        # name, not a taxonomic level).
        value="all",
        options=[
            "all",
            "ncbi_tax_id",
            "superkingdom",
            "kingdom",
            "phylum",
            "class",
            "order",
            "family",
            "genus",
            "species",
        ],
        description="Select a taxonomic search level",
    )

    search_term = pn.widgets.TextInput(
        name="Search term",
        value="",
        description="Enter a search term (string or NCBI tax ID)",
    )

    checkbox_exact_match = pn.widgets.Checkbox(
        name="Exact match of the search term",
        value=False,
    )

    # checkbox_ncbi_tax_id = pn.widgets.Checkbox(
    #     name="Search by NCBI taxonomic ID",
    #     value=False,
    # )
    return (
        select_table_tax,
        tax_level,
        search_term,
        checkbox_exact_match,
        # checkbox_ncbi_tax_id,
    )

## User settings

In [3]:
# When True, the data table names and the metadata column lists are logged
# after loading.
DEBUG = True  # enable stdout logging

## Loading

In [4]:
# Locate the repository root: the cloned repository on Colab, otherwise the
# parent of the current working directory. The assets live in its "assets"
# sub-folder.
root_folder = os.path.abspath(
    '/content/momics-demos' if 'google.colab' in str(get_ipython()) else '../'
)
assets_folder = os.path.join(root_folder, 'assets')

In [5]:
def get_valid_samples():
    """Read the valid-samples CSV (data/shipment_b1b2_181.csv under root_folder)."""
    csv_path = os.path.join(root_folder, 'data/shipment_b1b2_181.csv')
    return pd.read_csv(csv_path)

valid_samples = get_valid_samples()

In [6]:
# High level function from the momics.utils module
# Returns the sample metadata and a dict of data tables keyed by table name
# (the "lsu" and "ssu" entries are the taxonomy tables used below).
full_metadata, mgf_parquet_dfs = load_and_clean(valid_samples=valid_samples)

In [7]:
# Sorted metadata column names, split by dtype: categorical (object/boolean)
# first, then numerical (int64/float64).
categorical_columns, numerical_columns = (
    sorted(full_metadata.select_dtypes(include=dtypes).columns)
    for dtypes in (['object', "boolean"], ['int64', 'float64'])
)

if DEBUG:
    logger.info(f"Data table names are:\n{mgf_parquet_dfs.keys()}")
    logger.info(f"Categorical metadata columns are:\n{categorical_columns}")
    logger.info(f"Numerical metadata columns are:\n{numerical_columns}")

INFO | Taxonomic finder | Data table names are:
dict_keys(['go', 'go_slim', 'ips', 'ko', 'pfam', 'lsu', 'ssu'])
INFO | Taxonomic finder | Categorical metadata columns are:
['ammonium method', 'chlorophyll method', 'conductivity method', 'country', 'density method', 'dissolved oxygen method', 'environment (biome)', 'environment (feature)', 'environment (material)', 'environmental package', 'investigation type', 'month name', 'nitrate method', 'nitrite method', 'observatory ID', 'observatory local location', 'observatory location ocean or sea', 'observatory regional location', 'organism count', 'organism count method', 'organization', 'organization country', 'pH method', 'phaeopigments method', 'phosphate method', 'pigments (ug/l)', 'pigments method', 'pressure method', 'project name', 'replicate info', 'replicate number', 'sample collection device or method', 'sea subsurface salinity method', 'sea subsurface temperature method', 'sea surface salinity method', 'sea surface temperature me

In [8]:
# Keep private copies of the two taxonomy tables only.
tables = {name: mgf_parquet_dfs[name].copy() for name in ("lsu", "ssu")}

TAXONOMY = pd.DataFrame()
# Rank columns searched when the taxonomic level is 'all'.
TAXONOMY_RANKS = [
    'superkingdom',
    'kingdom',
    'phylum',
    'class',
    'order',
    'family',
    'genus',
    'species',
]

## Taxonomy finder

In [10]:
# Example search parameters for find_taxa_in_table.
# NOTE: `tax_level` and `search_term` are rebound to widgets in the
# "APP setup" section further down.
table = tables['ssu']
tax_level = 'all'
# NCBI taxonomic ID to look up (286 is Pseudomonas in the output below)
search_term = 286
# search_term = 'Pseudomonas'
exact_match = True
# search by NCBI taxonomic ID instead of by taxon name
ncbi_tax_id = True

In [11]:
# Run the search with the parameters defined in the previous cell.
out = find_taxa_in_table(
    table=table,
    tax_level=tax_level,
    search_term=search_term,
    ncbi_tax_id=ncbi_tax_id,
    exact_match=exact_match,
)
out

Unnamed: 0_level_0,Unnamed: 1_level_0,abundance,superkingdom,kingdom,phylum,class,order,family,genus,species
source material ID,ncbi_tax_id,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
EMOBON_BPNS_So_6,286,1.0,Bacteria,,Proteobacteria,Gammaproteobacteria,Pseudomonadales,Pseudomonadaceae,Pseudomonas,
EMOBON_BPNS_So_13,286,3.0,Bacteria,,Proteobacteria,Gammaproteobacteria,Pseudomonadales,Pseudomonadaceae,Pseudomonas,
EMOBON_NRMCB_So_1,286,2.0,Bacteria,,Proteobacteria,Gammaproteobacteria,Pseudomonadales,Pseudomonadaceae,Pseudomonas,
EMOBON_NRMCB_So_2,286,1.0,Bacteria,,Proteobacteria,Gammaproteobacteria,Pseudomonadales,Pseudomonadaceae,Pseudomonas,
EMOBON_ROSKOGO_So_2,286,2.0,Bacteria,,Proteobacteria,Gammaproteobacteria,Pseudomonadales,Pseudomonadaceae,Pseudomonas,
...,...,...,...,...,...,...,...,...,...,...
EMOBON_ROSKOGO_Wa_34,286,2.0,Bacteria,,Proteobacteria,Gammaproteobacteria,Pseudomonadales,Pseudomonadaceae,Pseudomonas,
EMOBON_VB_Wa_96,286,9.0,Bacteria,,Proteobacteria,Gammaproteobacteria,Pseudomonadales,Pseudomonadaceae,Pseudomonas,
EMOBON_VB_Wa_97,286,7.0,Bacteria,,Proteobacteria,Gammaproteobacteria,Pseudomonadales,Pseudomonadaceae,Pseudomonas,
EMOBON_VB_Wa_140,286,8.0,Bacteria,,Proteobacteria,Gammaproteobacteria,Pseudomonadales,Pseudomonadaceae,Pseudomonas,


### Visualize

In [13]:
# PCoA of the phylum-level beta diversity, coloured by the abundance of the
# taxa found above.
foo, explained_var = beta_plot_abund_taxa(table, full_metadata, out, taxon='phylum')

foo.opts(width=1000, height=600)

## APP setup

In [7]:
# Load the panel / holoviews extensions used by the dashboard.
pn.extension("tabulator")
hv.extension("bokeh", "plotly")
if 'google.colab' in str(get_ipython()):
    pn.extension(comms='colab')
ACCENT = "teal"

# CSS properties (shadow, rounded corners, padding); not used in the code
# visible in this notebook.
styles = {
    "box-shadow": "rgba(50, 50, 93, 0.25) 0px 6px 12px -2px, rgba(0, 0, 0, 0.3) 0px 3px 7px -3px",
    "border-radius": "4px",
    "padding": "10px",
}

# Input widgets of the taxonomy finder.
# NOTE: this rebinds `tax_level` and `search_term`, which held plain values in
# the "Taxonomy finder" section above.
(select_table_tax,
 tax_level,
 search_term,
 checkbox_exact_match,
 # checkbox_ncbi_tax_id,
) = tax_finder_selector1()

# Plotting backend selector; 'hvplot' is selected initially.
backend = pn.widgets.RadioBoxGroup(
    name='Backend',
    options=['matplotlib', 'hvplot'],
    inline=True,
)
backend.value = 'hvplot'

# Memory usage indicators, updated by update_used_gb below.
progress_bar, indicator_usage = create_indicators_diversity()

def update_used_gb(event):
    """Refresh the memory indicators; a falsy event is ignored."""
    if event:
        used, total = memory_load()
        # progress bar shows the used share in whole percent
        progress_bar.value = int(used / total * 100)
        indicator_usage.value = used

In [None]:
def _tax_beta_plot(table_name, tax_level, search_term, exact_match):
    """
    Search the selected taxonomy table and plot the PCoA coloured by the hits.

    beta_plot_abund_taxa cannot be bound to the widgets directly: it needs the
    result of the search (``found_taxa``) and accepts neither a table name nor
    a backend argument.

    Args:
        table_name (str): Key of the taxonomy table in ``tables``.
        tax_level (str): Taxonomic level to search; 'ncbi_tax_id' switches to a
            search by NCBI taxonomic ID.
        search_term (str): Term to search for.
        exact_match (bool): Whether the term has to match exactly.

    Returns:
        The PCoA plot returned by beta_plot_abund_taxa.
    """
    table = tables[table_name]
    found_taxa = find_taxa_in_table(
        table,
        tax_level,
        search_term,
        ncbi_tax_id=(tax_level == "ncbi_tax_id"),
        exact_match=exact_match,
    )
    plot, _ = beta_plot_abund_taxa(
        table=table,
        metadata=full_metadata,
        found_taxa=found_taxa,
    )
    return plot


# Bind the widgets themselves (not their current values) so that the plot is
# recomputed whenever a selection changes.
tax_plot_beta = pn.bind(
    _tax_beta_plot,
    table_name=select_table_tax,
    tax_level=tax_level,
    search_term=search_term,
    exact_match=checkbox_exact_match,
)