In [60]:
import os
import time
import socket
import struct
import pickle
from tqdm import tqdm

import torch
import torchvision
import torchvision.transforms as transforms
from torchvision import models

import torch.nn as nn
import torch.nn.functional as F


device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Assuming that we are on a CUDA machine, this should print a CUDA device:

print(device)

cuda:0


In [61]:
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != self.expansion * out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, self.expansion * out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * out_channels)
            )

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = self.relu(out)
        return out


In [62]:
class ClientResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=1000):
        super(ClientResNet, self).__init__()
        self.in_channels = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_channels, out_channels, stride))
            self.in_channels = out_channels * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.maxpool(out)
        out = self.layer1(out)
        #out = self.layer2(out)
        #out = self.layer3(out)
        #out = self.layer4(out)
        #out = self.avgpool(out)
        #out = torch.flatten(out, 1)
        #out = self.fc(out)
        return out

# Create the model
client_model = ClientResNet(BasicBlock, [1, 1, 1, 1])
client_model = client_model.to(device)
print(client_model)

ClientResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=

In [63]:
class ServerResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=1000):
        super(ServerResNet, self).__init__()
        self.in_channels = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_channels, out_channels, stride))
            self.in_channels = out_channels * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        #out = self.relu(self.bn1(self.conv1(x)))
        #out = self.maxpool(out)
        #out = self.layer1(out)
        out = self.layer2(x)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        out = self.fc(out)
        return out

# Create the model
server_model = ServerResNet(BasicBlock, [1, 1, 1, 1])
server_model = server_model.to(device)
print(server_model)

ServerResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=

In [64]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(server_model.parameters(), lr=0.001, momentum=0.9)

In [65]:
def send_msg(sock, msg):
    # prefix each message with a 4-byte length in network byte order
    msg = pickle.dumps(msg)
    msg = struct.pack('>I', len(msg)) + msg
    sock.sendall(msg)

def recv_msg(sock):
    # read message length and unpack it into an integer
    raw_msglen = recvall(sock, 4)
    if not raw_msglen:
        return None
    msglen = struct.unpack('>I', raw_msglen)[0]
    # read the message data
    msg =  recvall(sock, msglen)
    msg = pickle.loads(msg)
    return msg, msglen

def recvall(sock, n):
    # helper function to receive n bytes or return None if EOF is hit
    data = b''
    while len(data) < n:
        packet = sock.recv(n - len(data))
        if not packet:
            return None
        data += packet
    return data

In [66]:
host = 'localhost'
port = 12310

s = socket.socket()
s.bind((host, port))
s.listen(5)

In [67]:
users = 2
epochs = 5

clientsoclist = []
train_total_batch = []
val_acc = []

total_sendsize_list = []
total_receivesize_list = []

client_sendsize_list = [[] for i in range(users)]
client_receivesize_list = [[] for i in range(users)]

train_sendsize_list = [] 
train_receivesize_list = []

for i in range(users):
    conn, addr = s.accept()
    print("conn: ", conn)
    print('Conntected with', addr)
    clientsoclist.append(conn)    # append client socket on list

    datasize = send_msg(conn, epochs)    #send epoch
    total_sendsize_list.append(datasize)
    client_sendsize_list[i].append(datasize)

    total_batch, datasize = recv_msg(conn)    # get total_batch of train dataset
    total_receivesize_list.append(datasize)
    client_receivesize_list[i].append(datasize)

    train_total_batch.append(total_batch)    # append on list

conn:  <socket.socket fd=2956, family=AddressFamily.AF_INET, type=SocketKind.SOCK_STREAM, proto=0, laddr=('127.0.0.1', 12310), raddr=('127.0.0.1', 57373)>
Conntected with ('127.0.0.1', 57373)
conn:  <socket.socket fd=3712, family=AddressFamily.AF_INET, type=SocketKind.SOCK_STREAM, proto=0, laddr=('127.0.0.1', 12310), raddr=('127.0.0.1', 57377)>
Conntected with ('127.0.0.1', 57377)


In [68]:
import copy

client_weights = copy.deepcopy(client_model.state_dict())

