# In[ ]:
import torch
import numpy as np
from torch.utils.data import SubsetRandomSampler, WeightedRandomSampler, Sampler, Dataset

class ImbalancedDatasetSampler_multilabel(torch.utils.data.sampler.Sampler):
	"""Draw row indices with probability weighted towards rare labels.

	``dataset`` is indexed as ``dataset[row, 1:]``: column 0 is skipped and
	every remaining column is compared against 1 to find a row's labels
	(see ``_get_label``). Assumed to be a 2-D numpy array — TODO confirm.

	Each row's weight is the mean of ``1 / count[label]`` over the labels it
	carries. ``count`` is the number of rows (among ``indices``) carrying that
	label. Rows with no positive label get weight 0 and are never drawn.

	Args:
		dataset: 2-D array-like, one row per sample.
		indices: row positions to sample from; defaults to every row.
		num_samples: number of indices yielded per iteration; defaults to
			``len(indices)``.
	"""

	def __init__(self, dataset, indices=None, num_samples=None):
		self.indices = list(range(len(dataset))) if indices is None else indices
		self.num_samples = len(self.indices) if num_samples is None else num_samples

		# Read every row's labels once; both passes below need them.
		row_labels = [self._get_label(dataset, idx) for idx in self.indices]

		# Number of rows (restricted to self.indices) carrying each label.
		label_to_count = {}
		for labels in row_labels:
			for label in labels:
				label_to_count[label] = label_to_count.get(label, 0) + 1

		weights = []
		for labels in row_labels:
			if len(labels) == 0:
				# Unlabelled row: weight 0. Previously this read an unbound
				# loop variable (NameError on the first row).
				weights.append(0.0)
				continue
			inverse_freq = sum(1 / label_to_count[label] for label in labels)
			weights.append(inverse_freq / len(labels))
		self.weights = torch.DoubleTensor(weights)

	def _get_label(self, dataset, idx):
		"""Return the positions (0-based, counted from column 1) where row ``idx`` equals 1."""
		return np.where(dataset[idx, 1:] == 1)[0]

	def __iter__(self):
		# Draw with replacement; torch.multinomial normalizes the weights itself.
		drawn = torch.multinomial(self.weights, self.num_samples, replacement=True)
		return (self.indices[i] for i in drawn.tolist())

	def __len__(self):
		return self.num_samples

class Balanced_Multimodal(torch.utils.data.sampler.Sampler):
	"""Sample row indices from a mix of label-balanced and uniform weights.

	``dataset`` is indexed as ``dataset[row, 1:]``: column 0 is skipped and
	every remaining column is compared against 1 to find a row's labels
	(see ``_get_label``). Assumed to be a 2-D numpy array — TODO confirm.

	Balanced weight of a row: the mean of ``1 / count[label]`` over its labels.
	``count`` is computed over ``indices``. Rows with no positive label get 0.
	These weights are normalized to sum to 1 and then mixed with a uniform
	distribution::

		weights = alpha * balanced + (1 - alpha) * uniform

	``alpha=1`` therefore gives fully label-balanced sampling and ``alpha=0``
	gives uniform sampling.

	Args:
		dataset: 2-D array-like, one row per sample.
		indices: row positions to sample from; defaults to every row.
		num_samples: number of indices yielded per iteration; defaults to
			``len(indices)``.
		alpha: mixing coefficient for the balanced distribution.

	Attributes:
		weights_original: un-normalized balanced weights, one per index.
		weights_uniform: numpy array of ``1 / len(indices)``, one per index.
		weights: the mixed weights passed to ``torch.multinomial``.
	"""

	def __init__(self, dataset, indices=None, num_samples=None, alpha = 0.5):
		self.indices = list(range(len(dataset))) if indices is None else indices
		self.num_samples = len(self.indices) if num_samples is None else num_samples

		# Read every row's labels once; both passes below need them.
		row_labels = [self._get_label(dataset, idx) for idx in self.indices]

		# Number of rows (restricted to self.indices) carrying each label.
		label_to_count = {}
		for labels in row_labels:
			for label in labels:
				label_to_count[label] = label_to_count.get(label, 0) + 1

		weights = []
		for labels in row_labels:
			if len(labels) == 0:
				# Unlabelled row: weight 0. Previously this read an unbound
				# loop variable (NameError on the first row).
				weights.append(0.0)
				continue
			inverse_freq = sum(1 / label_to_count[label] for label in labels)
			weights.append(inverse_freq / len(labels))
		self.weights_original = torch.DoubleTensor(weights)

		# One uniform weight per candidate index. It must match
		# weights_original in length, so size it by len(indices), not
		# num_samples.
		n_rows = len(self.indices)
		self.weights_uniform = np.repeat(1 / n_rows, n_rows)

		# Normalize the balanced weights so that alpha mixes two probability
		# distributions of equal mass.
		total = self.weights_original.sum()
		balanced = self.weights_original / total if total > 0 else self.weights_original

		beta = 1 - alpha
		self.weights = (alpha * balanced) + (beta * torch.from_numpy(self.weights_uniform))

	def _get_label(self, dataset, idx):
		"""Return the positions (0-based, counted from column 1) where row ``idx`` equals 1."""
		return np.where(dataset[idx, 1:] == 1)[0]

	def __iter__(self):
		# Draw with replacement; torch.multinomial normalizes the weights itself.
		drawn = torch.multinomial(self.weights, self.num_samples, replacement=True)
		return (self.indices[i] for i in drawn.tolist())

	def __len__(self):
		return self.num_samples


class Dataset_instance(Dataset):
	"""Dataset of image paths yielding a ``(k, q)`` pair of float tensors.

	Each item is read from disk in binary mode and decoded with
	``pyspng.load``. ``k`` is the decoded image passed through ``preprocess``.
	``q`` is the same image first passed through ``pipeline_transform``, then
	through ``preprocess``. Both are cast to ``torch.FloatTensor``.
	``pyspng``, ``pipeline_transform`` and ``preprocess`` are not defined in
	this view. Presumably they come from another notebook cell — TODO confirm.

	Args:
		list_IDs: indexable sequence of file paths.
		mode: stored on the instance; not read anywhere in this class.
	"""

	def __init__(self, list_IDs, mode):
		self.list_IDs = list_IDs
		self.mode = mode

	def __len__(self):
		return len(self.list_IDs)

	def __getitem__(self, index):
		path = self.list_IDs[index]
		with open(path, 'rb') as handle:
			image = pyspng.load(handle.read())

		augmented = pipeline_transform(image=image)['image']

		# q is preprocessed before k, keeping the original call order in case
		# preprocess consumes random state.
		q = preprocess(augmented).type(torch.FloatTensor)
		k = preprocess(image).type(torch.FloatTensor)
		return k, q

class Dataset_bag(Dataset):
	"""Dataset over a list of identifiers (presumably whole-slide images).

	Indexing returns only the stored identifier. ``labels`` is kept on the
	instance but is not returned by ``__getitem__``.

	Args:
		list_IDs: indexable sequence of identifiers.
		labels: stored as-is on the instance.
	"""

	def __init__(self, list_IDs, labels):
		self.list_IDs = list_IDs
		self.labels = labels

	def __len__(self):
		return len(self.list_IDs)

	def __getitem__(self, index):
		return self.list_IDs[index]