# Training individual models

## Imports

In [1]:
import os
import time
import numpy as np
from tqdm import tqdm
import rdflib as rl
import torch
import torchtuples as tt
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim import lr_scheduler
from torch.utils.data.sampler import SubsetRandomSampler
from sklearn.preprocessing import MinMaxScaler
from pycox.models import CoxPH
from pycox.evaluation import EvalSurv
from torch_geometric.data import Data, DataLoader, Batch
from torch_geometric.nn import GCNConv, SAGEConv, GraphConv, SAGPooling
from torch_geometric.nn import global_max_pool as gmp
import click as ck
import gzip
import pickle
import sys
import matplotlib.pyplot as plt
import statistics
import pandas as pd

## Constant variables

In [2]:
# Manually categorized cancer subtypes
CANCER_SUBTYPES = [
    [0,12,7,14,4,1,6,2,3],
    [4],
    [5,4,14,6],
    [6,4,12,7],
    [4],
    [6,4,12,7],
    [8],
    [6,4,12],
    [9],
    [6],
    [4],
    [4],
    [4],
    [10],
    [9],
    [4],
    [4,11,12],
    [6],
    [13],
    [12],
    [0,4,12,14],
    [15],
    [4,0,12],
    [4,12],
    [16,17,18,19,20],
    [20],
    [4,12],
    [22],
    [4,14],
    [23],
    [4,12,14],
    [24],
    [21]
]

CELL_TYPES = [
    0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 3, 0, 0, 4, 2, 0,
    0, 0, 5, 0, 0, 6, 0, 0, 7, 8, 0, 9, 0, 0, 0, 0,
    8]

cancer_types = [
    "TCGA-ACC", "TCGA-BLCA", "TCGA-BRCA", "TCGA-CESC",
    "TCGA-CHOL", "TCGA-COAD", "TCGA-DLBC", "TCGA-ESCA",
    "TCGA-GBM", "TCGA-HNSC", "TCGA-KICH", "TCGA-KIRC",
    "TCGA-KIRP", "TCGA-LAML","TCGA-LGG","TCGA-LIHC",
    "TCGA-LUAD","TCGA-LUSC","TCGA-MESO","TCGA-OV",
    "TCGA-PAAD","TCGA-PCPG","TCGA-PRAD","TCGA-READ",
    "TCGA-SARC","TCGA-SKCM","TCGA-STAD","TCGA-TGCT",
    "TCGA-THCA","TCGA-THYM","TCGA-UCEC","TCGA-UCS","TCGA-UVM"]


## Load proteins and interactions

In [3]:
device = 'cuda:0'

proteins_df = pd.read_pickle('data/proteins.pkl')
interactions_df = pd.read_pickle('data/interactions.pkl')
proteins = {row.proteins: row.ids for row in proteins_df.itertuples()}
edge_index = [interactions_df['protein1'].values, interactions_df['protein2'].values]
edge_index = torch.LongTensor(edge_index).to(device)

proteins, edge_index


