In [1]:
# Mount Google Drive (Colab only) so the CSV inputs under /content/drive can be read.
# NOTE(review): the saved output of this cell shows the mount failing with a
# credential error; re-authenticate before running the cells that read from Drive.
from google.colab import drive
drive.mount('/content/drive')

MessageError: Error: credential propagation was unsuccessful

In [None]:
# @title Data Reading
import pandas as pd
import os
from pandas import ExcelFile
import matplotlib.pyplot as plt
import seaborn as sns
from random import sample
import numpy as np
from statsmodels.tsa.stattools import adfuller
from statsmodels.stats.multitest import multipletests
import warnings
from scipy.spatial.distance import pdist, squareform
from scipy.stats import entropy
from scipy.signal import savgol_filter
from scipy.spatial.distance import euclidean
from sklearn.preprocessing import StandardScaler

from sklearn.decomposition import PCA
from sklearn.metrics import pairwise_distances
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.patches as mpatches
from scipy.stats import gmean
import math

# Locations of the two preprocessed input tables on the mounted Google Drive.
# NOTE(review): absolute Colab paths; adjust them if the Drive layout differs.
gene_count_path ='/content/drive/MyDrive/MITResearch/IBD_13/gene_count_preprocessed.csv'
metaphlan_path = '/content/drive/MyDrive/MITResearch/IBD_13/metaphlan_preprocessed.csv'

# Load both tables; later cells access their 'patient_id', 'week' and
# 'Flare_status' columns.
gene_count_df = pd.read_csv(gene_count_path)
metaphlan_df = pd.read_csv(metaphlan_path)

# Sanity check: (rows, columns) of each table.
print(gene_count_df.shape)
print(metaphlan_df.shape)

In [None]:

# Qualitative palette installed below as the default matplotlib colour cycle.
extended_colors = [
    '#66c2a5', '#ff7f00', '#8da0cb',  # teal, orange, light blue
    '#e78ac3', '#a6d854', '#e5c494',  # pink, light green, tan
    '#b3b3b3', '#80b1d3', '#bebada',  # gray, sky blue, lavender
    '#fb8072', '#7fc97f', '#d9d9d9',  # light red, medium green, light gray
    '#ffed6f', '#b15928', '#fccde5',  # pale yellow, brown, pale pink
    '#8dd3c7', '#ffffb3', '#bc80bd',  # pale teal, very pale yellow, purple
]

# Apply every global matplotlib default in a single update.
plt.rcParams.update({
    # Colour cycle
    'axes.prop_cycle': plt.cycler(color=extended_colors),
    # Axes background and grid appearance (the grid itself is off by default)
    'axes.facecolor': 'white',
    'grid.color': '#cccccc',
    'grid.linestyle': '--',
    'grid.alpha': 0.5,
    'axes.grid': False,
    'axes.grid.which': 'major',
    'axes.grid.axis': 'both',
    # Line width
    'lines.linewidth': 1,
})

In [None]:
# @title count columns with non zero
def count_columns_with_nonzero(df, excluded_columns, min_fraction=0.2):
    """
    Count the feature columns whose share of non-zero entries exceeds a threshold.

    Args:
    df (pd.DataFrame): Table with one row per sample.
    excluded_columns (list): Columns to drop before counting (e.g. metadata).
    min_fraction (float): A column is counted when its number of non-zero values is
        strictly greater than ``min_fraction * number_of_rows``. Default is 0.2 (20%).

    Returns:
    Number of columns passing the threshold.
    """
    # Keep only the feature columns
    df_filtered = df.drop(columns=excluded_columns)

    # Truthiness is used as the "non-zero" test; note that NaN casts to True and is
    # therefore counted as non-zero.
    non_zero_counts = df_filtered.astype(bool).sum(axis=0)

    # Minimum number of non-zero samples a column must exceed
    threshold = min_fraction * len(df_filtered)

    return (non_zero_counts > threshold).sum()

# Metadata columns that must not be counted as features.
excluded_columns = ['patient_id', 'week', 'Flare_status']
# Report, for each table, how many feature columns pass the prevalence
# threshold applied inside count_columns_with_nonzero.
result = count_columns_with_nonzero(gene_count_df, excluded_columns)
print(f'Number of columns with more than 20% non-zero values in gene_count_df: {result}')

result = count_columns_with_nonzero(metaphlan_df, excluded_columns)
print(f'Number of columns with more than 20% non-zero values in metaphlan_df: {result}')

In [None]:
# @title flag_first_flare_weeks function

def flag_first_flare_weeks(df):
    """
    Add a 0/1 'Flare_start' column marking the first sample of every flare episode.

    A row is a flare start when its 'Flare_status' is 'During_flare' or
    'During_flare_2' and the previous sample *of the same patient* is not in flare
    (or there is no previous sample for that patient).

    The frame is modified in place: it is sorted by ('patient_id', 'week') and the
    'Flare_start' column is placed at column position 3. Calling the function again
    on its own output recomputes the column.

    Args:
    df (pd.DataFrame): Must contain 'patient_id', 'week' and 'Flare_status'.

    Returns:
    pd.DataFrame: The same (mutated) DataFrame, for convenient re-assignment.
    """
    # Chronological order within each patient
    df.sort_values(by=['patient_id', 'week'], inplace=True)

    is_flare = df['Flare_status'].isin(['During_flare', 'During_flare_2'])

    # Flare state of the preceding row; the first row of the frame has none.
    previous_is_flare = is_flare.shift(1, fill_value=False).astype(bool)

    # A change of patient_id means the preceding row belongs to someone else, so
    # it must not suppress a flare start on this patient's first sample.
    first_row_of_patient = df['patient_id'] != df['patient_id'].shift(1)

    flare_start = (is_flare & (~previous_is_flare | first_row_of_patient)).astype(int)

    # Replace any previous result, then place the column at the expected position
    if 'Flare_start' in df.columns:
        df.drop(columns='Flare_start', inplace=True)
    df.insert(3, 'Flare_start', flare_start)

    return df

In [None]:
# @title align_flare_start_function
def align_flare_start_week(df, flare_column='Flare_start', week_column='week', patient_column='patient_id', target_week=0):
    """
    Shift each patient's week numbers so that their first flare start lands on a common week.

    For every patient that has at least one flare start, the earliest flare start
    that is NOT on the patient's first recorded week is moved to ``target_week``;
    all of that patient's weeks are shifted by the same amount. Patients without
    such a flare start are left unshifted. If the shifting produces negative weeks,
    the weeks of all patients that had a flare are moved up by the overall minimum.

    Args:
    df (pd.DataFrame): Input data; it is not modified.
    flare_column (str): Column holding the 0/1 flare-start flag.
    week_column (str): Column holding the week number.
    patient_column (str): Column holding the patient ID.
    target_week (int): Week onto which the chosen flare start is aligned.

    Returns:
    pd.DataFrame: Copy of the data with adjusted weeks and a fresh 0..n-1 index.
    """
    working = df.copy()

    # 1 for every row of a patient that has any flare start, 0 otherwise
    working['had_flare'] = working.groupby(patient_column)[flare_column].transform('max')

    per_patient = []
    for pid in working[patient_column].unique():
        rows = working[working[patient_column] == pid].copy()

        if rows['had_flare'].iloc[0] == 1:
            earliest_week = rows[week_column].min()
            # Flare starts strictly after the first recorded week
            later_flare = (rows[flare_column] == 1) & (rows[week_column] > earliest_week)
            candidate_weeks = rows.loc[later_flare, week_column]

            if len(candidate_weeks) > 0:
                offset = target_week - candidate_weeks.min()
                rows[week_column] = rows[week_column] + offset

        per_patient.append(rows)

    aligned_df = pd.concat(per_patient, ignore_index=True)

    # Lift the flare patients so that no week is negative
    lowest_week = aligned_df[week_column].min()
    if lowest_week < 0:
        flare_rows = aligned_df['had_flare'] == 1
        aligned_df.loc[flare_rows, week_column] = aligned_df.loc[flare_rows, week_column] - lowest_week

    # The helper column is internal only
    return aligned_df.drop(columns='had_flare')

In [None]:
# @title analyze representation and variation

def _plot_top_categories(ax, top_values, title, xlabel):
    """Draw one horizontal bar panel for the given (category -> value) Series on ``ax``."""
    sns.barplot(x=top_values.values, y=top_values.index, palette="rocket", ax=ax)
    ax.set_title(title, fontsize=18)
    ax.set_xlabel(xlabel, fontsize=14)
    ax.set_ylabel('')
    ax.tick_params(axis='both', labelsize=12)

    # Remove axes spines
    for side in ('top', 'right', 'left', 'bottom'):
        ax.spines[side].set_visible(False)

    # Remove legend border if a legend exists
    legend = ax.get_legend()
    if legend:
        legend.get_frame().set_linewidth(0.0)


