# Survival Tree

Here we show an implementation of a survival tree. It is a scaffold that can be:

* scaled up into package-level code
* used for building survival forests

In [1]:
from __future__ import print_function
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import collections
import operator

import lifelines
from lifelines.statistics import logrank_test
from lifelines.utils import concordance_index

### Example Dataset

This is a Hodgkin's lymphoma dataset with the following variables:

* **trtgiven** (treatment given):
    - RT = radiation 
    - CMT = chemotherapy and radiation
    
* **medwidsi** (mediastinum involvement):
    - N = no
    - S = small  
    - L = large
    
* **extranod** (extranodal disease):
    - Y = extranodal disease
    - N = nodal disease
    
* **clinstg** (clinical stage):
    - 1 = stage I
    - 2 = stage II
    
* **status** (after the transformation below, which drops rows with raw status 1 and recodes raw status 2 as the event):
    - 1 - Death (Event)
    - 0 - Censored

In [3]:
# Load the Hodgkin's dataset and drop its first column by position.
# NOTE(review): presumably the first CSV column is a saved row index; confirm,
# because positional slicing would otherwise drop a real variable.
# `display` is not imported here; it is supplied by the IPython/Jupyter environment.
df = pd.read_csv('../datasets/hd.csv')
df= df.iloc[:,1:]
print(df.shape)
display(df.head())

(865, 8)


Unnamed: 0,age,sex,trtgiven,medwidsi,extranod,clinstg,time,status
0,64.0,F,RT,N,N,1,3.1,2
1,63.0,M,RT,N,N,1,15.9,2
2,17.0,M,RT,N,N,2,0.9,1
3,63.0,M,RT,N,N,2,13.1,2
4,21.0,M,RT,L,N,2,35.9,0


In [26]:
# Print the frequency of each level of the listed columns, including the raw
# (not yet recoded) status codes.
for col in ['trtgiven', 'medwidsi', 'extranod', 'clinstg','status']:
    print(df[col].value_counts())
    print("\n")

RT     616
CMT    249
Name: trtgiven, dtype: int64


N    464
S    288
L    113
Name: medwidsi, dtype: int64


N    786
Y     79
Name: extranod, dtype: int64


2    569
1    296
Name: clinstg, dtype: int64


0    439
1    291
2    135
Name: status, dtype: int64




In [30]:
# Drop rows with raw status 1 and recode raw status 2 as the event (1);
# raw status 0 stays 0 (censored).
# NOTE(review): status-1 rows are removed rather than treated as censored;
# confirm this is the intended handling.
# .copy() makes the filtered frame independent, so the column assignment below
# is not chained assignment on a possible view (SettingWithCopyWarning).
df = df[df.status != 1].copy()
print(df.shape)
df['status'] = (df['status'] == 2).astype(int)
display(df.head(3))

(574, 8)


Unnamed: 0,age,sex,trtgiven,medwidsi,extranod,clinstg,time,status
0,64.0,F,RT,N,N,1,3.1,1
1,63.0,M,RT,N,N,1,15.9,1
3,63.0,M,RT,N,N,2,13.1,1


In [31]:
# Confirm the recoded status distribution: 0 = censored, 1 = event.
print(df['status'].value_counts())
print("\n")

0    439
1    135
Name: status, dtype: int64




# Survival Tree: Implementation

## Class and Function Definitions

In [20]:
def class_counts(df, column):
    """Tally how often each distinct value occurs in ``df[column]``.

    @input: df <pd.DataFrame>, column <string>
    @returns: collections.Counter mapping each value to its number of occurrences
    """
    tally = collections.Counter()
    tally.update(df[column])
    return tally

In [21]:
def isNumeric(value):
    """Return True if *value* is a numeric scalar (split with ordering operators).

    Python ``int``/``float`` and NumPy integer/floating scalars count as numeric.
    The NumPy types must be listed explicitly: ``Series.unique()`` on a numeric
    column yields NumPy scalars, and ``numpy.int64`` is not a subclass of Python
    ``int``, so integer columns would otherwise be treated as categorical.
    ``bool`` counts as numeric because it subclasses ``int``.
    """
    return isinstance(value, (int, float, np.integer, np.floating))


def isCategorical(value):
    """Return True if *value* should only be split with equality tests.

    Every value that is not numeric (per ``isNumeric``) is categorical.
    """
    return not isNumeric(value)

