In [39]:
import numpy as np
import pandas as pd

In [258]:
import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor

class CASE(BaseEstimator, TransformerMixin):
    """
    Turn survival data into a per-period binary classification problem and back.

    ``transform`` replicates every record once per period and labels each copy
    with whether the record's observed time reaches that period;
    ``inverse_transform`` folds per-copy predictions back into one curve per
    original record. The remaining methods wrap a user-supplied classifier and
    a ``RandomForestRegressor``.
    """

    def __init__(self, study_period, period_type='year'):
        """
        Initialize the CASE Transformer with period settings.
        
        :param study_period: Maximum number of periods to consider for oversampling.
        :param period_type: Type of period to use ('year', 'month', '6-month', '3-month').
        :raises ValueError: If ``period_type`` is not one of the supported names.
        """
        self.study_period = study_period
        self.period_type = period_type
        # Length of one period, in the same unit as the time field of y
        # (days for the values below).
        self.period_lengths = {
            'year': 365,
            'month': 30,
            '6-month': 182,
            '3-month': 91
        }
        if period_type not in self.period_lengths:
            raise ValueError("Invalid period_type. Choose from 'year', 'month', '6-month', '3-month'.")
        self.period_length = self.period_lengths[period_type]
        self.event_field = None
        self.time_field = None
        # {'train': {record label: [augmented row positions]}, 'test': {...}}
        self.transformation_map = None
        self.classifier = None
        # Both names are kept: `fitted_classifier` is what the prediction code
        # reads, `fitted_clf` is the name older code assigned.
        self.fitted_classifier = None
        self.fitted_clf = None
        self.regressor = None

    def _detect_fields(self, y):
        """
        Detect the event and time fields in the structured array y.

        A boolean or integer field is taken as the event indicator and a
        floating-point field as the time; if several fields match, the last
        one of each kind wins.
        
        :param y: Structured array with event and time fields.
        :raises ValueError: If y is not a structured array or a field is missing.
        """
        fields = getattr(getattr(y, 'dtype', None), 'fields', None)
        if not fields:
            # Plain (non-structured) arrays have dtype.fields == None.
            raise ValueError("Unable to detect the event or time field in y.")

        for name, dtype in fields.items():
            if np.issubdtype(dtype[0], np.bool_) or np.issubdtype(dtype[0], np.integer):
                self.event_field = name
            elif np.issubdtype(dtype[0], np.floating):
                self.time_field = name
        
        if not self.event_field or not self.time_field:
            raise ValueError("Unable to detect the event or time field in y.")

    def transform(self, X, y, is_test_data = False):
        """
        Transform survival data into classification data by oversampling for each period.

        Each record is copied once per period with an extra ``period`` column.
        Records with an observed event (and every record when
        ``is_test_data`` is true) get periods ``0 .. study_period + 1``;
        censored training records get periods ``0 .. surv_period + 1``. The
        label of a copy is 1 when ``period <= surv_period`` and 0 otherwise.
        
        :param X: Feature matrix (pandas DataFrame, shape = (n_samples, n_features)).
        :param y: Structured array with event and time fields.
        :param is_test_data: Store the row mapping under 'test' instead of
            'train' and always expand to the full study period.
        :return: Augmented feature matrix (X_aug), augmented target vector (y_aug).
        :raises ValueError: If X is not a DataFrame, already has a 'period'
            column, or X and y differ in length.
        """
        if not isinstance(X, pd.DataFrame):
            raise ValueError("X must be a Pandas dataframe.")
        if 'period' in X.columns:
            # The augmented frame appends its own 'period' column.
            raise ValueError("X must not already contain a 'period' column.")
        if len(X) != len(y):
            # zip() below would otherwise silently drop the unmatched tail.
            raise ValueError("X and y must contain the same number of records.")

        # Detect event and time fields if not already set
        if self.event_field is None or self.time_field is None:
            self._detect_fields(y)

        if self.transformation_map is None:
            self.transformation_map = {'test': {}, 'train': {}}

        # Initialize lists to store augmented data and targets
        X_aug, y_aug = [], []
        # Built fresh on every call so positions from an earlier call with
        # different data cannot linger in the map.
        index_map = {}

        for (label, features), row_y in zip(X.iterrows(), y):
            # Extract event and survival time for the current record
            event = row_y[self.event_field]
            surv_period = int(np.floor(row_y[self.time_field] / self.period_length))

            # Determine the range of periods for oversampling
            # NOTE(review): a censored record also receives a copy at
            # surv_period + 1 labelled 0 — confirm that is intended.
            if event or is_test_data:
                periods = np.arange(0, self.study_period + 2)
            else:
                periods = np.arange(0, surv_period + 2)

            # Positions of this record's copies inside the augmented data
            aug_indices = []
            for period in periods:
                row_copy = features.copy()
                row_copy['period'] = period

                X_aug.append(row_copy.values)
                y_aug.append(1 if period <= surv_period else 0)
                aug_indices.append(len(X_aug) - 1)

            index_map[label] = aug_indices

        self.transformation_map['test' if is_test_data else 'train'] = index_map

        # Convert the augmented data to Pandas
        X_aug_df = pd.DataFrame(X_aug, columns=X.columns.tolist()+['period'])
        
        # Return Augmented data
        return X_aug_df, y_aug

    def inverse_transform(self, preds, is_test_data= False):
        """
        De-augment the predicted survival probabilities to reconstruct individual survival curves.

        For every original record the predictions of its augmented copies are
        turned into a running product, one value per period.

        :param preds: Predicted probabilities for the augmented data, in the
            row order returned by ``transform`` (array-like).
        :param is_test_data: Use the 'test' mapping instead of the 'train' one.
        :return: 0-d object array wrapping a dict ``{record label: [curve]}``;
            call ``.item()`` on it to get the dict.
        :raises ValueError: If ``transform`` has not been called yet.
        """
        if self.transformation_map is None:
            raise ValueError("No augmentation map found. Ensure that the model has been fitted.")

        # Accept lists as well as arrays; fancy indexing below needs an ndarray.
        preds = np.asarray(preds)

        current_map = self.transformation_map['test'] if is_test_data else self.transformation_map['train']
        survival_curves = {
            idx: np.cumprod(preds[aug_indices]).tolist()
            for idx, aug_indices in current_map.items()
        }

        # Kept for backward compatibility: np.array(dict) yields a 0-d object array.
        return np.array(survival_curves)
    
    def fit_classifier(self, X_aug, y_aug, classifier):
        """
        Fit the given classifier on augmented data and keep it on the instance.
        
        :param X_aug: Augmented feature matrix (as returned by ``transform``).
        :param y_aug: Augmented binary targets (as returned by ``transform``).
        :param classifier: Estimator exposing ``fit``; ``predict_proba`` is
            needed by the prediction methods.
        :return: self
        """
        self.classifier = classifier
        self.fitted_classifier = self.classifier.fit(X_aug, y_aug)
        self.fitted_clf = self.fitted_classifier
        return self

    def fit(self, X_aug, y_aug, classifier):
        """
        Alias of ``fit_classifier`` for callers that use the ``fit`` name.

        :param X_aug: Augmented feature matrix.
        :param y_aug: Augmented binary targets.
        :param classifier: Estimator to fit.
        :return: self
        """
        return self.fit_classifier(X_aug, y_aug, classifier)
    
    def predict_survival_function(self, X):
        """
        Predict survival probabilities for new data.

        Each record is scored once per period ``1 .. study_period`` with the
        period appended as the last feature.

        :param X: Feature matrix (DataFrame or array-like, one record per row).
        :return: Array of shape (n_records, study_period) holding the
            classifier's class-1 probability for every record and period.
        :raises ValueError: If no classifier has been fitted.
        """
        if self.fitted_classifier is None:
            raise ValueError("Classifier is not fitted. Call fit_classifier() first.")

        periods = range(1, self.study_period + 1)
        # Iterating a DataFrame yields column labels, not rows, so frames and
        # arrays are handled explicitly.
        is_frame = isinstance(X, pd.DataFrame)
        features = X if is_frame else np.asarray(X)
        n_records = len(features)
        if n_records == 0:
            return np.empty((0, len(periods)))
        if not is_frame:
            features = features.reshape(n_records, -1)

        per_period = []
        for tau in periods:
            if is_frame:
                x_aug = features.assign(period=tau)
            else:
                x_aug = np.column_stack([features, np.full(n_records, tau)])
            # One batched call per period instead of one call per record.
            per_period.append(self.fitted_classifier.predict_proba(x_aug)[:, 1])

        return np.array(per_period, dtype=float).reshape(len(periods), n_records).T

    def fit_regression(self, X, survival_probs):
        """
        Fit a regression model on the combined dataset to predict survival times.

        The regression target is the index of the first period whose
        probability is below 0.5. Records that never drop below 0.5 get the
        number of periods as target (a bare ``argmax`` would report 0 for them).
        
        :param X: Original feature matrix.
        :param survival_probs: Predicted survival probabilities,
            shape (n_records, n_periods).
        :return: self
        """
        below_half = np.asarray(survival_probs) < 0.5
        first_below = np.argmax(below_half, axis=1)
        y_reg = np.where(below_half.any(axis=1), first_below, below_half.shape[1])
        self.regressor = RandomForestRegressor().fit(X, y_reg)
        return self

    def predict_survival_times(self, X):
        """
        Predict exact survival times using the fitted regression model.

        :param X: Feature matrix.
        :return: Predicted survival times.
        :raises ValueError: If ``fit_regression`` has not been called.
        """
        if self.regressor is None:
            raise ValueError("Regression model is not fitted. Call fit_regression() first.")

        return self.regressor.predict(X)

    def predict_survival_function_for_censored_training_data(self, X_tr, record_probs):
        """
        Handle records with incomplete probability lists by constructing new augmented samples and predicting.

        Lists shorter than ``study_period`` are extended in place with the
        classifier's class-1 probability for periods
        ``len(list) + 1 .. study_period``.
        
        :param X_tr: Original feature DataFrame the records were taken from.
        :param record_probs: Dictionary with records as keys and probability
            lists as values; the 0-d object array returned by
            ``inverse_transform`` is accepted too.
        :return: Updated record probabilities (dict).
        :raises ValueError: If no classifier has been fitted.
        """
        if self.classifier is None:
            raise ValueError("Classifier is not fitted. Call fit_classifier() first.")

        if isinstance(record_probs, np.ndarray) and record_probs.ndim == 0:
            # Unwrap the dict that inverse_transform boxes in a 0-d array.
            record_probs = record_probs.item()

        for record, probs in record_probs.items():
            current_len = len(probs)
            if current_len >= self.study_period:
                continue

            # Keys written by transform() are index labels; fall back to a
            # positional lookup for callers that pass row positions.
            if record in X_tr.index:
                base_record = X_tr.loc[record]
            else:
                base_record = X_tr.iloc[record]

            # NOTE(review): transform() numbers periods from 0, so the first
            # missing period may be `current_len` rather than `current_len + 1`
            # — confirm the intended numbering.
            new_aug_samples = []
            for period in range(current_len + 1, self.study_period + 1):
                # Fresh copy per period: `.values` can be a view of the
                # Series' buffer, so reusing one Series would make every
                # appended row end up with the last period.
                sample = base_record.copy()
                sample['period'] = period
                new_aug_samples.append(sample.values)

            new_aug_samples = np.array(new_aug_samples)

            # Query the model to get probabilities for the missing periods
            new_probs = self.classifier.predict_proba(new_aug_samples)[:, 1]

            # Update the record's probability list with the new probabilities
            probs.extend(new_probs.tolist())

        return record_probs

