In [3]:
pip install ruptures



In [4]:
# @title Data Reading
import pandas as pd
import random
import seaborn as sns
import numpy as np
import ruptures as rpt
import time
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
from sklearn.metrics import silhouette_score
from scipy.stats import gmean






gene_count_path ='/content/drive/MyDrive/MITResearch/IBD_13/gene_count_preprocessed.csv'
metaphlan_path = '/content/drive/MyDrive/MITResearch/IBD_13/metaphlan_preprocessed.csv'

gene_count_df = pd.read_csv(gene_count_path)
metaphlan_df = pd.read_csv(metaphlan_path)

print(gene_count_df.shape)
print(metaphlan_df.shape)

(605, 546)
(479, 1453)


In [5]:
def clr_transform(df, metadata_columns):
    species_data = df.drop(columns=metadata_columns)
    species_data += 1e-5
    geometric_means = gmean(species_data, axis=1)
    clr_transformed_data = np.log(species_data.divide(geometric_means, axis=0))
    clr_transformed_df = pd.concat([df[metadata_columns], clr_transformed_data], axis=1)
    return clr_transformed_df

metaphlan_df = clr_transform(metaphlan_df, ['patient_id', 'week', 'Flare_status'])

In [6]:
# @title Adding a Flare_start column function

def flag_first_flare_weeks(df):
    # Identify the rows where flare starts
    df['is_flare'] = (df['Flare_status'] == 'During_flare') | (df['Flare_status'] == 'During_flare_2')

    # Sort by patient and week to ensure the chronological order
    df.sort_values(by=['patient_id', 'week'], inplace=True)

    # Mark each flare start for each patient
    df['Flare_start'] = (df['is_flare']) & (df['is_flare'] != df['is_flare'].shift(1))

    # Convert boolean to integer (1 for True, 0 for False)
    df['Flare_start'] = df['Flare_start'].astype(int)

    # Drop helper columns if they are no longer needed
    df.drop('is_flare', axis=1, inplace=True)

    return df

In [7]:
gene_count_df.dtypes

Flare_status                               object
patient_id                                 object
week                                        int64
ABCA1_ENSG00000165029_ENST00000678995     float64
ABCA8_ENSG00000141338_ENST00000586539     float64
                                           ...   
LCN2_ENSG00000148346_ENST00000373017      float64
MUC19_ENSG00000205592_ENST00000454784     float64
NLRP7_ENSG00000167634_ENST00000588756     float64
ORMDL3_ENSG00000172057_ENST00000394169    float64
TSLP_ENSG00000145777_ENST00000344895      float64
Length: 546, dtype: object

In [8]:
metaphlan_df = metaphlan_df.sort_values(by=['patient_id', 'week'])

# Function to find weeks with Flare_status change
def find_flare_changes(group):
    # Detect changes by comparing each element with the previous one
    changes = group['Flare_status'].ne(group['Flare_status'].shift(1))
    if changes.any():  # If there are any changes
        change_weeks = group.loc[changes, 'week']
        print(f"Patient {group['patient_id'].iloc[0]} has Flare_status changes in weeks: {list(change_weeks)}")

# Group by patient_id and apply the function
metaphlan_df.groupby('patient_id').apply(find_flare_changes)

Patient TR_2101 has Flare_status changes in weeks: [1, 25, 44]
Patient TR_2102 has Flare_status changes in weeks: [1, 16, 29]
Patient TR_2103 has Flare_status changes in weeks: [1, 27, 42, 51, 52]
Patient TR_2104 has Flare_status changes in weeks: [1, 23, 36]
Patient TR_2105 has Flare_status changes in weeks: [1, 23, 36]
Patient TR_2106 has Flare_status changes in weeks: [0, 23, 36]
Patient TR_2107 has Flare_status changes in weeks: [0, 23, 36]
Patient TR_2108 has Flare_status changes in weeks: [16, 23, 36]
Patient TR_2201 has Flare_status changes in weeks: [1, 7, 8, 9, 15, 16, 18, 22, 23, 24, 31, 32, 33, 47, 48, 49]
Patient TR_2202 has Flare_status changes in weeks: [1]
Patient TR_2203 has Flare_status changes in weeks: [0]
Patient TR_2205 has Flare_status changes in weeks: [39]


