In [1]:
# Identifier of this server. It is sent to the client right after the handshake
# and is used in the progress-bar label and the final log lines.
server_name = 'SERVER_001'

import os
import h5py

import socket
import struct
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim

import time
import sys



from tqdm import tqdm


from torch.utils.data import Dataset, DataLoader

from torch.autograd import Variable
import torch.nn.init as init

import copy

In [2]:
# Device and reproducibility setup
seed_num = 777

# Seed PyTorch's global RNG so the random weight initialisation of the server
# model is repeatable between runs. (The seed used to be defined but never
# applied, because the seeding code was commented out together with the
# CUDA device selection.)
torch.manual_seed(seed_num)

# The server is deliberately pinned to the CPU. To run on a GPU instead, use:
#   device = "cuda:0" if torch.cuda.is_available() else "cpu"
device = "cpu"

In [3]:
def send_msg(sock, msg):
    """Pickle ``msg`` and send it over ``sock`` as one length-prefixed frame.

    The frame consists of a 4-byte unsigned length in network (big-endian)
    byte order, followed by the pickled payload.

    Args:
        sock: object exposing ``sendall(bytes)`` (e.g. a connected socket).
        msg: any picklable object.

    Returns:
        The size of the pickled payload in bytes (the 4-byte header is not
        counted).
    """
    payload = pickle.dumps(msg)
    payload_size = len(payload)
    header = struct.pack('>I', payload_size)
    sock.sendall(header + payload)
    return payload_size

def recv_msg(sock):
    """Receive one length-prefixed, pickled message from ``sock``.

    The expected frame layout is a 4-byte big-endian unsigned length followed
    by that many bytes of pickled payload.

    Args:
        sock: object exposing ``recv(n)`` (e.g. a connected socket).

    Returns:
        A tuple ``(obj, payload_size)`` with the unpickled object and the
        payload length in bytes, or ``None`` when the peer closed the
        connection before a complete frame was received. Callers must check
        for ``None`` before unpacking the result.
    """
    # read message length and unpack it into an integer
    raw_msg_len = recv_all(sock, 4)
    if not raw_msg_len:
        return None
    msg_len = struct.unpack('>I', raw_msg_len)[0]

    # read the message data
    payload = recv_all(sock, msg_len)
    if payload is None:
        # The connection dropped in the middle of a frame. Report it the same
        # way as EOF on the header instead of handing None to pickle.loads,
        # which would raise an unrelated TypeError.
        return None

    # SECURITY: unpickling can execute arbitrary code. Only use this protocol
    # between mutually trusted peers on a trusted network.
    return pickle.loads(payload), msg_len

def recv_all(sock, n):
    """Read exactly ``n`` bytes from ``sock``.

    Args:
        sock: object exposing ``recv(max_bytes)`` (e.g. a connected socket).
        n: number of bytes to read.

    Returns:
        A ``bytes`` object of length ``n``, or ``None`` if ``recv`` reported
        end-of-stream (an empty result) before ``n`` bytes arrived.
    """
    # Accumulate into a bytearray: repeated ``bytes += packet`` copies
    # everything received so far on every packet, which becomes quadratic for
    # large payloads such as feature tensors.
    buffer = bytearray()
    while len(buffer) < n:
        packet = sock.recv(n - len(buffer))
        if not packet:
            return None
        buffer += packet
    return bytes(buffer)

Model definition for the server side (the layers applied to the output received from the client)

In [4]:
def _weights_init(m):
    classname = m.__class__.__name__

    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        init.kaiming_normal_(m.weight)

class LambdaLayer(nn.Module):
    """Wrap an arbitrary callable so it can be used as an ``nn.Module``.

    The callable is stored as a plain attribute and invoked on the input in
    ``forward``; the layer itself owns no parameters.
    """

    def __init__(self, lambd):
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        """Return ``lambd(x)``."""
        fn = self.lambd
        return fn(x)