In [246]:
from sksurv.datasets import load_veterans_lung_cancer
from sksurv.datasets import load_breast_cancer

# Load the Veterans' lung cancer data: a feature frame plus a structured
# target array holding the event flag and the survival time.
(data_x, data_y) = load_veterans_lung_cancer()

# Peek at the first two targets to check the field layout.
data_y[0:2]

array([( True,  72.), ( True, 411.)],
      dtype=[('Status', '?'), ('Survival_in_days', '<f8')])

In [247]:
from sklearn.preprocessing import OneHotEncoder
# One-hot encode the categorical columns and append them to the remaining ones.
categorical_x = data_x.select_dtypes('category')
enc = OneHotEncoder(handle_unknown='ignore')
enc.fit(categorical_x)
encoded_categoricals = pd.DataFrame(
    enc.transform(categorical_x).toarray(),
    columns=enc.get_feature_names_out(),
    # concat(axis=1) aligns on the index, so reuse data_x's row labels instead
    # of relying on them happening to equal a fresh RangeIndex.
    index=data_x.index,
)
data_x_encoded = pd.concat(
    [data_x.select_dtypes(exclude='category'), encoded_categoricals],
    axis=1,
)


In [248]:
# CASE configuration: up to 3 periods, each one 'year' long
# (365 days according to CASE.period_lengths).
study_period = 3
period_type = 'year'

