In [1]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import os 
import pandas as pd
import numpy as np

import torch
import pickle

from torch_geometric.data import Dataset, download_url, DataLoader
from torch_geometric.nn import SAGEConv, global_mean_pool, SAGPooling, GATConv, JumpingKnowledge, ASAPooling, GlobalAttention
from torch.optim import Adam
import copy

import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import CosineAnnealingLR, CosineAnnealingWarmRestarts

from sksurv.metrics import concordance_index_censored

import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.explain import Explainer, GNNExplainer, ModelConfig

#### Use/Check GPU availability

In [2]:
!nvidia-smi

Fri Jan 17 12:05:09 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 555.42.06              Driver Version: 555.42.06      CUDA Version: 12.5     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla V100-SXM2-32GB           Off |   00000000:18:00.0 Off |                    0 |
| N/A   37C    P0             41W /  300W |       4MiB /  32768MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  Tesla V100-SXM2-32GB          

In [3]:
# Optional debugging switch, intentionally left disabled:
#os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

# Prefer the first CUDA device when CUDA is usable; otherwise run on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=0)

# Load Clinical Data

## TCGA

In [4]:
# Slide-level TCGA metadata; the first CSV column becomes the index.
# A patient barcode can appear on more than one row (one row per slide, it appears),
# and the 'graphs' column holds the path of the pickled graph for that row.
WSI_TCGA_info = pd.read_csv("/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Interns_2023/user/JiQing/TCGA_model/WSI/wsi_metadata_with_patches_and_embeddings_and_graphs.csv",index_col=0)
WSI_TCGA_info

Unnamed: 0,bcr_patient_barcode,age_at_diagnosis,cigarettes_per_day,primary_diagnosis,tissue_or_organ_of_origin,ajcc_pathologic_stage,race,gender,prior_malignancy,vital_status,ajcc_pathologic_t,ajcc_pathologic_n,ajcc_pathologic_m,survival_time,patches_npy,patches_pkl,embeddings,graphs
0,TCGA-2F-A9KO,63.898630,3.780822,Transitional cell carcinoma,Posterior wall of bladder,Stage IV,white,male,no,1,T3,N1,M0,734.0,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...
1,TCGA-2F-A9KP,66.926027,3.397260,Transitional cell carcinoma,Lateral wall of bladder,Stage IV,white,male,no,1,T3a,N2,MX,364.0,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...
2,TCGA-2F-A9KP,66.926027,3.397260,Transitional cell carcinoma,Lateral wall of bladder,Stage IV,white,male,no,1,T3a,N2,MX,364.0,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...
3,TCGA-2F-A9KQ,69.202740,0.000000,Transitional cell carcinoma,"Bladder, NOS",Stage III,white,male,no,0,T3a,N0,M0,2886.0,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...
4,TCGA-2F-A9KR,59.857534,1.232877,Papillary transitional cell carcinoma,"Bladder, NOS",Stage III,not reported,female,no,1,T3a,N0,M0,3183.0,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
451,TCGA-ZF-AA54,71.583562,0.000000,Transitional cell carcinoma,Lateral wall of bladder,Stage III,white,male,no,1,T3,NX,MX,590.0,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...
452,TCGA-ZF-AA58,61.778082,2.136986,Transitional cell carcinoma,"Bladder, NOS",Stage IV,white,female,no,0,T3a,N2,MX,1649.0,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...
453,TCGA-ZF-AA5H,60.608219,0.438356,Transitional cell carcinoma,"Bladder, NOS",Stage IV,white,female,no,0,T3b,N2,M0,897.0,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...
454,TCGA-ZF-AA5N,62.304110,0.547945,Transitional cell carcinoma,"Bladder, NOS",Stage IV,white,female,no,1,T2,NX,M1,168.0,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...


#### Get demographics and tumor stage for TCGA, then convert the categorical variables into dummy variables

In [5]:
# Patient-level TCGA phenotype table: barcode, age, sex (female=0, male=1) and tumor stage.
# Built as one non-mutating chain so that re-running the cell always gives the same frame.
TCGA_Pheno = (
    pd.read_csv("/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Interns_2023/user/JiQing/TCGA_model/BLCA_TCGA_meta_data.csv")
    .sort_values(by='bcr_patient_barcode', ascending=True)
    .reset_index(drop=True)
    .loc[:, ['bcr_patient_barcode', 'age_at_diagnosis', 'gender', 'ajcc_pathologic_stage']]
    .rename(columns={'age_at_diagnosis': 'Age',
                     'gender': 'Gender',
                     'ajcc_pathologic_stage': 'Tumor_Stage'})
    .replace({'Tumor_Stage': {"0": 'cis_Stage_0',
                              'Stage I':'Stage_I', 'Stage II':'Stage_II',
                              'Stage III':'Stage_III', 'Stage IV':'Stage_IV'},
              'Gender':{'female':0, 'male':1}})
    .drop_duplicates()
)

# The scaler is fitted on TCGA ages here and reused later to transform the DH ages.
stand = StandardScaler()
TCGA_Pheno['Age_stand'] = stand.fit_transform(TCGA_Pheno['Age'].values.reshape(-1,1)).ravel()

# Fix the stage categories (and therefore the dummy columns and their order) explicitly.
# The downstream datasets slice 'Gender':'Stage_cis_Stage_0' by label and the model
# expects exactly these seven features in this order, so the layout must not depend
# on which stage values happen to occur in the data.
stage_order = ['Stage_I', 'Stage_II', 'Stage_III', 'Stage_IV', 'cis_Stage_0']
unexpected_stages = set(TCGA_Pheno['Tumor_Stage'].dropna().unique()) - set(stage_order)
if unexpected_stages:
    raise ValueError(f"Unmapped TCGA tumor stage values: {sorted(map(str, unexpected_stages))}")

TCGA_Pheno_Dummy = pd.get_dummies(
    TCGA_Pheno.assign(Tumor_Stage=pd.Categorical(TCGA_Pheno['Tumor_Stage'], categories=stage_order)),
    columns=['Tumor_Stage'],
    drop_first=False,
    prefix=['Stage'],
    dtype='uint8',  # explicit, so the 0/1 encoding does not depend on the pandas version
)

TCGA_Pheno_Dummy

