In [None]:
# Mount Google Drive so the MNIST CSVs under /content/drive/MyDrive/mnistDataset
# (read in the data-loading cell) are reachable from this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
%matplotlib inline
!pip install MulticoreTSNE

%matplotlib inline
from MulticoreTSNE import MulticoreTSNE as TSNE
from matplotlib import pyplot as plt
import torch
import copy
from torchvision import datasets, transforms
from torch import nn
import torch.nn.functional as F
import numpy as np
import pandas as pd 
from sklearn.preprocessing import Normalizer
from tqdm import tqdm_notebook
# Seed the torch and numpy global RNGs (in that order) and force deterministic,
# non-autotuned cuDNN kernels to reduce run-to-run variation.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True



In [None]:
# DataLoader batch sizes: unlabeled pool, labeled pool, and test set.
UNLABELED_BS = 256
TRAIN_BS = 32
TEST_BS = 1024

# Despite its name, `label_unlabel_ratio` is the labeled *fraction*: the first
# `num_train_samples` rows of the training CSV keep their labels, the rest are
# treated as unlabeled by getClientDataset.
total_samples = 60000
label_unlabel_ratio = 0.1
num_train_samples = int(total_samples * label_unlabel_ratio)

# Split the 'label' column off each frame; the remaining columns are pixel features.
x = pd.read_csv('/content/drive/MyDrive/mnistDataset/mnist_train.csv')
y = x.pop('label')

x_test = pd.read_csv('/content/drive/MyDrive/mnistDataset/mnist_test.csv')
y_test = x_test.pop('label')

# sklearn Normalizer: rescales each sample row to unit norm; shared by the
# train and test preprocessing helpers.
normalizer = Normalizer()

In [None]:
# Architecture from : https://github.com/peimengsui/semi_supervised_mnist
class Net(nn.Module):
    """Small LeNet-style CNN for 28x28 single-channel digit images.

    Input: any tensor with 784 values per sample (reshaped to ``(-1, 1, 28, 28)``).
    Output: ``(batch, 10)`` log-probabilities, suitable for ``F.nll_loss``.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, kernel_size=5)
        self.conv2 = nn.Conv2d(20, 40, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        # 28 -conv5-> 24 -pool2-> 12 -conv5-> 8 -pool2-> 4, so 40 * 4 * 4 = 640 features.
        self.fc1 = nn.Linear(640, 150)
        self.fc2 = nn.Linear(150, 10)
        self.log_softmax = nn.LogSoftmax(dim = 1)

    def forward(self, x):
        x = x.view(-1, 1, 28, 28)
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 640)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        # No ReLU on the output layer: log-softmax needs unconstrained logits.
        # A ReLU here clamps every negative logit to 0, pushing predictions toward
        # uniform (likely why supervised training stalled at loss ~2.30 = ln 10).
        x = self.fc2(x)
        return self.log_softmax(x)
        
# NOTE(review): `net` is not referenced by any other cell shown in this notebook;
# it only allocates an extra model on the GPU and looks safe to remove.
net = Net().cuda()

In [None]:
def getClientDataset(num_clients, x, y):
    """Shard the training data into per-client labeled and unlabeled DataLoaders.

    Rows ``[0, num_train_samples)`` of ``x``/``y`` form the labeled pool; the rest of
    ``x`` is the unlabeled pool. Each pool is cut into ``num_clients`` equal contiguous
    shards (remainder rows are dropped). Features go through the module-level
    ``normalizer`` before being wrapped as float tensors.

    Returns:
        (labeled_loaders, unlabeled_loaders): one DataLoader per client in each list.
        Labeled loaders yield ``(features, labels)``; unlabeled loaders yield
        1-tuples ``(features,)``.
    """
    labeled_x = x.iloc[:num_train_samples]
    labeled_y = y.iloc[:num_train_samples]
    unlabeled_x = x.iloc[num_train_samples:]

    lab_shard = int(labeled_x.shape[0] / num_clients)
    unlab_shard = int(unlabeled_x.shape[0] / num_clients)

    labeled_loaders, unlabeled_loaders = [], []
    for client in range(num_clients):
        lab_lo, lab_hi = client * lab_shard, (client + 1) * lab_shard
        unlab_lo, unlab_hi = client * unlab_shard, (client + 1) * unlab_shard

        feats = labeled_x.iloc[lab_lo:lab_hi, :].values
        targets = labeled_y.iloc[lab_lo:lab_hi].values
        unlab_feats = unlabeled_x.iloc[unlab_lo:unlab_hi, :].values
        print(feats.shape)
        print(targets.shape)
        print(unlab_feats.shape)

        # Fit on the labeled shard, then reuse for the unlabeled shard.
        feats = normalizer.fit_transform(feats)
        unlab_feats = normalizer.transform(unlab_feats)

        labeled_ds = torch.utils.data.TensorDataset(
            torch.from_numpy(feats).type(torch.FloatTensor),
            torch.from_numpy(targets).type(torch.LongTensor),
        )
        unlabeled_ds = torch.utils.data.TensorDataset(
            torch.from_numpy(unlab_feats).type(torch.FloatTensor)
        )

        labeled_loaders.append(torch.utils.data.DataLoader(
            labeled_ds, batch_size=TRAIN_BS, shuffle=True, num_workers=8))
        unlabeled_loaders.append(torch.utils.data.DataLoader(
            unlabeled_ds, batch_size=UNLABELED_BS, shuffle=True, num_workers=8))

    return labeled_loaders, unlabeled_loaders


def getTestData(x_test, y_test):
    """Wrap the test features/labels in a shuffled DataLoader of size-``TEST_BS`` batches.

    Features pass through the module-level ``normalizer`` (transform only, no refit)
    and become float tensors; labels become int64 tensors.
    """
    features = torch.from_numpy(normalizer.transform(x_test.values)).type(torch.FloatTensor)
    labels = torch.from_numpy(y_test.values).type(torch.LongTensor)
    return torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(features, labels),
        batch_size=TEST_BS,
        shuffle=True,
        num_workers=8,
    )


def getClientModels(num_clients):
    """Return a list of ``num_clients`` independently constructed ``Net`` models on the GPU."""
    return [Net().cuda() for _ in range(num_clients)]


def getGlobalModel():
    """Build the server-side ``Net`` (moved to the GPU) that receives averaged client weights."""
    server_net = Net()
    return server_net.cuda()


def updateGlobalModel(num_client, client_models, global_model):
    """Federated averaging: set each parameter of ``global_model`` to the element-wise
    mean of the matching parameters of ``client_models[:num_client]``.

    Client models are left unmodified. The global parameters are updated in place
    (same Parameter objects), so existing references to them stay valid.

    Args:
        num_client: how many clients (from the front of ``client_models``) to average.
        client_models: models sharing ``global_model``'s architecture.
        global_model: receives the averaged weights.

    Returns:
        ``global_model``.

    Raises:
        ValueError: if ``num_client`` is < 1 or exceeds ``len(client_models)``, or if a
            client's parameter list does not line up with the global model's.
    """
    if num_client < 1 or len(client_models) < num_client:
        raise ValueError(
            "updateGlobalModel: need 1 <= num_client <= len(client_models), got "
            "num_client={} with {} models".format(num_client, len(client_models)))
    participants = client_models[:num_client]

    global_params = list(global_model.parameters())
    client_params = [list(model.parameters()) for model in participants]
    if any(len(params) != len(global_params) for params in client_params):
        raise ValueError("updateGlobalModel: client and global architectures differ")

    with torch.no_grad():
        for idx, global_param in enumerate(global_params):
            # torch.stack makes a fresh tensor, so no client's weights are mutated
            # (the previous `+=` accumulated into client 0's own tensors).
            stacked = torch.stack(
                [params[idx].detach().to(global_param.device) for params in client_params])
            global_param.copy_(stacked.mean(dim=0))

    return global_model





In [None]:
def evaluate(model, test_loader):
    """Score ``model`` on ``test_loader`` without updating it.

    Switches the model to eval mode and leaves it there (callers switch back with
    ``model.train()``). Each batch is moved to the device of the model's parameters.

    Returns:
        (accuracy_percent, mean_nll): accuracy in percent and the negative
        log-likelihood averaged over *samples*, so a short final batch is not
        over-weighted and the result does not depend on batch order.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    model.eval()
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    correct = 0
    total_nll = 0.0
    seen = 0
    with torch.no_grad():
        for data, labels in test_loader:
            data, labels = data.to(device), labels.to(device)
            output = model(data)  # assumed: log-probabilities, shape (batch, classes)
            predicted = torch.max(output, 1)[1]
            correct += (predicted == labels).sum().item()
            # Sum per batch so the final division yields a true per-sample mean.
            total_nll += F.nll_loss(output, labels, reduction='sum').item()
            seen += labels.shape[0]

    if seen == 0:
        raise ValueError("evaluate: test_loader produced no samples")
    return 100.0 * correct / seen, total_nll / seen