def analyze_representation_and_variation(df, start_col, category_name, top_n=10):
    """
    Plot the most represented and the most variable categories of a table side by side.

    The mean and standard deviation are computed per column over all rows, using
    the columns from position ``start_col`` onwards. Non-numeric columns in that
    range are ignored.

    Args:
    df (pd.DataFrame): DataFrame containing the data.
    start_col (int): Column index from which to start the analysis.
    category_name (str): Name of the category (e.g., 'Bacteria', 'Genes') for labeling purposes.
    top_n (int): Number of categories shown in each panel. Default is 10.

    Returns:
    None. The figure is displayed with ``plt.show()``.
    """
    features = df.iloc[:, start_col:]

    # numeric_only guards against metadata/string columns left in the range
    category_mean = features.mean(numeric_only=True)
    category_std = features.std(numeric_only=True)

    # Highest mean = most represented; highest standard deviation = most variable
    most_represented = category_mean.sort_values(ascending=False).head(top_n)
    most_variable = category_std.sort_values(ascending=False).head(top_n)

    fig = plt.figure(figsize=(20, 10))
    # The seaborn style is set before the axes are created so that they pick it up
    sns.set(style="whitegrid", palette="muted")

    _plot_top_categories(
        fig.add_subplot(1, 2, 1),
        most_represented,
        f'Top {top_n} Most Represented {category_name}',
        f'Mean Relative Abundance for {category_name}',
    )
    _plot_top_categories(
        fig.add_subplot(1, 2, 2),
        most_variable,
        f'Top {top_n} {category_name} with Most Variation',
        f'Standard Deviation of Relative Abundance for {category_name}',
    )

    plt.tight_layout()
    plt.show()


In [None]:
# @title plot species abundance over time

def plot_species_abundance_over_time(df, cols_to_plot, title, type_, selection_mode='user'):
    """
    Plots the abundance over time for selected species, one subplot per patient (two subplots per row).

    Relies on the notebook-level ``exclude_columns`` (the non-species columns) being
    defined before the function is called.

    Args:
    df (pd.DataFrame): DataFrame containing microbial abundance data, with
        'patient_id', 'week' and 'Flare_status' columns.
    cols_to_plot (list or None): Species names to plot. May be None when
        selection_mode picks the species automatically.
    title (str): Title for the plots and y-axis label.
    type_ (str): Title of the plot legend.
    selection_mode (str): How the species are chosen, per patient:
        'user'      - use cols_to_plot as given;
        'variation' - the 6 species with the highest standard deviation;
        'mean'      - the 6 species with the highest mean;
        'top1'      - every species that is the most abundant one in at least one sample.
        Any other value falls back to cols_to_plot.

    Raises:
    ValueError: If cols_to_plot is None and selection_mode does not select species automatically.
    """
    auto_modes = ('variation', 'top1', 'mean')
    if cols_to_plot is None and selection_mode not in auto_modes:
        raise ValueError(
            "cols_to_plot must be provided unless selection_mode is 'variation', 'top1' or 'mean'"
        )

    unique_patients = df['patient_id'].unique()
    num_patients = len(unique_patients)
    cols = 2  # Number of columns for subplots
    rows = (num_patients // cols) + (num_patients % cols > 0)  # Calculate rows needed

    fig, axs = plt.subplots(rows, cols, figsize=(20, 5 * rows))
    axs = axs.flatten()  # One flat sequence of axes, row by row

    for ax, patient in zip(axs, unique_patients):
        patient_data = df[df['patient_id'] == patient].sort_values(by='week')

        # Exclude non-abundance columns to focus on species data
        species_data = patient_data.drop(exclude_columns, axis=1)

        # Choose the species for this patient without rebinding the argument
        if selection_mode == 'variation':
            selected = species_data.std().sort_values(ascending=False).head(6).index.tolist()
        elif selection_mode == 'top1':
            top_species = set()  # a set avoids duplicates
            for _, row in species_data.iterrows():
                top_species.update(row.nlargest(1).index)
            selected = sorted(top_species)  # sorted for a reproducible legend/colour order
        elif selection_mode == 'mean':
            selected = species_data.mean().sort_values(ascending=False).head(6).index.tolist()
        else:
            selected = cols_to_plot

        # Plot the abundance over time for the selected species
        for species in selected:
            if species in species_data.columns:
                sns.lineplot(data=patient_data, x='week', y=species, label=species, ax=ax)

        # Mark the first week at which each flare status appears
        flare_changes = patient_data.drop_duplicates('Flare_status', keep='first')[['week', 'Flare_status']]
        for _, (week, status) in flare_changes.iterrows():
            ax.axvline(x=week, color='grey', linestyle='--')
            ax.text(week, ax.get_ylim()[1], f' {status}', verticalalignment='top', fontsize=8)

        # Fixed tick grid: every 5 weeks from 0 to 50
        major_ticks = range(55)[::5]
        ax.set_xticks(major_ticks)
        ax.set_xticklabels(major_ticks, rotation=45)

        ax.set_title(f'{title} Over Time for Patient {patient}')
        ax.set_xlabel('Week')
        ax.set_ylabel(title)
        # frameon=False removes the legend border on every subplot
        ax.legend(title=type_, bbox_to_anchor=(1.05, 1), loc='upper left', frameon=False)
        ax.grid(False)

        # Keep only the left and bottom spines
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['left'].set_visible(True)
        ax.spines['bottom'].set_visible(True)

    # Hide any unused subplots
    for unused_ax in axs[num_patients:]:
        fig.delaxes(unused_ax)

    plt.tight_layout()
    plt.show()


In [None]:
# @title plot species abundance over time 2

def plot_species_abundance_over_time_2(df, cols_to_plot, title, type_, poly_degree=3):
    """
    Plots the abundance over time for the specified species for all patients in the dataframe,
    with separate plots for patients with and without a flare start, all in a single figure
    (one row per species; left panel = patients with a flare start, right panel = without).

    The input DataFrame is not modified.

    Args:
    df (pd.DataFrame): DataFrame containing microbial abundance data, with
        'patient_id', 'week' and 'Flare_start' columns.
    cols_to_plot (list): List of species names to plot.
    title (str): Title for the plots and y-axis label.
    type_ (str): Title of the plot legend.
    poly_degree (int): Kept for compatibility; no trend line is currently drawn,
        so this value is not used.

    Raises:
    ValueError: If cols_to_plot is None or empty.
    """
    if cols_to_plot is None or len(cols_to_plot) == 0:
        raise ValueError("cols_to_plot must be a non-empty list of species names")

    num_species = len(cols_to_plot)
    cols = 2  # One column for patients with a flare, one for patients without
    rows = num_species  # Each species occupies one row

    # squeeze=False always yields a 2D array of axes, even for a single species
    fig, axs = plt.subplots(rows, cols, figsize=(20, 5 * rows), squeeze=False)

    # Split patients by whether they have any flare start
    patients_with_flare = df[df['Flare_start'] == 1]['patient_id'].unique()
    patients_without_flare = df[~df['patient_id'].isin(patients_with_flare)]['patient_id'].unique()
    panels = [(patients_with_flare, "with Flare"), (patients_without_flare, "without Flare")]

    for i, species in enumerate(cols_to_plot):
        for j, (patients, title_suffix) in enumerate(panels):
            ax = axs[i][j]
            all_weeks = []

            for patient in sorted(patients):  # Sort patients for consistent order
                patient_data = df[df['patient_id'] == patient].sort_values(by='week')

                # Nothing to draw (line or flare markers) for a species that is absent
                if species not in patient_data.columns:
                    continue

                sns.lineplot(data=patient_data, x='week', y=species, ax=ax, label=f'Patient {patient}')
                all_weeks.extend(patient_data['week'])

                if j == 0:  # Only for patients with flare
                    flare_starts = patient_data[patient_data['Flare_start'] == 1]
                    # Label only one scatter so the legend shows a single 'Flare Start' entry
                    ax.scatter(flare_starts['week'], flare_starts[species],
                               color='red', s=10, zorder=5,
                               label='Flare Start' if patient == patients_with_flare[0] else "")

            ax.set_title(f'{title} ({species}) Over Time for Patients {title_suffix}')
            ax.set_xlabel('Week')
            ax.set_ylabel(title)

            # Keep only the left and bottom spines
            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)
            ax.spines['left'].set_visible(True)
            ax.spines['bottom'].set_visible(True)

            # Create legend with no box and sorted alphabetically
            handles, labels = ax.get_legend_handles_labels()
            by_label = dict(zip(labels, handles))
            sorted_labels = sorted(by_label.keys())
            ax.legend(title=type_, bbox_to_anchor=(1.05, 1), loc='upper left',
                      handles=[by_label[l] for l in sorted_labels],
                      labels=sorted_labels, frameon=False)

            # A panel can be empty (e.g. every patient has a flare); skip ticks then
            if all_weeks:
                major_ticks = range(0, int(max(all_weeks)) + 1, 5)
                ax.set_xticks(major_ticks)
                ax.set_xticklabels(major_ticks, rotation=45)
            ax.grid(False)

    plt.tight_layout()
    plt.show()