In [22]:
class Splitter(object):
    """A binary test ``<attribute> <operation> <target>`` held by a tree node.

    partition() uses it to divide a dataset into the rows that pass the test
    (the true branch) and the rows that fail it (the false branch).

    Attributes:
        attribute -- name of the column the test reads
        target -- value the column is compared against
        operation -- comparison operator as a string, e.g. '<=' or '=='
    """

    def __init__(self, attribute, operation, target):
        self.attribute, self.target, self.operation = attribute, target, operation

    def details(self):
        """Return a human-readable description of the test."""
        condition = " ".join([self.attribute, self.operation, str(self.target)])
        return 'The splitter condition is: ' + condition

In [23]:
def partition(df, splitter):
    """Partition a dataset into the rows that satisfy a splitter and the rest.

    @input: DataFrame <pandas.DataFrame>, splitter <Splitter>
    @returns: (true_rows, false_rows) <pd.DataFrame, pd.DataFrame>
    @raises: SyntaxError for an unknown operator with a numeric target,
             TypeError for an ordering operator with a categorical target.

    Every row lands in exactly one of the two frames: false_rows is the exact
    complement of true_rows. Building it from the negated mask (rather than
    from the opposite comparison) keeps rows whose attribute is NaN, for which
    <, >, <=, >= and == all evaluate to False, in the false branch, which is
    where ``predict`` routes such values. Previously those rows were silently
    dropped from both branches.
    """
    numeric_ops = {'<': operator.lt, '>': operator.gt,
                   '<=': operator.le, '>=': operator.ge,
                   '==': operator.eq, '!=': operator.ne}
    # Categorical values have no order, so only equality tests are allowed.
    categorical_ops = {'==': operator.eq, '!=': operator.ne}

    comparison = splitter.operation
    column = df[splitter.attribute]
    target = splitter.target

    if isNumeric(target):
        if comparison not in numeric_ops:
            raise SyntaxError("Unsupported comparison operator: %r" % (comparison,))
        mask = numeric_ops[comparison](column, target)
    elif isCategorical(target):
        if comparison not in categorical_ops:
            raise TypeError("Operator %r is not valid for categorical target %r"
                            % (comparison, target))
        mask = categorical_ops[comparison](column, target)
    else:
        raise TypeError("Unsupported target type: %r" % (type(target),))

    return df[mask], df[~mask]

In [57]:
def find_best_split(df, column=None, alpha=0.05, time_col='time', event_col='status'):
    """Find the split that best separates the survival of its two groups.

    Every (attribute, operator, value) candidate is tried on every column except
    the duration and event columns. Each split that divides the data is scored
    with the log-rank test statistic comparing the two groups. Splits whose
    p-value is not below ``alpha`` score 0.0. Ties go to the candidate tried last.

    @input: df <pd.DataFrame> -- must contain ``time_col`` and ``event_col``
            column <string> -- unused; accepted so that calls passing the leaf
                column as a second argument keep working
            alpha <float> -- significance level of the log-rank test (default 0.05)
            time_col <string> -- duration column (default 'time')
            event_col <string> -- event indicator column, 1 = event (default 'status')
    @returns: best_sepr <float>, best_splitter <Splitter or None>
              best_splitter is None (with best_sepr 0.0) only when no candidate
              divides the data.
    """
    best_sepr = 0.0  # Best separation (log-rank statistic) seen so far
    best_splitter = None  # Keep track of the splitter that produced it

    attributes = [c for c in df.columns if c not in (time_col, event_col)]

    for attr in attributes:
        # NaN is never a useful candidate: comparing a column against NaN gives
        # the same result for every row, so one side of the split is always
        # empty. Dropping it also ensures values[0] is a real value for the
        # type check below, and that an all-NaN column is skipped instead of
        # raising IndexError.
        values = df[attr].dropna().unique()
        if len(values) == 0:
            continue
        if isNumeric(values[0]):
            operations = ('>', '>=', '<', '<=', '==', '!=')
        else:
            operations = ('==', '!=')

        # For each unique value, try every applicable operator
        for val in values:
            for operation in operations:
                splitter = Splitter(attr, operation, val)
                true_branch, false_branch = partition(df, splitter)

                # Skip this split if it doesn't divide the dataset.
                if true_branch.empty or false_branch.empty:
                    continue

                # Log-rank test (lifelines) comparing the survival of the two
                # groups. Durations are passed positionally because their
                # keyword names differ between lifelines versions.
                result = logrank_test(true_branch[time_col],
                                      false_branch[time_col],
                                      event_observed_A=true_branch[event_col],
                                      event_observed_B=false_branch[event_col])
                sepr = result.test_statistic if result.p_value < alpha else 0.0

                if sepr >= best_sepr:
                    best_sepr, best_splitter = sepr, splitter

    return best_sepr, best_splitter

