In [6]:
from pathlib import Path
import socket
import struct
import pickle
import time

import numpy as np
import pandas as pd
import h5py
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam, SGD

In [7]:
# The notebook runs from a sub-directory of the repository, so the project
# root is one level above the current working directory. The walrus form
# both binds `project_path` and displays it as the cell's output.
(project_path := Path.cwd().parents[0])

PosixPath('/home/dk/Desktop/split-learning-1D-HE')

## Connect to the client

In [2]:
def send_msg(sock, msg):
    """Send ``msg`` (bytes) over ``sock`` as a single length-prefixed frame.

    The frame is a 4-byte unsigned big-endian length header followed by the
    payload itself, i.e. the layout ``recv_msg`` unpacks on the other side.
    """
    header = struct.pack('>I', len(msg))
    sock.sendall(header + msg)

def recv_msg(sock):
    """Receive one length-prefixed message from ``sock``.

    Reads the 4-byte big-endian length header, then exactly that many payload
    bytes. Returns None if the connection closes before the header has been
    fully read; otherwise returns whatever ``recvall`` yields for the payload
    (None if the connection closes mid-payload).
    """
    header = recvall(sock, 4)
    if not header:
        return None
    (payload_len,) = struct.unpack('>I', header)
    return recvall(sock, payload_len)

def recvall(sock, n):
    """Receive exactly ``n`` bytes from ``sock``.

    Returns the bytes read, or None if the peer closes the connection before
    ``n`` bytes have arrived (any partially received data is discarded).
    ``n == 0`` returns ``b''`` without touching the socket.
    """
    # Accumulate into a bytearray: `bytes += chunk` re-copies the whole buffer
    # for every chunk, which is quadratic in the message size.
    buf = bytearray()
    while len(buf) < n:
        packet = sock.recv(n - len(buf))
        if not packet:  # empty read => peer closed the connection
            return None
        buf += packet
    return bytes(buf)

In [4]:
host = 'localhost'
port = 10080
max_recv = 4096  # NOTE(review): not referenced by any code visible in this notebook

s = socket.socket()
# Allow re-binding the port right after a previous run / kernel restart instead
# of failing with "Address already in use" while the old socket is in TIME_WAIT.
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.bind((host, port))
s.listen(5)
conn, addr = s.accept()  # blocks until the client connects
print('Connected with', addr)

Conntected with ('127.0.0.1', 43114)


## The model on the server side

In [8]:
seed = 0

# Seed NumPy's global RNG and PyTorch's CPU, current-GPU and all-GPU generators.
for seed_fn in (np.random.seed,
                torch.manual_seed,
                torch.cuda.manual_seed,
                torch.cuda.manual_seed_all):
    seed_fn(seed)

# Prefer deterministic cuDNN kernels over benchmark-selected ones.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [5]:
class EcgServer(nn.Module):
    """Server-side part of the split model: a single linear classification head.

    Args:
        in_features: length of the activation vector received from the client
            (default 256, the size this notebook's checkpoint uses).
        num_classes: number of output logits (default 5).
    """

    def __init__(self, in_features=256, num_classes=5):
        super().__init__()
        self.linear = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Map activations ``x`` of shape (..., in_features) to (..., num_classes) logits."""
        return self.linear(x)

In [13]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
    # get_device_name raises on CPU-only machines, so only query it with CUDA.
    print(torch.cuda.get_device_name(device))

ecg_server = EcgServer()

# map_location lets a checkpoint that was saved from GPU tensors load on a
# CPU-only machine; the model is moved to `device` afterwards.
checkpoint = torch.load(project_path / "weights/init_weight_256.pth", map_location="cpu")
# Copy only the server-side linear layer's parameters out of the checkpoint;
# load_state_dict copies in place and rejects mismatched shapes.
ecg_server.linear.load_state_dict({
    "weight": checkpoint["linear.weight"],
    "bias": checkpoint["linear.bias"],
})

ecg_server.to(device)

EcgServer(
  (linear): Linear(in_features=256, out_features=5, bias=True)
)

In [None]:
# Batch settings. batch_size is not used by the server code visible here
# (presumably it mirrors the client's loader — TODO confirm); total_batch is
# the number of messages the server loops over per train/test pass.
batch_size = 32
total_batch = 414

# Optimisation settings for the server-side head.
epoch = 10
criterion = nn.CrossEntropyLoss()
lr = 0.001
optimizer = Adam(params=ecg_server.parameters(), lr=lr)

## Training process

In [None]:
train_losses = list()
train_accs = list()
test_losses = list()
test_accs = list()
best_test_acc = 0  # best test accuracy


def recv_client_batch(sock):
    """Receive one pickled ``{'client_output': ..., 'label': ...}`` message.

    Returns ``(client_output, label)`` as sent by the client.
    Raises ConnectionError if the client disconnected, instead of letting
    ``pickle.loads(None)`` fail with an opaque TypeError.
    """
    raw = recv_msg(sock)
    if raw is None:
        raise ConnectionError("client closed the connection")
    # NOTE(review): pickle.loads can execute arbitrary code from the peer; this
    # is only acceptable because the client is a trusted local process.
    payload = pickle.loads(raw)
    return payload['client_output'], payload['label']


for e in range(epoch):
    print("Epoch {} - ".format(e+1), end='')

    ecg_server.train()
    train_loss = 0.0
    correct, total = 0, 0
    for _ in range(total_batch):
        optimizer.zero_grad()  # initialize all gradients to zero
        # NOTE(review): assumes training batches use the same dict format as the
        # test batches below (activation + label) — confirm against the client.
        client_output, label = recv_client_batch(conn)
        client_device = client_output.device
        # Make the activation a grad-requiring leaf on `device` so that
        # backward() populates client_output.grad, which the client needs to
        # continue back-propagation through its own layers.
        client_output = client_output.to(device).detach().requires_grad_(True)
        label = label.to(device)

        output = ecg_server(client_output)  # forward propagation
        loss = criterion(output, label)
        loss.backward()  # backward propagation

        # Send dLoss/d(client_output) back on the device the activation came from.
        client_grad = client_output.grad.clone().detach().to(client_device)
        send_msg(conn, pickle.dumps(client_grad))
        optimizer.step()

        train_loss += loss.item()
        correct += torch.sum(output.argmax(dim=1) == label).item()
        total += len(label)
    train_losses.append(train_loss / total_batch)
    train_accs.append(correct / total)
    train_status = "loss: {:.4f}, acc: {:.2f}% / ".format(train_losses[-1], train_accs[-1]*100)
    print(train_status, end='')

    ecg_server.eval()
    with torch.no_grad():  # calculate test accuracy
        test_loss = 0.0
        correct, total = 0, 0
        # NOTE(review): the test pass also expects exactly total_batch messages;
        # if the client's test set has a different batch count the protocol desyncs.
        for _ in range(total_batch):
            client_output, test_label = recv_client_batch(conn)
            client_output, test_label = client_output.to(device), test_label.to(device)
            test_output = ecg_server(client_output)
            loss = criterion(test_output, test_label)

            test_loss += loss.item()
            correct += torch.sum(test_output.argmax(dim=1) == test_label).item()
            total += len(test_label)
        test_losses.append(test_loss / total_batch)
        test_accs.append(correct / total)
        test_status = "test_loss: {:.4f}, test_acc: {:.2f}%".format(test_losses[-1], test_accs[-1]*100)
        print(test_status)

    if test_accs[-1] > best_test_acc:
        best_test_acc = test_accs[-1]

    # Report this epoch's metrics back to the client.
    msg = pickle.dumps(train_status + test_status)
    send_msg(conn, msg)