# Visualize pivoted taxonomy (LSU and SSU tables)
- PCoA beta diversities
- Permanova calculations

## Platform dependent part
- Resolve platform setup
- The differences from local imports should be resolved by configuring the Blue Cloud VRE properly; Colab will still be an issue.

In [None]:
import sys
import os
import io
import gc
import logging
import psutil

from IPython import get_ipython
logger = logging.getLogger(name="Diversity analysis app")

# Platform detection: Google Colab, Binder (no logged-in users reported by
# psutil) or a local session. Every branch must define NUMBER_PERMUTATIONS,
# because the PERMANOVA callback further down reads it unconditionally.
if 'google.colab' in str(get_ipython()):
    print('Setting Google colab, you will need a ngrok account to make the dashboard display over the tunnel. \
    https://ngrok.com/')
    # Clone the momics-demos repository to use it to load data.
    # os.system() signals failure through its return value rather than by
    # raising, so the exit status has to be checked explicitly.
    clone_status = os.system('git clone https://github.com/palec87/momics-demos.git')
    if clone_status == 0:
        logger.info("Repository cloned")
    else:
        # A non-zero status is also what git returns when the target folder
        # already exists (e.g. the cell was re-run), hence only a warning.
        logger.warning(
            f"git clone exited with status {clone_status}; "
            "the repository may already exist or the clone failed"
        )

    sys.path.insert(0,'/content/momics-demos')

    # this step takes time because of many dependencies
    install_status = os.system('pip install momics@git+https://github.com/emo-bon/marine-omics-methods.git@main')
    if install_status != 0:
        logger.error(f"pip install of momics exited with status {install_status}")

    # Without this assignment the PERMANOVA callback raises NameError on Colab.
    # NOTE(review): 999 mirrors the local setting; lower it if Colab turns out
    # to be as slow as Binder.
    NUMBER_PERMUTATIONS = 999

elif not psutil.users():
    logger.info("Binder")
    NUMBER_PERMUTATIONS = 29  # permanova extremely slow on binder
else:
    logger.info("Local")
    NUMBER_PERMUTATIONS = 999

from momics.utils import (
    memory_load, reconfig_logger,
    init_setup, get_notebook_environment,
)

# Set up logging
# NOTE(review): reconfig_logger() is a momics helper whose effect is not
# visible in this notebook — presumably it routes log records to the cell
# output; confirm in momics.utils.
reconfig_logger()

# Determine the notebook environment
# `env` is reused in the last cell when the app is served through serve_app().
env = get_notebook_environment()

# NOTE(review): init_setup() is another opaque momics helper; it is called
# once here, before any data is loaded.
init_setup()
logger.info(f"Environment: {env}")

## Imports

In [None]:
# This needs to be repeated here for the Panel dashboard to work, WEIRD
# TODO: report as possible bug
import sys
import os
import io
import warnings

from functools import partial
warnings.filterwarnings('ignore')

# import numpy as np
import pandas as pd
# import matplotlib.pyplot as plt
# import seaborn as sns
import panel as pn
from dotenv import load_dotenv
load_dotenv()

# from skbio.diversity import beta_diversity
# from skbio.stats.ordination import pcoa

# All low level functions are imported from the momics package
from momics.diversity import run_permanova, update_subset_indicator, update_taxa_count_indicator
from momics.loader import load_parquets
from momics.metadata import enhance_metadata, filter_metadata_table, filter_data
import momics.plotting as pl
from momics.panel_utils import (
    create_indicators_diversity,
    serve_app,
    close_server,
)
from momics.utils import memory_load, reconfig_logger
from momics.taxonomy import (
    pivot_taxonomic_data,
    separate_taxonomy)

### User settings

In [None]:
# When True, the cells below log extra diagnostics (data table names,
# categorical/numerical metadata columns and the available granular levels).
DEBUG = True  # enable stdout logging

## Loading

