# Centralized MNIST Baseline

A centralized MNIST baseline experiment that trains a simple network, serving as
a point of comparison for the FLoES functionality.

In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

In [None]:
class MNISTModel(torch.nn.Module):
    """Small convolutional classifier for single-channel MNIST images.

    Architecture: one 3x3 convolution (1 -> 32 channels, no padding) with a
    ReLU, then two fully connected layers that map to 10 output logits.
    """

    def __init__(self):
        super().__init__()
        # fc1 expects 32 * 26 * 26 features: a 3x3 convolution without padding
        # turns a 28x28 input into a 26x26 feature map.
        conv_channels, conv_side = 32, 26
        hidden_units, num_classes = 128, 10

        # Registration order is kept stable so state_dict keys / parameter
        # order match checkpoints and optimizers built from this model.
        self.conv1 = torch.nn.Conv2d(1, conv_channels, kernel_size=3)
        self.relu = torch.nn.ReLU()
        self.flatten = torch.nn.Flatten()
        self.fc1 = torch.nn.Linear(conv_channels * conv_side * conv_side, hidden_units)
        self.fc2 = torch.nn.Linear(hidden_units, num_classes)

    def forward(self, x):
        """Return raw (unnormalised) class logits for the batch ``x``."""
        features = self.flatten(self.relu(self.conv1(x)))
        # NOTE(review): there is no activation between fc1 and fc2, so the two
        # linear layers compose into a single affine map. Kept as-is so this
        # baseline stays comparable with the model it is benchmarked against.
        return self.fc2(self.fc1(features))

In [None]:
# Dataset location shared by the train and test splits.
DATA_ROOT = 'data'
TRAIN_BATCH_SIZE = 32
# Evaluation does no backward pass, so large batches are cheap; batch_size=1
# meant 10,000 single-image forward passes per epoch. 1000 evenly divides the
# 10,000-image MNIST test split, so averaging per-batch mean losses still
# yields the exact per-sample mean.
TEST_BATCH_SIZE = 1000

train_data = datasets.MNIST(
    root=DATA_ROOT,
    train=True,
    download=True,
    transform=transforms.ToTensor()
)
train_dataloader = DataLoader(train_data, batch_size=TRAIN_BATCH_SIZE, shuffle=True)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

test_data = datasets.MNIST(
    root=DATA_ROOT,
    train=False,
    download=True,
    transform=transforms.ToTensor()
)
test_dataloader = DataLoader(test_data, batch_size=TEST_BATCH_SIZE, shuffle=False)

In [None]:
NUM_EPOCHS = 200      # full passes over the training set
LEARNING_RATE = 1e-3  # SGD step size


def _train_one_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over every batch yielded by ``dataloader``."""
    model.train()
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        loss = criterion(model(X), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


def _evaluate(model, dataloader, criterion, device):
    """Return ``(mean per-sample loss, fraction correct)`` over ``dataloader``.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.eval()
    total_loss, total_correct, total_seen = 0.0, 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            batch_size = y.size(0)
            # Assumes criterion uses reduction='mean' (the CrossEntropyLoss
            # default used below): weight each batch mean by its size so the
            # result is a true per-sample mean even if the last batch is short.
            total_loss += criterion(pred, y).item() * batch_size
            total_correct += (pred.argmax(1) == y).sum().item()
            total_seen += batch_size
    if total_seen == 0:
        raise ValueError('evaluation dataloader yielded no samples')
    return total_loss / total_seen, total_correct / total_seen


model = MNISTModel()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)

model = model.to(device)
accs = []         # test accuracy (%) recorded after each epoch
test_losses = []  # mean per-sample test loss recorded after each epoch

print(device)

with tqdm(range(NUM_EPOCHS), unit='epochs') as tepochs:
    for _ in tepochs:
        _train_one_epoch(model, train_dataloader, criterion, optimizer, device)
        test_loss, accuracy = _evaluate(model, test_dataloader, criterion, device)

        accs.append(accuracy * 100.)
        test_losses.append(test_loss)
        tepochs.set_postfix_str(f'Last accuracy: {accuracy:.3f}')

In [None]:
import matplotlib.pyplot as plt
import numpy

# accs[k] is the test accuracy measured after k + 1 training epochs, so the
# x-axis and the reported best epoch are 1-based to match epochs completed.
epochs = range(1, len(accs) + 1)

plt.figure()
plt.plot(epochs, accs)
plt.xlabel("Epoch")
plt.ylabel("Accuracy (%)")
plt.title("Centralized MNIST: Epoch vs. Accuracy")
plt.show()

print(f'Final Accuracy: {accs[-1]:.2f}')
best_epoch = int(numpy.argmax(accs)) + 1  # argmax picks the first maximum
print(f'Best Accuracy: {max(accs):.2f} at Epoch: {best_epoch}.')