In [None]:
from exp_framework.Ensemble import Ensemble
from exp_framework.delegation import DelegationMechanism, UCBDelegationMechanism
from exp_framework.experiment import Experiment
from matplotlib import pyplot as plt
import numpy as np

In [None]:
# Experiment configuration. Both ensembles are built from the same shared
# settings, so the comparison differs only in the delegation mechanism.
batch_size = 128
window_size = 5
N_TRIALS = 5  # number of trials passed to Experiment(n_trials=...)

# Ensemble settings shared by every ensemble in this comparison. Keeping them
# in one place guarantees the two ensembles cannot drift apart.
# NOTE(review): input_dim=28*28 / output_dim=10 presumably mean flattened
# 28x28 images with 10 classes (MNIST-like) — TODO confirm in exp_framework.
SHARED_ENSEMBLE_KWARGS = dict(
    training_epochs=1,
    n_voters=10,
    input_dim=28 * 28,
    output_dim=10,
)

# NOTE(review): no random seed is set here; results are only reproducible if
# exp_framework seeds its RNGs internally — TODO confirm.

# UCB-based delegation mechanism; ucb_window_size=None is passed through as-is
# (its meaning is defined in exp_framework.delegation).
UCB_del_mech = UCBDelegationMechanism(
    batch_size=batch_size, window_size=window_size, ucb_window_size=None
)
# Base DelegationMechanism, used for the "full_ensemble" baseline
# (presumably performs no delegation, per the NOOP name — confirm).
NOOP_del_mech = DelegationMechanism(batch_size=batch_size, window_size=window_size)

ensembles = [
    Ensemble(
        delegation_mechanism=UCB_del_mech,
        name="UCB_delegation_ensemble",
        **SHARED_ENSEMBLE_KWARGS,
    ),
    Ensemble(
        delegation_mechanism=NOOP_del_mech,
        name="full_ensemble",
        **SHARED_ENSEMBLE_KWARGS,
    ),
]

exp = Experiment(n_trials=N_TRIALS, ensembles=ensembles)
exp.run()

In [None]:
# Pull the per-batch test-accuracy series out of the experiment results, one
# variable per (ensemble, trial) pair. Indices 0-4 cover the five trials
# requested in the setup cell.
acc_key = "batch_test_acc"
ucb_trials = exp.batch_metric_values["UCB_delegation_ensemble"]
full_trials = exp.batch_metric_values["full_ensemble"]

UCB_trial_0_test_accs = ucb_trials[0][acc_key]
UCB_trial_1_test_accs = ucb_trials[1][acc_key]
UCB_trial_2_test_accs = ucb_trials[2][acc_key]
UCB_trial_3_test_accs = ucb_trials[3][acc_key]
UCB_trial_4_test_accs = ucb_trials[4][acc_key]

full_trial_0_test_accs = full_trials[0][acc_key]
full_trial_1_test_accs = full_trials[1][acc_key]
full_trial_2_test_accs = full_trials[2][acc_key]
full_trial_3_test_accs = full_trials[3][acc_key]
full_trial_4_test_accs = full_trials[4][acc_key]

In [None]:
# Average the per-batch test accuracy across trials for each ensemble.
# The number of trials is read from the recorded metrics instead of being
# hard-coded, so this cell stays correct if Experiment(n_trials=...) changes.


def _trial_mean_std(ensemble_name, metric="batch_test_acc"):
    """Per-batch mean and std of `metric` across every trial of one ensemble.

    Args:
        ensemble_name: key into ``exp.batch_metric_values``.
        metric: per-trial key whose value is a per-batch sequence of numbers.

    Returns:
        (means, stds): two lists of floats, one entry per batch.

    Raises:
        ValueError: if no trials were recorded, or if trials recorded
            different numbers of batches (the original cell would have
            raised IndexError or silently truncated in that case).
    """
    trials = exp.batch_metric_values[ensemble_name]
    # Index 0..n-1 (as the original cell did) so this works whether `trials`
    # is a list or an int-keyed mapping.
    series = [trials[t][metric] for t in range(len(trials))]
    if not series:
        raise ValueError(f"no trials recorded for {ensemble_name!r}")
    lengths = {len(s) for s in series}
    if len(lengths) != 1:
        raise ValueError(
            f"trials of {ensemble_name!r} have differing batch counts: {sorted(lengths)}"
        )
    stacked = np.asarray(series, dtype=float)  # (n_trials, n_batches)
    # np.std defaults to ddof=0 (population std), matching the original cell.
    return stacked.mean(axis=0).tolist(), stacked.std(axis=0).tolist()


UCB_avg_test_accs, UCB_std_test_accs = _trial_mean_std("UCB_delegation_ensemble")
full_avg_test_accs, full_std_test_accs = _trial_mean_std("full_ensemble")

In [None]:
# Plot the mean per-batch test accuracy of each ensemble with a ±1 std band
# across trials, on a 10x5 figure.
fig, ax = plt.subplots(figsize=(10, 5))

series_to_plot = [
    ("UCB Delegation Ensemble", UCB_avg_test_accs, UCB_std_test_accs),
    ("Full Ensemble", full_avg_test_accs, full_std_test_accs),
]

for label, avg, std in series_to_plot:
    mean_arr = np.asarray(avg, dtype=float)
    std_arr = np.asarray(std, dtype=float)
    batches = np.arange(len(mean_arr))  # same x values ax.plot uses implicitly
    (line,) = ax.plot(batches, mean_arr, label=label)
    # Bind the band's color to its line explicitly: fill_between otherwise
    # draws from a separate color cycle that only matches by coincidence.
    ax.fill_between(
        batches,
        mean_arr - std_arr,
        mean_arr + std_arr,
        facecolor=line.get_color(),
        alpha=0.5,
    )

ax.set_xlabel("Batch Number")
ax.set_ylabel("Test Accuracy")
ax.set_title("UCB Delegation Ensemble vs Full Ensemble")
ax.legend()

plt.show()