# Training split 1D CNN on homomorphic encrypted ECG data

Package versions:  
`torch`: 1.8.1+cu102  
`pysyft`: 0.5.0  
`tenseal`: 0.3.5  

In [1]:
from typing import List

from pathlib import Path
import h5py
import numpy as np
import matplotlib.pyplot as plt
from icecream import ic
from tqdm import tqdm
from time import time

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam, SGD

import syft as sy
from syft.core.node.vm.vm import VirtualMachine
from syft.core.node.vm.client import VirtualMachineClient
from syft.ast.module import Module
from syft.core.remote_dataloader import RemoteDataLoader
from syft.core.remote_dataloader import RemoteDataset

import tenseal as ts
from tenseal.tensors.ckksvector import CKKSVector
from tenseal.enc_context import Context

print(f'torch version: {torch.__version__}')
print(f'syft version: {sy.__version__}')
print(f'tenseal version: {ts.__version__}')


torch version: 1.8.1+cu102
syft version: 0.5.0
tenseal version: 0.3.5


## Files and Directories

In [2]:
project_path = Path.cwd().parent.parent
print(f'project_path: {project_path}')

project_path: /mnt/batch/tasks/shared/LS_root/mounts/clusters/teslak80-56gbram/code/Users/dkn.work/split-learning-he


In [3]:
# paths to files and directories
data_dir = 'data'  # used to be 'mitdb'
train_name = 'train_ecg.hdf5'
test_name = 'test_ecg.hdf5'
dry_run = True  # load less data (50 examples)

## Construct the client and server

In [4]:
server: VirtualMachine = sy.VirtualMachine(name="server")
client: VirtualMachineClient = server.get_root_client()
remote_torch: Module = client.torch
remote_torch

Module:
	.Tensor -> <syft.ast.klass.Class object at 0x7f6b442c31c0>
	.BFloat16Tensor -> <syft.ast.klass.Class object at 0x7f6b442c3220>
	.BoolTensor -> <syft.ast.klass.Class object at 0x7f6b442c3280>
	.ByteTensor -> <syft.ast.klass.Class object at 0x7f6b442c32e0>
	.CharTensor -> <syft.ast.klass.Class object at 0x7f6b442c3340>
	.DoubleTensor -> <syft.ast.klass.Class object at 0x7f6b442c33a0>
	.FloatTensor -> <syft.ast.klass.Class object at 0x7f6b442c3400>
	.HalfTensor -> <syft.ast.klass.Class object at 0x7f6b442c3460>
	.IntTensor -> <syft.ast.klass.Class object at 0x7f6b442c34c0>
	.LongTensor -> <syft.ast.klass.Class object at 0x7f6b442c3520>
	.ShortTensor -> <syft.ast.klass.Class object at 0x7f6b442c3580>
	.nn -> Module:
		.Parameter -> <syft.ast.klass.Class object at 0x7f6b442c36a0>
		.Module -> <syft.ast.klass.Class object at 0x7f6b440f1b20>
		.Conv2d -> <syft.ast.klass.Class object at 0x7f6b440f5040>
		.Dropout2d -> <syft.ast.klass.Class object at 0x7f6b440f5460>
		.Linear -> <syft.

## Client: preparing the dataset (only for training now)

In [5]:
class ECG(Dataset):
    # The class used to load the ECG dataset
    def __init__(self, mode='train'):
        if mode == 'train':
            with h5py.File(project_path/data_dir/train_name, 'r') as hdf:
                if dry_run:
                    self.x = torch.tensor(hdf['x_train'][:50], dtype=torch.float)
                    self._y = torch.tensor(hdf['y_train'][:50])
                else:
                    self.x = torch.tensor(hdf['x_train'][:], dtype=torch.float)
                    self._y = torch.tensor(hdf['y_train'][:])
        elif mode == 'test':
            with h5py.File(project_path/data_dir/test_name, 'r') as hdf:
                if dry_run:
                    self.x = torch.tensor(hdf['x_test'][:50], dtype=torch.float)
                    self._y = torch.tensor(hdf['y_test'][:50])
                else:
                    self.x = torch.tensor(hdf['x_test'][:], dtype=torch.float)
                    self._y = torch.tensor(hdf['y_test'][:])
        else:
            raise ValueError('Argument of mode should be train or test')

    def __len__(self):
        return len(self.x)
    
    def __getitem__(self, idx):
        return self.x[idx]  # only returns the input data this time
    
    def encrypt_y(self, context: Context):
        encrypted_y: List[CKKSVector] = [ts.ckks_vector(context, [y.tolist()]) for y in self._y]
        return encrypted_y

