---
## Preliminaries
---

### Import libraries

In [1]:
import matplotlib.pyplot as plt
plt.rcParams.update({
        'text.usetex' : True,
        'font.family' : 'Times',
        'text.latex.preamble' : r'''
    \usepackage{amsmath}
    \usepackage{amssymb}
    \usepackage{times}
    \usepackage{amsfonts}
''',
        'legend.shadow' : False,
        'legend.framealpha' : 1,
        'legend.fancybox' : True,
        'legend.edgecolor' : 'gray',
    }) 
%matplotlib inline
%config InlineBackend.figure_format = 'svg'

### Project Modules (utilities, client/server, networks)

In [2]:
from utils.data import *
from utils.model import *
from src.client import Client
from src.server import Server
from src.net import MNISTNet, CIFARNet

### Data and CNN

In [3]:
# Load the train/test splits of each benchmark dataset.
cifar10_trainset, cifar10_testset = load_CIFAR10()
fmnist_trainset, fmnist_testset = load_Fashion_MNIST()
mnist_trainset, mnist_testset = load_MNIST()

# Index the splits by dataset name so the experiment cells can select them by
# key (e.g. trainsets["MNIST"]). Both dicts share one key order.
_dataset_names = ("CIFAR10", "FashionMNIST", "MNIST")
trainsets = dict(zip(_dataset_names, (cifar10_trainset, fmnist_trainset, mnist_trainset)))
testsets = dict(zip(_dataset_names, (cifar10_testset, fmnist_testset, mnist_testset)))
del _dataset_names  # temporary helper; keep the notebook namespace clean

Files already downloaded and verified
Files already downloaded and verified


In [4]:
# One network instance per architecture; the baseline cells below hand these
# instances to Server.
# NOTE(review): if Server trains the passed network in place, re-running a
# baseline cell resumes from already-trained weights — confirm in src.server.
net_MNIST, net_CIFAR = MNISTNet(), CIFARNet()

---
## Baseline Models
---

### IID Scenario

In [None]:
# IID baseline on MNIST: 20 clients, 50 training rounds.
server = Server(
    net_MNIST,
    trainsets["MNIST"],
    testsets["MNIST"],
    num_clients=20
)
server.train(rounds=50)
# Plot immediately, mirroring the CIFAR-10 baseline. The next cell rebinds
# `server`, so without this the MNIST run's results are lost once it executes.
# NOTE(review): the plot argument follows the "testCIFAR10" naming pattern;
# confirm its meaning in src.server.
server.plot("testMNIST")

In [None]:
# IID baseline on CIFAR-10: 20 clients, 50 training rounds, then plot.
server = Server(net_CIFAR, trainsets["CIFAR10"], testsets["CIFAR10"], num_clients=20)
server.train(rounds=50)
# NOTE(review): the argument appears to name the plot/output — confirm in src.server.
server.plot("testCIFAR10")

Round 50/50 (Testing)			: 100%|██████████| 50/50 [4:49:36<00:00, 347.53s/it]              


Avg:
tensor([1., 1.])