In [None]:
def train_supervised(model, train_loader, test_loader, EPOCHS, lr=0.1):
    """Train ``model`` with plain SGD on labeled batches, evaluating after every epoch.

    Args:
        model: network returning log-probabilities; trained in place.
        train_loader: yields ``(inputs, labels)`` batches.
        test_loader: passed to ``evaluate`` after each epoch.
        EPOCHS: number of passes over ``train_loader``.
        lr: SGD learning rate (defaults to the previously hard-coded 0.1).

    Prints one line per epoch with the per-sample mean training loss and the
    test accuracy/loss. The model is left in training mode.
    """
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    device = next(model.parameters()).device
    model.train()

    for epoch in tqdm_notebook(range(EPOCHS)):
        loss_sum = 0.0
        seen = 0
        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            output = model(X_batch)
            labeled_loss = F.nll_loss(output, y_batch)

            optimizer.zero_grad()
            labeled_loss.backward()
            optimizer.step()

            # nll_loss is a batch mean: weight by batch size so that dividing by the
            # sample count gives a per-sample mean (previously a sum of batch means
            # was divided by the sample count, which is not a meaningful loss).
            loss_sum += labeled_loss.item() * y_batch.shape[0]
            seen += y_batch.shape[0]

        test_acc, test_loss = evaluate(model, test_loader)
        print('Epoch: {} : Train Loss : {:.5f} | Test Acc : {:.5f} | Test Loss : {:.3f} '.format(
            epoch, loss_sum / max(seen, 1), test_acc, test_loss))
        model.train()  # evaluate() switched the model to eval mode

In [None]:
# Pseudo-label loss schedule: weight 0 before position T1, a linear ramp from
# T1 to T2, then a constant `af` afterwards.
T1 = 100
T2 = 700
af = 3

def alpha_weight(epoch):
    """Return the pseudo-label loss weight at schedule position ``epoch``.

    ``0.0`` below ``T1``, ``af`` above ``T2``, linear interpolation in between.
    """
    if epoch > T2:
        return af
    if epoch < T1:
        return 0.0
    ramp_fraction = (epoch - T1) / (T2 - T1)
    return ramp_fraction * af

In [None]:
from tqdm import tqdm_notebook

