# Visualize taxonomy and alpha/beta diversities

## Platform dependent part
- Resolve the platform setup (Google Colab, Binder/Jupyter, or local).
- The difference from local imports should be resolved by configuring the Blue Cloud VRE properly; Colab will remain an issue.

In [1]:
import sys
import os
import io
import logging
from IPython import get_ipython
logger = logging.getLogger(name="Diversity analysis app")

if 'google.colab' in str(get_ipython()):
    # clone the momics-demos repository to use the utils module from there
    # TODO: eventually utils from momics will be used for that
    import subprocess

    # os.system() does not raise when the clone fails (it only returns the
    # exit status), so run git through subprocess and inspect the return code.
    # OSError is still caught for the case where git cannot be executed at all.
    try:
        clone = subprocess.run(
            ['git', 'clone', 'https://github.com/palec87/momics-demos.git'],
            check=False,
        )
    except OSError as e:
        logger.warning(f"An error occurred while cloning the repository: {e}")
    else:
        if clone.returncode == 0:
            logger.info("Repository cloned")
        else:
            # e.g. the destination directory already exists when the cell is re-run
            logger.warning(
                "An error occurred while cloning the repository: "
                f"git exited with status {clone.returncode}"
            )

    sys.path.insert(0, '/content/momics-demos')
elif "zmqshell" in str(get_ipython()):
    # NOTE(review): any Jupyter kernel (ipykernel's ZMQInteractiveShell) matches
    # this test, including a local Jupyter server, so local runs also land here
    # and print 'binder'. Both branches add the same path, so only the label is off.
    logger.info("Binder")
    print('binder')
    sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
else:
    logger.info("Local")
    print('local')
    sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))  # local utils, to be removed in the future

    # downside of this is that all the deps need to be installed in the current (momics-demos) environment
    # sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '../../../marine-omics')))  # local momics package, to be removed too

from utils import init_setup, get_notebook_environment

# Ask the project utils which platform the notebook runs on, then run the
# shared setup step they provide.
env = get_notebook_environment()
init_setup()
logger.info(f"Environment: {env}")

# Development convenience (to be removed): a marine-omics checkout located
# three directories above the working directory is put on sys.path, but only
# when that path actually exists.
local_momics_path = os.path.abspath(
    os.path.join(os.getcwd(), '../../../marine-omics')
)
if os.path.exists(local_momics_path):
    sys.path.append(local_momics_path)
    logger.info(f"Added local momics path: {local_momics_path}")
    print(f"Added local momics path: {local_momics_path}")

binder
Platform: local Linux
Added local momics path: /media/davidp/Data/coding/marine_omics/marine-omics


## Imports

In [2]:
# This needs to be repeated here for the Panel dashboard to work (cause unclear)
# TODO: report as possible bug
import sys
import os
import io
import warnings
import psutil


from functools import partial
warnings.filterwarnings('ignore')

# import numpy as np
import pandas as pd
# import matplotlib.pyplot as plt
# import seaborn as sns
import panel as pn
from dotenv import load_dotenv
load_dotenv()

from skbio.diversity import beta_diversity
from skbio.stats.ordination import pcoa

# All low level functions are imported from the momics package
from momics.loader import load_parquets, process_collection_date, extract_season
import momics.plotting as pl
from momics.panel_utils import (
    diversity_select_widgets, create_indicators_diversity,
    serve_app, close_server,
)
from momics.utils import memory_load, reconfig_logger
from momics.taxonomy import (
    pivot_taxonomic_data,
    separate_taxonomy)

# Note: This is breaking the panel preview functionality
# %load_ext autoreload
# %autoreload 2

## User settings

In [3]:
# When True, later cells log the table names and metadata column lists.
DEBUG = True

# (Re)configure the logging handlers/format through momics.utils.
reconfig_logger()

INFO | root | Logging.basicConfig completed successfully


## Loading

In [4]:
def fill_na_for_object_columns(df):
    """
    Replace missing values in the object-dtype columns of ``df`` with 'NA'.

    The frame is modified in place (its object columns are reassigned) and the
    same object is returned for convenient chaining. Non-object columns keep
    their missing values.

    Args:
        df (pd.DataFrame): The input dataframe.

    Returns:
        pd.DataFrame: ``df`` itself, with object-column NAs filled.
    """
    object_cols = df.select_dtypes(include=['object']).columns
    df[object_cols] = df[object_cols].fillna('NA')
    return df

