In [1]:
from federated_learning import Configuration

In [29]:
import torch
import torch.optim as optim
import torch.nn as nn

class Client:
    """A single training participant.

    Holds a network built from ``configs.NETWORK()``, an SGD optimizer
    (``LEARNING_RATE``/``MOMENTUM`` from ``configs``), a cross-entropy
    criterion, and the train/test dataloaders it was given.  Loss curves are
    accumulated on the instance so they can be plotted after training.
    """

    def __init__(self, configs: "Configuration", train_dataloader, test_dataloader):
        self.configs = configs
        self.net = self.configs.NETWORK()
        self.train_dataloader = train_dataloader
        self.test_dataloader = test_dataloader
        self.optimizer = optim.SGD(self.net.parameters(), lr=self.configs.LEARNING_RATE, momentum=self.configs.MOMENTUM)
        self.criterion = nn.CrossEntropyLoss()
        # Logged training losses and the cumulative number of training samples
        # seen (across epochs) at each logging point.
        self.train_losses = []
        self.train_counter = []
        # One test loss per call to test(); test_counter gives the x-position
        # (training samples seen) of each.  There is one evaluation before any
        # training plus one after each of the N_EPOCHS epochs, hence N_EPOCHS + 1.
        self.test_losses = []
        self.test_counter = [i * len(self.train_dataloader.dataset) for i in range(self.configs.N_EPOCHS + 1)]
        # One per-class accuracy tensor per call to test().
        self.class_test_accuracy = []

    def train(self, epoch):
        """Run one epoch of SGD over ``self.train_dataloader``.

        Every ``configs.LOG_INTERVAL`` batches, prints progress and records the
        batch loss in ``train_losses`` together with the cumulative number of
        training samples processed before that batch in ``train_counter``.

        Args:
            epoch: 1-based epoch index; used in the log line and to offset the
                sample counter by ``(epoch - 1) * len(train dataset)``.
        """
        self.net.train()
        n_train = len(self.train_dataloader.dataset)
        epoch_offset = (epoch - 1) * n_train
        samples_seen = 0  # samples processed this epoch before the current batch
        for batch_idx, (data, target) in enumerate(self.train_dataloader):
            self.optimizer.zero_grad()
            output = self.net(data)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            if batch_idx % self.configs.LOG_INTERVAL == 0:
                loss_value = loss.item()
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, samples_seen, n_train,
                    100. * batch_idx / len(self.train_dataloader), loss_value))
                self.train_losses.append(loss_value)
                # Count real samples instead of assuming a fixed batch size.
                self.train_counter.append(samples_seen + epoch_offset)
            samples_seen += len(data)

    def test(self):
        """Evaluate the network on ``self.test_dataloader``.

        Prints the confusion matrix (rows = true class, columns = predicted
        class), the per-class accuracy and the overall loss/accuracy.  Appends
        the per-sample average loss to ``test_losses`` and the per-class
        accuracy tensor to ``class_test_accuracy``.

        Raises:
            ValueError: if the test dataloader yields no samples.
        """
        self.net.eval()
        test_loss = 0.0
        correct = 0
        n_samples = 0
        confusion_matrix = None  # sized lazily from the network's output width
        with torch.no_grad():
            for data, target in self.test_dataloader:
                output = self.net(data)
                batch_size = target.size(0)
                # The criterion uses mean reduction; re-weight by batch size so
                # the final figure is a per-sample average, not a per-batch one.
                test_loss += self.criterion(output, target).item() * batch_size
                pred = output.argmax(dim=1)
                correct += (pred == target).sum().item()
                n_samples += batch_size
                n_classes = output.size(1)
                if confusion_matrix is None:
                    confusion_matrix = torch.zeros(n_classes, n_classes)
                # Vectorized histogram of (true, predicted) index pairs.
                pair_idx = target.view(-1).long() * n_classes + pred.view(-1).long()
                counts = torch.bincount(pair_idx, minlength=n_classes * n_classes)
                confusion_matrix += counts.view(n_classes, n_classes).to(confusion_matrix.dtype)
        if n_samples == 0:
            raise ValueError('test_dataloader yielded no samples')
        test_loss /= n_samples
        self.test_losses.append(test_loss)
        # NaN for classes that never occur in the test set (0 / 0).
        class_accuracy = confusion_matrix.diag() / confusion_matrix.sum(1)
        self.class_test_accuracy.append(class_accuracy)
        print(confusion_matrix)
        print(class_accuracy)
        print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, n_samples, 100. * correct / n_samples))

# Configuration, data loaders and client

In [30]:
# Seed PyTorch's global RNG before the network is initialised and the loaders
# are built, so weight initialisation and any RNG-driven shuffling repeat
# across runs.
SEED = 1
torch.manual_seed(SEED)

config = Configuration()
# The loader object exposes .train_dataloader and .test_dataloader.
mnist = config.DATALOADER(config)
client = Client(config, mnist.train_dataloader, mnist.test_dataloader)

MNIST training loader loaded.
MNIST test loader loaded.


In [31]:
# Evaluate the untrained network first to get a baseline, then alternate one
# training epoch (1-based index) with one evaluation pass.
client.test()
for epoch in range(1, config.N_EPOCHS + 1):
    client.train(epoch=epoch)
    client.test()

tensor([[471.,   0.,   0.,   0.,   0.,   0., 509.,   0.,   0.,   0.],
        [304.,   0.,   0.,  14.,   0.,   0., 817.,   0.,   0.,   0.],
        [298.,   0.,   0.,  10.,   0.,   0., 724.,   0.,   0.,   0.],
        [218.,   0.,   0.,   3.,   0.,   0., 789.,   0.,   0.,   0.],
        [ 61.,   0.,   0.,   0.,   0.,   0., 921.,   0.,   0.,   0.],
        [199.,   0.,   0.,   1.,   0.,   0., 692.,   0.,   0.,   0.],
        [ 76.,   0.,   0.,   7.,   0.,   0., 875.,   0.,   0.,   0.],
        [196.,   0.,   0.,   2.,   0.,   0., 830.,   0.,   0.,   0.],
        [ 28.,   0.,   0.,   0.,   0.,   0., 946.,   0.,   0.,   0.],
        [ 36.,   0.,   0.,   0.,   0.,   0., 973.,   0.,   0.,   0.]])
tensor([0.4806, 0.0000, 0.0000, 0.0030, 0.0000, 0.0000, 0.9134, 0.0000, 0.0000,
        0.0000])

Test set: Avg. loss: 2.3025, Accuracy: 1349/10000 (13%)

tensor([[6.0600e+02, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         3.7400e+02, 0.0000e+00, 0.0000e+00, 0.0000e+00],
     