# Build the transformer from the settings above.
case_model = CASE(study_period=study_period, period_type=period_type)

In [249]:
# Positional hold-out split: the first N_TRAIN records train the model,
# the remainder is kept for testing.
N_TRAIN = 100

data_x_tr = data_x_encoded.iloc[:N_TRAIN]
data_y_tr = data_y[:N_TRAIN]
data_x_test = data_x_encoded.iloc[N_TRAIN:]
data_y_test = data_y[N_TRAIN:]

In [250]:
# Show the held-out targets: a structured array of (Status, Survival_in_days).
data_y_test

array([( True,  99.), ( True,  61.), ( True,  25.), ( True,  95.),
       ( True,  80.), ( True,  51.), ( True,  29.), ( True,  24.),
       ( True,  18.), (False,  83.), ( True,  31.), ( True,  51.),
       ( True,  90.), ( True,  52.), ( True,  73.), ( True,   8.),
       ( True,  36.), ( True,  48.), ( True,   7.), ( True, 140.),
       ( True, 186.), ( True,  84.), ( True,  19.), ( True,  45.),
       ( True,  80.), ( True,  52.), ( True, 164.), ( True,  19.),
       ( True,  53.), ( True,  15.), ( True,  43.), ( True, 340.),
       ( True, 133.), ( True, 111.), ( True, 231.), ( True, 378.),
       ( True,  49.)],
      dtype=[('Status', '?'), ('Survival_in_days', '<f8')])

