In [1]:
import os
import struct
import socket
import pickle
import time

from rescalenet.layers import AvgPool2d, Bias2DMean
import h5py
from tqdm import tqdm
from rescale_u_shaped import rescale18 as rescale18_server
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
from utils import get_metrics_u_shaped, get_metrics_u_shaped_client_side
from torch.utils.data import Subset

import copy

In [2]:
root_path = 'cifar10_data'

# Setup cpu
device = 'cpu'

# Setup client order
client_order = int(0)
print('Client starts from: ', client_order)
batch_size = 256
num_train_data = 5000

# Load data
from random import shuffle

transform = transforms.Compose([
    transforms.ToTensor(), 
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
])

indices = list(range(50000))

part_tr = indices[num_train_data * client_order : num_train_data * (client_order + 1)]

train_set  = torchvision.datasets.CIFAR10(root=root_path, train=True, download=True, transform=transform)
train_set_sub = Subset(train_set, part_tr)
train_loader = torch.utils.data.DataLoader(train_set_sub, batch_size=batch_size, shuffle=True, num_workers=2, drop_last=True)

test_set = torchvision.datasets.CIFAR10(root=root_path, train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2, drop_last=True)


x_train, y_train = next(iter(train_loader))
print('Train Size (x, y): ')
print(x_train.size(), '\n', y_train.size())

total_batch = len(train_loader)
print('Total Batch Number')
print(total_batch)

Client starts from:  0
Files already downloaded and verified
Files already downloaded and verified
Train Size (x, y): 
torch.Size([256, 3, 32, 32]) 
 torch.Size([256])
Total Batch Number
19


Helper functions for communication between client and server.

In [3]:
def send_msg(sock, msg):
    # prefix each message with a 4-byte length in network byte order
    msg = pickle.dumps(msg)
    msg = struct.pack('>I', len(msg)) + msg
    sock.sendall(msg)

def recv_msg(sock):
    # read message length and unpack it into an integer
    raw_msg_len = recv_all(sock, 4)
    if not raw_msg_len:
        return None
    msg_len = struct.unpack('>I', raw_msg_len)[0]
    # read the message data
    msg =  recv_all(sock, msg_len)
    msg = pickle.loads(msg)
    return msg

def recv_all(sock, n):
    # helper function to receive n bytes or return None if EOF is hit
    data = b''
    while len(data) < n:
        packet = sock.recv(n - len(data))
        if not packet:
            return None
        data += packet
    return data

Definition of client side model (input layer only)

