In [6]:
import pandas as pd
import numpy as np
from lifelines import CoxTimeVaryingFitter

from sksurv.ensemble import RandomSurvivalForest
from sksurv.preprocessing import OneHotEncoder

import matplotlib.pyplot as plt
from lifelines import KaplanMeierFitter

from sklearn.model_selection import train_test_split

In [8]:
df = pd.read_csv('BDHSC_SCC_2025_synth_data.csv')

# A regimen is identified by the underscore-joined combination of its component columns.
regimen_columns = ["Base_Drug_Combo", "Comp_INI", "Comp_NNRTI", "ExtraPI", "ExtraPk_En"]
regimen_parts = df[regimen_columns].astype(str)
df["Regimen"] = regimen_parts.agg("_".join, axis=1)

# Split at the patient level (not the row level) so every month of a patient
# lands on the same side of the train/test split.
unique_ids = df["ID"].unique()
train_ids, test_ids = train_test_split(unique_ids, test_size=0.1, random_state=42)
train_df, test_df = (df[df["ID"].isin(ids)] for ids in (train_ids, test_ids))

In [9]:
df_sorted = train_df.sort_values(['ID', 'Month'])

# Flag every row whose regimen differs from the same patient's previous month.
# The first row of each patient is always flagged (its shifted value is NaN).
previous_regimen = df_sorted.groupby('ID')['Regimen'].shift()
df_sorted['Regimen_Shift'] = previous_regimen.ne(df_sorted['Regimen']).astype(int)
# Running count of changes = per-patient index of the current regimen interval.
df_sorted['Regimen_Start'] = df_sorted.groupby('ID')['Regimen_Shift'].cumsum()

# One row per (patient, regimen interval): month span, regimen, and raw VL/CD4 series.
intervals = (
    df_sorted
    .groupby(['ID', 'Regimen_Start'])
    .agg(
        Start=('Month', 'min'),
        End=('Month', 'max'),
        Regimen=('Regimen', 'first'),
        VL=('VL', lambda values: list(values)),
        CD4=('CD4', lambda values: list(values)),
    )
    .reset_index()
)

In [None]:
def _first_month_where(months, values, condition):
    """Return the first month whose paired value satisfies ``condition``, or None if none does."""
    return next((month for month, value in zip(months, values) if condition(value)), None)


def track_outcomes(group):
    """Track first suppression / CD4 recovery within one regimen interval and keep covariates.

    Parameters
    ----------
    group : pd.DataFrame
        Rows of a single (patient, regimen interval), assumed already sorted by
        'Month' (the caller sorts by ['ID', 'Month'] before grouping). Must contain
        'Month', 'VL', 'CD4', 'Gender', 'Ethnic' and 'RelCD4'.

    Returns
    -------
    pd.Series
        - ``VL_250_time`` / ``VL_50_time``: first month with VL <= 250 / <= 50, or None.
        - ``CD4_500_time``: first month with CD4 > 500, or None.
        - ``*_Censored``: 1 when the corresponding target is never reached in the
          interval, else 0.
        - Baseline covariates taken from the interval's first row.
    """
    months = group['Month']
    vl = group['VL']
    cd4 = group['CD4']

    vl_250_time = _first_month_where(months, vl, lambda value: value <= 250)
    vl_50_time = _first_month_where(months, vl, lambda value: value <= 50)
    cd4_500_time = _first_month_where(months, cd4, lambda value: value > 500)

    return pd.Series({
        'VL_250_time': vl_250_time,
        'VL_50_time': vl_50_time,
        'CD4_500_time': cd4_500_time,
        'VL_250_Censored': 1 if vl_250_time is None else 0,
        'VL_50_Censored': 1 if vl_50_time is None else 0,
        'CD4_500_Censored': 1 if cd4_500_time is None else 0,
        'Gender': group['Gender'].iloc[0],
        'Ethnicity': group['Ethnic'].iloc[0],
        'Baseline_VL': vl.iloc[0],
        'Baseline_CD4': cd4.iloc[0],
        # First value only, like the other baselines (previously the whole RelCD4 Series
        # was stored in a single cell).
        'Baseline_CD4_percent': group['RelCD4'].iloc[0],
    })

# One outcome row per (patient, regimen interval), joined back onto the interval table.
outcomes = (
    df_sorted
    .groupby(['ID', 'Regimen_Start'])
    .apply(track_outcomes)
    .reset_index()
)
intervals_outcomes = intervals.merge(outcomes, on=['ID', 'Regimen_Start'])

