# Split Learning with CIFAR-10

In [1]:
#export PYTHONPATH=/home/hroth/Code2/nvflare/splitnn:/home/hroth/Code2/nvflare/splitnn/examples/cifar10

## 1. Download and split the CIFAR-10 dataset
To simulate a vertically split dataset, we first download the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset and distribute it between the two clients.

In [2]:
%env SPLIT_DIR=/tmp/cifar10_vert_splits
!python3 ../pt/utils/cifar10_split_data_vertical.py --split_dir ${SPLIT_DIR}

env: SPLIT_DIR=/tmp/cifar10_vert_splits
INFO:Cifar10VerticalDataSplitter:[identity=local, run=_]: Partition CIFAR-10 dataset into vertically with 10000 overlapping samples.
Files already downloaded and verified
INFO:Cifar10VerticalDataSplitter:[identity=local, run=_]: save /tmp/cifar10_vert_splits/overlap.npy
INFO:Cifar10VerticalDataSplitter:[identity=local, run=_]: save /tmp/cifar10_vert_splits/site-1.npy
INFO:Cifar10VerticalDataSplitter:[identity=local, run=_]: save /tmp/cifar10_vert_splits/site-2.npy
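
The idea behind the three saved index files can be illustrated with a minimal sketch. This is not the code of `cifar10_split_data_vertical.py`; the helper name `split_indices_vertically`, the fixed seed, and the even division of the non-overlapping samples are assumptions made for illustration. It uses the 50,000 CIFAR-10 training images and the 10,000 overlapping samples reported in the log above.

```python
import numpy as np


def split_indices_vertically(n_samples, n_overlap, seed=0):
    """Hypothetical helper: pick shared sample ids, then give each site extra private ids."""
    rng = np.random.default_rng(seed)
    shuffled = rng.permutation(n_samples)

    # Ids that both sites hold (the ground-truth overlap).
    overlap = shuffled[:n_overlap]

    # Assumption: the remaining ids are divided evenly between the two sites.
    remaining = shuffled[n_overlap:]
    half = len(remaining) // 2

    # Shuffle again so the overlap is not recognizable by its position.
    site_1 = rng.permutation(np.concatenate([overlap, remaining[:half]]))
    site_2 = rng.permutation(np.concatenate([overlap, remaining[half:]]))
    return overlap, site_1, site_2


overlap, site_1, site_2 = split_indices_vertically(n_samples=50_000, n_overlap=10_000)
shared = np.intersect1d(site_1, site_2)
print(len(overlap), len(site_1), len(site_2), len(shared))
```

In this sketch, the only ids the two sites have in common are the ones in `overlap`, which is what the PSI step is expected to recover.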


## 2. Run private set intersection
To find the overlapping data indices between the clients participating in split learning, we run private set intersection (PSI).
The ground-truth overlap is a randomly selected subset of the training indices, created during the split in step 1.
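
The workflow name `DhPSIController` in the log below suggests a Diffie-Hellman style protocol. The following toy sketch illustrates that general idea only. It is not NVFlare's implementation, and the modulus `P`, the helpers `hash_to_group`, `random_secret`, and `mask`, and the example ids are all assumptions chosen for readability, not for security.

```python
import hashlib
import math
import secrets

# Toy modulus for illustration only (a Mersenne prime); NOT a secure parameter choice.
P = 2**127 - 1


def hash_to_group(item):
    """Map an id to an integer modulo P via SHA-256."""
    digest = hashlib.sha256(str(item).encode()).digest()
    return int.from_bytes(digest, "big") % P


def random_secret():
    """Draw a secret exponent coprime to P - 1, so that masking is a bijection."""
    while True:
        k = secrets.randbelow(P - 3) + 2
        if math.gcd(k, P - 1) == 1:
            return k


def mask(values, secret):
    """Raise every value to the secret exponent modulo P."""
    return [pow(v, secret, P) for v in values]


site_1_ids = [3, 8, 15, 42, 99]
site_2_ids = [42, 7, 3, 100, 15]
a, b = random_secret(), random_secret()  # a is known only to site 1, b only to site 2

# Round 1: each site masks its hashed ids with its own secret and sends them over.
site_1_masked = mask([hash_to_group(i) for i in site_1_ids], a)
site_2_masked = mask([hash_to_group(i) for i in site_2_ids], b)

# Round 2: each site applies its own secret to the values received from the other.
site_1_double = mask(site_1_masked, b)       # done by site 2, order preserved
site_2_double = set(mask(site_2_masked, a))  # done by site 1

# Exponentiation commutes, so shared ids end up with identical double-masked values.
intersection = sorted(i for i, m in zip(site_1_ids, site_1_double) if m in site_2_double)
print(intersection)  # [3, 15, 42]
```

Neither site sees the other's raw ids in this exchange; only masked values cross the boundary, and matching double-masked values reveal which ids are shared.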

In [3]:
import os
from nvflare import SimulatorRunner    

simulator = SimulatorRunner(
    job_folder="job_configs/cifar10_psi",
    workspace="/tmp/nvflare/cifar10_psi",
    n_clients=2,
    threads=2
)
run_status = simulator.run()
print("Simulator finished with run_status", run_status)

2023-01-20 16:54:35,329 - SimulatorRunner - INFO - Create the Simulator Server.
2023-01-20 16:54:35,407 - nvflare.fuel.hci.server.hci - INFO - Starting Admin Server localhost on Port 45001
2023-01-20 16:54:35,413 - SimulatorServer - INFO - starting insecure server at localhost:50835
2023-01-20 16:54:35,415 - SimulatorRunner - INFO - Deploy the Apps.
2023-01-20 16:54:35,419 - SimulatorRunner - INFO - Create the simulate clients.
2023-01-20 16:54:35,470 - ClientManager - INFO - Client: New client site-1@127.0.0.1 joined. Sent token: a2f1d779-dc35-4cd0-b64d-7fda136ffd3c.  Total clients: 1
2023-01-20 16:54:35,474 - FederatedClient - INFO - Successfully registered client:site-1 for project simulator_server. Token:a2f1d779-dc35-4cd0-b64d-7fda136ffd3c SSID:
2023-01-20 16:54:35,587 - ClientManager - INFO - Client: New client site-2@127.0.0.1 joined. Sent token: 2426a266-61fa-4f1c-b2da-2fb0cf0cd692.  Total clients: 2
2023-01-20 16:54:35,592 - FederatedClient - INFO - Successfully registered cli

E0120 16:54:38.682205140   17417 fork_posix.cc:76]           Other threads are currently calling into gRPC, skipping fork() handlers


2023-01-20 16:54:40,954 - ServerRunner - INFO - [identity=simulator_server, run=simulate_job, wf=DhPSIController, peer=site-2, peer_run=simulate_job, task_name=PSI, task_id=525bf31f-322a-46fb-ab34-36777fba0f25]: assigned task to client site-2: name=PSI, id=525bf31f-322a-46fb-ab34-36777fba0f25
2023-01-20 16:54:40,961 - ServerRunner - INFO - [identity=simulator_server, run=simulate_job, wf=DhPSIController, peer=site-2, peer_run=simulate_job, task_name=PSI, task_id=525bf31f-322a-46fb-ab34-36777fba0f25]: sent task assignment to client
2023-01-20 16:54:40,964 - SimulatorServer - INFO - GetTask: Return task: PSI to client: site-2 (2426a266-61fa-4f1c-b2da-2fb0cf0cd692) 
2023-01-20 16:54:40,966 - ServerRunner - INFO - [identity=simulator_server, run=simulate_job, wf=DhPSIController, peer=site-1, peer_run=simulate_job, task_name=PSI, task_id=8f4fd87d-5f92-4219-9301-7957a91c478e]: assigned task to client site-1: name=PSI, id=8f4fd87d-5f92-4219-9301-7957a91c478e
2023-01-20 16:54:40,969 - ServerRu

The result is saved in each client's working directory as `intersection.txt`.

### Check the PSI result
We can check the correctness of the result by comparing it to the generated ground-truth overlap, saved in `overlap.npy`.

In [4]:
import os
import numpy as np

split_dir = os.environ["SPLIT_DIR"]
gt_overlap = np.load(os.path.join(split_dir, "overlap.npy"))

psi_overlap = np.loadtxt("/tmp/nvflare/cifar10_psi/simulate_job/site-1/psi/intersection.txt")
                     
print("gt_overlap", gt_overlap, f"n={len(gt_overlap)}")
print("psi_overlap", psi_overlap, f"n={len(psi_overlap)}")

intersect = np.intersect1d(psi_overlap, gt_overlap, assume_unique=True)

print(f"Found {100*len(intersect)/len(gt_overlap):.1f}% of the overlapping sample ids.")

gt_overlap [11841 19602 45519 ... 47278 37020  2217] n=10000
psi_overlap [ 4481. 45431. 46253. ... 34846.   179.  7277.] n=10000
Found 100.0% of the overlapping sample ids.
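
The percentage above measures how many ground-truth ids were found, but it does not by itself rule out extra ids in the PSI result. A stricter check could compare the two id sets for exact equality. The sketch below is one way to do that; the helper name `same_ids` is hypothetical, and the small arrays are synthetic stand-ins for `gt_overlap` and `psi_overlap`. The cast to integers matters because `np.loadtxt` returns floats by default.

```python
import numpy as np


def same_ids(ids_a, ids_b):
    """Return True if both arrays contain exactly the same set of integer ids."""
    a = np.unique(np.asarray(ids_a).astype(np.int64))
    b = np.unique(np.asarray(ids_b).astype(np.int64))
    return bool(np.array_equal(a, b))


# Synthetic stand-ins: integer ground truth vs. float ids in a different order.
gt = np.array([11841, 19602, 45519, 2217])
psi = np.array([2217.0, 45519.0, 11841.0, 19602.0])

print(same_ids(gt, psi))       # True
print(same_ids(gt, psi[:-1]))  # False
```

The same helper could be used to confirm that both clients wrote identical results, assuming the second client's `intersection.txt` is loaded from its own working directory.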


## 3. Run simulated split-learning experiments
We are using NVFlare's [FL simulator](https://nvflare.readthedocs.io/en/latest/user_guide/fl_simulator.html) to run the following experiments. 

To run the experiment, execute:

In [5]:
#nvflare simulator job_configs/cifar10_splitnn --workspace /tmp/nvflare/splitnn_cifar10 --threads 2 --n_clients 2
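
If it is more convenient to launch the command above from Python, its argument list could be assembled as in the following sketch. The helper `build_simulator_command` is hypothetical; it only reuses the job folder, workspace, and flags shown in the command above and does not call any NVFlare API.

```python
import shlex


def build_simulator_command(job_folder, workspace, n_clients, threads):
    """Hypothetical helper: build the `nvflare simulator` call as an argument list."""
    return [
        "nvflare", "simulator", job_folder,
        "--workspace", workspace,
        "--threads", str(threads),
        "--n_clients", str(n_clients),
    ]


cmd = build_simulator_command(
    job_folder="job_configs/cifar10_splitnn",
    workspace="/tmp/nvflare/splitnn_cifar10",
    n_clients=2,
    threads=2,
)
print(shlex.join(cmd))

# To actually launch the run (requires NVFlare to be installed):
# import subprocess
# subprocess.run(cmd, check=True)
```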