In [1]:
import tensorflow as tf
import tensorflow_quantum as tfq
import flwr as fl
from typing import Optional, Tuple

import cirq
import sympy
import numpy as np
import seaborn as sns
import collections

# visualization tools
%matplotlib inline
import matplotlib.pyplot as plt
from cirq.contrib.svg import SVGCircuit

tf.config.run_functions_eagerly(True)

In [2]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_test = x_test[6000:10000]
y_test = y_test[6000:10000]


x_train, x_test = x_train[..., np.newaxis]/255.0, x_test[..., np.newaxis]/255.0

print("Number of original training examples:", len(x_train))
print("Number of original test examples:", len(x_test))

Number of original training examples: 60000
Number of original test examples: 4000


In [3]:
def filter_36(x, y):
    keep = (y == 3) | (y == 6)
    x, y = x[keep], y[keep]
    y = y == 3
    return x,y

In [4]:
x_train, y_train = filter_36(x_train, y_train)
x_test, y_test = filter_36(x_test, y_test)

print("Number of filtered training examples:", len(x_train))
print("Number of filtered test examples:", len(x_test))

Number of filtered training examples: 12049
Number of filtered test examples: 811


In [6]:
x_train_small = tf.image.resize(x_train, (4,4)).numpy()
x_test_small = tf.image.resize(x_test, (4,4)).numpy()

In [8]:
def remove_contradicting(xs, ys):
    mapping = collections.defaultdict(set)
    orig_x = {}
    # Determine the set of labels for each unique image:
    for x,y in zip(xs,ys):
       orig_x[tuple(x.flatten())] = x
       mapping[tuple(x.flatten())].add(y)

    new_x = []
    new_y = []
    for flatten_x in mapping:
      x = orig_x[flatten_x]
      labels = mapping[flatten_x]
      if len(labels) == 1:
          new_x.append(x)
          new_y.append(next(iter(labels)))
      else:
          # Throw out images that match more than one label.
          pass

    num_uniq_3 = sum(1 for value in mapping.values() if len(value) == 1 and True in value)
    num_uniq_6 = sum(1 for value in mapping.values() if len(value) == 1 and False in value)
    num_uniq_both = sum(1 for value in mapping.values() if len(value) == 2)

    print("Number of unique images:", len(mapping.values()))
    print("Number of unique 3s: ", num_uniq_3)
    print("Number of unique 6s: ", num_uniq_6)
    print("Number of unique contradicting labels (both 3 and 6): ", num_uniq_both)
    print()
    print("Initial number of images: ", len(xs))
    print("Remaining non-contradicting unique images: ", len(new_x))

    return np.array(new_x), np.array(new_y)

In [9]:
x_train_nocon, y_train_nocon = remove_contradicting(x_train_small, y_train)

Number of unique images: 10387
Number of unique 3s:  4912
Number of unique 6s:  5426
Number of unique contradicting labels (both 3 and 6):  49

Initial number of images:  12049
Remaining non-contradicting unique images:  10338


In [10]:
THRESHOLD = 0.5

x_train_bin = np.array(x_train_nocon > THRESHOLD, dtype=np.float32)
x_test_bin = np.array(x_test_small > THRESHOLD, dtype=np.float32)

In [12]:
def convert_to_circuit(image):
    """Encode truncated classical image into quantum datapoint."""
    values = np.ndarray.flatten(image)
    qubits = cirq.GridQubit.rect(4, 4)
    circuit = cirq.Circuit()
    for i, value in enumerate(values):
        if value:
            circuit.append(cirq.X(qubits[i]))
    return circuit


x_train_circ = [convert_to_circuit(x) for x in x_train_bin]
x_test_circ = [convert_to_circuit(x) for x in x_test_bin]

In [15]:
x_train_tfcirc = tfq.convert_to_tensor(x_train_circ)
x_test_tfcirc = tfq.convert_to_tensor(x_test_circ)

