# Visualize taxonomy and alpha/beta diversities

In [1]:
## server-side imports
#- the difference from the local imports should be resolved by setting up the VRE packages properly

In [2]:
# #Install
# !pip install scikit-bio
# !pip install duckdb
# !pip install pingouin

## Local imports

In [34]:
import sys
import os
import io

sys.path.append(os.path.abspath(os.path.join('..')))

import numpy as np
import pandas as pd
import ipywidgets as widgets
import matplotlib.pyplot as plt
import seaborn as sns
import pingouin as pg
import panel as pn

from skbio.diversity import alpha_diversity, beta_diversity
from skbio.stats.ordination import pcoa, pcoa_biplot
from scipy.stats import levene, shapiro, ttest_ind

import momics as mo
from momics.loader import load_parquet_files
# %load_ext autoreload
# %autoreload 2

## Loading

### methods

In [35]:
@pn.cache()
def get_data(folder):
    """Load the parquet tables stored in ``folder``.

    Memoised with ``pn.cache`` so that re-rendering the Panel app does not
    re-read the files from disk.

    Args:
        folder (str): Directory containing the parquet files.

    Returns:
        Whatever ``load_parquet_files`` returns for ``folder``. The notebook
        indexes it by table name, e.g. ``['SSU']``.
    """
    tables = load_parquet_files(folder)
    return tables

@pn.cache()
def get_metadata(
    folder,
    sample_file="Batch1and2_combined_logsheets_2024-09-11.csv",
    observatory_file="Observatory_combined_logsheets_validated.csv",
):
    """Load the sample and observatory logsheets and join them.

    Args:
        folder (str): Directory containing both CSV logsheets.
        sample_file (str): File name of the per-sample logsheet. The default is
            the batch 1+2 export; pass another name for newer batches.
        observatory_file (str): File name of the observatory logsheet.

    Returns:
        pd.DataFrame: Inner join of both logsheets on ``obs_id`` and
        ``env_package``, sorted by ``ref_code`` in ascending order. Samples
        without a matching observatory row are dropped by the inner join.
    """
    sample_metadata = pd.read_csv(os.path.join(folder, sample_file))
    observatory_metadata = pd.read_csv(os.path.join(folder, observatory_file))

    # Inner join: keep only samples whose (obs_id, env_package) pair also
    # appears in the observatory logsheet.
    full_metadata = pd.merge(
        sample_metadata,
        observatory_metadata,
        on=["obs_id", "env_package"],
        how="inner",
    )

    return full_metadata.sort_values(by="ref_code", ascending=True)

In [36]:
# Data locations, resolved relative to the parent of the current working
# directory (the notebook is expected to run one level below the repo root).
root_folder = os.path.abspath("..")
data_folder = os.path.join(root_folder, "parquet_files")
assets_folder = os.path.join(root_folder, "assets")

# Load all parquet tables (memoised by get_data).
mgf_parquet_dfs = get_data(data_folder)

In [37]:
# Load and merge metadata
full_metadata = get_metadata(data_folder)

# select categorical columns from metadata (object dtype)
categorical_columns = sorted(full_metadata.select_dtypes(include=['object']).columns)

# select numerical columns from metadata (int64 / float64 only)
numerical_columns = sorted(full_metadata.select_dtypes(include=['int64', 'float64']).columns)

# Every metadata column must land in exactly one group. Name the leftovers
# (e.g. bool or datetime columns) so a failure is actionable.
unclassified_columns = sorted(
    set(full_metadata.columns) - set(categorical_columns) - set(numerical_columns)
)
assert not unclassified_columns, (
    f"Metadata columns with unhandled dtypes: {unclassified_columns}"
)

# print(f"Data table names are:\n{mgf_parquet_dfs.keys()}")
# print(f"Categorical metadata columns are:\n{categorical_columns}")
# print(f"Numerical metadata columns are:\n{numerical_columns}")

In [38]:
# mgf_parquet_dfs['SSU'].sort_values(by='abundance', ascending=False)

In [39]:
# Work on a copy of the SSU table so the cached original stays untouched.
df = mgf_parquet_dfs["SSU"].copy()

