In [None]:
from helper import *

In [None]:
# Experiment configuration.
dataset_name = "MNIST"             # name of a torchvision.datasets class
model_cfg = models.reducedLeNet5   # architecture spec exposing .base, .args, .kwargs
n_clients = 10                     # number of federated clients
n_classes = 10                     # number of label classes
pretrained_init = False            # True: load the general model from ckpts/ instead of training it
pretrained_clients = False         # True: skip per-client pre-training (checkpoints must already exist)
# Prefix shared by checkpoint and figure file names.
title = f"{model_cfg.__name__}_{dataset_name}_federated"

In [None]:
# Seed every RNG used downstream so runs are repeatable.
seed = 42
# NOTE(review): benchmark mode lets cuDNN choose kernels per input shape, which may be
# non-deterministic and can undermine the seeding below — confirm this trade-off is intended.
torch.backends.cudnn.benchmark = True
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed):
    _seed_fn(seed)

In [None]:
# Resolve the torchvision dataset class from its configured name and build
# train/test loaders over tensor-converted images.
dataset = getattr(torchvision.datasets, dataset_name)
batch_size = 256


def _load_split(is_train):
    """Instantiate the train (True) or test (False) split under ./data, downloading if needed."""
    return dataset(root='./data', train=is_train, download=True, transform=ToTensor())


train_dataset = _load_split(True)
test_dataset = _load_split(False)
# NOTE(review): shuffle is left at DataLoader's default (off), so training batches
# arrive in the same order every epoch — confirm this is intended.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)

In [None]:
# Flag re-asserted here so this cell can be toggled on its own: False trains the
# general model from scratch in the next cell, True loads it from
# "ckpts/pretrained_init<title>.pt" instead.
pretrained_init = False

In [None]:
# General (centrally trained) model: train it on the full training set, or load a
# previously saved checkpoint. Its weights are later copied into each client model.
general_model = model_cfg.base(*model_cfg.args, **model_cfg.kwargs)
title_pretrained_init = "pretrained_init" + title

# Defined unconditionally: later cells pass `criterion` to `train` even when this
# training branch is skipped.
criterion = nn.CrossEntropyLoss()

if not pretrained_init:
    lr_init = 1e-1
    wd = 0.0
    # Build the optimizer from lr_init/wd so the learning rate handed to `train`
    # and the optimizer's own settings cannot drift apart.
    optimizer = torch.optim.SGD(general_model.parameters(), lr=lr_init, weight_decay=wd)
    train(general_model, train_loader, test_loader, optimizer, criterion, lr_init, title=title_pretrained_init, epochs=20)
else:
    # map_location='cpu' lets a checkpoint saved from a GPU run load on a CPU-only
    # host; load_state_dict then copies the tensors into general_model's parameters.
    general_model.load_state_dict(torch.load("ckpts/" + title_pretrained_init + ".pt", map_location='cpu'))

In [None]:
# Split the training set across clients with `noniid_datasets` and plot each
# client's label distribution. Uses the configured n_clients / n_classes rather than
# literals, so changes in the configuration cell take effect here too.
client_loaders = noniid_datasets(train_loader, n_clients=n_clients, n_classes=n_classes)
visualizing_client_loader(client_loaders, n_clients, n_classes, path_figures="./figures/"+title)

In [None]:
# Flag re-asserted here so this cell can be toggled on its own: False pre-trains
# every client in the next cell, True skips that step and relies on the existing
# "ckpts/parametrized_preclient_<title><i>.pt" checkpoints.
pretrained_clients = False

In [None]:
def _head_param_names(net):
    """Return the names of the parameters owned by ``net``'s last parameterised module.

    Assumes the last entry of ``named_parameters()`` belongs to the output layer;
    every parameter sharing that entry's module prefix (e.g. both ``fc3.weight`` and
    ``fc3.bias``) is treated as part of the same layer. Returns an empty set when
    ``net`` has no parameters.
    """
    names = [name for name, _ in net.named_parameters()]
    if not names:
        return set()
    prefix = names[-1].rpartition('.')[0]
    if not prefix:
        # Parameter registered directly on the top-level module: skip only that one.
        return {names[-1]}
    return {name for name in names if name.rpartition('.')[0] == prefix}