Unnamed: 0,bcr_patient_barcode,Age,Gender,Age_stand,Stage_Stage_I,Stage_Stage_II,Stage_Stage_III,Stage_Stage_IV,Stage_cis_Stage_0
0,TCGA-2F-A9KO,63.898630,1,-0.408341,0,0,0,1,0
1,TCGA-2F-A9KP,66.926027,1,-0.135424,0,0,0,1,0
2,TCGA-2F-A9KQ,69.202740,1,0.069820,0,0,1,0,0
3,TCGA-2F-A9KR,59.857534,0,-0.772643,0,0,1,0,0
4,TCGA-2F-A9KT,83.616438,1,1.369205,0,1,0,0,0
...,...,...,...,...,...,...,...,...,...
407,TCGA-ZF-AA56,79.279452,0,0.978229,0,0,1,0,0
408,TCGA-ZF-AA58,61.778082,0,-0.599507,0,0,0,1,0
409,TCGA-ZF-AA5H,60.608219,0,-0.704970,0,0,0,1,0
410,TCGA-ZF-AA5N,62.304110,0,-0.552086,0,0,0,1,0


#### Merge WSI and Phenotype

In [6]:
# Attach the patient-level phenotype columns to every slide row.
# Phenotype columns left over from an earlier run of this cell are dropped first;
# otherwise a re-run would produce Age_x/Age_y style duplicates and break the
# label-based 'Gender':'Stage_cis_Stage_0' slice used by the dataset classes.
tcga_pheno_columns = TCGA_Pheno_Dummy.columns.drop('bcr_patient_barcode')
tcga_merged = (
    WSI_TCGA_info
    .drop(columns=tcga_pheno_columns, errors='ignore')
    .merge(TCGA_Pheno_Dummy, on="bcr_patient_barcode", how="left")
)

# A left merge can only add rows, and only when a barcode is duplicated in the
# phenotype table; that would silently duplicate slides, so fail loudly instead.
assert len(tcga_merged) == len(WSI_TCGA_info), \
    "phenotype table has duplicated bcr_patient_barcode values; merge changed the number of slide rows"

WSI_TCGA_info = tcga_merged
WSI_TCGA_info

Unnamed: 0,bcr_patient_barcode,age_at_diagnosis,cigarettes_per_day,primary_diagnosis,tissue_or_organ_of_origin,ajcc_pathologic_stage,race,gender,prior_malignancy,vital_status,...,embeddings,graphs,Age,Gender,Age_stand,Stage_Stage_I,Stage_Stage_II,Stage_Stage_III,Stage_Stage_IV,Stage_cis_Stage_0
0,TCGA-2F-A9KO,63.898630,3.780822,Transitional cell carcinoma,Posterior wall of bladder,Stage IV,white,male,no,1,...,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...,63.898630,1,-0.408341,0,0,0,1,0
1,TCGA-2F-A9KP,66.926027,3.397260,Transitional cell carcinoma,Lateral wall of bladder,Stage IV,white,male,no,1,...,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...,66.926027,1,-0.135424,0,0,0,1,0
2,TCGA-2F-A9KP,66.926027,3.397260,Transitional cell carcinoma,Lateral wall of bladder,Stage IV,white,male,no,1,...,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...,66.926027,1,-0.135424,0,0,0,1,0
3,TCGA-2F-A9KQ,69.202740,0.000000,Transitional cell carcinoma,"Bladder, NOS",Stage III,white,male,no,0,...,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...,69.202740,1,0.069820,0,0,1,0,0
4,TCGA-2F-A9KR,59.857534,1.232877,Papillary transitional cell carcinoma,"Bladder, NOS",Stage III,not reported,female,no,1,...,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...,59.857534,0,-0.772643,0,0,1,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
451,TCGA-ZF-AA54,71.583562,0.000000,Transitional cell carcinoma,Lateral wall of bladder,Stage III,white,male,no,1,...,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...,71.583562,1,0.284450,0,0,1,0,0
452,TCGA-ZF-AA58,61.778082,2.136986,Transitional cell carcinoma,"Bladder, NOS",Stage IV,white,female,no,0,...,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...,61.778082,0,-0.599507,0,0,0,1,0
453,TCGA-ZF-AA5H,60.608219,0.438356,Transitional cell carcinoma,"Bladder, NOS",Stage IV,white,female,no,0,...,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...,60.608219,0,-0.704970,0,0,0,1,0
454,TCGA-ZF-AA5N,62.304110,0.547945,Transitional cell carcinoma,"Bladder, NOS",Stage IV,white,female,no,1,...,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...,/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Inte...,62.304110,0,-0.552086,0,0,0,1,0


#### Split the dataset into training and validation sets

In [7]:
# Row-level 80/20 split of the merged slide table with a fixed random seed.
# `y` is only carried through train_test_split; no `stratify=` argument is given,
# so it does not influence which rows land in which subset.
# NOTE(review): rows are slides and one patient can own several rows (same
# bcr_patient_barcode), so slides of one patient can land on both sides of the
# split. Confirm whether a patient-level (grouped) split is intended.
y = WSI_TCGA_info['gender'].copy()
WSI_TCGA_info_train, WSI_TCGA_info_val, y_train, y_val = train_test_split(WSI_TCGA_info, y, 
                                                                          random_state=2,
                                                                          test_size=0.2)
# training: testing = 4:1 (N = 364 vs N = 92)

## DH

In [8]:
# DH cohort table: keep rows that have a slide file, exclude one slide, and drop
# 'Age' (an 'Age' column comes back in with the phenotype merge further down).
clinical_DH_df = (
    pd.read_csv("/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Interns_2023/user/JiQing/Bladder_DNAm_HiTIMED/Total_88_subjects_WSI_Pheno_HiTIMED.csv")
    .dropna(subset=['file_name'])
    .loc[lambda frame: frame['file_name'] != '205943_.svs']  # "205943_.pkl" < 810 byte
    .drop(columns='Age')
)
clinical_DH_df

