In [1]:
''' Imports and loss function definitions  '''
!pip install h5py
!pip install lifelines
import h5py
import pickle
import torch
import numpy as np
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn.functional as F
import math
from lifelines.utils import concordance_index
import collections
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")

def nll_loss(hazards, S, Y, c, alpha=0.5, eps=1e-7):
    batch_size = len(Y)
    Y = Y.view(batch_size, 1) 
    c = c.view(batch_size, 1).float() 
    if S is None:
        S = torch.cumprod(1 - hazards, dim=1)
    S_padded = torch.cat([torch.ones_like(c), S], 1) 
    uncensored_loss = -(c) * (torch.log(torch.gather(S_padded, 1, Y).clamp(min=eps)) + torch.log(torch.gather(hazards, 1, Y).clamp(min=eps)))
    censored_loss = - (1-c) * torch.log(torch.gather(S_padded, 1, Y+1).clamp(min=eps))
    neg_l = censored_loss + uncensored_loss
    loss = (1-alpha) * neg_l + alpha * uncensored_loss
    loss = loss.mean()
    return loss

def ce_loss(hazards, S, Y, c, alpha=0.4, eps=1e-7):
    batch_size = len(Y)
    Y = Y.view(batch_size, 1) 
    c = c.view(batch_size, 1).float() 
    if S is None:
        S = torch.cumprod(1 - hazards, dim=1) 
    S_padded = torch.cat([torch.ones_like(c), S], 1)
    reg = -(c) * (torch.log(torch.gather(S_padded, 1, Y)+eps) + torch.log(torch.gather(hazards, 1, Y).clamp(min=eps)))
    ce_l = - (1-c) * torch.log(torch.gather(S, 1, Y).clamp(min=eps)) - (c) * torch.log(1 - torch.gather(S, 1, Y).clamp(min=eps))
    loss = (1-alpha) * ce_l + alpha * reg
    loss = loss.mean()
    return loss

class CrossEntropySurvLoss(object):
    def __init__(self, alpha=0.15):
        self.alpha = alpha

    def __call__(self, hazards, S, Y, c, alpha=None): 
        if alpha is None:
            return ce_loss(hazards, S, Y, c, alpha=self.alpha)
        else:
            return ce_loss(hazards, S, Y, c, alpha=alpha)

class NLLSurvLoss(object):
    def __init__(self, alpha=0.5):
        self.alpha = alpha

    def __call__(self, hazards, S, Y, c, alpha=None):
        if alpha is None:
            return nll_loss(hazards, S, Y, c, alpha=self.alpha)
        else:
            return nll_loss(hazards, S, Y, c, alpha=alpha)

class CoxSurvLoss(object):
    def __call__(self,hazards, S, c, **kwargs):
        current_batch_len = len(S)
        R_mat = np.zeros([current_batch_len, current_batch_len], dtype=int)
        for i in range(current_batch_len):
            for j in range(current_batch_len):
                R_mat[i,j] = S[0][j] >= S[0][i]
    
        R_mat = torch.FloatTensor(R_mat).to(device)
        theta = hazards.reshape(-1)
        exp_theta = torch.exp(theta)
        loss_cox = -torch.mean((theta - torch.log(torch.sum(exp_theta*R_mat, dim=1))) * (c))
        return loss_cox



Collecting lifelines
  Downloading lifelines-0.26.4-py3-none-any.whl (348 kB)