if not pretrained_clients:
    # The output (logistic-regression) layer keeps its fresh initialisation, so
    # ALL of its parameters (weight and bias) are skipped. Every other parameter
    # is copied from the general model.
    head_names = _head_param_names(general_model)
    general_params = dict(general_model.named_parameters())
    for i in range(n_clients):
        loader_i = client_loaders[i]
        model = model_cfg.base(*model_cfg.args, **model_cfg.kwargs)
        criterion = nn.CrossEntropyLoss()
        with torch.no_grad():
            for name, param in model.named_parameters():
                if name not in head_names:
                    param.copy_(general_params[name])
        # NOTE(review): the optimizer lr (1e-2) differs from lr_init (1e-1) passed to
        # `train`, and wd is never given to the optimizer — confirm the intended values.
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
        title_i = "parametrized_preclient_" + title + str(i)
        wd = 0.0
        lr_init = 1e-1
        train(model, loader_i, test_loader, optimizer, criterion, lr_init, title=title_i, epochs=10)

In [None]:
# One SWAG posterior per client, configured with a PCA subspace (max_rank=2, pca_rank=2).
clients = [
    SWAG(model_cfg.base, *model_cfg.args, subspace_type="pca", **model_cfg.kwargs,
         subspace_kwargs={"max_rank": 2, "pca_rank": 2})
    for _ in range(n_clients)
]
# Per-client test-set predictions, appended by the SWAG training cell.
probs = []

In [None]:
# Collect SWAG statistics for every client, starting from its pre-trained checkpoint,
# then store the test-set predictions returned by `model_averaging` for each client.
lr_init = 1e-2
wd = 0.0
epochs = 20
# Local loss, so this cell does not depend on `criterion` leaking from an earlier
# cell (it is only defined there when the corresponding training branch ran).
criterion = nn.CrossEntropyLoss()
# Reset so re-running this cell does not append a second set of predictions.
# NOTE: the SWAG objects in `clients` are NOT reset; re-run the cell that builds
# `clients` before re-running this one.
probs = []

for i in range(n_clients):
    title_i = "parametrized_preclient_" + title + str(i)
    new_title_i = "swag_" + title_i

    swag_model_i = clients[i]
    loader_i = client_loaders[i]
    model = model_cfg.base(*model_cfg.args, **model_cfg.kwargs)
    # map_location='cpu' allows GPU-saved checkpoints to load on a CPU-only host.
    model.load_state_dict(torch.load("ckpts/" + title_i + ".pt", map_location='cpu'))
    optimizer = torch.optim.SGD(model.parameters(), lr=lr_init, weight_decay=wd)

    # test_loader is only included to display accuracy
    train(model, loader_i, test_loader, optimizer, criterion, lr_init, epochs, title=new_title_i, print_freq=5,
          swag=True, swag_model=swag_model_i, swag_start=5, swag_freq=5, swag_lr=1e-2)
    all_probs = model_averaging(swag_model_i, model=model_cfg.base(*model_cfg.args, **model_cfg.kwargs), loader=test_loader)
    probs.append(all_probs)

In [None]:
# Stack every client's SWAG mean and variance row-wise (one row per client, assuming
# each moment is a flat parameter vector). Each client is queried once instead of twice.
_client_moments = [swag_model._get_mean_and_variance() for swag_model in clients]
Mu_s = np.vstack([np.array(mean) for mean, _ in _client_moments])
Sigma_s = np.vstack([np.array(variance) for _, variance in _client_moments])

In [None]:
# Product-of-Gaussians aggregation: the server precision is the sum of the client
# precisions (element-wise), and the server variance is its reciprocal.
_client_precisions = np.reciprocal(Sigma_s)
Sigma_server = np.reciprocal(_client_precisions.sum(axis=0))

In [None]:
# Precision-weighted server mean: Mu_server = Sigma_server * sum_k(Mu_k / Sigma_k).
_precision_weighted_means = np.multiply(Mu_s, np.reciprocal(Sigma_s))
Mu_server = np.multiply(Sigma_server, _precision_weighted_means.sum(axis=0))

In [None]:
# Test accuracy of a single network whose weights are set from the server mean.
server_weights = torch.tensor(Mu_server)
model = model_cfg.base(*model_cfg.args, **model_cfg.kwargs)
set_weights(model, server_weights)
accuracy_model(model, test_loader, 'cpu')

In [None]:
# Server-side SWAG object filled directly from the aggregated moments
# (mean Mu_server, variance Sigma_server) rather than from training snapshots.
swag_model = SWAG(model_cfg.base, *model_cfg.args, subspace_type="pca", **model_cfg.kwargs,
                  subspace_kwargs={"max_rank": 2, "pca_rank": 2})
