# Label DP SGD (Post Scale)

This notebook walks through how to train a model to recognize hand written
digits using label differentially private gradient decent and the MNIST dataset.
In this setting, one party has the images and the other party has the labels.
They would like to collaborate to train a model without revealing their data.

This colab uses the post-scale approach to training.

Before starting, install the tf-shell package.

```bash
pip install tf-shell
```

First, import some modules and set up tf-shell. The parameters are for the SHELL
encryption library, which tf-shell uses, and mostly depend on the multiplicative
depth of the computation to be performed. This example performs back
propagation, thus the multiplicative depth is determined by the number of
layers. For more information, see [SHELL](https://github.com/google/shell).

In [1]:
import time
from datetime import datetime
import tensorflow as tf
import keras
import numpy as np
import tf_shell
import tf_shell_ml
import os

2024-08-15 17:59:47.764859: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-08-15 17:59:47.908829: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
use_fast_rotation_protocol = True

if use_fast_rotation_protocol:
    # Num plaintext bits: 19, noise bits: 39
    # Max representable value: 61895
    context = tf_shell.create_context64(
        log_n=11,
        main_moduli=[288230376151748609],
        plaintext_modulus=557057,
        scaling_factor=3,
        mul_depth_supported=1,
    )
    # accuracy: 0.83642578125
    # Total training time: 572.2677536010742 seconds
else:
    # Num plaintext bits: 19, noise bits: 39
    # Max representable value: 61895
    context = tf_shell.create_context64(
        log_n=11,
        main_moduli=[288230376151748609],
        plaintext_modulus=557057,
        scaling_factor=3,
        mul_depth_supported=1,
    )
    # accuracy: 0.82861328125
    # Total training time: 2218.1095881462097 seconds
    

# Create the secret key for encryption and a rotation key (rotation key is
# an auxilary key required for operations like roll or matmul).
secret_key = tf_shell.create_key64(context)
public_rotation_key = tf_shell.create_rotation_key64(context, secret_key)
secret_fast_rotation_key = tf_shell.create_fast_rotation_key64(context, secret_key)

# The batch size is determined by the ciphertext parameters, specifically the
# schemes polynomial's ring degree because tf-shell uses batch axis packing.
# Furthermore, two micro-batches to run in parallel.
batch_size = context.num_slots

Setup MNIST dataset.

In [3]:
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train, x_test = np.reshape(x_train, (-1, 784)), np.reshape(x_test, (-1, 784))
x_train, x_test = x_train / np.float32(255.0), x_test / np.float32(255.0)
y_train, y_test = tf.one_hot(y_train, 10), tf.one_hot(y_test, 10)

epochs = 1
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = (
    train_dataset.shuffle(buffer_size=2048)
    .batch(batch_size, drop_remainder=True)
    .repeat(count=epochs)
)

val_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
val_dataset = val_dataset.batch(batch_size, drop_remainder=True)

Create a simple model with a hidden layer of size 64 and an output layer
of size 10 (for each of the 10 digits).

In [4]:
mnist_layers = [
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dense(10, activation="sigmoid"),
]

model = keras.Sequential(mnist_layers)
model.compile(
    optimizer="adam",
    metrics=["accuracy"],
)

Next, define the `train_step` function which will be called for each batch on an
encrypted batch of labels, y. The function first does a forward on the plaintext
image x to compute a predicted label, then does backpropagation using the
encrypted label y.

In [5]:
@tf.function
def train_step(x, y):
    """One step of training with using the "post scale" approach.

    High level idea:
    For each output class, backprop to compute the gradient but exclude the loss
    function. Now we have a _vector_ of model updates for one sample. The real
    gradient update for the sample is a linear combination of the vector of
    weight updates whose scale is determined by dJ_dyhat (the derivative of the
    loss with respect to the predicted output yhat). Effectively, we have
    factored out dJ_dyhat from the gradient. Separating out dJ_dyhat allows us
    to scale the weight updates easily when the label is secret and the gradient
    must be computed under encryption / multiparty computation because the
    multiplicative depth of the computation is 1, however the number of
    multiplications required now depends on the model size AND the number of
    output classes. In contrast, standard backpropagation only requires
    multiplications proportional to the model size, howver the multiplicative
    depth is proportional to the model depth.
    """

    # Unset the activation function for the last layer so it is not used in
    # computing the gradient. The effect of the last layer activation function
    # is factored out of the gradient computation and accounted for below.
    model.layers[-1].activation = tf.keras.activations.linear

    with tf.GradientTape() as tape:
        y_pred = model(x, training=True)  # forward pass
    grads = tape.jacobian(y_pred, model.trainable_weights)
    # ^  layers list x (batch size x num output classes x weights) matrix
    # dy_pred_j/dW_sample_class


    # Reset the activation function for the last layer and compute the real
    # prediction.
    model.layers[-1].activation = tf.keras.activations.sigmoid
    y_pred = model(x, training=False)

    # Compute y_pred - y (where y is encrypted).
    scalars = y.__rsub__(y_pred)  # dJ/dy_pred
    # ^  batch_size x num output classes.

    # Expand the last dim so that the subsequent multiplication is
    # broadcasted.
    scalars = tf_shell.expand_dims(scalars, axis=-1)
    # ^ batch_size x num output classes x 1

    # Scale each gradient. Since 'scalars' may be a vector of ciphertexts, this
    # requires multiplying plaintext gradient for the specific layer (2d) by the
    # ciphertext (scalar). To do so efficiently under encryption requires
    # flattening and packing the weights, as shown below.
    ps_grads = []
    for layer_grad_full in grads:
        # Remember the original shape of the gradient in order to unpack them
        # after the multiplication so they can be applied to the model.
        batch_sz = layer_grad_full.shape[0]
        num_output_classes = layer_grad_full.shape[1]
        grad_shape = layer_grad_full.shape[2:]

        packable_grad = tf.reshape(layer_grad_full, [batch_sz, num_output_classes, -1])
        # ^  batch_size x num output classes x flattened weights

        # Scale the gradient precursors.
        scaled_grad = scalars * packable_grad
        # ^ dJ/dW = dJ/dy_pred * dy_pred/dW 

        # Sum over the output classes.
        scaled_grad = tf_shell.reduce_sum(scaled_grad, axis=1)
        # ^  batch_size x 1 x flattened weights

        # In the real world, this approach would also likely require clipping
        # the gradient, aggregation, and adding DP noise.

        # Reshape to remove the '1' dimension in the middle.
        scaled_grad = tf_shell.reshape(scaled_grad, [batch_sz] + grad_shape)
        # ^  batch_size x weights

        # Sum over the batch.
        if use_fast_rotation_protocol:
            scaled_grad = tf_shell.fast_reduce_sum(scaled_grad)
        else:
            scaled_grad = tf_shell.reduce_sum(scaled_grad, axis=0, rotation_key=public_rotation_key)
        # ^  batch_size x flattened weights
        # Every [i, ...] is the same, the sum over the batching dim axis=0.

        ps_grads.append(scaled_grad)

    return ps_grads


@tf.function
def train_step_wrapper(x_batch, y_batch):
    # Encrypt
    enc_y_batch = tf_shell.to_encrypted(y_batch, secret_key, context)

    # Train
    ps_grads = train_step(x_batch, enc_y_batch)

    # Decrypt
    if use_fast_rotation_protocol:
        decrypt_key = secret_fast_rotation_key
    else:
        decrypt_key = secret_key
    batch_sz = context.num_slots
    top_grads = [tf_shell.to_tensorflow(enc_g, decrypt_key)[0] for enc_g in ps_grads]
    bottom_grads = [tf_shell.to_tensorflow(enc_g, decrypt_key)[batch_sz // 2] for enc_g in ps_grads]
    # ^ take the first element of each batch because the grad sum is repeated over the batching dim.

    model.optimizer.apply_gradients(
        zip(
            top_grads,
            model.trainable_weights
        )
    )
    model.optimizer.apply_gradients(
        zip(
            bottom_grads,
            model.trainable_weights
        )
    )

Here is the training loop. Each inner iteration runs a batch of size 2^(11),
then meaures the model accuracy.

Tensorboard can be used to visualize the training progress. See cell output for
command to start tensorboard.

In [6]:
start_time = time.time()

# Set up tensorboard logging.
stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
fast_str = "-fast" if use_fast_rotation_protocol else ""
logdir = os.path.abspath("") + f"/tflogs/post-scale{fast_str}-{stamp}"
print(f"To start tensorboard, run: tensorboard --logdir ./ --host 0.0.0.0")
print(f"\ttensorboard profiling requires: pip install tensorboard_plugin_profile")
writer = tf.summary.create_file_writer(logdir)

# Iterate over the batches of the dataset.
for step, (x_batch, y_batch) in enumerate(train_dataset.take(batch_size)):
    print(
        f"Batch: {step} / {len(train_dataset)}, Time Stamp: {time.time() - start_time}"
    )

    if step == 0:
        tf.summary.trace_on(graph=True, profiler=True, profiler_outdir=logdir)

    train_step_wrapper(x_batch, y_batch)

    if step == 0:
        with writer.as_default():
            tf.summary.trace_export(name="label_dp_sgd_post_scale", step=step)

    # Check the accuracy.
    average_loss = 0
    average_accuracy = 0
    for x, y in val_dataset:
        y_pred = model(x, training=False)
        loss = tf.reduce_mean(tf.keras.losses.categorical_crossentropy(y, y_pred))
        accuracy = tf.reduce_mean(
            tf.cast(
                tf.equal(tf.argmax(y, axis=1), tf.argmax(y_pred, axis=1)), tf.float32
            )
        )
        average_accuracy += accuracy
        average_loss += loss
    average_loss /= len(val_dataset)
    average_accuracy /= len(val_dataset)
    tf.print(f"\taccuracy: {accuracy}")

    with writer.as_default():
        tf.summary.scalar("loss", average_loss, step=step)
        tf.summary.scalar("accuracy", average_accuracy, step=step)


print(f"Total training time: {time.time() - start_time} seconds")

To start tensorboard, run: tensorboard --logdir ./ --host 0.0.0.0
	tensorboard profiling requires: pip install tensorboard_plugin_profile
Batch: 0 / 29, Time Stamp: 0.07215428352355957


2024-08-15 17:59:50.696763: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:104] Profiler session initializing.
2024-08-15 17:59:50.696786: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:119] Profiler session started.


	accuracy: 0.09521484375


2024-08-15 18:00:36.145366: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:70] Profiler session collecting data.
2024-08-15 18:00:36.156101: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:131] Profiler session tear down.
2024-08-15 18:00:36.157071: I external/local_tsl/tsl/profiler/rpc/client/save_profile.cc:144] Collecting XSpace to repository: /workspaces/tf-shell/examples/tflogs/post-scale-fast-20240815-175950/plugins/profile/2024_08_15_18_00_36/e64b0b6b3843.xplane.pb
2024-08-15 18:00:36.219544: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 1 / 29, Time Stamp: 45.64277410507202
	accuracy: 0.10546875


