In [1]:
from exp_framework.Ensemble import Ensemble, PretrainedEnsemble, StudentExpertEnsemble
from exp_framework.delegation import (
    DelegationMechanism,
    UCBDelegationMechanism,
    ProbaSlopeDelegationMechanism,
    RestrictedMaxGurusDelegationMechanism,
    StudentExpertDelegationMechanism,
)
from exp_framework.learning import Net
from exp_framework.experiment import (
    Experiment,
    calculate_avg_std_test_accs,
    calculate_avg_std_train_accs,
)
from avalanche.training.supervised import Naive
from matplotlib import pyplot as plt
from exp_framework.data_utils import Data
from avalanche.benchmarks.classic import RotatedMNIST, SplitMNIST
import numpy as np
import matplotlib as mpl
import seaborn as sns
from itertools import product
import pandas as pd
import torch.optim as optim
from torch.nn import CrossEntropyLoss

from avalanche.training.plugins import (
    CWRStarPlugin,
    ReplayPlugin,
    EWCPlugin,
    TrainGeneratorAfterExpPlugin,
    LwFPlugin,
    SynapticIntelligencePlugin,
)
from exp_framework.MinibatchEvalAccuracy import MinibatchEvalAccuracy
from avalanche.training.plugins import EvaluationPlugin
from avalanche.evaluation.metrics import accuracy_metrics

  from .autonotebook import tqdm as notebook_tqdm


### learning the mapping $\mathcal{X} \rightarrow \mathcal{G}$ (i.e. $\mathcal{X} \rightarrow \mathcal{Y}\times\mathcal{C}$)

In [2]:
batch_size = 128
window_size = 50
num_trials = 2
n_voters = 5


# # Set up the Class Incremental framework
# data = Data(
#     data_set_name="mnist",
#     # train_digit_groups=[range(5), range(5, 10)],
#     # train_digit_groups=[[0, 1, 2], [3, 4, 5,], [6, 7, 8, 9]],
#     train_digit_groups=[[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]],
#     # test_digit_groups=[[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]],
#     # test_digit_groups=[range(5), range(5, 10)],
#     test_digit_groups=[range(10)],
#     batch_size=batch_size,
# )

data = SplitMNIST(n_experiences=5, fixed_class_order=list(range(10)))
# if data == "MNIST":
#     benchmark = SplitMNIST(n_experiences=5, fixed_class_order=list(range(10)), seed=self.seed)
# elif data == "RotatedMNIST":
#     benchmark = RotatedMNIST(n_experiences=5)

# Single Active Voter

### Create Delegation Mechanisms and Ensembles

For simplicity, only explore full ensemble and variants of ProbaSlopeDelegationMechanism since they can be created programmatically.

#### Create Delegation Mechanisms

In [3]:
# Create Delegation Mechanisms - single guru

NOOP_del_mech = DelegationMechanism(batch_size=batch_size, window_size=window_size)

probability_functions = [
    "random_better",
    "probabilistic_better",
    "probabilistic_weighted",
]
score_functions = [
    "accuracy_score",
    "balanced_accuracy_score",
    "f1_score",
    "precision_score",
    "recall_score",
    "top_k_accuracy_score",
    "roc_auc_score",
    "log_loss_score",
    "max_diversity",
]
probability_functions = ["max_diversity"]
score_functions = ["accuracy_score"]
max_active_gurus = 1

del_mechs = {"full-ensemble": NOOP_del_mech}
for prob_func, score_func in product(probability_functions, score_functions):
    dm = ProbaSlopeDelegationMechanism(
        batch_size=batch_size,
        window_size=window_size,
        max_active=max_active_gurus,
        probability_function=prob_func,
        score_method=score_func,
    )
    del_mechs[f"{prob_func}-{score_func}"] = dm


ensembles_dict = {
    dm_name: Ensemble(
        training_epochs=1,
        n_voters=n_voters,
        delegation_mechanism=dm,
        name=dm_name,
        input_dim=28 * 28,
        output_dim=10,
    )
    for dm_name, dm in del_mechs.items()
}