train_dataset = ECG(mode='train')

Make a tenseal context to encrypt the ground-truth output

In [6]:
# parameters
poly_mod_degree = 4096
coeff_mod_bit_sizes = [40, 20, 40]
# create TenSEALContext
context = ts.context(ts.SCHEME_TYPE.CKKS, poly_mod_degree, -1, coeff_mod_bit_sizes)
# scale of ciphertext to use
context.global_scale = 2 ** 20
# this key is needed for doing dot-product operations
context.generate_galois_keys()

In [7]:
enc_train_y: List = train_dataset.encrypt_y(context)

The client creates the Dataset object and save it in a `.pt` file. If using `duet`, he can send the string path to the server using `sy.lib.python.String(string_path).send(duet, pointable=True, tags=["data"])`. From the `.pt` file, the server can point to the dataset, but he needs to ask for permissions if he wants to access the data

In [8]:
torch.save(train_dataset, "train_dataset.pt")

## Server: creating the remote dataset and dataloader for the train dataset

In [9]:
train_rds = RemoteDataset(path='train_dataset.pt', data_type="torch_tensor")
train_rds

<class 'syft.core.remote_dataloader.remote_dataloader.RemoteDataset'>: torch_tensor

From the remote dataset, the server constructs the data loader. Then the server uses `.send`
to create a pointer to do remote data loading

In [10]:
# we need to use batch_size 1 (for now) because of training on encrypted data
train_rdl = RemoteDataLoader(remote_dataset=train_rds, batch_size=1)
train_rdl_ptr = train_rdl.send(client)
ic(train_rdl, train_rdl_ptr)
# call create_dataset to create the real Dataset object on remote side
train_rdl_ptr.load_dataset()
# call create_dataloader to create the real DataLoader object on remote side
train_rdl_ptr.create_dataloader()
ic(len(train_rdl_ptr))

ic| train_rdl: <syft.core.remote_dataloader.remote_dataloader.RemoteDataLoader object at 0x7f6b43e904c0>
    train_rdl_ptr: <syft.proxy.syft.core.remote_dataloader.RemoteDataLoaderPointer object at 0x7f6b43e90760>
ic| len(train_rdl_ptr): 50


50

Let's play with the remote dataloader and tenseal encrypted vectors

In [38]:
def context():
    context = ts.context(ts.SCHEME_TYPE.CKKS, 8192, coeff_mod_bit_sizes=[60, 40, 40, 60])
    context.global_scale = pow(2, 40)
    context.generate_galois_keys()
    return context

context2 = context()

In [12]:
for i, b in enumerate(zip(train_rdl_ptr, enc_train_y)):
    if i==0:
        x, enc_y = b[0], b[1]
        ic(x, x.get_copy().shape, enc_y, enc_y.decrypt(secret_key=context.secret_key()))

ic| x: <syft.proxy.syft.lib.misc.union.FloatIntStringTensorParameterUnionPointer object at 0x7f6b3fc24be0>
    x.get_copy().shape: torch.Size([1, 1, 128])
    enc_y: <tenseal.tensors.ckksvector.CKKSVector object at 0x7f6c01c59790>
    enc_y.decrypt(secret_key=context.secret_key()): [2.0004831754342143]


What if we decrypt with a wrong context key?

In [16]:
for i, b in enumerate(zip(train_rdl_ptr, enc_train_y)):
    if i==0:
        x, enc_y = b[0], b[1]
        try:
            ic(x, x.get_copy().shape, enc_y, 
               enc_y.decrypt(secret_key=context2.secret_key()))
        except Exception as e:
            ic(e)

