# Alice and Bob

In [18]:
import os
import sys

import matplotlib.pyplot as plt
import torch
import torchvision

# This notebook insists on exactly Python 3.7 and stops here on any other version.
assert sys.version_info[:2] == (3, 7), "python 3.7 is required"

import crypten
from crypten import mpc

# Initialize CrypTen, then limit torch to a single thread.
crypten.init()
torch.set_num_threads(1)


%matplotlib inline

# Save Encrypted Data From Each Party

In [19]:
# Sample tensors for the two parties: three float values each.
alice_data, bob_data = (
    torch.tensor([1.0, 2.0, 3.0]),
    torch.tensor([4.0, 5.0, 6.0]),
)

In [20]:
# Party ids used as `src` below: Alice is party 0, Bob is party 1.
ALICE, BOB = 0, 1

In [21]:
# Create the output directory up front: the saves below fail if /tmp/data does
# not exist, and nothing earlier in this notebook creates it.
os.makedirs("/tmp/data", exist_ok=True)


@mpc.run_multiprocess(world_size=2)
def save_all_data():
    """Save Alice's and Bob's tensors to separate files under /tmp/data.

    Runs as two processes (world_size=2). Each save is passed the id of the
    party the tensor belongs to via `src`.
    """
    crypten.save(alice_data, "/tmp/data/alice_data.pth", src=ALICE)
    crypten.save(bob_data, "/tmp/data/bob_data.pth", src=BOB)


save_all_data()

[None, None]

In [22]:
# List /tmp/data to confirm that alice_data.pth and bob_data.pth were written.
! ls -lh /tmp/data

total 16
drwxr-xr-x  4 marksibrahim  wheel   128B Apr 17 18:58 MNIST
-rw-r--r--  1 marksibrahim  wheel   351B Apr 17 19:55 alice_data.pth
-rw-r--r--  1 marksibrahim  wheel   351B Apr 17 19:55 bob_data.pth


# Load Encrypted Data From Each Party

In [24]:
@mpc.run_multiprocess(world_size=2)
def load_data():
    """Load both saved tensors and show what Alice's looks like.

    Runs as two processes (world_size=2); each one prints the type of the
    loaded object and then its plain-text values.
    """
    alice_path = "/tmp/data/alice_data.pth"
    bob_path = "/tmp/data/bob_data.pth"

    alice_enc = crypten.load(alice_path, src=ALICE)
    # Bob's tensor is loaded as well, although only Alice's is printed below.
    bob_enc = crypten.load(bob_path, src=BOB)

    print(type(alice_enc))
    plain = alice_enc.get_plain_text()
    print(f"alice data is {plain}")


load_data()
 

<class 'crypten.mpc.mpc.MPCTensor'>
<class 'crypten.mpc.mpc.MPCTensor'>
alice data is tensor([1., 2., 3.])
alice data is tensor([1., 2., 3.])


[None, None]

# Digits

## Alice has Digits. Bob has Labels.

In [25]:
# MNIST training split rooted at /tmp/data, downloaded when requested via
# download=True, with ToTensor applied as the image transform.
digits = torchvision.datasets.MNIST(
    root='/tmp/data',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)


In [26]:
def take_samples(digits, n_samples=1000, n_classes=10):
    """Collect the first `n_samples` items of `digits` as two tensors.

    Args:
        digits: iterable of ``(image, label)`` pairs; each image is a tensor
            and each label is an int (or scalar tensor) class index.
        n_samples: maximum number of items taken from the front of `digits`.
        n_classes: width of the one-hot label encoding. Defaults to 10, the
            value this function previously hard-coded.

    Returns:
        ``(images, labels)``: `images` is every collected image concatenated
        along dim 0; `labels` is an ``(n, n_classes)`` one-hot integer tensor.
    """
    images, class_ids = [], []

    for i, (image, label) in enumerate(digits):
        if i == n_samples:
            break
        images.append(image)
        # int() accepts plain ints and scalar tensors alike, so labels need
        # not be wrapped in torch.tensor() one at a time.
        class_ids.append(int(label))

    # torch.cat rejects an empty list, so at least one sample must be collected.
    images = torch.cat(images)
    # Encode all labels in a single call instead of one call per sample.
    labels = torch.nn.functional.one_hot(
        torch.tensor(class_ids, dtype=torch.long), n_classes
    )
    return images, labels

In [27]:
# Keep the demo small: take only the first 100 samples.
images, labels = take_samples(digits, 100)

In [29]:
# Display the shape of the collected image tensor.
images.shape

torch.Size([100, 28, 28])

### Save Alice and Bob's Digits

In [30]:
@mpc.run_multiprocess(world_size=2)
def save_digits():
    """Save the images under Alice's id and the labels under Bob's id.

    Runs as two processes (world_size=2); the images are saved first,
    then the labels.
    """
    for tensor, path, party in (
        (images, "/tmp/data/alice_images.pth", ALICE),
        (labels, "/tmp/data/bob_labels.pth", BOB),
    ):
        crypten.save(tensor, path, src=party)


save_digits()

[None, None]

In [31]:
# clean up tmp directory

In [32]:
# Delete every regular file ending in .pth found anywhere under /tmp/data.
! find /tmp/data -name "*.pth" -type f -delete

# Full Joint Training
See `training_across_parties.py`.

For a full set of examples, with scripts to run on separate AWS machines, see https://github.com/facebookresearch/CrypTen/tree/master/examples.