In [None]:
# @title segment_time_series
def segment_time_series(df, length=6):
    """
    Split each patient's samples into consecutive numbered segments of ``length`` rows.

    For a patient with a flare start, the first flare-start row is forced to be the
    first row of a segment: any leading rows that do not fill a whole number of
    segments before it form a shorter first segment. Patients without a flare start
    are cut every ``length`` rows from their first sample.

    The frame is modified in place (sorted by patient and week, with the columns
    'ts' and 'ts_patient_id' added) and also returned.

    Args:
    df (pd.DataFrame): Must contain 'patient_id', 'week' and 'Flare_start'.
    length (int): Number of rows per segment.

    Returns:
    pd.DataFrame: The same DataFrame with 'ts' (int segment number, starting at 1
    per patient) and 'ts_patient_id' ('<patient_id>_TS<ts>').
    """
    df.sort_values(by=['patient_id', 'week'], inplace=True)
    df['ts'] = np.nan

    for pid in df['patient_id'].unique():
        patient_rows = df[df['patient_id'] == pid]
        labels = patient_rows.index
        n_rows = len(patient_rows)

        # Row labels at which a flare starts; only the first one is used
        flare_labels = patient_rows[patient_rows['Flare_start'] == 1].index.tolist()
        if flare_labels:
            # Size of the short leading segment that aligns the flare to a segment start
            lead = labels.get_loc(flare_labels[0]) % length
        else:
            lead = 0

        # Positional start of every segment, then the matching end positions
        starts = list(range(lead, n_rows, length))
        if lead > 0:
            starts.insert(0, 0)
        stops = starts[1:] + [n_rows]

        for segment_number, (lo, hi) in enumerate(zip(starts, stops), start=1):
            df.loc[labels[lo:hi], 'ts'] = segment_number

    df['ts'] = df['ts'].astype(int)

    # Identifier that is unique across patients
    df['ts_patient_id'] = df['patient_id'] + '_TS' + df['ts'].astype(str)

    return df


In [None]:
# @title Augmented Dickey-Fuller (ADF) test
def analyze_time_series(data, significance_level=0.05, mht="bonferroni", id_col="ts_patient_id"):
    """
    Run an Augmented Dickey-Fuller (ADF) test on every variable of every sub-time-series
    and print which variables remain significant after multiple-testing correction.

    Sub-time-series are the groups of rows sharing the same ``id_col`` value; only
    groups with more than 4 rows are tested. Relies on the notebook-level
    ``exclude_columns`` (the metadata columns) being defined.

    Args:
    data (pd.DataFrame): Data with an ``id_col`` column and a boolean 'RBF' column;
        rows with RBF == True define the "target" ids.
    significance_level (float): Threshold applied to the adjusted p-values.
    mht (str): Multiple-testing correction used for filtering: "bonferroni" or "bh".
    id_col (str): Column identifying each sub-time-series.

    Returns:
    None. Results are printed.

    Raises:
    ValueError: If ``mht`` is not "bonferroni" or "bh".
    """
    adjusted_column = {
        "bonferroni": 'bonferroni_adjusted_p-value',
        "bh": 'bh_adjusted_p-value',
    }
    if mht not in adjusted_column:
        raise ValueError(f"mht must be 'bonferroni' or 'bh', got {mht!r}")

    target_ids = data[data["RBF"]==True][id_col].unique()
    print(f'Target ids : {target_ids}')

    # Suppress runtime warnings from statsmodels or other calculations.
    # NOTE: this filter stays active for the rest of the session.
    warnings.filterwarnings('ignore', category=RuntimeWarning)

    # Keep only sub-time-series with more than 4 observations
    filtered_df = data.groupby(id_col).filter(lambda x: len(x) > 4)
    print(exclude_columns)

    # Metadata columns are not tested; list() accepts both a pandas Index and a list
    col_to_exclude = list(exclude_columns) + ['ts', 'RBF', 'ts_patient_id', id_col]
    analyze_columns = [col for col in filtered_df.columns if col not in col_to_exclude]

    # One record per (sub-time-series, variable)
    adf_results = []
    for series_id in filtered_df[id_col].unique():
        patient_data = filtered_df[filtered_df[id_col] == series_id]
        for column in analyze_columns:
            time_series = patient_data[column]
            if time_series.notna().all() and len(time_series) > 1 and time_series.var() > 0:
                try:
                    result = adfuller(time_series)
                    adf_results.append({
                        'id': series_id,
                        'variable': column,
                        'ADF Statistic': result[0],
                        'p-value': result[1],
                        'Used lag': result[2],
                        'Number of observations': result[3],
                        'Critical values': result[4],
                        'IC used': result[5]
                    })
                except Exception as e:
                    # Record the failure instead of aborting the whole scan
                    adf_results.append({
                        'id': series_id,
                        'variable': column,
                        'Error': str(e)
                    })
            else:
                adf_results.append({
                    'id': series_id,
                    'variable': column,
                    'Error': 'Insufficient data or variance'
                })

    results_df = pd.DataFrame(adf_results)

    # Without a single successful test there is no 'p-value' column to correct
    if 'p-value' not in results_df.columns:
        print("No ADF test could be computed; nothing to report.")
        return
    results_df['p-value'] = pd.to_numeric(results_df['p-value'], errors='coerce')
    valid_p_values = results_df['p-value'].dropna()
    if valid_p_values.empty:
        print("No ADF test could be computed; nothing to report.")
        return

    # Apply multiple testing corrections to the valid p-values only
    _, bh_adjusted, _, _ = multipletests(valid_p_values, method='fdr_bh')
    _, bonf_adjusted, _, _ = multipletests(valid_p_values, method='bonferroni')

    results_df['bh_adjusted_p-value'] = pd.Series(bh_adjusted, index=valid_p_values.index)
    results_df['bonferroni_adjusted_p-value'] = pd.Series(bonf_adjusted, index=valid_p_values.index)

    # Filter significant results with the requested correction
    significant_results = results_df[results_df[adjusted_column[mht]] < significance_level]

    # Print results
    significant_counts = significant_results.groupby('id').size()
    print("Summary of significant adjusted p-value counts by sub-timeseries:")
    print(significant_counts.sort_values(ascending=False))

    for series_id, count in significant_counts.items():
        print(f"Patient ID: {series_id}")
        print(f"Number of variables with significant adjusted p-values: {count}")

    # Variables significant in at least half of the target ids
    target_significant = significant_results[significant_results['id'].isin(target_ids)]
    common_significant_vars = target_significant.groupby('variable').filter(lambda x: len(x) >= 0.5*len(target_ids))

    if not common_significant_vars.empty:
        print("\nVariables significant across at least 50% of target patient IDs:")
        for variable in common_significant_vars['variable'].unique():
            print(f"Variable: {variable}")
            relevant_patients = significant_results[significant_results['variable'] == variable]
            in_target = relevant_patients['id'].isin(target_ids)
            target_hits = relevant_patients[in_target]
            other_patients = relevant_patients[~in_target]
            # Only the target ids are listed here; the others are counted below
            print(f"Significant in target IDs: {', '.join(target_hits['id'].unique())}")
            if not other_patients.empty:
                print(f"Also significant in other {len(other_patients['id'].unique())} portions")
            else:
                print("Not significant in any other patient IDs.")
    else:
        print("No variables are significant across all specified target patient IDs.")

    print("\n")  # Add a new line for readability

In [None]:
# @title smooth_patient_data
def smooth_patient_data(df, value_columns, patient_column='patient_id', week_column='week'):
    """
    Smooth the given columns per patient with a centred 3-point rolling mean.

    Each patient's rows are ordered by week before smoothing. The first and last
    values of a patient are set to the mean of the two outermost raw values. The
    input frame is left untouched.

    Args:
    df (pd.DataFrame): DataFrame containing the data to be smoothed.
    value_columns (list): Columns to smooth.
    patient_column (str): Column name for patient ID.
    week_column (str): Column name for week number.

    Returns:
    pd.DataFrame: Smoothed copy; patients appear in order of first occurrence,
    rows sorted by week within each patient, original index labels kept.
    """
    source = df.copy()

    def _smooth_one_patient(block):
        # Chronological order first, then smooth every requested column
        block = block.sort_values(by=week_column)
        for name in value_columns:
            raw = block[name]
            averaged = raw.rolling(window=3, min_periods=1, center=True).mean()
            if len(averaged) > 1:
                # Edges: mean of the two outermost raw observations
                averaged.iloc[0] = raw.iloc[:2].mean()
                averaged.iloc[-1] = raw.iloc[-2:].mean()
            block[name] = averaged
        return block

    pieces = [
        _smooth_one_patient(source[source[patient_column] == pid].copy())
        for pid in source[patient_column].unique()
    ]
    return pd.concat(pieces)


In [None]:
# @title
# Species-level columns are recognised by their 's__' prefix.
species_columns = [name for name in metaphlan_df.columns if name.startswith('s__')]

# Per-patient mean of every species column (one row per patient).
mean_abundances = metaphlan_df.groupby('patient_id')[species_columns].mean()

# Report, patient by patient, the species whose mean exceeds 0.07.
for patient_id, abundances in mean_abundances.iterrows():
    print(f"Patient {patient_id} has the following species with an average relative abundance above 0.07:")
    filtered_species = abundances[abundances > 0.07 ]
    print("No species above the threshold." if filtered_species.empty else filtered_species)
    print()

In [None]:
# @title Plot Microbial abundance profiles
# Define the number of rows and columns for the subplot grid
nrows = 4
ncols = 4

combinations = []

# Draw random (patient, species) pairs until every panel of the grid has one.
# Only species columns are sampled, so metadata columns such as 'week' or
# 'Flare_status' can never be selected.
# NOTE(review): no random seed is set, so the selection changes on every run.
while len(combinations) < nrows * ncols:
    # Randomly select a patient_id and a species column
    patient_id = metaphlan_df['patient_id'].sample().values[0]
    column_name = species_columns[np.random.randint(len(species_columns))]

    # Keep the pair only if that patient has at least one non-zero value for it
    patient_values = metaphlan_df.loc[metaphlan_df['patient_id'] == patient_id, column_name]
    if not (patient_values == 0).all():
        combinations.append((patient_id, column_name))

