In [None]:
import pandas as pd
import numpy as np
from lifelines import CoxTimeVaryingFitter

from sksurv.ensemble import RandomSurvivalForest
from sksurv.preprocessing import OneHotEncoder

import matplotlib.pyplot as plt
from lifelines import KaplanMeierFitter

from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from xgboost import XGBClassifier
from sklearn.metrics import roc_auc_score, accuracy_score



In [None]:
def get_first_target_times(df, patient_col="ID", time_col="Month", targets=["event_vl50", "event_vl200", "event_cd4"]):
    """
    Finds the first time each patient reaches the given targets.

    Parameters:
    - df (pd.DataFrame): Original DataFrame with patient data.
    - patient_col (str): Name of the patient ID column.
    - time_col (str): Name of the time column.
    - targets (list): List of target event columns to check.

    Returns:
    - pd.DataFrame: A DataFrame containing Patient ID and first time they hit each target.
    """
    result_df = pd.DataFrame()

    # Loop through each target and find first occurrence
    for target in targets:
        first_times = df[df[target] == 1].groupby(patient_col)[time_col].min().reset_index()
        first_times.rename(columns={time_col: f"first_{target}"}, inplace=True)
        
        # Merge into result DataFrame
        if result_df.empty:
            result_df = first_times
        else:
            result_df = result_df.merge(first_times, on=patient_col, how="outer")

    return result_df

def map_regimen(row):
    """Combine regimen components into a standardized name."""
    drug_mappings = {
    'Base_Drug_Combo': {
        0: 'FTC + TDF',
        1: '3TC + ABC',
        2: 'FTC + TAF',
        3: 'DRV + FTC + TDF',
        4: 'FTC + RTVB + TDF',
        5: 'Other'
    },
    'Comp_INI': {0: 'DTG', 1: 'RAL', 2: 'EVG', 3: 'N/A'},
    'Comp_NNRTI': {0: 'NVP', 1: 'EFV', 2: 'RPV', 3: 'N/A'},
    'ExtraPI': {0: 'DRV', 1: 'RTVB', 2: 'LPV', 3: 'RTV', 4: 'ATV', 5: 'N/A'},
}

    base = drug_mappings['Base_Drug_Combo'].get(row['Base_Drug_Combo'], 'Unknown')
    ini = drug_mappings['Comp_INI'].get(row['Comp_INI'], '')
    nnrti = drug_mappings['Comp_NNRTI'].get(row['Comp_NNRTI'], '')
    pi = drug_mappings['ExtraPI'].get(row['ExtraPI'], '')
    enhancer = ' + ExtraPK' if row['ExtraPk_En'] == 1 else ''
    components = [base, ini, nnrti, pi]
    regimen = ' + '.join(filter(None, components)) + enhancer
    return regimen

In [None]:
df = pd.read_csv('BDHSC_SCC_2025_synth_data.csv')

In [None]:
df['Regimen'] = df.apply(map_regimen, axis=1)
df['event_vl50'] = df['VL'] < 50
df['event_vl250'] = df['VL'] < 250
df['event_cd4'] = df['CD4'] > 500
targets = ["event_vl50", "event_vl250", "event_cd4"]
first_target_times_df = get_first_target_times(df, targets=targets)
df = pd.merge(df, first_target_times_df, on='ID', how='inner')

In [None]:
df_sorted = df.sort_values(['ID', 'Month'])
df_sorted['Regimen_Shift'] = (df_sorted.groupby('ID')['Regimen'].shift() != df_sorted['Regimen']).astype(int)
df_sorted['Regimen_Start'] = df_sorted.groupby('ID')['Regimen_Shift'].cumsum()
intervals = df_sorted.groupby(['ID', 'Regimen_Start']).agg(
    Start=('Month', 'min'),
    End=('Month', 'max'),
    Regimen=('Regimen', 'first'),
    VL=('VL', lambda x: list(x)),
    CD4=('CD4', lambda x: list(x)),
    VL_50_time=('first_event_vl50', 'first'),
    VL_250_time=('first_event_vl250', 'first'),
    CD4_500_time=('first_event_cd4', 'first'),
    Gender=('Gender', 'first'),
    Ethnicity=('Ethnic', 'first'),
    VL_50_baseline=('VL', 'first'),
    CD4_500_baseline=('CD4', 'first'),
    CD4_Percent_baseline=('RelCD4', 'first')
).reset_index()
intervals

