# Label DP SGD (Post Scale)

This notebook walks through how to train a model to recognize hand written
digits using label differentially private gradient decent and the MNIST dataset.
In this setting, one party has the images and the other party has the labels.
They would like to collaborate to train a model without revealing their data.

This colab uses the post-scale approach to training.

Before starting, install the tf-shell package.

```bash
pip install tf-shell
```

First, import some modules and set up tf-shell. The parameters are for the SHELL
encryption library, which tf-shell uses, and mostly depend on the multiplicative
depth of the computation to be performed. This example performs back
propagation, thus the multiplicative depth is determined by the number of
layers. For more information, see [SHELL](https://github.com/google/shell).

In [1]:
import time
from datetime import datetime
import tensorflow as tf
import keras
import numpy as np
import tf_shell
import tf_shell_ml

2024-04-26 15:20:18.850048: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-04-26 15:20:18.873633: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Num plaintext bits: 19, noise bits: 40
# Max representable value: 61895
context = tf_shell.create_context64(
    log_n=11,
    main_moduli=[576460752303439873],
    plaintext_modulus=557057,
    scaling_factor=3,
    mul_depth_supported=1,
)
# 121 bits of security according to lattice estimator primal_bdd.

# Create the secret key for encryption and a rotation key (rotation key is
# an auxilary key required for operations like roll or matmul).
secret_key = tf_shell.create_key64(context)
public_rotation_key = tf_shell.create_rotation_key64(context, secret_key)

# The batch size is determined by the ciphertext parameters, specifically the
# schemes polynomial's ring degree because tf-shell uses batch axis packing.
# Furthermore, two micro-batches to run in parallel.
batch_size = context.num_slots

Setup MNIST dataset.

In [3]:
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train, x_test = np.reshape(x_train, (-1, 784)), np.reshape(x_test, (-1, 784))
x_train, x_test = x_train / np.float32(255.0), x_test / np.float32(255.0)
y_train, y_test = tf.one_hot(y_train, 10), tf.one_hot(y_test, 10)

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=2048).batch(batch_size)

val_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
val_dataset = val_dataset.batch(batch_size)

Create a simple model with a hidden layer of size 64 and an output layer
of size 10 (for each of the 10 digits).

In [4]:
mnist_layers = [
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dense(10, activation="sigmoid"),
]

model = keras.Sequential(mnist_layers)
model.compile(
    optimizer="adam",
    metrics=["accuracy"],
)

Next, define the `train_step` function which will be called for each batch on an
encrypted batch of labels, y. The function first does a forward on the plaintext
image x to compute a predicted label, then does backpropagation using the
encrypted label y.

In [5]:
def train_step(x, y):
    """One step of training with using the "post scale" approach.

    High level idea:
    For each output class, backprop to compute the gradient but exclude the loss
    function. Now we have a _vector_ of model updates for one sample. The real
    gradient update for the sample is a linear combination of the vector of
    weight updates whose scale is determined by dJ_dyhat (the derivative of the
    loss with respect to the predicted output yhat). Effectively, we have
    factored out dJ_dyhat from the gradient. Separating out dJ_dyhat allows us
    to scale the weight updates easily when the label is secret and the gradient
    must be computed under encryption / multiparty computation because the
    multiplicative depth of the computation is 1, however the number of
    multiplications required now depends on the model size AND the number of
    output classes. In contrast, standard backpropagation only requires
    multiplications proportional to the model size, howver the multiplicative
    depth is proportional to the model depth.
    """

    # Unset the activation function for the last layer so it is not used in
    # computing the gradient. The effect of the last layer activation function
    # is factored out of the gradient computation and accounted for below.
    model.layers[-1].activation = tf.keras.activations.linear

    with tf.GradientTape() as tape:
        y_pred = model(x, training=True)  # forward pass
    grads = tape.jacobian(y_pred, model.trainable_weights)
    # ^  layers list x (batch size x num output classes x weights) matrix
    # dy_pred_j/dW_sample_class


    # Reset the activation function for the last layer and compute the real
    # prediction.
    model.layers[-1].activation = tf.keras.activations.sigmoid
    y_pred = model(x, training=False)

    # Compute y_pred - y (where y is encrypted).
    scalars = y_pred - y  # dJ/dy_pred
    # ^  batch_size x num output classes.

    # Scale each gradient. Since 'scalars' may be a vector of ciphertexts, this
    # requires multiplying plaintext gradient for the specific layer (2d) by the
    # ciphertext (scalar). To do so efficiently under encryption requires
    # flattening and packing the weights, as shown below.
    ps_grads = []
    for layer_grad_full in grads:
        # Remember the original shape of the gradient in order to unpack them
        # after the multiplication so they can be applied to the model.
        batch_sz = layer_grad_full.shape[0]
        num_output_classes = layer_grad_full.shape[1]
        grad_shape = layer_grad_full.shape[2:]

        packable_grad = tf.reshape(layer_grad_full, [batch_sz, num_output_classes, -1])
        # ^  batch_size x num output classes x flattened weights

        # Expand the last dim so that the subsequent multiplication is
        # broadcasted.
        expanded_scalars = tf_shell.expand_dims(scalars, axis=-1)
        # ^ batch_size x num output classes x 1

        # Scale the gradient precursors.
        scaled_grad = packable_grad * expanded_scalars
        # ^ dy_pred/dW * dJ/dy_pred = dJ/dW

        # Sum over the output classes.
        scaled_grad = tf_shell.reduce_sum(scaled_grad, axis=1)
        # ^  batch_size x 1 x flattened weights

        # In the real world, this approach would also likely require clipping
        # the gradient, aggregation, and adding DP noise.

        # Reshape to remove the '1' dimension in the middle.
        scaled_grad = tf_shell.reshape(scaled_grad, [batch_sz] + grad_shape)
        # ^  batch_size x weights

        # Sum over the batch.
        scaled_grad = tf_shell.reduce_sum(scaled_grad, axis=0, rotation_key=public_rotation_key)
        # ^  batch_size x flattened weights
        # Every [i, ...] is the same, the sum over the batching dim axis=0.

        ps_grads.append(scaled_grad)

    return ps_grads


