In [1]:
# Identifier this server sends to the client in the handshake and shows in the training progress bar.
server_name = 'SERVER_001'

import os
import h5py

import socket
import struct
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim

import time
import sys



from tqdm import tqdm


from torch.utils.data import Dataset, DataLoader

from torch.autograd import Variable
import torch.nn.init as init

import copy

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Setup: random seed and compute device.
seed_num = 777
# Seed torch's RNGs so the server model's weight initialisation (done when the
# model is built) is reproducible across runs. Previously the seed was defined
# but never applied.
torch.manual_seed(seed_num)
# Training runs on CPU. To use a GPU instead:
#   device = "cuda:0" if torch.cuda.is_available() else "cpu"
device = "cpu"

In [3]:
def send_msg(sock, msg):
    """Pickle `msg` and send it over `sock` as one length-prefixed frame.

    Frame layout: a 4-byte big-endian unsigned length, then the pickled bytes.
    The whole frame goes out in a single `sendall` call.

    Returns:
        The size of the pickled payload in bytes (excluding the 4-byte header).
    """
    payload = pickle.dumps(msg)
    payload_len = len(payload)
    header = struct.pack('>I', payload_len)
    sock.sendall(header + payload)
    return payload_len

def recv_msg(sock):
    """Receive one length-prefixed, pickled message from `sock`.

    Reads a 4-byte big-endian unsigned length header, then exactly that many
    payload bytes, and unpickles the payload.

    Returns:
        A tuple ``(obj, msg_len)`` with the unpickled object and the payload
        size in bytes, or ``None`` if the peer closed the connection before a
        complete message (header or payload) arrived.

    Security:
        ``pickle.loads`` can execute arbitrary code embedded in the payload.
        Only use this with trusted peers.
    """
    raw_msg_len = recv_all(sock, 4)
    if raw_msg_len is None:
        return None
    (msg_len,) = struct.unpack('>I', raw_msg_len)
    payload = recv_all(sock, msg_len)
    # The peer may hang up mid-message. Report that the same way as a missing
    # header instead of crashing inside pickle.loads(None).
    if payload is None:
        return None
    return pickle.loads(payload), msg_len

def recv_all(sock, n):
    """Read exactly `n` bytes from `sock`.

    Returns:
        The `n` bytes as a ``bytes`` object (``b''`` when ``n == 0``), or
        ``None`` if the peer closes the connection before `n` bytes arrive.
    """
    # Accumulate into a bytearray: repeated `bytes +=` copies the whole buffer
    # on every chunk, which is quadratic for multi-megabyte tensor payloads.
    buf = bytearray()
    while len(buf) < n:
        packet = sock.recv(n - len(buf))
        if not packet:
            return None
        buf += packet
    return bytes(buf)

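## Server-side model definition

The part of ResNet-50 that runs on the server: it consumes the intermediate features sent by the client and returns the gradient of the loss with respect to those features.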

In [4]:
''' ResNet '''

class BasicBlock(nn.Module):
    """Two-convolution residual block (ResNet-18/34 style).

    Main path: 3x3 conv -> norm -> ReLU -> 3x3 conv -> norm. It is summed with
    a shortcut and passed through a final ReLU. The shortcut is the identity,
    unless the stride or channel count changes; then it is a 1x1 conv
    projection followed by a norm.

    Args:
        in_planes: number of input channels.
        planes: number of output channels (``expansion == 1``).
        stride: stride of the first 3x3 conv and of the projection shortcut.
        norm: 'instancenorm' selects per-channel GroupNorm (one group per
            channel); any other value selects BatchNorm2d.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, norm='instancenorm'):
        super(BasicBlock, self).__init__()
        self.norm = norm
        out_planes = self.expansion * planes
        use_group_norm = self.norm == 'instancenorm'

        def make_norm(channels):
            # GroupNorm with num_groups == num_channels normalises each channel
            # per sample, i.e. instance norm with a learnable affine transform.
            if use_group_norm:
                return nn.GroupNorm(channels, channels, affine=True)
            return nn.BatchNorm2d(channels)

        # Layers are created in the same order as before so parameter order
        # (and RNG consumption during init) is unchanged.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = make_norm(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = make_norm(planes)

        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                make_norm(out_planes),
            )

    def forward(self, x):
        residual = self.shortcut(x)
        y = F.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        return F.relu(y + residual)


class Bottleneck(nn.Module):
    """Three-convolution bottleneck residual block (ResNet-50/101/152 style).

    Main path: 1x1 conv -> norm -> ReLU -> 3x3 conv (strided) -> norm -> ReLU
    -> 1x1 conv expanding to ``expansion * planes`` channels -> norm. It is
    summed with a shortcut (identity, or a 1x1 conv projection plus norm when
    the shape changes) and passed through a final ReLU.

    Args:
        in_planes: number of input channels.
        planes: bottleneck width; the block outputs ``4 * planes`` channels.
        stride: stride of the 3x3 conv and of the projection shortcut.
        norm: 'instancenorm' selects per-channel GroupNorm; any other value
            selects BatchNorm2d.
    """
    expansion = 4

    def __init__(self, in_planes, planes, stride=1, norm='instancenorm'):
        super(Bottleneck, self).__init__()
        self.norm = norm
        out_planes = self.expansion * planes
        use_group_norm = self.norm == 'instancenorm'

        def make_norm(channels):
            # One group per channel == instance norm with learnable affine.
            if use_group_norm:
                return nn.GroupNorm(channels, channels, affine=True)
            return nn.BatchNorm2d(channels)

        # Creation order matches the original layout (conv1, bn1, conv2, bn2,
        # conv3, bn3, shortcut) so parameter order and init RNG use are unchanged.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = make_norm(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = make_norm(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = make_norm(out_planes)

        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                make_norm(out_planes),
            )

    def forward(self, x):
        identity = self.shortcut(x)
        y = F.relu(self.bn1(self.conv1(x)))
        y = F.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))
        return F.relu(y + identity)
    

class Bottleneck_server(nn.Module):
    """Server-side remainder of a Bottleneck block (no leading 1x1 conv).

    Compared with ``Bottleneck``, ``conv1``/``bn1`` are absent: this block
    starts at the 3x3 conv and applies both that conv and the shortcut to its
    input ``x``. NOTE(review): presumably the missing conv1/bn1/ReLU runs on
    the client before the features are sent over the socket; confirm against
    the client notebook. In that case the shortcut sees the received features
    rather than the original block input.

    Args:
        in_planes: channels of the input ``x``. Both ``conv2`` and the
            shortcut consume this tensor.
        planes: bottleneck width; the block outputs ``4 * planes`` channels.
        stride: stride of the 3x3 conv and of the projection shortcut.
        norm: 'instancenorm' selects per-channel GroupNorm; any other value
            selects BatchNorm2d.
    """
    expansion = 4

    def __init__(self, in_planes, planes, stride=1, norm='instancenorm'):
        super(Bottleneck_server, self).__init__()
        self.norm = norm
        # conv2 reads `x` directly (there is no conv1 in this block), so its
        # input width must be `in_planes`, the same tensor the shortcut uses.
        # Using `planes` here only worked when in_planes == planes.
        self.conv2 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
        self.bn3 = nn.GroupNorm(self.expansion*planes, self.expansion*planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion*planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.GroupNorm(self.expansion*planes, self.expansion*planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn2(self.conv2(x)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    """Server-side portion of a ResNet split between client and server.

    There is no stem (conv1/bn1) here. ``forward`` takes feature maps
    directly into ``layer1``, whose first block is ``block_server`` (a
    bottleneck without its leading 1x1 conv). NOTE(review): presumably the
    stem and that conv run on the client; confirm against the client notebook.

    Args:
        block: residual block class for all remaining blocks. It must expose
            ``expansion`` and accept ``(in_planes, planes, stride, norm)``.
        block_server: block class used for the first block of ``layer1``.
            Same constructor signature as ``block``.
        num_blocks: number of blocks in each of the four stages.
        num_classes: output size of the final linear classifier.
        norm: normalisation selector passed to every block.
    """
    def __init__(self, block, block_server, num_blocks, num_classes=10, norm='instancenorm'):
        super(ResNet, self).__init__()
        self.in_planes = 64
        self.norm = norm

        self.layer1 = self._make_layer_server(block, block_server, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.classifier = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage: the first block uses `stride`, the rest stride 1."""
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for s in strides:
            layers.append(block(self.in_planes, planes, s, self.norm))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def _make_layer_server(self, block, block_server, planes, num_blocks, stride):
        """Build layer1: one `block_server` followed by `num_blocks - 1` `block`s.

        Unlike the earlier version, this honours `num_blocks`, `stride`, and
        `self.norm`. The earlier version hard-coded three stride-1 blocks with
        the default norm. For ResNet50's defaults the result is identical.
        """
        layers = [block_server(self.in_planes, planes, stride, self.norm)]
        self.in_planes = planes * block_server.expansion
        for _ in range(num_blocks - 1):
            layers.append(block(self.in_planes, planes, 1, self.norm))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Fixed 4x4 pooling: assumes layer4's spatial size pools to 1x1
        # (e.g. 4x4 for 32x32 inputs). TODO confirm input resolution.
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out



def ResNet50(num_classes):
    """Build the server-side split ResNet-50.

    Uses stage depths [3, 4, 6, 3] with ``Bottleneck`` blocks. The first block
    of layer1 is a ``Bottleneck_server``.

    Args:
        num_classes: number of logits produced by the final classifier.
    """
    stage_depths = [3, 4, 6, 3]
    return ResNet(Bottleneck, Bottleneck_server, stage_depths, num_classes=num_classes)

In [5]:
# Optimisation hyper-parameters.
lr = 0.001
# NOTE(review): not used by the training cell, which takes the epoch count from the client's handshake.
epochs = 1

# Loss, server-side model and optimiser.
criterion = nn.CrossEntropyLoss()
resnet_server = ResNet50(num_classes=10).to(device)
optimizer = optim.SGD(resnet_server.parameters(), lr=lr, momentum=0.9)

In [6]:
# Show the server-side model's module tree (the cell's last expression is rendered as output).
resnet_server

ResNet(
  (layer1): Sequential(
    (0): Bottleneck_server(
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): GroupNorm(64, 64, eps=1e-05, affine=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): GroupNorm(256, 256, eps=1e-05, affine=True)
      (shortcut): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): GroupNorm(256, 256, eps=1e-05, affine=True)
      )
    )
    (1): Bottleneck(
      (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): GroupNorm(64, 64, eps=1e-05, affine=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): GroupNorm(64, 64, eps=1e-05, affine=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): GroupNorm(256, 256, eps=1e-05, affine=True)
      (shortcut): Sequential()
    )
    (2

In [7]:
# Server address; override with the SERVER_HOST / SERVER_PORT environment variables.
host = os.environ.get('SERVER_HOST', '10.2.143.109')
port = int(os.environ.get('SERVER_PORT', '10081'))

s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# Allow re-running this cell right after a previous run without hitting
# "Address already in use" while the old socket is in TIME_WAIT.
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

# `with` guarantees both the listening and the connected socket are closed,
# even if training raises.
with s:
    s.bind((host, port))
    s.listen(5)

    conn, addr = s.accept()
    with conn:
        print("Connected to: ", addr)

        # Handshake: the client sends the epoch count and number of batches.
        # NOTE: recv_msg unpickles network data; only accept trusted clients.
        handshake = recv_msg(conn)
        if handshake is None:
            raise ConnectionError('client closed the connection before sending the epoch/batch header')
        rmsg, data_size = handshake

        epoch = rmsg['epoch']
        num_batch = rmsg['total_batch']

        print("received epoch: ", rmsg['epoch'], rmsg['total_batch'])

        send_msg(conn, server_name)  # send server meta information.

        # Start training
        start_time = time.time()
        print("Start training @ ", time.asctime())
        resnet_server.train()

        for epc in range(epoch):
            for i in tqdm(range(num_batch), ncols=100, desc='Training with {}'.format(server_name)):
                optimizer.zero_grad()

                received = recv_msg(conn)  # label and features from the client.
                if received is None:
                    raise ConnectionError(
                        'client disconnected during epoch {}, batch {}'.format(epc + 1, i + 1))
                msg, data_size = received

                label = torch.as_tensor(msg['label']).long().to(device)

                # Make our own leaf tensor on `device` so the gradient w.r.t. the
                # client's features is always populated. This no longer depends on
                # whether the client pickled a tensor with requires_grad=True.
                client_output_cpu = msg['client_output']
                client_output = client_output_cpu.detach().to(device).requires_grad_(True)

                output = resnet_server(client_output)
                loss = criterion(output, label)
                loss.backward()

                # Return dLoss/d(client features) to the client on the CPU.
                grad_to_client = client_output.grad.detach().cpu()
                data_size = send_msg(conn, grad_to_client)

                optimizer.step()

                if i % 100 == 0:
                    # Report batch accuracy and loss. tqdm.write keeps the progress bar intact.
                    _, predicted = torch.max(output, 1)
                    correct = (predicted == label).sum().item()
                    accuracy = correct / len(label)
                    tqdm.write(f'Epoch: {epc+1}/{epoch}, Batch: {i+1}/{num_batch}, Train Loss: {round(loss.item(), 2)} Train Accuracy: {round(accuracy, 2)}')

print('Contribution from {} is done'.format(server_name))
print('Contribution duration is: {} seconds'.format(time.time() - start_time))

Connected to:  ('10.2.143.109', 44298)
received epoch:  1 6250
Start training @  Thu Oct  5 11:10:15 2023


Training with SERVER_001:   0%|                                            | 0/6250 [00:00<?, ?it/s]

Epoch: 1/1, Batch: 1/6250, Train Loss: 2.6 Train Accuracy: 0.12
Contribution from SERVER_001 is done
Contribution duration is: 0.29991793632507324 seconds