@pn.cache()
def get_data(folder):
    """Load the parquet tables stored in ``folder``.

    Memoized with Panel's ``pn.cache`` so repeated calls with the same folder
    reuse the loaded tables. Returns whatever ``load_parquets`` returns; later
    cells index it by table name (e.g. ``'LSU'``, ``'SSU'``).
    """
    parquet_tables = load_parquets(folder)
    return parquet_tables

# @pn.cache()
def get_metadata(
    folder,
    sample_file="Batch1and2_combined_logsheets_2024-11-12.csv",
    observatory_file="Observatory_combined_logsheets_validated.csv",
):
    """Load the sample and observatory logsheets from ``folder`` and merge them.

    Args:
        folder (str): Directory containing both CSV logsheets.
        sample_file (str): File name of the sample logsheet. Defaults to the
            2024-11-12 combined batch 1+2 export.
        observatory_file (str): File name of the observatory logsheet.

    Returns:
        pd.DataFrame: Inner join of the two logsheets on ``obs_id`` and
        ``env_package``, sorted by ``ref_code``. The ``failure`` column is
        stringified with missing entries labelled 'NA', and missing values in
        all object columns are replaced by 'NA'.
    """
    sample_metadata = pd.read_csv(os.path.join(folder, sample_file))
    observatory_metadata = pd.read_csv(os.path.join(folder, observatory_file))

    # Inner join: samples without a matching observatory row are dropped.
    full_metadata = pd.merge(
        sample_metadata,
        observatory_metadata,
        on=["obs_id", "env_package"],
        how="inner",
    )

    full_metadata = full_metadata.sort_values(by="ref_code", ascending=True)

    # Stringify 'failure' so it becomes an object column; its missing values
    # turn into the literal 'nan', which is relabelled 'NA'.
    full_metadata["failure"] = full_metadata["failure"].astype(str).replace("nan", "NA")

    # adding replacement for the missing values for object type columns
    full_metadata = fill_na_for_object_columns(full_metadata)

    return full_metadata

In [5]:
# parquet files
# Repository root: the Colab clone, otherwise the parent of the notebook directory.
root_folder = os.path.abspath(
    '/content/momics-demos' if 'google.colab' in str(get_ipython()) else '../'
)

data_folder, assets_folder = (
    os.path.join(root_folder, sub_dir)
    for sub_dir in ('data/parquet_files', 'assets')
)

mgf_parquet_dfs = get_data(data_folder)

### Enhance metadata

In [6]:
## TODO: enhance metadata here
# Load and merge metadata, then derive the collection-date and season columns
full_metadata = get_metadata(os.path.join(root_folder, 'data'))
full_metadata = process_collection_date(full_metadata)
full_metadata = extract_season(full_metadata)

# filter the metadata only for the 181 valid samples
df_valid = pd.read_csv(
    os.path.join(root_folder, 'data/shipment_b1b2_181.csv')
)

# Filter the full_metadata on the 'ref_code' only for entries that are in df_valid
full_metadata = full_metadata[full_metadata['ref_code'].isin(df_valid['ref_code'])]

# Explicit checks instead of `assert` (asserts are stripped under `python -O`),
# with messages that say which samples / counts are off.
missing = df_valid[~df_valid['ref_code'].isin(full_metadata['ref_code'])]
if len(missing) > 0:
    raise ValueError(
        f"Missing samples in the metadata: {missing['ref_code'].tolist()}"
    )
if len(full_metadata) != len(df_valid):
    # Every valid sample is present, so extra rows mean duplicated ref_codes.
    raise ValueError(
        "Filtered metadata does not match the valid samples: "
        f"{len(full_metadata)} metadata rows vs {len(df_valid)} valid samples"
    )

In [7]:
# select categorical columns from metadata
# Identifier / free-text columns that are never used as categorical factors.
cat_to_remove = ["ref_code", "samp_description", "source_mat_id", "source_mat_id_orig"]

# select categorical columns from metadata
categorical_columns = [
    k for k in sorted(full_metadata.select_dtypes(include=['object', "boolean"]).columns)
    if k not in cat_to_remove
]

