# Let's put it together

In [None]:
import numpy as np
%load_ext line_profiler

In [None]:
import sys

import torch
import torchvision
import crypten

# The notebook is pinned to Python 3.7 (major, minor); stop early otherwise.
# NOTE(review): exact-version pin, presumably because of the crypten build in
# use — confirm whether newer interpreters are acceptable.
assert sys.version_info[:2] == (3, 7), "python 3.7 is required!"

print(f"Okay, good! You have: {sys.version_info[:3]}")
# Environment looks right, so CrypTen can be initialised now.
crypten.init()

import matplotlib.pyplot
%matplotlib inline

In [None]:
# Load parent folders into path
import os,sys,inspect
import pathlib
import shutil
import crypten.communicator as mpc_comm # the communicator is similar to the MPI communicator for example
from crypten import mpc

# Put the parent of the notebook's folder on sys.path so the project modules
# imported below (config, ZeNet, plot_mnist) can be found.
# inspect.getfile(inspect.currentframe()) is unreliable inside a notebook: the
# cell's "file" is synthetic (e.g. "<ipython-input-...>", or with newer
# ipykernel versions possibly a file in a temporary directory). The kernel's
# working directory is typically the notebook's folder, so use that instead.
currentdir = os.path.abspath(os.getcwd())
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)
# Import some config variables
from config import PETER_ROOT, DATA_DIR, MNIST_SIZE

# Load a pytorch net
from ZeNet.nets import *

# Plotting
from plot_mnist import plot_batch, plot_digit

In [None]:
def check_and_mkdir(directory: pathlib.Path):
    """Create *directory* (including missing parents) if it does not exist.

    Prints "TMP_DIR created" once the directory has actually been created.
    An already existing path is left untouched.

    Args:
        directory: path of the directory to ensure.
    """
    if not directory.exists():
        # parents=True allows nested paths; exist_ok=True tolerates the
        # directory appearing between the exists() check and mkdir().
        directory.mkdir(parents=True, exist_ok=True)
        # Report only after success, so a failing mkdir cannot print a
        # misleading "created" message.
        print("TMP_DIR created")

def rm_dir(directory: pathlib.Path):
    """Delete *directory* together with all files and subdirectories in it.

    Errors from the removal (e.g. a missing directory) propagate unchanged.
    """
    shutil.rmtree(path=directory)
    

def get_filenames(directory:pathlib.Path):
    # Specify file locations to save each piece of data
    filenames = {
        "features": directory / "features.pth",
        "labels": directory / "labels.pth",
        "b_true": directory / "b_true.pth",
        "test_features": directory / "test_features.pth",
        "targets": directory / "targets.pth",
        "w_true": directory / "w_true.pth",
    }

    for u in participants:
        filenames["labels_"+u] = directory / ("labels_" + u)
        filenames["features_"+u] = directory / ("features_" + u)
    return filenames

def setup(participants, tmp_dir_name="./TMP"):
    num_participants = len(participants)
    TMP_DIR = pathlib.Path(tmp_dir_name)
    print(f"Our temporary data will land here: {TMP_DIR}")
    check_and_mkdir(TMP_DIR)
    filenames = get_filenames(TMP_DIR)
    return TMP_DIR, filenames, num_participants 

# One candidate party name per letter of the alphabet (26 names, a..z).
POSSIBLE_PARTICIPANTS = """
    alice bob clara daniel elina franz georg hilda ilya julia karin luke
    martin nadia olaf peter queenie rasmus sarah tal ulyana valerie walter
    xander ymir zorro
""".split()
len(POSSIBLE_PARTICIPANTS)

In [None]:
# Ranks (process indices) of the two parties in the multi-process runs below.
ALICE, BOB = 0, 1
participants = POSSIBLE_PARTICIPANTS[:2]
dir_name = "./TMP_" + "train_on_shared_data"

TMP_DIR, filenames, num_participants = setup(participants, tmp_dir_name=dir_name)
# NOTE(review): this rebinds the DATA_DIR imported from `config`, so the MNIST
# download below goes under TMP_DIR — confirm this is intended.
DATA_DIR = TMP_DIR / "data"
participants

In [None]:
# Show the participants, the temporary directory and the file-name mapping,
# one per line.
print(participants, TMP_DIR, filenames, sep="\n")

In [None]:
# Use only a small subset of MNIST, split into train and test portions.
subset = 1/60
train_ratio = 0.75
test_ratio = 1 - train_ratio
batch_size_train = int((subset * MNIST_SIZE) * train_ratio)
batch_size_test = int((subset * MNIST_SIZE) * test_ratio)

print(f"Using train_test ratios: {train_ratio} : {test_ratio}")
print(f"Train batch size: {batch_size_train}")
print(f"Test batch size: {batch_size_test}")

