# **Clinical Survival Convolutional Neural Network**

In [None]:
import tensorflow as tf 
import pandas as pd 
import numpy as np 
import matplotlib.pyplot as plt 
import scipy
from typing import Dict, Iterable, Sequence, Tuple, Optional
import pathlib
from pathlib import Path 
pip install lifelines
from lifelines.utils import concordance_index

## **Loading Data and Preprocessing**

### **CT scans**

In [None]:
train_folder = pathlib.Path("/content/drive/My Drive/x_train/images")
# Collect every entry of the training image folder as a string path, in
# lexicographic order so that the scan order is reproducible across runs.
all_image_paths = sorted(str(img_path) for img_path in train_folder.glob("*"))
images = all_image_paths
len(images)

CT scans store raw voxel intensities in Hounsfield units (HU). On this scale, air = −1000 HU, lung ≈ −500 HU, water = 0 HU, soft tissue (and blood) ≈ +50 HU, and bone ≈ +1000 HU. A lung window can be applied to highlight lung tumors and to normalize the CT scans: intensities are clipped to the window and rescaled to [0, 1]. The code below uses a window from −1200 HU to +400 HU.

In [127]:
def read_file(filepath):
    """Load the array stored under the ``'scan'`` key of an ``.npz`` archive.

    Parameters
    ----------
    filepath : str or path-like
        Location of the ``.npz`` archive.

    Returns
    -------
    np.ndarray
        The array saved under the ``'scan'`` key.
    """
    # For .npz archives ``np.load`` returns a lazy ``NpzFile`` that keeps the
    # underlying file open. The context manager closes it as soon as the
    # array has been read, so loading many scans does not leak file handles.
    with np.load(filepath) as archive:
        return archive['scan']

def normalize(volume, hu_min=-1200, hu_max=400):
    """Clip a volume to an intensity window and rescale it to ``[0, 1]``.

    Parameters
    ----------
    volume : np.ndarray
        Raw voxel intensities. The array is NOT modified; a new array is
        returned.
    hu_min : number, optional, default=-1200
        Lower bound of the window; values below it are mapped to 0.
    hu_max : number, optional, default=400
        Upper bound of the window; values above it are mapped to 1.

    Returns
    -------
    np.ndarray
        ``float32`` array of the same shape as `volume` with values in
        ``[0, 1]``.
    """
    # np.clip returns a new array, so the caller's data is left untouched
    # (boolean-mask assignment on `volume` would overwrite it in place).
    clipped = np.clip(volume, hu_min, hu_max)
    scaled = (clipped - hu_min) / (hu_max - hu_min)
    return scaled.astype("float32")

def process_scan(path):
    """Load one scan from `path` with `read_file` and rescale it with `normalize`."""
    return normalize(read_file(path))

In [None]:
# Load and rescale every scan listed in `images`, then stack the results into
# a single array (first axis = scan index, in the order of `images`).
# NOTE(review): this holds all scans in memory at once and presumably relies on
# every scan having the same shape - confirm.
normal_scans = np.array([process_scan(path) for path in images])
normal_scans.shape

### **Clinical Data**

In [None]:
# Load the clinical table, indexed by patient identifier.
clinical=pd.read_csv("/content/drive/My Drive/x_train/features/clinical_data_train.csv", index_col='PatientID')

#Preprocessing
# Merge T stage 5 into T stage 4.
clinical['Tstage']=clinical.Tstage.apply(lambda x: 4 if x==5 else x)
# Encode histology as an integer code: 0 = adenocarcinoma, 1 = large cell,
# 2 = squamous cell carcinoma, 3 = anything else (any other spelling or a
# missing value also ends up in this bucket).
clinical['Histology_cat']=clinical.Histology.apply(lambda x: 0 if x in ('Adenocarcinoma','adenocarcinoma') 
                                                               else 1 if x=='large cell' 
                                                               else 2 if x in('squamous cell carcinoma', 'Squamous cell carcinoma') 
                                                               else 3)
# One-hot encode the categorical columns; drop_first removes the first level of
# each variable so it acts as the reference category.
categories = ['Nstage','Tstage','Histology_cat']
clinical = pd.get_dummies(clinical, columns=categories, drop_first=True)
# The raw histology string and the M stage are not used as features.
clinical=clinical.drop(columns=['Histology','Mstage'])
# Sort by PatientID.
# NOTE(review): presumably this is meant to match the order of the sorted image
# paths - confirm that both orderings agree.
clinical=clinical.sort_index()
# Binary source indicator: 'l1' -> 0, every other value -> 1.
clinical.SourceDataset=clinical.SourceDataset.apply(lambda x: 0 if x=='l1' else 1)

from sklearn.impute import SimpleImputer

# Despite its name, `imp_mean` imputes missing values with the column MEDIAN.
# The fitted imputer is reused later for the test set.
imp_mean = SimpleImputer(strategy='median')
clinical_=imp_mean.fit_transform(clinical)
# fit_transform returns a bare array; restore the index and column labels.
clinical=pd.DataFrame(clinical_, index=clinical.index, columns=clinical.columns)
clinical.head()


### **Target and Censorship Variables**