# Number of distinct sample identifiers in each candidate column.
df["ref_code"].nunique(), df["reads_name"].nunique()

(54, 54)

`ref_code` and `reads_name` have the same number of unique values (54 each), so `ref_code` is used as the sample identifier from here on.

## Methods

### Panel
- TODO: eventually wrap the Panel app setup in a function.

### Taxonomy

### Diversity

In [40]:
# NOTE: probably only needed for beta diversity; the alpha-diversity path
# builds its input from the functional tables (go, ips, ...) instead.
def diversity_input(df, kind='alpha', taxon_col="ncbi_tax_id"):
    """
    Prepare a sample x taxon abundance table for diversity analysis.

    Args:
        df (pd.DataFrame): Long-format table with ``ref_code``, ``abundance``
            and ``taxon_col`` columns.
        kind (str): Either 'alpha' (raw abundances) or 'beta' (each row is
            divided by its total, i.e. relative abundances). For 'beta', a
            sample whose abundances sum to 0 yields a row of NaN.
        taxon_col (str): The column name containing the taxon IDs.

    Returns:
        pd.DataFrame: Indexed by ``ref_code``, one column per taxon. Sample/taxon
        combinations absent from ``df`` are filled with 0. Duplicate
        (ref_code, taxon) rows are averaged (pivot_table's default aggfunc).

    Raises:
        ValueError: If ``kind`` is not 'alpha' or 'beta'.
    """
    if kind not in ("alpha", "beta"):
        raise ValueError(f"kind must be 'alpha' or 'beta', got {kind!r}")

    # Long -> wide: one row per sample, one column per taxon.
    out = pd.pivot_table(
        df,
        index="ref_code",
        columns=taxon_col,
        values="abundance",
        fill_value=0,
    )

    # Normalize rows to relative abundances
    if kind == 'beta':
        out = out.div(out.sum(axis=1), axis=0)

    # Every distinct taxon in the input must have become exactly one column.
    # (The previous `assert a, b` form only checked that `a` was truthy.)
    n_taxa = df[taxon_col].nunique()
    assert out.shape[1] == n_taxa, (
        f"pivot produced {out.shape[1]} taxon columns but the input has "
        f"{n_taxa} distinct {taxon_col!r} values"
    )
    return out


# Map each functional-annotation table to the column holding its identifiers.
# Example tables: ['go', 'go_slim', 'ips', 'ko', 'pfam']
def get_key_column(table_name):
    """Return the identifier column for a functional-annotation table.

    Args:
        table_name (str): One of 'go', 'go_slim', 'ips', 'ko' or 'pfam'.

    Returns:
        str: 'id' for GO tables, 'accession' for 'ips', 'entry' for 'ko'/'pfam'.

    Raises:
        ValueError: For any other table name.
    """
    lookup = (
        (("go", "go_slim"), "id"),
        (("ips",), "accession"),
        (("ko", "pfam"), "entry"),
    )
    for table_names, column in lookup:
        if table_name in table_names:
            return column
    raise ValueError(f"Unknown table: {table_name}")
    

def alpha_input(table_name):
    """Pivot a functional-annotation table into a feature x sample matrix.

    Reads ``mgf_parquet_dfs[table_name]`` from the notebook namespace and
    prints a few diagnostics along the way.

    Args:
        table_name (str): A table accepted by ``get_key_column``.

    Returns:
        pd.DataFrame: Rows are the table's identifiers (its key column),
        columns are ``ref_code`` values, and cells hold the summed
        ``abundance`` (0 where a feature is absent from a sample).
    """
    # Resolve the key column first so an unknown table fails with ValueError.
    key_column = get_key_column(table_name)
    print("Key column:", key_column)

    source = mgf_parquet_dfs[table_name]
    n_ref_codes = len(source["ref_code"].unique())
    print('length of the ref_codes:', n_ref_codes)

    matrix = source.pivot_table(
        values="abundance",
        index=[key_column],
        columns=["ref_code"],
        aggfunc="sum",
        fill_value=0,
    )
    print('table shape:', matrix.shape)
    return matrix