ic| e: ValueError('secret key is not valid for encryption parameters')


## Server: define the spit neural network used to train on the ECG dataset

Client's side contains conv layers, trained on plaintext input data

In [39]:
class EcgClient(sy.Module):
    # will be sent to the client
    def __init__(self, torch_ref, context: Context):
        super(EcgClient, self).__init__(torch_ref=torch_ref)
        self.conv1 = self.torch_ref.nn.Conv1d(1, 16, 7, padding=3)  # 128 x 16
        self.relu1 = self.torch_ref.nn.LeakyReLU()
        self.pool1 = self.torch_ref.nn.MaxPool1d(2)  # 64 x 16
        self.conv2 = self.torch_ref.nn.Conv1d(16, 16, 5, padding=2)  # 64 x 16
        self.relu2 = self.torch_ref.nn.LeakyReLU()
        self.pool2 = self.torch_ref.nn.MaxPool1d(2)  # 32 x 16
        
        self.load_init_weights()
        self.context = context
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool2(x)
        x = x.view(-1, 32 * 16)
        # x is a syft's TensorPointer. x.get() returns torch.Tensor of size [1, 512]
        enc_x = self.encrypt_activations(x.get())  # enc_x is a list of 512 elements
        return enc_x
    
    def load_init_weights(self):
        checkpoint = torch.load("init_weight.pth")
        self.conv1.weight.data = checkpoint["conv1.weight"]
        self.conv1.bias.data = checkpoint["conv1.bias"]
        self.conv2.weight.data = checkpoint["conv2.weight"]
        self.conv2.bias.data = checkpoint["conv2.bias"]
    
    def encrypt_activations(self, x: torch.Tensor):
        enc_x: CKKSVector = ts.ckks_vector(self.context, x.tolist()[0])
        return enc_x

ecg_client = EcgClient(torch_ref=torch, context=context2)
ecg_client_ptr = ecg_client.send(client)
        

Let's try to do a forward pass on the client's model

In [40]:
checkpoint = torch.load("init_weight.pth")
linear3_weight = checkpoint["linear3.weight"]  # torch.Tensor size [128, 512]
linear3_bias = checkpoint["linear3.bias"]  # torch.Tensor size [128]
linear4_weight = checkpoint["linear4.weight"]  # torch.Tensor size [5, 128]
linear4_bias = checkpoint["linear4.bias"]  # torch.Tensor size [5]

In [41]:
for i, b in enumerate(zip(train_rdl_ptr, enc_train_y)):
    if i==0:
        x_ptr, enc_y = b[0], b[1]
        # ic(x_ptr, x_ptr.get_copy().shape, enc_y)
        enc_activs: CKKSVector = ecg_client_ptr(x_ptr)
        ic(enc_activs.size())
        ic(enc_y.size())

ic| enc_activs.size(): 512
ic| enc_y.size(): 1


Server's side contains fully connected layers, trained on HE activation maps

In [89]:
class EcgServer(sy.Module):
    def __init__(self, torch_ref):
        super(EcgServer, self).__init__(torch_ref=torch_ref)
        self.linear3 = nn.Linear(32 * 16, 128)
        self.relu3 = nn.LeakyReLU() 
        self.linear4 = nn.Linear(128, 5)
#         self.softmax4 = nn.Softmax(dim=1)

        checkpoint = torch.load("init_weight.pth")
        self.linear3_weight = checkpoint["linear3.weight"]  # torch.Tensor size [128, 512]
        self.linear3_bias = checkpoint["linear3.bias"]  # torch.Tensor size [128]
        self.linear4_weight = checkpoint["linear4.weight"]  # torch.Tensor size [5, 128]
        self.linear4_bias = checkpoint["linear4.bias"]  # torch.Tensor size [5]
    
    @staticmethod
    def approx_leaky_relu(enc_x: CKKSVector):
        # 2.368475785867e-19*x**5 - 0.000252624921308674*x**4 - 
        # 2.90138283768708e-17*x**3 + 0.0660873211772537*x**2 + 
        # 0.500000000000001*x + 0.862730150341736
        
        # 2.30556314780491e-19*x**5 - 0.000250098672095587*x**4 - 
        # 2.83384427035571e-17*x**3 + 0.0654264479654812*x**2 + 
        # 0.505000000000001*x + 0.854102848838318
        return enc_x.polyval([0.854102848838318, 0.505000000000001, 
                              0.0654264479654812, 2.83384427035571**-17])
