In [None]:
# CUDA_VISIBLE_DEVICES=0,1
# CUDA_LAUNCH_BLOCKING=1
import os
import time
import numpy as np
from tqdm import tqdm
import rdflib as rl
import torch
import torchtuples as tt
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim import lr_scheduler
from torch.utils.data.sampler import SubsetRandomSampler
from sklearn.preprocessing import MinMaxScaler
from pycox.models import CoxPH
from pycox.evaluation import EvalSurv
from torch_geometric.data import Data, DataLoader, Batch
from torch_geometric.nn import GCNConv, SAGEConv, GraphConv, SAGPooling, GENConv, GATConv
from torch_geometric.nn import global_max_pool as gmp
from torch.nn.parallel import DistributedDataParallel as DDP
import click as ck
import gzip
import pickle
import sys
import matplotlib.pyplot as plt
import statistics
import pandas as pd
import random
# from torch.utils.tensorboard import SummaryWriter

In [None]:
# Manually categorized cancer subtypes
# One list of integer category ids per entry. There are 33 entries, the same
# number as `cancer_types` below -- presumably aligned with it by index
# (TODO confirm).
# NOTE(review): not referenced anywhere in the visible part of this notebook.
CANCER_SUBTYPES = [
    [0,12,7,14,4,1,6,2,3],
    [4],
    [5,4,14,6],
    [6,4,12,7],
    [4],
    [6,4,12,7],
    [8],
    [6,4,12],
    [9],
    [6],
    [4],
    [4],
    [4],
    [10],
    [9],
    [4],
    [4,11,12],
    [6],
    [13],
    [12],
    [0,4,12,14],
    [15],
    [4,0,12],
    [4,12],
    [16,17,18,19,20],
    [20],
    [4,12],
    [22],
    [4,14],
    [23],
    [4,12,14],
    [24],
    [21]
]

# One integer code (0-9) per entry, 33 entries in total -- presumably one per
# project in `cancer_types`, in the same order (TODO confirm).
# NOTE(review): not referenced anywhere in the visible part of this notebook.
CELL_TYPES = [
    0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 3, 0, 0, 4, 2, 0,
    0, 0, 5, 0, 0, 6, 0, 0, 7, 8, 0, 9, 0, 0, 0, 0,
    8]

# The 33 project identifiers. Each one is used further down as the stem of a
# pickle file name: preprocessing_codes/<identifier>.pkl
cancer_types = [
    "TCGA-ACC", "TCGA-BLCA", "TCGA-BRCA", "TCGA-CESC",
    "TCGA-CHOL", "TCGA-COAD", "TCGA-DLBC", "TCGA-ESCA",
    "TCGA-GBM", "TCGA-HNSC", "TCGA-KICH", "TCGA-KIRC",
    "TCGA-KIRP", "TCGA-LAML","TCGA-LGG","TCGA-LIHC",
    "TCGA-LUAD","TCGA-LUSC","TCGA-MESO","TCGA-OV",
    "TCGA-PAAD","TCGA-PCPG","TCGA-PRAD","TCGA-READ",
    "TCGA-SARC","TCGA-SKCM","TCGA-STAD","TCGA-TGCT",
    "TCGA-THCA","TCGA-THYM","TCGA-UCEC","TCGA-UCS","TCGA-UVM"]


# Held-out project: the training cell skips this entry when pooling the
# training data and loads it separately as the test set.
# cancer = str(sys.argv[1])
cancer = 'TCGA-ACC'

In [None]:
# Use the first CUDA device when one is available and fall back to the CPU
# otherwise, so the notebook does not crash on machines without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# def setup(rank, world_size):
#     os.environ['MASTER_ADDR'] = 'localhost'
#     os.environ['MASTER_PORT'] = '12355'

#     # initialize the process group
#     dist.init_process_group("gloo", rank=rank, world_size=world_size)

# def cleanup():
#     dist.destroy_process_group()

# torch.cuda.set_device(device)

# SECURITY: pd.read_pickle unpickles arbitrary Python objects and can execute
# code while doing so. Only load files that come from a trusted source.
proteins_df = pd.read_pickle('data/proteins.pkl')
interactions_df = pd.read_pickle('data/interactions.pkl')

# Map every entry of the `proteins` column to its value in the `ids` column.
proteins = {row.proteins: row.ids for row in proteins_df.itertuples()}

# Build the (2, num_edges) int64 edge list. Stacking into a single ndarray
# first avoids the slow "tensor from a list of ndarrays" conversion path.
edge_index = np.stack(
    [interactions_df['protein1'].values, interactions_df['protein2'].values]
).astype(np.int64)
edge_index = torch.from_numpy(edge_index).to(device)