start_time = time.time()    # store start time
print("timer start!")

for e in range(epochs):

    # train client 0

    for user in range(users):

        datasize = send_msg(clientsoclist[user], client_weights)
        total_sendsize_list.append(datasize)
        client_sendsize_list[user].append(datasize)
        train_sendsize_list.append(datasize)

        for i in tqdm(range(train_total_batch[user]), ncols=100, desc='Epoch {} Client{} '.format(e+1, user)):
            optimizer.zero_grad()  # initialize all gradients to zero

            msg, datasize = recv_msg(clientsoclist[user])  # receive client message from socket
            total_receivesize_list.append(datasize)
            client_receivesize_list[user].append(datasize)
            train_receivesize_list.append(datasize)

            client_output_cpu = msg['client_output']  # client output tensor
            label = msg['label']  # label

            client_output = client_output_cpu.to(device)
            label = label.clone().detach().long().to(device)

            output = server_model(client_output)  # forward propagation
            loss = criterion(output, label)  # calculates cross-entropy loss
            loss.backward()  # backward propagation
            msg = client_output_cpu.grad.clone().detach()

            datasize = send_msg(clientsoclist[user], msg)
            total_sendsize_list.append(datasize)
            client_sendsize_list[user].append(datasize)
            train_sendsize_list.append(datasize)
            
            optimizer.step()
            
        client_weights, datasize = recv_msg(clientsoclist[user])
        print(client_weights)
        total_receivesize_list.append(datasize)
        client_receivesize_list[user].append(datasize)
        train_receivesize_list.append(datasize)
        
print('train is done')

end_time = time.time()  # store end time
print("TrainingTime: {} sec".format(end_time - start_time))

# Let's quickly save our trained model:
PATH = './cifar10_sp_server_model.pth'
torch.save(server_model.state_dict(), PATH)

timer start!


Epoch 1 Client0 : 100%|█████████████████████████████████████████| 1000/1000 [01:19<00:00, 12.62it/s]


