# Data Preparation

This notebook explores and summarizes environmental and pathogen data for high and medium priority fungal pathogens in Northeastern Thailand. 

- Loads and processes abundance and metadata for both training and test datasets.
- Focuses on high and medium priority pathogens, excluding critical priority due to lack of samples.
- Identifies and saves sample IDs associated with each pathogen group. This data is used for training and testing the decision tree models.
- Merges pathogen presence/absence with environmental properties for each sample.
- Summarizes the number of samples per class (present/absent) for each pathogen group.
- Visualizes data distributions and relationships using pairplots and boxplots for environmental features, grouped by pathogen presence.

**Note**: 
- Results from `match_taxonomy_to_key_names.ipynb` are required to run this notebook
- In order to run `analysis/model/automated_decision_tree.ipynb`, you need to run this notebook for both `data_set='training'` and `data_set='test'`

In [None]:
import os
import pandas as pd

import seaborn as sns
import matplotlib.pyplot as plt

## Data
- This analysis focuses on high and medium priority fungal pathogens.
- Critical priority pathogens are excluded because the training and test datasets do not contain samples with critical priority pathogens.

In [None]:
# Project root: two levels above the current working directory
project_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))

# Data folder (not created here)
data_folder = os.path.join(project_root, "data")

# Results folder and its sub-folders
results_folder = os.path.join(project_root, "results")
group_names_folder = os.path.join(results_folder, "group_names")
cross_plots_folder = os.path.join(results_folder, "cross_plots")
box_plots_folder = os.path.join(results_folder, "box_plots")

# Create the output folders; folders that already exist are left as they are
for output_folder in (results_folder, group_names_folder, cross_plots_folder, box_plots_folder):
    os.makedirs(output_folder, exist_ok=True)

In [None]:
# data set to use; this value selects the input files read below and is
# embedded in the names of the files written by this notebook
data_set = "test"  # options: "training" or "test"

In [None]:
# Abundance table file for each supported data set
abundance_files = {
    "training": "DroughtITS.final.txt",
    "test": "SakhonNakhonApril2025.final.txt",
}
if data_set not in abundance_files:
    raise ValueError("Invalid data_set value. Choose 'training' or 'test'.")

# Read the tab-separated table
data_df = pd.read_csv(os.path.join(data_folder, abundance_files[data_set]), sep="\t")

# Rename the '#OTU ID' column (when present), then transpose the table so the
# original columns become rows; reset_index turns the old header into a column
data_df = data_df.rename(columns={'#OTU ID': 'SampleID'}).transpose().reset_index()

# The first row of the transposed table holds the header values:
# promote it to column names, then drop it from the data rows
data_df.columns = data_df.iloc[0]
data_df = data_df.iloc[1:]

### High Priority

In [None]:
# Table of high priority pathogen names for the selected data set; the cells
# below read its 'classification_group' and 'Name' columns
high_priority_samples_df = pd.read_csv(f"{data_folder}/high_priority_sample_names_{data_set}_data.csv")

In [None]:
# Collect the distinct names that belong to each classification group
high_priority_names_df = (
    high_priority_samples_df
    .groupby('classification_group')
    .agg({'Name': set})
    .reset_index()
)

# Store each group's names as a list (used further down to select columns of data_df)
high_priority_names_df['Name'] = high_priority_names_df['Name'].map(list)

In [None]:
# names of the high priority groups found in this data set, in the row order
# of high_priority_names_df
high_priority_group_names = high_priority_names_df['classification_group'].tolist()

In [None]:
# The per-group files are written into <data>/high_priority; create the folder
# first so that to_csv does not fail when it is missing
os.makedirs(os.path.join(data_folder, "high_priority"), exist_ok=True)

# Look-up table: group name -> list of names (columns of data_df) in that group
high_priority_names_by_group = dict(
    zip(high_priority_names_df['classification_group'], high_priority_names_df['Name'])
)

for group_name in high_priority_group_names:
    names = high_priority_names_by_group[group_name]

    # Keep the rows in which at least one of the group's columns is above zero
    condition = (data_df[names] > 0).any(axis=1)
    filtered_df = data_df[condition]

    # Save the matching rows (all columns) for this group
    filtered_df.to_csv(f"{data_folder}/high_priority/{group_name}_sampleIDs_{data_set}_data.csv", index=False)

In [None]:
# Write the high priority pathogen group names to a text file, one name per line
high_priority_names_path = f"{results_folder}/group_names/high_priority_pathogen_group_names_{data_set}_data.txt"
with open(high_priority_names_path, "w") as file:
    file.writelines(f"{group_name}\n" for group_name in high_priority_group_names)

### Medium Priority

In [None]:
# Table of medium priority pathogen names for the selected data set; the cells
# below read its 'classification_group' and 'Name' columns
medium_priority_samples_df = pd.read_csv(f"{data_folder}/medium_priority_sample_names_{data_set}_data.csv")