server_mean = torch.tensor(Mu_server, dtype=torch.float32)
server_var = torch.tensor(Sigma_server, dtype=torch.float32)
swag_model.mean = server_mean
# Second moment E[w^2] = Var[w] + E[w]^2.
swag_model.sq_mean = server_var + server_mean ** 2
# Square matrix with sq_mean on the diagonal (same values as eye(D) * sq_mean).
# NOTE(review): this is a dense D x D tensor (D = number of weights), so memory grows
# quadratically with model size. It is also scaled by the second moment sq_mean rather
# than by the variance or its square root — confirm this is the intended factor.
swag_model.cov_factor = torch.diag(swag_model.sq_mean)


In [None]:
accuracies = dict()  # method name (e.g. 'swag') -> test accuracy

In [None]:
# Accuracy of the server SWAG predictions averaged by `model_averaging`
# (S=10, presumably the number of posterior samples). `model` is the network
# most recently assigned by an earlier cell.
ytest = np.array(test_loader.dataset.targets)
swag_all_probs = model_averaging(swag_model, model, test_loader, S=10)
accuracies['swag'] = acc_swag = accuracy_all_probs(swag_all_probs, ytest)
acc_swag

In [None]:
# Test accuracy of each client's SWAG mean used directly as network weights.
# Iterates over the rows of Mu_s (one per client) instead of a hard-coded count of
# 10, so it stays correct when n_clients changes.
model = model_cfg.base(*model_cfg.args, **model_cfg.kwargs)
for i, client_mean in enumerate(Mu_s):
    set_weights(model, torch.tensor(client_mean))
    print(f"Mu {i}:" , accuracy_model(model, test_loader, 'cpu'))

In [None]:
# Aggregate the client means weighted by each client's local dataset size.
# Weights are looked up by client index so they line up with the rows of Mu_s
# (row i = client i) instead of relying on client_loaders' iteration order.
client_sizes = [len(client_loaders[i].dataset) for i in range(len(Mu_s))]
new_mu_server = np.average(Mu_s, weights=client_sizes, axis=0)
model = model_cfg.base(*model_cfg.args, **model_cfg.kwargs)
set_weights(model, torch.tensor(new_mu_server))
accuracy_model(model, test_loader, 'cpu')

In [None]:
# Unweighted aggregate: element-wise mean of the client means.
new_mu_server = Mu_s.mean(axis=0)
model = model_cfg.base(*model_cfg.args, **model_cfg.kwargs)
set_weights(model, torch.tensor(new_mu_server))
accuracy_model(model, test_loader, 'cpu')

In [None]:
# Ensemble the clients' test-set predictions, weighted by local dataset size.
# Weights are indexed by client so they align with `probs`, which the SWAG training
# loop fills in client order 0..n_clients-1.
client_sizes = [len(client_loaders[i].dataset) for i in range(len(probs))]
all_probs = np.average(probs, weights=client_sizes, axis=0)
ytest = np.array(test_loader.dataset.targets)
accuracy_all_probs(all_probs, ytest)

In [None]:
# Unweighted ensemble: element-wise mean of the clients' test-set predictions.
ytest = np.array(test_loader.dataset.targets)
all_probs = np.asanyarray(probs).mean(axis=0)
accuracy_all_probs(all_probs, ytest)

In [None]:
# Baseline: test accuracy of the general model (trained or loaded in the earlier cell).
accuracy_model(general_model, test_loader, 'cpu')

In [None]:
# Centralised SGLD baseline on the full training set. Samples are presumably saved
# to sgld_path (per `path_save_samples`) and then used for test-set predictions.
sgld_path = f"./ckpts/sgld_{title}.pt"
sgld = Sgld(model_cfg.base(*model_cfg.args, **model_cfg.kwargs))
state_dict, save_dict = sgld.run(
    train_loader, test_loader, 20,
    params_optimizer={'lr' : 1e-2},
    weight_decay=0.0,
    t_burn_in=5,
    path_save_samples=sgld_path,
)
# NOTE(review): `model` is whichever network an earlier cell assigned last; it is
# presumably only used as an architecture template here — confirm.
sgld_all_probs = np.array(sgld_tools.predictions(test_loader, model, path=sgld_path, device='cpu'))

