In [None]:
''' Imports and loss function definitions  '''
!pip install h5py
!pip install lifelines
import h5py
import pickle
import torch
import numpy as np
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn.functional as F
import math
from lifelines.utils import concordance_index
import collections
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")

def nll_loss(hazards, S, Y, c, alpha=0.5, eps=1e-7):
    batch_size = len(Y)
    Y = Y.view(batch_size, 1) 
    c = c.view(batch_size, 1).float() 
    if S is None:
        S = torch.cumprod(1 - hazards, dim=1)
    S_padded = torch.cat([torch.ones_like(c), S], 1) 
    uncensored_loss = -(c) * (torch.log(torch.gather(S_padded, 1, Y).clamp(min=eps)) + torch.log(torch.gather(hazards, 1, Y).clamp(min=eps)))
    censored_loss = - (1-c) * torch.log(torch.gather(S_padded, 1, Y+1).clamp(min=eps))
    neg_l = censored_loss + uncensored_loss
    loss = (1-alpha) * neg_l + alpha * uncensored_loss
    loss = loss.mean()
    return loss

def ce_loss(hazards, S, Y, c, alpha=0.4, eps=1e-7):
    batch_size = len(Y)
    Y = Y.view(batch_size, 1) 
    c = c.view(batch_size, 1).float() 
    if S is None:
        S = torch.cumprod(1 - hazards, dim=1) 
    S_padded = torch.cat([torch.ones_like(c), S], 1)
    reg = -(c) * (torch.log(torch.gather(S_padded, 1, Y)+eps) + torch.log(torch.gather(hazards, 1, Y).clamp(min=eps)))
    ce_l = - (1-c) * torch.log(torch.gather(S, 1, Y).clamp(min=eps)) - (c) * torch.log(1 - torch.gather(S, 1, Y).clamp(min=eps))
    loss = (1-alpha) * ce_l + alpha * reg
    loss = loss.mean()
    return loss

class CrossEntropySurvLoss(object):
    def __init__(self, alpha=0.15):
        self.alpha = alpha

    def __call__(self, hazards, S, Y, c, alpha=None): 
        if alpha is None:
            return ce_loss(hazards, S, Y, c, alpha=self.alpha)
        else:
            return ce_loss(hazards, S, Y, c, alpha=alpha)

class NLLSurvLoss(object):
    def __init__(self, alpha=0.5):
        self.alpha = alpha

    def __call__(self, hazards, S, Y, c, alpha=None):
        if alpha is None:
            return nll_loss(hazards, S, Y, c, alpha=self.alpha)
        else:
            return nll_loss(hazards, S, Y, c, alpha=alpha)

class CoxSurvLoss(object):
    def __call__(self,hazards, S, c, **kwargs):
        current_batch_len = len(S)
        R_mat = np.zeros([current_batch_len, current_batch_len], dtype=int)
        for i in range(current_batch_len):
            for j in range(current_batch_len):
                R_mat[i,j] = S[0][j] >= S[0][i]
    
        R_mat = torch.FloatTensor(R_mat).to(device)
        theta = hazards.reshape(-1)
        exp_theta = torch.exp(theta)
        loss_cox = -torch.mean((theta - torch.log(torch.sum(exp_theta*R_mat, dim=1))) * (c))
        return loss_cox





