In [None]:
''' Imports and loss function definitions  '''
!pip install h5py
!pip install lifelines
import h5py
import pickle
import torch
import numpy as np
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn.functional as F
import math
from lifelines.utils import concordance_index
import collections
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")

def nll_loss(hazards, S, Y, c, alpha=0.5, eps=1e-7):
    batch_size = len(Y)
    Y = Y.view(batch_size, 1) 
    c = c.view(batch_size, 1).float() 
    if S is None:
        S = torch.cumprod(1 - hazards, dim=1)
    S_padded = torch.cat([torch.ones_like(c), S], 1) 
    uncensored_loss = -(c) * (torch.log(torch.gather(S_padded, 1, Y).clamp(min=eps)) + torch.log(torch.gather(hazards, 1, Y).clamp(min=eps)))
    censored_loss = - (1-c) * torch.log(torch.gather(S_padded, 1, Y+1).clamp(min=eps))
    neg_l = censored_loss + uncensored_loss
    loss = (1-alpha) * neg_l + alpha * uncensored_loss
    loss = loss.mean()
    return loss

def ce_loss(hazards, S, Y, c, alpha=0.4, eps=1e-7):
    batch_size = len(Y)
    Y = Y.view(batch_size, 1) 
    c = c.view(batch_size, 1).float() 
    if S is None:
        S = torch.cumprod(1 - hazards, dim=1) 
    S_padded = torch.cat([torch.ones_like(c), S], 1)
    reg = -(c) * (torch.log(torch.gather(S_padded, 1, Y)+eps) + torch.log(torch.gather(hazards, 1, Y).clamp(min=eps)))
    ce_l = - (1-c) * torch.log(torch.gather(S, 1, Y).clamp(min=eps)) - (c) * torch.log(1 - torch.gather(S, 1, Y).clamp(min=eps))
    loss = (1-alpha) * ce_l + alpha * reg
    loss = loss.mean()
    return loss

class CrossEntropySurvLoss(object):
    def __init__(self, alpha=0.15):
        self.alpha = alpha

    def __call__(self, hazards, S, Y, c, alpha=None): 
        if alpha is None:
            return ce_loss(hazards, S, Y, c, alpha=self.alpha)
        else:
            return ce_loss(hazards, S, Y, c, alpha=alpha)

class NLLSurvLoss(object):
    def __init__(self, alpha=0.5):
        self.alpha = alpha

    def __call__(self, hazards, S, Y, c, alpha=None):
        if alpha is None:
            return nll_loss(hazards, S, Y, c, alpha=self.alpha)
        else:
            return nll_loss(hazards, S, Y, c, alpha=alpha)

class CoxSurvLoss(object):
    def __call__(self,hazards, S, c, **kwargs):
        current_batch_len = len(S)
        R_mat = np.zeros([current_batch_len, current_batch_len], dtype=int)
        for i in range(current_batch_len):
            for j in range(current_batch_len):
                R_mat[i,j] = S[0][j] >= S[0][i]
    
        R_mat = torch.FloatTensor(R_mat).to(device)
        theta = hazards.reshape(-1)
        exp_theta = torch.exp(theta)
        loss_cox = -torch.mean((theta - torch.log(torch.sum(exp_theta*R_mat, dim=1))) * (c))
        return loss_cox



Collecting lifelines
  Downloading lifelines-0.26.4-py3-none-any.whl (348 kB)