In [251]:
# Expand both splits into per-period rows. The test split is flagged so that
# its row mapping is stored separately and every record spans the full study period.
x_case_tr, y_case_tr = case_model.transform(X=data_x_tr, y=data_y_tr)
x_case_test, y_case_test = case_model.transform(
    X=data_x_test,
    y=data_y_test,
    is_test_data=True,
)

In [252]:
# The class exposes `fit_classifier` (it defines no `fit`); the forest is
# seeded so that a fresh Restart & Run All reproduces the same predictions.
case_model.fit_classifier(x_case_tr, y_case_tr, RandomForestClassifier(random_state=42))

CASE(study_period=3)

In [253]:
# Class-1 probability for every augmented test row. `classifier` is the
# attribute `fit_classifier` always sets (the estimator is fitted in place).
x_preds = case_model.classifier.predict_proba(x_case_test)[:, 1]

In [254]:
# Fold the per-row predictions back into one curve per test record. The result
# is a 0-d object array wrapping a dict keyed by the record's index label.
case_model.inverse_transform(x_preds, is_test_data=True)

array({100: [0.91, 0.0364, 0.0, 0.0, 0.0], 101: [0.91, 0.0637, 0.0006370000000000001, 0.0, 0.0], 102: [0.93, 0.0651, 0.0006510000000000001, 0.0, 0.0], 103: [0.95, 0.08549999999999999, 0.00171, 3.42e-05, 6.839999999999999e-07], 104: [0.93, 0.0279, 0.0, 0.0, 0.0], 105: [0.86, 0.043000000000000003, 0.00043000000000000004, 4.3e-06, 4.3e-08], 106: [0.97, 0.097, 0.00388, 0.0, 0.0], 107: [0.92, 0.1748, 0.005244, 0.00010488000000000001, 2.0976000000000005e-06], 108: [0.91, 0.1274, 0.007644000000000001, 0.00022932, 2.2932e-06], 109: [0.95, 0.418, 0.02926, 0.0002926, 2.926e-06], 110: [0.9, 0.06300000000000001, 0.0012600000000000003, 2.5200000000000006e-05, 2.520000000000001e-07], 111: [0.97, 0.0776, 0.0, 0.0, 0.0], 112: [0.85, 0.0935, 0.0056099999999999995, 0.00011219999999999999, 2.244e-06], 113: [0.92, 0.1748, 0.0034960000000000004, 0.0, 0.0], 114: [0.93, 0.1023, 0.0, 0.0, 0.0], 115: [0.95, 0.08549999999999999, 0.00171, 0.0, 0.0], 116: [0.93, 0.037200000000000004, 0.0007440000000000001, 0.0, 0

In [257]:
# Show the encoded test features; the index labels match the keys of the
# survival-curve mapping displayed above.
data_x_test

Unnamed: 0,Age_in_years,Karnofsky_score,Months_from_Diagnosis,Celltype_adeno,Celltype_large,Celltype_smallcell,Celltype_squamous,Prior_therapy_no,Prior_therapy_yes,Treatment_standard,Treatment_test
100,62.0,85.0,4.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,1.0
101,71.0,70.0,2.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,1.0
102,70.0,70.0,2.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,1.0
103,61.0,70.0,1.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,1.0
104,71.0,50.0,17.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,1.0
105,59.0,30.0,87.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,1.0
106,67.0,40.0,8.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,1.0
107,60.0,40.0,2.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0
108,69.0,40.0,5.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0
109,57.0,99.0,3.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0