In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
''' Attention Model definition suited for extracting survival/hazard rates '''
class Attention(nn.Module):
    def __init__(self):
        super(Attention, self).__init__()
        self.L = 1024
        self.D = 128
        self.K = 1

        self.feature_extractor_part1 = nn.Sequential(
            nn.Conv2d(3, 20, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(20, 50, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2)
        )

        self.feature_extractor_part2 = nn.Sequential(
            nn.Linear(50 * 53 * 53, self.L),
            nn.ReLU(),
        )

        self.attention = nn.Sequential(
            nn.Linear(self.L, self.D),
            nn.Tanh(),
            nn.Dropout(0.25),
            nn.Linear(self.D, self.K)
        )

        self.classifier = nn.Sequential(
            nn.Linear(self.L*self.K, 4),
            nn.Sigmoid()
        )
        
    def forward(self, **kwargs):
        H = kwargs['x']
        A= self.attention(H)  
        A = torch.transpose(A, 1, 0) 
        A_raw = A 
        A = F.softmax(A, dim=1)
        M = torch.mm(A, H) 
        logits  = self.classifier(M) 
        Y_hat = torch.topk(logits, 1, dim = 1)[1]
        hazards = torch.sigmoid(logits)
        S = torch.cumprod(1 - hazards, dim=1)

        
        return hazards, S, Y_hat, A_raw


       

In [None]:
''' Input to model taken from filtered clinical dataset to include WSI patch, event observation and event duration. 
    It is to be noted that although the variable for event observation is defined as "censor" it contains value of event observation.
    This particular definition is for the training phase '''
class InputData(torch.utils.data.Dataset):
    def __init__(self,train_image, test_image, valid_image,censor='/content/drive/MyDrive/propro/train_valid_censor_1.pt',event_time='/content/drive/MyDrive/propro/train_valid_et_1.pt' ,mode = 'train', transform = None):
        self.censor=censor
        self.event_time=event_time
        self.mode = mode
        self.final_censor = torch.load(self.censor)
        self.final_event_time = torch.load(self.event_time)
        self.train_image = train_image
        self.test_image = test_image
        self.valid_image = valid_image
        self.transform = transforms.Compose([transforms.ToTensor()]) 
        
    
    def __len__(self):
        return 179709
    
    def __getitem__(self, index):
        temp_ind=index
        if temp_ind>=0 and temp_ind<146926:
          self.image=self.train_image
          ind=temp_ind
        elif temp_ind>=146926 and temp_ind<179709:
          self.image=self.valid_image
          ind=temp_ind-146926
        img_test = self.image[ind]
        input_batch = img_test
        censor = self.final_censor[index]
        event_time = self.final_event_time[index]
        
        
        
        return (input_batch, censor,event_time)


In [None]:
''' Input to model taken from filtered clinical dataset to include WSI patch, event observation and event duration. 
    It is to be noted that although the variable for event observation is defined as "censor" it contains value of event observation.
    This particular definition is for the testing/validation phase '''
class InputTest(torch.utils.data.Dataset):
    def __init__(self,image,censor='/content/drive/MyDrive/propro/test_censor_1.pt',event_time='/content/drive/MyDrive/propro/test_et_1.pt' ,mode = 'test', transform = None):
        self.censor=censor
        self.event_time=event_time
        self.mode = mode
        self.final_censor = torch.load(self.censor)
        self.final_event_time = torch.load(self.event_time)
        self.dset = image
           
    def __len__(self):
        return len(self.dset)
    
    def __getitem__(self, index):
        img_test =self.dset[index]
        input_batch = img_test
        censor = self.final_censor[index]
        event_time = self.final_event_time[index]
        
        
        
        return (input_batch, censor,event_time)


In [None]:
''' Training is guided by a loss function that takes in survival information and optimizer updates after every 16 slides  '''
def train_loop_survival(epoch, model, loader, optimizer, loss_fn=NLLSurvLoss(),gc=16):   
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu") 
    model.train()
    train_loss_surv = 0.

    print('\n')
    risk_scores = np.zeros((len(loader)))
    observations = np.zeros((len(loader)))
    event_times = np.zeros((len(loader)))

    for batch_idx, (data_WSI,censor,event_time) in enumerate(loader):
        if censor.item() == -1.0:
            continue
        model.train()
        data_WSI= data_WSI.to(device)    
        hazards, S, Y_hat, _ = model(x=data_WSI)       
        c=censor.to(device)
        loss = loss_fn(hazards,S,Y_hat,c)
        loss_value = loss.item()
        risk = -torch.sum(S, dim=1).detach().cpu().numpy()
        risk_scores[batch_idx] = risk
        observations[batch_idx] = c.item()
        event_times[batch_idx] = event_time       

        train_loss_surv += loss_value
  
        if (batch_idx + 1) % 10000 == 0:
            print('batch {}, loss: {:.4f}, risk: {:.4f}, bag_size: {}'.format(batch_idx, loss_value,float(risk), data_WSI.size(0)))
        loss = loss / gc 
        loss.backward()

        if (batch_idx + 1) % gc == 0: 
            optimizer.step()
            optimizer.zero_grad()
    train_loss_surv /= len(loader)
    c_index = concordance_index(event_times, risk_scores, event_observed=observations)
    print('Epoch: {}, train_loss_surv: {:.4f}, train_c_index: {:.4f}'.format(epoch, train_loss_surv, c_index))

  





In [None]:
''' Used for validation/testing '''
def validate_survival(epoch, model, loader, loss_fn=NLLSurvLoss()):
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    val_loss_surv= 0.
    risk_scores = np.zeros((len(loader)))
    observations = np.zeros((len(loader)))
    event_times = np.zeros((len(loader)))   
    for batch_idx, (data_WSI,censor,event_time) in enumerate(loader):
        if censor.item() == -1.0:
            continue
        data_WSI = data_WSI.to(device)
        c = censor.to(device)
        with torch.no_grad():
            hazards, S, Y_hat, _ = model(x=data_WSI)

        loss = loss_fn(hazards,S,Y_hat,c)
        loss_value = loss.item()

        risk = -torch.sum(S, dim=1).cpu().numpy()
        risk_scores[batch_idx] = risk
        observations[batch_idx] = c.cpu().numpy()
        event_times[batch_idx] = event_time

        val_loss_surv += loss_value
        
    val_loss_surv /= len(loader)
    
    c_index = concordance_index(event_times, risk_scores, event_observed=observations)
    
    print('Epoch: {}, val_loss_surv: {:.4f}, val_c_index: {:.4f}'.format(epoch, val_loss_surv, c_index))

  

In [None]:
''' Input patches and unique ID stored in variables '''
import h5py
train = "/content/drive/MyDrive/propro/hdf5_TCGAFFPE_LUAD_5x_perP_he_train.h5"
test = "/content/drive/MyDrive/propro/hdf5_TCGAFFPE_LUAD_5x_perP_he_test.h5"
valid = "/content/drive/MyDrive/propro/hdf5_TCGAFFPE_LUAD_5x_perP_he_validation.h5"
hdf5_train = h5py.File(train, "r")
hdf5_test = h5py.File(test, "r")
hdf5_valid = h5py.File(valid, "r")
#dtrain = hdf5_train['train_img']
#dtest = hdf5_test['test_img']
#dvalid = hdf5_valid['valid_img']
train_slides = hdf5_train['train_slides']
test_slides = hdf5_test['test_slides']
valid_slides = hdf5_valid['valid_slides']
dtrain=torch.load('/content/drive/MyDrive/propro/train__resnet50_pretrained_image_features_dec2.pt')
dtest=torch.load('/content/drive/MyDrive/propro/test__resnet50_pretrained_image_features_dec2.pt')
dvalid=torch.load('/content/drive/MyDrive/propro/valid__resnet50_pretrained_image_features_dec2.pt')


In [None]:
''' Model Training phase '''
print('Initializing the Model')
model = Attention()
model.cuda()
print('\nInitializing the optimizer ')
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=2e-4, weight_decay=1e-5)


