# Visualize taxonomy and alpha/beta diversities

## Platform dependent part
- Resolve platform setup
- The difference from local imports should be resolved by configuring the Blue Cloud VRE properly; Colab will remain an issue.

In [None]:
import sys
import os
import logging
from IPython import get_ipython
logger = logging.getLogger(name="Diversity analysis app")

if 'google.colab' in str(get_ipython()):
    print('Setting Google colab, you will need a ngrok account to make the dashboard display over the tunnel. \
    https://ngrok.com/')
    # Clone the momics-demos repository so its data files can be loaded.
    # os.system() does not raise when the command fails -- it returns the
    # command's status -- so check that value instead of catching OSError.
    clone_status = os.system('git clone https://github.com/palec87/momics-demos.git')
    if clone_status == 0:
        logger.info("Repository cloned")
    else:
        logger.warning(f"An error occurred while cloning the repository: returned status {clone_status}")

    sys.path.insert(0, '/content/momics-demos')

    # This step takes time because of many dependencies.
    # Use the running interpreter's pip so the package is installed into the
    # environment this notebook kernel actually imports from.
    pip_status = os.system(f'"{sys.executable}" -m pip install marine-omics')
    if pip_status != 0:
        logger.warning(f"Installing marine-omics failed: returned status {pip_status}")

from momics.utils import (
    memory_load, reconfig_logger,
    init_setup, get_notebook_environment,
)

# Set up logging
reconfig_logger()

# Determine the notebook environment
env = get_notebook_environment()

init_setup()
logger.info(f"Environment: {env}")

## Imports

In [None]:
import warnings
import holoviews as hv
from skbio.stats.ordination import pcoa

from functools import partial
warnings.filterwarnings('ignore')

import pandas as pd
import panel as pn

from mgo.udal import UDAL

# All low level functions are imported from the momics package
from momics.loader import load_parquets_udal
from momics.metadata import get_metadata_udal, enhance_metadata
import momics.plotting as pl
from momics.panel_utils import diversity_select_widgets

from momics.diversity import (
    beta_diversity_parametrized,
)
from momics.utils import load_and_clean, taxonomy_common_preprocess01

## User settings

In [None]:
DEBUG = True  # when True, log the data table names and metadata column lists after loading

## Loading

In [None]:
# parquet files
if 'google.colab' in str(get_ipython()):
    root_folder = os.path.abspath(os.path.join('/content/momics-demos'))
else:
    root_folder = os.path.abspath(os.path.join('../'))

assets_folder = os.path.join(root_folder, 'assets')

In [None]:
def get_valid_samples():
    """Read the table of valid samples shipped with the repository.

    Returns:
        pandas.DataFrame parsed from ``<root_folder>/data/shipment_b1b2_181.csv``.
    """
    csv_path = os.path.join(root_folder, 'data/shipment_b1b2_181.csv')
    return pd.read_csv(csv_path)


valid_samples = get_valid_samples()

In [None]:
# High level function from the momics.utils module
full_metadata, mgf_parquet_dfs = load_and_clean(valid_samples=valid_samples)

In [None]:
# select categorical columns from metadata
categorical_columns = sorted(full_metadata.select_dtypes(include=['object', "boolean"]).columns)

# select numerical columns from metadata
numerical_columns = sorted(full_metadata.select_dtypes(include=['int64', 'float64']).columns)

if DEBUG:
    logger.info(f"Data table names are:\n{mgf_parquet_dfs.keys()}")
    logger.info(f"Categorical metadata columns are:\n{categorical_columns}")
    logger.info(f"Numerical metadata columns are:\n{numerical_columns}")

In [None]:
# df = mgf_parquet_dfs['ssu'].copy()
# if DEBUG:
#     logger.info(f'Number of unique ref_codes: {df.ref_code.nunique()}')

In [None]:
# Build the shared selection widgets from the metadata column lists.
_diversity_widgets = diversity_select_widgets(categorical_columns, numerical_columns)
(
    select_table,
    select_cat_factor,
    select_table_beta,
    select_taxon,
    select_beta_factor,
    beta_norm,
) = _diversity_widgets

In [None]:
# Working copies of the LSU and SSU tables used for beta diversity.
tables = {name: mgf_parquet_dfs[name].copy() for name in ("lsu", "ssu")}

# Result of taxonomy pre-processing; starts empty and is replaced by process_taxonomy().
TAXONOMY = pd.DataFrame()
TAXONOMY_RANKS = [
    'superkingdom', 'kingdom', 'phylum', 'class',
    'order', 'family', 'genus', 'species',
]

