# Label DP SGD

This notebook walks through how to train a model to recognize hand written
digits using label differentially private gradient decent and the MNIST dataset.
In this setting, one party has the images and the other party has the labels.
They would like to collaborate to train a model without revealing their data.

Before starting, install the tf-shell package.

```bash
pip install tf-shell
```

First, import some modules and set up tf-shell. The parameters are for the SHELL
encryption library, which tf-shell uses, and mostly depend on the multiplicative
depth of the computation to be performed. This example performs back
propagation, thus the multiplicative depth is determined by the number of
layers. For more information, see [SHELL](https://github.com/google/shell).

In [1]:
import time
import os
from datetime import datetime
import tensorflow as tf
import keras
import numpy as np
import tf_shell
import tf_shell_ml

2024-06-10 22:24:25.445864: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-06-10 22:24:25.446300: I external/local_tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2024-06-10 22:24:25.448282: I external/local_tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2024-06-10 22:24:25.475776: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Set up parameters for the SHELL encryption library.
context = tf_shell.create_context64(
    log_n=12,
    main_moduli=[288230376151760897, 288230376152137729],
    plaintext_modulus=4294991873,
    scaling_factor=3,
    mul_depth_supported=3,
    seed="test_seed",
)

# Create the secret key for encryption and a rotation key (rotation key is
# an auxilary key required for operations like roll or matmul).
secret_key = tf_shell.create_key64(context)
public_rotation_key = tf_shell.create_rotation_key64(context, secret_key)

# The batch size is determined by the ciphertext parameters, specifically the
# schemes polynomial's ring degree because tf-shell uses batch axis packing.
# Furthermore, two micro-batches to run in parallel.
batch_size = context.num_slots

Setup MNIST dataset.

In [3]:
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train, x_test = np.reshape(x_train, (-1, 784)), np.reshape(x_test, (-1, 784))
x_train, x_test = x_train / np.float32(255.0), x_test / np.float32(255.0)
y_train, y_test = tf.one_hot(y_train, 10), tf.one_hot(y_test, 10)

epochs = 1
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = (
    train_dataset.shuffle(buffer_size=2048)
    .batch(batch_size, drop_remainder=True)
    .repeat(count=epochs)
)

val_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
val_dataset = val_dataset.batch(batch_size, drop_remainder=True)

Create a simple model with a hidden layer of size 64 and an output layer
of size 10 (for each of the 10 digits).

In [4]:
# Create the layers
hidden_layer = tf_shell_ml.ShellDense(
    64,
    activation=tf_shell_ml.relu,
    activation_deriv=tf_shell_ml.relu_deriv,
    is_first_layer=True,
)
output_layer = tf_shell_ml.ShellDense(
    10,
    activation=tf.nn.softmax,
)

# Call the layers once to create the weights.
y1 = hidden_layer(tf.zeros((batch_size, 784)))
y2 = output_layer(y1)

loss_fn = tf_shell_ml.CategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam(0.01)

Next, define the `train_step` function which will be called for each batch on an
encrypted batch of labels, y. The function first does a forward on the plaintext
image x to compute a predicted label, then does backpropagation using the
encrypted label y.

In [5]:
@tf.function
def train_step(x, enc_y):
    # Forward pass always in plaintext
    y_1 = hidden_layer(x)
    y_pred = output_layer(y_1)

    # Backward pass.
    dJ_dy_pred = loss_fn.grad(enc_y, y_pred)
    dJ_dw1, dJ_dx1 = output_layer.backward(dJ_dy_pred, public_rotation_key)
    dJ_dw0, _ = hidden_layer.backward(dJ_dx1, public_rotation_key)

    # dJ_dw1, the output layer gradient, would usually have shape [10] for the
    # 10 classes. tf-shell instead back propagates in two mini-batches per batch
    # resulting in two gradients of shape [10]. Furthermore, the gradients are
    # in an "expanded" form where the gradient is repeated by the size of the
    # batch. Said another way, if real_grad_top/bottom is the "real" gradient of
    # shape [10] from the top/bottom halves of the batch:
    #
    # dJ_dw = tf.concat([
    #   tf.repeat(
    #       tf.expand_dims(real_grad_top, 0), repeats=[batch_sz // 2], axis=0
    #   ),
    #   tf.repeat(
    #       tf.expand_dims(real_grad_bottom, 0), repeats=[batch_sz // 2], axis=0
    #   )
    # ])
    #
    # This repetition is result of the SHELL library using a packed
    # representation of ciphertexts for efficiency. As such, if the ciphertexts
    # need to be sent over the network, they may be masked and packed together
    # before being transmitted to the party with the key.
    #
    # Only return the weight gradients at [0], not the bias gradients at [1].
    # The bias is not used in this test.
    return [dJ_dw1[0], dJ_dw0[0]]


@tf.function
def train_step_wrapper(x_batch, y_batch):
    x_batch = tf.cast(x_batch, tf.float32)
    y_batch = tf.cast(y_batch, tf.float32)

    # Encrypt the batch of secret labels y.
    enc_y_batch = tf_shell.to_encrypted(y_batch, secret_key, context)

    # Run the training step. The top and bottom halves of the batch are
    # treated as two separate mini-batches run in parallel to maximize
    # efficiency.
    enc_grads = train_step(x_batch, enc_y_batch)

    # Decrypt the weight gradients. In practice, the gradients should be
    # noised before decrypting.
    repeated_grads = [tf_shell.to_tensorflow(g, secret_key) for g in enc_grads]

    # Pull out grads from the top and bottom batches.
    top_grad = [g[0] for g in repeated_grads]
    bottom_grad = [g[batch_size // 2] for g in repeated_grads]

    # Decrypt the weight gradients. In practice, the gradients should be
    # noised before decrypting.
    weights = output_layer.weights + hidden_layer.weights

    optimizer.apply_gradients(zip(top_grad, weights))
    optimizer.apply_gradients(zip(bottom_grad, weights))

Here is the training loop. Each inner iteration runs two batches of size
$2^{12-1}$ simultaneously.

Tensorboard can be used to visualize the training progress. See cell output for
command to start tensorboard.

In [6]:
start_time = time.time()

# Set up tensorboard logging.
stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = os.path.abspath("") + "/tflogs/dp-sgd-%s" % stamp
print(f"To start tensorboard, run: tensorboard --logdir /tmp/tflogs")
print(f"\ttensorboard profiling requires: pip install tensorboard_plugin_profile")
writer = tf.summary.create_file_writer(logdir)

for step, (x_batch, y_batch) in enumerate(train_dataset.take(batch_size)):
    print(
        f"Batch: {step} / {len(train_dataset)}, Time Stamp: {time.time() - start_time}"
    )

    if step == 0:
        tf.summary.trace_on(graph=True, profiler=True, profiler_outdir=logdir)

    train_step_wrapper(x_batch, y_batch)

    if step == 0:
        with writer.as_default():
            tf.summary.trace_export(name="label_dp_sgd", step=step)

    # Check the accuracy.
    average_loss = 0
    average_accuracy = 0
    for x, y in val_dataset:
        y_pred = output_layer(hidden_layer(x))
        loss = tf.reduce_mean(loss_fn(y, y_pred))
        accuracy = tf.reduce_mean(
            tf.cast(
                tf.equal(tf.argmax(y, axis=1), tf.argmax(y_pred, axis=1)), tf.float32
            )
        )
        average_loss += loss
        average_accuracy += accuracy
    average_loss /= len(val_dataset)
    average_accuracy /= len(val_dataset)
    tf.print(f"\taccuracy: {accuracy}")

    with writer.as_default():
        tf.summary.scalar("loss", average_loss, step=step)
        tf.summary.scalar("accuracy", average_accuracy, step=step)


print(f"Total training time: {time.time() - start_time} seconds")

To start tensorboard, run: tensorboard --logdir /tmp/tflogs
	tensorboard profiling requires: pip install tensorboard_plugin_profile
Batch: 0 / 14, Time Stamp: 0.06959199905395508


2024-06-10 22:24:40.385175: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:104] Profiler session initializing.
2024-06-10 22:24:40.385203: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:119] Profiler session started.


	accuracy: 0.119384765625


2024-06-10 22:30:47.746997: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:70] Profiler session collecting data.
2024-06-10 22:30:47.758943: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:131] Profiler session tear down.
2024-06-10 22:30:47.759670: I external/local_tsl/tsl/profiler/rpc/client/save_profile.cc:144] Collecting XSpace to repository: /workspaces/tf-shell/examples/tflogs/dp-sgd-20240610-222440/plugins/profile/2024_06_10_22_30_47/e64b0b6b3843.xplane.pb
2024-06-10 22:30:47.796887: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 1 / 14, Time Stamp: 367.5231137275696
	accuracy: 0.130615234375


