In [1]:
from pathlib import Path
import h5py
import numpy as np
import matplotlib.pyplot as plt
from icecream import ic
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam, SGD

import syft as sy
from syft.core.node.vm.vm import VirtualMachine
from syft.core.node.vm.client import VirtualMachineClient
from syft.ast.module import Module
from syft.core.remote_dataloader import RemoteDataLoader
from syft.core.remote_dataloader import RemoteDataset

print(f'torch version: {torch.__version__}')
print(f'syft version: {sy.__version__}')


torch version: 1.8.1+cu102
syft version: 0.5.0


## Files and Directories

In [2]:
# paths to files and directories
project_path = Path.cwd().parent
print(f'project_path: {project_path}')
data_dir = 'mitdb'
train_name = 'train_ecg.hdf5'
test_name = 'test_ecg.hdf5'
all_name = 'all_ecg.hdf5'
model_dir = 'model'
model_name = 'conv2'
model_ext = '.pth'
csv_dir = 'csv'
csv_ext = '.csv'
csv_name = 'conv2'
csv_accs_name = 'accs_conv2'

project_path: /mnt/batch/tasks/shared/LS_root/mounts/clusters/teslak80-56gbram/code/Users/dkn.work/split-learning-he


## Construct the client and server

In [None]:
server: VirtualMachine = sy.VirtualMachine(name="server")
client: VirtualMachineClient = server.get_root_client()

In [None]:
remote_torch: Module = client.torch
remote_torch

## Client: loading and exploring the dataset

In [None]:
class ECG(Dataset):
    # The class used to load the ECG dataset
    def __init__(self, mode='train'):
        if mode == 'train':
            with h5py.File(project_path/data_dir/train_name, 'r') as hdf:
                self.x = torch.tensor(hdf['x_train'][:], dtype=torch.float)
                self.y = torch.tensor(hdf['y_train'][:])                
        elif mode == 'test':
            with h5py.File(project_path/data_dir/test_name, 'r') as hdf:
                self.x = torch.tensor(hdf['x_test'][:], dtype=torch.float)
                self.y = torch.tensor(hdf['y_test'][:])
        else:
            raise ValueError('Argument of mode should be train or test')
    
    def __len__(self):
        return len(self.x)
    
    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

In [None]:
train_dataset = ECG(mode='train')
test_dataset = ECG(mode='test')

Let's first get everything in the dataset and see how many examples we have, and how each
of them look like

In [None]:
train_loader = DataLoader(train_dataset, batch_size=len(train_dataset))
test_loader = DataLoader(test_dataset, batch_size=len(test_dataset))
x_train, y_train = next(iter(train_loader))
x_test, y_test = next(iter(test_loader))
print(f'x_train: {type(x_train)}, {x_train.size()}')
print(f'y_train: {type(y_train)}, {y_train.size()}')
print(f'x_test: {type(x_test)}, {x_test.size()}')
print(f'y_test: {type(y_test)}, {y_test.size()}')

In [None]:
x0 = x_train[0, :, :]
print(f'x_0: {x0.shape}')
x0_unroll = x0.view(-1)
print(f'unrolling: {x0_unroll.shape}')
indx = np.arange(0, 128)

# plt.figure(figsize=(3,3))
plt.style.use('dark_background')
plt.plot(indx, x0_unroll)
plt.show()

The client creates the Dataset object and save it in a `.pt` file. 
If using `duet`, he can send the string path to the server using 
`sy.lib.python.String(string_path).send(duet, pointable=True, tags=["data"])`

In [None]:
torch.save(train_dataset, "train_dataset.pt")
torch.save(test_dataset, "test_dataset.pt")

## Server: creating remote dataset and dataloader

In [None]:
train_rds = RemoteDataset(path='train_dataset.pt', data_type="torch_tensor")
train_rds

From the remote dataset, the server constructs the data loader. Then the server uses `.send`
to create a pointer to do remote data loading

In [None]:
train_rdl = RemoteDataLoader(remote_dataset=train_rds, batch_size=32)
train_rdl_ptr = train_rdl.send(client)
ic(train_rdl, train_rdl_ptr)
# call create_dataset to create the real Dataset object on remote side
train_rdl_ptr.load_dataset()
# call create_dataloader to create the real DataLoader object on remote side
train_rdl_ptr.create_dataloader()

