In [1]:
import time
import math
import tables
import matplotlib.pyplot as plt
import numpy as np
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter

from torch.utils.data import DataLoader

In [2]:
class Dataset(object):
	'''
	Assumptions made about the dataset:
	
	Pytable contains associated hdf5 earrays'''
	def __init__(self, fname, img_transform):
		#nothing special here, just internalizing the constructor parameters
		self.fname=fname

		self.img_transform=img_transform
			
		with tables.open_file(self.fname,'r') as db:
			self.classsizes=db.root.classsizes[:]
			self.nitems=db.root.imgs.shape[0]
			
		self.imgs = None
		self.labels = None
		self.slide_ids = None
		self.mode = 1
		self.idxs = None

	def get_slide_ids(self):
		with tables.open_file(self.fname,'r') as db:
			return torch.tensor(np.array(db.root.slide_ids))
		
	def get_labels(self):
		with tables.open_file(self.fname,'r') as db:
			return torch.tensor(np.array(db.root.labels))

	def setmode(self, mode):
		self.mode = mode

	def make_train_data(self, idxs):
		"""creates a training set based on the topk highest probability patches
		
			Args:
				idxs (int[]): array of patch ids. 
		
			Returns: 
				None: 
		"""
		self.idxs = idxs

	def shuffle_training_data(self):
		self.idxs = random.sample(self.idxs, len(self.idxs))
			
	def __getitem__(self, index):
		if self.mode==1:	# without topk grouping: the full dataset
			#opening should be done in __init__ but seems to be
			#an issue with multithreading so doing here. need to do it everytime, otherwise hdf5 crashes
			with tables.open_file(self.fname,'r') as db:
				self.imgs=db.root.imgs
				self.labels=db.root.labels
				self.slide_ids=db.root.slide_ids

				#get the requested image
				img = self.imgs[index,::]
				label = self.labels[index]
				slide_id = self.slide_ids[index]
			
			if self.img_transform is not None:
				img_new = self.img_transform(img)

			return img_new, label, img, slide_id

		elif self.mode==2:	# topk sampling: only the top k instances from each slide, compiled.
			with tables.open_file(self.fname,'r') as db:
				self.imgs=db.root.imgs
				self.labels=db.root.labels
				self.slide_ids=db.root.slide_ids

				high_prob_tile_index = self.idxs[index]	# converts input index to the index within interesting patches.

				img = self.imgs[high_prob_tile_index,::]
				label = self.labels[high_prob_tile_index]
				slide_id = self.slide_ids[high_prob_tile_index]

			if self.img_transform is not None:
				img_new = self.img_transform(img)

			return img_new, label, img, slide_id
		else:
			raise ValueError('Mode has not been set to either 1 or 2')

	def __len__(self):
		if self.mode == 1:
			return self.nitems
		elif self.mode == 2:
			return len(self.idxs)
		else:
			return None


In [3]:
#helper function for pretty printing of current time and remaining time
def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)
def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent+.00001)
    rs = es - s
    return '%s (- %s)' % (asMinutes(s), asMinutes(rs))

In [4]:
def prepare_dataset(dataname, batch_size, phases):
	"""Prepares a dataset based on a pre-created pytable set.
	
		Args:
			dataname (str): name of the data
			phase (str): target phase
			batch_size: size of each batch
		
		Returns: 
			dataset: dataset object
			dataloader: used to load data directly into the model
	"""
	
	img_transform = transforms.Compose([
    transforms.ToPILImage(),
    # transforms.RandomVerticalFlip(),
    # transforms.RandomHorizontalFlip(),
    # transforms.RandomCrop(size=(patch_size,patch_size),pad_if_needed=True), #these need to be in a reproducible order, first affine transforms and then color
    transforms.ToTensor()
    ])

	dataset={}
	dataLoader={}
	for phase in phases: #now for each of the phases, we're creating the dataloader
						#interestingly, given the batch size, i've not seen any improvements from using a num_workers>0
		
		dataset[phase]=Dataset(f"./data/{dataname}_{phase}.pytable", img_transform=img_transform)
		dataLoader[phase]=DataLoader(dataset[phase], batch_size=batch_size, 
									shuffle=False, num_workers=0,pin_memory=True, drop_last=False) 
		print(f"{phase} dataset size:\t{len(dataset[phase])}")

	return dataset, dataLoader