# Shared preprocessing: convert to tensors, then normalize with the
# mean/std values commonly used for MNIST.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

# Training samples come from the MNIST *training* split (previously the test
# split was used, and with the test batch size).
loader_train = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(DATA_DIR, train=True, download=True,
                               transform=mnist_transform),
    batch_size=batch_size_train, shuffle=True)

# Held-out samples come from the MNIST test split, so they never overlap
# with the training samples.
loader_test = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(DATA_DIR, train=False, download=True,
                               transform=mnist_transform),
    batch_size=batch_size_test, shuffle=True)

# Only the first shuffled batch of each loader is used in this notebook.
train_batches = enumerate(loader_train)
train_idx, digits = next(train_batches)
test_batches = enumerate(loader_test)
# Draw the test batch from the test loader (was mistakenly the train loader).
test_idx, digits_test = next(test_batches)




In [None]:
# Plot one (image, label) sample of the training batch.
img_num = 0
plot_digit(*(part[img_num] for part in (digits[0], digits[1])))

In [None]:
def split_data(data, frac):
    """Split a ``(features, labels)`` pair into two consecutive parts.

    The first ``int(len(labels) * frac)`` samples form the first part and the
    remaining samples form the second. A one-line summary of the split index
    is printed.

    Args:
        data: indexable pair whose element 0 holds features and element 1
            holds labels; both are sliced along their first axis.
        frac: fraction of samples assigned to the first part.

    Returns:
        ``((features_1, labels_1), (features_2, labels_2))``.
    """
    features, labels = data[0], data[1]
    total = len(labels)
    cut = int(total * frac)
    print(f"Returning: 0 <-1-> {cut} <-2->{total}")
    first = (features[:cut], labels[:cut])
    second = (features[cut:], labels[cut:])
    return first, second

In [None]:
frac_alice = 0.6
frac_bob = 1 - frac_alice  # Bob's share is the remainder of the batch

# Alice gets the first `frac_alice` share of the training batch, Bob the rest.
dig_alice, dig_bob = split_data(digits, frac_alice)


# Save the complete (unsplit) features and labels for the Data Labeling example.
crypten.save(digits[0], filenames["features"])
crypten.save(digits[1], filenames["labels"])

@mpc.run_multiprocess(world_size=num_participants)
def save_all_data():
    """Save each party's data shard, each with ``src`` set to its owner's rank.

    Executed once per participant process. Every process issues the same
    sequence of ``crypten.save`` calls: Alice's and Bob's features, their
    labels (Dataset Aggregation example), then Bob's test features/targets.

    NOTE(review): the Model Hiding example's ``w_true``/``b_true`` are not
    saved; those tensors are not defined in the visible notebook.
    """
    rank = mpc_comm.get().get_rank()
    print(f"Hello from {rank}")

    # (tensor, filenames key, owning rank), in save order.
    owned_items = (
        (dig_alice[0], "features_alice", ALICE),
        (dig_bob[0], "features_bob", BOB),
        (dig_alice[1], "labels_alice", ALICE),
        (dig_bob[1], "labels_bob", BOB),
        (digits_test[0], "test_features", BOB),
        (digits_test[1], "targets", BOB),
    )
    for tensor, key, owner in owned_items:
        crypten.save(tensor, filenames[key], src=owner)

    print(f"{rank} is done! Signing off...")

save_all_data()

# Load the data into the respective threads

*Note*: We use the term *thread* loosely: here a thread is synonymous with a process or a participant (which may well be running on a completely different machine).

## About the mpc.run_multiprocess decorator

To execute multi-party computations locally, CrypTen provides the `@mpc.run_multiprocess` function decorator, which makes it possible to run CrypTen code for all parties from a single script. CrypTen follows the standard MPI programming model: it runs a separate process for each party, but each process runs an identical (complete) program. Each process has a rank variable to identify itself.