# Example usage
# df_alpha_input = diversity_input(df, kind='alpha')
# df_beta_input = diversity_input(df, kind='beta')


def shannon_index(row):
    """Shannon entropy H = -sum(p_i * ln p_i) of one abundance vector.

    Args:
        row (pd.Series): Abundances of one sample. Non-numeric entries are
            coerced to NaN and ignored.

    Returns:
        float: Shannon entropy in nats, or NaN when the total abundance is 0.
    """
    row = pd.to_numeric(row, errors="coerce")
    total_abundance = row.sum()
    if total_abundance == 0:
        return np.nan
    relative_abundance = row / total_abundance
    # Only strictly positive proportions contribute (0 * ln 0 is taken as 0).
    # Filtering before the log avoids evaluating log(0), which emitted a
    # "divide by zero" RuntimeWarning for every sample with absent features.
    present = relative_abundance[relative_abundance > 0]
    return -(present * np.log(present)).sum()  # Shannon entropy


def calculate_shannon_index(df):
    """Apply ``shannon_index`` to every row (sample) of ``df``.

    Returns:
        One Shannon entropy value per row of ``df``.
    """
    per_sample = df.apply(lambda sample: shannon_index(sample), axis="columns")
    return per_sample


def calculate_alpha_diversity(df, factors, prefixes=("GO:", "IPR", "K", "PF")):
    """Compute the Shannon index per sample and attach metadata factors.

    Args:
        df (pd.DataFrame): One row per sample, with a ``ref_code`` column and
            feature columns whose labels start with one of ``prefixes``. All
            other columns (e.g. merged metadata) are ignored for the index.
        factors (pd.DataFrame): Metadata with a ``ref_code`` column to merge on.
        prefixes (str or tuple[str, ...]): Label prefixes that identify feature
            columns. The defaults presumably cover GO, InterPro, KEGG orthology
            and Pfam identifiers. NOTE(review): "K" is broad and also matches
            any metadata column whose name starts with a capital K.

    Returns:
        pd.DataFrame: ``ref_code``, ``Shannon`` and all columns of ``factors``,
        from an inner merge on ``ref_code``.
    """
    # A bare string would otherwise be split into single-character prefixes.
    if isinstance(prefixes, str):
        prefixes = (prefixes,)
    prefixes = tuple(prefixes)

    # Non-string labels cannot be feature identifiers, and str.startswith
    # would raise on them, so skip them explicitly.
    feature_columns = [
        col
        for col in df.columns
        if isinstance(col, str) and col.startswith(prefixes)
    ]

    # Calculate the Shannon index only from the selected feature columns.
    shannon_values = calculate_shannon_index(df[feature_columns])

    alpha_diversity_df = pd.DataFrame(
        {"ref_code": df["ref_code"], "Shannon": shannon_values}
    )

    # Default merge is inner: samples missing from `factors` are dropped.
    return alpha_diversity_df.merge(factors, on="ref_code")




### Plotting

