# ECG Split 1D-CNN Client Side

This code is the client part of the ECG split 1D-CNN model for a **single** client and a server.

## Import required packages

In [5]:
# Dataset selector read by the data-loading cell below; that cell handles 'ECG' and 'CIFAR10'.
mode = 'CIFAR10'

import os
import struct
import socket
import pickle
import time

import h5py
import numpy as np

import torch
import torchvision
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam, SGD
from torchvision.datasets import CIFAR10

## Define ECG dataset class and build the data loaders (ECG or CIFAR10)

In [6]:
# Build the train/test datasets and loaders for the dataset selected by `mode`.
# Both branches define: batch_size, train_dataset, test_dataset, train_loader,
# test_loader and total_batch (number of training batches per epoch).
batch_size = 32

if mode == 'ECG':
    class ECG(Dataset):
        """ECG samples and labels read from the HDF5 files under ``../../mitdb``.

        Args:
            train: when True, read ``train_ecg.hdf5`` (keys ``x_train``/``y_train``);
                otherwise read ``test_ecg.hdf5`` (keys ``x_test``/``y_test``).
        """

        def __init__(self, train=True):
            if train:
                with h5py.File(os.path.join('../../', 'mitdb', 'train_ecg.hdf5'), 'r') as hdf:
                    # `[:]` copies the full arrays into memory so the file can be closed.
                    self.x = hdf['x_train'][:]
                    self.y = hdf['y_train'][:]
            else:
                with h5py.File(os.path.join('../../', 'mitdb', 'test_ecg.hdf5'), 'r') as hdf:
                    self.x = hdf['x_test'][:]
                    self.y = hdf['y_test'][:]

        def __len__(self):
            """Return the number of samples."""
            return len(self.x)

        def __getitem__(self, idx):
            """Return ``(sample, label)`` as tensors; the sample is cast to float."""
            return torch.tensor(self.x[idx], dtype=torch.float), torch.tensor(self.y[idx])

    train_dataset = ECG(train=True)
    test_dataset = ECG(train=False)
    train_loader = DataLoader(train_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)
    total_batch = len(train_loader)
    print(total_batch)