Unnamed: 0,FFPE.DNA.ID,Blood.Sample.ID,ChIP_ID_Blood,Batch,Sex,Grade,Grade2,Muscle_Invasive,BCG,ImToBlood,...,CD4mem,Treg,CD8nv,CD8mem,Mono,DC,NK,Bas,Eos,Neu
0,BLD050,A00000FBK,203723190040_R08C01,Batch_3,male,grade 3,Grade 3 + 4,no,Without Immuno,,...,0.000000,5.580140,0.0,3.388986,0.000000,0.000000,2.652840,0.000000,0.000000,0.0
1,BLD043,A00000FC3,203752100070_R04C01,Batch_2,female,grade 3,Grade 3 + 4,no,Without Immuno,,...,0.185204,1.613221,0.0,3.676642,1.383615,5.232480,1.701208,0.000000,0.000000,0.0
2,BLD043,A00000FC3,203752100070_R04C01,Batch_2,female,grade 3,Grade 3 + 4,no,Without Immuno,,...,0.185204,1.613221,0.0,3.676642,1.383615,5.232480,1.701208,0.000000,0.000000,0.0
3,BLD061,A00000EH7,203789410029_R03C01,Batch_3,female,grade 1,Grade 1 + 2,no,Without Immuno,,...,0.997968,1.589810,0.0,3.141940,0.000000,4.534836,2.254315,0.073649,0.205418,0.0
4,BLD061,A00000EH7,203789410029_R03C01,Batch_3,female,grade 1,Grade 1 + 2,no,Without Immuno,,...,0.997968,1.589810,0.0,3.141940,0.000000,4.534836,2.254315,0.073649,0.205418,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
136,BLD048,BDB1P03023,202163530099_R07C01,Batch_1,male,grade 2,Grade 1 + 2,no,Without Immuno,,...,2.428655,1.116786,0.0,2.827982,0.000000,0.000000,1.453588,0.618560,0.974700,0.0
137,BLD075,BDB1P04022,202172220150_R06C01,Batch_1,male,grade 1,Grade 1 + 2,no,Without Immuno,,...,1.010792,2.579993,0.0,0.000000,0.000000,0.000000,0.000000,2.200406,0.000000,0.0
138,BLD024,BDB1P02094,202163530080_R06C01,Batch_1,male,grade 1,Grade 1 + 2,no,Without Immuno,,...,1.592609,2.059939,0.0,1.649354,3.207488,0.000000,0.704491,0.724507,0.487317,0.0
139,BLD094,BDB1P03058,202163550181_R02C01,Batch_1,male,grade 1,Grade 1 + 2,no,Without Immuno,,...,0.620439,0.556929,0.0,1.327215,0.388502,0.000000,0.000000,0.287847,0.132852,0.0


#### Load DH phenotype

In [9]:
# Patient-level DH phenotype table: ID, age, sex (female=0, male=1) and tumor stage.
# Built as one non-mutating chain so that re-running the cell always gives the same frame.
DH_pheno = (
    pd.read_csv("/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Interns_2023/user/JiQing/Bladder_DNAm_HiTIMED/Total_88_subjects_WSI_Pheno_HiTIMED.csv")
    .loc[:, ['FFPE.DNA.ID', 'Age', 'Sex', 'stage2']]
    .drop_duplicates()
    .rename(columns={'Sex': 'Gender',
                     'stage2': 'Tumor_Stage'})
    .replace({'Tumor_Stage': {"cis + stage 0a": 'cis_Stage_0',
                              'stage I':'Stage_I', 'stage II':'Stage_II',
                              'stage III':'Stage_III', 'stage IV':'Stage_IV'},
              'Gender':{'female':0, 'male':1}})
)

# Ages are standardised with the scaler fitted on the TCGA ages (`stand`), not refitted.
DH_pheno['Age_stand'] = stand.transform(DH_pheno['Age'].values.reshape(-1,1)).ravel()

# Fix the stage categories explicitly so this cohort always yields the same five
# dummy columns, in the same order, as the TCGA table, even if a stage is absent
# from the DH data. The model input layout depends on that order.
stage_order = ['Stage_I', 'Stage_II', 'Stage_III', 'Stage_IV', 'cis_Stage_0']
unexpected_stages = set(DH_pheno['Tumor_Stage'].dropna().unique()) - set(stage_order)
if unexpected_stages:
    raise ValueError(f"Unmapped DH tumor stage values: {sorted(map(str, unexpected_stages))}")

DH_pheno = pd.get_dummies(
    DH_pheno.assign(Tumor_Stage=pd.Categorical(DH_pheno['Tumor_Stage'], categories=stage_order)),
    columns=['Tumor_Stage'],
    drop_first=False,
    prefix=['Stage'],
    dtype='uint8',  # explicit, so the 0/1 encoding does not depend on the pandas version
)
DH_pheno

Unnamed: 0,FFPE.DNA.ID,Age,Gender,Age_stand,Stage_Stage_I,Stage_Stage_II,Stage_Stage_III,Stage_Stage_IV,Stage_cis_Stage_0
0,BLD050,77,1,0.772738,0,0,0,0,1
1,BLD043,77,0,0.772738,1,0,0,0,0
3,BLD061,76,0,0.682589,0,0,0,0,1
5,BLD046,69,1,0.051544,0,0,0,1,0
6,BLD080,66,1,-0.218904,0,0,0,0,1
...,...,...,...,...,...,...,...,...,...
136,BLD048,58,1,-0.940099,0,0,0,0,1
137,BLD075,51,1,-1.571144,0,0,0,0,1
138,BLD024,68,1,-0.038606,0,0,0,0,1
139,BLD094,53,1,-1.390845,0,0,0,0,1


#### Merge WSI and Phenotype

In [10]:
# Attach the patient-level phenotype columns to every DH slide row.
# Phenotype columns left over from an earlier run of this cell are dropped first;
# otherwise a re-run would produce Age_x/Age_y style duplicates and break the
# label-based 'Gender':'Stage_cis_Stage_0' slice used by the dataset class.
dh_pheno_columns = DH_pheno.columns.drop('FFPE.DNA.ID')
dh_merged = (
    clinical_DH_df
    .drop(columns=dh_pheno_columns, errors='ignore')
    .merge(DH_pheno, on="FFPE.DNA.ID", how="left")
)

# A left merge can only add rows, and only when an ID is duplicated in the
# phenotype table; that would silently duplicate slides, so fail loudly instead.
assert len(dh_merged) == len(clinical_DH_df), \
    "phenotype table has duplicated FFPE.DNA.ID values; merge changed the number of slide rows"

clinical_DH_df = dh_merged
clinical_DH_df

