# CIFAR10 Federated Mobilenet Server Side
This notebook is the server side of the CIFAR10 federated MobileNet example: a single server coordinating **multiple** clients.

## Setting variables

In [16]:
# Number of federated rounds; sent to every client in the handshake and used
# as the loop count in train().
rounds = 50
# Sent to every client in the handshake message (not used by the server itself).
local_epoch = 1
# Number of client connections the server accepts; also sizes the per-client lists.
users = 8 # number of clients


In [17]:
import os
import h5py

import socket
import struct
import pickle
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

from threading import Thread
from threading import Lock


import time

from tqdm import tqdm

## Cuda

In [18]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cpu


## PyTorch layer modules for *Conv1D* Network (reference only — the `Model` defined below uses the 2D variants `Conv2d`/`MaxPool2d`)



### `Conv1d` layer
- `torch.nn.Conv1d(in_channels, out_channels, kernel_size)`

### `MaxPool1d` layer
- `torch.nn.MaxPool1d(kernel_size, stride=None)`
- Parameter `stride` follows `kernel_size`.

### `ReLU` layer
- `torch.nn.ReLU()`

### `Linear` layer
- `torch.nn.Linear(in_features, out_features, bias=True)`

### `Softmax` layer
- `torch.nn.Softmax(dim=None)`
- Parameter `dim` is usually set to `1`.

In [19]:
# -*- coding: utf-8 -*-
"""
Created on Thu Nov  1 14:23:31 2018
@author: tshzzz
"""



# def conv_dw(inplane,outplane,stride=1):
#     return nn.Sequential(
#         nn.Conv2d(inplane,inplane,kernel_size = 3,groups = inplane,stride=stride,padding=1),
#         nn.BatchNorm2d(inplane),
#         nn.ReLU(),
#         nn.Conv2d(inplane,outplane,kernel_size = 1,groups = 1,stride=1),
#         nn.BatchNorm2d(outplane),
#         nn.ReLU()    
#         )

# def conv_bw(inplane,outplane,kernel_size = 3,stride=1):
#     return nn.Sequential(
#         nn.Conv2d(inplane,outplane,kernel_size = kernel_size,groups = 1,stride=stride,padding=1),
#         nn.BatchNorm2d(outplane),
#         nn.ReLU() 
#         )


# class MobileNet(nn.Module):
    
#     def __init__(self,num_class=10):
#         super(MobileNet,self).__init__()
        
#         layers = []
#         layers.append(conv_bw(3,32,3,1))
#         layers.append(conv_dw(32,64,1))
#         layers.append(conv_dw(64,128,2))
#         layers.append(conv_dw(128,128,1))
#         layers.append(conv_dw(128,256,2))
#         layers.append(conv_dw(256,256,1))
#         layers.append(conv_dw(256,512,2))

#         for i in range(5):
#             layers.append(conv_dw(512,512,1))
#         layers.append(conv_dw(512,1024,2))
#         layers.append(conv_dw(1024,1024,1))

#         self.classifer = nn.Sequential(
#                 nn.Dropout(0.5),
#                 nn.Linear(1024,num_class)
#                 )
#         self.feature = nn.Sequential(*layers)
        
        

#     def forward(self,x):
#         out = self.feature(x)
#         out = out.mean(3).mean(2)
#         out = out.view(-1,1024)
#         out = self.classifer(out)
#         return out

