In [23]:
pip install pytorch-tabnet



In [24]:
pip install torch torchvision



In [25]:
# General imports
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from random import sample
import warnings

# Data processing
from sklearn.preprocessing import LabelEncoder, StandardScaler

# Model training and evaluation
from sklearn.model_selection import train_test_split, GridSearchCV, StratifiedKFold
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, accuracy_score, f1_score, make_scorer, roc_auc_score, precision_score, recall_score, confusion_matrix

# PyTorch and TabNet
import torch
from pytorch_tabnet.tab_model import TabNetClassifier

gene_count_path = '/content/drive/MyDrive/MITResearch/sparsecca/target.gene.count.matrix.csv'
metaphlan_path = '/content/drive/MyDrive/MITResearch/sparsecca/UMCGPilotdata_metaphlan.xlsx'

# Load gene count matrix
gene_count_df = pd.read_csv(gene_count_path)
metaphlan_df = pd.read_excel(metaphlan_path)
print(gene_count_df.shape)
print(metaphlan_df.shape)

(543, 213)
(228, 597)


In [26]:
# @title

# ---- Data Transformation Section ----
# Transpose gene count matrix to align with metaphlan_df data format
gene_count_df = gene_count_df.T

# Set the first row as column names and convert the matrix to a dataframe
new_col_names = gene_count_df.iloc[0]  # First row for column names
gene_count_df = gene_count_df[1:]  # Remove the first row
gene_count_df.columns = new_col_names  # Set new column names
gene_count_df.reset_index(drop=False,inplace=True)
gene_count_df.rename(columns={'index': 'PatientID_Weeknr'}, inplace=True)

# Align metaphlan_df PatientID_Weeknr format with gene_count_df
metaphlan_df['PatientID_Weeknr'] = metaphlan_df['PatientID_Weeknr'].str.replace("Weekly_Feces_", "")
metaphlan_df['PatientID_Weeknr'] = metaphlan_df['PatientID_Weeknr'].str.replace("_Week_", "__Week_")
metaphlan_df['PatientID_Weeknr'] = metaphlan_df['PatientID_Weeknr'].str.replace("_(\\d)$", r"_0\1", regex=True)

pattern = r'^TR_\d+__Week_\d{2}$'

# Identify and print rows with incorrect formats
invalid_df = gene_count_df[~gene_count_df['PatientID_Weeknr'].str.match(pattern)]
if not invalid_df.empty:
    print("Rows removed (incorrect format):")
    print(invalid_df['PatientID_Weeknr'])
else:
    print("No rows removed.")

gene_count_df = gene_count_df[gene_count_df['PatientID_Weeknr'].str.match(pattern)]

Rows removed (incorrect format):
15    TR_2101__Week_17_2
18    TR_2101__Week_19_2
Name: PatientID_Weeknr, dtype: object


In [27]:
# @title split and insert function
def split_and_insert(df, column_name, split_str):
    df[column_name] = df[column_name].astype(str)
    patient_id, week = df[column_name].str.split(split_str, expand=True)[0], df[column_name].str.extract('(\d+)$')[0].astype(int)
    df.insert(0, 'week', week)
    df.insert(0, 'patient_id', patient_id)
    df.drop(column_name, axis=1, inplace=True)
    return df

In [28]:
# @title flag_first_flare_weeks function

def flag_first_flare_weeks(df):
    # Identify the rows where flare starts
    df['is_flare'] = (df['Flare_status'] == 'During_flare') | (df['Flare_status'] == 'During_flare_2')

    # Sort by patient and week to ensure the chronological order
    df.sort_values(by=['patient_id', 'week'], inplace=True)

    # Mark each flare start for each patient
    df['Flare_start'] = (df['is_flare']) & (df['is_flare'] != df['is_flare'].shift(1))

    # Convert boolean to integer (1 for True, 0 for False)
    df['Flare_start'] = df['Flare_start'].astype(int)

    # Drop helper columns if they are no longer needed
    df.drop('is_flare', axis=1, inplace=True)

    return df