# Print the list of good combinations
print(combinations)

# Create the figure and subplots
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(16, 12))
sns.set(style="whitegrid")
# One histogram per selected (patient, species) pair
for i, (patient_id, column_name) in enumerate(combinations):
    # Select the current subplot
    ax = axes.flat[i]

    # Plot the histogram of values
    sns.histplot(data=metaphlan_df[metaphlan_df['patient_id'] == patient_id][column_name], ax=ax, bins = 30, color='skyblue', kde=True)

    # Set the title of the subplot
    ax.set_title(f'{patient_id} - {column_name}')


# Adjust the spacing between subplots
plt.tight_layout()

# Show the plot
plt.show()

In [None]:
# Metadata (non-species) columns: every metaphlan_df column that is not in species_columns.
# The result is a pandas Index; the plotting and ADF helpers read it as a global.
exclude_columns = metaphlan_df.columns.difference(species_columns)

In [None]:
# # @title CLR transformation for Microbial Composition
# def clr_transform(df, metadata_columns):
#     species_data = df.drop(columns=metadata_columns)
#     species_data += 1e-5
#     geometric_means = gmean(species_data, axis=1)
#     clr_transformed_data = np.log(species_data.divide(geometric_means, axis=0))
#     clr_transformed_df = pd.concat([df[metadata_columns], clr_transformed_data], axis=1)
#     return clr_transformed_df

# metaphlan_df = clr_transform(metaphlan_df, exclude_columns)

In [None]:
# @title Plot Microbial Abundance Profiles after CLR transformation
# Re-draw the same (patient, column) histograms chosen earlier in `combinations`,
# on a grid of the same nrows x ncols size.
# NOTE(review): the CLR transformation cell is commented out in this notebook, so
# metaphlan_df is presumably untransformed here and these panels repeat the
# earlier ones; confirm before interpreting them as "after CLR".
# Create the figure and subplots
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(16, 12))
sns.set(style="whitegrid")

# Iterate through the list of combinations
for i, (patient_id, column_name) in enumerate(combinations):
    # Select the current subplot
    ax = axes.flat[i]

    # Plot the histogram of values
    sns.histplot(data=metaphlan_df[metaphlan_df['patient_id'] == patient_id][column_name], ax=ax, bins = 30,  color='skyblue', kde=True)

    # Set the title of the subplot
    ax.set_title(f'{patient_id} - {column_name}')


# Adjust the spacing between subplots
plt.tight_layout()

# Show the plot
plt.show()

In [None]:
# @title Adding a Flare_start column
# Add the 0/1 'Flare_start' column to both tables.
metaphlan_df = flag_first_flare_weeks(metaphlan_df)
gene_count_df = flag_first_flare_weeks(gene_count_df)
# Rebuild the metadata column list so that it now contains 'Flare_start'.
# `exclude_columns` is a pandas Index, and `Index += [...]` performs element-wise
# addition on the labels instead of appending one, so it is recomputed from the
# updated frame. It stays an Index, which the helpers reading it expect.
exclude_columns = metaphlan_df.columns.difference(species_columns)

In [None]:
# Bar charts of the most represented and most variable features of each table.
# NOTE(review): start_col=5 assumes the first five columns of both tables are
# metadata; confirm against the actual column layout.
analyze_representation_and_variation(metaphlan_df, 5, 'Bacteria')
analyze_representation_and_variation(gene_count_df, 5, 'Genes')

Patient 2102 had a flare from week 16 to 24
Patient 2101 had a flare from week 25 to 43
Patient 2103 had a flare from week 27 to 39

In [None]:
# @title PCOA functions
def perform_pcoa(data, columns_to_exclude, patient_id_column='patient_id', week_column='week', flare_status_column='Flare_status'):
    """
    Projects the feature columns of ``data`` onto two principal components.

    Despite the name, the projection is computed with scikit-learn's ``PCA``
    (two components) on the columns that remain after dropping ``columns_to_exclude``.

    Parameters:
    data (pd.DataFrame): The dataframe containing the data.
    columns_to_exclude (list of str): Columns to drop before the projection.
    patient_id_column (str): The column name for patient ID. Default is 'patient_id'.
    week_column (str): The column name for week. Default is 'week'.
    flare_status_column (str): The column name for flare status. Default is 'Flare_status'.

    Returns:
    pd.DataFrame: Columns 'PC1' and 'PC2', followed by the patient ID, week and
    flare status of each input row (same row order, fresh 0..n-1 index).
    """
    features = data.drop(columns=columns_to_exclude)
    components = PCA(n_components=2).fit_transform(features)

    result = pd.DataFrame(components, columns=['PC1', 'PC2'])
    # Re-attach the identifying columns positionally (.values ignores the index)
    for meta_column in (patient_id_column, week_column, flare_status_column):
        result[meta_column] = data[meta_column].values

    return result


def create_combined_pcoa_plots(microbial_data, gene_expression_data, patient_ids):
    """
    Draws, in one figure, a pair of 3D plots per patient: microbial abundances
    followed by gene expression. The grid has 4 plots per row, i.e. two patients per row.

    Relies on the notebook-level ``exclude_columns`` for the columns dropped
    before the projection of both tables.

    Parameters:
    microbial_data (pd.DataFrame): The dataframe containing microbial abundances.
    gene_expression_data (pd.DataFrame): The dataframe containing gene expression data.
    patient_ids (list of str): List of patient IDs to create plots for.

    Returns:
    None
    """
    n_patients = len(patient_ids)
    n_rows = (n_patients + 1) // 2  # two patients (four plots) per row, rounded up
    fig = plt.figure(figsize=(20, 3*n_rows))

    # (table, panel title) in the left-to-right order of each patient's pair
    panels = (
        (microbial_data, 'Microbial Abundances'),
        (gene_expression_data, 'Gene Expression'),
    )

    for position, patient_id in enumerate(patient_ids):
        # Project both tables for this patient before drawing anything
        projections = [
            perform_pcoa(table[table['patient_id'] == patient_id], exclude_columns)
            for table, _ in panels
        ]

        for slot, (projection, (_, panel_title)) in enumerate(zip(projections, panels), start=1):
            ax = fig.add_subplot(n_rows, 4, 2*position + slot, projection='3d')
            plot_pcoa_3d_subplot(projection, patient_id, panel_title, ax)

    plt.tight_layout()
    plt.show()

def plot_pcoa_3d_subplot(pcoa_df, patient_id, title, ax, flare_status_column='Flare_status', week_column='week'):
    """
    Plots 3D PCoA results on a given subplot, with time on the x axis and the
    two components on the y and z axes, coloured by flare status.

    Parameters:
    pcoa_df (pd.DataFrame): The dataframe containing the PCoA results. Must hold
        the columns 'PC1', 'PC2', `week_column` and `flare_status_column`.
    patient_id (str): The patient ID for the plot title.
    title (str): The title for the plot.
    ax (matplotlib.axes.Axes): The 3D subplot to plot on.
    flare_status_column (str): The column name for flare status. Default is 'Flare_status'.
    week_column (str): The column name for the time axis. Default is 'week'.

    Returns:
    None

    Side effects:
    Adds (or overwrites) a 'Flare_status_color' column on `pcoa_df`.
    """
    flare_status_mapping = {
        'Pre_flare': '#e78ac3',
        'Pre_flare_two': '#e78ac3',
        'During_flare': '#ff7f00',
        'During_flare_2': '#ff7f00',
        'Post_flare': '#8da0cb',
        'Post_flare_2': '#8da0cb',
        'No_flare': '#66c2a5',
        'Before_infusion': 'pink',
        'After_infusion': 'orange'
    }
    # Colour for statuses missing from the mapping (including NaN). Without it the
    # colour column would contain NaN and matplotlib would reject the `c` argument.
    fallback_color = '#b3b3b3'

    pcoa_df['Flare_status_color'] = [
        flare_status_mapping.get(status, fallback_color)
        for status in pcoa_df[flare_status_column]
    ]

    ax.scatter(pcoa_df[week_column], pcoa_df['PC1'], pcoa_df['PC2'], c=pcoa_df['Flare_status_color'], s=20, alpha=0.7)

    ax.set_xlabel('Week')
    ax.set_ylabel('PC1')
    ax.set_zlabel('PC2')
    ax.set_title(f'Patient {patient_id} - {title}')

    # Legend lists only the statuses actually present in this panel; statuses that
    # are not in the mapping are shown with the fallback colour they were drawn in.
    present_statuses = pcoa_df[flare_status_column].dropna().unique()
    legend_handles = [
        mpatches.Patch(color=flare_status_mapping.get(status, fallback_color), label=status)
        for status in present_statuses
    ]
    ax.legend(handles=legend_handles, title='Flare Status', bbox_to_anchor=(1.05, 1), loc='upper left')



In [None]:
# Every metaphlan_df column that is not listed in species_columns.
# create_combined_pcoa_plots reads this variable as a global and passes it to perform_pcoa.
# NOTE(review): species_columns must already exist when this cell runs — confirm it is
# defined in an earlier cell.
exclude_columns = metaphlan_df.columns.difference(species_columns)

In [None]:
# Draw the paired PCoA panels for every patient ID found in metaphlan_df.
# Requires the global `exclude_columns` (read inside create_combined_pcoa_plots).
patient_ids = metaphlan_df['patient_id'].unique()
create_combined_pcoa_plots(metaphlan_df, gene_count_df, patient_ids)

