In [None]:
from accuracies import *

def generate_accuracies(select_fn, subset_sizes=None):
    """Measure MNIST test accuracy as a function of the labelling budget.

    For every budget in ``subset_sizes`` a fresh model is created, ``select_fn``
    picks that many training points, the model is trained on the selection and
    then evaluated on the MNIST test split.

    Args:
        select_fn: Callable invoked as ``select_fn(device, train_set, size)``;
            its result is passed straight to ``train`` as the training data.
        subset_sizes: Iterable of labelling budgets to evaluate. Defaults to
            ``256, 512, ..., 5120`` (20 budgets), which is the schedule
            ``plot_accuracies`` uses for its x-axis.

    Returns:
        ``torch.tensor`` built from the list of values returned by ``test``,
        one per budget, in the order of ``subset_sizes``.
    """
    if subset_sizes is None:
        subset_sizes = [256 * i for i in range(1, 21)]

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Convert images to tensors and normalise with mean 0.5 / std 0.5.
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,))
    ])
    train_set = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    test_set = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    # A random 10% of the training split is held out and the selection pool is
    # the remaining 90%. The held-out part is not used in this function, so no
    # Subset is built for it; the permutation is still drawn so the pool (and
    # the global RNG stream) stay exactly as before.
    val_size = int(0.1 * len(train_set))
    indices = torch.randperm(len(train_set))
    train_set = torch.utils.data.Subset(train_set, indices[val_size:])

    accuracies = []

    for size in tqdm(subset_sizes):
        # Start from a freshly created model for every budget so results for
        # different budgets do not share trained weights.
        model = create_model().to(device)

        subset = select_fn(device, train_set, size)

        train(model, subset, device)
        accuracies.append(test(model, test_set, device))

    return torch.tensor(accuracies)

def plot_accuracies(accuracies_from_models):
    """Plot accuracy against the number of labelled points for each strategy.

    Args:
        accuracies_from_models: Sequence of 2-D tensors, one per strategy, in
            the order uniform random, cluster-margin, committee (soft),
            committee (hard). Dimension 0 indexes runs; dimension 1 is plotted
            against the 20 budgets ``256, 512, ..., 5120``. Entries beyond the
            fourth are ignored.

    A strategy with a single run is drawn as a plain line. A strategy with
    several runs is drawn as the mean over runs with error bars of one
    standard deviation (as computed by ``Tensor.std``). The figure is shown
    with ``plt.show()``; nothing is returned.
    """
    plt.ylabel("Accuracy")
    plt.xlabel("Number of labelled points")

    labels = ["Uniform random", "Cluster-Margin", "Committee (Soft)", "Committee (Hard)"]

    subset_sizes = [256 * i for i in range(1, 21)]

    for label, accuracies in zip(labels, accuracies_from_models):
        num_runs = accuracies.size(0)

        if num_runs == 1:
            # Only one run: there is no spread to show, so draw the curve
            # itself. (No std is available here; passing one to plt.plot
            # would be an error or an extra spurious line.)
            plt.plot(subset_sizes, accuracies[0, :], label=label)
        elif num_runs > 1:
            mean = accuracies.mean(dim=0)
            std = accuracies.std(dim=0)

            (_, caps, _) = plt.errorbar(subset_sizes, mean, std, capsize=3, elinewidth=0.5, label=label)

            # Thin the cap markers so they match the thin error-bar lines.
            for cap in caps:
                cap.set_markeredgewidth(0.5)

    plt.legend()
    plt.grid(alpha=0.3)
    plt.show()

# Seed PyTorch's global RNG once, before the first run, so the whole
# experiment is repeatable.
torch.manual_seed(1234)
num_runs = 3

# Per-strategy collectors: each gathers one accuracy curve per run.
accuracies_uniform_random, accuracies_cluster_margin = [], []
accuracies_committee_soft, accuracies_committee_hard = [], []

for i in range(num_runs):
    print(f"run {i}")
    # Strategies are evaluated in this fixed order within every run; the order
    # matters because they all consume the same seeded global RNG stream.
    for _curves, _select_fn in (
        (accuracies_uniform_random, select_uniform_random),
        (accuracies_cluster_margin, select_cluster_margin),
        (accuracies_committee_soft, select_committee_soft),
        (accuracies_committee_hard, select_committee_hard),
    ):
        _curves.append(generate_accuracies(_select_fn))

# Collapse each strategy's list of per-run curves into one tensor with the
# runs stacked along dimension 0, in the order plot_accuracies expects.
accuracies_from_models = [
    torch.stack(_runs, dim=0)
    for _runs in (
        accuracies_uniform_random,
        accuracies_cluster_margin,
        accuracies_committee_soft,
        accuracies_committee_hard,
    )
]

# Rebind the per-strategy names to the stacked tensors so they remain
# available under the same names after this cell has run.
(
    accuracies_uniform_random,
    accuracies_cluster_margin,
    accuracies_committee_soft,
    accuracies_committee_hard,
) = accuracies_from_models

plot_accuracies(accuracies_from_models)