# Indicator Species Analysis

## Calculating ISA values and running 4999 Monte Carlo permutations to determine significance

In [None]:
import pandas as pd
import numpy as np
import os
import re
from collections import defaultdict

def calculate_indicator_value_stats(species_abundances_in_group, 
                                    species_abundances_all_samples,
                                    group_sample_indices, 
                                    all_sample_group_assignments):
    """
    Compute the IndVal components of one species for one target group.

    Specificity (A) is the species' mean abundance in the target group divided
    by the sum of its per-group mean abundances over every group label found in
    `all_sample_group_assignments`. Fidelity (B) is the fraction of the target
    group's samples in which the species has an abundance greater than zero.
    IndVal is the product A * B.

    Args:
        species_abundances_in_group (pd.Series): Abundances of the species in the target group's samples.
        species_abundances_all_samples (pd.Series): Abundances of the species across all samples.
        group_sample_indices (list): Sample identifiers of the target group; only its length is used.
        all_sample_group_assignments (pd.Series): Group label of every sample, indexed like
            `species_abundances_all_samples`.

    Returns:
        tuple: (specificity_A, fidelity_B, indval). A component falls back to 0 when its
        denominator is not positive or when the ratio evaluates to NaN.
    """
    # --- Specificity (A_ij) ---
    # One mean per distinct group label; a label that selects no samples yields NaN,
    # which is counted as 0 in the denominator.
    per_group_means = [
        species_abundances_all_samples[all_sample_group_assignments == label].mean()
        for label in all_sample_group_assignments.unique()
    ]
    denominator = sum(0 if pd.isna(group_mean) else group_mean for group_mean in per_group_means)

    target_group_mean = species_abundances_in_group.mean()
    specificity_A = 0
    if denominator > 0:
        ratio = target_group_mean / denominator
        if not pd.isna(ratio):
            specificity_A = ratio

    # --- Fidelity (B_ij) ---
    # Share of the target group's sites where the species is present (abundance > 0).
    n_sites = len(group_sample_indices)
    n_occupied_sites = (species_abundances_in_group > 0).sum()
    fidelity_B = 0
    if n_sites > 0:
        occupancy = n_occupied_sites / n_sites
        if not pd.isna(occupancy):
            fidelity_B = occupancy

    return specificity_A, fidelity_B, specificity_A * fidelity_B


def _read_csv_or_report(path, noun):
    """Load a CSV into a DataFrame; print a diagnostic and return None on failure."""
    try:
        return pd.read_csv(path)
    except FileNotFoundError:
        print(f"Error: {noun[0].upper() + noun[1:]} file '{path}' not found.")
    except Exception as e:
        print(f"Error loading {noun}: {e}")
    return None


def _indval_matrix(abundances, presence, group_codes, n_groups):
    """
    Compute IndVal for every (species, group) pair at once.

    Args:
        abundances (np.ndarray): Float array of shape (n_species, n_samples).
        presence (np.ndarray): Float 0/1 array of the same shape (abundance > 0).
        group_codes (np.ndarray): Integer group code per sample; codes outside
            ``range(n_groups)`` (e.g. -1 for a missing label) belong to no group.
        n_groups (int): Number of groups.

    Returns:
        np.ndarray: Array of shape (n_species, n_groups) holding A * B.
    """
    # One-hot membership: membership[s, g] == 1.0 when sample s is in group g.
    membership = (group_codes[:, None] == np.arange(n_groups)).astype(float)
    group_sizes = membership.sum(axis=0)
    # Empty groups have zero sums, so dividing by 1 instead of 0 yields 0 for them.
    safe_sizes = np.where(group_sizes > 0, group_sizes, 1.0)

    group_means = (abundances @ membership) / safe_sizes
    mean_totals = group_means.sum(axis=1, keepdims=True)
    # Specificity is 0 wherever the sum of group means is not positive.
    specificity = np.divide(
        group_means, mean_totals,
        out=np.zeros_like(group_means),
        where=mean_totals > 0,
    )
    fidelity = (presence @ membership) / safe_sizes
    return specificity * fidelity