In [None]:

# Eighteen-colour palette used as the default line/marker colour cycle
extended_colors = [
    '#66c2a5',  # Teal
    '#ff7f00',  # Orange
    '#8da0cb',  # Light Blue
    '#e78ac3',  # Pink
    '#a6d854',  # Light Green
    '#e5c494',  # Tan
    '#b3b3b3',  # Gray
    '#80b1d3',  # Sky Blue
    '#bebada',  # Lavender
    '#fb8072',  # Light Red
    '#7fc97f',  # Medium Green
    '#d9d9d9',  # Light Gray
    '#ffed6f',  # Pale Yellow
    '#b15928',  # Brown
    '#fccde5',  # Pale Pink
    '#8dd3c7',  # Pale Teal
    '#ffffb3',  # Very Pale Yellow
    '#bc80bd'   # Purple
]

# Apply all matplotlib defaults in one call
plt.rcParams.update({
    # Default colour cycle
    'axes.prop_cycle': plt.cycler(color=extended_colors),
    # Background and grid appearance
    #'figure.facecolor': '#f0f0f0',
    'axes.facecolor': 'white',
    'grid.color': '#cccccc',
    'grid.linestyle': '--',
    'grid.alpha': 0.5,
    # Grid is off by default; when a plot enables it, use major ticks on both axes
    'axes.grid': False,
    'axes.grid.which': 'major',
    'axes.grid.axis': 'both',
    # Default line width
    'lines.linewidth': 1,
})

In [None]:
# Microbial abundance over time, 'mean' mode. None is passed as the column selection
# (presumably "use the function's default set" — confirm against plot_species_abundance_over_time).
plot_species_abundance_over_time(metaphlan_df, None, 'Microbial Abundance', 'Bacteria', 'mean')

In [None]:
# Same plot as the 'mean' call on metaphlan_df, but in 'variation' mode.
plot_species_abundance_over_time(metaphlan_df, None, 'Microbial Abundance', 'Bacteria', 'variation')

In [None]:
# Gene expression over time, 'mean' mode. The species plotting helper is reused on gene_count_df.
plot_species_abundance_over_time(gene_count_df, None, 'Gene Expression', 'Genes', 'mean')

In [None]:
# Gene expression over time, 'variation' mode.
plot_species_abundance_over_time(gene_count_df, None, 'Gene Expression', 'Genes', 'variation')

In [None]:
# Candidate taxa: every listed Blautia species plus a hand-picked set of other species.
blautia_species = [
    's__Blautia_SGB4805', 's__Blautia_SGB4815', 's__Blautia_SGB4831',
    's__Blautia_argi', 's__Blautia_caecimuris', 's__Blautia_faecicola',
    's__Blautia_faecis', 's__Blautia_glucerasea', 's__Blautia_hansenii',
    's__Blautia_hydrogenotrophica', 's__Blautia_massiliensis',
    's__Blautia_obeum', 's__Blautia_producta', 's__Blautia_sp_AF19_10LB',
    's__Blautia_sp_An249', 's__Blautia_sp_An81', 's__Blautia_sp_MSK_20_85',
    's__Blautia_sp_MSK_21_1', 's__Blautia_sp_OF03_15BH',
    's__Blautia_stercoris', 's__Blautia_wexlerae',
]
other_species = [
    's__Phocaeicola_dorei', 's__Bacteroides_uniformis', 's__Faecalibacterium_prausnitzii',
    's__Bifidobacterium_pullorum', 's__Bacteroides_ovatus', 's__Clostridia_bacterium',
    's__Eubacterium_rectale', 's__Blautia_sp_MSK_21_1', 's__GGB4569_SGB6310',
    's__Blautia_wexlerae',
]

# Alphabetical list without repeats (two Blautia entries occur in both groups)
x = sorted(set(blautia_species + other_species))
x

In [None]:
# @title Remove dominant species
# # Step 1: Compute mean relative abundances for each species per patient
# species_columns = [col for col in metaphlan_df.columns if col.startswith('s__')]
# mean_abundances = metaphlan_df.groupby('patient_id')[species_columns].mean()

# # Step 2: Identify species to zero out
# species_to_zero = mean_abundances.apply(lambda x: x > 0.07)

# # Step 3: Set values to zero for identified species
# for patient in species_to_zero.index:
#     for species in species_columns:
#         if species_to_zero.at[patient, species]:
#             metaphlan_df.loc[metaphlan_df['patient_id'] == patient, species] = 0

# # Step 4: Normalize the DataFrame so that each row sums to 1
# metaphlan_df[species_columns] = metaphlan_df[species_columns].div(metaphlan_df[species_columns].sum(axis=1), axis=0)

Microbes to investigate

In [None]:
# @title assign_rbf
def assign_rbf(df, samples_before_flare, strict=False):
    """
    Return a copy of `df`, sorted by patient and week, with an 'RBF' column that
    is 1 on the samples preceding a flare start and 0 elsewhere.

    Parameters:
    df (pd.DataFrame): Must contain 'patient_id', 'week' and 'Flare_start' columns.
        Rows with Flare_start == 1 mark the weeks in which a flare begins.
    samples_before_flare (int): With strict=False, the number of most recent earlier
        weeks to flag per flare. With strict=True, the width of the week window
        [flare_week - samples_before_flare, flare_week) to flag.
    strict (bool): Selects between the two rules above. Default is False.

    Returns:
    pd.DataFrame: The sorted frame with the added 'RBF' column; the input is not modified.
    """
    df = df.sort_values(by=['patient_id', 'week'])
    df['RBF'] = 0

    # One list of flare-start weeks per patient that has at least one flare
    flare_weeks_by_patient = (
        df[df['Flare_start'] == 1].groupby('patient_id')['week'].apply(list)
    )

    for patient, flare_weeks in flare_weeks_by_patient.items():
        is_patient = df['patient_id'] == patient

        for flare_week in flare_weeks:
            earlier_weeks = df.loc[is_patient, 'week']
            earlier_weeks = earlier_weeks[earlier_weeks < flare_week]

            if strict:
                # Only weeks that fall inside the fixed-width window before the flare
                flagged_weeks = earlier_weeks[earlier_weeks >= flare_week - samples_before_flare]
            else:
                # The latest N earlier weeks, however far back they are
                flagged_weeks = earlier_weeks.nlargest(samples_before_flare)

            df.loc[is_patient & df['week'].isin(flagged_weeks), 'RBF'] = 1

    return df

In [None]:
# Flag (RBF = 1) up to 6 of the most recent samples before each flare start, per patient.
# Both frames are replaced by the sorted copies that assign_rbf returns.
metaphlan_df = assign_rbf(metaphlan_df, 6, strict=False)
gene_count_df = assign_rbf(gene_count_df, 6, strict=False)

In [None]:
# @title Segment time series
# Run segment_time_series on both dataframes and overwrite them with the result.
# NOTE(review): the meaning of the second argument (6) is defined by segment_time_series,
# which lives elsewhere in the notebook — confirm there.
metaphlan_df = segment_time_series(metaphlan_df, 6)
gene_count_df = segment_time_series(gene_count_df, 6)

# Check the results: print the identifier, flare and segment columns of each frame
# ('ts' and 'ts_patient_id' are expected to exist after segmentation).
print(metaphlan_df[['patient_id', 'week', 'Flare_status', 'Flare_start', 'ts', 'ts_patient_id', 'RBF']])
print(gene_count_df[['patient_id', 'week', 'Flare_status', 'Flare_start', 'ts', 'ts_patient_id','RBF']])

In [None]:
# Run the time-series analysis on the species table; its printed per-segment counts are
# pasted into `output` in the next cell.
# NOTE(review): 0.1 and "bh" are presumably the significance level and the
# Benjamini-Hochberg correction — confirm against analyze_time_series.
analyze_time_series(metaphlan_df, 0.1, "bh")

In [None]:
# @title copy paste output here to process it
import re

