In [15]:
import os
import time
import socket
import struct
import pickle
from tqdm import tqdm

import torch
import torchvision
import torchvision.transforms as transforms
from torchvision import models

import torch.nn as nn
import torch.nn.functional as F


device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Assuming that we are on a CUDA machine, this should print a CUDA device:

print(device)

cuda:0


In [16]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 4

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [17]:
def conv_bn(inp, oup, stride):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU(inplace=True)
    )

def conv_dw(inp, oup, stride):
    return nn.Sequential(
        nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
        nn.BatchNorm2d(inp),
        nn.ReLU(inplace=True),

        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU(inplace=True),
    )

In [18]:
class ClientMobileNet(nn.Module):
    def __init__(self, num_classes=1000):
        super(ClientMobileNet, self).__init__()
        self.layer1 = conv_bn(3, 32, 2)
        self.layer2 = conv_dw(32, 64, 1)
        self.layer3 = conv_dw(64, 128, 2)
        self.layer4 = conv_dw(128, 128, 1)
        self.layer5 = conv_dw(128, 256, 2)
        self.layer6 = conv_dw(256, 256, 1)
        self.layer7 = conv_dw(256, 512, 2)
        self.layer8 = conv_dw(512, 512, 1)
        self.layer9 = conv_dw(512, 512, 1)
        self.layer10 = conv_dw(512, 512, 1)
        self.layer11 = conv_dw(512, 512, 1)
        self.layer12 = conv_dw(512, 512, 1)
        self.layer13 = conv_dw(512, 1024, 2)
        self.layer14 = conv_dw(1024, 1024, 1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(1024, num_classes)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        #x = self.layer4(x)
        #x = self.layer5(x)
        #x = self.layer6(x)
        #x = self.layer7(x)
        #x = self.layer8(x)
        #x = self.layer9(x)
        #x = self.layer10(x)
        #x = self.layer11(x)
        #x = self.layer12(x)
        #x = self.layer13(x)
        #x = self.layer14(x)
        #x = self.avg_pool(x)
        #x = x.view(-1, 1024)
        #x = self.fc(x)
        return x

client_model = ClientMobileNet(num_classes=10)
client_model = client_model.to(device)
print(client_model)

ClientMobileNet(
  (layer1): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (layer2): Sequential(
    (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
  )
  (layer3): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=64, bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (4): B

In [19]:
class ServerMobileNet(nn.Module):
    def __init__(self, num_classes=1000):
        super(ServerMobileNet, self).__init__()
        self.layer1 = conv_bn(3, 32, 2)
        self.layer2 = conv_dw(32, 64, 1)
        self.layer3 = conv_dw(64, 128, 2)
        self.layer4 = conv_dw(128, 128, 1)
        self.layer5 = conv_dw(128, 256, 2)
        self.layer6 = conv_dw(256, 256, 1)
        self.layer7 = conv_dw(256, 512, 2)
        self.layer8 = conv_dw(512, 512, 1)
        self.layer9 = conv_dw(512, 512, 1)
        self.layer10 = conv_dw(512, 512, 1)
        self.layer11 = conv_dw(512, 512, 1)
        self.layer12 = conv_dw(512, 512, 1)
        self.layer13 = conv_dw(512, 1024, 2)
        self.layer14 = conv_dw(1024, 1024, 1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(1024, num_classes)

    def forward(self, x):
        # x = self.layer1(x)
        # x = self.layer2(x)
        # x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        x = self.layer6(x)
        x = self.layer7(x)
        x = self.layer8(x)
        x = self.layer9(x)
        x = self.layer10(x)
        x = self.layer11(x)
        x = self.layer12(x)
        x = self.layer13(x)
        x = self.layer14(x)
        x = self.avg_pool(x)
        x = x.view(-1, 1024)
        x = self.fc(x)
        return x

server_model = ServerMobileNet(num_classes=10)
server_model = server_model.to(device)
print(server_model)

ServerMobileNet(
  (layer1): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (layer2): Sequential(
    (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
  )
  (layer3): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=64, bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (4): B

In [20]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(server_model.parameters(), lr=0.001, momentum=0.9)

In [21]:
def send_msg(sock, msg):
    # prefix each message with a 4-byte length in network byte order
    msg = pickle.dumps(msg)
    msg = struct.pack('>I', len(msg)) + msg
    sock.sendall(msg)

def recv_msg(sock):
    # read message length and unpack it into an integer
    raw_msglen = recvall(sock, 4)
    if not raw_msglen:
        return None
    msglen = struct.unpack('>I', raw_msglen)[0]
    # read the message data
    msg =  recvall(sock, msglen)
    msg = pickle.loads(msg)
    return msg, msglen

def recvall(sock, n):
    # helper function to receive n bytes or return None if EOF is hit
    data = b''
    while len(data) < n:
        packet = sock.recv(n - len(data))
        if not packet:
            return None
        data += packet
    return data

In [22]:
host = 'localhost'
port = 12303

s = socket.socket()
s.bind((host, port))
s.listen(5)

In [23]:
users = 1
epochs = 1

clientsoclist = []
train_total_batch = []
val_acc = []

total_sendsize_list = []
total_receivesize_list = []

client_sendsize_list = [[] for i in range(users)]
client_receivesize_list = [[] for i in range(users)]

train_sendsize_list = [] 
train_receivesize_list = []

for i in range(users):
    conn, addr = s.accept()
    print("conn: ", conn)
    print('Conntected with', addr)
    clientsoclist.append(conn)    # append client socket on list

    datasize = send_msg(conn, epochs)    #send epoch
    total_sendsize_list.append(datasize)
    client_sendsize_list[i].append(datasize)

    total_batch, datasize = recv_msg(conn)    # get total_batch of train dataset
    total_receivesize_list.append(datasize)
    client_receivesize_list[i].append(datasize)

    train_total_batch.append(total_batch)    # append on list

conn:  <socket.socket fd=1940, family=AddressFamily.AF_INET, type=SocketKind.SOCK_STREAM, proto=0, laddr=('127.0.0.1', 12303), raddr=('127.0.0.1', 55665)>
Conntected with ('127.0.0.1', 55665)


In [24]:
import copy

client_weights = copy.deepcopy(client_model.state_dict())

start_time = time.time()    # store start time
print("timmer start!")

for e in range(epochs):

    # train client 0

    for user in range(users):

        datasize = send_msg(clientsoclist[user], client_weights)
        total_sendsize_list.append(datasize)
        client_sendsize_list[user].append(datasize)
        train_sendsize_list.append(datasize)

        for i in tqdm(range(train_total_batch[user]), ncols=100, desc='Epoch {} Client{} '.format(e+1, user)):
            optimizer.zero_grad()  # initialize all gradients to zero

            msg, datasize = recv_msg(clientsoclist[user])  # receive client message from socket
            total_receivesize_list.append(datasize)
            client_receivesize_list[user].append(datasize)
            train_receivesize_list.append(datasize)

            client_output_cpu = msg['client_output']  # client output tensor
            label = msg['label']  # label

            client_output = client_output_cpu.to(device)
            label = label.clone().detach().long().to(device)

            output = server_model(client_output)  # forward propagation
            loss = criterion(output, label)  # calculates cross-entropy loss
            loss.backward()  # backward propagation
            msg = client_output_cpu.grad.clone().detach()

            datasize = send_msg(clientsoclist[user], msg)
            total_sendsize_list.append(datasize)
            client_sendsize_list[user].append(datasize)
            train_sendsize_list.append(datasize)
            
            optimizer.step()
            
        client_weights, datasize = recv_msg(clientsoclist[user])
        total_receivesize_list.append(datasize)
        client_receivesize_list[user].append(datasize)
        train_receivesize_list.append(datasize)
        
print('train is done')

end_time = time.time()  # store end time
print("TrainingTime: {} sec".format(end_time - start_time))

# Let's quickly save our trained model:
PATH = './cifar10_sp_server_model.pth'
torch.save(server_model.state_dict(), PATH)

timmer start!


Epoch 1 Client0 : 100%|███████████████████████████████████████| 12500/12500 [03:25<00:00, 60.92it/s]


train is done
TrainingTime: 205.7849817276001 sec


In [25]:
# train acc
with torch.no_grad():
    corr_num = 0
    total_num = 0
    train_loss = 0.0
    for j, trn in enumerate(trainloader):
        trn_x, trn_label = trn
        trn_x = trn_x.to(device)
        trn_label = trn_label.to(device)


        trn_output = client_model(trn_x)
        trn_label = trn_label.long()
        
        trn_output = server_model(trn_output)
        
        loss = criterion(trn_output, trn_label)
        train_loss += loss.item()
        model_label = trn_output.argmax(dim=1)
        corr = trn_label[trn_label == model_label].size(0)
        corr_num += corr
        total_num += trn_label.size(0)
    print("train_acc: {:.2f}%, train_loss: {:.4f}".format(corr_num / total_num * 100, train_loss / len(trainloader)))

train_acc: 8.22%, train_loss: 2.9990


In [26]:
with torch.no_grad():
    corr_num = 0
    total_num = 0
    val_loss = 0.0
    for j, val in enumerate(testloader):
        val_x, val_label = val
        val_x = val_x.to(device)
        val_label = val_label.to(device)

        val_output = client_model(val_x)
        val_label = val_label.long()
        
        val_output = server_model(val_output)
        
        loss = criterion(val_output, val_label)
        val_loss += loss.item()
        model_label = val_output.argmax(dim=1)
        corr = val_label[val_label == model_label].size(0)
        corr_num += corr
        total_num += val_label.size(0)
    print("test_acc: {:.2f}%, test_loss: {:.4f}".format(corr_num / total_num * 100, val_loss / len(testloader)))

test_acc: 8.06%, test_loss: 3.0048


In [27]:
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))

with torch.no_grad():
    for data in testloader:
        x, labels = data
        x = x.to(device)
        labels = labels.to(device)

        outputs = client_model(x)
        outputs = server_model(outputs)
        labels = labels.long()
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels).squeeze()
        for i in range(len(labels)):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1


for i in range(10):
    print('Accuracy of %5s : %2d %%' % (
        classes[i], 100 * class_correct[i] / class_total[i]))

Accuracy of plane :  7 %
Accuracy of   car : 24 %
Accuracy of  bird :  7 %
Accuracy of   cat :  0 %
Accuracy of  deer :  0 %
Accuracy of   dog : 15 %
Accuracy of  frog : 11 %
Accuracy of horse : 12 %
Accuracy of  ship :  1 %
Accuracy of truck :  0 %
