# Label DP SGD

This notebook walks through how to train a model to recognize hand written
digits using label differentially private gradient decent and the MNIST dataset.
In this setting, one party has the images and the other party has the labels.
They would like to collaborate to train a model without revealing their data.

Before starting, install the tf-shell package.

```bash
pip install tf-shell
```

First, import some modules and set up tf-shell. The parameters are for the SHELL
encryption library, which tf-shell uses, and mostly depend on the multiplicative
depth of the computation to be performed. This example performs back
propagation, thus the multiplicative depth is determined by the number of
layers. For more information, see [SHELL](https://github.com/google/shell).

In [1]:
import time
import os
from datetime import datetime
import tensorflow as tf
import keras
import numpy as np
import tf_shell
import tf_shell_ml

2024-08-15 18:41:41.612358: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-08-15 18:41:41.634304: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
use_fast_rotation_protocol = True

# Set up parameters for the SHELL encryption library.
if use_fast_rotation_protocol:
    # Num plaintext bits: 32, noise bits: 86
    # Max representable value: 477221319
    context = tf_shell.create_context64(
        log_n=12,
        main_moduli=[288230376151760897, 288230376152137729],
        plaintext_modulus=4294991873,
        scaling_factor=3,
        mul_depth_supported=1,
        seed="test_seed",
    )
else:
    # Num plaintext bits: 32, noise bits: 86
    # Max representable value: 477221319
    context = tf_shell.create_context64(
        log_n=12,
        main_moduli=[288230376151760897, 288230376152137729],
        plaintext_modulus=4294991873,
        scaling_factor=3,
        mul_depth_supported=1,
        seed="test_seed",
    )

# Create the secret key for encryption and a rotation key (rotation key is
# an auxilary key required for operations like roll or matmul).
secret_key = tf_shell.create_key64(context)
public_rotation_key = tf_shell.create_rotation_key64(context, secret_key)
secret_fast_rotation_key = tf_shell.create_fast_rotation_key64(context, secret_key)

# The batch size is determined by the ciphertext parameters, specifically the
# schemes polynomial's ring degree because tf-shell uses batch axis packing.
# Furthermore, two micro-batches to run in parallel.
batch_size = context.num_slots

Setup MNIST dataset.

In [3]:
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train, x_test = np.reshape(x_train, (-1, 784)), np.reshape(x_test, (-1, 784))
x_train, x_test = x_train / np.float32(255.0), x_test / np.float32(255.0)
y_train, y_test = tf.one_hot(y_train, 10), tf.one_hot(y_test, 10)

epochs = 1
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = (
    train_dataset.shuffle(buffer_size=2048)
    .batch(batch_size, drop_remainder=True)
    .repeat(count=epochs)
)

val_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
val_dataset = val_dataset.batch(batch_size, drop_remainder=True)

Create a simple model with a hidden layer of size 64 and an output layer
of size 10 (for each of the 10 digits).

In [4]:
# Create the layers
hidden_layer = tf_shell_ml.ShellDense(
    64,
    activation=tf_shell_ml.relu,
    activation_deriv=tf_shell_ml.relu_deriv,
    is_first_layer=True,
    use_fast_reduce_sum=use_fast_rotation_protocol,
)
output_layer = tf_shell_ml.ShellDense(
    10,
    activation=tf.nn.softmax,
    use_fast_reduce_sum=use_fast_rotation_protocol,
)

# Call the layers once to create the weights.
y1 = hidden_layer(tf.zeros((batch_size, 784)))
y2 = output_layer(y1)

loss_fn = tf_shell_ml.CategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam(0.01)

Next, define the `train_step` function which will be called for each batch on an
encrypted batch of labels, y. The function first does a forward on the plaintext
image x to compute a predicted label, then does backpropagation using the
encrypted label y.

In [5]:
@tf.function
def train_step(x, enc_y):
    # Forward pass always in plaintext
    y_1 = hidden_layer(x)
    y_pred = output_layer(y_1)

    # Backward pass.
    dJ_dy_pred = loss_fn.grad(enc_y, y_pred)
    dJ_dw1, dJ_dx1 = output_layer.backward(dJ_dy_pred, public_rotation_key)
    dJ_dw0, _ = hidden_layer.backward(dJ_dx1, public_rotation_key)

    # dJ_dw1, the output layer gradient, would usually have shape [10] for the
    # 10 classes. tf-shell instead back propagates in two mini-batches per batch
    # resulting in two gradients of shape [10]. Furthermore, the gradients are
    # in an "expanded" form where the gradient is repeated by the size of the
    # batch. Said another way, if real_grad_top/bottom is the "real" gradient of
    # shape [10] from the top/bottom halves of the batch:
    #
    # dJ_dw = tf.concat([
    #   tf.repeat(
    #       tf.expand_dims(real_grad_top, 0), repeats=[batch_sz // 2], axis=0
    #   ),
    #   tf.repeat(
    #       tf.expand_dims(real_grad_bottom, 0), repeats=[batch_sz // 2], axis=0
    #   )
    # ])
    #
    # This repetition is result of the SHELL library using a packed
    # representation of ciphertexts for efficiency. As such, if the ciphertexts
    # need to be sent over the network, they may be masked and packed together
    # before being transmitted to the party with the key.
    #
    # Only return the weight gradients at [0], not the bias gradients at [1].
    # The bias is not used in this test.
    return [dJ_dw1[0], dJ_dw0[0]]


@tf.function
def train_step_wrapper(x_batch, y_batch):
    x_batch = tf.cast(x_batch, tf.float32)
    y_batch = tf.cast(y_batch, tf.float32)

    # Encrypt the batch of secret labels y.
    enc_y_batch = tf_shell.to_encrypted(y_batch, secret_key, context)

    # Run the training step. The top and bottom halves of the batch are
    # treated as two separate mini-batches run in parallel to maximize
    # efficiency.
    enc_grads = train_step(x_batch, enc_y_batch)

    # Decrypt the weight gradients. In practice, the gradients should be
    # noised before decrypting.
    if use_fast_rotation_protocol:
        decrypt_key = secret_fast_rotation_key
    else:
        decrypt_key = secret_key
    repeated_grads = [tf_shell.to_tensorflow(g, decrypt_key) for g in enc_grads]

    # Pull out grads from the top and bottom batches.
    top_grad = [g[0] for g in repeated_grads]
    bottom_grad = [g[batch_size // 2] for g in repeated_grads]

    # Decrypt the weight gradients. In practice, the gradients should be
    # noised before decrypting.
    weights = output_layer.weights + hidden_layer.weights

    optimizer.apply_gradients(zip(top_grad, weights))
    optimizer.apply_gradients(zip(bottom_grad, weights))

Here is the training loop. Each inner iteration runs two batches of size
$2^{12-1}$ simultaneously.

Tensorboard can be used to visualize the training progress. See cell output for
command to start tensorboard.

In [6]:
start_time = time.time()

# Set up tensorboard logging.
stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
fast_str = "-fast" if use_fast_rotation_protocol else ""
logdir = os.path.abspath("") + f"/tflogs/dp-sgd{fast_str}-{stamp}"
print(f"To start tensorboard, run: tensorboard --logdir ./ --host 0.0.0.0")
print(f"\ttensorboard profiling requires: pip install tensorboard_plugin_profile")
writer = tf.summary.create_file_writer(logdir)

for step, (x_batch, y_batch) in enumerate(train_dataset.take(batch_size)):
    print(
        f"Batch: {step} / {len(train_dataset)}, Time Stamp: {time.time() - start_time}"
    )

    if step == 0:
        tf.summary.trace_on(graph=True, profiler=True, profiler_outdir=logdir)

    train_step_wrapper(x_batch, y_batch)

    if step == 0:
        with writer.as_default():
            tf.summary.trace_export(name="label_dp_sgd", step=step)

    # Check the accuracy.
    average_loss = 0
    average_accuracy = 0
    for x, y in val_dataset:
        y_pred = output_layer(hidden_layer(x))
        loss = tf.reduce_mean(loss_fn(y, y_pred))
        accuracy = tf.reduce_mean(
            tf.cast(
                tf.equal(tf.argmax(y, axis=1), tf.argmax(y_pred, axis=1)), tf.float32
            )
        )
        average_loss += loss
        average_accuracy += accuracy
    average_loss /= len(val_dataset)
    average_accuracy /= len(val_dataset)
    tf.print(f"\taccuracy: {accuracy}")

    with writer.as_default():
        tf.summary.scalar("loss", average_loss, step=step)
        tf.summary.scalar("accuracy", average_accuracy, step=step)


print(f"Total training time: {time.time() - start_time} seconds")

To start tensorboard, run: tensorboard --logdir ./ --host 0.0.0.0
	tensorboard profiling requires: pip install tensorboard_plugin_profile
Batch: 0 / 14, Time Stamp: 0.06904864311218262


2024-08-15 18:41:56.693395: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:104] Profiler session initializing.
2024-08-15 18:41:56.693430: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:119] Profiler session started.


	accuracy: 0.12451171875


2024-08-15 18:42:05.604902: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:70] Profiler session collecting data.
2024-08-15 18:42:05.618155: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:131] Profiler session tear down.
2024-08-15 18:42:05.618925: I external/local_tsl/tsl/profiler/rpc/client/save_profile.cc:144] Collecting XSpace to repository: /workspaces/tf-shell/examples/tflogs/dp-sgd-fast-20240815-184156/plugins/profile/2024_08_15_18_42_05/e64b0b6b3843.xplane.pb
2024-08-15 18:42:05.656513: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 1 / 14, Time Stamp: 9.061949491500854
	accuracy: 0.14697265625