epoch = 10

for i in range(epoch):
    print("...Train...")
    for x in range(1):
      train_data = InputData(train_image=dtrain,test_image=dtest,valid_image=dvalid)   
      test_data = InputTest(image=dtest)
      train_loader = torch.utils.data.DataLoader(train_data, shuffle = True, num_workers = 0, batch_size = 1)
      test_loader = torch.utils.data.DataLoader(test_data, shuffle = False, num_workers = 0, batch_size = 1)
      train_loop_survival(i,model,train_loader,optimizer)
      validate_survival(1,model,test_loader)



Initializing the Model

Initializing the optimizer 
...Train...


batch 9999, loss: 1.8847, risk: -0.8521, bag_size: 1
batch 19999, loss: 1.1931, risk: -0.8647, bag_size: 1
batch 29999, loss: 1.8546, risk: -0.8726, bag_size: 1
batch 39999, loss: 1.8690, risk: -0.8807, bag_size: 1
batch 49999, loss: 1.8836, risk: -0.8847, bag_size: 1
batch 59999, loss: 1.1527, risk: -0.8895, bag_size: 1
batch 69999, loss: 1.5419, risk: -0.9179, bag_size: 1
batch 79999, loss: 2.5534, risk: -0.9201, bag_size: 1
batch 89999, loss: 2.5663, risk: -0.9218, bag_size: 1
batch 99999, loss: 2.5130, risk: -0.9185, bag_size: 1
batch 109999, loss: 1.5367, risk: -0.9207, bag_size: 1
batch 119999, loss: 1.4962, risk: -0.9248, bag_size: 1
batch 129999, loss: 2.5894, risk: -0.9244, bag_size: 1
batch 139999, loss: 1.5423, risk: -0.9205, bag_size: 1
batch 149999, loss: 1.5018, risk: -0.9244, bag_size: 1
batch 159999, loss: 1.5257, risk: -0.9222, bag_size: 1
batch 169999, loss: 2.5972, risk: -0.9254, bag_size: 1
Epoch: 0, 

In [None]:
''' Continue model training from saved parameters '''
model.load_state_dict(torch.load('/content/drive/MyDrive/propro/resnet50_nllloss_epoch40_05536.pt'))
print('\nInit optimizer ...', end=' ')
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=2e-4, weight_decay=1e-5)
model.train()
epoch = 10 #4

