In [1]:
from tqdm.auto import tqdm
from sklearn.metrics import accuracy_score
import torch
import numpy as np
from simulator import Simulator
import yaml

setting = yaml.safe_load(open('../configurations/MNIST/MLP/ClusterAgg/average/no_straggle.yaml'))

sim = Simulator(**setting)
n_epoch = 100

for epoch in tqdm(
    range(n_epoch), 
    desc=f'Running {str(sim.output_dir)}', 
    leave=True,
):
    picked_clients, to_update_global = sim.step()
    avg_losses = [u.avg_loss for u in to_update_global]
    train_acc_scores = [u.train_acc_score for u in to_update_global]
    test_acc_scores = [u.test_acc_score for u in to_update_global]

    print(f'Simulator -- got {len(to_update_global)} updates to incorporate')

    nan_in_model = torch.stack([w.isnan().any() for w in sim.global_model.get_state().values()]).any().item()
    if nan_in_model:
        print('nan is bleeding in!!!!')
    
    if len(to_update_global):
        # only update the global model if we have any updates
        avg_train_acc = sum(train_acc_scores) / len(train_acc_scores)
        avg_test_acc = sum(test_acc_scores) / len(test_acc_scores)
        print(f'Simulator -- avg loss: {sum(avg_losses) / len(to_update_global)}')
        print(f'Simulator -- client avg train acc: {sum(train_acc_scores) / len(train_acc_scores)}')
        print(f'Simulator -- client avg test acc: {sum(test_acc_scores) / len(test_acc_scores)}')

        # Step 4. Update the global model with the finished local updates
        new_state = sim.aggregator(sim.global_model, to_update_global)

        if new_state is not None:
            sim.global_model.set_state(new_state)
    else:
        avg_train_acc = 0
        avg_test_acc = 0
    
    pred = sim.global_model.predict(sim.x_test)
    print(f'Simulator -- Global model test acc: {accuracy_score(sim.y_test, pred)}')

Running output/MNIST/MLP/ClusterAgg/average/no_straggle:   0%|          | 0/100 [00:00<?, ?it/s]

Simulator -- got 20 updates to incorporate
Simulator -- avg loss: 2.320893573760986
Simulator -- client avg train acc: 0.7166666666666667
Simulator -- client avg test acc: 0.6640898345153664


OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.


Simulator -- Global model test acc: 0.46335697399527187
Simulator -- got 20 updates to incorporate
Simulator -- avg loss: 1.5979796648025513
Simulator -- client avg train acc: 0.4499999999999999
Simulator -- client avg test acc: 0.4633569739952719
Simulator -- Global model test acc: 0.46335697399527187
Simulator -- got 20 updates to incorporate
Simulator -- avg loss: 1.5979796648025513
Simulator -- client avg train acc: 0.5125
Simulator -- client avg test acc: 0.4633569739952719




Simulator -- Global model test acc: 0.46335697399527187
Simulator -- got 20 updates to incorporate
Simulator -- avg loss: 1.5979796648025513
Simulator -- client avg train acc: 0.45
Simulator -- client avg test acc: 0.4633569739952719
Simulator -- Global model test acc: 0.46335697399527187
Simulator -- got 20 updates to incorporate
Simulator -- avg loss: 1.5979796648025513
Simulator -- client avg train acc: 0.4833333333333333
Simulator -- client avg test acc: 0.4633569739952719




Simulator -- Global model test acc: 0.5049645390070922
Simulator -- got 20 updates to incorporate
Simulator -- avg loss: nan
Simulator -- client avg train acc: 0.4541666666666666
Simulator -- client avg test acc: 0.4633569739952719


ValueError: Input contains NaN.

In [28]:
[
    torch.stack([w.isnan().any() for w in u.new_state.values()]).any().item()
    for u in to_update_global
]

[True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True]

In [25]:
client.x_train.isnan().any()

tensor(False)

In [26]:
client.y_train.isnan().any()

tensor(False)

In [8]:


points, update_weights, update_delays = [], [], []
for u in to_update_global:
    points.append(list(u.new_state.values()))
    update_weights.append(u.train_size)
    update_delays.append(u.counter)

update_weights = torch.tensor(update_weights).to(points[0][0].device)
final_agg = []

for component in map(torch.stack,zip(*points)): 
    reduced = sim.aggregator.dim_reducer.fit_transform(component.flatten(1).cpu())
    component_clusters = sim.aggregator.cluster_detector.fit_predict(reduced)
    cluster_medians = torch.stack([
        sim.aggregator.combine_points(
            component[component_clusters == c], 
            update_weights[component_clusters == c],
        )
        for c in np.unique(component_clusters)
    ])
    cluster_weights = torch.tensor(np.unique(component_clusters, return_counts = True)[1]).to(cluster_medians.device)
    component_agg = sim.aggregator.combine_points(cluster_medians, cluster_weights)         
    final_agg.append(component_agg)

for i,k in enumerate(new_global_state.keys()):
    new_global_state[k] = final_agg[i]

ValueError: Input contains NaN.

False