In [1]:
# experiment

import math
import time
import torch
from torch import nn as nn
from torch.nn import functional as F

from batchbald_redux import (
    active_learning,
    repeated_mnist,
)
#torch.manual_seed(0)
from torchvision.datasets import MNIST,FashionMNIST,CIFAR10
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
import wandb
import random

from tqdm.auto import tqdm

In [2]:
from scipy.spatial import distance_matrix
import numpy as np
def greedy_k_center(labeled, unlabeled, amount):

        greedy_indices = []

        # get the minimum distances between the labeled and unlabeled examples (iteratively, to avoid memory issues):
        min_dist = np.min(distance_matrix(labeled[0, :].reshape((1, labeled.shape[1])), unlabeled), axis=0)
        min_dist = min_dist.reshape((1, min_dist.shape[0]))
        for j in range(1, labeled.shape[0], 100):
            if j + 100 < labeled.shape[0]:
                dist = distance_matrix(labeled[j:j+100, :], unlabeled)
            else:
                dist = distance_matrix(labeled[j:, :], unlabeled)
            min_dist = np.vstack((min_dist, np.min(dist, axis=0).reshape((1, min_dist.shape[1]))))
            min_dist = np.min(min_dist, axis=0)
            min_dist = min_dist.reshape((1, min_dist.shape[0]))

        # iteratively insert the farthest index and recalculate the minimum distances:
        farthest = np.argmax(min_dist)
        greedy_indices.append(farthest)
        for i in range(amount-1):
            dist = distance_matrix(unlabeled[greedy_indices[-1], :].reshape((1,unlabeled.shape[1])), unlabeled)
            min_dist = np.vstack((min_dist, dist.reshape((1, min_dist.shape[1]))))
            min_dist = np.min(min_dist, axis=0)
            min_dist = min_dist.reshape((1, min_dist.shape[0]))
            farthest = np.argmax(min_dist)
            greedy_indices.append(farthest)

        return np.array(greedy_indices)



In [3]:
class CNN(nn.Module):
    def __init__(self,):
        super(CNN, self).__init__()
        self.convs = nn.Sequential(
                                nn.Conv2d(1,32,4),
                                nn.ReLU(),
                                nn.Conv2d(32,32,4),
                                nn.ReLU(),
                                nn.MaxPool2d(2),
                                nn.Dropout(0.25),
                                nn.Flatten(),
                                nn.Linear(11*11*32,128)
                                
        )
        self.fcs = nn.Sequential(nn.ReLU(),
                                nn.Dropout(0.5),
                                nn.Linear(128,10)
        )

    def forward(self, x):
        out = x
        out = self.convs(out)
        out = self.fcs(out)
        return out

In [4]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

train_dataset = MNIST("data", train=True, transform=transform)
test_dataset = MNIST("data", train=False,transform=transform)

subset_size=1000
train_dataset,monitor_set = random_split(
    train_dataset, [59900,100]
)

num_initial_samples = 20
num_classes = 10

initial_samples = active_learning.get_balanced_sample_indices(
    repeated_mnist.get_targets(train_dataset), num_classes=num_classes, n_per_digit=num_initial_samples / num_classes
)

# experiment
max_training_samples = 310
acquisition_batch_size = 5

test_batch_size = 512
batch_size = 64
scoring_batch_size = 128
training_iterations = 1 
use_cuda = torch.cuda.is_available()

print(f"use_cuda: {use_cuda}")

device = "cuda" if use_cuda else "cpu"

kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}

test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False, **kwargs)

active_learning_data = active_learning.ActiveLearningData(train_dataset)
active_learning_data.acquire(initial_samples)
active_learning_data.extract_dataset_from_pool(59900-subset_size-num_initial_samples)
train_loader = torch.utils.data.DataLoader(
    active_learning_data.training_dataset,
    sampler=active_learning.RandomFixedLengthSampler(active_learning_data.training_dataset, 300),
    batch_size=batch_size,
    **kwargs,
    )

pool_loader = torch.utils.data.DataLoader(
    active_learning_data.pool_dataset, batch_size=scoring_batch_size, shuffle=False, **kwargs
)

use_cuda: True