acc_scores = []
unlabel = []
pseudo_label = []

alpha_log = []
test_acc_log = []
test_loss_log = []
def semisup_train(model, train_loader, unlabeled_loader, test_loader, EPOCHS, lr=0.1):
    """Alpha-weighted pseudo-label training, evaluating after every epoch.

    For each unlabeled batch, predict pseudo-labels with the model in eval mode.
    Then take one SGD step on ``alpha_weight(step) * nll(pseudo-labels)``. Every 50
    unlabeled batches (batch indices 0, 50, 100, ...), run one full supervised pass
    over ``train_loader`` and advance ``step`` by one.

    Side effects: appends to the module-level lists ``unlabel``/``pseudo_label``
    (first 3 batches of every 10th epoch, for visualization) and to
    ``alpha_log``/``test_acc_log``/``test_loss_log`` (once per epoch).

    Args:
        model: network returning log-probabilities; trained in place.
        train_loader: yields labeled ``(inputs, labels)`` batches.
        unlabeled_loader: yields 1-tuples ``(inputs,)``.
        test_loader: passed to ``evaluate`` after each epoch.
        EPOCHS: number of passes over ``unlabeled_loader``.
        lr: SGD learning rate (defaults to the previously hard-coded 0.1).
    """
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    device = next(model.parameters()).device

    # Position fed to alpha_weight(). It starts at T1 (100), so the unlabeled loss
    # ramps up from the first increment. It advances once per labeled pass rather
    # than once per epoch, and restarts at 100 on every call.
    step = 100

    model.train()
    for epoch in tqdm_notebook(range(EPOCHS)):
        for batch_idx, x_unlabeled in enumerate(unlabeled_loader):
            x_unlabeled = x_unlabeled[0].to(device)

            # Pseudo-labels come from a dropout-free forward pass. Run it under no_grad
            # because argmax labels carry no gradient and the graph would only hold memory.
            model.eval()
            with torch.no_grad():
                output_unlabeled = model(x_unlabeled)
                _, pseudo_labeled = torch.max(output_unlabeled, 1)
            model.train()

            # Visualization only: keep a few (inputs, pseudo-labels) batches.
            if batch_idx < 3 and epoch % 10 == 0:
                unlabel.append(x_unlabeled.cpu())
                pseudo_label.append(pseudo_labeled.cpu())

            # Unlabeled loss against the pseudo-labels, scaled by the schedule weight.
            output = model(x_unlabeled)
            unlabeled_loss = alpha_weight(step) * F.nll_loss(output, pseudo_labeled)

            optimizer.zero_grad()
            unlabeled_loss.backward()
            optimizer.step()

            # Every 50 unlabeled batches, run one full epoch on the labeled data.
            # The inner loop does not reuse `batch_idx`, so the outer index is not clobbered.
            if batch_idx % 50 == 0:
                for X_batch, y_batch in train_loader:
                    X_batch, y_batch = X_batch.to(device), y_batch.to(device)
                    output = model(X_batch)
                    labeled_loss = F.nll_loss(output, y_batch)

                    optimizer.zero_grad()
                    labeled_loss.backward()
                    optimizer.step()

                step += 1

        test_acc, test_loss = evaluate(model, test_loader)
        print('Epoch: {} : Alpha Weight : {:.5f} | Test Acc : {:.5f} | Test Loss : {:.3f} '.format(epoch, alpha_weight(step), test_acc, test_loss))

        # Per-epoch logs for later plotting.
        alpha_log.append(alpha_weight(step))
        test_acc_log.append(test_acc/100)
        test_loss_log.append(test_loss)
        model.train()  # evaluate() switched the model to eval mode

In [None]:
num_global_iter = 5
num_clients = 4
num_local_epochs = 5

print("======================================== !!! Starting Federated Learning !!! ==================================================")

global_model = getGlobalModel()
# FedAvg needs every client to start each round from the *same* weights. Averaging
# independently initialised networks (the previous getClientModels() start) mixes
# unrelated hidden units. That is likely why round 1's global accuracy was ~24%
# while the individual clients reached 66-86%.
client_models = [copy.deepcopy(global_model) for _ in range(num_clients)]

client_train_lab, client_train_unlab = getClientDataset(num_clients, x, y)
test_loader = getTestData(x_test, y_test)

for global_round in range(num_global_iter):

    print("============================================ Global Training Iteration : " + str(global_round+1) + " ========================================\n\n")

    for i in range(num_clients):

        print("==================================== Supervised Training on Client : " + str(i+1) + "============================================\n")

        train_supervised(client_models[i], client_train_lab[i], test_loader, num_local_epochs)

        print("====================================== Semi Supervised Alpha Weighted Training =========================================\n")

        semisup_train(client_models[i], client_train_lab[i], client_train_unlab[i], test_loader, num_local_epochs)

    global_model = updateGlobalModel(num_clients, client_models, global_model)

    test_acc, test_loss = evaluate(global_model, test_loader)

    print("Global Model Accuracy : ", test_acc)
    print("Global Model Loss : ", test_loss)

    # Broadcast: every client starts the next round from the new global weights.
    client_models = [copy.deepcopy(global_model) for _ in range(num_clients)]


(1500, 784)
(1500,)
(13500, 784)
(1500, 784)
(1500,)
(13500, 784)
(1500, 784)
(1500,)
(13500, 784)


  cpuset_checked))
  cpuset_checked))
  cpuset_checked))


