# Label DP SGD

This notebook walks through how to train a model to recognize hand written
digits using label differentially private gradient decent and the MNIST dataset.

Before starting, install the tf-shell package.

```bash
pip install tf-shell
```

First, import some modules and set up tf-shell. These parameters are for the
SHELL encryption library, which tf-shell uses. The parameters mostly depend on
the multiplicative depth of the computation to be performed, which in this
example is back propagation, and thus is mostly set by the number of layers. For
more information, see [SHELL](https://github.com/google/shell).

In [1]:
import time
from datetime import datetime
import tensorflow as tf
import keras
import numpy as np
import shell_tensor
import shell_ml

# First set up parameters for the SHELL encryption library.
log_slots = 11
slots = 2**log_slots

# Num plaintext bits: 27, noise bits: 65, num rns moduli: 2
context = shell_tensor.create_context64(
    log_n=11,
    main_moduli=[140737488486401, 140737488498689],
    aux_moduli=[],
    plaintext_modulus=134246401,
    noise_variance=8,
    seed="",
)

# Create the secret key and a rotation key for certain operations.
key = shell_tensor.create_key64(context)
rotation_key = shell_tensor.create_rotation_key64(context, key)

# The most efficient batch size is determined by the ciphertext parameters.
# The batch_size is set to the ciphertext polynomial's ring degree allowing two
# mini-batches to run in parallel.
batch_size = slots
fxp_num_bits = 8  # number of fractional bits to use in fixed-point encoding.
plaintext_dtype = tf.float32

2024-02-02 18:26:55.026140: I tensorflow/core/util/port.cc:111] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-02-02 18:26:55.047894: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Setup MNIST dataset.

In [2]:
# Prepare the dataset.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train, x_test = np.reshape(x_train, (-1, 784)), np.reshape(x_test, (-1, 784))
x_train, x_test = x_train / np.float32(255.0), x_test / np.float32(255.0)
y_train, y_test = tf.one_hot(y_train, 10), tf.one_hot(y_test, 10)

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=2048).batch(batch_size)

val_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
val_dataset = val_dataset.batch(batch_size)

Create a simple model with a hidden layer of size 64 and an output layer
of size 10 (for each of the 10 digits).

In [3]:
# Create the layers
hidden_layer = shell_ml.ShellDense(
    64,
    activation=shell_ml.relu,
    activation_deriv=shell_ml.relu_deriv,
    fxp_fractional_bits=fxp_num_bits,
    weight_dtype=plaintext_dtype,
)
output_layer = shell_ml.ShellDense(
    10,
    fxp_fractional_bits=fxp_num_bits,
    weight_dtype=plaintext_dtype,
)

# Call the layers once to create the weights.
y1 = hidden_layer(tf.zeros((batch_size, 784)))
y2 = output_layer(y1)

loss_fn = shell_ml.CategoricalCrossentropy()
optimizer = shell_ml.Adam()
optimizer.compile([hidden_layer.weights, output_layer.weights])

Here is the custom training loop. The `train_step` function is called for each
batch which first encrypts the input y, does a forward pass on the model in
plaintext to set up gradient precursors, then does backpropagation under
encryption.

In [10]:
stop_after_n_batches = 2
epochs = 1
start_time = time.time()


def train_step(x, enc_y):
    # Forward pass always in plaintext
    y_1 = hidden_layer(x)
    y_pred = output_layer(y_1)

    # Backward pass.
    dJ_dy_pred = loss_fn.grad(enc_y, y_pred)
    (dJ_dw1, dJ_dx1) = output_layer.backward(
        dJ_dy_pred, rotation_key, is_first_layer=False
    )
    (dJ_dw0, _) = hidden_layer.backward(dJ_dx1, rotation_key, is_first_layer=True)

    # dJ_dw1, the output layer gradient, would usually have shape [10] for the
    # 10 classes. tf-shell instead back propagates in two mini-batches per batch
    # resulting in two gradients of shape [10]. Furthermore, the gradients are
    # in an "expanded" form where the gradient is repeated by the size of the
    # mini-batch. Said another way, if real_grad_top/bottom is the "real"
    # gradient of shape [10] from the top/bottom halves of the batch:
    #
    # dJ_dw = tf.concat([
    #   tf.repeat(
    #       tf.expand_dims(real_grad_top, 0), repeats=[batch_sz // 2], axis=0
    #   ),
    #   tf.repeat(
    #       tf.expand_dims(real_grad_bottom, 0), repeats=[batch_sz // 2], axis=0
    #   )
    # ])
    #
    # This repetition is result of the SHELL library using a "packed"
    # representation of ciphertexts for efficiency. As such, if the ciphertexts
    # need to be sent over the network, they may be masked and packed together
    # before being transmitted to the party with the key.
    #
    # Only return the weight gradients at [0], not the bias gradients at [1].
    # The bias is not used in this test.
    return dJ_dw1[0], dJ_dw0[0]