output = """
Patient ID: TR_2101_TS2
Number of variables with significant adjusted p-values: 44
Patient ID: TR_2101_TS3
Number of variables with significant adjusted p-values: 61
Patient ID: TR_2101_TS4
Number of variables with significant adjusted p-values: 52
Patient ID: TR_2101_TS5
Number of variables with significant adjusted p-values: 51
Patient ID: TR_2101_TS6
Number of variables with significant adjusted p-values: 40
Patient ID: TR_2101_TS7
Number of variables with significant adjusted p-values: 33
Patient ID: TR_2101_TS8
Number of variables with significant adjusted p-values: 50
Patient ID: TR_2102_TS1
Number of variables with significant adjusted p-values: 56
Patient ID: TR_2102_TS2
Number of variables with significant adjusted p-values: 58
Patient ID: TR_2102_TS3
Number of variables with significant adjusted p-values: 69
Patient ID: TR_2102_TS4
Number of variables with significant adjusted p-values: 62
Patient ID: TR_2102_TS5
Number of variables with significant adjusted p-values: 75
Patient ID: TR_2102_TS6
Number of variables with significant adjusted p-values: 37
Patient ID: TR_2102_TS7
Number of variables with significant adjusted p-values: 62
Patient ID: TR_2103_TS1
Number of variables with significant adjusted p-values: 83
Patient ID: TR_2103_TS2
Number of variables with significant adjusted p-values: 90
Patient ID: TR_2103_TS3
Number of variables with significant adjusted p-values: 48
Patient ID: TR_2103_TS4
Number of variables with significant adjusted p-values: 45
Patient ID: TR_2103_TS5
Number of variables with significant adjusted p-values: 26
Patient ID: TR_2103_TS6
Number of variables with significant adjusted p-values: 49
Patient ID: TR_2103_TS7
Number of variables with significant adjusted p-values: 42
Patient ID: TR_2104_TS1
Number of variables with significant adjusted p-values: 55
Patient ID: TR_2104_TS2
Number of variables with significant adjusted p-values: 127
Patient ID: TR_2104_TS3
Number of variables with significant adjusted p-values: 127
Patient ID: TR_2104_TS4
Number of variables with significant adjusted p-values: 75
Patient ID: TR_2104_TS5
Number of variables with significant adjusted p-values: 81
Patient ID: TR_2104_TS6
Number of variables with significant adjusted p-values: 62
Patient ID: TR_2104_TS7
Number of variables with significant adjusted p-values: 75
Patient ID: TR_2104_TS8
Number of variables with significant adjusted p-values: 63
Patient ID: TR_2105_TS1
Number of variables with significant adjusted p-values: 61
Patient ID: TR_2105_TS2
Number of variables with significant adjusted p-values: 50
Patient ID: TR_2105_TS3
Number of variables with significant adjusted p-values: 75
Patient ID: TR_2105_TS4
Number of variables with significant adjusted p-values: 54
Patient ID: TR_2105_TS5
Number of variables with significant adjusted p-values: 53
Patient ID: TR_2105_TS6
Number of variables with significant adjusted p-values: 102
Patient ID: TR_2105_TS7
Number of variables with significant adjusted p-values: 83
Patient ID: TR_2105_TS8
Number of variables with significant adjusted p-values: 71
Patient ID: TR_2106_TS1
Number of variables with significant adjusted p-values: 40
Patient ID: TR_2106_TS2
Number of variables with significant adjusted p-values: 58
Patient ID: TR_2106_TS3
Number of variables with significant adjusted p-values: 57
Patient ID: TR_2106_TS4
Number of variables with significant adjusted p-values: 49
Patient ID: TR_2106_TS5
Number of variables with significant adjusted p-values: 57
Patient ID: TR_2106_TS6
Number of variables with significant adjusted p-values: 47
Patient ID: TR_2106_TS7
Number of variables with significant adjusted p-values: 83
Patient ID: TR_2106_TS8
Number of variables with significant adjusted p-values: 47
Patient ID: TR_2107_TS2
Number of variables with significant adjusted p-values: 66
Patient ID: TR_2107_TS3
Number of variables with significant adjusted p-values: 79
Patient ID: TR_2107_TS4
Number of variables with significant adjusted p-values: 46
Patient ID: TR_2107_TS5
Number of variables with significant adjusted p-values: 63
Patient ID: TR_2107_TS6
Number of variables with significant adjusted p-values: 83
Patient ID: TR_2107_TS7
Number of variables with significant adjusted p-values: 77
Patient ID: TR_2108_TS1
Number of variables with significant adjusted p-values: 44
Patient ID: TR_2108_TS2
Number of variables with significant adjusted p-values: 48
Patient ID: TR_2108_TS3
Number of variables with significant adjusted p-values: 52
Patient ID: TR_2108_TS4
Number of variables with significant adjusted p-values: 64
Patient ID: TR_2201_TS1
Number of variables with significant adjusted p-values: 56
Patient ID: TR_2201_TS2
Number of variables with significant adjusted p-values: 69
Patient ID: TR_2201_TS3
Number of variables with significant adjusted p-values: 74
Patient ID: TR_2201_TS4
Number of variables with significant adjusted p-values: 73
Patient ID: TR_2201_TS5
Number of variables with significant adjusted p-values: 49
Patient ID: TR_2201_TS6
Number of variables with significant adjusted p-values: 49
Patient ID: TR_2201_TS7
Number of variables with significant adjusted p-values: 69
Patient ID: TR_2202_TS1
Number of variables with significant adjusted p-values: 39
Patient ID: TR_2202_TS2
Number of variables with significant adjusted p-values: 36
Patient ID: TR_2202_TS3
Number of variables with significant adjusted p-values: 22
Patient ID: TR_2202_TS4
Number of variables with significant adjusted p-values: 40
Patient ID: TR_2202_TS5
Number of variables with significant adjusted p-values: 51
Patient ID: TR_2202_TS6
Number of variables with significant adjusted p-values: 38
Patient ID: TR_2202_TS7
Number of variables with significant adjusted p-values: 31
Patient ID: TR_2202_TS8
Number of variables with significant adjusted p-values: 36
Patient ID: TR_2203_TS1
Number of variables with significant adjusted p-values: 33
Patient ID: TR_2203_TS2
Number of variables with significant adjusted p-values: 49
Patient ID: TR_2203_TS3
Number of variables with significant adjusted p-values: 58
Patient ID: TR_2203_TS4
Number of variables with significant adjusted p-values: 27
Patient ID: TR_2205_TS1
Number of variables with significant adjusted p-values: 99
"""


In [None]:
# Extract (patient number, segment number, significant-variable count) from the
# log text stored in `output`. Each "Patient ID" line is matched together with
# the "Number of variables" line that follows it.
pattern = re.compile(r"Patient ID: TR_(\d+)_TS(\d+)\s+Number of variables with significant adjusted p-values: (\d+)")
matches = pattern.findall(output)

# Every match is a 3-tuple of digit strings; unpack into three parallel integer lists
patients = [int(patient_number) for patient_number, _, _ in matches]
ts = [int(segment_number) for _, segment_number, _ in matches]
significant_p_values = [int(count) for _, _, count in matches]

In [None]:
# Segments (ts_patient_id values) of metaphlan_df that contain at least one RBF-flagged row
highlight_set = set(metaphlan_df[metaphlan_df["RBF"]==True]['ts_patient_id'].unique())

# Segment label for every parsed entry, in the same 'TR_<patient>_TS<segment>' form
combinations = [f'TR_{p}_TS{t}' for p, t in zip(patients, ts)]

unique_patients = sorted(set(patients))

# Grid layout: three panels per row, as many rows as needed
n = len(unique_patients)
cols = 3
rows = math.ceil(n / cols)

# squeeze=False keeps `axes` two-dimensional even when there is only one row
fig, axes = plt.subplots(rows, cols, figsize=(15, 5*rows), squeeze=False)
panels = axes.ravel()  # row-major: panel i sits at row i // cols, column i % cols

for ax, patient in zip(panels, unique_patients):
    # Entries belonging to this patient, kept in their parsed order
    patient_points = [
        (t, count, comb)
        for p, t, count, comb in zip(patients, ts, significant_p_values, combinations)
        if p == patient
    ]
    patient_ts = [t for t, _, _ in patient_points]
    patient_significant_p_values = [count for _, count, _ in patient_points]

    # Count per segment as a line
    ax.plot(patient_ts, patient_significant_p_values, label=f'Patient {patient}')

    # Red markers on the segments listed in highlight_set
    for t, sp, comb in patient_points:
        if comb in highlight_set:
            ax.plot(t, sp, 'ro')

    ax.set_xlabel('Batch')
    ax.set_ylabel('Number of non stationnary species')
    ax.set_title(f'Patient {patient}')
    ax.grid(True, linestyle='--', alpha=0.7)

# Drop the unused panels at the end of the grid
for ax in panels[n:]:
    fig.delaxes(ax)

plt.tight_layout()
plt.subplots_adjust(top=0.95)
plt.show()


In [None]:
# Same analysis as for the species table, now on gene expression; the printed
# per-segment counts are pasted into `output` in the next cell.
analyze_time_series(gene_count_df, 0.1, "bh")

In [None]:
# @title copy paste output here

