In [7]:
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader


class DummyDataset(Dataset):
	def __init__(self, n_bins=4, nb_of_cases=10):
		np.random.seed(7)
		self.num_intervals = n_bins
		self.path_features =  [torch.randn((i, 1024)) for i in np.random.randint(1, 100, nb_of_cases)]
		self.tabular_data = [torch.randn(5)] * nb_of_cases
		self.patients_df = pd.DataFrame({
			"case_id": np.arange(1, nb_of_cases+1),
			"survival_months": np.random.randint(1, 100, nb_of_cases),
			"event": np.random.randint(0, 2, nb_of_cases)
		})
		survival_time_list = self.patients_df["survival_months"]
		
		_, time_breaks = pd.qcut(survival_time_list, q=self.num_intervals, retbins=True, labels=False)
		time_breaks[0] = 0
		time_breaks[-1] += 1
		self.time_breaks = time_breaks
		print("Time intervals: ", self.time_breaks)
		disc_labels, _ = pd.cut(self.patients_df["survival_months"], bins=self.time_breaks, retbins=True, labels=False, right=False, include_lowest=True)
		self.patients_df.insert(2, 'label', disc_labels.values.astype(int))
		
		self.label_dict = {}
		key_count = 0
		for i in range(len(self.time_breaks)-1):
			for c in [0, 1]:
				self.label_dict.update({(i, c):key_count})
				key_count+=1

		self.patients_df.reset_index(drop=True, inplace=True)
		
		for i in self.patients_df.index:
			key = self.patients_df.loc[i, 'label']
			self.patients_df.at[i, 'disc_label'] = key
			event = self.patients_df.loc[i, 'event']
			key = (key, int(event))
			self.patients_df.at[i, 'label'] = self.label_dict[key]

		self.num_classes=len(self.label_dict)
		self.summarize()

	def summarize(self):
		print("label column: {}".format("survival_months"))
		print("number of classes: {}".format(self.num_classes))
		for i in range(self.num_classes):
			cases = self.patients_df["case_id"][self.patients_df["label"]==i].values
			nb_cases = len(cases)
			print('Number of samples registered in class %d: %d' % (i, nb_cases))
			
	def __len__(self):
		return len(self.path_features)

	def __getitem__(self, idx):
		case_id = self.patients_df['case_id'].iloc[idx]
		
		t = self.patients_df["survival_months"].iloc[idx]
		e = self.patients_df['event'].iloc[idx]
		label = torch.Tensor([self.patients_df['disc_label'][idx]])
		
		return self.path_features[idx], label, t, e, self.tabular_data[idx], case_id

In [8]:
n_bins = 4
dataset = DummyDataset(n_bins=n_bins, nb_of_cases=100)
data_WSI, y_disc, event_time, event, data_tab, case_id = next(iter(dataset))
data_WSI.shape, y_disc, event_time, event, data_tab.shape, case_id

Time intervals:  [  0.    32.5   52.5   77.25 100.  ]
label column: survival_months
number of classes: 8
Number of samples registered in class 0: 10
Number of samples registered in class 1: 15
Number of samples registered in class 2: 14
Number of samples registered in class 3: 11
Number of samples registered in class 4: 14
Number of samples registered in class 5: 11
Number of samples registered in class 6: 13
Number of samples registered in class 7: 12


(torch.Size([48, 1024]), tensor([2.]), 60, 1, torch.Size([5]), 1)

In [28]:
def collate_MIL_survival(batch):
	img = torch.cat([item[0] for item in batch], dim = 0)	
	label = torch.cat([item[1] for item in batch], dim = 0).type(torch.LongTensor)
	event_time = torch.FloatTensor([item[2] for item in batch])
	c = torch.FloatTensor([item[3] for item in batch])
	tabular = torch.cat([item[4] for item in batch], dim = 0).type(torch.FloatTensor)
	case_id = np.array([item[5] for item in batch])
	
	return [img, label, event_time, c, tabular, case_id]

dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=collate_MIL_survival)
data_WSI, y_disc, event_time, event, data_tab, case_id = next(iter(dataloader))
data_WSI.shape, y_disc, event_time, event, data_tab.shape, case_id

(torch.Size([62, 1024]),
 tensor([0]),
 tensor([7.]),
 tensor([1.]),
 torch.Size([5]),
 array([60]))

In [29]:
from torch import nn
from torch import optim
from utils.loss_func import NLLSurvLoss
from models.model_clam import CLAM_SB

# SVM loss!
instance_loss_fn = nn.CrossEntropyLoss()
surv_loss_fn = NLLSurvLoss()
# subtyping?
model = CLAM_SB(n_classes=n_bins, instance_loss_fn=instance_loss_fn, subtyping=True)

optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3, weight_decay=1e-3)

In [31]:
logits, Y_prob, Y_hat, _, instance_dict = model(data_WSI, label=y_disc, instance_eval=True, return_features=True)

In [33]:
instance_dict.keys()

dict_keys(['instance_loss', 'inst_labels', 'inst_preds', 'features'])

In [34]:
hazards = torch.sigmoid(logits)
S = torch.cumprod(1 - hazards, dim=1)
print(hazards.shape, S.shape, event_time.shape, event.shape)
risk = -torch.mean(S, dim=1).detach().cpu().numpy()

torch.Size([1, 4]) torch.Size([1, 4]) torch.Size([1]) torch.Size([1])


In [35]:
pt_dict = {
    'case_id': case_id, 
    'risk': risk, 
    'time': event_time.detach().cpu().numpy(), 
    'event': event.detach().cpu().numpy(),
    "hazards": np.squeeze(hazards.detach().cpu().numpy()),
}



In [24]:
import torch.nn.functional as F
inst_labels = F.one_hot(torch.tensor([0, 1]), num_classes=2).squeeze() #binarize label

In [27]:
inst_labels[0].item()

RuntimeError: a Tensor with 2 elements cannot be converted to Scalar