In [52]:
class CircuitLayerBuilder():
    def __init__(self, data_qubits, readout):
        self.data_qubits = data_qubits
        self.readout = readout

    def add_layer(self, circuit, gate, prefix):
        for i, qubit in enumerate(self.data_qubits):
            symbol = sympy.Symbol(prefix + '-' + str(i))
            circuit.append(gate(qubit, self.readout)**symbol)

In [54]:
def create_quantum_model():
    """Create a QNN model circuit and readout operation to go along with it."""
    data_qubits = cirq.GridQubit.rect(4, 4)  # a 4x4 grid.
    readout = cirq.GridQubit(-1, -1)         # a single qubit at [-1,-1]
    circuit = cirq.Circuit()

    # Prepare the readout qubit.
    circuit.append(cirq.X(readout))
    circuit.append(cirq.H(readout))

    builder = CircuitLayerBuilder(
        data_qubits = data_qubits,
        readout=readout)

    # Then add layers (experiment by adding more).
    builder.add_layer(circuit, cirq.XX, "xx1")
    builder.add_layer(circuit, cirq.ZZ, "zz1")

    # Finally, prepare the readout qubit.
    circuit.append(cirq.H(readout))

    return circuit, cirq.Z(readout)

In [55]:
model_circuit, model_readout = create_quantum_model()

In [56]:
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(), dtype=tf.string),
    tfq.layers.PQC(model_circuit, model_readout),
])

In [57]:
y_train_hinge = 2.0*y_train_nocon-1.0
y_test_hinge = 2.0*y_test-1.0

In [58]:
def hinge_accuracy(y_true, y_pred):
    y_true = tf.squeeze(y_true) > 0.0
    y_pred = tf.squeeze(y_pred) > 0.0
    result = tf.cast(y_true == y_pred, tf.float32)

    return tf.reduce_mean(result)

In [59]:
model.compile(
    loss=tf.keras.losses.Hinge(),
    optimizer=tf.keras.optimizers.Adam(),
    metrics=[hinge_accuracy])

In [60]:
print(model.summary())

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
pqc_1 (PQC)                  (None, 1)                 32        
Total params: 32
Trainable params: 32
Non-trainable params: 0
_________________________________________________________________
None


In [63]:
quantum_model_history = {'loss': [], 'accuracy': []}


def get_eval_fn(model, x_test_data, y_test_data, model_history):
    """Return an evaluation function for server-side evaluation."""
    def evaluate(weights: fl.common.Weights) -> Optional[Tuple[float, float]]:
        model.set_weights(weights)  # Update model with the latest parameters
        loss, accuracy = model.evaluate(x_test_data, y_test_data)
        model_history['loss'].append(loss)
        model_history['accuracy'].append(accuracy)
        return loss, {"accuracy": accuracy}

    return evaluate


strategy = fl.server.strategy.FedAvg(
    eval_fn=get_eval_fn(model, x_test_tfcirc, y_test),
    fraction_eval=0.2,
    min_eval_clients=2,
    min_available_clients=2,
)

In [65]:
fl.server.start_server(config={"num_rounds": 5}, strategy=strategy)

INFO flower 2021-12-24 15:23:32,457 | app.py:80 | Flower server running (insecure, 5 rounds)
INFO flower 2021-12-24 15:23:32,457 | server.py:118 | Initializing global parameters
INFO flower 2021-12-24 15:23:32,457 | server.py:304 | Requesting initial parameters from one random client
INFO flower 2021-12-24 15:23:48,579 | server.py:307 | Received initial parameters from one random client
INFO flower 2021-12-24 15:23:48,580 | server.py:120 | Evaluating initial parameters




INFO flower 2021-12-24 15:23:50,530 | server.py:127 | initial parameters (loss, other metrics): 1.0157755613327026, {'accuracy': 0.3838505446910858}
INFO flower 2021-12-24 15:23:50,531 | server.py:133 | FL starting
DEBUG flower 2021-12-24 15:24:01,882 | server.py:255 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2021-12-24 15:26:06,979 | server.py:264 | fit_round received 2 results and 0 failures