# proteins, edge_index

In [None]:
class MyNet(nn.Module):
    """Graph network producing one output value per input sample.

    Pipeline: ``GraphConv(6, 6)`` -> ReLU -> ``SAGPooling(ratio=0.1)`` ->
    flatten -> three bias-free linear layers (-> 1024 -> 512 -> 1), each
    followed by ``BatchNorm1d`` and ``Dropout(0.1)``.

    Args:
        num_nodes: number of graph nodes per sample; every node carries 6
            features.
        edge_index: edge list shared by all samples; it is attached unchanged
            to every per-sample graph built in ``forward``.
    """

    def __init__(self, num_nodes, edge_index):
        super(MyNet, self).__init__()
        self.num_nodes = num_nodes
        self.edge_index = edge_index
        self.conv1 = GraphConv(6, 6)
        self.pool1 = SAGPooling(6, ratio=0.1, GNN=GraphConv)
        # The flattened pooled output must have exactly 10314 columns.
        # NOTE(review): presumably ceil(0.1 * 17185) * 6 for the 17185-node
        # graph used by the training cell -- TODO confirm before changing
        # num_nodes or the pooling ratio.
        self.fc1 = nn.Linear(10314, 1024, bias=False)
        self.bn1 = nn.BatchNorm1d(1024)
        self.dropout1 = nn.Dropout(0.1)
        self.fc2 = nn.Linear(1024, 512, bias=False)
        self.bn2 = nn.BatchNorm1d(512)
        self.dropout2 = nn.Dropout(0.1)
        self.fc3 = nn.Linear(512, 1, bias=False)
        self.bn3 = nn.BatchNorm1d(1)
        self.dropout3 = nn.Dropout(0.1)
        # `sigmoid` and `batches` are not used by forward(); they are kept so
        # the attribute set of existing instances does not change.
        self.sigmoid = nn.Sigmoid()
        self.batches = {}

    def forward(self, data):
        """Compute the network output for a batch of flattened samples.

        Args:
            data: 2-D tensor with one row per sample. Only the first
                ``num_nodes * 6`` columns of each row are used.

        Returns:
            Tensor of shape ``(batch_size, 1)``.
        """
        batch_size = data.shape[0]
        x = data[:, :self.num_nodes * 6]
        x = x.reshape(batch_size, self.num_nodes, 6)
        # One graph per sample, all sharing the same edges; batching them
        # yields the offset edge index and the node -> sample assignment.
        graphs = [
            Data(x=x[i], edge_index=self.edge_index) for i in range(batch_size)
        ]
        # Follow the input's device instead of a module-level global so the
        # network works wherever the caller placed the model and its inputs.
        batch = Batch.from_data_list(graphs).to(data.device)
        x = x.reshape(-1, 6)
        x = F.relu(self.conv1(x=x, edge_index=batch.edge_index))
        # Only the pooled node features are needed from the pooling output.
        x, _, _, _, _, _ = self.pool1(
            x, batch.edge_index, None, batch.batch)
        x = x.view(batch_size, -1)
        x = self.dropout1(self.bn1(torch.relu(self.fc1(x))))
        x = self.dropout2(self.bn2(torch.relu(self.fc2(x))))
        x = self.dropout3(self.bn3(self.fc3(x)))
        return x

net = MyNet(len(proteins), edge_index).to(device)
net

In [None]:
def normalize(data, minx=None, maxx=None):
    """Min-max scale ``data`` to ``(data - minx) / (maxx - minx)``.

    Args:
        data: array-like of numbers.
        minx: lower bound of the scale; defaults to ``np.min(data)``.
        maxx: upper bound of the scale; defaults to ``np.max(data)``.

    Returns:
        The scaled values. ``data`` itself is returned unchanged when the two
        bounds are equal (the scale would divide by zero) or when ``data`` is
        empty and a bound would have to be derived from it.
    """
    # Each bound falls back to the data independently, so a caller may supply
    # only one of them.
    if minx is None or maxx is None:
        if np.size(data) == 0:
            # np.min / np.max raise on empty input; there is nothing to scale.
            return data
        if minx is None:
            minx = np.min(data)
        if maxx is None:
            maxx = np.max(data)
    if minx == maxx:
        return data
    return (data - minx) / (maxx - minx)
        
