In [None]:
%matplotlib inline

# Run active learning with pre-existing weights

In [None]:
# coding=utf-8
from active_learn import get_mean_std
from active.random import RandomSampling
from active.mc_dropout import MCDropoutSampling
from active.ensemble import EnsembleSampling
from active.coreset import CoreSetSampling
from active.policy_learner import PolicyLearner
from train_helper import train_validate_model, reinit_seed
from models.CNN import CNN
import logging
import numpy as np
from models.model_helpers import weights_init
import properties as prop
from results.results_reader import read_results, set_results
import copy
import json
from tqdm import trange
from torch.utils.data import ConcatDataset
import torch
from data.data_helpers import make_tensordataset, stratified_split_dataset, concat_datasets


## Overwrite properties 

Set the dataset to evaluate on and the policy weights to load.

In [None]:
# Dataset to evaluate on. The run and plotting cells below recognise
# "mnist", "fmnist" and "kmnist".
prop.DATASET = "mnist" 

# Policy weights to load. Exactly one POLICY_FILEPATH assignment should be
# active; the alternatives are kept commented out for quick switching.
# NOTE(review): presumably consumed by PolicyLearner when it is constructed -
# confirm against active/policy_learner.py.

# fixed expert
prop.POLICY_FILEPATH = "./weights/fixed_experts/policy_99.pth"

# exponential expert
#prop.POLICY_FILEPATH = "./weights/exp_experts/policy_99.pth"

# exponential expert with an additional random expert
#prop.POLICY_FILEPATH = "./weights/exp_random/policy_52.pth"

# The weights below require edits in the state calculation before they can
# be used (per the original author's note; the required edits are not shown here):

# exponential expert with an additional pool embedding
#prop.POLICY_FILEPATH = "./weights/exp_pool/policy_99.pth"

# exponential expert with an additional pool embedding + Coreset expert
#prop.POLICY_FILEPATH = "./weights/exp_pool_coreset/policy_99.pth"

In [None]:
def active_learn(exp_num, StrategyClass, subsample):
    """Run one active-learning experiment and return the test accuracy per round.

    In every acquisition round the model is reset to the same initial weights,
    trained on the current labelled set with ``train_validate_model``, and the
    returned test accuracy is recorded. The strategy is then asked which pool
    samples to move from the pool into the training set.

    This function relies on two notebook-level names that are defined by the
    run cell and must exist before it is called: ``device`` and
    ``get_data_splits``.

    Args:
        exp_num: Experiment index; ``exp_num * 10`` seeds the data split.
        StrategyClass: Strategy class. It is instantiated as
            ``StrategyClass(dataset_pool, valid_dataset, test_dataset, device)``
            and must provide ``query(acq_size, model, train_dataset, pool)``
            returning ``(selected_indices, remaining_indices)`` and a ``name``
            attribute.
        subsample: If truthy, the strategy only sees a random subset of at most
            ``prop.K`` pool samples per round; otherwise it sees the whole
            remaining pool.

    Returns:
        A list with one entry per acquisition round (``prop.NUM_ACQS`` entries),
        each being the value returned by ``train_validate_model``.
    """
    # All strategies start from the same initial model weights ...
    reinit_seed(prop.RANDOM_SEED)
    test_acc_list = []
    model = CNN().apply(weights_init).to(device)
    init_weights = copy.deepcopy(model.state_dict())

    # ... while the data split is seeded per experiment.
    reinit_seed(exp_num * 10)
    dataset_pool, valid_dataset, test_dataset = get_data_splits()
    # NOTE(review): 20 is presumably the initial labelled-set size and 10 the
    # number of classes - confirm against stratified_split_dataset.
    train_dataset, pool_dataset = stratified_split_dataset(dataset_pool, 20, 10)

    strategy = StrategyClass(dataset_pool, valid_dataset, test_dataset, device)

    t = trange(1, prop.NUM_ACQS + 1, desc="Aquisitions (size {})".format(prop.ACQ_SIZE), leave=True)
    for acq_num in t:
        # Retrain from the same initial weights in every round.
        model.load_state_dict(init_weights)

        test_acc = train_validate_model(model, device, train_dataset, valid_dataset, test_dataset)
        test_acc_list.append(test_acc)

        if subsample:
            pool_size = len(pool_dataset)
            # Sampling without replacement fails when more samples are requested
            # than the pool holds, so clamp the window to what is left.
            subset_size = min(prop.K, pool_size)
            subset_ind = np.random.choice(a=pool_size, size=subset_size, replace=False)
            pool_subset = make_tensordataset(pool_dataset, subset_ind)
            # Only the selected positions are needed; the remainder is recomputed
            # below relative to the full pool.
            sel_ind, _ = strategy.query(prop.ACQ_SIZE, model, train_dataset, pool_subset)
            q_idxs = subset_ind[sel_ind]  # map subset positions back to full-pool indices
            # setdiff1d returns the unselected indices in sorted order, so the
            # pool order does not depend on set iteration order.
            remaining_ind = list(np.setdiff1d(np.arange(pool_size), q_idxs))
            sel_dataset = make_tensordataset(pool_dataset, q_idxs)
            train_dataset = concat_datasets(train_dataset, sel_dataset)
            pool_dataset = make_tensordataset(pool_dataset, remaining_ind)
        else:
            # all strategies work on k-sized windows in semi-batch setting
            sel_ind, remaining_ind = strategy.query(prop.ACQ_SIZE, model, train_dataset, pool_dataset)
            sel_dataset = make_tensordataset(pool_dataset, sel_ind)
            pool_dataset = make_tensordataset(pool_dataset, remaining_ind)
            train_dataset = concat_datasets(train_dataset, sel_dataset)

        logging.info("Accuracy for {} sampling and {} acquisition is {}".format(strategy.name, acq_num, test_acc))
    return test_acc_list