In [None]:
def fill_na_for_object_columns(df):
    """
    Replace missing values with the string 'NA' in every object-dtype column.

    Columns of any other dtype are left untouched. The dataframe is modified
    in place and the same object is returned for convenience.

    Args:
        df (pd.DataFrame): The input dataframe.

    Returns:
        pd.DataFrame: The same dataframe, with NA values filled in its object columns.
    """
    object_column_names = df.select_dtypes(include=['object']).columns
    # Handle one column at a time so that only object columns are rewritten
    for column_name in object_column_names:
        df[column_name] = df[column_name].fillna('NA')
    return df

@pn.cache()
def get_data(folder):
    """
    Load the parquet tables stored in ``folder``.

    The call is wrapped in ``pn.cache`` so repeated calls with the same folder
    reuse the previously loaded result.

    Args:
        folder (str): Path to the folder containing the parquet files.

    Returns:
        The object returned by ``load_parquets``; later cells index it by
        table name (e.g. 'LSU', 'SSU').
    """
    parquet_tables = load_parquets(folder)
    return parquet_tables

@pn.cache()
def get_metadata(folder):
    """
    Load the sample and observatory logsheets and merge them into one table.

    The two CSV files are inner-joined on 'obs_id' and 'env_package', sorted
    by 'ref_code', and missing values in object columns are replaced by 'NA'.
    The result is cached by ``pn.cache``.

    Args:
        folder (str): Folder containing both logsheet CSV files.

    Returns:
        pd.DataFrame: The merged and cleaned metadata table.
    """
    samples_csv = os.path.join(folder, "Batch1and2_combined_logsheets_2024-11-12.csv")
    observatories_csv = os.path.join(folder, "Observatory_combined_logsheets_validated.csv")

    samples = pd.read_csv(samples_csv)
    observatories = pd.read_csv(observatories_csv)

    # Keep only samples with a matching observatory record, ordered by ref_code
    merged = samples.merge(
        observatories,
        on=["obs_id", "env_package"],
        how="inner",
    ).sort_values(by="ref_code", ascending=True)

    # 'failure' is turned into text first; the textual 'nan' left behind by
    # the conversion is then replaced with the 'NA' placeholder
    merged["failure"] = merged["failure"].astype(str).replace("nan", "NA")

    # Remaining missing values in object columns get the same placeholder
    return fill_na_for_object_columns(merged)

### Data

In [None]:
# Resolve the repository root: the clone location on Colab, otherwise the
# parent of the current working directory
running_in_colab = 'google.colab' in str(get_ipython())
root_folder = os.path.abspath('/content/momics-demos' if running_in_colab else '../')

# parquet files and static assets live under the repository root
data_folder = os.path.join(root_folder, 'data/parquet_files')
assets_folder = os.path.join(root_folder, 'assets')

mgf_parquet_dfs = get_data(data_folder)

### Enhance metadata

In [None]:
## TODO: enhance metadata here
# Load and merge metadata (sample + observatory logsheets from <root>/data)
full_metadata = get_metadata(os.path.join(root_folder, 'data'))
# filter the metadata only for valid 181 samples
df_valid = pd.read_csv(
    os.path.join(root_folder, 'data/shipment_b1b2_181.csv')
)

# NOTE(review): enhance_metadata is a momics helper; presumably it restricts
# the metadata to the samples listed in df_valid — confirm in momics.metadata.
full_metadata = enhance_metadata(full_metadata, df_valid)

In [None]:
# select categorical columns from metadata
categorical_columns = sorted(full_metadata.select_dtypes(include=['object', "boolean"]).columns)
# identifier / free-text columns that are not useful as grouping factors
cat_to_remove = ["ref_code", "samp_description", "source_mat_id", "source_mat_id_orig"]
categorical_columns = [k for k in categorical_columns if k not in cat_to_remove]

# select numerical columns from metadata
numerical_columns = sorted(full_metadata.select_dtypes(include=['int64', 'float64']).columns)