def train_step_wrapper(x_batch, y_batch):
    # Encrypt
    enc_y_batch = tf_shell.to_encrypted(y_batch, secret_key, context)

    # Train
    ps_grads = train_step(x_batch, enc_y_batch)

    # Decrypt
    grads = []
    for enc_g in ps_grads:
        grads.append(tf_shell.to_tensorflow(enc_g, secret_key)[0])
        # ^ take the first element because the grad sum is repeated over the batching dim.

    model.optimizer.apply_gradients(
        zip(
            grads,
            model.trainable_weights
        )
    )

Here is the training loop. Each inner iteration runs a batch of size 2^(11),
then meaures the model accuracy.

Tensorboard can be used to visualize the training progress. See cell output for
command to start tensorboard.

In [6]:
epochs = 1
start_time = time.time()

# Set up tensorboard logging.
stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = "/tmp/tflogs/pt-%s" % stamp
print(f"To start tensorboard, run: tensorboard --logdir /tmp/tflogs")
writer = tf.summary.create_file_writer(logdir)

for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))

    # Iterate over the batches of the dataset.
    for step, (x_batch, y_batch) in enumerate(train_dataset.take(batch_size)):
        print(
            f"Epoch: {epoch}, Batch: {step} / {len(train_dataset)}, Time Stamp: {time.time() - start_time}"
        )

        # Skip the last batch if it is not full for performance.
        if x_batch.shape[0] != batch_size:
            break

        # If using deferred execution, one can trace and profile the training.
        # tf.summary.trace_on(graph=True, profiler=True, profiler_outdir=logdir)

        train_step_wrapper(x_batch, y_batch)

        # with writer.as_default():
        #     tf.summary.trace_export(
        #         name="tf_shell_example_label_dp_sgd", step=(epoch + 1) * step
        #     )

        # Check the accuracy.
        average_loss = 0
        average_accuracy = 0
        for x, y in val_dataset:
            y_pred = model(x, training=False)
            accuracy = tf.reduce_mean(
                tf.cast(
                    tf.equal(tf.argmax(y, axis=1), tf.argmax(y_pred, axis=1)), tf.float32
                )
            )
            average_accuracy += accuracy
        average_loss /= len(val_dataset)
        average_accuracy /= len(val_dataset)
        tf.print(f"\taccuracy: {accuracy}")

        with writer.as_default():
            tf.summary.scalar("loss", average_loss, step=(epoch + 1) * batch_size - 1)
            tf.summary.scalar(
                "accuracy", average_accuracy, step=(epoch + 1) * batch_size - 1
            )


print(f"Total plaintext training time: {time.time() - start_time} seconds")

To start tensorboard, run: tensorboard --logdir /tmp/tflogs

Start of epoch 0
Epoch: 0, Batch: 0 / 30, Time Stamp: 0.06940650939941406
	accuracy: 0.06139380484819412


2024-04-26 15:21:43.872918: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch: 0, Batch: 1 / 30, Time Stamp: 82.75300288200378
	accuracy: 0.08683628588914871