In [None]:
# Results of this run go to a dataset-specific temporary file; the plotting
# cell further down reads "./results/experiments_<dataset>_TMP.json" back.
prop.RESULTS_FILE = "./results/experiments_{}_TMP.json".format(prop.DATASET)

logging.basicConfig(level=logging.INFO)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Pick the data loader for the configured dataset. Fail early on an unknown
# name instead of hitting a NameError for get_data_splits inside active_learn.
dataset_name = prop.DATASET.lower()
if dataset_name == "mnist":
    from data.mnist import get_data_splits
    print("Training on MNIST")
elif dataset_name == "fmnist":
    from data.fmnist import get_data_splits
    print("Training on Fashion-MNIST")
elif dataset_name == "kmnist":
    from data.kmnist import get_data_splits
    print("Training on KMNIST")
else:
    raise ValueError(
        "Unknown prop.DATASET {!r}; expected 'mnist', 'fmnist' or 'kmnist'".format(prop.DATASET)
    )

# Disable cuDNN. The documented switch is torch.backends.cudnn.enabled;
# assigning torch.cuda.cudnn_enabled only creates an unused attribute.
torch.backends.cudnn.enabled = False
reinit_seed(prop.RANDOM_SEED)
logging.info("later dumping to {}".format(prop.RESULTS_FILE))

# Evaluate the learned policy. Each experiment yields one accuracy curve;
# mean and standard deviation across experiments are stored under the
# strategy's name.
strategies = [PolicyLearner]
results = read_results()
for strategy in strategies:
    test_acc = [active_learn(exp_num, strategy, subsample=True) for exp_num in range(prop.NUM_EXPS)]
    mean, std = get_mean_std(test_acc)
    results[strategy.name] = [mean.tolist(), std.tolist()]

# Baseline strategies (disabled). Uncomment to re-run them on the full pool:
# strategies = [RandomSampling, EnsembleSampling, MCDropoutSampling, CoreSetSampling]
# for strategy in strategies:
#     test_acc = [active_learn(exp_num, strategy, subsample=False) for exp_num in range(prop.NUM_EXPS)]
#     mean, std = get_mean_std(test_acc)
#     results[strategy.name] = [mean.tolist(), std.tolist()]

logging.info("dumping results to {}".format(prop.RESULTS_FILE))
set_results(results, results_file=prop.RESULTS_FILE)

# Plot results