Unnamed: 0,FFPE.DNA.ID,Blood.Sample.ID,ChIP_ID_Blood,Batch,Sex,Grade,Grade2,Muscle_Invasive,BCG,ImToBlood,...,Eos,Neu,Age,Gender,Age_stand,Stage_Stage_I,Stage_Stage_II,Stage_Stage_III,Stage_Stage_IV,Stage_cis_Stage_0
0,BLD050,A00000FBK,203723190040_R08C01,Batch_3,male,grade 3,Grade 3 + 4,no,Without Immuno,,...,0.000000,0.0,77,1,0.772738,0,0,0,0,1
1,BLD043,A00000FC3,203752100070_R04C01,Batch_2,female,grade 3,Grade 3 + 4,no,Without Immuno,,...,0.000000,0.0,77,0,0.772738,1,0,0,0,0
2,BLD043,A00000FC3,203752100070_R04C01,Batch_2,female,grade 3,Grade 3 + 4,no,Without Immuno,,...,0.000000,0.0,77,0,0.772738,1,0,0,0,0
3,BLD061,A00000EH7,203789410029_R03C01,Batch_3,female,grade 1,Grade 1 + 2,no,Without Immuno,,...,0.205418,0.0,76,0,0.682589,0,0,0,0,1
4,BLD061,A00000EH7,203789410029_R03C01,Batch_3,female,grade 1,Grade 1 + 2,no,Without Immuno,,...,0.205418,0.0,76,0,0.682589,0,0,0,0,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
131,BLD048,BDB1P03023,202163530099_R07C01,Batch_1,male,grade 2,Grade 1 + 2,no,Without Immuno,,...,0.974700,0.0,58,1,-0.940099,0,0,0,0,1
132,BLD075,BDB1P04022,202172220150_R06C01,Batch_1,male,grade 1,Grade 1 + 2,no,Without Immuno,,...,0.000000,0.0,51,1,-1.571144,0,0,0,0,1
133,BLD024,BDB1P02094,202163530080_R06C01,Batch_1,male,grade 1,Grade 1 + 2,no,Without Immuno,,...,0.487317,0.0,68,1,-0.038606,0,0,0,0,1
134,BLD094,BDB1P03058,202163550181_R02C01,Batch_1,male,grade 1,Grade 1 + 2,no,Without Immuno,,...,0.132852,0.0,53,1,-1.390845,0,0,0,0,1


### Generate Dataset Objects for Each Unimodality

- For TCGA

In [11]:
class WSISurvivalDataset_TCGA(Dataset):
    """Pairs each pickled TCGA slide graph with its demographic feature vector.

    Parameters
    ----------
    samples : pandas.DataFrame
        Must contain a 'graphs' column (path of the pickled graph for each row)
        and the contiguous column range 'Gender' .. 'Stage_cis_Stage_0'.
    """

    def __init__(self, samples):
        super().__init__()
        self.samples = samples['graphs'].tolist()
        # Label-based slice: both end columns are included. Copied so that later
        # edits to the caller's dataframe cannot change what the dataset returns.
        self.demographic = samples.loc[:, 'Gender':'Stage_cis_Stage_0'].copy()

    def len(self):
        """Number of rows (slides) in the dataset."""
        return len(self.samples)

    def get(self, idx):
        """Return ``(graph, demographics)`` for the row at position ``idx``.

        ``demographics`` is a 1-D float64 tensor of the demographic columns.
        """
        # NOTE: pickle can execute arbitrary code on load; only use trusted graph files.
        # The context manager closes the file even if unpickling fails
        # (the bare open() used previously leaked one handle per sample).
        with open(self.samples[idx], 'rb') as handle:
            graph_data_object = pickle.load(handle)
        # Convert explicitly to float64 so the result does not depend on how pandas
        # stores the 0/1 columns (int, uint8, or bool in newer pandas versions).
        demographics = torch.tensor(self.demographic.iloc[idx].to_numpy(dtype='float64'))
        return graph_data_object, demographics

In [12]:
# Datasets over the training rows and the validation rows of the split.
train_dataset = WSISurvivalDataset_TCGA(WSI_TCGA_info_train)
val_dataset = WSISurvivalDataset_TCGA(WSI_TCGA_info_val)

# Dataset over every TCGA row, in dataframe row order.
whole_dataset = WSISurvivalDataset_TCGA(WSI_TCGA_info)

In [17]:
train_dataset.get(0)

(Data(x=[25143, 2048], edge_index=[2, 226287], y=[25143, 2], pos=[25143, 2], censor=0, duration=2423.0),
 tensor([0.0000, 0.4304, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000],
        dtype=torch.float64))

In [13]:
whole_dataset.get(0)

(Data(x=[29163, 2048], edge_index=[2, 262467], y=[29163, 2], pos=[29163, 2], censor=1, duration=734.0),
 tensor([ 1.0000, -0.4083,  0.0000,  0.0000,  0.0000,  1.0000,  0.0000],
        dtype=torch.float64))

- For DH

In [13]:
class WSISurvivalDataset_DH(Dataset):
    """Pairs each pickled DH slide graph with its demographic feature vector.

    Parameters
    ----------
    graph_dataset_dict_path : str
        Directory that contains the pickled graph files.
    dataframe : pandas.DataFrame
        Must contain a 'file_name' column ('.svs' names) and the contiguous
        column range 'Gender' .. 'Stage_cis_Stage_0'.
    """

    def __init__(self, graph_dataset_dict_path, dataframe):
        super().__init__()
        self.datasets_path = graph_dataset_dict_path # path to subtype pkl, they contain the Data objects
        self.list_svs = dataframe['file_name'].tolist()
        # Graph file name is derived from the slide name, e.g. '214671_.svs' -> '214671_Graph.pkl'.
        self.samples = [a.replace('.svs','Graph.pkl') for a in self.list_svs]
        # Label-based slice: both end columns are included. Copied so that later
        # edits to the caller's dataframe cannot change what the dataset returns.
        self.demographic = dataframe.loc[:, 'Gender':'Stage_cis_Stage_0'].copy()

    def len(self):
        """Number of rows (slides) in the dataset."""
        return len(self.samples)

    def get(self, idx):
        """Return ``(graph, demographics)`` for the row at position ``idx``.

        ``demographics`` is a 1-D float64 tensor of the demographic columns.
        """
        graph_path = os.path.join(self.datasets_path, self.samples[idx])
        # NOTE: pickle can execute arbitrary code on load; only use trusted graph files.
        # The context manager closes the file even if unpickling fails
        # (the bare open() used previously leaked one handle per sample).
        with open(graph_path, 'rb') as handle:
            graph_data_object = pickle.load(handle)
        # Convert explicitly to float64 so the result does not depend on how pandas
        # stores the 0/1 columns (int, uint8, or bool in newer pandas versions).
        demographic = torch.tensor(self.demographic.iloc[idx].to_numpy(dtype='float64'))
        return graph_data_object, demographic

In [14]:
# Directory holding the pickled DH graphs; file names are derived from 'file_name' in clinical_DH_df.
graph_dataset_path = '/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Interns_2023/user/JiQing/Bladder_Graphs'
test_dataset_wsi = WSISurvivalDataset_DH(graph_dataset_path, clinical_DH_df)