In [54]:
# Plot the PCoA with optional coloring: numeric `color_by` gets a colour bar,
# non-numeric (categorical) `color_by` gets one colour per category + legend.
def plot_pcoa_black(pcoa_df, color_by=None):
    """Scatter-plot the first two principal coordinates and show the figure.

    Args:
        pcoa_df (pd.DataFrame): Must contain ``PC1`` and ``PC2`` columns.
        color_by (array-like, optional): One value per row of ``pcoa_df``,
            matched by position (not by index label). Numeric values are
            mapped through the ``RdYlGn`` colormap with a colour bar.
            Non-numeric or boolean values are treated as categories, each
            drawn in its own colour with a legend. If ``None``, all points
            are drawn in black.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    pc1 = pcoa_df["PC1"].to_numpy()
    pc2 = pcoa_df["PC2"].to_numpy()

    if color_by is None:
        ax.scatter(pc1, pc2, color="black")
    else:
        values = color_by if isinstance(color_by, pd.Series) else pd.Series(color_by)
        # A Series name (None for plain lists) becomes the colour-bar/legend title.
        label = values.name
        is_numeric = (
            pd.api.types.is_numeric_dtype(values)
            and not pd.api.types.is_bool_dtype(values)
        )
        if is_numeric:
            scatter = ax.scatter(
                pc1, pc2, c=values.to_numpy(), cmap="RdYlGn", edgecolor="k"
            )
            fig.colorbar(scatter, ax=ax, label=label)
        else:
            # Strings cannot go through a continuous colormap: draw one scatter
            # per category and let matplotlib's colour cycle pick the colours.
            codes, categories = pd.factorize(values)
            for code, category in enumerate(categories):
                mask = codes == code
                ax.scatter(pc1[mask], pc2[mask], edgecolor="k", label=str(category))
            missing = codes == -1  # pd.factorize marks missing values with -1
            if missing.any():
                ax.scatter(
                    pc1[missing], pc2[missing],
                    color="lightgrey", edgecolor="k", label="missing",
                )
            ax.legend(title=label)

    ax.set_xlabel("PC1")
    ax.set_ylabel("PC2")
    ax.set_title("PCoA Plot")
    plt.show()


def plot_heatmap(df):
    """Draw ``df`` (e.g. a beta-diversity distance matrix) as a heatmap.

    The figure is shown immediately; nothing is returned.
    """
    _, ax = plt.subplots(figsize=(10, 6))
    sns.heatmap(df, cmap="viridis", ax=ax)
    ax.set_title("Heatmap of beta diversity")
    plt.show()


def plot_alpha_diversity(series, factor=None):
    """Bar plot of one alpha-diversity value per index label of ``series``.

    Args:
        series (pd.Series): Diversity values; the index gives the bar labels.
        factor (array-like, optional): Passed to seaborn as ``hue``.
    """
    _, ax = plt.subplots(figsize=(10, 6))
    sns.barplot(x=series.index, y=series.values, hue=factor, ax=ax)
    ax.set_title("Alpha diversity")
    plt.show()


def mpl_alpha_diversity(df, factor=None, show_legend=True):
    """Build (without displaying) a bar plot of per-sample Shannon diversity.

    Args:
        df (pd.DataFrame): Must contain a ``Shannon`` column and, when
            ``factor`` is given, a column named ``factor``. One bar is drawn
            per row, positioned by ``df.index``.
        factor (str, optional): Column used as the seaborn ``hue`` and as the
            x-axis label. If ``None``, bars are not grouped.
        show_legend (bool): Whether to keep the hue legend.

    Returns:
        matplotlib.figure.Figure: The figure, already closed via ``plt.close``
        so it is not rendered inline; the caller embeds it elsewhere.
    """
    # Fully transparent figure background (RGBA alpha = 0).
    plot = plt.figure(figsize=(10, 6), facecolor=(0, 0, 0, 0))
    ax = plot.add_subplot(111)
    ax.set_title("Alpha diversity")
    ax.set_xlabel(factor)

    # df[None] would raise KeyError, so only colour by a factor when given.
    hue = df[factor] if factor is not None else None
    # Draw on our own axes rather than relying on pyplot's "current" axes.
    sns.barplot(x=df.index, y=df["Shannon"], hue=hue, ax=ax)

    if show_legend and hue is not None:
        ax.legend(loc=4)  # lower right
    elif ax.get_legend() is not None:
        # seaborn adds a hue legend by itself; remove it when not wanted.
        ax.get_legend().remove()

    plt.close(plot)
    return plot


## Data setup

In [65]:
# alpha diversity
def alpha_diversity_parametrized(table_name, metadata):
    """Shannon alpha diversity per sample for one functional-annotation table.

    Args:
        table_name (str): Table accepted by ``alpha_input`` (e.g. 'go', 'pfam').
        metadata (pd.DataFrame): Sample metadata with a ``ref_code`` column.

    Returns:
        pd.DataFrame: Result of ``calculate_alpha_diversity``, i.e.
        ``ref_code``, ``Shannon`` and the metadata columns.
    """
    # alpha_input returns features x samples; transpose so samples are rows.
    samples_by_feature = alpha_input(table_name).T
    samples_by_feature = samples_by_feature.sort_values(by="ref_code")

    # Attach metadata by matching the ref_code index to metadata["ref_code"].
    merged = pd.merge(
        samples_by_feature,
        metadata,
        left_index=True,
        right_on="ref_code",
    )
    return calculate_alpha_diversity(merged, metadata)


# beta diversity: Bray-Curtis distances between per-sample relative abundances.
df_beta_input = diversity_input(df, kind='beta', taxon_col="ncbi_tax_id")

# Pass the sample ids explicitly. Older scikit-bio releases treat `counts` as a
# plain 2-D array and would otherwise label samples 0..n-1, which would make
# the ref_code merge below silently come back empty.
beta = beta_diversity(
    "braycurtis",
    df_beta_input.to_numpy(),
    ids=list(df_beta_input.index),
)

# merge metadata (distance-matrix rows are keyed by ref_code)
df_beta = pd.merge(beta.to_data_frame(), full_metadata, left_index=True, right_on='ref_code')
assert not df_beta.empty, "No beta-diversity sample matched a metadata ref_code"

## App setup

In [44]:
# dfa = diversity_input(df, kind='alpha')

In [None]:
# Panel setup: Tabulator extension (used by the "Table" tab) and shared styling.
pn.extension("tabulator")
ACCENT = "teal"

# Card-like style applied to the indicator and to the tabs.
styles = {
    "box-shadow": (
        "rgba(50, 50, 93, 0.25) 0px 6px 12px -2px, "
        "rgba(0, 0, 0, 0.3) 0px 3px 7px -3px"
    ),
    "border-radius": "4px",
    "padding": "10px",
}

# Sidebar logo.
# TODO: a Panel bug prevents opening PNG files, so the logo was converted to JPG.
image = pn.pane.JPG(
    os.path.join(assets_folder, "figs/metaGOflow_logo_italics.jpg"),
    width=200,
    height=100,
)

## Widgets
# Functional-annotation table used as the alpha-diversity source.
select_table = pn.widgets.Select(
    name="Alpha diversity source table",
    value="go",
    options=["go", "go_slim", "ips", "ko", "pfam"],
    description="Select a table for alpha diversity analysis",
)

# Metadata column used as the hue of the alpha-diversity bars.
# TODO: probably not all of these make sense
select_categorical_factor = pn.widgets.Select(
    name="Factor",
    value=categorical_columns[0],
    options=categorical_columns,
    description="Categorical columns to compare alpha diversities",
)

# Placeholder indicator card; it shows a fixed dummy value.
indicators = pn.FlexBox(
    pn.indicators.Number(
        value=10,
        name="Not implemented",
        format="{value:,.0f}",
        styles=styles,
    ),
)


def alpha_plot(table_name, factor):
    """Render the alpha-diversity bar plot for the current widget selection.

    The diversity table is recomputed on every call.

    Args:
        table_name (str): Functional-annotation table to use.
        factor (str): Metadata column used as the bar hue.

    Returns:
        pn.pane.Matplotlib: Pane wrapping the figure, named "Plot".
    """
    alpha = alpha_diversity_parametrized(table_name, full_metadata)
    figure = mpl_alpha_diversity(alpha, factor=factor)
    return pn.pane.Matplotlib(figure, sizing_mode="stretch_both", name="Plot")

# Reactive plot: re-evaluated whenever either select widget changes.
bplot = pn.bind(alpha_plot, table_name=select_table, factor=select_categorical_factor)
# The SSU table (`df`) for inspection in the second tab.
table = pn.widgets.Tabulator(df, sizing_mode="stretch_both", name="Table")

tabs = pn.Tabs(
    bplot,
    table,
    styles=styles,
    sizing_mode="stretch_width",
    height=500,
    margin=10,
)

template = pn.template.FastListTemplate(
    title="Diversity Analysis",
    sidebar=[image, select_table, select_categorical_factor],
    main=[
        pn.Column(indicators, tabs, sizing_mode="stretch_both"),
    ],
    main_layout=None,
    accent=ACCENT,
)

template.servable()