# restricted_max_gurus_mech = RestrictedMaxGurusDelegationMechanism(
#     batch_size=batch_size,
#     num_voters=n_voters,
#     max_active_voters=max_active_gurus,
#     window_size=window_size,
#     t_between_delegation=3,
# )
# UCB_del_mech = UCBDelegationMechanism(
#     batch_size=batch_size,
#     window_size=window_size,
#     ucb_window_size=None
# )

#### Create Avalanche Strategies to Compare Against

In [4]:
def initialize_strategies_to_evaluate():

    model = Net(input_dim=28 * 28, output_dim=10)
    # model = SimpleMLP(num_classes=10)
    optimize = optim.Adam(model.parameters(), lr=0.001)

    plugins_to_evaluate = {
        "LwF": LwFPlugin(),
        # "EWC": EWCPlugin(ewc_lambda=0.001),
        # "SynapticIntelligence": SynapticIntelligencePlugin(si_lambda=0.5),
        # "Replay": ReplayPlugin(mem_size=100),
    }

    strategies_to_evaluate = {}
    for name, pte in plugins_to_evaluate.items():
        mb_eval = MinibatchEvalAccuracy()
        evp = EvaluationPlugin(
            accuracy_metrics(minibatch=True, epoch=True, experience=True, stream=True),
            mb_eval
        )
        cl_strategy = Naive(
            model=model,
            optimizer=optimize,
            criterion=CrossEntropyLoss(),
            train_mb_size=batch_size,
            train_epochs=1,
            eval_mb_size=batch_size,
            plugins=[pte, evp, mb_eval],
            # evaluator=evp,
        )
        strategies_to_evaluate[name] = (cl_strategy, evp)
    
    return strategies_to_evaluate

### Train Ensemble

In [5]:
# Train ensembles - single guru

one_active_exp = Experiment(
    n_trials=num_trials,
    ensembles=list(ensembles_dict.values()),
    benchmark=data,
    strategies_to_evaluate=initialize_strategies_to_evaluate,
)
_ = one_active_exp.run()

  0%|          | 0/2 [00:00<?, ?it/s]



Starting trial  0
-- >> Start of training phase << --
100%|██████████| 99/99 [00:00<00:00, 161.73it/s]
Epoch 0 ended.
	Loss_Epoch/train_phase/train_stream/Task000 = 0.0755
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.9878
-- >> End of training phase << --
-- >> Start of training phase << --
100%|██████████| 95/95 [00:00<00:00, 150.53it/s]
Epoch 0 ended.
	Loss_Epoch/train_phase/train_stream/Task000 = 1.4892
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.8196
-- >> End of training phase << --
-- >> Start of training phase << --
100%|██████████| 88/88 [00:00<00:00, 154.79it/s]
Epoch 0 ended.
	Loss_Epoch/train_phase/train_stream/Task000 = 2.0096
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.7018
-- >> End of training phase << --
-- >> Start of training phase << --
100%|██████████| 96/96 [00:00<00:00, 146.00it/s]
Epoch 0 ended.
	Loss_Epoch/train_phase/train_stream/Task000 = 2.0663
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.7007
-- >> End of training phase << -

 50%|█████     | 1/2 [00:12<00:12, 12.02s/it]

Starting trial  1
-- >> Start of training phase << --
100%|██████████| 99/99 [00:00<00:00, 156.36it/s]
Epoch 0 ended.
	Loss_Epoch/train_phase/train_stream/Task000 = 0.0807
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.9904
-- >> End of training phase << --
-- >> Start of training phase << --
100%|██████████| 95/95 [00:00<00:00, 145.72it/s]
Epoch 0 ended.
	Loss_Epoch/train_phase/train_stream/Task000 = 1.4354
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.8337
-- >> End of training phase << --
-- >> Start of training phase << --
100%|██████████| 88/88 [00:00<00:00, 148.70it/s]
Epoch 0 ended.
	Loss_Epoch/train_phase/train_stream/Task000 = 1.9490
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.6927
-- >> End of training phase << --
-- >> Start of training phase << --
100%|██████████| 96/96 [00:00<00:00, 148.20it/s]
Epoch 0 ended.
	Loss_Epoch/train_phase/train_stream/Task000 = 2.1216
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.7384
-- >> End of training phase << -

