In [None]:
import pandas as pd
import numpy as np
import os
import openpyxl
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

## Data

In [None]:
# Resolve the project layout relative to the notebook's working directory:
# the project root is taken to be two levels above the current directory.
project_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))

# Input data and generated results live in sibling folders under the root.
data_folder, results_folder = (
    os.path.join(project_root, subdir) for subdir in ("data", "results")
)

# Output folder for the summary tables and figures; created if missing.
summary_stats_folder = os.path.join(results_folder, "summary_statistics")
os.makedirs(summary_stats_folder, exist_ok=True)

In [None]:
# Load the labelled training data CSV from the data folder into `df`.
training_csv_path = f"{data_folder}/DroughtITS_mapping_w_labels_training_data.csv"
df = pd.read_csv(training_csv_path)

## Mean and SD
- Calculates and summarizes the mean and standard deviation of selected environmental features for each pathogen, separately for samples where the pathogen is absent (Class 0) and present (Class 1).

In [None]:
# Pathogen label columns (compared against 0 and 1 below) and the
# environmental feature columns to summarise.
pathogens = [
    'Acremonium spp.',
    'Candida tropicalis',
    'Curvularia lunata',
    'Falciformispora senegalensis',
    'Lichtheimia spp.',
    'Mucor spp.',
    'Rhizopus spp.',
    'Scedosporium spp.',
    'Talaromyces marneffei',
]
env_features = [
    'drought', 'water_content', 'organic_matter', 'nitrogen',
    'phosphorus', 'potassium', 'temp_soil', 'pH',
]


def _mean_sd_cell(values):
    """Format a column of values as "mean ± sd" with two decimals."""
    return f"{values.mean():.2f} ± {values.std():.2f}"


# One row per (pathogen, class) pair; each row carries one formatted
# "mean ± sd" string per environmental feature for the matching samples.
summary_rows = [
    {
        'Pathogen': pathogen_name,
        'Class': presence,
        **{
            f"{feat_name} Mean ± SD": _mean_sd_cell(subset[feat_name])
            for feat_name in env_features
        },
    }
    for pathogen_name in pathogens
    for presence in (0, 1)
    for subset in [df[df[pathogen_name] == presence]]
]

# Build the table with identifier columns first, then the features in the
# order they are listed in `env_features`.
column_order = ['Pathogen', 'Class'] + [
    f"{feat_name} Mean ± SD" for feat_name in env_features
]
summary_df = pd.DataFrame(summary_rows)[column_order]

In [None]:
# Write the summary table to an Excel workbook without the row index.
summary_xlsx_path = f"{summary_stats_folder}/training_data_summary_statistics.xlsx"
summary_df.to_excel(summary_xlsx_path, index=False)

### Box Plots

In [None]:
# Pathogen label columns and environmental feature columns to plot.
pathogens = ['Acremonium spp.', 'Candida tropicalis',
       'Curvularia lunata', 'Falciformispora senegalensis',
       'Lichtheimia spp.', 'Mucor spp.', 'Rhizopus spp.', 'Scedosporium spp.',
       'Talaromyces marneffei']
env_features = ['drought', 'water_content', 'organic_matter',
       'nitrogen', 'phosphorus', 'potassium', 'temp_soil', 'pH']

# Convert from wide to long format: one (Value, Class) frame per
# pathogen/feature pair, tagged with the pair it came from.
plot_df = []

for pathogen in pathogens:
    for feature in env_features:
        temp = df[[feature, pathogen]].copy()
        temp.columns = ['Value', 'Class']
        temp['Pathogen'] = pathogen
        temp['Feature'] = feature
        plot_df.append(temp)

long_df = pd.concat(plot_df, ignore_index=True)

# Colouring the boxes by class needs an explicit `hue`: passing `palette`
# without `hue` is deprecated in seaborn 0.13 and slated for removal.
# The class label is made categorical (in a copy, so `long_df` keeps its
# original dtype) so that the palette is applied as discrete colours rather
# than as a continuous colour map over the numeric 0/1 values.
box_df = long_df.assign(Class=long_df["Class"].astype("category"))

# Plot: Grid of boxplots (one feature per column, one pathogen per row)
g = sns.catplot(
    data=box_df,
    x="Class", y="Value",
    hue="Class",
    dodge=False,     # hue duplicates x, so keep each box centred on its tick
    legend=False,    # the x axis already identifies the classes
    col="Feature", row="Pathogen",
    kind="box",
    height=2.2, aspect=1.2,
    sharey=False, sharex=True,
    margin_titles=True,
    palette="Set2"
)