def run_indicator_species_analysis(
        otu_table_path, 
        metadata_path, 
        output_path,
        metadata_sample_id_col="#SampleID", 
        grouping_variable_col="DiseaseStatus",
        num_permutations=4999,
        taxonomy_cols_count=9, # Number of taxonomy columns after OTU_ID
        random_seed=None,
        sample_col_pattern=r'^[A-Z]\d{2}$',
    ):
    """
    Performs Indicator Species Analysis using IndVal and Monte Carlo permutations.

    For every taxon the IndVal (specificity * fidelity) is computed for each group,
    the maximum over groups is taken as the test statistic, and group labels are
    shuffled `num_permutations` times to obtain a permutation p-value
    ``(count(permuted_max >= observed_max) + 1) / (num_permutations + 1)``.

    Args:
        otu_table_path (str): CSV whose first column is the taxon identifier, followed by
            taxonomy columns and then one abundance column per sample.
        metadata_path (str): CSV holding the sample ID and grouping columns.
        output_path (str): Destination CSV for the results table.
        metadata_sample_id_col (str): Metadata column containing sample IDs.
        grouping_variable_col (str): Metadata column containing the group labels.
        num_permutations (int): Number of label permutations.
        taxonomy_cols_count (int): Number of taxonomy columns following the ID column.
        random_seed (int | None): Seed for a dedicated NumPy generator. When None the
            global ``np.random`` state is used, as before.
        sample_col_pattern (str): Regular expression a column name must match (from its
            start) to be treated as a sample column.

    Returns:
        pd.DataFrame | None: Results sorted by p-value (ascending) then Max_IndVal
        (descending), or None when inputs cannot be loaded or validated. A failure to
        write `output_path` is reported but the results are still returned.
    """
    print(f"Loading OTU table from: {otu_table_path}")
    df_otu_full = _read_csv_or_report(otu_table_path, "OTU table")
    if df_otu_full is None:
        return None

    print(f"Loading metadata from: {metadata_path}")
    df_metadata = _read_csv_or_report(metadata_path, "metadata")
    if df_metadata is None:
        return None

    if df_otu_full.empty or df_metadata.empty:
        print("OTU table or metadata is empty. Aborting.")
        return None

    # --- Data Preprocessing ---
    # OTU Table
    otu_id_col = df_otu_full.columns[0]
    print(f"Using '{otu_id_col}' as OTU/Taxon identifier column.")

    sample_regex = re.compile(sample_col_pattern)
    all_columns = df_otu_full.columns.tolist()
    first_sample_position = 1 + taxonomy_cols_count

    # Preferred: sample columns located after the ID and taxonomy columns.
    sample_cols = [c for c in all_columns[first_sample_position:] if sample_regex.match(str(c))]
    if sample_cols:
        # A pattern match inside the taxonomy range usually means taxonomy_cols_count is too large.
        overlooked = [c for c in all_columns[1:first_sample_position] if sample_regex.match(str(c))]
        if overlooked:
            print(f"Warning: {len(overlooked)} column(s) matching the sample pattern lie within the "
                  f"first {taxonomy_cols_count} taxonomy columns and are not analysed as samples: "
                  f"{overlooked}. Check 'taxonomy_cols_count'.")
    else:
        # Fallback: accept matching columns anywhere after the ID column.
        sample_cols = [c for c in all_columns[1:] if sample_regex.match(str(c))]
        if not sample_cols:
            print("Error: No sample columns found matching the pattern (e.g., A01, B12). Please check OTU table format.")
            return None
    print(f"Identified {len(sample_cols)} sample columns in OTU table.")

    otu_indexed = df_otu_full.set_index(otu_id_col)
    df_otu_abundances = otu_indexed[sample_cols].apply(pd.to_numeric, errors='coerce').fillna(0)

    # Keep taxonomy for final output (never include columns already used as samples).
    sample_col_set = set(sample_cols)
    taxonomy_cols = [c for c in all_columns[1:first_sample_position] if c not in sample_col_set]
    taxonomy_map = otu_indexed[taxonomy_cols]

    # Metadata
    if metadata_sample_id_col not in df_metadata.columns:
        print(f"Error: Metadata sample ID column '{metadata_sample_id_col}' not found.")
        return None
    if grouping_variable_col not in df_metadata.columns:
        print(f"Error: Grouping variable column '{grouping_variable_col}' not found in metadata.")
        return None

    df_metadata[metadata_sample_id_col] = df_metadata[metadata_sample_id_col].astype(str)
    df_metadata = df_metadata.set_index(metadata_sample_id_col)
    duplicated_ids = df_metadata.index.duplicated(keep='first')
    if duplicated_ids.any():
        # Duplicates would make the label vector longer than the abundance matrix.
        print(f"Warning: {int(duplicated_ids.sum())} duplicate sample ID(s) in metadata; "
              "keeping the first occurrence of each.")
        df_metadata = df_metadata[~duplicated_ids]

    # Align samples, keeping OTU-table column order so results do not depend on set ordering.
    metadata_ids = set(df_metadata.index)
    common_samples = [s for s in df_otu_abundances.columns if s in metadata_ids]
    if not common_samples:
        print("Error: No common samples found between OTU table and metadata. Check sample IDs.")
        return None

    df_otu_aligned = df_otu_abundances[common_samples]
    group_assignments = df_metadata.loc[common_samples, grouping_variable_col].astype('category')
    unique_groups = group_assignments.cat.categories.tolist()
    print(f"Found {len(unique_groups)} unique groups: {unique_groups}")
    if len(unique_groups) < 2:
        print("Error: Less than 2 groups found. Indicator Species Analysis requires at least 2 groups.")
        return None

    abundances = df_otu_aligned.to_numpy(dtype=float)
    presence = (abundances > 0).astype(float)
    # Integer code per sample; samples with a missing group label get -1 and join no group.
    group_codes = group_assignments.cat.codes.to_numpy()
    n_groups = len(unique_groups)

    # --- Calculate Observed Indicator Values ---
    print("\nCalculating observed indicator values...")
    observed_indvals = _indval_matrix(abundances, presence, group_codes, n_groups)
    observed_max_indvals = observed_indvals.max(axis=1)
    # argmax returns the first maximum, i.e. the earliest group in category order on ties.
    observed_associated_group = np.asarray(unique_groups, dtype=object)[observed_indvals.argmax(axis=1)]

    # --- Permutation Test ---
    print(f"\nStarting permutation test ({num_permutations} permutations)...")
    rng = np.random if random_seed is None else np.random.default_rng(random_seed)
    # How many times the permuted max IndVal was >= the observed max IndVal, per species.
    perm_greater_counts = np.zeros(abundances.shape[0], dtype=int)

    for i_perm in range(num_permutations):
        if (i_perm + 1) % 100 == 0:
            print(f"  Permutation {i_perm + 1}/{num_permutations}")
        permuted_codes = rng.permutation(group_codes)
        permuted_max = _indval_matrix(abundances, presence, permuted_codes, n_groups).max(axis=1)
        perm_greater_counts += permuted_max >= observed_max_indvals

    # Calculate p-values (the +1 counts the observed labelling as one of the permutations).
    p_values = (perm_greater_counts + 1) / (num_permutations + 1)

    # --- Prepare Results ---
    # Rows of taxonomy_map and df_otu_aligned share the OTU table's row order, so the
    # table is assembled positionally; this also works when taxon IDs are not unique.
    otu_index = df_otu_aligned.index
    results_df = pd.DataFrame({'Taxon_ID': otu_index.to_numpy()}, index=otu_index)
    for tax_col in taxonomy_cols:
        results_df[tax_col] = taxonomy_map[tax_col].to_numpy()
    results_df['Associated_Group'] = observed_associated_group
    results_df['Max_IndVal'] = observed_max_indvals
    results_df['Permutation_P_Value'] = p_values

    results_df = results_df.sort_values(by=['Permutation_P_Value', 'Max_IndVal'], ascending=[True, False])

    print(f"\nSaving results to: {output_path}")
    try:
        results_df.to_csv(output_path, index=False)
        print("Indicator Species Analysis complete.")
    except Exception as e:
        # Best effort: the computed table is still returned to the caller.
        print(f"Error saving results: {e}")

    return results_df