100%|██████████| 2/2 [00:24<00:00, 12.02s/it]


### Save and Print Results

In [6]:
batch_metrics = one_active_exp.get_aggregate_batch_metrics()
dfs = []
for ens, metric_dict in batch_metrics.items():
    df = pd.DataFrame.from_dict(metric_dict, orient="index")
    df["ensemble_name"] = ens
    dfs.append(df)
df = pd.concat(dfs)
col_order = [len(df.columns) - 1] + list(range(len(df.columns) - 1))
df = df[df.columns[col_order]]
print(df)
file_prefix = f"class_incremental_single_guru-trials={num_trials}-batch_size={batch_size}_window_size={window_size}"
path = "results"

df.to_csv(f"{path}/{file_prefix}.csv")

Batch metric values for full-ensemble and batch_test_acc are: [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9765625, 0.9609375, 0.96875, 0.96875, 0.9765625, 0.953125, 0.96875, 0.96875, 0.984375, 0.9765625, 0.984375, 0.96875, 1.0, 0.9765625, 0.96875, 0.9841269850730896], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9609375, 0.9765625, 0.9921875, 0.9921875, 0.9453125, 0.953125, 0.984375, 0.9453125, 0.984375, 0.9765625, 0.96875, 0.984375, 0.96

In [7]:
# Print results - single guru

print(f"Results for mechanisms with max_active_gurus = {max_active_gurus}:")

# Collect and print train accuracies - aggregate and by batch
train_results_dict = dict()
for ens_name, ensemble in ensembles_dict.items():
    train_acc, train_acc_std = calculate_avg_std_train_accs(
        one_active_exp, ens_name, num_trials
    )
    train_results_dict[ens_name] = (train_acc, train_acc_std)

for strat_name, (strat, eval_plugin) in initialize_strategies_to_evaluate().items():
    train_acc, train_acc_std = calculate_avg_std_train_accs(
        one_active_exp, strat_name, num_trials
    )
    train_results_dict[strat_name] = (train_acc, train_acc_std)

for ens_name, (train_acc, train_acc_std) in train_results_dict.items():
    print(
        f"Mean train acc for {ens_name}: {round(np.mean(train_acc), 3)}+-{round(np.mean(train_acc_std), 3)}"
    )
# for ens_name, (train_acc, train_acc_std) in train_results_dict.items():
#     print(f"All train accs for {ens_name}: {train_acc}")

print("--------------")

# Collect and print test accuracies
results_dict = dict()
for ens_name, ensemble in ensembles_dict.items():
    test_acc, test_acc_std = calculate_avg_std_test_accs(
        one_active_exp, ens_name, num_trials
    )
    results_dict[ens_name] = (test_acc, test_acc_std)

for strat_name, (strat, eval_plugin) in initialize_strategies_to_evaluate().items():
    test_acc, test_acc_std = calculate_avg_std_train_accs(
        one_active_exp, strat_name, num_trials
    )
    results_dict[strat_name] = (test_acc, test_acc_std)

for ens_name, (test_acc, test_acc_std) in results_dict.items():
    print(
        f"Mean test acc for {ens_name}: {round(np.mean(test_acc), 3)}+-{round(np.mean(test_acc_std), 3)}"
    )

Results for mechanisms with max_active_gurus = 1:


Mean train acc for full-ensemble: 0.862+-0.017
Mean train acc for max_diversity-accuracy_score: 0.863+-0.043
Mean train acc for LwF: 0.767+-0.026
--------------
Mean test acc for full-ensemble: 0.194+-0.002
Mean test acc for max_diversity-accuracy_score: 0.252+-0.166
Mean test acc for LwF: 0.767+-0.026




# Many Active Voters

In [8]:
# Create Delegation Mechanisms - many gurus

NOOP_del_mech = DelegationMechanism(batch_size=batch_size, window_size=window_size)

probability_functions = [
    "random_better",
    "probabilistic_better",
    "probabilistic_weighted",
]
score_functions = [
    "accuracy_score",
    "balanced_accuracy_score",
    "f1_score",
    "precision_score",
    "recall_score",
    "top_k_accuracy_score",
    "roc_auc_score",
    "log_loss_score",
]
max_active_gurus = 3