In [None]:
def filter_intervals(df, mode, min_duration=6):
    """Keep intervals still at risk for the outcome and long enough to model.

    An interval is kept when both conditions hold:
    1. It starts before the outcome month, or the outcome is never reached in it
       (censored). Intervals where the target is already met at ``Start`` are dropped.
    2. ``End - Start >= min_duration`` (in the units of the 'Start'/'End' columns).

    Parameters
    ----------
    df : pd.DataFrame
        Interval table with 'Start', 'End' and the outcome columns for ``mode``.
    mode : str
        'vl50' or 'vl250'; any other value selects the CD4 > 500 outcome.
    min_duration : int or float, default 6
        Minimum interval length to keep.

    Returns
    -------
    pd.DataFrame
        The filtered rows, with the original index preserved.
    """
    if mode == 'vl50':
        time_col, censor_col = 'VL_50_time', 'VL_50_Censored'
    elif mode == 'vl250':
        time_col, censor_col = 'VL_250_time', 'VL_250_Censored'
    else:
        time_col, censor_col = 'CD4_500_time', 'CD4_500_Censored'

    at_risk = (df['Start'] < df[time_col]) | (df[censor_col] == 1)
    long_enough = (df['End'] - df['Start']) >= min_duration
    return df[at_risk & long_enough]

def cox(mode, verbose=False, min_regimen_count=1000):
    """Fit a time-varying Cox model for time to the outcome selected by ``mode``.

    Uses the global training table ``intervals_outcomes``, filtered with
    ``filter_intervals``. Each kept interval becomes one (start, stop] row:
    - It stops at the outcome month when the outcome is reached in the interval (event = 1).
    - Otherwise it stops at the interval's last month (event = 0, censored).

    Parameters
    ----------
    mode : str
        'vl50' -> first VL <= 50, 'vl250' -> first VL <= 250; any other value
        selects first CD4 > 500 (same fallback as ``filter_intervals``).
    verbose : bool, default False
        Print the lifelines summary after fitting.
    min_regimen_count : int, default 1000
        Regimens appearing in this many intervals or fewer are lumped into "Other".

    Returns
    -------
    CoxTimeVaryingFitter
        The fitted model.

    Raises
    ------
    ValueError
        If no interval survives ``filter_intervals`` for this mode.
    """
    if mode == 'vl50':
        time_col, censor_col = 'VL_50_time', 'VL_50_Censored'
    elif mode == 'vl250':
        time_col, censor_col = 'VL_250_time', 'VL_250_Censored'
    else:
        time_col, censor_col = 'CD4_500_time', 'CD4_500_Censored'

    filtered_intervals_outcomes = filter_intervals(intervals_outcomes, mode)
    if filtered_intervals_outcomes.empty:
        raise ValueError(f"No intervals left to fit for mode={mode!r}")

    tv_data = []
    for _, row in filtered_intervals_outcomes.iterrows():
        censored = row[censor_col] == 1
        tv_data.append({
            'Patient ID': row['ID'],
            'start': row['Start'],
            'stop': row['End'] if censored else row[time_col],
            'gender': row['Gender'],
            'ethnicity': row['Ethnicity'],
            # lifelines' event_col means 1 = event observed. Invert this mode's censoring
            # flag; previously the raw flag was passed, and always the VL <= 50 one.
            'event': 0 if censored else 1,
            'Regimen': row['Regimen'],
            'Baseline_VL': row['Baseline_VL'],
            'Baseline_CD4': row['Baseline_CD4'],
        })

    tv_df = pd.DataFrame(tv_data)

    # Rare regimens are pooled so each regimen dummy has enough support.
    counts = tv_df["Regimen"].value_counts()
    common_regimens = counts[counts > min_regimen_count].index
    tv_df["Regimen_lumped"] = tv_df["Regimen"].where(tv_df["Regimen"].isin(common_regimens), "Other")

    tv_df["gender"] = tv_df["gender"].astype("category")
    tv_df["ethnicity"] = tv_df["ethnicity"].astype("category")
    tv_df = pd.get_dummies(tv_df, columns=["Regimen_lumped", "gender", "ethnicity"], drop_first=True)
    tv_df = tv_df.drop(columns=['Regimen'])

    ctv = CoxTimeVaryingFitter()
    ctv.fit(
        tv_df,
        id_col="Patient ID",
        event_col="event",
        start_col="start",
        stop_col="stop",
        show_progress=True
    )
    if verbose:
        ctv.print_summary()
    return ctv

