# Split Learning with CIFAR-10


## 1. Download and split the CIFAR-10 dataset
To simulate a vertically split dataset, we first download the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset and distribute it between the two clients.

In [2]:
%env SPLIT_DIR=/tmp/cifar10_vert_splits
%env OVERLAP=10000
!python3 ../pt/utils/cifar10_split_data_vertical.py --split_dir ${SPLIT_DIR} --overlap ${OVERLAP}

env: SPLIT_DIR=/tmp/cifar10_vert_splits
env: OVERLAP=10000
INFO:Cifar10VerticalDataSplitter:[identity=local, run=_]: Partition CIFAR-10 dataset into vertically with 10000 overlapping samples.
Files already downloaded and verified
INFO:Cifar10VerticalDataSplitter:[identity=local, run=_]: save /tmp/cifar10_vert_splits/overlap.npy
INFO:Cifar10VerticalDataSplitter:[identity=local, run=_]: save /tmp/cifar10_vert_splits/site-1.npy
INFO:Cifar10VerticalDataSplitter:[identity=local, run=_]: save /tmp/cifar10_vert_splits/site-2.npy
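The following sketch illustrates one way the sample IDs could be split. It is not the code of `cifar10_split_data_vertical.py`. It assumes that every site receives the same `overlap` shared sample IDs plus a disjoint share of the remaining CIFAR-10 training IDs. The helper name `split_sample_ids` is hypothetical, and the sketch does not model which part of each sample (for example, features or labels) a site holds.

```python
import numpy as np

def split_sample_ids(n_samples, overlap, n_sites=2, seed=0):
    """Hypothetical sketch of a vertical ID split: every site gets the same
    `overlap` shared sample ids plus a disjoint share of the remaining ids."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n_samples)
    shared = perm[:overlap]
    private_parts = np.array_split(perm[overlap:], n_sites)
    # Shuffle each site's ids so the shared ones are not simply listed first.
    site_ids = [rng.permutation(np.concatenate([shared, part])) for part in private_parts]
    return shared, site_ids

# CIFAR-10 has 50,000 training images; OVERLAP=10000 as in the cell above.
shared, (ids_1, ids_2) = split_sample_ids(n_samples=50000, overlap=10000)
print(len(shared), len(ids_1), len(ids_2))  # 10000 30000 30000
print(len(np.intersect1d(ids_1, ids_2)))    # 10000
```

Only the shared IDs appear at both sites, so recovering them is exactly the job of the PSI step that follows.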


## 2. Run private set intersection
We are using NVFlare's FL simulator to run the following experiments.

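In the previous step, a random subset of the training indices (`OVERLAP=10000`) was selected to be shared by both clients.
To find these overlapping data indices between the clients participating in split learning, without revealing the non-overlapping ones, we run private set intersection (PSI).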
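The idea behind PSI can be illustrated with a toy Diffie-Hellman-style protocol. Each party blinds the hashes of its IDs with a secret exponent. Because exponentiation commutes, an ID held by both parties ends up with the same double-blinded value. This is a sketch for intuition only. The modulus, the hashing scheme, and the helper names (`hash_to_group`, `random_key`, `toy_psi`) are illustrative assumptions. The sketch is not secure, it simulates both parties in one process, and it is not the PSI implementation used by the `cifar10_psi` job.

```python
import hashlib
import math
import secrets

# Toy modulus: the Mersenne prime 2**127 - 1. Chosen for readability only;
# this is NOT a secure group choice.
P = 2**127 - 1

def hash_to_group(x):
    """Map an id to an integer mod P via SHA-256."""
    digest = hashlib.sha256(str(x).encode()).digest()
    return int.from_bytes(digest, "big") % P

def random_key():
    """Secret exponent coprime to P - 1, so v -> v**k mod P is a bijection."""
    while True:
        k = secrets.randbelow(P - 3) + 2
        if math.gcd(k, P - 1) == 1:
            return k

def blind(values, key):
    return [pow(v, key, P) for v in values]

def toy_psi(ids_a, ids_b):
    a, b = random_key(), random_key()  # each party's secret key
    # Party A sends H(x)^a; party B raises it to b, giving H(x)^(ab).
    a_double = blind(blind([hash_to_group(x) for x in ids_a], a), b)
    # Party B sends H(y)^b; party A raises it to a, giving H(y)^(ba).
    b_double = set(blind(blind([hash_to_group(y) for y in ids_b], b), a))
    # Equal ids yield equal double-blinded values; other ids stay hidden.
    return [x for x, v in zip(ids_a, a_double) if v in b_double]

print(toy_psi([1, 5, 7, 9, 12], [2, 5, 9, 12, 40]))  # [5, 9, 12]
```

Production PSI libraries typically use elliptic-curve groups and compact set encodings for efficiency, but the commutative-blinding principle is the same.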

In [3]:
# Run the PSI job in the FL simulator with two clients.
from nvflare import SimulatorRunner

simulator = SimulatorRunner(
    job_folder="job_configs/cifar10_psi",
    workspace="/tmp/nvflare/cifar10_psi",
    n_clients=2,
    threads=2
)
run_status = simulator.run()
print("Simulator finished with run_status", run_status)

2023-02-01 10:54:39,623 - SimulatorRunner - INFO - Create the Simulator Server.
2023-02-01 10:54:39,644 - Cell - INFO - server: creating listener on grpc://localhost:36511
2023-02-01 10:54:39,646 - Cell - INFO - server: created backbone external listener for grpc://localhost:36511
2023-02-01 10:54:39,647 - ConnectorManager - INFO - 337646: Try start_listener Listener resources: {'secure': False, 'host': 'localhost', 'ports': ['30000-40000']}
2023-02-01 10:54:39,649 - nvflare.fuel.f3.sfm.conn_manager - INFO - Connector TcpDriver:075465be-e3f5-45a3-879e-f532cd804faa is starting in PASSIVE mode
2023-02-01 10:54:40,152 - Cell - INFO - server: created backbone internal listener for tcp://localhost:33609
2023-02-01 10:54:40,154 - nvflare.fuel.f3.sfm.conn_manager - INFO - Connector AioGrpcDriver:5dc51723-6de2-439d-9acd-68e89fb3f219 is starting in PASSIVE mode
2023-02-01 10:54:40,156 - nvflare.fuel.f3.communicator - INFO - Communicator is started for local endpoint: server
2023-02-01 10:54:40,

E0201 10:54:44.733119470  337817 fork_posix.cc:76]           Other threads are currently calling into gRPC, skipping fork() handlers
E0201 10:54:44.743499568  337818 fork_posix.cc:76]           Other threads are currently calling into gRPC, skipping fork() handlers


2023-02-01 10:54:45,824 - Cell - INFO - site-1.simulate_job: created backbone internal connector to tcp://localhost:35122 on parent
2023-02-01 10:54:45,825 - ConnectorManager - INFO - 337827: Try start_listener Listener resources: {'secure': False, 'host': 'localhost', 'ports': ['30000-40000']}
2023-02-01 10:54:45,825 - nvflare.fuel.f3.sfm.conn_manager - INFO - Connector TcpDriver:5fb0b737-04e8-4d76-961c-75fe4befb5e2 is starting in PASSIVE mode
2023-02-01 10:54:45,827 - Cell - INFO - site-2.simulate_job: created backbone internal connector to tcp://localhost:38478 on parent
2023-02-01 10:54:45,827 - ConnectorManager - INFO - 337828: Try start_listener Listener resources: {'secure': False, 'host': 'localhost', 'ports': ['30000-40000']}
2023-02-01 10:54:45,827 - nvflare.fuel.f3.sfm.conn_manager - INFO - Connector TcpDriver:bf6cc37b-b841-4769-a1ca-dfb5204e0d39 is starting in PASSIVE mode
2023-02-01 10:54:46,326 - Cell - INFO - site-1.simulate_job: created backbone internal listener for tc

The result will be saved in each client's working directory as `intersection.txt`.

### Check the PSI result
We can check the correctness of the result by comparing it to the generated ground-truth overlap, saved in `overlap.npy`.

In [4]:
import os
import numpy as np

split_dir = os.environ["SPLIT_DIR"]
gt_overlap = np.load(os.path.join(split_dir, "overlap.npy"))

psi_overlap_1 = np.loadtxt("/tmp/nvflare/cifar10_psi/simulate_job/site-1/psi/intersection.txt")
psi_overlap_2 = np.loadtxt("/tmp/nvflare/cifar10_psi/simulate_job/site-2/psi/intersection.txt")
                     
print("gt_overlap", gt_overlap, f"n={len(gt_overlap)}")
print("psi_overlap_1", psi_overlap_1, f"n={len(psi_overlap_1)}")
print("psi_overlap_2", psi_overlap_2, f"n={len(psi_overlap_2)}")

intersect_1 = np.intersect1d(psi_overlap_1, gt_overlap, assume_unique=True)
intersect_2 = np.intersect1d(psi_overlap_2, gt_overlap, assume_unique=True)

print(f"Found {100*len(intersect_1)/len(gt_overlap):.1f}% of the overlapping sample ids for site-1.")
print(f"Found {100*len(intersect_2)/len(gt_overlap):.1f}% of the overlapping sample ids for site-2.")

gt_overlap [11841 19602 45519 ... 47278 37020  2217] n=10000
psi_overlap_1 [ 4481. 45431. 46253. ... 34846.   179.  7277.] n=10000
psi_overlap_2 [38639. 10733. 31911. ... 12172. 46167.   865.] n=10000
Found 100.0% of the overlapping sample ids for site-1.
Found 100.0% of the overlapping sample ids for site-2.
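The percentages above measure recall, meaning how many ground-truth IDs PSI recovered. A stricter check also reports precision, which confirms that PSI returned no extra IDs. Here, both sets have `n=10000`, so a recall of 100% already implies a precision of 100%. The helper below is a hypothetical sketch. It casts to integers because `np.loadtxt` returns floats.

```python
import numpy as np

def overlap_metrics(found, ground_truth):
    """Return (recall, precision) of a recovered id set w.r.t. the ground truth."""
    found = np.unique(np.asarray(found).astype(np.int64))
    ground_truth = np.unique(np.asarray(ground_truth).astype(np.int64))
    hits = np.intersect1d(found, ground_truth, assume_unique=True).size
    recall = hits / ground_truth.size
    precision = hits / found.size if found.size else 0.0
    return recall, precision

print(overlap_metrics([3.0, 7.0, 11.0], [3, 7, 8, 11]))  # (0.75, 1.0)
```

With the variables from the cell above, `overlap_metrics(psi_overlap_1, gt_overlap)` would report both numbers for site-1.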


## 3. Run simulated split-learning experiments
Next we use the `intersection.txt` files to align the datasets on each participating site in order to do split learning.
The [config_fed_client.json](./job_configs/cifar10_splitnn/site-1/config/config_fed_client.json) takes as input the previously generated intersection file for each site.
```
    {
        "id": "cifar10-learner",
        "path": "pt.learners.cifar10_learner_splitnn.CIFAR10LearnerSplitNN",
        "args": {
            "dataset_root": "{DATASET_ROOT}",
            "intersection_file": "{INTERSECTION_FILE}",
            "lr": 1e-2,
            "model": {"path": "pt.networks.split_nn.SplitNN", "args":  {"split_id":  0}},
            "timeit": true
        }
    }
```
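To use the intersection file, each site must map the shared IDs to rows of its own local data, and both sites must order those rows identically. The sketch below assumes that each site knows the global sample ID of every local row. It sorts the intersection so that both sites agree on the order. The function `align_to_intersection` is hypothetical and is not part of `CIFAR10LearnerSplitNN`.

```python
import numpy as np

def align_to_intersection(local_ids, intersection):
    """Hypothetical helper: return the sorted common ids and, for each of them,
    the row index of that sample in this site's local data."""
    local_ids = np.asarray(local_ids, dtype=np.int64)
    common = np.sort(np.asarray(intersection).astype(np.int64))  # loadtxt yields floats
    order = np.argsort(local_ids)
    pos = np.searchsorted(local_ids, common, sorter=order)
    pos = np.clip(pos, 0, len(local_ids) - 1)  # keep indices valid; checked below
    rows = order[pos]
    if not np.array_equal(local_ids[rows], common):
        raise ValueError("intersection contains ids this site does not hold")
    return common, rows

# Two sites hold the same two samples at different local positions.
common, rows_1 = align_to_intersection([7, 3, 9, 1], [9.0, 3.0])
_, rows_2 = align_to_intersection([3, 5, 9], [9.0, 3.0])
print(common, rows_1, rows_2)  # [3 9] [1 2] [0 2]
```

After alignment, position `k` refers to the same sample at every site, so a batch can be described by positions alone.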
On the server side, the [config_fed_server.json](./job_configs/cifar10_splitnn/server/config/config_fed_server.json) needs to specify the size of the training set so that the controller can generate the random sample IDs that make up each training batch. Here, the training set size (`train_size`) is equal to the number of overlapping samples defined above.
```
    {
        "id": "splitnn_ctl",
        "path": "pt.workflows.splitnn_workflow.SplitNNController",
        "args": {
            "num_rounds" : "{num_rounds}",
            "batch_size": "{batch_size}",
            "train_size": "{train_size}",
            "start_round": 0,
            "persistor_id": "persistor",
            "task_timeout": 0,
            "shareable_generator_id": "shareable_generator",
            "timeit": true
        }
    }
```
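The following sketch shows why `train_size` and `batch_size` are enough for the server to drive training. It assumes that both sites have ordered their overlapping samples identically. Under that assumption, the server only needs to send positions in `[0, train_size)`, and each site looks up its own rows. The function name, the batch size of 64, and the choice of sampling without replacement within a batch are illustrative assumptions, not the behavior of `SplitNNController`.

```python
import numpy as np

def sample_batch_indices(train_size, batch_size, rng):
    """Hypothetical sketch: draw `batch_size` distinct positions in [0, train_size)."""
    return rng.choice(train_size, size=batch_size, replace=False)

rng = np.random.default_rng(0)
train_size, batch_size = 10000, 64  # train_size = number of overlapping samples
steps_per_epoch = train_size // batch_size
batch = sample_batch_indices(train_size, batch_size, rng)
print(steps_per_epoch, batch.shape)  # 156 (64,)
```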
To run the experiment, execute:

In [5]:
from nvflare import SimulatorRunner

simulator = SimulatorRunner(
    job_folder="job_configs/cifar10_splitnn",
    workspace="/tmp/nvflare/cifar10_splitnn",
    n_clients=2,
    threads=2
)
run_status = simulator.run()
print("Simulator finished with run_status", run_status)

2023-02-01 10:55:55,569 - SimulatorRunner - INFO - Create the Simulator Server.
2023-02-01 10:55:55,589 - Cell - INFO - server: creating listener on grpc://localhost:38613
2023-02-01 10:55:55,590 - Cell - INFO - server: created backbone external listener for grpc://localhost:38613
2023-02-01 10:55:55,592 - ConnectorManager - INFO - 338942: Try start_listener Listener resources: {'secure': False, 'host': 'localhost', 'ports': ['30000-40000']}
2023-02-01 10:55:55,595 - nvflare.fuel.f3.sfm.conn_manager - INFO - Connector TcpDriver:f496ddad-8c3a-41bb-9c23-8fae7ba54db0 is starting in PASSIVE mode
2023-02-01 10:55:56,098 - Cell - INFO - server: created backbone internal listener for tcp://localhost:38730
2023-02-01 10:55:56,100 - nvflare.fuel.f3.sfm.conn_manager - INFO - Connector AioGrpcDriver:18ff358c-a539-49b3-8d8d-a75f03de65b8 is starting in PASSIVE mode
2023-02-01 10:55:56,102 - nvflare.fuel.f3.communicator - INFO - Communicator is started for local endpoint: server
2023-02-01 10:55:56,

E0201 10:56:01.775715151  339122 fork_posix.cc:76]           Other threads are currently calling into gRPC, skipping fork() handlers
E0201 10:56:01.789724459  339123 fork_posix.cc:76]           Other threads are currently calling into gRPC, skipping fork() handlers


2023-02-01 10:56:02,859 - Cell - INFO - site-1.simulate_job: created backbone internal connector to tcp://localhost:33082 on parent
2023-02-01 10:56:02,859 - ConnectorManager - INFO - 339128: Try start_listener Listener resources: {'secure': False, 'host': 'localhost', 'ports': ['30000-40000']}
2023-02-01 10:56:02,859 - nvflare.fuel.f3.sfm.conn_manager - INFO - Connector TcpDriver:ee127528-f1b3-449f-b591-ae436fcdd5a7 is starting in PASSIVE mode
2023-02-01 10:56:02,875 - Cell - INFO - site-2.simulate_job: created backbone internal connector to tcp://localhost:36295 on parent
2023-02-01 10:56:02,875 - ConnectorManager - INFO - 339129: Try start_listener Listener resources: {'secure': False, 'host': 'localhost', 'ports': ['30000-40000']}
2023-02-01 10:56:02,875 - nvflare.fuel.f3.sfm.conn_manager - INFO - Connector TcpDriver:b730d40f-1906-471a-a375-8ebcf30ab97d is starting in PASSIVE mode
2023-02-01 10:56:03,360 - Cell - INFO - site-1.simulate_job: created backbone internal listener for tc

The site containing the labels can compute accuracy and loss, which can be visualized in TensorBoard.
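To illustrate why only the label-holding site can compute these metrics, here is a minimal NumPy sketch of a split-learning training step. Site-1 holds the features and a bottom model. Site-2 holds the labels and a top model. Only activations travel forward, and only their gradients travel back. The one-layer models, the synthetic data, and the hyperparameters are assumptions for illustration. This is not the `pt.networks.split_nn.SplitNN` network.

```python
import numpy as np

def split_step(x, y, w_bottom, w_top, lr):
    """One split-learning SGD step; updates w_bottom and w_top in place."""
    # Site-1 (features only): bottom-model forward pass; `act` is sent to site-2.
    pre = x @ w_bottom
    act = np.maximum(pre, 0.0)

    # Site-2 (labels only): top-model forward pass, softmax cross-entropy, accuracy.
    logits = act @ w_top
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    probs = np.exp(logits)
    probs /= probs.sum(axis=1, keepdims=True)
    rows = np.arange(len(y))
    loss = -np.log(probs[rows, y]).mean()
    accuracy = (probs.argmax(axis=1) == y).mean()

    # Site-2: backward pass; only d(loss)/d(act) is sent back to site-1.
    d_logits = probs.copy()
    d_logits[rows, y] -= 1.0
    d_logits /= len(y)
    grad_top = act.T @ d_logits
    d_act = d_logits @ w_top.T

    # Site-1: finish backpropagation through the bottom model.
    grad_bottom = x.T @ (d_act * (pre > 0))

    w_top -= lr * grad_top
    w_bottom -= lr * grad_bottom
    return loss, accuracy

rng = np.random.default_rng(0)
x = rng.normal(size=(64, 20))                        # site-1's features
y = np.argmax(x @ rng.normal(size=(20, 3)), axis=1)  # site-2's labels
w_bottom = rng.normal(scale=0.3, size=(20, 16))      # site-1's model
w_top = rng.normal(scale=0.1, size=(16, 3))          # site-2's model

history = [split_step(x, y, w_bottom, w_top, lr=0.05) for _ in range(300)]
print(f"loss {history[0][0]:.3f} -> {history[-1][0]:.3f}")
```

Site-1 never sees `y`, and site-2 never sees `x`. Only the activations and their gradients cross the boundary.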

In [8]:
# Load the TensorBoard notebook extension
%load_ext tensorboard

%tensorboard --logdir /tmp/nvflare/cifar10_splitnn