# Set up tensorboard logging.
stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = "/tmp/tflogs/pt-%s" % stamp
print(f"tensorboard --logdir /tmp/tflogs")
writer = tf.summary.create_file_writer(logdir)
tf.summary.trace_on(graph=True, profiler=True)

for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))

    # Check the accuracy.
    average_loss = 0
    average_accuracy = 0
    for x, y in val_dataset:
        y_pred = output_layer(hidden_layer(x))
        accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(y, axis=1), tf.argmax(y_pred, axis=1)), tf.float32))
        average_accuracy += accuracy
    average_loss /= len(val_dataset)
    average_accuracy /= len(val_dataset)
    tf.print(f"Before Training: Epoch {epoch} - Accuracy: {accuracy}")

    # Iterate over the batches of the dataset.
    for step, (x_batch, y_batch) in enumerate(train_dataset.take(stop_after_n_batches)):
        print(
            f"Epoch: {epoch}, Batch: {step} / {len(train_dataset)}, Time Stamp: {time.time() - start_time}"
        )

        x_batch = tf.cast(x_batch, tf.float32)
        y_batch = tf.cast(y_batch, tf.float32)

        # Encrypt the batch of secret labels y.
        enc_y_batch = shell_tensor.to_shell_tensor(
            context, y_batch, fxp_fractional_bits=fxp_num_bits
        ).get_encrypted(key)

        # Run the training step. The top and bottom halves of the batch are
        # treated as two separate mini-batches run in parallel to maximize
        # efficiency.
        enc_output_layer_grad, enc_hidden_layer_grad = train_step(x_batch, enc_y_batch)

        # Decrypt the weight gradients. In practice, the gradients should be
        # noised before decrypting.
        repeated_output_layer_grad = enc_output_layer_grad.get_decrypted(key)
        repeated_hidden_layer_grad = enc_hidden_layer_grad.get_decrypted(key)

        # Apply the gradients to the model. We choose the first dimension at
        # index 0 arbitrarily. The weight gradients are repeated across the
        # first dimension.
        optimizer.grad_to_weight(output_layer.weights, repeated_output_layer_grad[0])
        optimizer.grad_to_weight(hidden_layer.weights, repeated_hidden_layer_grad[0])
        optimizer.grad_to_weight(
            output_layer.weights, repeated_output_layer_grad[batch_size // 2]
        )
        optimizer.grad_to_weight(
            hidden_layer.weights, repeated_hidden_layer_grad[batch_size // 2]
        )
    
    # Check the accuracy.
    average_loss = 0
    average_accuracy = 0
    for x, y in val_dataset:
        y_pred = output_layer(hidden_layer(x))
        accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(y, axis=1), tf.argmax(y_pred, axis=1)), tf.float32))
        average_accuracy += accuracy
    average_loss /= len(val_dataset)
    average_accuracy /= len(val_dataset)
    tf.print(f"After Training: Epoch {epoch} - Accuracy: {accuracy}")


print(f"Total plaintext training time: {time.time() - start_time} seconds")

with writer.as_default():
    tf.summary.trace_export(name="mnist_shell_example", step=0, profiler_outdir=logdir)

tensorboard --logdir /tmp/tflogs

Start of epoch 0
Epoch 0 - Accuracy: 0.07300885021686554
Total plaintext training time: 0.0222165584564209 seconds
Instructions for updating:
use `tf.profiler.experimental.stop` instead.
Instructions for updating:
`tf.python.eager.profiler` has deprecated, use `tf.profiler` instead.
Instructions for updating:
`tf.python.eager.profiler` has deprecated, use `tf.profiler` instead.


2024-02-02 21:45:08.784033: I tensorflow/tsl/profiler/lib/profiler_session.cc:70] Profiler session collecting data.
2024-02-02 21:45:08.886853: I tensorflow/tsl/profiler/lib/profiler_session.cc:131] Profiler session tear down.