In [None]:
# One time-varying Cox model per outcome definition, fitted in a fixed order.
ctv_vl50 = cox(mode='vl50', verbose=False)
ctv_vl250 = cox(mode='vl250', verbose=False)
ctv_cd4_500 = cox(mode='cd4_500', verbose=False)

In [None]:
test_df_sorted = test_df.sort_values(['ID', 'Month'])
# Flag a new interval whenever the regimen differs from the same patient's previous row.
# Compare against this frame's own 'Regimen' column: comparing against the training frame
# `df_sorted` fails because its row labels differ, and pandas rejects `!=` between
# differently-labelled Series.
test_df_sorted['Regimen_Shift'] = (
    test_df_sorted.groupby('ID')['Regimen'].shift() != test_df_sorted['Regimen']
).astype(int)
test_df_sorted['Regimen_Start'] = test_df_sorted.groupby('ID')['Regimen_Shift'].cumsum()

# Split into regimen intervals (same aggregation as for the training set)
test_intervals = test_df_sorted.groupby(['ID', 'Regimen_Start']).agg(
    Start=('Month', 'min'),
    End=('Month', 'max'),
    Regimen=('Regimen', 'first'),
    VL=('VL', lambda x: list(x)),
    CD4=('CD4', lambda x: list(x))
).reset_index()

test_outcomes = test_df_sorted.groupby(['ID', 'Regimen_Start']).apply(track_outcomes).reset_index()
test_intervals_outcomes = pd.merge(test_intervals, test_outcomes, on=['ID', 'Regimen_Start'])

In [None]:
def predict_probability_vl50(new_sample, ctv, baseline_survival, time_point):
    """
    Predicts the probability of reaching VL ≤ 50 for a new patient sample at a given time.

    Parameters:
    - new_sample (pd.Series | pd.DataFrame | array-like): One patient row. A Series or
      DataFrame is matched to the model by column name (extra columns such as 'start'
      are ignored). A plain array is assumed to be in the order of ``ctv.params_``.
    - ctv (CoxTimeVaryingFitter): Trained Cox PH model.
    - baseline_survival (pd.Series | pd.DataFrame): Baseline survival indexed by time.
      lifelines' ``baseline_survival_`` is a one-column DataFrame; the first column is used.
    - time_point (float): The time at which to predict the probability.

    Returns:
    - float: Probability of reaching VL ≤ 50 by time_point, i.e. 1 - S0(t) ** exp(lp).
    """
    covariates = ctv.params_.index
    if isinstance(new_sample, pd.DataFrame):
        sample_frame = new_sample
    elif isinstance(new_sample, pd.Series):
        sample_frame = new_sample.to_frame().T
    else:
        # Positional input: values are assumed to already follow the model's covariate order.
        sample_frame = pd.DataFrame([np.asarray(new_sample, dtype=float)], columns=covariates)
    sample_frame = sample_frame[covariates].astype(float)

    # Use the model's own linear predictor rather than a raw dot product. In lifelines it
    # is computed on mean-centred covariates, which is the reference baseline_survival_ is
    # built against. NOTE(review): confirm for the installed lifelines version.
    log_partial_hazard = float(np.asarray(ctv.predict_log_partial_hazard(sample_frame)).ravel()[0])

    if isinstance(baseline_survival, pd.DataFrame):
        baseline_survival = baseline_survival.iloc[:, 0]
    baseline_survival = baseline_survival.sort_index()

    # The baseline survival is a step function: take the last value at or before
    # time_point (an exact .loc lookup raises KeyError for any other time). Before the
    # first recorded time, no event has happened, so S0 = 1.
    at_or_before = baseline_survival.loc[baseline_survival.index <= time_point]
    s0_t = float(at_or_before.iloc[-1]) if len(at_or_before) else 1.0

    s_new_t = s0_t ** np.exp(log_partial_hazard)
    return 1.0 - s_new_t


