# ECG Split 1D-CNN Client Side

This notebook is the client part of the ECG split 1D-CNN model for a **single** client and a server.

## Import required packages

In [21]:
import os
import struct
import socket
import pickle
import time

import h5py
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam, SGD

## Define ECG dataset class

In [22]:
class ECG(Dataset):
    """ECG beat dataset read fully into memory from an HDF5 file.

    Args:
        train: if True, read ``<root>/train_ecg.hdf5`` (datasets ``x_train`` / ``y_train``);
            otherwise read ``<root>/test_ecg.hdf5`` (datasets ``x_test`` / ``y_test``).
        root: directory containing the HDF5 files. Defaults to ``'mitdb'``, the
            previously hard-coded location.
    """

    def __init__(self, train=True, root='mitdb'):
        split = 'train' if train else 'test'
        path = os.path.join(root, '{}_ecg.hdf5'.format(split))
        with h5py.File(path, 'r') as hdf:
            # `[:]` reads the whole dataset into a numpy array, so the file can be closed.
            self.x = hdf['x_{}'.format(split)][:]
            self.y = hdf['y_{}'.format(split)][:]

    def __len__(self):
        """Number of samples in the split."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return ``(signal, label)`` for sample ``idx``.

        The signal is converted to a float32 tensor. The label tensor keeps the numpy
        dtype stored in the file.
        """
        # NOTE(review): EcgClient uses Conv1d(1, ...), so each signal is presumably
        # stored as (1, 128); confirm against the HDF5 files.
        return torch.tensor(self.x[idx], dtype=torch.float), torch.tensor(self.y[idx])

### Set batch size

In [23]:
# Mini-batch size shared by the train and test data loaders.
batch_size: int = 32

## Create the train and test data loaders (batch generators)

In [24]:
# Load both splits into memory, then wrap each one in a sequential (unshuffled) batch iterator.
train_dataset, test_dataset = ECG(train=True), ECG(train=False)
train_loader, test_loader = (
    DataLoader(split_ds, batch_size=batch_size)
    for split_ds in (train_dataset, test_dataset)
)

### Total number of batches

In [25]:
# Mini-batches sent to the server per training epoch.
total_batch: int = len(train_loader)
print(total_batch)

414


## Define ECG client model
The client side has only **3 convolutional layers**; their flattened output is sent to the server.