class BasicBlock(nn.Module):
    """Residual block: conv3x3 -> BN -> ReLU -> conv3x3 -> BN, plus shortcut, then ReLU.

    Args:
        in_planes: number of input channels.
        planes: number of output channels.
        stride: stride of the first convolution (the second always uses 1).
        option: shortcut type used when ``stride != 1`` or the channel count
            changes. ``'A'`` keeps every second pixel in both spatial
            dimensions and zero-pads ``planes // 4`` channels on each side;
            ``'B'`` uses a strided 1x1 convolution followed by batch norm.
            Any other value leaves the identity shortcut in place.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option='A'):
        super().__init__()
        conv3x3 = dict(kernel_size=3, padding=1, bias=False)
        self.conv1 = nn.Conv2d(in_planes, planes, stride=stride, **conv3x3)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, stride=1, **conv3x3)
        self.bn2 = nn.BatchNorm2d(planes)

        # Identity shortcut by default; replaced below when the block changes
        # the spatial resolution or the number of channels.
        self.shortcut = nn.Sequential()
        shape_changes = not (stride == 1 and in_planes == planes)
        if shape_changes and option == 'A':
            # For CIFAR10 ResNet paper uses option A.
            channel_pad = planes // 4
            self.shortcut = LambdaLayer(
                lambda x: F.pad(x[:, :, ::2, ::2],
                                (0, 0, 0, 0, channel_pad, channel_pad),
                                "constant", 0)
            )
        elif shape_changes and option == 'B':
            out_planes = self.expansion * planes
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Apply the residual block to ``x`` and return the activated sum."""
        hidden = self.conv1(x)
        hidden = self.bn1(hidden)
        hidden = F.relu(hidden)
        out = self.conv2(hidden)
        out = self.bn2(out)
        out += self.shortcut(x)
        return F.relu(out)
    
class BasicBlock_server(nn.Module):
    """Reduced residual block holding only ``conv2``/``bn2`` and a shortcut.

    Unlike ``BasicBlock`` there is no ``conv1``/``bn1`` here
    (NOTE(review): presumably that half runs on the client - confirm against
    the client notebook). The shortcut is applied to the same tensor that is
    fed into ``conv2``.

    Args:
        in_planes: used only to decide whether a non-identity shortcut is
            needed (and as the input width of the option-``'B'`` shortcut).
        planes: number of input and output channels of ``conv2``.
        stride: used only for the shortcut decision / option-``'B'`` stride;
            ``conv2`` always uses stride 1.
        option: ``'A'`` (subsample + zero-pad channels) or ``'B'``
            (1x1 convolution + batch norm); any other value keeps identity.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option='A'):
        super().__init__()
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Identity unless the block is asked to change stride or channels.
        self.shortcut = nn.Sequential()
        shape_changes = not (stride == 1 and in_planes == planes)
        if shape_changes and option == 'A':
            # For CIFAR10 ResNet paper uses option A.
            channel_pad = planes // 4
            self.shortcut = LambdaLayer(
                lambda x: F.pad(x[:, :, ::2, ::2],
                                (0, 0, 0, 0, channel_pad, channel_pad),
                                "constant", 0)
            )
        elif shape_changes and option == 'B':
            out_planes = self.expansion * planes
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Return ``relu(bn2(conv2(x)) + shortcut(x))``."""
        out = self.bn2(self.conv2(x))
        out += self.shortcut(x)
        return F.relu(out)