g.set_titles(col_template="{col_name}", row_template="{row_name}")
g.set_axis_labels("Presence (0 = Absent, 1 = Present)", "Value")
g.tight_layout()
# Leave head-room above the grid for the figure-level title.
plt.subplots_adjust(top=0.95)
g.fig.suptitle("Boxplot of Environmental Features by Pathogen Presence/Absence", fontsize=14)

plt.show()

## Standardized mean differences

Helpers

In [None]:
# Function to compute Cohen's d
def cohen_d(x0, x1):
    """Return Cohen's d for ``x1`` relative to ``x0`` using the pooled SD.

    The effect size is ``(mean(x1) - mean(x0)) / pooled_sd`` where the pooled
    standard deviation is computed from the within-group sums of squared
    deviations divided by ``n0 + n1 - 2``.

    Parameters
    ----------
    x0, x1 : array-like
        Observations of the two groups. Values are coerced to float, so
        boolean (e.g. one-hot) inputs are accepted. NaN values are not
        removed and propagate to the result; drop them before calling.

    Returns
    -------
    float
        Cohen's d, or NaN when it is undefined: either group is empty, the
        pooled degrees of freedom are not positive, or the pooled standard
        deviation is zero.
    """
    x0 = np.asarray(x0, dtype=float).ravel()
    x1 = np.asarray(x1, dtype=float).ravel()
    nx0, nx1 = x0.size, x1.size
    dof = nx0 + nx1 - 2
    if nx0 == 0 or nx1 == 0 or dof <= 0:
        return np.nan

    mean0, mean1 = x0.mean(), x1.mean()
    # Pool sums of squares rather than (n - 1) * var(ddof=1): for a group
    # with a single observation the latter is 0 * NaN = NaN, which would
    # discard an otherwise well-defined pooled estimate.
    pooled_ss = np.sum((x0 - mean0) ** 2) + np.sum((x1 - mean1) ** 2)
    pooled_std = np.sqrt(pooled_ss / dof)
    if pooled_std == 0:
        # No within-group spread: the standardized difference is undefined.
        return np.nan
    return (mean1 - mean0) / pooled_std

# Function to calculate Cohen's d across pathogens and features
def calculate_cohens_d(df, pathogens, features):
    """Compute Cohen's d for every pathogen/feature combination.

    For each pathogen column, rows are split into those equal to 0 and those
    equal to 1; for each feature the two groups (missing values dropped) are
    passed to ``cohen_d`` as ``(class 0, class 1)``.

    Returns a DataFrame with one row per combination and the columns
    'Pathogen', 'Feature' and 'Cohens_d'.
    """
    def _effect_size_records():
        for pathogen_name in pathogens:
            labels = df[pathogen_name]
            absent, present = df[labels == 0], df[labels == 1]
            for feature_name in features:
                yield {
                    'Pathogen': pathogen_name,
                    'Feature': feature_name,
                    'Cohens_d': cohen_d(
                        absent[feature_name].dropna(),
                        present[feature_name].dropna(),
                    ),
                }

    return pd.DataFrame(list(_effect_size_records()))

# Custom formatting function for pathogen labels
def format_pathogen_label(name):
    """Return a matplotlib mathtext label with the taxon name in italics.

    Names containing 'spp.' have the ' spp.' suffix removed, the remainder
    italicised, and an upright ' spp.' appended. Any other name is split on
    whitespace and each word is italicised in its own mathtext group, with
    the groups joined by ordinary spaces. Mathtext ignores spaces inside a
    math group, so wrapping a multi-word name in a single group would render
    the words run together.

    Parameters
    ----------
    name : str
        Pathogen name, e.g. 'Mucor spp.' or 'Candida tropicalis'.

    Returns
    -------
    str
        Label string suitable for use as a matplotlib tick label.
    """
    if 'spp.' in name:
        genus = name.replace(' spp.', '')
        return f"$\\it{{{genus}}}$ spp."

    words = name.split()
    if len(words) < 2:
        # Single word (or empty): keep the name as one italic group.
        return f"$\\it{{{name}}}$"
    return ' '.join(f"$\\it{{{word}}}$" for word in words)

In [None]:
# Pathogen label columns and the environmental feature columns.
pathogens = [
    'Acremonium spp.',
    'Candida tropicalis',
    'Curvularia lunata',
    'Falciformispora senegalensis',
    'Lichtheimia spp.',
    'Mucor spp.',
    'Rhizopus spp.',
    'Scedosporium spp.',
    'Talaromyces marneffei',
]
features = [
    'drought', 'water_content', 'organic_matter', 'nitrogen',
    'phosphorus', 'potassium', 'temp_soil', 'pH',
]