In [None]:
intervals_censored = intervals.copy()
intervals_censored['VL_50_Censored'] = np.where(
    ~((intervals_censored['Start'] <= intervals_censored['VL_50_time']) & 
      (intervals_censored['VL_50_time'] <= intervals_censored['End'])),
    True,  # If the condition is not met, mark as censored (True)
    False  # Otherwise, mark as not censored (False)
)
intervals_censored['VL_250_Censored'] = np.where(
    ~((intervals_censored['Start'] <= intervals_censored['VL_250_time']) & 
      (intervals_censored['VL_250_time'] <= intervals_censored['End'])),
    True,  # If the condition is not met, mark as censored (True)
    False  # Otherwise, mark as not censored (False)
)
intervals_censored['CD4_500_Censored'] = np.where(
    ~((intervals_censored['Start'] <= intervals_censored['CD4_500_time']) & 
      (intervals_censored['CD4_500_time'] <= intervals_censored['End'])),
    True,  # If the condition is not met, mark as censored (True)
    False  # Otherwise, mark as not censored (False)
)
intervals_censored

In [None]:
intervals_censored.to_csv('intervals_censored.csv')

In [135]:
def filter_intervals(df, mode):
    """Filter intervals where start time < target time (or target not achieved)."""

    if mode == 'vl50':
        mask = (df['Start'] < df['VL_50_time'])
    elif mode == 'vl250':
        mask = (df['Start'] < df['VL_250_time'])
    else:
        mask = (df['Start'] < df['CD4_500_time'])
    df = df[mask]
    return df[df['End'] - df['Start'] >= 3]

def cox(mode, intervals_censored, verbose=False):
    filtered_intervals_outcomes = filter_intervals(intervals_censored, mode)

    tv_data = []
    for _, row in filtered_intervals_outcomes.iterrows():
        start = row['Start']
        if mode == 'vl50':
            end = row['End'] if row['VL_50_Censored'] == 1 else row['VL_50_time']
        elif mode == 'vl250':
            end = row['End'] if row['VL_250_Censored'] == 1 else row['VL_250_time']
        else:
            end = row['End'] if row['CD4_500_Censored'] == 1 else row['CD4_500_time']
        
        tv_data.append({
            'Patient ID': row['ID'],
            'start': start,
            'stop': end,
            'gender': row['Gender'],
            'ethnicity': row['Ethnicity'],
            'censor': row['VL_50_Censored'] if mode == 'vl50' else row['VL_250_Censored'] if mode == 'vl250' else row['CD4_500_Censored'],
            'Regimen': row['Regimen'],
            'Baseline_VL': row['VL_50_baseline'],
            'Baseline_CD4': row['CD4_500_baseline'],
            'Baseline_CD4_Percent': row['CD4_Percent_baseline']
        })

    
    tv_df = pd.DataFrame(tv_data)
    counts = tv_df["Regimen"].value_counts()
    common_regimens = counts[counts > len(filtered_intervals_outcomes)/40].index

    tv_df["Regimen"] = tv_df["Regimen"].apply(
        lambda x: x if x in common_regimens else "Other"
    )

    # tv_df["gender"] = tv_df["gender"].astype("category")
    # tv_df["ethnicity"] = tv_df["ethnicity"].astype("category")
    tv_df = pd.get_dummies(tv_df, columns=["Regimen"], drop_first=True)

    tv_df.drop(["gender", 'ethnicity'], axis=1, inplace=True)
    if 'Regimen_Other' in tv_df.columns:
        tv_df.drop(['Regimen_Other'], axis=1, inplace=True)

    ctv = CoxTimeVaryingFitter()
    ctv.fit(
        tv_df,
        id_col="Patient ID",
        event_col="censor",
        start_col="start",
        stop_col="stop",
        show_progress=True,
        fit_options={'precision': 1e-4, 'r_precision': 1e-4}
    )
    if verbose:
        ctv.print_summary()

    coeffs = ctv.summary
    # coeffs = coeffs[coeffs['p'] < 0.05]
    for baseline in ['Baseline_VL', 'Baseline_CD4', 'Baseline_CD4_Percent']:
        if baseline in coeffs.index:
            coeffs = coeffs.drop(index=[baseline])

    coeffs = coeffs.reset_index()
    coeffs = coeffs[['covariate', 'exp(coef)']]
    coeffs.columns = ['covariate', f'effect_{mode}']
    
    return coeffs

