In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

from util import *
from model import SimpleCNN, Net
from train_fed import train_fed
from load_data import load_iid, load_noniid

In [None]:
# Hyperparameters shared by every experiment in this notebook
num_models = 5          # how many local models run_fed builds
learning_rate = 1e-3    # lr given to each local Adam optimizer
batch_size = 100        # forwarded to the data-loading helper

# Fix the random seed
seed = 42
set_seed(seed)

In [None]:
# One train_size entry (5000) per local model. The list is sized from num_models
# so the split count cannot drift away from the number of models built in run_fed
# (presumably load_iid returns one loader per entry -- confirm in load_data).
train_loaders, test_loader = load_iid(batch_size, train_size=[5000] * num_models) # iid case
# train_loaders, test_loader = load_noniid(batch_size, train_size=[5000,5000,5000,5000,5000], min_label=[0,0,0,5,5], max_label=[4,4,4,9,9]) # non-iid case

In [None]:
def build_model():
    """Return a newly constructed network.

    Net() is the large CNN; swap in SimpleCNN() to use the small CNN instead.
    """
    model = Net()
    return model
# Build one throwaway model instance and report its parameter counts
net = build_model()
n_params = count_parameters(net)
print('total parameters = ', n_params)
print(count_detailed_parameters(net))

In [None]:
# Prefer the first CUDA GPU when one is available; otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("Using device:", device)

def run_fed(num_epochs, top_percent, local_buffer=True, global_buffer=True):
    """Run one federated-training experiment on freshly built models.

    Builds `num_models` local models and one global model on `device`, gives
    each local model its own Adam optimizer (lr=learning_rate) and its own
    CrossEntropyLoss, then passes everything to train_fed together with the
    notebook-level `train_loaders` and `test_loader`.

    Args:
        num_epochs: forwarded unchanged to train_fed.
        top_percent: forwarded unchanged to train_fed (the experiment cells
            describe a value of 1 as "no filter").
        local_buffer: forwarded unchanged to train_fed.
        global_buffer: forwarded unchanged to train_fed.

    Returns:
        The value returned by train_fed (callers store it as test accuracy).
    """
    # NOTE(review): the seed is not reset here, so if model construction or
    # training draws random numbers, each run continues from the RNG state left
    # by earlier cells -- confirm this is intended.
    # Local models are created before the global one; keep this order.
    local_models = []
    for _ in range(num_models):
        local_models.append(build_model().to(device))
    global_model = build_model().to(device)

    optimizers = [optim.Adam(model.parameters(), lr=learning_rate) for model in local_models]
    criterions = [nn.CrossEntropyLoss() for _ in local_models]

    return train_fed(
        device, num_epochs, top_percent,
        train_loaders, test_loader,
        optimizers, criterions,
        local_models, global_model,
        local_buffer, global_buffer,
    )

In [None]:
# Run 1: 5 epochs, top_percent = 1 (means no filter), both buffers off
test_acc1 = run_fed(
    num_epochs=5,
    top_percent=1,
    local_buffer=False,
    global_buffer=False,
)

In [None]:
# Run 2: 5 epochs, top_percent = 1, both buffers on
test_acc2 = run_fed(
    num_epochs=5,
    top_percent=1,
    local_buffer=True,
    global_buffer=True,
)

In [None]:
# Run 3: 50 epochs, top_percent = 0.1, both buffers off
test_acc3 = run_fed(
    num_epochs=50,
    top_percent=0.1,
    local_buffer=False,
    global_buffer=False,
)

In [None]:
# Run 4: 50 epochs, top_percent = 0.1, local buffer only
test_acc4 = run_fed(
    num_epochs=50,
    top_percent=0.1,
    local_buffer=True,
    global_buffer=False,
)

In [None]:
# Run 5: 50 epochs, top_percent = 0.1, global buffer only
test_acc5 = run_fed(
    num_epochs=50,
    top_percent=0.1,
    local_buffer=False,
    global_buffer=True,
)

In [None]:
# Run 6: 50 epochs, top_percent = 0.1, both buffers on
test_acc6 = run_fed(
    num_epochs=50,
    top_percent=0.1,
    local_buffer=True,
    global_buffer=True,
)

In [None]:
# Store the results of all six runs in a single archive, keyed by variable name
np.savez(
    'LargeModel_iid.npz',
    test_acc1=test_acc1,
    test_acc2=test_acc2,
    test_acc3=test_acc3,
    test_acc4=test_acc4,
    test_acc5=test_acc5,
    test_acc6=test_acc6,
)

In [None]:
# Imports are repeated here so this cell can be run on its own from the saved
# archive, without executing the training cells first.
import numpy as np
import matplotlib.pyplot as plt

# Read the saved results back. The context manager closes the .npz archive once
# the arrays have been pulled into memory.
with np.load('LargeModel_iid.npz') as loaded_arrays:
    test_acc1 = loaded_arrays['test_acc1']
    test_acc2 = loaded_arrays['test_acc2']
    test_acc3 = loaded_arrays['test_acc3']
    test_acc4 = loaded_arrays['test_acc4']
    test_acc5 = loaded_arrays['test_acc5']
    test_acc6 = loaded_arrays['test_acc6']

# All runs are drawn over the same 0-10 x range. The number of points is taken
# from the data itself so the plot keeps working if the epoch counts used for
# training are changed.
x0 = np.linspace(0, 10, len(test_acc1))   # x values for the top_percent = 1 runs
x = np.linspace(0, 10, len(test_acc3))    # x values for the top_percent = 0.1 runs

# Create a figure and axis
fig, ax = plt.subplots(figsize=(8,6),dpi=150)

# Plot the curves (runs 4 and 5 are left out of the figure; uncomment to add them)
ax.plot(x0, test_acc1, label='fed')
ax.plot(x0, test_acc2, label='fed+buffer')
ax.plot(x, test_acc3, label='fed+filter')
#ax.plot(x, test_acc4, label='fed+filter+localbuffer')
#ax.plot(x, test_acc5, label='fed+filter+globalbuffer')
ax.plot(x, test_acc6, label='fed+filter+buffer')

# Set axis labels
ax.set_xlabel('communication cost')
ax.set_ylabel('accuracy %')

# Set the plot title
ax.set_title('IID, 5 Local Models, Large CNN')

# Add a legend
ax.legend()

# Show the plot
plt.show()