# Label DP SGD

This notebook walks through how to train a model to recognize handwritten
digits using label differentially private gradient descent and the MNIST dataset.
In this setting, one party has the images and the other party has the labels.
They would like to collaborate to train a model without revealing their data.

Before starting, install the tf-shell package.

```bash
pip install tf-shell
```

First, import some modules and set up tf-shell. The parameters are for the SHELL
encryption library, which tf-shell uses, and mostly depend on the multiplicative
depth of the computation to be performed. This example performs
backpropagation, so the multiplicative depth is determined by the number of
layers. For more information, see [SHELL](https://github.com/google/shell).

In [1]:
import time
from datetime import datetime
import tensorflow as tf
import keras
import numpy as np
import tf_shell
import tf_shell_ml

2024-04-26 15:59:42.151758: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-04-26 15:59:42.173513: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Set up parameters for the SHELL encryption library.
context = tf_shell.create_context64(
    log_n=12,
    main_moduli=[288230376151760897, 288230376152137729],
    plaintext_modulus=4294991873,
    scaling_factor=3,
    mul_depth_supported=3,
    seed="test_seed",
)

# Create the secret key for encryption and a rotation key (the rotation key is
# an auxiliary key required for operations like roll or matmul).
secret_key = tf_shell.create_key64(context)
public_rotation_key = tf_shell.create_rotation_key64(context, secret_key)

# The batch size is determined by the ciphertext parameters, specifically the
# scheme's polynomial ring degree, because tf-shell uses batch axis packing.
# Furthermore, each batch is split into two micro-batches that run in parallel.
batch_size = context.num_slots
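
As a rough illustration of the batch size arithmetic, the sketch below assumes that `context.num_slots` equals the ring degree `2**log_n`. That assumption is consistent with the mini-batch size quoted later in this notebook, but it is not read from the library here. The name `micro_batch_size` is hypothetical and used only for illustration.

```python
# Assumption: the number of slots equals the ring degree 2**log_n.
log_n = 12
num_slots = 2**log_n

# Each batch is processed as two halves (top and bottom) in parallel.
micro_batch_size = num_slots // 2

print(f"batch size: {num_slots}, micro-batch size: {micro_batch_size}")
```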

Set up the MNIST dataset.

In [3]:
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train, x_test = np.reshape(x_train, (-1, 784)), np.reshape(x_test, (-1, 784))
x_train, x_test = x_train / np.float32(255.0), x_test / np.float32(255.0)
y_train, y_test = tf.one_hot(y_train, 10), tf.one_hot(y_test, 10)

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=2048).batch(batch_size)

val_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
val_dataset = val_dataset.batch(batch_size)
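
Because `Dataset.batch` keeps a final partial batch by default, the number of batches and the size of the last one follow from simple arithmetic. The sketch below assumes a batch size of 4096 (the assumed value of `context.num_slots` for `log_n=12`) and the standard MNIST split sizes. It does not touch TensorFlow.

```python
import math

num_train, num_test = 60_000, 10_000  # standard MNIST split sizes
batch_size = 4096  # assumed value of context.num_slots for log_n=12

# Number of batches, counting the final partial batch.
train_batches = math.ceil(num_train / batch_size)
val_batches = math.ceil(num_test / batch_size)

# Size of the final, partial batch in each split.
last_train_batch = num_train - (train_batches - 1) * batch_size
last_val_batch = num_test - (val_batches - 1) * batch_size

print(train_batches, last_train_batch, val_batches, last_val_batch)
```

Under these assumptions the training set yields 15 batches, which matches the `/ 15` shown in the training output, and the last training batch is not full.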

Create a simple model with a hidden layer of size 64 and an output layer
of size 10 (for each of the 10 digits).

In [4]:
# Create the layers
hidden_layer = tf_shell_ml.ShellDense(
    64,
    activation=tf_shell_ml.relu,
    activation_deriv=tf_shell_ml.relu_deriv,
    is_first_layer=True,
)
output_layer = tf_shell_ml.ShellDense(
    10,
    activation=tf.nn.softmax,
)

# Call the layers once to create the weights.
y1 = hidden_layer(tf.zeros((batch_size, 784)))
y2 = output_layer(y1)

loss_fn = tf_shell_ml.CategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam(0.01)
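
For readers who want to see the shapes involved, here is a plain NumPy stand-in for the forward pass of this two-layer model. It is a sketch only: the weights are random and hypothetical, the bias is omitted, and it does not use `tf_shell_ml.ShellDense`.

```python
import numpy as np

rng = np.random.default_rng(0)


def relu(z):
    return np.maximum(z, 0.0)


def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # stabilize the exponentials
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


# Hypothetical weights with the same shapes as the two layers (no bias).
w0 = rng.normal(scale=0.05, size=(784, 64))
w1 = rng.normal(scale=0.05, size=(64, 10))

x = rng.random((8, 784))   # small stand-in batch of flattened images
y1 = relu(x @ w0)          # hidden layer output, shape (8, 64)
y_pred = softmax(y1 @ w1)  # class probabilities, shape (8, 10)

print(y1.shape, y_pred.shape)
```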

Next, define the `train_step` function, which is called once per batch with an
encrypted batch of labels, y. The function first does a forward pass on the
plaintext image x to compute a predicted label, then does backpropagation using
the encrypted label y.
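
One reason backpropagation on encrypted labels is feasible is that, for a softmax output followed by categorical cross-entropy, the gradient of the loss with respect to the pre-softmax logits is `y_pred - y`, which is linear in the label. The NumPy sketch below checks that standard identity against finite differences. Whether `loss_fn.grad` computes exactly this expression is an assumption, and the helper names here are hypothetical.

```python
import numpy as np


def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()


def cross_entropy(z, y):
    # Categorical cross-entropy of softmax(z) against a one-hot label y.
    return -np.sum(y * np.log(softmax(z)))


rng = np.random.default_rng(0)
z = rng.normal(size=10)  # pre-softmax logits for one example
y = np.eye(10)[3]        # one-hot label for class 3

# Analytic gradient with respect to the logits: linear in the label y.
analytic = softmax(z) - y

# Central finite differences as an independent check.
eps = 1e-6
numeric = np.zeros_like(z)
for i in range(z.size):
    step = np.zeros_like(z)
    step[i] = eps
    numeric[i] = (cross_entropy(z + step, y) - cross_entropy(z - step, y)) / (2 * eps)

max_error = np.max(np.abs(analytic - numeric))
print(f"max abs difference: {max_error:.2e}")
```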

In [5]:
def train_step(x, enc_y):
    # Forward pass always in plaintext
    y_1 = hidden_layer(x)
    y_pred = output_layer(y_1)

    # Backward pass.
    dJ_dy_pred = loss_fn.grad(enc_y, y_pred)
    dJ_dw1, dJ_dx1 = output_layer.backward(dJ_dy_pred, public_rotation_key)
    dJ_dw0, _ = hidden_layer.backward(dJ_dx1, public_rotation_key)

    # dJ_dw1, the output layer gradient, would usually have shape [10] for the
    # 10 classes. tf-shell instead back propagates in two mini-batches per batch
    # resulting in two gradients of shape [10]. Furthermore, the gradients are
    # in an "expanded" form where the gradient is repeated by the size of the
    # mini-batch. Said another way, if real_grad_top/bottom is the "real"
    # gradient of shape [10] from the top/bottom halves of the batch:
    #
    # dJ_dw = tf.concat([
    #   tf.repeat(
    #       tf.expand_dims(real_grad_top, 0), repeats=[batch_sz // 2], axis=0
    #   ),
    #   tf.repeat(
    #       tf.expand_dims(real_grad_bottom, 0), repeats=[batch_sz // 2], axis=0
    #   )
    # ])
    #
    # This repetition is a result of the SHELL library using a packed
    # representation of ciphertexts for efficiency. As such, if the ciphertexts
    # need to be sent over the network, they may be masked and packed together
    # before being transmitted to the party with the key.
    #
    # Only return the weight gradients at [0], not the bias gradients at [1].
    # The bias is not used in this test.
    return dJ_dw1[0], dJ_dw0[0]


def train_step_wrapper(x_batch, y_batch):
    x_batch = tf.cast(x_batch, tf.float32)
    y_batch = tf.cast(y_batch, tf.float32)

    # Encrypt the batch of secret labels y.
    enc_y_batch = tf_shell.to_encrypted(y_batch, secret_key, context)

    # Run the training step. The top and bottom halves of the batch are
    # treated as two separate mini-batches run in parallel to maximize
    # efficiency.
    enc_output_layer_grad, enc_hidden_layer_grad = train_step(x_batch, enc_y_batch)

    # Decrypt the weight gradients. In practice, the gradients should be
    # noised before decrypting.
    repeated_output_layer_grad = tf_shell.to_tensorflow(
        enc_output_layer_grad, secret_key
    )
    repeated_hidden_layer_grad = tf_shell.to_tensorflow(
        enc_hidden_layer_grad, secret_key
    )

    # Apply the gradients to the model. The weight gradients are repeated
    # across the first dimension within each half of the batch, so index 0
    # selects the top half's gradient and index batch_size // 2 selects the
    # bottom half's. See the note in train_step for more information.
    optimizer.apply_gradients(
        zip(
            [repeated_output_layer_grad[0], repeated_hidden_layer_grad[0]],
            output_layer.weights + hidden_layer.weights,
        )
    )
    optimizer.apply_gradients(
        zip(
            [
                repeated_output_layer_grad[batch_size // 2],
                repeated_hidden_layer_grad[batch_size // 2],
            ],
            output_layer.weights + hidden_layer.weights,
        )
    )
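
The repeated gradient layout described in the `train_step` comments can be reproduced in plain NumPy. The sketch below uses a small stand-in batch size and random, hypothetical gradients of shape `[10]`. It shows why indexing at `0` and at `batch_size // 2` recovers one gradient per half-batch. No encryption is involved.

```python
import numpy as np

batch_size = 8  # small stand-in for context.num_slots
half = batch_size // 2
rng = np.random.default_rng(0)

# Hypothetical "real" gradients from the top and bottom halves of the batch.
real_grad_top = rng.normal(size=(10,))
real_grad_bottom = rng.normal(size=(10,))

# Expanded layout: each gradient is repeated along a leading batch axis.
repeated = np.concatenate(
    [
        np.repeat(real_grad_top[np.newaxis], half, axis=0),
        np.repeat(real_grad_bottom[np.newaxis], half, axis=0),
    ],
    axis=0,
)

# Any row in a half holds that half's gradient, so one row per half suffices.
recovered_top = repeated[0]
recovered_bottom = repeated[half]

print(repeated.shape)
```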

Here is the training loop. Each inner iteration runs two mini-batches of size
$2^{12-1} = 2048$ simultaneously.

TensorBoard can be used to visualize the training progress. See the cell output
for the command to start TensorBoard.

In [6]:
epochs = 1
start_time = time.time()

# Set up tensorboard logging.
stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = "/tmp/tflogs/pt-%s" % stamp
print("To start tensorboard, run: tensorboard --logdir /tmp/tflogs")
writer = tf.summary.create_file_writer(logdir)

for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))

    # Iterate over the batches of the dataset.
    for step, (x_batch, y_batch) in enumerate(train_dataset.take(batch_size)):
        print(
            f"Epoch: {epoch}, Batch: {step} / {len(train_dataset)}, Time Stamp: {time.time() - start_time}"
        )

        # Skip the last batch if it is not full for performance.
        if x_batch.shape[0] != batch_size:
            break

        # If using deferred execution, one can trace and profile the training.
        # tf.summary.trace_on(graph=True, profiler=True, profiler_outdir=logdir)

        train_step_wrapper(x_batch, y_batch)

        # with writer.as_default():
        #     tf.summary.trace_export(
        #         name="tf_shell_example_label_dp_sgd", step=(epoch + 1) * step
        #     )

        # Check the accuracy.
        average_loss = 0
        average_accuracy = 0
        for x, y in val_dataset:
            y_pred = output_layer(hidden_layer(x))
            accuracy = tf.reduce_mean(
                tf.cast(
                    tf.equal(tf.argmax(y, axis=1), tf.argmax(y_pred, axis=1)), tf.float32
                )
            )
            average_accuracy += accuracy
        average_loss /= len(val_dataset)
        average_accuracy /= len(val_dataset)
        tf.print(f"\taccuracy: {accuracy}")

        # Log against a global step so each batch gets its own point.
        global_step = epoch * len(train_dataset) + step
        with writer.as_default():
            tf.summary.scalar("loss", average_loss, step=global_step)
            tf.summary.scalar("accuracy", average_accuracy, step=global_step)


print(f"Total plaintext training time: {time.time() - start_time} seconds")

To start tensorboard, run: tensorboard --logdir /tmp/tflogs

Start of epoch 0
Epoch: 0, Batch: 0 / 15, Time Stamp: 0.07017874717712402
	accuracy: 0.11117256432771683


2024-04-26 16:05:57.677839: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch: 0, Batch: 1 / 15, Time Stamp: 360.4916572570801
	accuracy: 0.1150442510843277


2024-04-26 16:11:59.988257: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch: 0, Batch: 2 / 15, Time Stamp: 722.7873704433441
	accuracy: 0.11117256432771683


2024-04-26 16:18:09.918304: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch: 0, Batch: 3 / 15, Time Stamp: 1092.714866399765
	accuracy: 0.10951327532529831


2024-04-26 16:24:15.157070: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch: 0, Batch: 4 / 15, Time Stamp: 1457.9539773464203
	accuracy: 0.11946902424097061


2024-04-26 16:30:21.150770: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch: 0, Batch: 5 / 15, Time Stamp: 1823.9476954936981
	accuracy: 0.12555310130119324


2024-04-26 16:36:25.598224: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch: 0, Batch: 6 / 15, Time Stamp: 2188.3950967788696
	accuracy: 0.13993363082408905


2024-04-26 16:42:28.388670: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch: 0, Batch: 7 / 15, Time Stamp: 2551.185366868973
	accuracy: 0.15873894095420837


2024-04-26 16:48:32.459850: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch: 0, Batch: 8 / 15, Time Stamp: 2915.256863594055
	accuracy: 0.1692477911710739


2024-04-26 16:54:34.126222: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch: 0, Batch: 9 / 15, Time Stamp: 3276.922953605652
	accuracy: 0.17865043878555298


2024-04-26 17:00:35.599672: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch: 0, Batch: 10 / 15, Time Stamp: 3638.3964817523956
	accuracy: 0.1875


2024-04-26 17:06:38.453861: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch: 0, Batch: 11 / 15, Time Stamp: 4001.250802755356
	accuracy: 0.20630531013011932


2024-04-26 17:12:41.857820: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch: 0, Batch: 12 / 15, Time Stamp: 4364.654639005661
	accuracy: 0.21902655065059662


2024-04-26 17:18:46.174162: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch: 0, Batch: 13 / 15, Time Stamp: 4728.971588611603
	accuracy: 0.24170354008674622
Epoch: 0, Batch: 14 / 15, Time Stamp: 5091.019182920456
Total plaintext training time: 5091.019740104675 seconds


2024-04-26 17:24:48.222336: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