del_mechs = {"full-ensemble": NOOP_del_mech}
for prob_func, score_func in product(probability_functions, score_functions):
    dm = ProbaSlopeDelegationMechanism(
        batch_size=batch_size,
        window_size=window_size,
        max_active=max_active_gurus,
        probability_function=prob_func,
        score_method=score_func,
    )
    del_mechs[f"{prob_func}-{score_func}"] = dm


many_active_ensembles_dict = {
    dm_name: Ensemble(
        training_epochs=1,
        n_voters=n_voters,
        delegation_mechanism=dm,
        name=dm_name,
        input_dim=28 * 28,
        output_dim=10,
    )
    for dm_name, dm in del_mechs.items()
}

In [9]:
# Run experiment - Many gurus

many_active_exp = Experiment(
    n_trials=num_trials,
    ensembles=list(many_active_ensembles_dict.values()),
    data=data,
    seed=4090,
)
_ = many_active_exp.run()

TypeError: Experiment.__init__() got an unexpected keyword argument 'data'

In [None]:
print(f"Results for mechanisms with max_active_gurus = {max_active_gurus}:")

# Collect and print train accuracies - aggregate and by batch
train_results_dict = dict()
for ens_name, ensemble in many_active_ensembles_dict.items():
    train_acc, train_acc_std = calculate_avg_std_train_accs(
        many_active_exp, ens_name, num_trials
    )
    train_results_dict[ens_name] = (train_acc, train_acc_std)

for ens_name, (train_acc, train_acc_std) in train_results_dict.items():
    print(
        f"Mean train acc for {ens_name}: {round(np.mean(train_acc), 3)}+-{round(np.mean(train_acc_std), 3)}"
    )
# for ens_name, (train_acc, train_acc_std) in train_results_dict.items():
#     print(f"All train accs for {ens_name}: {train_acc}")

print("--------------")

# Collect and print test accuracies
results_dict = dict()
for ens_name, ensemble in many_active_ensembles_dict.items():
    test_acc, test_acc_std = calculate_avg_std_test_accs(
        many_active_exp, ens_name, num_trials
    )
    results_dict[ens_name] = (test_acc, test_acc_std)

for ens_name, (test_acc, test_acc_std) in results_dict.items():
    print(
        f"Mean test acc for {ens_name}: {round(np.mean(test_acc), 3)}+-{round(np.mean(test_acc_std), 3)}"
    )

# Basic Comparison with Avalanche

In [None]:
from exp_framework.learning import Net
from torch.optim import SGD, Adam
from torch.nn import CrossEntropyLoss
from avalanche.models import SimpleMLP
from avalanche.training.supervised import (
    Naive,
    CWRStar,
    Replay,
    GDumb,
    Cumulative,
    LwF,
    GEM,
    AGEM,
    EWC,
)  # and many more!
from avalanche.benchmarks.classic import RotatedMNIST, SplitMNIST
from avalanche.training.plugins import ReplayPlugin
import pprint

model = Net(input_dim=28 * 28, output_dim=10)
optimize = Adam(model.parameters(), lr=0.001)
replay = ReplayPlugin(mem_size=100)

cl_strategy = Naive(
    model,
    optimizer=optimize,
    criterion=CrossEntropyLoss(),
    train_mb_size=128,
    train_epochs=1,
    eval_mb_size=128,
    # plugins=[replay]
)
# optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9)
# model = SimpleMLP(num_classes=10)
# criterion = CrossEntropyLoss()
# cl_strategy = Naive(
#     model, SGD(model.parameters(), lr=0.001, momentum=0.9), criterion,
#     train_mb_size=100, train_epochs=4, eval_mb_size=100
# )


# scenario
# benchmark = RotatedMNIST(n_experiences=5, seed=1)
benchmark = SplitMNIST(n_experiences=5, fixed_class_order=list(range(10)), seed=1)

