In [9]:
import os
import sys
import argparse
import torch
import torch.multiprocessing
import torch.nn as nn
from torch.nn.modules.module import Module
from scipy.spatial.distance import cdist
from sklearn.metrics import pairwise_distances, adjusted_rand_score, normalized_mutual_info_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

from tqdm.auto import tqdm

import numpy as np
import pandas as pd
import snf
from sklearn.metrics.pairwise import pairwise_kernels
from sklearn.cluster import spectral_clustering, KMeans
from sklearn.metrics import v_measure_score
import matplotlib.pyplot as plt

# Use the GPU when torch can see one; otherwise run everything on the CPU.
device = 'cpu' if not torch.cuda.is_available() else 'cuda'
from MIND import MIND

# TCGA datasets
We train and test the proposed model on the 17 multi-omics TCGA datasets (PRAD is also trained below but is not included in the downstream evaluations). Training and testing on the CCMA dataset follow the same procedure.
## Training

In [2]:
# Hyperparameters shared by every MIND training run in this notebook:
# embedding dimension, learning rate, and number of epochs passed to MIND.my_train.
emb_dim, lr, epoch = 64, 1e-4, 5000

In [3]:
# Fix RNG state so the MIND training runs below are reproducible.
np.random.seed(31415)
torch.manual_seed(31415)

# Train one MIND model per cancer type on all available modalities and save the learned
# patient embeddings to ./TCGA_results/<cancer_type>/embeddings.csv.
# PRAD is trained here but does not appear in the evaluation cells' cancer-type lists.
cancer_types = ['GBM', 'HNSC', 'LUSC', 'LIHC', 'LUAD', 'KIRC', 'BRCA', 'LGG', 'OV', 'PRAD', 'SKCM', 'THCA', 'BLCA', 'STAD', 'UCEC', 'COADREAD', 'COAD', 'CESC']

for cancer_type in cancer_types:
    data_dir = './TCGA_preprocessed/{}'.format(cancer_type)
    clinic_data = pd.read_csv('{}/clinic_data.csv'.format(data_dir), header=0, index_col=0)
    RNA_data = pd.read_csv('{}/RNA_data.csv'.format(data_dir), header=0, index_col=0)
    methyl_data = pd.read_csv('{}/meth_data.csv'.format(data_dir), header=0, index_col=0)
    rppa_data = pd.read_csv('{}/rppa_data_imp.csv'.format(data_dir), header=0, index_col=0)
    cna_data = pd.read_csv('{}/cna_data.csv'.format(data_dir), header=0, index_col=0)

    # miRNA is optional: only some cancer types ship a miRNA file. The dict is built so the
    # modality order stays RNA, methyl, CNA, [miRNA], RPPA. Whether MIND depends on that
    # order is not visible here, so do not reorder. Building it per iteration also avoids
    # reusing a miRNA frame left over from a previous cancer type.
    miRNA_path = '{}/miRNA_data_imp.csv'.format(data_dir)
    data_dict = {'RNA': RNA_data, 'methyl': methyl_data, 'CNA': cna_data}
    if os.path.isfile(miRNA_path):
        data_dict['miRNA'] = pd.read_csv(miRNA_path, header=0, index_col=0)
    data_dict['RPPA'] = rppa_data

    N = clinic_data.shape[0]
    print('Cancer type = {}'.format(cancer_type))
    print('total number of patients = {}'.format(N))

    # Missingness is read off the first column only. A patient absent from a modality
    # presumably has an all-NaN row (TODO confirm against the preprocessing step).
    for mod_name, mod_data in data_dict.items():
        missing = mod_data.iloc[:, 0].isna()
        print('{} data missing {}/{}, missing proportion = {}'.format(mod_name, missing.sum(), mod_data.shape[0], np.round(missing.mean(), 3)))

    mind_model = MIND(data_dict=data_dict, device=device, emb_dim=emb_dim).to(device)
    mind_model.my_train(epoch, lr=lr)
    with torch.no_grad():
        z_emb = mind_model.get_embedding()[0].cpu().numpy()

    # exist_ok makes the directory creation idempotent across re-runs.
    out_dir = './TCGA_results/{}'.format(cancer_type)
    os.makedirs(out_dir, exist_ok=True)
    # Embedding rows are assumed to follow RNA_data's row order. TODO confirm in MIND.
    pd.DataFrame(z_emb, index=RNA_data.index).to_csv('{}/embeddings.csv'.format(out_dir))


Cancer type = GBM
total number of patients = 590
RNA data missing 431/590, missing proportion = 0.731
methyl data missing 307/590, missing proportion = 0.52
CNA data missing 19/590, missing proportion = 0.032
RPPA data missing 353/590, missing proportion = 0.598
Epoch=0
Epoch=1000
Epoch=2000
Epoch=3000
Epoch=4000
Cancer type = HNSC
total number of patients = 528
RNA data missing 7/528, missing proportion = 0.013
methyl data missing 0/528, missing proportion = 0.0
CNA data missing 11/528, missing proportion = 0.021
miRNA data missing 4/528, missing proportion = 0.008
RPPA data missing 316/528, missing proportion = 0.598
Epoch=0
Epoch=1000
Epoch=2000
Epoch=3000
Epoch=4000
Cancer type = LUSC
total number of patients = 504
RNA data missing 3/504, missing proportion = 0.006
methyl data missing 132/504, missing proportion = 0.262
CNA data missing 17/504, missing proportion = 0.034
miRNA data missing 26/504, missing proportion = 0.052
RPPA data missing 176/504, missing proportion = 0.349
Epoc

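## Downstream task 1: Cancer stage classification
We fit XGBoost classifiers to predict the cancer stage of patients from the output embeddings. Accuracy is averaged over 10-fold cross-validation. Cancer types without stage annotations, or with too few staged patients, are reported as NaN.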

In [25]:
from sklearn.model_selection import KFold
import xgboost as xgb
from sklearn.preprocessing import LabelEncoder
# Cancer *stage* classification: predict each patient's stage from the learned embeddings
# with an XGBoost classifier. The score is the mean accuracy over N_STAGE_FOLDS-fold CV.
# A cancer type gets NaN when it has no 'Stage' column or too few staged patients.
N_STAGE_FOLDS = 10
cancer_type_subtypes = ['LUAD', 'KIRC', 'BRCA', 'LGG', 'OV', 'SKCM', 'THCA', 'BLCA', 'STAD', 'UCEC', 'COADREAD', 'COAD', 'GBM', 'HNSC', 'LUSC', 'LIHC', 'CESC']
res_ACC = pd.DataFrame(np.full((len(cancer_type_subtypes), 1), np.nan), index=cancer_type_subtypes, columns=['Accuracy'])

for c_type in cancer_type_subtypes:
    print(c_type)
    clinic_data = pd.read_csv('./TCGA_preprocessed/{}/clinic_data.csv'.format(c_type), header=0, index_col=0)
    if 'Stage' not in clinic_data.columns:
        continue  # no stage annotation -> stays NaN

    staged_patients = clinic_data.index.to_numpy()[~clinic_data['Stage'].isna()]
    stage_vec = clinic_data['Stage'].loc[staged_patients]
    # Encode over *all* staged patients so every fold shares one label space 0..n_classes-1.
    encoder = LabelEncoder()
    labels = encoder.fit_transform(stage_vec)

    emb = pd.read_csv('./TCGA_results/{}/embeddings.csv'.format(c_type), index_col=0).loc[staged_patients]
    if emb.shape[0] <= N_STAGE_FOLDS:
        continue  # too few staged patients for N_STAGE_FOLDS-fold CV -> stays NaN

    X = emb.to_numpy()
    n_classes = len(np.unique(labels))
    kf = KFold(n_splits=N_STAGE_FOLDS, shuffle=True, random_state=314159)
    fold_acc = []
    for train_idx, test_idx in kf.split(X):
        # NOTE(review): XGBClassifier may reject a training fold that lacks one of the encoded
        # stage classes; consider StratifiedKFold if that happens for rare stages.
        model = xgb.XGBClassifier(
            objective='multi:softmax',
            num_class=n_classes,
            eval_metric='mlogloss'
        )
        model.fit(X[train_idx], labels[train_idx])
        y_pred = model.predict(X[test_idx])
        fold_acc.append(np.mean(y_pred == labels[test_idx]))
    res_ACC.loc[c_type] = np.mean(fold_acc)
print(res_ACC)


LUAD
KIRC
BRCA
LGG
OV
SKCM
THCA
BLCA
STAD
UCEC
COADREAD
COAD
GBM
HNSC
LUSC
LIHC
CESC
          Accuracy
LUAD      0.523115
KIRC      0.517855
BRCA      0.508270
LGG            NaN
OV             NaN
SKCM      0.457491
THCA      0.542902
BLCA      0.460976
STAD      0.407898
UCEC           NaN
COADREAD  0.346913
COAD      0.344293
GBM            NaN
HNSC      0.519614
LUSC      0.472000
LIHC      0.481349
CESC      0.473333


## Downstream task 2: Survival prediction
Here we fit Coxnet survival models to predict patient survival. In addition to the output embeddings, we add the patients' sex and age (`gender`, `years_to_birth`) as input features. Performance is measured by the C-index, averaged over 5-fold cross-validation.

In [32]:
from sklearn.model_selection import train_test_split
from sksurv.linear_model import CoxnetSurvivalAnalysis, CoxPHSurvivalAnalysis
from sksurv.util import Surv
from sksurv.metrics import concordance_index_censored, concordance_index_ipcw, integrated_brier_score
from sklearn.model_selection import StratifiedKFold

def run_test_survival(K=5, random_state=314159):
    """Cross-validated Coxnet survival evaluation on the learned embeddings.

    For each cancer type, only patients with known ``years_to_birth`` and ``gender`` are
    kept. They also need a non-negative follow-up time, ``merged_days``: ``days_to_death``,
    falling back to ``days_to_last_followup``. The features are ``years_to_birth``,
    ``gender`` (one-hot, first level dropped) and the embedding columns. The target is
    (``vital_status``, ``merged_days``).

    Parameters
    ----------
    K : int
        Number of KFold splits.
    random_state : int or None
        Seed for the KFold shuffle so the reported C-indices are reproducible.
        Pass None to get an unseeded shuffle.

    Returns
    -------
    pandas.DataFrame
        One row per cancer type with column 'C-index'. Each value is the mean of
        ``concordance_index_censored`` over the K test folds.
    """
    # Imported here so this function does not depend on an earlier cell having imported KFold.
    from sklearn.model_selection import KFold

    cancer_types = ['GBM', 'HNSC', 'LUSC', 'LIHC', 'CESC', 'LUAD', 'KIRC', 'BRCA', 'LGG', 'OV', 'SKCM', 'THCA', 'BLCA', 'STAD', 'UCEC', 'COADREAD', 'COAD']
    res_C_idx = pd.DataFrame(np.zeros((len(cancer_types), 1)), index=cancer_types, columns=['C-index'])
    for c_type in cancer_types:
        print(c_type)

        clinic_data = pd.read_csv('./TCGA_preprocessed/{}/clinic_data.csv'.format(c_type), header=0, index_col=0)
        clinic_data['merged_days'] = clinic_data["days_to_death"].fillna(clinic_data["days_to_last_followup"])
        # NaN merged_days compares False, so patients with no follow-up time are dropped too.
        indicator = (~clinic_data[['years_to_birth', 'gender']].isna().any(axis=1)) & (clinic_data['merged_days'] >= 0)
        samples = clinic_data.index[indicator]
        y = Surv.from_dataframe("vital_status", "merged_days", clinic_data.loc[samples])

        emb = pd.read_csv('./TCGA_results/{}/embeddings.csv'.format(c_type), index_col=0).loc[samples]
        features = clinic_data.loc[samples, ['years_to_birth', 'gender']].join(emb)
        # NOTE(review): features are not standardised, and Coxnet's elastic-net penalty is
        # scale-dependent. Consider StandardScaler or normalize=True if that is unintended.
        features = pd.get_dummies(features, drop_first=True).to_numpy() * 1.0

        kf = KFold(n_splits=K, shuffle=True, random_state=random_state)
        fold_c = []
        for train_idx, test_idx in kf.split(features):
            y_train, y_test = y[train_idx], y[test_idx]
            model = CoxnetSurvivalAnalysis(l1_ratio=0.1, fit_baseline_model=True)
            model.fit(features[train_idx], y_train)
            risk_scores = model.predict(features[test_idx])
            fold_c.append(concordance_index_censored(y_test["vital_status"], y_test["merged_days"], risk_scores)[0])

        res_C_idx.loc[c_type] = np.mean(fold_c)
    return res_C_idx


res = run_test_survival(K=5)
print(res)

GBM
HNSC
LUSC
LIHC
CESC
LUAD
KIRC
BRCA
LGG
OV
SKCM
THCA
BLCA
STAD
UCEC
COADREAD
COAD
           C-index
GBM       0.640125
HNSC      0.621491
LUSC      0.536826
LIHC      0.615794
CESC      0.676878
LUAD      0.595405
KIRC      0.715703
BRCA      0.673233
LGG       0.841450
OV        0.609655
SKCM      0.650040
THCA      0.892058
BLCA      0.628860
STAD      0.563993
UCEC      0.629874
COADREAD  0.573442
COAD      0.518750


## Downstream task 3: Clustering
Here we partition the patients into 3 to 10 clusters by running k-means on the output embeddings. For each cancer type and each number of clusters, we run log-rank tests between every pair of clusters, using the patients' survival times and vital status. We report $-\log_2(\bar{p})$, where $\bar{p}$ is the mean of the pairwise $p$-values. A smaller $\bar{p}$ (equivalently, a larger $-\log_2(\bar{p})$) indicates stronger evidence of survival differences between clusters.

In [34]:
from lifelines.statistics import multivariate_logrank_test, pairwise_logrank_test
from lifelines import KaplanMeierFitter

# For every cancer type and every cluster count k in my_range, cluster the embeddings with
# k-means. Each clustering is scored by -log2 of the MEAN p-value over all pairwise
# log-rank tests between its clusters. Larger values mean clearer survival separation.
cancer_types = ['GBM', 'HNSC', 'LUSC', 'LIHC', 'CESC', 'LUAD', 'KIRC', 'BRCA', 'LGG', 'OV', 'SKCM', 'THCA', 'BLCA', 'STAD', 'UCEC', 'COADREAD', 'COAD']
my_range = range(3, 11)
res_logp_pair = pd.DataFrame(np.zeros((len(cancer_types), len(my_range))), index=cancer_types, columns=my_range)

for c_type in cancer_types:
    print(c_type)
    clinic_data = pd.read_csv('./TCGA_preprocessed/{}/clinic_data.csv'.format(c_type), header=0, index_col=0)
    clinic_data['merged_days'] = clinic_data["days_to_death"].fillna(clinic_data["days_to_last_followup"])
    # Same patient filter as the survival task, presumably so both use the same cohort,
    # even though age and gender are not used for clustering.
    indicator = (~clinic_data[['years_to_birth', 'gender']].isna().any(axis=1)) & (clinic_data['merged_days'] >= 0)
    samples = clinic_data.index[indicator]
    durations = clinic_data.loc[samples, 'merged_days']
    events = clinic_data.loc[samples, 'vital_status']

    emb = pd.read_csv('./TCGA_results/{}/embeddings.csv'.format(c_type), index_col=0).loc[samples]
    for n_clusters in my_range:
        kmeans = KMeans(n_clusters=n_clusters, n_init=20, random_state=314159)
        labels_pred = kmeans.fit_predict(emb)
        pairwise = pairwise_logrank_test(durations, labels_pred, events)
        res_logp_pair.loc[c_type, n_clusters] = -1. * np.log2(pairwise.p_value.mean())
print(res_logp_pair)

GBM
HNSC
LUSC
LIHC
CESC
LUAD
KIRC
BRCA
LGG
OV
SKCM
THCA
BLCA
STAD
UCEC
COADREAD
COAD
                3         4         5         6         7         8   \
GBM       5.468568  2.415752  2.453869  1.652026  1.954711  1.910928   
HNSC      3.159770  2.603318  2.614053  2.543306  2.373723  2.092139   
LUSC      2.954474  1.939171  2.179644  1.680865  1.494637  1.313886   
LIHC      1.232846  1.041398  1.546891  2.419295  1.588979  2.091966   
CESC      2.525748  1.757811  1.523754  1.625143  1.656757  1.417515   
LUAD      2.637015  1.856228  1.625765  1.685225  1.471031  1.728038   
KIRC      2.312045  1.978773  2.154344  2.702073  2.368248  1.861773   
BRCA      1.012320  4.246113  1.971604  2.177411  1.798338  2.024668   
LGG       5.040676  2.538890  2.628751  2.620503  2.342840  1.912463   
OV        0.725235  0.937297  1.229061  1.460342  1.550677  0.716696   
SKCM      5.950288  2.306732  2.324576  2.422261  2.151044  2.310080   
THCA      2.011360  1.252562  0.950964  2.173230  1

## Downstream task 4: Reconstruction
We investigate how well the proposed model reconstructs (predicts) the missing part of the training data. We first re-train the model on masked versions of the multi-omics datasets. For each modality of each cancer type, we randomly mask $10\%$ of its data. The masking is constrained so that, for every cancer type, each patient is still present in at least one modality of the masked dataset. Once the models have been trained on the masked datasets, we predict the masked data from the learned embeddings and compare the predictions with the observed values.

In [35]:
# Fix RNG state so the masked-data MIND runs are reproducible.
np.random.seed(31415)
torch.manual_seed(31415)

# Re-train MIND on the masked ("_train") data for each cancer type. Then save the model's
# predictions for the held-out ("_test") rows of each modality to
# ./TCGA_results/<cancer_type>/<modality>_data_test_pred.csv.
cancer_types = ['GBM', 'HNSC', 'LUSC', 'LIHC', 'LUAD', 'KIRC', 'BRCA', 'LGG', 'OV', 'PRAD', 'SKCM', 'THCA', 'BLCA', 'STAD', 'UCEC', 'COADREAD', 'COAD', 'CESC']
for cancer_type in cancer_types:
    data_dir = './TCGA_preprocessed/{}'.format(cancer_type)
    clinic_data = pd.read_csv('{}/clinic_data.csv'.format(data_dir), header=0, index_col=0)
    RNA_data = pd.read_csv('{}/RNA_data_train.csv'.format(data_dir), header=0, index_col=0)
    methyl_data = pd.read_csv('{}/methyl_data_train.csv'.format(data_dir), header=0, index_col=0)
    rppa_data = pd.read_csv('{}/RPPA_data_train.csv'.format(data_dir), header=0, index_col=0)
    cna_data = pd.read_csv('{}/CNA_data_train.csv'.format(data_dir), header=0, index_col=0)

    # miRNA is optional. The modality order RNA, methyl, CNA, [miRNA], RPPA must be kept:
    # pred[i] below is matched to data_dict's keys by position.
    miRNA_path = '{}/miRNA_data_train.csv'.format(data_dir)
    data_dict = {'RNA': RNA_data, 'methyl': methyl_data, 'CNA': cna_data}
    if os.path.isfile(miRNA_path):
        data_dict['miRNA'] = pd.read_csv(miRNA_path, header=0, index_col=0)
    data_dict['RPPA'] = rppa_data

    N = clinic_data.shape[0]
    print('Cancer type = {}'.format(cancer_type))
    print('total number of patients = {}'.format(N))

    # Missingness is read off the first column only.
    for mod_name, mod_data in data_dict.items():
        missing = mod_data.iloc[:, 0].isna()
        print('{} data missing {}/{}, missing proportion = {}'.format(mod_name, missing.sum(), mod_data.shape[0], np.round(missing.mean(), 3)))

    mind_model = MIND(data_dict=data_dict, device=device, emb_dim=emb_dim).to(device)
    mind_model.my_train(epoch, lr=lr)

    # Reconstructed version of the full dataset, one entry per modality in data_dict order.
    pred = mind_model.predict()

    out_dir = './TCGA_results/{}'.format(cancer_type)
    os.makedirs(out_dir, exist_ok=True)
    # Keep only the rows that were masked out, i.e. the ones listed in the _test file.
    for i, modality in enumerate(data_dict.keys()):
        test_obs = pd.read_csv('{}/{}_data_test.csv'.format(data_dir, modality), header=0, index_col=0)
        # predict()'s output may live on the GPU and/or carry autograd history, and
        # Tensor.numpy() accepts neither. detach().cpu() is a no-op otherwise.
        pred_np = pred[i].detach().cpu().numpy()
        # Prediction rows are assumed to follow RNA_data's row order. TODO confirm in MIND.
        test_predicted = pd.DataFrame(pred_np, index=RNA_data.index).loc[test_obs.index]
        test_predicted.to_csv('{}/{}_data_test_pred.csv'.format(out_dir, modality))


Cancer type = GBM
total number of patients = 590
RNA data missing 445/590, missing proportion = 0.754
methyl data missing 331/590, missing proportion = 0.561
CNA data missing 59/590, missing proportion = 0.1
RPPA data missing 374/590, missing proportion = 0.634
Epoch=0
Epoch=1000
Epoch=2000
Epoch=3000
Epoch=4000
Cancer type = HNSC
total number of patients = 528
RNA data missing 56/528, missing proportion = 0.106
methyl data missing 51/528, missing proportion = 0.097
CNA data missing 61/528, missing proportion = 0.116
miRNA data missing 53/528, missing proportion = 0.1
RPPA data missing 337/528, missing proportion = 0.638
Epoch=0
Epoch=1000
Epoch=2000
Epoch=3000
Epoch=4000
Cancer type = LUSC
total number of patients = 504
RNA data missing 47/504, missing proportion = 0.093
methyl data missing 169/504, missing proportion = 0.335
CNA data missing 64/504, missing proportion = 0.127
miRNA data missing 71/504, missing proportion = 0.141
RPPA data missing 205/504, missing proportion = 0.407
E

In [7]:
# Pearson correlation between held-out observations and MIND's predictions, pooled over all
# masked entries of a modality. A cell stays NaN when that cancer type has no test file
# for the modality.
cancer_types = ['GBM', 'HNSC', 'LUSC', 'LIHC', 'CESC', 'LUAD', 'KIRC', 'BRCA', 'LGG', 'OV', 'SKCM', 'THCA', 'BLCA', 'STAD', 'UCEC', 'COADREAD', 'COAD']
mods = ['RNA', 'methyl', 'CNA', 'miRNA', 'RPPA']
res_corr = pd.DataFrame(np.full((len(cancer_types), len(mods)), np.nan), index=cancer_types, columns=mods)
for cancer_type in cancer_types:
    for mod in mods:
        obs_path = './TCGA_preprocessed/{}/{}_data_test.csv'.format(cancer_type, mod)
        if not os.path.isfile(obs_path):
            continue
        pred_path = './TCGA_results/{}/{}_data_test_pred.csv'.format(cancer_type, mod)
        # Distinct names so the reconstruction cell's `pred` (list of tensors) is not clobbered.
        pred_vals = pd.read_csv(pred_path, header=0, index_col=0).to_numpy()
        obs_vals = pd.read_csv(obs_path, header=0, index_col=0).to_numpy()
        if pred_vals.shape != obs_vals.shape:
            raise ValueError('{} {}: prediction shape {} != observation shape {}'.format(
                cancer_type, mod, pred_vals.shape, obs_vals.shape))
        pred_vals, obs_vals = pred_vals.ravel(), obs_vals.ravel()
        # Ignore entries missing in either array; a single NaN would otherwise make the result NaN.
        finite = np.isfinite(pred_vals) & np.isfinite(obs_vals)
        res_corr.loc[cancer_type, mod] = np.corrcoef(pred_vals[finite], obs_vals[finite])[0, 1]

print(res_corr)
            

               RNA    methyl       CNA     miRNA      RPPA
GBM       0.283667  0.159389  0.304017       NaN  0.066223
HNSC      0.340869  0.362852  0.221831  0.286871  0.077927
LUSC      0.391946  0.347897  0.148169  0.249944  0.187262
LIHC      0.444042  0.617602  0.350798  0.329479  0.345397
CESC      0.301724  0.316541  0.240796  0.346496  0.144993
LUAD      0.298295  0.356028  0.197201  0.290651  0.232931
KIRC      0.498263  0.422525  0.379365  0.402921  0.261167
BRCA      0.432397  0.420671  0.367710  0.421113  0.250354
LGG       0.555975  0.558664  0.575640  0.411837  0.393617
OV        0.210555  0.234810  0.310151  0.176573  0.093143
SKCM      0.298319  0.389299  0.273702  0.228896  0.223381
THCA      0.542357  0.254725  0.188864  0.322165  0.237826
BLCA      0.410563  0.399036  0.182782  0.445897  0.261960
STAD      0.284350  0.497969  0.279424  0.412589  0.286063
UCEC      0.329905  0.413536  0.330754  0.394110  0.160096
COADREAD  0.262425  0.295340  0.435800  0.287395  0.2582

In [8]:
# Average correlation per cancer type across the modalities it has (pandas skips NaN by default).
print(res_corr.mean(axis=1))

GBM         0.203324
HNSC        0.258070
LUSC        0.265044
LIHC        0.417464
CESC        0.270110
LUAD        0.275021
KIRC        0.392848
BRCA        0.378449
LGG         0.499147
OV          0.205046
SKCM        0.282719
THCA        0.309187
BLCA        0.340047
STAD        0.352079
UCEC        0.325680
COADREAD    0.307844
COAD        0.264355
dtype: float64