[K     |████████████████████████████████| 348 kB 7.8 MB/s 
Collecting formulaic<0.3,>=0.2.2
  Downloading formulaic-0.2.4-py3-none-any.whl (55 kB)
[K     |████████████████████████████████| 55 kB 4.5 MB/s 
[?25hCollecting autograd-gamma>=0.3
  Downloading autograd-gamma-0.5.0.tar.gz (4.0 kB)
Collecting interface-meta>=1.2
  Downloading interface_meta-1.2.4-py2.py3-none-any.whl (14 kB)
Building wheels for collected packages: autograd-gamma
  Building wheel for autograd-gamma (setup.py) ... [?25l[?25hdone
  Created wheel for autograd-gamma: filename=autograd_gamma-0.5.0-py3-none-any.whl size=4049 sha256=971775150132788627f82c663162544ef82f47dfa2f7223cf3fd95a15822698e
  Stored in directory: /root/.cache/pip/wheels/9f/01/ee/1331593abb5725ff7d8c1333aee93a50a1c29d6ddda9665c9f
Successfully built autograd-gamma
Installing collected packages: interface-meta, formulaic, autograd-gamma, lifelines
Successfully instal

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
''' Attention Model definition suited for extracting survival/hazard rates '''
import torch
import torch.nn as nn
import torch.nn.functional as F
import pdb
import numpy as np


class Attention(nn.Module):
    def __init__(self):
        super(Attention, self).__init__()
        self.L = 1024
        self.D = 128
        self.K = 1

        self.feature_extractor_part1 = nn.Sequential(
            nn.Conv2d(3, 20, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(20, 50, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2)
        )

        self.feature_extractor_part2 = nn.Sequential(
            nn.Linear(50 * 53 * 53, self.L),
            nn.ReLU(),
        )

        self.attention = nn.Sequential(
            nn.Linear(self.L, self.D),
            nn.Tanh(),
            nn.Dropout(0.25),
            nn.Linear(self.D, self.K)
        )

        self.classifier = nn.Sequential(
            nn.Linear(self.L*self.K, 4),
            nn.Sigmoid()
        )
        
    def forward(self, **kwargs):
        h = kwargs['x']
        H = self.feature_extractor_part1(h)
        H = H.view(-1, 50 * 53 * 53)
        H = self.feature_extractor_part2(H)
        A= self.attention(H)  
        A = torch.transpose(A, 1, 0) 
        A_raw = A 
        A = F.softmax(A, dim=1)
        M = torch.mm(A, H) 
        logits  = self.classifier(M) 
        Y_hat = torch.topk(logits, 1, dim = 1)[1]
        hazards = torch.sigmoid(logits)
        S = torch.cumprod(1 - hazards, dim=1)

        
        return hazards, S, Y_hat, A_raw


       

In [5]:
''' Input to model taken from filtered clinical dataset to include WSI patch, event observation and event duration. 
    It is to be noted that although the variable for event observation is defined as "censor" it contains value of event observation.
    This particular definition is for the training phase '''
class InputData(torch.utils.data.Dataset):
    def __init__(self,train_image,valid_image,pos,censor='/content/drive/MyDrive/propro/train_valid_censor_1.pt',event_time='/content/drive/MyDrive/propro/train_valid_et_1.pt' ,mode = 'train', transform = None):
        self.censor=censor
        self.event_time=event_time
        self.mode = mode
        self.final_censor = torch.load(self.censor)
        self.final_event_time = torch.load(self.event_time)
        self.train_image = train_image
        self.valid_image = valid_image
        self.pos=pos
        self.transform = transforms.Compose([transforms.ToTensor()]) 
        
    def __len__(self):
        return len(self.pos)
    
    def __getitem__(self, index):
        temp_ind=self.pos[index]
        if temp_ind>=0 and temp_ind<146926:
          self.image=self.train_image
          ind=temp_ind
        elif temp_ind>=146926 and temp_ind<179709:
          self.image=self.valid_image
          ind=temp_ind-146926
        preprocess = transforms.Compose([
        transforms.ToTensor()
        ])
        img_test = transforms.ToPILImage()(self.image[ind]).convert("RGB")
        input_batch = preprocess(img_test)
        censor = self.final_censor[index]
        event_time = self.final_event_time[index]
        
        
        
        return (input_batch, censor,event_time)


In [6]:
''' Input to model taken from filtered clinical dataset to include WSI patch, event observation and event duration. 
    It is to be noted that although the variable for event observation is defined as "censor" it contains value of event observation.
    This particular definition is for the testing/validation phase '''
class InputTest(torch.utils.data.Dataset):
    def __init__(self,image,censor='/content/drive/MyDrive/propro/test_censor_1.pt',event_time='/content/drive/MyDrive/propro/test_et_1.pt' ,mode = 'test', transform = None):
        self.censor=censor
        self.event_time=event_time
        self.mode = mode
        self.final_censor = torch.load(self.censor)
        self.final_event_time = torch.load(self.event_time)
        self.dset = image
        self.transform = transforms.Compose([transforms.ToTensor()]) 
            
    def __len__(self):
        return len(self.dset)
    
    def __getitem__(self, index):
        preprocess = transforms.Compose([
        transforms.ToTensor()
        ])
        img_test = transforms.ToPILImage()(self.dset[index]).convert("RGB")
        input_batch = preprocess(img_test)
        censor = self.final_censor[index]
        event_time = self.final_event_time[index]
        
        
        
        return (input_batch, censor,event_time)


In [7]:
''' Training is guided by a loss function that takes in survival information and optimizer updates after every 16 slides  '''
def train_loop_survival(epoch, model, loader, optimizer, loss_fn=NLLSurvLoss(), gc=16):   
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu") 
    model.train()
    train_loss_surv, train_loss = 0., 0.

    print('\n')
    risk_scores = np.zeros((len(loader)))
    observations = np.zeros((len(loader)))
    event_times = np.zeros((len(loader)))

    for batch_idx, (data_WSI,censor,event_time) in enumerate(loader):
        if censor.item() == -1.0:
            continue
        model.train()
        data_WSI= data_WSI.to(device)

        hazards, S, Y_hat, _ = model(x=data_WSI) 
        c=censor.to(device)
        loss = loss_fn(hazards,S,Y_hat,c) 
        loss_value = loss.item()


        risk = -torch.sum(S, dim=1).detach().cpu().numpy()
        risk_scores[batch_idx] = risk
        observations[batch_idx] = c.item()
        event_times[batch_idx] = event_time 
        

        train_loss_surv += loss_value
        train_loss += loss_value + loss_reg

        if (batch_idx + 1) % 1000 == 0:
            print('batch {}, loss: {:.4f}, risk: {:.4f}, bag_size: {}'.format(batch_idx, loss_value + loss_reg,float(risk), data_WSI.size(0)))

        loss = loss / gc + loss_reg
        loss.backward()

        if (batch_idx + 1) % gc == 0: 
            optimizer.step()
            optimizer.zero_grad()

    train_loss_surv /= len(loader)
    train_loss /= len(loader)
    
    c_index = concordance_index(event_times, risk_scores, event_observed=observations) 

    print('Epoch: {}, train_loss_surv: {:.4f}, train_loss: {:.4f}, train_c_index: {:.4f}'.format(epoch, train_loss_surv, train_loss, c_index))

  





In [8]:
''' Used for validation/testing '''
def validate_survival(epoch, model, loader, loss_fn=NLLSurvLoss()):
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    val_loss_surv, val_loss = 0., 0.
    risk_scores = np.zeros((len(loader)))
    observations = np.zeros((len(loader)))
    event_times = np.zeros((len(loader))) 

    for batch_idx, (data_WSI,censor,event_time) in enumerate(loader):
        if censor.item() == -1.0:
            continue
        data_WSI = data_WSI.to(device)
        c = censor.to(device)

        with torch.no_grad():
            hazards, S, Y_hat, _ = model(x=data_WSI) 

        loss = loss_fn(hazards,S,Y_hat,c)
        loss_value = loss.item()

        


        risk = -torch.sum(S, dim=1).cpu().numpy()
        risk_scores[batch_idx] = risk
        observations[batch_idx] = c.cpu().numpy()
        event_times[batch_idx] = event_time

        val_loss_surv += loss_value
        val_loss += loss_value 

    val_loss_surv /= len(loader)
    val_loss /= len(loader)

    c_index = concordance_index(event_times, risk_scores, event_observed=observations) 
   
    print('Epoch: {}, val_loss_surv: {:.4f}, val_loss: {:.4f}, val_c_index: {:.4f}'.format(epoch, val_loss_surv, val_loss, c_index))

  

In [2]:
'''  Clinical dataset being filtered for the purpose of survival resulting in 942 unique cases. This will be used to get the appropriate patchesin training and testing sets '''
import pandas as pd
import numpy as np
cols=["case_submitter_id","days_to_birth","days_to_death","vital_status","year_of_birth","year_of_death","age_at_diagnosis","days_to_diagnosis","days_to_last_follow_up","year_of_diagnosis","treatment_or_therapy"]
df = pd.read_csv("/content/drive/MyDrive/propro/clinical.tsv",usecols=cols, sep='\t')
temp1=pd.get_dummies(df["vital_status"])
df2 = pd.concat((df,temp1),axis=1)
df2 = df2.drop(['vital_status'],axis=1)
df2 = df2.drop(['Alive'],axis=1)
df2['Dead'] = df2['Dead'].astype(int)
df2 = df2.drop_duplicates()
df2 = df2.replace("'--", np.nan)
df2['event_data'] = df2['days_to_death']
df2.event_data.fillna(df2.days_to_last_follow_up, inplace=True)
df2 = df2[df2['event_data'].notna()]
df2['duration_months'] = np.round(df2['event_data'].astype(int)/31)
df2['duration_months']=df2['duration_months'].astype(int)
df2=df2.sort_values(["duration_months"], ascending=True)
df3 = df2[['case_submitter_id', 'Dead', 'event_data','duration_months']].copy()
df3.reset_index(drop=True, inplace=True)
df3=df3.drop_duplicates()
len(df3)

942

In [None]:
''' Function to generate 5 folds of training and test splits in such a way that the event durations are uniformly distributed '''
def get_folds_bins(frame, n_bins=5, eps=1e-6, num_folds=5):
    def get_folds_event_data(frame, k, event_col):
        df = frame.copy(deep=True)
        df = df.reindex(np.random.permutation(df.index)).sort_values(event_col)
        n, _ = df.shape

        assignments = np.array((n // k + 1) * list(range(1, k + 1)))
        assignments = assignments[:n]

        folds = list()
        for i in range(1, k+1):
            ix = assignments == i
            training_data = df.loc[~ix]
            test_data     = df.loc[ix]
            training_pat  = pd.unique(training_data.case_submitter_id).tolist()
            test_pat  = pd.unique(test_data.case_submitter_id).tolist()
            folds.append((training_pat,test_pat))
        return folds

    frame_working = frame.copy(deep=True)
    uncensored_df = frame_working[frame_working.Dead== 1]

    disc_labels, q_bins = pd.qcut(uncensored_df['duration_months'], q=n_bins, retbins=True, labels=False)
    q_bins[-1] = frame_working['duration_months'].max() + eps
    q_bins[0] = frame_working['duration_months'].min() - eps

    disc_labels, q_bins = pd.cut(frame_working['duration_months'], bins=q_bins, retbins=True, labels=False, right=False, include_lowest=True)
    frame_working.insert(2, 'label', disc_labels.values.astype(int))

    total_folds = dict()
    for i in range(num_folds):
        total_folds[i] = dict()
        total_folds[i]['train'] = list()
        total_folds[i]['valid'] = list()
        total_folds[i]['test'] = list()

    for i in range(len(q_bins)-1):
        bin_censored   = frame_working[(frame_working.label==i)&(frame_working.Dead==0)]
        bin_uncensored = frame_working[(frame_working.label==i)&(frame_working.Dead==1)]
        bin_folds_censored   = get_folds_event_data(frame=bin_censored,   k=num_folds, event_col='duration_months')
        bin_folds_uncensored = get_folds_event_data(frame=bin_uncensored, k=num_folds, event_col='duration_months')

        for i in range(num_folds):
            total_folds[i]['train'].extend([pat for pat in bin_folds_censored[i][0]] + [pat for pat in bin_folds_uncensored[i][0]])
            total_folds[i]['test'].extend([pat for pat in bin_folds_censored[i][1]] + [pat for pat in bin_folds_uncensored[i][1]])  
    return total_folds

In [None]:
import numpy as np
import matplotlib.pyplot as plt



In [None]:
''' Input patches and unique ID stored in variables '''
import h5py
train = "/content/drive/MyDrive/propro/hdf5_TCGAFFPE_LUAD_5x_perP_he_train.h5"
test = "/content/drive/MyDrive/propro/hdf5_TCGAFFPE_LUAD_5x_perP_he_test.h5"
valid = "/content/drive/MyDrive/propro/hdf5_TCGAFFPE_LUAD_5x_perP_he_validation.h5"
hdf5_train = h5py.File(train, "r")
hdf5_test = h5py.File(test, "r")
hdf5_valid = h5py.File(valid, "r")
dtrain = hdf5_train['train_img']
dtest = hdf5_test['test_img']
dvalid = hdf5_valid['valid_img']
train_slides = hdf5_train['train_slides']
test_slides = hdf5_test['test_slides']
valid_slides = hdf5_valid['valid_slides']


In [None]:
data=get_folds_bins(df3)
data[0]

{'test': ['TCGA-NK-A7XE',
  'TCGA-37-3789',
  'TCGA-56-A5DR',
  'TCGA-66-2763',
  'TCGA-99-8032',
  'TCGA-56-8503',
  'TCGA-55-8616',
  'TCGA-66-2737',
  'TCGA-56-8622',
  'TCGA-MN-A4N5',
  'TCGA-37-5819',
  'TCGA-73-4670',
  'TCGA-66-2795',
  'TCGA-94-8035',
  'TCGA-56-8083',
  'TCGA-69-7761',
  'TCGA-51-4080',
  'TCGA-94-7557',
  'TCGA-77-A5GB',
  'TCGA-86-8672',
  'TCGA-L9-A5IP',
  'TCGA-MP-A4TC',
  'TCGA-34-2596',
  'TCGA-66-2773',
  'TCGA-86-A4D0',
  'TCGA-34-8455',
  'TCGA-52-7809',
  'TCGA-33-A4WN',
  'TCGA-85-6798',
  'TCGA-55-6978',
  'TCGA-37-4132',
  'TCGA-68-7757',
  'TCGA-71-6725',
  'TCGA-44-A47B',
  'TCGA-49-6761',
  'TCGA-05-4422',
  'TCGA-67-3774',
  'TCGA-69-7764',
  'TCGA-69-8254',
  'TCGA-44-A4SS',
  'TCGA-66-2769',
  'TCGA-55-6979',
  'TCGA-85-8481',
  'TCGA-21-1081',
  'TCGA-22-4609',
  'TCGA-85-8277',
  'TCGA-MP-A4TD',
  'TCGA-46-3768',
  'TCGA-50-6593',
  'TCGA-MP-A4TJ',
  'TCGA-49-4490',
  'TCGA-85-8584',
  'TCGA-60-2697',
  'TCGA-56-8201',
  'TCGA-55-7725',
  

In [None]:
''' Getting the index positions of train patches based on the data available in the filtered dataset represented in each fold generated '''
ix_train=[]
for x in range(5):
  label=data[x]['train']
  train_ix=[]
  for i in range(179709):
      if i < 146926:
          if train_slides[i].decode("utf-8")[0:12] in label:
              train_ix.append(i)
      elif i < 179709:
          ix=i-146926
          if valid_slides[ix].decode("utf-8")[0:12] in label:
              train_ix.append(i)
  len(train_ix)
  ix_train.append(train_ix)
print(len(ix_train[2]))       

141679


In [None]:
''' Getting the index positions of test patches based on the data available in the filtered dataset represented in each fold generated '''
ix_test=[]
for x in range(5):
  label=data[x]['test']
  test_ix=[]
  for i in range(179709):
      if i < 146926:
          if train_slides[i].decode("utf-8")[0:12] in label:
              test_ix.append(i)
      elif i < 179709:
          ix=i-146926
          if valid_slides[ix].decode("utf-8")[0:12] in label:
              test_ix.append(i)
  len(test_ix)
  ix_test.append(test_ix)
print(len(ix_test[2]))

36609


In [None]:
''' Model Training phase '''
print('\nInit Loaders...')
print('Init Model')
model = Attention()
model.cuda()
print('\nInit optimizer ...', end=' ')
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=2e-4, weight_decay=1e-5)


epoch = 1

for i in range(epoch):
    print("...Train...")
    for x in range(5):
      train_data = InputData(train_image=dtrain,valid_image=dvalid,pos=ix_train[x])   
      test_data = InputData(train_image=dtrain,valid_image=dvalid,pos=ix_test[x])
      train_loader = torch.utils.data.DataLoader(train_data, shuffle = False, num_workers = 0, batch_size = 1)
      test_loader = torch.utils.data.DataLoader(test_data, shuffle = False, num_workers = 0, batch_size = 1)
      train_loop_survival(i,model,train_loader,optimizer)
      validate_survival(1,model,test_loader)




Init Loaders...
Init Model

Init optimizer ... ...Train...


batch 999, loss: 2.8838, risk: -0.9086, bag_size: 1
batch 1999, loss: 2.3927, risk: -0.9086, bag_size: 1
batch 2999, loss: 2.8838, risk: -0.9086, bag_size: 1
batch 3999, loss: 2.8838, risk: -0.9086, bag_size: 1
batch 4999, loss: 2.3927, risk: -0.9086, bag_size: 1
batch 5999, loss: 2.3927, risk: -0.9086, bag_size: 1
batch 6999, loss: 2.8838, risk: -0.9086, bag_size: 1
batch 7999, loss: 2.8838, risk: -0.9086, bag_size: 1
batch 8999, loss: 2.8838, risk: -0.9086, bag_size: 1
batch 9999, loss: 2.8838, risk: -0.9086, bag_size: 1
batch 10999, loss: 2.8838, risk: -0.9086, bag_size: 1
batch 11999, loss: 2.8838, risk: -0.9086, bag_size: 1
batch 12999, loss: 2.8838, risk: -0.9086, bag_size: 1
batch 13999, loss: 2.3927, risk: -0.9086, bag_size: 1
batch 14999, loss: 2.3927, risk: -0.9086, bag_size: 1
batch 15999, loss: 2.8838, risk: -0.9086, bag_size: 1
batch 16999, loss: 2.8838, risk: -0.9086, bag_size: 1
batch 17999, loss: 2.3927, risk

In [None]:
torch.save(model.state_dict(), '/content/drive/MyDrive/propro/model_updc_kfold_epoch8_nov29.pt')

In [None]:
''' Continue training from the loaded model parameters '''
print('Init Model')
model = Attention()
model.cuda()
model.load_state_dict(torch.load('/content/drive/MyDrive/propro/model_updc_kfold_epoch7_nov27.pt'))
print('\nInit optimizer ...', end=' ')
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=2e-4, weight_decay=1e-5)


epoch = 1

for i in range(epoch):
    print("...Train...")
    for x in range(5):
      train_data = InputData(train_image=dtrain,valid_image=dvalid,pos=ix_train[x])   
      test_data = InputData(train_image=dtrain,valid_image=dvalid,pos=ix_test[x])
      train_loader = torch.utils.data.DataLoader(train_data, shuffle = False, num_workers = 0, batch_size = 1)
      test_loader = torch.utils.data.DataLoader(test_data, shuffle = False, num_workers = 0, batch_size = 1)
      train_loop_survival(i,model,train_loader,optimizer)
      validate_survival(1,model,test_loader)

Init Model

Init optimizer ... ...Train...


batch 999, loss: 2.6749, risk: -0.9180, bag_size: 1
batch 1999, loss: 2.3927, risk: -0.9086, bag_size: 1
batch 2999, loss: 1.1784, risk: -0.9375, bag_size: 1
batch 3999, loss: 1.1784, risk: -0.9375, bag_size: 1
batch 4999, loss: 1.3863, risk: -0.9375, bag_size: 1
batch 5999, loss: 1.3863, risk: -0.9375, bag_size: 1
batch 6999, loss: 1.1784, risk: -0.9375, bag_size: 1
batch 7999, loss: 1.7675, risk: -0.9375, bag_size: 1
batch 8999, loss: 1.7675, risk: -0.9375, bag_size: 1
batch 9999, loss: 1.7675, risk: -0.9375, bag_size: 1
batch 10999, loss: 1.7675, risk: -0.9375, bag_size: 1
batch 11999, loss: 1.7675, risk: -0.9375, bag_size: 1
batch 12999, loss: 1.7675, risk: -0.9375, bag_size: 1
batch 13999, loss: 2.0794, risk: -0.9375, bag_size: 1
batch 14999, loss: 2.0794, risk: -0.9375, bag_size: 1
batch 15999, loss: 1.7675, risk: -0.9375, bag_size: 1
batch 16999, loss: 1.7675, risk: -0.9375, bag_size: 1
batch 17999, loss: 2.0794, risk: -0.9375, bag_si

In [25]:
''' Load model parameters '''
model1=Attention()
model1.cuda()
model1.load_state_dict(torch.load('/content/drive/MyDrive/propro/model_updc_5fold_epoch6.pt'))
model1.eval()

Attention(
  (feature_extractor_part1): Sequential(
    (0): Conv2d(3, 20, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (feature_extractor_part2): Sequential(
    (0): Linear(in_features=140450, out_features=1024, bias=True)
    (1): ReLU()
  )
  (attention): Sequential(
    (0): Linear(in_features=1024, out_features=128, bias=True)
    (1): Tanh()
    (2): Dropout(p=0.25, inplace=False)
    (3): Linear(in_features=128, out_features=1, bias=True)
  )
  (classifier): Sequential(
    (0): Linear(in_features=1024, out_features=4, bias=True)
    (1): Sigmoid()
  )
)

In [26]:
''' Testing model on test patches to get final performance of model '''
filename = "/content/drive/MyDrive/propro/hdf5_TCGAFFPE_LUAD_5x_perP_he_train.h5"
test = "/content/drive/MyDrive/propro/hdf5_TCGAFFPE_LUAD_5x_perP_he_test.h5"
hdf5_file = h5py.File(filename, "r")
hdf5_test = h5py.File(test, "r")
dset = hdf5_file['train_img']
dtest = hdf5_test['test_img']
label = hdf5_file['train_labels']
test_data = InputTest(image=dtest)
print('\nInit Loaders...')
test_loader = torch.utils.data.DataLoader(test_data, shuffle = False, num_workers = 0, batch_size = 1)





validate_survival(1,model1,test_loader)



Init Loaders...
Epoch: 1, val_loss_surv: 1.8842, val_loss: 1.8842, val_c_index: 0.5475