In [None]:
import matplotlib.pyplot as plt
from results.results_reader import read_results
import numpy as np
import seaborn as sns
import pandas as pd
from matplotlib.pyplot import figure

# Size of the initial labelled set and number of samples added per acquisition;
# together they define the x-axis ("labeling effort") of the plots.
INIT_SIZE = 20
ACQ_SIZE = 10

current_palette = sns.color_palette()
sns.set_style('whitegrid')

# Shared line styling for all curves.
linewidth = 3
markersize = 5
with_stddev = True  # draw a +/- one standard deviation band around each curve


def plot(keys, results, strategies, acq_size=ACQ_SIZE, ylim=(0, 1), lim=1000, loc='lower right',
         init_size=INIT_SIZE, series_labels=None):
    """Plot mean test accuracy against labeling effort for the selected strategies.

    The figure is created inside this function. With the inline backend a
    figure created in an earlier cell has already been shown and closed by the
    time this function runs, so creating it here is what makes the configured
    size apply to the actual plot.

    Args:
        keys: Indices into ``strategies`` (and the labels) selecting which
            curves to draw.
        results: Mapping from strategy name to ``[means, stds]``, two
            equal-length sequences with one value per acquisition round.
        strategies: Strategy names used as keys into ``results``.
        acq_size: Number of samples added per acquisition round.
        ylim: y-axis limits.
        lim: Upper x-axis limit.
        loc: Legend location.
        init_size: Labeling effort of the first data point, i.e. the size of
            the initial labelled set.
        series_labels: Legend labels, indexed like ``strategies``. When omitted,
            the notebook-level ``labels`` list is used.
    """
    if series_labels is None:
        # Fallback for existing callers, which define a global `labels` list
        # before calling plot().
        series_labels = labels

    _, ax = plt.subplots(figsize=(12, 8), dpi=80, facecolor='w', edgecolor='k')

    for i in keys:
        means = np.array(results[strategies[i]][0])
        std = np.array(results[strategies[i]][1])
        effort = init_size + acq_size * np.arange(len(means))
        label = series_labels[i]

        if label.endswith('Random'):
            # The random baseline is drawn as a thick dashed black reference line.
            line_style = dict(linestyle='--', linewidth=5, color='black')
        else:
            line_style = dict(linewidth=linewidth)

        (line,) = ax.plot(effort, means, marker='.', label=label, markersize=markersize, **line_style)
        if with_stddev:
            # Reuse the line's colour so the band always matches its curve.
            ax.fill_between(effort, means + std, means - std, alpha=0.1, color=line.get_color())

    ax.legend(loc=loc, prop={'size': 10})
    ax.set_ylabel("Test data accuracy score")
    ax.set_xlabel("Labeling effort")
    ax.set_ylim(ylim)
    ax.set_xlim(0, lim)
    plt.show()

In [None]:
# y-axis range per dataset; everything else is identical across datasets.
dataset_ylims = {
    "mnist": (0.6, 1),
    "fmnist": (0.6, 1),
    "kmnist": (0.5, .85),
}

# Compare case-insensitively, like the run cell does when picking the loader.
dataset_name = prop.DATASET.lower()
if dataset_name not in dataset_ylims:
    raise ValueError(
        "No plot configuration for prop.DATASET {!r}; expected one of {}".format(
            prop.DATASET, sorted(dataset_ylims)
        )
    )

# Keys into the results files and the legend labels, in matching order.
strategies = ['random', 'mc-dropout', 'ensemble', 'coreset_all', 'policy-learner_state_True']
labels = ['Random', 'MC Dropout', 'Ensemble', 'Coreset', 'Our method']
keys = [0, 1, 2, 3, 4]

# Stored baseline and coreset results, overlaid with the results of the
# current run (written by the run cell to the *_TMP.json file).
results = read_results(results_file="./results/experiments_{}_1000_10_baselines.json".format(dataset_name))
results.update(read_results(results_file="./results/experiments_{}_1000_10_coresets.json".format(dataset_name)))
results.update(read_results(results_file="./results/experiments_{}_TMP.json".format(prop.DATASET)))
plot(keys, results, strategies, acq_size=10, ylim=dataset_ylims[dataset_name], loc='lower right')
