In [28]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import medmnist
from medmnist import INFO
import numpy as np
from tqdm import tqdm

In [29]:
# GLOBAL VARIABLES

# Dataset
DATASET = 'bloodmnist'
NUM_CLASSES = 8
NUM_CHANNELS = 3

#Environment
CUDA_SEED = 0

# Training
BATCH_SIZE = 128
LABELS_PER_CLASS = 10
LR=.0001
NUM_EPOCHS = 3

# Active Learning
NUM_ACTIVE_LEARNING_ITERATIONS = 3
QUERY_SIZE = NUM_CLASSES * LABELS_PER_CLASS

In [35]:
# DATASET HANDLING

class DataSubset(Dataset):
  def __init__(self, base_dataset, inds=None, size=-1):
    self.base_dataset = base_dataset
    if inds is None:
        inds = np.random.choice(
            list(range(len(base_dataset))), size, replace=False)
    self.inds = inds

  def __getitem__(self, ind):
    self.base_ind = self.inds[ind]
    return self.base_dataset[self.base_ind]

  def __len__(self):
    return len(self.inds)

def get_data(train_inds=None, train_labeled_inds=None, train_unlabeled_inds=None, inds_to_query=None, active_learning_iter=False):
  def cycle(loader):
    while True:
      for data in loader:
        yield data
            
  def MedMNIST(train, transforms):
    info = INFO[DATASET]
    DataClass = getattr(medmnist, info['python_class'])
    return DataClass(split='train' if train else 'val', transform=transforms, download=True)
  
  transform_train = transforms.Compose([
    transforms.Pad(4, padding_mode="reflect"),
    transforms.RandomCrop(28),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    lambda x: x + 3e-2 * torch.randn_like(x)
  ])

  transform_val = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((.5, .5, .5), (.5, .5, .5)),
    lambda x: x + 3e-2 * torch.randn_like(x)
  ])

  if not active_learning_iter: # if start iteration
    train_class = MedMNIST(train=True, transforms=transform_train) # (img_vector, label)

    train_inds = list(range(len(train_class)))
    
    # np.random.seed(0)
    # np.random.shuffle(train_inds)

    train_inds = np.array(train_inds)
    train_labels = np.array([np.squeeze(train_class[ind][1]) for ind in train_inds])

    if NUM_CLASSES > 0:
      train_labeled_inds, train_unlabeled_inds = [], []
      for i in range(NUM_CLASSES):
        train_labeled_inds.extend(train_inds[train_labels == i][:LABELS_PER_CLASS])
        train_unlabeled_inds.extend(train_inds[train_labels == i][LABELS_PER_CLASS:])
    else:
      train_labeled_inds = train_inds

    
  else: # Every active learning iteration
    train_labeled_inds = np.append(train_labeled_inds, inds_to_query)
    relative_inds_in_unlabeled_set = np.argwhere(np.isin(train_unlabeled_inds, inds_to_query))
    train_unlabeled_inds = np.delete(train_unlabeled_inds, relative_inds_in_unlabeled_set)

  dset_train = DataSubset(MedMNIST(train=True, transforms=transform_train), inds=train_inds)
  dset_train_labeled = DataSubset(MedMNIST(train=True, transforms=transform_train), inds=train_labeled_inds)
  dset_train_unlabeled = DataSubset(MedMNIST(train=True, transforms=transform_train), inds=train_unlabeled_inds)
  dset_validation = MedMNIST(train=False, transforms=transform_val)

  BATCH_SIZE_LABELED = BATCH_SIZE_UNLABELED = BATCH_SIZE if LABELS_PER_CLASS * NUM_CLASSES > BATCH_SIZE else LABELS_PER_CLASS * NUM_CLASSES

  dload_train = DataLoader(dset_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, drop_last=True)
  dload_train_labeled = DataLoader(dset_train_labeled, batch_size=BATCH_SIZE_LABELED, shuffle=True, num_workers=2, drop_last=True)
  # dload_train_labeled = cycle(dload_train_labeled)
  dload_train_unlabeled = DataLoader(dset_train_unlabeled, batch_size=BATCH_SIZE_UNLABELED, shuffle=True, num_workers=2, drop_last=True)
  dload_val = DataLoader(dset_validation, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, drop_last=False)

  return dload_train, dload_train_labeled, dload_train_unlabeled, dload_val, train_inds, train_labeled_inds, train_unlabeled_inds

In [36]:
# Feature Visualization via t-SNE

import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