elif mode == 'CIFAR10':
    root_path = '../../data/'  # Replace with actual path to CIFAR10 data

    # NOTE(review): these values match the widely used ImageNet channel statistics
    # rather than CIFAR10's own mean/std -- confirm this is intended. Left unchanged
    # so that results stay comparable with earlier runs.
    cifar10_mean = (0.485, 0.456, 0.406)
    cifar10_std = (0.229, 0.224, 0.225)

    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=cifar10_mean, std=cifar10_std)
    ])
    train_dataset = CIFAR10(root=root_path, train=True, download=True, transform=transform)
    test_dataset = CIFAR10(root=root_path, train=False, download=True, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    # Defined here as well so both modes expose the same set of names.
    total_batch = len(train_loader)
    print("Loaded CIFAR10 datasets with batch size", batch_size)
else:
    # Fail here rather than with a NameError on `train_loader` in a later cell.
    raise ValueError("Unsupported mode {!r}; expected 'ECG' or 'CIFAR10'".format(mode))

Files already downloaded and verified
Files already downloaded and verified
Loaded CIFAR10 datasets with batch size 32


In [7]:
from torchvision.models import resnet18, vgg16

class EcgClient(nn.Module):
    """Client-side part of the split 1D-CNN: two Conv1d -> ReLU -> MaxPool1d stages.

    ``forward`` flattens the result to rows of ``32 * 16`` values, so the
    input length must be compatible with that reshape (a length of 128 with a
    single input channel gives 32 positions x 16 channels).
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 1 -> 16 channels; padding keeps the length, pooling halves it.
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=16, kernel_size=7, padding=3)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool1d(kernel_size=2)
        # Stage 2: 16 -> 16 channels; length is halved again by the pooling.
        self.conv2 = nn.Conv1d(in_channels=16, out_channels=16, kernel_size=5, padding=2)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool1d(kernel_size=2)

    def forward(self, x):
        """Run both stages and return the activations flattened to ``(-1, 512)``."""
        stage1 = self.pool1(self.relu1(self.conv1(x)))
        stage2 = self.pool2(self.relu2(self.conv2(stage1)))
        flat_width = 32 * 16
        return stage2.view(-1, flat_width)
class VGG16_Any_Cut_Client(nn.Module):
    """Client-side prefix of an untrained VGG16, cut after layer ``cut_layer``.

    The child modules of ``vgg16().features`` followed by those of
    ``vgg16().classifier`` are chained into one ``nn.Sequential``; ``forward``
    applies the layers with index ``0..cut_layer`` (inclusive).

    NOTE(review): only the children of ``features`` and ``classifier`` are
    chained, so any step the stock VGG16 forward performs between them is not
    part of this sequence -- confirm that cuts reaching the classifier layers
    are not expected to work.

    Args:
        config: mapping that provides the integer ``"cut_layer"``.
    """

    def __init__(self, config):
        super().__init__()
        self.cut_layer = config["cut_layer"]

        backbone = vgg16(weights=None)
        feature_layers = list(backbone.features.children())
        classifier_layers = list(backbone.classifier.children())
        self.model = nn.Sequential(*(feature_layers + classifier_layers))

    def forward(self, x):
        """Apply layers ``0..cut_layer`` to ``x`` and return the activation."""
        for index, layer in enumerate(self.model):
            if index > self.cut_layer:
                return x
            x = layer(x)
        return x

class AlexNetClient(nn.Module):
    """Client-side convolutional stack of an AlexNet-style network.

    Five 3x3 convolutions (3 -> 64 -> 192 -> 384 -> 256 -> 256 channels), each
    followed by an in-place ReLU, with 2x2 max-pooling after the first, second
    and fifth convolution. The first convolution uses stride 2.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=64, out_channels=192, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=192, out_channels=384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
        ]
        self.features = nn.Sequential(*layers)

    def forward(self, x):
        """Return the feature map produced by the convolutional stack."""
        return self.features(x)

class Resnet18_Any_Cut_Client(nn.Module):
    """Client-side prefix of an untrained ResNet18, cut after child ``cut_layer``.

    The top-level child modules of ``resnet18()`` are chained into one
    ``nn.Sequential``; ``forward`` applies the children with index
    ``0..cut_layer`` (inclusive).

    NOTE(review): the children are applied strictly one after another, so any
    reshaping the stock ResNet18 forward performs between children is not
    reproduced -- confirm that cuts reaching the final layers are not expected
    to work.

    Args:
        config: mapping that provides the integer ``"cut_layer"``.
    """

    def __init__(self, config):
        super().__init__()
        self.cut_layer = config["cut_layer"]

        backbone = resnet18(weights=None)
        self.model = nn.Sequential(*backbone.children())

    def forward(self, x):
        """Apply children ``0..cut_layer`` to ``x`` and return the activation."""
        for index, layer in enumerate(self.model):
            if index > self.cut_layer:
                return x
            x = layer(x)
        return x

In [8]:
# --- Device: this notebook runs on the first CUDA GPU -------------------------
device = torch.device('cuda:0')
torch.cuda.get_device_name(0)  # return value is not used here
assert 'cuda' in device.type