#         return enc_x.polyval([0.5, 0.197, 0, -0.004])

    def approx_softmax():
        raise NotImplementedError

    def forward(self, enc_x: CKKSVector):
        x = self.linear3(enc_x)
        x = self.relu3(x)
        x = self.linear4(x)
#         x = self.softmax4(x)
        return x

ecg_server = EcgServer(torch_ref=torch)

In [90]:
x = ts.plain_tensor(tensor=[-10, 0.2, 0.3, 0.4], shape=[2,2])
enc_x = ts.ckks_tensor(context2, x)
approx_relu_enc_x = EcgServer.approx_leaky_relu(enc_x)
lr = nn.LeakyReLU(0.1)

print(f"X = {enc_x.decrypt().tolist()}")
print(f"approximating leaky relu of X = {approx_relu_enc_x.decrypt().tolist()}.")
print(f'Output of the real leaky relu function: \n {lr(torch.tensor(x.tolist()))}')

X = [[-10.00000000078723, 0.2000000013982431], [0.300000000690922, 0.3999999988392446]]
approximating leaky relu of X = [[2.3467317631676647, 0.9577199245346658], [1.0114912549774142, 1.0665711164118634]].
Output of the real leaky relu function: 
 tensor([[-1.0000,  0.2000],
        [ 0.3000,  0.4000]])


## Server: training process

Some hyper-parameters

In [18]:
total_batch = 414  # 32*414=13248. We have 13245 data samples

epoch = 400
criterion = nn.CrossEntropyLoss()
lr = 0.001

optim_client = remote_torch.optim.Adam(params=ecg_client_ptr.parameters(), lr=lr)
optim_server = torch.optim.Adam(params=ecg_server.parameters(), lr=lr)

seed = 0  # the meaning of life
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
remote_torch.manual_seed(seed)

Training (with CPU)

In [19]:
train_losses = list()
train_accs = list()
test_losses = list()
test_accs = list()
best_test_acc = 0  # best test accuracy
for e in range(epoch):
    print(f"Epoch {e+1} - train ", end='')
    
    train_loss = 0.0
    correct, total = 0, 0
    for i, batch in enumerate(tqdm(train_rdl_ptr)):
        x_ptr, y_gt_ptr = batch[0], batch[1]
        # ic(x.get_copy(), y.get_copy())
        # initialize all gradients to zero
        optim_server.zero_grad()
        optim_client.zero_grad()
        # compute and get the activation signals from the first half of the network
        activs_ptr = ecg_client_ptr(x_ptr)
        # the server still gets access to plain activation signals
        activs = activs_ptr.clone().get(request_block=True)
        # the server continues the forward pass on the activation maps
        y_hat = ecg_server(activs)
        # the server asks to access ground truths in plain text
        y_gt = y_gt_ptr.get_copy()
        # calculates cross-entropy loss
        loss = criterion(y_hat, y_gt)
        train_loss += loss.item()
        correct += torch.sum(y_hat.argmax(dim=1) == y_gt).item()
        # backward propagation (calculating gradients of the loss w.r.t the weights)
        loss.backward()
        # send the gradients to the client
        client_grad_ptr = activs.grad.clone().send(client)
        # update the gradients of the client's model
        activs_ptr.backward(client_grad_ptr)
        # update the weights based on the gradients
        optim_client.step()
        optim_server.step()
        total += len(y_gt)

    train_losses.append(train_loss / total_batch)
    train_accs.append(correct / total)

    print(f'loss: {train_losses[-1]: .4f}, accuracy: {train_accs[-1]*100: 2f}')

    # testing
    with torch.no_grad():  
        test_loss = 0.0
        correct, total = 0, 0
        for i, batch in enumerate(tqdm(test_rdl_ptr)):
            x_ptr, y_gt_ptr = batch[0], batch[1]
            # forward pass
            activs_ptr = ecg_client_ptr(x_ptr)
            activs = activs_ptr.clone().get(request_block=True)
            y_hat = ecg_server(activs)
            # the server asks to access ground truths in plain text
            y_gt = y_gt_ptr.get_copy()
            # calculate test loss
            loss = criterion(y_hat, y_gt)
            test_loss += loss.item()
            correct += torch.sum(y_hat.argmax(dim=1) == y_gt).item()
            total += len(y_gt)

        test_losses.append(test_loss / total_batch)
        test_accs.append(correct / total)
        print(f'test_loss: {test_losses[-1]: .4f}, test_acc: {test_accs[-1]*100: 2f}')
        
    if test_accs[-1] > best_test_acc:
        best_test_acc = test_accs[-1]

