In [1]:
import argparse, random
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader,Dataset
from torchvision import datasets, transforms, utils, models
from skimage import io, transform, img_as_float
from torch.autograd import Variable
from collections import Counter
import pandas as pd
import matplotlib.pyplot as plt
import math
import random
from imblearn.over_sampling import RandomOverSampler
import pandas as pd
from lifelines.statistics import logrank_test
from lifelines.utils import concordance_index
import tables
import csv
import numpy as np
import json
from tqdm import tqdm
import gc
import copy
import os
from PIL import *
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [2]:
class Prognostic_Dataset(Dataset):

    def __init__(self, csv_file, transform=None):
     
        self.files_list = pd.read_csv(csv_file)
        self.transform = transform

    def __len__(self):
        return len(self.files_list)

    def __getitem__(self, idx):
        
        if type(idx) == torch.Tensor:
                idx = idx.item()
        
        he_name = os.path.join(self.files_list.iloc[idx, 0])
        he_name = he_name.replace('Teresa_TMA_patch', 'Patches')
        index_sub = he_name.find('Patches')
        he_name = os.path.join('/deepdata/adib/prognostic_study', he_name[index_sub:])
        he_name = he_name.replace('\\', '/') 
       

        
        status = files_list.iloc[idx, 11]
        age = files_list.iloc[idx, 2]
        sex = files_list.iloc[idx, 3]
        os_month = files_list.iloc[idx, 13]

        event=0
        if status=='deceased':
            event=1
        
        he_image = Image.open( he_name )

        he_image = img_as_float(he_image)
        
        sample = {'input': he_image, 'event': event, 'OS':os_month}

        if self.transform:
            sample = self.transform(sample)

        return sample
    
class ToTensor(object):

    def __call__(self, sample):
        input, event, os_month = sample['input'], sample['event'], sample['OS']

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C X H X W
        input = input.transpose((2, 0, 1))
        sample = {'input': input, 'event': event, 'OS':os_month}

        return sample

In [3]:
csv_file='/deepdata/adib/prognostic_study/pytorch_code/Patch_clinical_clean_v2.csv'
files_list = pd.read_csv(csv_file)
prognostic_dataset= Prognostic_Dataset( csv_file, transform=ToTensor())
#import nonechucks as nc
#prognostic_dataset = nc.SafeDataset(prognostic_dataset)
dataloader = DataLoader(prognostic_dataset, batch_size=280,shuffle=True, num_workers=0)

  interactivity=interactivity, compiler=compiler, result=result)
  if (yield from self.run_code(code, result)):


In [4]:
class survresnet(nn.Module):
    def __init__(self):
        super(survresnet, self).__init__()
        cox_in_dim=4
        label_dim = 1
        use_pretrained=True 
        feature_extract = False

        PATH="/deepdata/adib/prognostic_study/pytorch_code/trained_model/trainedResnet.pth"
        model_ft = models.resnet101(pretrained=use_pretrained)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, 4)
        input_size = 224
        model_ft.load_state_dict(torch.load(PATH))
    
        self.resnet=model_ft
        self.coxnet = nn.Sequential(nn.Linear(cox_in_dim, label_dim),nn.Softplus())
        
    def forward(self, x):
        x_d = None
        
        code = self.resnet(x)
        lbl_pred = self.coxnet(code)
        
        return x_d, code, lbl_pred


In [5]:
def accuracy(output, labels):
    preds = output.max(1)[1].type_as(labels)
    correct = preds.eq(labels).double()
    correct = correct.sum()
    return correct / len(labels)

def accuracy_cox(hazards, labels):
    # This accuracy is based on estimated survival events against true survival events
    hazardsdata = hazards.cpu().numpy().reshape(-1)
    median = np.median(hazardsdata)
    hazards_dichotomize = np.zeros([len(hazardsdata)], dtype=int)
    hazards_dichotomize[hazardsdata > median] = 1
    labels = labels.data.cpu().numpy()
    correct = np.sum(hazards_dichotomize == labels)
    return correct / len(labels)

def cox_log_rank(hazards, labels, survtime_all):
    hazardsdata = hazards.cpu().numpy().reshape(-1)
    median = np.median(hazardsdata)
    hazards_dichotomize = np.zeros([len(hazardsdata)], dtype=int)
    hazards_dichotomize[hazardsdata > median] = 1
    survtime_all = survtime_all.data.cpu().numpy().reshape(-1)
    idx = hazards_dichotomize == 0
    labels = labels.data.cpu().numpy()
    T1 = survtime_all[idx]
    T2 = survtime_all[~idx]
    E1 = labels[idx]
    E2 = labels[~idx]
    results = logrank_test(T1, T2, event_observed_A=E1, event_observed_B=E2)
    pvalue_pred = results.p_value
    return(pvalue_pred)
    