In [41]:
class Leaf:
    """Terminal node of the survival tree.

    ``predictions`` is the Counter returned by class_counts: it maps each
    distinct value of ``df[column]`` among the rows that reached this leaf to
    how many times it occurs. The notebook builds its tree with column 'time',
    so each leaf stores the 'time' values of its training rows.
    """

    def __init__(self, df, column):
        counts = class_counts(df, column)
        self.predictions = counts

In [42]:
class Decision_Node:
    """Internal node: a Splitter plus the subtrees for rows that pass or fail it.

    true_branch and false_branch each hold either another Decision_Node or a Leaf.
    """

    def __init__(self, splitter, true_branch, false_branch):
        self.splitter = splitter
        self.true_branch, self.false_branch = true_branch, false_branch

In [93]:
def build_tree(df, column, stop_val):
    """Recursively grow a survival tree.

    @input: df <pd.DataFrame> -- rows reaching this node; split candidates and
                their scores come from find_best_split
            column <string> -- column whose value counts each Leaf stores
            stop_val <float> -- a node is split only when the best separation
                found is strictly greater than this value
    @returns: Decision_Node or Leaf <object>
    """
    # Searching for a split needs only the dataframe; `column` matters only
    # for the leaves.
    sepr, splitter = find_best_split(df)

    # Base case: no split separates the data by more than stop_val, or no
    # candidate divided the data at all. In the latter case splitter is None
    # and the reported separation is 0, so without this check a negative
    # stop_val would pass None to partition().
    if splitter is None or sepr <= stop_val:
        return Leaf(df, column)

    # A useful split was found: partition on it and recurse on each side.
    true_rows, false_rows = partition(df, splitter)
    true_branch = build_tree(true_rows, column, stop_val)
    false_branch = build_tree(false_rows, column, stop_val)

    # The node records the chosen splitter and the subtree for each answer.
    return Decision_Node(splitter, true_branch, false_branch)

In [48]:
def print_tree(node, spacing=""):
    """Print a tree top-down, indenting each deeper level by two spaces.

    A decision node prints its splitter condition, then its true and false
    subtrees under '|--> True:' / '|--> False:' headers. A leaf prints its
    stored prediction counts.
    """
    if not isinstance(node, Leaf):
        print(spacing + str(node.splitter.details()))
        child_spacing = spacing + "  "
        for header, child in (('|--> True:', node.true_branch),
                              ('|--> False:', node.false_branch)):
            print(spacing + header)
            print_tree(child, child_spacing)
        return
    print(spacing + "Predict", node.predictions)

### Training on Hodgkin's data

We train on the first 100 rows of the data.

In [95]:
%%time
# Grow a tree on the first 100 rows. Leaves store counts of 'time' values, and a
# node is split only when its best log-rank separation exceeds 0.1.
myTree = build_tree(df.iloc[:100,:], 'time', 0.1)

CPU times: user 5min 24s, sys: 11.5 s, total: 5min 35s
Wall time: 1min 29s


In [101]:
# Show the learned splits and the 'time' counts stored at each leaf.
print_tree(myTree)

The splitter condition is: age >= 58.43
|--> True:
  Predict Counter({13.1: 2, 0.6: 1, 2.7: 1, 3.1: 1, 4.3: 1, 5.3: 1, 12.0: 1, 9.5: 1, 10.2: 1, 11.1: 1, 13.0: 1, 14.7: 1, 15.9: 1, 3.0: 1, 9.0: 1})