In [5]:
def define_model(dataname, dataset, device, load_weights=False):
	"""creates the model object and associated tools
	
		Args:
			dataname (str): name of the data
			dataset (Dataset): dataset object
			device (): processing unit to use
			load_weights (bool): whether or not to load pretrained
								weights from a .pth file
	
		Returns: 
			model: resnet18 model
			criterion: a CrossEntropyLoss object
			optimizer: an Adam optimizer
	"""
	
	model = models.resnet18()
	model.fc = nn.Linear(model.fc.in_features, 2)
	model.to(device)

	class_weight=dataset["train"].classsizes
	class_weight=torch.from_numpy(1-class_weight/class_weight.sum()).type('torch.FloatTensor').to(device)
	criterion = nn.CrossEntropyLoss(weight=class_weight)
	optimizer = torch.optim.Adam(model.parameters())

	if load_weights:
		checkpoint = torch.load(f"{dataname}_unet_best_model.pth")
		model.load_state_dict(checkpoint['model_dict'])

	return model, criterion, optimizer

In [6]:
def define_device(gpuid):
	"""specify if we should use a GPU (cuda) or only the CPU
	
		Args:
			gpuid (int): which gpu is being used 
	
		Returns: 
			device: the gpu (cuda) or the cpu
	"""
	
	if(torch.cuda.is_available()):
		print(torch.cuda.get_device_properties(gpuid))
		torch.cuda.set_device(gpuid)
		device = torch.device(f'cuda:{gpuid}')
	else:
		device = torch.device(f'cpu')

	return device

In [7]:
def infer(run, dataLoader, dataset, model, phase, device, batchsize):
	
	model.eval()
	probs = torch.FloatTensor(len(dataset[phase]))
	for ii, (X, label, img_orig, slide_id) in enumerate(dataLoader[phase]):
		X = X.to(device)  # [Nbatch, 3, H, W]
		label = label.type('torch.LongTensor').to(device)
		output = F.softmax(model(X), dim=1)
		probs[ii*batchsize : ii*batchsize+X.shape[0]] = output.detach()[:,1].clone()
	
	return probs.cpu().numpy()

In [8]:
def train(run, dataLoader, model, criterion, optimizer, device):
	model.train()
	running_loss = 0.
	for ii, (X, label, img_orig, slide_id) in enumerate(dataLoader['train']): # where ii is the batch number
		# print(f"batch={ii} epoch={run}")
		# print(f"X.shape={X.shape}")
		X = X.to(device)
		label = label.type('torch.LongTensor').to(device)

		output = model(X)
		loss = criterion(output, label)
		optimizer.zero_grad()
		loss.backward()
		optimizer.step()
	
	return loss



In [9]:
def calc_err(pred,real):
	"""error analysis during validation
	
		Args:
			pred (int[]): the array of 0/1 predictions of the model
			real (int[]): the array of ground-truth labels 
	
		Returns: 
			err (float): the err generated from pred and real
			fpr (float): false positive rate
			fnr (float): false negative rate
	"""

	pred = np.array(pred)
	real = np.array(real)
	
	neq = np.not_equal(pred, real)
	err = float(neq.sum())/pred.shape[0]
	fpr = float(np.logical_and(pred==1,neq).sum())/(real==0).sum()
	fnr = float(np.logical_and(pred==0,neq).sum())/(real==1).sum()
	return err, fpr, fnr

In [10]:
def group_argtopk(groups, data,k=1):
	"""Outputs the index of the instance with the highest probability for each spot. 
	
		Args:
			groups (int[]): list of slide ids for each instance, ex. [1,1,5,4,7,4,7,7, ...]
			data (float[]): list of probabilities for each instance, ex. [.7,.8,.1,.4,.5, ...]
		Returns: 
			int[]: k indices for each slide, which include the highest inferred probabilities for that slide.
				   ex. index 0 refers to the highest prob for the 0th slide,
					   index 1 refers to the highest prob for the 1st slide,
					   ...
	"""
	
	order = np.lexsort((data, groups))
	groups = groups[order]
	data = data[order]
	index = np.empty(len(groups), 'bool')
	index[-k:] = True
	index[:-k] = groups[k:] != groups[:-k]
	return list(order[index])


In [11]:
def group_max(groups, data, nmax):
	"""description

		Args:
			paramname (paramtype): describe_the_param 

		Returns: 
			type: description
	"""
	out = np.empty(nmax)
	out[:] = np.nan
	order = np.lexsort((data, groups))
	groups = groups[order]
	data = data[order]
	index = np.empty(len(groups), 'bool')
	index[-1] = True
	index[:-1] = groups[1:] != groups[:-1]
	out[groups[index]] = data[index]
	return out

In [12]:
def save_to_tensorboard(writer, phase, all_loss, all_acc, epoch, cmatrix):
	all_acc[phase]=(cmatrix[phase]/cmatrix[phase].sum()).trace()
	all_loss[phase] = all_loss[phase].cpu().numpy().mean()

	writer.add_scalar(f'{phase}/loss', all_loss[phase], epoch)

	if phase == 'val':
		writer.add_scalar(f'{phase}/acc', all_acc[phase], epoch)
		for r in range(2):
			for c in range(2): #essentially write out confusion matrix
				writer.add_scalar(f'{phase}/{r}{c}', cmatrix[phase][r][c],epoch)
	