def CIndex(hazards, labels, survtime_all):
    labels = labels.data.cpu().numpy()
    concord = 0.
    total = 0.
    N_test = labels.shape[0]
    labels = np.asarray(labels, dtype=bool)
    for i in range(N_test):
        if labels[i] == 1:
            for j in range(N_test):
                if survtime_all[j] > survtime_all[i]:
                    total = total + 1
                    if hazards[j] < hazards[i]: concord = concord + 1
                    elif hazards[j] < hazards[i]: concord = concord + 0.5

    return(concord/total)
    
def CIndex_lifeline(hazards, labels, survtime_all):
    labels = labels.data.cpu().numpy()
    hazards = hazards.cpu().numpy().reshape(-1)
    return(concordance_index(survtime_all, -hazards, labels))
        
def frobenius_norm_loss(a, b):
    loss = torch.sqrt(torch.sum(torch.abs(a-b)**2))
    return loss

In [6]:
def train(dataloader, num_epochs, batch_size, learning_rate, dropout_rate,
                        lambda_1, measure, verbose):
    

   # dataloader = DataLoader(prognostic_dataset, batch_size=batch_size,shuffle=False, num_workers=0)
    
    cudnn.deterministic = True
    torch.cuda.manual_seed_all(666)
    torch.manual_seed(666)
    random.seed(666)
    
#     model = survresnet()
#     device=torch.device('cuda:0')
#     model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0)
    
    lr = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=LambdaLR(num_epochs, 0, 10).step
    )
    
    c_index_list = {}
    c_index_list = []
    loss_nn_all = []
    pvalue_all = []
    c_index_all = []
    acc_train_all = []
    c_index_best = 0
    code_output = None
    
    for epoch in tqdm(range(num_epochs)):
        lr.step()
        for param_group in optimizer.param_groups:
            print("learning rate:" + str(param_group['lr']))
        model.train()
        lbl_pred_all = None
        lbl_all = None
        survtime_all = None
        code_final = None
        loss_nn_sum = 0
        iter = 0
        gc.collect()
        for iteration, batch in enumerate(dataloader):
            data, lbl, survtime = batch['input'], batch['event'], batch['OS']
            optimizer.zero_grad() # zero the gradient buffer
            graph = data
            
            
            graph = graph.to(device, dtype=torch.float)
            lbl = lbl.to(device)
            # ===================forward=====================
            output, code, lbl_pred = model(graph)
            
            if iter == 0:
                lbl_pred_all = lbl_pred
                survtime_all = survtime
                lbl_all = lbl
                code_final = code
            else:
                lbl_pred_all = torch.cat([lbl_pred_all, lbl_pred])
                lbl_all = torch.cat([lbl_all, lbl])
                survtime_all = torch.cat([survtime_all, survtime])
                code_final = torch.cat([code_final, code])
            current_batch_len = len(survtime)
            R_matrix_train = np.zeros([current_batch_len, current_batch_len], dtype=int)
            for i in range(current_batch_len):
                for j in range(current_batch_len):
                    R_matrix_train[i,j] = survtime[j] >= survtime[i]
        
            train_R = torch.FloatTensor(R_matrix_train)
            
            train_R = train_R.to(device, dtype=torch.float)
            train_ystatus = lbl
            
            theta = lbl_pred.reshape(-1)
            exp_theta = torch.exp(theta)
            
            loss_nn = -torch.mean( (theta - torch.log(torch.sum( exp_theta*train_R ,dim=1))) * train_ystatus.float() )

            l1_reg = None
            for W in model.parameters():
                if l1_reg is None:
                    l1_reg = torch.abs(W).sum()
                else:
                    l1_reg = l1_reg + torch.abs(W).sum() # torch.abs(W).sum() is equivalent to W.norm(1)
            
            loss = loss_nn + lambda_1 * l1_reg
            if verbose > 2:
                print("\nloss_nn: %.4f, L1: %.4f" % (loss_nn, lambda_1 * l1_reg))
            loss_nn_sum = loss_nn_sum + loss_nn.data.item()
            # ===================backward====================
            loss.backward()
            optimizer.step()
            
            iter += 1
