
Deep Learning Training with Multi-Party Computation

Small set of scripts for training MNIST and CIFAR10 with a number of networks using MP-SPDZ. This version underlies the figures in https://arxiv.org/abs/2107.00501. For the version underlying https://eprint.iacr.org/2020/1330 and https://arxiv.org/abs/2011.11202, see the v1 branch.

Installation (with Docker)

Run the following from this directory to build a Docker container:

$ docker build .

The Dockerfile contains some commented-out tests at the end.
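If you want to work inside the resulting image interactively, you can tag it at build time and start a shell in a container. This is only a sketch: the deep-mpc tag and the bash entry point are illustrative choices, not something the repository requires.

$ docker build -t deep-mpc .     # same build as above, with an image tag
$ docker run -it deep-mpc bash   # interactive shell inside the container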

Running locally

After setting everything up, you can use this script to run the computation:

$ ./run-local.sh <protocol> <net> <round> <n_threads> <n_epochs> <precision> [<further>]

The options are as follows:

  • protocol is one of emul (emulation), sh2 (semi-honest two-party computation), sh3 (semi-honest three-party computation), mal3 (malicious three-party computation), mal4 (malicious four-party computation), dm3 (dishonest-majority semi-honest three-party computation), sh10 (semi-honest ten-party computation), or dm10 (dishonest-majority semi-honest ten-party computation). All protocols assume an honest majority unless stated otherwise.
  • net is the network (A-D for MNIST, alex for Falcon AlexNet on CIFAR10, and new_alex for a more sophisticated AlexNet-like network on CIFAR10).
  • round is the kind of rounding, prob for probabilistic and near for nearest.
  • n_threads is the number of threads per party.
  • n_epochs is the number of epochs.
  • precision is the precision (in bits) after the decimal point.

The following options can be given in addition.

  • rate<rate> sets the learning rate, e.g., rate.1 for a learning rate of 0.1.
  • adam, adamapprox, or amsgrad to use Adam, Adam with less precise approximation of inverse square root, or AMSgrad instead of SGD.

For example,

$ ./run-local.sh emul D prob 2 20 32 adamapprox

runs 20 epochs of training network D in emulation with two threads, probabilistic rounding, 32-bit precision, and Adam with the less precise inverse square root approximation.

Running remotely

You need to set up hosts that run SSH and have all higher TCP ports open between each other. We have used c5.9xlarge instances in the same AWS zone and hence 36 threads. The hosts have to run Linux with a glibc not older than that of Ubuntu 18.04 (2.27), which is the case for Amazon Linux 2. Honest-majority protocols require three hosts while dishonest-majority protocols require two.

With Docker, you can run the following script to set up host names, user name and SSH RSA key. We do NOT recommend running it outside Docker because it might overwrite an existing RSA key file.

$ ./setup-remote.sh

Without Docker, familiarise yourself with SSH configuration options and SSH keys. You can use ssh_config and the above script to find out the requirements. HOSTS has to contain the hostnames separated by whitespace.
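As a minimal sketch of what this can look like (the host names, user name, and key path below are purely illustrative, and it is assumed here that HOSTS is a file in this directory; the authoritative requirements are what ssh_config and setup-remote.sh expect):

$ cat ~/.ssh/config
Host mpc0 mpc1 mpc2
    User ubuntu
    IdentityFile ~/.ssh/id_rsa

$ cat HOSTS
mpc0 mpc1 mpc2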

After setting up, you can run the following using the same options as above:

$ ./run-remote.sh <protocol> <net> <round> <n_threads> <n_epochs> <precision> [<further>]

For example,

$ ./run-remote.sh sh3 A near 1 1 16 rate.1

runs one epoch of training network A with semi-honest three-party computation, one thread, nearest rounding, 16-bit precision, and SGD with a learning rate of 0.1.

Cleartext training

train.py allows running the equivalent cleartext training with TensorFlow (using floating-point instead of fixed-point representation). Simply run the following:

$ ./train.py <network> <n_epochs> [<optimizer>] [<further>]

<optimizer> is either a learning rate for SGD, or adam or amsgrad followed by a learning rate. <further> is any combination of dropout to add a Dropout layer (only with network C) and fashion for Fashion MNIST.
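For example, the following would run ten epochs of network C with Adam at a learning rate of 0.001, a Dropout layer, and Fashion MNIST (the particular network, epoch count, and rate are illustrative, not recommendations):

$ ./train.py C 10 adam 0.001 dropout fashion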

Data format and scripts

full.py processes the whole MNIST dataset, and full01.py processes two digits to allow logistic regression and similar. By default the digits are 0 and 1, but you can request any pair. The following processes the digits 4 and 9, which are harder to distinguish:

$ ./full01.py 49

Both output to stdout in the format that mnist_*.mpc expect in the input file for party 0, that is: training set labels, training set data (example by example), test set labels, test set data (example by example), all as whitespace-separated text. Labels are stored as one-hot vectors for the whole dataset and as 0/1 for logistic regression.

The scripts expect the unzipped MNIST dataset in the current directory. Run download.sh to download the MNIST dataset.
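Putting these steps together, a minimal preparation run might look like the following. The redirect target is an assumption: MP-SPDZ reads party 0's input from Player-Data/Input-P0-0 by default, but the run scripts in this repository may already take care of this step.

$ ./download.sh                        # fetch and unzip MNIST into the current directory
$ ./full.py > Player-Data/Input-P0-0   # assumed default MP-SPDZ input file for party 0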