2024-04-26 15:22:52.390485: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch: 0, Batch: 2 / 30, Time Stamp: 151.25495791435242
	accuracy: 0.12721239030361176


2024-04-26 15:24:00.741010: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch: 0, Batch: 3 / 30, Time Stamp: 219.60518217086792
	accuracy: 0.1548672616481781


2024-04-26 15:25:14.001617: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch: 0, Batch: 4 / 30, Time Stamp: 292.86561703681946
	accuracy: 0.17643804848194122


2024-04-26 15:26:22.228981: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch: 0, Batch: 5 / 30, Time Stamp: 361.09285974502563
	accuracy: 0.19081857800483704


2024-04-26 15:27:31.338013: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch: 0, Batch: 6 / 30, Time Stamp: 430.2023301124573
	accuracy: 0.2101769894361496


2024-04-26 15:28:42.044969: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch: 0, Batch: 7 / 30, Time Stamp: 500.9089617729187
	accuracy: 0.21902655065059662


2024-04-26 15:29:54.669820: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch: 0, Batch: 8 / 30, Time Stamp: 573.5338339805603
	accuracy: 0.22400441765785217


2024-04-26 15:31:03.402643: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch: 0, Batch: 9 / 30, Time Stamp: 642.2663412094116
	accuracy: 0.2317477911710739


2024-04-26 15:32:11.294701: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch: 0, Batch: 10 / 30, Time Stamp: 710.1584296226501
	accuracy: 0.24668142199516296


2024-04-26 15:33:19.957185: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch: 0, Batch: 11 / 30, Time Stamp: 778.8208358287811
	accuracy: 0.26493361592292786


2024-04-26 15:34:32.724597: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch: 0, Batch: 12 / 30, Time Stamp: 851.5886828899384
	accuracy: 0.2887168228626251


2024-04-26 15:35:40.807055: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch: 0, Batch: 13 / 30, Time Stamp: 919.6708896160126
	accuracy: 0.3163716793060303


2024-04-26 15:36:49.307414: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch: 0, Batch: 14 / 30, Time Stamp: 988.1711373329163
	accuracy: 0.35011062026023865


2024-04-26 15:37:56.881641: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch: 0, Batch: 15 / 30, Time Stamp: 1055.7455956935883
	accuracy: 0.3794247806072235


2024-04-26 15:39:10.054974: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch: 0, Batch: 16 / 30, Time Stamp: 1128.918863773346
	accuracy: 0.41261062026023865


2024-04-26 15:40:17.709488: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch: 0, Batch: 17 / 30, Time Stamp: 1196.5733196735382
	accuracy: 0.451880544424057


2024-04-26 15:41:25.546967: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch: 0, Batch: 18 / 30, Time Stamp: 1264.4107003211975
	accuracy: 0.4933628439903259


2024-04-26 15:42:33.810712: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch: 0, Batch: 19 / 30, Time Stamp: 1332.6749150753021
	accuracy: 0.5221238732337952


2024-04-26 15:43:46.661953: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch: 0, Batch: 20 / 30, Time Stamp: 1405.525713443756
	accuracy: 0.5365044474601746


2024-04-26 15:44:54.080379: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch: 0, Batch: 21 / 30, Time Stamp: 1472.944087266922
	accuracy: 0.5553097128868103


2024-04-26 15:46:01.651246: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch: 0, Batch: 22 / 30, Time Stamp: 1540.5149257183075
	accuracy: 0.5636062026023865


2024-04-26 15:47:10.968213: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch: 0, Batch: 23 / 30, Time Stamp: 1609.8318963050842
	accuracy: 0.571349561214447


2024-04-26 15:48:23.901182: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch: 0, Batch: 24 / 30, Time Stamp: 1682.7653839588165
	accuracy: 0.5818583965301514


2024-04-26 15:49:33.029080: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch: 0, Batch: 25 / 30, Time Stamp: 1751.8927001953125
	accuracy: 0.5923672318458557


2024-04-26 15:50:41.049063: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch: 0, Batch: 26 / 30, Time Stamp: 1819.9127042293549
	accuracy: 0.6100663542747498


2024-04-26 15:51:49.092682: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch: 0, Batch: 27 / 30, Time Stamp: 1887.9565889835358
	accuracy: 0.6299778819084167


2024-04-26 15:53:02.865500: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch: 0, Batch: 28 / 30, Time Stamp: 1961.7292737960815
	accuracy: 0.6493362784385681
Epoch: 0, Batch: 29 / 30, Time Stamp: 2030.1340026855469
Total plaintext training time: 2030.1345376968384 seconds


2024-04-26 15:54:11.269970: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