In [4]:
class BasicBlock(nn.Module):
    expansion: int = 1

    def __init__(self,
                 inplanes: int,
                 planes: int,
                 block_idx: int,
                 max_block: int,
                 stride: int = 1,
                 groups: int = 1,
                 base_width: int = 64,
                 drop_conv=0.0) -> None:

        super(BasicBlock, self).__init__()

        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')

        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, padding=1, stride=stride, groups=groups, bias=False)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, stride=1, groups=groups, bias=False)

        self.addbias1 = Bias2DMean(inplanes)
        self.addbias2 = Bias2DMean(planes)

        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.stride = stride
        self._scale = nn.Parameter(torch.ones(1))
        multiplier = (block_idx + 1)**-(1 / 6) * max_block**(1 / 6)
        multiplier = multiplier * (1 - drop_conv)**.5

        for m in self.modules():
            if isinstance(m, nn.Conv2d) and m.weight.requires_grad:
                _, C, H, W = m.weight.shape
                stddev = (C * H * W / 2)**-.5
                nn.init.normal_(m.weight, std=stddev * multiplier)

        self.residual = max_block**-.5
        self.identity = block_idx**.5 / (block_idx + 1)**.5

        self.downsample = nn.Sequential()
        if stride != 1 or inplanes != self.expansion * planes:
            if stride == 1:
                avgpool = nn.Sequential()
            else:
                avgpool = nn.AvgPool2d(stride)

            self.downsample = nn.Sequential(avgpool, Bias2DMean(num_features=inplanes),
                                            nn.Conv2d(inplanes, self.expansion * planes, kernel_size=1, bias=False))

            nn.init.kaiming_normal_(self.downsample[2].weight, a=1)

        self.drop = nn.Sequential()
        if drop_conv > 0.0:
            self.drop = nn.Dropout2d(drop_conv)

    def forward(self, x):
        # Not adding dropout here.
        out = F.relu(self.drop(self.conv1(self.addbias1(x))))
        out = self.drop(self.conv2(self.addbias2(out)))
        out = out * self.residual * self._scale + self.identity * self.downsample(x)
        out = F.relu(out)
        return out

    def init_pass(self, x, count):
        out = F.relu(self.drop(self.conv1(self.addbias1.init_pass(x, count))))
        out = self.drop(self.conv2(self.addbias2.init_pass(out, count)))
        out = out * self.residual * self._scale + self.identity * self.downsample(x)
        out = F.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, block_idx, max_block, stride=1, groups=1, base_width=64, drop_conv=0.0):
        super(Bottleneck, self).__init__()
        width = int(planes * (base_width / 64.)) * groups
        self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False)
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, padding=1, stride=stride, groups=groups, bias=False)
        self.conv3 = nn.Conv2d(width, planes * self.expansion, kernel_size=1, bias=False)

        self.addbias1 = Bias2DMean(inplanes)
        self.addbias2 = Bias2DMean(width)
        self.addbias3 = Bias2DMean(width)

        self._scale = nn.Parameter(torch.ones(1))
        multiplier = (block_idx + 1)**-(1 / 6) * max_block**(1 / 6)
        multiplier = multiplier * (1 - drop_conv)**.5

        for m in self.modules():
            if isinstance(m, nn.Conv2d) and m.weight.requires_grad:
                _, C, H, W = m.weight.shape
                stddev = (C * H * W / 2)**-.5
                nn.init.normal_(m.weight, std=stddev * multiplier)

        self.residual = max_block**-.5
        self.identity = block_idx**.5 / (block_idx + 1)**.5

        self.downsample = nn.Sequential()
        if stride != 1 or inplanes != self.expansion * planes:
            if stride == 1:
                avgpool = nn.Sequential()
            else:
                avgpool = nn.AvgPool2d(stride)

            self.downsample = nn.Sequential(avgpool, Bias2DMean(num_features=inplanes),
                                            nn.Conv2d(inplanes, self.expansion * planes, kernel_size=1, bias=False))
            nn.init.kaiming_normal_(self.downsample[2].weight, a=1)

        self.drop = nn.Sequential()
        if drop_conv > 0.0:
            self.drop = nn.Dropout2d(drop_conv)

    def forward(self, x):
        out = F.relu(self.drop(self.conv1(self.addbias1(x))))
        out = F.relu(self.drop(self.conv2(self.addbias2(out))))
        out = self.drop(self.conv3(self.addbias3(out)))
        out = out * self.residual * self._scale + self.identity * self.downsample(x)
        out = F.relu(out)
        return out

    def init_pass(self, x, count):
        out = F.relu(self.drop(self.conv1(self.addbias1.init_pass(x, count))))
        out = F.relu(self.drop(self.conv2(self.addbias2.init_pass(out, count))))
        out = self.drop(self.conv3(self.addbias3.init_pass(out, count)))
        out = out * self.residual * self._scale + self.identity * self.downsample(x)
        out = F.relu(out)
        return out


