In [1]:
import crypten
import torch
import crypten.mpc as mpc

crypten.init()
torch.set_num_threads(1)

In [2]:
import argparse
import os

import torch
from torchvision import datasets, transforms

In [3]:
# Define source argument values for Alice and Bob
ALICE = 0
BOB = 1
CHARLIE = 2

In [4]:
#get mnist and normalize
def _get_norm_mnist(dir="/tmp", reduced=None, binary=False):
    """Downloads and normalizes mnist"""
    mnist_train = datasets.MNIST(dir, download=True, train=True)
    mnist_test = datasets.MNIST(dir, download=True, train=False)

    # compute normalization factors
    data_all = torch.cat([mnist_train.data, mnist_test.data]).float()
    data_mean, data_std = data_all.mean(), data_all.std()
    tensor_mean, tensor_std = data_mean.unsqueeze(0), data_std.unsqueeze(0)

    # normalize
    mnist_train_norm = transforms.functional.normalize(
        mnist_train.data.float(), tensor_mean, tensor_std
    )
    mnist_test_norm = transforms.functional.normalize(
        mnist_test.data.float(), tensor_mean, tensor_std
    )
    
    # change all nonzero labels to 1 if binary classification required
    if binary:
        mnist_train.targets[mnist_train.targets != 0] = 1
        mnist_test.targets[mnist_test.targets != 0] = 1

    # create a reduced dataset if required
    if reduced is not None:
        mnist_norm = (mnist_train_norm[:reduced], mnist_test_norm[:reduced])
        mnist_labels = (mnist_train.targets[:reduced], mnist_test.targets[:reduced])
    else:
        mnist_norm = (mnist_train_norm, mnist_test_norm)
        mnist_labels = (mnist_train.targets, mnist_test.targets)
    return mnist_norm, mnist_labels
    



#split features and lables
def split_features_v_labels(dir="/tmp", party1="alice", party2="bob", reduced=None, binary=False):
    """Gives Party 1 features and Party 2 labels"""
    mnist_norm, mnist_labels = _get_norm_mnist(dir, reduced, binary)
    mnist_train_norm, mnist_test_norm = mnist_norm
    mnist_train_labels, mnist_test_labels = mnist_labels
    
    #Alice has features(train and test)
    torch.save(mnist_train_norm, os.path.join(dir, party1 + "_train.pth"))
    torch.save(mnist_test_norm, os.path.join(dir, party1 + "_test.pth"))
    
    #save from party
    crypten.save_from_party(mnist_train_norm, '/tmp/alice_train.pth', src=ALICE)
    crypten.save_from_party(mnist_test_norm, '/tmp/alice_test.pth', src=ALICE)
    
    
    #Bob has labels(train and test)
    torch.save(mnist_train_labels, os.path.join(dir, party2 + "_train_labels.pth"))
    torch.save(mnist_test_labels, os.path.join(dir, party2 + "_test_labels.pth"))

    #save from party
    crypten.save_from_party(mnist_train_labels, '/tmp/bob_train_labels.pth', src=BOB)
    crypten.save_from_party(mnist_test_labels, '/tmp/bob_test_labels.pth', src=BOB)
    
    
    
    
def split_features(mnist_norm, split=0.5, dir="/tmp", party1="alice", party2="bob", reduced=None, binary=False):
    """Splits features between Party 1 and Party 2"""
    
#     mnist_train_norm, mnist_test_norm = mnist_norm
#     mnist_train_labels, mnist_test_labels = mnist_labels

    num_features = mnist_norm.shape[1]
    split_point = int(split * num_features)

    party1_train = mnist_norm[:, :, :split_point]
    party2_train = mnist_norm[:, :, split_point:]
#     party1_test = mnist_test_norm[:, :, :split_point]
#     party2_test = mnist_test_norm[:, :, split_point:]

    torch.save(party1_train, os.path.join(dir, party1 + "_train.pth"))
    torch.save(party2_train, os.path.join(dir, party2 + "_train.pth"))

