In [1]:
import torch
import pandas as pd
import numpy as np
import syft as sy
import copy
hook = sy.TorchHook(torch)
from torch import nn
import torch.nn.functional as F
from torch import optim

In [2]:
bob = sy.VirtualWorker(hook, id="bob")
alice = sy.VirtualWorker(hook, id="alice")
secure_worker = sy.VirtualWorker(hook, id="secure_worker")

bob.add_workers([alice, secure_worker])
alice.add_workers([bob, secure_worker])
secure_worker.add_workers([alice, bob])

In [3]:
dataframe=np.genfromtxt("diabetes.csv",delimiter=',',skip_header=1)

In [4]:
X=dataframe[:,0:8]
Y=dataframe[:,8]
Y=Y.reshape(768,1)
dtype=torch.float
device=torch.device("cpu")

In [5]:
from sklearn.preprocessing import StandardScaler
model=StandardScaler()
X=model.fit_transform(X)


In [6]:
data = torch.from_numpy(X)
target = torch.from_numpy(Y)
data,target=data.type(torch.FloatTensor),target.type(torch.FloatTensor)

In [7]:
data_length,data_width=data.shape

In [8]:
bobs_data = data[0:int(data_length/2)].send(bob)
bobs_target = target[0:int(data_length/2)].send(bob)

alices_data = data[int(data_length/2):].send(alice)
alices_target = target[int(data_length/2):].send(alice)


In [9]:
# class my_network(torch.nn.Module):
#     def __init__(self):
#         super(my_network, self).__init__()
#         self.fc1=nn.Linear(data_width,500)
#         self.fc2=nn.Linear(500,100)
#         self.fc3=nn.Linear(100,1)

In [10]:
class my_network(torch.nn.Module):
    def __init__(self):
        super(my_network, self).__init__()
        self.fc1=nn.Linear(data_width,200)
        self.activ=nn.ReLU()
        self.fc2=nn.Linear(200,100)
        self.activ2=nn.ReLU()
        self.fc3=nn.Linear(100,1)
#         self.activ3=nn.ReLU()
    
    def forward(self,input_):
        a1=self.fc1(input_)
        a1=self.activ(a1)
        a2=self.fc2(a1)
        a2=self.activ2(a2)
        y=self.fc3(a2)
#         y=self.activ3(y)
        return y

In [74]:
model=my_network()
loss= torch.nn.MSELoss()

In [12]:
bobs_model = model.copy().send(bob)
alices_model = model.copy().send(alice)

bobs_opt = optim.SGD(params=bobs_model.parameters(),lr=0.001)
alices_opt = optim.SGD(params=alices_model.parameters(),lr=0.001)

In [84]:
for i in range(10):

    # Train Bob's Model
    bobs_opt.zero_grad()
    bobs_pred = bobs_model(bobs_data)
    bobs_loss = loss(bobs_pred,bobs_target)
    bobs_loss.backward()

    bobs_opt.step()
    bobs_loss = bobs_loss.get().data

    # Train Alice's Model
    alices_opt.zero_grad()
    alices_pred = alices_model(alices_data)
    alices_loss = loss(alices_pred,alices_target)
    alices_loss.backward()

    alices_opt.step()
    alices_loss = alices_loss.get().data
    
    print("Bob:" + str(bobs_loss) + " Alice:" + str(alices_loss))

Bob:tensor(0.3146) Alice:tensor(0.2721)
Bob:tensor(0.3133) Alice:tensor(0.2712)
Bob:tensor(0.3121) Alice:tensor(0.2704)
Bob:tensor(0.3109) Alice:tensor(0.2696)
Bob:tensor(0.3097) Alice:tensor(0.2689)
Bob:tensor(0.3085) Alice:tensor(0.2681)
Bob:tensor(0.3073) Alice:tensor(0.2673)
Bob:tensor(0.3061) Alice:tensor(0.2666)
Bob:tensor(0.3050) Alice:tensor(0.2658)
Bob:tensor(0.3038) Alice:tensor(0.2651)


In [None]:
iterations = 10
worker_iters = 200

for a_iter in range(iterations):
    
    bobs_model = model.copy().send(bob)
    alices_model = model.copy().send(alice)

    bobs_opt = optim.Adam(params=bobs_model.parameters(),lr=0.01)
    alices_opt = optim.Adam(params=alices_model.parameters(),lr=0.01)

    for wi in range(worker_iters):

        # Train Bob's Model
        bobs_opt.zero_grad()