# Replace the beta-table selector returned by diversity_select_widgets so its
# options are exactly the keys of `tables`, defaulting to 'ssu'.
select_table_beta = pn.widgets.Select(
    name='Select table for beta diversity',
    options=list(tables),
    value='ssu',
)

## Alpha diversity

In [None]:
# Load the Panel Tabulator extension and the HoloViews bokeh/plotly backends;
# on Colab, additionally switch Panel to Colab comms.
_running_in_colab = 'google.colab' in str(get_ipython())
pn.extension("tabulator")
hv.extension("bokeh", "plotly")
if _running_in_colab:
    pn.extension(comms='colab')

In [None]:
# Radio buttons for the alpha-diversity plot: sort order and plotting backend.
sort_alpha = pn.widgets.RadioBoxGroup(
    name='Sort by', options=['factor', 'values'], inline=True, value='factor',
)
backend = pn.widgets.RadioBoxGroup(
    name='Backend', options=['matplotlib', 'hvplot'], inline=True,
)

_alpha_selectors = pn.Row(select_table, select_cat_factor)
pn.Column(_alpha_selectors, sort_alpha, backend)

In [None]:
# Alpha-diversity plot for the selected table and categorical factor.
# Widget values are read once, when this cell runs; re-run it after changing them.
_alpha_kwargs = dict(
    tables_dict=mgf_parquet_dfs,
    table_name=select_table.value,
    factor=select_cat_factor.value,
    metadata=full_metadata,
    order=sort_alpha.value,
    backend=backend.value,
)
pl.alpha_plot(**_alpha_kwargs)

In [None]:
# Averaged alpha-diversity plot; the returned object wraps the plot in `.object`.
_selected_table = select_table.value
avplot = pl.av_alpha_plot(
    tables_dict=mgf_parquet_dfs,
    table_name=_selected_table,
    factor=select_cat_factor.value,
    metadata=full_metadata,
    order=sort_alpha.value,
    backend=backend.value,
)
avplot.object.opts(
    title=f'Alpha diversity plot for {_selected_table} table', width=1000, height=600,
)

## Beta diversity

In [None]:
# Controls for the taxonomy pre-processing used by the beta-diversity plots.

# Strictly map taxa to the selected taxonomic level (slow).
mapping = pn.widgets.Checkbox(
    name="strict mapping to selected taxonomic level (takes time)", value=True,
)

# Percentage threshold below which taxa are dropped for low prevalence.
low_prevalence_cutoff = pn.widgets.FloatInput(
    name='Low prevalence cutoff [%]',
    value=10,
    step=1,
    start=0,
    end=100,
    description="Percentage of samples in which the taxon must be present not to be removed.",
)

button_process_taxonomy = pn.widgets.Button(
    name="Process taxonomy",
    button_type="primary",
    description="This will process the taxonomy and update the plots.",
    width=200,
)

# value=-1 leaves the bar in its indeterminate state; process_taxonomy() sets it to 100.
progress1 = pn.indicators.Progress(
    name='Pre-processing progress',
    value=-1,
    active=True,
    width=200,
)

## Pre-process taxonomy

In [None]:
def process_taxonomy(table, high_taxon, mapping, prevalence_cutoff_value):
    """Pre-process one taxonomy table and store the result in the global TAXONOMY.

    Args:
        table: key into ``tables`` (e.g. 'lsu' or 'ssu').
        high_taxon: taxonomic level chosen in the taxon selector.
        mapping: whether to map strictly to the selected taxonomic level.
        prevalence_cutoff_value: low-prevalence cutoff, in percent.
    """
    global TAXONOMY
    # Clear first: if pre-processing raises, TAXONOMY is left empty rather than
    # stale, and the beta-diversity cell falls back to pl.beta_plot_pc.
    TAXONOMY = pd.DataFrame()
    TAXONOMY = taxonomy_common_preprocess01(
        tables[table], high_taxon, mapping, prevalence_cutoff_value, TAXONOMY_RANKS,
    )
    progress1.value = 100


def _on_process_taxonomy_click(event):
    """Button callback: run process_taxonomy with the current widget values."""
    process_taxonomy(
        select_table_beta.value,
        select_taxon.value,
        mapping.value,
        low_prevalence_cutoff.value,
    )


button_process_taxonomy.on_click(_on_process_taxonomy_click)