In [None]:
for i, b in enumerate(tqdm(train_rdl_ptr)):
    if i<2:
        X, y = b[0], b[1]
        ic(X, y)
        ic(X.get_copy(), y.get_copy())

Similarly, for the test dataset

In [None]:
test_rds = RemoteDataset(path='test_dataset.pt', data_type="torch_tensor")
test_rdl = RemoteDataLoader(remote_dataset=test_rds, batch_size=32)
test_rdl_ptr = test_rdl.send(client)
ic(test_rds, test_rdl, test_rdl_ptr)
# call create_dataset to create the real Dataset object on remote side
test_rdl_ptr.load_dataset()
# call create_dataloader to create the real DataLoader object on remote side
test_rdl_ptr.create_dataloader()

In [None]:
for i, b in enumerate(test_rdl_ptr):
    if i<2:
        X, y = b[0], b[1]
        ic(X, y)
        ic(X.get_copy(), y.get_copy())

## Server: define the spit neural network used to train on the ECG dataset

Client's side contains conv layers

In [None]:
class EcgClient(sy.Module):
    # used by the data owners
    def __init__(self, torch_ref):
        super(EcgClient, self).__init__(torch_ref=torch_ref)
        self.conv1 = self.torch_ref.nn.Conv1d(1, 16, 7, padding=3)  # 128 x 16
        self.relu1 = self.torch_ref.nn.LeakyReLU()
        self.pool1 = self.torch_ref.nn.MaxPool1d(2)  # 64 x 16
        self.conv2 = self.torch_ref.nn.Conv1d(16, 16, 5, padding=2)  # 64 x 16
        self.relu2 = self.torch_ref.nn.LeakyReLU()
        self.pool2 = self.torch_ref.nn.MaxPool1d(2)  # 32 x 16
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool2(x)
        x = x.view(-1, 32 * 16)
        return x

Server's side contains fully connected layers

In [None]:
class EcgServer(sy.Module):
    def __init__(self, torch_ref):
        super(EcgServer, self).__init__(torch_ref=torch_ref)
        self.linear3 = nn.Linear(32 * 16, 128)
        self.relu3 = nn.LeakyReLU() 
        self.linear4 = nn.Linear(128, 5)
        self.softmax4 = nn.Softmax(dim=1)
        
    def forward(self, x):
        x = self.linear3(x)
        x = self.relu3(x)
        x = self.linear4(x)
        x = self.softmax4(x)
        return x

## Server: training process

In [17]:
ecg_client = EcgClient(torch_ref=torch)
checkpoint = torch.load("init_weight.pth")
ecg_client.conv1.weight.data = checkpoint["conv1.weight"]
ecg_client.conv1.bias.data = checkpoint["conv1.bias"]
ecg_client.conv2.weight.data = checkpoint["conv2.weight"]
ecg_client.conv2.bias.data = checkpoint["conv2.bias"]

ecg_server = EcgServer(torch_ref=torch)
checkpoint = torch.load("init_weight.pth")
ecg_server.linear3.weight.data = checkpoint["linear3.weight"]
ecg_server.linear3.bias.data = checkpoint["linear3.bias"]
ecg_server.linear4.weight.data = checkpoint["linear4.weight"]
ecg_server.linear4.bias.data = checkpoint["linear4.bias"]

# Send the client's model to the client
ecg_client_ptr = ecg_client.send(client)

Some hyper-parameters

In [24]:
total_batch = 414  # 32*414=13248. We have 13245 data samples

epoch = 400
lr = 0.001