In [20]:
test_dataset_wsi.get(0)

(Data(x=[7603, 2048], edge_index=[2, 68427], y=[7603, 2], pos=[7603, 2], censor=1, duration=3816.34497),
 tensor([1.0000, 0.7727, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000],
        dtype=torch.float64))

In [15]:
# Only the training loader shuffles. The whole-TCGA and DH loaders must keep the
# dataframe row order, because the embeddings they produce are written back to
# the dataframes by position further down.
whole_TCGA_loader = DataLoader(whole_dataset, shuffle=False, batch_size=5)
test_loader = DataLoader(test_dataset_wsi, shuffle=False, batch_size=5)

train_loader = DataLoader(train_dataset, shuffle=True, batch_size=5)
val_loader = DataLoader(val_dataset, batch_size=5)



# Model Structure

#### Define Interaction Function

In [16]:
def self_outer(x):
    """Flattened per-row outer product of ``x`` with itself: (B, I) -> (B, I*I)."""
    return outer(x, x)


def outer(x, y):
    """Flattened per-row outer product of two 2-D tensors.

    For ``x`` of shape (B, I) and ``y`` of shape (B, J) the result has shape
    (B, I*J); entry ``i*J + j`` of row ``b`` is ``x[b, i] * y[b, j]``.
    """
    pairwise = torch.einsum('bi,bj->bij', x, y)
    batch_size = x.shape[0]
    return pairwise.reshape(batch_size, -1)

#### Create model for training

In [17]:
class NodeNorm(nn.Module):
    """Standardise each row of the input along dim 1 (zero mean, unit variance).

    Adapted from https://github.com/miafei/NodeNorm/blob/master/layers.py

    Parameters
    ----------
    unbiased : bool
        Passed to the variance computation (True selects the unbiased estimator).
    eps : float
        Added to the variance before the square root, for numerical stability.
    """

    def __init__(self, unbiased=False, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.unbiased = unbiased

    def forward(self, x):
        centered = x - x.mean(dim=1, keepdim=True)
        variance = x.var(dim=1, unbiased=self.unbiased, keepdim=True)
        return centered / torch.sqrt(variance + self.eps)

class Network3(nn.Module):
    """Three-stage SAGEConv/SAGPooling graph network with a demographic interaction head.

    Each stage applies a convolution, ReLU, NodeNorm and pooling, then takes a
    mean readout per graph. The three readouts are fused with max-mode
    JumpingKnowledge, combined with the demographic vector through a flattened
    outer product, and passed through an MLP that outputs one hazard value per graph.

    The attribute names below are the keys of the saved checkpoint's state dict,
    so they must not be renamed.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = SAGEConv(2048, 256)  # try lower, ex; 64 - then 0.7 pooling, then both
        self.conv2 = SAGEConv(256, 256)
        self.conv3 = SAGEConv(256, 256)

        self.norm = NodeNorm()

        self.pool = SAGPooling(256, 0.3)
        self.pool2 = SAGPooling(256, 0.3)
        self.pool3 = SAGPooling(256, 0.3)

        self.fc = nn.Linear(1792, 256)  # 1792 = 256 x 7 (from demographic) -> for outer product
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 32)
        self.out = nn.Linear(32, 1)
        self.dropout = nn.Dropout(p=0.4)  # constructed but not applied in forward()
        self.jk = JumpingKnowledge(mode="max")  # max pooling jumping knowledge

    def forward(self, x, edge_index, batch, demogr):
        """Return the predicted hazard (one value per graph in the batch)."""
        stages = (
            (self.conv1, self.pool),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
        )
        readouts = []
        for conv, pool in stages:
            x = self.norm(F.relu(conv(x, edge_index)))
            # Pooling keeps a subset of the nodes; edges and batch vector are updated with it.
            x, edge_index, _, batch, _, _ = pool(x, edge_index, batch=batch)
            readouts.append(global_mean_pool(x, batch))

        # Interaction between the graph embedding and the demographic features.
        fused = outer(self.jk(readouts), demogr)

        z = fused
        for layer in (self.fc, self.fc2, self.fc3):
            z = F.relu(layer(z))
        return self.out(z)

#### Create the network used to extract GNN embeddings

In [18]:
class GNN_Embedding(nn.Module):
    """Graph encoder only: the convolution/pooling part of the network, without the MLP head.

    Attribute names match the graph-side keys of the saved checkpoint so that its
    weights can be loaded with ``strict=False``. ``forward`` returns the
    JumpingKnowledge-fused graph embedding (one row per graph).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = SAGEConv(2048, 256)  # try lower, ex; 64 - then 0.7 pooling, then both
        self.conv2 = SAGEConv(256, 256)
        self.conv3 = SAGEConv(256, 256)

        self.norm = NodeNorm()

        self.pool = SAGPooling(256, 0.3)
        self.pool2 = SAGPooling(256, 0.3)
        self.pool3 = SAGPooling(256, 0.3)
        self.jk = JumpingKnowledge(mode="max")  # max pooling jumping knowledge

    def forward(self, x, edge_index, batch, demogr):
        """Return the graph embedding; ``demogr`` is accepted but not used here."""
        stages = (
            (self.conv1, self.pool),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
        )
        readouts = []
        for conv, pool in stages:
            x = self.norm(F.relu(conv(x, edge_index)))
            # Pooling keeps a subset of the nodes; edges and batch vector are updated with it.
            x, edge_index, _, batch, _, _ = pool(x, edge_index, batch=batch)
            readouts.append(global_mean_pool(x, batch))

        return self.jk(readouts)

#### Modified NN for Using GNN Embedding as Input

In [19]:
class Network_Embedding(nn.Module):
    """MLP head only: maps a precomputed graph embedding plus demographics to a hazard.

    Attribute names match the fully connected keys of the saved checkpoint so that
    its weights can be loaded with ``strict=False``.
    """

    def __init__(self):
        super().__init__()

        self.fc = nn.Linear(1792, 256)  # 1792 = 256 x 7 (from demographic) -> for outer product
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 32)
        self.out = nn.Linear(32, 1)

    def forward(self, x, demogr):
        """Return the predicted hazard (one value per row of ``x``)."""
        # Interaction between the graph embedding and the demographic features.
        z = outer(x, demogr)
        for layer in (self.fc, self.fc2, self.fc3):
            z = F.relu(layer(z))
        return self.out(z)

#### Load Pretrained Model