INFO flower 2021-12-24 15:26:08,865 | server.py:154 | fit progress: (1, 0.9996570348739624, {'accuracy': 0.5460008978843689}, 138.3343628579987)
INFO flower 2021-12-24 15:26:08,866 | server.py:199 | evaluate_round: no clients selected, cancel
DEBUG flower 2021-12-24 15:26:08,866 | server.py:255 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2021-12-24 15:28:10,919 | server.py:264 | fit_round received 2 results and 0 failures




INFO flower 2021-12-24 15:28:13,060 | server.py:154 | fit progress: (2, 0.9697707891464233, {'accuracy': 0.5219624042510986}, 262.5286173199984)
INFO flower 2021-12-24 15:28:13,060 | server.py:199 | evaluate_round: no clients selected, cancel
DEBUG flower 2021-12-24 15:28:13,061 | server.py:255 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2021-12-24 15:30:10,761 | server.py:264 | fit_round received 2 results and 0 failures




INFO flower 2021-12-24 15:30:12,691 | server.py:154 | fit progress: (3, 0.68275386095047, {'accuracy': 0.7117570042610168}, 382.1601249859996)
INFO flower 2021-12-24 15:30:12,692 | server.py:199 | evaluate_round: no clients selected, cancel
DEBUG flower 2021-12-24 15:30:12,692 | server.py:255 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2021-12-24 15:32:08,386 | server.py:264 | fit_round received 2 results and 0 failures




INFO flower 2021-12-24 15:32:10,335 | server.py:154 | fit progress: (4, 0.36644405126571655, {'accuracy': 0.8812281489372253}, 499.80389293199914)
INFO flower 2021-12-24 15:32:10,335 | server.py:199 | evaluate_round: no clients selected, cancel
DEBUG flower 2021-12-24 15:32:10,336 | server.py:255 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2021-12-24 15:34:06,589 | server.py:264 | fit_round received 2 results and 0 failures




INFO flower 2021-12-24 15:34:09,160 | server.py:154 | fit progress: (5, 0.3065353333950043, {'accuracy': 0.8800262212753296}, 618.6294312949994)
INFO flower 2021-12-24 15:34:09,161 | server.py:199 | evaluate_round: no clients selected, cancel
INFO flower 2021-12-24 15:34:09,161 | server.py:172 | FL finished in 618.6303053599986
INFO flower 2021-12-24 15:34:09,162 | app.py:119 | app_fit: losses_distributed []
INFO flower 2021-12-24 15:34:09,162 | app.py:120 | app_fit: metrics_distributed {}
INFO flower 2021-12-24 15:34:09,162 | app.py:121 | app_fit: losses_centralized [(0, 1.0157755613327026), (1, 0.9996570348739624), (2, 0.9697707891464233), (3, 0.68275386095047), (4, 0.36644405126571655), (5, 0.3065353333950043)]
INFO flower 2021-12-24 15:34:09,163 | app.py:122 | app_fit: metrics_centralized {'accuracy': [(0, 0.3838505446910858), (1, 0.5460008978843689), (2, 0.5219624042510986), (3, 0.7117570042610168), (4, 0.8812281489372253), (5, 0.8800262212753296)]}


In [16]:
def create_fair_classical_model():
    # A simple model based off LeNet from https://keras.io/examples/mnist_cnn/
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Flatten(input_shape=(4,4,1)))
    model.add(tf.keras.layers.Dense(2, activation='relu'))
    model.add(tf.keras.layers.Dense(1))
    return model


model = create_fair_classical_model()
model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              optimizer=tf.keras.optimizers.Adam(),
              metrics=['accuracy'])

simple_model_history = {'loss': [], 'accuracy': []}


strategy = fl.server.strategy.FedAvg(
    eval_fn=get_eval_fn(model, x_test_bin, y_test, simple_model_history),
    fraction_eval=0.2,
    min_eval_clients=2,
    min_available_clients=2,
)

In [18]:
fl.server.start_server(config={"num_rounds": 5}, strategy=strategy)