In [29]:
# @title normalize_to_housekeeping_genes function
def normalize_to_housekeeping_genes(df, housekeeping_genes, exclude_columns):
    # Convert data in DataFrame to numeric types, except for excluded columns
    for column in df.columns:
        if column not in exclude_columns:
            df[column] = pd.to_numeric(df[column], errors='coerce')

    # Find columns that start with any of the housekeeping gene symbols and are not in the excluded list
    hk_gene_cols = [col for col in df.columns
                    if any(col.startswith(hk + '_') for hk in housekeeping_genes) and col not in exclude_columns]

    # Selecting housekeeping genes data
    hk_genes_data = df[hk_gene_cols]

    # Calculate the geometric mean of the housekeeping genes for each sample
    # Adding a small value (e.g., 1e-9) to avoid taking log(0)
    geometric_mean_hk = np.exp(np.log(hk_genes_data + 1e-9).mean(axis=1))

    # Normalize the entire gene count matrix by the geometric mean of the housekeeping genes
    # Apply normalization only to numeric columns
    for column in df.columns:
        if column not in exclude_columns:
            df[column] = df[column].div(geometric_mean_hk, axis=0)

    return df

In [30]:
# @title assign_rbf

def assign_rbf(df, samples_before_flare=6):
    # Initial setup
    df = df.sort_values(by=['patient_id', 'week'])
    df['RBF'] = 0  # Initialize the RBF column

    # Identify patients and the weeks where flares start
    flare_starts = df[df['Flare_start'] == 1].groupby('patient_id')['week'].min()

    for patient, flare_week in flare_starts.items():
        # Get all weeks for the current patient
        patient_weeks = df.loc[df['patient_id'] == patient, 'week']

        # Find the indexes for weeks before the flare start
        weeks_before_flare = patient_weeks[patient_weeks < flare_week].nlargest(samples_before_flare)

        # Set RBF to 1 for these weeks
        df.loc[(df['patient_id'] == patient) & (df['week'].isin(weeks_before_flare)), 'RBF'] = 1

    return df

In [31]:
# @title Split 'PatientID_Weeknr' in patient id and week
metaphlan_df = split_and_insert(metaphlan_df, 'PatientID_Weeknr', '__')
gene_count_df = split_and_insert(gene_count_df, 'PatientID_Weeknr', '__')
metaphlan_df = metaphlan_df.drop(['SampleName', 'Weeknumber', 'Exacerbation_judgement_treatingphysician'], axis=1)

In [32]:
print(gene_count_df.shape)
print(metaphlan_df.shape)

(210, 545)
(228, 595)


In [33]:
# @title merge metaphlan with gene_count to get the Flare_satus
gene_count_df = pd.merge(gene_count_df, metaphlan_df[['patient_id', 'week', 'Flare_status']], on=['patient_id', 'week'], how='left')
cols = ['Flare_status'] + [col for col in gene_count_df.columns if col != 'Flare_status']
gene_count_df = gene_count_df[cols]


gene_count_df.sort_values(by=['patient_id', 'week'], inplace=True)

# Step 3: Fill missing 'Flare_status' values using forward fill
gene_count_df['Flare_status'] = gene_count_df['Flare_status'].fillna(method='ffill')


In [34]:
print(gene_count_df.shape)
print(metaphlan_df.shape)

(210, 546)
(228, 595)


In [35]:
# @title Adding a Flare_start column
metaphlan_df = flag_first_flare_weeks(metaphlan_df)
gene_count_df = flag_first_flare_weeks(gene_count_df)

In [36]:
# @title normalize to housekeeping genes
house_keeping_genes = ["ACTB", "ATP5F1", "B2M", "GAPDH", "GUSB",
                       "HPRT", "PGK1", "PPIA", "RPS18", "TBP",
                       "TFRC", "YWHAZ"]