In [131]:
# Load the targets (the `SurvivalTime` and `Event` columns are used below),
# indexed by the first CSV column, and sort by that index.
# NOTE(review): the index is presumably the PatientID, so that rows line up
# with the sorted clinical table and the sorted scans - confirm.
y_train = pd.read_csv('/content/drive/My Drive/y_train.csv', index_col=0)
y_train=y_train.sort_index()

### **Splitting Train Data in Training and Validation**

In [132]:
# Materialise the survival times and event indicators as plain Python lists,
# in the (sorted) row order of `y_train`.
survivaltime = list(y_train.SurvivalTime)
event = list(y_train.Event)

In [133]:
# Split the data by position: the first 260 samples are used for training and
# all remaining samples for validation.
# NOTE(review): the original comment described this as a 70-30 split; the
# actual ratio depends on the total number of scans - confirm.
# NOTE(review): scans are ordered by sorted file path while the clinical data
# and targets are ordered by sorted index; the split assumes these orderings
# refer to the same patients - confirm.
x_train = normal_scans[:260]
time_train = np.array(survivaltime[:260])
event_train = np.array(event[:260])
clinical_train = clinical.iloc[:260,:].to_numpy(dtype="float32")

# Despite the `_test` suffix, these arrays form the VALIDATION set.
x_test = normal_scans[260:]
time_test = np.array(survivaltime[260:])
event_test = np.array(event[260:])
clinical_test = clinical.iloc[260:,:].to_numpy(dtype="float32")

### **Data Augmentation**

Because the training set is quite small and the Clinical Convolutional Neural Network tends to overfit quickly, I tried several data augmentation techniques. They did not improve my results; however, I have left the code in place, since it could be useful in other cases.

In [None]:
pip install dltk
from dltk.io.augmentation import *
from dltk.io.preprocessing import *

## **Creation of Clinical Convolution Neural Network for Survival Analysis**

### **Train and Validation Data**

In [135]:
def _make_riskset(time: np.ndarray) -> np.ndarray:
    """Compute mask that represents each sample's risk set.

    Parameters
    ----------
    time : np.ndarray, shape=(n_samples,)
        Observed event time sorted in descending order.

    Returns
    -------
    risk_set : np.ndarray, shape=(n_samples, n_samples)
        Boolean matrix where the `i`-th row denotes the
        risk set of the `i`-th instance, i.e. the indices `j`
        for which the observer time `y_j >= y_i`.
    """
    assert time.ndim == 1, "expected 1D array"

    # sort in descending order
    o = np.argsort(-time, kind="mergesort")
    n_samples = len(time)
    risk_set = np.zeros((n_samples, n_samples), dtype=np.bool_)
    for i_org, i_sort in enumerate(o):
        ti = time[i_sort]
        k = i_org
        while k < n_samples and ti == time[o[k]]:
            k += 1
        risk_set[i_sort, o[:k]] = True
    return risk_set


def random_rotate3D(img_numpy, min_angle, max_angle):
  """
  3D Medical image rotation
  -----
  Rotates every volume of a batch by its own random angle, in its own
  randomly chosen plane, and returns the batch in the same shape.
  :param img_numpy: batch of 3D numpy arrays, shape (n_samples, h, w, d),
                    with intensities expected in [0, 1]
  :param min_angle: in degrees (inclusive), must be > -360
  :param max_angle: in degrees (inclusive), must be < 360
  :return: numpy array of the rotated volumes, clipped to [0, 1]
  """
  # Imported here because `import scipy` alone does not guarantee that the
  # `scipy.ndimage` submodule is loaded.
  from scipy import ndimage

  assert min_angle < max_angle, "min should be less than max val"
  # Both bounds have to be inside (-360, 360).
  assert min_angle > -360 and max_angle < 360
  all_axes = [(1, 0), (1, 2), (0, 2)]

  liste=[]
  for volume in img_numpy:
    assert volume.ndim == 3, "provide a 3d numpy array"
    angle = np.random.randint(low=min_angle, high=max_angle+1)
    axes_random_id = np.random.randint(low=0, high=len(all_axes))
    axes = all_axes[axes_random_id]
    # Pass the sampled plane on to `rotate`; without it every volume is
    # rotated in the default plane regardless of the random choice above.
    rotated = ndimage.rotate(volume, angle, axes=axes, reshape=False)
    # Spline interpolation can overshoot the input range: clip back to [0, 1].
    rotated[rotated < 0] = 0
    rotated[rotated > 1] = 1
    liste.append(rotated)

  if not liste:
    # Empty batch: nothing to rotate, return a copy with the same shape.
    return np.array(img_numpy)
  # Build the output array once, after all volumes have been processed.
  return np.array(liste)