class ReScale_First(nn.Module):
    def __init__(self,
                 layers,
                 num_classes=1000,
                 groups=1,
                 width_per_group=64,
                 drop_conv=0.0,
                 drop_fc=0.0,
                 block=Bottleneck,
                 input_shapes=(None, None),
                 num_flexible_classes=-1):
        super(ReScale_First, self).__init__()

        self.inplanes = 64
        self.num_classes = num_classes
        self.input_shapes = input_shapes
        self.groups = groups
        self.base_width = width_per_group
        self.block_idx = sum(layers) - 1
        self.max_depth = sum(layers)
        self.num_flexible_classes = num_flexible_classes

        # KT TEST SPLIT LEARNING
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.addbias1 = Bias2DMean(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # self.layer1 = self._make_layer(block, 64, layers[0])
        # self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        # self.layer3 = self._make_layer(block, 256, layers[2], stride=2, drop_conv=drop_conv)
        # self.layer4 = self._make_layer(block, 512, layers[3], stride=2, drop_conv=drop_conv)
        # self.addbias2 = Bias2DMean(512 * block.expansion)
        # self.drop = nn.Dropout(drop_fc)
        # self.fc = nn.Linear(512 * block.expansion, num_classes)

        self.mean_pool = nn.AvgPool2d((input_shapes[0] // 32, input_shapes[1] // 32))

        # KT TEST SPLIT LEARNING
        #nn.init.kaiming_normal_(self.conv1.weight)
        # nn.init.kaiming_normal_(self.fc.weight, a=1)

        if self.num_flexible_classes != -1:
            _fixed_sum_layer = torch.zeros(num_classes)
            num_unused_classes = num_classes - self.num_flexible_classes
            if num_unused_classes > 0:
                _fixed_sum_layer[self.num_flexible_classes:] = torch.ones(num_unused_classes) * -10000.0
                # initialize bias and weight of unused to 0
                # self.fc.bias.data[self.num_flexible_classes:] = 0
                # self.fc.weight.data[self.num_flexible_classes:, :] = 0

            # make the fixed_mask not trainable
            self.register_buffer("fixed_sum_layer", _fixed_sum_layer)

    def _make_layer(self, block, planes, num_blocks, stride=1, drop_conv=0.0):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(
                block(self.inplanes,
                      planes,
                      block_idx=self.block_idx,
                      max_block=self.max_depth,
                      stride=stride,
                      groups=self.groups,
                      base_width=self.base_width,
                      drop_conv=drop_conv))
            self.inplanes = planes * block.expansion
            self.block_idx += 1
        return nn.Sequential(*layers)

    def forward(self, x):
        # KT TEST SPLIT LEARNING
        x = self.conv1(x)
        x = self.addbias1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        # x = self.layer1(x)
        # x = self.layer2(x)
        # x = self.layer3(x)
        # x = self.layer4(x)
        # x = self.addbias2(x)
        #
        # x = self.mean_pool(x)
        # x = x.squeeze(-1).squeeze(-1)
        # x = self.drop(x)
        # x = self.fc(x)
        # if self.num_flexible_classes != -1:
        #     x = x + self.fixed_sum_layer

        return x


class ReScale_Last(nn.Module):
    def __init__(self,
                 layers,
                 num_classes=1000,
                 groups=1,
                 width_per_group=64,
                 drop_conv=0.0,
                 drop_fc=0.0,
                 block=Bottleneck,
                 input_shapes=(None, None),
                 num_flexible_classes=-1):
        super(ReScale_Last, self).__init__()

        self.inplanes = 64
        self.num_classes = num_classes
        self.input_shapes = input_shapes
        self.groups = groups
        self.base_width = width_per_group
        self.block_idx = sum(layers) - 1
        self.max_depth = sum(layers)
        self.num_flexible_classes = num_flexible_classes

        # KT TEST SPLIT LEARNING
        # self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        # self.addbias1 = Bias2DMean(self.inplanes)
        # self.relu = nn.ReLU(inplace=True)
        # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # self.layer1 = self._make_layer(block, 64, layers[0])
        # self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        # self.layer3 = self._make_layer(block, 256, layers[2], stride=2, drop_conv=drop_conv)
        # self.layer4 = self._make_layer(block, 512, layers[3], stride=2, drop_conv=drop_conv)
        # self.addbias2 = Bias2DMean(512 * block.expansion)
        # self.drop = nn.Dropout(drop_fc)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        self.mean_pool = nn.AvgPool2d((input_shapes[0] // 32, input_shapes[1] // 32))

        # KT TEST SPLIT LEARNING
        #nn.init.kaiming_normal_(self.conv1.weight)
        nn.init.kaiming_normal_(self.fc.weight, a=1)

        if self.num_flexible_classes != -1:
            _fixed_sum_layer = torch.zeros(num_classes)
            num_unused_classes = num_classes - self.num_flexible_classes
            if num_unused_classes > 0:
                _fixed_sum_layer[self.num_flexible_classes:] = torch.ones(num_unused_classes) * -10000.0
                # initialize bias and weight of unused to 0
                self.fc.bias.data[self.num_flexible_classes:] = 0
                self.fc.weight.data[self.num_flexible_classes:, :] = 0

            # make the fixed_mask not trainable
            self.register_buffer("fixed_sum_layer", _fixed_sum_layer)

    # def _make_layer(self, block, planes, num_blocks, stride=1, drop_conv=0.0):
    #     strides = [stride] + [1] * (num_blocks - 1)
    #     layers = []
    #     for stride in strides:
    #         layers.append(
    #             block(self.inplanes,
    #                   planes,
    #                   block_idx=self.block_idx,
    #                   max_block=self.max_depth,
    #                   stride=stride,
    #                   groups=self.groups,
    #                   base_width=self.base_width,
    #                   drop_conv=drop_conv))
    #         self.inplanes = planes * block.expansion
    #         self.block_idx += 1
    #     return nn.Sequential(*layers)

    def forward(self, x):
        # KT TEST SPLIT LEARNING
        # x = self.conv1(x)
        # x = self.addbias1(x)
        # x = self.relu(x)
        # x = self.maxpool(x)
        # x = self.layer1(x)
        # x = self.layer2(x)
        # x = self.layer3(x)
        # x = self.layer4(x)
        # x = self.addbias2(x)
        #
        # x = self.mean_pool(x)
        # x = x.squeeze(-1).squeeze(-1)
        # x = self.drop(x)
        x = self.fc(x)
        if self.num_flexible_classes != -1:
            x = x + self.fixed_sum_layer

        return x



def rescale18_first(num_classes=10, drop_conv=0.0, drop_fc=0.0, **kwargs):
    return ReScale_First([2, 2, 2, 2],
                   num_classes=num_classes,
                   drop_conv=drop_conv,
                   drop_fc=drop_fc,
                   groups=1,
                   width_per_group=64,
                   input_shapes=[32, 32],
                   block=BasicBlock,
                   **kwargs)

def rescale18_last(num_classes=10, drop_conv=0.0, drop_fc=0.0, **kwargs):
    return ReScale_Last([2, 2, 2, 2],
                   num_classes=num_classes,
                   drop_conv=drop_conv,
                   drop_fc=drop_fc,
                   groups=1,
                   width_per_group=64,
                   input_shapes=[32, 32],
                   block=BasicBlock,
                   **kwargs)



Training hyper parameters


In [5]:
resnet18_client_first = rescale18_first().to(device)
lr = 0.001
criterion = nn.CrossEntropyLoss()
optimizer_first = optim.SGD(resnet18_client_first.parameters(), lr = lr, momentum = 0.9)


resnet18_client_last = rescale18_last().to(device)
optimizer_last = optim.SGD(resnet18_client_last.parameters(), lr = lr, momentum = 0.9)

Training

In [6]:
epoch = 10
host = '10.9.240.14'
port = 8890

# host = '10.2.16.246'
# port = 18888
s1 = socket.socket()
s1.connect((host, port)) # establish connection


msg = {
    'epoch': epoch,
    'batch_size': batch_size,
    'total_batch': total_batch
}

send_msg(s1, msg) # send 'epoch' and 'batch size' to server



remote_server = recv_msg(s1) # get server's meta information.

for epc in range(epoch):
    # print("running epoch ", epc)
    resnet18_client_first.train()
    resnet18_client_last.train()

    for i, data in enumerate(tqdm(train_loader, ncols=100, desc='Training with {}'.format(remote_server), disable=True)):
        x, label = data
        x = x.to(device)
        label = label.to(device)
        optimizer_first.zero_grad()
        optimizer_last.zero_grad()


        output = resnet18_client_first(x)
        client_output_first = output.clone().detach().requires_grad_(True)
        client_output_first = client_output_first.to(torch.bfloat16)

        msg = {
            # 'label': label,
            'client_output': client_output_first
        }
        send_msg(s1, msg) # send label and output(feature) to server
        
        
        rmsg = recv_msg(s1)
        server_output_samba = rmsg['server_output']
        server_output = server_output_samba.to(device)

        output_last = resnet18_client_last(server_output)

        loss = criterion(output_last, label)
        # print(loss.item())
        loss.backward()
        
        msg = {
            'server_grad': server_output_samba.grad.clone().detach()
        }
        send_msg(s1, msg) # send server's gradient to server
        optimizer_first.step()


        rmsg = recv_msg(s1)
        client_output_first.backward(rmsg['client_output_grad'])
        optimizer_last.step()


    # rmsg = recv_msg(s1)
    # server_model_state_dict = rmsg['server model']
    # rescalenet_server = rescale18_server(num_classes=10, input_shapes=(32, 32)).to(device)
    # rescalenet_server.load_state_dict(server_model_state_dict)

    # # print(rescalenet_server.state_dict().keys())
    # # print(rescalenet_server.state_dict()['layer3.0.conv1.weight'])

    # train_loss, train_acc, train_auc, train_bal_acc = get_metrics_u_shaped(rescalenet_server, resnet18_client_first, resnet18_client_last, train_loader, criterion, device)
    # test_loss, test_acc, test_auc, test_bal_acc = get_metrics_u_shaped(rescalenet_server, resnet18_client_first, resnet18_client_last, test_loader, criterion, device)

    print(f"Epoch: {epc+1}/{epoch}")
    train_loss, train_acc, train_auc, train_bal_acc = get_metrics_u_shaped_client_side(resnet18_client_first, resnet18_client_last, train_loader, send_msg, recv_msg, s1, criterion, device)
    test_loss, test_acc, test_auc, test_bal_acc = get_metrics_u_shaped_client_side(resnet18_client_first, resnet18_client_last, test_loader, send_msg, recv_msg, s1, criterion, device)
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {round(train_acc, 4)}, Train AUC: {train_auc:.4f}, Train Bal Acc: {train_bal_acc:.4f}")
    print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}, Test AUC: {test_auc:.4f}, Test Bal Acc: {test_bal_acc:.4f}")
    print()    
     
s1.close()


Epoch: 0/10
Train Loss: 0.0004, Train Acc: 0.2208, Train AUC: 0.7158, Train Bal Acc: 0.2185
Test Loss: 0.0002, Test Acc: 0.2073, Test AUC: 0.7058, Test Bal Acc: 0.2073

Epoch: 1/10
Train Loss: 0.0004, Train Acc: 0.286, Train AUC: 0.7501, Train Bal Acc: 0.2852
Test Loss: 0.0002, Test Acc: 0.2795, Test AUC: 0.7441, Test Bal Acc: 0.2795

Epoch: 2/10
Train Loss: 0.0004, Train Acc: 0.2911, Train AUC: 0.7763, Train Bal Acc: 0.2897
Test Loss: 0.0002, Test Acc: 0.2830, Test AUC: 0.7701, Test Bal Acc: 0.2829

Epoch: 3/10
Train Loss: 0.0004, Train Acc: 0.3363, Train AUC: 0.8019, Train Bal Acc: 0.3388
Test Loss: 0.0002, Test Acc: 0.3279, Test AUC: 0.7938, Test Bal Acc: 0.3280

Epoch: 4/10
Train Loss: 0.0004, Train Acc: 0.3565, Train AUC: 0.8210, Train Bal Acc: 0.3570
Test Loss: 0.0002, Test Acc: 0.3381, Test AUC: 0.8119, Test Bal Acc: 0.3382

Epoch: 5/10
Train Loss: 0.0003, Train Acc: 0.3705, Train AUC: 0.8289, Train Bal Acc: 0.3720
Test Loss: 0.0002, Test Acc: 0.3538, Test AUC: 0.8161, Test Bal 