[Docs](https://crypten.readthedocs.io/en/latest/mpctensor.html#communicator)

### Scenario: Alice Data -> Bob Model

In [None]:
# Scratch cell: a one-element list, displayed as the cell output.
dummz = [1]
dummz

In [None]:
# Helpers for the "Alice data -> Bob model" scenario: load encrypted data,
# encrypt a model, and train it.

def load_enc_data(X_files, y_files):
    """Load encrypted feature and label tensors contributed by several parties.

    Args:
        X_files: iterable of ``(file, rank)`` pairs for the feature tensors.
        y_files: iterable of ``(file, rank)`` pairs for the label tensors.

    Returns:
        Tuple ``(X, y)`` of lists holding one ``crypten.load(file, src=rank)``
        result per input pair, in input order (all features are loaded
        before any labels).
    """
    X = [crypten.load(path, src=owner) for path, owner in X_files]
    y = [crypten.load(path, src=owner) for path, owner in y_files]
    return X, y

def load_model(Net, dummy_input, ENC_RANK):
    """Convert a PyTorch model to CrypTen and encrypt it from one party.

    Args:
        Net: PyTorch model instance to convert.
        dummy_input: example input passed to ``crypten.nn.from_pytorch``.
            The original note suggests shape ``(1, 1, 28, 28)`` — TODO
            confirm for the network in use.
        ENC_RANK: rank of the party whose parameters are encrypted (``src``).

    Returns:
        The encrypted CrypTen model.
    """
    encrypted_model = crypten.nn.from_pytorch(Net, dummy_input)
    encrypted_model.encrypt(src=ENC_RANK)
    return encrypted_model

def train_model(model, X, y, epochs=10, learning_rate=0.05, *, verbose=True):
    """Train a CrypTen model with full-batch gradient descent.

    Args:
        model: CrypTen module, e.g. as returned by ``load_model``.
        X: input batch passed to ``model`` on every epoch.
        y: targets for ``crypten.nn.CrossEntropyLoss``.
            NOTE(review): CrypTen's CrossEntropyLoss expects one-hot encoded
            targets — confirm callers encode labels accordingly.
        epochs: number of gradient steps (each over the whole batch).
        learning_rate: step size passed to ``model.update_parameters``.
        verbose: if True (default), decrypt and print the loss every epoch.
            Pass False to keep the loss encrypted.

    Returns:
        The same ``model`` object, updated in place.
    """
    criterion = crypten.nn.CrossEntropyLoss()
    # The model may have been switched to evaluation mode earlier (e.g. via
    # .eval() before inference); make sure training mode is active.
    model.train()

    for epoch in range(epochs):
        model.zero_grad()
        output = model(X)
        loss = criterion(output, y)
        if verbose:
            # get_plain_text() decrypts the loss, so its value is no longer private.
            print(f"epoch {epoch} loss: {loss.get_plain_text()}")
        loss.backward()
        model.update_parameters(learning_rate)
    return model

def evaluate_model():
    """Placeholder for model evaluation; not implemented yet.

    Currently does nothing and returns ``None``.
    """
    return None


In [None]:
@mpc.run_multiprocess(world_size=num_participants)
def load_data_and_encrypt():
    """Debug helper run by every party: load Alice's labels and print them.

    Despite its name, this only loads ``filenames["labels_alice"]`` with
    ``src=ALICE`` and prints the decrypted values in each process.
    """
    rank = mpc_comm.get().get_rank()
    print(f"pid: {rank}")

    alice_labels = crypten.load(filenames["labels_alice"], src=ALICE)
    print(alice_labels.get_plain_text())

load_data_and_encrypt()

In [None]:
labels = torch.load('/tmp/bob_test_labels.pth').long()
count = 100 # For illustration purposes, we'll use only 100 samples for classification

@mpc.run_multiprocess(world_size=2)
def encrypt_model_and_data():
    """Encrypt Alice's pre-trained model and Bob's data, then classify.

    The file paths suggest this cell was adapted from CrypTen's Tutorial 4.

    NOTE(review): ``dummy_model`` and ``compute_accuracy`` are not defined in
    the visible notebook (perhaps provided by ``from ZeNet.nets import *``),
    and no earlier cell writes ``/tmp/bob_test.pth``,
    ``/tmp/bob_test_labels.pth`` or ``models/tutorial4_alice_model.pth`` —
    confirm they exist before running.
    """
    # Alice (src=ALICE) loads her pre-trained PyTorch model.
    plain_model = crypten.load('models/tutorial4_alice_model.pth', dummy_model=dummy_model, src=ALICE)

    # Convert to CrypTen (traced with a 1 x 784 dummy input), encrypt from Alice.
    private_model = crypten.nn.from_pytorch(plain_model, torch.empty((1, 784)))
    private_model.encrypt(src=ALICE)

    # Bob (src=BOB) supplies the data: keep the first `count` samples and
    # flatten everything after the first dimension.
    bob_inputs = crypten.load('/tmp/bob_test.pth', src=BOB)[:count].flatten(start_dim=1)

    # Encrypted inference, then decrypt the outputs to score them.
    private_model.eval()
    decrypted_output = private_model(bob_inputs).get_plain_text()
    acc = compute_accuracy(decrypted_output, labels[:count])
    print("\tAccuracy: {0:.4f}".format(acc.item()))

encrypt_model_and_data()

## Running on different machines

Running the parties on separate machines is tricky; see the discussion in <https://github.com/facebookresearch/CrypTen/issues/104>.

Scripts: <https://github.com/facebookresearch/CrypTen/tree/master/scripts>