In [20]:
# Load the checkpoint onto the device selected at the top of the notebook.
# (A hard-coded 'cuda:0' here made this cell fail on machines without CUDA even
# though `device` already falls back to the CPU.)
graph_state_dict = torch.load('/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Interns_2023/user/JiQing/TCGA_model/WSI/Model_Saving/No_Dropout_with_AgeStandSexStage_WSIPhenoInteraction/model_epoch_39.pth', map_location=device)

#### Load Pretrained Model to Encoders

In [64]:
# Full network (graph encoder + MLP head). strict=True requires every checkpoint
# key to match a parameter of the model, and vice versa.
Original_GNN = Network3().to(device)
Original_GNN.load_state_dict(graph_state_dict, strict=True)

<All keys matched successfully>

In [21]:
GNN_Embedding_encoder = GNN_Embedding().to(device) 
# Do not call .double() on the model (neither after .to(device) nor after model()):
# it causes "RuntimeError: mat1 and mat2 must have the same dtype".
# Details: https://discuss.pytorch.org/t/runtimeerror-mat1-and-mat2-must-have-the-same-dtype/166759/7

# strict=False: the checkpoint also holds the MLP-head weights, which this encoder
# does not have; the returned report lists them under unexpected_keys.
GNN_Embedding_encoder.load_state_dict(graph_state_dict, strict=False)

_IncompatibleKeys(missing_keys=[], unexpected_keys=['fc.weight', 'fc.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias', 'out.weight', 'out.bias'])

In [22]:
# MLP head only. strict=False: the graph-encoder weights in the checkpoint have no
# counterpart here; the returned report lists them under unexpected_keys.
Embedding_NN = Network_Embedding().to(device)
Embedding_NN.load_state_dict(graph_state_dict, strict=False)

_IncompatibleKeys(missing_keys=[], unexpected_keys=['conv1.lin_l.weight', 'conv1.lin_l.bias', 'conv1.lin_r.weight', 'conv2.lin_l.weight', 'conv2.lin_l.bias', 'conv2.lin_r.weight', 'conv3.lin_l.weight', 'conv3.lin_l.bias', 'conv3.lin_r.weight', 'pool.gnn.lin_rel.weight', 'pool.gnn.lin_rel.bias', 'pool.gnn.lin_root.weight', 'pool2.gnn.lin_rel.weight', 'pool2.gnn.lin_rel.bias', 'pool2.gnn.lin_root.weight', 'pool3.gnn.lin_rel.weight', 'pool3.gnn.lin_rel.bias', 'pool3.gnn.lin_root.weight'])

- Freeze the encoders

In [65]:
# The three networks are used for inference only, so gradient tracking is
# switched off for every one of their parameters.
encoders = [Original_GNN, GNN_Embedding_encoder, Embedding_NN]
for network in encoders:
    network.requires_grad_(False)

# Generate GNN Embedding

## Training Set (TCGA)

In [66]:
GNN_Grph_Embedings_Train = []

In [71]:
# Start from an empty list on every run, so re-executing this cell cannot append
# a second copy of the embeddings.
GNN_Grph_Embedings_Train = []

GNN_Embedding_encoder.eval()
# Inference only: without an autograd graph the outputs can be converted to NumPy
# regardless of whether the "freeze" cell has been run.
with torch.no_grad():
    for graphs, demographics in whole_TCGA_loader:
        edge_index = graphs.edge_index.to(device)
        batch = graphs.batch.to(device)
        data_x = graphs.x.float().to(device)
        demographics = demographics.float().to(device)

        graph_embeddings = GNN_Embedding_encoder(data_x, edge_index, batch, demographics)
        # One list of floats per graph, in loader (= dataframe row) order.
        GNN_Grph_Embedings_Train.extend(graph_embeddings.cpu().numpy().tolist())

        # Drop the references to the large device tensors before the next batch.
        del data_x, edge_index, batch

# The embeddings are written back to the dataframe by position, so there must be
# exactly one per dataset row.
assert len(GNN_Grph_Embedings_Train) == len(whole_TCGA_loader.dataset)

- Merge with Demographic Data

In [75]:
# Identifier, outcome and demographic columns kept next to the TCGA embeddings.
# NOTE: the name Target_Features is re-bound further down for the DH cohort, so
# this cell must be run before the TCGA table below is (re)built.
Target_Features = ['bcr_patient_barcode', 'survival_time', 'vital_status', 'Gender', 
                   'Age_stand', 'Stage_Stage_I', 'Stage_Stage_II', 'Stage_Stage_III', 
                   'Stage_Stage_IV', 'Stage_cis_Stage_0']

In [76]:
# The embedding list is assigned by position, so it must be in the same row order
# as WSI_TCGA_info (the loader that produced it was created with shuffle=False).
clinical_df_With_Embedding_TCGA = WSI_TCGA_info[Target_Features].copy()
clinical_df_With_Embedding_TCGA['GNN_Embedding'] = GNN_Grph_Embedings_Train
clinical_df_With_Embedding_TCGA

Unnamed: 0,bcr_patient_barcode,survival_time,vital_status,Gender,Age_stand,Stage_Stage_I,Stage_Stage_II,Stage_Stage_III,Stage_Stage_IV,Stage_cis_Stage_0,GNN_Embedding
0,TCGA-2F-A9KO,734.0,1,1,-0.408341,0,0,0,1,0,"[0.142581507563591, 0.7256076335906982, 1.8738..."
1,TCGA-2F-A9KP,364.0,1,1,-0.135424,0,0,0,1,0,"[0.3259793221950531, 0.725544273853302, 1.8715..."
2,TCGA-2F-A9KP,364.0,1,1,-0.135424,0,0,0,1,0,"[0.20245903730392456, 0.7255260348320007, 1.86..."
3,TCGA-2F-A9KQ,2886.0,0,1,0.069820,0,0,1,0,0,"[0.22644075751304626, 0.7256844639778137, 1.89..."
4,TCGA-2F-A9KR,3183.0,1,0,-0.772643,0,0,1,0,0,"[0.07130429148674011, 0.7255820035934448, 1.86..."
...,...,...,...,...,...,...,...,...,...,...,...
451,TCGA-ZF-AA54,590.0,1,1,0.284450,0,0,1,0,0,"[0.30232250690460205, 0.725551426410675, 1.877..."
452,TCGA-ZF-AA58,1649.0,0,0,-0.599507,0,0,0,1,0,"[0.09882383793592453, 0.725610613822937, 1.879..."
453,TCGA-ZF-AA5H,897.0,0,0,-0.704970,0,0,0,1,0,"[0.3739091455936432, 0.7255582213401794, 1.890..."
454,TCGA-ZF-AA5N,168.0,1,0,-0.552086,0,0,0,1,0,"[0.30932381749153137, 0.7255484461784363, 1.87..."