def optimal_regimen(intervals_censored):
    coeffs_tv50 = cox('vl50', intervals_censored)
    coeffs_tv250 = cox('vl250', intervals_censored)
    coeffs_cd500 = cox('cd500', intervals_censored)

    df = pd.merge(coeffs_tv50, coeffs_tv250, on='covariate', how='inner')
    df = pd.merge(df, coeffs_cd500, on='covariate', how='inner')
    # df['multiplicative_effect'] = np.power(df['effect_vl50'], weights[0]/denom) * np.power(df['effect_vl250'], weights[0]/denom) * np.power(df['effect_cd500'], weights[1]/denom)
    # df = df[df['multiplicative_effect'] >= 1]

    # df = df.sort_values(['multiplicative_effect'], ascending=False)
    df.columns = ['regimen', 'effect_vl50', 'effect_vl250', 'effect_cd500']
    df['regimen'] = df['regimen'].apply(lambda x:x.split('Regimen_')[1])
    return df

In [124]:
intervals_censored = pd.read_csv('intervals_censored.csv')
intervals_censored["id"] = range(len(intervals_censored))
black_male_cluster = pd.read_csv('black_male_cluster_labels.csv')
intervals_censored_black_male = pd.merge(intervals_censored, black_male_cluster, on='id', how='inner')


In [137]:
for cluster_label in range(3):
    interval_cluster = intervals_censored_black_male[intervals_censored_black_male['cluster_label'] == cluster_label]
    res = optimal_regimen(interval_cluster)
    res.to_csv(f'black_male/label={cluster_label}.csv', index=False)

Iteration 1: norm_delta = 6.13e-01, step_size = 0.9500, log_lik = -95465.37163, newton_decrement = 1.19e+03, seconds_since_start = 0.2
Iteration 2: norm_delta = 1.94e-01, step_size = 0.9500, log_lik = -94228.66555, newton_decrement = 7.93e+01, seconds_since_start = 0.4
Iteration 3: norm_delta = 2.17e-02, step_size = 0.9500, log_lik = -94147.33886, newton_decrement = 7.97e-01, seconds_since_start = 0.5
Iteration 4: norm_delta = 1.30e-03, step_size = 1.0000, log_lik = -94146.54097, newton_decrement = 2.52e-03, seconds_since_start = 0.7
Iteration 5: norm_delta = 5.39e-07, step_size = 1.0000, log_lik = -94146.53844, newton_decrement = 4.16e-10, seconds_since_start = 0.9
Convergence completed after 5 iterations.
Iteration 1: norm_delta = 5.54e-01, step_size = 0.9500, log_lik = -40672.60550, newton_decrement = 4.57e+02, seconds_since_start = 0.1
Iteration 2: norm_delta = 1.35e-01, step_size = 0.9500, log_lik = -40246.09027, newton_decrement = 1.97e+01, seconds_since_start = 0.2
Iteration 3: 