# TRAINING LOOP
print("Starting experiment...")
results = []
for experience in benchmark.train_stream:
    print("Start of experience: ", experience.current_experience)
    print("Current Classes: ", experience.classes_in_this_experience)

    cl_strategy.train(experience)
    print("Training completed")

    print("Computing accuracy on the whole test set")
    results.append(cl_strategy.eval(benchmark.test_stream))

for r in results:
    pprint.pprint(r)
# print(results)

In [None]:
from avalanche.benchmarks.classic import RotatedMNIST

# scenario
benchmark = RotatedMNIST(n_experiences=5, seed=1)

# TRAINING LOOP
print("Starting experiment...")
results = []
for experience in benchmark.train_stream:
    print("Start of experience: ", experience.current_experience)
    print("Current Classes: ", experience.classes_in_this_experience)

    cl_strategy.train(experience)
    print("Training completed")

    print("Computing accuracy on the whole test set")
    results.append(cl_strategy.eval(benchmark.test_stream))

print(results)

# Explore Results

(leftover copied code from other file, not adapted for the above code)

In [None]:
sns.set_style("whitegrid")
sns.set_context("paper", font_scale=2)

plt.rcParams["font.family"] = "serif"
plt.rcParams["font.serif"] = "Georgia"

# set colors for each bar. Use pastel
colors = sns.color_palette("pastel")
# assign colors for each bar

fig = plt.figure(figsize=(10, 5))
ax = fig.add_subplot(111)
ax.set_title("Test Accuracies")
ax.set_ylabel("Test Accuracy")
ax.set_xlabel("Delegation Mechanism")
# ax.set_xticks([0, 1, 2])
# ax.set_xticklabels(
#     ["No Delegation", "Proba Slope", "Restricted Max Guru"], rotation=45, ha="right"
# )
ax.set_xticks([0, 1])
ax.set_xticklabels(["No Delegation", "Proba Slope"], rotation=45, ha="right")
# Data for bar plot
means = [
    np.mean(full_avg_test_accs),
    np.mean(proba_slope_avg_test_accs),
    # np.mean(restricted_max_guru_avg_test_accs),
]
stds = [
    np.std(full_avg_test_accs),
    np.std(proba_slope_avg_test_accs),
    # np.std(restricted_max_guru_avg_test_accs),
]

# Create each bar individually to set different colors
for i in range(len(ensembles)):
    ax.bar(i, means[i], color=colors[i], yerr=stds[i], capsize=10)

plt.tight_layout()
plt.show()

In [None]:
proba_slope_avg_train_accs, proba_slope_std_train_accs = calculate_avg_std_train_accs(
    exp, "proba_slope_delegations", num_trials
)
full_avg_train_accs, full_std_train_accs = calculate_avg_std_train_accs(
    exp, "full_ensemble", num_trials
)

# (
#     restricted_max_guru_avg_train_accs,
#     restricted_max_guru_std_train_accs,
# ) = calculate_avg_std_train_accs(exp, "restricted_max_guru_delegations", num_trials)

print(
    "Mean train accs for proba_slope delegation ensemble: ",
    np.mean(proba_slope_avg_train_accs),
)
print("Mean train accs for full ensemble: ", np.mean(full_avg_train_accs))

# print(
#     "Mean train accs for restricted_max_guru delegation ensemble: ",
#     np.mean(restricted_max_guru_avg_train_accs),
# )

In [None]:
train_splits = exp.train_splits

In [None]:
sns.set(style="whitegrid", palette="pastel", context="paper")

# Set the font to Georgia
mpl.rcParams["font.family"] = "Georgia"
mpl.rcParams["font.size"] = 12
mpl.rcParams["axes.labelsize"] = 14
mpl.rcParams["axes.titlesize"] = 16

fig, ax = plt.subplots(figsize=(10, 5))

colors = sns.color_palette("pastel")
proba_slope_color = colors[1]
full_color = colors[0]
restricted_max_guru_color = colors[2]

ax.plot(
    proba_slope_avg_train_accs,
    label="ProbaSlope Delegation Ensemble",
    color=proba_slope_color,
    linewidth=2,
)
ax.fill_between(
    range(len(proba_slope_avg_train_accs)),
    np.array(proba_slope_avg_train_accs) - np.array(proba_slope_std_train_accs),
    np.array(proba_slope_avg_train_accs) + np.array(proba_slope_std_train_accs),
    color=proba_slope_color,
    alpha=0.3,
)