seed = 0
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [None]:
optim_client = remote_torch.optim.Adam(params=ecg_client_ptr.parameters(), lr=lr)
optim_server = torch.optim.Adam(params=ecg_server.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

Training (with CPU)

In [25]:
train_losses = list()
train_accs = list()
test_losses = list()
test_accs = list()
best_test_acc = 0  # best test accuracy
for e in range(epoch):
    print(f"Epoch {e+1} - train ", end='')
    
    train_loss = 0.0
    correct, total = 0, 0
    for i, batch in enumerate(tqdm(train_rdl_ptr)):
        x_ptr, y_gt_ptr = batch[0], batch[1]
        # ic(x.get_copy(), y.get_copy())
        # initialize all gradients to zero
        optim_server.zero_grad()
        optim_client.zero_grad()
        # compute and get the activation signals from the first half of the network
        activs_ptr = ecg_client_ptr(x_ptr)
        # the server still gets access to plain activation signals
        activs = activs_ptr.clone().get(request_block=True)
        # the server continues the forward pass on the activation maps
        y_hat = ecg_server(activs)
        # the server asks to access ground truths in plain text
        y_gt = y_gt_ptr.get_copy()
        # calculates cross-entropy loss
        loss = criterion(y_hat, y_gt)
        train_loss += loss.item()
        correct += torch.sum(y_hat.argmax(dim=1) == y_gt).item()
        # backward propagation (calculating gradients of the loss w.r.t the weights)
        loss.backward()
        # send the gradients to the client
        client_grad_ptr = activs.grad.clone().send(client)
        # update the gradients of the client's model
        activs_ptr.backward(client_grad_ptr)
        # update the weights based on the gradients
        optim_client.step()
        optim_server.step()
        total += len(y_gt)

    train_losses.append(train_loss / total_batch)
    train_accs.append(correct / total)

    print(f'loss: {train_losses[-1]: .4f}, accuracy: {train_accs[-1]*100: 2f}')

    # testing
    with torch.no_grad():  
        test_loss = 0.0
        correct, total = 0, 0
        for i, batch in enumerate(tqdm(test_rdl_ptr)):
            x_ptr, y_gt_ptr = batch[0], batch[1]
            # forward pass
            activs_ptr = ecg_client_ptr(x_ptr)
            activs = activs_ptr.clone().get(request_block=True)
            y_hat = ecg_server(activs)
            # the server asks to access ground truths in plain text
            y_gt = y_gt_ptr.get_copy()
            # calculate test loss
            loss = criterion(y_hat, y_gt)
            test_loss += loss.item()
            correct += torch.sum(y_hat.argmax(dim=1) == y_gt).item()
            total += len(y_gt)

        test_losses.append(test_loss / total_batch)
        test_accs.append(correct / total)
        print(f'test_loss: {test_losses[-1]: .4f}, test_acc: {test_accs[-1]*100: 2f}')
        
    if test_accs[-1] > best_test_acc:
        best_test_acc = test_accs[-1]

Epoch 1 - train loss:  1.0215, accuracy:  88.312571
test_loss:  1.0128, test_acc:  89.218573
Epoch 2 - train loss:  1.0201, accuracy:  88.440921
test_loss:  1.0171, test_acc:  88.659872
Epoch 3 - train loss:  1.0183, accuracy:  88.614572
test_loss:  1.0114, test_acc:  89.316723
Epoch 4 - train loss:  1.0171, accuracy:  88.720272
test_loss:  1.0086, test_acc:  89.467724
Epoch 5 - train loss:  1.0164, accuracy:  88.735372
test_loss:  1.0079, test_acc:  89.596074
Epoch 6 - train loss:  1.0137, accuracy:  88.984522
test_loss:  1.0134, test_acc:  89.143073
Epoch 7 - train loss:  1.0136, accuracy:  89.014723


  grad = getattr(obj, "grad", None)
100%|██████████| 414/414 [01:22<00:00,  5.05it/s]
100%|██████████| 414/414 [00:58<00:00,  7.09it/s]
100%|██████████| 414/414 [01:03<00:00,  6.52it/s]
100%|██████████| 414/414 [00:58<00:00,  7.12it/s]
100%|██████████| 414/414 [01:03<00:00,  6.54it/s]
100%|██████████| 414/414 [00:58<00:00,  7.13it/s]
100%|██████████| 414/414 [01:03<00:00,  6.54it/s]
100%|██████████| 414/414 [00:57<00:00,  7.14it/s]
100%|██████████| 414/414 [01:03<00:00,  6.56it/s]
100%|██████████| 414/414 [00:58<00:00,  7.12it/s]
100%|██████████| 414/414 [01:11<00:00,  5.83it/s]
100%|██████████| 414/414 [00:58<00:00,  7.08it/s]
100%|██████████| 414/414 [01:03<00:00,  6.56it/s]
 13%|█▎        | 54/414 [00:07<00:51,  7.03it/s]


KeyboardInterrupt: 