# SMPC idea for model inference

This notebook shows the overall flow of model inference with SMPC (secure multi-party computation).  
Assumptions:  
1. Two workers perform the computation. PyGrid and the data holder's device do not take part in any computation; they are only responsible for splitting the data and sending the shares to the workers.  
2. No worker receives a full copy of the unencrypted user data or of the model weights. Each worker receives only an encrypted share, and the original data can be reconstructed only when the shares held by the different workers are combined. 
3. Both workers can obtain the basic model architecture and initialize it with random weights (i.e. only the weights are assumed to be sensitive, not the architecture).
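
To make assumption 2 concrete, here is a minimal sketch of additive secret sharing in plain Python. It is not CrypTen's implementation: the 64-bit ring, the 16 fractional bits, and the helper names (`encode`, `decode`, `share`, `reconstruct`) are assumptions chosen for illustration. Each worker's share is a uniformly random ring element on its own; only the sum of both shares recovers the value.

```python
import secrets

MOD = 2 ** 64      # assumed ring size (illustrative)
SCALE = 2 ** 16    # assumed fixed-point scale: 16 fractional bits (illustrative)

def encode(x):
    """Map a float to a fixed-point element of the ring."""
    return round(x * SCALE) % MOD

def decode(v):
    """Map a ring element back to a float; the upper half of the ring is negative."""
    if v >= MOD // 2:
        v -= MOD
    return v / SCALE

def share(x):
    """Split x into two additive shares. Each share alone is uniformly random."""
    r = secrets.randbelow(MOD)
    return r, (encode(x) - r) % MOD

def reconstruct(s0, s1):
    """Combine the two shares and decode the result."""
    return decode((s0 + s1) % MOD)

features = [0.233021, 1.320897, -0.377054]
shares = [share(v) for v in features]
worker1 = [s[0] for s in shares]   # what the first worker would hold
worker2 = [s[1] for s in shares]   # what the second worker would hold
recovered = [reconstruct(a, b) for a, b in zip(worker1, worker2)]
print(recovered)
```

The recovered values match the inputs up to the fixed-point rounding error (at most 2^-17 with this scale).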

In [1]:
#Run this if needed
#!pip install crypten

In [2]:
import crypten
import torch as th
import collections

crypten.init()
th.set_num_threads(1)

In [3]:
#define different hosts participating in SMPC
WORKER1 = 0
WORKER2 = 1

In [4]:
# Global Parameters (this will be )
n_users = 10
song_features = 10
bs = 64  # batch_size
lr = 5e-4
embedding_size = 50
layer_sizes = [(embedding_size + song_features, 150), (150, 300), (300, 200), (200, 1)]

# Step 1 -- User data split

The following are assumed to be run on user device

In [5]:
#User data generation -- assume the data below are from the user device
user = th.zeros((n_users,))
user[0] = 1
user = th.reshape(user,(1,-1))
features = th.tensor([0.233021,1.320897,0.128987,0.613270,0.256593,-0.377054,0.104040,0.394174,-0.239261,-0.550271])
features = th.reshape(features,(1,-1))

In [6]:
import os
from crypten import mpc
import crypten.communicator as comm

os.makedirs("data", exist_ok=True)  # crypten.save does not create the target directory

@mpc.run_multiprocess(world_size=2)  # world_size is the number of workers running the SMPC
def data_separation():
    usertensor = crypten.cryptensor(user)
    featuretensor = crypten.cryptensor(features)
    rank = comm.get().get_rank()
    #crypten.print(f"\nRank {rank}:\n {usertensor}\n", in_order=True)
    # Each process saves its own share under a rank-specific file name
    crypten.save(usertensor, f"data/user{rank}.pth")
    crypten.save(featuretensor, f"data/feature{rank}.pth")
data_separation()

[None, None]

# Step 2 -- Model split

The following is assumed to run on PyGrid / parcel. (TODO: check whether the encrypted, split weights can be obtained directly in parcel and saved to PyGrid.)
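
Once the weights are secret-shared, every product of a shared weight and a shared input inside a linear layer needs an interactive protocol, because neither worker may see either operand. A common construction for this is the Beaver triple. The sketch below illustrates it for two parties over plain integers; it is not taken from this notebook or from CrypTen's source, and the 64-bit ring, the in-process "trusted dealer", and the helper names are assumptions.

```python
import secrets

MOD = 2 ** 64  # assumed ring size (illustrative)

def share(v):
    """Split an integer into two additive shares modulo MOD."""
    r = secrets.randbelow(MOD)
    return [r, (v - r) % MOD]

def reveal(shares):
    """Open a shared value by summing both shares."""
    return sum(shares) % MOD

def beaver_mul(x_sh, y_sh):
    """Multiply two shared values without revealing either of them."""
    # A trusted dealer (hypothetical here) hands out shares of a, b and c = a * b.
    a, b = secrets.randbelow(MOD), secrets.randbelow(MOD)
    a_sh, b_sh, c_sh = share(a), share(b), share(a * b % MOD)
    # The parties open the masked values eps = x - a and delta = y - b.
    eps = reveal([(x_sh[i] - a_sh[i]) % MOD for i in range(2)])
    delta = reveal([(y_sh[i] - b_sh[i]) % MOD for i in range(2)])
    # Each party computes its share of x * y locally.
    z_sh = [(c_sh[i] + eps * b_sh[i] + delta * a_sh[i]) % MOD for i in range(2)]
    z_sh[0] = (z_sh[0] + eps * delta) % MOD  # the public term is added by one party only
    return z_sh

weight_sh = share(6)   # stands in for one secret model weight
input_sh = share(7)    # stands in for one secret input value
product = reveal(beaver_mul(weight_sh, input_sh))
print(product)
```

The opened values `eps` and `delta` are masked by the random `a` and `b`, so they reveal nothing about the weight or the input.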

In [7]:
#TODO: work out how to store this model architecture in PyGrid and send it to both workers (unencrypted, random weights only)
#NOTE: the model definition differs slightly from the one used in PyGrid mobile, but it should still work
#because the shapes of the model weights are unchanged (i.e. the weights of the model trained in PyGrid can still be copied here)
class EmbeddingNet(th.nn.Module):  
    """
    Simple feed-forward model: a linear user embedding followed by four fully connected layers.
    """

    def __init__(self) -> None:
        super(EmbeddingNet, self).__init__()
        self.embedlayer = th.nn.Linear(n_users, embedding_size)  # changed
        self.fc1 = th.nn.Linear(layer_sizes[0][0], layer_sizes[0][1])
        self.fc2 = th.nn.Linear(layer_sizes[1][0], layer_sizes[1][1])
        self.fc3 = th.nn.Linear(layer_sizes[2][0], layer_sizes[2][1])
        self.fc4 = th.nn.Linear(layer_sizes[3][0], layer_sizes[3][1])

    def forward(self, users, features):
        """
        users: a one-hot tensor of shape (1, n_users) representing the user.
        features: a tensor of shape (1, song_features) with the Spotify-provided feature values.
        """
        out = self.embedlayer(users)
        out = th.cat((out,features),dim=1)
        out = self.fc1(out)
        out = th.nn.functional.relu(out)
        out = self.fc2(out)
        out = th.nn.functional.relu(out)
        out = self.fc3(out)
        out = th.nn.functional.relu(out)
        out = self.fc4(out)
        out = th.sigmoid(out)
        return out

In [8]:
#save a dummy model with random weights
crypten.common.serial.register_safe_class(EmbeddingNet)
local_model = EmbeddingNet()
th.save(local_model,"data/model_random.pth")
#Save trained model : Weights are dummy for the purpose of this demo
trained_model = EmbeddingNet()
sd = trained_model.state_dict()
for layer in sd:
    sd[layer] = th.ones(sd[layer].size())
trained_model.load_state_dict(sd)
th.save(trained_model,"data/model_trained.pth")

In [9]:
@mpc.run_multiprocess(world_size=2)
def model_split():
    #1. Load the trained plaintext model
    plaintext_model = th.load('data/model_trained.pth')
    
    #2. Construct a CrypTen network from the trained model and a dummy input
    dummy_input = th.empty((1, n_users)), th.empty((1, song_features))
    private_model = crypten.nn.from_pytorch(plaintext_model, dummy_input)
    
    #3. Encrypt the CrypTen network 
    private_model.encrypt(src=0)
    rank = comm.get().get_rank()

    #4. Save this process's share of the weights
    crypten.save(private_model.state_dict(), f"data/modelweights{rank}.pth") 

    #Check that the model is encrypted
    print("Model successfully encrypted:", private_model.encrypted)

model_split()

Model successfully encrypted: True
Model successfully encrypted: True



[None, None]

# Step 3: Run the SMPC inferencing

The two workers are simulated here by two processes on a single machine. Running them on separate machines has not been worked out yet (hopefully it is straightforward).
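
For the multi-machine case, one possible starting point is sketched below. It rests on an assumption that should be checked against the installed CrypTen version: that `crypten.init()` reads the environment variables `WORLD_SIZE`, `RANK`, `RENDEZVOUS` and `DISTRIBUTED_BACKEND` when it is not started through `mpc.run_multiprocess`. The helper `smpc_env`, the address `tcp://192.0.2.10:29500` and the `gloo` backend are hypothetical placeholders.

```python
import os

def smpc_env(rank, world_size=2, rendezvous="tcp://192.0.2.10:29500", backend="gloo"):
    """Build the environment a CrypTen process is assumed to read at init time.

    The rendezvous address is a placeholder; it would have to be an address
    that every worker can reach.
    """
    return {
        "WORLD_SIZE": str(world_size),
        "RANK": str(rank),
        "RENDEZVOUS": rendezvous,
        "DISTRIBUTED_BACKEND": backend,
    }

# On the machine acting as WORKER1 (rank 0); the WORKER2 machine would use rank=1.
os.environ.update(smpc_env(rank=0))

# Assumed next steps on each machine (not executed in this sketch):
#   import crypten
#   crypten.init()
#   ... run the body of model_inference() without the run_multiprocess decorator ...
```

Each machine would also need its own share files (`data/user{rank}.pth` and so on) available locally.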

In [10]:
#Confirm data ok
@mpc.run_multiprocess(world_size=2)
def model_inference():
    #1. Combine user data 
    user0 = crypten.load_from_party("data/user0.pth",encrypted=True,src=WORKER1)
    user1 = crypten.load_from_party("data/user1.pth",encrypted=True,src=WORKER2)
    user = (user0+user1)
    crypten.print("Tensor encrypted:", crypten.is_encrypted_tensor(user)) 
    crypten.print("Decrypted data:",user.get_plain_text())
    feature0 = crypten.load_from_party("data/feature0.pth",encrypted=True,src=WORKER1)
    feature1 = crypten.load_from_party("data/feature1.pth",encrypted=True,src=WORKER2)
    feature = (feature0+feature1)
    crypten.print("Tensor encrypted:", crypten.is_encrypted_tensor(feature)) 
    crypten.print("Decrypted data:",feature.get_plain_text())
    #2. load basic model
    plaintext_model = th.load('data/model_random.pth')  # th.load instead of load_from_party: both processes must have a local copy of this file
    dummy_input = th.empty((1,10,)), th.empty((1,10,))
    private_model = crypten.nn.from_pytorch(plaintext_model, dummy_input)
    private_model.encrypt(src=WORKER1) #Actually the src is not important, weights are just dummy
    #3. load and combine model weights
    w1 = crypten.load_from_party("data/modelweights0.pth",encrypted=True,src=WORKER1)
    w2 = crypten.load_from_party("data/modelweights1.pth",encrypted=True,src=WORKER2)
    rank = comm.get().get_rank()
    for k in w1:
        #Crypten strangely only supports enc + plain but does not support the other way round (plain + enc)
        if rank == WORKER1:
            w1[k] = w1[k] + w2[k]
        else:
            w2[k] = w2[k] + w1[k]
    #4. restore weights to basic model
    if rank == WORKER1:
        private_model.load_state_dict(w1)
    elif rank == WORKER2:
        private_model.load_state_dict(w2)
    #5. inference
    private_model.eval()
    output_enc = private_model(user,feature)
    crypten.print(output_enc.get_plain_text())
    #Sanity check the model weights
    # private_model.decrypt()
    # print(private_model.state_dict())
model_inference()

Tensor encrypted: True
Decrypted data: tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
Tensor encrypted: True
Decrypted data: tensor([[ 0.2330,  1.3209,  0.1290,  0.6133,  0.2566, -0.3770,  0.1040,  0.3942,
         -0.2393, -0.5503]])


tensor([[-7.8450e+09]])


[None, None]
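
A plaintext reference value helps when judging the SMPC output above. The sketch below reproduces the forward pass of the all-ones dummy model in plain Python (no CrypTen, no torch). It relies on one fact: a linear layer whose weights and biases are all 1 outputs `sum(x) + 1` in every unit. The helper names are illustrative. Under these assumptions the logit is roughly 9.3e8 and the sigmoid saturates at 1.0, so the value printed above does not match the plaintext result. One possible cause, offered only as a guess, is that inputs of this magnitude lie outside the range where an approximate sigmoid is accurate; smaller dummy weights would be one way to test that.

```python
import math

n_users, song_features, embedding_size = 10, 10, 50
hidden_sizes = [150, 300, 200, 1]   # output sizes of fc1..fc4

def ones_linear(x, out_features):
    """Linear layer with all weights and biases equal to 1."""
    return [sum(x) + 1.0] * out_features

def relu(x):
    return [max(v, 0.0) for v in x]

user = [1.0] + [0.0] * (n_users - 1)
features = [0.233021, 1.320897, 0.128987, 0.613270, 0.256593,
            -0.377054, 0.104040, 0.394174, -0.239261, -0.550271]

# Embedding layer (no activation), then concatenation with the song features
out = ones_linear(user, embedding_size) + features
# fc1..fc3, each followed by ReLU
for size in hidden_sizes[:-1]:
    out = relu(ones_linear(out, size))
# fc4, followed by the sigmoid
logit = ones_linear(out, hidden_sizes[-1])[0]
prob = 1.0 / (1.0 + math.exp(-logit))
print(logit, prob)
```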

In [11]:
'''
TODO
1. Complete this PoC and confirm the results are correct
2. Try deploying this on two separate machines
3. (See if OK) Modify PyGrid to provide the encrypted tensors directly, so that even PyGrid does not need to see the data?
'''