ax.plot(full_avg_train_accs, label="Full Ensemble", color=full_color, linewidth=2)
ax.fill_between(
    range(len(full_avg_train_accs)),
    np.array(full_avg_train_accs) - np.array(full_std_train_accs),
    np.array(full_avg_train_accs) + np.array(full_std_train_accs),
    color=full_color,
    alpha=0.3,
)

# ax.plot(
#     restricted_max_guru_avg_train_accs,
#     label="Restricted Max Guru Delegation Ensemble",
#     color=restricted_max_guru_color,
#     linewidth=2,
# )
# ax.fill_between(
#     range(len(restricted_max_guru_avg_train_accs)),
#     np.array(restricted_max_guru_avg_train_accs)
#     - np.array(restricted_max_guru_std_train_accs),
#     np.array(restricted_max_guru_avg_train_accs)
#     + np.array(restricted_max_guru_std_train_accs),
#     color=colors[2],
#     alpha=0.3,
# )


# plot vertical lines at test splits
for split in train_splits[:-1]:
    ax.axvline(x=split, color="k", linestyle="--", linewidth=1)

# Setting labels, title, and legend
ax.set_xlabel("Batch Number")
ax.set_ylabel("Train Accuracy")
ax.set_title(
    "ProbaSlope Delegation Ensemble vs Full Ensemble vs Restricted Max Guru Delegation Ensemble"
)

ax.legend(loc="upper left")
# set y lim to lower
ax.set_ylim(top=1.3)
# set y ticks to 0-1
ax.set_yticks(np.arange(0, 1.1, 0.1))

plt.tight_layout()

# Show the plot
plt.show()

In [None]:
ps_voters = exp.ensembles[1].voters
print(ensembles[1].name)
batch_accs = []
for v in ps_voters:
    batch_accs.append(v.batch_accuracies)

In [None]:
len_train = len(data.train_data_loader.dataset) / batch_size

In [None]:
def find_active_streaks(voter_id, trial_num):
    """
    Find active streaks for a specified voter.

    :param voter_id: ID of the voter for which to find active streaks.
    :param batch_metric_values: Dictionary containing the batch metric values.
    :param metric_key: Key to access the relevant metric in batch_metric_values.
    :return: List of active streaks for the specified voter.
    """
    active_batches = []
    active_streak = [None, None]
    voter_active = False

    for i, av in enumerate(
        exp.batch_metric_values["proba_slope_delegations"][trial_num][
            "active_voters-train"
        ]
    ):
        # print(av)
        if voter_id in av:
            if not voter_active:
                # Start a new streak
                active_streak[0] = i
                voter_active = True
                # print("streak started")
            active_streak[1] = i
        else:
            if voter_active:
                # End the current streak
                active_batches.append(active_streak.copy())
                active_streak = [None, None]
                voter_active = False
                # print("streak done")

    # Handle case where the streak continues till the end of the list
    if voter_active:
        active_batches.append(active_streak.copy())

    return active_batches

### Look at activity on last trial

In [None]:
for voter_id in range(n_voters):
    active_streaks = find_active_streaks(voter_id, num_trials - 1)
    # print(f"Active Streaks for Voter {voter_id}: {active_streaks}")

    plt.figure(figsize=(10, 5))  # Create a new figure for each voter
    plt.plot(batch_accs[voter_id])
    plt.axvline(x=len_train, color="k", linestyle="--", linewidth=1)

    # Shade the active batches for this voter
    for streak in active_streaks:
        if streak[0] is not None and streak[1] is not None:
            plt.axvspan(streak[0], streak[1], alpha=0.3, color="red")

    # Plot a green vertical line at all train splits
    for split in train_splits[:-1]:
        plt.axvline(x=split, color="g", linestyle="--", linewidth=2)

    plt.title(f"Voter {voter_id} Activity")
    plt.xlabel("Batches")
    plt.ylabel("Accuracy")
    plt.show()  # Display the plot for each voter