# Deep Learning reproducibility project

## Deep learning with multimodal representation for pancancer prognosis prediction

## Authors: Anika Cheerla, Olivier Gevaert

## Authors of the reproduction: Luke Prananta, Joris Feijen, Favian Stelmach, Zsombor Csuvar (Group 37)

### Short description of the article 
The goal of the article was to create a deep learning algorithm that gives accurate predictions of the future course of the disease for patients with cancer. It does this by combining large amounts of multimodal data, something physicians previously could not do. Of all the available data types (clinical data, genomic profiles, histology slide images and radiographic images), the article uses gene expression data, miRNA data, clinical data and whole slide images.
Preceding articles showed that each of these data types can individually be used to predict prognosis with high accuracy. Combining them was made possible by databases such as "The Cancer Genome Atlas (TCGA)", which was the main source of data for this article.
Using this data, the authors created a deep learning model that aggregates the data described above into a single feature vector, learned partly in an unsupervised way through a similarity loss, and uses it to make an accurate prognosis (overall C-index of 0.78).

### Composition of the original article

The first problem the article had to solve was how to use data that is both heterogeneous and high-dimensional. The architecture therefore had to be able to process inputs of very different types and sizes. Additionally, the model had to cope with missing data, as the information available varied greatly per patient and per cancer type.

To solve these problems, the authors created four separate models, one for each data type. Each of these deep neural networks outputs a feature vector; a similarity loss encourages the feature vectors of the same patient to agree, and the vectors are combined into a single feature vector. The survival prediction is trained with the Cox loss function, which serves as a surrogate for maximizing the concordance (C-index) of the predictions.

### Description of the model architectures

As mentioned earlier, the deep learning model of the article consists of four sub-networks whose feature vectors are tied together by the similarity loss; the prognosis prediction is then trained with the Cox loss function.
Only the sub-network for the whole slide images is a convolutional neural network (CNN); the other three are fully connected. The clinical data model uses fully connected layers with sigmoid activation. The gene expression and miRNA models keep the basic structure of the clinical data model, with highway gates and dropout added.
Because of the size of the whole slide images (WSI), stochastic sampling was used to reduce the amount of data: random patches of each image were chosen and then fed to a SqueezeNet model. As with the other models, the output is a feature vector of 512 elements.
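The stochastic sampling step can be sketched as follows. This is an illustration under assumptions, not the article's code: `sample_patches` is a hypothetical helper, the slide is assumed to be already loaded as an (H, W, C) NumPy array, and the patch count (40) and patch size (224 x 224) are taken from the tensor shapes used later in this notebook. Real slides are far too large to load whole, so in practice the patches would be read region by region from the slide file.

```python
import numpy as np


def sample_patches(slide, n_patches=40, patch_size=224, rng=None):
    """Draw n_patches random square patches from an (H, W, C) image array.

    Hypothetical helper for illustration. Returns an array of shape
    (n_patches, C, patch_size, patch_size), i.e. channels-first as nn.Conv2d expects.
    """
    rng = np.random.default_rng() if rng is None else rng
    h, w, _ = slide.shape
    if h < patch_size or w < patch_size:
        raise ValueError("slide is smaller than one patch")
    # Random top-left corners; the upper bound is exclusive, so every patch fits inside the slide
    ys = rng.integers(0, h - patch_size + 1, size=n_patches)
    xs = rng.integers(0, w - patch_size + 1, size=n_patches)
    patches = np.stack([slide[y:y + patch_size, x:x + patch_size] for y, x in zip(ys, xs)])
    return patches.transpose(0, 3, 1, 2)
```

Each call yields a different random subset of the slide, which is what makes the sampling stochastic.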
 
### Reproduction of the article

The reproduction of the article presented several difficulties. Although the authors provided the code for the experiments and the data collection, we could not use it without modification; in practice it had to be rewritten completely. The code also lacked comments, which made it hard to understand and reuse.

#### Steps of reproduction

The reproduction consisted of the following steps:
 1. Download the necessary data (all four data types)
 2. Replicate the individual models for the gene expression data, miRNA data, clinical data and whole slide images
 3. Reproduce the results of the paper

#### Download the necessary data

Of the different types of medical data, the whole slide images were the most difficult to handle because of their size. A single slide can be several gigabytes, and the [GDC database](https://portal.gdc.cancer.gov/repository?filters=%7B%22op%22%3A%22and%22%2C%22content%22%3A%5B%7B%22op%22%3A%22in%22%2C%22content%22%3A%7B%22field%22%3A%22files.data_type%22%2C%22value%22%3A%5B%22Slide%20Image%22%5D%7D%7D%5D%7D) currently holds more than 30 000 whole slide image files, 16.98 TB in total. This amount of data was impossible to download or store on our personal computers. Therefore, following the guidance of our supervisor (Soufiane Mouragui), we only used the slides whose primary cancer site was the skin. This meant that only 950 files (707 GB) had to be downloaded. To this end, the original code had to be modified (the modified code is in slide_download_convert.py). For it to run correctly, a new manifest file is needed that includes only the skin cancer files.
The other three data types could be downloaded using the code provided by the original authors.

### Graph reproduction 

As required by the assignment, Figure 1 of the article also had to be reproduced. Figure 1 shows the Kaplan-Meier survival curves for all cancer sites, illustrating that different cancer sites have different expected survival rates. The difficulty in reproducing this graph was that the clinical data downloaded from the GDC database does not contain the primary site of the cancer, only the specific cancer type. In addition, the data was incomplete for some patients, and the article gives no indication of how missing data was handled.
The resulting figure can be obtained by running fig_1_reproduce.py.
There are some differences between the reproduced figure and the original in the article, and several possible explanations. The article was written at the end of 2018 (based on the GitHub commits), but it is unknown when the data was downloaded from the GDC database. Because the database is continuously updated, we may be working with patient records that were not present when the article was written. Another possible explanation is the handling of missing information. In the reproduction, records with missing information were discarded, whereas the original work may have handled them differently, e.g. by substituting an average, maximum or minimum value.
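For reference, a minimal sketch of the Kaplan-Meier (product-limit) estimator behind such survival curves is given below. It is not the code in fig_1_reproduce.py: the function names and the (site, duration, event) record format are assumptions for illustration. Records with a missing field are skipped, mirroring how missing information was handled in the reproduction.

```python
import numpy as np


def kaplan_meier(durations, events):
    """Product-limit estimate of the survival function S(t).

    durations: days to death, or to last follow-up for censored patients
    events:    1 if the death was observed, 0 if the patient was censored
    Returns the distinct death times and the estimated S(t) just after each of them.
    """
    durations = np.asarray(durations, dtype=float)
    events = np.asarray(events, dtype=bool)
    times = np.unique(durations[events])
    survival = []
    s = 1.0
    for t in times:
        at_risk = np.sum(durations >= t)            # patients still under observation at t
        deaths = np.sum((durations == t) & events)  # deaths observed exactly at t
        s *= 1.0 - deaths / at_risk
        survival.append(s)
    return times, np.array(survival)


def km_per_site(records):
    """records: iterable of (site, duration, event) tuples (hypothetical format).
    Records with a missing (None) field are skipped."""
    grouped = {}
    for site, duration, event in records:
        if site is None or duration is None or event is None:
            continue
        durations, events = grouped.setdefault(site, ([], []))
        durations.append(duration)
        events.append(event)
    return {site: kaplan_meier(d, e) for site, (d, e) in grouped.items()}
```

Each curve can then be drawn as a step function, for example with `plt.step(times, survival, where="post")`.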

In [None]:
%matplotlib inline  
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using Device: {device}")

In [None]:
layer_size = 512
np_type = 'float32'

# Highway Network

Implementing the highway network required reading the referenced literature. The referenced article describes the structure of the highway network in detail, which made it possible to construct it, and it also recommends a negative initial value for the bias of the transform gate.
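For reference, the highway layer defined in that article computes the following, where $H$ is the non-linear transform, $T$ the transform gate, $\sigma$ the sigmoid function and $b_T$ the bias of the gate:

\begin{equation}
y = H(x, W_H) \cdot T(x, W_T) + x \cdot \left(1 - T(x, W_T)\right), \qquad T(x, W_T) = \sigma(W_T x + b_T)
\end{equation}

With a negative bias such as $b_T = -2$ and a small pre-activation $W_T x$, the gate starts close to $\sigma(-2) \approx 0.12$, so each layer initially passes most of its input $x$ through unchanged.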


In [None]:
# Highway layer as described in Srivastava et al., "Highway Networks" (https://arxiv.org/pdf/1505.00387.pdf)
class Highway(nn.Module):
    def __init__(self):
        super(Highway, self).__init__()
        
        # Transformation gate exactly as defined by the paper
        self.transform_gate = nn.Linear(layer_size, layer_size)
        # Authors of the paper recommend filling the bias term with negative values
        self.transform_gate.bias.data.fill_(-2)
        
        # The authors say that:
        # "H is usually an affine transform followed by a non-linear activation function, but in general it may take other forms"
        # https://arxiv.org/pdf/1505.00387.pdf
        self.affine = nn.Linear(layer_size, layer_size)

    def forward(self, x):
        # Sigmoid as recommended by authors
        T = torch.sigmoid(self.transform_gate(x))
        
        # Any non-linear activation can be used here
        H = F.relu(self.affine(x))
        
        return H * T + x * (1 - T)

# Gene Expression Sub-Network

When downloading the gene expression data from the GDC database with the provided code, we ran into a discrepancy: the article reports 10198 cases, whereas we could only download 5898. The same issue occurred with the miRNA data, where we obtained 5149 cases instead of 10125.
Apart from having access to only part of the data described in the article, and the additional reading needed on highway networks, implementing the gene expression and miRNA sub-networks was straightforward.

In [None]:
class GenExp(nn.Module):
    def __init__(self, n, dropout):
        super(GenExp, self).__init__()
        cycle_layers = []
        for i in range(n):
            l = 60483 if i == 0 else layer_size
            cycle_layers.append(nn.Linear(l, layer_size))
            cycle_layers.append(Highway())
        self.cycle_layers = nn.Sequential(*cycle_layers)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.cycle_layers(x)
        x = self.dropout(x)
        x = torch.sigmoid(x)
        return x


# miRNA Sub-Network

In [None]:
class miRNA(nn.Module):
    def __init__(self, n, dropout):
        super(miRNA, self).__init__()
        cycle_layers = []
        for i in range(n):
            l = 1881 if i == 0 else layer_size
            cycle_layers.append(nn.Linear(l, layer_size))
            cycle_layers.append(Highway())
        self.cycle_layers = nn.Sequential(*cycle_layers)
        
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.cycle_layers(x)
        x = self.dropout(x)
        x = torch.sigmoid(x)
        return x

# Clinical Data Sub-Network

The clinical data sub-network was the simplest of the four; its structure is clearly shown in Figure 2 of the article.

In [None]:
class Clinical(nn.Module):
    def __init__(self):
        super(Clinical, self).__init__()
        self.fc = nn.Linear(4, layer_size)

    def forward(self, x):
        x = self.fc(x)
        x = torch.sigmoid(x)
        return x

# Whole Slide Images Sub-Network
We successfully reproduced the algorithm and all the steps needed to implement this model. However, because of the memory requirements of the fire modules used here, our resources (a desktop computer with 32 GB of RAM) were not sufficient, so this sub-network was not used in the analysis.
Designing the network also raised several issues. To reproduce the network that encodes the WSI into feature vectors, we had to consult the paper that introduced SqueezeNet, as well as the code in the Git repository to find the exact parameters. There is also an unmentioned difference between the original SqueezeNet and the one used in the article: the last pooling layer of the original SqueezeNet uses average pooling, while the article uses max pooling. Finally, because the network processes a batch of 40 patches for each image, it outputs a feature vector for each patch. Since all patches originate from one image, there should be a single feature vector, but the paper does not mention how to achieve this.
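A quick calculation, based only on the layer sizes of `WholeSlide` as defined in this notebook, suggests that a large share of the memory may come from the final fully connected layer, which receives the flattened feature maps of all 40 patches (`nn.Linear(1000*6*6*40, layer_size)`), while all convolutional layers together have well under two million parameters. The 6 x 6 spatial size follows from 224 x 224 input patches.

```python
# Size of the final fully connected layer of WholeSlide as defined in this notebook
in_features = 1000 * 6 * 6 * 40    # conv10 channels x 6x6 feature map x 40 patches
out_features = 512                 # layer_size
fc_params = in_features * out_features + out_features  # weights + bias
fc_gib = fc_params * 4 / 2**30     # float32 uses 4 bytes per parameter
print(f"{fc_params:,} parameters, {fc_gib:.2f} GiB for the weights alone")
```

During training, the gradients and the optimizer's moment buffers add several times this amount on top of the weights.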

In [None]:
class fire(nn.Module):
    def __init__(self, n_channels, s1x1, e1x1, e3x3):
        super(fire,self).__init__()
        self.squeeze = nn.Conv2d(n_channels,s1x1,1)
        self.expand1 = nn.Conv2d(s1x1,e1x1,1)
        self.expand3 = nn.Conv2d(s1x1,e3x3,3,padding=1)
    
    def forward(self, x):
        out1 = F.relu(self.squeeze(x))
        out2 = F.relu(self.expand1(out1))
        out3 = F.relu(self.expand3(out1))
        return torch.cat([out2,out3],1)

In [None]:
class WholeSlide(nn.Module):
    def __init__(self,n_channels):
        super(WholeSlide,self).__init__()
        self.conv1 = nn.Conv2d(n_channels,96,3,stride=2)
        self.pool = nn.MaxPool2d(3,stride=2)
        self.fire2 = fire(96,16,64,64)
        self.fire3 = fire(128,16,64,64)
        self.fire4 = fire(128,32,128,128)
        self.fire5 = fire(256,32,128,128)
        self.fire6 = fire(256,48,192,192)
        self.fire7 = fire(384,48,192,192)
        self.fire8 = fire(384,64,256,256)
        self.fire9 = fire(512,64,256,256)
        self.dropout = nn.Dropout2d(0.5)
        self.conv10 = nn.Conv2d(512,1000,1)
        self.fc = nn.Linear(1000*6*6*40,layer_size)
    
    def forward(self,x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.fire2(x)
        x = self.fire3(x)
        x = self.fire4(x)
        x = self.pool(x)
        x = self.fire5(x)
        x = self.fire6(x)
        x = self.fire7(x)
        x = self.fire8(x)
        x = self.pool(x)
        x = self.fire9(x)
        x = self.dropout(x)
        x = self.pool(F.relu(self.conv10(x)))
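        # Flatten the feature maps of all 40 patches into one vector
        # (assumes exactly 40 patches of 224x224, which gives 6x6 maps after the last pooling)
        x = x.view(1000*6*6*40)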
        x = self.fc(x)
        return x

# Main Network

After excluding the WSI sub-network from the reproduction, the model consists of only three sub-networks.

In [None]:
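from torch.distributions.bernoulli import Bernoulli

class Network(nn.Module):
    def __init__(self, p=0.25):
        super(Network, self).__init__()
        self.gen_exp = GenExp(10, 0.3)
        self.mirna = miRNA(10, 0.3)
        self.clinical = Clinical()
        self.fc = nn.Linear(layer_size * 3, 1)
        # Probability of dropping the entire feature vector of a modality (multimodal dropout)
        self.p = p

    def modality_dropout(self, x):
        # x has shape [batch, layer_size]; the whole vector of a patient is dropped with
        # probability p and the kept vectors are scaled by 1 / (1 - p), as in inverted dropout.
        # Unlike Dropout2d on a 3D input, this does not depend on the PyTorch version
        # and keeps the batch dimension when the batch holds a single patient.
        if not self.training or self.p == 0:
            return x
        keep = Bernoulli(probs=torch.full((x.shape[0], 1), 1 - self.p, device=x.device)).sample()
        return x * keep / (1 - self.p)

    def forward(self, x):
        # Feature vectors of the sub-networks, with multimodal dropout (no-op if p is 0)
        x_gen_exp = self.modality_dropout(self.gen_exp(x['gen_exp']))
        x_mirna = self.modality_dropout(self.mirna(x['mirna']))
        x_clinical = self.modality_dropout(self.clinical(x['clinical']))

        # Here, the modalities are merged into a single tensor of shape [batch, 3 * layer_size]
        x_merged = torch.cat([x_gen_exp, x_mirna, x_clinical], -1)

        prognosis = self.fc(x_merged)

        # Output is prognosis (COX LOSS)
        # As well as the outputs of the sub-networks (SIMILARITY LOSS)
        return {"prognosis": prognosis, "gen_exp": x_gen_exp, "mirna": x_mirna, "clinical": x_clinical}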


# Data Manipulation Helpers
Helper functions for data manipulation

In [None]:
# Given a list of datapoints, create a batch
def datapoints_to_batch(xs):
    sample = xs[0]
    out = {}
    for key in sample.keys():
        if key == "type":
            continue
        lst = []
        for x in xs:
            lst.append(x[key])
        out[key] = torch.stack(lst)
    return out

# Send datapoint(s) to cuda, if available
def datapoints_to_device(x):
    y = {}
    for key in x.keys():
        y[key] = x[key].to(device)
    return y

# Load Data

In [None]:
import os
import pandas as pd
from sklearn.model_selection import train_test_split
from numpy import errstate, isneginf

mirna_path = "./data/miRNA/"
rnaseq_path = "./data/rnaseq/"
clinical_data_path = "./data/original_clinical_dataset.json"
# If you want to load a specific dataset only, specify below
# For example: 'READ' or 'LUSC'
# Empty string: '' for the whole dataset
dataset_to_load = ''

# Use for miRNA and gene expression only
def load_data(path, data_type, data):
    for x in os.listdir(path):
        if "data" not in x:
            continue
        if dataset_to_load not in x:
            continue
        data_path = os.path.join(path, x)
        cancer_type = x.split("_")[1]
        df = pd.read_pickle(data_path, compression="gzip")
        # Normalization (log(x + 1) and mean centering) is applied later, after the
        # train/test split, so that the test data does not influence the training statistics
        for _, row in df.iterrows():
            
            splitname = row.name.split("-")
            patient = "-".join([splitname[0], splitname[1], splitname[2]])
            if patient not in data:
                data[patient] = {}
                data[patient]['type'] = None
            # If a patient has more than one sample of this data type, only the first one is kept
            # TODO: decide how multiple samples of the same patient should be combined
            if data_type in data[patient]:
                continue

            data[patient][data_type] = torch.tensor(row.values.astype(np_type))
            data[patient]['type'] = cancer_type
            if data[patient]['type'] == 'and':
                data[patient]['type'] = 'HNSC'

def load_clinical_data(path, data_type, data):
    import json
    with open(path, "r") as file:
        clinical_data = json.load(file)
    
    races = []
    genders = []
    for case in clinical_data:
        patient = case['demographic']['submitter_id'].split("_")[0]
        
        # Assumption: clinical data is loaded LAST
        # If we only have clinical data, skip the patient
        if patient not in data:
            continue
        
        race = case['demographic']['race']
        if race not in races:
            races.append(race)
        race = races.index(race)
        
        if 'days_to_birth' in case['demographic']:
            age = case['demographic']['days_to_birth']
        else:
            age = -1
        
        gender = case['demographic']['gender']
        if gender not in genders:
            genders.append(gender)
        gender = genders.index(gender)
        
        # TODO: HISTOLOGICAL GRADE
        
        vital_status = case['demographic']['vital_status']
        # Patients who are still alive get the sentinel value 1e5, which marks them as censored.
        # (In the figure reproduction, days_to_last_follow_up was used for them instead:
        #  time_to_death = case['diagnoses']['days_to_last_follow_up'])
        # They are excluded from the Cox loss, following the article: "neural network model trained
        # to predict survival times. The loss is computed over all patients whose lack of survival was observed."
        if vital_status.lower() == "alive":
            time_to_death = 1e5
        else:
            if 'days_to_death' in case['demographic']:
                time_to_death = case['demographic']['days_to_death']
            else:
                # TODO: days_to_death is not reported; the patient is treated as censored for now
                time_to_death = 1e5
        
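        # Features: race, age (days_to_birth), gender; the fourth value is a placeholder (0)
        # for the histological grade, which is not extracted yet (see the TODO above)
        datapoint = np.array([race, age, gender, 0]).astype(np_type)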
        data[patient]['clinical'] = torch.Tensor(datapoint)
        
        time_to_death = np.array([time_to_death]).astype(np_type)
        data[patient]['prognosis'] = torch.Tensor(time_to_death)

def move_to(obj, device):
    if torch.is_tensor(obj):
        return obj.to(device)
    elif isinstance(obj, dict):
        res = {}
        for k, v in obj.items():
            res[k] = move_to(v, device)
            if k == 'type':
                res[k] = v
        return res

dataset = {}
load_data(mirna_path, 'mirna', dataset)
load_data(rnaseq_path, 'gen_exp', dataset)
load_clinical_data(clinical_data_path, 'clinical', dataset)
types = []
for key in dataset.keys():
    typ = dataset[key]['type']
    if typ not in types:
        types.append(typ)
print(f"Loaded patient data with {len(dataset)} patients")


# Handle Missing Data

Missing modalities are filled with zero vectors of the appropriate size.

In [None]:
empty_mirna = np.zeros(1881).astype(np_type)
empty_mirna = torch.tensor(empty_mirna)

empty_gen_exp = np.zeros(60483).astype(np_type)
empty_gen_exp = torch.tensor(empty_gen_exp)

empty_clinical = np.zeros(4).astype(np_type)
empty_clinical = torch.tensor(empty_clinical)

# empty_slides = np.zeros([40,3,224,224]).astype(np_type)
# empty_slides = torch.tensor(empty_slides)

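# Patients without clinical data have no survival time. Use the censoring sentinel 1e5 so that
# they are excluded from the Cox loss and the C-index (0 would count as a death at day 0).
empty_prognosis = np.array([1e5]).astype(np_type)
empty_prognosis = torch.tensor(empty_prognosis)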

def fill_empty(data, key, value):
    if key not in data:
        data[key] = value

# Given a datapoint, fill it with 0s if one of the modalities is missing
def fill_dataset(dataset):
    for datapoint in dataset:
        fill_empty(dataset[datapoint], 'mirna', empty_mirna)
        fill_empty(dataset[datapoint], 'gen_exp', empty_gen_exp)
        fill_empty(dataset[datapoint], 'clinical', empty_clinical)
#         fill_empty(dataset[datapoint], 'slides', empty_slides)
        fill_empty(dataset[datapoint], 'prognosis', empty_prognosis)

fill_dataset(dataset)

## Separation of the data to training set and test set

To evaluate how well the model generalizes (and to detect overfitting), the dataset is split into a training set and a test set, stratified by cancer type. The size of the test set is set by the parameter test_size.

In [None]:
def create_train_test(dataset,test_size):
    data = pd.DataFrame(dataset).T
    y = data.pop('prognosis').to_frame()
    X = data
    traindf, testdf, ytrain, ytest = train_test_split(X,y,test_size=test_size,shuffle=True,stratify=X.loc[:,'type'])
    traindf['prognosis'] = ytrain
    testdf['prognosis'] = ytest
    trainset = traindf.T.to_dict()
    testset = testdf.T.to_dict()
    return move_to(trainset,device), move_to(testset,device)

import time

t1 = time.time()

test_size = 0.25
trainset, testset = create_train_test(dataset,test_size)

t2 = time.time()

del dataset

torch.cuda.empty_cache()

print(f"Trainset size: {len(trainset)}")
print(f"Testset size: {len(testset)}")
print(f"Data loaded in {t2-t1} seconds")

# Data normalization
Two transformations are applied to the miRNA and gene expression data:

1. The data is log-transformed: $f(x) = \log(x+1)$
2. The data is mean-centered

Both transformations are applied per feature. The means are computed on the training set and reused to center the test set.

In [None]:
def normalize(dataset, train_means=None):
    yy = ['mirna', 'gen_exp']
    means = {}
    # First, log(x + 1)
    for x in dataset:
        for y in yy:
            if y not in means:
                means[y] = []
            dataset[x][y] = torch.log(dataset[x][y] + 1)
            means[y].append(dataset[x][y])
            
    # Now, Calculate means
    for y in means:
        means[y] = torch.stack(means[y])
        means[y] = means[y].mean(dim=0)

    # Finally, mean shift
    for x in dataset:
        for y in yy:
            if train_means is None:
                dataset[x][y] = dataset[x][y] - means[y]
            else:
                dataset[x][y] = dataset[x][y] - train_means[y]
    
    return means

means = normalize(trainset)
normalize(testset, means)

# Similarity Loss

The purpose of the similarity loss is to create meaningful feature vectors by forcing the sub-networks to produce similar feature vectors from the different data modalities of the same patient. The similarity between two patients $x$ and $y$ is defined as follows:

\begin{equation}
sim_{\theta}(x,y) = \sum_{i,j \in \text{modalities}} \frac{\hat{h}_{\theta,i}(x_i) \cdot \hat{h}_{\theta,j}(y_j)}{\|\hat{h}_{\theta,i}(x_i)\| \, \|\hat{h}_{\theta,j}(y_j)\|}
\end{equation}

Here $\hat{h}_{\theta,i}(x_i)$ denotes the feature vector produced by the sub-network $h$ with parameters $\theta$ from the input data $x_i$ of modality $i$. Each term is the cosine of the angle between two feature vectors: it is close to 1 when the vectors point in the same direction and decreases as they become more dissimilar. The cosines are summed over all pairs of modalities.
The loss for two data points is then:

\begin{equation}
L_{\theta} (x,y) = \max(0, M - sim_{\theta} (x,x) + sim_{\theta} (x,y))
\end{equation}

This formulation makes the loss larger when a patient's own feature vectors are dissimilar to each other, or when they are similar to those of other patients: the loss is zero only if the self-similarity $sim_{\theta}(x,x)$ exceeds the similarity to the other patient $sim_{\theta}(x,y)$ by at least the margin $M$. A larger $M$ therefore demands a larger gap between same-patient and different-patient similarities. The max makes sure the loss cannot become negative.
Finally, the loss is summed over all pairs of patients:

\begin{equation}
l_{sim} (\theta ) = \sum _{x,y} L_{\theta} (x,y)
\end{equation}
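The loss can be checked on toy data with a small NumPy sketch. It uses the margin-ranking form $\max(0, M - sim_{\theta}(x,x) + sim_{\theta}(x,y))$, in which each patient's own cross-modal similarity should exceed its similarity to any other patient by at least $M$. The function names are hypothetical, the feature vectors are given directly instead of being produced by sub-networks, and pairs with $x = y$ are left out of the sum (an assumption; they only add the constant $M$ each).

```python
import numpy as np


def cosine_sim_matrix(modalities, eps=1e-5):
    """modalities: list of arrays of shape (n_patients, n_features), one per modality.
    Returns S with S[k, l] = sum over modality pairs (i, j) of cos(h_i(x_k), h_j(x_l))."""
    normed = [m / (np.linalg.norm(m, axis=1, keepdims=True) + eps) for m in modalities]
    n = modalities[0].shape[0]
    S = np.zeros((n, n))
    for a in normed:
        for b in normed:
            S += a @ b.T
    return S


def similarity_margin_loss(modalities, M=0.1):
    S = cosine_sim_matrix(modalities)
    self_sim = np.diag(S)[:, None]         # sim(x, x) of the row patient
    L = np.maximum(0.0, M - self_sim + S)  # max(0, M - sim(x, x) + sim(x, y))
    np.fill_diagonal(L, 0.0)               # keep only pairs of different patients
    return L.sum()
```

With well-separated patients the loss is zero; when all patients collapse onto the same feature vectors, every pair of different patients costs exactly $M$.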
    

In [None]:
# Calculate all cosine similarities 
# Fast because vectorized
def sim_matrix(data):
    # sim[k][l] is the cosine similarity between patients k and l
    sim = None
    
    # Based on: https://stackoverflow.com/a/50426321
    for i in range(len(data)):
        for j in range(len(data)):
            # Cosine similarity
            # = U . V / (|U||V|)
            # = (U / |U|) . (V / |V|)

            # Small epsilon avoids division by zero for all-zero feature vectors
            # (e.g. a modality removed by multimodal dropout)
            eps = 1e-5

            # First, calculate the normalized values (U / |U|) and (V / |V|)
            i_norm = data[i] / (data[i].norm(dim=1)[:, None] + eps)
            j_norm = data[j] / (data[j].norm(dim=1)[:, None] + eps)

            # Now calculate for each patient, the dot product using matrix multiplication
            # res[k][l] is cosine similarity of modalities [i][j] for patients [k][l]
            res = torch.mm(i_norm, j_norm.transpose(0, 1))

            # Initialize the tensor
            if (sim is None):
                sim = res
            # Sum up similarities over different modalities
            else:
                sim = sim.add(res)
    return sim


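def similarity(data):
    # Margin M of the ranking loss
    M = 0.1

    # sim[k][l] = sim(x_k, x_l), the summed cross-modal cosine similarity of patients k and l
    sim = sim_matrix(data)

    # The diagonal holds the self-similarity sim(x, x) of every patient;
    # as a column vector it is broadcast along each row k
    simxx = torch.diag(sim)[:, None]

    # L(x, y) = max(0, M - sim(x, x) + sim(x, y))
    # clamp is not in-place, so its result has to be assigned
    loss = (M - simxx + sim).clamp(min=0)

    return loss.sum()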


def similarity_loss(outputs):
    mirna = outputs['mirna']
    genexp = outputs['gen_exp']
    clinical = outputs['clinical']
    
    data = [mirna, genexp, clinical]
    
    return similarity(data) 

# Cox loss
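The Cox loss is the negative logarithm of the Cox partial likelihood. For every patient $i$ whose death was observed ($\delta_i = 1$), the predicted risk $\hat{y}_i$ is compared with the predicted risks of all patients still at risk at the survival time $t_i$, i.e. all patients with $t_j \geq t_i$, including patient $i$ itself:

\begin{equation}
l_{cox}(\theta) = - \sum_{i:\, \delta_i = 1} \left( \hat{y}_i - \log \sum_{j:\, t_j \geq t_i} e^{\hat{y}_j} \right)
\end{equation}

A minimal NumPy sketch of this standard formulation follows; `cox_neg_log_partial_likelihood` is a hypothetical name and ties in survival time are not treated specially. In this formulation censored patients still appear in the risk sets of earlier deaths, whereas the `cox_loss` function in this notebook removes all patients without an observed death before computing the loss.

```python
import numpy as np


def cox_neg_log_partial_likelihood(risk, times, observed):
    """risk:     predicted log-hazard per patient (higher = earlier expected death)
    times:    survival time, or follow-up time for censored patients
    observed: True if the death was observed"""
    risk = np.asarray(risk, dtype=float)
    times = np.asarray(times, dtype=float)
    observed = np.asarray(observed, dtype=bool)
    loss = 0.0
    for i in np.flatnonzero(observed):
        # Risk set: everyone still under observation at t_i, including patient i
        at_risk = times >= times[i]
        loss -= risk[i] - np.log(np.sum(np.exp(risk[at_risk])))
    return loss
```

A model that assigns the higher risk to the patient who died first obtains the lower loss.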


In [None]:
def cox_loss(data, labels):
    '''
    Negative Cox partial log-likelihood for the patients in a batch whose death was observed.
    data: network output; data['prognosis'] holds the predicted risk scores (torch.tensor.shape = [batchsize, 1])
    labels: actual survival times in days (torch.tensor.shape = [batchsize, 1]); 1e5 marks patients without an observed death
    '''
    pred = data['prognosis']
    # Keep only the patients whose death was observed
    dead = labels < (1e5 - 1)
    labels = labels[dead]
    pred = pred[dead].float()
    # Sort by survival time so that the risk set of patient i is pred[i:]
    _, ix = labels.sort()
    pred = pred[ix]
    loss = torch.tensor([0.]).to(device)
    for i in range(len(pred)):
        # The risk set contains every patient still alive at t_i, including patient i itself;
        # logsumexp is the numerically stable form of log(sum(exp(...)))
        loss -= pred[i] - torch.logsumexp(pred[i:], dim=0)
    return loss

# Validation Loop
The C-index is calculated for a given dataset.
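lifelines computes Harrell's concordance index. The simplified pure-Python sketch below shows what is being counted: among all comparable pairs (patient i died, and patient j has a later recorded time, whether death or censoring), the fraction in which the model gives i the lower predicted survival score. `harrell_c_index` is a hypothetical helper; pairs with tied survival times are skipped, which lifelines handles more carefully, so the values can differ slightly. Scores follow the lifelines convention (higher means longer survival), which is why the notebook passes the negated risk `-predicted`.

```python
def harrell_c_index(times, scores, observed):
    """times: survival or follow-up times; scores: higher = longer predicted survival;
    observed: True if the death was observed."""
    concordant = 0.0
    comparable = 0
    n = len(times)
    for i in range(n):
        if not observed[i]:
            continue  # a censored patient cannot be the earlier death of a pair
        for j in range(n):
            if times[i] < times[j]:
                comparable += 1
                if scores[i] < scores[j]:
                    concordant += 1.0
                elif scores[i] == scores[j]:
                    concordant += 0.5  # tied predictions count as half
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable
```

A C-index of 1.0 means perfect ranking, 0.5 corresponds to random predictions.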

In [None]:
def validate(network, loader, dataset):
    # Evaluation mode: disables dropout while computing the C-index
    network.eval()
    actual = None
    predicted = None
    with torch.no_grad():
        for data_keys in loader:
            data = list(map(lambda x: dataset[x], data_keys))
            data = datapoints_to_batch(data)
            output = network(data)
            if actual is None:
                actual = data['prognosis']
                predicted = output['prognosis']
            else:
                actual = torch.cat((actual, data['prognosis']))
                predicted = torch.cat((predicted, output['prognosis']))
    # Back to training mode for the next epoch
    network.train()

    actual = np.squeeze(np.array(actual.cpu()))
    predicted = np.squeeze(np.array(predicted.cpu()))

    # Mask is to exclude patients that did not die yet (censored, sentinel value 1e5)
    mask = actual < (1e5 - 1)
    # lifelines expects scores where higher means longer survival, so the predicted risk is negated
    c = lifelines.utils.concordance_index(actual, -predicted, mask)
    return c

# Training Loop
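The training procedure differs substantially from the one in the paper. The similarity loss and the Cox loss are minimized jointly; in every batch the similarity loss is rescaled so that its value is one tenth of the Cox loss. After each epoch the C-index is computed on the training and test sets, and the network state with the highest test C-index is kept.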

In [None]:
import lifelines
import torch.optim as optim
import copy

batch_size = 512
epochs = 25
multimodal_dropout = False

if (multimodal_dropout):
    p = 0.25
else:
    p = 0

network = Network(p=p).to(device)
optimizer = optim.AdamW(network.parameters(), lr=0.00001, amsgrad=True)

data_loader = torch.utils.data.DataLoader(list(trainset.keys()), batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(list(testset.keys()), batch_size=batch_size, shuffle=False)

losses = [[] for epoch in range(epochs)]
losses_sim = [[] for epoch in range(epochs)]
losses_cox = [[] for epoch in range(epochs)]

C_test = []
C_data = []

counter = 0

best_network_state_dict = None
best_network_c_test = 0.0

for epoch in range(epochs): # Epochs
    for data_keys in data_loader: # Patients
        optimizer.zero_grad()
        counter += 1

        data = list(map(lambda x: trainset[x], data_keys))
        data = datapoints_to_batch(data)
        
        labels = data['prognosis']

        outputs = network(data)

        loss_sim = similarity_loss(outputs)
        loss_cox = cox_loss(outputs, labels)
        
        ratio = loss_cox / loss_sim
        ratio = ratio.detach()
        
        loss_sim = loss_sim * ratio * 0.1
        
        loss_total = loss_sim + loss_cox
        loss_total.backward()
        
        optimizer.step()
        
        losses_sim[epoch].append(loss_sim.item())
        losses_cox[epoch].append(loss_cox.item())
        losses[epoch].append(loss_total.item())

    c_test = validate(network, test_loader, testset)
    c_data = validate(network, data_loader, trainset)
    C_test.append(c_test)
    C_data.append(c_data)
    counter = 0
    print(f"Epoch: {epoch} C-score test: {c_test}, C-score data: {c_data}")
    
    if c_test > best_network_c_test:
        best_network_state_dict = copy.deepcopy(network.state_dict())
        best_network_c_test = c_test
        print(f"Found new best C-score of: {c_test}")

# Various plots

In [None]:
import os

output_dir = "./output"

# Create the output directory if it does not exist yet
os.makedirs(output_dir, exist_ok=True)

### Loss curve

In [None]:
import matplotlib
import numpy as np
import matplotlib.pyplot as plt

plt.figure()
total = np.mean(np.array(losses), axis=1)
plt.plot(total)
plt.xlabel("Epoch")
plt.ylabel("Total loss")
plt.savefig(f"{output_dir}/loss.svg", dpi=1000)


### C-index curve

In [None]:
plt.figure()

plt.plot(C_test)
plt.plot(C_data)
plt.ylim(0.65, 0.85)
plt.legend(["C-index test", "C-index train"])
plt.xlabel("Epoch")
plt.ylabel("C-index")
plt.savefig(f"{output_dir}/default.svg", dpi=1000)

print(max(C_test))

### C-index per cancer site

In [None]:
def sort_per_type(dataset):
    types = []
    for key in dataset.keys():
        typ = dataset[key]['type']
        if typ not in types:
            types.append(typ)
    data_per_type = {typ:{} for typ in types}
    for key in dataset.keys():
        data_per_type[dataset[key]['type']][key] = dataset[key]
    return data_per_type

def validate_per_type(network, dataset):
    # Evaluation mode: disables dropout while predicting
    network.eval()
    sorted_data = sort_per_type(dataset)
    loaders = {}
    for typ, data in sorted_data.items():
        loaders[typ] = torch.utils.data.DataLoader(list(data.keys()), batch_size=batch_size, shuffle=True)
    # Validation
    C = {}
    for typ, loader in loaders.items():
        actual = None
        predicted = None
        for data_keys in loader:
            data = list(map(lambda x: dataset[x], data_keys))
            data = datapoints_to_batch(data)
            with torch.no_grad():
                output = network(data)
                if actual is None:
                    actual = data['prognosis']
                    predicted = output['prognosis']
                else:
                    actual = torch.cat((actual, data['prognosis']))
                    predicted = torch.cat((predicted, output['prognosis']))

        actual = np.squeeze(np.array(actual.cpu()))
        predicted = np.squeeze(np.array(predicted.cpu()))
        
        # Mask is to exclude patients that did not die yet
        mask = actual < (1e5 - 1)
        try:
            c = lifelines.utils.concordance_index(actual, -predicted, mask)
            C[typ] = c
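        except Exception as e:
            # e.g. no comparable pairs (no observed deaths) for this cancer type
            print(f"C-index could not be computed for {typ}: {e}")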
    return C

network.load_state_dict(best_network_state_dict)
validate_per_type(network, testset)

In [None]:
def c_per_type(network, dataset):
    # Evaluation mode: disables dropout while predicting
    network.eval()
    allowed_types = ["BLCA", "BRCA", "CESC", "COAD", "READ", "HNSC", "KICH", "KIRC", "KIRP", "LAML", "LGG", "LIHC", "LUAD", "LUSC", "OV", "PAAD", "PRAD", "SKCM", "STAD", "THCA", "UCEC"]
    types = []
    for key in dataset.keys():
        typ = dataset[key]['type']
        if typ not in types:
            types.append(typ)
    C = {}
    with torch.no_grad():
        data = datapoints_to_batch(list(dataset.values()))
        output = network(data)

        actual = data['prognosis']
        predicted = output['prognosis']

        actual = np.squeeze(np.array(actual.cpu()))
        predicted = np.squeeze(np.array(predicted.cpu()))
        
        for typ in types:
#             if typ not in allowed_types:
#                 continue

            mask_type = list(map(lambda x: x['type'] == typ, dataset.values()))
            mask = actual < (1e5 - 1)
            try:
                c = lifelines.utils.concordance_index(actual[mask_type], -predicted[mask_type], mask[mask_type])
                C[typ] = c
            except Exception:
                # Skip cancer types for which the C-index cannot be computed (e.g. no comparable pairs)
                pass
    return C

network.load_state_dict(best_network_state_dict)

Cs = c_per_type(network, testset)
for x in Cs:
    print(f"{x} {round(Cs[x], 3)}")
mn = np.mean(np.array(list(Cs.values())))
print(mn)