# ECG Split 1D-CNN Client Side

This code is the server part of ECG split 1D-CNN model for **single** client and a server.

## Import required packages

In [47]:
from typing import List
import os
import struct
import socket
import pickle
import time
from pathlib import Path
from icecream import ic

import h5py
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam, SGD

import tenseal as ts
from tenseal.tensors.ckksvector import CKKSVector
from tenseal.enc_context import Context

print(f'torch version: {torch.__version__}')
print(f'tenseal version: {ts.__version__}')

torch version: 1.8.1+cu102
tenseal version: 0.3.5


## Define ECG dataset class

In [2]:
project_path = Path.cwd().parent.parent
print(f'project_path: {project_path}')

project_path: /home/dk/Desktop/split-learning-1D-HE


In [31]:
# paths to files and directories
data_dir = 'data'  # used to be 'mitdb'
train_name = 'train_ecg.hdf5'
test_name = 'test_ecg.hdf5'

In [8]:
class ECG(Dataset):
    def __init__(self, train=True):
        if train:
            with h5py.File(project_path/data_dir/train_name, 'r') as hdf:
                self.x = hdf['x_train'][:]
                self.y = hdf['y_train'][:]
        else:
            with h5py.File(project_path/data_dir/test_name, 'r') as hdf:
                self.x = hdf['x_test'][:]
                self.y = hdf['y_test'][:]
    
    def __len__(self):
        return len(self.x)
    
    def __getitem__(self, idx):
        return torch.tensor(self.x[idx], dtype=torch.float), torch.tensor(self.y[idx])

### Set batch size

In [9]:
batch_size = 32

## Make train and test dataset batch generator

