In [1]:
import os
import struct
import socket
import pickle
import time

import h5py
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim

from torch.utils.data import Subset
from torch.autograd import Variable
import torch.nn.init as init

import copy

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
root_path = '../cifar10_data'

# Setup cpu; seed torch so data shuffling and weight initialisation are reproducible.
device = 'cpu'
torch.manual_seed(777)

# Setup client order: index of this client's data shard (override with the CLIENT_ORDER env var).
client_order = int(os.environ.get('CLIENT_ORDER', 0))
print('Client starts from: ', client_order)

# Number of training samples in each client's shard.
num_train_data = 50000

# Load data
from random import shuffle  # NOTE(review): `shuffle` is not used by any visible cell

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
])

train_set = torchvision.datasets.CIFAR10(root=root_path, train=True, download=True, transform=transform)

# This client's contiguous shard. The index range is derived from the dataset
# length instead of a hard-coded 50000.
indices = list(range(len(train_set)))
part_tr = indices[num_train_data * client_order : num_train_data * (client_order + 1)]
if not part_tr:
    # Fail early with a clear message instead of a StopIteration from the peek below.
    raise ValueError(
        'client_order={} selects no samples: the dataset has {} samples and each shard holds {}'.format(
            client_order, len(train_set), num_train_data))
train_set_sub = Subset(train_set, part_tr)

train_loader = torch.utils.data.DataLoader(train_set_sub, batch_size=4, shuffle=True, num_workers=2)

# Peek at one batch to report the tensor shapes.
x_train, y_train = next(iter(train_loader))
print('Train Size (x, y): ')
print(x_train.size(), '\n', y_train.size())

# Number of mini-batches per epoch (len of the DataLoader), not the batch size;
# this is the value sent to the server as 'total_batch'.
total_batch = len(train_loader)
print('Total Batch Size')
print(total_batch)

Client starts from:  0
Files already downloaded and verified
Train Size (x, y): 
torch.Size([4, 3, 32, 32]) 
 torch.Size([4])
Total Batch Size
12500


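Helper functions for communication between client and server. Every message is pickled and prefixed with its length as a 4-byte big-endian unsigned integer, so the receiver knows exactly how many bytes to read.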

In [3]:
def send_msg(sock, msg):
    """Pickle ``msg`` and send it over ``sock`` as one length-prefixed frame.

    The frame is a 4-byte big-endian unsigned length followed by the pickled
    payload, written with a single ``sendall`` call.
    """
    payload = pickle.dumps(msg)
    header = struct.pack('>I', len(payload))
    sock.sendall(header + payload)

def recv_msg(sock):
    """Receive one length-prefixed pickled message from ``sock``.

    Reads a 4-byte big-endian length header, then exactly that many payload
    bytes, and unpickles the payload.

    Returns:
        The unpickled object, or None if the connection is closed before a
        complete message (header or body) has been received.

    WARNING: ``pickle.loads`` can execute arbitrary code embedded in the
    payload; only use this with a trusted peer.
    """
    raw_msg_len = recv_all(sock, 4)
    if not raw_msg_len:
        return None
    (msg_len,) = struct.unpack('>I', raw_msg_len)
    payload = recv_all(sock, msg_len)
    if payload is None:
        # Peer closed mid-message; previously None reached pickle.loads and raised TypeError.
        return None
    return pickle.loads(payload)

def recv_all(sock, n):
    """Read exactly ``n`` bytes from ``sock``.

    Returns:
        The ``n`` bytes as ``bytes``, or None if the peer closes the
        connection (``recv`` returns empty) before ``n`` bytes arrive.
    """
    # bytearray.extend is amortised O(1); repeated bytes += would copy the
    # whole buffer on every chunk.
    buf = bytearray()
    while len(buf) < n:
        chunk = sock.recv(n - len(buf))
        if not chunk:
            return None
        buf.extend(chunk)
    return bytes(buf)

Definition of the client-side model (input layer only: the first convolution, BatchNorm and ReLU; the remaining layers are handled by the server).

In [4]:
def _weights_init(m):
    classname = m.__class__.__name__
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        init.kaiming_normal_(m.weight)

class LambdaLayer(nn.Module):
    """Wrap an arbitrary callable as an ``nn.Module``.

    ``forward`` simply applies the stored callable to its input.
    """

    def __init__(self, lambd):
        super().__init__()
        # The wrapped callable (attribute name kept for compatibility).
        self.lambd = lambd

    def forward(self, x):
        fn = self.lambd
        return fn(x)

class BasicBlock(nn.Module):
    """Truncated residual block: one 3x3 conv -> BatchNorm -> ReLU.

    Unlike a typical ResNet basic block there is no second convolution and no
    shortcut connection. ``option`` is accepted for signature compatibility
    but is ignored.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option='A'):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(planes)

    def forward(self, x):
        y = self.conv1(x)
        y = self.bn1(y)
        return F.relu(y)

class ResNet(nn.Module):
    """Client-side front of the ResNet: a 3->16 stem conv, BatchNorm and ReLU.

    ``block`` and ``num_blocks`` are only consumed by ``_make_layer``, whose
    call is currently commented out, so they do not affect the model.
    ``num_classes`` is unused.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        self.in_planes = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)

        # self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)

        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first uses ``stride``.

        Side effect: advances ``self.in_planes`` to ``planes * block.expansion``.
        """
        layers = []
        for s in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_planes, planes, s))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        h = self.conv1(x)
        h = F.relu(self.bn1(h))
        # h = self.layer1(h)  # disabled together with self.layer1 in __init__
        return h

def resnet20():
    """Build the client-side ``ResNet`` using ``BasicBlock``.

    The per-stage block counts ``[1, 3, 3]`` are passed through, but ResNet
    currently does not build any stage layers from them.
    """
    return ResNet(block=BasicBlock, num_blocks=[1, 3, 3])

Training hyperparameters


In [5]:
# Hyper-parameters
epoch = 1
lr = 0.001
criterion = nn.CrossEntropyLoss()  # NOTE(review): not referenced by the training cell; loss appears to be computed server-side

# Client-side model and the optimizer over its parameters
resnet20_client = resnet20().to(device)
optimizer = optim.SGD(resnet20_client.parameters(), lr=lr, momentum=0.9)

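Training: for each batch the client sends its output features and labels to the server, receives the gradient of its output back, and finishes back-propagation through the client layers locally.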

In [6]:

# Server address; override with the SPLIT_SERVER_HOST / SPLIT_SERVER_PORT env vars.
# Previously used hosts: '10.2.144.188', '10.9.240.14'
host = os.environ.get('SPLIT_SERVER_HOST', '10.2.143.109')
port = int(os.environ.get('SPLIT_SERVER_PORT', 10081))
# `epoch` comes from the hyper-parameter cell (no silent override here): it is
# sent to the server and also drives the epoch loop below, so both sides agree.

client_weight = copy.deepcopy(resnet20_client.state_dict())  # snapshot of the initial client weights

s1 = socket.socket()
try:
    s1.connect((host, port))  # establish connection

    msg = {
        'epoch': epoch,
        'total_batch': total_batch
    }
    send_msg(s1, msg)  # send 'epoch' and number of batches per epoch to server

    # The client layers are trained here, so use train() mode. In eval() mode
    # BatchNorm would normalise with its initial running mean/var and never
    # update them while the optimizer keeps stepping.
    resnet20_client.train()

    remote_server = recv_msg(s1)  # server's meta information (used as the progress-bar label)

    for epc in range(epoch):
        print("running epoch ", epc)

        for x, label in tqdm(train_loader, ncols=100, desc='Training with {}'.format(remote_server)):
            x = x.to(device)
            label = label.to(device)
            optimizer.zero_grad()

            # Client forward pass. A detached copy is sent so the server builds its own graph.
            output = resnet20_client(x)
            client_output = output.clone().detach().requires_grad_(True)

            msg = {
                'label': label,
                'client_output': client_output
            }
            send_msg(s1, msg)  # send label and output (features) to server

            # Gradient returned after the server's back-propagation
            # (presumably d(loss)/d(client_output), same shape as output).
            client_grad = recv_msg(s1)
            if client_grad is None:
                raise ConnectionError('server closed the connection before sending gradients')

            output.backward(client_grad)  # continue back-propagation through the client layers
            optimizer.step()
finally:
    # Always release the socket, even if training is interrupted or fails.
    s1.close()


running epoch  0


Training with SERVER_001:   0%|                                           | 0/12500 [00:00<?, ?it/s]

Training with SERVER_001: 100%|███████████████████████████████| 12500/12500 [04:44<00:00, 43.93it/s]
