In [None]:
import torch
import numpy as np
from torch.utils.data import SubsetRandomSampler, WeightedRandomSampler, Sampler, Dataset
import pandas as pd

class ImbalancedDatasetSampler_multilabel(Sampler):
	"""Weighted random sampler for multi-label data that favours rare labels.

	Every index gets a weight equal to the mean of ``1 / count(label)`` over
	the labels active for that row, where ``count(label)`` is the number of
	rows among ``indices`` that carry the label. Indices are then drawn with
	replacement through ``torch.multinomial``.

	Rows without any active label get weight 0.0 and are therefore never drawn.

	Args:
		dataset: 2-D array indexable as ``dataset[idx, 1:]``. Columns from 1
			onward are compared against 1 to find the active labels
			(presumably column 0 holds a sample identifier - confirm with caller).
		indices: Positions of ``dataset`` to sample from. Defaults to all rows.
		num_samples: Number of draws per iteration. Defaults to ``len(indices)``.
	"""

	def __init__(self, dataset, indices=None, num_samples=None):
		self.indices = list(range(len(dataset))) if indices is None else indices
		self.num_samples = len(self.indices) if num_samples is None else num_samples

		# Look the labels up once per index; they are needed both for counting
		# and for weighting.
		labels_per_index = [self._get_label(dataset, idx) for idx in self.indices]

		# Number of sampled rows carrying each label.
		label_to_count = {}
		for labels in labels_per_index:
			for l in labels:
				label_to_count[l] = label_to_count.get(l, 0) + 1

		weights = []
		for labels in labels_per_index:
			if len(labels) == 0:
				# No active label: nothing to average, and no loop variable to
				# divide by. Give the row zero probability mass.
				weights.append(0.0)
				continue
			inverse_frequency_sum = sum(1 / label_to_count[l] for l in labels)
			weights.append(inverse_frequency_sum / len(labels))
		self.weights = torch.DoubleTensor(weights)

	def _get_label(self, dataset, idx):
		"""Return the positions (relative to column 1) whose value equals 1."""
		labels = np.where(dataset[idx,1:]==1)[0]
		return labels

	def __iter__(self):
		"""Yield ``num_samples`` entries of ``indices`` drawn by weight, with replacement."""
		return (self.indices[i] for i in torch.multinomial(
			self.weights, self.num_samples, replacement=True))

	def __len__(self):
		"""Number of indices produced per iteration."""
		return self.num_samples

class Balanced_Multimodal(Sampler):
	"""Sampler blending rare-label weighting with uniform sampling.

	For every index an inverse-frequency weight is computed as the mean of
	``1 / count(label)`` over the labels active for that row (0.0 when the row
	has no active label). The final weight is

		``alpha * inverse_frequency_weight + (1 - alpha) * (1 / len(indices))``

	and indices are drawn with replacement through ``torch.multinomial``.

	NOTE(review): the inverse-frequency weights are not normalised while the
	uniform term sums to 1, so ``alpha`` is not a mixing proportion between two
	probability distributions - confirm this is intended.

	Args:
		dataset: 2-D array indexable as ``dataset[idx, 1:]``. Columns from 1
			onward are compared against 1 to find the active labels
			(presumably column 0 holds a sample identifier - confirm with caller).
		indices: Positions of ``dataset`` to sample from. Defaults to all rows.
		num_samples: Number of draws per iteration. Defaults to ``len(indices)``.
		alpha: Factor applied to the inverse-frequency weights; the uniform
			weights are scaled by ``1 - alpha``.
	"""

	def __init__(self, dataset, indices=None, num_samples=None, alpha = 0.5):
		self.indices = list(range(len(dataset))) if indices is None else indices
		self.num_samples = len(self.indices) if num_samples is None else num_samples

		# Everything below is derived from the `dataset` argument only; the
		# sampler does not read any notebook-level variable.
		labels_per_index = [self._get_label(dataset, idx) for idx in self.indices]

		# Number of sampled rows carrying each label.
		label_to_count = {}
		for labels in labels_per_index:
			for l in labels:
				label_to_count[l] = label_to_count.get(l, 0) + 1

		weights = []
		for labels in labels_per_index:
			if len(labels) == 0:
				# No active label: nothing to average over.
				weights.append(0.0)
				continue
			inverse_frequency_sum = sum(1 / label_to_count[l] for l in labels)
			weights.append(inverse_frequency_sum / len(labels))

		self.weights_original = torch.DoubleTensor(weights)

		# One uniform weight per candidate index (not per draw), so the two
		# weight vectors always have the same length even when `num_samples`
		# differs from the number of indices.
		n_indices = len(self.indices)
		if n_indices:
			self.weights_uniform = np.repeat(1/n_indices, n_indices)
		else:
			self.weights_uniform = np.zeros(0)

		beta = 1 - alpha
		# Mix tensor with tensor so that `self.weights` is a torch tensor, as
		# required by torch.multinomial.
		self.weights = (alpha * self.weights_original) + (beta * torch.from_numpy(self.weights_uniform))

	def _get_label(self, dataset, idx):
		"""Return the positions (relative to column 1) whose value equals 1."""
		labels = np.where(dataset[idx,1:]==1)[0]
		return labels

	def __iter__(self):
		"""Yield ``num_samples`` entries of ``indices`` drawn by weight, with replacement."""
		return (self.indices[i] for i in torch.multinomial(
			self.weights, self.num_samples, replacement=True))

	def __len__(self):
		"""Number of indices produced per iteration."""
		return self.num_samples