Epoch 1 - train loss:  1.3243, accuracy:  59.041148
test_loss:  1.1502, test_acc:  78.391846
Epoch 2 - train loss:  1.0875, accuracy:  83.057758
test_loss:  1.0550, test_acc:  85.730464
Epoch 3 - train loss:  1.0533, accuracy:  85.654964
test_loss:  1.0352, test_acc:  87.180068
Epoch 4 - train loss:  1.0423, accuracy:  86.621367
test_loss:  1.0261, test_acc:  88.070970
Epoch 5 - train loss:  1.0350, accuracy:  87.240468
test_loss:  1.0218, test_acc:  88.478671
Epoch 6 - train loss:  1.0306, accuracy:  87.580219
test_loss:  1.0184, test_acc:  88.682522
Epoch 7 - train loss:  1.0270, accuracy:  87.889770
test_loss:  1.0208, test_acc:  88.614572
Epoch 8 - train loss:  1.0258, accuracy:  87.950170
test_loss:  1.0208, test_acc:  88.493771
Epoch 9 - train loss:  1.0245, accuracy:  88.101170
test_loss:  1.0184, test_acc:  88.659872
Epoch 10 - train loss:  1.0216, accuracy:  88.289921
test_loss:  1.0123, test_acc:  89.188373
Epoch 11 - train loss:  1.0219, accuracy:  88.244621
test_loss:  1.01

  grad = getattr(obj, "grad", None)
100%|██████████| 414/414 [01:05<00:00,  6.31it/s]
100%|██████████| 414/414 [00:58<00:00,  7.12it/s]
100%|██████████| 414/414 [01:03<00:00,  6.56it/s]
100%|██████████| 414/414 [00:58<00:00,  7.12it/s]
100%|██████████| 414/414 [01:03<00:00,  6.55it/s]
100%|██████████| 414/414 [00:58<00:00,  7.13it/s]
100%|██████████| 414/414 [01:03<00:00,  6.56it/s]
100%|██████████| 414/414 [00:58<00:00,  7.13it/s]
100%|██████████| 414/414 [01:03<00:00,  6.57it/s]
100%|██████████| 414/414 [01:02<00:00,  6.66it/s]
100%|██████████| 414/414 [01:49<00:00,  3.78it/s]
100%|██████████| 414/414 [01:09<00:00,  5.99it/s]
100%|██████████| 414/414 [01:49<00:00,  3.78it/s]
100%|██████████| 414/414 [01:00<00:00,  6.81it/s]
100%|██████████| 414/414 [01:03<00:00,  6.57it/s]
100%|██████████| 414/414 [00:58<00:00,  7.14it/s]
100%|██████████| 414/414 [01:03<00:00,  6.57it/s]
100%|██████████| 414/414 [00:58<00:00,  7.14it/s]
100%|██████████| 414/414 [01:03<00:00,  6.57it/s]
100%|█████████