#             torch.cuda.empty_cache()
       # code_final_4_original_data = code_final.data.cpu().numpy()
        
        if measure or epoch == (num_epochs - 1):
            acc_train = accuracy_cox(lbl_pred_all.data, lbl_all)
            pvalue_pred = cox_log_rank(lbl_pred_all.data, lbl_all, survtime_all)
            c_index = CIndex_lifeline(lbl_pred_all.data, lbl_all, survtime_all)
            
            c_index_list.append(c_index)
            if c_index > c_index_best:
                c_index_best = c_index
            #    code_output = code_final_4_original_data
            if verbose > 0:
                print('\n[Training]\t loss (nn):{:.4f}'.format(loss_nn_sum),
                      'c_index: {:.4f}, p-value: {:.3e}'.format(c_index, pvalue_pred))
                torch.save(model.state_dict(), Save_model_path)
            pvalue_all.append(pvalue_pred)
            c_index_all.append(c_index)
            loss_nn_all.append(loss_nn_sum)
            acc_train_all.append(acc_train)
    return(model, loss_nn_all, pvalue_all, c_index_all, c_index_list, acc_train_all)

In [7]:
class LambdaLR:
    def __init__(self, n_epochs, offset, decay_start_epoch):
        assert (n_epochs - decay_start_epoch) > 0, "Decay must start before the training session ends!"
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        return 1.0 - max(0, epoch + self.offset - self.decay_start_epoch) / (self.n_epochs - self.decay_start_epoch)

In [8]:
cwd = os.getcwd();
save_path = os.path.join(cwd, 'Saved model')
Save_model_path=os.path.join(save_path,'survivalnet.pth')
Tensor = torch.cuda.LongTensor

In [9]:
cwd = os.getcwd()
path = os.path.join(cwd, 'Saved model', 'survivalnet.pth')
state_dict = torch.load(path, map_location=torch.device('cuda:1'))
device = torch.device('cuda:1')
gpu_ids = [1, 2, 3, 4, 5, 6, 7]
torch.cuda.set_device(gpu_ids[0])
model = survresnet()
model = torch.nn.DataParallel(model, device_ids=gpu_ids)
model.load_state_dict(torch.load(path))
model.eval()
model.to(device)