OrderedDict([('conv1.weight', tensor([[[[ 0.0040, -0.0783,  0.0089,  ...,  0.0310,  0.0828, -0.0035],
          [-0.0251, -0.0444, -0.0911,  ...,  0.0273,  0.0691,  0.0843],
          [-0.0119,  0.0514, -0.0814,  ..., -0.0118, -0.0103,  0.0372],
          ...,
          [-0.0291,  0.0745,  0.0631,  ...,  0.0904, -0.0442, -0.0624],
          [-0.0053, -0.0400,  0.0664,  ..., -0.0288, -0.0315, -0.0773],
          [ 0.0316,  0.0172, -0.0232,  ...,  0.0132,  0.0642,  0.0059]],

         [[ 0.0457,  0.0483,  0.0557,  ..., -0.0109,  0.0775, -0.0094],
          [ 0.0006, -0.0784,  0.0620,  ..., -0.0821, -0.0320, -0.0326],
          [-0.0012, -0.0276, -0.0468,  ...,  0.0002, -0.0604, -0.0297],
          ...,
          [ 0.0549, -0.0843,  0.0208,  ..., -0.0163,  0.0038, -0.0266],
          [-0.0166, -0.0270, -0.0550,  ..., -0.0373, -0.0752,  0.0081],
          [-0.0606, -0.0590, -0.0504,  ...,  0.0263, -0.0057, -0.0229]],

         [[ 0.0111, -0.0235,  0.0376,  ..., -0.0043, -0.0092, -0.0910],


Epoch 1 Client1 : 100%|█████████████████████████████████████████| 1000/1000 [00:19<00:00, 50.35it/s]


OrderedDict([('conv1.weight', tensor([[[[-1.1810e-03, -8.5127e-02,  6.6092e-03,  ...,  3.7576e-02,
            8.7964e-02, -1.7954e-03],
          [-2.0777e-02, -4.6067e-02, -9.4931e-02,  ...,  3.5145e-02,
            8.6778e-02,  9.6602e-02],
          [-5.0025e-03,  4.8112e-02, -8.9400e-02,  ..., -1.1455e-02,
            8.7519e-04,  4.1916e-02],
          ...,
          [-3.3295e-02,  7.0357e-02,  6.1320e-02,  ...,  8.9943e-02,
           -4.5269e-02, -5.1427e-02],
          [-1.1900e-02, -4.6268e-02,  6.6575e-02,  ..., -2.2217e-02,
           -2.9819e-02, -7.7578e-02],
          [ 2.0916e-02,  7.3783e-03, -2.5528e-02,  ...,  1.8445e-02,
            7.0988e-02,  6.9551e-03]],

         [[ 5.0899e-02,  5.3417e-02,  6.2520e-02,  ..., -9.0795e-03,
            7.2540e-02, -1.8743e-02],
          [ 1.2013e-02, -6.8494e-02,  6.7181e-02,  ..., -8.1081e-02,
           -2.8759e-02, -3.5136e-02],
          [ 1.1765e-02, -1.9894e-02, -4.6018e-02,  ..., -5.0279e-03,
           -5.9948e-02, -3.6

Epoch 2 Client0 : 100%|█████████████████████████████████████████| 1000/1000 [01:19<00:00, 12.63it/s]


OrderedDict([('conv1.weight', tensor([[[[-0.0049, -0.0906,  0.0026,  ...,  0.0439,  0.0972,  0.0043],
          [-0.0175, -0.0493, -0.0993,  ...,  0.0334,  0.0928,  0.1018],
          [ 0.0030,  0.0458, -0.0959,  ..., -0.0179,  0.0046,  0.0461],
          ...,
          [-0.0275,  0.0745,  0.0610,  ...,  0.0905, -0.0485, -0.0447],
          [-0.0034, -0.0432,  0.0686,  ..., -0.0218, -0.0291, -0.0742],
          [ 0.0301,  0.0080, -0.0233,  ...,  0.0183,  0.0752,  0.0101]],

         [[ 0.0502,  0.0539,  0.0651,  ..., -0.0058,  0.0784, -0.0137],
          [ 0.0147, -0.0653,  0.0720,  ..., -0.0801, -0.0244, -0.0334],
          [ 0.0163, -0.0181, -0.0445,  ..., -0.0099, -0.0596, -0.0372],
          ...,
          [ 0.0485, -0.0902,  0.0111,  ..., -0.0286, -0.0141, -0.0250],
          [-0.0270, -0.0394, -0.0626,  ..., -0.0410, -0.0836, -0.0006],
          [-0.0766, -0.0782, -0.0627,  ...,  0.0207, -0.0064, -0.0326]],

         [[ 0.0184, -0.0102,  0.0487,  ..., -0.0085, -0.0257, -0.1181],


Epoch 2 Client1 : 100%|█████████████████████████████████████████| 1000/1000 [00:19<00:00, 50.04it/s]


OrderedDict([('conv1.weight', tensor([[[[-1.4269e-02, -9.6552e-02, -3.9497e-03,  ...,  5.1021e-02,
            1.0612e-01,  1.0027e-02],
          [-2.3327e-02, -5.4048e-02, -1.1033e-01,  ...,  3.7096e-02,
            1.0381e-01,  1.0619e-01],
          [ 1.4402e-03,  4.0936e-02, -1.0275e-01,  ..., -1.5278e-02,
            1.4789e-02,  5.5058e-02],
          ...,
          [-2.7287e-02,  7.7343e-02,  6.6823e-02,  ...,  9.5720e-02,
           -4.1418e-02, -2.8545e-02],
          [ 2.8181e-03, -3.7042e-02,  8.0390e-02,  ..., -1.5233e-02,
           -2.7730e-02, -7.1522e-02],
          [ 3.3069e-02,  1.0487e-02, -1.8526e-02,  ...,  1.8971e-02,
            7.4896e-02,  7.1489e-03]],

         [[ 4.0328e-02,  4.8657e-02,  6.0550e-02,  ..., -7.3904e-03,
            7.7487e-02, -1.4891e-02],
          [ 8.9693e-03, -6.7513e-02,  6.6014e-02,  ..., -8.2843e-02,
           -2.4750e-02, -3.8163e-02],
          [ 1.3882e-02, -2.1751e-02, -4.8513e-02,  ..., -1.1972e-02,
           -5.9377e-02, -3.8

Epoch 3 Client0 : 100%|█████████████████████████████████████████| 1000/1000 [01:17<00:00, 12.97it/s]


OrderedDict([('conv1.weight', tensor([[[[-1.2162e-02, -1.0131e-01, -1.0591e-02,  ...,  6.3421e-02,
            1.1533e-01,  1.4948e-02],
          [-1.8563e-02, -5.5676e-02, -1.1681e-01,  ...,  4.2215e-02,
            1.1382e-01,  1.1348e-01],
          [ 2.5212e-03,  3.7295e-02, -1.0792e-01,  ..., -1.8121e-02,
            2.1259e-02,  6.3917e-02],
          ...,
          [-2.8179e-02,  7.4827e-02,  6.4536e-02,  ...,  9.2228e-02,
           -4.6457e-02, -2.5706e-02],
          [ 5.9563e-03, -3.7703e-02,  8.0004e-02,  ..., -1.9272e-02,
           -2.8540e-02, -6.7772e-02],
          [ 3.4179e-02,  1.0540e-02, -1.7825e-02,  ...,  1.5336e-02,
            7.5042e-02,  7.8581e-03]],

         [[ 4.2762e-02,  4.6625e-02,  5.5992e-02,  ..., -2.8227e-03,
            7.6564e-02, -1.7942e-02],
          [ 1.1719e-02, -6.7889e-02,  6.1493e-02,  ..., -8.3451e-02,
           -2.6755e-02, -4.0685e-02],
          [ 1.2069e-02, -2.4850e-02, -5.1505e-02,  ..., -1.8921e-02,
           -6.1768e-02, -3.7

Epoch 3 Client1 : 100%|█████████████████████████████████████████| 1000/1000 [00:20<00:00, 49.45it/s]


OrderedDict([('conv1.weight', tensor([[[[-0.0166, -0.1054, -0.0149,  ...,  0.0617,  0.1183,  0.0130],
          [-0.0249, -0.0647, -0.1231,  ...,  0.0445,  0.1188,  0.1111],
          [-0.0005,  0.0329, -0.1113,  ..., -0.0197,  0.0214,  0.0610],
          ...,
          [-0.0200,  0.0858,  0.0703,  ...,  0.0972, -0.0429, -0.0207],
          [ 0.0147, -0.0288,  0.0870,  ..., -0.0130, -0.0310, -0.0702],
          [ 0.0394,  0.0148, -0.0173,  ...,  0.0151,  0.0681, -0.0011]],

         [[ 0.0431,  0.0497,  0.0574,  ..., -0.0021,  0.0810, -0.0182],
          [ 0.0091, -0.0711,  0.0606,  ..., -0.0799, -0.0225, -0.0447],
          [ 0.0107, -0.0242, -0.0489,  ..., -0.0163, -0.0594, -0.0411],
          ...,
          [ 0.0335, -0.1012,  0.0030,  ..., -0.0319, -0.0184, -0.0105],
          [-0.0328, -0.0506, -0.0679,  ..., -0.0438, -0.0926, -0.0026],
          [-0.0884, -0.0947, -0.0815,  ...,  0.0089, -0.0160, -0.0465]],

         [[ 0.0128, -0.0036,  0.0474,  ..., -0.0162, -0.0454, -0.1413],


Epoch 4 Client0 : 100%|█████████████████████████████████████████| 1000/1000 [01:20<00:00, 12.39it/s]


OrderedDict([('conv1.weight', tensor([[[[-2.1481e-02, -1.0511e-01, -1.1214e-02,  ...,  6.9607e-02,
            1.2527e-01,  1.8622e-02],
          [-2.2744e-02, -6.3811e-02, -1.2280e-01,  ...,  4.9627e-02,
            1.2597e-01,  1.1950e-01],
          [ 5.6234e-03,  3.6006e-02, -1.1396e-01,  ..., -2.1264e-02,
            2.4106e-02,  7.0120e-02],
          ...,
          [-1.3773e-02,  9.0362e-02,  6.7310e-02,  ...,  9.3430e-02,
           -4.8202e-02, -1.9255e-02],
          [ 1.3885e-02, -2.8183e-02,  8.3910e-02,  ..., -1.9288e-02,
           -3.6644e-02, -6.9614e-02],
          [ 3.3656e-02,  1.2624e-02, -2.1144e-02,  ...,  1.9624e-02,
            6.8364e-02,  1.4674e-03]],

         [[ 3.9211e-02,  4.8739e-02,  5.8767e-02,  ...,  5.5077e-04,
            8.2347e-02, -1.7664e-02],
          [ 1.0418e-02, -7.1996e-02,  5.9489e-02,  ..., -7.8759e-02,
           -2.0510e-02, -4.1826e-02],
          [ 1.5254e-02, -2.3713e-02, -5.2225e-02,  ..., -2.1004e-02,
           -5.9834e-02, -3.6

Epoch 4 Client1 : 100%|█████████████████████████████████████████| 1000/1000 [00:20<00:00, 49.09it/s]


OrderedDict([('conv1.weight', tensor([[[[-2.4681e-02, -1.0569e-01, -1.2344e-02,  ...,  6.9840e-02,
            1.2425e-01,  1.8727e-02],
          [-2.6905e-02, -6.9592e-02, -1.2998e-01,  ...,  4.6274e-02,
            1.2297e-01,  1.1940e-01],
          [ 8.7171e-05,  2.7606e-02, -1.2178e-01,  ..., -2.6459e-02,
            2.5007e-02,  7.6547e-02],
          ...,
          [-8.1168e-03,  9.7963e-02,  7.4360e-02,  ...,  1.0007e-01,
           -4.1338e-02, -1.2457e-02],
          [ 1.4622e-02, -2.5131e-02,  9.0194e-02,  ..., -1.2817e-02,
           -3.0344e-02, -6.3539e-02],
          [ 2.9839e-02,  9.5500e-03, -2.2615e-02,  ...,  2.2916e-02,
            7.3030e-02,  4.8836e-03]],

         [[ 4.0487e-02,  5.2501e-02,  6.3108e-02,  ..., -1.0812e-03,
            7.7442e-02, -1.9015e-02],
          [ 8.2508e-03, -7.5172e-02,  5.6835e-02,  ..., -8.2367e-02,
           -2.7017e-02, -4.4891e-02],
          [ 1.0881e-02, -2.9199e-02, -5.6727e-02,  ..., -2.6301e-02,
           -6.1441e-02, -3.3

Epoch 5 Client0 : 100%|█████████████████████████████████████████| 1000/1000 [01:19<00:00, 12.55it/s]


OrderedDict([('conv1.weight', tensor([[[[-2.2999e-02, -1.0626e-01, -1.2924e-02,  ...,  7.6807e-02,
            1.2873e-01,  2.0032e-02],
          [-1.9949e-02, -6.7369e-02, -1.2879e-01,  ...,  4.9849e-02,
            1.2691e-01,  1.2186e-01],
          [ 5.5918e-03,  3.1211e-02, -1.1781e-01,  ..., -2.5510e-02,
            2.6631e-02,  7.8120e-02],
          ...,
          [-1.2641e-02,  9.4471e-02,  7.4609e-02,  ...,  9.5088e-02,
           -4.4127e-02, -1.0151e-02],
          [ 1.6008e-02, -2.5905e-02,  9.3523e-02,  ..., -1.2689e-02,
           -3.2613e-02, -5.9929e-02],
          [ 3.3950e-02,  1.2884e-02, -2.0586e-02,  ...,  2.6916e-02,
            7.7172e-02,  1.2531e-02]],

         [[ 4.1128e-02,  5.0773e-02,  5.9800e-02,  ..., -1.6790e-04,
            7.7424e-02, -1.9494e-02],
          [ 1.3565e-02, -7.3465e-02,  5.5885e-02,  ..., -8.0277e-02,
           -2.3319e-02, -4.0658e-02],
          [ 1.5748e-02, -2.4347e-02, -5.1726e-02,  ..., -2.3938e-02,
           -5.8237e-02, -3.1

Epoch 5 Client1 : 100%|█████████████████████████████████████████| 1000/1000 [00:22<00:00, 43.64it/s]


OrderedDict([('conv1.weight', tensor([[[[-2.5063e-02, -1.0628e-01, -1.2407e-02,  ...,  8.1355e-02,
            1.3416e-01,  1.9519e-02],
          [-2.1756e-02, -6.8190e-02, -1.3377e-01,  ...,  5.0814e-02,
            1.2645e-01,  1.1999e-01],
          [ 8.2482e-03,  3.1776e-02, -1.2343e-01,  ..., -2.5373e-02,
            2.7292e-02,  7.5826e-02],
          ...,
          [-9.5263e-03,  9.7304e-02,  7.4897e-02,  ...,  9.5666e-02,
           -4.0942e-02, -6.2378e-03],
          [ 1.4411e-02, -2.4151e-02,  9.4994e-02,  ..., -1.7252e-02,
           -3.6921e-02, -6.3127e-02],
          [ 3.5446e-02,  1.9323e-02, -1.7429e-02,  ...,  2.3149e-02,
            7.5422e-02,  7.2052e-03]],

         [[ 4.1143e-02,  5.4288e-02,  6.2320e-02,  ...,  2.8345e-03,
            8.2342e-02, -1.7790e-02],
          [ 1.3365e-02, -7.2136e-02,  5.2644e-02,  ..., -8.1976e-02,
           -2.5284e-02, -4.1936e-02],
          [ 2.0096e-02, -2.2933e-02, -5.5389e-02,  ..., -2.5725e-02,
           -5.9410e-02, -3.3

In [69]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 50

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [70]:
# train acc
with torch.no_grad():
    corr_num = 0
    total_num = 0
    train_loss = 0.0
    for j, trn in enumerate(trainloader):
        trn_x, trn_label = trn
        trn_x = trn_x.to(device)
        trn_label = trn_label.to(device)


        trn_output = client_model(trn_x)
        trn_label = trn_label.long()
        
        trn_output = server_model(trn_output)
        
        loss = criterion(trn_output, trn_label)
        train_loss += loss.item()
        model_label = trn_output.argmax(dim=1)
        corr = trn_label[trn_label == model_label].size(0)
        corr_num += corr
        total_num += trn_label.size(0)
    print("train_acc: {:.2f}%, train_loss: {:.4f}".format(corr_num / total_num * 100, train_loss / len(trainloader)))

train_acc: 41.39%, train_loss: 2.7979


In [71]:
with torch.no_grad():
    corr_num = 0
    total_num = 0
    val_loss = 0.0
    for j, val in enumerate(testloader):
        val_x, val_label = val
        val_x = val_x.to(device)
        val_label = val_label.to(device)

        val_output = client_model(val_x)
        val_label = val_label.long()
        
        val_output = server_model(val_output)
        
        loss = criterion(val_output, val_label)
        val_loss += loss.item()
        model_label = val_output.argmax(dim=1)
        corr = val_label[val_label == model_label].size(0)
        corr_num += corr
        total_num += val_label.size(0)
    print("test_acc: {:.2f}%, test_loss: {:.4f}".format(corr_num / total_num * 100, val_loss / len(testloader)))

test_acc: 36.84%, test_loss: 3.2665


In [72]:
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))

with torch.no_grad():
    for data in testloader:
        x, labels = data
        x = x.to(device)
        labels = labels.to(device)

        outputs = client_model(x)
        outputs = server_model(outputs)
        labels = labels.long()
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels).squeeze()
        for i in range(len(labels)):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1


for i in range(10):
    print('Accuracy of %5s : %2d %%' % (
        classes[i], 100 * class_correct[i] / class_total[i]))

Accuracy of plane : 40 %
Accuracy of   car : 39 %
Accuracy of  bird : 32 %
Accuracy of   cat : 28 %
Accuracy of  deer : 38 %
Accuracy of   dog : 29 %
Accuracy of  frog : 31 %
Accuracy of horse : 39 %
Accuracy of  ship : 35 %
Accuracy of truck : 55 %
