In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import medmnist
from medmnist import INFO
import numpy as np
from tqdm import tqdm

In [26]:
class DataSubset(Dataset):
  """View of ``base_dataset`` restricted to the indices in ``inds``.

  Args:
    base_dataset: any object supporting ``len()`` and integer indexing.
    inds: sequence of indices into ``base_dataset``. When None, indices are
      chosen from ``size`` (see below).
    size: used only when ``inds`` is None. A non-negative value draws that
      many distinct indices uniformly at random with ``np.random.choice``
      (global NumPy RNG). A negative value (the default) selects every index
      of ``base_dataset`` in order.
  """

  def __init__(self, base_dataset, inds=None, size=-1):
    self.base_dataset = base_dataset
    if inds is None:
      n = len(base_dataset)
      if size < 0:
        # np.random.choice raises on negative sizes, so the old default
        # (size=-1) crashed; treat it as "use the whole dataset".
        inds = np.arange(n)
      else:
        # Passing the population size directly gives the same draws as
        # passing list(range(n)), without materialising the list.
        inds = np.random.choice(n, size, replace=False)
    self.inds = inds

  def __getitem__(self, ind):
    # The base index of the most recently fetched item is kept on the instance
    # (``self.base_ind``) as before, in case other code reads it.
    self.base_ind = self.inds[ind]
    return self.base_dataset[self.base_ind]

  def __len__(self):
    return len(self.inds)

def _load_medmnist(dataset, train, transform):
  """Load the train (``train=True``) or val split of MedMNIST ``dataset``.

  Returns ``(dataset_object, num_classes)`` where ``num_classes`` is
  ``len(INFO[dataset]['label'])``. ``download=True`` fetches the data file if
  it is not already cached.
  """
  info = INFO[dataset]
  data_class = getattr(medmnist, info['python_class'])
  split = 'train' if train else 'val'
  return data_class(split=split, transform=transform, download=True), len(info['label'])


def _class_labels(ds):
  """Return the class label of every sample of ``ds`` as a numpy array."""
  labels = getattr(ds, 'labels', None)
  if labels is None:
    # Generic fallback: index every sample (this also runs its transform).
    return np.array([np.squeeze(ds[i][1]) for i in range(len(ds))])
  # NOTE(review): MedMNIST datasets appear to keep their raw labels in
  # ``.labels`` (presumably shape (N, 1)). Reading them directly avoids pushing
  # every image through the random train augmentation just to get its label.
  # Confirm against the installed medmnist version.
  labels = np.asarray(labels)
  if labels.ndim == 2 and labels.shape[1] == 1:
    labels = labels[:, 0]
  return labels


def _apply_query(train_labeled_inds, train_unlabeled_inds, inds_to_query):
  """Move ``inds_to_query`` from the unlabeled index set to the labeled one.

  Returns ``(labeled, unlabeled)`` as 1-D integer arrays. Queried indices are
  appended to the end of the labeled set, and the unlabeled set keeps its order.
  """
  if train_labeled_inds is None or train_unlabeled_inds is None or inds_to_query is None:
    raise ValueError('active_learning_iter=True requires train_labeled_inds, '
                     'train_unlabeled_inds and inds_to_query')
  # dtype=int keeps indices integral even when a set is an empty list
  # (np.append([], ...) would otherwise produce float indices).
  labeled = np.asarray(train_labeled_inds, dtype=int).ravel()
  unlabeled = np.asarray(train_unlabeled_inds, dtype=int).ravel()
  queried = np.asarray(inds_to_query, dtype=int).ravel()
  return np.concatenate([labeled, queried]), unlabeled[~np.isin(unlabeled, queried)]


def get_data(train_labeled_inds=None, train_unlabeled_inds=None, inds_to_query=None,
             active_learning_iter=False, labels_per_class=0, dataset='bloodmnist'):
  """Build or update the labeled/unlabeled index split of a MedMNIST train set.

  Initial call (``active_learning_iter=False``): load the train split of
  ``dataset`` and, for each class in turn, put the first ``labels_per_class``
  indices of that class in the labeled set and the rest in the unlabeled set.
  The incoming index arguments are ignored in this mode.

  Active-learning iteration (``active_learning_iter=True``): no data is
  loaded. ``inds_to_query`` is appended to ``train_labeled_inds`` and removed
  from ``train_unlabeled_inds``.

  Args:
    train_labeled_inds: current labeled indices (required for an AL iteration).
    train_unlabeled_inds: current unlabeled indices (required for an AL iteration).
    inds_to_query: indices to move to the labeled set (required for an AL iteration).
    active_learning_iter: False for the initial split, True to apply a query.
    labels_per_class: labeled examples taken per class in the initial split.
    dataset: MedMNIST dataset key in ``INFO``; defaults to ``'bloodmnist'``.

  Returns:
    ``(train_labeled_inds, train_unlabeled_inds)`` as numpy index arrays.

  Raises:
    ValueError: an AL iteration is missing an index argument, or
      ``labels_per_class`` is negative.
  """
  if active_learning_iter:
    return _apply_query(train_labeled_inds, train_unlabeled_inds, inds_to_query)

  if labels_per_class < 0:
    raise ValueError(f'labels_per_class must be >= 0, got {labels_per_class}')

  # One 0.5 mean/std entry per image channel, so Normalize also fits
  # single-channel MedMNIST datasets (bloodmnist is RGB).
  norm = (.5,) * INFO[dataset].get('n_channels', 3)
  transform_train = transforms.Compose([
    transforms.Pad(4, padding_mode="reflect"),
    transforms.RandomCrop(28),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # Previously missing: train inputs stayed in [0, 1] while the val pipeline
    # mapped them to [-1, 1].
    transforms.Normalize(norm, norm),
    lambda x: x + 3e-2 * torch.randn_like(x),
  ])

  train_class, num_classes = _load_medmnist(dataset, train=True, transform=transform_train)

  train_inds = np.arange(len(train_class))
  print(f'train_inds[:10] = {train_inds[:10]}')

  train_labels = _class_labels(train_class)

  if num_classes > 0:
    labeled_parts, unlabeled_parts = [], []
    for c in range(num_classes):
      class_inds = train_inds[train_labels == c]
      labeled_parts.append(class_inds[:labels_per_class])
      unlabeled_parts.append(class_inds[labels_per_class:])
    train_labeled_inds = np.concatenate(labeled_parts)
    train_unlabeled_inds = np.concatenate(unlabeled_parts)
  else:
    # No class information: everything is labeled, nothing left to query.
    train_labeled_inds = train_inds
    train_unlabeled_inds = np.array([], dtype=train_inds.dtype)

  print(f'train_labeled_inds[:10] = {train_labeled_inds[:10]}')
  print(f'train_unlabeled_inds[:10] = {train_unlabeled_inds[:10]}')
  return train_labeled_inds, train_unlabeled_inds

# Initial labeled/unlabeled split; the trailing ';' keeps IPython from echoing the return value.
get_data();

Using downloaded and verified file: /Users/carlosgil/.medmnist/bloodmnist.npz
train_inds[:10] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
train_labeled_inds[:10] = [26, 38, 42, 10, 11, 19, 6, 7, 9, 1]
train_unlabeled_inds[:10] = [47, 79, 81, 86, 92, 112, 159, 165, 176, 196]
