# Reproduction of the ML Hype or Hope Paper
- Goal: Reproduce Regression, XGBoost and Random Forest (RF) stats
- Step 1: Get data + merge
- Step 2: Train models (Regression, XGBoost, RF)
- Step 3: Figure out how to model w/ 

## Note: variable definitions are freely available in the SRTR data dictionary

# Data
- Location: 
- Time range: January 1, 2005 to December 31, 2017
- Who are we looking at?
  - Kidney transplant patients --> KI files
- What are we looking for?
  - **Delayed graft function** = Need for dialysis within the 1st week after transplant (post tx) --> Got it
  - **One-year acute rejection** = All acute rejection episodes reported UP TO the 1-yr follow-up (post tx) --> Binary outcome --> *counted if it happened at all*
    - Got it: Filter by follow-up period + check rejection episode (Y/N)
  - **Death-censored graft failure** = $\Delta t$ from first transplant to EITHER re-initiation of dialysis OR another transplant --> *Death-censoring?*
  - **All-cause graft failure** = $\Delta t$ from first transplant to EITHER re-initiation of dialysis OR another transplant OR death
  - **Death** = $\Delta t$ from first transplant to death (?)
- Start w/ TX_KI and TXF_KI
- Demographics:
  - 18+: `rec_age_at_listing`
- Need to find:
  - Date of graft failure: `rec_prev_graft*_dt`, `rec_fail_dt`
  - Reason for graft failure: `rec_fail_cause_ty` --> code 102, `tfl_dial_ty` --> not codes 1 or 998, `rec_resum_maint_dial` & `rec_resum_maint_dial_dt` (might not use these, since maintenance dialysis = not yet transplanted, according to a Google search)
  - Delayed graft function: `rec_first_week_dial` (dialysis needed within the first week after transplant)
  - Deceased Donor (Y/N): `don_ty` --> C = deceased donor
  - Tx center: `rec_ctr_cd`
  - Tx date: `rec_tx_dt`, `rec_tx_org_ty` 
  - Acute rejection episodes during follow-up (Y/N): `tfl_acute_rej_episode`, `rec_acute_rej_episode` --> not code 3
    - Follow-up period: `tfl_fol_cd` --> codes 1 (discharge), 3 (3 months), 6 (6 months), 10 (1 yr), 999 (death)
  - Death: `rec_px_stat`, `tfl_px_stat` --> code D & `rec_px_stat_dt`, `tfl_px_stat_dt`
  - Retransplant: `rec_px_stat`,`tfl_px_stat` --> code R
  - Might need: `CAN_PREV_KI`, `CAN_PREV_KI_TX_FUNCTN`, `CAN_PREV_TX`
  - Linking variables
     - candidate file + tx file: `px_id`
     - tx file to follow up file: `tx_id`
- **Variables by file**
- tx_ki: `["rec_prev_graft1_dt", "rec_fail_dt", "rec_fail_cause_ty", "rec_resum_maint_dial", "rec_resum_maint_dial_dt", "rec_first_week_dial", "don_ty", "rec_ctr_cd", "rec_ctr_ty", "rec_tx_dt", "rec_tx_org_ty", "rec_acute_rej_episode", "rec_px_stat", "rec_px_stat_dt", "can_prev_ki","can_prev_ki_tx_functn","tfl_death_dt"]`
- txf_ki: `["tfl_dial_ty", "tfl_acute_rej_episode","tfl_fol_cd","tfl_px_stat","tfl_px_stat_dt"]`
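A minimal sketch of the Step 1 merge on made-up toy data, assuming one row per transplant in TX_KI and several follow-up rows per transplant in TXF_KI. It joins on `trr_id`, the key the cells below use (the note above lists `tx_id` for this link). `validate="one_to_many"` is a standard `DataFrame.merge` option that raises if a transplant ID is duplicated on the left side; the `tfl_fol_cd` values here are plain integers, not the raw category labels.

```python
import pandas as pd

# Toy stand-ins: one row per transplant (TX_KI), several follow-up rows per transplant (TXF_KI)
tx = pd.DataFrame({
    "trr_id": [1, 2, 3],
    "rec_tx_dt": pd.to_datetime(["2010-01-05", "2012-06-01", "2015-03-20"]),
})
txf = pd.DataFrame({
    "trr_id": [1, 1, 2, 3, 3, 3],
    "tfl_fol_cd": [1, 10, 3, 1, 6, 10],
    "tfl_acute_rej_episode": ["N", "Y", "N", "N", "N", "N"],
})

# Left join keeps every transplant; validate guards against duplicate trr_id in the transplant file
merged = tx.merge(txf, on="trr_id", how="left", validate="one_to_many")
print(merged.shape)
print(merged.groupby("trr_id").size())
```

Each transplant row is repeated once per follow-up visit, so outcome flags still need to be aggregated back to one row per `trr_id`.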

In [None]:
import pandas as pd
from datetime import datetime
#Removed paths for Release
donor_df = pd.read_stata()
tx_ki_raw = pd.read_stata()
txf_ki_raw = pd.read_stata()

# Covariates
- Donor variables: still need donation after cardiac death (DCD)
- Recipient variables: still need pre-emptive transplant and time on dialysis
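The two recipient covariates can also be derived without a row-wise `apply`. Below is a minimal vectorized sketch on toy dates. Like the cells below, it assumes that a missing `rec_dial_dt`, or a dialysis start after `rec_tx_dt`, means a pre-emptive transplant with zero time on dialysis. The frame name `toy` is illustrative only.

```python
import pandas as pd

toy = pd.DataFrame({
    "rec_tx_dt": pd.to_datetime(["2010-01-01", "2012-01-01", "2014-01-01"]),
    "rec_dial_dt": pd.to_datetime(["2008-01-01", None, "2015-01-01"]),
})

# Pre-emptive: no dialysis before transplant (missing start date, or start after the transplant)
preemptive = toy["rec_dial_dt"].isna() | (toy["rec_dial_dt"] > toy["rec_tx_dt"])
toy["rec_preemptive_tx"] = preemptive.astype(int)

# Years on dialysis before transplant; forced to 0 for pre-emptive recipients
years = (toy["rec_tx_dt"] - toy["rec_dial_dt"]).dt.days / 365.25
toy["rec_time_on_dialysis"] = years.where(~preemptive, 0.0)
print(toy[["rec_preemptive_tx", "rec_time_on_dialysis"]])
```

Comparisons against `NaT` evaluate to `False`, so the explicit `isna()` term is what marks recipients with no dialysis date as pre-emptive.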

In [None]:
donor_vars = ["don_hist_hyperten", "don_hist_diab", "don_anti_hcv", "don_anti_cmv", "don_creat", "don_cod_don_stroke", "don_wgt_kg", "don_hgt_cm", "don_abo", "don_race", "don_gender", "don_age", "don_ethnicity_srtr"]
tx_ki_raw[donor_vars]

### Donation after cardiac death

In [None]:
donor_df

In [None]:
donor_df = donor_df[['donor_id','don_dcd_progress_to_brain_death']].copy()  # copy avoids SettingWithCopyWarning
donor_df['DCD'] = donor_df['don_dcd_progress_to_brain_death'].apply(lambda x: 1 if x == '' else 0)

In [None]:
dcd_donor_ids = set(donor_df[donor_df['DCD'] == 1].donor_id)
# 1 if the transplant's donor is flagged as DCD, else 0
tx_ki_raw['DCD'] = tx_ki_raw.donor_id.isin(dcd_donor_ids).astype(int)

In [None]:
recipient_vars = ["rec_dr_mm_equiv_tx","rec_b_mm_equiv_tx","rec_a_mm_equiv_tx","rec_cold_isch_tm","can_education","can_tot_albumin","can_periph_vasc","rec_malig","can_drug_treat_hyperten","can_diab_ty","rec_pra_most_recent","rec_dial_dt","rec_prev_ki","rec_ebv_stat","rec_hcv_stat","rec_hbv_antibody","rec_cmv_stat","rec_hiv_stat","rec_bmi","rec_primary_pay","can_abo","rec_dgn","can_race","can_gender","rec_age_at_tx","can_ethnicity_srtr"]
tx_ki_raw[recipient_vars]

### Pre-emptive transplant

In [None]:
# 1 = pre-emptive (no dialysis before transplant): missing dialysis start date, or dialysis started after the transplant
tx_ki_raw['rec_preemptive_tx'] = tx_ki_raw.apply(
    lambda x: 1 if pd.isna(x['rec_dial_dt']) or x['rec_dial_dt'] > x['rec_tx_dt'] else 0,
    axis=1
)

### Time on Dialysis

In [None]:
# Years on dialysis before transplant (0 for pre-emptive recipients)
tx_ki_raw['rec_time_on_dialysis'] = tx_ki_raw.apply(
    lambda x: 0 if pd.isna(x['rec_dial_dt']) or x['rec_dial_dt'] > x['rec_tx_dt'] else (x['rec_tx_dt']-x['rec_dial_dt']).days/365.25,
    axis=1
)

In [None]:
donor_vars.append('DCD')
recipient_vars.append('rec_preemptive_tx')
recipient_vars.append('rec_time_on_dialysis')

In [None]:
tx_ki_raw[donor_vars + recipient_vars]

# Outcomes

In [None]:
#Date boundaries
start_date = datetime(2005,1,1) #January 1, 2005
end_date = datetime(2017, 12, 31) # December 31, 2017

In [None]:
# Study population: transplants between start_date and end_date, adult recipients (age bin >= '18-34'),
# deceased donors (don_ty == 'C'), kidney-only transplants (rec_tx_org_ty == 'KI')
tx_var = ["px_id","tx_id","trr_id","rec_age_at_tx","pers_retx","rec_prev_graft1_dt", "rec_fail_dt", "rec_fail_cause_ty", "rec_resum_maint_dial", "rec_resum_maint_dial_dt", "rec_first_week_dial", "don_ty", "rec_ctr_cd", "rec_ctr_ty", "rec_tx_dt", "rec_tx_org_ty", "rec_acute_rej_episode", "rec_px_stat", "rec_px_stat_dt", "can_prev_ki","can_prev_ki_tx_functn","tfl_death_dt","pers_ssa_death_dt","pers_retx_trr_id","pers_optn_death_dt"] + donor_vars + recipient_vars
population_mask = (
    (start_date <= tx_ki_raw.rec_tx_dt) & (tx_ki_raw.rec_tx_dt <= end_date)
    & (tx_ki_raw.rec_age_at_tx >= '18-34')
    & (tx_ki_raw.don_ty == 'C')
    & (tx_ki_raw.rec_tx_org_ty == 'KI')
)
# .copy() so that the new columns added below do not trigger SettingWithCopyWarning
tx_ki_df = tx_ki_raw.loc[population_mask, tx_var].copy()

In [None]:
txf_var = ["px_id","tx_id","trr_id","tfl_dial_ty", "tfl_acute_rej_episode","tfl_fol_cd","tfl_px_stat","tfl_px_stat_dt", "tfl_resum_maint_dial_dt"]
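# Keep follow-up records up to and including the 1-year visit; the string must match the raw category label exactly (including its leading spaces)
txf_ki_df = txf_ki_raw.loc[txf_ki_raw.tfl_fol_cd <= '          10: 1 YEAR', txf_var].copy()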

## Delayed graft function
- rec_first_week_dial --> Y/N?

In [None]:
tx_ki_df['DGF'] = tx_ki_df.rec_first_week_dial.apply(lambda x: 1 if x == 'Y' else 0)

## One-Year Acute Rejection
- 'all acute rejection episodes reported up to one-year follow up.'
- 'treated as a binary outcome'
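The aggregation below reduces every follow-up row up to one year to a single flag per transplant. Transplants with no follow-up rows then default to 0. Here is a minimal sketch of that logic on toy data. The `'Y'` code is a stand-in: the notebook instead selects the first two categories of `tfl_acute_rej_episode`.

```python
import pandas as pd

followups = pd.DataFrame({
    "trr_id": [1, 1, 2, 2, 3],
    "tfl_acute_rej_episode": ["N", "Y", "N", "N", "N"],
})
transplants = pd.DataFrame({"trr_id": [1, 2, 3, 4]})  # trr_id 4 has no follow-up rows

# True if any follow-up (up to 1 year) reported an acute rejection episode
any_rej = followups.groupby("trr_id")["tfl_acute_rej_episode"].agg(lambda s: s.eq("Y").any())

# trr_ids missing from the follow-up file map to NaN; eq(True) turns those into False -> 0
transplants["AR"] = transplants["trr_id"].map(any_rej).eq(True).astype(int)
print(transplants["AR"].tolist())
```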

In [None]:
txf_ki_df.tfl_acute_rej_episode.dtypes.categories

In [None]:
txf_ki_df.tfl_fol_cd.dtypes.categories

In [None]:
oya_rej_cats = txf_ki_df.tfl_acute_rej_episode.dtypes.categories
txf_ki_df[(txf_ki_df.tfl_acute_rej_episode == oya_rej_cats[0]) | (txf_ki_df.tfl_acute_rej_episode == oya_rej_cats[1])]

In [None]:
len(tx_ki_df.trr_id.unique())

In [None]:
def _oya_aggregation(s):
    return s.isin(oya_rej_cats[0:2]).any()

oya_test = txf_ki_df.groupby('trr_id')['tfl_acute_rej_episode'].agg(_oya_aggregation)

In [None]:
oya_test

In [None]:
def _ar_helper(x):
    try:
        if oya_test[x]:
            return 1
        else:
            return 0
    except KeyError:  # trr_id has no follow-up rows within 1 year
        return 0

tx_ki_df['AR'] = tx_ki_df.trr_id.apply(lambda x: _ar_helper(x))

In [None]:
tx_ki_df

## DCGF
- "defined as the time from KT to graft failure (re-initiation of dialysis or re-KT), censoring for death"
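- Death censoring: suppose the transplant happens at time $k > 0$ and death at time $d > k$. A graft failure at time $g$ counts as an event only if $g \in [k, d]$; if the patient dies first, the observation is censored at $d$ (and at the study end date, December 31, 2017, if neither event occurs).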
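The `_dcgf_helper` cell below applies this rule row by row. Here is a vectorized sketch of the same logic on toy dates. It assumes the notebook's column names (`custom_resum_maint_dial`, `pers_retx`, `custom_death`, `rec_tx_dt`) and a study end of December 31, 2017 (`END`, a name used only here). The earliest of dialysis, re-transplant, death and study end sets the follow-up time. Only dialysis or re-transplant on or before the censoring time counts as an event.

```python
import pandas as pd

END = pd.Timestamp("2017-12-31")  # study end, same date as end_date in the notebook

toy = pd.DataFrame({
    "rec_tx_dt": pd.to_datetime(["2010-01-01", "2010-01-01", "2010-01-01", "2016-01-01"]),
    "custom_resum_maint_dial": pd.to_datetime(["2012-01-01", None, "2014-01-01", None]),
    "pers_retx": pd.to_datetime([None, None, None, None]),
    "custom_death": pd.to_datetime([None, "2011-01-01", "2013-01-01", None]),
})

# Earliest graft-failure marker (row-wise min skips NaT)
graft_fail = toy[["custom_resum_maint_dial", "pers_retx"]].min(axis=1)
# Censoring time: death if it happened before the study end, otherwise the study end
censor = toy["custom_death"].where(toy["custom_death"] < END, END)

# Event only if graft failure happens on or before the censoring time (NaT compares False)
event = graft_fail <= censor
toy["dcgf_event"] = event.astype(int)
toy["dcgf_time"] = graft_fail.where(event, censor) - toy["rec_tx_dt"]
print(toy[["dcgf_event", "dcgf_time"]])
```

Ties between graft failure and death count as events here, which matches the helper's `arg0 == bound or arg1 == bound` check running before the death branch.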

In [None]:
tx_ki_raw[['pers_id','tx_id','trr_id','rec_tx_dt','tfl_death_dt','rec_fail_dt','pers_retx',"pers_ssa_death_dt","pers_retx_trr_id"]]#.dropna(subset=['tfl_death_dt','pers_ssa_death_dt'])

In [None]:
tx_ki_df['custom_death'] = tx_ki_df[['tfl_death_dt','pers_ssa_death_dt','pers_optn_death_dt']].min(axis=1)

In [None]:
tx_ki_df[tx_ki_df.custom_death.isna()]#.iloc[7]#.apply(lambda x: min(x.tfl_death_dt, x.pers_ssa_death_dt, x.pers_optn_death_dt ))

In [None]:
extra_deaths = txf_ki_raw[txf_ki_raw.trr_id.isin(tx_ki_df.trr_id) & (txf_ki_raw.tfl_px_stat == 'D')][['trr_id','tfl_px_stat_dt']]

In [None]:
id_list = set(tx_ki_df.trr_id)

In [None]:
tx_ki_df['extra_deaths'] = pd.NaT
for _,row in extra_deaths.iterrows():
    if row.trr_id in id_list:
        idx = tx_ki_df[tx_ki_df.trr_id == row.trr_id]
        tx_ki_df.loc[idx.index.values,'extra_deaths'] = row.tfl_px_stat_dt

In [None]:
tx_ki_df[tx_ki_df.extra_deaths.notna()]

In [None]:
tx_ki_df['custom_death'] = tx_ki_df[['tfl_death_dt','pers_ssa_death_dt','pers_optn_death_dt','extra_deaths']].min(axis=1)

In [None]:
extra_resume_times = txf_ki_raw[txf_ki_raw.trr_id.isin(tx_ki_df.trr_id)][['trr_id','tfl_resum_maint_dial_dt']].dropna()

In [None]:
tx_ki_df['tfl_resum_maint_dial_dt'] = pd.NaT
for _,row in extra_resume_times.iterrows():
    if row.trr_id in id_list:
        idx = tx_ki_df[tx_ki_df.trr_id == row.trr_id]
        tx_ki_df.loc[idx.index.values,'tfl_resum_maint_dial_dt'] = row.tfl_resum_maint_dial_dt

In [None]:
tx_ki_df['custom_resum_maint_dial'] = tx_ki_df[['rec_resum_maint_dial_dt','tfl_resum_maint_dial_dt']].min(axis=1)

In [None]:
tx_ki_df[tx_ki_df.tfl_resum_maint_dial_dt.notna()]

In [None]:
tx_ki_df[['tx_id','trr_id','rec_resum_maint_dial_dt','rec_resum_maint_dial','custom_death','pers_retx']]

In [None]:
import numpy as np
# For each transplant:
# - From the tx file: date of resuming maintenance dialysis, re-transplant date (pers_retx), death date
#   (recall: end_date = datetime(2017, 12, 31) = December 31, 2017)
# - From the follow-up file: date of resuming maintenance dialysis, death date (should match, but double-check)
# - custom_resum_maint_dial = earliest maintenance-dialysis date across the tx and follow-up files
# - custom_death = earliest death date across the tx and follow-up files
#
# Function logic: boundary = min(maintenance dialysis, re-transplant, death, end_date), ignoring missing dates
# - boundary is dialysis or re-transplant --> graft failure: event = 1, time = boundary - transplant date
# - boundary is death                      --> censored at death: event = 0, time = death - transplant date
# - otherwise (functioning graft at end_date) --> censored: event = 0, time = end_date - transplant date
def _dcgf_helper(row):
    arg0, arg1, arg2 = row[['custom_resum_maint_dial','pers_retx','custom_death']]
    out = np.array([arg0,arg1,arg2,np.datetime64(end_date)])
    bound = np.min(out[~pd.isna(out)])

    if arg0 == bound or arg1 == bound:
        return 1, bound - row.rec_tx_dt
    elif arg2 == bound:
        return 0, bound - row.rec_tx_dt
    else:
        return 0, end_date - row.rec_tx_dt


tx_ki_df[["dcgf_event", "dcgf_time"]] = tx_ki_df.apply(_dcgf_helper, axis=1, result_type='expand')

In [None]:
tx_ki_df[tx_ki_df.dcgf_event == 1][['rec_tx_dt','pers_retx','custom_death','rec_resum_maint_dial_dt','tfl_resum_maint_dial_dt','pers_optn_death_dt','dcgf_event','dcgf_time']]

## ACGF
- "defined as the time from KT to graft failure (defined above) or death"
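- Censored only at the end of the study period (December 31, 2017); unlike DCGF, death counts as an event rather than a censoring time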
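In formula form, the `_dcgf_helper` and `_acgf_helper` cells share one follow-up time and differ only in the event indicator. Let $t_0$ be the transplant date, $R$ the date maintenance dialysis resumes, $X$ the re-transplant date, $D$ the death date and $E$ the study end (December 31, 2017). Any missing date is treated as $+\infty$:

$$
T = \min(R, X, D, E) - t_0
$$

$$
\delta_{\text{DCGF}} = \mathbf{1}\big[\min(R, X) \le \min(D, E)\big], \qquad
\delta_{\text{ACGF}} = \mathbf{1}\big[\min(R, X, D) \le E\big]
$$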

In [None]:
# Same boundary as DCGF, but death is also an event here; only end_date censors
def _acgf_helper(row):
    arg0, arg1, arg2 = row[['custom_resum_maint_dial','pers_retx','custom_death']]
    out = np.array([arg0,arg1,arg2,np.datetime64(end_date)])
    bound = np.min(out[~pd.isna(out)])

    if arg0 == bound or arg1 == bound or arg2 == bound:
        return 1, bound - row.rec_tx_dt
    else:
        return 0, end_date - row.rec_tx_dt


tx_ki_df[["acgf_event", "acgf_time"]] = tx_ki_df.apply(_acgf_helper, axis=1, result_type='expand')

## Death

In [None]:
def _death_time_helper(row):
    death_time = row['custom_death']
    # NaT (no recorded death) compares False, so it falls through to censoring at end_date
    if death_time <= end_date:
        return 1, death_time - row.rec_tx_dt
    else:
        return 0, end_date - row.rec_tx_dt


tx_ki_df[["death_event", "death_time"]] = tx_ki_df.apply(_death_time_helper, axis=1, result_type='expand')

In [None]:
df = tx_ki_df[donor_vars+recipient_vars+['rec_ctr_cd','DGF','AR']+["dcgf_event", "dcgf_time"]+["acgf_event", "acgf_time"]+["death_event", "death_time"]]

In [None]:
df = df.loc[:, ~df.columns.duplicated()]

# Regression

In [None]:
reg_df = df.copy()

## Impute

### Clean categoricals

In [None]:
def _bin_to_numeric_helper(age_bin):
    if age_bin == '50-64':
        return 57.0
    elif age_bin == '35-49':
        return 42.0
    elif age_bin == '65+':
        return 80.5
    elif age_bin == '18-34':
        return 26.0
    
reg_df.rec_age_at_tx = reg_df.rec_age_at_tx.apply(lambda x: x if type(x) == float else _bin_to_numeric_helper(x))

In [None]:
def _map_to_binary_helper(val):
    val = str(val).strip()
    if val.startswith("1: NO"):
        return 0
    elif val.startswith("998"):
        return np.nan
    elif val.startswith(("2:", "3:", "4:", "5:")):
        return 1
    else:
        return np.nan  # catch anything unexpected

reg_df.don_hist_hyperten = reg_df.don_hist_hyperten.apply(_map_to_binary_helper)
reg_df.don_hist_diab = reg_df.don_hist_diab.apply(_map_to_binary_helper)

In [None]:
def _serology_to_binary(val):
    val = str(val).strip().upper()
    if val in ['P', 'PD']:
        return 1
    elif val == 'N':
        return 0
    elif val in ['I', 'ND', 'U', '']:
        return np.nan
    else:
        return np.nan  # catch unexpected values

reg_df.don_anti_cmv = reg_df.don_anti_cmv.apply(_serology_to_binary)
reg_df.don_anti_hcv = reg_df.don_anti_hcv.apply(_serology_to_binary)
reg_df.rec_ebv_stat = reg_df.rec_ebv_stat.apply(_serology_to_binary)
reg_df.rec_hcv_stat = reg_df.rec_hcv_stat.apply(_serology_to_binary)
reg_df.rec_hbv_antibody = reg_df.rec_hbv_antibody.apply(_serology_to_binary)
reg_df.rec_cmv_stat = reg_df.rec_cmv_stat.apply(_serology_to_binary)
reg_df.rec_hiv_stat = reg_df.rec_hiv_stat.apply(_serology_to_binary)

In [None]:
# Map race labels exactly as they appear in the raw data
race_map = {
    '8: White': 'White',
    '16: Black or African American': 'Black/African American',
    '1024: Unknown (for Donor Referral only)': np.nan
}

def map_race_label(val):
    if not isinstance(val, str):
        return np.nan
    val_clean = val.strip()  # remove leading/trailing spaces
    # Mapped codes use race_map (Unknown -> NaN); every other code collapses to 'Other'
    return race_map.get(val_clean, 'Other')

def override_with_ethnicity(race, ethnicity):
    if str(ethnicity).strip() in ['LATINO']:
        return 'Hispanic/Latino'
    return race

reg_df.don_race = reg_df.don_race.apply(map_race_label)
reg_df.can_race = reg_df.can_race.apply(map_race_label)
reg_df.don_race = reg_df.apply(lambda row: override_with_ethnicity(row.don_race, row.don_ethnicity_srtr), axis=1)
reg_df.can_race = reg_df.apply(lambda row: override_with_ethnicity(row.can_race, row.can_ethnicity_srtr), axis=1)

In [None]:
def map_secondary_education(val):
    if not isinstance(val, str):
        return np.nan

    val = val.strip()

    if val in [
        '3: HIGH SCHOOL (9-12) or GED',
        '4: ATTENDED COLLEGE/TECHNICAL SCHOOL',
        '5: ASSOCIATE/BACHELOR DEGREE',
        '6: POST-COLLEGE GRADUATE DEGREE'
    ]:
        return 1
    elif val in [
        '1: NONE',
        '2: GRADE SCHOOL (0-8)'
    ]:
        return 0
    elif val in [
        '998: UNKNOWN',
        '996: N/A (< 5 YRS OLD)'
    ]:
        return np.nan
    else:
        return np.nan

reg_df.can_education = reg_df.can_education.apply(map_secondary_education)

In [None]:
def map_binary_outcome_general(val):
    val = str(val).strip().upper()
    if val == 'Y':
        return 1
    elif val == 'N':
        return 0
    else:  # includes 'U', '', or any unknown codes
        return np.nan

reg_df.can_periph_vasc = reg_df.can_periph_vasc.apply(map_binary_outcome_general)
reg_df.rec_malig = reg_df.rec_malig.apply(map_binary_outcome_general)
reg_df.can_drug_treat_hyperten = reg_df.can_drug_treat_hyperten.apply(map_binary_outcome_general)

In [None]:
def map_diabetes_status(val):
    if not isinstance(val, str):
        return np.nan

    val = val.strip()

    if val in [
        '2: Type I',
        '3: Type II',
        '4: Type Other',
        '5: Type Unknown'
    ]:
        return 1
    elif val == '1: No':
        return 0
    elif val == '998: Diabetes Status Unknown':
        return np.nan
    else:
        return np.nan

reg_df.can_diab_ty = reg_df.can_diab_ty.apply(map_diabetes_status)


In [None]:
def map_medicare_primary(val):
    if not isinstance(val, str):
        return np.nan

    val = val.strip()

    if val in [
        '3: Public insurance - Medicare FFS (Fee for Service)',
        '4: Public insurance - Medicare & Choice'
    ]:
        return 1
    else:
        return 0

reg_df.rec_primary_pay = reg_df.rec_primary_pay.apply(map_medicare_primary)

In [None]:
def map_esrd_cause(val):
    if not isinstance(val, str):
        return np.nan

    val = val.strip().upper()

    if '999' in val or val == '':
        return np.nan
    elif any(keyword in val for keyword in [
        'GLOMERULONEPHRITIS', 'GLOMERULOSCLEROSIS', 'IGA NEPHROPATHY',
        'FSG', 'ANTI-GBM', 'GOODPASTURE', 'RPGN', 'MESANGIO', 'MEMBRANOUS'
    ]):
        return 'Glomerulonephritis'
    elif any(keyword in val for keyword in [
        'DIABETES', 'INSULIN DEP', 'PANCREATITIS'
    ]):
        return 'Diabetes'
    elif 'HYPERTENSION' in val or 'HYPERTENSIVE' in val:
        return 'Hypertension'
    else:
        return 'Others'

reg_df.rec_dgn = reg_df.rec_dgn.apply(map_esrd_cause)

In [None]:
blood_map = {
    'A': 'A',
    'A1': 'A',
    'A2': 'A',
    'B': 'B',
    'AB': 'AB',
    'A1B': 'AB',
    'A2B': 'AB',
    'O': 'O'
}

# Collapse ABO subgroups (A1, A2, A1B, A2B) into the four main blood groups
reg_df.don_abo = reg_df.don_abo.map(blood_map)
reg_df.can_abo = reg_df.can_abo.map(blood_map)

In [None]:
reg_df

### Impute missing vars
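`IterativeImputer` models every column as continuous. Rounding an imputed nominal code (for example the 0-3 codes of `rec_dgn`) can therefore, in principle, land outside the valid range. Such codes would then map to `NaN` in the reference-category step. Here is a minimal sketch of clipping rounded codes back into range. `clip_codes` and the toy values are hypothetical, not part of the original notebook. `IterativeImputer` also accepts `min_value`/`max_value` bounds as an alternative.

```python
import pandas as pd

def clip_codes(col: pd.Series, n_categories: int) -> pd.Series:
    """Hypothetical helper: round imputed codes and clip them into 0..n_categories-1."""
    return col.round().clip(lower=0, upper=n_categories - 1).astype(int)

# Imputed values for a 4-level nominal variable (codes 0-3)
imputed = pd.Series([0.2, 3.7, -0.6, 1.4, 2.6])
print(clip_codes(imputed, n_categories=4).tolist())
```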

In [None]:
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import IterativeImputer

covars = ["don_hist_hyperten", 
          "don_hist_diab", 
          "don_anti_hcv", 
          "don_anti_cmv", 
          "don_creat", 
          "don_cod_don_stroke", 
          "don_wgt_kg", 
          "don_hgt_cm", 
          "don_abo", 
          "don_race", 
          "don_gender", 
          "don_age"] + \
         ["rec_dr_mm_equiv_tx",
          "rec_b_mm_equiv_tx",
          "rec_a_mm_equiv_tx",
          "rec_cold_isch_tm",
          "can_education",
          "can_tot_albumin",
          "can_periph_vasc",
          "rec_malig",
          "can_drug_treat_hyperten",
          "can_diab_ty",
          "rec_pra_most_recent",
          "rec_prev_ki",
          "rec_ebv_stat",
          "rec_hcv_stat",
          "rec_hbv_antibody",
          "rec_cmv_stat",
          "rec_hiv_stat",
          "rec_bmi",
          "rec_primary_pay",
          "can_abo",
          "rec_dgn",
          "can_race",
          "can_gender",
          "rec_age_at_tx"]

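# Subset the data (rec_dial_dt is a date and is deliberately excluded from the imputation covariates)
# .copy() because the categorical conversion below modifies df_subset in place
df_subset = reg_df[covars].copy()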



In [None]:
for column in df_subset.select_dtypes(include=['object']).columns:
    print(column, list(df_subset[column].unique()))

In [None]:
mapping_dicts = {}

cat_mappings = {
    'don_abo': ['A', 'O', 'AB', 'B'],  # Reference 'A'
    'don_race': ['White', 'Black/African American', 'Hispanic/Latino', 'Other'],  # Reference 'White'
    'don_gender': ['M', 'F'],
    'can_abo': ['A', 'O', 'AB', 'B'],  # Reference 'A'
    'rec_dgn': ['Glomerulonephritis', 'Diabetes', 'Hypertension', 'Others'],  # Reference 'Glomerulonephritis'
    'can_race': ['White', 'Black/African American', 'Hispanic/Latino', 'Other'],  # Reference 'White'
    'can_gender': ['M', 'F']
}

def convert_to_nominal_cat(df, col, categories):
    df[col] = pd.Categorical(df[col], categories=categories, ordered=False)
    mapping_dicts[col] = dict(enumerate(df[col].cat.categories))
    df[col] = df[col].cat.codes.replace(-1, np.nan)
    
for col, cats in cat_mappings.items():
    convert_to_nominal_cat(df_subset, col, cats)


In [None]:
for column in df_subset.select_dtypes(include=['object']).columns:
    print(column, list(df_subset[column].unique()))

In [None]:
categorical_covars = ["don_hist_hyperten",
"don_hist_diab",
"don_anti_cmv",
"don_anti_hcv",
"rec_ebv_stat",
"rec_hcv_stat",
"rec_hbv_antibody",
"rec_cmv_stat",
"rec_hiv_stat",
"can_education",
"can_periph_vasc",
"rec_malig",
"can_drug_treat_hyperten",
"can_diab_ty",
"rec_primary_pay",
"rec_dgn",
"don_abo",
"don_gender",
"don_race",
"can_abo",
"can_gender",
"can_race",
"rec_dr_mm_equiv_tx",
"rec_a_mm_equiv_tx",
"rec_b_mm_equiv_tx",
"don_cod_don_stroke"]

In [None]:
random_state=42
# Set up the imputer
imputer = IterativeImputer(max_iter=10, random_state=random_state)

# Fit and transform the subset
imputed_array = imputer.fit_transform(df_subset)

# Convert back to DataFrame
df_imputed_subset = pd.DataFrame(imputed_array, columns=covars, index=reg_df.index)

for col in categorical_covars:
    df_imputed_subset[col] = df_imputed_subset[col].round().astype(int)

for col in df_imputed_subset.columns:
    df_subset[col] = df_imputed_subset[col]

for col in df_subset.columns:
    reg_df[col] = df_subset[col]

In [None]:
reg_df = reg_df.drop(columns=['don_ethnicity_srtr','can_ethnicity_srtr'])

In [None]:
reg_df.dtypes

## Splines
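`linear_spline` builds a piecewise-linear basis. For a variable $x$ with knots $\kappa_1, \dots, \kappa_K$ it returns $K+1$ columns, so the fitted effect is

$$
f(x) = \beta_1 x + \sum_{k=1}^{K} \beta_{k+1}\,(x - \kappa_k)_+, \qquad (u)_+ = \max(u, 0)
$$

For example, `don_age` with knots 30 and 60 gives the columns $x$, $(x-30)_+$ and $(x-60)_+$. The slope is then $\beta_1$ below 30, $\beta_1+\beta_2$ between 30 and 60, and $\beta_1+\beta_2+\beta_3$ above 60.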

In [None]:
def linear_spline(df, var, knots):
    """
    Create linear spline columns for var in df with given knots list.
    Returns DataFrame with spline columns.
    """
    df_splines = pd.DataFrame()
    df_splines[var + '_spline_1'] = df[var]
    for i, k in enumerate(knots):
        df_splines[f'{var}_spline_{i+2}'] = (df[var] - k).clip(lower=0)
    return df_splines

# Linear-spline bases; knot locations are chosen per variable
don_age_splines = linear_spline(reg_df, 'don_age', knots=[30, 60])
don_hgt_cm_splines = linear_spline(reg_df, 'don_hgt_cm', knots=[170])
don_wgt_kg_splines = linear_spline(reg_df, 'don_wgt_kg', knots=[80])
don_creat_splines = linear_spline(reg_df, 'don_creat', knots=[0.8])
rec_age_splines = linear_spline(reg_df, 'rec_age_at_tx', knots=[30, 60])
rec_bmi_splines = linear_spline(reg_df, 'rec_bmi', knots=[30])
rec_serum_albumin_splines = linear_spline(reg_df, 'can_tot_albumin', knots=[3])
rec_pra_splines = linear_spline(reg_df, 'rec_pra_most_recent', knots=[80,90])
rec_cold_splines = linear_spline(reg_df, 'rec_cold_isch_tm', knots=[6,36])
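# Only recipients with dialysis before transplant (time > 0); the missing rows become 0 after concat + fillna(0) below
rec_dial_splines = linear_spline(reg_df.loc[reg_df['rec_time_on_dialysis'] > 0], 'rec_time_on_dialysis', [2,6])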

In [None]:
reg_df.loc[reg_df['rec_time_on_dialysis'] > 0]

In [None]:
reg_df[reg_df.rec_time_on_dialysis.isna()]

## Reference Categories
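The reference level is dropped because a full set of dummies sums to one in every row. That set is therefore collinear with the model intercept. Dropping one level keeps the remaining coefficients identifiable and interpretable as contrasts against the reference. Here is a small check of the collinearity on a made-up blood-group column:

```python
import numpy as np
import pandas as pd

abo = pd.Series(["A", "O", "AB", "B", "O", "A"], name="abo")

full = pd.get_dummies(abo, prefix="abo", dtype=float)  # one column per blood group
reduced = full.drop(columns=["abo_A"])                 # 'A' as the reference level

intercept = np.ones((len(abo), 1))
rank_full = np.linalg.matrix_rank(np.hstack([intercept, full.to_numpy()]))
rank_reduced = np.linalg.matrix_rank(np.hstack([intercept, reduced.to_numpy()]))

# Intercept + 4 dummies = 5 columns but rank 4 (collinear); intercept + 3 dummies = full column rank
print(full.shape[1] + 1, rank_full, reduced.shape[1] + 1, rank_reduced)
```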

In [None]:
mapping_dicts.pop('don_gender')
mapping_dicts.pop('can_gender')

In [None]:
mapping_dicts

In [None]:
for col, cats in mapping_dicts.items():
    reg_df[col + '_cat'] = reg_df[col].map(cats)

In [None]:
reg_df.rec_ctr_cd

In [None]:
def one_hot_encode(df, var, ref):
    """One-hot encode df[var] and drop the reference level's column (if present)."""
    dummies = pd.get_dummies(df[var], prefix=var)
    ref_col = f"{var}_{ref}"
    if ref_col in dummies.columns:
        dummies = dummies.drop(columns=[ref_col])
    return dummies

don_race_dummies = one_hot_encode(reg_df, 'don_race_cat', 'White')
don_abo_dummies = one_hot_encode(reg_df, 'don_abo_cat', 'A')
can_race_dummies = one_hot_encode(reg_df, 'can_race_cat', 'White')
can_abo_dummies = one_hot_encode(reg_df, 'can_abo_cat', 'A')
can_dgn_dummies = one_hot_encode(reg_df, 'rec_dgn_cat', 'Glomerulonephritis')
can_a_mm_dummies = one_hot_encode(reg_df, 'rec_a_mm_equiv_tx',0)
can_b_mm_dummies = one_hot_encode(reg_df, 'rec_b_mm_equiv_tx',0)
can_dr_mm_dummies = one_hot_encode(reg_df, 'rec_dr_mm_equiv_tx',0)

In [None]:
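# Binary covariates, plus the outcome columns and the center code (rec_ctr_cd) carried through to df_model
binary_vars = ['don_hist_hyperten', 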
 'don_hist_diab', 
 'don_anti_hcv', 
 'don_anti_cmv',
 'don_cod_don_stroke',
 'don_gender', 
 'DCD',
 'can_education', 
 'can_periph_vasc', 
 'rec_malig',
 'can_drug_treat_hyperten', 
 'can_diab_ty', 
 'rec_prev_ki',
 'rec_ebv_stat', 
 'rec_hcv_stat', 
 'rec_hbv_antibody', 
 'rec_cmv_stat',
 'rec_hiv_stat', 
 'rec_primary_pay', 
 'can_gender', 
 'DGF', 
 'AR',
 'dcgf_event', 
 'dcgf_time',
 'acgf_event',
 'acgf_time',
 'death_event',
 'death_time',
 'rec_ctr_cd']

reg_df[binary_vars]

In [None]:
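# Time-on-dialysis spline columns live in rec_dial_splines, not in reg_df
spline_cols = [col for col in rec_dial_splines.columns if col.startswith('rec_time_on_dialysis_spline')]

# Zero out the splines for pre-emptive transplants (the column is 'rec_preemptive_tx')
preemptive = reg_df.loc[rec_dial_splines.index, 'rec_preemptive_tx']
rec_dial_splines[spline_cols] = rec_dial_splines[spline_cols].multiply(1 - preemptive, axis=0)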


In [None]:
df_model = pd.concat([
don_age_splines,
don_hgt_cm_splines,
don_wgt_kg_splines,
don_creat_splines,
rec_age_splines,
rec_bmi_splines,
rec_serum_albumin_splines,
rec_pra_splines,
rec_cold_splines,
rec_dial_splines,
don_race_dummies,
don_abo_dummies,
can_race_dummies,
can_abo_dummies,
can_dgn_dummies,
can_a_mm_dummies,
can_b_mm_dummies,
can_dr_mm_dummies,
reg_df[binary_vars]
], axis=1)
df_model.fillna(0,inplace=True)

## Split
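The split is done at the level of transplant centers (`rec_ctr_cd`) rather than individual transplants, so the validation set measures performance at unseen centers. The notebook uses `train_test_split` on the unique center codes. Here is a pandas/numpy-only sketch of the same idea with an explicit leakage check. `split_by_center` and the toy center codes are hypothetical, and its random assignment will not match sklearn's.

```python
import numpy as np
import pandas as pd

def split_by_center(df, center_col, test_size=0.3, seed=42):
    """Hypothetical helper: assign whole centers to train or validation."""
    rng = np.random.default_rng(seed)
    centers = np.array(df[center_col].unique())
    rng.shuffle(centers)
    n_val = int(round(test_size * len(centers)))
    is_val = df[center_col].isin(set(centers[:n_val]))
    return df[~is_val].copy(), df[is_val].copy()

toy = pd.DataFrame({"rec_ctr_cd": list("AABBCCDDEEFFGGHHIIJJ"), "y": range(20)})
train, val = split_by_center(toy, "rec_ctr_cd")

# No center is shared between the two sets, and every row lands in exactly one of them
assert set(train.rec_ctr_cd).isdisjoint(val.rec_ctr_cd)
assert len(train) + len(val) == len(toy)
print(train.shape, val.shape)
```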

In [None]:
random_state = 42

from sklearn.model_selection import train_test_split

centers = df_model.rec_ctr_cd.unique()

train_centers, val_centers = train_test_split(centers, test_size=0.3,random_state=random_state)

train_df = df_model[df_model.rec_ctr_cd.isin(train_centers)].copy()
val_df = df_model[df_model.rec_ctr_cd.isin(val_centers)].copy()

print(train_df.shape)
print(val_df.shape)

## Logistic Regression (AR and DGF)

In [None]:
non_covars = ['DGF', 'AR', 'dcgf_event', 'dcgf_time', 'acgf_event', 'acgf_time', 'death_event', 'death_time', 'rec_ctr_cd']

In [None]:
predictors = [col for col in train_df.columns if col not in non_covars]
predictors

In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, roc_auc_score

X_train = train_df[predictors]
X_valid = val_df[predictors]

y_train_dgf = train_df['DGF']
y_valid_dgf = val_df['DGF']

# Fit model
model_dgf = LogisticRegression(max_iter=10000, class_weight='balanced')  # optional: class_weight
model_dgf.fit(X_train, y_train_dgf)

# Predict and evaluate
y_pred_dgf = model_dgf.predict(X_valid)
y_prob_dgf = model_dgf.predict_proba(X_valid)[:, 1]

print("DGF Classification Report:")
print(classification_report(y_valid_dgf, y_pred_dgf))
print("DGF ROC AUC:", roc_auc_score(y_valid_dgf, y_prob_dgf))


In [None]:
y_train_ar = train_df['AR']
y_valid_ar = val_df['AR']

# Fit model
model_ar = LogisticRegression(max_iter=10000, class_weight='balanced')  # optional: class_weight
model_ar.fit(X_train, y_train_ar)

# Predict and evaluate
y_pred_ar = model_ar.predict(X_valid)
y_prob_ar = model_ar.predict_proba(X_valid)[:, 1]

print("AR Classification Report:")
print(classification_report(y_valid_ar, y_pred_ar))
print("AR ROC AUC:", roc_auc_score(y_valid_ar, y_prob_ar))


# TabPFN

In [None]:
categorical_covars = [
"don_hist_diab",
"don_anti_cmv",
"don_anti_hcv",
"rec_ebv_stat",
"rec_hcv_stat",
"rec_hbv_antibody",
"rec_cmv_stat",
"rec_hiv_stat",
"can_education",
"can_periph_vasc",
"rec_malig",
"can_drug_treat_hyperten",
"can_diab_ty",
"rec_primary_pay",
"rec_dr_mm_equiv_tx",
"rec_a_mm_equiv_tx",
"rec_b_mm_equiv_tx",
"don_cod_don_stroke"]
for col in categorical_covars:
    reg_df[col] = reg_df[col].astype(int)

In [None]:
tab_pfn_df = reg_df[['don_age',
 'don_hgt_cm',
 'don_wgt_kg',
 'don_creat',
 'rec_age_at_tx',
 'rec_bmi',
 'can_tot_albumin',
 'rec_pra_most_recent',
 'rec_cold_isch_tm',
 'rec_time_on_dialysis',
 'don_race',
 'don_abo',
 'can_race',
 'can_abo',
 'rec_dgn',
 'rec_a_mm_equiv_tx',
 'rec_b_mm_equiv_tx',
 'rec_dr_mm_equiv_tx',
 'don_hist_hyperten',
 'don_hist_diab',
 'don_anti_hcv',
 'don_anti_cmv',
 'don_cod_don_stroke',
 'don_gender',
 'DCD',
 'can_education',
 'can_periph_vasc',
 'rec_malig',
 'can_drug_treat_hyperten',
 'can_diab_ty',
 'rec_prev_ki',
 'rec_ebv_stat',
 'rec_hcv_stat',
 'rec_hbv_antibody',
 'rec_cmv_stat',
 'rec_hiv_stat',
 'rec_primary_pay',
 'can_gender',
 'DGF','AR','rec_ctr_cd']]

In [None]:
predictors = ['don_age',
 'don_hgt_cm',
 'don_wgt_kg',
 'don_creat',
 'rec_age_at_tx',
 'rec_bmi',
 'can_tot_albumin',
 'rec_pra_most_recent',
 'rec_cold_isch_tm',
 'rec_time_on_dialysis',
 'don_race',
 'don_abo',
 'can_race',
 'can_abo',
 'rec_dgn',
 'rec_a_mm_equiv_tx',
 'rec_b_mm_equiv_tx',
 'rec_dr_mm_equiv_tx',
 'don_hist_hyperten',
 'don_hist_diab',
 'don_anti_hcv',
 'don_anti_cmv',
 'don_cod_don_stroke',
 'don_gender',
 'DCD',
 'can_education',
 'can_periph_vasc',
 'rec_malig',
 'can_drug_treat_hyperten',
 'can_diab_ty',
 'rec_prev_ki',
 'rec_ebv_stat',
 'rec_hcv_stat',
 'rec_hbv_antibody',
 'rec_cmv_stat',
 'rec_hiv_stat',
 'rec_primary_pay',
 'can_gender']

## Subsampling

In [None]:
random_state = 42

from sklearn.model_selection import train_test_split

centers = tab_pfn_df.rec_ctr_cd.unique()

train_centers, val_centers = train_test_split(centers, test_size=0.3,random_state=random_state)

train_df_tab = tab_pfn_df[tab_pfn_df.rec_ctr_cd.isin(train_centers)].copy()
val_df_tab = tab_pfn_df[tab_pfn_df.rec_ctr_cd.isin(val_centers)].copy()

print(train_df_tab.shape)
print(val_df_tab.shape)

X_train = train_df_tab[predictors]
X_valid = val_df_tab[predictors]

y_train_dgf = train_df_tab['DGF']
y_valid_dgf = val_df_tab['DGF']

In [None]:
# Positions of the categorical predictors within X_train (the matrix TabPFN actually sees)
indices = [X_train.columns.get_loc(col) for col in categorical_covars]
print(indices)

In [None]:
from tabpfn import TabPFNClassifier
from sklearn.metrics import roc_auc_score, accuracy_score

# Initialize and train classifier
clf_base = TabPFNClassifier(device='cuda',
                             categorical_features_indices=indices,
                             ignore_pretraining_limits=True,
                             inference_config = {"SUBSAMPLE_SAMPLES":6000})

  
clf_base.fit(X_train, y_train_dgf)

preds = clf_base.predict_proba(X_valid)
y_eval = np.argmax(preds, axis=1)


print('ROC AUC: ',  roc_auc_score(y_valid_dgf, preds[:,1]), 'Accuracy', accuracy_score(y_valid_dgf, y_eval))

In [None]:
from tabpfn import TabPFNClassifier
y_train_ar = train_df_tab['AR']
y_valid_ar = val_df_tab['AR']

# Initialize and train classifier
clf = TabPFNClassifier(device='cuda',
                             categorical_features_indices=indices,
                             ignore_pretraining_limits=True,
                             inference_config = {"SUBSAMPLE_SAMPLES": 15000})
clf.fit(X_train, y_train_ar)

preds = clf.predict_proba(X_valid)
y_eval = np.argmax(preds, axis=1)

from sklearn.metrics import roc_auc_score, accuracy_score
print('ROC AUC: ',  roc_auc_score(y_valid_ar, preds[:,1]), 'Accuracy', accuracy_score(y_valid_ar, y_eval))

In [None]:
X_train.shape[0]

## TabPFN-DT
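TabPFN-DT, as implemented below, fits a shallow decision tree and trains one TabPFN model per leaf. It then routes each validation row to its leaf's model. The batching and reassembly step is easy to get wrong, so here is a library-free sketch of it. Plain callables stand in for the per-leaf models, and the names `route_by_leaf` and `leaf_models` are hypothetical.

```python
from collections import defaultdict
import numpy as np

def route_by_leaf(leaf_ids, X, leaf_models, n_classes=2):
    """Batch rows by leaf, call each leaf's model once, write results back in original row order."""
    groups = defaultdict(list)
    for row, leaf in enumerate(leaf_ids):
        groups[leaf].append(row)

    out = np.zeros((len(X), n_classes))
    for leaf, rows in groups.items():
        if leaf not in leaf_models:
            raise ValueError(f"No model for leaf {leaf}")
        out[rows] = leaf_models[leaf](X[rows])
    return out

# Toy per-leaf "models" returning constant class probabilities
leaf_models = {
    3: lambda X: np.tile([0.9, 0.1], (len(X), 1)),
    4: lambda X: np.tile([0.2, 0.8], (len(X), 1)),
}
X = np.arange(10).reshape(5, 2)
leaf_ids = np.array([3, 4, 3, 4, 4])
probs = route_by_leaf(leaf_ids, X, leaf_models)
print(probs[:, 1])
```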

In [None]:
random_state = 42

from sklearn.model_selection import train_test_split

centers = tab_pfn_df.rec_ctr_cd.unique()

train_centers, val_centers = train_test_split(centers, test_size=0.3,random_state=random_state)

train_df_tab = tab_pfn_df[tab_pfn_df.rec_ctr_cd.isin(train_centers)].copy()
val_df_tab = tab_pfn_df[tab_pfn_df.rec_ctr_cd.isin(val_centers)].copy()

print(train_df_tab.shape)
print(val_df_tab.shape)

X_train = train_df_tab[predictors]
X_valid = val_df_tab[predictors]

y_train_dgf = train_df_tab['DGF']
y_valid_dgf = val_df_tab['DGF']

In [None]:
from sklearn.tree import DecisionTreeClassifier
from collections import defaultdict

def train_tree_tabpfn(X_train, y_train):
    n_samples = X_train.shape[0]
    max_depth = 3 if n_samples < 100_000 else 10

    # Step 1: Train shallow decision tree
    tree = DecisionTreeClassifier(max_depth=max_depth)
    tree.fit(X_train, y_train)

    # Step 2: Find leaf indices for all training data
    leaf_indices = tree.apply(X_train)

    # Step 3: Train a TabPFN model per leaf
    tabpfn_models = {}
    for leaf_id in set(leaf_indices):
        idx = (leaf_indices == leaf_id)
        X_leaf, y_leaf = X_train[idx], y_train[idx]

        # `indices` (categorical feature positions) is the global list computed for the TabPFN cells above
        model = TabPFNClassifier(device='cuda', categorical_features_indices=indices, ignore_pretraining_limits=True, inference_config={"SUBSAMPLE_SAMPLES": 6000})
        model.fit(X_leaf, y_leaf)
        tabpfn_models[leaf_id] = model

    return tree, tabpfn_models


def predict_tree_tabpfn(tree, tabpfn_models, X_test):
    """Row-by-row version (slow: one TabPFN call per row).
    Kept for reference; use predict_tree_tabpfn_batch instead."""
    leaf_indices = tree.apply(X_test)
    y_pred = []

    for i, leaf_id in enumerate(leaf_indices):
        model = tabpfn_models.get(leaf_id)

        if model is None:
            raise ValueError(f"No TabPFN model found for leaf {leaf_id}")

        pred = model.predict_proba(X_test.iloc[[i]])
        y_pred.append(pred[0])

    return y_pred

def predict_tree_tabpfn_batch(tree, tabpfn_models, X_test):
    """Route each test row to its leaf and predict with that leaf's TabPFN model, one batch per leaf."""
    # Step 1: Get the leaf node for each test sample
    leaf_ids = tree.apply(X_test)

    # Step 2: Group test row positions by leaf ID
    leaf_to_test_indices = defaultdict(list)
    for i, leaf_id in enumerate(leaf_ids):
        leaf_to_test_indices[leaf_id].append(i)

    # Step 3: Run batched inference per leaf
    # (loop variable is test_idx so it does not shadow the global categorical `indices`)
    y_pred = np.zeros((len(X_test), 2))  # binary task: columns = P(class 0), P(class 1)
    for leaf_id, test_idx in leaf_to_test_indices.items():
        model = tabpfn_models.get(leaf_id)
        if model is None:
            raise ValueError(f"No TabPFN model found for leaf {leaf_id}")

        # Batch inference for all test samples in this leaf
        X_leaf = X_test.iloc[test_idx]
        preds = model.predict_proba(X_leaf)  # shape (len(test_idx), 2)

        # Write predictions back to the original row positions
        y_pred[test_idx] = preds

    return y_pred


In [None]:
tree, tabpfn_models = train_tree_tabpfn(X_train, y_train_dgf)

preds = predict_tree_tabpfn_batch(tree, tabpfn_models, X_valid)
y_eval = np.argmax(preds, axis=1)

from sklearn.metrics import roc_auc_score, accuracy_score
print('ROC AUC: ',  roc_auc_score(y_valid_dgf, preds[:,1]), 'Accuracy', accuracy_score(y_valid_dgf, y_eval))
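The evaluation above scores AUROC on the positive-class column `preds[:, 1]` and accuracy on the arg-max label. The distinction matters: passing hard labels to `roc_auc_score` collapses the ranking to two values and can understate the AUC. A small sketch with made-up numbers; the helper name `report_binary` is hypothetical.

```python
import numpy as np
from sklearn.metrics import accuracy_score, roc_auc_score


def report_binary(y_true, proba):
    """Return (AUROC from P(class 1), accuracy from the arg-max label)."""
    proba = np.asarray(proba)
    auc = roc_auc_score(y_true, proba[:, 1])
    acc = accuracy_score(y_true, np.argmax(proba, axis=1))
    return auc, acc


# Made-up example: 4 patients, columns are P(class 0), P(class 1)
y_true = np.array([0, 0, 1, 1])
proba = np.array([[0.9, 0.1], [0.4, 0.6], [0.6, 0.4], [0.1, 0.9]])

auc, acc = report_binary(y_true, proba)
auc_from_labels = roc_auc_score(y_true, np.argmax(proba, axis=1))  # the hard-label mistake
print(auc, acc, auc_from_labels)  # 0.75 0.5 0.5
```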

In [None]:
tree, tabpfn_models = train_tree_tabpfn(X_train, y_train_ar)

preds = predict_tree_tabpfn_batch(tree, tabpfn_models, X_valid)
y_eval = np.argmax(preds, axis=1)

from sklearn.metrics import roc_auc_score, accuracy_score
print('ROC AUC: ',  roc_auc_score(y_valid_ar, preds[:,1]), 'Accuracy', accuracy_score(y_valid_ar, y_eval))
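The TabPFN-DT idea (a shallow tree routes each row to a leaf, and each leaf gets its own model) does not depend on TabPFN itself. The sketch below reproduces the routing logic on synthetic data with scikit-learn's `LogisticRegression` standing in for `TabPFNClassifier`, so it runs on CPU. The stand-in model, the synthetic data, `min_samples_leaf=50`, and the pure-leaf fallback are all assumptions for illustration, not part of the notebook's method; the fallback exists because a single-class leaf cannot be fit by the stand-in and would not yield a two-column `predict_proba`.

```python
import numpy as np
from collections import defaultdict
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.tree import DecisionTreeClassifier


def fit_tree_then_leaf_models(X, y, max_depth=3,
                              make_model=lambda: LogisticRegression(max_iter=1000)):
    """Partition rows with a shallow tree, then fit one (stand-in) model per leaf."""
    tree = DecisionTreeClassifier(max_depth=max_depth, min_samples_leaf=50, random_state=0)
    tree.fit(X, y)
    leaves = tree.apply(X)
    leaf_models = {}
    for leaf_id in np.unique(leaves):
        mask = leaves == leaf_id
        y_leaf = y[mask]
        if np.unique(y_leaf).size < 2:
            # Pure leaf: store the constant P(class 1) instead of a model
            leaf_models[leaf_id] = float(y_leaf[0])
        else:
            leaf_models[leaf_id] = make_model().fit(X[mask], y_leaf)
    return tree, leaf_models


def predict_tree_then_leaf_models(tree, leaf_models, X):
    """Batch-predict class probabilities by routing rows to their leaf's model."""
    proba = np.zeros((len(X), 2))
    groups = defaultdict(list)
    for i, leaf_id in enumerate(tree.apply(X)):
        groups[leaf_id].append(i)
    for leaf_id, rows in groups.items():
        model = leaf_models[leaf_id]
        if isinstance(model, float):
            proba[rows, 1] = model
            proba[rows, 0] = 1.0 - model
        else:
            proba[rows] = model.predict_proba(X[rows])
    return proba


X, y = make_classification(n_samples=2000, n_features=8, random_state=0)
X_tr, X_va, y_tr, y_va = X[:1500], X[1500:], y[:1500], y[1500:]

tree, models = fit_tree_then_leaf_models(X_tr, y_tr)
proba = predict_tree_then_leaf_models(tree, models, X_va)
print(proba.shape, round(roc_auc_score(y_va, proba[:, 1]), 3))
```

Every test row lands in a leaf that also exists in training (the tree is fixed after fitting), so the per-leaf lookup never misses.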

## TabPFN-SQ

In [None]:
import numpy as np
from tabpfn_extensions import TabPFNClassifier
from tabpfn_extensions.embedding import TabPFNEmbedding
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, roc_auc_score


n_repeats = 4
n_support = 10_000
predictions_list = []

# Draw support/query rows from the TRAINING centers only; sampling from tab_pfn_df
# would leak the validation centers (X_valid) into the query set
X_train_full = train_df_tab[predictors]
y_train_full = train_df_tab['DGF']

tabpfn_clf = TabPFNClassifier(device='cuda')
tabpfn_model = TabPFNEmbedding(tabpfn_clf=tabpfn_clf, n_fold=0)

for seed in range(n_repeats):
    rng = np.random.default_rng(seed)  # different support set per repeat

    # Sample 10,000 support (context) rows; the remaining rows form the query set
    support_indices = rng.choice(len(X_train_full), size=n_support, replace=False)
    query_indices = np.setdiff1d(np.arange(len(X_train_full)), support_indices)

    X_support = X_train_full.iloc[support_indices]
    y_support = y_train_full.iloc[support_indices]

    X_query = X_train_full.iloc[query_indices]
    y_query = y_train_full.iloc[query_indices]

    # Embed query and validation rows as test rows, with the support set as context
    # (data_source='train' would return embeddings of the support rows instead)
    query_embeddings = np.asarray(tabpfn_model.get_embeddings(
        X_train=X_support, y_train=y_support, X=X_query, data_source='test'))
    test_embeddings = np.asarray(tabpfn_model.get_embeddings(
        X_train=X_support, y_train=y_support, X=X_valid, data_source='test'))

    # Embeddings may come back as (n_estimators, n_samples, dim); average over estimators
    if query_embeddings.ndim == 3:
        query_embeddings = query_embeddings.mean(axis=0)
        test_embeddings = test_embeddings.mean(axis=0)

    # Logistic-regression head fit on the query embeddings
    lr = LogisticRegression(max_iter=500)
    lr.fit(query_embeddings, y_query)

    predictions_list.append(lr.predict_proba(test_embeddings))

# Average class probabilities over the repeats
final_probs = np.mean(predictions_list, axis=0)
final_preds = np.argmax(final_probs, axis=1)

# AUROC needs scores (P(class 1)), not hard labels
roc = roc_auc_score(y_valid_dgf, final_probs[:, 1])
accuracy = accuracy_score(y_valid_dgf, final_preds)
print(f"TabPFN v2*-SQ (DGF) accuracy: {accuracy:.4f}")
print(f"TabPFN v2*-SQ (DGF) auroc: {roc:.4f}")

In [None]:
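# Too lazy to wrap in function... same procedure as above, for one-year acute rejection (AR)
# Reuses `tabpfn_model` from the previous cell
n_repeats = 4
n_support = 10_000
predictions_list = []

# Training centers only, to avoid leaking validation rows into the query set
X_train_full = train_df_tab[predictors]
y_train_full = train_df_tab['AR']
y_valid_target = val_df_tab['AR']

for seed in range(n_repeats):
    # Seed with the loop variable; a fixed seed would make all repeats identical
    rng = np.random.default_rng(seed)

    support_indices = rng.choice(len(X_train_full), size=n_support, replace=False)
    query_indices = np.setdiff1d(np.arange(len(X_train_full)), support_indices)

    # Positional row selection needs .iloc (plain [] on a DataFrame selects columns)
    X_support = X_train_full.iloc[support_indices]
    y_support = y_train_full.iloc[support_indices]

    X_query = X_train_full.iloc[query_indices]
    y_query = y_train_full.iloc[query_indices]

    query_embeddings = np.asarray(tabpfn_model.get_embeddings(
        X_train=X_support, y_train=y_support, X=X_query, data_source='test'))
    test_embeddings = np.asarray(tabpfn_model.get_embeddings(
        X_train=X_support, y_train=y_support, X=X_valid, data_source='test'))

    # Embeddings may come back as (n_estimators, n_samples, dim); average over estimators
    if query_embeddings.ndim == 3:
        query_embeddings = query_embeddings.mean(axis=0)
        test_embeddings = test_embeddings.mean(axis=0)

    lr = LogisticRegression(max_iter=500)
    lr.fit(query_embeddings, y_query)

    predictions_list.append(lr.predict_proba(test_embeddings))

final_probs = np.mean(predictions_list, axis=0)
final_preds = np.argmax(final_probs, axis=1)

# Evaluate against the AR labels, using P(class 1) for AUROC
roc = roc_auc_score(y_valid_target, final_probs[:, 1])
accuracy = accuracy_score(y_valid_target, final_preds)
print(f"TabPFN v2*-SQ (AR) accuracy: {accuracy:.4f}")
print(f"TabPFN v2*-SQ (AR) auroc: {roc:.4f}")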
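The SQ procedure repeats four steps: draw a support (context) set, embed the remaining query rows and the validation rows conditioned on that support, fit a logistic-regression head on the query embeddings, and average the head's probabilities over repeats. The sketch below shows that loop on synthetic data with a deliberately simple stand-in embedding in place of `TabPFNEmbedding.get_embeddings`. The `toy_embed` function (raw features plus distances to each support-class centroid) and `support_query_ensemble` are hypothetical names for illustration only; the point is to check the resampling and averaging logic on CPU, not to approximate TabPFN.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score


def toy_embed(X_support, y_support, X):
    """Hypothetical support-conditioned embedding: raw features plus
    distances to each support-class centroid."""
    mu0 = X_support[y_support == 0].mean(axis=0)
    mu1 = X_support[y_support == 1].mean(axis=0)
    d0 = np.linalg.norm(X - mu0, axis=1, keepdims=True)
    d1 = np.linalg.norm(X - mu1, axis=1, keepdims=True)
    return np.hstack([X, d0, d1])


def support_query_ensemble(X_pool, y_pool, X_eval, embed, n_support, n_repeats=4):
    """Average class probabilities over repeats of:
    sample support -> embed query and eval rows -> fit LR head on query."""
    probs = []
    for seed in range(n_repeats):
        rng = np.random.default_rng(seed)  # a different support set each repeat
        support = rng.choice(len(X_pool), size=n_support, replace=False)
        query = np.setdiff1d(np.arange(len(X_pool)), support)

        q_emb = embed(X_pool[support], y_pool[support], X_pool[query])
        e_emb = embed(X_pool[support], y_pool[support], X_eval)

        head = LogisticRegression(max_iter=1000).fit(q_emb, y_pool[query])
        probs.append(head.predict_proba(e_emb))
    return np.mean(probs, axis=0)


X, y = make_classification(n_samples=3000, n_features=6, random_state=1)
X_pool, y_pool, X_eval, y_eval = X[:2500], y[:2500], X[2500:], y[2500:]

final_probs = support_query_ensemble(X_pool, y_pool, X_eval, toy_embed, n_support=500)
print(final_probs.shape, round(roc_auc_score(y_eval, final_probs[:, 1]), 3))
```

Wrapping the loop this way would also remove the duplicated DGF/AR cells: call it once per outcome with the matching target column.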