In [None]:
# Collect the distinct names that belong to each classification group
medium_priority_names_df = (
    medium_priority_samples_df
    .groupby('classification_group')
    .agg({'Name': set})
    .reset_index()
)

# Store each group's names as a list (used further down to select columns of data_df)
medium_priority_names_df['Name'] = medium_priority_names_df['Name'].map(list)

In [None]:
# names of the medium priority groups found in this data set, in the row order
# of medium_priority_names_df
medium_priority_group_names = medium_priority_names_df['classification_group'].tolist()

In [None]:
# The per-group files are written into <data>/medium_priority; create the folder
# first so that to_csv does not fail when it is missing
os.makedirs(os.path.join(data_folder, "medium_priority"), exist_ok=True)

# Look-up table: group name -> list of names (columns of data_df) in that group
medium_priority_names_by_group = dict(
    zip(medium_priority_names_df['classification_group'], medium_priority_names_df['Name'])
)

for group_name in medium_priority_group_names:
    names = medium_priority_names_by_group[group_name]

    # Keep the rows in which at least one of the group's columns is above zero
    condition = (data_df[names] > 0).any(axis=1)
    filtered_df = data_df[condition]

    # Save the matching rows (all columns) for this group
    filtered_df.to_csv(f"{data_folder}/medium_priority/{group_name}_sampleIDs_{data_set}_data.csv", index=False)

In [None]:
# Write the medium priority pathogen group names to a text file, one name per line
medium_priority_names_path = f"{results_folder}/group_names/medium_priority_pathogen_group_names_{data_set}_data.txt"
with open(medium_priority_names_path, "w") as file:
    file.writelines(f"{group_name}\n" for group_name in medium_priority_group_names)

## Environmental properties data

In [None]:
if data_set == 'training':
    # Tab-separated mapping file of the training data
    properties_df = pd.read_csv(os.path.join(data_folder, "DroughtITS.mapping_file.fix.txt"), sep="\t")

    # Columns that are removed from the mapping table
    drop_columns = [
        'BarcodeSequence', 'LinkerPrimerSequence', 'RevBarcodeSequence',
        'ReversePrimer', 'phinchID', 'DemuxReads', 'Treatment',
    ]
    properties_df = properties_df.drop(columns=drop_columns)

elif data_set == 'test':
    # Tab-separated mapping file of the test data
    properties_df = pd.read_csv(os.path.join(data_folder, "SakhonNakhonApril2025.mapping_file.txt"), sep="\t")

    # Clean the 'plant' column: remove '_n' and '_p', then turn the remaining
    # underscores into spaces (applied in this order).
    # NOTE(review): these are substring replacements, so '_n' / '_p' are removed
    # wherever they occur in a value, not only at its end - confirm this is intended.
    for old, new in (('_n', ''), ('_p', ''), ('_', ' ')):
        properties_df['plant'] = properties_df['plant'].str.replace(old, new)

else:
    raise ValueError("Invalid data_set value. Choose 'training' or 'test'.")

In [None]:
# All pathogen groups found in this data set: high priority first, then medium priority
all_group_names = [*high_priority_group_names, *medium_priority_group_names]

In [None]:
# Fixed list of pathogen groups: each of them gets a 0/1 label column in
# properties_df, whether or not it occurs in the selected data set
all_possible_group_names = ['Acremonium spp.', 'Candida tropicalis',
       'Curvularia lunata', 'Falciformispora senegalensis', 'Fusarium spp.',
       'Lichtheimia spp.', 'Mucor spp.', 'Rhizopus spp.', 'Scedosporium spp.',
       'Talaromyces marneffei']

# Name of the ID column in the per-group files, which differs between data sets
sample_id_column = {'training': 'SampleID', 'test': 'OTU ID'}[data_set]

# Groups found in the data but missing from the fixed list get no label column,
# and the cells below index properties_df by every name in all_group_names
unlisted_groups = [name for name in all_group_names if name not in all_possible_group_names]
if unlisted_groups:
    print(f"Warning: groups missing from all_possible_group_names (no label column created): {unlisted_groups}")

for group_name in all_possible_group_names:
    if group_name not in all_group_names:
        # Group not found in this data set: every sample is labelled 0
        properties_df[group_name] = 0
        continue

    # Load the rows saved for this group from its priority folder
    priority_folder = "high_priority" if group_name in high_priority_group_names else "medium_priority"
    df = pd.read_csv(f"{data_folder}/{priority_folder}/{group_name}_sampleIDs_{data_set}_data.csv")

    # Replace "-" with "_" in the IDs before matching them against
    # properties_df['SampleID'] (presumably the two files spell IDs differently)
    SampleIDs = [x.replace("-", "_") for x in df[sample_id_column].to_list()]

    # Label: 1 when the sample ID is in the group's list, 0 otherwise
    properties_df[group_name] = properties_df['SampleID'].isin(SampleIDs).astype('int64')

    # Sanity check: the number of samples labelled 1 must equal the number of IDs
    if properties_df[group_name].sum() != len(SampleIDs):
        print(f"Error: {group_name} sampleIDs do not match")

