# Label DP SGD

This notebook walks through how to train a model to recognize handwritten
digits using label differentially private gradient descent and the MNIST dataset.

Before starting, install the tfshell package.

First, import some modules and set up the training dataset.

In [None]:
import time
from datetime import datetime
import tensorflow as tf
import keras
import numpy as np
import shell_tensor
import shell_ml

# Number of samples per batch. Labels are encrypted one batch at a time, so
# this must match the SHELL polynomial degree.
batch_size = 1024

# Load MNIST and flatten every image into a 784-element vector.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784))
x_test = np.reshape(x_test, (-1, 784))

# Reserve the last 10,000 training samples for validation.
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]

# Prepare the training dataset. The sample count is not a multiple of
# batch_size, so without drop_remainder=True the last batch of each epoch
# would be smaller than batch_size and would no longer match the polynomial
# degree required for encryption.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=2048).batch(
    batch_size, drop_remainder=True
)

# Prepare the validation dataset. It is not encrypted in this notebook, so a
# partial final batch is kept.
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(batch_size)

Next, set up the SHELL encryption library. The choice of security parameters is
for testing purposes only and is not secure. The parameters heavily depend on
the multiplicative depth of the computation to be performed, which in this case
is backpropagation, thus the number of layers has a high impact.

Here we create a simple model with a hidden layer of size 64 and an output layer
of size 10 (for each of the 10 digits).

In [None]:
# Encryption parameters. log_n=10 corresponds to a polynomial degree of
# 2**10 = 1024, the value batch_size has to match.
ct_params = shell_tensor.shell.ContextParams64(
    modulus=shell_tensor.shell.kModulus59,
    log_n=10,
    log_t=16,
    variance=0,  # Too low for prod. Okay for test.
)
context = shell_tensor.create_context64(ct_params)
prng = shell_tensor.create_prng()
key = shell_tensor.create_key64(context, prng)

# The layers, loss and optimizer are taken from shell_ml, the module imported
# at the top of the notebook. (They were previously referenced through the
# name `label_dp_sgd`, which is never imported and raised a NameError.)
hidden_layer = shell_ml.ShellDense(
    64,
    activation=shell_ml.relu,
    activation_deriv=shell_ml.relu_deriv,
)
output_layer = shell_ml.ShellDense(
    10,
    activation=shell_ml.sigmoid,
    activation_deriv=shell_ml.sigmoid_deriv,
)

# Call the layers once on a dummy batch to create the weights.
y1 = hidden_layer(tf.zeros((batch_size, 784)))
y2 = output_layer(y1)

loss_fn = shell_ml.CategoricalCrossentropy()
optimizer = shell_ml.Adam()
optimizer.compile([hidden_layer.weights, output_layer.weights])

Here is the custom training loop. The `train_step` function is called once per
batch. It first encrypts the labels `y`, then runs a forward pass through the
model in plaintext to set up the gradient precursors, and finally performs
backpropagation under encryption.

In [None]:
# Limits for the training run: the number of passes over the dataset and the
# batch count at which each pass is cut short.
epochs = 1
stop_after_n_batches = 4

# Reference point for the elapsed-time figures printed during training.
start_time = time.time()

@tf.function
def train_step(x, y):
    """Run one training step on a single batch.

    The labels are encrypted, the forward pass runs in plaintext, and
    backpropagation runs on the encrypted labels. The resulting weight
    gradients are decrypted and applied to both layers via the optimizer.

    Args:
        x: Batch of input features, fed to `hidden_layer`.
        y: Batch of labels; cast to int32 and encrypted before use.
    """
    # A real deployment would quantize y to fixed point before encrypting it.
    # That step is skipped here to avoid depending on external libraries.
    # NOTE(review): y appears to reach this point as raw class labels --
    # confirm that loss_fn.grad does not expect one-hot encoded labels.
    labels = tf.cast(y, tf.int32)
    enc_labels = shell_tensor.to_shell_tensor(context, labels).get_encrypted(
        prng, key
    )

    # Forward pass in plaintext.
    hidden_out = hidden_layer(x)
    y_pred = output_layer(hidden_out)

    # Backward pass under encryption, from the output layer back to the hidden
    # layer. The input gradient of the hidden layer is not needed.
    dJ_dy_pred = loss_fn.grad(enc_labels, y_pred)
    dJ_dw1, dJ_dx1 = output_layer.backward(dJ_dy_pred, False, prng, key)
    dJ_dw0, _ = hidden_layer.backward(dJ_dx1, True, prng, key)

    # In practice, the gradients are likely secret and should be aggregated
    # and noised before being decrypted and applied to the weights.
    #
    # The weight gradients are also in an "expanded" form, where every element
    # of a ciphertext polynomial holds the same value: the gradient. So dJ_dw1
    # is a tensor of shape [1024, 10], but the real gradient has shape [10].
    # Said another way,
    #
    # dJ_dw1 = tf.repeat(real_grad, repeats=[1024], axis=0)
    #
    # This repetition may seem wasteful, and it is, but it is a product of the
    # polynomial representation of ciphertexts. As such, decryption may be
    # more efficient if ciphertexts are packed together before being
    # transmitted to the party with the key.
    #
    # Only element 0 of each gradient list is decrypted; element 1 is the
    # bias.
    dec_dw1 = dJ_dw1[0].get_decrypted(key)
    dec_dw0 = dJ_dw0[0].get_decrypted(key)

    # Apply the decrypted gradients, output layer first.
    optimizer.grad_to_weight(output_layer.weights, tf.cast(dec_dw1, tf.float32))
    optimizer.grad_to_weight(hidden_layer.weights, tf.cast(dec_dw0, tf.float32))


# Set up tensorboard logging. Each run writes to its own timestamped
# directory under /tmp/tflogs.
stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = f"/tmp/tflogs/pt-{stamp}"
print("tensorboard --logdir /tmp/tflogs")
writer = tf.summary.create_file_writer(logdir)
tf.summary.trace_on(graph=True, profiler=True)

# The number of batches per epoch does not change, so look it up once.
num_batches = len(train_dataset)

for epoch in range(epochs):
    print(f"\nStart of epoch {epoch}")

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        # Stop once stop_after_n_batches batches have been trained. The check
        # runs before the step: comparing with `==` after the step trained
        # one batch more than the limit.
        if step >= stop_after_n_batches:
            break

        # Log every 2 batches.
        if step % 2 == 0:
            print(f"Epoch: {epoch}, Batch: {step} / {num_batches}, Time: {time.time() - start_time}")

        train_step(x_batch_train, y_batch_train)

# NOTE(review): the message says "plaintext" although backpropagation runs
# under encryption in this notebook -- confirm the intended wording.
print(f"Total plaintext training time: {time.time() - start_time} seconds")

# Export the graph/profiler trace that has been recording since trace_on().
with writer.as_default():
    tf.summary.trace_export(name="my_func_trace", step=0, profiler_outdir=logdir)