class ResNet(nn.Module):
    """Server-side ResNet: three stages, global average pooling and a linear head.

    There is no stem convolution in this module; ``forward`` feeds its input
    straight into ``layer1``, whose first block is built from ``block_server``.

    Args:
        block: class used for every regular residual block. Must expose an
            ``expansion`` attribute and accept ``(in_planes, planes, stride)``.
        block_server: class used for the very first block of ``layer1``.
        num_blocks: sequence with the number of blocks in each of the three
            stages.
        num_classes: size of the output layer.
    """

    def __init__(self, block, block_server, num_blocks, num_classes=10, ):
        super(ResNet, self).__init__()
        self.in_planes = 16
        self.layer1 = self._make_layer_server(block, block_server, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        self.linear = nn.Linear(64, num_classes)

        self.apply(_weights_init)

    def _make_layer_server(self, block, block_server, planes, num_blocks, stride):
        """Build the first stage: one ``block_server`` then regular blocks.

        The stage contains ``num_blocks`` blocks in total (at least the
        leading ``block_server``). Previously the count was hard-coded to
        three regardless of ``num_blocks``; for ``num_blocks == 3`` the result
        is unchanged.

        ``stride`` is accepted for signature compatibility but every block in
        this stage is built with stride 1.
        """
        layers = [block_server(self.in_planes, planes, 1)]
        self.in_planes = planes * block.expansion
        for _ in range(num_blocks - 1):
            layers.append(block(self.in_planes, planes, 1))
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build a stage of ``num_blocks`` blocks; only the first uses ``stride``."""
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.in_planes, planes, block_stride))
            # Track the channel count that the next block will receive.
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the three stages, pool, flatten and classify.

        Returns the output of the final linear layer (no softmax applied).
        """
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        # Pool with a kernel equal to the size of the last dimension of the
        # feature map, then flatten everything except the batch dimension.
        out = F.avg_pool2d(out, out.size()[3])
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def resnet20():
    """Build the server-side ``ResNet`` with three blocks in each of its three stages.

    Regular blocks are ``BasicBlock``; the first block of the first stage is
    ``BasicBlock_server``.
    """
    blocks_per_stage = [3, 3, 3]
    return ResNet(BasicBlock, BasicBlock_server, blocks_per_stage)

In [5]:
# Instantiate the server-side model and move it to the configured device.
resnet20_server =  resnet20().to(device)

# Cross-entropy loss on the model output, optimised with SGD + momentum.
criterion = nn.CrossEntropyLoss()
lr = 0.001
optimizer = optim.SGD(resnet20_server.parameters(), lr=lr, momentum=0.9)

# NOTE(review): the training cell takes its epoch count from the client's
# handshake message; this value does not appear to be read in the visible cells.
epochs = 1

In [6]:
# Address the server listens on.
# host = '10.2.144.188'
host = '10.9.240.14'
port = 10081

# Both sockets are managed by ``with`` so they are closed even if training
# fails; leaving them open makes a re-run of this cell fail with
# "Address already in use".
with socket.socket() as s:
    # Allow an immediate restart of the server on the same port.
    s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    s.bind((host, port))
    s.listen(5)

    conn, addr = s.accept()
    with conn:
        print("Connected to: ", addr)

        # Handshake: the client sends the number of epochs and the total
        # number of batches per epoch.
        handshake = recv_msg(conn)
        if handshake is None:
            raise ConnectionError("client disconnected before sending the training configuration")
        rmsg, data_size = handshake

        epoch = rmsg['epoch']
        # Despite its name this is the number of batches per epoch, not the
        # number of samples per batch.
        batch_size = rmsg['total_batch']

        print("received epoch: ", rmsg['epoch'], rmsg['total_batch'])

        send_msg(conn, server_name) # send server meta information.

        # Start training
        start_time = time.time()
        print("Start training @ ", time.asctime())

        # NOTE: do not introduce a variable called ``init`` in this cell; that
        # name is the module alias imported at the top (``torch.nn.init as
        # init``) and must not be shadowed.
        for epc in range(epoch):
            for i in tqdm(range(batch_size), ncols = 100, desc='Training with {}'.format(server_name)):
                optimizer.zero_grad()

                # receives label and feature from client.
                received = recv_msg(conn)
                if received is None:
                    raise ConnectionError(
                        "client disconnected during training (epoch {}, batch {})".format(epc, i))
                msg, data_size = received

                # label
                label = msg['label']
                label = label.clone().detach().long().to(device) # conversion between gpu and cpu.

                # feature
                client_output_cpu = msg['client_output']
                client_output = client_output_cpu.to(device)

                # forward propagation
                output = resnet20_server(client_output)
                loss = criterion(output, label) # compute cross-entropy loss
                loss.backward() # backward propagation

                # send gradient to client
                # ``.grad`` is only populated when the received tensor has
                # requires_grad=True; fail with a clear message otherwise.
                if client_output_cpu.grad is None:
                    raise RuntimeError(
                        "received 'client_output' has no gradient; the client must send it with requires_grad=True")
                msg = client_output_cpu.grad.clone().detach()
                data_size = send_msg(conn, msg)

                optimizer.step()

print('Contribution from {} is done'.format(server_name))
print('Contribution duration is: {} seconds'.format(time.time() - start_time))

Connected to:  ('10.9.240.14', 43922)
received epoch:  1 12500
Start training @  Mon Oct  2 12:00:06 2023


Training with SERVER_001: 100%|███████████████████████████████| 12500/12500 [04:22<00:00, 47.57it/s]

Contribution from SERVER_001 is done
Contribution duration is: 262.7710027694702 seconds