In [9]:
def print_flare_start_indices(df):
  df_sorted = df.sort_values(by=['patient_id', 'week'])

  df_sorted = flag_first_flare_weeks(df_sorted)

  def find_flare_start_indices(group):
      indices = group.index[group['Flare_start'] == 1].tolist()
      return [group.index.get_loc(idx) for idx in indices]

  # Apply the function to each group
  flare_indices = df_sorted.groupby('patient_id').apply(find_flare_start_indices)

  print(flare_indices)

print("Flare start indices for metaphlan_df")
print(print_flare_start_indices(metaphlan_df))
print("Flare start indices for gene_count_df")
print(print_flare_start_indices(gene_count_df))

Flare start indices for metaphlan_df
patient_id
TR_2101        [22]
TR_2102        [12]
TR_2103    [24, 42]
TR_2104        [22]
TR_2105        [22]
TR_2106        [22]
TR_2107        [23]
TR_2108         [6]
TR_2201          []
TR_2202          []
TR_2203          []
TR_2205          []
dtype: object
None
Flare start indices for gene_count_df
patient_id
TR_2101        [25]
TR_2102        [11]
TR_2103    [25, 46]
TR_2104        [22]
TR_2105        [21]
TR_2106        [22]
TR_2107        [22]
TR_2108        [20]
TR_2201          []
TR_2202          []
TR_2203          []
TR_2204          []
TR_2205          []
dtype: object
None


# New approach

In [11]:
def detect_change_points(data, pen=50):
    """Detect change points in a given time series using the PELT method."""
    model = "l2"
    algo = rpt.Pelt(model=model, min_size=1, jump=1).fit(data)
    result = algo.predict(pen=pen)
    if result[-1] == len(data):
        result = result[:-1]
    return result

def process_patient_data(data, patient_id, pen=50):
    """Process data for a single patient and perform clustering on change points, returning species in each cluster."""
    print(f"Processing data for patient ID: {patient_id}")

    # Filter and prepare data
    filtered_df = data[data['patient_id'] == patient_id].drop(columns=['patient_id', 'Flare_status', 'week'])
    filtered_df = 10 * filtered_df.loc[:, (filtered_df != 0).any(axis=0)]
    species_names = filtered_df.columns.tolist()
    signal = filtered_df.to_numpy()
    print(f"Data prepared with {len(species_names)} species/genes")

    # Detect change points
    change_points = [detect_change_points(signal[:, i], pen) for i in range(signal.shape[1])]
    print("Change point detection complete")

    # Create feature vectors for clustering
    feature_vectors = np.zeros((signal.shape[1], signal.shape[0]))
    for idx, cps in enumerate(change_points):
        if cps:
            feature_vectors[idx, cps] = 1
    print(f"Feature vectors created with non-zero entries: {np.sum(feature_vectors)}")

    # Clustering process
    if np.any(feature_vectors.sum(axis=1)):
        silhouette_scores = []
        print(feature_vectors.shape)
        print(int(np.sum(feature_vectors.sum(axis=1) > 0)) - 1)
        unique_counts = [len(np.unique(feature_vectors[:, i])) for i in range(feature_vectors.shape[1])]
        range_n_clusters = range(2, int(sum(np.array(unique_counts) > 1) / 3))
        print(range_n_clusters)
        for num_clusters in range_n_clusters:
            kmeans = KMeans(n_clusters=num_clusters, random_state=42, n_init=10)
            cluster_labels = kmeans.fit_predict(feature_vectors)
            silhouette_avg = silhouette_score(feature_vectors, cluster_labels)
            silhouette_scores.append(silhouette_avg)
            #print(f"Evaluating {num_clusters} clusters: Silhouette Score={silhouette_avg}")

        optimal_n_clusters = range_n_clusters[np.argmax(silhouette_scores)]
        kmeans = KMeans(n_clusters=optimal_n_clusters, random_state=42, n_init=10)
        labels = kmeans.fit_predict(feature_vectors)

        # Organize species by cluster
        clusters = {i: [] for i in range(optimal_n_clusters)}
        for label, species in zip(labels, species_names):
            clusters[label].append(species)
        print(f"Optimal number of clusters: {optimal_n_clusters}")

        for cluster_id, species_list in clusters.items():
          print(f"Cluster {cluster_id} has {len(species_list)} elements.")

        print(clusters)
        return clusters
    else:
        print("No change points detected across all features; clustering not applicable.")
        return {}