#     torch.save(party1_test, os.path.join(dir, party1 + "_test.pth"))
#     torch.save(party2_test, os.path.join(dir, party2 + "_test.pth"))
#     torch.save(mnist_train_labels, os.path.join(dir, "train_labels.pth"))
#     torch.save(mnist_test_labels, os.path.join(dir, "test_labels.pth"))    
    
    crypten.save_from_party(party1_train, '/tmp/s1_alice_train_labels.pth', src=ALICE)

    crypten.save_from_party(party2_train, '/tmp/s2_charlie_test_labels.pth', src=BOB)



In [5]:
import torch.nn as nn
import torch.nn.functional as F

#Define an example network
class ExampleNet(nn.Module):
    def __init__(self):
        super(ExampleNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=5, padding=0)
        self.fc1 = nn.Linear(16 * 12 * 12, 100)
        self.fc2 = nn.Linear(100, 2) # For binary classification, final layer needs only 2 outputs
 
    def forward(self, x):
        out = self.conv1(x)
        out = F.relu(out)
        out = F.max_pool2d(out, 2)
        out = out.view(-1, 16 * 12 * 12)
        out = self.fc1(out)
        out = F.relu(out)
        out = self.fc2(out)
        return out
    
crypten.common.serial.register_safe_class(ExampleNet)

In [16]:
#Alice has features and bob has labels
import crypten.mpc as mpc
import crypten.communicator as comm
from sklearn.svm import SVC
from examples.mpc_linear_svm.mpc_linear_svm import train_linear_svm, evaluate_linear_svm

@mpc.run_multiprocess(world_size=2)
def loadData(dir="/tmp", party1="alice", party2="bob", reduced=None, binary=False):
    
    #getting mnist data
    mnist_norm, mnist_labels = _get_norm_mnist(dir, reduced, binary)
    
    #features train test split
    mnist_train_norm, mnist_test_norm = mnist_norm
    
    #labels train test split
    mnist_train_labels, mnist_test_labels = mnist_labels
    
    #split features
    split_features(mnist_train_norm)
    
    #loading split1 and split2
    split1 = crypten.load_from_party('/tmp/s1_alice_train_labels.pth', src=ALICE)
    split2 = crypten.load_from_party('/tmp/s2_charlie_test_labels.pth', src=BOB)

    #encrypting train labels
    torch.save(mnist_train_labels, os.path.join(dir, party2 + "_train_labels.pth"))
    crypten.save_from_party(mnist_train_labels, '/tmp/bob_train_labels.pth', src=BOB)
    label = crypten.load_from_party('/tmp/bob_train_labels.pth', src=BOB)
    

    print("label shape: ",label.shape) 
    print("split1 shape: ",split1.shape) 
    #train(split1, label_oh)
    print("till here")
    w, b = train_linear_svm(split1, label, epochs=40, lr=0.001)
    print("training done!")
    # Evaluate model
    evaluate_linear_svm(mnist_test_norm, mnist_test_labels, w, b)
               
    #train(split2, label_oh)
               

loadData()
print("hello")


label shape: label shape:   torch.Size([60000])
torch.Size([60000])
split1 shape: split1 shape:  torch.Size([60000, 28, 14]) torch.Size([60000, 28, 14])

till here
till here


Process Process-21:
Process Process-22:
Traceback (most recent call last):
  File "/anaconda/envs/py38_default/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
Traceback (most recent call last):
  File "/anaconda/envs/py38_default/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/anaconda/envs/py38_default/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/nsblack/crypten/CrypTen/crypten/mpc/context.py", line 30, in _launch
    return_value = func(*func_args, **func_kwargs)
  File "/anaconda/envs/py38_default/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/nsblack/crypten/CrypTen/crypten/mpc/context.py", line 30, in _launch
    return_value = func(*func_args, **func_kwargs)
  File "/tmp/ipykernel_22972/2240038617.py", line 36, in loadData
    w, b = train_linear_svm(mnist_tr

hello