if __name__ == "__main__":
    # --- Configuration ---
    # Input files, resolved relative to the current working directory.
    otu_table_file = "otu_table_uclust_with_updated_taxonomy_05172025_78bats_singletsRemoved_78bats_filtered_domains_combined_by_genus_keep_blanks_species_column_removed.csv"
    metadata_file = "otu_metadata_uclust_updated_taxonomy_05162025.csv"  # Make sure this is the correct metadata file

    # Where the results table is written.
    output_results_file = "indicator_species_analysis_results.csv"

    # Column names in the metadata file.
    meta_sample_id_col = "SampleID"  # Adjust if your metadata uses "#SampleID", "sample_id", etc.
    meta_grouping_col = "DiseaseStatus"

    # Taxonomy columns that follow the first (OTU_ID) column of the OTU table:
    # Domain, Super Kingdom, Kingdom, Phylum, Class, Order, Family, Genus = 8 columns
    num_taxonomy_cols = 8

    num_perms = 4999  # As requested

    # --- Run Analysis ---
    # Both inputs must exist; a missing OTU table is reported before missing metadata.
    if os.path.exists(otu_table_file) and os.path.exists(metadata_file):
        analysis_settings = {
            "otu_table_path": otu_table_file,
            "metadata_path": metadata_file,
            "output_path": output_results_file,
            "metadata_sample_id_col": meta_sample_id_col,
            "grouping_variable_col": meta_grouping_col,
            "num_permutations": num_perms,
            "taxonomy_cols_count": num_taxonomy_cols,
        }
        run_indicator_species_analysis(**analysis_settings)
    elif not os.path.exists(otu_table_file):
        print(f"CRITICAL ERROR: OTU table file '{otu_table_file}' not found.")
    else:
        print(f"CRITICAL ERROR: Metadata file '{metadata_file}' not found.")