In [None]:
pbar = tqdm(initial=len(active_learning_data.training_dataset), total=max_training_samples, desc="Training Set Size")
epoch=0
percentage_correct_list=[]
from sklearn.cluster import KMeans
import numpy as np
wandb.init(project="CORESET")
wandb.run.name = f'Core_{subset_size}'
wandb.run.save()
model = CNN().to(device=device)
optimizer = torch.optim.Adam(model.parameters())
added_indices = []
while epoch !=60:

    epoch+=1
    model.train()
    train_loss=0
    # Train
    for data, target in tqdm(train_loader, desc="Training", leave=False):
        data = data.to(device=device)
        target = target.to(device=device)

        optimizer.zero_grad()

        prediction = model(data)
        loss = nn.CrossEntropyLoss()(prediction, target)
        train_loss+=loss.item()
        loss.backward()
        optimizer.step()
    train_loss /= len(train_loader)
    # Test
    loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in tqdm(test_loader, desc="Testing", leave=False):
            data = data.to(device=device)
            target = target.to(device=device)
        
            prediction = model(data)
            loss += nn.CrossEntropyLoss()(prediction, target)

            prediction = prediction.max(1)[1]
            correct += prediction.eq(target.view_as(prediction)).sum().item()

    loss /= len(test_loader.dataset)

    percentage_correct = 100.0 * correct / len(test_loader.dataset)

    if len(active_learning_data.training_dataset) > max_training_samples or len(active_learning_data.pool_dataset)<100:
        break
    

    acquir_start=time.time()
    
#     N = len(active_learning_data.pool_dataset)
#     logits_N_K_C = torch.empty((N, 128), dtype=torch.double, pin_memory=use_cuda)

#     N_t = len(active_learning_data.training_dataset)
#     logits_N_K_C_t = torch.empty((N_t, 128), dtype=torch.double, pin_memory=use_cuda)
#     with torch.no_grad():
#         model.eval()
#         for i, (data, _) in enumerate(tqdm(train_loader, desc="Evaluating Acquisition Set", leave=False)):
#             data = data.to(device=device)

#             lower = i * train_loader.batch_size
#             upper = min(lower + train_loader.batch_size, N_t)
#             logits_N_K_C_t[lower:upper].copy_(model.convs(data)[lower:upper], non_blocking=True)

#         for i, (data, _) in enumerate(tqdm(pool_loader, desc="Evaluating Acquisition Set", leave=False)):
#             data = data.to(device=device)

#             lower = i * pool_loader.batch_size
#             upper = min(lower + pool_loader.batch_size, N)
#             logits_N_K_C[lower:upper].copy_(model.convs(data), non_blocking=True)
#     indices=greedy_k_center(logits_N_K_C_t,logits_N_K_C,5)
    N = len(active_learning_data.pool_dataset)
    indices=random.sample(range(N-1), 5)
    N = len(active_learning_data.pool_dataset)
    logits_N_K_C = torch.empty((N, 128), dtype=torch.double, pin_memory=use_cuda)

    with torch.no_grad():
        model.eval()
        for i, (data, _) in enumerate(tqdm(pool_loader, desc="Evaluating Acquisition Set", leave=False)):
            data = data.to(device=device)

            lower = i * pool_loader.batch_size
            upper = min(lower + pool_loader.batch_size, N)
            logits_N_K_C[lower:upper].copy_(model.convs(data), non_blocking=True)
    
    x=KMeans(n_clusters=5, random_state=0).fit_transform(logits_N_K_C.numpy())
    indices=np.argmin(x,axis=0)
    
    

    acquire_end=acquir_start-time.time()
    wandb.log({"Epoch":epoch,
               "Num_Pool": len(active_learning_data.pool_dataset),
               "Num_Train":len(active_learning_data.training_dataset),
               "Test Loss": loss,
               'Train Loss': train_loss,
              "Test Accuracy":percentage_correct,
               "Query Time":-acquire_end})


    targets = repeated_mnist.get_targets(active_learning_data.pool_dataset)
    dataset_indices = active_learning_data.get_dataset_indices(indices)
#
    print("Labels: ", targets[indices])
# #     active_learning_data.acquire_remove(candidate_batch.indices,remove_indices)
    active_learning_data.acquire(indices)
    added_indices.append(targets[indices])
    pbar.update(len(dataset_indices))
print(added_indices)

Training Set Size:   6%|6         | 20/310 [00:00<?, ?it/s]