class Dataset_instance(Dataset):
	"""Dataset of image files that yields two preprocessed views per item.

	Each item is read from the path stored at ``list_IDs[index]`` and decoded
	with ``pyspng.load``. The first returned tensor is the decoded image passed
	through ``preprocess``; the second is the decoded image passed through
	``pipeline_transform`` and then ``preprocess``.

	NOTE(review): ``pyspng``, ``pipeline_transform`` and ``preprocess`` are not
	defined in this cell and must exist in the notebook namespace. ``mode`` is
	stored but not read by this class.
	"""

	def __init__(self, list_IDs, mode):
		self.mode = mode
		self.list_IDs = list_IDs

	def __len__(self):
		"""Number of image paths in the dataset."""
		return len(self.list_IDs)

	def __getitem__(self, index):
		"""Return ``(k, q)``: the plain view and the transformed view, as FloatTensors."""
		path = self.list_IDs[index]
		with open(path, 'rb') as handle:
			raw_image = pyspng.load(handle.read())

		# The transformed view is derived from the raw image first, then both
		# views go through the same preprocessing (transformed view first).
		transformed_image = pipeline_transform(image=raw_image)['image']

		query_view = preprocess(transformed_image).type(torch.FloatTensor)
		key_view = preprocess(raw_image).type(torch.FloatTensor)
		return key_view, query_view

class Dataset_bag(Dataset):
	"""Dataset over bag identifiers; each item is the identifier itself.

	``labels`` is kept on the instance but is not returned by ``__getitem__``.
	"""

	def __init__(self, list_IDs, labels):
		self.list_IDs, self.labels = list_IDs, labels

	def __len__(self):
		"""Number of identifiers in the dataset."""
		return len(self.list_IDs)

	def __getitem__(self, index):
		"""Return the identifier stored at ``index``."""
		return self.list_IDs[index]

In [None]:
def get_splits(data, n):
	"""Split rows into training and validation sets by their fold value.

	Each row of ``data`` must provide at least seven positional fields. The
	first six (index 0 followed by indices 1-5) are copied into the output
	row; the value at index 6 is the fold. Rows whose fold equals ``n`` go to
	the validation set, all others to the training set.

	Args:
		data: Iterable of rows indexable by position 0..6.
		n: Fold value selecting the validation rows.

	Returns:
		``(train_dataset, valid_dataset)``: two object-dtype arrays with one
		row per sample and six columns. An empty split has shape ``(0, 6)`` so
		that column indexing such as ``split[:, 0]`` keeps working.
	"""
	n_columns = 6  # fields 0..5 of every input row are kept

	train_rows = []
	valid_rows = []

	for sample in data:
		row = [sample[i] for i in range(n_columns)]
		fold = sample[6]
		if fold == n:
			valid_rows.append(row)
		else:
			train_rows.append(row)

	def _as_table(rows):
		# np.array([]) would be 1-D; build the empty case with explicit columns.
		if not rows:
			return np.empty((0, n_columns), dtype=object)
		return np.array(rows, dtype=object)

	train_dataset = _as_table(train_rows)
	valid_dataset = _as_table(valid_rows)

	return train_dataset, valid_dataset