# Columns to exclude from normalization
exclude_columns = ['Flare_status', 'patient_id', 'week', 'Flare_start']

# Normalizing the DataFrame
normalized_df = normalize_to_housekeeping_genes(gene_count_df, house_keeping_genes, exclude_columns)

In [37]:
# @title Assign RBF flags with 6 samples before flare as default or any other number you wish
metaphlan_df = assign_rbf(metaphlan_df, samples_before_flare=6)
gene_count_df = assign_rbf(gene_count_df, samples_before_flare=6)

In [38]:
# Merge datasets on 'patient_id' and 'week'
data = pd.merge(metaphlan_df, gene_count_df, on=['patient_id', 'week', 'Flare_start',  'Flare_status', 'RBF'])
data = data[(data["Flare_status"] != "Post_flare") &
            (data["Flare_status"] != "During_flare") &
            (data["Flare_status"] != "During_flare_2")]
# Drop unwanted columns
cols_to_drop = ['Flare_start', 'Flare_status']
data = data.drop(cols_to_drop, axis=1)

# Fill missing values if necessary
#data.fillna(data.mean(), inplace=True)  # Using mean to fill missing values, adjust according to your data


# Transformer

In [39]:

# Convert 'patient_id' into a categorical variable suitable for modeling
data['patient_id'] = LabelEncoder().fit_transform(data['patient_id'])

# Split data into features and target
X = data.drop('RBF', axis=1)
y = data['RBF']

# Split data into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Initialize TabNetClassifier
model = TabNetClassifier(optimizer_fn=torch.optim.Adam,
                         optimizer_params=dict(lr=2e-2),
                         scheduler_params={"step_size":10, "gamma":0.9},
                         scheduler_fn=torch.optim.lr_scheduler.StepLR,
                         verbose=1)

# Fit model
model.fit(
    X_train.values, y_train.values,
    eval_set=[(X_train.values, y_train.values), (X_test.values, y_test.values)],
    eval_name=['train', 'valid'],
    eval_metric=['accuracy'],
    max_epochs=1000,
    patience=50,  # Early stopping patience
    batch_size=1024,
    virtual_batch_size=128,
    num_workers=0,
    drop_last=False
)

# Predictions
preds = model.predict(X_test.values)
print(classification_report(y_test, preds))

# Feature Importance
feature_importances = model.feature_importances_
print("Feature importances:\n", feature_importances)




epoch 0  | loss: 0.6974  | train_accuracy: 0.88393 | valid_accuracy: 0.86207 |  0:00:00s
epoch 1  | loss: 0.59382 | train_accuracy: 0.875   | valid_accuracy: 0.86207 |  0:00:00s
epoch 2  | loss: 0.56243 | train_accuracy: 0.89286 | valid_accuracy: 0.86207 |  0:00:00s
epoch 3  | loss: 0.52191 | train_accuracy: 0.89286 | valid_accuracy: 0.86207 |  0:00:00s
epoch 4  | loss: 0.4135  | train_accuracy: 0.89286 | valid_accuracy: 0.86207 |  0:00:01s
epoch 5  | loss: 0.43707 | train_accuracy: 0.89286 | valid_accuracy: 0.86207 |  0:00:01s
epoch 6  | loss: 0.35756 | train_accuracy: 0.89286 | valid_accuracy: 0.86207 |  0:00:01s
epoch 7  | loss: 0.33195 | train_accuracy: 0.89286 | valid_accuracy: 0.86207 |  0:00:01s
epoch 8  | loss: 0.43734 | train_accuracy: 0.89286 | valid_accuracy: 0.86207 |  0:00:02s
epoch 9  | loss: 0.39767 | train_accuracy: 0.89286 | valid_accuracy: 0.86207 |  0:00:02s
epoch 10 | loss: 0.43677 | train_accuracy: 0.89286 | valid_accuracy: 0.86207 |  0:00:02s
epoch 11 | loss: 0.40