In [26]:
class EcgClient(nn.Module):
    """Client half of the split 1D-CNN: three Conv1d + LeakyReLU stages with two max-pools.

    The shape comments assume a single-channel input of length 128, as the notebook's
    data uses. In that case the flattened output has 32 * 16 = 512 features per sample,
    and the server receives it as ``client_output``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(1, 16, 7, padding=3)  # 128 x 16
        self.relu1 = nn.LeakyReLU()
        self.pool1 = nn.MaxPool1d(2)  # 64 x 16
        self.conv2 = nn.Conv1d(16, 16, 5, padding=2)  # 64 x 16
        self.relu2 = nn.LeakyReLU()
        self.conv3 = nn.Conv1d(16, 16, 5, padding=2)  # 64 x 16
        self.relu3 = nn.LeakyReLU()
        self.pool3 = nn.MaxPool1d(2)  # 32 x 16

    def forward(self, x):
        """Run the client layers and flatten each sample's activation.

        Args:
            x: signal tensor with a single input channel, e.g. ``(batch, 1, 128)``.

        Returns:
            Tensor of shape ``(batch, 16 * length // 4)``, i.e. ``(batch, 512)`` for length 128.
        """
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.relu2(self.conv2(x))
        x = self.pool3(self.relu3(self.conv3(x)))
        # Flatten using the real per-sample size (channels * length). The former hard-coded
        # view(-1, 32 * 16) only matched length-128 inputs. For any other length it split
        # each sample across several rows and corrupted the batch dimension.
        return x.view(-1, x.size(-2) * x.size(-1))

In [27]:
# Prefer the first CUDA GPU; fall back to the CPU when CUDA is unavailable.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda:0' if use_cuda else 'cpu')
if use_cuda:
    # Show which GPU the client model will run on.
    print(torch.cuda.get_device_name(0))

### Set random seed

In [28]:
# Seed every RNG source used here (numpy, torch CPU, each CUDA device) with the same value.
seed = 0
np.random.seed(seed)
for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(seed)

# Make cuDNN pick deterministic kernels instead of benchmarking for the fastest one.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

### Assign the same initial weights as the non-split model

In [29]:
# Build the client model and copy in the initial weights used by the non-split model,
# so both experiments start from the same parameters.
ecg_client = EcgClient()

# map_location='cpu' lets a checkpoint saved during a GPU run load on a CPU-only machine.
# The weights are moved to `device` together with the model below.
checkpoint = torch.load("init_weight.pth", map_location='cpu')

# Keep only the checkpoint entries that belong to the client model (conv1..conv3 weight/bias).
# load_state_dict copies into the existing parameters and checks shapes. The old
# `.data = ...` rebinding accepted a tensor of any shape. Because it ran after
# `.to(device)`, it also left the parameters on whatever device the checkpoint came from.
client_state = {name: checkpoint[name] for name in ecg_client.state_dict()}
ecg_client.load_state_dict(client_state)
ecg_client.to(device)

### Set other hyperparameters in the model
The hyperparameters here should be the same as on the server side.

In [30]:
epoch = 400          # number of passes over train_loader
lr = 0.001           # Adam learning rate for the client-side layers
# NOTE(review): `criterion` is not referenced by any visible cell; the loss appears to be
# computed on the server, which receives the labels. Kept so this cell defines the same names.
criterion = nn.CrossEntropyLoss()
optimizer = Adam(params=ecg_client.parameters(), lr=lr)

## Socket initialization

### Required socket functions

In [31]:
def send_msg(sock, msg):
    """Send one framed message over ``sock``.

    The frame is a 4-byte unsigned big-endian length header followed by ``msg``.
    Header and payload go out in a single ``sendall`` call.

    Args:
        sock: connected socket-like object exposing ``sendall``.
        msg: payload bytes.
    """
    header = struct.pack('>I', len(msg))
    sock.sendall(header + msg)

def recv_msg(sock):
    """Receive one message framed by ``send_msg`` (4-byte big-endian length + payload).

    Args:
        sock: connected socket-like object.

    Returns:
        The payload bytes. Returns None if the connection closed before the header
        was complete, or before the full payload arrived.
    """
    header = recvall(sock, 4)
    if not header:
        return None
    (length,) = struct.unpack('>I', header)
    return recvall(sock, length)

def recvall(sock, n):
    """Read exactly ``n`` bytes from ``sock``.

    ``sock.recv`` may return fewer bytes than requested, so keep reading until ``n``
    bytes have arrived.

    Args:
        sock: connected socket-like object exposing ``recv``.
        n: number of bytes to read.

    Returns:
        The ``n`` bytes as ``bytes``, or None if ``recv`` returned an empty chunk
        (peer closed the connection) before ``n`` bytes were received.
    """
    # A bytearray grows in place. Repeated `bytes += bytes` copies the whole buffer
    # on every chunk, which is quadratic in the message size.
    data = bytearray()
    while len(data) < n:
        packet = sock.recv(n - len(data))
        if not packet:
            return None
        data.extend(packet)
    return bytes(data)

### Set host address and port number

In [32]:
# Address of the split-learning server this client connects to.
host, port = 'localhost', 10080
# NOTE(review): not used by the socket helpers, which size every read from the length header.
max_recv = 4096

### Open the client socket

In [33]:
# IPv4 TCP connection to the server; every epoch's training and test traffic goes through it.
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.connect((host, port))

## Real training process

In [34]:
def _recv_pickled(sock):
    """Receive one framed message from ``sock`` and unpickle it.

    Raises:
        ConnectionError: the server closed the connection. Previously this surfaced
            as an opaque ``TypeError`` from ``pickle.loads(None)``.
    """
    payload = recv_msg(sock)
    if payload is None:
        raise ConnectionError('server closed the connection')
    # SECURITY: pickle.loads can execute arbitrary code; only connect to a trusted server.
    return pickle.loads(payload)


for e in range(epoch):
    print("Epoch {} - ".format(e+1), end='')

    # Training: client forward -> server forward/backward -> client backward with the
    # gradient the server returns for the activation it was sent.
    # train()/eval() are no-ops for the current layers (conv/LeakyReLU/max-pool only), but
    # they keep the loop correct if dropout or batch norm is added later.
    ecg_client.train()
    for x, label in train_loader:
        x, label = x.to(device), label.to(device)
        optimizer.zero_grad()
        output = ecg_client(x)
        # Send a detached copy so the server gets a fresh leaf tensor to backprop into.
        client_output = output.clone().detach().requires_grad_(True)
        send_msg(s, pickle.dumps({
            'client_output': client_output,
            'label': label
        }))
        client_grad = _recv_pickled(s)
        # The server may have computed the gradient on a different device than `output`.
        output.backward(client_grad.to(device))
        optimizer.step()

    # Test: stream every test batch to the server, which computes the test metrics.
    ecg_client.eval()
    with torch.no_grad():
        for x, label in test_loader:
            x, label = x.to(device), label.to(device)
            client_output = ecg_client(x)
            send_msg(s, pickle.dumps({
                'client_output': client_output,
                'label': label
            }))

    # One status message per epoch, containing the server's train/test loss and accuracy.
    train_test_status = _recv_pickled(s)
    print(train_test_status)

Epoch 1 - loss: 1.3068, acc: 60.72% / test_loss: 1.1514, test_acc: 76.04%
Epoch 2 - loss: 1.1009, acc: 80.64% / test_loss: 1.0799, test_acc: 82.37%
Epoch 3 - loss: 1.0882, acc: 81.50% / test_loss: 1.0766, test_acc: 82.75%
Epoch 4 - loss: 1.0816, acc: 82.20% / test_loss: 1.0481, test_acc: 85.82%
Epoch 5 - loss: 1.0405, acc: 86.56% / test_loss: 1.0291, test_acc: 87.64%
Epoch 6 - loss: 1.0336, acc: 87.22% / test_loss: 1.0231, test_acc: 88.18%
Epoch 7 - loss: 1.0301, acc: 87.50% / test_loss: 1.0183, test_acc: 88.64%
Epoch 8 - loss: 1.0294, acc: 87.57% / test_loss: 1.0265, test_acc: 88.01%
Epoch 9 - loss: 1.0280, acc: 87.69% / test_loss: 1.0188, test_acc: 88.52%
Epoch 10 - loss: 1.0246, acc: 87.99% / test_loss: 1.0154, test_acc: 88.85%
Epoch 11 - loss: 1.0249, acc: 87.93% / test_loss: 1.0142, test_acc: 89.00%
Epoch 12 - loss: 1.0215, acc: 88.31% / test_loss: 1.0133, test_acc: 89.05%
Epoch 13 - loss: 1.0226, acc: 88.17% / test_loss: 1.0147, test_acc: 88.95%
Epoch 14 - loss: 1.0240, acc: 87.9

Epoch 111 - loss: 0.9696, acc: 93.18% / test_loss: 0.9711, test_acc: 93.00%
Epoch 112 - loss: 0.9668, acc: 93.46% / test_loss: 0.9712, test_acc: 92.97%
Epoch 113 - loss: 0.9666, acc: 93.47% / test_loss: 0.9724, test_acc: 92.91%
Epoch 114 - loss: 0.9679, acc: 93.34% / test_loss: 0.9708, test_acc: 93.05%
Epoch 115 - loss: 0.9670, acc: 93.39% / test_loss: 0.9722, test_acc: 92.87%
Epoch 116 - loss: 0.9691, acc: 93.20% / test_loss: 0.9726, test_acc: 92.84%
Epoch 117 - loss: 0.9682, acc: 93.33% / test_loss: 0.9715, test_acc: 92.95%
Epoch 118 - loss: 0.9663, acc: 93.48% / test_loss: 0.9726, test_acc: 92.83%
Epoch 119 - loss: 0.9671, acc: 93.40% / test_loss: 0.9713, test_acc: 92.97%
Epoch 120 - loss: 0.9660, acc: 93.55% / test_loss: 0.9846, test_acc: 91.63%
Epoch 121 - loss: 0.9689, acc: 93.23% / test_loss: 0.9738, test_acc: 92.74%
Epoch 122 - loss: 0.9694, acc: 93.20% / test_loss: 0.9748, test_acc: 92.67%
Epoch 123 - loss: 0.9662, acc: 93.48% / test_loss: 0.9708, test_acc: 93.05%
Epoch 124 - 

Epoch 219 - loss: 0.9166, acc: 98.84% / test_loss: 0.9277, test_acc: 97.73%
Epoch 220 - loss: 0.9201, acc: 98.47% / test_loss: 0.9244, test_acc: 98.06%
Epoch 221 - loss: 0.9178, acc: 98.70% / test_loss: 0.9252, test_acc: 97.98%
Epoch 222 - loss: 0.9179, acc: 98.70% / test_loss: 0.9301, test_acc: 97.43%
Epoch 223 - loss: 0.9180, acc: 98.67% / test_loss: 0.9247, test_acc: 97.98%
Epoch 224 - loss: 0.9163, acc: 98.86% / test_loss: 0.9235, test_acc: 98.15%
Epoch 225 - loss: 0.9147, acc: 99.00% / test_loss: 0.9214, test_acc: 98.35%
Epoch 226 - loss: 0.9152, acc: 98.96% / test_loss: 0.9242, test_acc: 98.06%
Epoch 227 - loss: 0.9166, acc: 98.84% / test_loss: 0.9274, test_acc: 97.75%
Epoch 228 - loss: 0.9200, acc: 98.48% / test_loss: 0.9256, test_acc: 97.92%
Epoch 229 - loss: 0.9179, acc: 98.68% / test_loss: 0.9286, test_acc: 97.63%
Epoch 230 - loss: 0.9174, acc: 98.72% / test_loss: 0.9222, test_acc: 98.26%
Epoch 231 - loss: 0.9156, acc: 98.93% / test_loss: 0.9222, test_acc: 98.26%
Epoch 232 - 

Epoch 327 - loss: 0.9155, acc: 98.93% / test_loss: 0.9233, test_acc: 98.13%
Epoch 328 - loss: 0.9138, acc: 99.11% / test_loss: 0.9233, test_acc: 98.15%
Epoch 329 - loss: 0.9150, acc: 98.97% / test_loss: 0.9229, test_acc: 98.20%
Epoch 330 - loss: 0.9166, acc: 98.82% / test_loss: 0.9255, test_acc: 97.92%
Epoch 331 - loss: 0.9149, acc: 99.00% / test_loss: 0.9232, test_acc: 98.17%
Epoch 332 - loss: 0.9163, acc: 98.86% / test_loss: 0.9242, test_acc: 98.06%
Epoch 333 - loss: 0.9165, acc: 98.84% / test_loss: 0.9237, test_acc: 98.11%
Epoch 334 - loss: 0.9155, acc: 98.93% / test_loss: 0.9253, test_acc: 97.94%
Epoch 335 - loss: 0.9152, acc: 98.95% / test_loss: 0.9232, test_acc: 98.17%
Epoch 336 - loss: 0.9144, acc: 99.03% / test_loss: 0.9285, test_acc: 97.63%
Epoch 337 - loss: 0.9175, acc: 98.73% / test_loss: 0.9240, test_acc: 98.06%
Epoch 338 - loss: 0.9169, acc: 98.78% / test_loss: 0.9274, test_acc: 97.73%
Epoch 339 - loss: 0.9156, acc: 98.94% / test_loss: 0.9223, test_acc: 98.25%
Epoch 340 - 

In [35]:
# Training is over: close the connection so the server sees EOF instead of a socket that
# stays open for the lifetime of the kernel. Closing an already-closed socket is harmless,
# so this cell can be re-run.
s.close()
print('Finished Training!')
print('Result is on the server side.')

Finished Training!
Result is on the server side.