# Sanity check: every metadata column must be classified as numerical,
# categorical, or one of the explicitly removed categorical columns.
# The condition and the message must NOT be wrapped together in parentheses:
# `assert (cond, msg)` tests a non-empty tuple, which is always true.
# NOTE(review): columns of other dtypes (e.g. numpy bool, datetime) are
# selected by neither filter above and will make this check fail.
expected_column_count = (
    len(numerical_columns) + len(categorical_columns) + len(cat_to_remove)  # + for removed cats
)
assert len(full_metadata.columns) == expected_column_count, (
    "i have wrong number orf columns in the metadata"
)

if DEBUG:
    logger.info(f"Data table names are:\n{mgf_parquet_dfs.keys()}")
    logger.info(f"Categorical metadata columns are:\n{categorical_columns}")
    logger.info(f"Numerical metadata columns are:\n{numerical_columns}")

# Categorical columns excluded from the PCoA / PERMANOVA factor lists
# (selection follows, more or less, Andrzej's PCoA notebook)
factors_to_remove = [
    # accession numbers, dates, contacts and other administrative fields
    "ENA_accession_number_project",
    "ENA_accession_number_umbrella",
    "arr_date_hq",
    "arr_date_seq",
    "contact_email",
    "contact_name",
    "contact_orcid",
    "investigation_type",
    "long_store",
    "organism_count_method",
    "organization_edmoid",
    "other_person",
    "other_person_orcid",
    "organization_country",
    "project_name",
    "samp_store_date",
    "samp_mat_process",
    "samp_mat_process_dev",
    "samp_store_loc",
    "sampl_person",
    "sampl_person_orcid",
    "store_person",
    "store_person_orcid",
    "time_fi",
    "wa_id",
    # site / sampling descriptors
    "env_broad_biome",
    "env_local",
    "extra_site_info",
    "failure_comment",
    "obs_id",
    "size_frac",
    "ship_date",
    "ship_date_seq",
    "sampling_event",
    "organism_count",
    "samp_collect_device",
    # measurement method descriptions
    "ammonium_method",
    "chlorophyll_method",
    "conduc_method",
    "density_method",
    "diss_oxygen_method",
    "nitrate_method",
    "nitrite_method",
    "ph_method",
    "phaeopigments_method",
    "phosphate_method",
    "pigments_method",
    "pressure_method",
    "sea_subsurf_salinity_method",
    "sea_subsurf_temp_method",
    "sea_surf_salinity_method",
    "sea_surf_temp_method",
    "silicate_method",
    "turbidity_method",
]

# Remaining categorical columns are offered as factors in the dashboard
factor_cols = [name for name in categorical_columns if name not in factors_to_remove]

## Pivot the tables

In [None]:
# LSU and SSU
# Pick the two taxonomy tables out of the loaded parquet collection
lsu = mgf_parquet_dfs['LSU']
ssu = mgf_parquet_dfs['SSU']

# Pivot each table with the momics helper; the results carry a
# 'taxonomic_concat' column that becomes the index in a later cell.
lsu_standard = pivot_taxonomic_data(lsu)
ssu_standard = pivot_taxonomic_data(ssu)

In [None]:
# Free memory
# The raw tables are no longer needed once the pivoted versions exist; drop
# the references and ask the garbage collector to reclaim them right away.
del mgf_parquet_dfs
del lsu
del ssu

gc.collect()

In [None]:
# Use the concatenated taxonomy string as the index (modifies the frames in place)
lsu_standard.set_index('taxonomic_concat', inplace=True)
ssu_standard.set_index('taxonomic_concat', inplace=True)

# Split each pivoted table into sub-tables with the momics helper; the results
# are used below as mappings keyed by level name (e.g. 'Bacteria_phylum').
split_taxo_tables_lsu = separate_taxonomy(lsu_standard)
split_taxo_tables_ssu = separate_taxonomy(ssu_standard)

In [None]:
# The pivoted tables are only needed to build the split tables; release them
del lsu_standard
del ssu_standard
gc.collect()

## Granular PCoA page for the app


