<a href="https://colab.research.google.com/github/strongrunner/Secure-Machine-Learning/blob/main/Encrypted_Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install tf-encrypted
!pip install syft
! URL="https://github.com/openmined/PySyft.git" && FOLDER="PySyft" && if [ ! -d $FOLDER ]; then git clone -b ryffel/4P --single-branch $URL; else (cd $FOLDER && git pull $URL && cd ..); fi;

!cd PySyft; python setup.py install  > /dev/null

import os
import sys
module_path = os.path.abspath(os.path.join('./PySyft'))
if module_path not in sys.path:
    sys.path.append(module_path)
    
!pip install --upgrade --force-reinstall lz4
!pip install --upgrade --force-reinstall websocket
!pip install --upgrade --force-reinstall websockets

Collecting tf-encrypted
[?25l  Downloading https://files.pythonhosted.org/packages/15/be/a4c0af9fdc5e5cee28495460538acf2766382bd572e01d4847abc7608dba/tf_encrypted-0.5.9-py3-none-manylinux1_x86_64.whl (2.7MB)
[K     |████████████████████████████████| 2.7MB 5.8MB/s 
Collecting pyyaml>=5.1
[?25l  Downloading https://files.pythonhosted.org/packages/64/c2/b80047c7ac2478f9501676c988a5411ed5572f35d1beff9cae07d321512c/PyYAML-5.3.1.tar.gz (269kB)
[K     |████████████████████████████████| 276kB 43.2MB/s 
[?25hCollecting tensorflow<2,>=1.12.0
[?25l  Downloading https://files.pythonhosted.org/packages/8e/64/7a19837dd54d3f53b1ce5ae346ab401dde9678e8f233220317000bfdb3e2/tensorflow-1.15.4-cp36-cp36m-manylinux2010_x86_64.whl (110.5MB)
[K     |████████████████████████████████| 110.5MB 78kB/s 
[?25hCollecting gast==0.2.2
  Downloading https://files.pythonhosted.org/packages/4e/35/11749bf99b2d4e3cceb4d55ca22590b0d7c2c62b9de38ac4a4a7f4687421/gast-0.2.2.tar.gz
Collecting tensorflow-estimator==1.15.1

Collecting websockets
  Using cached https://files.pythonhosted.org/packages/bb/d9/856af84843912e2853b1b6e898ac8b802989fcf9ecf8e8445a1da263bf3b/websockets-8.1-cp36-cp36m-manylinux2010_x86_64.whl
Installing collected packages: websockets
  Found existing installation: websockets 8.1
    Uninstalling websockets-8.1:
      Successfully uninstalled websockets-8.1
Successfully installed websockets-8.1


## Imports and training configuration

In [None]:
from __future__ import print_function
import argparse
import time

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
# Display the PyTorch version in use (the recorded output below shows the build this notebook ran with).
torch.__version__

'1.7.0+cu101'

In [None]:
class Arguments:
    """Training hyper-parameters, read as plain attributes (``args.lr``, ``args.epochs``, ...)."""

    def __init__(self):
        # All settings are installed in one call; names/values form the run configuration.
        vars(self).update(
            batch_size=64,           # samples per private training batch
            test_batch_size=64,      # samples per private test batch
            epochs=10,               # passes over the private training batches
            lr=0.02,                 # SGD learning rate
            seed=1,                  # torch RNG seed
            log_interval=1,          # log every `log_interval` batches (1 = every batch)
            precision_fractional=3,  # precision passed to `.fix_precision()` when encoding the data
        )


args = Arguments()

# Seed torch's global RNG; the returned Generator object is not needed.
_ = torch.manual_seed(args.seed)

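## Secret-share the data across workers

Create two virtual workers, each of which will hold a share of every private tensor, plus a crypto provider that is passed to `.share()` alongside them.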

In [None]:
import syft as sy
hook = sy.TorchHook(torch)


def connect_to_workers(n_workers):
    """Create ``n_workers`` in-process syft VirtualWorkers with ids worker1 .. workerN."""
    return [sy.VirtualWorker(hook, id=f"worker{n}") for n in range(1, n_workers + 1)]


def connect_to_crypto_provider():
    """Create the VirtualWorker later passed as ``crypto_provider`` when sharing tensors."""
    return sy.VirtualWorker(hook, id="crypto_provider")


workers = connect_to_workers(n_workers=2)
crypto_provider = connect_to_crypto_provider()

In [None]:

# Only the first ceil(n_items / batch_size) batches of each MNIST split are secret-shared,
# keeping the encrypted run small.
n_train_items = n_test_items = 640

def get_private_data_loaders(precision_fractional, workers, crypto_provider):
    """Load the first MNIST batches, fixed-precision encode them and secret-share them.

    Only the first ``ceil(n_train_items / args.batch_size)`` training batches and the
    first ``ceil(n_test_items / args.test_batch_size)`` test batches are kept. These are
    read from the module-level ``n_train_items``, ``n_test_items`` and ``args``. The
    DataLoaders are not shuffled, so the same leading samples are used every time.

    Args:
        precision_fractional: precision passed to ``.fix_precision()`` for every tensor.
        workers: syft workers that receive the shares.
        crypto_provider: worker passed to ``.share()`` as ``crypto_provider``.

    Returns:
        ``(private_train_loader, private_test_loader)``: lists of ``(data, target)``
        pairs of shared tensors. Training targets are one-hot encoded over 10 classes;
        test targets are the class indices cast to float.
    """
    import math
    from itertools import islice

    def one_hot_of(index_tensor):
        """One-hot encode a (presumably 1-D) tensor of class indices into 10 columns."""
        onehot_tensor = torch.zeros(*index_tensor.shape, 10)
        return onehot_tensor.scatter(1, index_tensor.view(-1, 1), 1)

    def secret_share(tensor):
        """Fixed-precision encode ``tensor`` and split it into shares over ``workers``."""
        return (
            tensor
            .fix_precision(precision_fractional=precision_fractional)
            .share(*workers, crypto_provider=crypto_provider, requires_grad=True)
        )

    def first_batches(loader, n_items, batch_size):
        """Yield only the batches whose index i satisfies i < n_items / batch_size.

        Stops the DataLoader early instead of loading and transforming the whole
        dataset just to discard everything after the first few batches.
        """
        return islice(loader, max(0, math.ceil(n_items / batch_size)))

    # Normalise with the commonly used MNIST mean / std.
    transformation = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True, transform=transformation),
        batch_size=args.batch_size
    )
    private_train_loader = [
        (secret_share(data), secret_share(one_hot_of(target)))
        for data, target in first_batches(train_loader, n_train_items, args.batch_size)
    ]

    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, download=True, transform=transformation),
        batch_size=args.test_batch_size
    )
    private_test_loader = [
        (secret_share(data), secret_share(target.float()))
        for data, target in first_batches(test_loader, n_test_items, args.test_batch_size)
    ]

    return private_train_loader, private_test_loader
    
    
# Build the secret-shared train/test batches once; the epoch loop reuses them.
private_train_loader, private_test_loader = get_private_data_loaders(
    args.precision_fractional, workers, crypto_provider
)

In [None]:
class Net(nn.Module):
    """Three-layer MLP for 28x28 inputs: 784 -> 128 -> 64 -> 10, ReLU between layers."""

    def __init__(self):
        super().__init__()
        # Creation order is kept so parameter initialisation draws from the RNG identically.
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Flatten each sample to 784 values and return raw (un-normalised) class scores."""
        flat = x.view(-1, 28 * 28)
        hidden1 = F.relu(self.fc1(flat))
        hidden2 = F.relu(self.fc2(hidden1))
        return self.fc3(hidden2)

In [None]:
def train(args, model, private_train_loader, optimizer, epoch):
    """Run one epoch of SGD over the secret-shared training batches.

    Args:
        args: config providing ``batch_size`` and ``log_interval``.
        model: fixed-precision, secret-shared model (switched to train mode here).
        private_train_loader: list of shared ``(data, target)`` pairs; targets one-hot.
        optimizer: optimizer whose hyper-parameters were passed through ``fix_precision()``.
        epoch: epoch number, used only in the log line.
    """
    model.train()
    n_batches = len(private_train_loader)
    for batch_idx, (data, target) in enumerate(private_train_loader):  # private (shared) dataset
        start_time = time.time()

        optimizer.zero_grad()
        output = model(data)

        # F.nll_loss(output, target) is not possible on this private data, so use a
        # squared error against the one-hot targets, averaged over the batch.
        batch_size = output.shape[0]
        loss = ((output - target) ** 2).sum().refresh() / batch_size

        loss.backward()
        optimizer.step()

        if batch_idx % args.log_interval == 0:
            # Decrypting the loss reveals it to the caller, so only do it when it is reported.
            loss = loss.get().float_precision()
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tTime: {:.3f}s'.format(
                epoch, batch_idx * args.batch_size, n_batches * args.batch_size,
                100. * batch_idx / n_batches, loss.item(), time.time() - start_time))

In [None]:
def test(args, model, private_test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in private_test_loader:
            start_time = time.time()
            
            output = model(data)
            pred = output.argmax(dim=1)
            correct += pred.eq(target.view_as(pred)).sum()

    correct = correct.get().float_precision()
  #  print('\nTest set: Accuracy: {}/{} ({:.0f}%)\n'.format(
   #     correct.item(), len(private_test_loader)* args.test_batch_size,
    #    100. * correct.item() / (len(private_test_loader) * args.test_batch_size)))

In [None]:
# Encode the model with the same fixed-point precision as the data, then secret-share it.
model = Net()
model = model.fix_precision(precision_fractional=args.precision_fractional).share(
    *workers, crypto_provider=crypto_provider, requires_grad=True
)

# Keep the optimizer's hyper-parameters (e.g. lr) on the same fixed-point precision as
# the model and data, so changing `args.precision_fractional` cannot mix encodings.
optimizer = optim.SGD(model.parameters(), lr=args.lr)
optimizer = optimizer.fix_precision(precision_fractional=args.precision_fractional)

for epoch in range(1, args.epochs + 1):
    train(args, model, private_train_loader, optimizer, epoch)
    test(args, model, private_test_loader)


Test set: Accuracy: 221.0/640 (35%)


Test set: Accuracy: 362.0/640 (57%)


Test set: Accuracy: 400.0/640 (62%)


Test set: Accuracy: 429.0/640 (67%)


Test set: Accuracy: 445.0/640 (70%)


Test set: Accuracy: 461.0/640 (72%)


Test set: Accuracy: 473.0/640 (74%)


Test set: Accuracy: 478.0/640 (75%)


Test set: Accuracy: 482.0/640 (75%)


Test set: Accuracy: 483.0/640 (75%)