def normalize_by_row(data):
    """Apply ``normalize`` to every row of ``data`` in place and return ``data``."""
    row_count = data.shape[0]
    for row_idx in range(row_count):
        current_row = data[row_idx, :]
        # Write the result back into the caller's array.
        data[row_idx, :] = normalize(current_row)
    return data

def normalize_by_column(data):
    """Apply ``normalize`` to every column of ``data`` in place and return ``data``."""
    column_count = data.shape[1]
    for col_idx in range(column_count):
        current_column = data[:, col_idx]
        # Write the result back into the caller's array.
        data[:, col_idx] = normalize(current_column)
    return data

In [None]:
def _normalize_features(flat_data, num_nodes, num_features):
    """Normalize a (samples, num_nodes * num_features) array feature by feature.

    The array is viewed as (samples, num_nodes, num_features). Feature indices
    0-5 are scaled per row (per sample), 6-11 per column (per node) and any
    further index over the whole slice. With 6 features only the per-row
    branch runs. Scaling writes into the reshaped array, so ``flat_data`` is
    modified as well whenever the reshape is a view.

    Returns the normalized data reshaped back to two dimensions.
    """
    cube = flat_data.reshape(-1, num_nodes, num_features)
    for i in range(num_features):
        if i <= 5:
            cube[:, :, i] = normalize_by_row(cube[:, :, i])
        elif i <= 11:
            cube[:, :, i] = normalize_by_column(cube[:, :, i])
        else:
            cube[:, :, i] = normalize(cube[:, :, i])
    return cube.reshape(-1, num_nodes * num_features)


lr_ = 0.01
pat_ = 10

# Pool every project except the held-out `cancer`. The frames are collected
# first and concatenated once rather than growing a DataFrame inside the loop.
# SECURITY: pd.read_pickle can execute code; only load trusted files.
cancer_frames = [
    pd.read_pickle(f'preprocessing_codes/{cancer__type}.pkl')
    for cancer__type in cancer_types
    if cancer__type != cancer
]
cancer_combined = pd.concat(cancer_frames, ignore_index=True)

# Hold out 20% of the pooled rows for validation. The split is seeded (same
# seed as torch below) so that a Restart & Run All reproduces it.
val_dataset = cancer_combined.sample(frac=.2, random_state=0)

val_labels_days = val_dataset['duration'].values
val_labels_surv = val_dataset['survival'].values
val_labels = (val_labels_days, val_labels_surv)

cancer_combined = cancer_combined.drop(val_dataset.index)

# Flatten the per-sample feature arrays into one row per sample.
dataset = np.stack(cancer_combined['features'].values).reshape(len(cancer_combined), -1)
val_dataset = np.stack(val_dataset['features'].values).reshape(len(val_dataset), -1)

in_features = dataset.shape[1]
labels_days = cancer_combined['duration'].values
labels_surv = cancer_combined['survival'].values

# The held-out project is the test set.
df_test = pd.read_pickle(f'preprocessing_codes/{cancer}.pkl')

test_labels_days = df_test['duration'].values
test_labels_surv = df_test['survival'].values

df_test = np.stack(df_test['features'].values).reshape(len(df_test), -1)

num_features = 6
# NOTE(review): hard-coded; the model itself is built with len(proteins).
# Confirm the two agree before changing either.
num_nodes = 17185

# Drop the network created in the model cell and build a fresh, seeded one.
del net
torch.manual_seed(0)
net = MyNet(len(proteins), edge_index).to(device)

model = CoxPH(net, tt.optim.Adam(lr_))

train_data = _normalize_features(dataset, num_nodes, num_features)
train_labels_days = labels_days
train_labels_surv = labels_surv
train_labels = (train_labels_days, train_labels_surv)

val_data = _normalize_features(val_dataset, num_nodes, num_features)
test_data = _normalize_features(df_test, num_nodes, num_features)

callbacks = [tt.callbacks.EarlyStopping(patience=pat_)]
batch_size = 32
epochs = 100
val = (val_data, val_labels)
log = model.fit(
    train_data, train_labels, batch_size, epochs, callbacks, verbose=True,
    val_data=val,
    val_batch_size=batch_size, shuffle=True)
train = train_data, train_labels
# Compute the evaluation measurements: baseline hazards come from the training
# data, concordance is measured on the held-out project.
_ = model.compute_baseline_hazards(*train)
surv = model.predict_surv_df(test_data)
ev = EvalSurv(surv, test_labels_days, test_labels_surv)
result = ev.concordance_td()

print(cancer)
print('lr = '+str(lr_))
print('patience = '+str(pat_))
print('normalization = row')
print('Concordance', result)