## Creating Dot Plot from ISA output.
- Includes at least one significant taxon from each DiseaseStatus group
- Remaining plotted taxa must be significant and have an ISA indicator value (IndVal) of at least 0.2

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import numpy as np
from matplotlib.lines import Line2D 

def get_most_specific_taxon(row, tax_levels, uninformative_strings, fallback_id_str):
    """
    Pick the first informative taxonomic name for a row.

    `tax_levels` is scanned in the order given; a level is skipped when its value
    is missing, blank after stripping, or (lower-cased) listed in
    `uninformative_strings`.

    Returns:
        tuple: (name, level) for the first informative level, or
        (fallback_id_str, "ID") when no level qualifies.
    """
    for rank in tax_levels:
        raw_value = row[rank]
        if pd.isna(raw_value):
            continue
        name = str(raw_value).strip()
        if name == "" or name.lower() in uninformative_strings:
            continue
        return name, rank
    return fallback_id_str, "ID"

def plot_significant_indicator_species(
    results_csv_path="indicator_species_analysis_results.csv",
    output_plot_path="significant_indicator_species_dot_plot.png",
    p_value_threshold=0.05,
    indval_threshold=0.2,
    no_disease_group_name="No Disease Present" 
):
    """
    Generates a dot plot of significant indicator species.

    Rows with ``Permutation_P_Value < p_value_threshold`` are kept. Rows whose
    ``Associated_Group`` equals `no_disease_group_name` are plotted regardless of
    their IndVal; rows of every other group additionally need
    ``Max_IndVal > indval_threshold``. Each taxon is drawn as a dot (x = Max_IndVal,
    colour = associated group, size = -log10(p)) with a stem from x = 0, and two
    legends are added: one for group colours and one for the smallest and largest
    plotted p-value.

    Args:
        results_csv_path (str): CSV produced by the indicator species analysis.
        output_plot_path (str): Image file the figure is saved to.
        p_value_threshold (float): Significance cut-off (strict ``<``).
        indval_threshold (float): Minimum IndVal (strict ``>``) for groups other than
            `no_disease_group_name`.
        no_disease_group_name (str): Group exempt from the IndVal threshold.

    Returns:
        None. Problems (missing file, missing columns, nothing to plot, save failure)
        are reported with a printed message.
    """
    try:
        df_results = pd.read_csv(results_csv_path)
        print(f"Successfully loaded indicator species results: '{results_csv_path}' (Shape: {df_results.shape})")
    except FileNotFoundError:
        print(f"Error: Results file '{results_csv_path}' not found.")
        return
    except Exception as e:
        print(f"Error loading results CSV: {e}")
        return

    if df_results.empty:
        print("Results DataFrame is empty. Cannot generate plot.")
        return

    required_cols = ['Taxon_ID', 'Domain', 'Genus', 'Family', 'Order', 'Class', 'Phylum', 
                     'Kingdom', 'Super Kingdom', 'Associated_Group', 'Max_IndVal', 'Permutation_P_Value']
    missing_cols = [col for col in required_cols if col not in df_results.columns]
    if missing_cols:
        print(f"Error: The following required columns are missing from '{results_csv_path}': {', '.join(missing_cols)}")
        return

    df_significant_p = df_results[df_results['Permutation_P_Value'] < p_value_threshold].copy()

    if df_significant_p.empty:
        print(f"No significant indicators found with P-value < {p_value_threshold}. No plot will be generated.")
        return

    # The exempt group keeps every significant hit; other groups must also clear the IndVal threshold.
    in_exempt_group = df_significant_p['Associated_Group'] == no_disease_group_name
    df_significant = pd.concat([
        df_significant_p[in_exempt_group],
        df_significant_p[
            (df_significant_p['Associated_Group'] != no_disease_group_name) &
            (df_significant_p['Max_IndVal'] > indval_threshold)
        ],
    ])

    if df_significant.empty:
        print(f"No significant indicators found meeting the specified P-value and IndVal criteria (with exception for '{no_disease_group_name}'). No plot will be generated.")
        return

    # Exempt-group hits that would have been dropped by the IndVal threshold.
    num_no_disease_included_conditionally = int(
        (in_exempt_group & (df_significant_p['Max_IndVal'] <= indval_threshold)).sum()
    )
    print(f"Found {len(df_significant)} significant indicators to plot.")
    if num_no_disease_included_conditionally > 0:
         print(f"  (Including {num_no_disease_included_conditionally} significant hits from '{no_disease_group_name}' with IndVal <= {indval_threshold})")

    # Dot size metric: -log10(p); epsilon guards against log10(0).
    epsilon = 1e-300 
    df_significant['P_Value_Size_Metric'] = -np.log10(df_significant['Permutation_P_Value'] + epsilon)

    # Most specific rank first; used to pick the label shown for each taxon.
    taxonomic_hierarchy = ['Genus', 'Family', 'Order', 'Class', 'Phylum', 'Kingdom', 'Super Kingdom'] ## because all like genera were combined earlier species assignment doesn't mean anything now.
    uninformative_tax_strings = [
        '', 'nan', 'none', 'na', '<na>', 'unassigned', 'unclassified', 
        'unknown', 'no blast hit', 'incertae sedis', 'metazoa', 
        'no classification', 'unknown_domain' 
    ]
    uninformative_tax_strings = [s.lower() for s in uninformative_tax_strings]

    # Label format: "<Domain>: <name> (<rank>)", falling back to the taxon ID.
    display_labels = []
    for _, row in df_significant.iterrows():
        domain_val_raw = row['Domain']
        domain_val_str = str(domain_val_raw).strip()
        if pd.isna(domain_val_raw) or domain_val_str.lower() in uninformative_tax_strings or domain_val_str == "":
            domain_prefix = "Unknown_Domain"
        else:
            domain_prefix = domain_val_str
        
        most_specific_name, specific_level_name = get_most_specific_taxon(row, taxonomic_hierarchy, uninformative_tax_strings, str(row['Taxon_ID']))
        
        if specific_level_name == "ID":
            if domain_prefix == "Unknown_Domain":
                final_label = f"{most_specific_name} (ID)" 
            else:
                final_label = f"{domain_prefix}: {most_specific_name} (ID)"
        else:
            final_label = f"{domain_prefix}: {most_specific_name} ({specific_level_name})"
        display_labels.append(final_label)
        
    df_significant['Display_Label'] = display_labels

    # After reset_index the row position doubles as the y coordinate.
    df_significant_sorted = df_significant.sort_values(
        by=['Associated_Group', 'Max_IndVal'],
        ascending=[True, True] 
    ).reset_index(drop=True) 

    if df_significant_sorted.empty:
        print("No data to plot after processing labels and sorting.")
        return

    plt.style.use('seaborn-v0_8-pastel') 
    
    num_indicators = len(df_significant_sorted)
    fig_height = max(8, num_indicators * 0.4) 
    fig_width = 12 

    fig, ax = plt.subplots(figsize=(fig_width, fig_height)) 

    min_dot_size_plot, max_dot_size_plot = 50, 450

    # One group -> colour mapping, built once and shared by the dots, the stems and
    # the legend so the three can never disagree.
    hue_order = df_significant_sorted['Associated_Group'].astype('category').cat.categories.tolist()
    group_colors = dict(zip(hue_order, sns.color_palette("pastel", n_colors=len(hue_order))))
    
    sns.scatterplot(
        y=df_significant_sorted.index, 
        x='Max_IndVal',
        hue='Associated_Group',
        hue_order=hue_order,
        size='P_Value_Size_Metric',  
        sizes=(min_dot_size_plot, max_dot_size_plot),  
        data=df_significant_sorted,
        palette=group_colors, 
        ax=ax, 
        legend=False 
    )

    # Stems from x = 0 to each dot, drawn in a single call.
    ax.hlines(
        y=df_significant_sorted.index,
        xmin=0,
        xmax=df_significant_sorted['Max_IndVal'],
        colors=[group_colors[group] for group in df_significant_sorted['Associated_Group']],
        alpha=0.7,
        linewidth=1.5,
    )

    plt.yticks(ticks=df_significant_sorted.index, labels=df_significant_sorted['Display_Label'], fontsize=9)
    plt.xlabel('Max Indicator Value (Max_IndVal)', fontsize=12)
    plt.ylabel('Significant Taxa (Domain: Most Specific Level (Rank))', fontsize=12) 
    plt.title(f'Significant Indicator Taxa (P < {p_value_threshold}) by {df_significant_sorted["Associated_Group"].name}', fontsize=14)
    
    plt.xlim(0, 1.0) 
    
    # --- Custom Legend Handling ---
    # 1. Hue Legend (Associated Group)
    hue_handles = [
        Line2D([0], [0], marker='o', color='w',
               markerfacecolor=group_colors[group_name], markersize=8, label=group_name)
        for group_name in hue_order
    ]
    
    # Place the hue legend using ax.legend, its bbox_to_anchor is relative to the axes.
    hue_legend = ax.legend(hue_handles, hue_order, title='Associated Group', 
                           bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0.,
                           frameon=True, facecolor='white', edgecolor='grey') 

    # Render now so the legend extent queried below reflects an actual draw
    # (draw_idle() would only schedule one).
    fig.canvas.draw()

    # 2. P-value Size Legend (Custom with dots)
    # Bounding box of the hue legend in figure coordinates.
    hue_legend_bbox_fig = hue_legend.get_window_extent().transformed(fig.transFigure.inverted())
    
    min_p_val_plotted = df_significant_sorted['Permutation_P_Value'].min()
    max_p_val_plotted = df_significant_sorted['Permutation_P_Value'].max()
    
    min_p_size_metric = -np.log10(min_p_val_plotted + epsilon)
    max_p_size_metric = -np.log10(max_p_val_plotted + epsilon)
    
    overall_min_metric = df_significant_sorted['P_Value_Size_Metric'].min()
    overall_max_metric = df_significant_sorted['P_Value_Size_Metric'].max()

    def scale_p_metric_to_dot_size(p_metric, overall_min, overall_max, size_min, size_max):
        """Linearly map a size metric onto [size_min, size_max]; midpoint when all metrics are equal."""
        if overall_max == overall_min: 
            return (size_min + size_max) / 2
        p_metric_clipped = np.clip(p_metric, overall_min, overall_max)
        return size_min + ((p_metric_clipped - overall_min) / (overall_max - overall_min)) * (size_max - size_min)

    size_for_min_p = scale_p_metric_to_dot_size(min_p_size_metric, overall_min_metric, overall_max_metric, min_dot_size_plot, max_dot_size_plot)
    size_for_max_p = scale_p_metric_to_dot_size(max_p_size_metric, overall_min_metric, overall_max_metric, min_dot_size_plot, max_dot_size_plot)

    # Scatter sizes are areas (points^2) while Line2D markersize is a length, hence the sqrt.
    legend_p_handles = [
        Line2D([0], [0], marker='o', color='w', markerfacecolor='black', markersize=np.sqrt(size_for_min_p), label=f"p = {min_p_val_plotted:.2e}"),
        Line2D([0], [0], marker='o', color='w', markerfacecolor='black', markersize=np.sqrt(size_for_max_p), label=f"p = {max_p_val_plotted:.2e}")
    ]
    
    # The p-value legend is a figure-level legend whose upper-left corner is anchored
    # (in figure coordinates) at the hue legend's bottom edge, shifted 0.25 to the left.
    # NOTE(review): the 0.25 shift looks like an empirical correction for the axes being
    # resized by tight_layout() below — confirm visually before changing it.
    p_value_legend_anchor_x = hue_legend_bbox_fig.x0 - 0.25
    p_value_legend_anchor_y = hue_legend_bbox_fig.y0

    fig.legend(handles=legend_p_handles, 
               title="P-value (Dot Size)",
               loc='upper left', # Anchor point of the p_value_legend box
               bbox_to_anchor=(p_value_legend_anchor_x, p_value_legend_anchor_y), 
               borderaxespad=0.,
               frameon=True, facecolor='white', edgecolor='grey'
              )

    # Adjust overall layout to make space for legends on the right
    plt.tight_layout(rect=[0, 0, 0.82, 1])

    try:
        plt.savefig(output_plot_path, bbox_inches='tight', dpi=600)
        print(f"\nDot plot of significant indicators saved to '{output_plot_path}'")
        plt.show()
    except Exception as e:
        print(f"Error saving plot: {e}")

if __name__ == "__main__":
    # Input: results table written by the indicator species analysis; output: the figure.
    indicator_results_file = "indicator_species_analysis_results.csv" 
    plot_output_file = "significant_indicators_dot_plot_refined_legend_placement.png" 

    # Plot filters: IndVal threshold, and the group passed as `no_disease_group_name`.
    min_indval_to_plot = 0.2
    no_disease_group = "No Disease Present" 

    if os.path.exists(indicator_results_file):
        plot_significant_indicator_species(
            results_csv_path=indicator_results_file,
            output_plot_path=plot_output_file,
            p_value_threshold=0.05,
            indval_threshold=min_indval_to_plot,
            no_disease_group_name=no_disease_group,
        )
        print("\nDot plot generation script finished.")
    else:
        print(f"CRITICAL ERROR: The indicator results file '{indicator_results_file}' was not found.")