class _SurvivalInputFunction:
    """Shared implementation of the batch generators used for training and evaluation.

    Calling an instance returns a `tf.data.Dataset` that yields
    ``(images, clinical, labels)`` batches, where ``labels`` is a dict with
    the event indicator, the observed time and the per-batch risk set.

    Parameters
    ----------
    images : np.ndarray, shape=(n_samples, height, width, depth)
        Image data. A trailing channel axis is appended when the array is
        4-dimensional; 5-dimensional input is used as is.
    clinical : np.ndarray, shape=(n_samples, n_variables)
        Clinical data.
    time : np.ndarray, shape=(n_samples,)
        Observed time.
    event : np.ndarray, shape=(n_samples,)
        Event indicator.
    batch_size : int, optional, default=20
        Number of samples per batch.
    drop_last : bool, optional, default=False
        Whether to drop the last incomplete batch.
    shuffle : bool, optional, default=False
        Whether to shuffle data.
    seed : int, optional, default=89
        Random number seed.
    """

    def __init__(self,
                 images: np.ndarray,
                 clinical: np.ndarray,
                 time: np.ndarray,
                 event: np.ndarray,
                 batch_size: int = 20,
                 drop_last: bool = False,
                 shuffle: bool = False,
                 seed: int = 89) -> None:
        # Add the channel axis expected by the network input.
        if images.ndim == 4:
            images = images[..., np.newaxis]
        self.images = images
        self.clinical = clinical
        self.time = time
        self.event = event
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.shuffle = shuffle
        self.seed = seed

    def size(self) -> int:
        """Total number of samples."""
        return self.images.shape[0]

    def steps_per_epoch(self) -> int:
        """Number of complete batches in one epoch."""
        return int(np.floor(self.size() / self.batch_size))

    def _get_data_batch(self, index: np.ndarray) -> Tuple[np.ndarray, np.ndarray, Dict[str, np.ndarray]]:
        """Gather the samples selected by `index` and compute their risk set."""
        time = self.time[index]
        event = self.event[index]
        images = self.images[index]
        clinical = self.clinical[index]

        labels = {
            "label_event": event.astype(np.int32),
            "label_time": time.astype(np.float32),
            # The risk set is computed within the batch only.
            "label_riskset": _make_riskset(time)
        }
        return images, clinical, labels

    def _iter_data(self) -> Iterable[Tuple[np.ndarray, np.ndarray, Dict[str, np.ndarray]]]:
        """Generator that yields one batch at a time."""
        index = np.arange(self.size())
        # A fresh generator is seeded on every call, so with `shuffle=True`
        # each epoch sees the same permutation.
        rnd = np.random.RandomState(self.seed)

        if self.shuffle:
            rnd.shuffle(index)

        n_full_batches = self.steps_per_epoch()
        for b in range(n_full_batches):
            start = b * self.batch_size
            idx = index[start:(start + self.batch_size)]
            yield self._get_data_batch(idx)

        if not self.drop_last:
            idx = index[n_full_batches * self.batch_size:]
            # Only emit the remainder when there is one: if the number of
            # samples is a multiple of `batch_size` the slice is empty, and an
            # empty batch must not reach the model or the metrics.
            if len(idx) > 0:
                yield self._get_data_batch(idx)

    def _get_shapes(self) -> "Tuple[tf.TensorShape, tf.TensorShape, Dict[str, tf.TensorShape]]":
        """Return shapes of data returned by `self._iter_data`."""
        # The batch dimension is only static when incomplete batches are dropped.
        batch_size = self.batch_size if self.drop_last else None
        h, w, d, c = self.images.shape[1:]
        images = tf.TensorShape([batch_size, h, w, d, c])
        clinical = tf.TensorShape([batch_size,self.clinical.shape[1]])

        labels = {k: tf.TensorShape((batch_size,))
                  for k in ("label_event", "label_time")}
        labels["label_riskset"] = tf.TensorShape((batch_size, batch_size))
        return images, clinical, labels

    def _get_dtypes(self) -> "Tuple[tf.DType, tf.DType, Dict[str, tf.DType]]":
        """Return dtypes of data returned by `self._iter_data`."""
        labels = {"label_event": tf.int32,
                  "label_time": tf.float32,
                  "label_riskset": tf.bool}
        return tf.float32, tf.float32, labels

    def _make_dataset(self) -> "tf.data.Dataset":
        """Create dataset from generator."""
        ds = tf.data.Dataset.from_generator(
            self._iter_data,
            self._get_dtypes(),
            self._get_shapes()
        )
        return ds

    def __call__(self) -> "tf.data.Dataset":
        return self._make_dataset()


class InputFunction_train(_SurvivalInputFunction):
    """Callable input function that computes the risk set for each training batch.

    Same parameters and behaviour as `_SurvivalInputFunction`.
    """

    # Optional data augmentation (not enabled; it did not improve results).
    #
    # To apply a random rotation to the images, `__init__` can be overridden
    # to run, before the channel axis is added:
    #
    #     rotated_images = random_rotate3D(images, -20, 20)
    #     if rotated_images.ndim == 4:
    #         images = rotated_images[..., np.newaxis]
    #
    # Data augmentation techniques from the DLTK library can be applied to
    # `images` in `_get_data_batch`:
    #
    #     # Randomly flip the image along axis 1
    #     images = flip(images.copy(), axis=1)
    #     # Add a Gaussian offset
    #     images = add_gaussian_offset(images.copy(), sigma=0.5)
    #     # Add Gaussian noise
    #     images = add_gaussian_noise(images.copy(), sigma=0.15)


