In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import syft as sy

Falling back to insecure randomness since the required custom op could not be found for the installed version of TensorFlow. Fix this by compiling custom ops. Missing file was '/opt/conda/lib/python3.7/site-packages/tf_encrypted/operations/secure_random/secure_random_module_tf_1.15.2.so'





In [2]:
# Set everything up: hook torch through PySyft, then create the three
# virtual workers used by the later cells (alice, bob and james).
hook = sy.TorchHook(torch)

# Workers are built in the same order as before: alice, bob, james.
alice, bob, james = (
    sy.VirtualWorker(id=worker_id, hook=hook)
    for worker_id in ("alice", "bob", "james")
)

In [3]:
# A Toy Dataset: four samples with two features each, and one label per sample.
# In this table every label equals the first feature of its sample.
data = torch.tensor(
    [
        [0.0, 0.0],
        [0.0, 1.0],
        [1.0, 0.0],
        [1.0, 1.0],
    ]
)
target = torch.tensor(
    [
        [0.0],
        [0.0],
        [1.0],
        [1.0],
    ]
)

# A Toy Model
class Net(nn.Module):
    """Tiny feed-forward network: Linear(2, 2) -> ReLU -> Linear(2, 1)."""

    def __init__(self):
        super().__init__()
        # Layers are created in this order (fc1, then fc2) so parameter
        # initialisation draws from the RNG in the same sequence as before.
        self.fc1 = nn.Linear(2, 2)
        self.fc2 = nn.Linear(2, 1)

    def forward(self, x):
        """Return fc2(relu(fc1(x))) for an input whose last dimension is 2."""
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)


model = Net()

In [4]:
# We encode everything
def _encode(plain):
    """Apply fix_precision() to `plain`, then share it between bob and alice.

    james is passed as the crypto provider and requires_grad=True is
    requested, exactly as for every object encoded in this cell.
    """
    fixed = plain.fix_precision()
    return fixed.share(bob, alice, crypto_provider=james, requires_grad=True)


# Same order as before: data first, then target, then the model.
data = _encode(data)
target = _encode(target)
model = _encode(model)



In [5]:
# SGD over the model's parameters; the optimizer goes through
# fix_precision() just like the data, target and model did above.
opt = optim.SGD(params=model.parameters(),lr=0.1).fix_precision()

# The loop variable is named `epoch` rather than `iter`: at notebook top level
# the name stays bound after the loop, and `iter` would shadow the built-in
# iter() for every later cell run in this kernel.
for epoch in range(20):
    # 1) erase previous gradients (if they exist)
    opt.zero_grad()

    # 2) make a prediction
    pred = model(data)

    # 3) calculate how much we missed (sum of squared errors)
    loss = ((pred - target)**2).sum()

    # 4) figure out which weights caused us to miss
    loss.backward()

    # 5) change those weights
    opt.step()

    # 6) print our progress: get() retrieves the loss and float_precision()
    #    turns it back into an ordinary float tensor for display
    print(loss.get().float_precision())

tensor(0.8350)
tensor(0.7070)
tensor(0.6110)
tensor(0.5240)
tensor(0.4390)
tensor(0.3560)
tensor(0.2800)
tensor(0.2140)
tensor(0.1660)
tensor(0.1350)
tensor(0.1010)
tensor(0.0740)
tensor(0.0610)
tensor(0.0380)
tensor(0.0330)
tensor(0.0210)
tensor(0.0140)
tensor(0.0110)
tensor(0.0080)
tensor(0.0070)