2024-08-15 18:42:12.393712: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 2 / 14, Time Stamp: 15.772825002670288
	accuracy: 0.147216796875


2024-08-15 18:42:18.987119: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 3 / 14, Time Stamp: 22.36575150489807
	accuracy: 0.160888671875


2024-08-15 18:42:25.594964: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 4 / 14, Time Stamp: 28.97419834136963
	accuracy: 0.161865234375


2024-08-15 18:42:32.087415: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 5 / 14, Time Stamp: 35.46611022949219
	accuracy: 0.16357421875


2024-08-15 18:42:38.869085: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 6 / 14, Time Stamp: 42.248823404312134
	accuracy: 0.17333984375


2024-08-15 18:42:45.390588: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 7 / 14, Time Stamp: 48.76966071128845
	accuracy: 0.197509765625


2024-08-15 18:42:51.944617: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 8 / 14, Time Stamp: 55.32348322868347
	accuracy: 0.20751953125


2024-08-15 18:42:58.800163: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 9 / 14, Time Stamp: 62.178874254226685
	accuracy: 0.2138671875


2024-08-15 18:43:05.479284: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 10 / 14, Time Stamp: 68.8585274219513
	accuracy: 0.22412109375


2024-08-15 18:43:12.208747: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 11 / 14, Time Stamp: 75.58803963661194
	accuracy: 0.232421875


2024-08-15 18:43:18.942619: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 12 / 14, Time Stamp: 82.32158493995667
	accuracy: 0.248046875


2024-08-15 18:43:25.422888: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 13 / 14, Time Stamp: 88.801762342453
	accuracy: 0.262451171875
Total training time: 95.44570875167847 seconds


2024-08-15 18:43:32.066184: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-08-15 18:43:32.069279: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