In [13]:
def generate_single_output(dataname, phase, imgid, model):
	'''
	WORK IN PROGRESS'''
	db = tables.open_file(f"./{dataname}_{phase}.pytable")
	img = db.root.imgs[imgid, ::]
	label = torch.tensor(np.array(db.root.labels[imgid]))
	fig, ax = plt.subplots(1, 2)
	ax = ax.flatten()

	output = model(img)

	ax[0].set_title()
	ax[0]

In [14]:
# --- general parameters
dataname = 'MIL_32x32_100pos'
gpuid=0
k=50

# --- resnet params
n_classes = 2
in_channels = 3

# --- evaluation params
batch_size=64
patch_size=32 #based on resnet architecture. Changed from 224
num_epochs=10
phases = ["train","val"] #how many phases did we create databases for?

# --- initializing model and dataset
writer = SummaryWriter()
device = define_device(gpuid)
dataset, dataLoader = prepare_dataset(dataname, batch_size, phases)
model, criterion, optimizer = define_model(dataname, dataset, device)
best_loss_on_test = np.Infinity

for epoch in range(num_epochs):
	all_acc = {key: 0 for key in phases} 
	all_loss = {key: torch.zeros(0).to(device) for key in phases} #keep this on GPU for greatly improved performance
	cmatrix = {key: np.zeros((n_classes,n_classes)) for key in phases}
	
	for phase in phases:
		'''
		If phase is train, model will:
		1. infer probabilities for each instance in the training set
		2. group the highest probability instances for each WSI
		3. train the model based on these instances
		4. log the loss.
		'''
		labels = dataset[phase].get_labels()
		slide_ids = dataset[phase].get_slide_ids()

		if phase == 'train':
			dataset[phase].setmode(1)
			probs = infer(epoch, dataLoader, dataset, model, phase, device, batch_size)
			topk = group_argtopk(slide_ids, probs, k)
			dataset[phase].make_train_data(topk)								# this will need to be tested to see if it works
			dataset[phase].shuffle_training_data()

			dataset[phase].setmode(2)
			loss = train(epoch, dataLoader, model, criterion, optimizer, device)
			all_loss[phase]=torch.cat((all_loss[phase],loss.detach().view(1,-1)))
			print('Training\tEpoch: [{}/{}]\tLoss: {}'.format(epoch+1, num_epochs, loss))

		'''
		If phase is val, model will:
		1. infer probabilities
		2. identify the maximum probabilities'''
		if phase == 'val':
			# do error analysis
			dataset[phase].setmode(1)
			probs = infer(epoch, dataLoader, dataset, model, phase, device, batch_size)
			maxs = group_max(slide_ids, probs, )
			print(*maxs)
			# print(maxs)	produced an array with four numbered values followed by many NaN values
			pred = [1 if x >= 0.5 else 0 for x in maxs]
			print(*pred)
			err,fpr,fnr = calc_err(pred, labels)	# pred is of type int[] and labels is a tensor-list
			#cmatrix[phase]=cmatrix[phase]+confusion_matrix(yflat,cpredflat, labels=range(nclasses))
			print(f'error: {err}, false positive rate: {fpr}, false negative rate: {fnr}')

		save_to_tensorboard(writer, phase, all_loss, all_acc, epoch, cmatrix)
	
	# save state of model if it's the best one so far.
	if all_loss["val"] < best_loss_on_test:
		best_loss_on_test = all_loss["val"]
		print("  **")
		state = {'epoch': epoch + 1,
		'model_dict': model.state_dict(),
		'optim_dict': optimizer.state_dict(),
		'best_loss_on_test': all_loss,
		'n_classes': n_classes}
		torch.save(state, f"{dataname}_resnet_best_model.pth")
	else:
		print("")

train dataset size:	10000
val dataset size:	3000


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


Training	Epoch: [1/10]	Loss: 0.14530184864997864


  all_acc[phase]=(cmatrix[phase]/cmatrix[phase].sum()).trace()


0.0003320824180264026 0.0002983476151712239 0.00029037389322184026 0.00028079934418201447 0.0003890176594723016 0.00038670049980282784 0.0003481248568277806 0.0003535803116392344 0.000364798674127087 0.00040211627492681146 0.0003573053400032222 0.00039709749398753047 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan 

  all_loss[phase] = all_loss[phase].cpu().numpy().mean()
  ret = ret.dtype.type(ret / rcount)


KeyboardInterrupt: 

In [None]:
#for ii, (X, label, img_orig, slide_id) in enumerate(dataLoader['train']):
#	print(ii)