class InputFunction_test(_SurvivalInputFunction):
    """Callable input function that computes the risk set for each evaluation batch.

    Same parameters and behaviour as `_SurvivalInputFunction`.
    """

def safe_normalize(x: tf.Tensor) -> tf.Tensor:
    """Shift risk scores so that no minimum along axis 0 is negative.

    Only the differences between risk scores matter, so adding a constant
    is harmless. Wherever the minimum along axis 0 is negative, the scores
    are shifted up by its absolute value, which moves that minimum to zero
    and guards against `exp` underflow; otherwise they are left unchanged.
    """
    lowest = tf.reduce_min(x, axis=0)
    shift = tf.where(lowest < 0, -lowest, tf.zeros_like(lowest))
    return x + shift


def logsumexp_masked(risk_scores: tf.Tensor,
                     mask: tf.Tensor,
                     axis: int = 0,
                     keepdims: Optional[bool] = None) -> tf.Tensor:
    """Log of the sum of exponentials along `axis`, restricted to entries where `mask` is true.

    `risk_scores` and `mask` must have the same rank. The reduced axis is
    kept only when `keepdims` is truthy.
    """
    risk_scores.shape.assert_same_rank(mask.shape)

    with tf.name_scope("logsumexp_masked"):
        weights = tf.cast(mask, risk_scores.dtype)
        # Entries outside the mask are set to zero before the maximum is taken.
        masked_scores = tf.math.multiply(risk_scores, weights)
        # Subtracting the maximum before exponentiating keeps exp() in a
        # numerically safe range; it is added back after the log.
        peak = tf.reduce_max(masked_scores, axis=axis, keepdims=True)
        shifted_exp = tf.exp(masked_scores - peak)

        # Multiply by the mask again so that masked-out entries add nothing.
        total = tf.reduce_sum(
            tf.math.multiply(shifted_exp, weights), axis=axis, keepdims=True)
        result = peak + tf.math.log(total)
        if not keepdims:
            result = tf.squeeze(result, axis=axis)
    return result

### **Computation of Cox PH loss function**