output = """
Length: 97, dtype: int64
Patient ID: TR_2101_TS2
Number of variables with significant adjusted p-values: 162
Patient ID: TR_2101_TS3
Number of variables with significant adjusted p-values: 119
Patient ID: TR_2101_TS4
Number of variables with significant adjusted p-values: 120
Patient ID: TR_2101_TS5
Number of variables with significant adjusted p-values: 125
Patient ID: TR_2101_TS6
Number of variables with significant adjusted p-values: 147
Patient ID: TR_2101_TS7
Number of variables with significant adjusted p-values: 66
Patient ID: TR_2101_TS8
Number of variables with significant adjusted p-values: 98
Patient ID: TR_2102_TS1
Number of variables with significant adjusted p-values: 173
Patient ID: TR_2102_TS2
Number of variables with significant adjusted p-values: 124
Patient ID: TR_2102_TS3
Number of variables with significant adjusted p-values: 102
Patient ID: TR_2102_TS4
Number of variables with significant adjusted p-values: 127
Patient ID: TR_2102_TS5
Number of variables with significant adjusted p-values: 133
Patient ID: TR_2102_TS6
Number of variables with significant adjusted p-values: 100
Patient ID: TR_2102_TS7
Number of variables with significant adjusted p-values: 46
Patient ID: TR_2103_TS2
Number of variables with significant adjusted p-values: 107
Patient ID: TR_2103_TS3
Number of variables with significant adjusted p-values: 107
Patient ID: TR_2103_TS4
Number of variables with significant adjusted p-values: 91
Patient ID: TR_2103_TS5
Number of variables with significant adjusted p-values: 165
Patient ID: TR_2103_TS6
Number of variables with significant adjusted p-values: 126
Patient ID: TR_2103_TS7
Number of variables with significant adjusted p-values: 135
Patient ID: TR_2103_TS8
Number of variables with significant adjusted p-values: 165
Patient ID: TR_2104_TS1
Number of variables with significant adjusted p-values: 73
Patient ID: TR_2104_TS2
Number of variables with significant adjusted p-values: 150
Patient ID: TR_2104_TS3
Number of variables with significant adjusted p-values: 78
Patient ID: TR_2104_TS4
Number of variables with significant adjusted p-values: 101
Patient ID: TR_2104_TS5
Number of variables with significant adjusted p-values: 119
Patient ID: TR_2104_TS6
Number of variables with significant adjusted p-values: 147
Patient ID: TR_2104_TS7
Number of variables with significant adjusted p-values: 122
Patient ID: TR_2104_TS8
Number of variables with significant adjusted p-values: 142
Patient ID: TR_2105_TS1
Number of variables with significant adjusted p-values: 84
Patient ID: TR_2105_TS2
Number of variables with significant adjusted p-values: 111
Patient ID: TR_2105_TS3
Number of variables with significant adjusted p-values: 126
Patient ID: TR_2105_TS4
Number of variables with significant adjusted p-values: 89
Patient ID: TR_2105_TS5
Number of variables with significant adjusted p-values: 99
Patient ID: TR_2105_TS6
Number of variables with significant adjusted p-values: 115
Patient ID: TR_2105_TS7
Number of variables with significant adjusted p-values: 150
Patient ID: TR_2105_TS8
Number of variables with significant adjusted p-values: 103
Patient ID: TR_2106_TS1
Number of variables with significant adjusted p-values: 165
Patient ID: TR_2106_TS2
Number of variables with significant adjusted p-values: 165
Patient ID: TR_2106_TS3
Number of variables with significant adjusted p-values: 152
Patient ID: TR_2106_TS4
Number of variables with significant adjusted p-values: 173
Patient ID: TR_2106_TS5
Number of variables with significant adjusted p-values: 115
Patient ID: TR_2106_TS6
Number of variables with significant adjusted p-values: 145
Patient ID: TR_2106_TS7
Number of variables with significant adjusted p-values: 126
Patient ID: TR_2106_TS8
Number of variables with significant adjusted p-values: 161
Patient ID: TR_2107_TS2
Number of variables with significant adjusted p-values: 182
Patient ID: TR_2107_TS3
Number of variables with significant adjusted p-values: 180
Patient ID: TR_2107_TS4
Number of variables with significant adjusted p-values: 103
Patient ID: TR_2107_TS5
Number of variables with significant adjusted p-values: 108
Patient ID: TR_2107_TS6
Number of variables with significant adjusted p-values: 130
Patient ID: TR_2107_TS7
Number of variables with significant adjusted p-values: 145
Patient ID: TR_2107_TS8
Number of variables with significant adjusted p-values: 173
Patient ID: TR_2108_TS1
Number of variables with significant adjusted p-values: 159
Patient ID: TR_2108_TS2
Number of variables with significant adjusted p-values: 196
Patient ID: TR_2108_TS3
Number of variables with significant adjusted p-values: 161
Patient ID: TR_2108_TS4
Number of variables with significant adjusted p-values: 128
Patient ID: TR_2108_TS5
Number of variables with significant adjusted p-values: 177
Patient ID: TR_2108_TS6
Number of variables with significant adjusted p-values: 98
Patient ID: TR_2201_TS1
Number of variables with significant adjusted p-values: 94
Patient ID: TR_2201_TS2
Number of variables with significant adjusted p-values: 136
Patient ID: TR_2201_TS3
Number of variables with significant adjusted p-values: 142
Patient ID: TR_2201_TS4
Number of variables with significant adjusted p-values: 140
Patient ID: TR_2201_TS5
Number of variables with significant adjusted p-values: 145
Patient ID: TR_2201_TS6
Number of variables with significant adjusted p-values: 76
Patient ID: TR_2201_TS7
Number of variables with significant adjusted p-values: 111
Patient ID: TR_2201_TS8
Number of variables with significant adjusted p-values: 61
Patient ID: TR_2202_TS1
Number of variables with significant adjusted p-values: 108
Patient ID: TR_2202_TS2
Number of variables with significant adjusted p-values: 104
Patient ID: TR_2202_TS3
Number of variables with significant adjusted p-values: 94
Patient ID: TR_2202_TS4
Number of variables with significant adjusted p-values: 138
Patient ID: TR_2202_TS5
Number of variables with significant adjusted p-values: 101
Patient ID: TR_2202_TS6
Number of variables with significant adjusted p-values: 145
Patient ID: TR_2202_TS7
Number of variables with significant adjusted p-values: 93
Patient ID: TR_2202_TS8
Number of variables with significant adjusted p-values: 94
Patient ID: TR_2203_TS1
Number of variables with significant adjusted p-values: 56
Patient ID: TR_2203_TS2
Number of variables with significant adjusted p-values: 136
Patient ID: TR_2203_TS3
Number of variables with significant adjusted p-values: 100
Patient ID: TR_2203_TS4
Number of variables with significant adjusted p-values: 94
Patient ID: TR_2203_TS5
Number of variables with significant adjusted p-values: 89
Patient ID: TR_2203_TS6
Number of variables with significant adjusted p-values: 102
Patient ID: TR_2203_TS7
Number of variables with significant adjusted p-values: 82
Patient ID: TR_2203_TS8
Number of variables with significant adjusted p-values: 72
Patient ID: TR_2204_TS1
Number of variables with significant adjusted p-values: 196
Patient ID: TR_2204_TS2
Number of variables with significant adjusted p-values: 104
Patient ID: TR_2204_TS3
Number of variables with significant adjusted p-values: 116
Patient ID: TR_2204_TS4
Number of variables with significant adjusted p-values: 124
Patient ID: TR_2204_TS5
Number of variables with significant adjusted p-values: 167
Patient ID: TR_2204_TS6
Number of variables with significant adjusted p-values: 143
Patient ID: TR_2204_TS7
Number of variables with significant adjusted p-values: 149
Patient ID: TR_2205_TS1
Number of variables with significant adjusted p-values: 84
Patient ID: TR_2205_TS2
Number of variables with significant adjusted p-values: 140
Patient ID: TR_2205_TS3
Number of variables with significant adjusted p-values: 165
Patient ID: TR_2205_TS4
Number of variables with significant adjusted p-values: 149
Patient ID: TR_2205_TS5
Number of variables with significant adjusted p-values: 138
Patient ID: TR_2205_TS6
Number of variables with significant adjusted p-values: 159
Patient ID: TR_2205_TS7
Number of variables with significant adjusted p-values: 129
Patient ID: TR_2205_TS8
Number of variables with significant adjusted p-values: 213
"""



In [None]:
# Extract (patient number, segment number, significant-variable count) from the
# log text stored in `output`. Lines that do not fit the two-line pattern are ignored.
pattern = re.compile(r"Patient ID: TR_(\d+)_TS(\d+)\s+Number of variables with significant adjusted p-values: (\d+)")
matches = pattern.findall(output)

# Every match is a 3-tuple of digit strings; unpack into three parallel integer lists
patients = [int(patient_number) for patient_number, _, _ in matches]
ts = [int(segment_number) for _, segment_number, _ in matches]
significant_p_values = [int(count) for _, _, count in matches]

In [None]:
# Segments (ts_patient_id values) that contain at least one RBF-flagged sample.
# The counts plotted in this cell are the gene-expression results, so the
# highlighted segments are taken from gene_count_df: its segments are not the
# same set as metaphlan_df's (the gene log contains segments the species log lacks).
highlight_set = set(gene_count_df[gene_count_df["RBF"]==True]['ts_patient_id'].unique())

# Generate combinations of Patient and TS, matching the 'ts_patient_id' format
combinations = ['TR_' + str(p) + '_TS' + str(t) for p, t in zip(patients, ts)]

# Get unique patients
unique_patients = sorted(set(patients))

# Calculate the number of rows and columns for subplots
n = len(unique_patients)
cols = 3
rows = math.ceil(n / cols)

# Create the plot. squeeze=False keeps `axes` two-dimensional for any row count,
# so the same indexing works whether there is one row or several.
fig, axes = plt.subplots(rows, cols, figsize=(15, 5*rows), squeeze=False)