for i in range(epoch):
    print("...Train...")
    for x in range(1):
      train_data = InputData(train_image=dtrain,test_image=dtest,valid_image=dvalid)   
      test_data = InputTest(image=dtest)
      train_loader = torch.utils.data.DataLoader(train_data, shuffle = True, num_workers = 0, batch_size = 1)
      test_loader = torch.utils.data.DataLoader(test_data, shuffle = False, num_workers = 0, batch_size = 1)
      train_loop_survival(i,model,train_loader,optimizer)
      validate_survival(1,model,test_loader)


Init optimizer ... ...Train...


batch 9999, loss: 1.4148, risk: -0.9340, bag_size: 1
batch 19999, loss: 1.3997, risk: -0.9358, bag_size: 1
batch 29999, loss: 1.5918, risk: -0.9164, bag_size: 1
batch 39999, loss: 2.6130, risk: -0.9266, bag_size: 1
batch 49999, loss: 1.5166, risk: -0.9231, bag_size: 1
batch 59999, loss: 2.4746, risk: -0.9157, bag_size: 1
batch 69999, loss: 1.4091, risk: -0.9347, bag_size: 1
batch 79999, loss: 1.6050, risk: -0.9152, bag_size: 1
batch 89999, loss: 1.4370, risk: -0.9314, bag_size: 1
batch 99999, loss: 2.6518, risk: -0.9294, bag_size: 1
batch 109999, loss: 1.4615, risk: -0.9287, bag_size: 1
batch 119999, loss: 2.5827, risk: -0.9244, bag_size: 1
batch 129999, loss: 1.4689, risk: -0.9279, bag_size: 1
batch 139999, loss: 1.5402, risk: -0.9209, bag_size: 1
batch 149999, loss: 1.6654, risk: -0.9107, bag_size: 1
batch 159999, loss: 2.5706, risk: -0.9234, bag_size: 1
batch 169999, loss: 2.5810, risk: -0.9242, bag_size: 1
Epoch: 0, train_loss_surv: 1.9553, train_l

In [None]:
 model.eval()

Attention(
  (feature_extractor_part1): Sequential(
    (0): Conv2d(3, 20, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (feature_extractor_part2): Sequential(
    (0): Linear(in_features=140450, out_features=1024, bias=True)
    (1): ReLU()
  )
  (attention): Sequential(
    (0): Linear(in_features=1024, out_features=128, bias=True)
    (1): Tanh()
    (2): Dropout(p=0.25, inplace=False)
    (3): Linear(in_features=128, out_features=1, bias=True)
  )
  (classifier): Sequential(
    (0): Linear(in_features=1024, out_features=4, bias=True)
    (1): Sigmoid()
  )
)

In [None]:
''' Testing model on test patches to get final performance of model '''
test_data = InputTest(image=dtest)
test_loader = torch.utils.data.DataLoader(test_data, shuffle = False, num_workers = 0, batch_size = 1)
validate_survival(1,model,test_loader)

Epoch: 1, val_loss_surv: 1.8758, val_loss: 1.8758, val_c_index: 0.5536


In [None]:
torch.save(model.state_dict(), '/content/drive/MyDrive/propro/resnet50_nllloss_epoch50.pt')

In [None]:
''' Checkpoint '''
model = Attention()
model.cuda()
model.load_state_dict(torch.load('/content/drive/MyDrive/propro/resnet50_nllloss_epoch50.pt'))
model.eval()

Attention(
  (feature_extractor_part1): Sequential(
    (0): Conv2d(3, 20, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (feature_extractor_part2): Sequential(
    (0): Linear(in_features=140450, out_features=1024, bias=True)
    (1): ReLU()
  )
  (attention): Sequential(
    (0): Linear(in_features=1024, out_features=128, bias=True)
    (1): Tanh()
    (2): Dropout(p=0.25, inplace=False)
    (3): Linear(in_features=128, out_features=1, bias=True)
  )
  (classifier): Sequential(
    (0): Linear(in_features=1024, out_features=4, bias=True)
    (1): Sigmoid()
  )
)

In [None]:
''' Testing model on test patches to get final performance of model '''
test_data = InputTest(image=dtest)
test_loader = torch.utils.data.DataLoader(test_data, shuffle = False, num_workers = 0, batch_size = 1)
validate_survival(1,model,test_loader)

Epoch: 1, val_loss_surv: 1.8737, val_c_index: 0.5568