INFO flower 2021-12-24 16:27:02,484 | app.py:80 | Flower server running (insecure, 5 rounds)
INFO flower 2021-12-24 16:27:02,485 | server.py:118 | Initializing global parameters
INFO flower 2021-12-24 16:27:02,486 | server.py:304 | Requesting initial parameters from one random client
INFO flower 2021-12-24 16:27:28,976 | server.py:307 | Received initial parameters from one random client
INFO flower 2021-12-24 16:27:28,976 | server.py:120 | Evaluating initial parameters




  "Even though the tf.config.experimental_run_functions_eagerly "
INFO flower 2021-12-24 16:27:29,161 | server.py:127 | initial parameters (loss, other metrics): 0.6911640167236328, {'accuracy': 0.4882860779762268}
INFO flower 2021-12-24 16:27:29,162 | server.py:133 | FL starting
DEBUG flower 2021-12-24 16:27:55,693 | server.py:255 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2021-12-24 16:27:56,282 | server.py:264 | fit_round received 2 results and 0 failures




INFO flower 2021-12-24 16:27:56,388 | server.py:154 | fit progress: (1, 0.6886425614356995, {'accuracy': 0.4882860779762268}, 27.22556137799984)
INFO flower 2021-12-24 16:27:56,388 | server.py:199 | evaluate_round: no clients selected, cancel
DEBUG flower 2021-12-24 16:27:56,388 | server.py:255 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2021-12-24 16:27:56,652 | server.py:264 | fit_round received 2 results and 0 failures




INFO flower 2021-12-24 16:27:56,783 | server.py:154 | fit progress: (2, 0.6855823993682861, {'accuracy': 0.4882860779762268}, 27.62114687600115)
INFO flower 2021-12-24 16:27:56,784 | server.py:199 | evaluate_round: no clients selected, cancel
DEBUG flower 2021-12-24 16:27:56,784 | server.py:255 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2021-12-24 16:27:57,032 | server.py:264 | fit_round received 2 results and 0 failures




INFO flower 2021-12-24 16:27:57,138 | server.py:154 | fit progress: (3, 0.6807022094726562, {'accuracy': 0.4882860779762268}, 27.9763766140004)
INFO flower 2021-12-24 16:27:57,139 | server.py:199 | evaluate_round: no clients selected, cancel
DEBUG flower 2021-12-24 16:27:57,139 | server.py:255 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2021-12-24 16:27:57,381 | server.py:264 | fit_round received 2 results and 0 failures




INFO flower 2021-12-24 16:27:57,481 | server.py:154 | fit progress: (4, 0.6746236085891724, {'accuracy': 0.4882860779762268}, 28.319192856000882)
INFO flower 2021-12-24 16:27:57,482 | server.py:199 | evaluate_round: no clients selected, cancel
DEBUG flower 2021-12-24 16:27:57,482 | server.py:255 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2021-12-24 16:27:57,720 | server.py:264 | fit_round received 2 results and 0 failures




INFO flower 2021-12-24 16:27:57,822 | server.py:154 | fit progress: (5, 0.667334794998169, {'accuracy': 0.4882860779762268}, 28.659764197000186)
INFO flower 2021-12-24 16:27:57,822 | server.py:199 | evaluate_round: no clients selected, cancel
INFO flower 2021-12-24 16:27:57,823 | server.py:172 | FL finished in 28.66055929200047
INFO flower 2021-12-24 16:27:57,823 | app.py:119 | app_fit: losses_distributed []
INFO flower 2021-12-24 16:27:57,823 | app.py:120 | app_fit: metrics_distributed {}
INFO flower 2021-12-24 16:27:57,823 | app.py:121 | app_fit: losses_centralized [(0, 0.6911640167236328), (1, 0.6886425614356995), (2, 0.6855823993682861), (3, 0.6807022094726562), (4, 0.6746236085891724), (5, 0.667334794998169)]
INFO flower 2021-12-24 16:27:57,824 | app.py:122 | app_fit: metrics_centralized {'accuracy': [(0, 0.4882860779762268), (1, 0.4882860779762268), (2, 0.4882860779762268), (3, 0.4882860779762268), (4, 0.4882860779762268), (5, 0.4882860779762268)]}