class Model(nn.Module):
    """Small convolutional classifier with 10 output scores.

    Three 5x5 convolutions (stride 1, padding 2), each followed by 2x2 max
    pooling, then a flatten and two linear layers. No activation function is
    applied between layers. The first linear layer expects 64 * 4 * 4
    features, i.e. 3-channel inputs of spatial size 32x32.
    """

    def __init__(self):
        super().__init__()
        # The layer order fixes the state_dict keys (model.0, model.2, ...),
        # which must stay in sync with the weights exchanged with clients.
        layers = [
            nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(32, 32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 64),
            nn.Linear(64, 10),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw class scores for the input batch ``x``."""
        return self.model(x)


In [20]:
# Server-side copy of the model; its state_dict seeds global_weights below.
mobile_net = Model()
mobile_net.to('cpu')

Model(
  (model): Sequential(
    (0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Flatten(start_dim=1, end_dim=-1)
    (7): Linear(in_features=1024, out_features=64, bias=True)
    (8): Linear(in_features=64, out_features=10, bias=True)
  )
)

## variables

In [21]:
import copy

# One slot per client; run_thread() stores each accepted connection here.
clientsoclist = [0]*users

# Set by run_thread() once every client has connected.
start_time = 0
# Number of clients that have checked in for the current step; the client
# threads update it while holding `lock`.
weight_count = 0

# Current global model parameters; sent to the clients by train().
global_weights = copy.deepcopy(mobile_net.state_dict())

# Per-client training-set size reported by each client in receive().
datasetsize = [0]*users
# Per-client weights most recently received in train().
weights_list = [0]*users

# Guards weight_count, datasetsize and the broadcast/averaging steps.
lock = Lock()

## Communication overhead

In [22]:
# Payload sizes (bytes, as returned by send_msg / recv_msg) of every message
# the server sent or received, handshake included.
total_sendsize_list = []
total_receivesize_list = []

# The same sizes, split per client (indexed by user id).
client_sendsize_list = [[] for i in range(users)]
client_receivesize_list = [[] for i in range(users)]

# Sizes of the weight messages exchanged inside train() only.
train_sendsize_list = [] 
train_receivesize_list = []

## Socket initialization
### Set host address and port number

### Required socket functions

In [23]:
def send_msg(sock, msg):
    """Pickle ``msg`` and send it on ``sock`` as one length-prefixed frame.

    The frame is a 4-byte unsigned length in network (big-endian) byte order
    followed by the pickled payload.

    Returns:
        The size in bytes of the pickled payload (the 4-byte prefix is not
        counted).
    """
    payload = pickle.dumps(msg)
    payload_len = len(payload)
    header = struct.pack('>I', payload_len)
    sock.sendall(header + payload)
    return payload_len

def recv_msg(sock):
    """Receive one length-prefixed, pickled message from ``sock``.

    The frame layout mirrors ``send_msg``: a 4-byte unsigned big-endian
    length followed by that many bytes of pickled payload.

    Returns:
        ``(msg, msglen)`` where ``msg`` is the unpickled object and ``msglen``
        the payload size in bytes, or ``None`` if the connection was closed
        before a complete frame (header or body) could be read.
    """
    # read message length and unpack it into an integer
    raw_msglen = recvall(sock, 4)
    if not raw_msglen:
        return None
    msglen = struct.unpack('>I', raw_msglen)[0]
    # read the message data
    payload = recvall(sock, msglen)
    if payload is None:
        # Peer closed the connection mid-message; report EOF the same way as
        # for a missing header instead of handing None to pickle.loads.
        return None
    # SECURITY NOTE(review): pickle.loads can execute arbitrary code embedded
    # in the payload. Only use this with trusted peers on a trusted network.
    msg = pickle.loads(payload)
    return msg, msglen

def recvall(sock, n):
    """Receive exactly ``n`` bytes from ``sock``.

    Args:
        sock: object with a ``recv(bufsize)`` method returning bytes.
        n: number of bytes to read.

    Returns:
        A ``bytes`` object of length ``n``, or ``None`` if the peer closed
        the connection before ``n`` bytes arrived (partial data is dropped).
    """
    # A bytearray grows in place, so large messages arriving in many packets
    # are not re-copied on every packet as they would be with bytes +=.
    buf = bytearray()
    while len(buf) < n:
        packet = sock.recv(n - len(buf))
        if not packet:
            return None
        buf += packet
    return bytes(buf)

In [24]:
import copy

def average_weights(w, datasize):
    """Return the dataset-size-weighted average of the clients' weights.

    Args:
        w: sequence of state-dict-like mappings (parameter name -> tensor),
            one per client, all sharing the keys of ``w[0]``.
        datasize: per-client weighting factors (the clients' training-set
            sizes), in the same order and of the same length as ``w``.

    Returns:
        A new mapping of the same type and key order as ``w[0]`` whose values
        are ``sum_i(w[i][key] * datasize[i]) / sum(datasize)``. The input
        mappings and their tensors are left unmodified. Each result tensor
        lives on the device of the corresponding tensor in ``w[0]``, so
        clients on different devices (cpu / gpu) can be mixed.

    Raises:
        ValueError: if ``w`` is empty, if ``w`` and ``datasize`` differ in
            length, or if the sizes sum to zero.
    """
    if len(w) == 0 or len(w) != len(datasize):
        raise ValueError(
            "average_weights needs one dataset size per client "
            "(got {} weight sets and {} sizes)".format(len(w), len(datasize))
        )
    total = sum(datasize)
    if total == 0:
        raise ValueError("sum of dataset sizes is zero; cannot average")

    # Shallow copy keeps the mapping type and key order of w[0]; every value
    # is replaced by a freshly computed tensor below, so nothing is shared.
    w_avg = copy.copy(w[0])
    for key in w[0]:
        first = w[0][key]
        # Out-of-place arithmetic: the clients' tensors are never modified.
        acc = first * datasize[0]
        for client_weights, size in zip(w[1:], datasize[1:]):
            acc = acc + client_weights[key].to(first.device) * size
        w_avg[key] = torch.div(acc, total)

    return w_avg

## Thread define

## Receive users before training

In [25]:
def run_thread(func, num_user):
    """Accept ``num_user`` clients and run ``func`` for each in its own thread.

    Every accepted connection is stored in ``clientsoclist`` at the client's
    index and handed to ``func(index, num_user, conn)`` on a new thread. Once
    all clients are connected the global ``start_time`` is set; the function
    then waits for every thread to finish and prints the elapsed time.
    """
    global start_time

    workers = []
    for client_idx in range(num_user):
        conn, addr = s.accept()
        print('Conntected with', addr)
        # remember the client socket so train() can broadcast to it
        clientsoclist[client_idx] = conn
        worker = Thread(target=func, args=(client_idx, num_user, conn))
        workers.append(worker)
        worker.start()

    print("timmer start!")
    start_time = time.time()    # timing starts after the last accept
    for worker in workers:
        worker.join()
    elapsed = time.time() - start_time
    print("TrainingTime: {} sec".format(elapsed))

In [26]:
def receive(userid, num_users, conn):  # per-client thread entry point
    """Handshake with one client, then run its training rounds.

    Sends the round count, the client's id and the local epoch count, reads
    back the size of the client's training set, records it in
    ``datasetsize`` and bumps ``weight_count`` under ``lock``, and finally
    hands over to ``train``. The payload sizes of both messages are logged
    in the total and per-client overhead lists.
    """
    global weight_count

    handshake = {
        'rounds': rounds,
        'client_id': userid,
        'local_epoch': local_epoch
    }

    # send the training configuration
    sent_bytes = send_msg(conn, handshake)
    total_sendsize_list.append(sent_bytes)
    client_sendsize_list[userid].append(sent_bytes)

    # read the client's reply: its training-set size
    train_dataset_size, received_bytes = recv_msg(conn)
    total_receivesize_list.append(received_bytes)
    client_receivesize_list[userid].append(received_bytes)

    with lock:
        datasetsize[userid] = train_dataset_size
        weight_count += 1

    train(userid, train_dataset_size, num_users, conn)

## Train

In [27]:
def train(userid, train_dataset_size, num_users, client_conn):
    """Run the server side of all training rounds for one client thread.

    Each round: if every client has checked in (``weight_count == num_users``),
    the thread that observes this broadcasts ``global_weights`` to all sockets
    in ``clientsoclist`` and resets the counter. The thread then waits for its
    own client's updated weights, stores them in ``weights_list[userid]`` and
    increments the counter; the thread that brings the counter up to
    ``num_users`` recomputes ``global_weights`` with ``average_weights``.

    Args:
        userid: index of this client in the per-client lists.
        train_dataset_size: passed in by receive(); not used in this function.
        num_users: number of clients that must check in before a broadcast
            or an averaging step happens.
        client_conn: this client's connected socket.
    """
    global weights_list
    global global_weights
    global weight_count
    # NOTE(review): mobile_net and val_acc are declared global here but are
    # neither read nor assigned anywhere in this function.
    global mobile_net
    global val_acc
    
    for r in range(rounds):
        with lock:
            # Only a thread that sees the full count broadcasts; the others
            # skip straight to waiting for their client's reply.
            if weight_count == num_users:
                for i, conn in enumerate(clientsoclist):
                    datasize = send_msg(conn, global_weights)
                    total_sendsize_list.append(datasize)
                    client_sendsize_list[i].append(datasize)
                    train_sendsize_list.append(datasize)
                    # Reset happens on every iteration of the send loop; the
                    # assigned value is the same each time.
                    weight_count = 0

        # Blocks until this client's weights arrive. recv_msg returns None when
        # the connection is closed, in which case this unpacking raises TypeError.
        client_weights, datasize = recv_msg(client_conn)
        total_receivesize_list.append(datasize)
        client_receivesize_list[userid].append(datasize)
        train_receivesize_list.append(datasize)

        # Each thread writes only its own slot of weights_list.
        weights_list[userid] = client_weights
        print("User" + str(userid) + "'s Round " + str(r + 1) +  " is done")
        with lock:
            weight_count += 1
            # The thread that completes the count aggregates all client
            # weights, weighted by the dataset sizes collected in receive().
            if weight_count == num_users:
                #average
                global_weights = average_weights(weights_list, datasetsize)
                
        
    

In [28]:
# Address and port the server socket is bound to (see s.bind below).
# Hard-coded here; adjust to the address of the machine running the server.
host = '192.168.1.104'
port = 2000
print(host,port)

192.168.1.104 2000


In [29]:
# Listening TCP socket on which run_thread() accepts the client connections.
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# Allow the address to be re-bound right after a previous server socket was
# closed (e.g. when this cell is re-run), instead of failing with
# "Address already in use" while the old socket lingers in TIME_WAIT.
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.bind((host, port))
s.listen()

In [30]:
# s.close()

### Open the server socket

In [31]:
# Accept `users` clients and run the receive() handshake + training loop for
# each one on its own thread; returns once every thread has finished.
run_thread(receive, users)

Conntected with ('192.168.1.111', 33538)
Conntected with ('192.168.1.112', 56492)
Conntected with ('192.168.1.113', 42892)
Conntected with ('192.168.1.114', 38052)
Conntected with ('192.168.1.115', 44030)
Conntected with ('192.168.1.116', 35006)
Conntected with ('192.168.1.117', 55482)
Conntected with ('192.168.1.118', 41914)
timmer start!
User4's Round 1 is done
User5's Round 1 is done
User6's Round 1 is done
User7's Round 1 is done
User3's Round 1 is done
User1's Round 1 is done
User2's Round 1 is done
User0's Round 1 is done
User4's Round 2 is done
User3's Round 2 is done
User5's Round 2 is done
User7's Round 2 is done
User6's Round 2 is done
User2's Round 2 is done
User1's Round 2 is done
User0's Round 2 is done
User3's Round 3 is done
User4's Round 3 is done
User5's Round 3 is done
User7's Round 3 is done
User6's Round 3 is done
User1's Round 3 is done
User2's Round 3 is done
User0's Round 3 is done
User7's Round 4 is done
User3's Round 4 is done
User6's Round 4 is done
User4's Ro

In [32]:
# Time elapsed since run_thread() set start_time, measured at the moment this
# cell runs (so it also includes any idle time before the cell was executed).
end_time = time.time()  # store end time
print("TrainingTime: {} sec".format(end_time - start_time))

TrainingTime: 2994.0973851680756 sec


## Print all communication overhead

In [33]:
# Report the communication overhead recorded while training: for each list,
# the total payload size in bytes and the number of messages.
print('\n')
print('---total_sendsize_list---')
total_size = sum(total_sendsize_list)
print("total_sendsize size: {} bytes".format(total_size))
print("number of total_send: ", len(total_sendsize_list))
print('\n')

print('---total_receivesize_list---')
total_size = sum(total_receivesize_list)
print("total receive sizes: {} bytes".format(total_size))
print("number of total receive: ", len(total_receivesize_list))
print('\n')

# Per-client breakdown.
for i in range(users):
    print('---client_sendsize_list(user{})---'.format(i))
    total_size = sum(client_sendsize_list[i])
    print("total client_sendsizes(user{}): {} bytes".format(i, total_size))
    print("number of client_send(user{}): ".format(i), len(client_sendsize_list[i]))
    print('\n')

    print('---client_receivesize_list(user{})---'.format(i))
    total_size = sum(client_receivesize_list[i])
    print("total client_receive sizes(user{}): {} bytes".format(i, total_size))
    # Label fixed: this line reports the receive count, not the send count.
    print("number of client_receive(user{}): ".format(i), len(client_receivesize_list[i]))
    print('\n')

# Weight messages exchanged during the training rounds only.
print('---train_sendsize_list---')
total_size = sum(train_sendsize_list)
print("total train_sendsizes: {} bytes".format(total_size))
print("number of train_send: ", len(train_sendsize_list))
print('\n')

print('---train_receivesize_list---')
total_size = sum(train_receivesize_list)
print("total train_receivesizes: {} bytes".format(total_size))
print("number of train_receive: ", len(train_receivesize_list))
print('\n')




---total_sendsize_list---
total_sendsize size: 234305992 bytes
number of total_send:  408


---total_receivesize_list---
total receive sizes: 234257920 bytes
number of total receive:  408


---client_sendsize_list(user0)---
total client_sendsizes(user0): 29288249 bytes
number of client_send(user0):  51


---client_receivesize_list(user0)---
total client_receive sizes(user0): 29282365 bytes
number of client_send(user0):  51


---client_sendsize_list(user1)---
total client_sendsizes(user1): 29288249 bytes
number of client_send(user1):  51


---client_receivesize_list(user1)---
total client_receive sizes(user1): 29282365 bytes
number of client_send(user1):  51


---client_sendsize_list(user2)---
total client_sendsizes(user2): 29288249 bytes
number of client_send(user2):  51


---client_receivesize_list(user2)---
total client_receive sizes(user2): 29282365 bytes
number of client_send(user2):  51


---client_sendsize_list(user3)---
total client_sendsizes(user3): 29288249 bytes
number of c

In [34]:
# Directory passed as `root` to the CIFAR10 dataset constructors below.
root_path = '../../models/cifar10_data'

In [35]:
from torch.utils.data import Dataset, DataLoader

In [36]:
# Convert images to tensors and normalize each of the three channels.
# NOTE(review): the constants are presumably the CIFAR-10 per-channel mean and
# standard deviation — confirm they match what the clients use.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])

## Making Batch Generator

In [37]:
# Mini-batch size used by both evaluation data loaders below.
batch_size=32

In [38]:
# CIFAR-10 training split (downloaded into root_path when not already present),
# served in shuffled mini-batches. It is only used for evaluation on the server.
trainset = torchvision.datasets.CIFAR10 (root=root_path, train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

# CIFAR-10 test split, served in its original order.
testset = torchvision.datasets.CIFAR10 (root=root_path, train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [39]:
# Class names indexed by label id; used when printing the per-class accuracy.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

### `DataLoader` for batch generating
`torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)`

### Number of total batches

In [40]:
# Number of mini-batches produced by one pass over each loader.
train_total_batch = len(trainloader)
print(train_total_batch)
test_batch = len(testloader)
print(test_batch)

1563
313


In [41]:
# Load the current global_weights into the server's model, switch it to
# evaluation mode and move it to the selected device.
mobile_net.load_state_dict(global_weights)
mobile_net.eval()
mobile_net = mobile_net.to(device)

# NOTE(review): lr and optimizer are not used by any later code visible in
# this notebook; only criterion is used by the evaluation below.
lr = 0.001
optimizer = optim.SGD(mobile_net.parameters(), lr=lr, momentum=0.9)
criterion = nn.CrossEntropyLoss()

## Accuracy of train and each of classes

In [42]:
def _evaluate(model, loader, loss_fn, dev, num_classes):
    """Evaluate ``model`` on ``loader`` in a single pass without gradients.

    Args:
        model: module mapping an input batch to per-class scores.
        loader: iterable of ``(inputs, labels)`` batches.
        loss_fn: callable ``loss_fn(outputs, labels)`` returning a scalar tensor.
        dev: device the batches are moved to before the forward pass.
        num_classes: number of classes; labels must lie in ``[0, num_classes)``.

    Returns:
        ``(accuracy, mean_loss, class_correct, class_total)`` where
        ``accuracy`` is the percentage of correct predictions, ``mean_loss``
        the loss averaged over batches, and the two lists hold per-class
        counts of correct predictions and of samples. ``accuracy`` and
        ``mean_loss`` are 0.0 when the loader yields no data.
    """
    model.eval()
    hits_per_class = torch.zeros(num_classes, dtype=torch.long)
    seen_per_class = torch.zeros(num_classes, dtype=torch.long)
    loss_sum = 0.0
    batches = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(dev)
            labels = labels.long().to(dev)

            outputs = model(inputs)
            loss_sum += loss_fn(outputs, labels).item()
            batches += 1

            # 1-D boolean mask; it is deliberately not squeezed, so a final
            # batch holding a single sample is handled like any other.
            hit = outputs.argmax(dim=1) == labels
            seen_per_class += torch.bincount(labels, minlength=num_classes).cpu()
            hits_per_class += torch.bincount(labels[hit], minlength=num_classes).cpu()

    seen = int(seen_per_class.sum())
    accuracy = int(hits_per_class.sum()) / seen * 100 if seen else 0.0
    mean_loss = loss_sum / batches if batches else 0.0
    return accuracy, mean_loss, hits_per_class.tolist(), seen_per_class.tolist()


# train acc
train_acc, train_loss, _, _ = _evaluate(mobile_net, trainloader, criterion, device, len(classes))
print("train_acc: {:.2f}%, train_loss: {:.4f}".format(train_acc, train_loss))

# test acc (overall and per-class counts come from the same pass)
accuracy, test_loss, class_correct, class_total = _evaluate(
    mobile_net, testloader, criterion, device, len(classes))
print("test_acc: {:.2f}%, test_loss: {:.4f}".format(accuracy, test_loss))

# acc of each class; a class without test samples is reported as 0 %
for class_name, hits, seen in zip(classes, class_correct, class_total):
    print('Accuracy of %5s : %2d %%' % (
        class_name, 100 * hits / seen if seen else 0))

# Save the aggregated model parameters:
PATH = './cifar10_fd_mobile.pth'
torch.save(mobile_net.state_dict(), PATH)

end_time = time.time()  # store end time
print("WorkingTime: {} sec".format(end_time - start_time))

train_acc: 77.45%, train_loss: 0.6731
test_acc: 69.77%, test_loss: 0.8998
Accuracy of plane : 72 %
Accuracy of   car : 81 %
Accuracy of  bird : 55 %
Accuracy of   cat : 51 %
Accuracy of  deer : 62 %
Accuracy of   dog : 59 %
Accuracy of  frog : 80 %
Accuracy of horse : 76 %
Accuracy of  ship : 79 %
Accuracy of truck : 79 %
WorkingTime: 3036.774134159088 sec