def visualize_dataset(dload_train, sampled_inds=[], show_entire_dataset=False):
  input_tensors = []
  target_labels = []

  for inputs, targets in tqdm(dload_train):
    input_tensors.append(inputs)
    target_labels.append(targets)

  inputs = torch.cat(input_tensors, dim=0)
  target_labels = torch.cat(target_labels, dim=0).squeeze().numpy()

  num_inputs, _, _, _ = inputs.shape

  tsne = TSNE(n_components=2, random_state=0)
  embeddings = tsne.fit_transform(inputs.view(num_inputs, -1))

  _, ax = plt.subplots(figsize=(10, 10))

  if show_entire_dataset:
    for i in range(NUM_CLASSES):
      embeddings_cluster = embeddings[target_labels == i]
      ax.scatter(embeddings_cluster[:, 0], embeddings_cluster[:, 1],label=f"Class {i}", s=3)
  else:
    ax.scatter(embeddings[:, 0], embeddings[:, 1], color='gray', s=3)
    embeddings_samples = embeddings[sampled_inds]
    ax.scatter(embeddings_samples[:, 0], embeddings_samples[:, 1],label="Sampled images", color='red', s=30)
  ax.legend()
  plt.show()

In [37]:
dload_train, dload_train_labeled, dload_train_unlabeled, dload_val, train_inds, train_labeled_inds, train_unlabeled_inds = get_data()

visualize_dataset(dload_train, sampled_inds=[], show_entire_dataset=True)

Using downloaded and verified file: /Users/carlosgil/.medmnist/bloodmnist.npz
Using downloaded and verified file: /Users/carlosgil/.medmnist/bloodmnist.npz
Using downloaded and verified file: /Users/carlosgil/.medmnist/bloodmnist.npz
Using downloaded and verified file: /Users/carlosgil/.medmnist/bloodmnist.npz
Using downloaded and verified file: /Users/carlosgil/.medmnist/bloodmnist.npz


  0%|          | 0/93 [00:00<?, ?it/s]


AttributeError: Can't pickle local object 'get_data.<locals>.<lambda>'

In [None]:
# DNN

class Net(nn.Module):
  def __init__(self, in_channels, num_classes):
    super(Net, self).__init__()

    self.layer1 = nn.Sequential(
      nn.Conv2d(in_channels, 16, kernel_size=3),
      nn.BatchNorm2d(16),
      nn.ReLU()
    )

    self.layer2 = nn.Sequential(
      nn.Conv2d(16, 16, kernel_size=3),
      nn.BatchNorm2d(16),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2, stride=2)
    )

    self.layer3 = nn.Sequential(
      nn.Conv2d(16, 64, kernel_size=3),
      nn.BatchNorm2d(64),
      nn.ReLU()
    )

    self.layer4 = nn.Sequential(
      nn.Conv2d(64, 64, kernel_size=3),
      nn.BatchNorm2d(64),
      nn.ReLU()
    )

    self.layer5 = nn.Sequential(
      nn.Conv2d(64, 64, kernel_size=3, padding=1),
      nn.BatchNorm2d(64),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2, stride=2)
    )

    self.fc = nn.Sequential(
      nn.Linear(64 * 4 * 4, 128),
      nn.ReLU(),
      nn.Linear(128, 128),
      nn.ReLU(),
      nn.Linear(128, num_classes)
    )

  def forward(self, x):
    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)
    x = self.layer5(x)
    x = x.view(x.size(0), -1)
    x = self.fc(x)
    return x

model = Net(in_channels=NUM_CHANNELS, num_classes=NUM_CLASSES)

criterion = nn.CrossEntropyLoss()
    
optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0.9)

In [None]:
# MODEL
dload_train, dload_train_labeled, dload_train_unlabeled, dload_val, train_inds, train_labeled_inds, train_unlabeled_inds = get_data()

for al_iter in range(NUM_ACTIVE_LEARNING_ITERATIONS):
  print(f'Active Learning iteration #{al_iter}')
  for epoch in range(NUM_EPOCHS):
    print(f'    Epoch #{epoch}')
    train_correct = 0
    train_total = 0
    test_correct = 0
    test_total = 0
    
    model.train()
    
    for inputs, targets in tqdm(dload_train_labeled):
        optimizer.zero_grad()
        outputs = model(inputs)
        
        targets = targets.squeeze().long()
        loss = criterion(outputs, targets)
        
        loss.backward()
        optimizer.step()

  inds_to_query = np.random.randint(0, len(train_unlabeled_inds), QUERY_SIZE)
  dload_train, dload_train_labeled, dload_train_unlabeled, dload_val, train_inds, train_labeled_inds, train_unlabeled_inds = get_data(train_inds=train_inds,
                                                                                                                                      train_labeled_inds=train_labeled_inds,
                                                                                                                                      train_unlabeled_inds=train_unlabeled_inds,
                                                                                                                                      inds_to_query=inds_to_query,
                                                                                                                                      active_learning_iter=True)