# select numerical columns from metadata
numerical_columns = sorted(full_metadata.select_dtypes(include=['int64', 'float64']).columns)

# Sanity check that every metadata column was classified. It must be a real
# check: `assert (cond, msg)` asserts a non-empty tuple and can never fail.
# Columns of other dtypes (presumably e.g. datetime columns from
# process_collection_date — TODO confirm) are reported rather than silently
# ignored; a warning keeps the notebook running while surfacing them.
unclassified_columns = sorted(
    set(full_metadata.columns)
    - set(numerical_columns)
    - set(categorical_columns)
    - set(cat_to_remove)
)
if unclassified_columns:
    logger.warning(
        f"Metadata columns that are neither categorical nor numerical: {unclassified_columns}"
    )

if DEBUG:
    logger.info(f"Data table names are:\n{mgf_parquet_dfs.keys()}")
    logger.info(f"Categorical metadata columns are:\n{categorical_columns}")
    logger.info(f"Numerical metadata columns are:\n{numerical_columns}")

INFO | Diversity analysis app | Data table names are:
dict_keys(['go', 'go_slim', 'ips', 'ko', 'LSU', 'pfam', 'SSU'])
INFO | Diversity analysis app | Categorical metadata columns are:
['ENA_accession_number_project', 'ENA_accession_number_umbrella', 'ammonium_method', 'arr_date_hq', 'arr_date_seq', 'chlorophyll_method', 'conduc_method', 'contact_email', 'contact_name', 'contact_orcid', 'density_method', 'diss_oxygen_method', 'env_broad_biome', 'env_local', 'env_material', 'env_package', 'extra_site_info', 'failure', 'failure_comment', 'geo_loc_name', 'investigation_type', 'loc_broad_ocean', 'loc_loc', 'loc_regional', 'month_name', 'nitrate_method', 'nitrite_method', 'obs_id', 'organism_count', 'organism_count_method', 'organization', 'organization_country', 'organization_edmoid', 'other_person', 'other_person_orcid', 'ph_method', 'phaeopigments_method', 'phosphate_method', 'pigments', 'pigments_method', 'pressure_method', 'project_name', 'replicate', 'samp_collect_device', 'samp_mat_pr

### Pivot the tables here

In [8]:
# LSU and SSU
# LSU and SSU raw tables and their pivoted (standardized) versions
lsu, ssu = mgf_parquet_dfs['LSU'], mgf_parquet_dfs['SSU']

lsu_standard, ssu_standard = map(pivot_taxonomic_data, (lsu, ssu))

In [9]:
# this is used for the tabular view only
# df = mgf_parquet_dfs['SSU'].copy()
# if DEBUG:
#     logger.info(f'Number of unique ref_codes: {df.ref_code.nunique()}')

## Development of the beta diversity part

In [10]:
# TODO: link these functions to the indicator
# TODO: put them in the momics package
def get_missing_taxa(
    df,
    levels=("superkingdom", "kingdom", "phylum", "class", "order", "family", "genus", "species"),
):
    """Log how many rows of ``df`` are unclassified at each taxonomic level.

    Args:
        df (pd.DataFrame): Table with one column per entry of ``levels``.
        levels (Iterable[str]): Taxonomic columns to inspect, in logging order.

    Returns:
        dict[str, int]: Number of missing (NaN/None) entries per level, so the
        counts can also be fed to an indicator instead of only being logged.
    """
    missing_counts = {}
    for taxon in levels:
        missing_counts[taxon] = get_missing_taxa_single(df, taxon)
        logger.info(f'Not classified on {taxon}: {missing_counts[taxon]}')
    return missing_counts

def get_missing_taxa_single(df, taxon):
    """Return the number of rows of ``df`` whose ``taxon`` column is missing.

    Args:
        df (pd.DataFrame): Table containing a ``taxon`` column.
        taxon (str): Name of the taxonomic column to inspect.

    Returns:
        int: Count of NaN/None entries in that column.
    """
    # Summing the boolean mask counts the gaps without building a filtered copy of the frame.
    return int(df[taxon].isnull().sum())

## PERMANOVA calculation

In [11]:
## TODO ##

### Extract from pivot tables

In [12]:
# Function to aggregate data by a specific taxonomic level
def aggregate_by_taxonomic_level(df, level):
    """Sum the numeric (abundance) columns of ``df`` per distinct ``level`` value.

    Args:
        df (pd.DataFrame): Table holding a ``level`` column plus numeric columns.
        level (str): Column to group by, e.g. ``'phylum'``.

    Returns:
        pd.DataFrame: One row per distinct ``level`` value. Rows with a missing
        ``level`` are dropped, and non-numeric columns are excluded from the sums.
    """
    df_level = df.dropna(subset=[level])
    return df_level.groupby(level).sum(numeric_only=True)


def separate_taxonomy(df):
    """Split a taxonomy-indexed abundance table by domain and bacterial rank.

    NOTE: this local definition shadows ``separate_taxonomy`` imported from
    ``momics.taxonomy`` in the imports cell; once this cell has run, this
    version is the one called.

    Args:
        df (pd.DataFrame): Abundance table indexed by a ';'-separated taxonomy
            string; the numeric columns are treated as per-sample abundances.

    Returns:
        dict[str, pd.DataFrame]: ``'Prokaryotes All'``, ``'Eukaryota All'``,
        ``'Bacteria'`` and ``'Archaea'`` row subsets of ``df``, plus
        ``'Bacteria_<rank>'`` tables for phylum…genus, aggregated per rank and
        rescaled so that every sample column sums to 100.
    """
    # Separate rows based on "Bacteria", "Archaea", and "Eukaryota" entries
    prokaryotes_all = df[df.index.str.contains("Bacteria|Archaea", regex=True)]
    eukaryota_all = df[df.index.str.contains("Eukaryota", regex=True)]

    # Further divide "Prokaryotes all" into "Bacteria" and "Archaea"
    bacteria = prokaryotes_all[prokaryotes_all.index.str.contains("Bacteria")]
    archaea = prokaryotes_all[prokaryotes_all.index.str.contains("Archaea")]

    # TODO: split "Eukaryota All" further by supergroup (Discoba, Stramenopiles,
    # Rhizaria, Alveolata, Amorphea, Archaeoplastida, Excavata).

    rank_names = ['phylum', 'class', 'order', 'family', 'genus', 'species']
    # split_taxonomy may return fewer than six ranks. Pad every row with None so
    # the frame always has exactly these columns; if no row reached six ranks,
    # the DataFrame constructor would otherwise raise.
    rank_rows = [
        ranks + [None] * (len(rank_names) - len(ranks))
        for ranks in (split_taxonomy(name) for name in bacteria.index)
    ]
    taxonomy_df = pd.DataFrame(rank_rows, columns=rank_names, index=bacteria.index)

    # Combine taxonomy with the abundance data
    bacteria_data = pd.concat([taxonomy_df, bacteria], axis=1)

    # Aggregate at each rank (species excluded) and normalize each sample column to 100
    bacteria_levels_dict = {}
    for level in rank_names[:-1]:
        aggregated_df = aggregate_by_taxonomic_level(bacteria_data, level)
        bacteria_levels_dict[f"Bacteria_{level}"] = (
            aggregated_df.div(aggregated_df.sum(axis=0), axis=1) * 100
        )

    all_data = {
        "Prokaryotes All": prokaryotes_all,
        "Eukaryota All": eukaryota_all,
        "Bacteria": bacteria,
        "Archaea": archaea
    }
    all_data.update(bacteria_levels_dict)

    return all_data


def split_taxonomy(index_name):
    """Return up to six ranks found below the domain in a taxonomy string.

    The text up to and including "Bacteria;" (checked first) or "Archaea;" is
    discarded, the first remaining field is skipped (presumably an empty
    kingdom slot — confirm against the data), and at most the next six
    fields are returned (intended as phylum … species).

    Args:
        index_name (str): ';'-separated taxonomy string.

    Returns:
        list[str]: Between 0 and 6 rank strings. Strings without either domain,
        or whose domain is not followed by ';', yield ``[]``.
    """
    for domain in ("Bacteria", "Archaea"):
        if domain in index_name:
            _, separator, remainder = index_name.partition(f"{domain};")
            taxonomy = remainder.split(";") if separator else []
            break
    else:
        taxonomy = []
    return taxonomy[1:7]  # ['phylum', 'class', 'order', 'family', 'genus', 'species']

### Dropdowns for the PCoA (as per Andrzej)

In [13]:
# lsu_standard.head()

In [14]:
# Index both pivot tables by their concatenated taxonomy string so that rows
# can be split by domain and rank.
lsu_standard = lsu_standard.set_index('taxonomic_concat')
ssu_standard = ssu_standard.set_index('taxonomic_concat')

split_taxo_tables_lsu, split_taxo_tables_ssu = (
    separate_taxonomy(table) for table in (lsu_standard, ssu_standard)
)

In [None]:
pn.extension("tabulator")

# Per-amplicon (LSU / SSU) dictionaries of taxonomic sub-tables.
granular_tables = {
    "LSU": split_taxo_tables_lsu,
    "SSU": split_taxo_tables_ssu,
}

# Amplicon table to analyse.
select_granular_table = pn.widgets.Select(
    name="Granular analysis",
    options=list(granular_tables),
    description="Select a table for granular analysis",
)

# Sub-table (domain or bacterial rank) of the chosen amplicon table.
select_granular_level = pn.widgets.Select(
    name="Subset taxonomic level",
    options=list(granular_tables[select_granular_table.value]),
    description="Select a table for analysis",
)

logger.info(f"Granular levels are:\n{list(granular_tables[select_granular_table.value])}")

# Metadata columns that are not offered as PCoA filter factors (mostly
# identifiers, contacts, dates, free text and method descriptions).
factors_to_remove = [
    'ENA_accession_number_project', "ENA_accession_number_umbrella", "arr_date_hq",
    "arr_date_seq", "contact_email", "contact_name", "contact_orcid",
    "investigation_type", "long_store", "organism_count_method", "organization_edmoid",
    'other_person', 'other_person_orcid', "organization_country", "project_name",
    "samp_store_date", 'samp_mat_process', 'samp_mat_process_dev',
    'samp_store_loc', 'sampl_person', 'sampl_person_orcid', 'store_person',
    'store_person_orcid', 'time_fi', "wa_id",
    'env_broad_biome', 'env_local', "extra_site_info", 'failure_comment',
    'obs_id', 'size_frac', 'ship_date', 'ship_date_seq', 'sampling_event', 'organism_count',
    'samp_collect_device',
    'ammonium_method', 'chlorophyll_method', 'conduc_method', 'density_method', 'diss_oxygen_method',
    'nitrate_method', 'nitrite_method', 'ph_method', 'phaeopigments_method', 'phosphate_method',
    'pigments_method', 'pressure_method',
    'sea_subsurf_salinity_method', 'sea_subsurf_temp_method', 'sea_surf_salinity_method',
    'sea_surf_temp_method',
    'silicate_method', 'turbidity_method',
]
factor_cols = [col for col in categorical_columns if col not in factors_to_remove]

# One MultiSelect per factor; the default 'All' means "do not filter on this factor".
# NOTE(review): max_width=1 looks suspicious for a visible widget — confirm intent.
pcoa_factor_dropdowns = {
    col: pn.widgets.MultiSelect(
        name=col,
        value=['All'],
        options=['All', *full_metadata[col].unique()],
        size=6,
        max_width=1,
    )
    for col in factor_cols
}
box_granular = pn.GridBox(*pcoa_factor_dropdowns.values(), ncols=5, sizing_mode="stretch_width")

# Metadata factor used to color the PCoA points.
color_factor_granular = pn.widgets.Select(
    name="Color by",
    value=factor_cols[0],
    options=factor_cols,
)

# Filter the metadata table based on the selections in box_granular
def filter_metadata_table(metadata_df, selected_factors):
    """Return the rows of ``metadata_df`` matching every active factor selection.

    Args:
        metadata_df (pd.DataFrame): Metadata table; it is not modified.
        selected_factors (dict[str, list]): Column name -> accepted values. A
            selection containing 'All' places no constraint on that column.

    Returns:
        pd.DataFrame: A new frame with the matching rows, in their original order.
    """
    constraints = {
        factor: values
        for factor, values in selected_factors.items()
        if 'All' not in values
    }
    row_mask = pd.Series(True, index=metadata_df.index)
    for factor, values in constraints.items():
        row_mask = row_mask & metadata_df[factor].isin(values)
    return metadata_df[row_mask.to_numpy()].copy()

def get_filtered_metadata():
    """Return ``full_metadata`` restricted by the current factor MultiSelect values."""
    return filter_metadata_table(
        full_metadata,
        {col: pcoa_factor_dropdowns[col].value for col in factor_cols},
    )

INFO | Diversity analysis app | Granular levels are:
['Prokaryotes All', 'Eukaryota All', 'Bacteria', 'Archaea', 'Bacteria_phylum', 'Bacteria_class', 'Bacteria_order', 'Bacteria_family', 'Bacteria_genus']


In [16]:
# box_granular

In [17]:
## filter data according to the metadata
def filter_data(df, filtered_metadata):
    """Keep only the sample columns of ``df`` listed in ``filtered_metadata['ref_code']``.

    Column labels are compared after stripping surrounding whitespace, but the
    original (unstripped) labels are used for the selection. Padded headers
    therefore match instead of raising ``KeyError``.

    Args:
        df (pd.DataFrame): Table whose columns are sample ref_codes; columns
            not found among the ref_codes are dropped.
        filtered_metadata (pd.DataFrame): Metadata with a ``ref_code`` column.

    Returns:
        pd.DataFrame: ``df`` restricted to the matching columns, in their original order.
    """
    # Build the ref_code lookup once; set membership keeps this linear.
    wanted = set(filtered_metadata['ref_code'])
    cols_to_keep = [
        col for col in df.columns
        if (col.strip() if isinstance(col, str) else col) in wanted
    ]
    return df[cols_to_keep]

def filter_all_box_selection(df):
    """Apply the current factor selections to both the metadata and ``df``.

    Returns:
        tuple[pd.DataFrame, pd.DataFrame]: ``(filtered_metadata, filtered_data)``,
        where ``filtered_data`` keeps only the columns of ``df`` whose names are
        ref_codes of the selected samples.
    """
    metadata_subset = get_filtered_metadata()
    return metadata_subset, filter_data(df, metadata_subset)


## Additional page to the app


In [None]:
# Initial filtered data: LSU bacterial phyla with the current factor selections.
filtered_metadata, filtered_data = filter_all_box_selection(granular_tables['LSU']['Bacteria_phylum'])

# show indicator of the explained variance
explained_var_indicator = pn.indicators.Number(
    name='Explained variance by PC1 + PC2', value=0, format='{value:.1f}%',
    font_size='20pt',
    title_size='12pt',
    colors=[(33, 'red'), (50, 'gold'), (66, 'green')]
)

beta_pc_plot_granular = pn.pane.Matplotlib(
    sizing_mode="stretch_both",
    height=600,
    name="Beta PCoA",
    )


def update_beta_pc_plot_granular(filtered_data, metadata, factor):
    """Redraw the PCoA pane and the explained-variance indicator.

    Args:
        filtered_data (pd.DataFrame): Abundance table restricted to the selected samples.
        metadata (pd.DataFrame): Metadata rows of the same samples.
        factor (str): Metadata column used to color the points.
    """
    beta_pc_plot_granular.object, explained_var_indicator.value = pl.beta_plot_pc_granular(
        filtered_data=filtered_data,
        metadata=metadata,
        factor=factor)


def _filter_current_selection():
    """Filter the currently selected (table, subtable) by the current factor selections."""
    filtered_metadata, filtered_data = filter_all_box_selection(
        granular_tables[select_granular_table.value][select_granular_level.value])
    logger.info(f"matadata shape {filtered_metadata.shape}")
    logger.info(f"data shape {filtered_data.shape}")
    return filtered_metadata, filtered_data


def refresh_beta_pc_plot_granular(factor=None):
    """Recompute the PCoA from the *current* widget selections and redraw it.

    Args:
        factor (str | None): Coloring factor; defaults to the "Color by" widget value.
    """
    if factor is None:
        factor = color_factor_granular.value
    filtered_metadata, filtered_data = _filter_current_selection()
    update_beta_pc_plot_granular(filtered_data, filtered_metadata, factor)


# Re-plot whenever the coloring factor changes. The data is re-read from the
# current table/level/factor widgets on every call, not captured once here.
pn.bind(refresh_beta_pc_plot_granular,
    factor=color_factor_granular,
    watch=True,
    )


def log_granular_selection(table, subtable):
    """Log the sizes resulting from a new table/subtable selection (does not re-plot)."""
    logger.info(f"Selections of BIg table and subtable: {table}, {subtable}")
    logger.info(f"Shape of the table {granular_tables[table][subtable].shape}")
    filtered_metadata, filtered_data = filter_all_box_selection(
        granular_tables[table][subtable])
    logger.info(f"matadata shape {filtered_metadata.shape}")
    logger.info(f"data shape {filtered_data.shape}")


# A single watcher on the table/subtable selectors (registering it once per
# factor would only duplicate the log output).
pn.bind(log_granular_selection,
    table=select_granular_table,
    subtable=select_granular_level,
    watch=True,
    )

button_filter_table = pn.widgets.Button(
    name="Filter table",
    button_type="primary",
    sizing_mode="stretch_width",
    width=200,
)


def update_filtered_data(button):
    """Button handler: apply the current factor selections and redraw the PCoA.

    Args:
        button: The click event passed by ``Button.on_click``.
    """
    logger.info(f"Button clicked: {button.name}")
    refresh_beta_pc_plot_granular(color_factor_granular.value)


button_filter_table.on_click(update_filtered_data)

pcoa_tab_granular = pn.Column(
        box_granular,
        button_filter_table,
        explained_var_indicator,
        beta_pc_plot_granular,
        scroll=True,
        sizing_mode="stretch_both",
    )



TypeError: 'module' object is not callable

## App setup

In [None]:
pn.extension("tabulator")
if 'google.colab' in str(get_ipython()):
    pn.extension(comms='colab')
ACCENT = "teal"

styles = {
    "box-shadow": "rgba(50, 50, 93, 0.25) 0px 6px 12px -2px, rgba(0, 0, 0, 0.3) 0px 3px 7px -3px",
    "border-radius": "4px",
    "padding": "10px",
}

# TODO: there is a bug in the panel library that does not allow opening png files; renaming does not help
image = pn.pane.JPG(os.path.join(assets_folder, "figs/metaGOflow_logo_italics.jpg"),
                    width=200,
                    height=100,
                    )


progress_bar, indicator_usage = create_indicators_diversity()


def update_used_gb(event):
    """Refresh the memory-usage indicators from ``memory_load()``.

    Args:
        event: A falsy value makes this a no-op. The periodic callback below
            passes ``indicator_usage`` here through ``functools.partial``,
            presumably just so the guard passes (assumes the indicator object
            is truthy — confirm).
    """
    if not event:
        return

    used_gb, total_gb = memory_load()
    progress_bar.value = int(used_gb / total_gb * 100)
    indicator_usage.value = used_gb


# logger.info(f"just before the app definition {granular_tables[select_granular_table][select_granular_level].shape}")
filtered_metadata, filtered_data = filter_all_box_selection(
        granular_tables[select_granular_table.value][select_granular_level.value],
        )


def app():
    """Build the diversity-analysis FastListTemplate and start the memory poller."""
    cb = pn.state.add_periodic_callback(
        partial(update_used_gb, indicator_usage),
        period=1000,
        timeout=None,
        )

    # Lets the user pause/resume the periodic memory callback.
    toggle = pn.widgets.Toggle(name='Toggle callback', value=True)
    toggle.link(cb, bidirectional=True, value='running')

    indicators = pn.FlexBox(
        indicator_usage,
        toggle)

    template = pn.template.FastListTemplate(
        title="Diversity Analysis",
        sidebar=[image,
                "# Beta granular", select_granular_table, select_granular_level,
                color_factor_granular,
                ],
        main=[pn.Column(
                indicators,
                pcoa_tab_granular,
            )],
        main_layout=None,
        accent=ACCENT,
    )
    return template


template = app()

# Fire the watchers bound to the "Color by" selector once so the initial PCoA
# is drawn. param's trigger() notifies watchers even though the value does not
# change, so this works when there is only a single factor to choose from and
# computes the plot only once.
color_factor_granular.param.trigger('value')

if 'google.colab' in str(get_ipython()):
    s = serve_app(template, env=env, name="diversity_analysis")
else:
    template.servable()

### Uncomment this if you are running an ngrok tunnel that you want to close

In [None]:
# only use for the ngrok tunnel in GColab
# close_server(s, env=env)