In [10]:
train_dataset = ECG(train=True)
test_dataset = ECG(train=False)
train_loader = DataLoader(train_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

In [12]:
for i, b in enumerate(train_loader):
    ic(b[0].shape, b[1])
    if i == 1: break

ic| b[0].shape: torch.Size([32, 1, 128])
    b[1]: tensor([2, 4, 2, 4, 2, 1, 1, 1, 2, 1, 0, 0, 2, 4, 3, 1, 2, 4, 1, 0, 1, 0, 4, 2,
                  4, 3, 0, 1, 2, 4, 4, 2])
ic| b[0].shape: torch.Size([32, 1, 128])
    b[1]: tensor([0, 4, 3, 4, 2, 0, 1, 4, 1, 0, 3, 4, 2, 4, 1, 0, 2, 0, 0, 2, 0, 1, 1, 0,
                  2, 4, 4, 2, 4, 4, 1, 4])


### Total number of batches

In [13]:
total_batch = len(train_loader)
print(total_batch)

414


## Making TenSeal context

The client makes the context to homomorphically encrypt and decrypt the data

In [48]:
def context(poly_modulus_degree: int, 
            coeff_mod_bit_sizes: List[int], 
            glob_scale: int) -> Context:
    context = ts.context(
        ts.SCHEME_TYPE.CKKS, 
        poly_modulus_degree=poly_modulus_degree, 
        coeff_mod_bit_sizes=coeff_mod_bit_sizes
    )
    context.global_scale = glob_scale
    context.generate_galois_keys()
    return context

context = context(8192, [60, 40, 40, 60], pow(2, 40))

## Define ECG client model
Client side has only **2 convolutional layers**.

In [18]:
class EcgClient(nn.Module):
    # will be sent to the client
    def __init__(self, context: Context):
        super(EcgClient, self).__init__()
        self.conv1 = self.torch_ref.nn.Conv1d(1, 16, 7, padding=3)  # 128 x 16
        self.relu1 = self.torch_ref.nn.LeakyReLU()
        self.pool1 = self.torch_ref.nn.MaxPool1d(2)  # 64 x 16
        self.conv2 = self.torch_ref.nn.Conv1d(16, 16, 5, padding=2)  # 64 x 16
        self.relu2 = self.torch_ref.nn.LeakyReLU()
        self.pool2 = self.torch_ref.nn.MaxPool1d(2)  # 32 x 16
        
        self.load_init_weights()
        self.context = context
    
    def forward(self, x: torch.Tensor):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool2(x)
        x = x.view(-1, 32 * 16)
        # x is a syft's TensorPointer. x.get() returns torch.Tensor of size [1, 512]
        enc_x: CKKSTensor = self.encrypt_activations(x)  # enc_x is a list of 512 elements
        return enc_x
    
    def load_init_weights(self):
        checkpoint = torch.load("init_weight.pth")
        self.conv1.weight.data = checkpoint["conv1.weight"]
        self.conv1.bias.data = checkpoint["conv1.bias"]
        self.conv2.weight.data = checkpoint["conv2.weight"]
        self.conv2.bias.data = checkpoint["conv2.bias"]
    
    def encrypt_activations(self, x: torch.Tensor):
        enc_x: CKKSVector = ts.ckks_vector(self.context, x.tolist()[0])
        return enc_x

ecg_client = EcgClient(torch_ref=torch, context=context)        

In [19]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
torch.cuda.get_device_name(0)

'NVIDIA GeForce GTX 1070 Ti'

### Set random seed

In [20]:
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

### Assign intial weight as same as non-split model

In [21]:
ecg_client = EcgClient()
checkpoint = torch.load("init_weight.pth")
ecg_client.conv1.weight.data = checkpoint["conv1.weight"]
ecg_client.conv1.bias.data = checkpoint["conv1.bias"]
ecg_client.conv2.weight.data = checkpoint["conv2.weight"]
ecg_client.conv2.bias.data = checkpoint["conv2.bias"]
ecg_client.to(device)

EcgClient(
  (conv1): Conv1d(1, 16, kernel_size=(7,), stride=(1,), padding=(3,))
  (relu1): LeakyReLU(negative_slope=0.01)
  (pool1): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv1d(16, 16, kernel_size=(5,), stride=(1,), padding=(2,))
  (relu2): LeakyReLU(negative_slope=0.01)
  (pool2): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)

### Set other hyperparameters in the model
Hyperparameters here should be same with the server side.

In [22]:
epoch = 400
criterion = nn.CrossEntropyLoss()
lr = 0.001
optimizer = Adam(ecg_client.parameters(), lr=lr)

## Socket initialization

### Required socket functions

In [23]:
def send_msg(sock, msg):
    # prefix each message with a 4-byte length in network byte order
    msg = struct.pack('>I', len(msg)) + msg
    sock.sendall(msg)

def recv_msg(sock):
    # read message length and unpack it into an integer
    raw_msglen = recvall(sock, 4)
    if not raw_msglen:
        return None
    msglen = struct.unpack('>I', raw_msglen)[0]
    # read the message data
    return recvall(sock, msglen)

def recvall(sock, n):
    # helper function to receive n bytes or return None if EOF is hit
    data = b''
    while len(data) < n:
        packet = sock.recv(n - len(data))
        if not packet:
            return None
        data += packet
    return data

### Set host address and port number

In [24]:
host = 'localhost'
port = 10080
max_recv = 4096

### Open the client socket

In [26]:
s = socket.socket()
s.connect((host, port))

In [27]:
type(s)

socket.socket

## Real training process

In [43]:
dry_run = True # break after 2 batches for 2 epoch

train_losses = list()
train_accs = list()
test_losses = list()
test_accs = list()
best_test_acc = 0  # best test accuracy

for e in range(epoch):
    print("Epoch {} - ".format(e+1), end='')
    
    # training loop
    for i, batch in enumerate(train_loader):
        x, y_gt = batch
        x, y_gt = x.to(device), y_gt.to(device)
        optimizer.zero_grad()
        y_hat = ecg_client(x)
        
        ic(y_hat)
        if dry_run:
            if i == 1: break
    
    if dry_run:
        if e == 2: break

Epoch 1 - 

ic| y_hat: tensor([[0.1132, 0.1128, 0.0665,  ..., 0.0355, 0.0208, 0.0751],
                   [0.1149, 0.1061, 0.0750,  ..., 0.0932, 0.0821, 0.1289],
                   [0.1143, 0.1103, 0.0829,  ..., 0.0584, 0.0451, 0.0943],
                   ...,
                   [0.1159, 0.1033, 0.0660,  ..., 0.0716, 0.0796, 0.1377],
                   [0.1165, 0.1005, 0.0622,  ..., 0.1024, 0.0983, 0.1450],
                   [0.1145, 0.1098, 0.0829,  ..., 0.0214, 0.0144, 0.0655]],
                  device='cuda:0', grad_fn=<ViewBackward>)
ic| y_hat: tensor([[ 1.1821e-01,  9.4653e-02,  5.2847e-02,  ...,  5.0661e-02,
                     3.8753e-02,  9.4769e-02],
                   [ 1.2001e-01,  1.0305e-01,  7.0538e-02,  ..., -1.1005e-05,
                    -1.2015e-04,  2.3481e-02],
                   [ 1.1315e-01,  1.1298e-01,  8.3798e-02,  ...,  4.6220e-02,
                     4.3631e-02,  1.0562e-01],
                   ...,
                   [ 1.1760e-01,  1.1810e-01,  9.9960e-02,  ...,  9

Epoch 2 - 

ic| y_hat: tensor([[0.1132, 0.1128, 0.0665,  ..., 0.0355, 0.0208, 0.0751],
                   [0.1149, 0.1061, 0.0750,  ..., 0.0932, 0.0821, 0.1289],
                   [0.1143, 0.1103, 0.0829,  ..., 0.0584, 0.0451, 0.0943],
                   ...,
                   [0.1159, 0.1033, 0.0660,  ..., 0.0716, 0.0796, 0.1377],
                   [0.1165, 0.1005, 0.0622,  ..., 0.1024, 0.0983, 0.1450],
                   [0.1145, 0.1098, 0.0829,  ..., 0.0214, 0.0144, 0.0655]],
                  device='cuda:0', grad_fn=<ViewBackward>)
ic| y_hat: tensor([[ 1.1821e-01,  9.4653e-02,  5.2847e-02,  ...,  5.0661e-02,
                     3.8753e-02,  9.4769e-02],
                   [ 1.2001e-01,  1.0305e-01,  7.0538e-02,  ..., -1.1005e-05,
                    -1.2015e-04,  2.3481e-02],
                   [ 1.1315e-01,  1.1298e-01,  8.3798e-02,  ...,  4.6220e-02,
                     4.3631e-02,  1.0562e-01],
                   ...,
                   [ 1.1760e-01,  1.1810e-01,  9.9960e-02,  ...,  9

Epoch 3 - 

ic| y_hat: tensor([[0.1132, 0.1128, 0.0665,  ..., 0.0355, 0.0208, 0.0751],
                   [0.1149, 0.1061, 0.0750,  ..., 0.0932, 0.0821, 0.1289],
                   [0.1143, 0.1103, 0.0829,  ..., 0.0584, 0.0451, 0.0943],
                   ...,
                   [0.1159, 0.1033, 0.0660,  ..., 0.0716, 0.0796, 0.1377],
                   [0.1165, 0.1005, 0.0622,  ..., 0.1024, 0.0983, 0.1450],
                   [0.1145, 0.1098, 0.0829,  ..., 0.0214, 0.0144, 0.0655]],
                  device='cuda:0', grad_fn=<ViewBackward>)
ic| y_hat: tensor([[ 1.1821e-01,  9.4653e-02,  5.2847e-02,  ...,  5.0661e-02,
                     3.8753e-02,  9.4769e-02],
                   [ 1.2001e-01,  1.0305e-01,  7.0538e-02,  ..., -1.1005e-05,
                    -1.2015e-04,  2.3481e-02],
                   [ 1.1315e-01,  1.1298e-01,  8.3798e-02,  ...,  4.6220e-02,
                     4.3631e-02,  1.0562e-01],
                   ...,
                   [ 1.1760e-01,  1.1810e-01,  9.9960e-02,  ...,  9

In [15]:
print('Finished Training!')
print('Result is on the server side.')

Finished Training!
Result is on the server side.
