<a id='Top'></a>

# MultiSurv results<a class='tocSkip'></a>

Evaluation metric results for MultiSurv.

In [27]:
# autoreload in mode 2 picks up edits to the imported project modules live
%load_ext autoreload
%autoreload 2

# Provides the %watermark magic used in the final cell of this notebook
%load_ext watermark

import sys
import os

import ipywidgets as widgets
import numpy as np
import pandas as pd
import torch

# Make modules in "src" dir visible.
# The parent of the current working directory is treated as the project root.
project_dir = os.path.split(os.getcwd())[0]
src_dir = os.path.join(project_dir, 'src')
# Test membership of the path that is actually appended: checking `project_dir`
# instead would add a duplicate "src" entry on every re-run of this cell, and
# would skip adding "src" entirely if `project_dir` itself were already on the path.
if src_dir not in sys.path:
    sys.path.append(src_dir)

import dataset
from model import Model
import utils

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
The watermark extension is already loaded. To reload it, use:
  %reload_ext watermark


<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#DataLoader" data-toc-modified-id="DataLoader-1"><span class="toc-item-num">1&nbsp;&nbsp;</span><code>DataLoader</code></a></span></li><li><span><a href="#Model" data-toc-modified-id="Model-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Model</a></span><ul class="toc-item"><li><span><a href="#Load-weights" data-toc-modified-id="Load-weights-2.1"><span class="toc-item-num">2.1&nbsp;&nbsp;</span>Load weights</a></span></li></ul></li><li><span><a href="#Evaluate" data-toc-modified-id="Evaluate-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Evaluate</a></span><ul class="toc-item"><li><span><a href="#Write-to-results-table" data-toc-modified-id="Write-to-results-table-3.1"><span class="toc-item-num">3.1&nbsp;&nbsp;</span>Write to results table</a></span></li><li><span><a href="#Check-results-on-all-datasets" data-toc-modified-id="Check-results-on-all-datasets-3.2"><span class="toc-item-num">3.2&nbsp;&nbsp;</span>Check results on all datasets</a></span></li></ul></li></ul></div>

In [14]:
# Run configuration.
# NOTE(review): these are absolute, machine-specific locations; adjust them
# before running the notebook in another environment.
DATA = '/mnt/data/d.kornilov/TCGA/processed_GBM_LGG'  # passed as data_location to the dataloaders
LABELS_FILE = '/home/d.kornilov/work/multisurv/data/labels_gbm_lgg.tsv'  # passed as labels_file
MODELS = '/home/d.kornilov/work/multisurv/outputs/models_gbm_lgg'  # directory joined with MODEL below
MODEL = 'clinical_lr0.005_epoch71_concord0.83.pth'  # weight file name loaded in "Load weights"

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# `DataLoader`

In [15]:
# Multi-select widget whose current selection (`modalities.value`) is passed to
# the dataloaders. The first option ('clinical') is pre-selected via index 0.
# NOTE(review): downstream results depend on this interactive selection.
modality_options = ['clinical', 'mRNA', 'DNAm', 'miRNA', 'CNV', 'wsi']
modalities = widgets.SelectMultiple(
    description='Input data',
    options=modality_options,
    index=[0],
    rows=len(modality_options),  # show every option without scrolling
    disabled=False,
)
display(modalities)