### Dropdowns for the PCoA
- credits for inspiration to Andrzej Tkacz's NB

In [None]:
# Load the Tabulator extension (used by the PERMANOVA result table)
pn.extension("tabulator")


# Split taxonomy tables, addressed as granular_tables[<table>][<level>]
granular_tables = {
    "LSU": split_taxo_tables_lsu,
    "SSU": split_taxo_tables_ssu
}

# Sidebar selector: which table (LSU / SSU) to analyse
select_granular_table = pn.widgets.Select(
    name="Granular analysis",
    options= list(granular_tables.keys()),
    description="Select a table for granular analysis",
)

# Sidebar selector: taxonomic level within the selected table.
# NOTE(review): the options are computed once from the table selected at
# creation time; nothing visible in this notebook refreshes them when the
# table selection changes — confirm LSU and SSU expose the same levels.
select_granular_level = pn.widgets.Select(
    name="Subset taxonomic level",
    options=list(granular_tables[select_granular_table.value].keys()),
    description="Select a table for analysis",
)

# One multi-select per factor column, listing its unique metadata values plus
# an 'All' entry that is selected by default
pcoa_factor_dropdowns = {
    categorical_col: pn.widgets.MultiSelect(
        name=categorical_col,
        value=['All'],
        options=['All'] + list(full_metadata[categorical_col].unique()),
        size=6, width=180,)
        for categorical_col in factor_cols
}

# Lay the factor dropdowns out in a grid, five per row
box_granular = pn.GridBox(
    *pcoa_factor_dropdowns.values(),
    ncols=5,
    )

# Factor used to color the PCoA plot; defaults to the first factor column
color_factor_granular = pn.widgets.Select(
    name="Color by",
    value=factor_cols[0],
    options=factor_cols,
)

# show indicator of the explained variance
explained_var_indicator = pn.indicators.Number(
    name='Explained variance by PC1 + PC2', value=0, format='{value:.1f}%',
    font_size='20pt',
    title_size='12pt',
    colors=[(33, 'red'), (50, 'gold'), (66, 'green')]
)

# Placeholder pane; its `object` is set by update_beta_pc_plot_granular
beta_pc_plot_granular = pn.pane.Matplotlib(
    height=600,
    name="Beta PCoA",
    )

# Filtering is only applied when this button is clicked
button_filter_table = pn.widgets.Button(
    name="Filter table",
    button_type="primary",
    width=200,
)

if DEBUG:
    logger.info(f"Granular levels are:\n{list(granular_tables[select_granular_table.value].keys())}")

### Methods
- filter data and metadata
- update widgets

In [None]:
def get_filtered_metadata():
    """
    Return ``full_metadata`` filtered by the current factor dropdown selections.

    The current value of every dropdown in ``pcoa_factor_dropdowns`` (one per
    entry of ``factor_cols``) is collected into a mapping and handed to
    ``filter_metadata_table``, which performs the actual filtering.

    Returns:
        The filtered metadata table produced by ``filter_metadata_table``.
    """
    selections_by_factor = {}
    for factor_name in factor_cols:
        selections_by_factor[factor_name] = pcoa_factor_dropdowns[factor_name].value
    return filter_metadata_table(full_metadata, selections_by_factor)


def filter_all_box_selection(df):
    """
    Filter the metadata by the dropdown selections, then the data accordingly.

    Args:
        df: The taxonomy table to filter (passed on to ``filter_data``).

    Returns:
        tuple: ``(filtered_metadata, filtered_data)``.
    """
    metadata_subset = get_filtered_metadata()
    # The data table is filtered against the already filtered metadata
    data_subset = filter_data(df, metadata_subset)
    return metadata_subset, data_subset