[K     |████████████████████████████████| 348 kB 5.1 MB/s 
[?25hCollecting autograd-gamma>=0.3
  Downloading autograd-gamma-0.5.0.tar.gz (4.0 kB)
Collecting formulaic<0.3,>=0.2.2
  Downloading formulaic-0.2.4-py3-none-any.whl (55 kB)
[K     |████████████████████████████████| 55 kB 4.0 MB/s 
Collecting interface-meta>=1.2
  Downloading interface_meta-1.2.4-py2.py3-none-any.whl (14 kB)
Building wheels for collected packages: autograd-gamma
  Building wheel for autograd-gamma (setup.py) ... [?25l[?25hdone
  Created wheel for autograd-gamma: filename=autograd_gamma-0.5.0-py3-none-any.whl size=4049 sha256=8927bac22ba25758bc0eaa8362cfc7f5b09d27db3dfad79300e8745bd1ae88a1
  Stored in directory: /root/.cache/pip/wheels/9f/01/ee/1331593abb5725ff7d8c1333aee93a50a1c29d6ddda9665c9f
Successfully built autograd-gamma
Installing collected packages: interface-meta, formulaic, autograd-gamma, lifelines
Successfully instal

In [None]:
im_test=[]
for i in range(200):
    im_test.append(np.random.randint(200, size=(224,224,3),dtype=np.uint8))
im_test[0].shape

(224, 224, 3)

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
''' Attention Model definition suited for extracting survival/hazard rates '''
import torch
import torch.nn as nn
import torch.nn.functional as F
import pdb
import numpy as np

class Attention(nn.Module):
    def __init__(self):
        super(Attention, self).__init__()
        self.L = 1024
        self.D = 128
        self.K = 1

        self.feature_extractor_part1 = nn.Sequential(
            nn.Conv2d(3, 20, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(20, 50, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2)
        )

        self.feature_extractor_part2 = nn.Sequential(
            nn.Linear(50 * 53 * 53, self.L),
            nn.ReLU(),
        )

        self.attention = nn.Sequential(
            nn.Linear(self.L, self.D),
            nn.Tanh(),
            nn.Dropout(0.25),
            nn.Linear(self.D, self.K)
        )

        self.classifier = nn.Sequential(
            nn.Linear(self.L*self.K, 4),
            nn.Sigmoid()
        )
        
    def forward(self, **kwargs):
        h = kwargs['x']
        H = self.feature_extractor_part1(h)
        H = H.view(-1, 50 * 53 * 53)
        H = self.feature_extractor_part2(H)
        #h = h.squeeze(0)
        A= self.attention(H)  
        A = torch.transpose(A, 1, 0) 
        A_raw = A 
        A = F.softmax(A, dim=1)
        M = torch.mm(A, H) 
        logits  = self.classifier(M) 
        Y_hat = torch.topk(logits, 1, dim = 1)[1]
        hazards = torch.sigmoid(logits)
        S = torch.cumprod(1 - hazards, dim=1)

        
        return hazards, S, Y_hat, A_raw


       

In [None]:
''' Input to model taken from filtered clinical dataset to include WSI patch, event observation and event duration. 
    It is to be noted that although the variable for event observation is defined as "censor" it contains value of event observation.
    This particular definition is for the training phase '''
class InputData(torch.utils.data.Dataset):
    def __init__(self,image,censor='/content/drive/MyDrive/propro/train_censor_1.pt',event_time='/content/drive/MyDrive/propro/train_et_1.pt' ,mode = 'train', transform = None):
        self.censor=censor
        self.event_time=event_time
        self.mode = mode
        self.final_censor = torch.load(self.censor)
        self.final_event_time = torch.load(self.event_time)
        self.dset = image
        self.transform = transforms.Compose([transforms.ToTensor()]) 
        
      
    def __len__(self):
        return len(self.dset)
    
    def __getitem__(self, index):
        preprocess = transforms.Compose([
        transforms.ToTensor()
        ])
        img_test = transforms.ToPILImage()(self.dset[index]).convert("RGB")
        input_batch = preprocess(img_test)
        censor = self.final_censor[index]
        event_time = self.final_event_time[index]
        
        
        
        return (input_batch, censor,event_time)


In [None]:
''' Input to model taken from filtered clinical dataset to include WSI patch, event observation and event duration. 
    It is to be noted that although the variable for event observation is defined as "censor" it contains value of event observation.
    This particular definition is for the testing/validation phase '''
class InputTest(torch.utils.data.Dataset):
    def __init__(self,image,censor='/content/drive/MyDrive/propro/test_censor_1.pt',event_time='/content/drive/MyDrive/propro/test_et_1.pt' ,mode = 'test', transform = None):
        self.censor=censor
        self.event_time=event_time
        self.mode = mode
        self.final_censor = torch.load(self.censor)
        self.final_event_time = torch.load(self.event_time)
        self.dset = image
        self.transform = transforms.Compose([transforms.ToTensor()]) 

    def __len__(self):
        return len(self.dset)
    
    def __getitem__(self, index):
        preprocess = transforms.Compose([
        transforms.ToTensor()
        ])
        img_test = transforms.ToPILImage()(self.dset[index]).convert("RGB")
        input_batch = preprocess(img_test)
        censor = self.final_censor[index]
        event_time = self.final_event_time[index]
        
        
        
        return (input_batch, censor,event_time)


In [None]:
''' Training is guided by a loss function that takes in survival information and optimizer updates after every 16 slides  '''
def train_loop_survival(epoch, model, loader, optimizer, loss_fn=NLLSurvLoss(),  gc=16):   
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu") 
    model.train()
    train_loss_surv = 0.

    print('\n')
    risk_scores = np.zeros((len(loader)))
    observations = np.zeros((len(loader)))
    event_times = np.zeros((len(loader)))

    for batch_idx, (data_WSI,censor,event_time) in enumerate(loader):
        if censor.item() == -1.0:
            continue
        model.train()
        data_WSI= data_WSI.to(device)

        hazards, S, Y_hat, _ = model(x=data_WSI)
        c=censor.to(device)
        loss = loss_fn(hazards,S,Y_hat,c) 
        loss_value = loss.item()


        risk = -torch.sum(S, dim=1).detach().cpu().numpy()
        risk_scores[batch_idx] = risk
        observations[batch_idx] = c.item()
        event_times[batch_idx] = event_time
        

        train_loss_surv += loss_value
       
        if (batch_idx + 1) % 100 == 0:
            print('batch {}, loss: {:.4f}, risk: {:.4f}, bag_size: {}'.format(batch_idx, loss_value,float(risk), data_WSI.size(0)))
        loss = loss/gc 
        loss.backward()

        if (batch_idx + 1) % gc == 0: 
            optimizer.step()
            optimizer.zero_grad()
    train_loss_surv /= len(loader)
 
    
    

    c_index = concordance_index(event_times, risk_scores, event_observed=observations) 
    
    print('Epoch: {}, train_loss_surv: {:.4f}, train_c_index: {:.4f}'.format(epoch, train_loss_surv, c_index))

  





In [None]:
''' Used for validation/testing '''
def validate_survival(epoch, model, loader, loss_fn=NLLSurvLoss()):
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    val_loss_surv= 0.
    risk_scores = np.zeros((len(loader)))
    observations = np.zeros((len(loader)))
    event_times = np.zeros((len(loader)))

    for batch_idx, (data_WSI,censor,event_time) in enumerate(loader):
        data_WSI = data_WSI.to(device)
        c = censor.to(device)

        with torch.no_grad():
            hazards, S, Y_hat, _ = model(x=data_WSI) 

        loss = loss_fn(hazards,S,Y_hat,c)
        loss_value = loss.item()

    
        risk = -torch.sum(S, dim=1).cpu().numpy()
        risk_scores[batch_idx] = risk
        observations[batch_idx] = c.cpu().numpy()
        event_times[batch_idx] = event_time

        val_loss_surv += loss_value
  
    val_loss_surv /= len(loader)
  
    c_index = concordance_index(event_times, risk_scores, event_observed=observations)
   
    print('Epoch: {}, val_loss_surv: {:.4f}, val_c_index: {:.4f}'.format(epoch, val_loss_surv, c_index))

  

In [None]:
''' Training and testing block '''
filename = "/content/drive/MyDrive/propro/hdf5_TCGAFFPE_LUAD_5x_perP_he_train.h5"
test = "/content/drive/MyDrive/propro/hdf5_TCGAFFPE_LUAD_5x_perP_he_test.h5"
hdf5_file = h5py.File(filename, "r")
hdf5_test = h5py.File(test, "r")
dset = hdf5_file['train_img']
dtest = hdf5_test['test_img']
label = hdf5_file['train_labels']
train_data = InputData(image=dset)
test_data = InputTest(image=dtest)
print('\nInit Loaders...')
train_loader = torch.utils.data.DataLoader(train_data, shuffle = False, num_workers = 0, batch_size = 1)
test_loader = torch.utils.data.DataLoader(test_data, shuffle = False, num_workers = 0, batch_size = 1)
print('Init Model')
model = Attention()
model.cuda()
print('\nInit optimizer ...', end=' ')
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=2e-4, weight_decay=1e-5)


epoch = 1

for i in range(epoch):
    print("...Train...")
    train_loop_survival(i,model,train_loader,optimizer)

validate_survival(1,model,test_loader)



Init Loaders...
Init Model

Init optimizer ... ...Train...


batch 99, loss: 1.3498, risk: -0.8387, bag_size: 1
batch 199, loss: 1.3498, risk: -0.8366, bag_size: 1
batch 299, loss: 1.6996, risk: -0.8359, bag_size: 1
batch 399, loss: 1.3498, risk: -0.8358, bag_size: 1
batch 499, loss: 1.6996, risk: -0.8357, bag_size: 1
batch 599, loss: 1.3498, risk: -0.8362, bag_size: 1
batch 699, loss: 1.3498, risk: -0.8358, bag_size: 1
batch 799, loss: 1.3498, risk: -0.8362, bag_size: 1
batch 899, loss: 1.6996, risk: -0.8359, bag_size: 1
batch 999, loss: 1.3498, risk: -0.8365, bag_size: 1
batch 1099, loss: 1.6996, risk: -0.8360, bag_size: 1
batch 1199, loss: 1.6996, risk: -0.8362, bag_size: 1
batch 1299, loss: 1.3498, risk: -0.8363, bag_size: 1
batch 1399, loss: 1.6996, risk: -0.8365, bag_size: 1
batch 1499, loss: 1.6996, risk: -0.8367, bag_size: 1
batch 1599, loss: 1.6996, risk: -0.8369, bag_size: 1
batch 1699, loss: 1.6996, risk: -0.8371, bag_size: 1
batch 1799, loss: 1.6996, risk: -0.8375, bag_siz

In [None]:
torch.save(model.state_dict(), 'trial.pt')

In [None]:
''' Checkpoint '''
model1=Attention()
model1.cuda()
model1.load_state_dict(torch.load('/content/drive/MyDrive/propro/random2_nov30.pt'))
model1.eval()

Attention(
  (feature_extractor_part1): Sequential(
    (0): Conv2d(3, 20, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (feature_extractor_part2): Sequential(
    (0): Linear(in_features=140450, out_features=1024, bias=True)
    (1): ReLU()
  )
  (attention): Sequential(
    (0): Linear(in_features=1024, out_features=128, bias=True)
    (1): Tanh()
    (2): Dropout(p=0.25, inplace=False)
    (3): Linear(in_features=128, out_features=1, bias=True)
  )
  (classifier): Sequential(
    (0): Linear(in_features=1024, out_features=4, bias=True)
    (1): Sigmoid()
  )
)

In [None]:
''' Testing the model loaded from checkpoint '''
filename = "/content/drive/MyDrive/propro/hdf5_TCGAFFPE_LUAD_5x_perP_he_train.h5"
test = "/content/drive/MyDrive/propro/hdf5_TCGAFFPE_LUAD_5x_perP_he_test.h5"
hdf5_file = h5py.File(filename, "r")
hdf5_test = h5py.File(test, "r")
dset = hdf5_file['train_img']
dtest = hdf5_test['test_img']
label = hdf5_file['train_labels']
test_data = InputTest(image=dtest)
#traindata_0 = FeatureData(root = 'traindata_0.pt')
#testdata_0 = FeatureData(root = 'testdata_0.pt')
#testdata_1 = FeatureData(root = 'testdata_1.pt')
#train_dataset = torch.utils.data.ConcatDataset([traindata_1,traindata_0])
#test_dataset = torch.utils.data.ConcatDataset([testdata_1,testdata_0])
print('\nInit Loaders...')
test_loader = torch.utils.data.DataLoader(test_data, shuffle = False, num_workers = 0, batch_size = 1)





validate_survival(1,model1,test_loader)



Init Loaders...
Epoch: 1, val_loss_surv: 1.8782, val_c_index: 0.5000