({'ENSP00000000233': 0,
  'ENSP00000432568': 1,
  'ENSP00000427900': 2,
  'ENSP00000350199': 3,
  'ENSP00000354878': 4,
  'ENSP00000405926': 5,
  'ENSP00000314615': 6,
  'ENSP00000349588': 7,
  'ENSP00000414982': 8,
  'ENSP00000480707': 9,
  'ENSP00000324020': 10,
  'ENSP00000300087': 11,
  'ENSP00000268919': 12,
  'ENSP00000384164': 13,
  'ENSP00000388878': 14,
  'ENSP00000258739': 15,
  'ENSP00000310226': 16,
  'ENSP00000303145': 17,
  'ENSP00000364864': 18,
  'ENSP00000404190': 19,
  'ENSP00000263373': 20,
  'ENSP00000359000': 21,
  'ENSP00000380308': 22,
  'ENSP00000380432': 23,
  'ENSP00000297044': 24,
  'ENSP00000401010': 25,
  'ENSP00000262305': 26,
  'ENSP00000324287': 27,
  'ENSP00000264712': 28,
  'ENSP00000273130': 29,
  'ENSP00000354560': 30,
  'ENSP00000357048': 31,
  'ENSP00000223369': 32,
  'ENSP00000249923': 33,
  'ENSP00000480301': 34,
  'ENSP00000378356': 35,
  'ENSP00000361057': 36,
  'ENSP00000341344': 37,
  'ENSP00000360532': 38,
  'ENSP00000320130': 39,
  'ENSP000

In [4]:
# device = torch.device('cpu')

# cancer_type_vector = np.zeros((33,), dtype=np.float32)
# cancer_type_vector[cancer_type] = 1

# cancer_subtype_vector = np.zeros((25,), dtype=np.float32)
# for i in CANCER_SUBTYPES[cancer_type]:
#     cancer_subtype_vector[i] = 1

# anatomical_location_vector = np.zeros((52,), dtype=np.float32)
# anatomical_location_vector[0] = 1
# cell_type_vector = np.zeros((10,), dtype=np.float32)
# cell_type_vector[CELL_TYPES[cancer_type]] = 1

# pt_tensor_cancer_type = torch.FloatTensor(cancer_type_vector).to(device)
# pt_tensor_cancer_subtype = torch.FloatTensor(cancer_subtype_vector).to(device)
# pt_tensor_anatomical_location = torch.FloatTensor(anatomical_location_vector).to(device)
# pt_tensor_cell_type = torch.FloatTensor(cell_type_vector).to(device)
# edge_index = torch.LongTensor(edge_index).to(device)


## Model class

In [5]:
class MyNet(nn.Module):
    def __init__(self, num_nodes, edge_index):
        super(MyNet, self).__init__()
        self.num_nodes = num_nodes
        self.edge_index = edge_index
        self.conv1 = GCNConv(6, 64)
        self.pool1 = SAGPooling(64, ratio=0.70, GNN=GCNConv)
        self.conv2 = GCNConv(64, 1)
        self.fc1 = nn.Linear(12030, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, data):
        batch_size = data.shape[0]
        x = data[:, :self.num_nodes * 6]
        x = x.reshape(batch_size, self.num_nodes, 6)
        l = []
        for i in range(batch_size):
            l.append(Data(x=x[i], edge_index=self.edge_index))
        batch = Batch.from_data_list(l)
        x = F.relu(self.conv1(x=batch.x, edge_index=batch.edge_index))
        x, edge_index, _, batch, perm, score = self.pool1(x, batch.edge_index, None, batch.batch)
        x = F.relu(self.conv2(x, edge_index=edge_index))
        # x = gmp(x, batch)
        x = x.view(batch_size, -1)
        x = self.sigmoid(self.fc1(x))
        return x
    
net = MyNet(len(proteins), edge_index).to(device)
net

MyNet(
  (conv1): GCNConv(6, 64)
  (pool1): SAGPooling(GCNConv, 64, ratio=0.7, multiplier=1.0)
  (conv2): GCNConv(64, 1)
  (fc1): Linear(in_features=12030, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)

## Train a model for each cancer type

In [6]:
for i, cancer_type in enumerate(cancer_types[:1]):
    # Create a new model for each cancer type
    
    df = pd.read_pickle(f'data/{cancer_type}.pkl')
    print(df)
    
    dataset = np.stack(df['features'].values).reshape(len(df), -1)
    # dataset = torch.FloatTensor(dataset)
    print(dataset.shape)
    in_features = dataset.shape[1]
    labels_days = df['duration'].values
    labels_surv = df['survival'].values
    
    censored_index = []
    uncensored_index = []
    for i in range(len(dataset)):
        if labels_surv[i] == 1:
            censored_index.append(i)
        else:
            uncensored_index.append(i)
    print('Censored', len(censored_index))
    print('Uncensored', len(uncensored_index))
    
    censored_index = np.array(censored_index)
    uncensored_index = np.array(uncensored_index)

    ev_ = []
    splits = 5
    best_cindex = 0
    for fold in range(splits):
        del net
        torch.manual_seed(0)
        # net = MyNet(len(proteins), edge_index).to(device)
        net = tt.practical.MLPVanilla(in_features, [1024, 512], 1, True,
                             0.1, output_bias=False)
        model = CoxPH(net, tt.optim.Adam(0.01))
        # Censored split
        n = len(censored_index)
        index = np.arange(n)
        i = n // 5
        np.random.seed(seed=0)
        np.random.shuffle(index)
        if fold < 4:
            ctest_idx = index[fold * i: fold * i + i]
            ctrain_idx = np.concatenate((index[:fold * i],index[fold * i + i:]))
        else:
            ctest_idx = index[fold * i:]
            ctrain_idx = index[:fold * i]
        ctrain_n = len(ctrain_idx)
        cvalid_n = ctrain_n // 10
        cvalid_idx = ctrain_idx[:cvalid_n]
        ctrain_idx = ctrain_idx[cvalid_n:]

        # Uncensored split
        n = len(uncensored_index)
        index = np.arange(n)
        i = n // 5
        np.random.seed(seed=0)
        np.random.shuffle(index)
        if fold < 4:
            utest_idx = index[fold * i: fold * i + i]
            utrain_idx = np.concatenate((index[:fold * i],index[fold * i + i:]))
        else:
            utest_idx = index[fold * i:]
            utrain_idx = index[:fold * i]
        utrain_n = len(utrain_idx)
        uvalid_n = utrain_n // 10
        uvalid_idx = utrain_idx[:uvalid_n]
        utrain_idx = utrain_idx[uvalid_n:]


        train_idx = np.concatenate((
            censored_index[ctrain_idx], uncensored_index[utrain_idx]))
        np.random.seed(seed=0)
        np.random.shuffle(train_idx)
        valid_idx = np.concatenate((
            censored_index[cvalid_idx], uncensored_index[uvalid_idx]))
        np.random.seed(seed=0)
        np.random.shuffle(valid_idx)
        test_idx = np.concatenate((
            censored_index[ctest_idx], uncensored_index[utest_idx]))
        np.random.seed(seed=0)
        np.random.shuffle(test_idx)

        min_max_scaler = MinMaxScaler()
        train_data = dataset[train_idx]
        train_data = min_max_scaler.fit_transform(train_data)
        train_labels_days = labels_days[train_idx]
        train_labels_surv = labels_surv[train_idx]
        train_labels = (train_labels_days, train_labels_surv)

        val_data = dataset[valid_idx]
        val_data = min_max_scaler.transform(val_data)
        val_labels_days = labels_days[valid_idx]
        val_labels_surv = labels_surv[valid_idx]
        test_data = dataset[test_idx]
        test_data = min_max_scaler.transform(test_data)
        test_labels_days = labels_days[test_idx]
        test_labels_surv = labels_surv[test_idx]
        val_labels = (val_labels_days, val_labels_surv)

        print(val_labels)
        print('Training data', train_data.shape)
        print('Validation data', val_data.shape)
        print('Testing data', test_data.shape)
        callbacks = [tt.callbacks.EarlyStopping(patience=10)]
        batch_size = 32
        epochs = 100
        val = (val_data, val_labels)
        log = model.fit(
            train_data, train_labels, batch_size, epochs, callbacks, verbose=True,
            val_data=val,
            val_batch_size=batch_size)
        train = train_data, train_labels
        # Compute the evaluation measurements
        _ = model.compute_baseline_hazards(*train)
        surv = model.predict_surv_df(test_data)
        ev = EvalSurv(surv, test_labels_days, test_labels_surv)
        result = ev.concordance_td()
        print('Concordance', result)
        ev_.append(result)

        if result > best_cindex:
            best_cindex = result
            
            np.savetxt('test'+str(fold+1)+'.csv', test_data, delimiter="\t")
                
            np.savetxt('test_labels_days'+str(fold+1)+'.csv', test_labels_days, delimiter="\t")
            
            np.savetxt('test_labels_surv'+str(fold+1)+'.csv', test_labels_surv, delimiter="\t")

    print(cancer_type)            
    print(str(statistics.mean(ev_))+"["+str(min(ev_))+"-"+str(max(ev_))+"]")

    survival  duration                                           features
0          0      2703  [[0.1338738, 0.0, 0.0022847964, 0.15337788, 0....
1          1       822  [[0.3156823, 0.0, 0.0018584887, 0.042907022, 0...
2          0       967  [[0.12269625, 0.0, 0.002814955, 0.4678744, 0.0...
3          1      1029  [[0.07559731, 0.0, 0.0061716787, 0.2649523, 0....
4          1      1201  [[0.40974522, 0.0, 0.0008181625, 0.2413451, 0....
..       ...       ...                                                ...
75         1      1096  [[0.12181443, 0.0, 0.0015971365, 0.077812135, ...
76         1      4673  [[0.08773822, 0.0, 0.002590872, 0.26549563, 0....
77         1      2740  [[0.3526518, 0.0, 0.0014291713, 0.74829865, 0....
78         0       253  [[0.06243348, 0.0, 0.006840582, 0.1760158, -1....
79         1      1194  [[0.21361445, 0.0, 0.00591789, 0.29969403, 0.0...

[80 rows x 3 columns]
(80, 103110)
Censored 52
Uncensored 28
(array([2342, 1364, 1317,  871, 1858,  253]), arra

OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.


Concordance 0.6666666666666666
(array([ 309, 1857,  861,  950,  616, 2091]), array([0, 1, 1, 1, 1, 0]))
Training data (59, 103110)
Validation data (6, 103110)
Testing data (15, 103110)
0:	[1s / 1s],		train_loss: 2.3477,	val_loss: 27.0933
1:	[1s / 2s],		train_loss: 5.2627,	val_loss: 0.0464
2:	[0s / 2s],		train_loss: 3.4598,	val_loss: 5.2792
3:	[0s / 2s],		train_loss: 2.3506,	val_loss: 0.5922
4:	[0s / 2s],		train_loss: 1.7068,	val_loss: 1.2903
5:	[0s / 2s],		train_loss: 1.5925,	val_loss: 0.6653
6:	[0s / 3s],		train_loss: 1.3487,	val_loss: 0.4809
7:	[0s / 3s],		train_loss: 1.3536,	val_loss: 0.5237
8:	[0s / 3s],		train_loss: 1.2418,	val_loss: 0.5649
9:	[0s / 3s],		train_loss: 1.2776,	val_loss: 0.4393
10:	[0s / 3s],		train_loss: 1.0327,	val_loss: 0.4136
11:	[0s / 3s],		train_loss: 0.9457,	val_loss: 0.4473
Concordance 0.7101449275362319
(array([ 309, 1857,  861,  950,  616, 2091]), array([0, 1, 1, 1, 1, 0]))
Training data (59, 103110)
Validation data (6, 103110)
Testing data (15, 103110)
0:	