In [1]:
!pip install git+https://github.com/scikit-bio/scikit-bio.git

Collecting git+https://github.com/scikit-bio/scikit-bio.git
  Cloning https://github.com/scikit-bio/scikit-bio.git to /tmp/pip-req-build-70ls5h38
  Running command git clone --filter=blob:none --quiet https://github.com/scikit-bio/scikit-bio.git /tmp/pip-req-build-70ls5h38
  Resolved https://github.com/scikit-bio/scikit-bio.git to commit 32b17ade580b0a7583ca3d66ef247f4f497f6dbb
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting biom-format>=2.1.16 (from scikit-bio==0.6.3.dev0)
  Downloading biom-format-2.1.16.tar.gz (11.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m11.7/11.7 MB[0m [31m48.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages:

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
# @title Data Reading
import pandas as pd
import os
from pandas import ExcelFile
import matplotlib.pyplot as plt
import seaborn as sns
from random import sample
import numpy as np
from statsmodels.tsa.stattools import adfuller
from statsmodels.stats.multitest import multipletests
import warnings
from scipy.spatial.distance import pdist, squareform
from scipy.stats import entropy
from scipy.signal import savgol_filter
from scipy.spatial.distance import euclidean
from scipy.spatial.distance import euclidean
from sklearn.preprocessing import StandardScaler
from skbio.stats.composition import clr

from sklearn.decomposition import PCA
from sklearn.metrics import pairwise_distances
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.patches as mpatches
from scipy.stats import gmean

from statsmodels.stats.outliers_influence import variance_inflation_factor

gene_count_path ='/content/drive/MyDrive/MITResearch/IBD_13/gene_count_preprocessed.csv'
metaphlan_path = '/content/drive/MyDrive/MITResearch/IBD_13/metaphlan_preprocessed.csv'

gene_count_df = pd.read_csv(gene_count_path)
metaphlan_df = pd.read_csv(metaphlan_path)

print(gene_count_df.shape)
print(metaphlan_df.shape)

(605, 547)
(479, 1454)


In [4]:
pd.set_option('display.max_rows', None)  # or set a specific large number if 'None' is too extensive
pd.set_option('display.max_columns', None)  # Ensure all columns are shown
pd.set_option('display.width', None)  # Use the maximum width necessary to display the DataFrame without wrapping

In [5]:
# @title count columns with non zero
def count_columns_with_nonzero(df, excluded_columns):
    # Exclude the specified columns
    df_filtered = df.drop(columns=excluded_columns)

    # Count non-zero values in each column
    non_zero_counts = df_filtered.astype(bool).sum(axis=0)

    # Get the number of samples
    num_samples = len(df_filtered)

    # Find columns where the count of non-zero values is greater than 10% of the number of samples
    more_than_10_percent = non_zero_counts > (0.2 * num_samples)

    # Count these columns
    count = more_than_10_percent.sum()

    return count

excluded_columns = ['patient_id', 'week', 'Flare_status']
result = count_columns_with_nonzero(gene_count_df, excluded_columns)
print(f'Number of columns with more than 20% non-zero values in gene_count_df: {result}')

result = count_columns_with_nonzero(metaphlan_df, excluded_columns)
print(f'Number of columns with more than 20% non-zero values in metaphlan_df: {result}')

Number of columns with more than 20% non-zero values in gene_count_df: 369
Number of columns with more than 20% non-zero values in metaphlan_df: 192


In [6]:
# @title flag_first_flare_weeks function

def flag_first_flare_weeks(df):
    # Identify the rows where flare starts
    df['is_flare'] = (df['Flare_status'] == 'During_flare') | (df['Flare_status'] == 'During_flare_2')

    # Sort by patient and week to ensure the chronological order
    df.sort_values(by=['patient_id', 'week'], inplace=True)

    # Mark each flare start for each patient
    df['Flare_start'] = (df['is_flare']) & (df['is_flare'] != df['is_flare'].shift(1))

    # Convert boolean to integer (1 for True, 0 for False)
    df['Flare_start'] = df['Flare_start'].astype(int)

    flare_start = df.pop('Flare_start')
    # Now, insert it at the desired position
    df.insert(3, 'Flare_start', flare_start)

    # Drop helper columns if they are no longer needed
    df.drop('is_flare', axis=1, inplace=True)

    return df

In [7]:
# @title align_flare_start_function
def align_flare_start_week(df, flare_column='Flare_start', week_column='week', patient_column='patient_id', target_week=0):
    """
    Adjusts the week numbers for each patient in the dataframe such that the first occurrence
    of Flare_start=1 (that is not the first sample for the patient) occurs in the same target week for all patients who have a flare start.

    Args:
    df (pd.DataFrame): DataFrame containing the data.
    flare_column (str): Column name indicating the flare start.
    week_column (str): Column name indicating the week number.
    patient_column (str): Column name indicating the patient ID.
    target_week (int): The week number to which the first flare start should be aligned.

    Returns:
    pd.DataFrame: Modified DataFrame with adjusted week numbers.
    """
    df = df.copy()
    adjusted_dfs = []

    # Add a column indicating if the patient had a flare or not
    df['had_flare'] = df.groupby(patient_column)[flare_column].transform('max')

    for patient in df[patient_column].unique():
        patient_data = df[df[patient_column] == patient].copy()
        first_week = patient_data[week_column].min()

        if patient_data['had_flare'].iloc[0] == 1:
            # Identify flare weeks and remove the first week if it's a flare
            flare_weeks = patient_data[patient_data[flare_column] == 1][week_column]
            flare_weeks = flare_weeks[flare_weeks > first_week]  # ignore flare at the first recorded week

            if not flare_weeks.empty:
                flare_week = flare_weeks.min()  # first flare that's not at the first week
                shift = target_week - flare_week
                patient_data[week_column] += shift

        adjusted_dfs.append(patient_data)

    aligned_df = pd.concat(adjusted_dfs, ignore_index=True)

    # Normalize to ensure there are no negative weeks
    min_week = aligned_df[week_column].min()
    if min_week < 0:
       aligned_df.loc[aligned_df['had_flare'] == 1, week_column] -= min_week

    # Remove the temporary 'had_flare' column before returning
    aligned_df.drop('had_flare', axis=1, inplace=True)

    return aligned_df

In [8]:
# @title analyze_representation_and_variation function

def analyze_representation_and_variation(df, start_col, category_name):
    """
    Function to calculate mean and standard deviation for categories in a DataFrame,
    and plot the top 10 most represented and most variable categories.

    Args:
    df (pd.DataFrame): DataFrame containing the data.
    start_col (int): Column index from which to start the analysis.
    category_name (str): Name of the category (e.g., 'Bacteria', 'Genes') for labeling purposes.
    """

    # Calculate mean and standard deviation for each category
    category_mean = df.iloc[:, start_col:].mean()
    category_std = df.iloc[:, start_col:].std()

    # Identify the most represented categories (highest mean)
    most_represented = category_mean.sort_values(ascending=False).head(10)

    # Identify categories with the most variation (highest standard deviation)
    most_variable = category_std.sort_values(ascending=False).head(10)

    # Plotting
    plt.figure(figsize=(20, 10))

    plt.subplot(1, 2, 1)
    sns.barplot(x=most_represented.values, y=most_represented.index)
    plt.title(f'Top 10 Most Represented {category_name}')
    plt.xlabel(f'Mean Relative Abundance for {category_name}')

    plt.subplot(1, 2, 2)
    sns.barplot(x=most_variable.values, y=most_variable.index)
    plt.title(f'Top 10 {category_name} with Most Variation')
    plt.xlabel(f'Standard Deviation of Relative Abundance for {category_name}')

    plt.tight_layout()
    plt.show()


In [9]:
# @title plot_species_abundance_over_time
def plot_species_abundance_over_time(df, cols_to_plot, title, type_, selection_mode='user'):
    """
    Plots the abundance over time for the specified species for each patient in the dataframe in a single figure with subplots.

    Args:
    df (pd.DataFrame): DataFrame containing microbial abundance data.
    cols_to_plot (list or None): List of species names to plot, or None if to be determined by selection_mode.
    title (str): Title for the plots and y-axis label.
    type_ (str): Label for the plot legend.
    selection_mode (str): Mode of selecting columns ('user', 'variation', 'mean'). 'user' expects cols_to_plot to be provided.
    """
    unique_patients = df['patient_id'].unique()
    num_patients = len(unique_patients)
    cols = 2  # Number of columns for subplots
    rows = (num_patients // cols) + (num_patients % cols > 0)  # Calculate rows needed

    fig, axs = plt.subplots(rows, cols, figsize=(20, 5 * rows))  # Adjust figsize as needed
    axs = axs.flatten()  # Flatten in case of 2D array of axes

    for i, patient in enumerate(unique_patients):
        ax = axs[i]
        patient_data = df[df['patient_id'] == patient].copy()
        patient_data.sort_values(by='week', inplace=True)  # Ensure data is sorted by week

        # Exclude non-abundance columns to focus on species data
        species_data = patient_data.drop(exclude_columns, axis=1)

        # Determine columns to plot based on selection_mode if cols_to_plot is None or selection_mode is not 'user'
        if cols_to_plot is None or selection_mode != 'user':
            if selection_mode == 'variation':
                # Find the x species with the highest standard deviation
                cols_to_plot = species_data.std().sort_values(ascending=False).head(6).index.tolist()
            elif selection_mode == 'mean':
                top_species = set()  # Use a set to avoid duplicates
                for index, row in species_data.iterrows():
                    top_species.update(row.nlargest(1).index)
                cols_to_plot = list(top_species)  # Convert set to list

        # Plot the abundance over time for the specified species
        for species in cols_to_plot:
            if species in species_data.columns:
                sns.lineplot(data=patient_data, x='week', y=species, label=species, ax=ax)

        # Annotate flare status changes (if applicable)
        flare_changes = patient_data.drop_duplicates('Flare_status', keep='first')[['week', 'Flare_status']]
        for _, (week, status) in flare_changes.iterrows():
            ax.axvline(x=week, color='grey', linestyle='--')
            ax.text(week, ax.get_ylim()[1], f' {status}', verticalalignment='top', fontsize=8)

        # Setting x-axis ticks for every week, adjusting label rotation for clarity
        patient_data["week"] = patient_data["week"].astype(int)
        ax.set_xticks(patient_data['week'].unique())
        ax.set_xticklabels(patient_data['week'].unique(), rotation=45)

        ax.set_title(f'{title} Over Time for Patient {patient}')
        ax.set_xlabel('Week')
        ax.set_ylabel(title)
        ax.legend(title=type_, bbox_to_anchor=(1.05, 1), loc='upper left')

    # Hide any unused subplots
    for j in range(i + 1, len(axs)):
        fig.delaxes(axs[j])

    plt.tight_layout()
    plt.show()


In [10]:
# @title plot_species_abundance_over_time_2 function
def plot_species_abundance_over_time_2(df, cols_to_plot, title, type_):
    """
    Plots the abundance over time for the specified species for all patients in the dataframe,
    with separate plots for patients with and without a flare start, all in a single figure with subplots.

    Args:
    df (pd.DataFrame): DataFrame containing microbial abundance data.
    cols_to_plot (list): List of species names to plot.
    title (str): Title for the plots and y-axis label.
    type_ (str): Label for the plot legend.
    """
    if not cols_to_plot:
        raise ValueError("cols_to_plot must be provided when selection_mode is 'user'")

    num_species = len(cols_to_plot)
    cols = 2  # Number of columns for subplots (one for with flare, one for without)
    rows = num_species  # Each species will occupy one row with two subplots

    fig, axs = plt.subplots(rows, cols, figsize=(20, 5 * rows))  # Adjust figsize as needed

    # If only one species, axs might not be a 2D array, ensure it is
    if num_species == 1:
        axs = [axs]

    # Determine patients where the first recorded week is a flare start
    df['first_week'] = df.groupby('patient_id')['week'].transform('min')  # Find first week for each patient
    first_flare_patients = df[(df['Flare_start'] == 1) & (df['week'] == df['first_week'])]['patient_id'].unique()
    patients_with_flare = df[(df['Flare_start'] == 1) & (~df['patient_id'].isin(first_flare_patients))]['patient_id'].unique()
    patients_without_flare = df[~df['patient_id'].isin(patients_with_flare)]['patient_id'].unique()


    for i, species in enumerate(cols_to_plot):
        # Plot for patients with flare
        ax = axs[i][0]

        for patient in patients_with_flare:
            patient_data = df[df['patient_id'] == patient].copy()
            patient_data.sort_values(by='week', inplace=True)  # Ensure data is sorted by week

            if species in patient_data.columns:
                sns.lineplot(data=patient_data, x='week', y=species, ax=ax, label=f'Patient {patient}')

        # Annotate flare start changes
        flare_changes = df[(df["Flare_start"] == 1)&(df['patient_id'].isin(patients_with_flare ))][['week', 'Flare_start']].drop_duplicates()
        for _, (week, status) in flare_changes.iterrows():
            ax.axvline(x=week, color='grey', linestyle='--')
            ax.text(week, ax.get_ylim()[1], 'Flare_start', verticalalignment='top', fontsize=8)

        ax.set_title(f'{title} ({species}) Over Time for Patients with Flare')
        ax.set_xlabel('Week')
        ax.set_ylabel(title)
        ax.legend(title=type_, bbox_to_anchor=(1.05, 1), loc='upper left')
        ax.set_xticks(df['week'].unique())
        ax.set_xticklabels(df['week'].unique(), rotation=45)

        # Plot for patients without flare
        ax = axs[i][1]
        for patient in patients_without_flare:
            patient_data = df[df['patient_id'] == patient].copy()
            patient_data.sort_values(by='week', inplace=True)  # Ensure data is sorted by week

            if species in patient_data.columns:
                sns.lineplot(data=patient_data, x='week', y=species, ax=ax, label=f'Patient {patient}')

        ax.set_title(f'{title} ({species}) Over Time for Patients without Flare')
        ax.set_xlabel('Week')
        ax.set_ylabel(title)
        ax.legend(title=type_, bbox_to_anchor=(1.05, 1), loc='upper left')
        ax.set_xticks(df['week'].unique())
        ax.set_xticklabels(df['week'].unique(), rotation=45)

    plt.tight_layout()
    plt.show()


In [11]:
# @title segment_time_series
def segment_time_series(df, length=6):
    # Sort by patient and week to ensure chronological order
    df.sort_values(by=['patient_id', 'week'], inplace=True)

    # Create 'ts' column initialized to NaN
    df['ts'] = np.nan
    df['RBF'] = False

    # Process each patient individually
    for patient in df['patient_id'].unique():
        patient_data = df[df['patient_id'] == patient]

        # Check for the presence of a flare start
        flare_start_indices = patient_data[patient_data['Flare_start'] == 1].index.tolist()

        # Initial segment number
        ts_value = 1

        if flare_start_indices:
            # There is at least one flare start
            flare_start_index = flare_start_indices[0]  # Assume first flare start is the one to consider
            # Calculate the position of the first flare start
            flare_position = patient_data.index.get_loc(flare_start_index)

            # Calculate the length of the first segment to align flare start with a new segment start
            if flare_position % length == 0:
                first_segment_length = 0  # Flare start is already aligned with a new segment
            else:
                first_segment_length = flare_position % length

            # Assign time series IDs
            # Segment before flare start if it has any length
            if first_segment_length > 0:
                df.loc[patient_data.index[:first_segment_length], 'ts'] = ts_value
                ts_value += 1

            # Remaining data after the initial adjustment
            for i in range(first_segment_length, len(patient_data), length):
                df.loc[patient_data.index[i:i+length], 'ts'] = ts_value
                ts_value += 1

            df.loc[patient_data.index[flare_position - length:flare_position], 'RBF'] = True

        else:
            # No flare, segment normally every x points
            for i in range(0, len(patient_data), length):
                df.loc[patient_data.index[i:i+length], 'ts'] = ts_value
                ts_value += 1

    # Convert 'ts' column to integer
    df['ts'] = df['ts'].astype(int)

    # Create a new column combining 'ts' and 'patient_id'
    df['ts_patient_id'] = df['patient_id'] + '_TS' + df['ts'].astype(str)

    return df


In [12]:
# @title smooth_patient_data
def smooth_patient_data(df, value_columns, patient_column='patient_id', week_column='week'):
    """
    Smooths the values of the specified columns in the dataframe using a rolling window average.

    Args:
    df (pd.DataFrame): DataFrame containing the data to be smoothed.
    value_columns (list): List of columns to smooth.
    patient_column (str): Column name for patient ID.
    week_column (str): Column name for week number.

    Returns:
    pd.DataFrame: Smoothed DataFrame.
    """
    df = df.copy()
    smoothed_dfs = []

    for patient in df[patient_column].unique():
        patient_data = df[df[patient_column] == patient].copy()
        patient_data = patient_data.sort_values(by=week_column)

        for col in value_columns:
            smoothed_values = patient_data[col].rolling(window=3, min_periods=1, center=True).mean()
            # Handle first and last weeks separately
            if len(smoothed_values) > 1:
                smoothed_values.iloc[0] = patient_data[col].iloc[:2].mean()
                smoothed_values.iloc[-1] = patient_data[col].iloc[-2:].mean()
            patient_data[col] = smoothed_values

        smoothed_dfs.append(patient_data)

    smoothed_df = pd.concat(smoothed_dfs)
    return smoothed_df


In [13]:
species_columns = [col for col in metaphlan_df.columns if col.startswith('s__')]
print(species_columns)
exclude_columns = metaphlan_df.columns.difference(species_columns)
gene_columns = gene_count_df.columns.difference(exclude_columns)

['s__GGB39918_SGB47522', 's__Actinobaculum_sp_oral_taxon_183', 's__Actinomyces_SGB17154', 's__Actinomyces_SGB17168', 's__Actinomyces_bouchesdurhonensis', 's__Actinomyces_dentalis', 's__Actinomyces_gerencseriae', 's__Actinomyces_graevenitzii', 's__Actinomyces_israelii', 's__Actinomyces_johnsonii', 's__Actinomyces_massiliensis', 's__Actinomyces_naeslundii', 's__Actinomyces_oris', 's__Actinomyces_sp_ICM47', 's__Actinomyces_sp_ICM58', 's__Actinomyces_sp_S6_Spd3', 's__Pauljensenia_hongkongensis', 's__Peptidiphaga_gingivicola', 's__Schaalia_turicensis', 's__Trueperella_pyogenes', 's__Alloscardovia_omnicolens', 's__Bifidobacterium_adolescentis', 's__Bifidobacterium_animalis', 's__Bifidobacterium_bifidum', 's__Bifidobacterium_breve', 's__Bifidobacterium_catenulatum', 's__Bifidobacterium_dentium', 's__Bifidobacterium_longum', 's__Bifidobacterium_pseudocatenulatum', 's__Bifidobacterium_pullorum', 's__Parascardovia_denticolens', 's__Scardovia_wiggsiae', 's__Corynebacterium_accolens', 's__Coryneba

In [14]:
metaphlan_df['count_above_0_3'] = (metaphlan_df.loc[:, species_columns] > 0.3).sum(axis=1)
metaphlan_df['count_above_0_03'] = (metaphlan_df.loc[:, species_columns] > 0.03).sum(axis=1)
metaphlan_df['count_above_0_01'] = (metaphlan_df.loc[:, species_columns] > 0.01).sum(axis=1)
metaphlan_df['count_above_0_1'] = (metaphlan_df.loc[:, species_columns] > 0.1).sum(axis=1)
metaphlan_df['count_above_0_05'] = (metaphlan_df.loc[:, species_columns] > 0.05).sum(axis=1)
#print(metaphlan_df[['patient_id', 'week', 'Flare_status', 'distance_to_flare', 'count_above_0_3']])

In [15]:
species_data= metaphlan_df[species_columns].iloc[1]
species_data.nlargest(10)

Unnamed: 0,1
s__Blautia_wexlerae,0.172019
s__Blautia_sp_MSK_21_1,0.08173
s__Fusicatenibacter_saccharivorans,0.078127
s__Clostridiaceae_bacterium,0.069694
s__Phocaeicola_vulgatus,0.067152
s__Ruminococcus_sp_JE7A12,0.057036
s__Candidatus_Cibiobacter_qucibialis,0.044256
s__Ruminococcus_torques,0.03553
s__Anaerostipes_hadrus,0.031832
s__Roseburia_faecis,0.024341


In [16]:
# Ensure the DataFrame is sorted by patient_id and week
metaphlan_df = metaphlan_df.sort_values(by=['patient_id', 'week'])

# Function to find top N species for a given row
def find_top_species(row, top_n=20):
    # Select species columns
    species_data = row[species_columns].apply(pd.to_numeric)
    # Sort by abundance and get top N species names
    top_species = species_data.nlargest(top_n).index.tolist()
    return top_species

# Apply the function to find the top 10 species for each row
metaphlan_df['top_10_species'] = metaphlan_df.apply(find_top_species, axis=1)

# Function to compare top species with previous week
def compare_with_previous_week(group):
    group = group.sort_values(by='week')
    group['common_species_with_last_week'] = np.nan
    for i in range(1, len(group)):
        current_week_species = set(group.iloc[i]['top_10_species'])
        previous_week_species = set(group.iloc[i-1]['top_10_species'])
        common_species = current_week_species.intersection(previous_week_species)
        group.at[group.index[i], 'common_species_with_last_week'] = len(common_species)
    return group


# Apply the function to each patient group and store results
metaphlan_df = metaphlan_df.groupby('patient_id').apply(compare_with_previous_week)

# Reset index if needed
metaphlan_df.reset_index(drop=True, inplace=True)

metaphlan_df['common_species_with_last_week'].fillna(metaphlan_df['common_species_with_last_week'].median(), inplace=True)

In [17]:
print(metaphlan_df[['patient_id', 'week', 'top_10_species', 'Flare_status']].head())

  patient_id  week                                     top_10_species Flare_status
0    TR_2101     1  [s__Blautia_wexlerae, s__Eubacterium_rectale, ...    Pre_flare
1    TR_2101     2  [s__Blautia_wexlerae, s__Blautia_sp_MSK_21_1, ...    Pre_flare
2    TR_2101     3  [s__Blautia_wexlerae, s__Eubacterium_rectale, ...    Pre_flare
3    TR_2101     5  [s__Lachnospiraceae_bacterium, s__Bacteroides_...    Pre_flare
4    TR_2101     6  [s__Eubacterium_rectale, s__Blautia_wexlerae, ...    Pre_flare


In [18]:

def compute_distance_to_flare(metaphlan_df):
    # Initialize distance_to_flare with NaN
    metaphlan_df['distance_to_flare'] = np.nan

    # Iterate through each unique patient
    for patient in metaphlan_df['patient_id'].unique():
        # Filter and sort data for the current patient
        patient_data = metaphlan_df[metaphlan_df['patient_id'] == patient].sort_values('week')

        # Identify weeks during flare
        during_flare_weeks = patient_data[patient_data['Flare_status'] == 'During_flare']['week']

        # Calculate distance to the next flare for each non-flare row
        for i, row in patient_data.iterrows():
            if row['Flare_status'] not in ['During_flare', 'Post_flare']:
                future_flares = during_flare_weeks[during_flare_weeks > row['week']]
                if not future_flares.empty:
                    metaphlan_df.at[i, 'distance_to_flare'] = future_flares.iloc[0] - row['week']

    # Manually adjust for specific cases
    adjustments = {
        39: 4, 40: 3, 41: 2, 42: 1
    }
    for week, distance in adjustments.items():
        metaphlan_df.loc[(metaphlan_df['week'] == week) & (metaphlan_df['patient_id'] == 'TR_2107'), 'distance_to_flare'] = distance

    median_flare_distance = metaphlan_df[metaphlan_df['distance_to_flare'].notna()]['distance_to_flare'].median()

    metaphlan_df.loc[
        (metaphlan_df['distance_to_flare'].isna()) &
        (metaphlan_df['Flare_status'] != 'During_flare'),
        'distance_to_flare'
    ] = median_flare_distance

    # Fill NaN distances based on the nearest non-NaN entries
    metaphlan_df = fill_distance(metaphlan_df)
    metaphlan_df.loc[metaphlan_df['Flare_status']=='During_flare', 'distance_to_flare'] = np.nan

    return metaphlan_df

def fill_distance(df):
    for patient_id, group in df.groupby('patient_id'):
        nan_indices = group[(group['distance_to_flare'].isna()) & (group['Flare_status'] != 'During_flare')].index
        last_index = group['week'].idxmax()
        last_week = group['week'].max()
        last_distance = group.at[last_index, 'distance_to_flare']

        for nan_index in nan_indices:
            current_week = group.at[nan_index, 'week']
            week_difference = np.abs(current_week - last_week)
            new_distance = last_distance + week_difference
            df.at[nan_index, 'distance_to_flare'] = new_distance

    return df

metaphlan_df = compute_distance_to_flare(metaphlan_df)
gene_count_df = compute_distance_to_flare(gene_count_df)

In [19]:
gene_count_df = compute_distance_to_flare(gene_count_df)

In [20]:


# Calculate the sum of the three largest values in each row excluding specified columns
metaphlan_df['dominant_species'] = metaphlan_df[species_columns].apply(lambda row: row.nlargest(4).sum(), axis=1)

print(metaphlan_df[['patient_id', 'week', 'dominant_species','Flare_status', 'distance_to_flare']][metaphlan_df['dominant_species']>0.6][metaphlan_df['Flare_status']!="During_flare"])

    patient_id  week  dominant_species Flare_status  distance_to_flare
18     TR_2101    21          0.602143    Pre_flare                4.0
109    TR_2103    19          0.761224    Pre_flare                8.0
110    TR_2103    20          0.672748    Pre_flare                7.0
111    TR_2103    21          0.934282    Pre_flare                6.0
112    TR_2103    22          0.656717    Pre_flare                5.0
148    TR_2104    14          0.608547    Pre_flare               29.0
294    TR_2107     8          0.661791    Pre_flare                2.0
310    TR_2107    26          0.668560    Pre_flare                7.0
360    TR_2201     8          0.621029     No_flare               12.0
365    TR_2201    13          0.667868     No_flare               12.0
373    TR_2201    23          0.627021     No_flare               12.0
436    TR_2202    41          0.620448     No_flare               12.0


  print(metaphlan_df[['patient_id', 'week', 'dominant_species','Flare_status', 'distance_to_flare']][metaphlan_df['dominant_species']>0.6][metaphlan_df['Flare_status']!="During_flare"])


In [21]:
# @title Adding a Flare_start column
metaphlan_df = flag_first_flare_weeks(metaphlan_df)
gene_count_df = flag_first_flare_weeks(gene_count_df)
#exclude_columns+= ['Flare_start']

In [22]:
#  @title  calculate aitchison distance for consecutive samples
def calculate_aitchison_distance_consecutive(df, species_cols):
    results = []

    for patient in df['patient_id'].unique():
        patient_data = df[df['patient_id'] == patient].copy()
        patient_data = patient_data.sort_values(by='week')

        patient_data['clr_transformed'] = patient_data[species_cols].apply(lambda row: clr(row), axis=1)

        for i in range(len(patient_data) - 1):
            week_current = patient_data.iloc[i]['week']
            week_next = patient_data.iloc[i + 1]['week']

            clr_current = patient_data.iloc[i]['clr_transformed']
            clr_next = patient_data.iloc[i + 1]['clr_transformed']

            aitchison_dist = euclidean(clr_current, clr_next)
            results.append({
                'patient_id': patient,
                'week_current': week_current,
                'week_next': week_next,
                'aitchison_distance': aitchison_dist,
                'Flare_start': patient_data.iloc[i]['Flare_start']
            })

    return pd.DataFrame(results)

In [23]:
# Create a copy of the DataFrame to avoid modifying the original
aitchison_df = metaphlan_df.copy()

# Select columns with species data
exclude_cols = ['patient_id', 'week', 'Flare_status', 'Flare_start', 'ts', 'RBF', 'ts_patient_id', 'PatientID_Weeknr']

# Find the minimum positive value in the species data
min_positive_value = aitchison_df[species_columns][aitchison_df[species_columns] > 0].min().min()

# Replace zeros with the minimum positive value divided by two
replacement_value = min_positive_value / 2
aitchison_df[species_columns] = aitchison_df[species_columns].replace(0, replacement_value)

to_plot_aitchison_df = calculate_aitchison_distance_consecutive(aitchison_df, species_columns)
to_plot_aitchison_df.drop('week_current', axis= 1, inplace=True)
to_plot_aitchison_df.rename(columns={'week_next': 'week'}, inplace=True)

In [24]:
# Function to calculate the number of top values needed to reach a cumulative sum of 0.8
def calculate_columns_to_reach_threshold(row, threshold=0.7):
    sorted_values = row.sort_values(ascending=False)  # Sort the values in descending order
    cumulative_sum = 0
    count = 0
    for value in sorted_values:
        cumulative_sum += value
        count += 1
        if cumulative_sum >= threshold:
            return count
    return count  # If threshold is not met, return the total count

# Apply the function across the rows
metaphlan_df['dominant_species_2'] = metaphlan_df[species_columns].apply(calculate_columns_to_reach_threshold, axis=1)


print(metaphlan_df[['patient_id', 'week', 'dominant_species', 'dominant_species_2' , 'Flare_status', 'distance_to_flare']][metaphlan_df['dominant_species_2']<8])

    patient_id  week  dominant_species  dominant_species_2  Flare_status  distance_to_flare
18     TR_2101    21          0.602143                   7     Pre_flare                4.0
22     TR_2101    25          0.560689                   7  During_flare                NaN
27     TR_2101    31          0.573689                   7  During_flare                NaN
30     TR_2101    35          0.762722                   4  During_flare                NaN
109    TR_2103    19          0.761224                   4     Pre_flare                8.0
110    TR_2103    20          0.672748                   5     Pre_flare                7.0
111    TR_2103    21          0.934282                   2     Pre_flare                6.0
112    TR_2103    22          0.656717                   5     Pre_flare                5.0
113    TR_2103    23          0.539810                   7     Pre_flare                4.0
116    TR_2103    27          0.727503                   4  During_flare        

In [25]:
# @title


non_zero_counts = metaphlan_df.groupby('patient_id')[species_columns].apply(lambda x: (x != 0).sum())
non_zero_counts = non_zero_counts.apply(lambda x: x[x > 0])
species_count_per_patient = non_zero_counts.count(axis=1)

# Print the result
print("Number of species with at least one non-zero value per patient:")
print(species_count_per_patient)

Number of species with at least one non-zero value per patient:
patient_id
TR_2101    226
TR_2102    229
TR_2103    330
TR_2104    323
TR_2105    289
TR_2106    265
TR_2107    250
TR_2108    191
TR_2201    281
TR_2202    236
TR_2203    188
TR_2205    290
dtype: int64


In [26]:
# @title
species_columns = [col for col in metaphlan_df.columns if col.startswith('s__')]

# Group by 'patient_id' and calculate the mean of each species
mean_abundances = metaphlan_df.groupby('patient_id')[species_columns].mean()


# and print the results for each patient
for patient_id, abundances in mean_abundances.iterrows():
    print(f"Patient {patient_id} has the following species with an average relative abundance above 0.07:")
    filtered_species = abundances[abundances > 0.07 ]
    if not filtered_species.empty:
        print(filtered_species)
    else:
        print("No species above the threshold.")
    print()


Patient TR_2101 has the following species with an average relative abundance above 0.07:
s__Blautia_sp_MSK_21_1    0.088540
s__Blautia_wexlerae       0.180948
s__Eubacterium_rectale    0.086417
Name: TR_2101, dtype: float64

Patient TR_2102 has the following species with an average relative abundance above 0.07:
s__Bacteroides_uniformis    0.089266
s__Phocaeicola_dorei        0.114793
Name: TR_2102, dtype: float64

Patient TR_2103 has the following species with an average relative abundance above 0.07:
s__Bifidobacterium_pullorum    0.085024
Name: TR_2103, dtype: float64

Patient TR_2104 has the following species with an average relative abundance above 0.07:
s__Phocaeicola_vulgatus    0.124184
Name: TR_2104, dtype: float64

Patient TR_2105 has the following species with an average relative abundance above 0.07:
s__Phocaeicola_vulgatus    0.070913
s__Ruminococcus_bromii     0.090685
Name: TR_2105, dtype: float64

Patient TR_2106 has the following species with an average relative abunda

In [27]:
metaphlan_df.columns

Index(['PatientID_Weeknr', 'patient_id', 'week', 'Flare_start', 'Flare_status',
       's__GGB39918_SGB47522', 's__Actinobaculum_sp_oral_taxon_183', 's__Actinomyces_SGB17154',
       's__Actinomyces_SGB17168', 's__Actinomyces_bouchesdurhonensis',
       ...
       'count_above_0_3', 'count_above_0_03', 'count_above_0_01', 'count_above_0_1',
       'count_above_0_05', 'top_10_species', 'common_species_with_last_week', 'distance_to_flare',
       'dominant_species', 'dominant_species_2'],
      dtype='object', length=1465)

In [28]:
# @title calculate_alpha_diversity function

def calculate_alpha_diversity(df, diversity_index='shannon'):
    """Calculate alpha diversity based on the specified diversity index."""
    if diversity_index == 'shannon':
        return df.apply(lambda x: entropy(x[x > 0]), axis=1)
    elif diversity_index == 'simpson':
        def simpson_index(x):
            x = x[x > 0]
            if len(x) == 0:
                return np.nan
            p = x / x.sum()
            return 1 - np.sum(p**2)
        return df.apply(simpson_index, axis=1)
    elif diversity_index == 'richness':
        return df.apply(lambda x: (x > 0).sum(), axis=1)
    elif diversity_index == 'evenness':
        # Calculate evenness using Pielou's evenness index
        def pielou_evenness(x):
            x = x[x > 0]
            if len(x) == 0:
                return np.nan
            shannon = entropy(x)
            richness = np.log(len(x))
            if richness == 0:
                return np.nan  # Avoid division by zero
            return shannon / richness
        return df.apply(pielou_evenness, axis=1)
    else:
        raise ValueError("Unsupported diversity index specified.")

In [29]:
diversity_indices = ['shannon', 'simpson', 'richness', 'evenness']
alpha_div = metaphlan_df[['patient_id', 'week', 'Flare_status', 'Flare_start']].copy()

# Calculate all types of alpha diversities
for index in diversity_indices:
    print(index)
    metaphlan_df[f'alpha_diversity_{index}'] = calculate_alpha_diversity(metaphlan_df[species_columns], index  )


# Separate patients based on whether they had a flare

shannon
simpson
richness
evenness


In [30]:
metaphlan_df.columns

Index(['PatientID_Weeknr', 'patient_id', 'week', 'Flare_start', 'Flare_status',
       's__GGB39918_SGB47522', 's__Actinobaculum_sp_oral_taxon_183', 's__Actinomyces_SGB17154',
       's__Actinomyces_SGB17168', 's__Actinomyces_bouchesdurhonensis',
       ...
       'count_above_0_05', 'top_10_species', 'common_species_with_last_week', 'distance_to_flare',
       'dominant_species', 'dominant_species_2', 'alpha_diversity_shannon',
       'alpha_diversity_simpson', 'alpha_diversity_richness', 'alpha_diversity_evenness'],
      dtype='object', length=1469)

In [31]:
metaphlan_df['log_Bwex']= np.log(0.01+metaphlan_df['s__Blautia_wexlerae'])

In [32]:
metaphlan_df['group']=metaphlan_df['patient_id']+ (metaphlan_df['week']>25).astype('str')
mask_2203 = metaphlan_df['patient_id'] == "TR_2203"
metaphlan_df.loc[mask_2203, 'group'] = metaphlan_df.loc[mask_2203, 'patient_id'] + (metaphlan_df.loc[mask_2203, 'week'] > 11).astype(str)

mask_2205 = metaphlan_df['patient_id'] == "TR_2205"
metaphlan_df.loc[mask_2205, 'group'] = metaphlan_df.loc[mask_2205, 'patient_id'] + (metaphlan_df.loc[mask_2205, 'week'] > 45).astype(str)

In [33]:
metaphlan_df.columns

Index(['PatientID_Weeknr', 'patient_id', 'week', 'Flare_start', 'Flare_status',
       's__GGB39918_SGB47522', 's__Actinobaculum_sp_oral_taxon_183', 's__Actinomyces_SGB17154',
       's__Actinomyces_SGB17168', 's__Actinomyces_bouchesdurhonensis',
       ...
       'common_species_with_last_week', 'distance_to_flare', 'dominant_species',
       'dominant_species_2', 'alpha_diversity_shannon', 'alpha_diversity_simpson',
       'alpha_diversity_richness', 'alpha_diversity_evenness', 'log_Bwex', 'group'],
      dtype='object', length=1471)

In [34]:
metaphlan_df=pd.merge(metaphlan_df, to_plot_aitchison_df, on = ['week','patient_id','Flare_start'], how ='left')

In [35]:
metaphlan_df['aitchison_distance'].fillna(metaphlan_df['aitchison_distance'].median(), inplace = True)

In [36]:
metaphlan_df['distance_to_flare_sqrt']=np.sqrt(metaphlan_df['distance_to_flare'])


In [37]:
all_features = ['alpha_diversity_evenness', 'alpha_diversity_richness',
       'alpha_diversity_shannon', 'alpha_diversity_simpson', 'common_species_with_last_week',
       'count_above_0_01', 'count_above_0_03', 'count_above_0_05',
       'count_above_0_1', 'count_above_0_3',  'dominant_species',
       'dominant_species_2', 'log_Bwex',  'aitchison_distance',
       'log_aitchison_distance']

In [38]:
import pandas as pd
import numpy as np
from sklearn.linear_model import Lasso
from sklearn.model_selection import KFold
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.preprocessing import StandardScaler
from statsmodels.stats.outliers_influence import variance_inflation_factor

alpha= 0.01
# Filter and prepare the data
data = metaphlan_df[(metaphlan_df['patient_id'] != 'TR_2103') & (metaphlan_df['Flare_status'] != 'During_flare')]
patient_cols = [col for col in data.columns if col.startswith('patient_')]
data = pd.get_dummies(data, columns=['patient_id'], prefix='patient')
patient_cols = [col for col in data.columns if col.startswith('patient_')]
data[patient_cols] = data[patient_cols].astype('int')


# Specify the features to include in the model
features = all_features + patient_cols
X = data[features]
target = "distance_to_flare_sqrt"
y = data[target]

# Checking for multicollinearity using Variance Inflation Factor (VIF)
vif_data = pd.DataFrame()
vif_data["feature"] = X.columns
vif_data["VIF"] = [variance_inflation_factor(X.values, i) for i in range(X.shape[1])]

print("Features with high collinearity (VIF > 8):")
print(vif_data[vif_data['VIF'] > 8])

# Standardize the features
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# Initialize KFold for 10 folds
kf = KFold(n_splits=10, shuffle=True, random_state=42)

# Initialize lists to store performance metrics and predictions
mse_scores = []
r2_scores = []
predictions = pd.DataFrame(columns=['week',  'actual', 'predicted'])

# Perform 10-fold Cross-Validation
for train_index, test_index in kf.split(data):
    train_X, test_X = X_scaled[train_index], X_scaled[test_index]
    train_y, test_y = y.iloc[train_index], y.iloc[test_index]

    # Create and fit the Lasso model
    lasso = Lasso(alpha=alpha)  # Adjust alpha as needed
    lasso.fit(train_X, train_y)

    # Make predictions on the test set
    y_pred = lasso.predict(test_X)

    # Collect predictions
    fold_predictions = pd.DataFrame({
        'week': data.iloc[test_index].index,
        #'patient_id': data.iloc[test_index]['patient_id'],
        'actual': test_y**2,
        'predicted': y_pred**2
    })
    predictions = pd.concat([predictions, fold_predictions], ignore_index=True)

    # Calculate and collect MSE and R^2
    mse = mean_squared_error(test_y**2, y_pred**2)
    r2 = r2_score(test_y**2, y_pred**2)
    mse_scores.append(mse)
    r2_scores.append(r2)

# Calculate average MSE and R^2 over all folds
average_mse = np.mean(mse_scores)
average_r2 = np.mean(r2_scores)

print(f"Average Mean Squared Error: {average_mse}")
print(f"Average R-squared: {average_r2}")

# Fit the final Lasso model on the entire dataset
final_lasso = Lasso(alpha=alpha)  # Adjust alpha as needed
final_lasso.fit(X_scaled, y)

print("\nCoefficients from the final Lasso model:")
for feature, coef in zip(features, final_lasso.coef_):
    print(f"{feature}: {coef}")

# Add intercept (if the model has an intercept)
print(f"Intercept: {final_lasso.intercept_}")

# Print the predictions dataframe
print("\nPredictions DataFrame:")
print(predictions)


KeyError: "['log_aitchison_distance'] not in index"

In [None]:
131071/(60*60)

In [None]:
# import pandas as pd
# import numpy as np
# import statsmodels.api as sm
# from sklearn.model_selection import LeaveOneGroupOut
# from sklearn.metrics import r2_score
# from itertools import combinations

# # Assuming metaphlan_df is predefined and includes necessary data
# data = metaphlan_df[metaphlan_df['patient_id'] != 'TR_2103'][metaphlan_df['Flare_status'] != 'During_flare']
# data = pd.get_dummies(data, columns=['patient_id'], prefix='patient')

# # Define your initial setup
# target = "distance_to_flare_sqrt"
# y = data[target]
# groups = data['group']
# data = sm.add_constant(data)  # Add a constant to the model

# # Selecting potential predictor variables, excluding the target and grouping variables
# predictors = ['const','alpha_diversity_evenness', 'alpha_diversity_richness',
#        'alpha_diversity_shannon', 'alpha_diversity_simpson', 'common_species_with_last_week',
#        'count_above_0_01', 'count_above_0_03', 'count_above_0_05',
#        'count_above_0_1', 'count_above_0_3',  'dominant_species',
#        'dominant_species_2', 'log_Bwex',  'aitchison_distance',
#        'log_aitchison_distance']

# # Forward feature selection
# def forward_feature_selection(data, predictors, target, groups):
#     remaining_predictors = predictors[:]
#     selected_predictors = []
#     best_r2 = -np.inf

#     logo = LeaveOneGroupOut()

#     while remaining_predictors:
#         r2_with_candidates = []
#         for candidate in remaining_predictors:
#             trial_predictors = selected_predictors + [candidate]
#             X = data[trial_predictors]
#             group_r2_scores = []

#             # Perform Leave-One-Group-Out Cross-Validation
#             for train_index, test_index in logo.split(X, groups=groups):
#                 X_train, X_test = X.iloc[train_index], X.iloc[test_index]
#                 y_train, y_test = y.iloc[train_index], y.iloc[test_index]

#                 model = sm.OLS(y_train, X_train).fit()
#                 y_pred = model.predict(X_test)
#                 r2 = r2_score(y_test, y_pred)
#                 group_r2_scores.append(r2)

#             average_r2 = np.mean(group_r2_scores)
#             print(f"Average R-squared for {candidate}: {average_r2}")
#             r2_with_candidates.append((average_r2, candidate))

#         # Select the best candidate that maximizes R-squared
#         r2_with_candidates.sort(reverse=True)
#         best_new_r2, best_candidate = r2_with_candidates[0]

#         # Check if there's an improvement
#         if best_new_r2 > best_r2:
#             selected_predictors.append(best_candidate)
#             print(selected_predictors)
#             remaining_predictors.remove(best_candidate)
#             best_r2 = best_new_r2
#         else:
#             break

#     return selected_predictors

# selected_features = forward_feature_selection(data, predictors, target, groups)

# # Print selected features
# print("Selected features:", selected_features)

# # Fit the final model using the selected features
# final_X = data[selected_features]
# final_model = sm.OLS(y, final_X).fit()
# print("\nCoefficients and p-values from the final model:")
# print(final_model.summary())


In [None]:
data.head()

In [None]:
#chek multicolinearity
corr_matrix = data[features].corr()

# Plot correlation matrix
plt.figure(figsize=(10, 8))
sns.heatmap(corr_matrix, annot=True, fmt=".2f", cmap='coolwarm')
plt.title('Correlation Matrix')
plt.show()

In [None]:
metaphlan_df.columns[1400:1470]

In [None]:
import matplotlib.pyplot as plt

# Fit the model
final_model = sm.OLS(y, X).fit()

# Calculate residuals
residuals = final_model.resid

# Plot residuals
plt.scatter(final_model.predict(X), residuals)
plt.xlabel('Predicted Values')
plt.ylabel('Residuals')
plt.axhline(y=0, color='r', linestyle='--')
plt.title('Residual Plot')
plt.show()


In [None]:
from statsmodels.stats.outliers_influence import OLSInfluence

# Calculate influence measures
influence = OLSInfluence(final_model)
cooks = influence.cooks_distance[0]

# Plot Cook's Distance
plt.figure(figsize=(10, 6))
plt.stem(cooks, use_line_collection=True)
plt.title('Cook\'s Distance Outlier Detection')
plt.xlabel('Observation Index')
plt.ylabel('Cook\'s Distance')
plt.show()

# Threshold for Cook's Distance might typically be 4/n or 3 times the mean
cooks_threshold = 4 / len(X)
print(f"Suggested Cook's distance threshold: {cooks_threshold}")

outlier_indices = np.where(cooks > cooks_threshold)[0]
print("Indices of outliers based on Cook's distance:", outlier_indices)

In [None]:
X.iloc[cooks.argmax()]

In [None]:
metaphlan_df[metaphlan_df['log_Bwex']< -2.844918][metaphlan_df['log_Bwex']>-2.844920]

In [None]:
metaphlan_df[metaphlan_df['Flare_status']!='During_flare'].iloc[outlier_indices][['patient_id','week','Flare_status','distance_to_flare']]

In [None]:
studentized_residuals = influence.resid_studentized_external

# Find large studentized residuals
outliers = np.abs(studentized_residuals) > 2
print(f"Outliers based on studentized residuals: {data[outliers][['week']]}")


In [None]:
# Calculate the correlation matrix
columns_of_interest = ['distance_to_flare', 'count_above_0_3', 'log_Bwex', 'aitchison_distance', 'dominant_species','dominant_species_2', 'alpha_diversity_shannon', 'alpha_diversity_simpson', 'alpha_diversity_richness',
                       'alpha_diversity_evenness', 'count_above_0_01', 'count_above_0_03', 'count_above_0_05', 'count_above_0_1', 'common_species_with_last_week']
correlation_matrix = metaphlan_df[columns_of_interest][metaphlan_df['Flare_status'] != 'During_flare'].corr()

# Create a mask for the upper triangle
mask = np.triu(np.ones_like(correlation_matrix, dtype=bool))

# Plotting the correlation matrix
plt.figure(figsize=(10, 8))
sns.heatmap(correlation_matrix, mask=mask, annot=True, fmt=".2f", cmap='coolwarm', cbar=True)
plt.title('Correlation Matrix')
plt.show()


In [None]:
metaphlan_df['log_distance_to_Flare']= np.log(metaphlan_df['distance_to_flare'])
metaphlan_df['log_aitchison_distance']= np.log(metaphlan_df['aitchison_distance'])

In [None]:
# Assuming 'metaphlan_df' has been filtered to exclude 'During_flare' status
data = metaphlan_df[metaphlan_df['Flare_status']!='During_flare']

# Define the list of columns you want to plot against 'distance to flare'
columns_to_plot = ['count_above_0_3',  'log_Bwex', 'aitchison_distance',
                   'dominant_species', 'dominant_species_2', 'alpha_diversity_shannon',
                   'alpha_diversity_simpson', 'alpha_diversity_richness', 'alpha_diversity_evenness', 'log_aitchison_distance', 's__GGB9512_SGB14909','s__Isoptericola_variabilis', 's__GGB4682_SGB6472', 's__Blautia_sp_MSK_20_85' , 's__Bacteroides_stercoris']

# Set up the matplotlib figure
plt.figure(figsize=(15, 30))  # Adjust the size according to your needs

# Create a scatter plot for each variable against 'distance to flare'
for i, column in enumerate(columns_to_plot, 1):
    plt.subplot(6, 3, i)  # Adjust subplot grid parameters as needed
    sns.scatterplot(x=column, y='distance_to_flare', data=data)
    plt.title(f'Scatter Plot of Distance to Flare vs. {column}')
    plt.xlabel(column)
    plt.ylabel('Distance to Flare')

plt.tight_layout()
plt.show()


In [None]:
columns_of_interest = ['distance_to_flare'] + species_columns

# Filter the DataFrame to include only the relevant columns
filtered_data = metaphlan_df[columns_of_interest]

# Calculate the correlation matrix
correlation_matrix = filtered_data.corr()

# Focus on the 'distance_to_flare' correlations
distance_correlations = correlation_matrix['distance_to_flare']

# Drop the correlation of 'distance_to_flare' with itself to avoid trivial results
distance_correlations = distance_correlations.drop(index='distance_to_flare')

# Sort by the absolute values of correlations to find the top species
top_species = distance_correlations.abs().sort_values(ascending=False).head(10)

print("Top 10 species with the highest absolute correlation with distance to flare:")
print(top_species)

In [None]:
columns_of_interest = ['distance_to_flare'] + gene_columns.tolist()

# Filter the DataFrame to include only the relevant columns
filtered_data = gene_count_df[columns_of_interest]

# Calculate the correlation matrix
correlation_matrix = filtered_data.corr()

# Focus on the 'distance_to_flare' correlations
distance_correlations = correlation_matrix['distance_to_flare']

# Drop the correlation of 'distance_to_flare' with itself to avoid trivial results
distance_correlations = distance_correlations.drop(index='distance_to_flare')

# Sort by the absolute values of correlations to find the top species
top_genes = distance_correlations.abs().sort_values(ascending=False).head(10)

print("Top 10 genes with the highest absolute correlation with distance to flare:")
print(top_genes)


In [None]:
# prompt: make a histogram of 's__Blautia_wexlerae'

import matplotlib.pyplot as plt
plt.hist(np.log(0.01+metaphlan_df['s__Blautia_wexlerae']))
plt.xlabel('s__Blautia_wexlerae Abundance')
plt.ylabel('Frequency')
plt.title('Histogram of s__Blautia_wexlerae Abundance')
plt.show()


In [None]:
metaphlan_df['distance_to_flare_transf']= np.sqrt(1/(metaphlan_df['distance_to_flare']))

In [None]:
metaphlan_df[['distance_to_flare_transf','distance_to_flare']]

In [None]:
# Sample code to plot a histogram of the 'distance_to_flare' column in the metaphlan_df DataFrame


# Sample code to plot a histogram of the 'distance_to_flare' column in the metaphlan_df DataFrame
def plot_distance_to_flare_histogram(metaphlan_df):
    plt.figure(figsize=(10, 6))
    plt.hist(metaphlan_df['distance_to_flare'], bins=50, edgecolor='black', alpha=0.7)
    plt.title('Histogram of Distance to Flare')
    plt.xlabel('Distance to Flare')
    plt.ylabel('Frequency')
    plt.grid(True)
    plt.show()


plot_distance_to_flare_histogram(metaphlan_df)


def plot_distance_to_flare_histogram(metaphlan_df):
    plt.figure(figsize=(10, 6))
    plt.hist(metaphlan_df['distance_to_flare_sqrt'], bins=50, edgecolor='black', alpha=0.7)
    plt.title('Histogram of Distance to Flare')
    plt.xlabel('Distance to Flare')
    plt.ylabel('Frequency')
    plt.grid(True)
    plt.show()


plot_distance_to_flare_histogram(metaphlan_df)


In [None]:
metaphlan_df['distance_to_flare_sqrt'].max()

In [None]:
import matplotlib.pyplot as plt
import scipy.stats as stats

# Q-Q Plot
def qq_plot(metaphlan_df):
    stats.probplot(metaphlan_df['distance_to_flare_transf'], dist="norm", plot=plt)
    plt.title('Q-Q Plot')
    plt.grid(True)
    plt.show()

# Call the function to display the Q-Q plot
qq_plot(metaphlan_df)


In [None]:
import matplotlib.pyplot as plt
import scipy.stats as stats

# Q-Q Plot
def qq_plot(metaphlan_df):
    stats.probplot(metaphlan_df['distance_to_flare_sqrt'], dist="norm", plot=plt)
    plt.title('Q-Q Plot')
    plt.grid(True)
    plt.show()

# Call the function to display the Q-Q plot
qq_plot(metaphlan_df)


In [None]:
['top_10_species', 'patient_id','week','Flare_status', 'Flare_start', 'PatientID_Weeknr','distance_to_flare',
                                                                                                     'last_week','distance_to_flare','log_distace_to_flare', 'distance_to_flare_transf', 'group', 'distance_to_flare_sqrt']

In [None]:


# Define columns to plot, excluding specific columns and species columns
columns_to_plot = [col for col in metaphlan_df.columns if col not in species_columns and col not in ['top_10_species', 'patient_id','week','Flare_status', 'Flare_start', 'PatientID_Weeknr','distance_to_flare',
                                                                                                     'last_week','distance_to_flare','log_distace_to_flare', 'distance_to_flare_transf', 'group', 'distance_to_flare_sqrt']]

# Plotting for each patient and each feature
for column in columns_to_plot:
    for patient_id, group in metaphlan_df.groupby('patient_id'):
        plt.figure(figsize=(12, 6))
        plt.plot(group['week'], group[column], marker='.', label=f'Patient {patient_id}')
        plt.title(f'{column} Across Weeks for Patient {patient_id}')
        plt.xlabel('Week')
        plt.ylabel(column)
        plt.grid(True)
        plt.legend()
        plt.show()

In [None]:
to_plot = metaphlan_df.drop(species_columns, axis=1)

In [None]:
to_plot.columns

In [None]:
to_plot.drop(['PatientID_Weeknr','group'], axis= 1, inplace=True)

In [None]:
to_plot.to_csv('metaphlan_df.csv')

In [None]:
# @title align_flare_start_function
def align_flare_start_week(df, flare_column='Flare_start', week_column='week', patient_column='patient_id', target_week=0):
    """
    Adjusts the week numbers for each patient in the dataframe such that the first occurrence
    of Flare_start=1 (that is not the first sample for the patient) occurs in the same target week for all patients who have a flare start.

    Args:
    df (pd.DataFrame): DataFrame containing the data.
    flare_column (str): Column name indicating the flare start.
    week_column (str): Column name indicating the week number.
    patient_column (str): Column name indicating the patient ID.
    target_week (int): The week number to which the first flare start should be aligned.

    Returns:
    pd.DataFrame: Modified DataFrame with adjusted week numbers.
    """
    df = df.copy()
    adjusted_dfs = []

    # Add a column indicating if the patient had a flare or not
    df['had_flare'] = df.groupby(patient_column)[flare_column].transform('max')

    for patient in df[patient_column].unique():
        patient_data = df[df[patient_column] == patient].copy()
        first_week = patient_data[week_column].min()

        if patient_data['had_flare'].iloc[0] == 1:
            # Identify flare weeks and remove the first week if it's a flare
            flare_weeks = patient_data[patient_data[flare_column] == 1][week_column]
            flare_weeks = flare_weeks[flare_weeks > first_week]  # ignore flare at the first recorded week

            if not flare_weeks.empty:
                flare_week = flare_weeks.min()  # first flare that's not at the first week
                shift = target_week - flare_week
                patient_data[week_column] += shift

        adjusted_dfs.append(patient_data)

    aligned_df = pd.concat(adjusted_dfs, ignore_index=True)

    # Normalize to ensure there are no negative weeks
    min_week = aligned_df[week_column].min()
    if min_week < 0:
       aligned_df.loc[aligned_df['had_flare'] == 1, week_column] -= min_week

    # Remove the temporary 'had_flare' column before returning
    aligned_df.drop('had_flare', axis=1, inplace=True)

    return aligned_df

In [None]:
# @title shifting the week numbers so that the three patients have the beginning of their flare at the same time, which makes it easier to compare visually
to_plot = align_flare_start_week(to_plot)

In [None]:
# @title plot_species_abundance_over_time_2 function
def plot_species_abundance_over_time_2(df, cols_to_plot, title, type_):
    """
    Plots the abundance over time for the specified species for all patients in the dataframe,
    with separate plots for patients with and without a flare start, all in a single figure with subplots.

    Args:
    df (pd.DataFrame): DataFrame containing microbial abundance data.
    cols_to_plot (list): List of species names to plot.
    title (str): Title for the plots and y-axis label.
    type_ (str): Label for the plot legend.
    """
    if not cols_to_plot:
        raise ValueError("cols_to_plot must be provided when selection_mode is 'user'")

    num_species = len(cols_to_plot)
    cols = 2  # Number of columns for subplots (one for with flare, one for without)
    rows = num_species  # Each species will occupy one row with two subplots

    fig, axs = plt.subplots(rows, cols, figsize=(20, 5 * rows))  # Adjust figsize as needed

    # If only one species, axs might not be a 2D array, ensure it is
    if num_species == 1:
        axs = [axs]

    # Determine patients where the first recorded week is a flare start
    df['first_week'] = df.groupby('patient_id')['week'].transform('min')  # Find first week for each patient
    first_flare_patients = df[(df['Flare_start'] == 1) & (df['week'] == df['first_week'])]['patient_id'].unique()
    patients_with_flare = df[(df['Flare_start'] == 1) & (~df['patient_id'].isin(first_flare_patients))]['patient_id'].unique()
    patients_without_flare = df[~df['patient_id'].isin(patients_with_flare)]['patient_id'].unique()


    for i, species in enumerate(cols_to_plot):
        # Plot for patients with flare
        ax = axs[i][0]

        for patient in patients_with_flare:
            patient_data = df[df['patient_id'] == patient].copy()
            patient_data.sort_values(by='week', inplace=True)  # Ensure data is sorted by week

            if species in patient_data.columns:
                sns.lineplot(data=patient_data, x='week', y=species, ax=ax, label=f'Patient {patient}')

        # Annotate flare start changes
        flare_changes = df[(df["Flare_start"] == 1)&(df['patient_id'].isin(patients_with_flare ))][['week', 'Flare_start']].drop_duplicates()
        for _, (week, status) in flare_changes.iterrows():
            ax.axvline(x=week, color='grey', linestyle='--')
            ax.text(week, ax.get_ylim()[1], 'Flare_start', verticalalignment='top', fontsize=8)

        ax.set_title(f'{title} ({species}) Over Time for Patients with Flare')
        ax.set_xlabel('Week')
        ax.set_ylabel(title)
        ax.legend(title=type_, bbox_to_anchor=(1.05, 1), loc='upper left')
        ax.set_xticks(df['week'].unique())
        ax.set_xticklabels(df['week'].unique(), rotation=45)

        # Plot for patients without flare
        ax = axs[i][1]
        for patient in patients_without_flare:
            patient_data = df[df['patient_id'] == patient].copy()
            patient_data.sort_values(by='week', inplace=True)  # Ensure data is sorted by week

            if species in patient_data.columns:
                sns.lineplot(data=patient_data, x='week', y=species, ax=ax, label=f'Patient {patient}')

        ax.set_title(f'{title} ({species}) Over Time for Patients without Flare')
        ax.set_xlabel('Week')
        ax.set_ylabel(title)
        ax.legend(title=type_, bbox_to_anchor=(1.05, 1), loc='upper left')
        ax.set_xticks(df['week'].unique())
        ax.set_xticklabels(df['week'].unique(), rotation=45)

    plt.tight_layout()
    plt.show()


In [None]:
plot_species_abundance_over_time_2(to_plot, ['alpha_diversity_evenness', 'alpha_diversity_richness',
       'alpha_diversity_shannon', 'alpha_diversity_simpson', 'common_species_with_last_week',
  'count_above_0_01', 'count_above_0_03', 'count_above_0_05',
       'count_above_0_1', 'count_above_0_3', 'distance_to_flare', 'dominant_species',
       'dominant_species_2', 'log_Bwex',
       'aitchison_distance',
       'log_aitchison_distance'], '..', '..')

In [None]:
['alpha_diversity_evenness', 'alpha_diversity_richness',
       'alpha_diversity_shannon', 'alpha_diversity_simpson', 'common_species_with_last_week',
 'count_above_0_01', 'count_above_0_03', 'count_above_0_05',
       'count_above_0_1', 'count_above_0_3',  'dominant_species',
       'dominant_species_2', 'log_Bwex',
       'aitchison_distance',
   ]

In [None]:
import pandas as pd
import numpy as np
from sklearn.model_selection import KFold
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.ensemble import RandomForestRegressor
from statsmodels.stats.outliers_influence import variance_inflation_factor

# Filter and prepare the data
data = metaphlan_df[(metaphlan_df['patient_id'] != 'TR_2103') & (metaphlan_df['Flare_status'] != 'During_flare')]
data['original_patient_id'] = data['patient_id']  # Retain original patient_id
data = pd.get_dummies(data, columns=['patient_id'], prefix='patient')
patient_cols = [col for col in data.columns if col.startswith('patient_')]
data[patient_cols] = data[patient_cols].astype('int')


# Specify the features to include in the model
features = ['alpha_diversity_evenness', 'alpha_diversity_richness',
       'alpha_diversity_shannon', 'alpha_diversity_simpson', 'common_species_with_last_week',
'count_above_0_01', 'count_above_0_03', 'count_above_0_05',
       'count_above_0_1', 'count_above_0_3',  'dominant_species',
       'dominant_species_2', 'log_Bwex',
       'aitchison_distance',
   ] + patient_cols

X = data[features]
target = "distance_to_flare_sqrt"
y = data[target]

# Checking for multicollinearity using Variance Inflation Factor (VIF)
vif_data = pd.DataFrame()
vif_data["feature"] = X.columns
vif_data["VIF"] = [variance_inflation_factor(X.values, i) for i in range(X.shape[1])]

print("Features with high collinearity (VIF > 8):")
print(vif_data[vif_data['VIF'] > 8])

# Initialize KFold for 10 folds
kf = KFold(n_splits=10, shuffle=True, random_state=42)

# Initialize lists to store performance metrics and predictions
mse_scores = []
r2_scores = []
all_actuals = []
all_predictions = []
all_patient_ids = []
all_weeks = []

# Perform 10-fold Cross-Validation
for train_index, test_index in kf.split(data):
    train_data = data.iloc[train_index]
    test_data = data.iloc[test_index]

    # Create and fit the Random Forest model
    model = RandomForestRegressor(n_estimators=100, random_state=42)
    model.fit(train_data[features], train_data[target])

    # Make predictions on the test set
    y_pred = model.predict(test_data[features])

    # Collect actual values, predictions, patient IDs, and weeks
    all_actuals.extend(test_data[target].values)
    all_predictions.extend(y_pred)
    all_patient_ids.extend(test_data['original_patient_id'].values)
    all_weeks.extend(test_data['week'].values)  # Assuming there is a 'week' column

    # Calculate and collect MSE and R^2
    mse = mean_squared_error(test_data[target], y_pred)
    r2 = r2_score(test_data[target], y_pred)
    mse_scores.append(mse)
    r2_scores.append(r2)

# Calculate average MSE and R^2 over all folds
average_mse = np.mean(mse_scores)
average_r2 = np.mean(r2_scores)

print(f"Average Mean Squared Error: {average_mse}")
print(f"Average R-squared: {average_r2}")

# Fit the final model on the entire dataset
final_model = RandomForestRegressor(n_estimators=100, random_state=42)
final_model.fit(X, y)

# Output feature importances
importances = final_model.feature_importances_
feature_importances = pd.DataFrame({'feature': features, 'importance': importances}).sort_values(by='importance', ascending=False)
print("\nFeature importances from the final model:")
print(feature_importances)

# Show actual vs. predicted values along with patient ID and week
predictions_df = pd.DataFrame({
    'patient_id': all_patient_ids,
    'week': all_weeks,
    'Actual': all_actuals,
    'Predicted': all_predictions
})

In [None]:
print("\nActual vs. Predicted values with Patient ID and Week:")
predictions_df['Actual']=predictions_df['Actual']**2
predictions_df['Predicted']=predictions_df['Predicted']**2
print(predictions_df.sort_values(by=['patient_id','week']))