In [None]:
# Preconditioned SGLD baseline (precondition_decay_rate=0.95) on the full training set.
# Runs the freshly built `psgld` sampler rather than `sgld` from the previous cell,
# using the same third argument (20) as the SGLD cell.
psgld = Sgld(model_cfg.base(*model_cfg.args, **model_cfg.kwargs))
psgld_path = "./ckpts/psgld_" + title + ".pt"
pstate_dict, psave_dict = psgld.run(
    train_loader, test_loader, 20,
    params_optimizer={'lr' : 1e-2, 'precondition_decay_rate' : 0.95},
    weight_decay=0.0,
    t_burn_in=5,
    path_save_samples=psgld_path,
)
psgld_all_probs = np.array(sgld_tools.predictions(test_loader, model, path=psgld_path, device='cpu'))

In [None]:
# Persist calibration scores for each posterior approximation, in a fixed order.
for _method_probs, _method_title in (
    (swag_all_probs, "swag"),
    (sgld_all_probs, "SGLD"),
    (psgld_all_probs, "pSGLD"),
):
    save_calibration_scores(_method_probs, ytest, title=_method_title)

In [None]:
def compute_nll(all_probs, ytest, eps=0.0):
    """Mean negative log-likelihood of the true labels under predicted class probabilities.

    Args:
        all_probs: array-like of shape (n_samples, n_classes); row i holds the
            predicted class probabilities for sample i.
        ytest: array-like of n_samples integer class labels (cast to an index dtype).
        eps: if > 0, the true-class probabilities are clipped to [eps, 1] before the
            log, so a zero probability gives a large finite loss instead of ``inf``.
            The default 0.0 leaves probabilities untouched.

    Returns:
        The average of ``-log p(y_i | x_i)`` over all samples.
    """
    probs = np.asarray(all_probs)
    labels = np.asarray(ytest).astype(np.intp)
    # Pick, for every row, the probability assigned to that row's true label.
    true_class_probs = probs[np.arange(labels.shape[0]), labels]
    if eps > 0:
        true_class_probs = np.clip(true_class_probs, eps, 1.0)
    return (-np.log(true_class_probs)).mean()

In [None]:
# Test-set NLL of the server SWAG model's averaged predictions.
compute_nll(swag_all_probs, ytest)

In [None]:
# Test-set NLL of the SGLD baseline's predictions.
compute_nll(sgld_all_probs, ytest)

In [None]:
# Test-set NLL of the pSGLD baseline's predictions.
compute_nll(psgld_all_probs, ytest)

In [None]:
# Display the server SWAG predictive probabilities on the test set for inspection.
swag_all_probs

In [None]:
# 100 evenly spaced confidence thresholds tau in [0, 1] for the accuracy-confidence plot.
tau_list = np.linspace(0, 1, num=100)

In [None]:
import os

# Calibration diagnostics for the three posterior approximations:
#  1) accuracy - confidence among predictions with confidence >= tau, for tau in [0, 1];
#  2) calibration curve, drawn as (accuracy - confidence) against confidence.
sns.set(rc={"figure.dpi":600, 'savefig.dpi':600})
sns.set_style("darkgrid")
# NOTE(review): `path` is not defined in any visible cell; presumably it comes from
# `from helper import *` — confirm.
path_figures = path + "/figures"
# savefig does not create missing directories, so make sure the target exists.
os.makedirs(path_figures, exist_ok=True)
tau_list = np.linspace(0, 1, num=100)
# Loop variable is `method_probs` so the global `all_probs` (client ensemble) is not clobbered.
methods = [('swag', swag_all_probs), ('sgld', sgld_all_probs), ('psgld', psgld_all_probs)]

# Explicit figure handles: plots never land on a figure left open by an earlier cell.
fig, ax = plt.subplots()
for name, method_probs in methods:
    acc_conf = accuracy_confidence(method_probs, ytest, tau_list, num_bins = 20)
    ax.plot(tau_list, acc_conf, label=name)
ax.set_xlabel(r"$\tau$", fontsize=18)
ax.set_ylabel(r"accuracy - confidence | confidence $\geq \tau$", fontsize=12)
ax.legend()
fig.savefig(path_figures + '/acc_conf-' + title + '.pdf', bbox_inches='tight')
plt.show()

fig, ax = plt.subplots()
for name, method_probs in methods:
    cal_curve = calibration_curve(method_probs, ytest, num_bins = 20)
    ax.plot(cal_curve[1], cal_curve[0] - cal_curve[1], label=name)
ax.set_xlabel("confidence", fontsize=16)
ax.set_ylabel("accuracy - confidence", fontsize=12)
ax.legend()
fig.savefig(path_figures + '/cal_curve-' + title + '.pdf', bbox_inches='tight')
plt.show()