|--> False:
  The splitter condition is: age == 37.0
  |--> True:
    Predict Counter({10.6: 1})
  |--> False:
    The splitter condition is: age == 49.0
    |--> True:
      Predict Counter({12.6: 1})
    |--> False:
      The splitter condition is: age == 47.3
      |--> True:
        Predict Counter({17.9: 1})
      |--> False:
        The splitter condition is: age == 42.0
        |--> True:
          Predict Counter({19.2: 1, 21.0: 1})
        |--> False:
          The splitter condition is: age <= 39.0
          |--> True:
            The splitter condition is: age == 30.0
            |--> True:
              Predict Counter({27.5: 1, 20.3: 1, 21.1: 1})
            |--> False:
              The splitter condition is: age == 29.62
              |--> True:
                Predict Counter

# Prediction

In [97]:
# Comparison function for each operator string a Splitter can hold; predict()
# applies it to one observation's attribute value and the splitter's target.
operatorMap = dict(zip(
    ('>', '>=', '<', '<=', '==', '!='),
    (operator.gt, operator.ge, operator.lt, operator.le, operator.eq, operator.ne),
))

In [98]:
def predict(observation, node):
    """Route one observation from `node` down to a leaf and return its counts.

    @input: observation -- a row indexed by attribute name (e.g. a row from
                DataFrame.iterrows); node -- a Decision_Node or a Leaf
    @returns: the ``predictions`` Counter of the Leaf that is reached

    At each Decision_Node, the splitter's operator (looked up in operatorMap)
    is applied to the observation's attribute value and the splitter's target.
    A truthy result continues down the true branch, otherwise the false branch.
    """
    current = node
    while not isinstance(current, Leaf):
        rule = current.splitter
        passed = operatorMap[rule.operation](observation[rule.attribute], rule.target)
        current = current.true_branch if passed else current.false_branch
    return current.predictions

In [99]:
def print_leaf(counts):
    """Convert leaf counts into integer percentages for display.

    Each value becomes its share of the total times 100, truncated by int().
    An empty mapping yields an empty dict.
    """
    total = float(sum(counts.values()))
    return {label: int(n / total * 100) for label, n in counts.items()}

In [100]:
def getMean(myDict):
    """Return the mean of the keys of *myDict*, weighted by their values.

    Keys and values are converted with float() before multiplying. With plain
    Python numbers, weights that sum to zero (e.g. an empty dict) raise
    ZeroDivisionError.
    """
    weighted_total = sum(float(k) * float(v) for k, v in myDict.items())
    return weighted_total / sum(myDict.values())

**Training performance**

In [102]:
# Score each training row with the mean 'time' value of the training rows in its
# leaf; a larger score means longer predicted survival.
# `predict` is the routing function defined above. getMean is applied to the raw
# leaf counts because print_leaf truncates them to integer percentages, which
# distorts the weights and can round every weight to 0 in a large leaf.
train_rows = df.iloc[:100, :]
y_train_pred = [getMean(predict(obs, myTree)) for _, obs in train_rows.iterrows()]

concordance_index(train_rows['time'].values, y_train_pred, train_rows['status'].values)

0.895826025955805

### Testing on Hodgkin's data

We test on the next 50 rows (rows 100-149), which were not used for training.

In [103]:
# Score the held-out rows 100-149 with the mean 'time' value of the training rows
# in their leaf; a larger score means longer predicted survival.
# `predict` is the routing function defined above. getMean is applied to the raw
# leaf counts because print_leaf truncates them to integer percentages, which
# distorts the weights and can round every weight to 0 in a large leaf.
test_rows = df.iloc[100:150, :]
y_test_pred = [getMean(predict(obs, myTree)) for _, obs in test_rows.iterrows()]

concordance_index(test_rows['time'].values, y_test_pred, test_rows['status'].values)

0.7759674134419552

This is a good concordance index for a survival task, although it is measured on only 50 samples. Performance was not the goal of this analysis, however; it is just a proof of concept that the algorithm runs.

## Key features of the implementation


1. Can be scaled up to write the first package for Survival Tree and Survival Random Forest in Python
2. This implementation handles categorical variables as-is, as R does, instead of requiring them to be one-hot encoded as scikit-learn's tree implementation does

## What next?

1. Study scikit-learn's tree implementation and reuse as much of its code as possible  
2. Introduce our way of handling categorical variables  
3. Add survival split *criteria*, i.e. criteria that measure the difference in survival between the resulting groups  