2024-08-15 18:01:00.531878: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 2 / 29, Time Stamp: 69.91091227531433
	accuracy: 0.123046875


2024-08-15 18:01:24.659200: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 3 / 29, Time Stamp: 94.03831195831299
	accuracy: 0.13818359375


2024-08-15 18:01:48.374830: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 4 / 29, Time Stamp: 117.75612258911133
	accuracy: 0.16552734375


2024-08-15 18:02:11.962214: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 5 / 29, Time Stamp: 141.34227561950684
	accuracy: 0.212890625


2024-08-15 18:02:36.097529: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 6 / 29, Time Stamp: 165.4764814376831
	accuracy: 0.24365234375


2024-08-15 18:03:00.205054: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 7 / 29, Time Stamp: 189.58532667160034
	accuracy: 0.2646484375


2024-08-15 18:03:24.466542: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 8 / 29, Time Stamp: 213.84793329238892
	accuracy: 0.306640625


2024-08-15 18:03:48.593181: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 9 / 29, Time Stamp: 237.97351455688477
	accuracy: 0.3798828125


2024-08-15 18:04:08.998930: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 10 / 29, Time Stamp: 258.37798953056335
	accuracy: 0.462890625


2024-08-15 18:04:28.731907: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 11 / 29, Time Stamp: 278.1109175682068
	accuracy: 0.5400390625