# --- Reproducibility: seed NumPy and PyTorch (CPU and CUDA) before building the model
seed = 0
for seed_fn in (np.random.seed, torch.manual_seed,
                torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# --- Model ---------------------------------------------------------------------
cut_layer = 3
logits = 10

# `config` is the argument expected by the *_Any_Cut_Client classes;
# AlexNetClient takes no arguments.
config = dict(cut_layer=cut_layer, logits=logits)
ecg_client = AlexNetClient().to(device)

# Optional initial weights; the keys below are the EcgClient layer names.
# checkpoint = torch.load(f"../ecg_2conv/init_weight{mode}.pth")
# ecg_client.conv1.weight.data = checkpoint["conv1.weight"]
# ecg_client.conv1.bias.data = checkpoint["conv1.bias"]
# ecg_client.conv2.weight.data = checkpoint["conv2.weight"]
# ecg_client.conv2.bias.data = checkpoint["conv2.bias"]

# --- Training hyper-parameters -------------------------------------------------
epoch = 400
lr = 0.001
criterion = nn.CrossEntropyLoss()
optimizer = Adam(params=ecg_client.parameters(), lr=lr)
def send_msg(sock, msg):
    """Send ``msg`` (bytes) over ``sock`` framed with a length header.

    The frame is a 4-byte big-endian unsigned length followed by the payload.
    """
    header = struct.pack('>I', len(msg))
    sock.sendall(header + msg)

def recv_msg(sock):
    """Receive one length-prefixed message from ``sock``.

    Reads the 4-byte big-endian length header, then that many payload bytes.

    Returns:
        The payload bytes, or None if the header could not be read or
        ``recvall`` returned None for the payload.
    """
    header = recvall(sock, 4)
    if header:
        (payload_size,) = struct.unpack('>I', header)
        return recvall(sock, payload_size)
    return None

def recvall(sock, n):
    """Receive exactly ``n`` bytes from ``sock``.

    Args:
        sock: object with a ``recv(size)`` method returning bytes.
        n: number of bytes to read.

    Returns:
        A ``bytes`` object of length ``n``, or None if ``recv`` returned an
        empty result (EOF) before ``n`` bytes arrived.
    """
    # A bytearray grows in place; repeated `bytes +=` would re-copy the whole
    # buffer on every packet, which is quadratic for large payloads.
    buffer = bytearray()
    while len(buffer) < n:
        packet = sock.recv(n - len(buffer))
        if not packet:
            return None
        buffer.extend(packet)
    return bytes(buffer)
# Address of the split-learning server this client connects to.
host = 'localhost'
port = 10080
max_recv = 4096  # NOTE(review): not referenced in the visible cells -- confirm it is still needed
# Open a TCP socket and connect; `s` is the connection used by the training cell.
s = socket.socket()
s.connect((host, port))

## Training process

In [9]:
from tqdm import tqdm


def _recv_object(sock):
    """Receive one length-prefixed message from ``sock`` and unpickle it.

    Raises:
        ConnectionError: if ``recv_msg`` returns None, i.e. the connection was
            closed before a complete message arrived.
    """
    payload = recv_msg(sock)
    if payload is None:
        raise ConnectionError('Server closed the connection before sending a reply.')
    # NOTE(security): pickle.loads can execute arbitrary code from the payload;
    # only connect to a server you trust.
    return pickle.loads(payload)


# Move the model once, up front, instead of once per batch.
ecg_client = ecg_client.to(device)

for e in range(epoch):
    print("Epoch {} - ".format(e+1), end='')

    # --- Training: send activations + labels, receive gradients ----------------
    # train()/eval() only change behaviour for layers such as dropout or batch
    # norm; setting them keeps the loop correct for every client model.
    ecg_client.train()
    # Wrapping the loader directly lets tqdm show the total number of batches.
    for x, label in tqdm(train_loader):
        x, label = x.to(device), label.to(device)
        optimizer.zero_grad()
        output = ecg_client(x)
        # Detached copy for the server; the server's reply is the gradient
        # with respect to this tensor.
        client_output = output.clone().detach().requires_grad_(True)
        msg = {
            'client_output': client_output,
            'label': label
        }
        send_msg(s, pickle.dumps(msg))
        client_grad = _recv_object(s)
        # Continue back-propagation through the client layers.
        output.backward(client_grad)
        optimizer.step()

    # --- Evaluation: stream every test batch, then read one status reply --------
    ecg_client.eval()
    with torch.no_grad():
        for x, label in test_loader:
            x, label = x.to(device), label.to(device)
            client_output = ecg_client(x)
            msg = {
                'client_output': client_output,
                'label': label
            }
            send_msg(s, pickle.dumps(msg))

    train_test_status = _recv_object(s)
    print(train_test_status)

Epoch 1 - 

1563it [00:15, 101.94it/s]


loss: 1.6747, acc: 35.32% / test_loss: 0.2795, test_acc: 48.48%
Epoch 2 - 

1563it [00:14, 106.34it/s]


loss: 1.2900, acc: 53.77% / test_loss: 0.2339, test_acc: 58.39%
Epoch 3 - 

1563it [00:14, 107.33it/s]


loss: 1.1324, acc: 60.20% / test_loss: 0.2249, test_acc: 61.40%
Epoch 4 - 

1563it [00:14, 106.57it/s]


loss: 1.0315, acc: 63.73% / test_loss: 0.2235, test_acc: 62.51%
Epoch 5 - 

1563it [00:14, 106.82it/s]


loss: 0.9661, acc: 66.24% / test_loss: 0.2242, test_acc: 63.38%
Epoch 6 - 

1563it [00:14, 106.70it/s]


loss: 0.9013, acc: 68.68% / test_loss: 0.2295, test_acc: 63.76%
Epoch 7 - 

1563it [00:14, 106.96it/s]


loss: 0.8553, acc: 70.39% / test_loss: 0.2117, test_acc: 64.76%
Epoch 8 - 

1563it [00:14, 106.86it/s]


loss: 0.8076, acc: 71.95% / test_loss: 0.2191, test_acc: 65.45%
Epoch 9 - 

1563it [00:14, 106.02it/s]


loss: 0.7699, acc: 73.43% / test_loss: 0.2259, test_acc: 65.35%
Epoch 10 - 

1563it [00:14, 106.70it/s]


loss: 0.7487, acc: 74.09% / test_loss: 0.2201, test_acc: 66.64%
Epoch 11 - 

1563it [00:14, 106.96it/s]


loss: 0.7088, acc: 75.83% / test_loss: 0.2283, test_acc: 66.27%
Epoch 12 - 

1563it [00:14, 106.59it/s]


loss: 0.6851, acc: 76.29% / test_loss: 0.2405, test_acc: 66.48%
Epoch 13 - 

1563it [00:14, 107.24it/s]


loss: 0.6550, acc: 77.57% / test_loss: 0.2363, test_acc: 66.90%
Epoch 14 - 

1563it [00:14, 106.72it/s]


loss: 0.6220, acc: 78.60% / test_loss: 0.2327, test_acc: 67.34%
Epoch 15 - 

1563it [00:14, 107.21it/s]


loss: 0.5927, acc: 79.85% / test_loss: 0.2513, test_acc: 67.18%
Epoch 16 - 

1563it [00:14, 106.81it/s]


loss: 0.5683, acc: 80.68% / test_loss: 0.2561, test_acc: 66.88%
Epoch 17 - 

1563it [00:14, 107.46it/s]


loss: 0.5517, acc: 81.27% / test_loss: 0.2553, test_acc: 66.66%
Epoch 18 - 

1563it [00:14, 107.33it/s]


loss: 0.5268, acc: 82.22% / test_loss: 0.2757, test_acc: 65.62%
Epoch 19 - 

1563it [00:14, 107.12it/s]


loss: 0.5181, acc: 82.63% / test_loss: 0.2778, test_acc: 66.59%
Epoch 20 - 

1563it [00:14, 107.30it/s]


loss: 0.4978, acc: 83.29% / test_loss: 0.2767, test_acc: 67.57%
Epoch 21 - 

1563it [00:14, 107.33it/s]


loss: 0.4918, acc: 83.82% / test_loss: 0.2823, test_acc: 67.79%
Epoch 22 - 

1563it [00:14, 107.30it/s]


loss: 0.4601, acc: 84.77% / test_loss: 0.2937, test_acc: 66.61%
Epoch 23 - 

1563it [00:14, 105.74it/s]


KeyboardInterrupt: 

In [None]:
# Training is over: close the connection to the server so the socket is not
# left open for the lifetime of the kernel.
s.close()
print('Finished Training!')
print('Result is on the server side.')

Finished Training!
Result is on the server side.
