# Alpha diversity on functional information

The **Shannon, or Shannon-Wiener, Diversity Index** measures the diversity of species in a given community.
More details [here](https://www.statology.org/shannon-diversity-index/):

    
What is α (alpha) diversity? Check basic explanation [here](https://eco-intelligent.com/2016/10/14/alpha-beta-gamma-diversity/)

DP: For simplification, I removed all the SQL and duckDB, and use pure pandas
- TODO: benchmark both
  - for this create a benchmark suite in the momics package
- TODO: implement this in polars and benchmark too
- TODO: IF pandas is preferable, do it with pipelines, which should be faster
  - benchmark too

### Installing and importing required modules
DP: most of this is not needed, especially for the pure python approach

In [1]:
# # Install
# !pip install duckdb
# !pip install pandasql
# !pip install pingouin
# !pip install fastparquet
# !pip install nbQA # A linter for Python Jupyter notebooks.

# import contextlib
# import io
# import math
# # Import
# import os
# import warnings

# import duckdb
# import fastparquet
# import ipywidgets as widgets
# import matplotlib
# import matplotlib.pyplot as plt
# import numpy as np
# import pandas as pd
# import pingouin as pg
# import pyarrow
# import scipy
# import seaborn as sns
# from duckdb import BinderException, CatalogException
# from IPython.display import display
# from pandasql import sqldf
# from scipy.special import entr
# from scipy.stats import levene, shapiro, ttest_ind

## Local imports

In [2]:
import sys
import os
import platform

if platform.system() == 'Linux':
    sys.path.append(os.path.abspath(os.path.join('..')))
elif platform.system() == 'Windows':
    # I do not install the package via pip install -e, I rather add the path to the package to the sys.path
    # faster prototyping of the momics package
    sys.path.append("C:/Users/David Palecek/Documents/Python_projects/marine_omics/marine-omics")

import numpy as np
import pandas as pd
import ipywidgets as widgets
import matplotlib.pyplot as plt
import seaborn as sns
import pingouin as pg

from scipy.stats import levene, shapiro, ttest_ind

import momics as mo
from momics.loader import load_parquet_files

%load_ext autoreload
%autoreload 2

In [3]:
root_folder = os.path.abspath(os.path.join('../parquet_files'))

### Load parquet files

In [4]:
# mgf_parquet_dfs = load_parquet_files("../parquet_files")
mgf_parquet_dfs = load_parquet_files(root_folder)

In [5]:
# display dicts if necessary
# [display(val) for _, val in mgf_parquet_dfs.items()]

In [6]:
sample_metadata = pd.read_csv(
    os.path.join(root_folder, "Batch1and2_combined_logsheets_2024-09-11.csv")
)

# Observatory metadata - from the GoogleSheets
observatory_metadata = pd.read_csv(
    os.path.join(root_folder, "Observatory_combined_logsheets_validated.csv")
)

# This should be pipeline I guess
# merge metadata
full_metadata = pd.merge(
    sample_metadata,
    observatory_metadata,
    on=["obs_id", "env_package"],  # Matching conditions
    how="inner"  # Inner join
)

# Sort the merged dataframe by 'ref_code' column in ascending order
full_metadata = full_metadata.sort_values(by="ref_code", ascending=True)
# full_metadata

In [7]:
print(mgf_parquet_dfs.keys())
# mgf_parquet_dfs['go']
# mgf_parquet_dfs['go_slim']
# mgf_parquet_dfs['ips']
# mgf_parquet_dfs['LSU']
# mgf_parquet_dfs['ko']
# mgf_parquet_dfs['SSU']
# mgf_parquet_dfs['pfam']

dict_keys(['go', 'go_slim', 'ips', 'ko', 'LSU', 'pfam', 'SSU'])


## Methods


### Common methods

In [8]:
# Function to get the appropriate column based on the selected table
# Example tables: ['go', 'go_slim', 'ips', 'ko', 'pfam']
def get_key_column(table_name):
    if table_name in ["go", "go_slim"]:
        return "id"
    elif table_name == "ips":
        return "accession"
    elif table_name in ["ko", "pfam"]:
        return "entry"
    else:
        raise ValueError(f"Unknown table: {table_name}")


def merge_data(pivoted_table_transposed_sorted, full_metadata):
    return pd.merge(
        full_metadata, pivoted_table_transposed_sorted, on="ref_code", how="inner"
    )

### Statistical methods

In [9]:
def shannon_index(row):
    row = pd.to_numeric(row, errors="coerce")
    total_abundance = row.sum()
    if total_abundance == 0:
        return np.nan
    relative_abundance = row / total_abundance
    ln_relative_abundance = np.log(relative_abundance)
    ln_relative_abundance[relative_abundance == 0] = 0
    multi = relative_abundance * ln_relative_abundance * -1
    return multi.sum()  # Shannon entropy


def calculate_shannon_index(df):
    return df.apply(shannon_index, axis=1)


# Function to check normality using the Shapiro-Wilk test
def check_normality(data):
    stat, p_value = shapiro(data)
    return p_value > 0.05  # If p-value > 0.05, the data is normally distributed


# Function to check homogeneity of variances using Levene's test
def check_homogeneity_of_variances(groups):
    stat, p_value = levene(*groups)
    return p_value > 0.05  # If p-value > 0.05, variances are homogeneous


# Function to run t-test if there are only two conditions
def run_ttest(alpha_diversity_df, selected_factor, homogeneity):
    grouped_data = alpha_diversity_df.groupby(selected_factor)["Shannon"]
    groups = [group for _, group in grouped_data]

    # If variances are equal, run a standard t-test; otherwise, Welch's t-test
    equal_var = homogeneity
    t_stat, p_value = ttest_ind(groups[0], groups[1], equal_var=equal_var)

    return t_stat, p_value


# Modified ANOVA and post-hoc tests function with t-test option
def run_anova_and_posthoc(alpha_diversity_df, selected_factor):
    grouped_data = alpha_diversity_df.groupby(selected_factor)["Shannon"]
    groups = [group for _, group in grouped_data]

    if len(groups) == 1:
        return "Only one condition is present. No statistical test can be performed."

    normality = check_normality(alpha_diversity_df["Shannon"])

    # Handle cases with exactly two groups
    if len(groups) == 2:
        homogeneity = check_homogeneity_of_variances(groups)
        t_stat, p_value = run_ttest(alpha_diversity_df, selected_factor, homogeneity)
        result = f"\nT-test Results:\nT-statistic: {t_stat}, P-value: {p_value}\n"
        result += f"Normality: {'Pass' if normality else 'Fail'}\nHomogeneity: {'Pass' if homogeneity else 'Fail'}"
        return result

    # Handle cases with more than two groups
    try:
        homogeneity = check_homogeneity_of_variances(groups)
        anova_results = pg.anova(
            data=alpha_diversity_df, dv="Shannon", between=selected_factor
        )
        # print(anova_results.columns)  # Debugging line
        p_value_col = "p-unc" if "p-unc" in anova_results.columns else "p_val"
        result = f"\nANOVA Results:\n{anova_results}\nNormality: {'Pass' if normality else 'Fail'}\nHomogeneity: {'Pass' if homogeneity else 'Fail'}"

        if anova_results[p_value_col].values[0] < 0.05:
            if normality and homogeneity:
                posthoc = pg.pairwise_ttests(
                    data=alpha_diversity_df,
                    dv="Shannon",
                    between=selected_factor,
                    padjust="holm",
                )
                result += f"\nPost-hoc: Tukey's HSD\n{posthoc}"
            else:
                posthoc = pg.pairwise_gameshowell(
                    data=alpha_diversity_df, dv="Shannon", between=selected_factor
                )
                result += f"\nPost-hoc: Games-Howell\n{posthoc}"
        else:
            result += f"\nNo significant difference in {selected_factor}. No post-hoc test needed."

        return result
    except ValueError as e:
        return f"Error in statistical tests: {str(e)}"
    


def calculate_alpha_diversity(df, factors):
    # Select columns that start with the appropriate prefix
    numeric_columns = [
        col
        for col in df.columns
        if col.startswith("GO:")
        or col.startswith("IPR")
        or col.startswith("K")
        or col.startswith("PF")
    ]

    # Calculate Shannon index only from the selected columns
    shannon_values = calculate_shannon_index(df[numeric_columns])

    # Create DataFrame with Shannon index and ref_code
    alpha_diversity_df = pd.DataFrame(
        {"ref_code": df["ref_code"], "Shannon": shannon_values}
    )

    # Merge with factors
    alpha_diversity_df = alpha_diversity_df.merge(factors, on="ref_code")

    return alpha_diversity_df


### Plotting methods

In [10]:
def plot_shannon_index(alpha_diversity_df, selected_factor):
    alpha_diversity_df = alpha_diversity_df.sort_values(by=selected_factor)
    alpha_diversity_df["ref_code"] = pd.Categorical(
        alpha_diversity_df["ref_code"],
        categories=alpha_diversity_df["ref_code"],
        ordered=True,
    )
    plt.figure(figsize=(12, 6))
    # sns.barplot(x='ref_code', y='Shannon', hue=selected_factor, data=alpha_diversity_df, dodge=False, palette="coolwarm", errorbar=None)
    sns.barplot(
        x="ref_code",
        y="Shannon",
        hue=selected_factor,
        data=alpha_diversity_df,
        dodge=False,
        palette="coolwarm",
    )
    plt.xlabel("Sample")
    plt.ylabel("Shannon Index")
    plt.title(f"Shannon Index Grouped by {selected_factor}")
    plt.xticks(rotation=90)
    plt.tight_layout()
    plt.show()


def plot_average_shannon_per_condition(alpha_diversity_df, selected_factor):
    grouped_data = alpha_diversity_df.groupby(selected_factor)["Shannon"]
    means = grouped_data.mean()
    errors = grouped_data.sem()
    means.plot(
        kind="bar",
        yerr=errors,
        capsize=5,
        figsize=(10, 6),
        color=sns.color_palette("coolwarm", len(means)),
    )
    plt.xlabel(selected_factor)
    plt.ylabel("Average Shannon Index")
    plt.title(f"Average Shannon Index by {selected_factor}")
    plt.xticks(rotation=45, ha="right")
    plt.grid(True)
    plt.tight_layout()
    plt.show()

## Pure pandas

### Methods

In [11]:
def generate_pivot_table(table_name):
    key_column = get_key_column(table_name)
    print("Key column:", key_column)

    # select distinct ref_codes from the dataframe
    ref_codes = mgf_parquet_dfs[table_name]['ref_code'].unique()
    print('length of the ref_codes:', len(ref_codes))
    table = pd.pivot_table(mgf_parquet_dfs[table_name],
                           values='abundance',
                           index=[key_column],
                           columns=['ref_code'],
                           aggfunc='sum',
                           fill_value=0,
                           )
    print('table shape:', table.shape)
    return table

### Tests

In [12]:
## testing the function
table_name = 'go'
df = generate_pivot_table(table_name)

Key column: id
length of the ref_codes: 54
table shape: (2637, 54)


In [13]:
# df.columns

In [14]:
print([k for k, v in mgf_parquet_dfs.items()])

['go', 'go_slim', 'ips', 'ko', 'LSU', 'pfam', 'SSU']


### Widget pure pandas

In [None]:
# metadata already merged
# Widget Definition for Table Selection (SSU or LSU)
table_selection_dropdown = widgets.Dropdown(
    options=["go", "go_slim", "ips", "ko", "pfam"],  # Options for table selection
    value="go",  # Default selection
    description="Select Table:",
)

# Widget Definition for Color Factor Selection
color_factor_dropdown = widgets.Dropdown(
    options=["Please select"]
    + [
        col for col in full_metadata.columns if full_metadata[col].dtype == "object"
    ],  # Non-numeric columns from factors
    value="Please select",  # Default value
    description="Color by:",
)

# Output widget for displaying plots
output_plot = widgets.Output()

Key column: id
length of the ref_codes: 54
table shape: (2637, 54)
Merged data:                       source_mat_id                 source_mat_id_orig  \
0      EMOBON_OOB_So_210608_micro_1    EMO BON OOB So 210608 micro (1)   
1     EMOBON_BPNS_So_210726_micro_1  EMOBON-210726-330-SoSOP3-Micro-R1   
2     EMOBON_BPNS_So_210726_micro_2  EMOBON-210726-330-SoSOP3-Micro-R2   
3  EMOBON_ROSKOGO_So_210826_micro_1    EMOBON ROSKOGO So210826 micro 1   
4  EMOBON_ROSKOGO_So_210826_micro_2    EMOBON ROSKOGO So210826 micro 2   

                                    samp_description  tax_id  \
0  EMOBON metagenome sediment sample from station...  412755   
1  EMOBON metagenome sediment sample from station...  412755   
2  EMOBON metagenome sediment sample from station...  412755   
3  EMOBON metagenome sediment sample from station...  412755   
4  EMOBON metagenome sediment sample from station...  412755   

              scientific_name investigation_type              env_material  \
0  marine se

  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **k

Key column: id
length of the ref_codes: 54
table shape: (2637, 54)
Merged data:                       source_mat_id                 source_mat_id_orig  \
0      EMOBON_OOB_So_210608_micro_1    EMO BON OOB So 210608 micro (1)   
1     EMOBON_BPNS_So_210726_micro_1  EMOBON-210726-330-SoSOP3-Micro-R1   
2     EMOBON_BPNS_So_210726_micro_2  EMOBON-210726-330-SoSOP3-Micro-R2   
3  EMOBON_ROSKOGO_So_210826_micro_1    EMOBON ROSKOGO So210826 micro 1   
4  EMOBON_ROSKOGO_So_210826_micro_2    EMOBON ROSKOGO So210826 micro 2   

                                    samp_description  tax_id  \
0  EMOBON metagenome sediment sample from station...  412755   
1  EMOBON metagenome sediment sample from station...  412755   
2  EMOBON metagenome sediment sample from station...  412755   
3  EMOBON metagenome sediment sample from station...  412755   
4  EMOBON metagenome sediment sample from station...  412755   

              scientific_name investigation_type              env_material  \
0  marine se

  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **k

Key column: id
length of the ref_codes: 54
table shape: (116, 54)
Merged data:                       source_mat_id                 source_mat_id_orig  \
0      EMOBON_OOB_So_210608_micro_1    EMO BON OOB So 210608 micro (1)   
1     EMOBON_BPNS_So_210726_micro_1  EMOBON-210726-330-SoSOP3-Micro-R1   
2     EMOBON_BPNS_So_210726_micro_2  EMOBON-210726-330-SoSOP3-Micro-R2   
3  EMOBON_ROSKOGO_So_210826_micro_1    EMOBON ROSKOGO So210826 micro 1   
4  EMOBON_ROSKOGO_So_210826_micro_2    EMOBON ROSKOGO So210826 micro 2   

                                    samp_description  tax_id  \
0  EMOBON metagenome sediment sample from station...  412755   
1  EMOBON metagenome sediment sample from station...  412755   
2  EMOBON metagenome sediment sample from station...  412755   
3  EMOBON metagenome sediment sample from station...  412755   
4  EMOBON metagenome sediment sample from station...  412755   

              scientific_name investigation_type              env_material  \
0  marine sed

  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **k

In [16]:
### Updating plots - Function to update the plot based on user selections
def update_plot(change):
    # Fetch the selected table and factor from the dropdowns
    selected_table = table_selection_dropdown.value
    selected_factor = color_factor_dropdown.value

    if selected_table == "Please select":
        return
    
    # Generate the pivoted table
    pivoted_table = generate_pivot_table(selected_table)

    # Transpose the DataFrame
    pivoted_table_transposed = pivoted_table.T

    # Sort the DataFrame by 'ref_code'
    pivoted_table_transposed_sorted = pivoted_table_transposed.sort_values(
        by="ref_code"
    )

    # Merge the full metadata with the pivoted table
    merged_data = merge_data(pivoted_table_transposed_sorted, full_metadata)
    print(f"Merged data: {merged_data.head()}")

    # Calculate Shannon index and plot
    alpha_diversity_df = calculate_alpha_diversity(merged_data, full_metadata)
    print(f"Alpha diversity DF: {alpha_diversity_df.head()}")

    # Clear previous plot and display the new one
    with output_plot:
        output_plot.clear_output(wait=True)
        plot_shannon_index(alpha_diversity_df, selected_factor)
        plot_average_shannon_per_condition(alpha_diversity_df, selected_factor)
        anova_output = run_anova_and_posthoc(alpha_diversity_df, selected_factor)
        print(anova_output)


# Observe changes in dropdowns
color_factor_dropdown.observe(update_plot, names="value")
table_selection_dropdown.observe(update_plot, names="value")

# Display dropdowns and output
display(table_selection_dropdown, color_factor_dropdown, output_plot)

Dropdown(description='Select Table:', options=('go', 'go_slim', 'ips', 'ko', 'pfam'), value='go')

Dropdown(description='Color by:', options=('Please select', 'source_mat_id', 'source_mat_id_orig', 'samp_descr…

Output()

### Benchmarks

In [17]:
# benchmark creation of the pivot tables from mgf_parquet_dfs
from time import perf_counter

start = perf_counter()
for table_name in ['go', 'go_slim', 'ips', 'ko', 'pfam']:
    df = generate_pivot_table(table_name)
end = perf_counter()

print(f"Time taken: {end - start} seconds")

Key column: id
length of the ref_codes: 54
table shape: (2637, 54)
Key column: id
length of the ref_codes: 54
table shape: (116, 54)
Key column: accession
length of the ref_codes: 54
table shape: (18733, 54)
Key column: entry
length of the ref_codes: 54
table shape: (4076, 54)
Key column: entry
length of the ref_codes: 54
table shape: (17442, 54)
Time taken: 1.1082419000003938 seconds


## duckDB way

In [18]:
import duckdb
from duckdb import BinderException, CatalogException

%load_ext autoreload
%autoreload 2

ModuleNotFoundError: No module named 'duckdb'

### DB methods

In [None]:
def dbQuery2df(q):
    return duckdb.sql(q).df()


# this is basically pivot table generation
def generate_sql_query(table_name):
    key_column = get_key_column(table_name)

    # Step 2: Get distinct ref_code (EMOBON IDs)
    ref_codes = duckdb.sql(f"SELECT DISTINCT ref_code FROM {table_name};").fetchdf()

    # Step 3: Generate the dynamic pivot query with COALESCE to replace NULL with 0
    sql_query = f"SELECT {key_column}"
    for ref_code in ref_codes["ref_code"]:
        sql_query += f", COALESCE(MAX(CASE WHEN ref_code = '{ref_code}' THEN abundance END), 0) AS {ref_code}"
    sql_query += f" FROM {table_name} GROUP BY {key_column};"

    return sql_query


def createDuckDB(df_tables, sample_metadata, observatory_metadata):
    if isinstance(df_tables, list):
        # TODO: convert to dict
        raise NotImplementedError
    
    # dictionary
    duckdb.sql("CREATE TABLE SAMPLE_METADATA AS SELECT * FROM sample_metadata")
    duckdb.sql("SELECT COUNT(*) FROM SAMPLE_METADATA")
    duckdb.sql("CREATE TABLE OBS_METADATA AS SELECT * FROM observatory_metadata")
    duckdb.sql("SELECT COUNT(*) FROM OBS_METADATA")
    for table_name in df_tables:
        df = df_tables[table_name]
        cmd = f"CREATE TABLE {table_name} AS SELECT * FROM df"
        duckdb.sql(cmd)

    print(duckdb.sql("SHOW TABLES"))

### Code

### Create the data tables
- find out how to manage decision between creation of the DB vs CPU storage

In [None]:
# duckdb.execute("SHOW TABLES").df()
# duckdb.execute("SHOW ALL TABLES").df()

In [None]:
# delete tables if they exist, should the user be asked?
try:
    for table_name in duckdb.execute("SHOW TABLES").df()['name']:
        duckdb.sql(f"DROP TABLE {table_name}")
except CatalogException:
    pass

# data to duckDB
createDuckDB(mgf_parquet_dfs, sample_metadata, observatory_metadata)

In [81]:
# duckdb.execute("SHOW TABLES").df()

In [None]:
# display(duckdb.execute(f"SELECT * FROM go LIMIT 5").df())
# display(duckdb.execute(f"SELECT * FROM go_pivoted LIMIT 5").df())
# display(duckdb.execute(f"SELECT COUNT(*) FROM go_pivoted").df())

In [None]:
# show the first few rows of the tables

# for table_name in duckdb.execute("SHOW TABLES").df()['name']:
#     print(table_name)
#     display(duckdb.execute(f"SELECT * FROM {table_name} LIMIT 5").df())

In [77]:
query_metadata = """
    SELECT OBS_METADATA.*,
           SAMPLE_METADATA.*
    FROM SAMPLE_METADATA
    INNER JOIN OBS_METADATA 
    ON SAMPLE_METADATA.obs_id = OBS_METADATA.obs_id
    AND SAMPLE_METADATA.env_package = OBS_METADATA.env_package
    ORDER BY SAMPLE_METADATA.ref_code ASC
    """

# DP, do not understand, these two queries are the same
# Get the full_metadata DataFrame
full_metadata = dbQuery2df(query_metadata)

factors = duckdb.sql(query_metadata).df()

### Widget using DB

In [None]:
# Widget Definition for Table Selection (SSU or LSU)
table_selection_dropdown = widgets.Dropdown(
    options=["go", "go_slim", "ips", "ko", "pfam"],  # Options for table selection
    value="go",  # Default selection
    description="Select Table:",
)

# Widget Definition for Color Factor Selection
color_factor_dropdown = widgets.Dropdown(
    options=["Please select"]
    + [
        col for col in factors.columns if factors[col].dtype == "object"
    ],  # Non-numeric columns from factors
    value="Please select",  # Default value
    description="Color by:",
)

# Output widget for displaying plots
output_plot = widgets.Output()


#### Updating plots - Function to update the plot based on user selections
def update_plot(change):
    # Fetch the selected table and factor from the dropdowns
    selected_table = table_selection_dropdown.value
    selected_factor = color_factor_dropdown.value

    # Print debug information (you can remove these after debugging)
    print(f"Selected table: {selected_table}, Selected factor: {selected_factor}")

    if selected_factor != "Please select":
        # Generate the SQL query for the selected table
        query = generate_sql_query(selected_table)
        print(f"Generated SQL query: {query}")

        # Create or replace the pivoted table
        duckdb.sql(f"CREATE OR REPLACE TABLE {selected_table}_pivoted AS {query};")

        # Fetch the pivoted table
        pivoted_table = duckdb.sql(f"SELECT * FROM {selected_table}_pivoted;").fetchdf()
        print(f"Pivoted table: {pivoted_table.head()}")

        # Set key column ('id', 'accession', or 'entry') as index
        key_column = get_key_column(selected_table)
        pivoted_table.set_index(key_column, inplace=True)

        # Transpose the DataFrame
        pivoted_table_transposed = pivoted_table.T

        # Reset the index to get 'ref_code' as a column
        pivoted_table_transposed.reset_index(inplace=True)
        pivoted_table_transposed.rename(columns={"index": "ref_code"}, inplace=True)

        # Sort the DataFrame by 'ref_code'
        pivoted_table_transposed_sorted = pivoted_table_transposed.sort_values(
            by="ref_code"
        )

        # Merge the full metadata with the pivoted table
        merged_data = merge_data(pivoted_table_transposed_sorted, full_metadata)
        print(f"Merged data: {merged_data.head()}")

        # Calculate Shannon index and plot
        alpha_diversity_df = calculate_alpha_diversity(merged_data, full_metadata)
        print(f"Alpha diversity DF: {alpha_diversity_df.head()}")

        # Clear previous plot and display the new one
        with output_plot:
            output_plot.clear_output(wait=True)
            plot_shannon_index(alpha_diversity_df, selected_factor)
            plot_average_shannon_per_condition(alpha_diversity_df, selected_factor)
            anova_output = run_anova_and_posthoc(alpha_diversity_df, selected_factor)
            print(anova_output)


# Observe changes in dropdowns
color_factor_dropdown.observe(update_plot, names="value")
table_selection_dropdown.observe(update_plot, names="value")

# Display dropdowns and output
display(table_selection_dropdown, color_factor_dropdown, output_plot)

Dropdown(description='Select Table:', options=('go', 'go_slim', 'ips', 'ko', 'pfam'), value='go')

Dropdown(description='Color by:', options=('Please select', 'obs_id', 'project_name', 'geo_loc_name', 'loc_bro…

Output()