def update_beta_pc_plot_granular(filtered_data, metadata, factor):
    """
    Redraw the beta-diversity PCoA pane and refresh the explained-variance indicator.

    Args:
        filtered_data: Taxonomy table to plot.
        metadata: Metadata table matching ``filtered_data``.
        factor: Metadata column used to color the plot.
    """
    # The plotting helper returns a pair: the object to display and the value
    # shown by the explained-variance indicator
    plot_object, explained_variance = pl.beta_plot_pc_granular(
        filtered_data=filtered_data,
        metadata=metadata,
        factor=factor,
    )
    beta_pc_plot_granular.object = plot_object
    explained_var_indicator.value = explained_variance


def update_filtered_data(button):
    """
    Click handler of the `Filter table` button.

    Filters the currently selected table/level by the dropdown selections,
    redraws the PCoA plot and refreshes both sidebar indicators.

    Args:
        button: The object passed by the button's click callback; only its
            ``name`` attribute is read, for logging.
    """
    logger.info(f"Button clicked: {button.name}")

    # Table currently chosen in the sidebar (LSU/SSU and taxonomic level)
    tables_for_selection = granular_tables[select_granular_table.value]
    selected_table = tables_for_selection[select_granular_level.value]

    metadata_subset, data_subset = filter_all_box_selection(selected_table)
    logger.info(f"matadata shape {metadata_subset.shape}")
    logger.info(f"data shape {data_subset.shape}")

    # Refresh the plot first, then the two sidebar indicators
    update_beta_pc_plot_granular(data_subset, metadata_subset, color_factor_granular.value)
    update_subset_indicator(subset_selected, metadata_subset)
    update_taxa_count_indicator(taxa_selected, data_subset)

In [None]:
# Initial selection (LSU, bacterial phyla, default dropdown values); these
# module-level results seed the bindings and sidebar indicators defined below.
filtered_metadata, filtered_data = filter_all_box_selection(granular_tables['LSU']['Bacteria_phylum'])

### Bindings

In [None]:
# Filtering (and the plot refresh that follows) runs only on button click
button_filter_table.on_click(update_filtered_data)


def recolor_beta_pc_plot(factor):
    """
    Redraw the PCoA plot when the "Color by" selection changes.

    The table and metadata are re-derived from the current widget state on
    every call. Passing the module-level ``filtered_data`` and
    ``filtered_metadata`` objects to ``pn.bind`` would freeze them as
    constants at bind time, so recoloring after a filter would redraw the
    initial, stale selection.

    Args:
        factor (str): Metadata column used to color the plot.
    """
    current_metadata, current_data = filter_all_box_selection(
        granular_tables[select_granular_table.value][select_granular_level.value])
    update_beta_pc_plot_granular(current_data, current_metadata, factor)
    # Keep the sidebar indicators consistent with what the plot now shows
    update_subset_indicator(subset_selected, current_metadata)
    update_taxa_count_indicator(taxa_selected, current_data)


# watch=True makes the callback run for its side effects on every change of
# the color selector
pn.bind(recolor_beta_pc_plot,
    factor=color_factor_granular,
    watch=True,
    )


pcoa_instructions = pn.pane.Markdown(
    """
    ### Instructions
    1. Side panel filters LSU/SSU tables by taxonomy levels.
    2. Color_by is used to color the beta diversity plot.
    3. Main panel filter further the table by the metadata values.
        - `Ctrl`-click to select multiple values in the dropdowns.
    4. Filtering and update of the plot happens only after clicking the `Filter table` button to save CPU.
    """
)

# Content of the PCoA tab, top to bottom
pcoa_tab_granular = pn.Column(
    pcoa_instructions,
    box_granular,
    button_filter_table,
    explained_var_indicator,
    beta_pc_plot_granular,
    scroll=True,
)

## Permanova page for the app
- Credits to Andrzej Tkacz

### Widgets

In [None]:
# PERMANOVA Dropdowns
# Main factor; 'All' is offered in addition to the factor columns
permanova_factor = pn.widgets.Select(
    name="Main Permanova factor",
    options=['All'] + factor_cols,
    description='Limit by group(s) in factor:',
)

# Groups of the chosen factor; starts empty and is filled by update_groups
permanova_group = pn.widgets.MultiSelect(
    name="Groups of unique values of the factor",
    options=[],
    description='Groups:',
)