In [None]:
# Bag-level DataLoader configuration.
# NOTE(review): `train_dataset`, `data` and `encoder` are not defined in the
# code visible here. Presumably `train_dataset` is the first value returned by
# get_splits and `data` is `torch.utils.data` - confirm against earlier cells.
#parameters bag
batch_size_bag = 16

# Disabled alternative kept as a bare string literal (never executed):
# ImbalancedDatasetSampler_multilabel as the training sampler.
"""
sampler = ImbalancedDatasetSampler_multilabel
params_train_bag = {'batch_size': batch_size_bag,
		'sampler': sampler(train_dataset)}
		#'shuffle': True}
"""
# Active configuration: `sampler` is bound to the Balanced_Multimodal class,
# but the main training loader below uses plain shuffling (its 'sampler' entry
# is commented out). `sampler` is instantiated later for the queue loader.
#"""
sampler = Balanced_Multimodal
params_train_bag = {'batch_size': batch_size_bag,
		#'sampler': sampler(train_dataset,alpha=0.25)}
		'shuffle': True}
#"""
# Disabled alternative kept as a bare string literal (never executed). Note
# that it would define `params_bag_train`, not `params_train_bag`.
"""
sampler = Balanced_Multimodal
params_bag_train = {'batch_size': batch_size_bag,
		'sampler': sampler(train_dataset,alpha=0.5)}
		#'shuffle': True}
"""


# Shuffled loader parameters; in this cell they are referenced only by the
# commented-out validation loader.
params_bag_test = {'batch_size': batch_size_bag,
		#'sampler': sampler(train_dataset)
	  'shuffle': True}

# Queue loader: twice the bag batch size, indices drawn by a
# Balanced_Multimodal sampler built from `train_dataset` with alpha=0.25.
params_bag_train_queue = {'batch_size': int(batch_size_bag*2),
		'sampler': sampler(train_dataset,alpha=0.25)}
	  #'shuffle': True}

# Shuffled queue parameters; in this cell they are referenced only by the
# commented-out validation queue loader.
params_bag_test_queue = {'batch_size': int(batch_size_bag*2),
		#'sampler': sampler(train_dataset)
	  'shuffle': True}

# Column 0 of `train_dataset` is passed as the identifiers, the remaining
# columns as the labels.
training_set_bag = Dataset_bag(train_dataset[:,0], train_dataset[:,1:])
training_generator_bag = data.DataLoader(training_set_bag, **params_train_bag)

#validation_set_bag = Dataset_bag(valid_dataset[:,0], valid_dataset[:,1:])
#validation_generator_bag = data.DataLoader(validation_set_bag, **params_bag_test)

# NOTE(review): `training_set_bag` is rebuilt here with exactly the same
# arguments as above; the queue loader could reuse the existing instance.
training_set_bag = Dataset_bag(train_dataset[:,0], train_dataset[:,1:])
training_generator_bag_queue = data.DataLoader(training_set_bag, **params_bag_train_queue)

#validation_set_bag = Dataset_bag(valid_dataset[:,0], valid_dataset[:,1:])
#validation_generator_bag_queue = data.DataLoader(validation_set_bag, **params_bag_test_queue)

#params patches generated

# Find total parameters and trainable parameters
# (element counts summed over `encoder.parameters()`; the second sum keeps
# only parameters with `requires_grad` set).
total_params = sum(p.numel() for p in encoder.parameters())
print(f'{total_params:,} total parameters.')

total_trainable_params = sum(
	p.numel() for p in encoder.parameters() if p.requires_grad)
print(f'{total_trainable_params:,} training parameters.')