2024-06-10 22:36:52.484897: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 2 / 14, Time Stamp: 732.1724491119385
	accuracy: 0.12939453125


2024-06-10 22:42:57.920091: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 3 / 14, Time Stamp: 1097.6079428195953
	accuracy: 0.1337890625


2024-06-10 22:48:57.431460: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 4 / 14, Time Stamp: 1457.1190173625946
	accuracy: 0.144287109375


2024-06-10 22:54:55.687244: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 5 / 14, Time Stamp: 1815.3750908374786
	accuracy: 0.161376953125


2024-06-10 23:00:51.842868: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 6 / 14, Time Stamp: 2171.5303070545197
	accuracy: 0.1787109375


2024-06-10 23:06:46.069293: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 7 / 14, Time Stamp: 2525.7567131519318
	accuracy: 0.189453125


2024-06-10 23:12:38.180800: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 8 / 14, Time Stamp: 2877.8683331012726
	accuracy: 0.207275390625


2024-06-10 23:18:34.076211: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 9 / 14, Time Stamp: 3233.7637860774994
	accuracy: 0.2255859375


2024-06-10 23:24:32.819610: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 10 / 14, Time Stamp: 3592.507264852524
	accuracy: 0.2412109375


2024-06-10 23:30:24.891742: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 11 / 14, Time Stamp: 3944.5792849063873
	accuracy: 0.255859375


2024-06-10 23:36:22.160473: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 12 / 14, Time Stamp: 4301.848348855972
	accuracy: 0.2705078125


2024-06-10 23:42:21.741076: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 13 / 14, Time Stamp: 4661.428592205048
	accuracy: 0.28564453125
Total training time: 5035.84995174408 seconds


2024-06-10 23:48:36.161868: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-06-10 23:48:36.164909: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
