In [1]:
from avalanche.benchmarks.classic import RotatedMNIST
from torch.utils.data import DataLoader
from exp_framework.Ensemble import Ensemble, PretrainedEnsemble
from exp_framework.delegation import (
    DelegationMechanism,
    UCBDelegationMechanism,
    ProbaSlopeDelegationMechanism,
    RestrictedMaxGurusUCBDelegationMechanism,
)
from exp_framework.experiment import (
    Experiment,
    calculate_avg_std_test_accs,
    calculate_avg_std_train_accs,
    calculate_avg_std_test_accs_per_trial,
)
from matplotlib import pyplot as plt
from exp_framework.data_utils import Data
import numpy as np
import matplotlib as mpl
import seaborn as sns
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
def run_experiment(batch_size, window_size, num_trials, n_voters, seed=653):
    """Measure the test-accuracy gap between two ensembles on rotated MNIST.

    Two ensembles with identical sizes and training settings are run inside a
    single ``Experiment``: "full_ensemble", which uses the base
    ``DelegationMechanism``, and "proba_slope_delegations", which uses
    ``ProbaSlopeDelegationMechanism``.

    Args:
        batch_size: Forwarded to both delegation mechanisms.
        window_size: Forwarded to both delegation mechanisms.
        num_trials: Number of trials the experiment runs; also forwarded to
            ``calculate_avg_std_test_accs_per_trial``.
        n_voters: Number of voters in each ensemble.
        seed: Seed handed to ``Experiment``. Defaults to 653, the value that
            used to be hard-coded, so existing callers are unaffected.

    Returns:
        ``np.mean`` of the per-trial average test accuracies of
        "proba_slope_delegations" minus the same quantity for "full_ensemble".
        A positive value means the delegating ensemble scored higher.
    """
    data = Data(data_set_name="rotated_mnist")

    # Base mechanism; the original variable name suggests it performs no
    # delegation -- NOTE(review): confirm against exp_framework.delegation.
    NOOP_del_mech = DelegationMechanism(batch_size=batch_size, window_size=window_size)

    proba_slope_del_mech = ProbaSlopeDelegationMechanism(
        batch_size=batch_size, window_size=window_size
    )

    # Both ensembles share every setting except the delegation mechanism.
    ensembles = [
        Ensemble(
            training_epochs=1,
            n_voters=n_voters,
            delegation_mechanism=NOOP_del_mech,
            name="full_ensemble",
            input_dim=28 * 28,
            output_dim=10,
        ),
        Ensemble(
            training_epochs=1,
            n_voters=n_voters,
            delegation_mechanism=proba_slope_del_mech,
            name="proba_slope_delegations",
            input_dim=28 * 28,
            output_dim=10,
        ),
    ]

    exp = Experiment(n_trials=num_trials, ensembles=ensembles, data=data, seed=seed)
    _ = exp.run()

    # Only the averages are needed; the standard deviations are discarded.
    proba_slope_avg_test_accs_per_trial, _ = calculate_avg_std_test_accs_per_trial(
        exp, "proba_slope_delegations", num_trials
    )
    full_avg_test_accs_per_trial, _ = calculate_avg_std_test_accs_per_trial(
        exp, "full_ensemble", num_trials
    )

    return np.mean(proba_slope_avg_test_accs_per_trial) - np.mean(
        full_avg_test_accs_per_trial
    )

In [5]:
# Hyper-parameter search: call run_experiment once for every
# (batch_size, window_size, n_voters) combination listed below.
batch_sizes = [128]
window_sizes = [300, 310, 320, 330, 340, 350, 360, 370, 380, 390, 400]
num_trials = 2
n_voters = [3]

# results[i, j, k] holds the value returned by run_experiment for
# batch_sizes[i], window_sizes[j] and n_voters[k].
results = np.zeros((len(batch_sizes), len(window_sizes), len(n_voters)))
for i, batch_size in enumerate(batch_sizes):
    # Wrap the sized list, not the enumerate object: enumerate() has no len(),
    # so tqdm could not infer a total and printed only a bare "Nit" counter
    # instead of a progress bar with percentage and ETA.
    for j, window_size in enumerate(tqdm(window_sizes)):
        for k, n_voter in enumerate(n_voters):
            results[i, j, k] = run_experiment(
                batch_size, window_size, num_trials, n_voter
            )
            print(
                f"batch_size: {batch_size}, window_size: {window_size}, n_voter: {n_voter}, result: {results[i, j, k]}"
            )

100%|██████████| 2/2 [04:09<00:00, 124.82s/it]
1it [04:14, 254.01s/it]

batch_size: 128, window_size: 300, n_voter: 3, result: 0.15186221230670305


100%|██████████| 2/2 [04:05<00:00, 122.93s/it]
2it [08:24, 251.79s/it]

batch_size: 128, window_size: 310, n_voter: 3, result: 0.14728260871089627


100%|██████████| 2/2 [04:09<00:00, 124.69s/it]
3it [12:37, 252.68s/it]

batch_size: 128, window_size: 320, n_voter: 3, result: 0.15521499362138214


100%|██████████| 2/2 [04:11<00:00, 125.89s/it]
4it [16:54, 254.05s/it]

batch_size: 128, window_size: 330, n_voter: 3, result: 0.15311500956030444


100%|██████████| 2/2 [04:11<00:00, 125.81s/it]
5it [21:10, 254.76s/it]

batch_size: 128, window_size: 340, n_voter: 3, result: 0.15023377560593576


100%|██████████| 2/2 [04:12<00:00, 126.33s/it]
6it [25:27, 255.54s/it]

batch_size: 128, window_size: 350, n_voter: 3, result: 0.16707560740163552


100%|██████████| 2/2 [04:10<00:00, 125.21s/it]
7it [29:41, 255.28s/it]

batch_size: 128, window_size: 360, n_voter: 3, result: 0.1657269022348895


100%|██████████| 2/2 [04:03<00:00, 121.83s/it]
8it [33:50, 252.99s/it]

batch_size: 128, window_size: 370, n_voter: 3, result: 0.1853960198362159


100%|██████████| 2/2 [04:10<00:00, 125.25s/it]
9it [38:04, 253.59s/it]

batch_size: 128, window_size: 380, n_voter: 3, result: 0.16509351027591146


100%|██████████| 2/2 [04:10<00:00, 125.33s/it]
10it [42:19, 254.04s/it]

batch_size: 128, window_size: 390, n_voter: 3, result: 0.20055346865483248


100%|██████████| 2/2 [04:11<00:00, 125.84s/it]
11it [46:36, 254.19s/it]

batch_size: 128, window_size: 400, n_voter: 3, result: 0.21515345268542196