SelectMultiple(description='Input data', index=(0,), options=('clinical', 'mRNA', 'DNAm', 'miRNA', 'CNV', 'wsi…

In [16]:
# Keyword arguments for the project's dataloader factory.
# `exclude_patients` is not passed, so it is left at the helper's default.
dataloader_kwargs = dict(
    data_location=DATA,
    labels_file=LABELS_FILE,
    modalities=modalities.value,  # current widget selection
    wsi_patch_size=299,
    n_wsi_patches=5,
    num_workers=8,
    drop_last=False,
)
dataloaders = utils.get_dataloaders(**dataloader_kwargs)

# Report the length of each split's dataloader
# (the recorded output suggests this is the number of batches).
for split, dataloader in dataloaders.items():
    print(f"{split} dataloader: {len(dataloader)}")

Data modalities:
   clinical

Dataset sizes (# patients):
   train: 8880
   val: 1109
   test: 1092

Batch size: 128
train dataloader: 70
val dataloader: 9
test dataloader: 9


# Model

In [17]:
# Build the MultiSurv model around the dataloaders created above.
# `output_intervals` is not passed, so the model falls back to its default intervals.
multisurv = Model(device=device, dataloaders=dataloaders)

Instantiating MultiSurv model...




In [18]:
# Show the model's interval boundaries divided by 365
# (assumes `output_intervals` is expressed in days — TODO confirm).
# NOTE(review): the recorded output shows a second boundary of 3.0815e-33 rather
# than a clean value; worth checking how the intervals are constructed.
print('Output intervals (in years):')
intervals_in_years = multisurv.output_intervals / 365
print(intervals_in_years)

Output intervals (in years):
tensor([0.0000e+00, 3.0815e-33, 1.0000e+00, 2.0000e+00, 3.0000e+00, 4.0000e+00,
        5.0000e+00, 6.0000e+00, 7.0000e+00, 8.0000e+00, 9.0000e+00, 1.0000e+01,
        1.1000e+01, 1.2000e+01, 1.3000e+01, 1.4000e+01, 1.5000e+01, 1.6000e+01,
        1.7000e+01, 1.8000e+01, 1.9000e+01, 2.0000e+01, 2.1000e+01, 2.2000e+01,
        2.3000e+01, 2.4000e+01, 2.5000e+01, 2.6000e+01, 2.7000e+01, 2.8000e+01,
        2.9000e+01])


## Load weights

In [19]:
!ls -1 /home/d.kornilov/work/multisurv/outputs/models

clinical_lr0.005_epoch45_concord0.79.pth


In [8]:
# NOTE(review): lists WSI weight files under a hard-coded location that is unrelated
# to MODELS, and the cell's execution count (In [8]) is out of sequence with its
# neighbours — this looks like a leftover cell; confirm whether it is still needed.
!ls -1 /mnt/dataA/multisurv_models/wsi*

/mnt/dataA/multisurv_models/wsi5patches299px_lr0.001_epoch44_concord0.55.pth


In [20]:
# Load the best model's weights: MODEL is the file name, MODELS its directory
weights_path = os.path.join(MODELS, MODEL)
multisurv.load_weights(weights_path)

Load model weights:
/home/d.kornilov/work/multisurv/outputs/models/clinical_lr0.005_epoch45_concord0.79.pth


In [21]:
# Echo the currently selected input modalities, one per line
selected_modalities = modalities.value
for modality in selected_modalities:
    print(modality)

clinical


# Evaluate

In [22]:
%%time

# Using MultiSurv's default output intervals
performance = utils.Evaluation(model=multisurv, dataset=dataloaders['test'].dataset,
                               device=device)
performance.run_bootstrap()
print()

Collect patient predictions: 1092/1092

Bootstrap
---------


100%|██████████| 1000/1000 [11:04<00:00,  1.51it/s]



CPU times: user 18min 14s, sys: 41min 45s, total: 59min 59s
Wall time: 11min 36s





In [25]:
# Label for the current modality selection, e.g. "clinical" or "clinical + mRNA".
# str.join returns a lone element unchanged, so a single selection needs no
# special case; dropping the `[0]` branch also avoids an IndexError on an empty selection.
data_modalities = ' + '.join(modalities.value)
print(f'>> {data_modalities} <<')
print()
# Display the bootstrapped metrics with the default settings of show_results
performance.show_results()

>> clinical <<

          Value (95% CI)
-----------------------------
C-index   0.797 (0.781-0.815)
Ctd       0.804 (0.788-0.822)
IBS       0.144 (0.133-0.158)
INBLL     0.444 (0.412-0.479)


In [26]:
# Label for the current modality selection, e.g. "clinical" or "clinical + mRNA".
# str.join returns a lone element unchanged, so a single selection needs no
# special case; dropping the `[0]` branch also avoids an IndexError on an empty selection.
data_modalities = ' + '.join(modalities.value)
print(f'>> {data_modalities} <<')
print()
# Same metrics, requested with method='empirical'
performance.show_results(method='empirical')

>> clinical <<

          Value (95% CI)
-----------------------------
C-index   0.797 (0.781-0.815)
Ctd       0.804 (0.788-0.822)
IBS       0.144 (0.133-0.158)
INBLL     0.444 (0.412-0.479)


## Write to results table

In [16]:
# Results table helper from the project's `utils` module. The table displayed further down
# already holds rows for other algorithms, so this presumably loads previously saved
# results — TODO confirm against utils.ResultTable.
results = utils.ResultTable()

In [17]:
# Column label for the current modality selection, e.g. "clinical" or "clinical + mRNA".
# str.join returns a lone element unchanged, so a single selection needs no
# special case; dropping the `[0]` branch also avoids an IndexError on an empty selection.
data_modalities = ' + '.join(modalities.value)

# Record MultiSurv's formatted metrics under that label, then display the table
results.write_result_dict(result_dict=performance.format_results(),
                          algorithm='MultiSurv',
                          data_modality=data_modalities)
results.table

Unnamed: 0_level_0,Unnamed: 1_level_0,clinical,mRNA,DNAm,miRNA,CNV,wsi,clinical + mRNA,clinical + DNAm,clinical + mRNA + DNAm,clinical + mRNA + DNAm + miRNA + CNV,clinical + mRNA + DNAm + miRNA,clinical + miRNA,clinical + CNV,mRNA + DNAm,clinical + mRNA + DNAm + miRNA + CNV + wsi,clinical + wsi
Algorithm,Metric,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1
CPH,C-index,0.796 (0.779-0.813),0.733 (0.712-0.756),0.739 (0.719-0.76),0.676 (0.651-0.7),0.57 (0.543-0.599),,,,,,,,,,,
CPH,Ctd,0.796 (0.779-0.813),0.733 (0.712-0.755),0.739 (0.719-0.76),0.676 (0.651-0.7),0.57 (0.543-0.599),,,,,,,,,,,
CPH,IBS,0.143 (0.135-0.154),0.177 (0.165-0.19),0.179 (0.165-0.192),0.186 (0.171-0.202),0.214 (0.207-0.224),,,,,,,,,,,
CPH,INBLL,0.438 (0.414-0.465),0.528 (0.497-0.558),0.532 (0.499-0.563),0.547 (0.511-0.585),0.617 (0.601-0.64),,,,,,,,,,,
RSF,C-index,0.764 (0.744-0.782),0.718 (0.695-0.741),0.728 (0.707-0.751),0.663 (0.638-0.688),0.604 (0.579-0.63),,,,,,,,,,,
RSF,Ctd,0.77 (0.751-0.789),0.719 (0.695-0.741),0.729 (0.709-0.752),0.664 (0.639-0.689),0.604 (0.579-0.63),,,,,,,,,,,
RSF,IBS,0.184 (0.179-0.191),0.191 (0.181-0.2),0.186 (0.176-0.192),0.193 (0.183-0.201),0.217 (0.208-0.225),,,,,,,,,,,
RSF,INBLL,0.546 (0.533-0.56),0.563 (0.537-0.581),0.55 (0.527-0.564),0.567 (0.543-0.583),0.621 (0.602-0.64),,,,,,,,,,,
DeepSurv,C-index,0.792 (0.773-0.811),0.746 (0.722-0.768),0.76 (0.74-0.78),0.685 (0.661-0.711),0.596 (0.571-0.621),,,,,,,,,,,
DeepSurv,Ctd,0.792 (0.773-0.81),0.746 (0.722-0.768),0.759 (0.739-0.78),0.685 (0.661-0.711),0.596 (0.571-0.621),,,,,,,,,,,


## Check results on all datasets

In [12]:
%%time

print('~' * 23)
print('     RESULT CHECK')
print('~' * 23)
check_results = {'train': None, 'val': None, 'test': None}

for group in check_results.keys():
    print(f'~ {group} ~')
    performance = utils.Evaluation(model=multisurv, dataset=dataloaders[group].dataset, device=device)
    performance.compute_metrics()
    performance.show_results()
    print()

~~~~~~~~~~~~~~~~~~~~~~~
     RESULT CHECK
~~~~~~~~~~~~~~~~~~~~~~~
~ train ~
Collect patient predictions: 8880/8880

C-index   0.843
Ctd       0.85
IBS       0.112
INBLL     0.349

~ val ~
Collect patient predictions: 1109/1109

C-index   0.805
Ctd       0.808
IBS       0.137
INBLL     0.423

~ test ~
Collect patient predictions: 1092/1092

C-index   0.818
Ctd       0.822
IBS       0.138
INBLL     0.425

CPU times: user 5min 36s, sys: 1min 18s, total: 6min 55s
Wall time: 13min 54s


# Watermark<a class='tocSkip'></a>

In [13]:
# Versions of the imported packages
%watermark --iversions
# Python and IPython versions
%watermark -v
print()
# Date of the last update
%watermark -u -n

pandas     1.0.1
torch      1.4.0
numpy      1.18.1
ipywidgets 7.5.1

CPython 3.6.7
IPython 7.11.1

last updated: Tue Jul 28 2020


[Top of the page](#Top)