# Factors passed to run_permanova as the additional factors to test
permanova_additional_factors = pn.widgets.MultiSelect(
    name="Factors to test vs ALL the rest",
    options=factor_cols,
    description='PERMANOVA Factors:',
)

# Triggers the PERMANOVA computation
permanova_button = pn.widgets.Button(
    name="PERMANOVA",
    button_type="primary",
    width=200,
)

# Result table; starts empty and is replaced after each run
permanova_result_indicator = pn.widgets.Tabulator(pd.DataFrame(), name='Permanova Result')

permanova_instructions = pn.pane.Markdown(
    """
    ### Instructions
    1. Select a factor to limit the analysis.
    2. Select groups in the factor (`Ctrl`-click to select multiple).
    3. Select additional factors for against which PERMANOVA will be run (`Ctrl`-click to select multiple).
    4. Click the `PERMANOVA` button to run the analysis.
    5. **NOTE**, locally permanova with 999 permutations is instant, however takes extremely long on binder.
        - the number of permutations is set to 29 (for binder) and does not lead to correct p-value.
        - If you run locally it is 999.
    """
)

### Updates and bindings

In [None]:
def update_permanova_result():
    """
    Run PERMANOVA for the current selections and show the result table.

    Uses the table/level chosen in the sidebar, the full metadata and the
    values of the three PERMANOVA widgets; the number of permutations comes
    from ``NUMBER_PERMUTATIONS`` (29 for binder, 999 for local).
    """
    selected_table = granular_tables[select_granular_table.value][select_granular_level.value]
    results = run_permanova(
        selected_table,
        full_metadata,
        permanova_factor.value,
        permanova_group.value,
        permanova_additional_factors.value,
        permutations=NUMBER_PERMUTATIONS,
        verbose=True,
    )
    # The Tabulator widget displays a dataframe built from the result mapping
    results_table = pd.DataFrame.from_dict(results)
    permanova_result_indicator.value = results_table

# Refresh the selectable groups whenever the main factor changes
def update_groups(permanova_factor):
    """
    Populate the group selector with the unique values of the chosen factor.

    Args:
        permanova_factor (str): A name from ``factor_cols``, or 'All' to list
            every sample by its 'ref_code'.

    Raises:
        ValueError: If the value is neither a known factor nor 'All'.
    """
    logger.info(f"Permanova factor value: {permanova_factor}")

    # Decide which metadata column supplies the group names
    if permanova_factor in factor_cols:
        source_column = permanova_factor
    elif permanova_factor == 'All':
        source_column = 'ref_code'
    else:
        raise ValueError(f"Unknown factor: {permanova_factor}")

    permanova_group.options = sorted(full_metadata[source_column].dropna().unique())
    
# Re-run update_groups whenever the main factor selection changes
# (watch=True: the callback is executed for its side effect)
pn.bind(update_groups,
    permanova_factor,
    watch=True,
)

# The click callback receives an event argument, which is ignored here
permanova_button.on_click(
    lambda event: update_permanova_result()
)

In [None]:
# Content of the Permanova tab: instructions, the three selectors side by
# side, the run button and the result table
permanova_tab = pn.Column(
    permanova_instructions,
    pn.Row(
        permanova_factor,
        permanova_group,
        permanova_additional_factors,
    ),
    permanova_button,
    permanova_result_indicator,
    scroll=True,
)

### Add to the side panel

In [None]:
# Initial indicator values, computed from the initial filter results
total_samplings = full_metadata['ref_code'].nunique()
subset = filtered_metadata['ref_code'].nunique()
taxa_count = len(filtered_data)


# Displayed as "<value>/<total_samplings>"; refreshed by update_filtered_data
subset_selected = pn.indicators.Number(
    name="Subset of samples you filtered",
    value=subset,
    format="{value}" + f"/{total_samplings}",
    width=150,
    font_size="34px",
    title_size="14px",
)