2024-08-15 18:04:48.511234: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 12 / 29, Time Stamp: 297.8900656700134
	accuracy: 0.58837890625


2024-08-15 18:05:08.275190: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 13 / 29, Time Stamp: 317.65425848960876
	accuracy: 0.60302734375


2024-08-15 18:05:28.013728: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 14 / 29, Time Stamp: 337.3930377960205
	accuracy: 0.603515625


2024-08-15 18:05:47.544108: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 15 / 29, Time Stamp: 356.9234027862549
	accuracy: 0.61474609375


2024-08-15 18:06:07.041602: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 16 / 29, Time Stamp: 376.4208209514618
	accuracy: 0.65283203125


2024-08-15 18:06:26.811293: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 17 / 29, Time Stamp: 396.1906876564026
	accuracy: 0.68896484375


2024-08-15 18:06:46.524649: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 18 / 29, Time Stamp: 415.9036076068878
	accuracy: 0.708984375


2024-08-15 18:07:06.422710: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 19 / 29, Time Stamp: 435.80162358283997
	accuracy: 0.71337890625


2024-08-15 18:07:26.066425: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 20 / 29, Time Stamp: 455.4452106952667
	accuracy: 0.70849609375


2024-08-15 18:07:45.676431: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 21 / 29, Time Stamp: 475.05652117729187
	accuracy: 0.71240234375


2024-08-15 18:08:05.388632: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 22 / 29, Time Stamp: 494.76737999916077
	accuracy: 0.72119140625


2024-08-15 18:08:24.969979: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 23 / 29, Time Stamp: 514.3491895198822
	accuracy: 0.73876953125


2024-08-15 18:08:44.472569: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 24 / 29, Time Stamp: 533.8514201641083
	accuracy: 0.7578125


2024-08-15 18:09:03.994114: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 25 / 29, Time Stamp: 553.3729646205902
	accuracy: 0.77587890625


2024-08-15 18:09:23.844049: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 26 / 29, Time Stamp: 573.2234089374542
	accuracy: 0.79736328125


2024-08-15 18:09:43.488776: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 27 / 29, Time Stamp: 592.8677163124084
	accuracy: 0.8154296875


2024-08-15 18:10:03.621227: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Batch: 28 / 29, Time Stamp: 613.0001292228699
	accuracy: 0.8271484375
Total training time: 632.7654550075531 seconds


2024-08-15 18:10:23.386406: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-08-15 18:10:23.389255: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