[34m[1mwandb[0m: Currently logged in as: [33mhslrock[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.7 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade




Training:   0%|          | 0/5 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:00<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/8 [00:01<?, ?it/s]

Labels:  tensor([3, 2, 8, 4, 6])


Training:   0%|          | 0/5 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:00<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/8 [00:01<?, ?it/s]

Labels:  tensor([9, 6, 0, 3, 2])


Training:   0%|          | 0/5 [00:06<?, ?it/s]

Testing:   0%|          | 0/20 [00:00<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/8 [00:01<?, ?it/s]

Labels:  tensor([8, 4, 1, 6, 0])


Training:   0%|          | 0/5 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:05<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/8 [00:01<?, ?it/s]

Labels:  tensor([9, 0, 1, 8, 2])


Training:   0%|          | 0/5 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:00<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/8 [00:01<?, ?it/s]

Labels:  tensor([3, 9, 0, 6, 1])


Training:   0%|          | 0/5 [01:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:00<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/8 [00:06<?, ?it/s]

Labels:  tensor([3, 1, 8, 0, 2])


Training:   0%|          | 0/5 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:00<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/8 [00:06<?, ?it/s]

Labels:  tensor([2, 9, 0, 3, 1])


Training:   0%|          | 0/5 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:00<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/8 [00:01<?, ?it/s]

Labels:  tensor([1, 3, 2, 4, 0])


Training:   0%|          | 0/5 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:00<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/8 [00:01<?, ?it/s]

Labels:  tensor([9, 3, 6, 0, 1])


Training:   0%|          | 0/5 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:00<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/8 [00:06<?, ?it/s]

Labels:  tensor([0, 1, 9, 3, 6])


Training:   0%|          | 0/5 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:00<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/8 [00:01<?, ?it/s]

Labels:  tensor([9, 1, 6, 3, 0])


Training:   0%|          | 0/5 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:00<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/8 [00:01<?, ?it/s]

Labels:  tensor([6, 1, 8, 0, 8])


Training:   0%|          | 0/5 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:30<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/8 [00:01<?, ?it/s]

Labels:  tensor([6, 3, 9, 0, 8])


Training:   0%|          | 0/5 [00:06<?, ?it/s]

Testing:   0%|          | 0/20 [00:00<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/8 [00:01<?, ?it/s]

Labels:  tensor([1, 6, 7, 0, 3])


Training:   0%|          | 0/5 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:05<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/8 [00:01<?, ?it/s]

Labels:  tensor([1, 0, 9, 6, 3])


Training:   0%|          | 0/5 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:05<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/8 [00:01<?, ?it/s]

Labels:  tensor([1, 3, 6, 7, 0])


Training:   0%|          | 0/5 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:05<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/8 [00:01<?, ?it/s]

Labels:  tensor([0, 6, 2, 3, 7])


Training:   0%|          | 0/5 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:00<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/8 [00:56<?, ?it/s]

Labels:  tensor([9, 3, 0, 4, 1])


Training:   0%|          | 0/5 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:05<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/8 [00:01<?, ?it/s]

Labels:  tensor([1, 0, 4, 9, 2])


Training:   0%|          | 0/5 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:05<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/8 [00:01<?, ?it/s]

Labels:  tensor([7, 1, 2, 4, 0])


Training:   0%|          | 0/5 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:05<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/8 [00:01<?, ?it/s]

Labels:  tensor([7, 2, 6, 0, 8])


Training:   0%|          | 0/5 [00:06<?, ?it/s]

Testing:   0%|          | 0/20 [00:00<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/7 [00:01<?, ?it/s]

Labels:  tensor([3, 2, 7, 0, 1])


Training:   0%|          | 0/5 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:05<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/7 [00:01<?, ?it/s]

Labels:  tensor([2, 7, 1, 3, 4])


Training:   0%|          | 0/5 [00:06<?, ?it/s]

Testing:   0%|          | 0/20 [00:00<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/7 [00:01<?, ?it/s]

Labels:  tensor([7, 3, 8, 0, 1])


Training:   0%|          | 0/5 [01:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:00<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/7 [00:06<?, ?it/s]

Labels:  tensor([3, 9, 0, 6, 2])


Training:   0%|          | 0/5 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:00<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/7 [00:06<?, ?it/s]

Labels:  tensor([3, 7, 9, 2, 0])


Training:   0%|          | 0/5 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:05<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/7 [00:01<?, ?it/s]

Labels:  tensor([4, 8, 7, 0, 2])


Training:   0%|          | 0/5 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:00<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/7 [00:06<?, ?it/s]

Labels:  tensor([7, 4, 5, 1, 2])


Training:   0%|          | 0/5 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:05<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/7 [00:01<?, ?it/s]

Labels:  tensor([0, 4, 3, 2, 7])


Training:   0%|          | 0/5 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:05<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/7 [00:01<?, ?it/s]

Labels:  tensor([5, 4, 1, 7, 0])


Training:   0%|          | 0/5 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:00<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/7 [00:01<?, ?it/s]

Labels:  tensor([0, 1, 7, 9, 3])


Training:   0%|          | 0/5 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:05<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/7 [00:01<?, ?it/s]

Labels:  tensor([7, 3, 1, 3, 0])


Training:   0%|          | 0/5 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:00<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/7 [00:01<?, ?it/s]

Labels:  tensor([4, 8, 7, 0, 3])


Training:   0%|          | 0/5 [00:06<?, ?it/s]

Testing:   0%|          | 0/20 [00:00<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/7 [00:01<?, ?it/s]

Labels:  tensor([4, 3, 2, 7, 2])


Training:   0%|          | 0/5 [00:06<?, ?it/s]

Testing:   0%|          | 0/20 [00:00<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/7 [00:01<?, ?it/s]

Labels:  tensor([4, 6, 3, 7, 2])


Training:   0%|          | 0/5 [00:06<?, ?it/s]

Testing:   0%|          | 0/20 [00:00<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/7 [00:01<?, ?it/s]

Labels:  tensor([7, 9, 6, 5, 2])


Training:   0%|          | 0/5 [00:06<?, ?it/s]

Testing:   0%|          | 0/20 [00:00<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/7 [00:01<?, ?it/s]

Labels:  tensor([6, 5, 7, 2, 4])


Training:   0%|          | 0/5 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:00<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/7 [00:06<?, ?it/s]

Labels:  tensor([2, 3, 7, 4, 6])


Training:   0%|          | 0/5 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:00<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/7 [00:06<?, ?it/s]

Labels:  tensor([3, 9, 7, 0, 0])


Training:   0%|          | 0/5 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:05<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/7 [00:01<?, ?it/s]

Labels:  tensor([6, 9, 3, 7, 0])


Training:   0%|          | 0/5 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:00<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/7 [00:06<?, ?it/s]

Labels:  tensor([8, 9, 1, 0, 3])


Training:   0%|          | 0/5 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:05<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/7 [00:01<?, ?it/s]

Labels:  tensor([2, 0, 9, 2, 8])


Training:   0%|          | 0/5 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:05<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/7 [00:03<?, ?it/s]

Training:   0%|          | 0/5 [00:01<?, ?it/s]

Labels:  tensor([4, 8, 3, 0, 2])


Testing:   0%|          | 0/20 [00:00<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/7 [00:01<?, ?it/s]

Labels:  tensor([7, 4, 3, 6, 8])


Training:   0%|          | 0/5 [00:16<?, ?it/s]

Testing:   0%|          | 0/20 [00:00<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/7 [00:01<?, ?it/s]

Labels:  tensor([8, 3, 6, 7, 4])


Training:   0%|          | 0/5 [00:06<?, ?it/s]

Testing:   0%|          | 0/20 [00:00<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/7 [00:06<?, ?it/s]

Labels:  tensor([7, 4, 2, 8, 6])


Training:   0%|          | 0/5 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:00<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/7 [00:01<?, ?it/s]

Labels:  tensor([4, 3, 6, 7, 5])


Training:   0%|          | 0/5 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:00<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/6 [00:06<?, ?it/s]

Labels:  tensor([6, 4, 2, 3, 7])


Training:   0%|          | 0/5 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:00<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/6 [00:08<?, ?it/s]

Labels:  tensor([6, 0, 9, 2, 8])


Training:   0%|          | 0/5 [00:03<?, ?it/s]

In [53]:
greedy_k_center(logits_N_K_C_t,logits_N_K_C,5)

array([295, 392, 962, 626, 423], dtype=int64)