for i, patient in enumerate(unique_patients):
    ax = axes[i // cols, i % cols]

    # Filter data for the current patient
    patient_indices = [j for j, p in enumerate(patients) if p == patient]
    patient_ts = [ts[j] for j in patient_indices]
    patient_significant_p_values = [significant_p_values[j] for j in patient_indices]
    patient_combinations = [combinations[j] for j in patient_indices]

    # Plot the data
    ax.plot(patient_ts, patient_significant_p_values, label=f'Patient {patient}')

    # Highlight the segments that contain pre-flare (RBF) samples in red
    for t, sp, comb in zip(patient_ts, patient_significant_p_values, patient_combinations):
        if comb in highlight_set:
            ax.plot(t, sp, 'ro')

    ax.set_xlabel('Batch')
    ax.set_ylabel('Number of non stationnary genes')
    ax.set_title(f'Patient {patient}')
    ax.grid(True, linestyle='--', alpha=0.7)

# Remove empty subplots
for i in range(n, rows*cols):
    fig.delaxes(axes[i // cols, i % cols])

plt.tight_layout()
plt.subplots_adjust(top=0.95)
plt.show()


In [None]:
# @title shifting the week numbers so that the three patients have the beginning of their flare at the same time, which makes it easier to compare visually
# Both frames are overwritten with the output of align_flare_start_week.
# NOTE(review): re-running this cell feeds already-aligned data back in — confirm
# that align_flare_start_week is safe to apply twice.
metaphlan_df = align_flare_start_week(metaphlan_df)
gene_count_df = align_flare_start_week(gene_count_df)

In [None]:
species_signif_in_adf = ["s__Bacteroides_ovatus", "s__Barnesiella_intestinihominis", "s__Alistipes_onderdonkii", "s__Clostridiaceae_bacterium", "s__Anaerostipes_hadrus", "s__Blautia_wexlerae", "s__Dorea_formicigenerans", "s__Dorea_longicatena", "s__Eubacterium_rectale", "s__Lachnospiraceae_bacterium", "s__Roseburia_inulinivorans", "s__Flavonifractor_plautii", "s__Ruminococcus_bicirculans", "s__Faecalibacillus_intestinalis", "s__Bacteroides_uniformis", "s__Parabacteroides_distasonis", "s__Agathobaculum_butyriciproducens", "s__Faecalibacterium_prausnitzii"]


In [None]:
len(species_signif_in_adf)

In [None]:
species_signif_in_adf[9:17]

In [None]:
# Plot every species in species_signif_in_adf. The original cell had an empty
# subscript (`species_signif_in_adf[]`), which is a SyntaxError; pass a slice such as
# species_signif_in_adf[9:17] instead to plot the list in smaller batches.
plot_species_abundance_over_time_2(metaphlan_df, cols_to_plot=species_signif_in_adf, title='Abundance', type_='Species')

In [None]:
genes_signif_in_adf = ["ACVR1B_ENSG00000135503_ENST00000257963", "APAF1_ENSG00000120868_ENST00000551964", "BMP2_ENSG00000125845_ENST00000378827", "CASP3_ENSG00000164305_ENST00000308394", "CCL20_ENSG00000115009_ENST00000358813", "CLDN8_ENSG00000156284_ENST00000399899", "CLEC7A_ENSG00000172243_ENST00000304084", "CXCL1_ENSG00000163739_ENST00000395761", "FCGBP_ENSG00000275395_ENST00000616721", "FFAR2_ENSG00000126262_ENST00000599180", "HMGB1_ENSG00000189403_ENST00000341423", "ITGB1_ENSG00000150093_ENST00000437302", "MMP12_ENSG00000262406_ENST00000571244", "NR1D2_ENSG00000174738_ENST00000312521", "REL_ENSG00000162924_ENST00000394479", "RIPK2_ENSG00000104312_ENST00000220751", "S100A9_ENSG00000163220_ENST00000368738", "SELENBP1_ENSG00000143416_ENST00000368868", "TNFRSF25_ENSG00000215788_ENST00000356876", "XBP1_ENSG00000100219_ENST00000344347", "ATG5_ENSG00000057663_ENST00000369076", "CD68_ENSG00000129226_ENST00000250092", "CXCR5_ENSG00000160683_ENST00000292174", "DEFB1_ENSG00000164825_ENST00000297439", "IL4R_ENSG00000077238_ENST00000395762", "IRAG2_ENSG00000118308_ENST00000636465", "NFKB1_ENSG00000109320_ENST00000226574", "PIK3CA_ENSG00000121879_ENST00000263967", "TAGAP_ENSG00000164691_ENST00000367066", "BCAR1_ENSG00000050820_ENST00000162330", "CD44_ENSG00000026508_ENST00000428726", "IL1R2_ENSG00000115590_ENST00000332549", "NCF2_ENSG00000116701_ENST00000367535", "PTEN_ENSG00000171862_ENST00000371953", "SATB2_ENSG00000119042_ENST00000417098", "TNFRSF14_ENSG00000157873_ENST00000355716", "NOD2_ENSG00000167207_ENST00000647318"]


In [None]:
len(genes_signif_in_adf)

In [None]:
# Total number of significant variables: species plus genes (18 + 37 with the
# current lists). Computed from the lists so it stays correct if they are edited.
len(species_signif_in_adf) + len(genes_signif_in_adf)

In [None]:
# @title Genes to investigate
# Expression over time for every gene listed in genes_signif_in_adf.
plot_species_abundance_over_time_2(gene_count_df,   genes_signif_in_adf, 'Expression', 'Genes')

In [None]:
# @title Genes to investigate
# Expression over time for three hand-picked transcripts (IL17A, MKI67, PRLR).
plot_species_abundance_over_time_2(gene_count_df,   ['IL17A_ENSG00000112115_ENST00000648244', 'MKI67_ENSG00000148773_ENST00000368654', 'PRLR_ENSG00000113494_ENST00000618457'], 'Expression', 'Genes')

In [None]:
# Abundance over time for a single species: Flavonifractor plautii.
plot_species_abundance_over_time_2(metaphlan_df, cols_to_plot=['s__Flavonifractor_plautii'], title='Abundance', type_='Species')

In [None]:
# Abundance over time for three hand-picked species.
plot_species_abundance_over_time_2(metaphlan_df, cols_to_plot=['s__Barnesiella_sp_An22','s__Gordonibacter_pamelaeae','s__Firmicutes_bacterium_AF16_15'], title='Abundance', type_='Species')

In [None]:
# Two species plotted with the first plotting helper; no mode argument is passed here,
# so the helper's default applies (earlier calls pass 'mean' or 'variation' explicitly).
plot_species_abundance_over_time(metaphlan_df, ['s__Flavonifractor_plautii', 's__Blautia_wexlerae'], 'Microbial Abundance', 'Taxa')

In [None]:
# @title Aggregate species

# Metadata columns: carried through unchanged and placed first in the result.
# Later cells also use this list to tell metadata columns from measurement columns.
exclude_columns  = [ 'PatientID_Weeknr',
 'patient_id',
 'week',
 'Flare_status',
 'Flare_start',
 'ts',
 'RBF',
 'ts_patient_id']

# Map each species column to its genus. 's__Blautia_wexlerae'.split('_') gives
# ['s', '', 'Blautia', 'wexlerae'], so element 2 is the genus name.
metaphlan_agg_col = {col: col.split('_')[2] for col in metaphlan_df.columns if 's__' in col}
species_cols = list(metaphlan_agg_col)

# Sum the species of each genus using ONLY the species columns. Transposing the whole
# frame (metadata included) would turn every column into `object` dtype and push the
# metadata through `sum` as well; restricting to the numeric block keeps the dtypes.
genus_abundance = metaphlan_df[species_cols].T.groupby(metaphlan_agg_col).sum().T

# Re-attach the untouched non-species columns next to the genus totals
other_columns = [col for col in metaphlan_df.columns if col not in metaphlan_agg_col]
metaphlan_agg = pd.concat([metaphlan_df[other_columns], genus_abundance], axis=1)

# Metadata first, then every remaining column in alphabetical order
new_order = exclude_columns + sorted(col for col in metaphlan_agg.columns if col not in exclude_columns)
metaphlan_agg = metaphlan_agg[new_order]

# The DataFrame now has the microbial data aggregated by genus
print(metaphlan_agg.head())


In [None]:
# Run the representation/variation analysis on the genus-level table.
# NOTE(review): the meaning of the second argument (8) is set by
# analyze_representation_and_variation, defined elsewhere — confirm there.
analyze_representation_and_variation(metaphlan_agg, 8, 'Taxa')

In [None]:
# Genus-level abundance over time for four selected columns of metaphlan_agg.
plot_species_abundance_over_time(metaphlan_agg, ['Bacteroides', 'Christensenellaceae', 'Gemella', 'Sutterella'], 'Microbial Abundance', 'Taxa')

In [None]:
# Expression over time for a single transcript (XCR1).
plot_species_abundance_over_time_2(gene_count_df,   ['XCR1_ENSG00000173578_ENST00000683768'], 'Expression', 'Genes')

In [None]:
metaphlan_df.columns

In [None]:
gene_count_df.columns



# Smoothed Profiles Analysis



In [None]:
# Measurement columns = every column that is not listed in exclude_columns.
# NOTE(review): exclude_columns is reassigned in several cells; this cell assumes it
# holds the metadata column list at this point — confirm the run order.
species_columns = [col for col in metaphlan_df.columns if col not in exclude_columns]
gene_columns = [col for col in gene_count_df.columns if col not in exclude_columns]

In [None]:
# Smoothed copy of the species table; the smoothing itself is defined in
# smooth_patient_data (elsewhere in the notebook) and is applied to species_columns.
smoothed_metaphlan_df = smooth_patient_data(metaphlan_df, species_columns)

In [None]:
# Abundance over time for ten selected species, read from the smoothed table.
plot_species_abundance_over_time_2(smoothed_metaphlan_df , cols_to_plot=['s__Phocaeicola_dorei','s__Bacteroides_uniformis','s__Faecalibacterium_prausnitzii' , 's__Bifidobacterium_pullorum',
  's__Bacteroides_ovatus', 's__Clostridia_bacterium', 's__Eubacterium_rectale', 's__Blautia_sp_MSK_21_1',
  's__GGB4569_SGB6310', 's__Blautia_wexlerae'], title='Abundance', type_='Species')