In [1]:
import os
import struct
import socket
import pickle
import time

import h5py
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim

from torch.utils.data import Subset
from torch.autograd import Variable
import torch.nn.init as init

import copy

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
root_path = '../../models/cifar10_data'

# Setup cpu
device = 'cpu'
torch.manual_seed(777)

# Setup client order
client_order = int(0)
print('Client starts from: ', client_order)

num_train_data = 50000

# Load data
from random import shuffle

transform = transforms.Compose([
    transforms.ToTensor(), 
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
])

indices = list(range(50000))

part_tr = indices[num_train_data * client_order : num_train_data * (client_order + 1)]

train_set  = torchvision.datasets.CIFAR10(root=root_path, train=True, download=True, transform=transform)
train_set_sub = Subset(train_set, part_tr)
train_loader = torch.utils.data.DataLoader(train_set_sub, batch_size=8, shuffle=True, num_workers=2)

test_set = torchvision.datasets.CIFAR10(root=root_path, train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=8, shuffle=False, num_workers=2)

x_train, y_train = next(iter(train_loader))
print(f'Train batch shape x: {x_train.size()} y: {y_train.size()}')
total_batch = len(train_loader)
print(f'Num Batch {total_batch}')

Client starts from:  0
Files already downloaded and verified
Files already downloaded and verified
Train batch shape x: torch.Size([8, 3, 32, 32]) y: torch.Size([8])
Num Batch 6250


Helper functions for communication between client and server.

In [3]:
def send_msg(sock, msg):
    # prefix each message with a 4-byte length in network byte order
    msg = pickle.dumps(msg)
    msg = struct.pack('>I', len(msg)) + msg
    sock.sendall(msg)

def recv_msg(sock):
    # read message length and unpack it into an integer
    raw_msg_len = recv_all(sock, 4)
    if not raw_msg_len:
        return None
    msg_len = struct.unpack('>I', raw_msg_len)[0]
    # read the message data
    msg =  recv_all(sock, msg_len)
    msg = pickle.loads(msg)
    return msg

def recv_all(sock, n):
    # helper function to receive n bytes or return None if EOF is hit
    data = b''
    while len(data) < n:
        packet = sock.recv(n - len(data))
        if not packet:
            return None
        data += packet
    return data

Definition of client side model (input layer only)

In [4]:
''' ResNet '''

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, norm='instancenorm'):
        super(BasicBlock, self).__init__()
        self.norm = norm
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.GroupNorm(self.expansion*planes, self.expansion*planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1, norm='instancenorm'):
        super(Bottleneck, self).__init__()
        self.norm = norm
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
        self.bn3 = nn.GroupNorm(self.expansion*planes, self.expansion*planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion*planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.GroupNorm(self.expansion*planes, self.expansion*planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    # def __init__(self, block, num_blocks, channel=3, num_classes=10, norm='instancenorm'):
    def __init__(self, channel=3, norm='instancenorm'):
        super(ResNet, self).__init__()
        self.in_planes = 64
        self.norm = norm

        self.conv1 = nn.Conv2d(channel, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.GroupNorm(64, 64, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(64)
        # self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        # self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        # self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        # self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        # self.classifier = nn.Linear(512*block.expansion, num_classes)

    # def _make_layer(self, block, planes, num_blocks, stride):
    #     strides = [stride] + [1]*(num_blocks-1)
    #     layers = []
    #     for stride in strides:
    #         layers.append(block(self.in_planes, planes, stride, self.norm))
    #         self.in_planes = planes * block.expansion
    #     return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        # out = self.layer1(out)
        # out = self.layer2(out)
        # out = self.layer3(out)
        # out = self.layer4(out)
        # out = F.avg_pool2d(out, 4)
        # out = out.view(out.size(0), -1)
        # out = self.classifier(out)
        return out

    # def embed(self, x):
    #     out = F.relu(self.bn1(self.conv1(x)))
    #     out = self.layer1(out)
    #     out = self.layer2(out)
    #     out = self.layer3(out)
    #     out = self.layer4(out)
    #     out = F.avg_pool2d(out, 4)
    #     out = out.view(out.size(0), -1)
    #     return out

def ResNet50(channel):
    return ResNet(channel=channel)



Training hyper parameters


In [5]:
resnet_client = ResNet50(channel=3).to(device) # parameters depend on the dataset

epoch = 1
lr = 0.001
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(resnet_client.parameters(), lr = lr, momentum = 0.9)

In [6]:
resnet_client

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): GroupNorm(64, 64, eps=1e-05, affine=True)
)

Training 

In [7]:

# host = '10.2.144.188'
# host = '10.9.240.14'
host = '10.2.143.109'
port = 10081
epoch = 1

s1 = socket.socket()
s1.connect((host, port)) # establish connection
# s1.close()

client_weight = copy.deepcopy(resnet_client.state_dict()) # init weight of client model.

msg = {
    'epoch': epoch,
    'total_batch': total_batch
}

send_msg(s1, msg) # send 'epoch' and 'batch size' to server

# resnet_client.eval() # Why eval()?

remote_server = recv_msg(s1) # get server's meta information.

for epc in range(1):
    print("running epoch ", epc)

    target = 0

    for i, data in enumerate(tqdm(train_loader, ncols=100, desc='Training with {}'.format(remote_server))):
        x, label = data
        x = x.to(device)
        label = label.to(device)
        optimizer.zero_grad()


        output = resnet_client(x)
        client_output = output.clone().detach().requires_grad_(True)

        msg = {
            'label': label,
            'client_output': client_output
        }
        send_msg(s1, msg) # send label and output(feature) to server
        
        client_grad = recv_msg(s1) # receive gradaint after the server has completed the back propagation.

        output.backward(client_grad) # continue back propagation for client side layers.
        optimizer.step()


        if (i + 1) % 1000 == 0:
            msg = {
                'num_batch': len(test_loader),
                'dataset_size': len(test_loader.dataset)
            }
            print(msg)
            send_msg(s1, msg) # 'num test batch' to server
            resnet_client.eval()
            with torch.no_grad():
                for x, label in test_loader:
                    x = x.to(device)
                    label = label.to(device)
                    output = resnet_client(x)
                    client_output = output.clone().detach().requires_grad_(True)
                    msg = {
                        'label': label,
                        'client_output': client_output
                    }
                    send_msg(s1, msg) # send label and output(feature) to server


     
s1.close()


running epoch  0


Training with SERVER_001:  16%|█████▍                            | 999/6250 [03:51<17:26,  5.02it/s]

{'num_batch': 1250, 'dataset_size': 10000}


Training with SERVER_001:  32%|██████████▌                      | 1999/6250 [09:38<19:51,  3.57it/s]

{'num_batch': 1250, 'dataset_size': 10000}


Training with SERVER_001:  41%|█████████████▍                   | 2554/6250 [16:29<23:51,  2.58it/s]


RuntimeError: grad can be implicitly created only for scalar outputs