#         bobs_w1=bobs_model.fc1(bobs_data)
#         bobs_w2=bobs_model.fc2(bobs_w1)
#         bobs_pred=bobs_model.fc3(bobs_w2)
#         bobs_pred=bobs_model.fc3(bobs_w2)
#         bobs_w1=bobs_model.fc1(bobs_data)
#         bobs_w1=bobs_model.activ1(bobs_w1)
#         bobs_pred=bobs_model.fc2(bobs_w1)
#         bobs_pred=bobs_model.softmax(bobs_pred)
        bobs_pred = bobs_model.forward(bobs_data)
        bobs_loss = loss(bobs_pred,bobs_target)
        bobs_loss.backward()        
        bobs_opt.step()
#         bobs_opt.step(bobs_data.shape[0])
        bobs_loss = bobs_loss.get().data

        # Train Alice's Model
        alices_opt.zero_grad()
#         alices_w1=F.sigmoid(alices_model.fc1(alices_data))
#         alices_w2=F.sigmoid(alices_model.fc2(alices_w1))
#         alices_pred=F.softmax(alices_model.fc3(alices_w2))
#         alices_w1=alices_model.fc1(alices_data)
#         alices_w1=alices_model.activ1(alices_w1)
#         alices_pred=alices_model.fc2(alices_w1)
#         alices_pred=alices_model.softmax(alices_pred)
        alices_pred = alices_model.forward(alices_data)
        alices_loss = loss(alices_pred,alices_target)
        alices_loss.backward()
        alices_opt.step()
#         alices_opt.step(alices_data.shape[0])
        alices_loss = alices_loss.get().data
#         print("Bob:" + str(bobs_loss) + " Alice:" + str(alices_loss))
    
    alices_model.move(secure_worker)
    bobs_model.move(secure_worker)
    
    model.fc1.weight.data.set_(((alices_model.fc1.weight.data + bobs_model.fc1.weight.data) / 2).get())
    model.fc1.bias.data.set_(((alices_model.fc1.bias.data + bobs_model.fc1.bias.data) / 2).get())
    model.fc2.weight.data.set_(((alices_model.fc2.weight.data + bobs_model.fc2.weight.data) / 2).get())
    model.fc2.bias.data.set_(((alices_model.fc2.bias.data + bobs_model.fc2.bias.data) / 2).get())
    model.fc3.weight.data.set_(((alices_model.fc3.weight.data + bobs_model.fc3.weight.data) / 2).get())
    model.fc3.bias.data.set_(((alices_model.fc3.bias.data + bobs_model.fc3.bias.data) / 2).get())
    
    print("Bob:" + str(bobs_loss) + " Alice:" + str(alices_loss))

In [75]:
bobs_opt.zero_grad()

In [76]:
bobs_pred = bobs_model(bobs_data)

In [77]:
bobs_loss = loss(bobs_pred,bobs_target)

In [78]:
prev=(bobs_model.fc1.weight.data).clone().get()

In [79]:
bobs_loss.backward(retain_graph=True)

In [80]:
bobs_opt.step()

In [81]:
after=(bobs_model.fc1.weight.data).clone().get()

In [82]:
torch.sum(prev-after)

tensor(-0.0005)

In [83]:
list(bobs_model.parameters())[1].grad.clone().get()

tensor([-8.6007e-03, -7.4330e-04,  3.7530e-03,  1.0351e-03,  5.9490e-03,
         2.5547e-04,  3.7711e-03, -1.0140e-02,  2.1963e-03, -1.9414e-03,
         2.2193e-03,  7.1897e-04, -1.2450e-03,  6.9786e-04,  5.3132e-04,
        -7.3041e-03, -1.1606e-03, -8.4972e-03,  1.7780e-03,  1.3816e-03,
        -2.6396e-03,  7.7571e-04,  6.5324e-04, -5.5166e-03,  1.2163e-03,
         2.5656e-03,  2.5129e-03,  1.7607e-03, -3.5983e-03,  3.1386e-04,
         5.4849e-03,  1.5737e-03, -6.3848e-04, -1.1651e-03, -4.8420e-03,
        -8.7499e-03, -8.3866e-03, -1.8201e-03, -1.5377e-03, -6.5627e-03,
         6.2573e-03, -6.9072e-04, -2.9139e-03, -5.1051e-03, -4.7599e-03,
        -3.1543e-03,  1.4664e-04,  1.0161e-03,  3.3814e-03,  1.8423e-03,
         2.6983e-03, -3.5660e-03, -6.1779e-05,  5.5516e-03,  3.4894e-03,
         9.2809e-03, -4.2183e-03,  6.9037e-03,  2.3915e-03,  1.9225e-03,
         4.4367e-04,  6.4051e-04, -1.5056e-03, -8.4270e-03, -2.1535e-03,
         1.4576e-03, -2.0714e-03,  1.8619e-04,  3.1