In [None]:
# Beta-diversity control panel. The "Process taxonomy" button is not shown here;
# pre-processing is run explicitly by the next cell.
_beta_selectors = pn.Row(select_table_beta, select_taxon, select_beta_factor)
pn.Column(
    _beta_selectors,
    beta_norm,
    pn.layout.Divider(),
    mapping,
    low_prevalence_cutoff,
    progress1,
)

In [None]:
# Run the taxonomy pre-processing once with the current widget values.
process_taxonomy(
    table=select_table_beta.value,
    high_taxon=select_taxon.value,
    mapping=mapping.value,
    prevalence_cutoff_value=low_prevalence_cutoff.value,
)

In [None]:
# Inspect the (rows, columns) size of the pre-processed taxonomy table.
TAXONOMY.shape

In [None]:
# Beta-diversity ordination (PCoA on Bray-Curtis distances).
# - TAXONOMY empty (pre-processing not run, or it failed): fall back to
#   pl.beta_plot_pc on the raw table.
# - Otherwise: compute distances from TAXONOMY and run PCoA here.
# Both branches bind `beta_pc_plot` (plot supporting .opts) and `explained_var`.
if TAXONOMY.empty:
    # pl.beta_plot_pc returns (plot, explained_variance), mirroring the tuple
    # built in the else branch. (Previously the plot was assigned to
    # `pl.beta_pc_plot.object`, leaving `beta_pc_plot` unbound below.)
    beta_pc_plot, explained_var = pl.beta_plot_pc(
        tables_dict=tables,
        metadata=full_metadata,
        table_name=select_table_beta.value,
        factor=select_beta_factor.value,
        taxon=select_taxon.value,
    )
else:
    beta = beta_diversity_parametrized(
        TAXONOMY, taxon=select_taxon.value, metric="braycurtis"
    )
    pcoa_result = pcoa(beta, method="eigh")  # , number_of_dimensions=3)
    # proportion_explained is labelled by axis name; use positional .iloc
    # instead of the deprecated integer fallback of Series.__getitem__.
    explained_variance = (
        pcoa_result.proportion_explained.iloc[0],
        pcoa_result.proportion_explained.iloc[1],
    )

    if set(pcoa_result.samples.index) != set(full_metadata.index):
        raise ValueError("Metadata index name does not match PCoA result.")
    pcoa_df = pd.merge(
        pcoa_result.samples,
        full_metadata,
        left_index=True,
        right_index=True,
        how="inner",
    )
    beta_pc_plot = pl.hvplot_plot_pcoa_black(
        pcoa_df,
        color_by=select_beta_factor.value,
        explained_variance=explained_variance,
    )
    explained_var = explained_variance

explained_var_indicator = sum(explained_var) * 100  # convert to percentage
print('Explained variance:', explained_var_indicator)
beta_pc_plot.opts(
    title=f'Beta diversity PCA plot for {select_table_beta.value} table',
    width=1200,
    height=800,
)

In [None]:
# Bray-Curtis beta-diversity distances on the pre-processed taxonomy table.
# NOTE(review): `beta` is not used by the cells below.
_beta_metric = "braycurtis"
beta = beta_diversity_parametrized(
    TAXONOMY,
    taxon=select_taxon.value,
    metric=_beta_metric,
)

In [None]:
# Beta-diversity plot of the selected table with the chosen normalisation.
# Widget values are read once, when this cell runs.
_beta_table = select_table_beta.value
plot = pl.beta_plot(
    tables_dict=mgf_parquet_dfs,
    table_name=_beta_table,
    norm=beta_norm.value,
    taxon=select_taxon.value,
    backend=backend.value,
)

plot.object.opts(
    title=f'Beta diversity PCA plot for {_beta_table} table', width=1200, height=800,
)

In [None]:
# Ordination plot from the packaged helper, computed on the raw tables.
_beta_table = select_table_beta.value
beta_pc_plot, explained_var_indicator = pl.beta_plot_pc(
    tables_dict=mgf_parquet_dfs,
    metadata=full_metadata,
    table_name=_beta_table,
    factor=select_beta_factor.value,
    taxon=select_taxon.value,
)
# NOTE(review): the second return value is printed as-is here, whereas the PCoA
# cell above sums it and scales to percent; confirm which form is intended.
print('Explained variance:', explained_var_indicator)
beta_pc_plot.opts(
    title=f'Beta diversity PCA plot for {_beta_table} table',
    width=1200,
    height=800,
)