In [None]:
def compute_combined_probability(test_df, ctv, baseline_survival):
    """
    Computes the combined probability of reaching VL ≤ 50 for each patient in test_df,
    by multiplying predicted probabilities for each regimen interval.

    Parameters:
    - test_df (pd.DataFrame): One row per regimen interval, with 'Patient ID', 'start',
      'stop' and every covariate column of ``ctv`` (other columns are ignored).
    - ctv (CoxTimeVaryingFitter): Trained Cox PH model; ``ctv.params_.index`` names the covariates.
    - baseline_survival (pd.Series | pd.DataFrame): The baseline survival function from ctv.

    Returns:
    - A dictionary mapping Patient ID to combined probability of VL ≤ 50.

    Raises:
    - ValueError: if test_df is missing any of the model's covariate columns.
    """
    covariates = ctv.params_.index
    missing = covariates.difference(test_df.columns)
    if len(missing):
        raise ValueError(f"test_df is missing model covariates: {list(missing)}")

    patient_probs = {}

    for patient_id, patient_data in test_df.groupby("Patient ID"):
        combined_prob = 1.0  # neutral element for the running product

        for _, interval in patient_data.iterrows():
            # NOTE(review): the model was fit on start/stop values from the 'Month' column, so
            # its baseline survival is indexed on that time scale. Evaluating it at the interval
            # length treats every interval as starting at time 0. Confirm this is intended.
            duration = interval['stop'] - interval['start']
            # Exactly the model's covariates, in the model's order. Dropping only
            # id/stop/censor left 'start' in the vector and misaligned it with params_.
            features = interval[covariates].astype(float)

            prob_vl50 = predict_probability_vl50(features, ctv, baseline_survival, duration)

            # NOTE(review): this product is P(target reached in every interval). If the goal is
            # P(target reached in at least one interval), use 1 - prod(1 - p_i) instead.
            combined_prob *= prob_vl50

        patient_probs[patient_id] = combined_prob

    return patient_probs



In [None]:
mode = 'vl50'
# Regimen lumping threshold; must equal the one cox() used when the models were fitted.
MIN_REGIMEN_COUNT = 1000

# Outcome columns and the matching fitted model for this mode (same fallback as filter_intervals).
if mode == 'vl50':
    time_col, censor_col, fitted_model = 'VL_50_time', 'VL_50_Censored', ctv_vl50
elif mode == 'vl250':
    time_col, censor_col, fitted_model = 'VL_250_time', 'VL_250_Censored', ctv_vl250
else:
    time_col, censor_col, fitted_model = 'CD4_500_time', 'CD4_500_Censored', ctv_cd4_500

filtered_test_intervals_outcomes = filter_intervals(test_intervals_outcomes, mode)

tv_data = []
for _, row in filtered_test_intervals_outcomes.iterrows():
    censored = row[censor_col] == 1
    tv_data.append({
        'Patient ID': row['ID'],
        'start': row['Start'],
        'stop': row['End'] if censored else row[time_col],
        'gender': row['Gender'],
        'ethnicity': row['Ethnicity'],
        # Censoring flag of the selected outcome (previously always the VL ≤ 50 flag).
        'censor': row[censor_col],
        'Regimen': row['Regimen'],
        'Baseline_VL': row['Baseline_VL'],
        'Baseline_CD4': row['Baseline_CD4']
    })

tv_df = pd.DataFrame(tv_data)

# Lump regimens by their frequency in the TRAINING intervals, exactly as cox() did. The
# smaller test split would otherwise pool (almost) everything into "Other".
train_counts = filter_intervals(intervals_outcomes, mode)['Regimen'].value_counts()
common_regimens = train_counts[train_counts > MIN_REGIMEN_COUNT].index
tv_df["Regimen_lumped"] = tv_df["Regimen"].where(tv_df["Regimen"].isin(common_regimens), "Other")

# Encode every observed level (no drop_first), then keep exactly the fitted model's covariates:
# - The training reference level has no model column, so it is dropped (all zeros = reference).
# - Model levels absent from the test set are filled with 0.
# Applying drop_first to the test levels instead would pick a different reference level.
encoded = pd.get_dummies(tv_df.drop(columns=["Regimen"]),
                         columns=["Regimen_lumped", "gender", "ethnicity"])
model_covariates = encoded.reindex(columns=fitted_model.params_.index, fill_value=0)
tv_df = pd.concat([encoded[["Patient ID", "start", "stop", "censor"]], model_covariates], axis=1)

In [None]:
baseline_survival = ctv_vl50.baseline_survival_

# Compute the combined probabilities for all patients in the test dataset
patient_combined_probs = compute_combined_probability(tv_df, ctv_vl50, baseline_survival)

# Convert results to a DataFrame. `ace_tools` only exists inside ChatGPT's sandbox (ImportError
# elsewhere), so rely on Jupyter's rich display of the cell's last expression instead.
patient_prob_df = pd.DataFrame(
    list(patient_combined_probs.items()),
    columns=["Patient ID", "Combined Probability VL ≤ 50"],
)
patient_prob_df