# Starts at the length of the filtered table; refreshed by update_filtered_data
taxa_selected = pn.indicators.Number(
    name="Taxa in the selection.",
    value=taxa_count,
    format="{value}",
    width=150,
    font_size="34px",
    title_size="14px",
)

## APP setup

In [None]:
pn.extension("tabulator")
# On Colab, additionally switch Panel to the Colab comms
if 'google.colab' in str(get_ipython()):
    pn.extension(comms='colab')
ACCENT = "teal"

# CSS applied to the tabs container
styles = {
    "box-shadow": "rgba(50, 50, 93, 0.25) 0px 6px 12px -2px, rgba(0, 0, 0, 0.3) 0px 3px 7px -3px",
    "border-radius": "4px",
    "padding": "10px",
}

# TODO: there is a bug in the panel library that does not allow to open png files, renaming does not help
image = pn.pane.JPG(os.path.join(assets_folder, "figs/metaGOflow_logo_italics.jpg"),
                    width=200,
                    height=100,
                    )

# The two pages of the app
tabs = pn.Tabs(
    ('PCoA', pcoa_tab_granular),
    ('Permanova', permanova_tab),
    # atable,
    styles=styles,
    margin=10
)
# Only the second returned object is used: the memory usage indicator that
# update_used_gb refreshes
_, indicator_usage = create_indicators_diversity()

def update_used_gb(event):
    """
    Refresh the memory usage indicator with the currently used memory.

    Args:
        event: Truthiness guard; nothing happens when it is falsy. The
            periodic callback in ``app`` passes the indicator itself here.
    """
    if event:
        # memory_load() returns a pair; only the used amount is displayed
        used_gb, _total_gb = memory_load()
        indicator_usage.value = used_gb


# logger.info(f"just before the app definition {granular_tables[select_granular_table][select_granular_level].shape}")
# Recompute the filtered tables for the currently selected table and level.
# NOTE(review): no later code visible in this notebook reads these two
# module-level names again — confirm whether this call is still needed.
filtered_metadata, filtered_data = filter_all_box_selection(
        granular_tables[select_granular_table.value][select_granular_level.value],
        )

def app():
    """
    Assemble the dashboard template.

    Registers a periodic callback (every 1000 ms) that refreshes the memory
    usage indicator, adds a toggle linked to that callback's ``running``
    state, and builds the sidebar and main area.

    Returns:
        pn.template.FastListTemplate: The assembled template.
    """
    memory_callback = pn.state.add_periodic_callback(
        partial(update_used_gb, indicator_usage),
        period=1000,
        timeout=None,
    )

    # Lets the user pause/resume the periodic memory refresh
    toggle = pn.widgets.Toggle(name='Toggle callback', value=True, button_type='success')
    toggle.link(memory_callback, bidirectional=True, value='running')

    sidebar_items = [
        image,
        "# Beta granular",
        select_granular_table,
        select_granular_level,
        color_factor_granular,
        pn.layout.Divider(),
        subset_selected,
        taxa_selected,
        pn.layout.Divider(),
        indicator_usage,
        toggle,
    ]
    main_items = [pn.Column(tabs)]

    return pn.template.FastListTemplate(
        title="Diversity Analysis",
        sidebar=sidebar_items,
        main=main_items,
        main_layout=None,
        accent=ACCENT,
    )

template = app()

# Trick to draw the initial plot: switch the color selector to another option
# and back, so the watcher bound to it fires.
# NOTE(review): this requires at least two entries in the selector's options.
color_factor_granular.value = color_factor_granular.options[1]
color_factor_granular.value = color_factor_granular.options[0]

# On Colab the app is served through the momics helper (the returned handle
# `s` is what close_server() expects); elsewhere the template is marked servable.
if 'google.colab' in str(get_ipython()):  
    s = serve_app(template, env=env, name="diversity_analysis")
else:
    template.servable()

### Uncomment this if you are running an ngrok tunnel that you want to close

In [None]:
# only use for the ngrok tunnel in GColab
# close_server(s, env=env)