# One-hot encode plant type so every plant level becomes its own
# 'plant_<level>' indicator column.
plant_dummies = pd.get_dummies(df['plant'], prefix='plant')
plant_columns = plant_dummies.columns.tolist()

# Original columns plus the plant indicators, side by side.
df_extended = pd.concat([df, plant_dummies], axis=1)

# Environmental features first, followed by the plant indicator columns.
extended_features = features + plant_columns

# Cohen's d for every pathogen against every (extended) feature.
cohens_d_df = calculate_cohens_d(df_extended, pathogens, extended_features)

# Reshape to a pathogen x feature matrix and put the columns in the
# explicit feature order (pivot would otherwise order them by label).
ordered_features = list(extended_features)
heatmap_df = cohens_d_df.pivot(
    index='Pathogen', columns='Feature', values='Cohens_d'
)[ordered_features]

# Display labels for the rows, in the heatmap's own row order.
formatted_pathogens = list(map(format_pathogen_label, heatmap_df.index))

# Custom formatting function for feature labels

def format_feature_label(name):
    """Return a human-readable axis label for a feature column name.

    Lookup order:
      1. A fixed display name for a few known columns ('pH', 'temp_soil',
         'drought', 'water_content').
      2. Columns starting with the 'plant_' prefix: the prefix is removed
         and the remainder is title-cased.
      3. Anything else: underscores become spaces and the result is
         title-cased.

    Only a leading 'plant_' is treated as the prefix, so a column that merely
    contains that text (e.g. 'transplant_depth') is formatted by rule 3
    instead of having the text cut out of the middle of its name.
    """
    special_cases = {
        'pH': 'Soil pH',
        'temp_soil': 'Soil Temp',
        'drought': 'Drought Level',
        'water_content': 'Soil Water Content',
    }
    if name in special_cases:
        return special_cases[name]

    plant_prefix = 'plant_'
    if name.startswith(plant_prefix):
        return name[len(plant_prefix):].title()
    return name.replace('_', ' ').title()


# Display labels for the columns, in the heatmap's own column order.
formatted_features = list(map(format_feature_label, heatmap_df.columns))

# Draw the annotated heatmap on an explicit 12 x 8 inch figure, with a
# diverging colour map centred on zero.
fig, ax = plt.subplots(figsize=(12, 8))
sns.heatmap(
    heatmap_df,
    annot=True,
    fmt=".2f",
    cmap='RdYlBu',
    center=0,
    xticklabels=formatted_features,
    yticklabels=formatted_pathogens,
    linecolor='black',
    linewidths=0.5,
    annot_kws={"size": 12},  # regular weight by default
    ax=ax,
)

# Embolden annotations whose displayed magnitude exceeds 1. The value is
# parsed back from the annotation text; non-numeric text is left untouched.
for cell_text in ax.texts:
    try:
        exceeds_one = abs(float(cell_text.get_text())) > 1
    except ValueError:
        continue
    if exceeds_one:
        cell_text.set_fontweight('bold')

# Black frame around the full cell grid (one unit per column / row).
n_rows, n_cols = heatmap_df.shape
grid_frame = Rectangle(
    xy=(0, 0),
    width=n_cols,
    height=n_rows,
    fill=False,
    edgecolor='black',
    linewidth=1,
    clip_on=False,
)
ax.add_patch(grid_frame)

# Matching black outline for the colour bar.
cbar = ax.collections[0].colorbar
cbar.outline.set_edgecolor('black')
cbar.outline.set_linewidth(1)

# Title, axis labels and larger tick labels.
ax.set_title("Standardized Mean Differences", pad=20, fontsize=20)
ax.set_ylabel('Pathogen', fontsize=16)
ax.set_xlabel('Feature', fontsize=16)
for tick_label in ax.get_xticklabels():
    tick_label.set_fontsize(14)
for tick_label in ax.get_yticklabels():
    tick_label.set_fontsize(14)
fig.tight_layout()
cbar.ax.tick_params(labelsize=12)  # larger tick labels on the colour bar

# Save as PNG at 1000 dpi, cropped tightly to the drawn content.
# NOTE(review): an earlier comment described this as "full page width
# (7480 pixels)", but the figure is 12 inches wide, i.e. about 12000 px at
# 1000 dpi before tight cropping -- confirm the intended output width.
fig.savefig(
    f"{results_folder}/summary_statistics/standard_mean_diff_w_plants_1000dpi.png",
    dpi=1000,
    bbox_inches='tight',
    format='png'
)
plt.show()