- Save

In [78]:
# Cache the TCGA table (demographics + embeddings) so the extraction loop need not be re-run.
with open('/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Interns_2023/user/JiQing/TCGA_model/WSI/TCGA_Demographics_and_GNN_Embedding.pkl', 'wb') as file:
    pickle.dump(clinical_df_With_Embedding_TCGA, file)

- Load

In [24]:
# Reload the cached TCGA table. NOTE: pickle can execute arbitrary code on load;
# only load files written by this notebook.
with open('/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Interns_2023/user/JiQing/TCGA_model/WSI/TCGA_Demographics_and_GNN_Embedding.pkl', 'rb') as file:
    clinical_df_With_Embedding_TCGA = pickle.load(file)

## Testing Set (DH)

- Generate Embedding

In [56]:
GNN_Grph_Embedings_Test = []

In [57]:
# Start from an empty list on every run, so re-executing this cell cannot append
# a second copy of the embeddings.
GNN_Grph_Embedings_Test = []

GNN_Embedding_encoder.eval()
# Inference only: without an autograd graph the outputs can be converted to NumPy
# regardless of whether the "freeze" cell has been run.
with torch.no_grad():
    for graphs, demographics in test_loader:
        edge_index = graphs.edge_index.to(device)
        batch = graphs.batch.to(device)
        data_x = graphs.x.float().to(device)
        demographics = demographics.float().to(device)

        graph_embeddings = GNN_Embedding_encoder(data_x, edge_index, batch, demographics)
        # One list of floats per graph, in loader (= dataframe row) order.
        GNN_Grph_Embedings_Test.extend(graph_embeddings.cpu().numpy().tolist())

        # Drop the references to the large device tensors before the next batch.
        del data_x, edge_index, batch

# The embeddings are written back to the dataframe by position, so there must be
# exactly one per dataset row.
assert len(GNN_Grph_Embedings_Test) == len(test_loader.dataset)

- Merge with Demographic Data

In [46]:
# Identifier, outcome and demographic columns kept next to the DH embeddings.
# NOTE: this re-binds Target_Features, which held the TCGA column list earlier in
# the notebook; re-run the TCGA definition before rebuilding the TCGA table.
Target_Features = ['FFPE.DNA.ID', 'death_censor_time', 'death_stat', 'Gender', 
                   'Age_stand', 'Stage_Stage_I', 'Stage_Stage_II', 'Stage_Stage_III', 
                   'Stage_Stage_IV', 'Stage_cis_Stage_0']

In [61]:
# The embedding list is assigned by position, so it must be in the same row order
# as clinical_DH_df (the loader that produced it was created with shuffle=False).
clinical_df_With_Embedding_DH = clinical_DH_df[Target_Features].copy()
clinical_df_With_Embedding_DH['GNN_Embedding'] = GNN_Grph_Embedings_Test
clinical_df_With_Embedding_DH

Unnamed: 0,FFPE.DNA.ID,death_censor_time,death_stat,Gender,Age_stand,Stage_Stage_I,Stage_Stage_II,Stage_Stage_III,Stage_Stage_IV,Stage_cis_Stage_0,GNN_Embedding
0,BLD050,3816.344970,1,1,0.772738,0,0,0,0,1,"[0.3499777317047119, 0.7255560159683228, 1.877..."
1,BLD043,4869.979467,1,0,0.772738,1,0,0,0,0,"[0.16093844175338745, 0.7255772352218628, 1.86..."
2,BLD043,4869.979467,1,0,0.772738,1,0,0,0,0,"[0.18417532742023468, 0.7256166338920593, 1.87..."
3,BLD061,3780.000000,0,0,0.682589,0,0,0,0,1,"[0.17486831545829773, 0.7256158590316772, 1.87..."
4,BLD061,3780.000000,0,0,0.682589,0,0,0,0,1,"[0.2684324085712433, 0.725586473941803, 1.8730..."
...,...,...,...,...,...,...,...,...,...,...,...
131,BLD048,3609.363450,1,1,-0.940099,0,0,0,0,1,"[0.21637272834777832, 0.7256174087524414, 1.88..."
132,BLD075,4230.000000,0,1,-1.571144,0,0,0,0,1,"[0.29817965626716614, 0.725570797920227, 1.870..."
133,BLD024,4500.000000,0,1,-0.038606,0,0,0,0,1,"[0.28502145409584045, 0.7255651354789734, 1.87..."
134,BLD094,4410.000000,0,1,-1.390845,0,0,0,0,1,"[0.18989501893520355, 0.7256122827529907, 1.87..."


- Save

In [65]:
# Cache the DH table (demographics + embeddings) so the extraction loop need not be re-run.
with open('/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Interns_2023/user/JiQing/TCGA_model/WSI/DH_Demographics_and_GNN_Embedding.pkl', 'wb') as file:
    pickle.dump(clinical_df_With_Embedding_DH, file)

- Load

In [26]:
# Reload the cached DH table. NOTE: pickle can execute arbitrary code on load;
# only load files written by this notebook.
with open('/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Interns_2023/user/JiQing/TCGA_model/WSI/DH_Demographics_and_GNN_Embedding.pkl', 'rb') as file:
    clinical_df_With_Embedding_DH = pickle.load(file)

# Check New NN Using Embeddings as Input

## Generate Dataset Tensor Object for Embedding for Training and Testing

In [36]:
class TCGA_GNNEmbedding_Dataset(Dataset):
    """Serves precomputed graph embeddings together with demographics and outcome.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        Must contain 'GNN_Embedding' (one list of floats per row), the contiguous
        column range 'Gender' .. 'Stage_cis_Stage_0', 'survival_time' and
        'vital_status'.
    """

    def __init__(self, dataframe):
        super().__init__()
        self.df = dataframe
        self.embeddings = dataframe['GNN_Embedding']
        # Label-based slice: both end columns are included.
        self.demographic = dataframe.loc[:, 'Gender':'Stage_cis_Stage_0']
        self.death = dataframe['survival_time'] # output
        self.status = dataframe['vital_status'] # output

    def len(self):
        """Number of rows in the dataset."""
        return len(self.df)

    def get(self, idx):
        """Return ``(embeddings, demographic, death, status)`` for row position ``idx``.

        Raises
        ------
        ValueError
            If 'survival_time' is missing (NaN) for the row; it cannot be stored
            in the integer tensor returned for ``death``.
        """
        embeddings = torch.tensor(self.embeddings.iloc[idx])
        # Convert explicitly to float64 so the result does not depend on how pandas
        # stores the 0/1 columns (int, uint8, or bool in newer pandas versions).
        demographic = torch.tensor(self.demographic.iloc[idx].to_numpy(dtype='float64'))

        survival_time = float(self.death.iloc[idx])
        if survival_time != survival_time:  # NaN is the only value not equal to itself
            raise ValueError(f"survival_time is missing for row position {idx}")
        death = torch.tensor(int(survival_time))  # truncated to an integer, as before

        status = torch.tensor(self.status.iloc[idx])
        return embeddings, demographic, death, status