(1500, 784)
(1500,)
(13500, 784)





  cpuset_checked))
  cpuset_checked))
Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  """


  0%|          | 0/5 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc59a039050>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
    if w.is_alive():
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc59a039050>
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
Traceback (most recent call last):
AssertionError: can only test a child process
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1328, in __del__
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc59a039050>
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc59a039050>
    

Epoch: 0 : Train Loss : 0.07212 | Test Acc : 10.28000 | Test Loss : 2.301 


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc59a039050>
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc59a039050>
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc59a039050>
Traceback (most recent call last):
Traceback (most recent call last):
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc59a039050>
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1328, in __del__
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc59a039050>
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1328, in __del__
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc59a039050>
Traceback (most recent call last):
Traceback (most recent call last):
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc59a039050>
Traceback (most recent cal

Epoch: 1 : Train Loss : 0.07209 | Test Acc : 10.28000 | Test Loss : 2.301 


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc59a039050>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1328, in __del__
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc59a039050>
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc59a039050>
    self._shutdown_workers()
Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc59a039050>
    if w.is_alive():
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1328, in __del__
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1328, in __del__
Traceback (most recent call last):
  File "/usr/lib/python3.7/multi

Epoch: 2 : Train Loss : 0.07204 | Test Acc : 20.05000 | Test Loss : 2.301 


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc59a039050>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
    if w.is_alive():
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc59a039050>
Traceback (most recent call last):
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1328, in __del__
AssertionError: can only test a child process
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/pytho

Epoch: 3 : Train Loss : 0.07201 | Test Acc : 10.28000 | Test Loss : 2.301 
Epoch: 4 : Train Loss : 0.07197 | Test Acc : 10.28000 | Test Loss : 2.301 



Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


  0%|          | 0/5 [00:00<?, ?it/s]

Epoch: 0 : Alpha Weight : 0.01000 | Test Acc : 18.10000 | Test Loss : 2.296 
Epoch: 1 : Alpha Weight : 0.02000 | Test Acc : 19.14000 | Test Loss : 2.278 
Epoch: 2 : Alpha Weight : 0.03000 | Test Acc : 25.31000 | Test Loss : 2.087 
Epoch: 3 : Alpha Weight : 0.04000 | Test Acc : 51.64000 | Test Loss : 1.515 
Epoch: 4 : Alpha Weight : 0.05000 | Test Acc : 71.16000 | Test Loss : 0.901 



  0%|          | 0/5 [00:00<?, ?it/s]

Epoch: 0 : Train Loss : 0.07217 | Test Acc : 11.35000 | Test Loss : 2.302 
Epoch: 1 : Train Loss : 0.07210 | Test Acc : 11.35000 | Test Loss : 2.301 
Epoch: 2 : Train Loss : 0.07208 | Test Acc : 11.35000 | Test Loss : 2.301 
Epoch: 3 : Train Loss : 0.07205 | Test Acc : 19.14000 | Test Loss : 2.301 
Epoch: 4 : Train Loss : 0.07201 | Test Acc : 26.34000 | Test Loss : 2.300 



  0%|          | 0/5 [00:00<?, ?it/s]

Epoch: 0 : Alpha Weight : 0.01000 | Test Acc : 21.98000 | Test Loss : 2.292 
Epoch: 1 : Alpha Weight : 0.02000 | Test Acc : 23.25000 | Test Loss : 2.226 
Epoch: 2 : Alpha Weight : 0.03000 | Test Acc : 50.66000 | Test Loss : 1.726 
Epoch: 3 : Alpha Weight : 0.04000 | Test Acc : 76.71000 | Test Loss : 0.781 
Epoch: 4 : Alpha Weight : 0.05000 | Test Acc : 81.35000 | Test Loss : 0.564 



  0%|          | 0/5 [00:00<?, ?it/s]

Epoch: 0 : Train Loss : 0.07213 | Test Acc : 10.28000 | Test Loss : 2.303 
Epoch: 1 : Train Loss : 0.07210 | Test Acc : 10.28000 | Test Loss : 2.302 
Epoch: 2 : Train Loss : 0.07209 | Test Acc : 10.28000 | Test Loss : 2.301 
Epoch: 3 : Train Loss : 0.07207 | Test Acc : 20.31000 | Test Loss : 2.300 
Epoch: 4 : Train Loss : 0.07200 | Test Acc : 20.29000 | Test Loss : 2.298 



  0%|          | 0/5 [00:00<?, ?it/s]

Epoch: 0 : Alpha Weight : 0.01000 | Test Acc : 20.60000 | Test Loss : 2.286 
Epoch: 1 : Alpha Weight : 0.02000 | Test Acc : 25.96000 | Test Loss : 2.171 
Epoch: 2 : Alpha Weight : 0.03000 | Test Acc : 47.63000 | Test Loss : 1.434 
Epoch: 3 : Alpha Weight : 0.04000 | Test Acc : 79.51000 | Test Loss : 0.688 
Epoch: 4 : Alpha Weight : 0.05000 | Test Acc : 85.92000 | Test Loss : 0.471 



  0%|          | 0/5 [00:00<?, ?it/s]

Epoch: 0 : Train Loss : 0.07216 | Test Acc : 9.80000 | Test Loss : 2.303 
Epoch: 1 : Train Loss : 0.07214 | Test Acc : 9.80000 | Test Loss : 2.303 
Epoch: 2 : Train Loss : 0.07210 | Test Acc : 9.80000 | Test Loss : 2.303 
Epoch: 3 : Train Loss : 0.07210 | Test Acc : 18.70000 | Test Loss : 2.302 
Epoch: 4 : Train Loss : 0.07206 | Test Acc : 10.20000 | Test Loss : 2.301 



  0%|          | 0/5 [00:00<?, ?it/s]

Epoch: 0 : Alpha Weight : 0.01000 | Test Acc : 19.51000 | Test Loss : 2.298 
Epoch: 1 : Alpha Weight : 0.02000 | Test Acc : 19.10000 | Test Loss : 2.287 
Epoch: 2 : Alpha Weight : 0.03000 | Test Acc : 29.34000 | Test Loss : 2.218 
Epoch: 3 : Alpha Weight : 0.04000 | Test Acc : 45.04000 | Test Loss : 1.755 
Epoch: 4 : Alpha Weight : 0.05000 | Test Acc : 66.23000 | Test Loss : 0.993 
Global Model Accuracy :  24.169999999999998
Global Model Loss :  2.232388663291931





  0%|          | 0/5 [00:00<?, ?it/s]

Epoch: 0 : Train Loss : 0.06162 | Test Acc : 55.66000 | Test Loss : 1.470 
Epoch: 1 : Train Loss : 0.04019 | Test Acc : 72.10000 | Test Loss : 0.872 
Epoch: 2 : Train Loss : 0.03009 | Test Acc : 77.46000 | Test Loss : 0.677 
Epoch: 3 : Train Loss : 0.02602 | Test Acc : 81.78000 | Test Loss : 0.585 
Epoch: 4 : Train Loss : 0.02275 | Test Acc : 84.74000 | Test Loss : 0.489 



  0%|          | 0/5 [00:00<?, ?it/s]

Epoch: 0 : Alpha Weight : 0.01000 | Test Acc : 87.57000 | Test Loss : 0.396 
Epoch: 1 : Alpha Weight : 0.02000 | Test Acc : 89.59000 | Test Loss : 0.327 
Epoch: 2 : Alpha Weight : 0.03000 | Test Acc : 89.94000 | Test Loss : 0.317 
Epoch: 3 : Alpha Weight : 0.04000 | Test Acc : 91.90000 | Test Loss : 0.261 
Epoch: 4 : Alpha Weight : 0.05000 | Test Acc : 92.45000 | Test Loss : 0.243 



  0%|          | 0/5 [00:00<?, ?it/s]

Epoch: 0 : Train Loss : 0.05962 | Test Acc : 58.54000 | Test Loss : 1.359 
Epoch: 1 : Train Loss : 0.03513 | Test Acc : 69.83000 | Test Loss : 0.850 
Epoch: 2 : Train Loss : 0.02512 | Test Acc : 79.62000 | Test Loss : 0.635 
Epoch: 3 : Train Loss : 0.02100 | Test Acc : 76.99000 | Test Loss : 0.651 
Epoch: 4 : Train Loss : 0.01788 | Test Acc : 85.32000 | Test Loss : 0.471 



  0%|          | 0/5 [00:00<?, ?it/s]

Epoch: 0 : Alpha Weight : 0.01000 | Test Acc : 87.24000 | Test Loss : 0.387 
Epoch: 1 : Alpha Weight : 0.02000 | Test Acc : 89.57000 | Test Loss : 0.329 
Epoch: 2 : Alpha Weight : 0.03000 | Test Acc : 90.95000 | Test Loss : 0.289 
Epoch: 3 : Alpha Weight : 0.04000 | Test Acc : 90.76000 | Test Loss : 0.285 
Epoch: 4 : Alpha Weight : 0.05000 | Test Acc : 92.32000 | Test Loss : 0.246 



  0%|          | 0/5 [00:00<?, ?it/s]

Epoch: 0 : Train Loss : 0.06065 | Test Acc : 57.20000 | Test Loss : 1.391 
Epoch: 1 : Train Loss : 0.03679 | Test Acc : 72.86000 | Test Loss : 0.839 
Epoch: 2 : Train Loss : 0.02730 | Test Acc : 79.88000 | Test Loss : 0.648 
Epoch: 3 : Train Loss : 0.02329 | Test Acc : 83.65000 | Test Loss : 0.524 
Epoch: 4 : Train Loss : 0.02005 | Test Acc : 84.62000 | Test Loss : 0.456 



  0%|          | 0/5 [00:00<?, ?it/s]

Epoch: 0 : Alpha Weight : 0.01000 | Test Acc : 88.53000 | Test Loss : 0.383 
Epoch: 1 : Alpha Weight : 0.02000 | Test Acc : 89.23000 | Test Loss : 0.344 
Epoch: 2 : Alpha Weight : 0.03000 | Test Acc : 90.44000 | Test Loss : 0.303 
Epoch: 3 : Alpha Weight : 0.04000 | Test Acc : 91.52000 | Test Loss : 0.275 
Epoch: 4 : Alpha Weight : 0.05000 | Test Acc : 92.44000 | Test Loss : 0.243 



  0%|          | 0/5 [00:00<?, ?it/s]

Epoch: 0 : Train Loss : 0.06114 | Test Acc : 60.45000 | Test Loss : 1.406 
Epoch: 1 : Train Loss : 0.03715 | Test Acc : 76.86000 | Test Loss : 0.787 
Epoch: 2 : Train Loss : 0.02867 | Test Acc : 78.72000 | Test Loss : 0.635 
Epoch: 3 : Train Loss : 0.02341 | Test Acc : 83.65000 | Test Loss : 0.532 
Epoch: 4 : Train Loss : 0.02055 | Test Acc : 84.58000 | Test Loss : 0.493 



  0%|          | 0/5 [00:00<?, ?it/s]

Epoch: 0 : Alpha Weight : 0.01000 | Test Acc : 86.79000 | Test Loss : 0.417 
Epoch: 1 : Alpha Weight : 0.02000 | Test Acc : 88.95000 | Test Loss : 0.345 
Epoch: 2 : Alpha Weight : 0.03000 | Test Acc : 90.65000 | Test Loss : 0.300 
Epoch: 3 : Alpha Weight : 0.04000 | Test Acc : 91.75000 | Test Loss : 0.269 
Epoch: 4 : Alpha Weight : 0.05000 | Test Acc : 92.23000 | Test Loss : 0.248 
Global Model Accuracy :  92.86999999999999
Global Model Loss :  0.236835315823555





  0%|          | 0/5 [00:00<?, ?it/s]

Epoch: 0 : Train Loss : 0.01102 | Test Acc : 92.88000 | Test Loss : 0.230 
Epoch: 1 : Train Loss : 0.01088 | Test Acc : 92.54000 | Test Loss : 0.234 
Epoch: 2 : Train Loss : 0.00960 | Test Acc : 92.78000 | Test Loss : 0.233 
Epoch: 3 : Train Loss : 0.00917 | Test Acc : 93.31000 | Test Loss : 0.211 
Epoch: 4 : Train Loss : 0.00922 | Test Acc : 93.57000 | Test Loss : 0.202 



  0%|          | 0/5 [00:00<?, ?it/s]

Epoch: 0 : Alpha Weight : 0.01000 | Test Acc : 93.49000 | Test Loss : 0.204 
Epoch: 1 : Alpha Weight : 0.02000 | Test Acc : 93.87000 | Test Loss : 0.199 
Epoch: 2 : Alpha Weight : 0.03000 | Test Acc : 94.05000 | Test Loss : 0.186 
Epoch: 3 : Alpha Weight : 0.04000 | Test Acc : 94.81000 | Test Loss : 0.171 
Epoch: 4 : Alpha Weight : 0.05000 | Test Acc : 94.41000 | Test Loss : 0.178 



  0%|          | 0/5 [00:00<?, ?it/s]

Epoch: 0 : Train Loss : 0.00877 | Test Acc : 92.37000 | Test Loss : 0.236 
Epoch: 1 : Train Loss : 0.00834 | Test Acc : 92.65000 | Test Loss : 0.229 
Epoch: 2 : Train Loss : 0.00793 | Test Acc : 92.49000 | Test Loss : 0.230 
Epoch: 3 : Train Loss : 0.00741 | Test Acc : 93.12000 | Test Loss : 0.216 
Epoch: 4 : Train Loss : 0.00722 | Test Acc : 92.86000 | Test Loss : 0.218 



  0%|          | 0/5 [00:00<?, ?it/s]

Epoch: 0 : Alpha Weight : 0.01000 | Test Acc : 92.92000 | Test Loss : 0.218 
Epoch: 1 : Alpha Weight : 0.02000 | Test Acc : 93.37000 | Test Loss : 0.200 
Epoch: 2 : Alpha Weight : 0.03000 | Test Acc : 93.38000 | Test Loss : 0.194 
Epoch: 3 : Alpha Weight : 0.04000 | Test Acc : 93.98000 | Test Loss : 0.186 
Epoch: 4 : Alpha Weight : 0.05000 | Test Acc : 94.23000 | Test Loss : 0.185 



  0%|          | 0/5 [00:00<?, ?it/s]

Epoch: 0 : Train Loss : 0.00917 | Test Acc : 92.70000 | Test Loss : 0.235 
Epoch: 1 : Train Loss : 0.00926 | Test Acc : 93.08000 | Test Loss : 0.222 
Epoch: 2 : Train Loss : 0.00855 | Test Acc : 93.25000 | Test Loss : 0.215 
Epoch: 3 : Train Loss : 0.00840 | Test Acc : 93.57000 | Test Loss : 0.211 
Epoch: 4 : Train Loss : 0.00765 | Test Acc : 92.86000 | Test Loss : 0.235 



  0%|          | 0/5 [00:00<?, ?it/s]

Epoch: 0 : Alpha Weight : 0.01000 | Test Acc : 93.57000 | Test Loss : 0.213 
Epoch: 1 : Alpha Weight : 0.02000 | Test Acc : 93.96000 | Test Loss : 0.202 
Epoch: 2 : Alpha Weight : 0.03000 | Test Acc : 94.10000 | Test Loss : 0.196 
Epoch: 3 : Alpha Weight : 0.04000 | Test Acc : 94.25000 | Test Loss : 0.190 
Epoch: 4 : Alpha Weight : 0.05000 | Test Acc : 94.48000 | Test Loss : 0.182 



  0%|          | 0/5 [00:00<?, ?it/s]

Epoch: 0 : Train Loss : 0.01021 | Test Acc : 91.92000 | Test Loss : 0.259 
Epoch: 1 : Train Loss : 0.00938 | Test Acc : 93.13000 | Test Loss : 0.223 
Epoch: 2 : Train Loss : 0.00911 | Test Acc : 93.00000 | Test Loss : 0.226 
Epoch: 3 : Train Loss : 0.00891 | Test Acc : 93.02000 | Test Loss : 0.225 
Epoch: 4 : Train Loss : 0.00769 | Test Acc : 93.67000 | Test Loss : 0.203 



  0%|          | 0/5 [00:00<?, ?it/s]

Epoch: 0 : Alpha Weight : 0.01000 | Test Acc : 93.43000 | Test Loss : 0.210 
Epoch: 1 : Alpha Weight : 0.02000 | Test Acc : 93.62000 | Test Loss : 0.201 
Epoch: 2 : Alpha Weight : 0.03000 | Test Acc : 94.00000 | Test Loss : 0.197 
Epoch: 3 : Alpha Weight : 0.04000 | Test Acc : 94.03000 | Test Loss : 0.200 
Epoch: 4 : Alpha Weight : 0.05000 | Test Acc : 94.14000 | Test Loss : 0.191 
Global Model Accuracy :  95.04
Global Model Loss :  0.1554717466235161





  0%|          | 0/5 [00:00<?, ?it/s]

Epoch: 0 : Train Loss : 0.00751 | Test Acc : 94.87000 | Test Loss : 0.164 
Epoch: 1 : Train Loss : 0.00650 | Test Acc : 94.93000 | Test Loss : 0.158 
Epoch: 2 : Train Loss : 0.00677 | Test Acc : 95.25000 | Test Loss : 0.154 
Epoch: 3 : Train Loss : 0.00659 | Test Acc : 94.81000 | Test Loss : 0.167 
Epoch: 4 : Train Loss : 0.00647 | Test Acc : 94.92000 | Test Loss : 0.163 



  0%|          | 0/5 [00:00<?, ?it/s]

Epoch: 0 : Alpha Weight : 0.01000 | Test Acc : 95.08000 | Test Loss : 0.155 
Epoch: 1 : Alpha Weight : 0.02000 | Test Acc : 94.72000 | Test Loss : 0.166 
Epoch: 2 : Alpha Weight : 0.03000 | Test Acc : 95.27000 | Test Loss : 0.153 
Epoch: 3 : Alpha Weight : 0.04000 | Test Acc : 94.89000 | Test Loss : 0.160 
Epoch: 4 : Alpha Weight : 0.05000 | Test Acc : 95.49000 | Test Loss : 0.148 



  0%|          | 0/5 [00:00<?, ?it/s]

Epoch: 0 : Train Loss : 0.00605 | Test Acc : 95.27000 | Test Loss : 0.145 
Epoch: 1 : Train Loss : 0.00524 | Test Acc : 95.13000 | Test Loss : 0.154 
Epoch: 2 : Train Loss : 0.00514 | Test Acc : 95.04000 | Test Loss : 0.157 
Epoch: 3 : Train Loss : 0.00488 | Test Acc : 95.28000 | Test Loss : 0.150 
Epoch: 4 : Train Loss : 0.00462 | Test Acc : 94.93000 | Test Loss : 0.158 



  0%|          | 0/5 [00:00<?, ?it/s]

Epoch: 0 : Alpha Weight : 0.01000 | Test Acc : 95.19000 | Test Loss : 0.148 
Epoch: 1 : Alpha Weight : 0.02000 | Test Acc : 95.48000 | Test Loss : 0.143 
Epoch: 2 : Alpha Weight : 0.03000 | Test Acc : 95.15000 | Test Loss : 0.156 
Epoch: 3 : Alpha Weight : 0.04000 | Test Acc : 95.20000 | Test Loss : 0.156 
Epoch: 4 : Alpha Weight : 0.05000 | Test Acc : 94.93000 | Test Loss : 0.164 



  0%|          | 0/5 [00:00<?, ?it/s]

Epoch: 0 : Train Loss : 0.00615 | Test Acc : 95.00000 | Test Loss : 0.160 
Epoch: 1 : Train Loss : 0.00572 | Test Acc : 94.93000 | Test Loss : 0.161 
Epoch: 2 : Train Loss : 0.00581 | Test Acc : 95.09000 | Test Loss : 0.158 
Epoch: 3 : Train Loss : 0.00486 | Test Acc : 94.91000 | Test Loss : 0.159 
Epoch: 4 : Train Loss : 0.00469 | Test Acc : 94.81000 | Test Loss : 0.162 



  0%|          | 0/5 [00:00<?, ?it/s]

Epoch: 0 : Alpha Weight : 0.01000 | Test Acc : 95.19000 | Test Loss : 0.157 
Epoch: 1 : Alpha Weight : 0.02000 | Test Acc : 95.25000 | Test Loss : 0.153 
Epoch: 2 : Alpha Weight : 0.03000 | Test Acc : 94.74000 | Test Loss : 0.178 
Epoch: 3 : Alpha Weight : 0.04000 | Test Acc : 95.35000 | Test Loss : 0.162 
Epoch: 4 : Alpha Weight : 0.05000 | Test Acc : 95.22000 | Test Loss : 0.166 



  0%|          | 0/5 [00:00<?, ?it/s]

Epoch: 0 : Train Loss : 0.00676 | Test Acc : 94.82000 | Test Loss : 0.161 
Epoch: 1 : Train Loss : 0.00676 | Test Acc : 94.98000 | Test Loss : 0.158 
Epoch: 2 : Train Loss : 0.00593 | Test Acc : 94.86000 | Test Loss : 0.164 
Epoch: 3 : Train Loss : 0.00589 | Test Acc : 95.20000 | Test Loss : 0.153 
Epoch: 4 : Train Loss : 0.00620 | Test Acc : 94.90000 | Test Loss : 0.164 



  0%|          | 0/5 [00:00<?, ?it/s]

Epoch: 0 : Alpha Weight : 0.01000 | Test Acc : 94.78000 | Test Loss : 0.171 
Epoch: 1 : Alpha Weight : 0.02000 | Test Acc : 95.36000 | Test Loss : 0.149 
Epoch: 2 : Alpha Weight : 0.03000 | Test Acc : 95.48000 | Test Loss : 0.148 
Epoch: 3 : Alpha Weight : 0.04000 | Test Acc : 95.17000 | Test Loss : 0.152 
Epoch: 4 : Alpha Weight : 0.05000 | Test Acc : 95.53000 | Test Loss : 0.150 
Global Model Accuracy :  96.07
Global Model Loss :  0.1284249447286129





  0%|          | 0/5 [00:00<?, ?it/s]

Epoch: 0 : Train Loss : 0.00598 | Test Acc : 95.98000 | Test Loss : 0.124 
Epoch: 1 : Train Loss : 0.00544 | Test Acc : 95.47000 | Test Loss : 0.143 
Epoch: 2 : Train Loss : 0.00475 | Test Acc : 95.77000 | Test Loss : 0.131 
Epoch: 3 : Train Loss : 0.00498 | Test Acc : 95.61000 | Test Loss : 0.135 
Epoch: 4 : Train Loss : 0.00458 | Test Acc : 95.96000 | Test Loss : 0.128 



  0%|          | 0/5 [00:00<?, ?it/s]

Epoch: 0 : Alpha Weight : 0.01000 | Test Acc : 95.62000 | Test Loss : 0.132 
Epoch: 1 : Alpha Weight : 0.02000 | Test Acc : 95.97000 | Test Loss : 0.129 
Epoch: 2 : Alpha Weight : 0.03000 | Test Acc : 95.35000 | Test Loss : 0.148 
Epoch: 3 : Alpha Weight : 0.04000 | Test Acc : 96.18000 | Test Loss : 0.122 
Epoch: 4 : Alpha Weight : 0.05000 | Test Acc : 95.91000 | Test Loss : 0.131 



  0%|          | 0/5 [00:00<?, ?it/s]

Epoch: 0 : Train Loss : 0.00503 | Test Acc : 95.53000 | Test Loss : 0.145 
Epoch: 1 : Train Loss : 0.00479 | Test Acc : 95.80000 | Test Loss : 0.133 
Epoch: 2 : Train Loss : 0.00397 | Test Acc : 95.96000 | Test Loss : 0.131 
Epoch: 3 : Train Loss : 0.00399 | Test Acc : 95.87000 | Test Loss : 0.133 
Epoch: 4 : Train Loss : 0.00381 | Test Acc : 95.94000 | Test Loss : 0.127 



  0%|          | 0/5 [00:00<?, ?it/s]

Epoch: 0 : Alpha Weight : 0.01000 | Test Acc : 95.98000 | Test Loss : 0.128 
Epoch: 1 : Alpha Weight : 0.02000 | Test Acc : 95.81000 | Test Loss : 0.142 
Epoch: 2 : Alpha Weight : 0.03000 | Test Acc : 95.78000 | Test Loss : 0.142 
Epoch: 3 : Alpha Weight : 0.04000 | Test Acc : 95.84000 | Test Loss : 0.140 
Epoch: 4 : Alpha Weight : 0.05000 | Test Acc : 96.03000 | Test Loss : 0.129 



  0%|          | 0/5 [00:00<?, ?it/s]

Epoch: 0 : Train Loss : 0.00505 | Test Acc : 95.66000 | Test Loss : 0.137 
Epoch: 1 : Train Loss : 0.00443 | Test Acc : 95.90000 | Test Loss : 0.127 
Epoch: 2 : Train Loss : 0.00443 | Test Acc : 95.59000 | Test Loss : 0.143 
Epoch: 3 : Train Loss : 0.00382 | Test Acc : 95.63000 | Test Loss : 0.141 
Epoch: 4 : Train Loss : 0.00328 | Test Acc : 95.84000 | Test Loss : 0.135 



  0%|          | 0/5 [00:00<?, ?it/s]

Epoch: 0 : Alpha Weight : 0.01000 | Test Acc : 95.86000 | Test Loss : 0.135 
Epoch: 1 : Alpha Weight : 0.02000 | Test Acc : 95.49000 | Test Loss : 0.150 
Epoch: 2 : Alpha Weight : 0.03000 | Test Acc : 95.76000 | Test Loss : 0.144 
Epoch: 3 : Alpha Weight : 0.04000 | Test Acc : 95.81000 | Test Loss : 0.140 
Epoch: 4 : Alpha Weight : 0.05000 | Test Acc : 95.68000 | Test Loss : 0.153 



  0%|          | 0/5 [00:00<?, ?it/s]

Epoch: 0 : Train Loss : 0.00493 | Test Acc : 96.13000 | Test Loss : 0.127 
Epoch: 1 : Train Loss : 0.00539 | Test Acc : 95.91000 | Test Loss : 0.131 
Epoch: 2 : Train Loss : 0.00502 | Test Acc : 95.70000 | Test Loss : 0.142 
Epoch: 3 : Train Loss : 0.00466 | Test Acc : 95.95000 | Test Loss : 0.122 
Epoch: 4 : Train Loss : 0.00452 | Test Acc : 95.90000 | Test Loss : 0.131 



  0%|          | 0/5 [00:00<?, ?it/s]

Epoch: 0 : Alpha Weight : 0.01000 | Test Acc : 95.93000 | Test Loss : 0.139 
Epoch: 1 : Alpha Weight : 0.02000 | Test Acc : 95.76000 | Test Loss : 0.141 
Epoch: 2 : Alpha Weight : 0.03000 | Test Acc : 95.85000 | Test Loss : 0.140 
Epoch: 3 : Alpha Weight : 0.04000 | Test Acc : 95.73000 | Test Loss : 0.141 
Epoch: 4 : Alpha Weight : 0.05000 | Test Acc : 95.85000 | Test Loss : 0.143 
Global Model Accuracy :  96.46000000000001
Global Model Loss :  0.11153065487742424


In [None]:
# Final held-out score of the aggregated model after the last federated round.
test_acc, test_loss = evaluate(global_model, test_loader)
print(f'Test Acc : {test_acc:.5f} | Test Loss : {test_loss:.3f} ')

  cpuset_checked))


Test Acc : 96.46000 | Test Loss : 0.112 