In [136]:
class CoxPHLoss(tf.keras.losses.Loss):
    """Negative partial log-likelihood of Cox's proportional hazards model."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self,
             y_true: Sequence[tf.Tensor],
             y_pred: tf.Tensor) -> tf.Tensor:
        """Compute the per-instance loss.

        Parameters
        ----------
        y_true : list|tuple of tf.Tensor
            Pair ``(event, riskset)``. ``event`` is a binary vector where 1
            marks an event and 0 censoring. ``riskset`` is a boolean matrix
            whose `i`-th row is the risk set of the `i`-th instance, i.e.
            the indices `j` for which the observed time `y_j >= y_i`.
            Both must be rank 2 tensors.
        y_pred : tf.Tensor
            The predicted outputs. Must be a rank 2 tensor whose last
            dimension is 1.

        Returns
        -------
        loss : tf.Tensor
            Loss for each instance in the batch.

        Raises
        ------
        ValueError
            If the rank or last dimension of the inputs is not as described.
        """
        event, riskset = y_true
        predictions = y_pred

        # --- static shape validation -------------------------------------
        pred_shape = predictions.shape
        pred_rank = pred_shape.ndims
        if pred_rank != 2:
            raise ValueError("Rank mismatch: Rank of predictions (received %s) should "
                             "be 2." % pred_rank)

        last_dim = pred_shape[1]
        if last_dim is None:
            raise ValueError("Last dimension of predictions must be known.")
        if last_dim != 1:
            raise ValueError("Dimension mismatch: Last dimension of predictions "
                             "(received %s) must be 1." % last_dim)

        event_rank = event.shape.ndims
        if event_rank != pred_rank:
            raise ValueError("Rank mismatch: Rank of predictions (received %s) should "
                             "equal rank of event (received %s)" % (
                pred_rank, event_rank))

        riskset_rank = riskset.shape.ndims
        if riskset_rank != 2:
            raise ValueError("Rank mismatch: Rank of riskset (received %s) should "
                             "be 2." % riskset_rank)

        event = tf.cast(event, predictions.dtype)
        predictions = safe_normalize(predictions)

        # --- runtime value / dtype checks --------------------------------
        with tf.name_scope("assertions"):
            tf.debugging.assert_less_equal(event, 1.)
            tf.debugging.assert_greater_equal(event, 0.)
            tf.debugging.assert_type(riskset, tf.bool)

        # Transposing puts the batch dimension last, so the predictions are
        # broadcast row-wise against the risk set.
        transposed = tf.transpose(predictions)
        # Log of the sum of exp(prediction) over each instance's risk set.
        log_risk = logsumexp_masked(transposed, riskset, axis=1, keepdims=True)
        assert log_risk.shape.as_list() == predictions.shape.as_list()

        # Censored instances (event == 0) contribute nothing.
        return tf.math.multiply(event, log_risk - predictions)

### **Computation of the Concordance Index** (on the validation data at each epoch)

In [137]:
class CindexMetric:
    """Computes concordance index across one epoch."""

    def __init__(self) -> None:
        # Start with empty buffers so that `update_state` can be called
        # without an explicit `reset_states` first.
        self.reset_states()

    def reset_states(self) -> None:
        """Clear the buffer of collected values."""
        self._data = {
            "label_time": [],
            "label_event": [],
            "prediction": []
        }

    def update_state(self, y_true: "Dict[str, tf.Tensor]", y_pred: "tf.Tensor") -> None:
        """Collect observed time, event indicator and predictions for a batch.

        Parameters
        ----------
        y_true : dict
            Must have two items:
            `label_time`, a tensor containing observed time for one batch,
            and `label_event`, a tensor containing event indicator for one batch.
        y_pred : tf.Tensor
            Tensor containing predicted risk score for one batch.
        """
        self._data["label_time"].append(y_true["label_time"].numpy())
        self._data["label_event"].append(y_true["label_event"].numpy())
        # Flatten to 1-D explicitly. Squeezing every size-1 axis would turn a
        # single-sample batch into a 0-d array, which `np.concatenate` in
        # `result` rejects.
        self._data["prediction"].append(np.reshape(y_pred.numpy(), -1))

    def result(self) -> Dict[str, float]:
        """Computes the concordance index across collected values.

        Returns
        ----------
        metrics : dict
            Computed metrics, stored under the key ``"cindex"``.
        """
        data = {k: np.concatenate(v) for k, v in self._data.items()}

        results = concordance_index(
            data["label_time"],
            data["prediction"],
            data["label_event"] == 1,)

        # The value from `concordance_index` is reported as its complement.
        # NOTE(review): presumably because lifelines treats a higher score as
        # longer survival while the network predicts a risk - confirm.
        return {"cindex": 1-results}

### **Training of Survival Convolutional Neural Network**

In [138]:
import tensorflow.compat.v2.summary as summary
from tensorflow.python.ops import summary_ops_v2

class TrainAndEvaluateModel:
    """Custom training loop: trains `model` with the Cox PH loss and evaluates
    it on the validation dataset after every epoch.

    Loss and concordance index are written as summaries under
    ``model_dir/train`` and ``model_dir/valid``; a checkpoint is saved in
    ``model_dir`` once all epochs have finished.
    """

    def __init__(self, model, model_dir, train_dataset, eval_dataset,
                 learning_rate, num_epochs):
        """Store the model and datasets and create the optimizer, loss and metrics.

        Parameters
        ----------
        model : model called as ``model([images, clinical], training=...)``
        model_dir : path-like supporting the ``/`` operator (e.g. `pathlib.Path`)
            Directory for checkpoints and summaries.
        train_dataset : iterable of ``(images, clinical, labels)`` batches
        eval_dataset : iterable of ``(images, clinical, labels)`` batches
        learning_rate : learning rate passed to the Adam optimizer
        num_epochs : number of passes over `train_dataset`

        Note: 'train_dataset' to be removed from __init__ if
        data augmentation is applied
        """

        self.num_epochs = num_epochs
        self.model_dir = model_dir

        self.model = model

        """ Note: the line below has to be removed if data augmentation is 
        applied """
        self.train_ds = train_dataset

        self.val_ds = eval_dataset

        self.optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
        self.loss_fn = CoxPHLoss()

        # Running means of the loss; reset after each logging / evaluation.
        self.train_loss_metric = tf.keras.metrics.Mean(name="train_loss")
        self.val_loss_metric = tf.keras.metrics.Mean(name="val_loss")
        self.val_cindex_metric = CindexMetric()

    @tf.function
    def train_one_step(self, x_image, x_clinical, y_event, y_riskset):
        """Run one optimisation step on a batch; return the loss and the model outputs."""
        # The loss expects the event indicator as a rank 2 tensor.
        y_event = tf.expand_dims(y_event, axis=1)

        with tf.GradientTape() as tape:
            logits = self.model([x_image, x_clinical], training=True)

            train_loss = self.loss_fn(y_true=[y_event, y_riskset], y_pred=logits)

        with tf.name_scope("gradients"):
            grads = tape.gradient(train_loss, self.model.trainable_weights)
            self.optimizer.apply_gradients(zip(grads, self.model.trainable_weights))
        return train_loss, logits

    def train_and_evaluate(self):
        """Train for `num_epochs` epochs, evaluating after each one, then save a checkpoint.

        If `model_dir` already contains a checkpoint, training resumes from it.
        """
        ckpt = tf.train.Checkpoint(
            step=tf.Variable(0, dtype=tf.int64),
            optimizer=self.optimizer,
            model=self.model)
        ckpt_manager = tf.train.CheckpointManager(
            ckpt, str(self.model_dir), max_to_keep=2)

        if ckpt_manager.latest_checkpoint:
            ckpt.restore(ckpt_manager.latest_checkpoint)
            print(f"Latest checkpoint restored from {ckpt_manager.latest_checkpoint}.")

        train_summary_writer = summary.create_file_writer(
            str(self.model_dir / "train"))
        val_summary_writer = summary.create_file_writer(
            str(self.model_dir / "valid"))

        for epoch in range(self.num_epochs):
            with train_summary_writer.as_default():
                self.train_one_epoch(ckpt.step)

            # Run a validation loop at the end of each epoch.
            with val_summary_writer.as_default():
                self.evaluate(ckpt.step)

        # The checkpoint is written once, after the last epoch.
        save_path = ckpt_manager.save()
        print(f"Saved checkpoint for step {ckpt.step.numpy()}: {save_path}")

    def train_one_epoch(self, step_counter):

          """Iterate once over the training dataset, updating the model on every batch.

              `step_counter` is the global step variable; it is incremented
              after every batch, and the mean training loss is printed and
              written as a summary every 5 steps.

              Note: the two lines below have to be used if data augmentation
              is applied:
          
              data = InputFunction_train(x_train, clinical_train, time_train, event_train, batch_size=10, shuffle=True, drop_last=True)
              for x_images, x_clinical, y in data():
                        
              Image preprocessing (can be used for data augmentation)
              -------
              #Adjust the brightness of images by a random factor.
              x_images = tf.image.random_brightness(x_images, 0.2)
              #Adjust the contrast of images by a random factor.
              x_images = tf.image.random_contrast(x_images, 0.2, 0.5)
         """

          for x_images, x_clinical, y in self.train_ds:
            train_loss, logits = self.train_one_step(
                x_images, x_clinical, y["label_event"], y["label_riskset"])

            step = int(step_counter)

            # Update training metric.
            self.train_loss_metric.update_state(train_loss)

            # Log every 5 steps.
            if step % 5 == 0:
                # Display metrics
                mean_loss = self.train_loss_metric.result()
                print(f"step {step}: mean loss = {mean_loss:.4f}")
                # save summaries
                summary.scalar("loss", mean_loss, step=step_counter)
                # Reset training metrics
                self.train_loss_metric.reset_states()

            step_counter.assign_add(1)

    @tf.function
    def evaluate_one_step(self, x, x_clinical, y_event, y_riskset):
        """Compute the loss and model outputs for one validation batch (no weight update)."""
        # The loss expects the event indicator as a rank 2 tensor.
        y_event = tf.expand_dims(y_event, axis=1)
        val_logits = self.model([x, x_clinical], training=False)
        val_loss = self.loss_fn(y_true=[y_event, y_riskset], y_pred=val_logits)
        return val_loss, val_logits

    def evaluate(self, step_counter):
        """Evaluate on the whole validation dataset; log and print mean loss and c-index."""
        self.val_cindex_metric.reset_states()

        for val_images, val_clinical, y_val in self.val_ds:
            val_loss, val_logits = self.evaluate_one_step(
                val_images, val_clinical, y_val["label_event"], y_val["label_riskset"])

            # Update val metrics
            self.val_loss_metric.update_state(val_loss)
            self.val_cindex_metric.update_state(y_val, val_logits)

        # Mean validation loss over all batches of this evaluation pass.
        val_loss = self.val_loss_metric.result()
        summary.scalar("loss",
                       val_loss,
                       step=step_counter)
        self.val_loss_metric.reset_states()

        # Concordance index computed over all collected validation predictions.
        val_cindex = self.val_cindex_metric.result()
        for key, value in val_cindex.items():
          summary.scalar(key, value, step=step_counter)

        print(f"Validation: loss = {val_loss:.4f}, cindex = {val_cindex['cindex']:.4f}")

### **Architecture of the Clinical Survival Convolutional Neural Network**

This neural network handles two inputs: **3D CT scans** and **Clinical data**.

- The 3D CT scans (grayscale image, shape is 92x92x92) are normalized by batch before being fed to the network.
- Two convolutional layers, each using a 3x3x3 kernel, stride 1, padding='valid', and a ReLU activation function, and each followed by a batch normalization and a max pooling layer with a pool size of 2, which halves each spatial dimension.
- The output of the second max pooling layer is flattened and concatenated with the clinical data.
- A fully connected network, composed of two hidden dense layers and a dense output layer. Batch normalization is applied to the output of each of the two hidden dense layers. A dropout layer with a dropout rate of 10% is added before the output layer to reduce overfitting. The final dense layer outputs the predictions (risk score).

In [139]:
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers import Input, Dense, Flatten, Conv3D, BatchNormalization, MaxPooling3D, Concatenate, Dropout

# --- Image branch: batch-normalised input, then two Conv3D -> BatchNorm -> MaxPool stages ---
image_input = keras.Input(shape=(92,92,92,1), name="image_input")
x = layers.BatchNormalization()(image_input)
x = layers.Conv3D(filters=16, kernel_size=(3,3,3), activation="relu", use_bias=False)(x)
x = layers.BatchNormalization()(x)
x = layers.MaxPooling3D(pool_size=(2,2,2))(x)
x = layers.Conv3D(filters=32, kernel_size=(3,3,3), activation="relu", use_bias=False)(x)
x = layers.BatchNormalization()(x)
x = layers.MaxPooling3D(pool_size=(2,2,2))(x)
block_1_output = layers.Flatten()(x)

# --- Clinical branch ---
# `shape` has to be a tuple: `(11)` is just the integer 11, `(11,)` is the
# one-element tuple Keras expects. 11 must equal the number of clinical
# feature columns fed to the model.
numeric_input = keras.Input(shape=(11,), name="numeric_input")

# --- Fully connected head on the concatenated clinical + image features ---
x = layers.Concatenate()([numeric_input, block_1_output])
x = layers.Dense(120, activation='relu', name='dense_1', use_bias=False)(x)
x = layers.BatchNormalization()(x)
x = layers.Dense(84, activation='relu', name='dense_2', use_bias=False)(x)
x = layers.BatchNormalization()(x)
x = layers.Dropout(rate=0.10)(x)
# Single linear output: the predicted risk score.
output = layers.Dense(1, activation='linear', name='dense_3')(x)

model = keras.Model(inputs=[image_input, numeric_input], outputs=output)

In [140]:
""" Note: The line below has to be removed if data augmentation is applied. """
# Training batches: 20 samples each, shuffled, incomplete last batch dropped.
train_fn = InputFunction_train(x_train, clinical_train, time_train, event_train, batch_size=20, shuffle=True, drop_last=True)

# Validation batches: default settings (batch size 20, no shuffling, last batch kept).
eval_fn = InputFunction_test(x_test, clinical_test, time_test, event_test)

# NOTE(review): the checkpoint directory name mentions "mnist" although the
# model is trained on CT scans - presumably left over from a tutorial; an
# existing checkpoint in this directory is restored when training starts.
trainer = TrainAndEvaluateModel(
    model=model,
    model_dir=Path("ckpts-mnist-cnn"),
    # Note: The line below has to be removed if data augmentation is applied. 
    train_dataset=train_fn(),
    eval_dataset=eval_fn(),
    learning_rate=0.0001,
    num_epochs=5,
)

In [None]:
# Run the training loop (with a validation pass after every epoch) and save a
# checkpoint once all epochs are done.
trainer.train_and_evaluate()

## **Results**

### **Predicted Risk Score**

In [142]:
from lifelines import CoxPHFitter
# Cox proportional hazards fitter from lifelines; it is fitted further below on
# the risk scores predicted by the network.
# NOTE(review): alpha=0.05 presumably sets the level of the reported confidence
# intervals (95%) - confirm against the lifelines documentation.
cph = CoxPHFitter(alpha=0.05)

class Predictor:
    """Computes risk scores with a trained model, restoring the latest checkpoint first."""

    def __init__(self, model, model_dir):
        """Keep the model and the directory its checkpoints are stored in."""
        self.model = model
        self.model_dir = model_dir

    def predict(self, dataset):
        """Predict the risk score of every sample in `dataset`.

        Parameters
        ----------
        dataset : iterable of ``(batch_image, batch_num)`` pairs
            Batches of images and clinical features, in the order expected
            by the model inputs.

        Returns
        -------
        np.ndarray
            Predictions of all batches stacked along the first axis, in
            dataset order.
        """
        ckpt = tf.train.Checkpoint(
            step=tf.Variable(0, dtype=tf.int64),
            optimizer=tf.keras.optimizers.Adam(),
            model=self.model)
        ckpt_manager = tf.train.CheckpointManager(
            ckpt, str(self.model_dir), max_to_keep=2)

        # When no checkpoint exists, the model's current in-memory weights
        # are used as they are.
        if ckpt_manager.latest_checkpoint:
            # expect_partial(): only the model weights are needed here, so
            # unmatched checkpoint values (e.g. optimizer state) are not an error.
            ckpt.restore(ckpt_manager.latest_checkpoint).expect_partial()
            print(f"Latest checkpoint restored from {ckpt_manager.latest_checkpoint}.")

        risk_scores = []
        for batch_image, batch_num in dataset:
            pred = self.model([batch_image, batch_num], training=False)
            risk_scores.append(pred.numpy())

        # `np.row_stack` is a deprecated alias of `np.vstack`; call the
        # canonical function directly.
        return np.vstack(risk_scores)

**On Train Data**

In [143]:
# Prediction pipeline for the training data: images (with a channel axis added)
# and clinical features are batched separately in their original order and
# then zipped, so each element is an (image batch, clinical batch) pair.
train_pred_fn_1 = tf.data.Dataset.from_tensor_slices(x_train[..., np.newaxis]).batch(20)
train_pred_fn_2 = tf.data.Dataset.from_tensor_slices(clinical_train).batch(20)
train_pred_fn = tf.data.Dataset.zip((train_pred_fn_1,train_pred_fn_2))

# The predictor reads its checkpoint from the trainer's model directory.
predictor = Predictor(model, trainer.model_dir)

In [None]:
# Predicted risk score of the training data
train_predictions = predictor.predict(train_pred_fn)

# One row per training sample: predicted risk score, observed time, event flag.
risk_score_train = pd.DataFrame(train_predictions).rename(columns={0:'risk_score'})
risk_score_train['time_train'] = time_train
risk_score_train['event_train'] = event_train
risk_score_train.head()

**On Validation Data**

In [None]:
# Predicted risk score of the validation data
sample_pred_ds_1 = tf.data.Dataset.from_tensor_slices(x_test[..., np.newaxis]).batch(20)
sample_pred_ds_2 = tf.data.Dataset.from_tensor_slices(clinical_test).batch(20)
# Pair image batches with the matching clinical batches.
sample_pred_ds = tf.data.Dataset.zip((sample_pred_ds_1, sample_pred_ds_2))

sample_predictions = predictor.predict(sample_pred_ds)

# One row per validation sample: predicted risk score, observed time, event flag.
risk_score_val = pd.DataFrame(sample_predictions).rename(columns={0:'risk_score'})
risk_score_val['time_test'] = time_test
risk_score_val['event_test'] = event_test

### **Cox Model**

**On Train Data**

This is a univariate Cox model whose single explanatory variable is the risk score predicted from the 3D CT scans and the clinical data.

In [None]:
# Fit the Cox model on the training risk scores. Every column of
# `risk_score_train` other than the duration and event columns is a covariate;
# here that is only `risk_score`.
breslow = cph.fit(risk_score_train, duration_col="time_train", event_col="event_train")
breslow.print_summary()

**On Validation Data**

In [None]:
# Expected lifetime predicted by the fitted Cox model for each validation sample
risk_score_val['predictions']=breslow.predict_expectation(risk_score_val)

# Concordance index between observed times and predicted expected lifetimes.
# The result is used directly here (no `1 - ...`), unlike in `CindexMetric`,
# which scores raw risk scores.
print(f'Concordance index (lifelines): {concordance_index(risk_score_val.time_test, risk_score_val.predictions, risk_score_val.event_test)}')

### **Predictions of Expected Lifetime on Test Data**

In [None]:
# Loading Data
# Test scans are listed, sorted and rescaled exactly like the training scans.
test_folder = pathlib.Path("/content/drive/My Drive/x_test/images")
all_image_paths_test = [str(img_path) for img_path in list(test_folder.glob("*"))]
all_image_paths_test = sorted(all_image_paths_test)
normal_scans_test = np.array([process_scan(path) for path in all_image_paths_test])

# NOTE(review): from here on `clinical_test` no longer refers to the validation
# features defined in the train/validation split - re-running the validation
# cells after this cell will use the wrong data.
clinical_test=pd.read_csv("/content/drive/My Drive/x_test/features/clinical_data_test.csv", index_col='PatientID')

#Preprocessing
# NOTE(review): N stage 4 is merged into 3 here, but the training preprocessing
# has no such step - confirm that the resulting dummy columns match the
# training columns in number and order.
clinical_test['Nstage']=clinical_test.Nstage.apply(lambda x: 3 if x==4 else x)
# Merge T stage 5 into T stage 4.
clinical_test['Tstage']=clinical_test.Tstage.apply(lambda x: 4 if x==5 else x)
# Same histology coding as for the training data: 0 = adenocarcinoma,
# 1 = large cell, 2 = squamous cell carcinoma, 3 = anything else.
clinical_test['Histology_cat']=clinical_test.Histology.apply(lambda x: 0 if x in ('Adenocarcinoma','adenocarcinoma') 
                                                               else 1 if x=='large cell' 
                                                               else 2 if x in('squamous cell carcinoma', 'Squamous cell carcinoma') 
                                                               else 3)
categories = ['Nstage','Tstage','Histology_cat']
# NOTE(review): get_dummies only creates columns for the levels present in the
# test data, so a level missing here would silently change the feature layout.
clinical_test = pd.get_dummies(clinical_test, columns=categories, drop_first=True)
clinical_test=clinical_test.drop(columns=['Histology','Mstage'])
clinical_test=clinical_test.sort_index()
# Binary source indicator: 'l1' -> 0, every other value -> 1.
clinical_test.SourceDataset=clinical_test.SourceDataset.apply(lambda x: 0 if x=='l1' else 1)

# Impute with the medians learned on the training data (transform only, no re-fit).
clinical_test_= imp_mean.transform(clinical_test)
clinical_test=pd.DataFrame(clinical_test_, index=clinical_test.index, columns=clinical_test.columns)
# NOTE(review): this rebinds `clinical`, which held the training DataFrame, to
# the test feature array.
clinical = clinical_test.to_numpy(dtype="float32")

#Predicted risk score of test data
# Same pipeline as for the train/validation predictions: batch images and
# clinical features separately, then zip them into (image, clinical) pairs.
sample_pred_ds_1 = tf.data.Dataset.from_tensor_slices(normal_scans_test[..., np.newaxis]).batch(20)
sample_pred_ds_2 = tf.data.Dataset.from_tensor_slices(clinical).batch(20)
sample_pred_ds = tf.data.Dataset.zip((sample_pred_ds_1,sample_pred_ds_2))

sample_predictions = predictor.predict(sample_pred_ds)

In [None]:
# Expected lifetime of the test samples, predicted by the fitted Cox model
# from the network's risk scores.
risk_score_test = pd.DataFrame(data=sample_predictions).rename(columns={0:'risk_score'})

pred = breslow.predict_expectation(risk_score_test)

In [None]:
# Final table: predicted survival time plus an `Event` column filled with the
# literal string 'nan'.
output = pd.DataFrame(data=pred).rename(columns={0:'SurvivalTime'})
output['Event'] = 'nan'
output.head(15)