In [37]:
whole_Embedding_dataset = TCGA_GNNEmbedding_Dataset(clinical_df_With_Embedding_TCGA)

In [41]:
whole_TCGA_Embedding_loader = DataLoader(whole_Embedding_dataset, batch_size = 5, shuffle=False)



In [62]:
# Hazards predicted by the MLP head from the precomputed embeddings.
hazard_predict_New = []

Embedding_NN.eval()
# Inference only: without an autograd graph the outputs can be converted to NumPy
# regardless of whether the "freeze" cell has been run.
with torch.no_grad():
    for data in whole_TCGA_Embedding_loader:
        # data = (embeddings, demographics, survival time, status); only the two
        # model inputs are needed here.
        data1 = data[0].float().to(device)
        data2 = data[1].float().to(device)
        if torch.any(data1.isnan()):
            data1 = torch.nan_to_num(data1)
        if torch.any(data2.isnan()):
            data2 = torch.nan_to_num(data2)

        hazard_pred = Embedding_NN(data1, data2)
        # hazard_pred has one column; keep one scalar per row, in loader order.
        hazard_predict_New.extend(hazard_pred.cpu().numpy()[:, 0])

In [63]:
hazard_predict_New

[1.7863293,
 2.8467963,
 2.6066833,
 -0.04490111,
 -0.19755195,
 1.2734665,
 0.5189002,
 0.79393804,
 0.8492887,
 -0.4206177,
 -0.5024761,
 1.3423917,
 2.2409,
 -0.33212912,
 -1.3919519,
 -1.4994115,
 3.5418425,
 1.526831,
 0.59163266,
 -1.3305845,
 1.5190862,
 2.9914398,
 0.10799269,
 0.42050084,
 0.15770347,
 0.37229943,
 -18.458055,
 -18.44275,
 0.06921818,
 0.20128025,
 3.2269735,
 3.535084,
 -0.5469323,
 0.4959322,
 0.41439557,
 1.0671813,
 0.9045541,
 3.8538022,
 0.0052917097,
 0.8579634,
 1.4836627,
 1.7515732,
 4.491112,
 2.6614697,
 0.90907705,
 2.6607554,
 -1.1785678,
 0.5984566,
 3.938976,
 5.0177307,
 0.1542082,
 4.623092,
 1.3139383,
 -0.28145477,
 0.41474473,
 2.409527,
 2.7643309,
 -1.4820132,
 0.3754227,
 0.7792846,
 4.785987,
 -0.71539086,
 0.16715284,
 -1.217979,
 -3.3621666,
 -1.5448742,
 2.5859873,
 -1.3347144,
 -2.2074296,
 -0.7410207,
 -1.3077711,
 -2.1458359,
 -1.469124,
 -1.3969615,
 -1.4125707,
 1.0422345,
 -1.3342668,
 -0.96387774,
 -1.4323953,
 -1.2775863,
 4

## Compare With Original NN

In [68]:
# Hazards predicted by the original end-to-end network, for comparison with
# hazard_predict_New (encoder embeddings -> MLP head).
hazard_predict_Original = []

Original_GNN.eval()
# Inference only: without an autograd graph the outputs can be converted to NumPy
# regardless of whether the "freeze" cell has been run.
with torch.no_grad():
    for graphs, demographics in whole_TCGA_loader:
        edge_index = graphs.edge_index.to(device)
        batch = graphs.batch.to(device)
        data_x = graphs.x.float().to(device)
        demographics = demographics.float().to(device)

        hazard_pred = Original_GNN(data_x, edge_index, batch, demographics)
        # hazard_pred has one column; keep one scalar per graph, in loader order.
        hazard_predict_Original.extend(hazard_pred.cpu().numpy()[:, 0])

# Check the agreement numerically instead of comparing the printed lists by eye.
# The two pipelines agree only up to float32 rounding, hence the tolerance.
np.testing.assert_allclose(hazard_predict_Original, hazard_predict_New, rtol=1e-4, atol=1e-4)

In [69]:
hazard_predict_Original

[1.7863299,
 2.846796,
 2.6066823,
 -0.044900738,
 -0.19755183,
 1.2734671,
 0.5189004,
 0.79393756,
 0.8492888,
 -0.42061776,
 -0.502476,
 1.3423927,
 2.2408977,
 -0.33212924,
 -1.3919522,
 -1.4994115,
 3.541843,
 1.5268301,
 0.59163237,
 -1.3305843,
 1.5190852,
 2.9914405,
 0.10799182,
 0.42050052,
 0.1577036,
 0.37229943,
 -18.458055,
 -18.442753,
 0.069218434,
 0.20128028,
 3.2269735,
 3.535084,
 -0.54693234,
 0.4959316,
 0.4143955,
 1.0671811,
 0.9045541,
 3.8538024,
 0.0052921567,
 0.85796344,
 1.483663,
 1.7515724,
 4.4911113,
 2.66147,
 0.9090773,
 2.6607554,
 -1.1785678,
 0.59845644,
 3.9389756,
 5.0177307,
 0.15420811,
 4.6230927,
 1.3139385,
 -0.28145537,
 0.41474503,
 2.4095287,
 2.7643304,
 -1.4820132,
 0.37542242,
 0.77928424,
 4.785987,
 -0.7153914,
 0.1671532,
 -1.217979,
 -3.3621666,
 -1.5448742,
 2.585987,
 -1.3347142,
 -2.2074296,
 -0.74102104,
 -1.307771,
 -2.1458356,
 -1.469124,
 -1.3969619,
 -1.4125708,
 1.0422351,
 -1.3342669,
 -0.96387774,
 -1.4323955,
 -1.27758

- The two sets of predictions agree up to float32 rounding (differences on the order of 1e-6). Great!