In [None]:
# Write the names of all pathogen groups found in this data set, one name per line
all_names_path = f"{results_folder}/group_names/pathogen_group_names_{data_set}_data.txt"
with open(all_names_path, "w") as file:
    file.writelines(f"{group_name}\n" for group_name in all_group_names)

In [None]:
# save environmental properties data with labels (one 0/1 column per pathogen group)
# NOTE(review): the file name starts with 'DroughtITS' for both data sets; only
# the {data_set} part of the name differs between them
properties_df.to_csv(f"{data_folder}/DroughtITS_mapping_w_labels_{data_set}_data.csv", index=False)

## Summary for number of samples
- number of samples for each class for each pathogen

In [None]:
# For every pathogen group, count the samples labelled 0 and the samples labelled 1
group_name_sizes = [
    {
        "name": group_name,
        "class_0": int((properties_df[group_name] == 0).sum()),
        "class_1": int((properties_df[group_name] == 1).sum()),
    }
    for group_name in all_group_names
]

In [None]:
# one row per pathogen group with the columns: name, class_0, class_1
group_name_sizes_df = pd.DataFrame(group_name_sizes)

In [None]:
# save the per-pathogen class sizes to the results folder (row index not written)
group_name_sizes_df.to_csv(f"{results_folder}/n_class_samples_per_pathogen_{data_set}_data.csv", index=False)

# Data Correlation

## Crossplot

In [None]:
# columns of properties_df that are drawn against each other in the pairplots below
numerical_features = ['lat', 'lon', 'drought',
                      'water_content', 'organic_matter', 
                      'nitrogen', 'phosphorus', 'potassium',
                      'temp_soil', 'pH']

In [None]:
# One pairplot per pathogen group: the numerical features plotted against each
# other, coloured by the group's 0/1 label
for group_name in all_group_names:
    # Number of samples in each class (0 when a class does not occur); shown in the title
    class_counts = properties_df[group_name].value_counts()
    class_0_count = class_counts.get(0, 0)
    class_1_count = class_counts.get(1, 0)

    # Create the pairplot
    pair_grid = sns.pairplot(properties_df[numerical_features + [group_name]], hue=group_name, palette='husl')

    # Title with the group name and sample counts; y > 1 places it above the plot.
    # `.figure` is used instead of `.fig`, which seaborn documents as deprecated.
    pair_grid.figure.suptitle(f"{group_name} (Class 0: {class_0_count}, Class 1: {class_1_count})",
                              y=1.02)

    # Save the plot to a file
    pair_grid.savefig(f"{cross_plots_folder}/{group_name}_pairplot_{data_set}_data.png")

    # Close the figure to free memory
    plt.close(pair_grid.figure)

## Boxplots

In [None]:
# Features shown in the box plots and the matching axis labels (same order).
# Defined once, outside the loop, because they do not depend on the group.
features = ['drought',
            'water_content', 'organic_matter',
            'nitrogen', 'phosphorus', 'potassium',
            'temp_soil', 'pH']
y_labels = ['Drought Level', 'Water Content', 'Organic Matter', 'Nitrogen',
            'Phosphorus', 'Potassium', 'Soil Temperature', 'Soil pH']

# One figure per pathogen group: a box plot of every feature, split by the
# group's 0/1 label
for group_name in all_group_names:
    # 4 rows x 2 columns of subplots, flattened to 1D for easy pairing with the features
    fig, axes = plt.subplots(4, 2, figsize=(15, 16))
    axes = axes.flatten()

    # One subplot per feature
    for ax, feature, y_label in zip(axes, features, y_labels):
        sns.boxplot(data=properties_df, x=group_name, y=feature, ax=ax, hue=group_name, legend=False)
        ax.set_title(f'Box Plot of {y_label}')
        ax.set_xlabel('Class')
        ax.set_ylabel(y_label)

    # Number of samples in each class (0 when a class does not occur); shown in the title
    class_counts = properties_df[group_name].value_counts()
    class_0_count = class_counts.get(0, 0)
    class_1_count = class_counts.get(1, 0)

    # Title for the entire figure
    fig.suptitle(f'Box Plots for {group_name} (Class 0: {class_0_count}, Class 1: {class_1_count})', fontsize=16)

    # Lay out the subplots in the lower 96% of the figure to leave space for the title
    fig.tight_layout(rect=[0, 0, 1, 0.96])

    # Save the figure to a file, then close it to free memory
    output_path = f"{box_plots_folder}/{group_name}_boxplot_{data_set}_data.png"
    fig.savefig(output_path)
    plt.close(fig)

    print(f"Saved box plot for {group_name} to {output_path}")