DataParallel(
  (module): survresnet(
    (resnet): ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace

In [9]:
torch.save(model.state_dict(), Save_model_path)

In [10]:
num_epochs = 100
batch_size = 350
lr = 0.001
verbose = 0
measure_while_training = True
dropout_rate = 0
lambda_1 = 1e-5 # L1

In [15]:
# 70 per gpu
model, loss_nn_all, pvalue_all, c_index_all, c_index_list, acc_train_all, code_output = train(dataloader, num_epochs, 
                                                                                              batch_size, 
                                                                                              lr, dropout_rate,lambda_1,
                                                                                              measure=True, verbose=1)

  0%|          | 0/100 [00:00<?, ?it/s]

learning rate:0.001

[Training]	 loss (nn):8485.4032 c_index: 0.6095, p-value: 0.000e+00


  1%|          | 1/100 [1:53:11<186:46:14, 6791.66s/it]

learning rate:0.001


KeyboardInterrupt: 

In [12]:
# 50 per gpu
model, loss_nn_all, pvalue_all, c_index_all, c_index_list, acc_train_all, code_output = train(dataloader, num_epochs, 
                                                                                              batch_size, 
                                                                                              lr, dropout_rate,lambda_1,
                                                                                              measure=True, verbose=1)

  0%|          | 0/100 [00:00<?, ?it/s]

learning rate:0.001

[Training]	 loss (nn):10852.2202 c_index: 0.6704, p-value: 0.000e+00


  1%|          | 1/100 [1:45:56<174:47:30, 6356.07s/it]

learning rate:0.001


KeyboardInterrupt: 

In [None]:
# 40 per gpu
model, loss_nn_all, pvalue_all, c_index_all, c_index_list, acc_train_all, code_output = train(dataloader, num_epochs, 
                                                                                              batch_size, 
                                                                                              lr, dropout_rate,lambda_1,
                                                                                              measure=True, verbose=1)

  0%|          | 0/100 [00:00<?, ?it/s]

learning rate:0.001

[Training]	 loss (nn):12636.8156 c_index: 0.6876, p-value: 0.000e+00


  1%|          | 1/100 [1:45:38<174:19:02, 6338.81s/it]

learning rate:0.001

[Training]	 loss (nn):12227.6670 c_index: 0.7690, p-value: 0.000e+00


  2%|▏         | 2/100 [3:24:27<169:12:32, 6215.84s/it]

learning rate:0.001

[Training]	 loss (nn):11978.9867 c_index: 0.8132, p-value: 0.000e+00


  3%|▎         | 3/100 [5:04:24<165:42:45, 6150.16s/it]

learning rate:0.001

[Training]	 loss (nn):11782.1280 c_index: 0.8265, p-value: 0.000e+00


  4%|▍         | 4/100 [6:43:14<162:14:17, 6083.93s/it]

learning rate:0.001

[Training]	 loss (nn):11639.5710 c_index: 0.8391, p-value: 0.000e+00


  5%|▌         | 5/100 [8:21:51<159:13:59, 6034.10s/it]

learning rate:0.001

[Training]	 loss (nn):11532.9887 c_index: 0.8421, p-value: 0.000e+00


  6%|▌         | 6/100 [10:02:15<157:28:24, 6030.89s/it]

learning rate:0.001

[Training]	 loss (nn):11451.2215 c_index: 0.8503, p-value: 0.000e+00


  7%|▋         | 7/100 [11:41:31<155:12:59, 6008.38s/it]

learning rate:0.001

[Training]	 loss (nn):11361.8173 c_index: 0.8578, p-value: 0.000e+00


  8%|▊         | 8/100 [13:20:39<153:05:06, 5990.29s/it]

learning rate:0.001

[Training]	 loss (nn):11299.8168 c_index: 0.8544, p-value: 0.000e+00


  9%|▉         | 9/100 [15:00:29<151:25:20, 5990.33s/it]

learning rate:0.001

[Training]	 loss (nn):11262.4667 c_index: 0.8629, p-value: 0.000e+00


 10%|█         | 10/100 [16:39:26<149:21:15, 5974.18s/it]

learning rate:0.001

[Training]	 loss (nn):11185.3270 c_index: 0.8662, p-value: 0.000e+00


 11%|█         | 11/100 [18:18:36<147:31:18, 5967.18s/it]

learning rate:0.000988888888888889

[Training]	 loss (nn):11158.0592 c_index: 0.8670, p-value: 0.000e+00


 12%|█▏        | 12/100 [19:58:45<146:09:54, 5979.49s/it]

learning rate:0.0009777777777777777

[Training]	 loss (nn):11085.8110 c_index: 0.8741, p-value: 0.000e+00


 13%|█▎        | 13/100 [21:37:48<144:14:39, 5968.73s/it]

learning rate:0.0009666666666666667


csv_file='D:\\Prognostic_study\z\pytorch_code\\Patch_clinical_new.csv'
files_list = pd.read_csv(csv_file)

In [10]:
he_name = os.path.join(files_list.iloc[0, 0])
print(he_name)
pilImg = Image.open( he_name )
he_image = img_as_float(pilImg)


F:\Prognostic study\Patches\survival\All_A1_14900_rot180_res1.jpeg


In [11]:
from PIL import Image
for i in range(len(files_list)):
    he_name = os.path.join(files_list.iloc[i, 0])
    try:
      img = Image.open(he_name) # open the image file
      img.verify() # verify that it is, in fact an image
    except (IOError, SyntaxError) as e:
      print('Bad file:', he_name) # print out the names of corrupt files
print(i)

Bad file: F:\Prognostic study\Patches\deceased\All_A5_21000_rot180_res1.jpeg
Bad file: F:\Prognostic study\Patches\deceased\All_A5_21000_rot180_res1.jpeg
Bad file: F:\Prognostic study\Patches\deceased\All_A9_211400_rot180_res1.jpeg
Bad file: F:\Prognostic study\Patches\deceased\All_A9_211400_rot180_res1.jpeg
Bad file: F:\Prognostic study\Patches\deceased\All_A9_211400_rot180_res1.jpeg
Bad file: F:\Prognostic study\Patches\deceased\All_A9_211400_rot180_res1.jpeg
Bad file: F:\Prognostic study\Patches\deceased\All_A9_211400_rot180_res1.jpeg
Bad file: F:\Prognostic study\Patches\deceased\All_A9_211400_rot180_res1.jpeg
Bad file: F:\Prognostic study\Patches\deceased\All_C17_656100_rot180_res1.jpeg
Bad file: F:\Prognostic study\Patches\deceased\All_C17_845640_rot180_res1.jpeg
Bad file: F:\Prognostic study\Patches\deceased\All_C17_845640_rot180_res1.jpeg
Bad file: F:\Prognostic study\Patches\deceased\All_C9_760320_rot180_res1.jpeg
Bad file: F:\Prognostic study\Patches\survival\All_D3_68134_rot