In [1]:
import crypten
import torch
import crypten.mpc as mpc

# Initialize CrypTen once, before any crypten / crypten.mpc call further down.
crypten.init()
# Restrict PyTorch's CPU thread pool to a single thread.
# NOTE(review): presumably to avoid CPU oversubscription when the multi-party
# CrypTen code below runs several parties at once — confirm it is still needed.
torch.set_num_threads(1)

In [2]:
import argparse
import os

import torch
from torchvision import datasets, transforms

In [3]:
# Party identifiers passed as the ``src`` argument of crypten.load_from_party:
# Alice is party 0 and Bob is party 1.
ALICE, BOB = 0, 1

In [5]:
# Download MNIST and standardize the pixel values.
def _get_norm_mnist(dir="/tmp", reduced=None, binary=False):
    """Download MNIST into ``dir`` and return normalized images plus labels.

    All pixels are standardized with a single scalar mean and standard
    deviation.

    Args:
        dir: Root directory passed to ``datasets.MNIST`` (download target).
        reduced: If not None, keep only the first ``reduced`` samples of the
            train split and of the test split.
        binary: If True, relabel every non-zero digit as 1 (zero vs. non-zero).

    Returns:
        ``((train_images, test_images), (train_labels, test_labels))`` where the
        image tensors are float and normalized.
    """
    train_set = datasets.MNIST(dir, download=True, train=True)
    test_set = datasets.MNIST(dir, download=True, train=False)

    # The scalar statistics are computed over the train AND test images combined.
    # NOTE(review): including the test images leaks test-set information into
    # preprocessing; kept as-is to preserve existing results.
    all_pixels = torch.cat([train_set.data, test_set.data]).float()
    pixel_mean = all_pixels.mean().unsqueeze(0)
    pixel_std = all_pixels.std().unsqueeze(0)

    def _standardize(raw_images):
        # Shared helper so train and test use exactly the same statistics.
        return transforms.functional.normalize(raw_images.float(), pixel_mean, pixel_std)

    train_images = _standardize(train_set.data)
    test_images = _standardize(test_set.data)

    if binary:
        # Collapse digits 1-9 into class 1; the dataset targets are edited in place.
        for subset in (train_set, test_set):
            subset.targets[subset.targets != 0] = 1

    images = (train_images, test_images)
    labels = (train_set.targets, test_set.targets)
    if reduced is not None:
        images = tuple(split_images[:reduced] for split_images in images)
        labels = tuple(split_labels[:reduced] for split_labels in labels)
    return images, labels
    



# Split features and labels: one party holds the images, the other the labels.
def split_features_v_labels(
    dir="/tmp", party1="alice", party2="bob", reduced=None, binary=False
):
    """Save all MNIST images for ``party1`` and all MNIST labels for ``party2``.

    Files written under ``dir``:
        ``{party1}_train.pth`` / ``{party1}_test.pth``: normalized images.
        ``{party2}_train_labels.pth`` / ``{party2}_test_labels.pth``: labels.

    ``reduced`` and ``binary`` are forwarded unchanged to ``_get_norm_mnist``.
    """
    (train_images, test_images), (train_labels, test_labels) = _get_norm_mnist(
        dir, reduced, binary
    )

    # (tensor, file name) pairs, saved in this order.
    outputs = [
        (train_images, party1 + "_train.pth"),
        (test_images, party1 + "_test.pth"),
        (train_labels, party2 + "_train_labels.pth"),
        (test_labels, party2 + "_test_labels.pth"),
    ]
    for payload, filename in outputs:
        torch.save(payload, os.path.join(dir, filename))

    
    
# Write the per-party files under /tmp: Alice gets the images, Bob the labels.
split_features_v_labels(dir="/tmp", party1="alice", party2="bob")


# Alice contributes the training images and Bob the training labels; both are
# loaded as encrypted tensors inside a two-party CrypTen run.
@mpc.run_multiprocess(world_size=2)
def loadData():
    """Load Alice's training features and Bob's training labels as CrypTensors.

    Alice's file is loaded with ``src=ALICE`` and Bob's with ``src=BOB``. The
    paths match what ``split_features_v_labels()`` writes with its default
    arguments.

    Returns:
        ``(feature_size, label_size)``: the sizes of the two encrypted tensors.
        The encrypted tensors are bound only to locals of this function and are
        otherwise lost when it returns, so their sizes are returned to give the
        caller something to check.
    """
    x_alice_enc = crypten.load_from_party('/tmp/alice_train.pth', src=ALICE)
    x_bob_enc = crypten.load_from_party('/tmp/bob_train_labels.pth', src=BOB)
    return x_alice_enc.size(), x_bob_enc.size()


# NOTE(review): run_multiprocess is expected to return one result per party
# (or None if a party fails) — confirm against the CrypTen version in use.
loaded_sizes = loadData()

# Column-wise feature split: both parties keep every image, but each holds only
# a slice of its columns.
def split_features(
    split=0.5, dir="/tmp", party1="alice", party2="bob", reduced=None, binary=False
):
    """Split every MNIST image column-wise between two parties and save the parts.

    ``party1`` receives the first ``int(split * width)`` columns of each image
    (train and test), and ``party2`` receives the remaining columns. Labels are
    saved unsplit.

    Args:
        split: Fraction of image columns given to ``party1``; must be in [0, 1].
        dir: Directory used for the MNIST download and for the output files.
        party1: File-name prefix for the first party.
        party2: File-name prefix for the second party.
        reduced: If not None, keep only the first ``reduced`` samples per split.
        binary: If True, relabel every non-zero digit as 1.

    Raises:
        ValueError: If ``split`` lies outside [0, 1].

    Files written under ``dir``: ``{party1}_train.pth``, ``{party2}_train.pth``,
    ``{party1}_test.pth``, ``{party2}_test.pth``, ``train_labels.pth`` and
    ``test_labels.pth``.
    """
    # Validate before the (expensive) data preparation.
    if not 0 <= split <= 1:
        raise ValueError(f"split must be in [0, 1], got {split!r}")

    mnist_norm, mnist_labels = _get_norm_mnist(dir, reduced, binary)
    mnist_train_norm, mnist_test_norm = mnist_norm
    mnist_train_labels, mnist_test_labels = mnist_labels

    # The split is applied along the last (column) axis, so it must be sized from
    # that axis. shape[1] (rows) only happened to work because MNIST is square.
    num_features = mnist_train_norm.shape[-1]
    split_point = int(split * num_features)

    party1_train = mnist_train_norm[..., :split_point]
    party2_train = mnist_train_norm[..., split_point:]
    party1_test = mnist_test_norm[..., :split_point]
    party2_test = mnist_test_norm[..., split_point:]

    torch.save(party1_train, os.path.join(dir, party1 + "_train.pth"))
    torch.save(party2_train, os.path.join(dir, party2 + "_train.pth"))
    torch.save(party1_test, os.path.join(dir, party1 + "_test.pth"))
    torch.save(party2_test, os.path.join(dir, party2 + "_test.pth"))
    torch.save(mnist_train_labels, os.path.join(dir, "train_labels.pth"))
    torch.save(mnist_test_labels, os.path.join(dir, "test_labels.pth"))